SAE added CA weights

This commit is contained in:
iperov 2019-04-24 11:38:45 +04:00
parent 60ed56a801
commit b6711b97a3

View file

@ -58,6 +58,7 @@ class SAEModel(ModelBase):
default_ae_dims = 256 if 'liae' in self.options['archi'] else 512 default_ae_dims = 256 if 'liae' in self.options['archi'] else 512
default_e_ch_dims = 42 default_e_ch_dims = 42
default_d_ch_dims = default_e_ch_dims // 2 default_d_ch_dims = default_e_ch_dims // 2
def_ca_weights = False
if is_first_run: if is_first_run:
self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dims (32-1024 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 ) self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dims (32-1024 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
@ -65,16 +66,15 @@ class SAEModel(ModelBase):
default_d_ch_dims = self.options['e_ch_dims'] // 2 default_d_ch_dims = self.options['e_ch_dims'] // 2
self.options['d_ch_dims'] = np.clip ( io.input_int("Decoder dims per channel (10-85 ?:help skip:%d) : " % (default_d_ch_dims) , default_d_ch_dims, help_message="More decoder dims help to get better details, but require more VRAM. You can fine-tune model size to fit your GPU." ), 10, 85 ) self.options['d_ch_dims'] = np.clip ( io.input_int("Decoder dims per channel (10-85 ?:help skip:%d) : " % (default_d_ch_dims) , default_d_ch_dims, help_message="More decoder dims help to get better details, but require more VRAM. You can fine-tune model size to fit your GPU." ), 10, 85 )
#self.options['remove_gray_border'] = io.input_bool ("Remove gray border? (y/n, ?:help skip:n) : ", False, help_message="Removes gray border of predicted face, but requires more computing resources.") #self.options['remove_gray_border'] = io.input_bool ("Remove gray border? (y/n, ?:help skip:n) : ", False, help_message="Removes gray border of predicted face, but requires more computing resources.")
self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
self.options['ca_weights'] = io.input_bool ("Use CA weights? (y/n, ?:help skip: %s ) : " % (yn_str[def_ca_weights]), def_ca_weights, help_message="Initialize network with 'Convolution Aware' weights. This may help to achieve a higher accuracy model, but consumes time at first run.")
else: else:
self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims) self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
self.options['e_ch_dims'] = self.options.get('e_ch_dims', default_e_ch_dims) self.options['e_ch_dims'] = self.options.get('e_ch_dims', default_e_ch_dims)
self.options['d_ch_dims'] = self.options.get('d_ch_dims', default_d_ch_dims) self.options['d_ch_dims'] = self.options.get('d_ch_dims', default_d_ch_dims)
self.options['remove_gray_border'] = self.options.get('remove_gray_border', False) self.options['remove_gray_border'] = self.options.get('remove_gray_border', False)
if is_first_run:
self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
else:
self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False) self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)
self.options['ca_weights'] = self.options.get('ca_weights', def_ca_weights)
default_face_style_power = 0.0 default_face_style_power = 0.0
default_bg_style_power = 0.0 default_bg_style_power = 0.0
@ -135,6 +135,7 @@ class SAEModel(ModelBase):
'norm': norm, 'norm': norm,
'act':'' } 'act':'' }
models_list = []
weights_to_load = [] weights_to_load = []
if 'liae' in self.options['archi']: if 'liae' in self.options['archi']:
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape)) self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
@ -147,9 +148,11 @@ class SAEModel(ModelBase):
inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ] inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs) self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs)
models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]
if self.options['learn_mask']: if self.options['learn_mask']:
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs) self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs)
models_list += [self.decoderm]
if not self.is_first_run(): if not self.is_first_run():
weights_to_load += [ [self.encoder , 'encoder.h5'], weights_to_load += [ [self.encoder , 'encoder.h5'],
@ -187,10 +190,12 @@ class SAEModel(ModelBase):
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs) self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs) self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
models_list += [self.encoder, self.decoder_src, self.decoder_dst]
if self.options['learn_mask']: if self.options['learn_mask']:
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs) self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs) self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
models_list += [self.decoder_srcm, self.decoder_dstm]
if not self.is_first_run(): if not self.is_first_run():
weights_to_load += [ [self.encoder , 'encoder.h5'], weights_to_load += [ [self.encoder , 'encoder.h5'],
@ -213,6 +218,14 @@ class SAEModel(ModelBase):
pred_dst_dstm = self.decoder_dstm(warped_dst_code) pred_dst_dstm = self.decoder_dstm(warped_dst_code)
pred_src_dstm = self.decoder_srcm(warped_dst_code) pred_src_dstm = self.decoder_srcm(warped_dst_code)
if self.is_first_run() and self.options.get('ca_weights',False):
conv_weights_list = []
for model in models_list:
for layer in model.layers:
if type(layer) == keras.layers.Conv2D:
conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
CAInitializerMP ( conv_weights_list )
pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ] pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
if self.options['learn_mask']: if self.options['learn_mask']: