diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 5a1e0ca..ce2086d 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -363,6 +363,9 @@ class SAEHDModel(ModelBase): self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0 masked_training = self.options['masked_training'] + ct_mode = self.options['ct_mode'] + if ct_mode == 'none': + ct_mode = None models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0' @@ -730,19 +733,19 @@ class SAEHDModel(ModelBase): training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path() training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path() - random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' and not self.pretrain else None + random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None cpu_count = min(multiprocessing.cpu_count(), 8) src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 - if self.options['ct_mode'] != 'none': + if ct_mode is not None: src_generators_count = int(src_generators_count * 1.5) self.set_training_data_generators ([ SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), - output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': self.options['ct_mode'], 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': self.options['ct_mode'], 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.ALL_EYES_HULL, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, ], generators_count=src_generators_count ), @@ -898,8 +901,6 @@ class SAEHDModel(ModelBase): #override def get_MergerConfig(self): import merger - return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, - default_mode = 'overlay' if self.options['ct_mode'] != 'none' or self.options['face_style_power'] or self.options['bg_style_power'] else 'seamless', - ) + return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') Model = SAEHDModel diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py index 1b4fc89..f735925 100644 --- a/samplelib/SampleProcessor.py +++ b/samplelib/SampleProcessor.py @@ -72,7 +72,7 @@ class SampleProcessor(object): motion_blur = opts.get('motion_blur', None) gaussian_blur = opts.get('gaussian_blur', None) normalize_tanh = opts.get('normalize_tanh', False) - ct_mode = opts.get('ct_mode', 'None') + ct_mode = opts.get('ct_mode', None) data_format = opts.get('data_format', 'NHWC') if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK: