SAE: forgot to remove normalizing from tanh

This commit is contained in:
iperov 2019-03-17 13:03:49 +04:00
parent 4ff88865ce
commit d71a310fd7

View file

@ -267,19 +267,19 @@ class SAEModel(ModelBase):
pred_src_srcm, pred_dst_dstm, pred_src_dstm = [ [x] if type(x) != list else x for x in [pred_src_srcm, pred_dst_dstm, pred_src_dstm] ] pred_src_srcm, pred_dst_dstm, pred_src_dstm = [ [x] if type(x) != list else x for x in [pred_src_srcm, pred_dst_dstm, pred_src_dstm] ]
target_srcm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_srcm_ar] target_srcm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_srcm_ar]
target_srcm_sigm_ar = [ x / 2.0 + 0.5 for x in target_srcm_blurred_ar] target_srcm_sigm_ar = target_srcm_blurred_ar #[ x / 2.0 + 0.5 for x in target_srcm_blurred_ar]
target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar] target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar]
target_dstm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_dstm_ar] target_dstm_blurred_ar = [ gaussian_blur( max(1, K.int_shape(x)[1] // 32) )(x) for x in target_dstm_ar]
target_dstm_sigm_ar = [ x / 2.0 + 0.5 for x in target_dstm_blurred_ar] target_dstm_sigm_ar = target_dstm_blurred_ar#[ x / 2.0 + 0.5 for x in target_dstm_blurred_ar]
target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar] target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar]
target_src_sigm_ar = [ x + 1 for x in target_src_ar] target_src_sigm_ar = target_src_ar#[ x + 1 for x in target_src_ar]
target_dst_sigm_ar = [ x + 1 for x in target_dst_ar] target_dst_sigm_ar = target_dst_ar#[ x + 1 for x in target_dst_ar]
pred_src_src_sigm_ar = [ x + 1 for x in pred_src_src] pred_src_src_sigm_ar = pred_src_src#[ x + 1 for x in pred_src_src]
pred_dst_dst_sigm_ar = [ x + 1 for x in pred_dst_dst] pred_dst_dst_sigm_ar = pred_dst_dst#[ x + 1 for x in pred_dst_dst]
pred_src_dst_sigm_ar = [ x + 1 for x in pred_src_dst] pred_src_dst_sigm_ar = pred_src_dst#[ x + 1 for x in pred_src_dst]
target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))] target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))] target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
@ -311,7 +311,7 @@ class SAEModel(ModelBase):
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
if not self.options['pixel_loss']: if not self.options['pixel_loss']:
src_loss_batch = sum([ ( 100*K.square( dssim(kernel_size=int(resolution/11.6),max_value=2.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ]) src_loss_batch = sum([ ( 100*K.square( dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ])
else: else:
src_loss_batch = sum([ K.mean ( 100*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ]) src_loss_batch = sum([ K.mean ( 100*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])
@ -325,13 +325,13 @@ class SAEModel(ModelBase):
bg_style_power = self.options['bg_style_power'] / 100.0 bg_style_power = self.options['bg_style_power'] / 100.0
if bg_style_power != 0: if bg_style_power != 0:
if not self.options['pixel_loss']: if not self.options['pixel_loss']:
bg_loss = K.mean( (100*bg_style_power)*K.square(dssim(kernel_size=int(resolution/11.6),max_value=2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] ))) bg_loss = K.mean( (100*bg_style_power)*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
else: else:
bg_loss = K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] )) bg_loss = K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
src_loss += bg_loss src_loss += bg_loss
if not self.options['pixel_loss']: if not self.options['pixel_loss']:
dst_loss_batch = sum([ ( 100*K.square(dssim(kernel_size=int(resolution/11.6),max_value=2.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ]) dst_loss_batch = sum([ ( 100*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ])
else: else:
dst_loss_batch = sum([ K.mean ( 100*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ]) dst_loss_batch = sum([ K.mean ( 100*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])