SAE fix model crashes

This commit is contained in:
iperov 2019-04-22 00:40:23 +04:00
parent 5e3bd75007
commit f54cc744bd

View file

@ -127,7 +127,7 @@ class SAEModel(ModelBase):
padding = 'reflect' if self.options['remove_gray_border'] else 'zero' padding = 'reflect' if self.options['remove_gray_border'] else 'zero'
common_flow_kwargs = { 'padding': padding, common_flow_kwargs = { 'padding': padding,
'norm': 'bn', 'norm': '',
'act':'' } 'act':'' }
weights_to_load = [] weights_to_load = []
@ -258,9 +258,9 @@ class SAEModel(ModelBase):
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
if not self.options['pixel_loss']: if not self.options['pixel_loss']:
src_loss_batch = sum([ ( 100*K.square( dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ]) src_loss_batch = sum([ ( 10*K.square( dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ])
else: else:
src_loss_batch = sum([ K.mean ( 100*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ]) src_loss_batch = sum([ K.mean ( 10*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])
src_loss = K.mean(src_loss_batch) src_loss = K.mean(src_loss_batch)
@ -272,15 +272,15 @@ class SAEModel(ModelBase):
bg_style_power = self.options['bg_style_power'] / 100.0 bg_style_power = self.options['bg_style_power'] / 100.0
if bg_style_power != 0: if bg_style_power != 0:
if not self.options['pixel_loss']: if not self.options['pixel_loss']:
bg_loss = K.mean( (100*bg_style_power)*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] ))) bg_loss = K.mean( (10*bg_style_power)*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
else: else:
bg_loss = K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] )) bg_loss = K.mean( (10*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
src_loss += bg_loss src_loss += bg_loss
if not self.options['pixel_loss']: if not self.options['pixel_loss']:
dst_loss_batch = sum([ ( 100*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ]) dst_loss_batch = sum([ ( 10*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ])
else: else:
dst_loss_batch = sum([ K.mean ( 100*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ]) dst_loss_batch = sum([ K.mean ( 10*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])
dst_loss = K.mean(dst_loss_batch) dst_loss = K.mean(dst_loss_batch)