refactoring

This commit is contained in:
Colombo 2020-03-01 19:09:50 +04:00
parent cbff72f597
commit 757ec77e44
4 changed files with 154 additions and 228 deletions

View file

@ -344,6 +344,11 @@ class SAEHDModel(ModelBase):
devices = device_config.devices
self.resolution = resolution = self.options['resolution']
self.face_type = {'h' : FaceType.HALF,
'mf' : FaceType.MID_FULL,
'f' : FaceType.FULL,
'wf' : FaceType.WHOLE_FACE}[ self.options['face_type'] ]
learn_mask = self.options['learn_mask']
eyes_prio = self.options['eyes_prio']
archi = self.options['archi']
@ -722,23 +727,11 @@ class SAEHDModel(ModelBase):
# initializing sample generators
if self.is_training:
t = SampleProcessor.Types
if self.options['face_type'] == 'h':
face_type = t.FACE_TYPE_HALF
elif self.options['face_type'] == 'mf':
face_type = t.FACE_TYPE_MID_FULL
elif self.options['face_type'] == 'f':
face_type = t.FACE_TYPE_FULL
elif self.options['face_type'] == 'wf':
face_type = t.FACE_TYPE_WHOLE_FACE
training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()
random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' and not self.pretrain else None
t_img_warped = t.IMG_WARPED_TRANSFORMED if self.options['random_warp'] else t.IMG_TRANSFORMED
cpu_count = min(multiprocessing.cpu_count(), 8)
src_generators_count = cpu_count // 2
dst_generators_count = cpu_count // 2
@ -748,17 +741,17 @@ class SAEHDModel(ModelBase):
self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_EYES_HULL), 'data_format':nn.data_format, 'resolution': resolution },
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': self.options['ct_mode'], 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': self.options['ct_mode'], 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.ALL_EYES_HULL, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
generators_count=src_generators_count ),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution},
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution},
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_EYES_HULL), 'data_format':nn.data_format, 'resolution': resolution},
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.ALL_EYES_HULL, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
generators_count=dst_generators_count )
])
@ -904,17 +897,8 @@ class SAEHDModel(ModelBase):
#override
def get_MergerConfig(self):
if self.options['face_type'] == 'h':
face_type = FaceType.HALF
elif self.options['face_type'] == 'mf':
face_type = FaceType.MID_FULL
elif self.options['face_type'] == 'f':
face_type = FaceType.FULL
elif self.options['face_type'] == 'wf':
face_type = FaceType.WHOLE_FACE
import merger
return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=face_type,
return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type,
default_mode = 'overlay' if self.options['ct_mode'] != 'none' or self.options['face_style_power'] or self.options['bg_style_power'] else 'seamless',
)