mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-05 12:36:42 -07:00
refactoring
This commit is contained in:
parent
121c0cfc0f
commit
eda6433936
5 changed files with 9 additions and 13 deletions
|
@ -6,11 +6,7 @@ You can implement your own SampleGenerator
|
|||
class SampleGeneratorBase(object):
|
||||
|
||||
|
||||
def __init__ (self, samples_path, debug=False, batch_size=1):
|
||||
if samples_path is None:
|
||||
raise Exception('samples_path is None')
|
||||
|
||||
self.samples_path = Path(samples_path)
|
||||
def __init__ (self, debug=False, batch_size=1):
|
||||
self.debug = debug
|
||||
self.batch_size = 1 if self.debug else batch_size
|
||||
self.last_generation = None
|
||||
|
|
|
@ -30,7 +30,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
raise_on_no_data=True,
|
||||
**kwargs):
|
||||
|
||||
super().__init__(samples_path, debug, batch_size)
|
||||
super().__init__(debug, batch_size)
|
||||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
self.add_sample_idx = add_sample_idx
|
||||
|
@ -40,7 +40,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
else:
|
||||
self.generators_count = max(1, generators_count)
|
||||
|
||||
samples = SampleLoader.load (SampleType.FACE, self.samples_path)
|
||||
samples = SampleLoader.load (SampleType.FACE, samples_path)
|
||||
self.samples_len = len(samples)
|
||||
|
||||
self.initialized = False
|
||||
|
|
|
@ -26,14 +26,14 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
|||
person_id_mode=1,
|
||||
**kwargs):
|
||||
|
||||
super().__init__(samples_path, debug, batch_size)
|
||||
super().__init__(debug, batch_size)
|
||||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
self.person_id_mode = person_id_mode
|
||||
|
||||
raise NotImplementedError("Currently SampleGeneratorFacePerson is not implemented.")
|
||||
|
||||
samples_host = SampleLoader.mp_host (SampleType.FACE, self.samples_path)
|
||||
samples_host = SampleLoader.mp_host (SampleType.FACE, samples_path)
|
||||
samples = samples_host.get_list()
|
||||
self.samples_len = len(samples)
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
|||
output_sample_types=[],
|
||||
generators_count=2,
|
||||
**kwargs):
|
||||
super().__init__(samples_path, debug, batch_size)
|
||||
super().__init__(debug, batch_size)
|
||||
|
||||
self.temporal_image_count = temporal_image_count
|
||||
self.sample_process_options = sample_process_options
|
||||
|
@ -31,7 +31,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
|||
else:
|
||||
self.generators_count = generators_count
|
||||
|
||||
samples = SampleLoader.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
|
||||
samples = SampleLoader.load (SampleType.FACE_TEMPORAL_SORTED, samples_path)
|
||||
samples_len = len(samples)
|
||||
if samples_len == 0:
|
||||
raise ValueError('No training data provided.')
|
||||
|
|
|
@ -16,13 +16,13 @@ output_sample_types = [
|
|||
'''
|
||||
class SampleGeneratorImageTemporal(SampleGeneratorBase):
|
||||
def __init__ (self, samples_path, debug, batch_size, temporal_image_count, sample_process_options=SampleProcessor.Options(), output_sample_types=[], **kwargs):
|
||||
super().__init__(samples_path, debug, batch_size)
|
||||
super().__init__(debug, batch_size)
|
||||
|
||||
self.temporal_image_count = temporal_image_count
|
||||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
|
||||
self.samples = SampleLoader.load (SampleType.IMAGE, self.samples_path)
|
||||
self.samples = SampleLoader.load (SampleType.IMAGE, samples_path)
|
||||
|
||||
self.generator_samples = [ self.samples ]
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )] if self.debug else \
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue