diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py
index bc2e3d1..067a3d8 100644
--- a/models/Model_SAE/Model.py
+++ b/models/Model_SAE/Model.py
@@ -58,23 +58,23 @@ class SAEModel(ModelBase):
         default_ae_dims = 256 if 'liae' in self.options['archi'] else 512
         default_e_ch_dims = 42
         default_d_ch_dims = default_e_ch_dims // 2
-
+        def_ca_weights = False
+
         if is_first_run:
             self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dims (32-1024 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
             self.options['e_ch_dims'] = np.clip ( io.input_int("Encoder dims per channel (21-85 ?:help skip:%d) : " % (default_e_ch_dims) , default_e_ch_dims, help_message="More encoder dims help to recognize more facial features, but require more VRAM. You can fine-tune model size to fit your GPU." ), 21, 85 )
             default_d_ch_dims = self.options['e_ch_dims'] // 2
             self.options['d_ch_dims'] = np.clip ( io.input_int("Decoder dims per channel (10-85 ?:help skip:%d) : " % (default_d_ch_dims) , default_d_ch_dims, help_message="More decoder dims help to get better details, but require more VRAM. You can fine-tune model size to fit your GPU." ), 10, 85 )
             #self.options['remove_gray_border'] = io.input_bool ("Remove gray border? (y/n, ?:help skip:n) : ", False, help_message="Removes gray border of predicted face, but requires more computing resources.")
+            self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
+            self.options['ca_weights'] = io.input_bool ("Use CA weights? (y/n, ?:help skip: %s ) : " % (yn_str[def_ca_weights]), def_ca_weights, help_message="Initialize network with 'Convolution Aware' weights. This may help to achieve a higher accuracy model, but consumes time at first run.")
         else:
             self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
             self.options['e_ch_dims'] = self.options.get('e_ch_dims', default_e_ch_dims)
             self.options['d_ch_dims'] = self.options.get('d_ch_dims', default_d_ch_dims)
             self.options['remove_gray_border'] = self.options.get('remove_gray_border', False)
-
-        if is_first_run:
-            self.options['multiscale_decoder'] = io.input_bool ("Use multiscale decoder? (y/n, ?:help skip:n) : ", False, help_message="Multiscale decoder helps to get better details.")
-        else:
-            self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)
+            self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', False)
+            self.options['ca_weights'] = self.options.get('ca_weights', def_ca_weights)
 
         default_face_style_power = 0.0
         default_bg_style_power = 0.0
@@ -134,7 +134,8 @@ class SAEModel(ModelBase):
 
 
         common_flow_kwargs = { 'padding': padding, 'norm': norm, 'act':'' }
-
+
+        models_list = []
         weights_to_load = []
         if 'liae' in self.options['archi']:
             self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape))
@@ -147,10 +148,12 @@ class SAEModel(ModelBase):
             inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
 
             self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs)
-
+            models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]
+
             if self.options['learn_mask']:
                 self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs)
-
+                models_list += [self.decoderm]
+
             if not self.is_first_run():
                 weights_to_load += [ [self.encoder , 'encoder.h5'],
                                      [self.inter_B , 'inter_B.h5'],
@@ -187,11 +190,13 @@ class SAEModel(ModelBase):
 
             self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
             self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
-
+            models_list += [self.encoder, self.decoder_src, self.decoder_dst]
+
             if self.options['learn_mask']:
                 self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
                 self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
-
+                models_list += [self.decoder_srcm, self.decoder_dstm]
+
             if not self.is_first_run():
                 weights_to_load += [ [self.encoder , 'encoder.h5'],
                                      [self.decoder_src, 'decoder_src.h5'],
@@ -212,7 +217,15 @@ class SAEModel(ModelBase):
                 pred_src_srcm = self.decoder_srcm(warped_src_code)
                 pred_dst_dstm = self.decoder_dstm(warped_dst_code)
                 pred_src_dstm = self.decoder_srcm(warped_dst_code)
-
+
+        if self.is_first_run() and self.options.get('ca_weights',False):
+            conv_weights_list = []
+            for model in models_list:
+                for layer in model.layers:
+                    if type(layer) == keras.layers.Conv2D:
+                        conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
+            CAInitializerMP ( conv_weights_list )
+
         pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
 
         if self.options['learn_mask']: