mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
refactoring
This commit is contained in:
parent
121c0cfc0f
commit
eda6433936
5 changed files with 9 additions and 13 deletions
|
@ -6,11 +6,7 @@ You can implement your own SampleGenerator
|
||||||
class SampleGeneratorBase(object):
|
class SampleGeneratorBase(object):
|
||||||
|
|
||||||
|
|
||||||
def __init__ (self, samples_path, debug=False, batch_size=1):
|
def __init__ (self, debug=False, batch_size=1):
|
||||||
if samples_path is None:
|
|
||||||
raise Exception('samples_path is None')
|
|
||||||
|
|
||||||
self.samples_path = Path(samples_path)
|
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.batch_size = 1 if self.debug else batch_size
|
self.batch_size = 1 if self.debug else batch_size
|
||||||
self.last_generation = None
|
self.last_generation = None
|
||||||
|
|
|
@ -30,7 +30,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
raise_on_no_data=True,
|
raise_on_no_data=True,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
super().__init__(samples_path, debug, batch_size)
|
super().__init__(debug, batch_size)
|
||||||
self.sample_process_options = sample_process_options
|
self.sample_process_options = sample_process_options
|
||||||
self.output_sample_types = output_sample_types
|
self.output_sample_types = output_sample_types
|
||||||
self.add_sample_idx = add_sample_idx
|
self.add_sample_idx = add_sample_idx
|
||||||
|
@ -40,7 +40,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
else:
|
else:
|
||||||
self.generators_count = max(1, generators_count)
|
self.generators_count = max(1, generators_count)
|
||||||
|
|
||||||
samples = SampleLoader.load (SampleType.FACE, self.samples_path)
|
samples = SampleLoader.load (SampleType.FACE, samples_path)
|
||||||
self.samples_len = len(samples)
|
self.samples_len = len(samples)
|
||||||
|
|
||||||
self.initialized = False
|
self.initialized = False
|
||||||
|
|
|
@ -26,14 +26,14 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
||||||
person_id_mode=1,
|
person_id_mode=1,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
super().__init__(samples_path, debug, batch_size)
|
super().__init__(debug, batch_size)
|
||||||
self.sample_process_options = sample_process_options
|
self.sample_process_options = sample_process_options
|
||||||
self.output_sample_types = output_sample_types
|
self.output_sample_types = output_sample_types
|
||||||
self.person_id_mode = person_id_mode
|
self.person_id_mode = person_id_mode
|
||||||
|
|
||||||
raise NotImplementedError("Currently SampleGeneratorFacePerson is not implemented.")
|
raise NotImplementedError("Currently SampleGeneratorFacePerson is not implemented.")
|
||||||
|
|
||||||
samples_host = SampleLoader.mp_host (SampleType.FACE, self.samples_path)
|
samples_host = SampleLoader.mp_host (SampleType.FACE, samples_path)
|
||||||
samples = samples_host.get_list()
|
samples = samples_host.get_list()
|
||||||
self.samples_len = len(samples)
|
self.samples_len = len(samples)
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
||||||
output_sample_types=[],
|
output_sample_types=[],
|
||||||
generators_count=2,
|
generators_count=2,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__(samples_path, debug, batch_size)
|
super().__init__(debug, batch_size)
|
||||||
|
|
||||||
self.temporal_image_count = temporal_image_count
|
self.temporal_image_count = temporal_image_count
|
||||||
self.sample_process_options = sample_process_options
|
self.sample_process_options = sample_process_options
|
||||||
|
@ -31,7 +31,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
||||||
else:
|
else:
|
||||||
self.generators_count = generators_count
|
self.generators_count = generators_count
|
||||||
|
|
||||||
samples = SampleLoader.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
|
samples = SampleLoader.load (SampleType.FACE_TEMPORAL_SORTED, samples_path)
|
||||||
samples_len = len(samples)
|
samples_len = len(samples)
|
||||||
if samples_len == 0:
|
if samples_len == 0:
|
||||||
raise ValueError('No training data provided.')
|
raise ValueError('No training data provided.')
|
||||||
|
|
|
@ -16,13 +16,13 @@ output_sample_types = [
|
||||||
'''
|
'''
|
||||||
class SampleGeneratorImageTemporal(SampleGeneratorBase):
|
class SampleGeneratorImageTemporal(SampleGeneratorBase):
|
||||||
def __init__ (self, samples_path, debug, batch_size, temporal_image_count, sample_process_options=SampleProcessor.Options(), output_sample_types=[], **kwargs):
|
def __init__ (self, samples_path, debug, batch_size, temporal_image_count, sample_process_options=SampleProcessor.Options(), output_sample_types=[], **kwargs):
|
||||||
super().__init__(samples_path, debug, batch_size)
|
super().__init__(debug, batch_size)
|
||||||
|
|
||||||
self.temporal_image_count = temporal_image_count
|
self.temporal_image_count = temporal_image_count
|
||||||
self.sample_process_options = sample_process_options
|
self.sample_process_options = sample_process_options
|
||||||
self.output_sample_types = output_sample_types
|
self.output_sample_types = output_sample_types
|
||||||
|
|
||||||
self.samples = SampleLoader.load (SampleType.IMAGE, self.samples_path)
|
self.samples = SampleLoader.load (SampleType.IMAGE, samples_path)
|
||||||
|
|
||||||
self.generator_samples = [ self.samples ]
|
self.generator_samples = [ self.samples ]
|
||||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )] if self.debug else \
|
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )] if self.debug else \
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue