refactoring

This commit is contained in:
Colombo 2020-03-03 22:20:15 +04:00
parent f56d583cb5
commit 302d23a612
2 changed files with 9 additions and 8 deletions

View file

@ -363,6 +363,9 @@ class SAEHDModel(ModelBase):
self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0 self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0
masked_training = self.options['masked_training'] masked_training = self.options['masked_training']
ct_mode = self.options['ct_mode']
if ct_mode == 'none':
ct_mode = None
models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0' models_opt_device = '/GPU:0' if models_opt_on_gpu and self.is_training else '/CPU:0'
@ -730,19 +733,19 @@ class SAEHDModel(ModelBase):
training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path() training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path() training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()
random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' and not self.pretrain else None random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None
cpu_count = min(multiprocessing.cpu_count(), 8) cpu_count = min(multiprocessing.cpu_count(), 8)
src_generators_count = cpu_count // 2 src_generators_count = cpu_count // 2
dst_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2
if self.options['ct_mode'] != 'none': if ct_mode is not None:
src_generators_count = int(src_generators_count * 1.5) src_generators_count = int(src_generators_count * 1.5)
self.set_training_data_generators ([ self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': self.options['ct_mode'], 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': self.options['ct_mode'], 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.ALL_EYES_HULL, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.ALL_EYES_HULL, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
], ],
generators_count=src_generators_count ), generators_count=src_generators_count ),
@ -898,8 +901,6 @@ class SAEHDModel(ModelBase):
#override #override
def get_MergerConfig(self): def get_MergerConfig(self):
import merger import merger
return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
default_mode = 'overlay' if self.options['ct_mode'] != 'none' or self.options['face_style_power'] or self.options['bg_style_power'] else 'seamless',
)
Model = SAEHDModel Model = SAEHDModel

View file

@ -72,7 +72,7 @@ class SampleProcessor(object):
motion_blur = opts.get('motion_blur', None) motion_blur = opts.get('motion_blur', None)
gaussian_blur = opts.get('gaussian_blur', None) gaussian_blur = opts.get('gaussian_blur', None)
normalize_tanh = opts.get('normalize_tanh', False) normalize_tanh = opts.get('normalize_tanh', False)
ct_mode = opts.get('ct_mode', 'None') ct_mode = opts.get('ct_mode', None)
data_format = opts.get('data_format', 'NHWC') data_format = opts.get('data_format', 'NHWC')
if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK: if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK: