mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 13:32:09 -07:00
SAE: df-s liae-s archis
This commit is contained in:
parent
8969e80275
commit
47f9bad42b
1 changed files with 14 additions and 9 deletions
|
@ -51,11 +51,11 @@ class SAEModel(ModelBase):
|
|||
self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1)
|
||||
|
||||
if is_first_run:
|
||||
self.options['archi'] = io.input_str ("AE architecture (df, liae ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes.").lower()
|
||||
self.options['archi'] = io.input_str ("AE architecture (df, liae, df-s, liae-s ?:help skip:%s) : " % (default_archi) , default_archi, ['df','df-s','liae','liae-s'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. -s version is slower, but has decreased change to collapse.").lower()
|
||||
else:
|
||||
self.options['archi'] = self.options.get('archi', default_archi)
|
||||
|
||||
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
|
||||
default_ae_dims = 256 if 'liae' in self.options['archi'] else 512
|
||||
default_e_ch_dims = 42
|
||||
default_d_ch_dims = default_e_ch_dims // 2
|
||||
|
||||
|
@ -125,13 +125,18 @@ class SAEModel(ModelBase):
|
|||
target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
||||
target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
||||
|
||||
padding = 'reflect' if self.options['remove_gray_border'] else 'zero'
|
||||
padding = 'zero'
|
||||
norm = ''
|
||||
|
||||
if '-s' in self.options['archi']:
|
||||
norm = 'bn'
|
||||
|
||||
common_flow_kwargs = { 'padding': padding,
|
||||
'norm': 'bn',
|
||||
'norm': norm,
|
||||
'act':'' }
|
||||
|
||||
weights_to_load = []
|
||||
if self.options['archi'] == 'liae':
|
||||
if 'liae' in self.options['archi']:
|
||||
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
||||
|
||||
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||
|
@ -175,7 +180,7 @@ class SAEModel(ModelBase):
|
|||
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
|
||||
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
||||
|
||||
elif self.options['archi'] == 'df':
|
||||
elif 'df' in self.options['archi']:
|
||||
self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
||||
|
||||
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||
|
@ -248,7 +253,7 @@ class SAEModel(ModelBase):
|
|||
self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
|
||||
self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
|
||||
|
||||
if self.options['archi'] == 'liae':
|
||||
if 'liae' in self.options['archi']:
|
||||
src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||
if self.options['learn_mask']:
|
||||
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
||||
|
@ -344,7 +349,7 @@ class SAEModel(ModelBase):
|
|||
[self.src_dst_mask_opt, 'src_dst_mask_opt']
|
||||
]
|
||||
ar = []
|
||||
if self.options['archi'] == 'liae':
|
||||
if 'liae' in self.options['archi']:
|
||||
ar += [[self.encoder, 'encoder.h5'],
|
||||
[self.inter_B, 'inter_B.h5'],
|
||||
[self.inter_AB, 'inter_AB.h5'],
|
||||
|
@ -352,7 +357,7 @@ class SAEModel(ModelBase):
|
|||
]
|
||||
if self.options['learn_mask']:
|
||||
ar += [ [self.decoderm, 'decoderm.h5'] ]
|
||||
elif self.options['archi'] == 'df':
|
||||
elif 'df' in self.options['archi']:
|
||||
ar += [[self.encoder, 'encoder.h5'],
|
||||
[self.decoder_src, 'decoder_src.h5'],
|
||||
[self.decoder_dst, 'decoder_dst.h5']
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue