mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
SAE: changed default style power to 10.0 . Now style power is floating number in valid range 0.0 to 100.0
This commit is contained in:
parent
7233b24d2c
commit
ba06a71fff
3 changed files with 38 additions and 13 deletions
|
@ -35,17 +35,19 @@ class SAEModel(ModelBase):
|
|||
self.options['archi'] = self.options.get('archi', default_archi)
|
||||
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
|
||||
|
||||
default_face_style_power = 10.0
|
||||
if is_first_run or ask_override:
|
||||
default_style_power = 100 if is_first_run else self.options.get('face_style_power', 100)
|
||||
self.options['face_style_power'] = np.clip ( input_int("Face style power (0..100 ?:help skip:%d) : " % (default_style_power), default_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0, 100 )
|
||||
default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
|
||||
self.options['face_style_power'] = np.clip ( input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_face_style_power), default_face_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0.0, 100.0 )
|
||||
else:
|
||||
self.options['face_style_power'] = self.options.get('face_style_power', 100)
|
||||
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
|
||||
|
||||
default_bg_style_power = 10.0
|
||||
if is_first_run or ask_override:
|
||||
default_style_power = 100 if is_first_run else self.options.get('bg_style_power', 100)
|
||||
self.options['bg_style_power'] = np.clip ( input_int("Background style power (0..100 ?:help skip:%d) : " % (default_style_power), default_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0, 100 )
|
||||
default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power)
|
||||
self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0.0, 100.0 )
|
||||
else:
|
||||
self.options['bg_style_power'] = self.options.get('bg_style_power', 100)
|
||||
self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
|
||||
|
||||
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
|
||||
default_ed_ch_dims = 42
|
||||
|
@ -187,6 +189,9 @@ class SAEModel(ModelBase):
|
|||
|
||||
src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
|
||||
|
||||
def optimizer():
|
||||
return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
|
||||
|
||||
if self.options['face_style_power'] != 0:
|
||||
face_style_power = self.options['face_style_power'] / 100.0
|
||||
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)(psd_target_dst_masked, target_dst_masked)
|
||||
|
@ -200,7 +205,7 @@ class SAEModel(ModelBase):
|
|||
else:
|
||||
src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
|
||||
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss],
|
||||
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss, src_train_weights) )
|
||||
optimizer().get_updates(src_loss, src_train_weights) )
|
||||
|
||||
dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) )
|
||||
|
||||
|
@ -209,7 +214,7 @@ class SAEModel(ModelBase):
|
|||
else:
|
||||
dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
|
||||
self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss],
|
||||
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_loss, dst_train_weights) )
|
||||
optimizer().get_updates(dst_loss, dst_train_weights) )
|
||||
|
||||
src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
|
||||
|
||||
|
@ -219,7 +224,7 @@ class SAEModel(ModelBase):
|
|||
src_mask_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
|
||||
|
||||
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss],
|
||||
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_mask_loss, src_mask_train_weights ) )
|
||||
optimizer().get_updates(src_mask_loss, src_mask_train_weights ) )
|
||||
|
||||
dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm))
|
||||
|
||||
|
@ -229,7 +234,7 @@ class SAEModel(ModelBase):
|
|||
dst_mask_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
|
||||
|
||||
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss],
|
||||
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_mask_loss, dst_mask_train_weights) )
|
||||
optimizer().get_updates(dst_mask_loss, dst_mask_train_weights) )
|
||||
|
||||
self.AE_view = K.function ([warped_src, warped_dst],[pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm])
|
||||
self.AE_convert = K.function ([warped_dst],[pred_src_dst, pred_src_dstm])
|
||||
|
|
|
@ -109,6 +109,7 @@ AddUniformNoise = nnlib.AddUniformNoise
|
|||
keras_contrib = nnlib.keras_contrib
|
||||
GroupNormalization = keras_contrib.layers.GroupNormalization
|
||||
InstanceNormalization = keras_contrib.layers.InstanceNormalization
|
||||
Padam = keras_contrib.optimizers.Padam
|
||||
"""
|
||||
code_import_dlib_string = \
|
||||
"""
|
||||
|
|
|
@ -3,6 +3,25 @@ import sys
|
|||
import time
|
||||
import multiprocessing
|
||||
|
||||
def input_number(s, default_value, valid_list=None, help_message=None):
|
||||
while True:
|
||||
try:
|
||||
inp = input(s)
|
||||
if len(inp) == 0:
|
||||
raise ValueError("")
|
||||
|
||||
if help_message is not None and inp == '?':
|
||||
print (help_message)
|
||||
continue
|
||||
|
||||
i = float(inp)
|
||||
if (valid_list is not None) and (i not in valid_list):
|
||||
return default_value
|
||||
return i
|
||||
except:
|
||||
print (default_value)
|
||||
return default_value
|
||||
|
||||
def input_int(s, default_value, valid_list=None, help_message=None):
|
||||
while True:
|
||||
try:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue