From 1a5484df8f19094f0ac6ae069907c9fa4c4f9d88 Mon Sep 17 00:00:00 2001 From: Brigham Lysenko Date: Fri, 23 Aug 2019 19:37:04 -0600 Subject: [PATCH] fixed small things dont even worry about it --- models/Model_SAE/Model.py | 36 ++------------------------------ samplelib/SampleGeneratorFace.py | 4 ++-- samplelib/__init__.py | 1 + 3 files changed, 5 insertions(+), 36 deletions(-) diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index d12dabb..a24e4df 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -542,40 +542,8 @@ class SAEModel(ModelBase): # override def set_batch_size(self, batch_size): self.batch_size = batch_size - self.set_training_data_generators(None) - self.set_training_data_generators([ - SampleGeneratorFace(training_data_src_path, - sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None, - random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None, - debug=self.is_debug(), batch_size=self.batch_size, - sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, - scale_range=np.array([-0.05, - 0.05]) + self.src_scale_mod / 100.0), - output_sample_types=[{'types': ( - t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), - 'resolution': resolution, 'apply_ct': apply_random_ct}] + \ - [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr), - 'resolution': resolution // (2 ** i), - 'apply_ct': apply_random_ct} for i in range(ms_count)] + \ - [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M), - 'resolution': resolution // (2 ** i)} for i in - range(ms_count)], - ping_pong=self.ping_pong_options, - ), - - SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, - sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ), - output_sample_types=[{'types': ( - t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), - 'resolution': resolution}] + \ - [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr), - 'resolution': resolution // (2 ** i)} for i in - range(ms_count)] + \ - [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M), - 'resolution': resolution // (2 ** i)} for i in - range(ms_count)], - ping_pong=self.ping_pong_options,) - ]) + for generators in self.get_training_data_generators(): + generators.update_batch(batch_size) # override def onTrainOneIter(self, generators_samples, generators_list): diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 654ce1b..231756f 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -7,7 +7,7 @@ import numpy as np from facelib import LandmarksProcessor from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleGeneratorPingPong, SampleType) -from samplelib.SampleGeneratorPingPong import PingPongOptions +from samplelib.SampleGeneratorPingPong import PingPongOptions, SampleGeneratorPingPong from utils import iter_utils @@ -25,7 +25,7 @@ class SampleGeneratorFace(SampleGeneratorPingPong): random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, ping_pong=PingPongOptions(), **kwargs): - super().__init__(samples_path, debug, batch_size, ping_pong) + super().__init__(samples_path, debug, batch_size=batch_size, ping_pong=ping_pong) self.sample_process_options = sample_process_options self.output_sample_types = output_sample_types self.add_sample_idx = add_sample_idx diff --git a/samplelib/__init__.py b/samplelib/__init__.py index d865394..8214c98 100644 --- a/samplelib/__init__.py +++ b/samplelib/__init__.py @@ -4,5 +4,6 @@ from .SampleLoader import SampleLoader from .SampleProcessor import SampleProcessor from .SampleGeneratorBase import SampleGeneratorBase from .SampleGeneratorFace import SampleGeneratorFace +from .SampleGeneratorPingPong import SampleGeneratorPingPong from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal