SAE: dssim kernel size now depends on resolution

This commit is contained in:
iperov 2019-03-12 09:49:40 +04:00
parent fd3b9add2f
commit 46ff33bf89
3 changed files with 12 additions and 8 deletions

View file

@ -22,6 +22,8 @@ class SAEModel(ModelBase):
#override
def onInitializeOptions(self, is_first_run, ask_override):
yn_str = {True:'y',False:'n'}
default_resolution = 128
default_archi = 'df'
default_face_type = 'f'
@ -42,7 +44,7 @@ class SAEModel(ModelBase):
if is_first_run or ask_override:
def_simple_optimizer = self.options.get('simple_optimizer', False)
self.options['simple_optimizer'] = io.input_bool ("Use simple optimizer? (y/n, ?:help skip:%s) : " % ( {True:'y',False:'n'}[def_simple_optimizer] ), def_simple_optimizer, help_message="Simple optimizer allows you to train bigger network or more batch size, sacrificing training accuracy.")
self.options['simple_optimizer'] = io.input_bool ("Use simple optimizer? (y/n, ?:help skip:%s) : " % ( yn_str[def_simple_optimizer] ), def_simple_optimizer, help_message="Simple optimizer allows you to train bigger network or more batch size, sacrificing training accuracy.")
else:
self.options['simple_optimizer'] = self.options.get('simple_optimizer', False)
@ -75,7 +77,7 @@ class SAEModel(ModelBase):
default_bg_style_power = 0.0
if is_first_run or ask_override:
def_pixel_loss = self.options.get('pixel_loss', False)
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 15-25k epochs to enhance fine details and decrease face jitter.")
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: %s ) : " % (yn_str[def_pixel_loss]), def_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 15-25k epochs to enhance fine details and decrease face jitter.")
default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
self.options['face_style_power'] = np.clip ( io.input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_face_style_power), default_face_style_power,
@ -286,7 +288,7 @@ class SAEModel(ModelBase):
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
if not self.options['pixel_loss']:
src_loss_batch = sum([ ( 100*K.square( dssim(max_value=2.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ])
src_loss_batch = sum([ ( 100*K.square( dssim(kernel_size=int(resolution/11.6),max_value=2.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ])
else:
src_loss_batch = sum([ K.mean ( 100*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])
@ -300,13 +302,13 @@ class SAEModel(ModelBase):
bg_style_power = self.options['bg_style_power'] / 100.0
if bg_style_power != 0:
if not self.options['pixel_loss']:
bg_loss = K.mean( (100*bg_style_power)*K.square(dssim(max_value=2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
bg_loss = K.mean( (100*bg_style_power)*K.square(dssim(kernel_size=int(resolution/11.6),max_value=2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
else:
bg_loss = K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
src_loss += bg_loss
if not self.options['pixel_loss']:
dst_loss_batch = sum([ ( 100*K.square(dssim(max_value=2.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ])
dst_loss_batch = sum([ ( 100*K.square(dssim(kernel_size=int(resolution/11.6),max_value=2.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ])
else:
dst_loss_batch = sum([ K.mean ( 100*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])