From a3c5d957d3df9710e4045b58d29225ef2dad6a51 Mon Sep 17 00:00:00 2001 From: Brigham Lysenko Date: Wed, 14 Aug 2019 18:55:45 -0600 Subject: [PATCH] now accessing current generators instead of assigning new ones --- models/Model_SAE/Model.py | 43 +++++--------------------------- samplelib/SampleGeneratorFace.py | 4 +++ 2 files changed, 10 insertions(+), 37 deletions(-) diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index ead0282..8f492d2 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -169,7 +169,7 @@ class SAEModel(ModelBase): global ms_count self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1 - global apply_random_ct + apply_random_ct = self.options.get('apply_random_ct', False) masked_training = True @@ -448,11 +448,11 @@ class SAEModel(ModelBase): global t_mode_bgr t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE - global training_data_src_path + training_data_src_path = self.training_data_src_path - global training_data_dst_path + training_data_dst_path= self.training_data_dst_path - global sort_by_yaw + sort_by_yaw = self.sort_by_yaw if self.pretrain and self.pretraining_data_path is not None: @@ -529,39 +529,8 @@ class SAEModel(ModelBase): # override def set_batch_size(self, batch_size): self.batch_size = batch_size - self.set_training_data_generators([ - SampleGeneratorFace(training_data_src_path, - sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None, - random_ct_samples_path=training_data_dst_path if apply_random_ct else None, - debug=self.is_debug(), batch_size=self.batch_size, - sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, - scale_range=np.array([-0.05, - 0.05]) + self.src_scale_mod / 100.0), - output_sample_types=[{'types': ( - t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), - 'resolution': resolution, 'apply_ct': apply_random_ct}] + \ - [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr), - 'resolution': resolution // (2 ** i), - 'apply_ct': apply_random_ct} for i in range(ms_count)] + \ - [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M), - 'resolution': resolution // (2 ** i)} for i in - range(ms_count)] - ), - - SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, - sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ), - output_sample_types=[{'types': ( - t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), - 'resolution': resolution}] + \ - [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr), - 'resolution': resolution // (2 ** i)} for i in - range(ms_count)] + \ - [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M), - 'resolution': resolution // (2 ** i)} for i in - range(ms_count)]) - ]) - - + for i, generator in enumerate(self.generator_list): + generator.update_batch(batch_size) # override def onTrainOneIter(self, generators_samples, generators_list): diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 593f8e2..87a7e5b 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -140,3 +140,7 @@ class SampleGeneratorFace(SampleGeneratorBase): break yield [ np.array(batch) for batch in batches] + + def update_batch(self, batch_size): + self.batch_size = batch_size +