diff --git a/models/ModelBase.py b/models/ModelBase.py
index 7fc166f..8b66a5f 100644
--- a/models/ModelBase.py
+++ b/models/ModelBase.py
@@ -106,14 +106,8 @@ class ModelBase(object):
             self.options.pop('target_epoch')
 
         self.batch_size = self.options['batch_size']
-
-        self.sort_by_yaw = self.options['sort_by_yaw']
-        if not self.sort_by_yaw:
-            self.options.pop('sort_by_yaw')
-
+        self.sort_by_yaw = self.options['sort_by_yaw']
         self.random_flip = self.options['random_flip']
-        if self.random_flip:
-            self.options.pop('random_flip')
 
         self.src_scale_mod = self.options['src_scale_mod']
         if self.src_scale_mod == 0:
diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py
index cc46e91..c3d445c 100644
--- a/models/Model_SAE/Model.py
+++ b/models/Model_SAE/Model.py
@@ -50,6 +50,12 @@ class SAEModel(ModelBase):
             self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces. If style is learned good enough, set this value to 0.1-0.3 to prevent artifacts appearing."), 0.0, 100.0 )
         else:
             self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
+
+        if is_first_run or ask_override:
+            default_pixel_loss = False if is_first_run else self.options.get('pixel_loss', False)
+            self.options['pixel_loss'] = input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", default_pixel_loss, help_message="Default perceptual (DSSIM) loss good for initial understanding structure of faces. Use pixel loss after 20-30k epochs to enhance fine details.")
+        else:
+            self.options['pixel_loss'] = self.options.get('pixel_loss', False)
 
         default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
         default_ed_ch_dims = 42
@@ -207,20 +213,30 @@ class SAEModel(ModelBase):
             if self.options['learn_mask']:
                 src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
                 dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
-
-            src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
-
+
+            if self.options['pixel_loss']:
+                src_loss = sum([ K.mean( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] )) for i in range(len(target_src_masked_ar)) ])
+            else:
+                src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
+
             if self.options['face_style_power'] != 0:
                 face_style_power = self.options['face_style_power'] / 100.0
                 src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )
 
             if self.options['bg_style_power'] != 0:
                 bg_style_power = self.options['bg_style_power'] / 100.0
-                src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
-
+                if self.options['pixel_loss']:
+                    src_loss += K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
+                else:
+                    src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
+
             self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) )
 
-            dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
+            if self.options['pixel_loss']:
+                dst_loss = sum([ K.mean( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] )) for i in range(len(target_dst_masked_ar)) ])
+            else:
+                dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
+
             self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) )
 
             if self.options['learn_mask']:
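For reference, the core of this change is a switch between two masked loss formulations: plain per-pixel MSE when 'pixel_loss' is enabled, and a squared structural-dissimilarity (DSSIM) term by default. Below is a minimal standalone sketch of the two variants using NumPy and scikit-image; masked_dssim_loss stands in for DeepFaceLab's tf_dssim Keras helper, and the array names, shapes, all-ones mask, and [0, 1] data range are illustrative assumptions, not the repo's API.

# Minimal sketch of the two masked loss variants the 'pixel_loss' option
# switches between. masked_dssim_loss is a stand-in for DeepFaceLab's
# tf_dssim helper; names, shapes, and data range are assumptions.
import numpy as np
from skimage.metrics import structural_similarity

def masked_pixel_loss(target, pred, mask):
    # 'pixel_loss' branch: per-pixel MSE over the masked face region,
    # with the same 100 * square(...) scaling used in the diff.
    return np.mean(100.0 * np.square(target * mask - pred * mask))

def masked_dssim_loss(target, pred, mask):
    # Default branch: structural dissimilarity, DSSIM = (1 - SSIM) / 2,
    # squared and scaled like the pixel variant. The diff's tf_dssim(2.0)
    # suggests a data range of 2.0 for its tensors; 1.0 fits [0, 1] here.
    ssim = structural_similarity(target * mask, pred * mask,
                                 channel_axis=-1, data_range=1.0)
    return 100.0 * ((1.0 - ssim) / 2.0) ** 2

rng = np.random.default_rng(0)
target = rng.random((64, 64, 3), dtype=np.float32)
noise  = 0.05 * rng.standard_normal(target.shape, dtype=np.float32)
pred   = np.clip(target + noise, 0.0, 1.0)   # imperfect "reconstruction"
mask   = np.ones_like(target)                # all-ones mask for brevity

use_pixel_loss = False                       # the option this diff adds
loss_fn = masked_pixel_loss if use_pixel_loss else masked_dssim_loss
print("loss = %.4f" % loss_fn(target, pred, mask))

As the help_message added by the diff explains, the DSSIM loss is the default because it is better suited to learning the overall structure of faces; pixel loss is meant to be enabled only after roughly 20-30k epochs to sharpen fine details.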