Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-06 21:12:07 -07:00)
SAE: increased training speed by 10-18%; increased the clipping of the border mask in full face mode, which gives a better transition of the cheeks; default archi is now 'df'.
This commit is contained in:
parent 41abda42d2
commit e226ab5385

2 changed files with 65 additions and 70 deletions
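The main speedup comes from the training-function change visible in the hunks below: the separate self.src_train / self.dst_train (and self.src_mask_train / self.dst_mask_train) K.functions are replaced by single combined functions, so each iteration makes one call with one optimizer update over the joint weight list instead of two. A minimal sketch of that pattern with toy stand-in models (my own names and sizes, assuming Keras 2.x on the TF1 backend the project used at the time; this is not DeepFaceLab code):

```python
import numpy as np
import keras.backend as K
from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import Adam

def small_model(in_dim, out_dim):
    x = Input((in_dim,))
    return Model(x, Dense(out_dim)(x))

encoder     = small_model(8, 16)   # shared between both branches
decoder_src = small_model(16, 8)
decoder_dst = small_model(16, 8)

warped_src, target_src = Input((8,)), Input((8,))
warped_dst, target_dst = Input((8,)), Input((8,))

src_loss = K.mean(K.square(target_src - decoder_src(encoder(warped_src))))
dst_loss = K.mean(K.square(target_dst - decoder_dst(encoder(warped_dst))))

# One joint weight list and one optimizer update for both losses.
src_dst_weights = (encoder.trainable_weights
                   + decoder_src.trainable_weights
                   + decoder_dst.trainable_weights)

src_dst_train = K.function(
    [warped_src, target_src, warped_dst, target_dst],
    [src_loss, dst_loss],
    Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss + dst_loss, src_dst_weights))

# One call per iteration instead of two separate train functions:
batch = [np.random.rand(4, 8).astype(np.float32) for _ in range(4)]
sl, dl = src_dst_train(batch)
```

Summing the losses before get_updates is what lets the shared encoder be evaluated and updated once per iteration rather than once per branch.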
@@ -23,7 +23,7 @@ class SAEModel(ModelBase):
#override
def onInitializeOptions(self, is_first_run, ask_override):
default_resolution = 128
default_archi = 'liae'
default_archi = 'df'
default_face_type = 'f'

if is_first_run:
@@ -114,29 +114,26 @@ class SAEModel(ModelBase):
if self.options['learn_mask']:
self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5))

warped_src_code = self.encoder (warped_src)

warped_src_code = self.encoder (warped_src)
warped_src_inter_AB_code = self.inter_AB (warped_src_code)
warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])

pred_src_src = self.decoder(warped_src_inter_code)
if self.options['learn_mask']:
pred_src_srcm = self.decoderm(warped_src_inter_code)

warped_dst_code = self.encoder (warped_dst)
warped_dst_inter_B_code = self.inter_B (warped_dst_code)
warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
pred_dst_dst = self.decoder(warped_dst_inter_code)

if self.options['learn_mask']:
pred_dst_dstm = self.decoderm(warped_dst_inter_code)

warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])

pred_src_src = self.decoder(warped_src_inter_code)
pred_dst_dst = self.decoder(warped_dst_inter_code)
pred_src_dst = self.decoder(warped_src_dst_inter_code)

if self.options['learn_mask']:
pred_src_srcm = self.decoderm(warped_src_inter_code)
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)

else:
self.encoder = modelify(SAEModel.DFEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
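The LIAE hunk above mostly regroups how the graph is built, but it shows the latent routing clearly: the src image is decoded from its inter_AB code stacked with itself, the dst image from inter_B stacked with inter_AB, and the face swap reuses the dst image routed as [inter_AB, inter_AB]. A minimal sketch of that routing with toy stand-in networks (my own names and sizes, assuming Keras 2.x; not the project's actual nets):

```python
from keras.layers import Input, Dense, Concatenate
from keras.models import Model

def small_net(in_dim, out_dim):
    x = Input((in_dim,))
    return Model(x, Dense(out_dim)(x))

encoder  = small_net(32, 16)   # stand-in for self.encoder
inter_AB = small_net(16, 8)    # shared inter net, fed with both src and dst
inter_B  = small_net(16, 8)    # dst-only inter net

warped_src = Input((32,))
warped_dst = Input((32,))

src_AB = inter_AB(encoder(warped_src))
dst_B  = inter_B(encoder(warped_dst))
dst_AB = inter_AB(encoder(warped_dst))

# src decodes from [AB, AB], dst from [B, AB]; the swap routes the dst image
# through the shared AB space only, i.e. [AB, AB] again.
warped_src_inter_code     = Concatenate()([src_AB, src_AB])
warped_dst_inter_code     = Concatenate()([dst_B, dst_AB])
warped_src_dst_inter_code = Concatenate()([dst_AB, dst_AB])
```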
@@ -162,11 +159,13 @@ class SAEModel(ModelBase):
pred_src_src = self.decoder_src(warped_src_code)
pred_dst_dst = self.decoder_dst(warped_dst_code)
pred_src_dst = self.decoder_src(warped_dst_code)

if self.options['learn_mask']:
pred_src_srcm = self.decoder_srcm(warped_src_code)
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
pred_src_dstm = self.decoder_srcm(warped_dst_code)

ms_count = len(pred_src_src)

target_src_ar = [ target_src if i == 0 else tf.image.resize_bicubic( target_src, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
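The target_src_ar comprehension above builds a multiscale target pyramid: i runs from ms_count-1 down to 0, so the coarsest scale comes first and the full-resolution target last, which matches the pred_*[-1] indexing used for the previews later in the diff. A small sketch under the TF1 API the file already uses:

```python
import tensorflow as tf  # assuming the TF1 graph API used elsewhere in the file

resolution = 128
ms_count = 3   # e.g. three decoder output scales
target_src = tf.placeholder(tf.float32, (None, resolution, resolution, 3))

# i = 2, 1, 0 -> list goes coarsest to full resolution
target_src_ar = [target_src if i == 0
                 else tf.image.resize_bicubic(target_src, (resolution // (2 ** i),) * 2)
                 for i in range(ms_count - 1, -1, -1)]
# resulting shapes: (?, 32, 32, 3), (?, 64, 64, 3), (?, 128, 128, 3)
```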
@@ -201,19 +200,15 @@ class SAEModel(ModelBase):
return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)

if self.options['archi'] == 'liae':
src_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
if self.options['archi'] == 'liae':
src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
if self.options['learn_mask']:
src_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
else:
src_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
else:
src_dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + self.decoder_dst.trainable_weights
if self.options['learn_mask']:
src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights

src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights

if self.options['pixel_loss']:
src_loss = sum([ K.mean( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] )) for i in range(len(target_src_masked_ar)) ])
else:
@@ -229,24 +224,24 @@ class SAEModel(ModelBase):
src_loss += K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
else:
src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))

self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) )

if self.options['pixel_loss']:
dst_loss = sum([ K.mean( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] )) for i in range(len(target_dst_masked_ar)) ])
else:
dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])

self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) )
self.src_dst_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss,dst_loss], optimizer().get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )

if self.options['learn_mask']:
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], optimizer().get_updates(src_mask_loss, src_mask_loss_train_weights) )

dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], optimizer().get_updates(dst_mask_loss, dst_mask_loss_train_weights) )

self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] )
self.src_dst_mask_train = K.function ([warped_src, target_srcm, warped_dst, target_dstm],[src_mask_loss, dst_mask_loss], optimizer().get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) )

if self.options['learn_mask']:
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1], pred_src_dstm[-1]])
else:
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] )

else:
if self.options['learn_mask']:
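The bg_style_power term in the hunk above penalizes the swapped prediction outside the face mask so that the background stays close to dst. A small sketch of the squared-error variant with hypothetical names, assuming the anti-mask is simply one minus the face mask (the real code works on multiscale arrays and also offers a DSSIM variant):

```python
import keras.backend as K

def bg_style_loss(pred_src_dst, target_dst, target_dstm, bg_style_power=0.1):
    # hypothetical helper: anti-mask = everything outside the face mask
    anti_mask = 1.0 - target_dstm
    psd_bg = pred_src_dst * anti_mask   # swapped prediction, background only
    dst_bg = target_dst * anti_mask     # dst ground truth, background only
    return K.mean((100.0 * bg_style_power) * K.square(psd_bg - dst_bg))

# usage in the same spirit as the diff:
# src_loss += bg_style_loss(pred_src_dst[-1], target_dst, target_dstm, bg_style_power)
```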
@@ -299,12 +294,10 @@ class SAEModel(ModelBase):
warped_src, target_src, target_src_mask = sample[0]
warped_dst, target_dst, target_dst_mask = sample[1]

src_loss, = self.src_train ([warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
dst_loss, = self.dst_train ([warped_dst, target_dst, target_dst_mask])
src_loss, dst_loss = self.src_dst_train ([warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])

if self.options['learn_mask']:
src_mask_loss, = self.src_mask_train ([warped_src, target_src_mask])
dst_mask_loss, = self.dst_mask_train ([warped_dst, target_dst_mask])
src_mask_loss, dst_mask_loss, = self.src_dst_mask_train ([warped_src, target_src_mask, warped_dst, target_dst_mask])

return ( ('src_loss', src_loss), ('dst_loss', dst_loss) )
@@ -316,18 +309,20 @@ class SAEModel(ModelBase):
test_B = sample[1][1][0:4]
test_B_m = sample[1][2][0:4]

S, D, SS, DD, SD, = [ x / 2 + 0.5 for x in ([test_A,test_B] + self.AE_view ([test_A, test_B]) ) ]
#SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ]
if self.options['learn_mask']:
S, D, SS, DD, SD, SDM = [ x / 2 + 0.5 for x in ([test_A,test_B] + self.AE_view ([test_A, test_B]) ) ]
SDM, = [ np.repeat (x, (3,), -1) for x in [SDM] ]
else:
S, D, SS, DD, SD, = [ x / 2 + 0.5 for x in ([test_A,test_B] + self.AE_view ([test_A, test_B]) ) ]

st_x3 = []
st = []
for i in range(0, len(test_A)):
st_x3.append ( np.concatenate ( (
S[i], SS[i], #SM[i],
D[i], DD[i], #DM[i],
SD[i], #SDM[i]
), axis=1) )
ar = S[i], SS[i], D[i], DD[i], SD[i]
if self.options['learn_mask']:
ar += (SDM[i],)
st.append ( np.concatenate ( ar, axis=1) )

return [ ('SAE', np.concatenate (st_x3, axis=0 )), ]
return [ ('SAE', np.concatenate (st, axis=0 )), ]

def predictor_func (self, face):
face_tanh = face * 2.0 - 1.0
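The preview change above replaces the hard-coded st_x3 row with a tuple that only appends the swapped mask when learn_mask is on; np.repeat turns the 1-channel mask into 3 channels so it can sit next to the BGR tiles. A small self-contained sketch with dummy arrays standing in for S, SS, D, DD, SD and the mask SDM:

```python
import numpy as np

h = w = 64
S, SS, D, DD, SD = [np.random.rand(h, w, 3).astype(np.float32) for _ in range(5)]
SDM = np.random.rand(h, w, 1).astype(np.float32)   # 1-channel swapped-face mask

ar = (S, SS, D, DD, SD)
learn_mask = True
if learn_mask:
    SDM3 = np.repeat(SDM, 3, axis=-1)   # 1 channel -> 3 channels for display
    ar += (SDM3,)

row = np.concatenate(ar, axis=1)        # tiles side by side: (64, 64 * len(ar), 3)
```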
@@ -354,7 +349,7 @@ class SAEModel(ModelBase):
face_type=face_type,
base_erode_mask_modifier=base_erode_mask_modifier,
base_blur_mask_modifier=base_blur_mask_modifier,
clip_border_mask_per=0.03125,
clip_hborder_mask_per=0.0625 if self.options['face_type'] == 'f' else 0,
**in_options)

@staticmethod
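The clip_hborder_mask_per=0.0625 option above is the "increased clipping of the border mask in full face mode" from the commit message: in full-face mode the converter is told to clip a fraction of the mask along its side borders so the blend fades out before the cheeks. The converter code itself is not in this hunk, so the sketch below is a hypothetical helper that only illustrates what such a horizontal border fraction can mean in practice:

```python
import numpy as np

def clip_hborder(mask, clip_hborder_mask_per):
    # hypothetical illustration: zero a strip of width w * per on the left and
    # right edges of a (h, w, 1) float mask
    h, w = mask.shape[:2]
    border = int(w * clip_hborder_mask_per)   # 0.0625 * 256 px -> 16 px per side
    clipped = mask.copy()
    clipped[:, :border] = 0.0
    clipped[:, w - border:] = 0.0
    return clipped

mask = np.ones((256, 256, 1), dtype=np.float32)
print(clip_hborder(mask, 0.0625)[:, :20, 0].max(axis=0))  # first 16 columns are 0
```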