SAE: added options archi, learn_face_style , learn_bg_style, style_power

This commit is contained in:
iperov 2019-01-09 10:51:31 +04:00
commit e300662f89

View file

@ -31,6 +31,8 @@ class SAEModel(ModelBase):
#first run #first run
self.options['resolution'] = input_int("Resolution (64,128, ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.") self.options['resolution'] = input_int("Resolution (64,128, ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower() self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
self.options['learn_face_style'] = input_bool("Learn face style? (y/n skip:y) : ", True)
self.options['learn_bg_style'] = input_bool("Learn background style? (y/n skip:y) : ", True)
self.options['style_power'] = np.clip ( input_int("Style power (1..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst style during generalization of src and dst faces."), 1, 100 ) self.options['style_power'] = np.clip ( input_int("Style power (1..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst style during generalization of src and dst faces."), 1, 100 )
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512 default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
self.options['ae_dims'] = input_int("AutoEncoder dims (128,256,512 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, [128,256,512], help_message="More dims are better, but requires more VRAM." ) self.options['ae_dims'] = input_int("AutoEncoder dims (128,256,512 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, [128,256,512], help_message="More dims are better, but requires more VRAM." )
@ -38,6 +40,8 @@ class SAEModel(ModelBase):
else: else:
#not first run #not first run
self.options['resolution'] = self.options.get('resolution', default_resolution) self.options['resolution'] = self.options.get('resolution', default_resolution)
self.options['learn_face_style'] = self.options.get('learn_face_style', True)
self.options['learn_bg_style'] = self.options.get('learn_bg_style', True)
self.options['archi'] = self.options.get('archi', default_archi) self.options['archi'] = self.options.get('archi', default_archi)
self.options['style_power'] = self.options.get('style_power', default_style_power) self.options['style_power'] = self.options.get('style_power', default_style_power)
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512 default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
@ -163,7 +167,11 @@ class SAEModel(ModelBase):
style_power = self.options['style_power'] / 100.0 style_power = self.options['style_power'] / 100.0
src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) ) src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
if self.options['learn_face_style']:
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked) src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked)
if self.options['learn_bg_style']:
src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked ))) src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
if self.options['archi'] == 'liae': if self.options['archi'] == 'liae':