mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 04:59:27 -07:00
Merge pull request #9 from MachineEditor/preview_filenames
Preview filenames
This commit is contained in:
commit
bdd783b888
6 changed files with 54 additions and 7 deletions
|
@ -415,7 +415,7 @@ class ModelBase(object):
|
|||
return ( ('loss_src', 0), ('loss_dst', 0) )
|
||||
|
||||
#overridable
|
||||
def onGetPreview(self, sample, for_history=False):
|
||||
def onGetPreview(self, sample, for_history=False, filenames=None):
|
||||
#you can return multiple previews
|
||||
#return [ ('preview_name',preview_rgb), ... ]
|
||||
return []
|
||||
|
@ -447,7 +447,7 @@ class ModelBase(object):
|
|||
return self.target_iter != 0 and self.iter >= self.target_iter
|
||||
|
||||
def get_previews(self):
|
||||
return self.onGetPreview ( self.last_sample )
|
||||
return self.onGetPreview ( self.last_sample, filenames=self.last_sample_filenames )
|
||||
|
||||
def get_static_previews(self):
|
||||
return self.onGetPreview (self.sample_for_preview)
|
||||
|
@ -585,12 +585,19 @@ class ModelBase(object):
|
|||
|
||||
def generate_next_samples(self):
|
||||
sample = []
|
||||
sample_filenames = []
|
||||
for generator in self.generator_list:
|
||||
if generator.is_initialized():
|
||||
sample.append ( generator.generate_next() )
|
||||
batch = generator.generate_next()
|
||||
if type(batch) is tuple:
|
||||
sample.append ( batch[0] )
|
||||
sample_filenames.append( batch[1] )
|
||||
else:
|
||||
sample.append ( batch )
|
||||
else:
|
||||
sample.append ( [] )
|
||||
self.last_sample = sample
|
||||
self.last_sample_filenames = sample_filenames
|
||||
return sample
|
||||
|
||||
#overridable
|
||||
|
|
|
@ -13,6 +13,8 @@ from core.cv2ex import *
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from utils.label_face import label_face_filename
|
||||
|
||||
class AMPModel(ModelBase):
|
||||
|
||||
#override
|
||||
|
@ -888,7 +890,7 @@ class AMPModel(ModelBase):
|
|||
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
|
||||
|
||||
#override
|
||||
def onGetPreview(self, samples, for_history=False):
|
||||
def onGetPreview(self, samples, for_history=False, filenames=None):
|
||||
( (warped_src, target_src, target_srcm, target_srcm_em),
|
||||
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
|
||||
|
||||
|
@ -920,6 +922,10 @@ class AMPModel(ModelBase):
|
|||
|
||||
i = np.random.randint(n_samples) if not for_history else 0
|
||||
|
||||
if filenames is not None and len(filenames) > 0:
|
||||
S[i] = label_face_filename(S[i], filenames[0][i])
|
||||
D[i] = label_face_filename(D[i], filenames[1][i])
|
||||
|
||||
st = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ]
|
||||
st += [ np.concatenate ((SS[i], DD[i], SD_100[i] ), axis=1) ]
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@ from samplelib import *
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from utils.label_face import label_face_filename
|
||||
|
||||
class QModel(ModelBase):
|
||||
#override
|
||||
def on_initialize_options(self):
|
||||
|
@ -287,7 +289,7 @@ class QModel(ModelBase):
|
|||
return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
|
||||
|
||||
#override
|
||||
def onGetPreview(self, samples, for_history=False):
|
||||
def onGetPreview(self, samples, for_history=False, filenames=None):
|
||||
( (warped_src, target_src, target_srcm),
|
||||
(warped_dst, target_dst, target_dstm) ) = samples
|
||||
|
||||
|
@ -297,6 +299,12 @@ class QModel(ModelBase):
|
|||
target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
|
||||
|
||||
n_samples = min(4, self.get_batch_size() )
|
||||
|
||||
if filenames is not None and len(filenames) > 0:
|
||||
for i in range(n_samples):
|
||||
S[i] = label_face_filename(S[i], filenames[0][i])
|
||||
D[i] = label_face_filename(D[i], filenames[1][i])
|
||||
|
||||
result = []
|
||||
st = []
|
||||
for i in range(n_samples):
|
||||
|
|
|
@ -12,6 +12,8 @@ from samplelib import *
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from utils.label_face import label_face_filename
|
||||
|
||||
class SAEHDModel(ModelBase):
|
||||
|
||||
#override
|
||||
|
@ -953,7 +955,7 @@ class SAEHDModel(ModelBase):
|
|||
|
||||
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
|
||||
#override
|
||||
def onGetPreview(self, samples, for_history=False):
|
||||
def onGetPreview(self, samples, for_history=False, filenames=None):
|
||||
( (warped_src, target_src, target_srcm, target_srcm_em),
|
||||
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
|
||||
|
||||
|
@ -965,6 +967,11 @@ class SAEHDModel(ModelBase):
|
|||
|
||||
n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
|
||||
|
||||
if filenames is not None and len(filenames) > 0:
|
||||
for i in range(n_samples):
|
||||
S[i] = label_face_filename(S[i], filenames[0][i])
|
||||
D[i] = label_face_filename(D[i], filenames[1][i])
|
||||
|
||||
if self.resolution <= 256:
|
||||
result = []
|
||||
|
||||
|
|
|
@ -115,6 +115,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
samples, index_host, ct_samples, ct_index_host = param
|
||||
|
||||
bs = self.batch_size
|
||||
filenames = []
|
||||
while True:
|
||||
batches = None
|
||||
|
||||
|
@ -141,4 +142,6 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
for i in range(len(x)):
|
||||
batches[i].append ( x[i] )
|
||||
|
||||
yield [ np.array(batch) for batch in batches]
|
||||
filenames.append(sample.filename)
|
||||
|
||||
yield ([ np.array(batch) for batch in batches], filenames)
|
||||
|
|
16
utils/label_face.py
Normal file
16
utils/label_face.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
import cv2
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
def label_face_filename(face, filename):
|
||||
text = os.path.splitext(os.path.basename(filename))[0]
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
org = (5, face.shape[0] - 10)
|
||||
thickness = 1
|
||||
fontScale = 0.5
|
||||
color = (255, 255, 255)
|
||||
face = face.copy() # numpy array issue
|
||||
cv2.putText(face, text, org, font, fontScale, color, thickness, cv2.LINE_AA)
|
||||
|
||||
return face
|
Loading…
Add table
Add a link
Reference in a new issue