added option Eyes priority (y/n)

	fix eye problems during training (especially on HD architectures)
	by forcing the neural network to train the eyes with higher priority
	before/after: https://i.imgur.com/YQHOuSR.jpg

	It does not guarantee the correct eye direction.
This commit is contained in:
Colombo 2020-02-18 14:30:07 +04:00
parent 4f928074b9
commit 9598ba0141
5 changed files with 105 additions and 60 deletions


@@ -36,6 +36,7 @@ class SAEHDModel(ModelBase):
self.options['d_dims'] = None
self.options['d_mask_dims'] = None
default_learn_mask = self.options['learn_mask'] = self.load_or_def_option('learn_mask', True)
default_eyes_prio = self.options['eyes_prio'] = self.load_or_def_option('eyes_prio', False)
default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False)
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
@@ -83,6 +84,7 @@ class SAEHDModel(ModelBase):
if self.is_first_run() or ask_override:
self.options['learn_mask'] = io.input_bool ("Learn mask", default_learn_mask, help_message="Learning the mask can help the model to recognize face directions. Learning without the mask can reduce model size, but the merger is then forced to use a 'not predicted mask' that is not as smooth as the predicted one.")
self.options['eyes_prio'] = io.input_bool ("Eyes priority", default_eyes_prio, help_message="Fix eye problems during training (especially on HD architectures) by forcing the neural network to train eyes with higher priority. before/after: https://i.imgur.com/YQHOuSR.jpg . It does not guarantee the correct eye direction.")
if self.is_first_run() or ask_override:
if len(device_config.devices) == 1:
@@ -333,6 +335,7 @@ class SAEHDModel(ModelBase):
self.resolution = resolution = self.options['resolution']
learn_mask = self.options['learn_mask']
eyes_prio = self.options['eyes_prio']
archi = self.options['archi']
ae_dims = self.options['ae_dims']
e_dims = self.options['e_dims']
@@ -367,9 +370,9 @@ class SAEHDModel(ModelBase):
self.target_src = tf.placeholder (nn.tf_floatx, bgr_shape)
self.target_dst = tf.placeholder (nn.tf_floatx, bgr_shape)
self.target_srcm = tf.placeholder (nn.tf_floatx, mask_shape)
self.target_dstm = tf.placeholder (nn.tf_floatx, mask_shape)
self.target_srcm_all = tf.placeholder (nn.tf_floatx, mask_shape)
self.target_dstm_all = tf.placeholder (nn.tf_floatx, mask_shape)
# Initializing model classes
with tf.device (models_opt_device):
if 'df' in archi:
@@ -468,13 +471,13 @@ class SAEHDModel(ModelBase):
with tf.device(f'/CPU:0'):
# slice on CPU, otherwise all batch data will be transferred to GPU first
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
gpu_warped_src = self.warped_src [batch_slice,:,:,:]
gpu_warped_dst = self.warped_dst [batch_slice,:,:,:]
gpu_target_src = self.target_src [batch_slice,:,:,:]
gpu_target_dst = self.target_dst [batch_slice,:,:,:]
gpu_target_srcm = self.target_srcm[batch_slice,:,:,:]
gpu_target_dstm = self.target_dstm[batch_slice,:,:,:]
gpu_warped_src = self.warped_src [batch_slice,:,:,:]
gpu_warped_dst = self.warped_dst [batch_slice,:,:,:]
gpu_target_src = self.target_src [batch_slice,:,:,:]
gpu_target_dst = self.target_dst [batch_slice,:,:,:]
gpu_target_srcm_all = self.target_srcm_all[batch_slice,:,:,:]
gpu_target_dstm_all = self.target_dstm_all[batch_slice,:,:,:]
# process model tensors
if 'df' in archi:
gpu_src_code = self.inter(self.encoder(gpu_warped_src))
@@ -504,7 +507,13 @@ class SAEHDModel(ModelBase):
gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
# unpack masks from one combined mask
gpu_target_srcm = tf.clip_by_value (gpu_target_srcm_all, 0, 1)
gpu_target_dstm = tf.clip_by_value (gpu_target_dstm_all, 0, 1)
gpu_target_srcm_eyes = tf.clip_by_value (gpu_target_srcm_all-1, 0, 1)
gpu_target_dstm_eyes = tf.clip_by_value (gpu_target_dstm_all-1, 0, 1)
gpu_target_srcm_blur = nn.tf_gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
gpu_target_dstm_blur = nn.tf_gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
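The unpacking above relies on a packing convention on the generator side: in the combined mask, face-hull pixels are 1 and the eye hulls add another 1 on top, so eye pixels are 2 and background is 0. A minimal NumPy sketch of the clip arithmetic (the values are inferred from the diff; this is a sketch, not the project's code):

    import numpy as np

    # combined mask values: 0 = background, 1 = face hull, 2 = face + eye hull
    mask_all = np.array([0.0, 1.0, 2.0])

    face_mask = np.clip(mask_all, 0, 1)      # [0, 1, 1] -> whole face
    eyes_mask = np.clip(mask_all - 1, 0, 1)  # [0, 0, 1] -> eyes only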
@@ -513,7 +522,7 @@ class SAEHDModel(ModelBase):
gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst
gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
@@ -522,6 +531,10 @@ class SAEHDModel(ModelBase):
gpu_src_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
if eyes_prio:
gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_eyes - gpu_pred_src_src*gpu_target_srcm_eyes ), axis=[1,2,3])
if learn_mask:
gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
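The eyes_prio term added here is a plain L1 penalty restricted to the eye region, weighted 300 against the 10-weighted DSSIM and MSE terms; that weight ratio is what gives eye pixels their training priority. A rough NumPy equivalent (shapes, names, and the NHWC layout are illustrative assumptions, not the model's code):

    import numpy as np

    def eyes_l1_term(target, pred, eyes_mask, weight=300.0):
        # masked L1: only pixels inside the eye hull contribute
        diff = np.abs(target * eyes_mask - pred * eyes_mask)
        return weight * diff.mean(axis=(1, 2, 3))  # one loss value per sample

    target = np.random.rand(2, 4, 4, 3)            # toy NHWC batch
    pred = np.random.rand(2, 4, 4, 3)
    eyes_mask = np.zeros((2, 4, 4, 1))
    eyes_mask[:, 1:3, 1:3, :] = 1.0                # stand-in eye region
    print(eyes_l1_term(target, pred, eyes_mask))

The same term is added to gpu_dst_loss below, so both directions of the autoencoder get the extra eye pressure.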
@@ -534,8 +547,12 @@ class SAEHDModel(ModelBase):
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.tf_dssim(gpu_psd_target_dst_anti_masked, gpu_target_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square( gpu_psd_target_dst_anti_masked - gpu_target_dst_anti_masked), axis=[1,2,3] )
gpu_dst_loss = tf.reduce_mean ( 10*nn.tf_dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
if eyes_prio:
gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_eyes - gpu_pred_dst_dst*gpu_target_dstm_eyes ), axis=[1,2,3])
if learn_mask:
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
@@ -606,15 +623,15 @@ class SAEHDModel(ModelBase):
# Initializing training and view functions
def src_dst_train(warped_src, target_src, target_srcm, \
warped_dst, target_dst, target_dstm):
def src_dst_train(warped_src, target_src, target_srcm_all, \
warped_dst, target_dst, target_dstm_all):
s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
feed_dict={self.warped_src :warped_src,
self.target_src :target_src,
self.target_srcm:target_srcm,
self.target_srcm_all:target_srcm_all,
self.warped_dst :warped_dst,
self.target_dst :target_dst,
self.target_dstm:target_dstm,
self.target_dstm_all:target_dstm_all,
})
s = np.mean(s)
d = np.mean(d)
@@ -722,14 +739,16 @@ class SAEHDModel(ModelBase):
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_HULL), 'data_format':nn.data_format, 'resolution': resolution } ],
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_EYES_HULL), 'data_format':nn.data_format, 'resolution': resolution },
],
generators_count=src_generators_count ),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution},
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution},
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_HULL), 'data_format':nn.data_format, 'resolution': resolution} ],
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_EYES_HULL), 'data_format':nn.data_format, 'resolution': resolution},
],
generators_count=dst_generators_count )
])
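MODE_FACE_MASK_ALL_EYES_HULL is the generator-side half of that packing convention: it has to rasterize the face hull and then add the eye hulls on top. A hedged sketch of how such a packed mask could be built from 68-point landmarks with OpenCV (the helper name, the eye index ranges 36-41 / 42-47, and the call shape are assumptions for illustration, not DeepFaceLab's actual implementation):

    import cv2
    import numpy as np

    def packed_face_eyes_mask(landmarks, h, w):
        # landmarks: (68, 2) array in the common 68-point layout (assumed)
        face = np.zeros((h, w), np.float32)
        eyes = np.zeros((h, w), np.float32)
        pts = landmarks.astype(np.int32)
        cv2.fillConvexPoly(face, cv2.convexHull(pts), 1.0)
        for sl in (slice(36, 42), slice(42, 48)):      # left / right eye
            cv2.fillConvexPoly(eyes, cv2.convexHull(pts[sl]), 1.0)
        # face pixels -> 1, eye pixels -> 2: matches the clip-based unpacking
        return (face + eyes)[..., None]

Clipping this output exactly as in the loss code above recovers the face and eye masks again.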
@@ -748,23 +767,23 @@ class SAEHDModel(ModelBase):
#override
def onTrainOneIter(self):
( (warped_src, target_src, target_srcm), \
(warped_dst, target_dst, target_dstm) ) = self.generate_next_samples()
( (warped_src, target_src, target_srcm_all), \
(warped_dst, target_dst, target_dstm_all) ) = self.generate_next_samples()
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm)
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all)
if self.options['true_face_power'] != 0 and not self.pretrain:
self.D_train (warped_src, warped_dst)
if self.gan_power != 0:
self.D_src_dst_train (warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm)
self.D_src_dst_train (warped_src, target_src, target_srcm_all, warped_dst, target_dst, target_dstm_all)
return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
#override
def onGetPreview(self, samples):
( (warped_src, target_src, target_srcm),
(warped_dst, target_dst, target_dstm) ) = samples
( (warped_src, target_src, target_srcm_all,),
(warped_dst, target_dst, target_dstm_all,) ) = samples
if self.options['learn_mask']:
S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
@@ -772,8 +791,11 @@ class SAEHDModel(ModelBase):
else:
S, D, SS, DD, SD, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format) , 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
target_srcm_all, target_dstm_all = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm_all, target_dstm_all] )]
target_srcm = np.clip(target_srcm_all, 0, 1)
target_dstm = np.clip(target_dstm_all, 0, 1)
n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
result = []