mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
optimized sample generator
This commit is contained in:
parent
b5c234dac3
commit
21b25038ac
6 changed files with 201 additions and 160 deletions
|
@ -1,7 +1,5 @@
|
|||
import gc
|
||||
import multiprocessing
|
||||
import operator
|
||||
import pickle
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -16,9 +14,11 @@ from .Sample import Sample, SampleType
|
|||
|
||||
|
||||
class SampleHost:
|
||||
|
||||
|
||||
|
||||
|
||||
samples_cache = dict()
|
||||
host_cache = dict()
|
||||
|
||||
@staticmethod
|
||||
def get_person_id_max_count(samples_path):
|
||||
samples = None
|
||||
|
@ -35,7 +35,7 @@ class SampleHost:
|
|||
return len(list(persons_name_idxs.keys()))
|
||||
|
||||
@staticmethod
|
||||
def load(sample_type, samples_path):
|
||||
def host(sample_type, samples_path, number_of_clis):
|
||||
samples_cache = SampleHost.samples_cache
|
||||
|
||||
if str(samples_path) not in samples_cache.keys():
|
||||
|
@ -46,9 +46,11 @@ class SampleHost:
|
|||
if sample_type == SampleType.IMAGE:
|
||||
if samples[sample_type] is None:
|
||||
samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ]
|
||||
elif sample_type == SampleType.FACE:
|
||||
if samples[sample_type] is None:
|
||||
result = None
|
||||
elif sample_type == SampleType.FACE or \
|
||||
sample_type == SampleType.FACE_TEMPORAL_SORTED:
|
||||
result = None
|
||||
|
||||
if samples[sample_type] is None:
|
||||
try:
|
||||
result = samplelib.PackedFaceset.load(samples_path)
|
||||
except:
|
||||
|
@ -60,33 +62,26 @@ class SampleHost:
|
|||
if result is None:
|
||||
result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) )
|
||||
|
||||
result_dmp = pickle.dumps(result)
|
||||
del result
|
||||
gc.collect()
|
||||
result = pickle.loads(result_dmp)
|
||||
|
||||
samples[sample_type] = result
|
||||
|
||||
elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
|
||||
if samples[sample_type] is None:
|
||||
samples[sample_type] = SampleHost.upgradeToFaceTemporalSortedSamples( SampleHost.load(SampleType.FACE, samples_path) )
|
||||
samples[sample_type] = mp_utils.ListHost()
|
||||
|
||||
if sample_type == SampleType.FACE_TEMPORAL_SORTED:
|
||||
result = SampleHost.upgradeToFaceTemporalSortedSamples(result)
|
||||
|
||||
list_host = samples[sample_type]
|
||||
|
||||
clis = [ list_host.create_cli() for _ in range(number_of_clis) ]
|
||||
|
||||
if result is not None:
|
||||
while True:
|
||||
if len(result) == 0:
|
||||
break
|
||||
items = result[0:10000]
|
||||
del result[0:10000]
|
||||
clis[0].extend(items)
|
||||
return clis
|
||||
|
||||
return samples[sample_type]
|
||||
|
||||
@staticmethod
|
||||
def mp_host(sample_type, samples_path):
|
||||
result = SampleHost.load (sample_type, samples_path)
|
||||
|
||||
host_cache = SampleHost.host_cache
|
||||
if str(samples_path) not in host_cache.keys():
|
||||
host_cache[str(samples_path)] = [None]*SampleType.QTY
|
||||
hosts = host_cache[str(samples_path)]
|
||||
|
||||
if hosts[sample_type] is None:
|
||||
hosts[sample_type] = mp_utils.ListHost(result)
|
||||
|
||||
return hosts[sample_type]
|
||||
|
||||
@staticmethod
|
||||
def load_face_samples ( image_paths):
|
||||
result = FaceSamplesLoaderSubprocessor(image_paths).run()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue