diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index e275f9c..c26e539 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -35,17 +35,19 @@ class SAEModel(ModelBase): self.options['archi'] = self.options.get('archi', default_archi) self.options['lighter_encoder'] = self.options.get('lighter_encoder', False) + default_face_style_power = 10.0 if is_first_run or ask_override: - default_style_power = 100 if is_first_run else self.options.get('face_style_power', 100) - self.options['face_style_power'] = np.clip ( input_int("Face style power (0..100 ?:help skip:%d) : " % (default_style_power), default_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0, 100 ) + default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power) + self.options['face_style_power'] = np.clip ( input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_face_style_power), default_face_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0.0, 100.0 ) else: - self.options['face_style_power'] = self.options.get('face_style_power', 100) - + self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power) + + default_bg_style_power = 10.0 if is_first_run or ask_override: - default_style_power = 100 if is_first_run else self.options.get('bg_style_power', 100) - self.options['bg_style_power'] = np.clip ( input_int("Background style power (0..100 ?:help skip:%d) : " % (default_style_power), default_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0, 100 ) + default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power) + self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0.0, 100.0 ) else: - self.options['bg_style_power'] = self.options.get('bg_style_power', 100) + self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power) default_ae_dims = 256 if self.options['archi'] == 'liae' else 512 default_ed_ch_dims = 42 @@ -187,6 +189,9 @@ class SAEModel(ModelBase): src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) ) + def optimizer(): + return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999) + if self.options['face_style_power'] != 0: face_style_power = self.options['face_style_power'] / 100.0 src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)(psd_target_dst_masked, target_dst_masked) @@ -200,7 +205,7 @@ class SAEModel(ModelBase): else: src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], - Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss, src_train_weights) ) + optimizer().get_updates(src_loss, src_train_weights) ) dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) ) @@ -209,7 +214,7 @@ class SAEModel(ModelBase): else: dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss], - Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_loss, dst_train_weights) ) + optimizer().get_updates(dst_loss, dst_train_weights) ) src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm)) @@ -219,7 +224,7 @@ class SAEModel(ModelBase): src_mask_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], - Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_mask_loss, src_mask_train_weights ) ) + optimizer().get_updates(src_mask_loss, src_mask_train_weights ) ) dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm)) @@ -229,7 +234,7 @@ class SAEModel(ModelBase): dst_mask_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], - Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_mask_loss, dst_mask_train_weights) ) + optimizer().get_updates(dst_mask_loss, dst_mask_train_weights) ) self.AE_view = K.function ([warped_src, warped_dst],[pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm]) self.AE_convert = K.function ([warped_dst],[pred_src_dst, pred_src_dstm]) diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index 3ce3fd8..2d465e5 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -33,7 +33,7 @@ class nnlib(object): tf_adain = None tf_gaussian_blur = None tf_style_loss = None - + modelify = None ReflectionPadding2D = None DSSIMLoss = None @@ -109,6 +109,7 @@ AddUniformNoise = nnlib.AddUniformNoise keras_contrib = nnlib.keras_contrib GroupNormalization = keras_contrib.layers.GroupNormalization InstanceNormalization = keras_contrib.layers.InstanceNormalization +Padam = keras_contrib.optimizers.Padam """ code_import_dlib_string = \ """ @@ -378,7 +379,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator return func nnlib.tf_style_loss = tf_style_loss - + @staticmethod def import_keras(device_config = None): if nnlib.keras is not None: diff --git a/utils/console_utils.py b/utils/console_utils.py index 1b2638a..197e5e9 100644 --- a/utils/console_utils.py +++ b/utils/console_utils.py @@ -3,6 +3,25 @@ import sys import time import multiprocessing +def input_number(s, default_value, valid_list=None, help_message=None): + while True: + try: + inp = input(s) + if len(inp) == 0: + raise ValueError("") + + if help_message is not None and inp == '?': + print (help_message) + continue + + i = float(inp) + if (valid_list is not None) and (i not in valid_list): + return default_value + return i + except: + print (default_value) + return default_value + def input_int(s, default_value, valid_list=None, help_message=None): while True: try: