SAE: added pixel loss option.

This commit is contained in:
iperov 2019-01-25 09:56:40 +04:00
parent 72646becd1
commit ef3cf392c5
2 changed files with 23 additions and 13 deletions

View file

@ -106,14 +106,8 @@ class ModelBase(object):
self.options.pop('target_epoch')
self.batch_size = self.options['batch_size']
self.sort_by_yaw = self.options['sort_by_yaw']
if not self.sort_by_yaw:
self.options.pop('sort_by_yaw')
self.sort_by_yaw = self.options['sort_by_yaw']
self.random_flip = self.options['random_flip']
if self.random_flip:
self.options.pop('random_flip')
self.src_scale_mod = self.options['src_scale_mod']
if self.src_scale_mod == 0:

View file

@ -50,6 +50,12 @@ class SAEModel(ModelBase):
self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces. If style is learned good enough, set this value to 0.1-0.3 to prevent artifacts appearing."), 0.0, 100.0 )
else:
self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
if is_first_run or ask_override:
default_pixel_loss = False if is_first_run else self.options.get('pixel_loss', False)
self.options['pixel_loss'] = input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", default_pixel_loss, help_message="Default perceptual (DSSIM) loss good for initial understanding structure of faces. Use pixel loss after 20-30k epochs to enhance fine details.")
else:
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
default_ed_ch_dims = 42
@ -207,20 +213,30 @@ class SAEModel(ModelBase):
if self.options['learn_mask']:
src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
if self.options['pixel_loss']:
src_loss = sum([ K.mean( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] )) for i in range(len(target_src_masked_ar)) ])
else:
src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
if self.options['face_style_power'] != 0:
face_style_power = self.options['face_style_power'] / 100.0
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )
if self.options['bg_style_power'] != 0:
bg_style_power = self.options['bg_style_power'] / 100.0
src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
if self.options['pixel_loss']:
src_loss += K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
else:
src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) )
dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
if self.options['pixel_loss']:
dst_loss = sum([ K.mean( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] )) for i in range(len(target_dst_masked_ar)) ])
else:
dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) )
if self.options['learn_mask']: