mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-14 10:46:59 -07:00
SAE: added options archi, learn_face_style , learn_bg_style, style_power
This commit is contained in:
parent
0c3559d2dd
commit
e300662f89
1 changed files with 11 additions and 3 deletions
|
@ -31,6 +31,8 @@ class SAEModel(ModelBase):
|
|||
#first run
|
||||
self.options['resolution'] = input_int("Resolution (64,128, ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
|
||||
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
|
||||
self.options['learn_face_style'] = input_bool("Learn face style? (y/n skip:y) : ", True)
|
||||
self.options['learn_bg_style'] = input_bool("Learn background style? (y/n skip:y) : ", True)
|
||||
self.options['style_power'] = np.clip ( input_int("Style power (1..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst style during generalization of src and dst faces."), 1, 100 )
|
||||
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
|
||||
self.options['ae_dims'] = input_int("AutoEncoder dims (128,256,512 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, [128,256,512], help_message="More dims are better, but requires more VRAM." )
|
||||
|
@ -38,12 +40,14 @@ class SAEModel(ModelBase):
|
|||
else:
|
||||
#not first run
|
||||
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
||||
self.options['learn_face_style'] = self.options.get('learn_face_style', True)
|
||||
self.options['learn_bg_style'] = self.options.get('learn_bg_style', True)
|
||||
self.options['archi'] = self.options.get('archi', default_archi)
|
||||
self.options['style_power'] = self.options.get('style_power', default_style_power)
|
||||
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
|
||||
self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
|
||||
self.options['face_type'] = self.options.get('face_type', default_face_type)
|
||||
|
||||
|
||||
#override
|
||||
def onInitialize(self, **in_options):
|
||||
exec(nnlib.import_all(), locals(), globals())
|
||||
|
@ -163,8 +167,12 @@ class SAEModel(ModelBase):
|
|||
style_power = self.options['style_power'] / 100.0
|
||||
|
||||
src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
|
||||
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked)
|
||||
src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
|
||||
|
||||
if self.options['learn_face_style']:
|
||||
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked)
|
||||
|
||||
if self.options['learn_bg_style']:
|
||||
src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
|
||||
|
||||
if self.options['archi'] == 'liae':
|
||||
src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue