mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
added '?' help for model options. Added 'Src face scale modifier' to model options.
This commit is contained in:
parent
e8620919a7
commit
1a2555e160
8 changed files with 254 additions and 104 deletions
|
@ -92,7 +92,9 @@ DF - good for side faces, but results in a lower resolution and details. Covers
|
||||||
LIAE - can partially fix dissimilar face shapes, but results in a less recognizable face.
|
LIAE - can partially fix dissimilar face shapes, but results in a less recognizable face.
|
||||||
SAE - no matter how similar faces, src face will be morphed onto dst face, which can make face absolutely unrecognizable. Model can collapse on some scenes. Easy to overlay final face because dst background is also predicted.
|
SAE - no matter how similar faces, src face will be morphed onto dst face, which can make face absolutely unrecognizable. Model can collapse on some scenes. Easy to overlay final face because dst background is also predicted.
|
||||||
|
|
||||||
also quality of src faceset significantly affects the final face.
|
Quality of src faceset significantly affects the final face.
|
||||||
|
|
||||||
|
Narrow src face is better fakeable than wide. This is why Cage is so popular in deepfakes.
|
||||||
|
|
||||||
### **Sort tool**:
|
### **Sort tool**:
|
||||||
|
|
||||||
|
|
|
@ -58,11 +58,12 @@ class ModelBase(object):
|
||||||
|
|
||||||
if self.epoch == 0:
|
if self.epoch == 0:
|
||||||
print ("\nModel first run. Enter model options as default for each run.")
|
print ("\nModel first run. Enter model options as default for each run.")
|
||||||
self.options['write_preview_history'] = input_bool("Write preview history? (y/n skip:n) : ", False)
|
self.options['write_preview_history'] = input_bool("Write preview history? (y/n ?:help skip:n) : ", False, help_message="Preview history will be writed to <ModelName>_history folder.")
|
||||||
self.options['target_epoch'] = max(0, input_int("Target epoch (skip:unlimited) : ", 0))
|
self.options['target_epoch'] = max(0, input_int("Target epoch (skip:unlimited) : ", 0))
|
||||||
self.options['batch_size'] = max(0, input_int("Batch_size (skip:model choice) : ", 0))
|
self.options['batch_size'] = max(0, input_int("Batch_size (?:help skip:model choice) : ", 0, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
|
||||||
self.options['sort_by_yaw'] = input_bool("Feed faces to network sorted by yaw? (y/n skip:n) : ", False)
|
self.options['sort_by_yaw'] = input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
|
||||||
self.options['random_flip'] = input_bool("Flip faces randomly? (y/n skip:y) : ", True)
|
self.options['random_flip'] = input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
|
||||||
|
self.options['src_scale_mod'] = np.clip( input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
|
||||||
#self.options['use_fp16'] = use_fp16 = input_bool("Use float16? (y/n skip:n) : ", False)
|
#self.options['use_fp16'] = use_fp16 = input_bool("Use float16? (y/n skip:n) : ", False)
|
||||||
else:
|
else:
|
||||||
self.options['write_preview_history'] = self.options.get('write_preview_history', False)
|
self.options['write_preview_history'] = self.options.get('write_preview_history', False)
|
||||||
|
@ -70,6 +71,7 @@ class ModelBase(object):
|
||||||
self.options['batch_size'] = self.options.get('batch_size', 0)
|
self.options['batch_size'] = self.options.get('batch_size', 0)
|
||||||
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
|
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
|
||||||
self.options['random_flip'] = self.options.get('random_flip', True)
|
self.options['random_flip'] = self.options.get('random_flip', True)
|
||||||
|
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
|
||||||
#self.options['use_fp16'] = use_fp16 = self.options['use_fp16'] if 'use_fp16' in self.options.keys() else False
|
#self.options['use_fp16'] = use_fp16 = self.options['use_fp16'] if 'use_fp16' in self.options.keys() else False
|
||||||
|
|
||||||
use_fp16 = False #currently models fails with fp16
|
use_fp16 = False #currently models fails with fp16
|
||||||
|
@ -106,6 +108,10 @@ class ModelBase(object):
|
||||||
if self.random_flip:
|
if self.random_flip:
|
||||||
self.options.pop('random_flip')
|
self.options.pop('random_flip')
|
||||||
|
|
||||||
|
self.src_scale_mod = self.options['src_scale_mod']
|
||||||
|
if self.src_scale_mod == 0:
|
||||||
|
self.options.pop('src_scale_mod')
|
||||||
|
|
||||||
self.write_preview_history = session_write_preview_history
|
self.write_preview_history = session_write_preview_history
|
||||||
self.target_epoch = session_target_epoch
|
self.target_epoch = session_target_epoch
|
||||||
self.batch_size = session_batch_size
|
self.batch_size = session_batch_size
|
||||||
|
|
|
@ -40,7 +40,7 @@ class Model(ModelBase):
|
||||||
self.set_training_data_generators ([
|
self.set_training_data_generators ([
|
||||||
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
||||||
debug=self.is_debug(), batch_size=self.batch_size,
|
debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
|
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
||||||
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
|
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
|
||||||
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
|
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
|
||||||
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_M | f.FACE_MASK_FULL, 128] ] ),
|
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_M | f.FACE_MASK_FULL, 128] ] ),
|
||||||
|
|
|
@ -46,7 +46,7 @@ class Model(ModelBase):
|
||||||
self.set_training_data_generators ([
|
self.set_training_data_generators ([
|
||||||
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
||||||
debug=self.is_debug(), batch_size=self.batch_size,
|
debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
|
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
||||||
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128],
|
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128],
|
||||||
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128],
|
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128],
|
||||||
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 128] ] ),
|
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 128] ] ),
|
||||||
|
|
|
@ -47,7 +47,7 @@ class Model(ModelBase):
|
||||||
self.set_training_data_generators ([
|
self.set_training_data_generators ([
|
||||||
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
||||||
debug=self.is_debug(), batch_size=self.batch_size,
|
debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
|
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
||||||
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
|
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
|
||||||
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
|
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
|
||||||
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 64] ] ),
|
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 64] ] ),
|
||||||
|
|
|
@ -47,7 +47,7 @@ class Model(ModelBase):
|
||||||
|
|
||||||
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
||||||
debug=self.is_debug(), batch_size=self.batch_size,
|
debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
|
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
||||||
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
|
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
|
||||||
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
|
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
|
||||||
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_M | f.FACE_MASK_FULL, 128] ] ),
|
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_M | f.FACE_MASK_FULL, 128] ] ),
|
||||||
|
|
|
@ -15,21 +15,32 @@ class SAEModel(ModelBase):
|
||||||
decoderH5 = 'decoder.h5'
|
decoderH5 = 'decoder.h5'
|
||||||
decodermH5 = 'decoderm.h5'
|
decodermH5 = 'decoderm.h5'
|
||||||
|
|
||||||
|
decoder_srcH5 = 'decoder_src.h5'
|
||||||
|
decoder_srcmH5 = 'decoder_srcm.h5'
|
||||||
|
decoder_dstH5 = 'decoder_dst.h5'
|
||||||
|
decoder_dstmH5 = 'decoder_dstm.h5'
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onInitializeOptions(self, is_first_run, ask_for_session_options):
|
def onInitializeOptions(self, is_first_run, ask_for_session_options):
|
||||||
default_resolution = 128
|
default_resolution = 128
|
||||||
default_ae_dims = 256
|
default_archi = 'liae'
|
||||||
|
default_style_power = 100
|
||||||
default_face_type = 'f'
|
default_face_type = 'f'
|
||||||
|
|
||||||
if is_first_run:
|
if is_first_run:
|
||||||
#first run
|
#first run
|
||||||
self.options['resolution'] = input_int("Resolution (valid: 64,128, skip:128) : ", default_resolution, [64,128])
|
self.options['resolution'] = input_int("Resolution (64,128, ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
|
||||||
self.options['ae_dims'] = input_int("AutoEncoder dims (valid: 128,256,512 skip:256) : ", default_ae_dims, [128,256,512])
|
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
|
||||||
self.options['face_type'] = input_str ("Half or Full face? (h/f, skip:f) : ", default_face_type, ['h','f'])
|
self.options['style_power'] = np.clip ( input_int("Style power (1..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst style during generalization of src and dst faces."), 1, 100 )
|
||||||
|
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
|
||||||
|
self.options['ae_dims'] = input_int("AutoEncoder dims (128,256,512 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, [128,256,512], help_message="More dims are better, but requires more VRAM." )
|
||||||
|
self.options['face_type'] = input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
|
||||||
else:
|
else:
|
||||||
#not first run
|
#not first run
|
||||||
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
||||||
|
self.options['archi'] = self.options.get('archi', default_archi)
|
||||||
|
self.options['style_power'] = self.options.get('style_power', default_style_power)
|
||||||
|
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
|
||||||
self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
|
self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
|
||||||
self.options['face_type'] = self.options.get('face_type', default_face_type)
|
self.options['face_type'] = self.options.get('face_type', default_face_type)
|
||||||
|
|
||||||
|
@ -44,6 +55,15 @@ class SAEModel(ModelBase):
|
||||||
bgr_shape = (resolution, resolution, 3)
|
bgr_shape = (resolution, resolution, 3)
|
||||||
mask_shape = (resolution, resolution, 1)
|
mask_shape = (resolution, resolution, 1)
|
||||||
|
|
||||||
|
warped_src = Input(bgr_shape)
|
||||||
|
target_src = Input(bgr_shape)
|
||||||
|
target_srcm = Input(mask_shape)
|
||||||
|
|
||||||
|
warped_dst = Input(bgr_shape)
|
||||||
|
target_dst = Input(bgr_shape)
|
||||||
|
target_dstm = Input(mask_shape)
|
||||||
|
|
||||||
|
if self.options['archi'] == 'liae':
|
||||||
self.encoder = modelify(SAEModel.EncFlow() ) (Input(bgr_shape))
|
self.encoder = modelify(SAEModel.EncFlow() ) (Input(bgr_shape))
|
||||||
|
|
||||||
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||||
|
@ -63,10 +83,6 @@ class SAEModel(ModelBase):
|
||||||
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
|
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
|
||||||
self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5))
|
self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5))
|
||||||
|
|
||||||
warped_src = Input(bgr_shape)
|
|
||||||
target_src = Input(bgr_shape)
|
|
||||||
target_srcm = Input(mask_shape)
|
|
||||||
|
|
||||||
warped_src_code = self.encoder (warped_src)
|
warped_src_code = self.encoder (warped_src)
|
||||||
|
|
||||||
warped_src_inter_AB_code = self.inter_AB (warped_src_code)
|
warped_src_inter_AB_code = self.inter_AB (warped_src_code)
|
||||||
|
@ -75,9 +91,6 @@ class SAEModel(ModelBase):
|
||||||
pred_src_src = self.decoder(warped_src_inter_code)
|
pred_src_src = self.decoder(warped_src_inter_code)
|
||||||
pred_src_srcm = self.decoderm(warped_src_inter_code)
|
pred_src_srcm = self.decoderm(warped_src_inter_code)
|
||||||
|
|
||||||
warped_dst = Input(bgr_shape)
|
|
||||||
target_dst = Input(bgr_shape)
|
|
||||||
target_dstm = Input(mask_shape)
|
|
||||||
|
|
||||||
warped_dst_code = self.encoder (warped_dst)
|
warped_dst_code = self.encoder (warped_dst)
|
||||||
warped_dst_inter_B_code = self.inter_B (warped_dst_code)
|
warped_dst_inter_B_code = self.inter_B (warped_dst_code)
|
||||||
|
@ -89,6 +102,36 @@ class SAEModel(ModelBase):
|
||||||
warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
|
warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
|
||||||
pred_src_dst = self.decoder(warped_src_dst_inter_code)
|
pred_src_dst = self.decoder(warped_src_dst_inter_code)
|
||||||
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
||||||
|
else:
|
||||||
|
self.encoder = modelify(SAEModel.DFEncFlow(dims=ae_dims,lowest_dense_res=resolution // 16) ) (Input(bgr_shape))
|
||||||
|
|
||||||
|
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||||
|
|
||||||
|
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],dims=ae_dims)) (dec_Inputs)
|
||||||
|
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],dims=ae_dims)) (dec_Inputs)
|
||||||
|
|
||||||
|
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],dims=ae_dims//2)) (dec_Inputs)
|
||||||
|
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],dims=ae_dims//2)) (dec_Inputs)
|
||||||
|
|
||||||
|
|
||||||
|
if not self.is_first_run():
|
||||||
|
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||||
|
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
|
||||||
|
self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5))
|
||||||
|
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
|
||||||
|
self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
|
||||||
|
|
||||||
|
warped_src_code = self.encoder (warped_src)
|
||||||
|
pred_src_src = self.decoder_src(warped_src_code)
|
||||||
|
pred_src_srcm = self.decoder_srcm(warped_src_code)
|
||||||
|
|
||||||
|
warped_dst_code = self.encoder (warped_dst)
|
||||||
|
pred_dst_dst = self.decoder_dst(warped_dst_code)
|
||||||
|
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
|
||||||
|
|
||||||
|
pred_src_dst = self.decoder_src(warped_dst_code)
|
||||||
|
pred_src_dstm = self.decoder_srcm(warped_dst_code)
|
||||||
|
|
||||||
|
|
||||||
target_srcm_blurred = tf_gaussian_blur(resolution // 32)(target_srcm)
|
target_srcm_blurred = tf_gaussian_blur(resolution // 32)(target_srcm)
|
||||||
target_srcm_sigm = target_srcm_blurred / 2.0 + 0.5
|
target_srcm_sigm = target_srcm_blurred / 2.0 + 0.5
|
||||||
|
@ -105,48 +148,61 @@ class SAEModel(ModelBase):
|
||||||
pred_dst_dst_sigm = pred_dst_dst+1
|
pred_dst_dst_sigm = pred_dst_dst+1
|
||||||
pred_src_dst_sigm = pred_src_dst+1
|
pred_src_dst_sigm = pred_src_dst+1
|
||||||
|
|
||||||
pred_src_dstm_blurred = tf_gaussian_blur(resolution // 32)(pred_src_dstm)
|
|
||||||
pred_src_dstm_sigm = pred_src_dstm_blurred / 2.0 + 0.5
|
|
||||||
pred_src_dstm_anti_sigm = 1.0 - pred_src_dstm_sigm
|
|
||||||
|
|
||||||
target_src_masked = target_src_sigm*target_srcm_sigm
|
target_src_masked = target_src_sigm*target_srcm_sigm
|
||||||
|
|
||||||
target_dst_masked = target_dst_sigm * target_dstm_sigm
|
target_dst_masked = target_dst_sigm * target_dstm_sigm
|
||||||
target_dst_anti_masked = target_dst_sigm * target_dstm_anti_sigm
|
target_dst_anti_masked = target_dst_sigm * target_dstm_anti_sigm
|
||||||
|
|
||||||
target_dst_psd_masked = target_dst_sigm * pred_src_dstm_sigm
|
|
||||||
target_dst_psd_anti_masked = target_dst_sigm * pred_src_dstm_anti_sigm
|
|
||||||
|
|
||||||
pred_src_src_masked = pred_src_src_sigm * target_srcm_sigm
|
pred_src_src_masked = pred_src_src_sigm * target_srcm_sigm
|
||||||
pred_dst_dst_masked = pred_dst_dst_sigm * target_dstm_sigm
|
pred_dst_dst_masked = pred_dst_dst_sigm * target_dstm_sigm
|
||||||
|
|
||||||
psd_target_dst_masked = pred_src_dst_sigm * target_dstm_sigm
|
psd_target_dst_masked = pred_src_dst_sigm * target_dstm_sigm
|
||||||
psd_target_dst_anti_masked = pred_src_dst_sigm * target_dstm_anti_sigm
|
psd_target_dst_anti_masked = pred_src_dst_sigm * target_dstm_anti_sigm
|
||||||
|
|
||||||
psd_psd_masked = pred_src_dst_sigm * pred_src_dstm_sigm
|
|
||||||
psd_psd_anti_masked = pred_src_dst_sigm * pred_src_dstm_anti_sigm
|
style_power = self.options['style_power'] / 100.0
|
||||||
|
|
||||||
src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
|
src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
|
||||||
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2)(psd_target_dst_masked, target_dst_masked)
|
#src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked)
|
||||||
src_loss += K.mean( 100*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
|
src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
|
||||||
#src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2)(psd_psd_masked, target_dst_psd_masked)
|
|
||||||
#src_loss += K.mean( 100*K.square(tf_dssim(2.0)( psd_psd_anti_masked, target_dst_psd_anti_masked )))
|
|
||||||
|
|
||||||
|
if self.options['archi'] == 'liae':
|
||||||
|
src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||||
|
else:
|
||||||
|
src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
|
||||||
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss],
|
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss],
|
||||||
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss, self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights) )
|
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss, src_train_weights) )
|
||||||
|
|
||||||
|
|
||||||
dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) )
|
dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) )
|
||||||
|
|
||||||
|
if self.options['archi'] == 'liae':
|
||||||
|
dst_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||||
|
else:
|
||||||
|
dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
|
||||||
self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss],
|
self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss],
|
||||||
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_loss, self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights) )
|
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_loss, dst_train_weights) )
|
||||||
|
|
||||||
|
|
||||||
src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
|
src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
|
||||||
|
|
||||||
|
if self.options['archi'] == 'liae':
|
||||||
|
src_mask_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
||||||
|
else:
|
||||||
|
src_mask_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
|
||||||
|
|
||||||
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss],
|
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss],
|
||||||
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_mask_loss, self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights) )
|
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_mask_loss, src_mask_train_weights ) )
|
||||||
|
|
||||||
dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm))
|
dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm))
|
||||||
|
|
||||||
|
if self.options['archi'] == 'liae':
|
||||||
|
dst_mask_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
||||||
|
else:
|
||||||
|
dst_mask_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
|
||||||
|
|
||||||
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss],
|
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss],
|
||||||
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_mask_loss, self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights) )
|
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_mask_loss, dst_mask_train_weights) )
|
||||||
|
|
||||||
self.AE_view = K.function ([warped_src, warped_dst],[pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm])
|
self.AE_view = K.function ([warped_src, warped_dst],[pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm])
|
||||||
self.AE_convert = K.function ([warped_dst],[pred_src_dst, pred_src_dstm])
|
self.AE_convert = K.function ([warped_dst],[pred_src_dst, pred_src_dstm])
|
||||||
|
@ -159,7 +215,7 @@ class SAEModel(ModelBase):
|
||||||
self.set_training_data_generators ([
|
self.set_training_data_generators ([
|
||||||
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
||||||
debug=self.is_debug(), batch_size=self.batch_size,
|
debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True),
|
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
|
||||||
output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
|
output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
|
||||||
[f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
|
[f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
|
||||||
[f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution],
|
[f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution],
|
||||||
|
@ -175,12 +231,20 @@ class SAEModel(ModelBase):
|
||||||
])
|
])
|
||||||
#override
|
#override
|
||||||
def onSave(self):
|
def onSave(self):
|
||||||
|
if self.options['archi'] == 'liae':
|
||||||
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
||||||
[self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)],
|
[self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)],
|
||||||
[self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)],
|
[self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)],
|
||||||
[self.decoder, self.get_strpath_storage_for_file(self.decoderH5)],
|
[self.decoder, self.get_strpath_storage_for_file(self.decoderH5)],
|
||||||
[self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)],
|
[self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)],
|
||||||
] )
|
] )
|
||||||
|
else:
|
||||||
|
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
||||||
|
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
|
||||||
|
[self.decoder_srcm, self.get_strpath_storage_for_file(self.decoder_srcmH5)],
|
||||||
|
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)],
|
||||||
|
[self.decoder_dstm, self.get_strpath_storage_for_file(self.decoder_dstmH5)],
|
||||||
|
] )
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onTrainOneEpoch(self, sample):
|
def onTrainOneEpoch(self, sample):
|
||||||
|
@ -216,7 +280,7 @@ class SAEModel(ModelBase):
|
||||||
st.append ( np.concatenate ( (
|
st.append ( np.concatenate ( (
|
||||||
S[i], SS[i], #SM[i],
|
S[i], SS[i], #SM[i],
|
||||||
D[i], DD[i], #DM[i],
|
D[i], DD[i], #DM[i],
|
||||||
SD[i], #SDM[i]
|
SD[i], SDM[i]
|
||||||
), axis=1) )
|
), axis=1) )
|
||||||
|
|
||||||
return [ ('SAE', np.concatenate ( st, axis=0 ) ) ]
|
return [ ('SAE', np.concatenate ( st, axis=0 ) ) ]
|
||||||
|
@ -306,4 +370,54 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def DFEncFlow(dims=512, lowest_dense_res=8):
|
||||||
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
|
|
||||||
|
def downscale (dim):
|
||||||
|
def func(x):
|
||||||
|
return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
|
||||||
|
return func
|
||||||
|
|
||||||
|
def upscale (dim):
|
||||||
|
def func(x):
|
||||||
|
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
|
||||||
|
return func
|
||||||
|
|
||||||
|
def func(input):
|
||||||
|
x = input
|
||||||
|
|
||||||
|
x = downscale(128)(x)
|
||||||
|
x = downscale(256)(x)
|
||||||
|
x = downscale(512)(x)
|
||||||
|
x = downscale(1024)(x)
|
||||||
|
|
||||||
|
x = Dense(dims)(Flatten()(x))
|
||||||
|
x = Dense(lowest_dense_res * lowest_dense_res * dims)(x)
|
||||||
|
x = Reshape((lowest_dense_res, lowest_dense_res, dims))(x)
|
||||||
|
x = upscale(dims)(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
return func
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def DFDecFlow(output_nc,dims,activation='tanh'):
|
||||||
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
|
|
||||||
|
def upscale (dim):
|
||||||
|
def func(x):
|
||||||
|
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
|
||||||
|
return func
|
||||||
|
def func(input):
|
||||||
|
x = input[0]
|
||||||
|
x = upscale(dims)(x)
|
||||||
|
x = upscale(dims//2)(x)
|
||||||
|
x = upscale(dims//4)(x)
|
||||||
|
|
||||||
|
x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
Model = SAEModel
|
Model = SAEModel
|
|
@ -1,26 +1,54 @@
|
||||||
|
|
||||||
|
|
||||||
def input_int(s, default_value, valid_list=None):
|
def input_int(s, default_value, valid_list=None, help_message=None):
|
||||||
|
while True:
|
||||||
try:
|
try:
|
||||||
inp = input(s)
|
inp = input(s)
|
||||||
|
if len(inp) == 0:
|
||||||
|
raise ValueError("")
|
||||||
|
|
||||||
|
if help_message is not None and inp == '?':
|
||||||
|
print (help_message)
|
||||||
|
continue
|
||||||
|
|
||||||
i = int(inp)
|
i = int(inp)
|
||||||
if (valid_list is not None) and (i not in valid_list):
|
if (valid_list is not None) and (i not in valid_list):
|
||||||
return default_value
|
return default_value
|
||||||
return i
|
return i
|
||||||
except:
|
except:
|
||||||
|
print (default_value)
|
||||||
return default_value
|
return default_value
|
||||||
|
|
||||||
def input_bool(s, default_value):
|
def input_bool(s, default_value, help_message=None):
|
||||||
try:
|
while True:
|
||||||
return bool ( {"y":True,"n":False,"1":True,"0":False}.get(input(s).lower(), default_value) )
|
|
||||||
except:
|
|
||||||
return default_value
|
|
||||||
|
|
||||||
def input_str(s, default_value, valid_list=None):
|
|
||||||
try:
|
try:
|
||||||
inp = input(s)
|
inp = input(s)
|
||||||
|
if len(inp) == 0:
|
||||||
|
raise ValueError("")
|
||||||
|
|
||||||
|
if help_message is not None and inp == '?':
|
||||||
|
print (help_message)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return bool ( {"y":True,"n":False,"1":True,"0":False}.get(inp.lower(), default_value) )
|
||||||
|
except:
|
||||||
|
print ( "y" if default_value else "n" )
|
||||||
|
return default_value
|
||||||
|
|
||||||
|
def input_str(s, default_value, valid_list=None, help_message=None):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
inp = input(s)
|
||||||
|
if len(inp) == 0:
|
||||||
|
raise ValueError("")
|
||||||
|
|
||||||
|
if help_message is not None and inp == '?':
|
||||||
|
print (help_message)
|
||||||
|
continue
|
||||||
|
|
||||||
if (valid_list is not None) and (inp.lower() not in valid_list):
|
if (valid_list is not None) and (inp.lower() not in valid_list):
|
||||||
return default_value
|
return default_value
|
||||||
return inp
|
return inp
|
||||||
except:
|
except:
|
||||||
|
print (default_value)
|
||||||
return default_value
|
return default_value
|
Loading…
Add table
Add a link
Reference in a new issue