Decreased amount of RAM used by Sample Generator.

This commit is contained in:
Colombo 2020-04-05 13:52:32 +04:00
parent 8e9e346c9d
commit 33b0aadb4e
5 changed files with 221 additions and 32 deletions

146
core/mplib/MPSharedList.py Normal file
View file

@ -0,0 +1,146 @@
import multiprocessing
import pickle
import struct
from core.joblib import Subprocessor
class MPSharedList():
    """
    Read-only pickled list of constant objects stored in shared memory
    ('multiprocessing.RawArray'), so subprocesses can read it without
    per-process copies or the 4GB pickle-transfer limit.

    Supports list concatenation via '+' or sum().
    """
    def __init__(self, obj_list):
        # obj_list of None produces an empty shell; used internally by __add__.
        if obj_list is None:
            self.obj_counts    = None
            self.table_offsets = None
            self.data_offsets  = None
            self.sh_bs         = None
        else:
            obj_count, table_offset, data_offset, sh_b = MPSharedList.bake_data(obj_list)
            self.obj_counts    = [obj_count]
            self.table_offsets = [table_offset]
            self.data_offsets  = [data_offset]
            self.sh_bs         = [sh_b]

    def __add__(self, o):
        if isinstance(o, MPSharedList):
            # concatenation keeps both shared buffers and just chains the metadata
            m = MPSharedList(None)
            m.obj_counts    = self.obj_counts    + o.obj_counts
            m.table_offsets = self.table_offsets + o.table_offsets
            m.data_offsets  = self.data_offsets  + o.data_offsets
            m.sh_bs         = self.sh_bs         + o.sh_bs
            return m
        elif isinstance(o, int):
            # allows sum() over MPSharedList instances (sum starts from 0)
            return self
        else:
            raise ValueError(f"MPSharedList object of class {o.__class__} is not supported for __add__ operator.")

    def __radd__(self, o):
        return self+o

    def __len__(self):
        return sum(self.obj_counts)

    def __getitem__(self, key):
        obj_count = sum(self.obj_counts)
        if key < 0:
            key = obj_count+key
        if key < 0 or key >= obj_count:
            raise ValueError("out of range")

        # locate the shared buffer that contains the requested global index
        for i in range(len(self.obj_counts)):
            if key < self.obj_counts[i]:
                table_offset = self.table_offsets[i]
                data_offset  = self.data_offsets[i]
                sh_b         = self.sh_bs[i]
                break
            key -= self.obj_counts[i]

        # the offset table holds obj_count+1 cumulative uint64 offsets, so
        # entries key and key+1 delimit the pickled bytes of object 'key'
        offset_start, offset_end = struct.unpack('<QQ', bytes(sh_b[ table_offset + key*8 : table_offset + (key+2)*8]) )

        return pickle.loads( bytes(sh_b[ data_offset + offset_start : data_offset + offset_end ]) )

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

    @staticmethod
    def bake_data(obj_list):
        """
        Pickle every object of obj_list into one shared RawArray.

        Buffer layout: a table of (obj_count+1) cumulative uint64 offsets,
        followed by the concatenated pickled payloads.

        Returns (obj_count, table_offset, data_offset, sh_b).
        Raises ValueError if obj_list is not a list.
        """
        if not isinstance(obj_list, list):
            raise ValueError("MPSharedList: obj_list should be list type.")

        obj_count = len(obj_list)
        obj_pickled_ar = [pickle.dumps(o, 4) for o in obj_list]

        table_offset = 0
        table_size   = (obj_count+1)*8
        data_offset  = table_offset + table_size
        data_size    = sum([len(x) for x in obj_pickled_ar])

        sh_b = multiprocessing.RawArray('B', table_size + data_size)

        # cumulative offsets: entry i is the start of object i relative to
        # data_offset; the final entry is the total payload size
        offsets = []
        offset = 0
        for i in range(obj_count):
            offsets.append(offset)
            offset += len(obj_pickled_ar[i])
        offsets.append(offset)

        sh_b[table_offset:table_offset+table_size] = struct.pack( '<'+'Q'*len(offsets), *offsets )

        if obj_count != 0:
            # filling via worker subprocesses is much faster than one big
            # direct assignment; skipped for an empty list so that
            # MPSharedList([]) works instead of crashing on a None return
            ArrayFillerSubprocessor(sh_b, [ (data_offset+offsets[i], obj_pickled_ar[i] ) for i in range(obj_count) ] ).run()

        return obj_count, table_offset, data_offset, sh_b
class ArrayFillerSubprocessor(Subprocessor):
    """
    Copies a list of (offset, bytes) chunks into a shared array.
    Spreading the copies across worker subprocesses is much faster than one
    direct whole-buffer fill from the parent process.
    """
    #override
    def __init__(self, sh_b, data_list ):
        self.sh_b = sh_b
        self.data_list = data_list
        super().__init__('ArrayFillerSubprocessor', ArrayFillerSubprocessor.Cli, 60)

    #override
    def process_info_generator(self):
        # one worker per logical CPU; each receives a handle to the shared buffer
        for i in range(multiprocessing.cpu_count()):
            yield 'CPU%d' % (i), {}, {'sh_b':self.sh_b}

    #override
    def get_data(self, host_dict):
        # hand out pending chunks until the queue is drained
        return self.data_list.pop(0) if self.data_list else None

    #override
    def on_data_return (self, host_dict, data):
        # worker failed: requeue the chunk at the front for another attempt
        self.data_list.insert(0, data)

    #override
    def on_result (self, host_dict, data, result):
        pass

    class Cli(Subprocessor.Cli):
        #overridable optional
        def on_initialize(self, client_dict):
            self.sh_b = client_dict['sh_b']

        def process_data(self, data):
            # write the chunk's bytes into the shared buffer at its offset
            chunk_offset, chunk = data
            self.sh_b[chunk_offset : chunk_offset + len(chunk)] = chunk
            return 0

View file

@ -1,3 +1,4 @@
from .MPSharedList import MPSharedList
import multiprocessing import multiprocessing
import threading import threading
import time import time

View file

@ -1,5 +1,4 @@
import multiprocessing import multiprocessing
import pickle
import time import time
import traceback import traceback
@ -12,7 +11,6 @@ from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
SampleType) SampleType)
''' '''
arg arg
output_sample_types = [ output_sample_types = [
@ -59,13 +57,10 @@ class SampleGeneratorFace(SampleGeneratorBase):
ct_samples = None ct_samples = None
ct_index_host = None ct_index_host = None
pickled_samples = pickle.dumps(samples, 4)
ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None
if self.debug: if self.debug:
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )] self.generators = [ThisThreadGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
else: else:
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \ self.generators = [SubprocessGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \
for i in range(self.generators_count) ] for i in range(self.generators_count) ]
SubprocessGenerator.start_in_parallel( self.generators ) SubprocessGenerator.start_in_parallel( self.generators )
@ -90,11 +85,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
return next(generator) return next(generator)
def batch_func(self, param ): def batch_func(self, param ):
pickled_samples, index_host, ct_pickled_samples, ct_index_host = param samples, index_host, ct_samples, ct_index_host = param
samples = pickle.loads(pickled_samples)
ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None
bs = self.batch_size bs = self.batch_size
while True: while True:
batches = None batches = None

View file

@ -11,7 +11,7 @@ from core import imagelib, mplib, pathex
from core.imagelib import sd from core.imagelib import sd
from core.cv2ex import * from core.cv2ex import *
from core.interact import interact as io from core.interact import interact as io
from core.joblib import SubprocessGenerator, ThisThreadGenerator from core.joblib import Subprocessor, SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType) from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType)
@ -23,28 +23,24 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
super().__init__(debug, batch_size) super().__init__(debug, batch_size)
self.initialized = False self.initialized = False
samples = [] samples = sum([ SampleLoader.load (SampleType.FACE, path) for path in paths ] )
for path in paths: seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples).run()
samples += SampleLoader.load (SampleType.FACE, path)
seg_samples_len = len(seg_sample_idxs)
seg_samples = [ sample for sample in samples if sample.seg_ie_polys.get_pts_count() != 0]
seg_samples_len = len(seg_samples)
if seg_samples_len == 0: if seg_samples_len == 0:
raise Exception(f"No segmented faces found.") raise Exception(f"No segmented faces found.")
else: else:
io.log_info(f"Using {seg_samples_len} segmented samples.") io.log_info(f"Using {seg_samples_len} segmented samples.")
pickled_samples = pickle.dumps(seg_samples, 4)
if self.debug: if self.debug:
self.generators_count = 1 self.generators_count = 1
else: else:
self.generators_count = max(1, generators_count) self.generators_count = max(1, generators_count)
if self.debug: if self.debug:
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, resolution, face_type, data_format) )] self.generators = [ThisThreadGenerator ( self.batch_func, (samples, seg_sample_idxs, resolution, face_type, data_format) )]
else: else:
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, resolution, face_type, data_format), start_now=False ) \ self.generators = [SubprocessGenerator ( self.batch_func, (samples, seg_sample_idxs, resolution, face_type, data_format), start_now=False ) \
for i in range(self.generators_count) ] for i in range(self.generators_count) ]
SubprocessGenerator.start_in_parallel( self.generators ) SubprocessGenerator.start_in_parallel( self.generators )
@ -66,12 +62,9 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
return next(generator) return next(generator)
def batch_func(self, param ): def batch_func(self, param ):
pickled_samples, resolution, face_type, data_format = param samples, seg_sample_idxs, resolution, face_type, data_format = param
samples = pickle.loads(pickled_samples)
shuffle_idxs = [] shuffle_idxs = []
idxs = [*range(len(samples))]
random_flip = True random_flip = True
rotation_range=[-10,10] rotation_range=[-10,10]
@ -91,7 +84,7 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
while n_batch < bs: while n_batch < bs:
try: try:
if len(shuffle_idxs) == 0: if len(shuffle_idxs) == 0:
shuffle_idxs = idxs.copy() shuffle_idxs = seg_sample_idxs.copy()
np.random.shuffle(shuffle_idxs) np.random.shuffle(shuffle_idxs)
idx = shuffle_idxs.pop() idx = shuffle_idxs.pop()
@ -146,3 +139,56 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
io.log_err ( traceback.format_exc() ) io.log_err ( traceback.format_exc() )
yield [ np.array(batch) for batch in batches] yield [ np.array(batch) for batch in batches]
class SegmentedSampleFilterSubprocessor(Subprocessor):
    """
    Gathers, in parallel, the indexes of samples whose seg_ie_polys contain
    at least one point, i.e. samples that carry segmentation data.
    """
    #override
    def __init__(self, samples ):
        self.samples = samples
        self.samples_len = len(self.samples)
        self.idxs = [*range(self.samples_len)]
        self.result = []
        super().__init__('SegmentedSampleFilterSubprocessor', SegmentedSampleFilterSubprocessor.Cli, 60)

    #override
    def process_info_generator(self):
        # one worker per logical CPU; each receives the full sample list
        for i in range(multiprocessing.cpu_count()):
            yield 'CPU%d' % (i), {}, {'samples':self.samples}

    #override
    def on_clients_initialized(self):
        io.progress_bar ("Filtering", self.samples_len)

    #override
    def on_clients_finalized(self):
        io.progress_bar_close()

    #override
    def get_data(self, host_dict):
        # distribute remaining sample indexes to workers
        return self.idxs.pop(0) if self.idxs else None

    #override
    def on_data_return (self, host_dict, data):
        # worker failed: requeue the index at the front for another attempt
        self.idxs.insert(0, data)

    #override
    def on_result (self, host_dict, data, result):
        idx, is_ok = result
        if is_ok:
            self.result.append(idx)
        io.progress_bar_inc(1)

    def get_result(self):
        # indexes of segmented samples, in completion order
        return self.result

    class Cli(Subprocessor.Cli):
        #overridable optional
        def on_initialize(self, client_dict):
            self.samples = client_dict['samples']

        def process_data(self, idx):
            # a sample counts as segmented when its polys contain any points
            return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0

View file

@ -6,6 +6,7 @@ from pathlib import Path
import samplelib.PackedFaceset import samplelib.PackedFaceset
from core import pathex from core import pathex
from core.mplib import MPSharedList
from core.interact import interact as io from core.interact import interact as io
from core.joblib import Subprocessor from core.joblib import Subprocessor
from DFLIMG import * from DFLIMG import *
@ -33,6 +34,9 @@ class SampleLoader:
@staticmethod @staticmethod
def load(sample_type, samples_path, subdirs=False): def load(sample_type, samples_path, subdirs=False):
"""
Return MPSharedList of samples
"""
samples_cache = SampleLoader.samples_cache samples_cache = SampleLoader.samples_cache
if str(samples_path) not in samples_cache.keys(): if str(samples_path) not in samples_cache.keys():
@ -56,12 +60,12 @@ class SampleLoader:
if result is None: if result is None:
result = SampleLoader.load_face_samples( pathex.get_image_paths(samples_path, subdirs=subdirs) ) result = SampleLoader.load_face_samples( pathex.get_image_paths(samples_path, subdirs=subdirs) )
samples[sample_type] = result
samples[sample_type] = MPSharedList(result)
elif sample_type == SampleType.FACE_TEMPORAL_SORTED: elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
result = SampleLoader.load (SampleType.FACE, samples_path) result = SampleLoader.load (SampleType.FACE, samples_path)
result = SampleLoader.upgradeToFaceTemporalSortedSamples(result) result = SampleLoader.upgradeToFaceTemporalSortedSamples(result)
samples[sample_type] = result samples[sample_type] = MPSharedList(result)
return samples[sample_type] return samples[sample_type]