mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
DFL-2.0 initial branch commit
This commit is contained in:
parent
52a67a61b3
commit
38b85108b3
154 changed files with 5251 additions and 9414 deletions
|
@ -1,13 +1,16 @@
|
|||
import multiprocessing
|
||||
import traceback
|
||||
import pickle
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
from core import mplib
|
||||
from core.joblib import SubprocessGenerator, ThisThreadGenerator
|
||||
from facelib import LandmarksProcessor
|
||||
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
|
||||
SampleType)
|
||||
from utils import iter_utils, mp_utils
|
||||
|
||||
|
||||
'''
|
||||
|
@ -34,37 +37,33 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
if self.debug:
|
||||
self.generators_count = 1
|
||||
else:
|
||||
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, generators_count)
|
||||
|
||||
self.generators_count = max(1, generators_count)
|
||||
|
||||
samples = SampleHost.load (SampleType.FACE, self.samples_path)
|
||||
self.samples_len = len(samples)
|
||||
|
||||
if self.samples_len == 0:
|
||||
raise ValueError('No training data provided.')
|
||||
|
||||
index_host = mp_utils.IndexHost(self.samples_len)
|
||||
index_host = mplib.IndexHost(self.samples_len)
|
||||
|
||||
if random_ct_samples_path is not None:
|
||||
ct_samples = SampleHost.load (SampleType.FACE, random_ct_samples_path)
|
||||
ct_index_host = mp_utils.IndexHost( len(ct_samples) )
|
||||
ct_index_host = mplib.IndexHost( len(ct_samples) )
|
||||
else:
|
||||
ct_samples = None
|
||||
ct_index_host = None
|
||||
|
||||
pickled_samples = pickle.dumps(samples, 4)
|
||||
ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None
|
||||
|
||||
|
||||
if self.debug:
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
|
||||
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
|
||||
else:
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
|
||||
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
|
||||
|
||||
self.generator_counter = -1
|
||||
|
||||
#overridable
|
||||
def get_total_sample_count(self):
|
||||
return self.samples_len
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
|
@ -75,8 +74,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
|
||||
def batch_func(self, param ):
|
||||
pickled_samples, index_host, ct_pickled_samples, ct_index_host = param
|
||||
|
||||
samples = pickle.loads(pickled_samples)
|
||||
|
||||
samples = pickle.loads(pickled_samples)
|
||||
ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None
|
||||
|
||||
bs = self.batch_size
|
||||
|
@ -89,9 +88,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
t = time.time()
|
||||
for n_batch in range(bs):
|
||||
sample_idx = indexes[n_batch]
|
||||
sample = samples[sample_idx]
|
||||
|
||||
ct_sample = None
|
||||
sample = samples[sample_idx]
|
||||
|
||||
ct_sample = None
|
||||
if ct_samples is not None:
|
||||
ct_sample = ct_samples[ct_indexes[n_batch]]
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue