mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
SAE: added pixel loss option.
This commit is contained in:
parent
72646becd1
commit
ef3cf392c5
2 changed files with 23 additions and 13 deletions
|
@ -50,6 +50,12 @@ class SAEModel(ModelBase):
|
|||
self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces. If style is learned good enough, set this value to 0.1-0.3 to prevent artifacts appearing."), 0.0, 100.0 )
|
||||
else:
|
||||
self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
|
||||
|
||||
if is_first_run or ask_override:
|
||||
default_pixel_loss = False if is_first_run else self.options.get('pixel_loss', False)
|
||||
self.options['pixel_loss'] = input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", default_pixel_loss, help_message="Default perceptual (DSSIM) loss good for initial understanding structure of faces. Use pixel loss after 20-30k epochs to enhance fine details.")
|
||||
else:
|
||||
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
||||
|
||||
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
|
||||
default_ed_ch_dims = 42
|
||||
|
@ -207,20 +213,30 @@ class SAEModel(ModelBase):
|
|||
if self.options['learn_mask']:
|
||||
src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
|
||||
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
|
||||
|
||||
src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
|
||||
|
||||
|
||||
if self.options['pixel_loss']:
|
||||
src_loss = sum([ K.mean( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] )) for i in range(len(target_src_masked_ar)) ])
|
||||
else:
|
||||
src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
|
||||
|
||||
if self.options['face_style_power'] != 0:
|
||||
face_style_power = self.options['face_style_power'] / 100.0
|
||||
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )
|
||||
|
||||
if self.options['bg_style_power'] != 0:
|
||||
bg_style_power = self.options['bg_style_power'] / 100.0
|
||||
src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
|
||||
|
||||
if self.options['pixel_loss']:
|
||||
src_loss += K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
|
||||
else:
|
||||
src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
|
||||
|
||||
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) )
|
||||
|
||||
dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
|
||||
if self.options['pixel_loss']:
|
||||
dst_loss = sum([ K.mean( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] )) for i in range(len(target_dst_masked_ar)) ])
|
||||
else:
|
||||
dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
|
||||
|
||||
self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) )
|
||||
|
||||
if self.options['learn_mask']:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue