mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
upd fan segmentator
This commit is contained in:
parent
1f569117c8
commit
034ad3cce5
3 changed files with 27 additions and 37 deletions
|
@ -33,7 +33,6 @@ class FANSegmentator(object):
|
||||||
return self.model.train_on_batch(inp, outp)
|
return self.model.train_on_batch(inp, outp)
|
||||||
|
|
||||||
def extract_from_bgr (self, input_image):
|
def extract_from_bgr (self, input_image):
|
||||||
#return np.clip ( self.model.predict(input_image), 0, 1.0 )
|
|
||||||
return np.clip ( (self.model.predict(input_image) + 1) / 2.0, 0, 1.0 )
|
return np.clip ( (self.model.predict(input_image) + 1) / 2.0, 0, 1.0 )
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -44,8 +43,7 @@ class FANSegmentator(object):
|
||||||
x = FANSegmentator.EncFlow(ngf=ngf)(x)
|
x = FANSegmentator.EncFlow(ngf=ngf)(x)
|
||||||
x = FANSegmentator.DecFlow(ngf=ngf)(x)
|
x = FANSegmentator.DecFlow(ngf=ngf)(x)
|
||||||
model = Model(inp,x)
|
model = Model(inp,x)
|
||||||
model.compile (loss='mse', optimizer=Padam(tf_cpu_mode=2) )
|
model.compile (loss='mse', optimizer=Adam(tf_cpu_mode=2) )
|
||||||
#model.compile (loss='mse', optimizer=Adam(tf_cpu_mode=2) )
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -19,8 +19,10 @@ You can implement your own model. Check examples.
|
||||||
'''
|
'''
|
||||||
class ModelBase(object):
|
class ModelBase(object):
|
||||||
|
|
||||||
#DONT OVERRIDE
|
|
||||||
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, device_args = None):
|
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, device_args = None,
|
||||||
|
ask_write_preview_history=True, ask_target_iter=True, ask_batch_size=True, ask_sort_by_yaw=True,
|
||||||
|
ask_random_flip=True, ask_src_scale_mod=True):
|
||||||
|
|
||||||
device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
|
device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
|
||||||
device_args['cpu_only'] = device_args.get('cpu_only',False)
|
device_args['cpu_only'] = device_args.get('cpu_only',False)
|
||||||
|
@ -58,6 +60,8 @@ class ModelBase(object):
|
||||||
self.options = {}
|
self.options = {}
|
||||||
self.loss_history = []
|
self.loss_history = []
|
||||||
self.sample_for_preview = None
|
self.sample_for_preview = None
|
||||||
|
|
||||||
|
model_data = {}
|
||||||
if self.model_data_path.exists():
|
if self.model_data_path.exists():
|
||||||
model_data = pickle.loads ( self.model_data_path.read_bytes() )
|
model_data = pickle.loads ( self.model_data_path.read_bytes() )
|
||||||
self.iter = max( model_data.get('iter',0), model_data.get('epoch',0) )
|
self.iter = max( model_data.get('iter',0), model_data.get('epoch',0) )
|
||||||
|
@ -75,36 +79,36 @@ class ModelBase(object):
|
||||||
if self.iter == 0:
|
if self.iter == 0:
|
||||||
io.log_info ("\nModel first run. Enter model options as default for each run.")
|
io.log_info ("\nModel first run. Enter model options as default for each run.")
|
||||||
|
|
||||||
if self.iter == 0 or ask_override:
|
if ask_write_preview_history and (self.iter == 0 or ask_override):
|
||||||
default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',False)
|
default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',False)
|
||||||
self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % (yn_str[default_write_preview_history]) , default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
|
self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % (yn_str[default_write_preview_history]) , default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
|
||||||
else:
|
else:
|
||||||
self.options['write_preview_history'] = self.options.get('write_preview_history', False)
|
self.options['write_preview_history'] = self.options.get('write_preview_history', False)
|
||||||
|
|
||||||
if self.iter == 0 or ask_override:
|
if ask_target_iter and (self.iter == 0 or ask_override):
|
||||||
self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0))
|
self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0))
|
||||||
else:
|
else:
|
||||||
self.options['target_iter'] = max(model_data.get('target_iter',0), self.options.get('target_epoch',0))
|
self.options['target_iter'] = max(model_data.get('target_iter',0), self.options.get('target_epoch',0))
|
||||||
if 'target_epoch' in self.options:
|
if 'target_epoch' in self.options:
|
||||||
self.options.pop('target_epoch')
|
self.options.pop('target_epoch')
|
||||||
|
|
||||||
if self.iter == 0 or ask_override:
|
if ask_batch_size and (self.iter == 0 or ask_override):
|
||||||
default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0)
|
default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0)
|
||||||
self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % (default_batch_size), default_batch_size, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
|
self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % (default_batch_size), default_batch_size, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
|
||||||
else:
|
else:
|
||||||
self.options['batch_size'] = self.options.get('batch_size', 0)
|
self.options['batch_size'] = self.options.get('batch_size', 0)
|
||||||
|
|
||||||
if self.iter == 0:
|
if ask_sort_by_yaw and (self.iter == 0):
|
||||||
self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
|
self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
|
||||||
else:
|
else:
|
||||||
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
|
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
|
||||||
|
|
||||||
if self.iter == 0:
|
if ask_random_flip and (self.iter == 0):
|
||||||
self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
|
self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
|
||||||
else:
|
else:
|
||||||
self.options['random_flip'] = self.options.get('random_flip', True)
|
self.options['random_flip'] = self.options.get('random_flip', True)
|
||||||
|
|
||||||
if self.iter == 0:
|
if ask_src_scale_mod and (self.iter == 0):
|
||||||
self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
|
self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
|
||||||
else:
|
else:
|
||||||
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
|
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
|
||||||
|
|
|
@ -9,6 +9,13 @@ from interact import interact as io
|
||||||
|
|
||||||
class Model(ModelBase):
|
class Model(ModelBase):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs,
|
||||||
|
ask_write_preview_history=False,
|
||||||
|
ask_target_iter=False,
|
||||||
|
ask_sort_by_yaw=False,
|
||||||
|
ask_random_flip=False,
|
||||||
|
ask_src_scale_mod=False)
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onInitialize(self):
|
def onInitialize(self):
|
||||||
|
@ -25,17 +32,17 @@ class Model(ModelBase):
|
||||||
|
|
||||||
if self.is_training_mode:
|
if self.is_training_mode:
|
||||||
f = SampleProcessor.TypeFlags
|
f = SampleProcessor.TypeFlags
|
||||||
f_type = f.FACE_ALIGN_FULL #if self.face_type == FaceType.FULL else f.FACE_ALIGN_HALF
|
f_type = f.FACE_ALIGN_FULL
|
||||||
|
|
||||||
self.set_training_data_generators ([
|
self.set_training_data_generators ([
|
||||||
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ),
|
||||||
output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR, self.resolution],
|
output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR, self.resolution],
|
||||||
[f.TRANSFORMED | f_type | f.MODE_M | f.FACE_MASK_FULL, self.resolution]
|
[f.TRANSFORMED | f_type | f.MODE_M | f.FACE_MASK_FULL, self.resolution]
|
||||||
]),
|
]),
|
||||||
|
|
||||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ),
|
||||||
output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR, self.resolution]
|
output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR, self.resolution]
|
||||||
])
|
])
|
||||||
])
|
])
|
||||||
|
@ -59,6 +66,9 @@ class Model(ModelBase):
|
||||||
|
|
||||||
mAA = self.fan_seg.extract_from_bgr([test_A])
|
mAA = self.fan_seg.extract_from_bgr([test_A])
|
||||||
mBB = self.fan_seg.extract_from_bgr([test_B])
|
mBB = self.fan_seg.extract_from_bgr([test_B])
|
||||||
|
|
||||||
|
test_A, test_B, = [ np.clip( (x + 1.0)/2.0, 0.0, 1.0) for x in [test_A, test_B] ]
|
||||||
|
|
||||||
mAA = np.repeat ( mAA, (3,), -1)
|
mAA = np.repeat ( mAA, (3,), -1)
|
||||||
mBB = np.repeat ( mBB, (3,), -1)
|
mBB = np.repeat ( mBB, (3,), -1)
|
||||||
|
|
||||||
|
@ -81,25 +91,3 @@ class Model(ModelBase):
|
||||||
return [ ('FANSegmentator', np.concatenate ( st, axis=0 ) ),
|
return [ ('FANSegmentator', np.concatenate ( st, axis=0 ) ),
|
||||||
('never seen', np.concatenate ( st2, axis=0 ) ),
|
('never seen', np.concatenate ( st2, axis=0 ) ),
|
||||||
]
|
]
|
||||||
|
|
||||||
def predictor_func (self, face):
|
|
||||||
|
|
||||||
face_64_bgr = face[...,0:3]
|
|
||||||
face_64_mask = np.expand_dims(face[...,3],-1)
|
|
||||||
|
|
||||||
x, mx = self.src_view ( [ np.expand_dims(face_64_bgr,0) ] )
|
|
||||||
x, mx = x[0], mx[0]
|
|
||||||
|
|
||||||
return np.concatenate ( (x,mx), -1 )
|
|
||||||
|
|
||||||
#override
|
|
||||||
def get_converter(self):
|
|
||||||
from converters import ConverterMasked
|
|
||||||
return ConverterMasked(self.predictor_func,
|
|
||||||
predictor_input_size=64,
|
|
||||||
output_size=64,
|
|
||||||
face_type=FaceType.HALF,
|
|
||||||
base_erode_mask_modifier=100,
|
|
||||||
base_blur_mask_modifier=100)
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue