From 55b947eab5dfb0ccc7b245a40aeaa9fb08806521 Mon Sep 17 00:00:00 2001 From: iperov Date: Fri, 30 Jul 2021 17:24:21 +0400 Subject: [PATCH] XSeg: added pretrain option. --- core/leras/models/XSeg.py | 16 ++++- facelib/XSegNet.py | 4 +- models/Model_XSeg/Model.py | 137 +++++++++++++++++++++++++------------ 3 files changed, 111 insertions(+), 46 deletions(-) diff --git a/core/leras/models/XSeg.py b/core/leras/models/XSeg.py index e6bde65..f59eb8c 100644 --- a/core/leras/models/XSeg.py +++ b/core/leras/models/XSeg.py @@ -88,9 +88,9 @@ class XSeg(nn.ModelBase): self.uconv02 = ConvBlock(base_ch*2, base_ch) self.uconv01 = ConvBlock(base_ch, base_ch) self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME') + - - def forward(self, inp): + def forward(self, inp, pretrain=False): x = inp x = self.conv01(x) @@ -126,29 +126,41 @@ class XSeg(nn.ModelBase): x = nn.reshape_4D (x, 4, 4, self.base_ch*8 ) x = self.up5(x) + if pretrain: + x5 = tf.zeros_like(x5) x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis)) x = self.uconv52(x) x = self.uconv51(x) x = self.up4(x) + if pretrain: + x4 = tf.zeros_like(x4) x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis)) x = self.uconv42(x) x = self.uconv41(x) x = self.up3(x) + if pretrain: + x3 = tf.zeros_like(x3) x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis)) x = self.uconv32(x) x = self.uconv31(x) x = self.up2(x) + if pretrain: + x2 = tf.zeros_like(x2) x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis)) x = self.uconv21(x) x = self.up1(x) + if pretrain: + x1 = tf.zeros_like(x1) x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis)) x = self.uconv11(x) x = self.up0(x) + if pretrain: + x0 = tf.zeros_like(x0) x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis)) x = self.uconv01(x) diff --git a/facelib/XSegNet.py b/facelib/XSegNet.py index 5621a65..ff2bd08 100644 --- a/facelib/XSegNet.py +++ b/facelib/XSegNet.py @@ -81,8 +81,8 @@ class XSegNet(object): def get_resolution(self): return self.resolution - def flow(self, x): - return self.model(x) + def flow(self, x, pretrain=False): + return self.model(x, pretrain=pretrain) def get_weights(self): return self.model_weights diff --git a/models/Model_XSeg/Model.py b/models/Model_XSeg/Model.py index a27d5ad..53ae22b 100644 --- a/models/Model_XSeg/Model.py +++ b/models/Model_XSeg/Model.py @@ -25,13 +25,20 @@ class XSegModel(ModelBase): self.set_iter(0) default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf') + default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False) if self.is_first_run(): self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower() if self.is_first_run() or ask_override: self.ask_batch_size(4, range=[2,16]) - + self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain) + + if self.options['pretrain'] and self.get_pretraining_data_path() is None: + raise Exception("pretraining_data_path is not defined") + + self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False) + #override def on_initialize(self): device_config = nn.getCurrentDeviceConfig() @@ -50,7 +57,8 @@ class XSegModel(ModelBase): 'f' : FaceType.FULL, 'wf' : FaceType.WHOLE_FACE, 'head' : FaceType.HEAD}[ self.options['face_type'] ] - + + place_model_on_cpu = len(devices) == 0 models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name @@ -66,14 +74,19 @@ class XSegModel(ModelBase): place_model_on_cpu=place_model_on_cpu, optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'), data_format=nn.data_format) - + + self.pretrain = self.options['pretrain'] + if self.pretrain_just_disabled: + self.set_iter(0) + if self.is_training: # Adjust batch size for multiple GPU gpu_count = max(1, len(devices) ) bs_per_gpu = max(1, self.get_batch_size() // gpu_count) self.set_batch_size( gpu_count*bs_per_gpu) - + targetm_t = tf.placeholder (nn.floatx, mask_shape) + # Compute losses per GPU gpu_pred_list = [] @@ -81,20 +94,33 @@ class XSegModel(ModelBase): gpu_loss_gvs = [] for gpu_id in range(gpu_count): - - with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) gpu_input_t = self.model.input_t [batch_slice,:,:,:] gpu_target_t = self.model.target_t [batch_slice,:,:,:] + gpu_targetm_t = targetm_t [batch_slice,:,:,:] # process model tensors - gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t) + gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t, pretrain=self.pretrain) gpu_pred_list.append(gpu_pred_t) - - gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3]) + + + if self.pretrain: + gpu_targetm_blur = nn.gaussian_blur(gpu_targetm_t, max(1, resolution // 32) ) + gpu_targetm_blur = tf.clip_by_value(gpu_targetm_blur, 0, 0.5) * 2 + + gpu_target_t_blur = gpu_target_t*gpu_targetm_blur + gpu_pred_t_blur = gpu_pred_t*gpu_targetm_t + + # Structural loss + gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t_blur, gpu_pred_t_blur, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t_blur, gpu_pred_t_blur, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + # Pixel loss + gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t_blur-gpu_pred_t_blur), axis=[1,2,3]) + else: + gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3]) gpu_losses += [gpu_loss] @@ -110,9 +136,14 @@ class XSegModel(ModelBase): # Initializing training and view functions - def train(input_np, target_np): - l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np }) - return l + if self.pretrain: + def train(input_np, target_np, targetm_np): + l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np, targetm_t :targetm_np }) + return l + else: + def train(input_np, target_np): + l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np }) + return l self.train = train def view(input_np): @@ -124,30 +155,40 @@ class XSegModel(ModelBase): src_dst_generators_count = cpu_count // 2 src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 + + if self.pretrain: + pretrain_gen = SampleGeneratorFace(self.get_pretraining_data_path(), debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=True), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=False, + generators_count=cpu_count ) + self.set_training_data_generators ([pretrain_gen]) + else: + srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path], + debug=self.is_debug(), + batch_size=self.get_batch_size(), + resolution=resolution, + face_type=self.face_type, + generators_count=src_dst_generators_count, + data_format=nn.data_format) + src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=False), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + generators_count=src_generators_count, + raise_on_no_data=False ) + dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=False), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + generators_count=dst_generators_count, + raise_on_no_data=False ) - srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path], - debug=self.is_debug(), - batch_size=self.get_batch_size(), - resolution=resolution, - face_type=self.face_type, - generators_count=src_dst_generators_count, - data_format=nn.data_format) - - src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=False), - output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - ], - generators_count=src_generators_count, - raise_on_no_data=False ) - dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=False), - output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - ], - generators_count=dst_generators_count, - raise_on_no_data=False ) - - self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator]) + self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator]) #override def get_model_filename_list(self): @@ -159,16 +200,25 @@ class XSegModel(ModelBase): #override def onTrainOneIter(self): - image_np, mask_np = self.generate_next_samples()[0] - loss = self.train (image_np, mask_np) + if self.pretrain: + image_np, target_np, targetm_np = self.generate_next_samples()[0] + loss = self.train (image_np, target_np, targetm_np) + else: + image_np, mask_np = self.generate_next_samples()[0] + loss = self.train (image_np, mask_np) + return ( ('loss', np.mean(loss) ), ) #override def onGetPreview(self, samples, for_history=False): n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) - - srcdst_samples, src_samples, dst_samples = samples - image_np, mask_np = srcdst_samples + + if self.pretrain: + srcdst_samples, = samples + image_np, mask_np, _ = srcdst_samples + else: + srcdst_samples, src_samples, dst_samples = samples + image_np, mask_np = srcdst_samples I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ] M, IM, = [ np.repeat (x, (3,), -1) for x in [M, IM] ] @@ -178,11 +228,14 @@ class XSegModel(ModelBase): result = [] st = [] for i in range(n_samples): - ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i]) + if self.pretrain: + ar = I[i], IM[i] + else: + ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i]) st.append ( np.concatenate ( ar, axis=1) ) result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ] - if len(src_samples) != 0: + if not self.pretrain and len(src_samples) != 0: src_np, = src_samples @@ -196,7 +249,7 @@ class XSegModel(ModelBase): result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ] - if len(dst_samples) != 0: + if not self.pretrain and len(dst_samples) != 0: dst_np, = dst_samples