Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-06 21:12:07 -07:00)
SAE: added pixel loss option.
parent 72646becd1
commit ef3cf392c5
2 changed files with 23 additions and 13 deletions
models/ModelBase.py
@@ -106,14 +106,8 @@ class ModelBase(object):
         self.options.pop('target_epoch')

         self.batch_size = self.options['batch_size']

         self.sort_by_yaw = self.options['sort_by_yaw']
-        if not self.sort_by_yaw:
-            self.options.pop('sort_by_yaw')
-
         self.random_flip = self.options['random_flip']
-        if self.random_flip:
-            self.options.pop('random_flip')
-
         self.src_scale_mod = self.options['src_scale_mod']
         if self.src_scale_mod == 0:
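Note: the ModelBase context above reads previously saved options into attributes and pops entries that still hold their default value, presumably so those defaults are not persisted. A minimal sketch of that pattern on a plain dict (illustrative only, not the actual ModelBase implementation):

    # Read saved options into locals; pop a value that equals its default so it
    # is not carried forward and gets re-resolved on the next run.
    options = {'batch_size': 8, 'sort_by_yaw': False, 'random_flip': True, 'src_scale_mod': 0}

    batch_size    = options['batch_size']
    sort_by_yaw   = options['sort_by_yaw']
    random_flip   = options['random_flip']
    src_scale_mod = options['src_scale_mod']

    if src_scale_mod == 0:
        options.pop('src_scale_mod')   # default value -> drop it from the saved options

    print(options)  # {'batch_size': 8, 'sort_by_yaw': False, 'random_flip': True}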
models/Model_SAE/Model.py
@@ -51,6 +51,12 @@ class SAEModel(ModelBase):
         else:
             self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)

+        if is_first_run or ask_override:
+            default_pixel_loss = False if is_first_run else self.options.get('pixel_loss', False)
+            self.options['pixel_loss'] = input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", default_pixel_loss, help_message="Default perceptual (DSSIM) loss good for initial understanding structure of faces. Use pixel loss after 20-30k epochs to enhance fine details.")
+        else:
+            self.options['pixel_loss'] = self.options.get('pixel_loss', False)
+
         default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
         default_ed_ch_dims = 42
         if is_first_run:
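Note: the new 'pixel_loss' option follows the same ask-or-reuse flow as the other SAE options: prompt on a first run or when overriding, otherwise fall back to the previously saved value (default False). A minimal stand-alone sketch of that flow, with a hypothetical ask_bool helper standing in for DeepFaceLab's input_bool:

    # ask_bool is a hypothetical stand-in for input_bool: empty input keeps the default.
    def ask_bool(prompt, default):
        answer = input(prompt).strip().lower()
        return default if answer == '' else answer in ('y', 'yes', '1')

    def resolve_pixel_loss(options, is_first_run, ask_override):
        if is_first_run or ask_override:
            default_pixel_loss = False if is_first_run else options.get('pixel_loss', False)
            options['pixel_loss'] = ask_bool("Use pixel loss? (y/n, skip: n) : ", default_pixel_loss)
        else:
            options['pixel_loss'] = options.get('pixel_loss', False)
        return options['pixel_loss']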
@@ -208,6 +214,9 @@ class SAEModel(ModelBase):
             src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
             dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights

-            src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
+            if self.options['pixel_loss']:
+                src_loss = sum([ K.mean( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] )) for i in range(len(target_src_masked_ar)) ])
+            else:
+                src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])

             if self.options['face_style_power'] != 0:
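Note: the help text above summarizes the trade-off behind this branch: DSSIM is a structural (perceptual) measure that helps the network learn the overall face layout, while pixel loss is a plain per-pixel MSE that sharpens fine detail once the structure is in place. A minimal stand-alone sketch of the two masked loss terms, using tf.image.ssim as an assumed stand-in for the repository's tf_dssim helper:

    # Illustrative sketch, not the repository code. Tensors are NHWC image batches.
    import tensorflow as tf
    from keras import backend as K

    def masked_pixel_loss(target_masked, pred, mask):
        # plain per-pixel MSE on the masked face region
        return K.mean(100 * K.square(target_masked - pred * mask))

    def masked_dssim_loss(target_masked, pred, mask, max_val=2.0):
        # structural dissimilarity: (1 - SSIM) / 2, squared and scaled like the code above
        dssim = (1.0 - tf.image.ssim(target_masked, pred * mask, max_val)) / 2.0
        return K.mean(100 * K.square(dssim))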
@@ -216,11 +225,18 @@ class SAEModel(ModelBase):
             if self.options['bg_style_power'] != 0:
                 bg_style_power = self.options['bg_style_power'] / 100.0
-                src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
+                if self.options['pixel_loss']:
+                    src_loss += K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
+                else:
+                    src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))

             self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) )

-            dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
+            if self.options['pixel_loss']:
+                dst_loss = sum([ K.mean( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] )) for i in range(len(target_dst_masked_ar)) ])
+            else:
+                dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])

             self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) )

             if self.options['learn_mask']:
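Note: both branches feed the selected loss into the same K.function-compiled train steps (self.src_train, self.dst_train), so toggling the option changes only the loss graph, not the training loop. A rough usage sketch of such a compiled step, with hypothetical batch variables:

    # Each K.function above compiles inputs -> [loss] together with the optimizer updates,
    # so one call runs a single optimization step and returns the loss value.
    src_loss_value, = self.src_train([warped_src_batch, target_src_batch, target_srcm_batch,
                                      warped_dst_batch, target_dst_batch, target_dstm_batch])
    dst_loss_value, = self.dst_train([warped_dst_batch, target_dst_batch, target_dstm_batch])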