diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 30d85d5..ff9a922 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -31,6 +31,8 @@ class SAEModel(ModelBase): #first run self.options['resolution'] = input_int("Resolution (64,128, ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.") self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower() + self.options['learn_face_style'] = input_bool("Learn face style? (y/n skip:y) : ", True) + self.options['learn_bg_style'] = input_bool("Learn background style? (y/n skip:y) : ", True) self.options['style_power'] = np.clip ( input_int("Style power (1..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst style during generalization of src and dst faces."), 1, 100 ) default_ae_dims = 256 if self.options['archi'] == 'liae' else 512 self.options['ae_dims'] = input_int("AutoEncoder dims (128,256,512 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, [128,256,512], help_message="More dims are better, but requires more VRAM." ) @@ -38,12 +40,14 @@ class SAEModel(ModelBase): else: #not first run self.options['resolution'] = self.options.get('resolution', default_resolution) + self.options['learn_face_style'] = self.options.get('learn_face_style', True) + self.options['learn_bg_style'] = self.options.get('learn_bg_style', True) self.options['archi'] = self.options.get('archi', default_archi) self.options['style_power'] = self.options.get('style_power', default_style_power) default_ae_dims = 256 if self.options['archi'] == 'liae' else 512 self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims) self.options['face_type'] = self.options.get('face_type', default_face_type) - + #override def onInitialize(self, **in_options): exec(nnlib.import_all(), locals(), globals()) @@ -163,8 +167,12 @@ class SAEModel(ModelBase): style_power = self.options['style_power'] / 100.0 src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) ) - src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked) - src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked ))) + + if self.options['learn_face_style']: + src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked) + + if self.options['learn_bg_style']: + src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked ))) if self.options['archi'] == 'liae': src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights