mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
SAE: df-s liae-s archis
This commit is contained in:
parent
8969e80275
commit
47f9bad42b
1 changed files with 14 additions and 9 deletions
|
@ -51,11 +51,11 @@ class SAEModel(ModelBase):
|
||||||
self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1)
|
self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1)
|
||||||
|
|
||||||
if is_first_run:
|
if is_first_run:
|
||||||
self.options['archi'] = io.input_str ("AE architecture (df, liae ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes.").lower()
|
self.options['archi'] = io.input_str ("AE architecture (df, liae, df-s, liae-s ?:help skip:%s) : " % (default_archi) , default_archi, ['df','df-s','liae','liae-s'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. -s version is slower, but has decreased change to collapse.").lower()
|
||||||
else:
|
else:
|
||||||
self.options['archi'] = self.options.get('archi', default_archi)
|
self.options['archi'] = self.options.get('archi', default_archi)
|
||||||
|
|
||||||
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
|
default_ae_dims = 256 if 'liae' in self.options['archi'] else 512
|
||||||
default_e_ch_dims = 42
|
default_e_ch_dims = 42
|
||||||
default_d_ch_dims = default_e_ch_dims // 2
|
default_d_ch_dims = default_e_ch_dims // 2
|
||||||
|
|
||||||
|
@ -125,13 +125,18 @@ class SAEModel(ModelBase):
|
||||||
target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
||||||
target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
||||||
|
|
||||||
padding = 'reflect' if self.options['remove_gray_border'] else 'zero'
|
padding = 'zero'
|
||||||
|
norm = ''
|
||||||
|
|
||||||
|
if '-s' in self.options['archi']:
|
||||||
|
norm = 'bn'
|
||||||
|
|
||||||
common_flow_kwargs = { 'padding': padding,
|
common_flow_kwargs = { 'padding': padding,
|
||||||
'norm': 'bn',
|
'norm': norm,
|
||||||
'act':'' }
|
'act':'' }
|
||||||
|
|
||||||
weights_to_load = []
|
weights_to_load = []
|
||||||
if self.options['archi'] == 'liae':
|
if 'liae' in self.options['archi']:
|
||||||
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
||||||
|
|
||||||
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||||
|
@ -175,7 +180,7 @@ class SAEModel(ModelBase):
|
||||||
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
|
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
|
||||||
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
||||||
|
|
||||||
elif self.options['archi'] == 'df':
|
elif 'df' in self.options['archi']:
|
||||||
self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
|
||||||
|
|
||||||
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||||
|
@ -248,7 +253,7 @@ class SAEModel(ModelBase):
|
||||||
self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
|
self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
|
||||||
self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
|
self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
|
||||||
|
|
||||||
if self.options['archi'] == 'liae':
|
if 'liae' in self.options['archi']:
|
||||||
src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
||||||
|
@ -344,7 +349,7 @@ class SAEModel(ModelBase):
|
||||||
[self.src_dst_mask_opt, 'src_dst_mask_opt']
|
[self.src_dst_mask_opt, 'src_dst_mask_opt']
|
||||||
]
|
]
|
||||||
ar = []
|
ar = []
|
||||||
if self.options['archi'] == 'liae':
|
if 'liae' in self.options['archi']:
|
||||||
ar += [[self.encoder, 'encoder.h5'],
|
ar += [[self.encoder, 'encoder.h5'],
|
||||||
[self.inter_B, 'inter_B.h5'],
|
[self.inter_B, 'inter_B.h5'],
|
||||||
[self.inter_AB, 'inter_AB.h5'],
|
[self.inter_AB, 'inter_AB.h5'],
|
||||||
|
@ -352,7 +357,7 @@ class SAEModel(ModelBase):
|
||||||
]
|
]
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
ar += [ [self.decoderm, 'decoderm.h5'] ]
|
ar += [ [self.decoderm, 'decoderm.h5'] ]
|
||||||
elif self.options['archi'] == 'df':
|
elif 'df' in self.options['archi']:
|
||||||
ar += [[self.encoder, 'encoder.h5'],
|
ar += [[self.encoder, 'encoder.h5'],
|
||||||
[self.decoder_src, 'decoder_src.h5'],
|
[self.decoder_src, 'decoder_src.h5'],
|
||||||
[self.decoder_dst, 'decoder_dst.h5']
|
[self.decoder_dst, 'decoder_dst.h5']
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue