mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
Decreased amount of RAM used by Sample Generator.
This commit is contained in:
parent
8e9e346c9d
commit
33b0aadb4e
5 changed files with 221 additions and 32 deletions
|
@ -11,7 +11,7 @@ from core import imagelib, mplib, pathex
|
|||
from core.imagelib import sd
|
||||
from core.cv2ex import *
|
||||
from core.interact import interact as io
|
||||
from core.joblib import SubprocessGenerator, ThisThreadGenerator
|
||||
from core.joblib import Subprocessor, SubprocessGenerator, ThisThreadGenerator
|
||||
from facelib import LandmarksProcessor
|
||||
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType)
|
||||
|
||||
|
@ -23,28 +23,24 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
|
|||
super().__init__(debug, batch_size)
|
||||
self.initialized = False
|
||||
|
||||
samples = []
|
||||
for path in paths:
|
||||
samples += SampleLoader.load (SampleType.FACE, path)
|
||||
|
||||
seg_samples = [ sample for sample in samples if sample.seg_ie_polys.get_pts_count() != 0]
|
||||
seg_samples_len = len(seg_samples)
|
||||
samples = sum([ SampleLoader.load (SampleType.FACE, path) for path in paths ] )
|
||||
seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples).run()
|
||||
|
||||
seg_samples_len = len(seg_sample_idxs)
|
||||
if seg_samples_len == 0:
|
||||
raise Exception(f"No segmented faces found.")
|
||||
else:
|
||||
io.log_info(f"Using {seg_samples_len} segmented samples.")
|
||||
|
||||
pickled_samples = pickle.dumps(seg_samples, 4)
|
||||
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
else:
|
||||
self.generators_count = max(1, generators_count)
|
||||
|
||||
if self.debug:
|
||||
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, resolution, face_type, data_format) )]
|
||||
self.generators = [ThisThreadGenerator ( self.batch_func, (samples, seg_sample_idxs, resolution, face_type, data_format) )]
|
||||
else:
|
||||
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, resolution, face_type, data_format), start_now=False ) \
|
||||
self.generators = [SubprocessGenerator ( self.batch_func, (samples, seg_sample_idxs, resolution, face_type, data_format), start_now=False ) \
|
||||
for i in range(self.generators_count) ]
|
||||
|
||||
SubprocessGenerator.start_in_parallel( self.generators )
|
||||
|
@ -66,12 +62,9 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
|
|||
return next(generator)
|
||||
|
||||
def batch_func(self, param ):
|
||||
pickled_samples, resolution, face_type, data_format = param
|
||||
|
||||
samples = pickle.loads(pickled_samples)
|
||||
|
||||
samples, seg_sample_idxs, resolution, face_type, data_format = param
|
||||
|
||||
shuffle_idxs = []
|
||||
idxs = [*range(len(samples))]
|
||||
|
||||
random_flip = True
|
||||
rotation_range=[-10,10]
|
||||
|
@ -91,7 +84,7 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
|
|||
while n_batch < bs:
|
||||
try:
|
||||
if len(shuffle_idxs) == 0:
|
||||
shuffle_idxs = idxs.copy()
|
||||
shuffle_idxs = seg_sample_idxs.copy()
|
||||
np.random.shuffle(shuffle_idxs)
|
||||
idx = shuffle_idxs.pop()
|
||||
|
||||
|
@ -146,3 +139,56 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
|
|||
io.log_err ( traceback.format_exc() )
|
||||
|
||||
yield [ np.array(batch) for batch in batches]
|
||||
|
||||
|
||||
|
||||
class SegmentedSampleFilterSubprocessor(Subprocessor):
|
||||
#override
|
||||
def __init__(self, samples ):
|
||||
self.samples = samples
|
||||
self.samples_len = len(self.samples)
|
||||
|
||||
self.idxs = [*range(self.samples_len)]
|
||||
self.result = []
|
||||
super().__init__('SegmentedSampleFilterSubprocessor', SegmentedSampleFilterSubprocessor.Cli, 60)
|
||||
|
||||
#override
|
||||
def process_info_generator(self):
|
||||
for i in range(multiprocessing.cpu_count()):
|
||||
yield 'CPU%d' % (i), {}, {'samples':self.samples}
|
||||
|
||||
#override
|
||||
def on_clients_initialized(self):
|
||||
io.progress_bar ("Filtering", self.samples_len)
|
||||
|
||||
#override
|
||||
def on_clients_finalized(self):
|
||||
io.progress_bar_close()
|
||||
|
||||
#override
|
||||
def get_data(self, host_dict):
|
||||
if len (self.idxs) > 0:
|
||||
return self.idxs.pop(0)
|
||||
|
||||
return None
|
||||
|
||||
#override
|
||||
def on_data_return (self, host_dict, data):
|
||||
self.idxs.insert(0, data)
|
||||
|
||||
#override
|
||||
def on_result (self, host_dict, data, result):
|
||||
idx, is_ok = result
|
||||
if is_ok:
|
||||
self.result.append(idx)
|
||||
io.progress_bar_inc(1)
|
||||
def get_result(self):
|
||||
return self.result
|
||||
|
||||
class Cli(Subprocessor.Cli):
|
||||
#overridable optional
|
||||
def on_initialize(self, client_dict):
|
||||
self.samples = client_dict['samples']
|
||||
|
||||
def process_data(self, idx):
|
||||
return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0
|
Loading…
Add table
Add a link
Reference in a new issue