From ba586213611b57a47da5d90c4da817e3064f08b8 Mon Sep 17 00:00:00 2001 From: Colombo Date: Tue, 10 Mar 2020 10:36:08 +0400 Subject: [PATCH] SAEHD: removed option learn_mask, it is now enabled by default --- models/Model_SAEHD/Model.py | 105 ++++++++++++++---------------------- 1 file changed, 41 insertions(+), 64 deletions(-) diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index f1ddbd9..8c9d638 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -37,7 +37,6 @@ class SAEHDModel(ModelBase): default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True) - default_learn_mask = self.options['learn_mask'] = self.load_or_def_option('learn_mask', True) default_eyes_prio = self.options['eyes_prio'] = self.load_or_def_option('eyes_prio', False) default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False) default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) @@ -88,7 +87,6 @@ class SAEHDModel(ModelBase): if self.options['face_type'] == 'wf': self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' type. Masked training clips training area to full_face mask, thus network will train the faces properly. When the face is trained enough, disable this option to train all area of the frame. Merge with 'raw-rgb' mode, then use Adobe After Effects to manually mask and compose whole face include forehead.") - self.options['learn_mask'] = io.input_bool ("Learn mask", default_learn_mask, help_message="Learning mask will produce a smooth mask in the merger. Also it works as guide for neural network to recognize face directions.") self.options['eyes_prio'] = io.input_bool ("Eyes priority", default_eyes_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction ( especially on HD architectures ) by forcing the neural network to train eyes with higher priority. before/after https://i.imgur.com/YQHOuSR.jpg ') if self.is_first_run() or ask_override: @@ -134,7 +132,6 @@ class SAEHDModel(ModelBase): 'f' : FaceType.FULL, 'wf' : FaceType.WHOLE_FACE}[ self.options['face_type'] ] - learn_mask = self.options['learn_mask'] eyes_prio = self.options['eyes_prio'] archi = self.options['archi'] is_hd = 'hd' in archi @@ -234,14 +231,11 @@ class SAEHDModel(ModelBase): self.src_dst_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] if 'df' in archi: - self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() - self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights_ex(learn_mask) + self.decoder_dst.get_weights_ex(learn_mask) - + self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() elif 'liae' in archi: - self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() - self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights_ex(learn_mask) + self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() - self.src_dst_opt.initialize_variables (self.src_dst_all_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu) + self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu) if self.options['true_face_power'] != 0: self.D_code_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt') @@ -343,8 +337,7 @@ class SAEHDModel(ModelBase): if eyes_prio: gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_eyes - gpu_pred_src_src*gpu_target_srcm_eyes ), axis=[1,2,3]) - if learn_mask: - gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) + gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) face_style_power = self.options['face_style_power'] / 100.0 if face_style_power != 0 and not self.pretrain: @@ -361,8 +354,7 @@ class SAEHDModel(ModelBase): if eyes_prio: gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_eyes - gpu_pred_dst_dst*gpu_target_dstm_eyes ), axis=[1,2,3]) - if learn_mask: - gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) gpu_src_losses += [gpu_src_loss] gpu_dst_losses += [gpu_dst_loss] @@ -461,16 +453,11 @@ class SAEHDModel(ModelBase): self.target_dstm_all:target_dstm_all}) self.D_src_dst_train = D_src_dst_train - if learn_mask: - def AE_view(warped_src, warped_dst): - return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], - feed_dict={self.warped_src:warped_src, - self.warped_dst:warped_dst}) - else: - def AE_view(warped_src, warped_dst): - return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_src_dst], - feed_dict={self.warped_src:warped_src, - self.warped_dst:warped_dst}) + + def AE_view(warped_src, warped_dst): + return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], + feed_dict={self.warped_src:warped_src, + self.warped_dst:warped_dst}) self.AE_view = AE_view else: # Initializing merge function @@ -490,12 +477,9 @@ class SAEHDModel(ModelBase): gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) - if learn_mask: - def AE_merge( warped_dst): - return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst}) - else: - def AE_merge( warped_dst): - return nn.tf_sess.run ( [gpu_pred_src_dst], feed_dict={self.warped_dst:warped_dst}) + + def AE_merge( warped_dst): + return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst}) self.AE_merge = AE_merge @@ -605,11 +589,8 @@ class SAEHDModel(ModelBase): ( (warped_src, target_src, target_srcm_all,), (warped_dst, target_dst, target_dstm_all,) ) = samples - if self.options['learn_mask']: - S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] - DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] - else: - S, D, SS, DD, SD, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format) , 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] + S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] + DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] target_srcm_all, target_dstm_all = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm_all, target_dstm_all] )] @@ -627,13 +608,13 @@ class SAEHDModel(ModelBase): st.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD', np.concatenate (st, axis=0 )), ] - if self.options['learn_mask']: - st_m = [] - for i in range(n_samples): - ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i]) - st_m.append ( np.concatenate ( ar, axis=1) ) + + st_m = [] + for i in range(n_samples): + ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i]) + st_m.append ( np.concatenate ( ar, axis=1) ) - result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ] + result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ] else: result = [] @@ -655,38 +636,34 @@ class SAEHDModel(ModelBase): st.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD pred', np.concatenate (st, axis=0 )), ] - if self.options['learn_mask']: - st_m = [] - for i in range(n_samples): - ar = S[i]*target_srcm[i], SS[i] - st_m.append ( np.concatenate ( ar, axis=1) ) - result += [ ('SAEHD masked src-src', np.concatenate (st_m, axis=0 )), ] - st_m = [] - for i in range(n_samples): - ar = D[i]*target_dstm[i], DD[i]*DDM[i] - st_m.append ( np.concatenate ( ar, axis=1) ) - result += [ ('SAEHD masked dst-dst', np.concatenate (st_m, axis=0 )), ] - - st_m = [] - for i in range(n_samples): - ar = D[i]*target_dstm[i], SD[i]*(DDM[i]*SDM[i]) - st_m.append ( np.concatenate ( ar, axis=1) ) - result += [ ('SAEHD masked pred', np.concatenate (st_m, axis=0 )), ] + st_m = [] + for i in range(n_samples): + ar = S[i]*target_srcm[i], SS[i] + st_m.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD masked src-src', np.concatenate (st_m, axis=0 )), ] + + st_m = [] + for i in range(n_samples): + ar = D[i]*target_dstm[i], DD[i]*DDM[i] + st_m.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD masked dst-dst', np.concatenate (st_m, axis=0 )), ] + + st_m = [] + for i in range(n_samples): + ar = D[i]*target_dstm[i], SD[i]*(DDM[i]*SDM[i]) + st_m.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD masked pred', np.concatenate (st_m, axis=0 )), ] return result def predictor_func (self, face=None): face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC") - if self.options['learn_mask']: - bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ] + bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ] - mask = mask_dst_dstm[0] * mask_src_dstm[0] - return bgr[0], mask[...,0] - else: - bgr, = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ] - return bgr[0] + mask = mask_dst_dstm[0] * mask_src_dstm[0] + return bgr[0], mask[...,0] #override def get_MergerConfig(self):