diff --git a/models/ModelBase.py b/models/ModelBase.py index f39c0e8..59cfd57 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -85,12 +85,13 @@ class ModelBase(object): else: self.options['write_preview_history'] = self.options.get('write_preview_history', False) - if ask_target_iter and (self.iter == 0 or ask_override): - self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0)) - else: - self.options['target_iter'] = max(model_data.get('target_iter',0), self.options.get('target_epoch',0)) - if 'target_epoch' in self.options: - self.options.pop('target_epoch') + if ask_target_iter: + if (self.iter == 0 or ask_override): + self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0)) + else: + self.options['target_iter'] = max(model_data.get('target_iter',0), self.options.get('target_epoch',0)) + if 'target_epoch' in self.options: + self.options.pop('target_epoch') if ask_batch_size and (self.iter == 0 or ask_override): default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0) @@ -98,20 +99,23 @@ class ModelBase(object): else: self.options['batch_size'] = self.options.get('batch_size', 0) - if ask_sort_by_yaw and (self.iter == 0): - self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." ) - else: - self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False) + if ask_sort_by_yaw: + if (self.iter == 0): + self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." ) + else: + self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False) - if ask_random_flip and (self.iter == 0): - self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.") - else: - self.options['random_flip'] = self.options.get('random_flip', True) + if ask_random_flip: + if (self.iter == 0): + self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.") + else: + self.options['random_flip'] = self.options.get('random_flip', True) - if ask_src_scale_mod and (self.iter == 0): - self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30) - else: - self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0) + if ask_src_scale_mod: + if (self.iter == 0): + self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30) + else: + self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0) self.write_preview_history = self.options['write_preview_history'] if not self.options['write_preview_history']: diff --git a/samplelib/SampleGeneratorImageTemporal.py b/samplelib/SampleGeneratorImageTemporal.py index 2f3457c..d370bda 100644 --- a/samplelib/SampleGeneratorImageTemporal.py +++ b/samplelib/SampleGeneratorImageTemporal.py @@ -42,11 +42,13 @@ class SampleGeneratorImageTemporal(SampleGeneratorBase): if samples_len == 0: raise ValueError('No training data provided.') - if samples_len - self.temporal_image_count < 0: + mult_max = 4 + l = samples_len - (self.temporal_image_count-1)*mult_max + 1 + if l < 0: raise ValueError('Not enough samples to fit temporal line.') shuffle_idxs = [] - samples_sub_len = samples_len - self.temporal_image_count + 1 + samples_sub_len = samples_len - l + 1 while True: @@ -60,9 +62,9 @@ class SampleGeneratorImageTemporal(SampleGeneratorBase): idx = shuffle_idxs.pop() temporal_samples = [] - + mult = np.random.randint(mult_max) for i in range( self.temporal_image_count ): - sample = samples[ idx+i ] + sample = samples[ idx+i*mult ] try: temporal_samples += SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug) except: