diff --git a/samplelib/SampleGeneratorBase.py b/samplelib/SampleGeneratorBase.py index 19d02b3..6e60160 100644 --- a/samplelib/SampleGeneratorBase.py +++ b/samplelib/SampleGeneratorBase.py @@ -6,11 +6,7 @@ You can implement your own SampleGenerator class SampleGeneratorBase(object): - def __init__ (self, samples_path, debug=False, batch_size=1): - if samples_path is None: - raise Exception('samples_path is None') - - self.samples_path = Path(samples_path) + def __init__ (self, debug=False, batch_size=1): self.debug = debug self.batch_size = 1 if self.debug else batch_size self.last_generation = None diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index f386797..c91337e 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -30,7 +30,7 @@ class SampleGeneratorFace(SampleGeneratorBase): raise_on_no_data=True, **kwargs): - super().__init__(samples_path, debug, batch_size) + super().__init__(debug, batch_size) self.sample_process_options = sample_process_options self.output_sample_types = output_sample_types self.add_sample_idx = add_sample_idx @@ -40,7 +40,7 @@ class SampleGeneratorFace(SampleGeneratorBase): else: self.generators_count = max(1, generators_count) - samples = SampleLoader.load (SampleType.FACE, self.samples_path) + samples = SampleLoader.load (SampleType.FACE, samples_path) self.samples_len = len(samples) self.initialized = False diff --git a/samplelib/SampleGeneratorFacePerson.py b/samplelib/SampleGeneratorFacePerson.py index ba69f44..a72cf59 100644 --- a/samplelib/SampleGeneratorFacePerson.py +++ b/samplelib/SampleGeneratorFacePerson.py @@ -26,14 +26,14 @@ class SampleGeneratorFacePerson(SampleGeneratorBase): person_id_mode=1, **kwargs): - super().__init__(samples_path, debug, batch_size) + super().__init__(debug, batch_size) self.sample_process_options = sample_process_options self.output_sample_types = output_sample_types self.person_id_mode = person_id_mode raise NotImplementedError("Currently SampleGeneratorFacePerson is not implemented.") - samples_host = SampleLoader.mp_host (SampleType.FACE, self.samples_path) + samples_host = SampleLoader.mp_host (SampleType.FACE, samples_path) samples = samples_host.get_list() self.samples_len = len(samples) diff --git a/samplelib/SampleGeneratorFaceTemporal.py b/samplelib/SampleGeneratorFaceTemporal.py index 5ffb8e4..2137476 100644 --- a/samplelib/SampleGeneratorFaceTemporal.py +++ b/samplelib/SampleGeneratorFaceTemporal.py @@ -20,7 +20,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase): output_sample_types=[], generators_count=2, **kwargs): - super().__init__(samples_path, debug, batch_size) + super().__init__(debug, batch_size) self.temporal_image_count = temporal_image_count self.sample_process_options = sample_process_options @@ -31,7 +31,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase): else: self.generators_count = generators_count - samples = SampleLoader.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path) + samples = SampleLoader.load (SampleType.FACE_TEMPORAL_SORTED, samples_path) samples_len = len(samples) if samples_len == 0: raise ValueError('No training data provided.') diff --git a/samplelib/SampleGeneratorImageTemporal.py b/samplelib/SampleGeneratorImageTemporal.py index 4f86e43..ed3ab4a 100644 --- a/samplelib/SampleGeneratorImageTemporal.py +++ b/samplelib/SampleGeneratorImageTemporal.py @@ -16,13 +16,13 @@ output_sample_types = [ ''' class SampleGeneratorImageTemporal(SampleGeneratorBase): def __init__ (self, samples_path, debug, batch_size, temporal_image_count, sample_process_options=SampleProcessor.Options(), output_sample_types=[], **kwargs): - super().__init__(samples_path, debug, batch_size) + super().__init__(debug, batch_size) self.temporal_image_count = temporal_image_count self.sample_process_options = sample_process_options self.output_sample_types = output_sample_types - self.samples = SampleLoader.load (SampleType.IMAGE, self.samples_path) + self.samples = SampleLoader.load (SampleType.IMAGE, samples_path) self.generator_samples = [ self.samples ] self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )] if self.debug else \