mirror of
https://github.com/iperov/DeepFaceLab.git
enhanced SAE model. You should restart training.
New default style power = 2.0. Fixed DF default batch sizes. Updated readme.
This commit is contained in:
parent 22401cecfc
commit 946688567d
5 changed files with 137 additions and 134 deletions
README.md (11 lines changed)
@@ -82,7 +82,9 @@ LIAEF128 Cage video:

 [](https://www.youtube.com/watch?v=mRsexePEVco)

-- **SAE (2GB+)** - Styled AutoEncoder - new superior model based on style loss. Morphing/stylizing done directly by neural network. Face obstructions also reconstructed without any masks. Converter mode 'overlay' should be used. Model has several options on start for fine tuning to fit your GPU.
+- **SAE ( minimum 2GB+, recommended 11GB+ )** - Styled AutoEncoder - new superior model based on style loss. Morphing/stylizing is done directly by the neural network. Face obstructions are also reconstructed without any masks. Converter mode 'overlay' should be used. The model has several options on start for fine-tuning to fit your GPU.
+
+SAE actually contains all other models. Just set the style powers to 0.0 to get the default models.

 [image]
@@ -96,7 +98,6 @@ SAE model Putin-Navalny video: https://www.youtube.com/watch?v=Jj7b3mqx-Mw

 [image]

 ### **Tips and tricks**:

 unfortunately deepfaking is a time/electric-power consuming topic and has a lot of nuances.
@@ -119,13 +120,17 @@ Narrow src face is better fakeable than wide. This is why Cage is so popular in

 SAE tips:

+- SAE actually contains all other models. Just set the style power options to 0.0 to get the default models.
+
 - if the src faceset has more faces than the dst faceset, the model may not converge. In this case try the 'Feed faces to network sorted by yaw' option.

 - if the src face is wider than the dst face, the model may not converge. In this case try decreasing 'Src face scale modifier' to -5.

 - architecture 'df' makes the predicted face look more like src, but if the model does not converge, try the default 'liae'.

-- most scenes converge fine with batch size = 8. In this case better to increase 'encoder/decoder dims per channel' to get more sharp result.
+- if you have a lot of VRAM, you can choose between a batch size that affects the quality of generalization and enc/dec dims that affect image quality.
+
+- if the style power is too high, you will get artifacts before the face becomes sharp.

 ### **Sort tool**:
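A note on the style power tips above: in the SAE loss code changed by this commit, each style power acts as a plain multiplier on an extra loss term, so setting it to 0.0 removes that term entirely and leaves only the default reconstruction loss. A minimal sketch of that gating; the dist helper is a hypothetical stand-in for the repo's tf_dssim and tf_style_loss terms, not the actual implementation:

import numpy as np

def dist(a, b):
    # Hypothetical stand-in for the masked DSSIM / style-loss terms.
    return float(np.mean(np.square(a - b)))

def src_total_loss(pred_src_src, target_src, pred_src_dst, target_dst,
                   face_style_power=2.0, bg_style_power=2.0):
    loss = 100 * dist(pred_src_src, target_src)            # base reconstruction term
    if face_style_power != 0:                              # 0.0 disables the term,
        loss += (0.2 * face_style_power / 100.0) * dist(pred_src_dst, target_dst)
    if bg_style_power != 0:                                # recovering the default model
        loss += (100 * bg_style_power / 100.0) * dist(pred_src_dst, target_dst)
    return loss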
@@ -163,10 +163,10 @@ def previewThread (input_queue, output_queue):

        if update_preview:
            update_preview = False
-           (h,w,c) = previews[0][1].shape

            selected_preview_name = previews[selected_preview][0]
            selected_preview_rgb = previews[selected_preview][1]
+           (h,w,c) = selected_preview_rgb.shape

            # HEAD
            head_text_color = [0.8]*c
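Why the shape lookup moved: previews in the list can have different shapes, so the header strip must be sized from the preview actually selected, not from the first one. A small sketch with hypothetical preview sizes:

import numpy as np

# Hypothetical previews of different sizes: a model preview and a loss-history plot.
previews = [ ('SAE', np.zeros( (128, 640, 3) )),
             ('loss history', np.zeros( (64, 320, 3) )) ]
selected_preview = 1

selected_preview_name, selected_preview_rgb = previews[selected_preview]
(h, w, c) = selected_preview_rgb.shape      # not previews[0][1].shape
head = np.full( (16, w, c), 0.8 )           # header strip now matches the selected preview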
@@ -14,7 +14,7 @@ class Model(ModelBase):

    #override
    def onInitialize(self, **in_options):
        exec(nnlib.import_all(), locals(), globals())
-       self.set_vram_batch_requirements( {4.5:16,5:16,6:16,7:16,8:24,9:24,10:32,11:32,12:32,13:48} )
+       self.set_vram_batch_requirements( {4.5:4,5:6,6:8,7:16,8:24,9:24,10:32,11:32,12:32,13:48} )

        ae_input_layer = Input(shape=(128, 128, 3))
        mask_layer = Input(shape=(128, 128, 1)) #same as output
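The mapping passed to set_vram_batch_requirements pairs available VRAM in GB with a default batch size; this commit lowers the 4.5 to 6 GB tiers for the DF model. A sketch of how such a table could resolve a batch size; pick_batch_size is a hypothetical helper, not the actual ModelBase code:

def pick_batch_size(vram_gb, requirements):
    # Hypothetical helper: use the batch size of the largest
    # VRAM tier that still fits into the available memory.
    fitting = [bs for gb, bs in sorted(requirements.items()) if gb <= vram_gb]
    return fitting[-1] if fitting else 1

reqs = {4.5: 4, 5: 6, 6: 8, 7: 16, 8: 24, 9: 24, 10: 32, 11: 32, 12: 32, 13: 48}
print(pick_batch_size(6.0, reqs))   # -> 8  (was 16 before this commit)
print(pick_batch_size(11.0, reqs))  # -> 32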
@@ -28,21 +28,21 @@ class SAEModel(ModelBase):

        if is_first_run:
            self.options['resolution'] = input_int("Resolution (64,128 ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
-           self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
+           self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
            self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, sacrificing overall quality.")
        else:
            self.options['resolution'] = self.options.get('resolution', default_resolution)
            self.options['archi'] = self.options.get('archi', default_archi)
            self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)

-       default_face_style_power = 10.0
+       default_face_style_power = 2.0
        if is_first_run or ask_override:
            default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
            self.options['face_style_power'] = np.clip ( input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_face_style_power), default_face_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0.0, 100.0 )
        else:
            self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)

-       default_bg_style_power = 10.0
+       default_bg_style_power = 2.0
        if is_first_run or ask_override:
            default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power)
            self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0.0, 100.0 )
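The option prompts above all follow one pattern: on a first run, or when the user asks to override, the value is asked interactively with the previous value as the skip default and clipped to a valid range; otherwise the stored value is reused. A condensed sketch of that pattern; ask_option and its prompt format are hypothetical stand-ins for the repo's input_number helper:

import numpy as np

def ask_option(options, key, default, is_first_run, ask_override, lo=0.0, hi=100.0):
    # Hypothetical condensation of the prompt pattern above: prompt on first
    # run (or on override), otherwise reuse the stored value.
    if is_first_run or ask_override:
        if not is_first_run:
            default = options.get(key, default)   # previous value becomes the skip default
        raw = input("%s ( %.1f .. %.1f skip:%.1f) : " % (key, lo, hi, default)) or default
        options[key] = float(np.clip(float(raw), lo, hi))
    else:
        options[key] = options.get(key, default)
    return options[key]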
@@ -103,7 +103,6 @@ class SAEModel(ModelBase):
            self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (inter_output_Inputs)
            self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)
-

            if not self.is_first_run():
                self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
                self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
@@ -137,11 +136,9 @@ class SAEModel(ModelBase):

            self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
            self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)

            self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
            self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
-
-

            if not self.is_first_run():
                self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
                self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
@@ -150,100 +147,91 @@ class SAEModel(ModelBase):
                self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))

        warped_src_code = self.encoder (warped_src)
        pred_src_src = self.decoder_src(warped_src_code)
        pred_src_srcm = self.decoder_srcm(warped_src_code)

        warped_dst_code = self.encoder (warped_dst)
        pred_dst_dst = self.decoder_dst(warped_dst_code)
        pred_dst_dstm = self.decoder_dstm(warped_dst_code)

        pred_src_dst = self.decoder_src(warped_dst_code)
        pred_src_dstm = self.decoder_srcm(warped_dst_code)

-       target_srcm_blurred = tf_gaussian_blur(resolution // 32)(target_srcm)
-       target_srcm_sigm = target_srcm_blurred / 2.0 + 0.5
-       target_srcm_anti_sigm = 1.0 - target_srcm_sigm
-
-       target_dstm_blurred = tf_gaussian_blur(resolution // 32)(target_dstm)
-       target_dstm_sigm = target_dstm_blurred / 2.0 + 0.5
-       target_dstm_anti_sigm = 1.0 - target_dstm_sigm
-
-       target_src_sigm = target_src+1
-       target_dst_sigm = target_dst+1
-
-       pred_src_src_sigm = pred_src_src+1
-       pred_dst_dst_sigm = pred_dst_dst+1
-       pred_src_dst_sigm = pred_src_dst+1
-
-       target_src_masked = target_src_sigm*target_srcm_sigm
-       target_dst_masked = target_dst_sigm * target_dstm_sigm
-       target_dst_anti_masked = target_dst_sigm * target_dstm_anti_sigm
-
-       pred_src_src_masked = pred_src_src_sigm * target_srcm_sigm
-       pred_dst_dst_masked = pred_dst_dst_sigm * target_dstm_sigm
-
-       psd_target_dst_masked = pred_src_dst_sigm * target_dstm_sigm
-       psd_target_dst_anti_masked = pred_src_dst_sigm * target_dstm_anti_sigm
-
-       src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
-
-       def optimizer():
-           return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
-
-       if self.options['face_style_power'] != 0:
-           face_style_power = self.options['face_style_power'] / 100.0
-           src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)(psd_target_dst_masked, target_dst_masked)
-
-       if self.options['bg_style_power'] != 0:
-           bg_style_power = self.options['bg_style_power'] / 100.0
-           src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
-
-       if self.options['archi'] == 'liae':
-           src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
-       else:
-           src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
-       self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss],
-                                    optimizer().get_updates(src_loss, src_train_weights) )
-
-       dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) )
-       if self.options['archi'] == 'liae':
-           dst_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
-       else:
-           dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
-       self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss],
-                                    optimizer().get_updates(dst_loss, dst_train_weights) )
-
-       src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
-       if self.options['archi'] == 'liae':
-           src_mask_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
-       else:
-           src_mask_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
-       self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss],
-                                         optimizer().get_updates(src_mask_loss, src_mask_train_weights ) )
-
-       dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm))
-       if self.options['archi'] == 'liae':
-           dst_mask_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
-       else:
-           dst_mask_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
-       self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss],
-                                         optimizer().get_updates(dst_mask_loss, dst_mask_train_weights) )
-
-       self.AE_view = K.function ([warped_src, warped_dst],[pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm])
-       self.AE_convert = K.function ([warped_dst],[pred_src_dst, pred_src_dstm])
+       ms_count = len(pred_src_src)
+
+       target_src_ar = [ target_src if i == 0 else tf.image.resize_bicubic( target_src, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
+       target_srcm_ar = [ target_srcm if i == 0 else tf.image.resize_bicubic( target_srcm, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
+       target_dst_ar = [ target_dst if i == 0 else tf.image.resize_bicubic( target_dst, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
+       target_dstm_ar = [ target_dstm if i == 0 else tf.image.resize_bicubic( target_dstm, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
+
+       target_srcm_blurred_ar = [ tf_gaussian_blur( max(1, x.get_shape().as_list()[1] // 32) )(x) for x in target_srcm_ar]
+       target_srcm_sigm_ar = [ x / 2.0 + 0.5 for x in target_srcm_blurred_ar]
+       target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar]
+
+       target_dstm_blurred_ar = [ tf_gaussian_blur( max(1, x.get_shape().as_list()[1] // 32) )(x) for x in target_dstm_ar]
+       target_dstm_sigm_ar = [ x / 2.0 + 0.5 for x in target_dstm_blurred_ar]
+       target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar]
+
+       target_src_sigm_ar = [ x + 1 for x in target_src_ar]
+       target_dst_sigm_ar = [ x + 1 for x in target_dst_ar]
+
+       pred_src_src_sigm_ar = [ x + 1 for x in pred_src_src]
+       pred_dst_dst_sigm_ar = [ x + 1 for x in pred_dst_dst]
+       pred_src_dst_sigm_ar = [ x + 1 for x in pred_src_dst]
+
+       target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
+       target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
+       target_dst_anti_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
+
+       psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
+       psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
+
+       if self.is_training_mode:
+           def optimizer():
+               return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
+
+           if self.options['archi'] == 'liae':
+               src_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
+               src_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
+               dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
+               dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
+           else:
+               src_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
+               src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
+               dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
+               dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
+
+           src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
+
+           if self.options['face_style_power'] != 0:
+               face_style_power = self.options['face_style_power'] / 100.0
+               src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )
+
+           if self.options['bg_style_power'] != 0:
+               bg_style_power = self.options['bg_style_power'] / 100.0
+               src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
+
+           self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) )
+
+           src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
+           self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], optimizer().get_updates(src_mask_loss, src_mask_loss_train_weights) )
+
+           dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
+           self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) )
+
+           dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
+           self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], optimizer().get_updates(dst_mask_loss, dst_mask_loss_train_weights) )
+
+           self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_src_srcm[-1], pred_dst_dst[-1], pred_dst_dstm[-1], pred_src_dst[-1], pred_src_dstm[-1]] )
+       else:
+           self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_src_dstm[-1] ])

        if self.is_training_mode:
            f = SampleProcessor.TypeFlags
            face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF

            self.set_training_data_generators ([
                SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
                                    debug=self.is_debug(), batch_size=self.batch_size,
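The rewritten block above replaces the single-scale losses with per-scale sums over a coarse-to-fine pyramid: each target is resized to resolution // 2**i for every decoder output scale, and the DSSIM terms are summed across scales. A standalone NumPy sketch of the same pyramid construction; box downsampling is an assumption standing in for tf.image.resize_bicubic, and the batch dimension is omitted:

import numpy as np

def build_pyramid(target, ms_count, resolution):
    # Mirrors the list comprehensions above: coarsest scale first,
    # full resolution last.
    pyramid = []
    for i in range(ms_count - 1, -1, -1):
        if i == 0:
            pyramid.append(target)
        else:
            f = 2 ** i
            # Average over f x f blocks as a crude stand-in for bicubic resize.
            small = target.reshape(resolution // f, f, resolution // f, f, -1).mean(axis=(1, 3))
            pyramid.append(small)
    return pyramid

target = np.random.rand(128, 128, 3).astype(np.float32)
scales = build_pyramid(target, ms_count=3, resolution=128)
print([t.shape for t in scales])   # [(32, 32, 3), (64, 64, 3), (128, 128, 3)]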
@@ -257,7 +245,8 @@ class SAEModel(ModelBase):
                                    sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True),
                                    output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                                          [f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
-                                                         [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution] ] )
+                                                         [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution]
+                                                       ] )
            ])

    #override
    def onSave(self):
@@ -297,23 +286,18 @@ class SAEModel(ModelBase):
        test_B = sample[1][1][0:4]
        test_B_m = sample[1][2][0:4]

-       S = test_A
-       D = test_B
-
-       SS, SM, DD, DM, SD, SDM = self.AE_view ([test_A, test_B])
-       S, D, SS, SM, DD, DM, SD, SDM = [ x / 2 + 0.5 for x in [S, D, SS, SM, DD, DM, SD, SDM] ]
-
-       SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ]
-
-       st = []
+       S, D, SS, SM, DD, DM, SD, SDM, = [ x / 2 + 0.5 for x in ([test_A,test_B] + self.AE_view ([test_A, test_B]) ) ]
+       #SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ]
+
+       st_x3 = []
        for i in range(0, len(test_A)):
-           st.append ( np.concatenate ( (
+           st_x3.append ( np.concatenate ( (
                S[i], SS[i], #SM[i],
                D[i], DD[i], #DM[i],
                SD[i], #SDM[i]
                ), axis=1) )

-       return [ ('SAE', np.concatenate ( st, axis=0 ) ) ]
+       return [ ('SAE', np.concatenate (st_x3, axis=0 )), ]

    def predictor_func (self, face):
        face = face * 2.0 - 1.0
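The preview refactor above folds the normalize step into one list comprehension and keeps the same montage layout: one horizontal strip per sample, strips stacked vertically. A small NumPy sketch of that layout with hypothetical 64x64 previews:

import numpy as np

# Hypothetical 4-sample batch of 64x64 BGR previews.
S, SS, D, DD, SD = [np.random.rand(4, 64, 64, 3) for _ in range(5)]

st_x3 = []
for i in range(len(S)):
    # One row per sample: source, its reconstruction, destination,
    # its reconstruction, and the src-over-dst prediction, side by side.
    st_x3.append(np.concatenate((S[i], SS[i], D[i], DD[i], SD[i]), axis=1))

grid = np.concatenate(st_x3, axis=0)   # rows stacked vertically
print(grid.shape)                      # (256, 320, 3)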
@@ -371,7 +355,7 @@ class SAEModel(ModelBase):
                x = downscale(ed_dims*8)(x)
            else:
                x = downscale_sep(ed_dims*2)(x)
-               x = downscale_sep(ed_dims*4)(x)
+               x = downscale(ed_dims*4)(x)
                x = downscale_sep(ed_dims*8)(x)

            x = Flatten()(x)
@@ -400,26 +384,31 @@ class SAEModel(ModelBase):
    @staticmethod
    def LIAEDecFlow(output_nc,ed_ch_dims=21,activation='tanh'):
        exec (nnlib.import_all(), locals(), globals())
+       ed_dims = output_nc * ed_ch_dims

        def upscale (dim):
            def func(x):
                return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
            return func

+       def to_bgr ():
+           def func(x):
+               return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
+           return func
+
        def func(input):
-           ed_dims = output_nc * ed_ch_dims
-
            x = input[0]
-           x = upscale(ed_dims*8)(x)
-           x = upscale(ed_dims*4)(x)
-           x = upscale(ed_dims*2)(x)
-
-           x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x)
-           return x
+           x1 = upscale(ed_dims*8)( x )
+           x1_bgr = to_bgr() ( x1 )
+
+           x2 = upscale(ed_dims*4)( x1 )
+           x2_bgr = to_bgr() ( x2 )
+
+           x3 = upscale(ed_dims*2)( x2 )
+           x3_bgr = to_bgr() ( x3 )
+
+           return [ x1_bgr, x2_bgr, x3_bgr ]
        return func

    @staticmethod
    def DFEncFlow(resolution, adapt_k_size, light_enc, ae_dims=512, ed_ch_dims=42):
        exec (nnlib.import_all(), locals(), globals())
@@ -466,7 +455,7 @@ class SAEModel(ModelBase):
        return func

    @staticmethod
-   def DFDecFlow(output_nc, ed_ch_dims=21, activation='tanh'):
+   def DFDecFlow(output_nc, ed_ch_dims=21):
        exec (nnlib.import_all(), locals(), globals())
        ed_dims = output_nc * ed_ch_dims
@@ -474,15 +463,23 @@ class SAEModel(ModelBase):
            def func(x):
                return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
            return func

+       def to_bgr ():
+           def func(x):
+               return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
+           return func
+
        def func(input):
            x = input[0]
-           x = upscale(ed_dims*8)(x)
-           x = upscale(ed_dims*4)(x)
-           x = upscale(ed_dims*2)(x)
-
-           x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x)
-           return x
+           x1 = upscale(ed_dims*8)( x )
+           x1_bgr = to_bgr() ( x1 )
+
+           x2 = upscale(ed_dims*4)( x1 )
+           x2_bgr = to_bgr() ( x2 )
+
+           x3 = upscale(ed_dims*2)( x2 )
+           x3_bgr = to_bgr() ( x3 )
+
+           return [ x1_bgr, x2_bgr, x3_bgr ]
        return func

Model = SAEModel
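Both decoder flows now emit three BGR images at 1/4, 1/2, and full resolution, which is what makes ms_count = len(pred_src_src) and the per-scale loss sums earlier in this commit work. A minimal Keras sketch of that branching pattern; tensorflow.keras and UpSampling2D are assumptions standing in for the repo's nnlib aliases and its SubpixelUpscaler block, and the input shape is hypothetical:

from tensorflow import keras
from tensorflow.keras import layers

def upscale(dim):
    # Assumption: UpSampling2D stands in for the repo's subpixel upscaler.
    def func(x):
        x = layers.Conv2D(dim, 3, padding='same')(x)
        x = layers.LeakyReLU(0.1)(x)
        return layers.UpSampling2D()(x)
    return func

def to_bgr(output_nc=3):
    def func(x):
        return layers.Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
    return func

inp = keras.Input(shape=(16, 16, 256))   # hypothetical decoder input tensor
x1 = upscale(128)(inp)                   # 32x32
x2 = upscale(64)(x1)                     # 64x64
x3 = upscale(32)(x2)                     # 128x128
decoder = keras.Model(inp, [to_bgr()(x1), to_bgr()(x2), to_bgr()(x3)])
print([o.shape for o in decoder.outputs])  # three BGR maps, coarse to fine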
@@ -75,6 +75,7 @@ Conv2D = keras.layers.Conv2D
 Conv2DTranspose = keras.layers.Conv2DTranspose
 SeparableConv2D = keras.layers.SeparableConv2D
 MaxPooling2D = keras.layers.MaxPooling2D
+UpSampling2D = keras.layers.UpSampling2D
 BatchNormalization = keras.layers.BatchNormalization

 LeakyReLU = keras.layers.LeakyReLU