fixed bug fix

This commit is contained in:
Brigham Lysenko 2019-08-15 14:16:56 -06:00
commit 82ce28c784
2 changed files with 5 additions and 4 deletions

View file

@ -145,8 +145,8 @@ class ModelBase(object):
" Memory error. Tune this value for your" " Memory error. Tune this value for your"
" videocard manually.")) " videocard manually."))
self.options['ping_pong'] = io.input_bool( self.options['ping_pong'] = io.input_bool(
"Enable ping-pong? (y/n ?:help skip:%s) : " % self.options.get('batch_cap', False), "Enable ping-pong? (y/n ?:help skip:%s) : " % self.options.get('ping_pong', False),
self.options.get('batch_cap', False), self.options.get('ping_pong', False),
help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence") help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence")
self.options['paddle'] = self.options.get('paddle','ping') self.options['paddle'] = self.options.get('paddle','ping')
if self.options.get('ping_pong',False): if self.options.get('ping_pong',False):

View file

@ -539,12 +539,13 @@ class SAEModel(ModelBase):
# override # override
def set_batch_size(self, batch_size): def set_batch_size(self, batch_size):
self.batch_size = batch_size
self.set_training_data_generators(None) self.set_training_data_generators(None)
self.set_training_data_generators([ self.set_training_data_generators([
SampleGeneratorFace(training_data_src_path, SampleGeneratorFace(training_data_src_path,
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None, random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
debug=self.is_debug(), batch_size=batch_size, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
scale_range=np.array([-0.05, scale_range=np.array([-0.05,
0.05]) + self.src_scale_mod / 100.0), 0.05]) + self.src_scale_mod / 100.0),
@ -559,7 +560,7 @@ class SAEModel(ModelBase):
range(ms_count)] range(ms_count)]
), ),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=batch_size, SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
output_sample_types=[{'types': ( output_sample_types=[{'types': (
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),