mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
refactorings
This commit is contained in:
parent
e0e8970ab9
commit
754d6c385c
13 changed files with 243 additions and 104 deletions
|
@ -31,7 +31,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
self.add_sample_idx = add_sample_idx
|
||||
|
||||
samples_host = SampleHost.mp_host (SampleType.FACE, self.samples_path)
|
||||
self.samples_len = len(samples_host)
|
||||
self.samples_len = len(samples_host.get_list())
|
||||
|
||||
if self.samples_len == 0:
|
||||
raise ValueError('No training data provided.')
|
||||
|
@ -40,7 +40,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
|
||||
if random_ct_samples_path is not None:
|
||||
ct_samples_host = SampleHost.mp_host (SampleType.FACE, random_ct_samples_path)
|
||||
ct_index_host = mp_utils.IndexHost( len(ct_samples_host) )
|
||||
ct_index_host = mp_utils.IndexHost( len(ct_samples_host.get_list()) )
|
||||
else:
|
||||
ct_samples_host = None
|
||||
ct_index_host = None
|
||||
|
@ -76,7 +76,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
ct_indexes = ct_index_host.get(bs) if ct_samples is not None else None
|
||||
|
||||
for n_batch in range(bs):
|
||||
sample = samples[ indexes[n_batch] ]
|
||||
sample_idx = indexes[n_batch]
|
||||
sample = samples[ sample_idx ]
|
||||
ct_sample = ct_samples[ ct_indexes[n_batch] ] if ct_samples is not None else None
|
||||
|
||||
try:
|
||||
|
@ -94,9 +95,5 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
batches[i].append ( x[i] )
|
||||
|
||||
if self.add_sample_idx:
|
||||
batches[i_sample_idx].append (idx)
|
||||
batches[i_sample_idx].append (sample_idx)
|
||||
yield [ np.array(batch) for batch in batches]
|
||||
|
||||
@staticmethod
|
||||
def get_person_id_max_count(samples_path):
|
||||
return SampleHost.get_person_id_max_count(samples_path)
|
Loading…
Add table
Add a link
Reference in a new issue