Mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-07-06 21:12:07 -07:00
SAE: added test option: 'Apply random color transfer to src faceset'
parent bde700243c · commit a805f81142

9 changed files with 152 additions and 129 deletions
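The new option recolors src samples toward randomly chosen dst samples with RCT (Reinhard-style) color transfer, so the model trains on src faces under dst-like color statistics. For intuition only, here is a minimal numpy/OpenCV sketch of RCT-style transfer; DeepFaceLab's actual implementation lives in its imagelib module and may differ in detail:

    import cv2
    import numpy as np

    def reinhard_color_transfer(src_bgr, ref_bgr):
        # Match the per-channel LAB mean/std of src to a reference image.
        # Inputs are float32 BGR images in [0, 1].
        src = cv2.cvtColor(src_bgr, cv2.COLOR_BGR2LAB)
        ref = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2LAB)
        src_mean, src_std = src.mean(axis=(0, 1)), src.std(axis=(0, 1)) + 1e-6
        ref_mean, ref_std = ref.mean(axis=(0, 1)), ref.std(axis=(0, 1))
        out = (src - src_mean) / src_std * ref_std + ref_mean
        out = cv2.cvtColor(out.astype(np.float32), cv2.COLOR_LAB2BGR)
        return np.clip(out, 0.0, 1.0)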
@@ -59,7 +59,7 @@ class SAEModel(ModelBase):
        default_e_ch_dims = 42
        default_d_ch_dims = default_e_ch_dims // 2
        def_ca_weights = False

        if is_first_run:
            self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dims (32-1024 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
            self.options['e_ch_dims'] = np.clip ( io.input_int("Encoder dims per channel (21-85 ?:help skip:%d) : " % (default_e_ch_dims) , default_e_ch_dims, help_message="More encoder dims help to recognize more facial features, but require more VRAM. You can fine-tune model size to fit your GPU." ), 21, 85 )
@@ -87,16 +87,21 @@ class SAEModel(ModelBase):
            default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power)
            self.options['bg_style_power'] = np.clip ( io.input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power), default_bg_style_power,
                                                                       help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )

+            default_apply_random_ct = False if is_first_run else self.options.get('apply_random_ct', False)
+            self.options['apply_random_ct'] = io.input_bool ("Apply random color transfer to src faceset? (y/n, ?:help skip:%s) : " % (yn_str[default_apply_random_ct]), default_apply_random_ct, help_message="Increase variativity of src samples by apply RCT color transfer from random dst samples.")
+
        else:
            self.options['pixel_loss'] = self.options.get('pixel_loss', False)
            self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
            self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
+            self.options['apply_random_ct'] = self.options.get('apply_random_ct', False)

        if is_first_run:
            self.options['pretrain'] = io.input_bool ("Pretrain the model? (y/n, ?:help skip:n) : ", False, help_message="Pretrain the model with large amount of various faces. This technique may help to train the fake with overly different face shapes and light conditions of src/dst data. Face will be look more like a morphed. To reduce the morph effect, some model files will be initialized but not be updated after pretrain: LIAE: inter_AB.h5 DF: encoder.h5. The longer you pretrain the model the more morphed face will look. After that, save and run the training again.")
        else:
            self.options['pretrain'] = False

    #override
    def onInitialize(self):
        exec(nnlib.import_all(), locals(), globals())
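The hunk above follows SAE's usual option pattern: prompt on the first run (or on override), otherwise fall back to the value saved with the model. A schematic sketch of that pattern, with hypothetical helper names rather than the real ModelBase API:

    def resolve_option(options, key, default, is_first_run, ask):
        # First run: prompt the user; ask() falls back to `default` on skip.
        # Later runs: reuse whatever was saved alongside the model.
        if is_first_run:
            options[key] = ask(default)
        else:
            options[key] = options.get(key, default)
        return options[key]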
@@ -110,13 +115,14 @@ class SAEModel(ModelBase):
        self.pretrain = self.options['pretrain'] = self.options.get('pretrain', False)
        if not self.pretrain:
            self.options.pop('pretrain')

        d_residual_blocks = True
        bgr_shape = (resolution, resolution, 3)
        mask_shape = (resolution, resolution, 1)

        self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1

+        apply_random_ct = self.options.get('apply_random_ct', False)
        masked_training = True

        warped_src = Input(bgr_shape)
@@ -133,8 +139,8 @@ class SAEModel(ModelBase):
        target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]

        common_flow_kwargs = { 'padding': 'zero',
-                              'norm': 'norm',
-                              'act':'' }
+                              'norm': '',
+                              'act':'' }
        models_list = []
        weights_to_load = []
        if 'liae' in self.options['archi']:
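With the multiscale decoder enabled, ms_count is 3 and the mask targets above are built coarsest-first. A quick worked example, assuming resolution=128:

    resolution, ms_count = 128, 3
    mask_shape = (resolution, resolution, 1)
    shapes = [ (mask_shape[0] // (2**i),)*2 + (mask_shape[-1],)
               for i in range(ms_count-1, -1, -1) ]
    print(shapes)  # [(32, 32, 1), (64, 64, 1), (128, 128, 1)]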
@@ -149,11 +155,11 @@ class SAEModel(ModelBase):
            self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs)
            models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]

            if self.options['learn_mask']:
                self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs)
                models_list += [self.decoderm]

            if not self.is_first_run():
                weights_to_load += [ [self.encoder , 'encoder.h5'],
                                     [self.inter_B , 'inter_B.h5'],
@@ -191,12 +197,12 @@ class SAEModel(ModelBase):
            self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
            self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
            models_list += [self.encoder, self.decoder_src, self.decoder_dst]

            if self.options['learn_mask']:
                self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
                self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
                models_list += [self.decoder_srcm, self.decoder_dstm]

            if not self.is_first_run():
                weights_to_load += [ [self.encoder , 'encoder.h5'],
                                     [self.decoder_src, 'decoder_src.h5'],
@@ -217,18 +223,18 @@ class SAEModel(ModelBase):
            pred_src_srcm = self.decoder_srcm(warped_src_code)
            pred_dst_dstm = self.decoder_dstm(warped_dst_code)
            pred_src_dstm = self.decoder_srcm(warped_dst_code)

        if self.is_first_run():
            if self.options.get('ca_weights',False):
                conv_weights_list = []
                for model in models_list:
                    for layer in model.layers:
                        if type(layer) == keras.layers.Conv2D:
                            conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
                CAInitializerMP ( conv_weights_list )
        else:
            self.load_weights_safe(weights_to_load)

        pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]

        if self.options['learn_mask']:
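The first-run branch above collects every Conv2D kernel and hands the list to CAInitializerMP, which computes convolution-aware initial values in parallel worker processes (the 'MP' suffix). A rough sketch of the same gathering pass, with Keras's built-in Orthogonal initializer standing in for the CA math, which is not reproduced here:

    import keras
    import keras.backend as K

    def init_conv_kernels(models_list):
        init = keras.initializers.Orthogonal()  # stand-in for CAInitializerMP
        for model in models_list:
            for layer in model.layers:
                if isinstance(layer, keras.layers.Conv2D):
                    kernel = layer.weights[0]  # the Conv2D kernel variable
                    K.set_value(kernel, K.eval(init(K.int_shape(kernel))))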
@@ -264,7 +270,7 @@ class SAEModel(ModelBase):
        psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
        psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]

        if self.is_training_mode:
            self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
            self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
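The masked and anti-masked arrays above split each prediction into a face part and a background part with a sigmoid mask; the two parts sum back to the original, which is what lets face_style_power and bg_style_power act on disjoint regions. In plain numpy terms:

    import numpy as np

    pred = np.random.rand(4, 4, 3).astype(np.float32)
    m = np.random.rand(4, 4, 1).astype(np.float32)   # sigmoid mask in [0, 1]
    masked, anti_masked = pred * m, pred * (1.0 - m)
    assert np.allclose(masked + anti_masked, pred)   # exact decomposition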
@@ -328,7 +334,7 @@ class SAEModel(ModelBase):
            else:
                self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] )

        else:
            if self.options['learn_mask']:
                self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_dst_dstm[-1], pred_src_dstm[-1] ])
@@ -345,29 +351,31 @@ class SAEModel(ModelBase):
        t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE

-       output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution} ]
-       output_sample_types += [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i) } for i in range(ms_count)]
-       output_sample_types += [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M, t.FACE_MASK_FULL), 'resolution': resolution // (2**i) } for i in range(ms_count)]
-
        training_data_src_path = self.training_data_src_path
        training_data_dst_path = self.training_data_dst_path
        sort_by_yaw = self.sort_by_yaw

        if self.pretrain and self.pretraining_data_path is not None:
            training_data_src_path = self.pretraining_data_path
            training_data_dst_path = self.pretraining_data_path
            sort_by_yaw = False

        self.set_training_data_generators ([
                SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
-                                   debug=self.is_debug(), batch_size=self.batch_size,
+                                   random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
+                                   debug=self.is_debug(), batch_size=self.batch_size,
                                    sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
-                                   output_sample_types=output_sample_types ),
+                                   output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution, 'apply_ct': apply_random_ct} ] + \
+                                                         [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i), 'apply_ct': apply_random_ct } for i in range(ms_count)] + \
+                                                         [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution // (2**i) } for i in range(ms_count)]
+                                   ),

                SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                    sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
-                                   output_sample_types=output_sample_types )
-           ])
+                                   output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution} ] + \
+                                                         [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i)} for i in range(ms_count)] + \
+                                                         [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution // (2**i) } for i in range(ms_count)]
+                                   )
+           ])
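random_ct_samples_path points the src generator at the dst faceset, and 'apply_ct' tags the sample types whose images should be color-transferred; the dst generator gets neither. The implied per-sample flow, sketched with hypothetical names (the real logic lives in SampleGeneratorFace and SampleProcessor):

    import numpy as np

    def make_src_sample(src_img, ct_samples, apply_ct, color_transfer):
        # With apply_ct set, recolor the src image toward a randomly
        # chosen dst sample before warping/augmentation.
        if apply_ct and ct_samples:
            ref = ct_samples[np.random.randint(len(ct_samples))]
            src_img = color_transfer(src_img, ref)  # e.g. the RCT sketch above
        return src_img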
@@ -380,23 +388,23 @@ class SAEModel(ModelBase):
                    [self.inter_B, 'inter_B.h5'],
                    [self.decoder, 'decoder.h5']
                  ]

            if not self.pretrain or self.iter == 0:
                ar += [ [self.inter_AB, 'inter_AB.h5'],
                      ]

            if self.options['learn_mask']:
                ar += [ [self.decoderm, 'decoderm.h5'] ]

        elif 'df' in self.options['archi']:
            if not self.pretrain or self.iter == 0:
                ar += [ [self.encoder, 'encoder.h5'],
                      ]

            ar += [ [self.decoder_src, 'decoder_src.h5'],
                    [self.decoder_dst, 'decoder_dst.h5']
                  ]

            if self.options['learn_mask']:
                ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
                        [self.decoder_dstm, 'decoder_dstm.h5'] ]
@@ -442,15 +450,15 @@ class SAEModel(ModelBase):
        for i in range(0, len(test_S)):
            ar = S[i], SS[i], D[i], DD[i], SD[i]
            st.append ( np.concatenate ( ar, axis=1) )

        result += [ ('SAE', np.concatenate (st, axis=0 )), ]

        if self.options['learn_mask']:
            st_m = []
            for i in range(0, len(test_S)):
                ar = S[i]*test_S_m[i], SS[i], D[i]*test_D_m[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i])
                st_m.append ( np.concatenate ( ar, axis=1) )

            result += [ ('SAE masked', np.concatenate (st_m, axis=0 )), ]

        return result
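The preview code concatenates the five variants of each sample side by side (axis=1), then stacks the rows vertically (axis=0). Shape-wise, assuming 64px previews and two samples:

    import numpy as np

    row = np.concatenate([np.zeros((64, 64, 3))] * 5, axis=1)  # S|SS|D|DD|SD
    grid = np.concatenate([row, row], axis=0)                  # stack samples
    print(grid.shape)  # (128, 320, 3)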
@@ -458,7 +466,7 @@ class SAEModel(ModelBase):
    def predictor_func (self, face):
        if self.options['learn_mask']:
            bgr, mask_dst_dstm, mask_src_dstm = self.AE_convert ([face[np.newaxis,...]])
            mask = mask_dst_dstm[0] * mask_src_dstm[0]
            return bgr[0], mask[...,0]
        else:
            bgr, = self.AE_convert ([face[np.newaxis,...]])
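The converter mask is the pointwise product of the dst-dst mask and the src-dst mask, so a pixel is kept only where both masks agree. For example:

    import numpy as np

    mask_dst_dstm = np.array([[0.9, 0.1]])
    mask_src_dstm = np.array([[0.8, 0.9]])
    print(mask_dst_dstm * mask_src_dstm)  # [[0.72 0.09]]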
@@ -493,13 +501,13 @@ class SAEModel(ModelBase):
        def NormPass(x):
            return x

        def Norm(norm=''):
            if norm == 'bn':
                return BatchNormalization(axis=-1)
            else:
                return NormPass

        def Act(act='', lrelu_alpha=0.1):
            if act == 'prelu':
                return PReLU()
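Norm and Act are small factories returning layer callables, so the flow builders can chain them uniformly whatever common_flow_kwargs says. A usage sketch (x is a hypothetical Keras tensor; only the branches visible above are assumed):

    x = Norm('bn')(x)    # BatchNormalization(axis=-1) applied to x
    x = Norm('')(x)      # NormPass: returns x unchanged
    x = Act('prelu')(x)  # PReLU activation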
@@ -549,7 +557,7 @@ class SAEModel(ModelBase):
        exec (nnlib.import_all(), locals(), globals())
        upscale = partial(SAEModel.upscale, **kwargs)
        downscale = partial(SAEModel.downscale, **kwargs)

        def func(input):
            dims = K.int_shape(input)[-1]*ch_dims
@@ -571,7 +579,7 @@ class SAEModel(ModelBase):
        def func(input):
            x = input[0]
            x = Dense(ae_dims)(x)
            x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
            x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
            x = upscale(ae_dims*2)(x)
@@ -635,8 +643,8 @@ class SAEModel(ModelBase):
            x = downscale(dims*4)(x)
            x = downscale(dims*8)(x)

            x = Dense(ae_dims)(Flatten()(x))
            x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
            x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
            x = upscale(ae_dims)(x)
            return x
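For scale, a worked example of the dense bottleneck above, assuming resolution=128 and ae_dims=512 (lowest_dense_res is resolution // 16 in this model):

    resolution, ae_dims = 128, 512
    lowest_dense_res = resolution // 16                    # 8
    flat = lowest_dense_res * lowest_dense_res * ae_dims   # 8*8*512 = 32768
    # Flatten -> Dense(512) -> Dense(32768) -> Reshape((8, 8, 512)) -> upscale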