now accessing current generators instead of assigning new ones

This commit is contained in:
Brigham Lysenko 2019-08-14 18:55:45 -06:00
commit a3c5d957d3
2 changed files with 10 additions and 37 deletions

View file

@ -169,7 +169,7 @@ class SAEModel(ModelBase):
global ms_count
self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1
global apply_random_ct
apply_random_ct = self.options.get('apply_random_ct', False)
masked_training = True
@ -448,11 +448,11 @@ class SAEModel(ModelBase):
global t_mode_bgr
t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE
global training_data_src_path
training_data_src_path = self.training_data_src_path
global training_data_dst_path
training_data_dst_path= self.training_data_dst_path
global sort_by_yaw
sort_by_yaw = self.sort_by_yaw
if self.pretrain and self.pretraining_data_path is not None:
@ -529,39 +529,8 @@ class SAEModel(ModelBase):
# override
def set_batch_size(self, batch_size):
self.batch_size = batch_size
self.set_training_data_generators([
SampleGeneratorFace(training_data_src_path,
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
scale_range=np.array([-0.05,
0.05]) + self.src_scale_mod / 100.0),
output_sample_types=[{'types': (
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution, 'apply_ct': apply_random_ct}] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution // (2 ** i),
'apply_ct': apply_random_ct} for i in range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)]
),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
output_sample_types=[{'types': (
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution}] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)])
])
for i, generator in enumerate(self.generator_list):
generator.update_batch(batch_size)
# override
def onTrainOneIter(self, generators_samples, generators_list):

View file

@ -140,3 +140,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
break
yield [ np.array(batch) for batch in batches]
def update_batch(self, batch_size):
self.batch_size = batch_size