SAEHD: random flip replaced with random SRC flip(default False) and random DST flip(default True)

This commit is contained in:
iperov 2021-03-23 15:13:10 +04:00
parent b333fcea4b
commit e47b602ec8
2 changed files with 15 additions and 4 deletions

View file

@ -185,7 +185,9 @@ class ModelBase(object):
self.write_preview_history = self.options.get('write_preview_history', False)
self.target_iter = self.options.get('target_iter',0)
self.random_flip = self.options.get('random_flip',True)
self.random_src_flip = self.options.get('random_src_flip', False)
self.random_dst_flip = self.options.get('random_dst_flip', True)
self.on_initialize()
self.options['batch_size'] = self.batch_size
@ -297,6 +299,14 @@ class ModelBase(object):
def ask_random_flip(self):
default_random_flip = self.load_or_def_option('random_flip', True)
self.options['random_flip'] = io.input_bool("Flip faces randomly", default_random_flip, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
def ask_random_src_flip(self):
default_random_src_flip = self.load_or_def_option('random_src_flip', False)
self.options['random_src_flip'] = io.input_bool("Flip SRC faces randomly", default_random_src_flip, help_message="Random horizontal flip SRC faceset. Covers more angles, but the face may look less naturally.")
def ask_random_dst_flip(self):
default_random_dst_flip = self.load_or_def_option('random_dst_flip', True)
self.options['random_dst_flip'] = io.input_bool("Flip DST faces randomly", default_random_dst_flip, help_message="Random horizontal flip DST faceset. Makes generalization of src->dst better, if src random flip is not enabled.")
def ask_batch_size(self, suggest_batch_size=None, range=None):
default_batch_size = self.load_or_def_option('batch_size', suggest_batch_size or self.batch_size)

View file

@ -65,7 +65,8 @@ class SAEHDModel(ModelBase):
self.ask_autobackup_hour()
self.ask_write_preview_history()
self.ask_target_iter()
self.ask_random_flip()
self.ask_random_src_flip()
self.ask_random_dst_flip()
self.ask_batch_size(suggest_batch_size)
if self.is_first_run():
@ -630,7 +631,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
sample_process_options=SampleProcessor.Options(random_flip=self.random_src_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
@ -640,7 +641,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
generators_count=src_generators_count ),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
sample_process_options=SampleProcessor.Options(random_flip=self.random_dst_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},