SAE: added options archi, learn_face_style , learn_bg_style, style_power

This commit is contained in:
iperov 2019-01-09 10:51:31 +04:00
commit e300662f89

View file

@ -31,6 +31,8 @@ class SAEModel(ModelBase):
#first run
self.options['resolution'] = input_int("Resolution (64,128, ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
self.options['learn_face_style'] = input_bool("Learn face style? (y/n skip:y) : ", True)
self.options['learn_bg_style'] = input_bool("Learn background style? (y/n skip:y) : ", True)
self.options['style_power'] = np.clip ( input_int("Style power (1..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst style during generalization of src and dst faces."), 1, 100 )
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
self.options['ae_dims'] = input_int("AutoEncoder dims (128,256,512 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, [128,256,512], help_message="More dims are better, but requires more VRAM." )
@ -38,12 +40,14 @@ class SAEModel(ModelBase):
else:
#not first run
self.options['resolution'] = self.options.get('resolution', default_resolution)
self.options['learn_face_style'] = self.options.get('learn_face_style', True)
self.options['learn_bg_style'] = self.options.get('learn_bg_style', True)
self.options['archi'] = self.options.get('archi', default_archi)
self.options['style_power'] = self.options.get('style_power', default_style_power)
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
self.options['face_type'] = self.options.get('face_type', default_face_type)
#override
def onInitialize(self, **in_options):
exec(nnlib.import_all(), locals(), globals())
@ -163,8 +167,12 @@ class SAEModel(ModelBase):
style_power = self.options['style_power'] / 100.0
src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked)
src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
if self.options['learn_face_style']:
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked)
if self.options['learn_bg_style']:
src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
if self.options['archi'] == 'liae':
src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights