Decreased amount of RAM used by Sample Generator.

This commit is contained in:
Colombo 2020-04-05 13:52:32 +04:00
parent 8e9e346c9d
commit 33b0aadb4e
5 changed files with 221 additions and 32 deletions

View file

@ -1,5 +1,4 @@
import multiprocessing
import pickle
import time
import traceback
@ -12,7 +11,6 @@ from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
SampleType)
'''
arg
output_sample_types = [
@ -59,13 +57,10 @@ class SampleGeneratorFace(SampleGeneratorBase):
ct_samples = None
ct_index_host = None
pickled_samples = pickle.dumps(samples, 4)
ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None
if self.debug:
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
self.generators = [ThisThreadGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
else:
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \
self.generators = [SubprocessGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \
for i in range(self.generators_count) ]
SubprocessGenerator.start_in_parallel( self.generators )
@ -90,11 +85,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
return next(generator)
def batch_func(self, param ):
pickled_samples, index_host, ct_pickled_samples, ct_index_host = param
samples = pickle.loads(pickled_samples)
ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None
samples, index_host, ct_samples, ct_index_host = param
bs = self.batch_size
while True:
batches = None

View file

@ -11,7 +11,7 @@ from core import imagelib, mplib, pathex
from core.imagelib import sd
from core.cv2ex import *
from core.interact import interact as io
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from core.joblib import Subprocessor, SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType)
@ -23,28 +23,24 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
super().__init__(debug, batch_size)
self.initialized = False
samples = []
for path in paths:
samples += SampleLoader.load (SampleType.FACE, path)
seg_samples = [ sample for sample in samples if sample.seg_ie_polys.get_pts_count() != 0]
seg_samples_len = len(seg_samples)
samples = sum([ SampleLoader.load (SampleType.FACE, path) for path in paths ] )
seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples).run()
seg_samples_len = len(seg_sample_idxs)
if seg_samples_len == 0:
raise Exception(f"No segmented faces found.")
else:
io.log_info(f"Using {seg_samples_len} segmented samples.")
pickled_samples = pickle.dumps(seg_samples, 4)
if self.debug:
self.generators_count = 1
else:
self.generators_count = max(1, generators_count)
if self.debug:
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, resolution, face_type, data_format) )]
self.generators = [ThisThreadGenerator ( self.batch_func, (samples, seg_sample_idxs, resolution, face_type, data_format) )]
else:
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, resolution, face_type, data_format), start_now=False ) \
self.generators = [SubprocessGenerator ( self.batch_func, (samples, seg_sample_idxs, resolution, face_type, data_format), start_now=False ) \
for i in range(self.generators_count) ]
SubprocessGenerator.start_in_parallel( self.generators )
@ -66,12 +62,9 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
return next(generator)
def batch_func(self, param ):
pickled_samples, resolution, face_type, data_format = param
samples = pickle.loads(pickled_samples)
samples, seg_sample_idxs, resolution, face_type, data_format = param
shuffle_idxs = []
idxs = [*range(len(samples))]
random_flip = True
rotation_range=[-10,10]
@ -91,7 +84,7 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
while n_batch < bs:
try:
if len(shuffle_idxs) == 0:
shuffle_idxs = idxs.copy()
shuffle_idxs = seg_sample_idxs.copy()
np.random.shuffle(shuffle_idxs)
idx = shuffle_idxs.pop()
@ -146,3 +139,56 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
io.log_err ( traceback.format_exc() )
yield [ np.array(batch) for batch in batches]
class SegmentedSampleFilterSubprocessor(Subprocessor):
#override
def __init__(self, samples ):
self.samples = samples
self.samples_len = len(self.samples)
self.idxs = [*range(self.samples_len)]
self.result = []
super().__init__('SegmentedSampleFilterSubprocessor', SegmentedSampleFilterSubprocessor.Cli, 60)
#override
def process_info_generator(self):
for i in range(multiprocessing.cpu_count()):
yield 'CPU%d' % (i), {}, {'samples':self.samples}
#override
def on_clients_initialized(self):
io.progress_bar ("Filtering", self.samples_len)
#override
def on_clients_finalized(self):
io.progress_bar_close()
#override
def get_data(self, host_dict):
if len (self.idxs) > 0:
return self.idxs.pop(0)
return None
#override
def on_data_return (self, host_dict, data):
self.idxs.insert(0, data)
#override
def on_result (self, host_dict, data, result):
idx, is_ok = result
if is_ok:
self.result.append(idx)
io.progress_bar_inc(1)
def get_result(self):
return self.result
class Cli(Subprocessor.Cli):
#overridable optional
def on_initialize(self, client_dict):
self.samples = client_dict['samples']
def process_data(self, idx):
return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0

View file

@ -6,6 +6,7 @@ from pathlib import Path
import samplelib.PackedFaceset
from core import pathex
from core.mplib import MPSharedList
from core.interact import interact as io
from core.joblib import Subprocessor
from DFLIMG import *
@ -33,6 +34,9 @@ class SampleLoader:
@staticmethod
def load(sample_type, samples_path, subdirs=False):
"""
Return MPSharedList of samples
"""
samples_cache = SampleLoader.samples_cache
if str(samples_path) not in samples_cache.keys():
@ -56,12 +60,12 @@ class SampleLoader:
if result is None:
result = SampleLoader.load_face_samples( pathex.get_image_paths(samples_path, subdirs=subdirs) )
samples[sample_type] = result
samples[sample_type] = MPSharedList(result)
elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
result = SampleLoader.load (SampleType.FACE, samples_path)
result = SampleLoader.upgradeToFaceTemporalSortedSamples(result)
samples[sample_type] = result
samples[sample_type] = MPSharedList(result)
return samples[sample_type]