diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 392002c..ee45c55 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -51,11 +51,11 @@ class SAEModel(ModelBase): self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1) if is_first_run: - self.options['archi'] = io.input_str ("AE architecture (df, liae ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes.").lower() + self.options['archi'] = io.input_str ("AE architecture (df, liae, df-s, liae-s ?:help skip:%s) : " % (default_archi) , default_archi, ['df','df-s','liae','liae-s'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. -s version is slower, but has decreased change to collapse.").lower() else: self.options['archi'] = self.options.get('archi', default_archi) - default_ae_dims = 256 if self.options['archi'] == 'liae' else 512 + default_ae_dims = 256 if 'liae' in self.options['archi'] else 512 default_e_ch_dims = 42 default_d_ch_dims = default_e_ch_dims // 2 @@ -125,13 +125,18 @@ class SAEModel(ModelBase): target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)] target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)] - padding = 'reflect' if self.options['remove_gray_border'] else 'zero' + padding = 'zero' + norm = '' + + if '-s' in self.options['archi']: + norm = 'bn' + common_flow_kwargs = { 'padding': padding, - 'norm': 'bn', + 'norm': norm, 'act':'' } weights_to_load = [] - if self.options['archi'] == 'liae': + if 'liae' in self.options['archi']: self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape)) enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ] @@ -175,7 +180,7 @@ class SAEModel(ModelBase): pred_dst_dstm = self.decoderm(warped_dst_inter_code) pred_src_dstm = self.decoderm(warped_src_dst_inter_code) - elif self.options['archi'] == 'df': + elif 'df' in self.options['archi']: self.encoder = modelify(SAEModel.DFEncFlow(resolution, ae_dims=ae_dims, ch_dims=e_ch_dims, **common_flow_kwargs) ) (Input(bgr_shape)) dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ] @@ -248,7 +253,7 @@ class SAEModel(ModelBase): self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1) self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1) - if self.options['archi'] == 'liae': + if 'liae' in self.options['archi']: src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights if self.options['learn_mask']: src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights @@ -344,7 +349,7 @@ class SAEModel(ModelBase): [self.src_dst_mask_opt, 'src_dst_mask_opt'] ] ar = [] - if self.options['archi'] == 'liae': + if 'liae' in self.options['archi']: ar += [[self.encoder, 'encoder.h5'], [self.inter_B, 'inter_B.h5'], [self.inter_AB, 'inter_AB.h5'], @@ -352,7 +357,7 @@ class SAEModel(ModelBase): ] if self.options['learn_mask']: ar += [ [self.decoderm, 'decoderm.h5'] ] - elif self.options['archi'] == 'df': + elif 'df' in self.options['archi']: ar += [[self.encoder, 'encoder.h5'], [self.decoder_src, 'decoder_src.h5'], [self.decoder_dst, 'decoder_dst.h5']