SAE: changed default style power to 10.0. Style power is now a floating-point value in the valid range 0.0 to 100.0.

iperov 2019-01-14 21:24:33 +04:00
parent 7233b24d2c
commit ba06a71fff
3 changed files with 38 additions and 13 deletions

View file

@@ -35,17 +35,19 @@ class SAEModel(ModelBase):
         self.options['archi'] = self.options.get('archi', default_archi)
         self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)

+        default_face_style_power = 10.0
         if is_first_run or ask_override:
-            default_style_power = 100 if is_first_run else self.options.get('face_style_power', 100)
-            self.options['face_style_power'] = np.clip ( input_int("Face style power (0..100 ?:help skip:%d) : " % (default_style_power), default_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0, 100 )
+            default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
+            self.options['face_style_power'] = np.clip ( input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_face_style_power), default_face_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0.0, 100.0 )
         else:
-            self.options['face_style_power'] = self.options.get('face_style_power', 100)
+            self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)

+        default_bg_style_power = 10.0
         if is_first_run or ask_override:
-            default_style_power = 100 if is_first_run else self.options.get('bg_style_power', 100)
-            self.options['bg_style_power'] = np.clip ( input_int("Background style power (0..100 ?:help skip:%d) : " % (default_style_power), default_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0, 100 )
+            default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power)
+            self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0.0, 100.0 )
         else:
-            self.options['bg_style_power'] = self.options.get('bg_style_power', 100)
+            self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)

         default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
         default_ed_ch_dims = 42
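For reference, the new option flow prompts with a float default of 10.0 and clips the answer to the documented range. A minimal sketch, assuming numpy and the input_number helper added at the bottom of this commit are both in scope:

import numpy as np

default_face_style_power = 10.0  # new default introduced by this commit

# Empty input keeps 10.0, '?' prints the help text and re-prompts, and any
# numeric answer is parsed as a float and clipped to the valid 0.0..100.0 range.
answer = input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % default_face_style_power,
                      default_face_style_power,
                      help_message="How fast NN will learn dst face style during generalization of src and dst faces.")
face_style_power = float(np.clip(answer, 0.0, 100.0))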
@@ -187,6 +189,9 @@ class SAEModel(ModelBase):
         src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )

+        def optimizer():
+            return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
+
         if self.options['face_style_power'] != 0:
             face_style_power = self.options['face_style_power'] / 100.0
             src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)(psd_target_dst_masked, target_dst_masked)
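The repeated Adam constructor is also factored into a local optimizer() helper, so each K.function below gets its own freshly built optimizer with identical hyperparameters. Note the scaling: the user-facing 0.0..100.0 value is divided by 100 and then weighted by 0.2, so the new default of 10.0 contributes a style-loss weight of 0.2 * 0.1 = 0.02. A standalone sketch of the helper, assuming the standalone Keras 2.x API of this era (where Adam takes lr rather than learning_rate):

from keras.optimizers import Adam

def optimizer():
    # Each call returns a fresh Adam instance, so the src/dst and mask
    # training functions below do not share optimizer state.
    return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)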
@@ -200,7 +205,7 @@ class SAEModel(ModelBase):
         else:
             src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
         self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss],
-                                     Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss, src_train_weights) )
+                                     optimizer().get_updates(src_loss, src_train_weights) )

         dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) )
@@ -209,7 +214,7 @@ class SAEModel(ModelBase):
         else:
             dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
         self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss],
-                                     Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_loss, dst_train_weights) )
+                                     optimizer().get_updates(dst_loss, dst_train_weights) )

         src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
@@ -219,7 +224,7 @@ class SAEModel(ModelBase):
             src_mask_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
         self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss],
-                                          Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_mask_loss, src_mask_train_weights ) )
+                                          optimizer().get_updates(src_mask_loss, src_mask_train_weights ) )

         dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm))
@@ -229,7 +234,7 @@ class SAEModel(ModelBase):
             dst_mask_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
         self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss],
-                                          Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_mask_loss, dst_mask_train_weights) )
+                                          optimizer().get_updates(dst_mask_loss, dst_mask_train_weights) )

         self.AE_view = K.function ([warped_src, warped_dst],[pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm])
         self.AE_convert = K.function ([warped_dst],[pred_src_dst, pred_src_dstm])

View file

@@ -109,6 +109,7 @@ AddUniformNoise = nnlib.AddUniformNoise
 keras_contrib = nnlib.keras_contrib
 GroupNormalization = keras_contrib.layers.GroupNormalization
 InstanceNormalization = keras_contrib.layers.InstanceNormalization
+Padam = keras_contrib.optimizers.Padam
 """
 code_import_dlib_string = \
 """

View file

@@ -3,6 +3,25 @@ import sys
 import time
 import multiprocessing

+def input_number(s, default_value, valid_list=None, help_message=None):
+    while True:
+        try:
+            inp = input(s)
+            if len(inp) == 0:
+                raise ValueError("")
+
+            if help_message is not None and inp == '?':
+                print (help_message)
+                continue
+
+            i = float(inp)
+            if (valid_list is not None) and (i not in valid_list):
+                return default_value
+            return i
+        except:
+            print (default_value)
+            return default_value
+
 def input_int(s, default_value, valid_list=None, help_message=None):
     while True:
         try:
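input_number mirrors the existing input_int helper but parses the answer as a float: empty or non-numeric input falls back to the default, while '?' prints the help message and asks again. A quick hedged check of that behaviour, monkeypatching builtins.input and assuming the function above has been imported or pasted into scope:

import builtins

_real_input = builtins.input

def _answer_with(value):
    # Make the next prompt return a canned answer instead of reading stdin.
    builtins.input = lambda prompt='': value

_answer_with('42.5')
assert input_number("Face style power : ", 10.0) == 42.5   # parsed as float

_answer_with('not a number')
assert input_number("Face style power : ", 10.0) == 10.0   # falls back to the default

builtins.input = _real_input  # restore the real prompt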