mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
fix
This commit is contained in:
parent
754d6c385c
commit
9ed0111824
1 changed files with 6 additions and 11 deletions
|
@ -37,19 +37,14 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
|||
|
||||
if self.samples_len == 0:
|
||||
raise ValueError('No training data provided.')
|
||||
|
||||
persons_name_idxs = {}
|
||||
|
||||
for i,sample in enumerate(samples):
|
||||
person_name = sample.person_name
|
||||
if person_name not in persons_name_idxs:
|
||||
persons_name_idxs[person_name] = []
|
||||
persons_name_idxs[person_name].append (i)
|
||||
|
||||
indexes2D = [ persons_name_idxs[person_name] for person_name in sorted(list(persons_name_idxs.keys())) ]
|
||||
|
||||
unique_person_names = { sample.person_name for sample in samples }
|
||||
persons_name_idxs = { person_name : [] for person_name in unique_person_names }
|
||||
for i,sample in enumerate(samples):
|
||||
persons_name_idxs[sample.person_name].append (i)
|
||||
indexes2D = [ persons_name_idxs[person_name] for person_name in unique_person_names ]
|
||||
index2d_host = mp_utils.Index2DHost(indexes2D)
|
||||
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) )]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue