mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
refactoring
This commit is contained in:
parent
cbff72f597
commit
757ec77e44
4 changed files with 154 additions and 228 deletions
|
@ -344,6 +344,11 @@ class SAEHDModel(ModelBase):
|
|||
devices = device_config.devices
|
||||
|
||||
self.resolution = resolution = self.options['resolution']
|
||||
self.face_type = {'h' : FaceType.HALF,
|
||||
'mf' : FaceType.MID_FULL,
|
||||
'f' : FaceType.FULL,
|
||||
'wf' : FaceType.WHOLE_FACE}[ self.options['face_type'] ]
|
||||
|
||||
learn_mask = self.options['learn_mask']
|
||||
eyes_prio = self.options['eyes_prio']
|
||||
archi = self.options['archi']
|
||||
|
@ -722,23 +727,11 @@ class SAEHDModel(ModelBase):
|
|||
|
||||
# initializing sample generators
|
||||
if self.is_training:
|
||||
t = SampleProcessor.Types
|
||||
if self.options['face_type'] == 'h':
|
||||
face_type = t.FACE_TYPE_HALF
|
||||
elif self.options['face_type'] == 'mf':
|
||||
face_type = t.FACE_TYPE_MID_FULL
|
||||
elif self.options['face_type'] == 'f':
|
||||
face_type = t.FACE_TYPE_FULL
|
||||
elif self.options['face_type'] == 'wf':
|
||||
face_type = t.FACE_TYPE_WHOLE_FACE
|
||||
|
||||
training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
|
||||
training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()
|
||||
|
||||
random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' and not self.pretrain else None
|
||||
|
||||
t_img_warped = t.IMG_WARPED_TRANSFORMED if self.options['random_warp'] else t.IMG_TRANSFORMED
|
||||
|
||||
cpu_count = min(multiprocessing.cpu_count(), 8)
|
||||
src_generators_count = cpu_count // 2
|
||||
dst_generators_count = cpu_count // 2
|
||||
|
@ -748,17 +741,17 @@ class SAEHDModel(ModelBase):
|
|||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
|
||||
output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
|
||||
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
|
||||
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_EYES_HULL), 'data_format':nn.data_format, 'resolution': resolution },
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': self.options['ct_mode'], 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': self.options['ct_mode'], 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.ALL_EYES_HULL, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
generators_count=src_generators_count ),
|
||||
|
||||
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
|
||||
output_sample_types = [ {'types' : (t_img_warped, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_FACE_MASK_ALL_EYES_HULL), 'data_format':nn.data_format, 'resolution': resolution},
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':self.options['random_warp'], 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.ALL_EYES_HULL, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
generators_count=dst_generators_count )
|
||||
])
|
||||
|
@ -904,17 +897,8 @@ class SAEHDModel(ModelBase):
|
|||
|
||||
#override
|
||||
def get_MergerConfig(self):
|
||||
if self.options['face_type'] == 'h':
|
||||
face_type = FaceType.HALF
|
||||
elif self.options['face_type'] == 'mf':
|
||||
face_type = FaceType.MID_FULL
|
||||
elif self.options['face_type'] == 'f':
|
||||
face_type = FaceType.FULL
|
||||
elif self.options['face_type'] == 'wf':
|
||||
face_type = FaceType.WHOLE_FACE
|
||||
|
||||
import merger
|
||||
return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=face_type,
|
||||
return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type,
|
||||
default_mode = 'overlay' if self.options['ct_mode'] != 'none' or self.options['face_style_power'] or self.options['bg_style_power'] else 'seamless',
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue