From 4adeb0aa79aead408ce800f9dff16ed4038b0327 Mon Sep 17 00:00:00 2001
From: Brigham Lysenko
Date: Thu, 15 Aug 2019 14:01:16 -0600
Subject: [PATCH] fixed bug fix

---
 models/ModelBase.py       | 14 ++++++-------
 models/Model_SAE/Model.py | 41 ++++++++++++++++++++++++++++++++++-----
 2 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/models/ModelBase.py b/models/ModelBase.py
index 114b3aa..21905da 100644
--- a/models/ModelBase.py
+++ b/models/ModelBase.py
@@ -145,16 +145,16 @@ class ModelBase(object):
                                            " Memory error. Tune this value for your"
                                            " videocard manually."))

         self.options['ping_pong'] = io.input_bool(
-            "Enable ping-pong? (y/n ?:help skip:%s) : " % (yn_str[True]),
-            True,
+            "Enable ping-pong? (y/n ?:help skip:%s) : " % self.options.get('batch_cap', False),
+            self.options.get('batch_cap', False),
             help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence")
         self.options['paddle'] = self.options.get('paddle','ping')
-        if self.options.get('ping_pong',True):
+        if self.options.get('ping_pong',False):
             self.options['ping_pong_iter'] = max(0, io.input_int("Ping-pong iteration (skip:1000/default) : ", 1000))
         else:
             self.options['batch_cap'] = self.options.get('batch_cap', 16)
-            self.options['ping_pong'] = self.options.get('ping_pong', True)
+            self.options['ping_pong'] = self.options.get('ping_pong', False)
             self.options['ping_pong_iter'] = self.options.get('ping_pong_iter',1000)

         if ask_sort_by_yaw:
@@ -525,10 +525,10 @@ class ModelBase(object):


     def train_one_iter(self):
-        if self.iter == 1 and self.options.get('ping_pong', True):
+        if self.iter == 1 and self.options.get('ping_pong', False):
             self.set_batch_size(1)
             self.paddle = 'ping'
-        elif not self.options.get('ping_pong', True) and self.batch_cap != self.batch_size:
+        elif not self.options.get('ping_pong', False) and self.batch_cap != self.batch_size:
             self.set_batch_size(self.batch_cap)
         sample = self.generate_next_sample()
         iter_time = time.time()
@@ -556,7 +556,7 @@ class ModelBase(object):
                 img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8)
                 cv2_imwrite(filepath, img)

-            if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', True):
+            if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', False):
                 if self.batch_size == self.batch_cap:
                     self.paddle = 'pong'
                 if self.batch_size > self.batch_cap:
diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py
index 3c0d375..5a3983d 100644
--- a/models/Model_SAE/Model.py
+++ b/models/Model_SAE/Model.py
@@ -179,6 +179,7 @@ class SAEModel(ModelBase):



+        global apply_random_ct
         apply_random_ct = self.options.get('apply_random_ct', ColorTransferMode.NONE)
         masked_training = True

@@ -457,11 +458,12 @@ class SAEModel(ModelBase):

             global t_mode_bgr
             t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE
-
+            global training_data_src_path
             training_data_src_path = self.training_data_src_path
-
+            global training_data_dst_path
             training_data_dst_path= self.training_data_dst_path
+            global sort_by_yaw
             sort_by_yaw = self.sort_by_yaw

             if self.pretrain and self.pretraining_data_path is not None:
@@ -537,9 +539,38 @@ class SAEModel(ModelBase):

     # override
     def set_batch_size(self, batch_size):
-        self.batch_size = batch_size
-        for i, generator in enumerate(self.generator_list):
-            generator.update_batch(batch_size)
+        self.set_training_data_generators(None)
+        self.set_training_data_generators([
+            SampleGeneratorFace(training_data_src_path,
+                                sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
+                                random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
+                                debug=self.is_debug(), batch_size=batch_size,
+                                sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
+                                                                               scale_range=np.array([-0.05,
+                                                                                                     0.05]) + self.src_scale_mod / 100.0),
+                                output_sample_types=[{'types': (
+                                    t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
+                                    'resolution': resolution, 'apply_ct': apply_random_ct}] + \
+                                                    [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
+                                                      'resolution': resolution // (2 ** i),
+                                                      'apply_ct': apply_random_ct} for i in range(ms_count)] + \
+                                                    [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
+                                                      'resolution': resolution // (2 ** i)} for i in
+                                                     range(ms_count)]
+                                ),
+
+            SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=batch_size,
+                                sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
+                                output_sample_types=[{'types': (
+                                    t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
+                                    'resolution': resolution}] + \
+                                                    [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
+                                                      'resolution': resolution // (2 ** i)} for i in
+                                                     range(ms_count)] + \
+                                                    [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
+                                                      'resolution': resolution // (2 ** i)} for i in
+                                                     range(ms_count)])
+        ])

     # override
     def onTrainOneIter(self, generators_samples, generators_list):
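
Note on the behaviour these options control: ping_pong, batch_cap, ping_pong_iter and paddle drive a schedule that cycles the batch size between 1 and the chosen cap, flipping direction each time an endpoint is reached. The sketch below illustrates that schedule in isolation; it is not code from the patch, the helper name ping_pong_schedule is invented for illustration, and the doubling/halving step rule is an assumption -- the hunks above only show the direction flip at the endpoints.

# Illustrative sketch only, not part of the patch above.
# ping_pong_schedule is a hypothetical helper; the doubling/halving step rule is an
# assumption, the patch only dictates that the direction flips at 1 and batch_cap.
def ping_pong_schedule(total_iters, batch_cap=16, ping_pong_iter=1000):
    """Yield (iteration, batch_size, paddle) for a ping-pong batch-size schedule."""
    batch_size = 1
    paddle = 'ping'  # 'ping' grows toward batch_cap, 'pong' shrinks back toward 1
    for it in range(1, total_iters + 1):
        yield it, batch_size, paddle
        if it % ping_pong_iter == 0:
            if paddle == 'ping':
                batch_size = min(batch_size * 2, batch_cap)
                if batch_size == batch_cap:
                    paddle = 'pong'
            else:
                batch_size = max(batch_size // 2, 1)
                if batch_size == 1:
                    paddle = 'ping'

if __name__ == '__main__':
    # Print the batch size in effect just after each flip point.
    for it, bs, paddle in ping_pong_schedule(10000):
        if it % 1000 == 1:
            print(it, bs, paddle)

Run as a script this prints the ramp 1 -> 2 -> 4 -> 8 -> 16 and back down, matching the option's help_message ("Cycles batch size between 1 and chosen batch size"); the exact ramp shape in the real implementation may differ.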