refactoring

This commit is contained in:
Colombo 2020-03-08 23:19:04 +04:00
parent 121c0cfc0f
commit eda6433936
5 changed files with 9 additions and 13 deletions

View file

@ -6,11 +6,7 @@ You can implement your own SampleGenerator
class SampleGeneratorBase(object):
def __init__ (self, samples_path, debug=False, batch_size=1):
if samples_path is None:
raise Exception('samples_path is None')
self.samples_path = Path(samples_path)
def __init__ (self, debug=False, batch_size=1):
self.debug = debug
self.batch_size = 1 if self.debug else batch_size
self.last_generation = None

View file

@ -30,7 +30,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
raise_on_no_data=True,
**kwargs):
super().__init__(samples_path, debug, batch_size)
super().__init__(debug, batch_size)
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx
@ -40,7 +40,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
else:
self.generators_count = max(1, generators_count)
samples = SampleLoader.load (SampleType.FACE, self.samples_path)
samples = SampleLoader.load (SampleType.FACE, samples_path)
self.samples_len = len(samples)
self.initialized = False

View file

@ -26,14 +26,14 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
person_id_mode=1,
**kwargs):
super().__init__(samples_path, debug, batch_size)
super().__init__(debug, batch_size)
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.person_id_mode = person_id_mode
raise NotImplementedError("Currently SampleGeneratorFacePerson is not implemented.")
samples_host = SampleLoader.mp_host (SampleType.FACE, self.samples_path)
samples_host = SampleLoader.mp_host (SampleType.FACE, samples_path)
samples = samples_host.get_list()
self.samples_len = len(samples)

View file

@ -20,7 +20,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
output_sample_types=[],
generators_count=2,
**kwargs):
super().__init__(samples_path, debug, batch_size)
super().__init__(debug, batch_size)
self.temporal_image_count = temporal_image_count
self.sample_process_options = sample_process_options
@ -31,7 +31,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
else:
self.generators_count = generators_count
samples = SampleLoader.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
samples = SampleLoader.load (SampleType.FACE_TEMPORAL_SORTED, samples_path)
samples_len = len(samples)
if samples_len == 0:
raise ValueError('No training data provided.')

View file

@ -16,13 +16,13 @@ output_sample_types = [
'''
class SampleGeneratorImageTemporal(SampleGeneratorBase):
def __init__ (self, samples_path, debug, batch_size, temporal_image_count, sample_process_options=SampleProcessor.Options(), output_sample_types=[], **kwargs):
super().__init__(samples_path, debug, batch_size)
super().__init__(debug, batch_size)
self.temporal_image_count = temporal_image_count
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.samples = SampleLoader.load (SampleType.IMAGE, self.samples_path)
self.samples = SampleLoader.load (SampleType.IMAGE, samples_path)
self.generator_samples = [ self.samples ]
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )] if self.debug else \