From 82ce28c7846c6afdb8ab22e4571cec67f36b9d4f Mon Sep 17 00:00:00 2001 From: Brigham Lysenko Date: Thu, 15 Aug 2019 14:16:56 -0600 Subject: [PATCH] fixed bug fix --- models/ModelBase.py | 4 ++-- models/Model_SAE/Model.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/models/ModelBase.py b/models/ModelBase.py index 21905da..a96d229 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -145,8 +145,8 @@ class ModelBase(object): " Memory error. Tune this value for your" " videocard manually.")) self.options['ping_pong'] = io.input_bool( - "Enable ping-pong? (y/n ?:help skip:%s) : " % self.options.get('batch_cap', False), - self.options.get('batch_cap', False), + "Enable ping-pong? (y/n ?:help skip:%s) : " % self.options.get('ping_pong', False), + self.options.get('ping_pong', False), help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence") self.options['paddle'] = self.options.get('paddle','ping') if self.options.get('ping_pong',False): diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 5a3983d..59b0cc0 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -539,12 +539,13 @@ class SAEModel(ModelBase): # override def set_batch_size(self, batch_size): + self.batch_size = batch_size self.set_training_data_generators(None) self.set_training_data_generators([ SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None, random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None, - debug=self.is_debug(), batch_size=batch_size, + debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) + self.src_scale_mod / 100.0), @@ -559,7 +560,7 @@ class SAEModel(ModelBase): range(ms_count)] ), - SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=batch_size, + SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ), output_sample_types=[{'types': ( t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),