From d71a310fd7e5d7d067209e778a6bd03e370e96ab Mon Sep 17 00:00:00 2001 From: iperov Date: Sun, 17 Mar 2019 13:03:49 +0400 Subject: [PATCH] SAE: forgot to remove normalizing from tanh --- models/Model_SAE/Model.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 2616461..7807e4b 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -267,19 +267,19 @@ class SAEModel(ModelBase): pred_src_srcm, pred_dst_dstm, pred_src_dstm = [ [x] if type(x) != list else x for x in [pred_src_srcm, pred_dst_dstm, pred_src_dstm] ] target_srcm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_srcm_ar] - target_srcm_sigm_ar = [ x / 2.0 + 0.5 for x in target_srcm_blurred_ar] + target_srcm_sigm_ar = target_srcm_blurred_ar #[ x / 2.0 + 0.5 for x in target_srcm_blurred_ar] target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar] target_dstm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_dstm_ar] - target_dstm_sigm_ar = [ x / 2.0 + 0.5 for x in target_dstm_blurred_ar] + target_dstm_sigm_ar = target_dstm_blurred_ar#[ x / 2.0 + 0.5 for x in target_dstm_blurred_ar] target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar] - target_src_sigm_ar = [ x + 1 for x in target_src_ar] - target_dst_sigm_ar = [ x + 1 for x in target_dst_ar] + target_src_sigm_ar = target_src_ar#[ x + 1 for x in target_src_ar] + target_dst_sigm_ar = target_dst_ar#[ x + 1 for x in target_dst_ar] - pred_src_src_sigm_ar = [ x + 1 for x in pred_src_src] - pred_dst_dst_sigm_ar = [ x + 1 for x in pred_dst_dst] - pred_src_dst_sigm_ar = [ x + 1 for x in pred_src_dst] + pred_src_src_sigm_ar = pred_src_src#[ x + 1 for x in pred_src_src] + pred_dst_dst_sigm_ar = pred_dst_dst#[ x + 1 for x in pred_dst_dst] + pred_src_dst_sigm_ar = pred_src_dst#[ x + 1 for x in pred_src_dst] target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))] target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))] @@ -311,7 +311,7 @@ class SAEModel(ModelBase): src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights if not self.options['pixel_loss']: - src_loss_batch = sum([ ( 100*K.square( dssim(kernel_size=int(resolution/11.6),max_value=2.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ]) + src_loss_batch = sum([ ( 100*K.square( dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ]) else: src_loss_batch = sum([ K.mean ( 100*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ]) @@ -325,13 +325,13 @@ class SAEModel(ModelBase): bg_style_power = self.options['bg_style_power'] / 100.0 if bg_style_power != 0: if not self.options['pixel_loss']: - bg_loss = K.mean( (100*bg_style_power)*K.square(dssim(kernel_size=int(resolution/11.6),max_value=2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] ))) + bg_loss = K.mean( (100*bg_style_power)*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] ))) else: bg_loss = K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] )) src_loss += bg_loss if not self.options['pixel_loss']: - dst_loss_batch = sum([ ( 100*K.square(dssim(kernel_size=int(resolution/11.6),max_value=2.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ]) + dst_loss_batch = sum([ ( 100*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ]) else: dst_loss_batch = sum([ K.mean ( 100*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])