diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index a755d3d..9b647af 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -62,7 +62,7 @@ class SAEHDModel(ModelBase): resolution = np.clip ( (resolution // 16) * 16, 64, 512) self.options['resolution'] = resolution self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf'], help_message="Half / mid face / full face / whole face. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead, but requires manual merge in Adobe After Effects.").lower() - self.options['archi'] = io.input_str ("AE architecture", default_archi, ['dfhd','liaehd','df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'hd' is heavyweight version for the best quality.").lower() + self.options['archi'] = io.input_str ("AE architecture", default_archi, ['df','liae','dfhd','liaehd','liaech'], help_message="'df' keeps faces more natural.\n'liae' can fix overly different face shapes.\n'hd' are experimental versions.\n'liaech' - new experimental model by @chervoniy. Based on liae, but produces more src-like face.").lower() default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64 default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims) @@ -125,8 +125,6 @@ class SAEHDModel(ModelBase): nn.initialize(data_format=self.model_data_format) tf = nn.tf - Encoder, Inter, Decoder = nn.get_ae_models() - device_config = nn.getCurrentDeviceConfig() devices = device_config.devices @@ -163,8 +161,6 @@ class SAEHDModel(ModelBase): output_ch = 3 bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) mask_shape = nn.get4Dshape(resolution,resolution,1) - lowest_dense_res = resolution // 16 - self.model_filename_list = [] @@ -180,12 +176,18 @@ class SAEHDModel(ModelBase): self.target_dstm_all = tf.placeholder (nn.tf_floatx, mask_shape) # Initializing model classes + if archi == 'liaech': + lowest_dense_res, Encoder, Inter, Decoder = nn.get_ae_models_chervoniy(resolution) + else: + lowest_dense_res, Encoder, Inter, Decoder = nn.get_ae_models(resolution) + + with tf.device (models_opt_device): if 'df' in archi: self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder') encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape)) - self.inter = Inter (in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter') + self.inter = Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter') inter_out_ch = self.inter.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch))) self.decoder_src = Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_src') @@ -205,8 +207,8 @@ class SAEHDModel(ModelBase): self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder') encoder_out_ch = self.encoder.compute_output_channels ( (nn.tf_floatx, bgr_shape)) - self.inter_AB = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB') - self.inter_B = Inter(in_ch=encoder_out_ch, lowest_dense_res=lowest_dense_res, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B') + self.inter_AB = Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB') + self.inter_B = Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B') inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch))) inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.tf_floatx, (None,encoder_out_ch)))