Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-07 05:22:06 -07:00)
SAE: added CA weights

commit b6711b97a3 (parent 60ed56a801)

1 changed file with 24 additions and 11 deletions
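In short: the diff below adds a first-run "Use CA weights?" prompt (stored as the ca_weights option, default off), tracks every sub-model that gets built in a new models_list, and, on the first run with the option enabled, re-initializes all Conv2D kernels with 'Convolution Aware' weights via CAInitializerMP.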
@@ -58,6 +58,7 @@ class SAEModel(ModelBase):
         default_ae_dims = 256 if 'liae' in self.options['archi'] else 512
         default_e_ch_dims = 42
         default_d_ch_dims = default_e_ch_dims // 2
+        def_ca_weights = False
 
         if is_first_run:
             self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dims (32-1024 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
@@ -65,16 +66,15 @@ class SAEModel(ModelBase):
             default_d_ch_dims = self.options['e_ch_dims'] // 2
             self.options['d_ch_dims'] = np.clip ( io.input_int("Decoder dims per channel (10-85 ?:help skip:%d) : " % (default_d_ch_dims) , default_d_ch_dims, help_message="More decoder dims help to get better details, but require more VRAM. You can fine-tune model size to fit your GPU." ), 10, 85 )
             #self.options['remove_gray_border'] = io.input_bool ("Remove gray border? (y/n, ?:help skip:n) : ", False, help_message="Removes gray border of predicted face, but requires more computing resources.")
+            self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
+            self.options['ca_weights'] = io.input_bool ("Use CA weights? (y/n, ?:help skip: %s ) : " % (yn_str[def_ca_weights]), def_ca_weights, help_message="Initialize network with 'Convolution Aware' weights. This may help to achieve a higher accuracy model, but consumes time at first run.")
         else:
             self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
             self.options['e_ch_dims'] = self.options.get('e_ch_dims', default_e_ch_dims)
             self.options['d_ch_dims'] = self.options.get('d_ch_dims', default_d_ch_dims)
             self.options['remove_gray_border'] = self.options.get('remove_gray_border', False)
-
-        if is_first_run:
-            self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
-        else:
             self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)
+            self.options['ca_weights'] = self.options.get('ca_weights', def_ca_weights)
 
         default_face_style_power = 0.0
         default_bg_style_power = 0.0
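Aside: the prompt lines above follow DeepFaceLab's usual first-run/resume pattern. A minimal sketch of that pattern in plain Python, with a hypothetical input_bool standing in for io.input_bool (names and prompt text are illustrative, not the project's API):

    # Sketch only: 'input_bool' is a hypothetical stand-in for io.input_bool.
    def input_bool(prompt, default):
        s = input(prompt).strip().lower()
        return default if s == '' else s in ('y', 'yes', '1', 'true')

    options = {}
    is_first_run = True
    def_ca_weights = False

    if is_first_run:
        # First run: ask once; the answer is saved with the model options.
        options['ca_weights'] = input_bool("Use CA weights? (y/n skip:n) : ", def_ca_weights)
    else:
        # Resume: reuse the stored answer, falling back to the default
        # so models saved before this option existed still load.
        options['ca_weights'] = options.get('ca_weights', def_ca_weights)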
@@ -135,6 +135,7 @@ class SAEModel(ModelBase):
                                'norm': norm,
                                'act':'' }
 
+        models_list = []
         weights_to_load = []
         if 'liae' in self.options['archi']:
             self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
@@ -147,9 +148,11 @@ class SAEModel(ModelBase):
             inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
 
             self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs)
+            models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]
 
             if self.options['learn_mask']:
                 self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs)
+                models_list += [self.decoderm]
 
             if not self.is_first_run():
                 weights_to_load += [ [self.encoder , 'encoder.h5'],
@@ -187,10 +190,12 @@ class SAEModel(ModelBase):
 
             self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
             self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
+            models_list += [self.encoder, self.decoder_src, self.decoder_dst]
 
             if self.options['learn_mask']:
                 self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
                 self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
+                models_list += [self.decoder_srcm, self.decoder_dstm]
 
             if not self.is_first_run():
                 weights_to_load += [ [self.encoder , 'encoder.h5'],
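Note that both the 'liae' and 'df' branches register their encoder/decoder sub-models in models_list; this is what lets the first-run CA pass below reach every Conv2D layer regardless of which architecture was chosen.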
@@ -213,6 +218,14 @@ class SAEModel(ModelBase):
                 pred_dst_dstm = self.decoder_dstm(warped_dst_code)
                 pred_src_dstm = self.decoder_srcm(warped_dst_code)
 
+        if self.is_first_run() and self.options.get('ca_weights',False):
+            conv_weights_list = []
+            for model in models_list:
+                for layer in model.layers:
+                    if type(layer) == keras.layers.Conv2D:
+                        conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
+            CAInitializerMP ( conv_weights_list )
+
         pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
 
         if self.options['learn_mask']:
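For reference, a minimal self-contained sketch of the collect-then-initialize pass above. CAInitializerMP comes from DeepFaceLab's nnlib; here a He-normal draw is used as a placeholder for the actual 'Convolution Aware' computation, and batch_init_conv_kernels is an illustrative name, not a DeepFaceLab API (assumes keras with a TF1-style backend):

    import numpy as np
    import keras
    import keras.backend as K

    def batch_init_conv_kernels(models):
        # Gather every Conv2D kernel variable across all sub-models;
        # weights[0] is the kernel tensor (weights[1] would be the bias).
        kernels = []
        for model in models:
            for layer in model.layers:
                if isinstance(layer, keras.layers.Conv2D):
                    kernels.append(layer.weights[0])

        # Re-initialize each kernel in one pass, scaled by its fan-in
        # (He-normal stand-in for the real CA filter construction).
        for w in kernels:
            kh, kw, in_ch, out_ch = K.int_shape(w)
            fan_in = kh * kw * in_ch
            new_val = np.random.normal(0.0, np.sqrt(2.0 / fan_in),
                                       (kh, kw, in_ch, out_ch)).astype(np.float32)
            K.set_value(w, new_val)

Doing this as one batched pass after all sub-models are built, rather than per layer at construction time, fits the real scheme: CA filter construction is comparatively expensive, which is why the option's help text warns that it "consumes time at first run".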