XSeg model: remove the separate pretraining mask input and add ONNX export.

When self.is_exporting is set, the pretraining_data_path check is skipped and
the NCHW data format is forced. The targetm_t placeholder, the blurred-mask
weighting of the pretrain losses and the FACE_MASK sample type are dropped, so
train() takes (input_np, target_np) in both modes. A new export_dfm() method
freezes the graph and writes model.onnx through tf2onnx with opset 13.

NOTE(review): leading indentation, in-line column alignment and blank context
lines were reconstructed from the hunk line counts; verify against the target
file or apply with whitespace tolerance (e.g. git apply --ignore-whitespace).

diff --git a/models/Model_XSeg/Model.py b/models/Model_XSeg/Model.py
index 53ae22b..b0addfd 100644
--- a/models/Model_XSeg/Model.py
+++ b/models/Model_XSeg/Model.py
@@ -34,7 +34,7 @@ class XSegModel(ModelBase):
             self.ask_batch_size(4, range=[2,16])
             self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain)
 
-        if self.options['pretrain'] and self.get_pretraining_data_path() is None:
+        if not self.is_exporting and (self.options['pretrain'] and self.get_pretraining_data_path() is None):
             raise Exception("pretraining_data_path is not defined")
 
         self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
@@ -42,7 +42,7 @@ class XSegModel(ModelBase):
     #override
     def on_initialize(self):
         device_config = nn.getCurrentDeviceConfig()
-        self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
+        self.model_data_format = "NCHW" if self.is_exporting or (len(device_config.devices) != 0 and not self.is_debug()) else "NHWC"
         nn.initialize(data_format=self.model_data_format)
         tf = nn.tf
 
@@ -85,8 +85,6 @@ class XSegModel(ModelBase):
             bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
             self.set_batch_size( gpu_count*bs_per_gpu)
 
-            targetm_t = tf.placeholder (nn.floatx, mask_shape)
-
             # Compute losses per GPU
             gpu_pred_list = []
 
@@ -100,7 +98,6 @@ class XSegModel(ModelBase):
                         batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                         gpu_input_t = self.model.input_t [batch_slice,:,:,:]
                         gpu_target_t = self.model.target_t [batch_slice,:,:,:]
-                        gpu_targetm_t = targetm_t [batch_slice,:,:,:]
 
                     # process model tensors
                     gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t, pretrain=self.pretrain)
@@ -108,17 +105,11 @@ class XSegModel(ModelBase):
 
 
                     if self.pretrain:
-                        gpu_targetm_blur = nn.gaussian_blur(gpu_targetm_t, max(1, resolution // 32) )
-                        gpu_targetm_blur = tf.clip_by_value(gpu_targetm_blur, 0, 0.5) * 2
-
-                        gpu_target_t_blur = gpu_target_t*gpu_targetm_blur
-                        gpu_pred_t_blur = gpu_pred_t*gpu_targetm_t
-
                         # Structural loss
-                        gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t_blur, gpu_pred_t_blur, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
-                        gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t_blur, gpu_pred_t_blur, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
+                        gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
+                        gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
                         # Pixel loss
-                        gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t_blur-gpu_pred_t_blur), axis=[1,2,3])
+                        gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t-gpu_pred_t), axis=[1,2,3])
                     else:
                         gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
 
@@ -137,8 +128,8 @@ class XSegModel(ModelBase):
 
             # Initializing training and view functions
             if self.pretrain:
-                def train(input_np, target_np, targetm_np):
-                    l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np, targetm_t :targetm_np })
+                def train(input_np, target_np):
+                    l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np})
                     return l
             else:
                 def train(input_np, target_np):
@@ -160,8 +151,7 @@ class XSegModel(ModelBase):
                 pretrain_gen = SampleGeneratorFace(self.get_pretraining_data_path(), debug=self.is_debug(), batch_size=self.get_batch_size(),
                                     sample_process_options=SampleProcessor.Options(random_flip=True),
                                     output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
-                                                            {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
-                                                            {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                                            {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                           ],
                                     uniform_yaw_distribution=False,
                                     generators_count=cpu_count )
@@ -200,13 +190,9 @@ class XSegModel(ModelBase):
 
     #override
     def onTrainOneIter(self):
-        if self.pretrain:
-            image_np, target_np, targetm_np = self.generate_next_samples()[0]
-            loss = self.train (image_np, target_np, targetm_np)
-        else:
-            image_np, mask_np = self.generate_next_samples()[0]
-            loss = self.train (image_np, mask_np)
-
+        image_np, target_np = self.generate_next_samples()[0]
+        loss = self.train (image_np, target_np)
+
         return ( ('loss', np.mean(loss) ), )
 
     #override
@@ -215,7 +201,7 @@ class XSegModel(ModelBase):
 
         if self.pretrain:
             srcdst_samples, = samples
-            image_np, mask_np, _ = srcdst_samples
+            image_np, mask_np = srcdst_samples
         else:
             srcdst_samples, src_samples, dst_samples = samples
             image_np, mask_np = srcdst_samples
@@ -264,5 +250,34 @@ class XSegModel(ModelBase):
             result += [ ('XSeg dst faces', np.concatenate (st, axis=0 )), ]
 
         return result
-
+
+    def export_dfm (self):
+        output_path = self.get_strpath_storage_for_file(f'model.onnx')
+        io.log_info(f'Dumping .onnx to {output_path}')
+        tf = nn.tf
+
+        with tf.device (nn.tf_default_device_name):
+            input_t = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
+            input_t = tf.transpose(input_t, (0,3,1,2))
+            _, pred_t = self.model.flow(input_t)
+            pred_t = tf.transpose(pred_t, (0,2,3,1))
+
+        tf.identity(pred_t, name='out_mask')
+
+        output_graph_def = tf.graph_util.convert_variables_to_constants(
+            nn.tf_sess,
+            tf.get_default_graph().as_graph_def(),
+            ['out_mask']
+        )
+
+        import tf2onnx
+        with tf.device("/CPU:0"):
+            model_proto, _ = tf2onnx.convert._convert_common(
+                output_graph_def,
+                name='XSeg',
+                input_names=['in_face:0'],
+                output_names=['out_mask:0'],
+                opset=13,
+                output_path=output_path)
+
 Model = XSegModel
\ No newline at end of file