mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
Decreased amount of RAM used by Sample Generator.
This commit is contained in:
commit 33b0aadb4e
parent 8e9e346c9d
5 changed files with 221 additions and 32 deletions
core/mplib/MPSharedList.py (new file, 146 additions)

@@ -0,0 +1,146 @@
+import multiprocessing
+import pickle
+import struct
+from core.joblib import Subprocessor
+
+class MPSharedList():
+    """
+    Provides read-only pickled list of constant objects via shared memory aka 'multiprocessing.Array'
+    Thus no 4GB limit for subprocesses.
+
+    supports list concat via + or sum()
+    """
+
+    def __init__(self, obj_list):
+        if obj_list is None:
+            self.obj_counts = None
+            self.table_offsets = None
+            self.data_offsets = None
+            self.sh_bs = None
+        else:
+            obj_count, table_offset, data_offset, sh_b = MPSharedList.bake_data(obj_list)
+
+            self.obj_counts = [obj_count]
+            self.table_offsets = [table_offset]
+            self.data_offsets = [data_offset]
+            self.sh_bs = [sh_b]
+
+    def __add__(self, o):
+        if isinstance(o, MPSharedList):
+            m = MPSharedList(None)
+            m.obj_counts = self.obj_counts + o.obj_counts
+            m.table_offsets = self.table_offsets + o.table_offsets
+            m.data_offsets = self.data_offsets + o.data_offsets
+            m.sh_bs = self.sh_bs + o.sh_bs
+            return m
+        elif isinstance(o, int):
+            return self
+        else:
+            raise ValueError(f"MPSharedList object of class {o.__class__} is not supported for __add__ operator.")
+
+    def __radd__(self, o):
+        return self+o
+
+    def __len__(self):
+        return sum(self.obj_counts)
+
+    def __getitem__(self, key):
+        obj_count = sum(self.obj_counts)
+        if key < 0:
+            key = obj_count+key
+        if key < 0 or key >= obj_count:
+            raise ValueError("out of range")
+
+        for i in range(len(self.obj_counts)):
+            if key < self.obj_counts[i]:
+                table_offset = self.table_offsets[i]
+                data_offset = self.data_offsets[i]
+                sh_b = self.sh_bs[i]
+                break
+            key -= self.obj_counts[i]
+
+        offset_start, offset_end = struct.unpack('<QQ', bytes(sh_b[ table_offset + key*8 : table_offset + (key+2)*8]) )
+
+        return pickle.loads( bytes(sh_b[ data_offset + offset_start : data_offset + offset_end ]) )
+
+    def __iter__(self):
+        for i in range(self.__len__()):
+            yield self.__getitem__(i)
+
+    @staticmethod
+    def bake_data(obj_list):
+        if not isinstance(obj_list, list):
+            raise ValueError("MPSharedList: obj_list should be list type.")
+
+        obj_count = len(obj_list)
+
+        if obj_count != 0:
+            obj_pickled_ar = [pickle.dumps(o, 4) for o in obj_list]
+
+            table_offset = 0
+            table_size = (obj_count+1)*8
+            data_offset = table_offset + table_size
+            data_size = sum([len(x) for x in obj_pickled_ar])
+
+            sh_b = multiprocessing.RawArray('B', table_size + data_size)
+            sh_b[0:8] = struct.pack('<Q', obj_count)
+
+            offset = 0
+
+            sh_b_table = bytes()
+            offsets = []
+
+            offset = 0
+            for i in range(obj_count):
+                offsets.append(offset)
+                offset += len(obj_pickled_ar[i])
+            offsets.append(offset)
+
+            sh_b[table_offset:table_offset+table_size] = struct.pack( '<'+'Q'*len(offsets), *offsets )
+
+            ArrayFillerSubprocessor(sh_b, [ (data_offset+offsets[i], obj_pickled_ar[i] ) for i in range(obj_count) ] ).run()
+
+            return obj_count, table_offset, data_offset, sh_b
+
+
+class ArrayFillerSubprocessor(Subprocessor):
+    """
+    Much faster to fill shared memory via subprocesses rather than direct whole bytes fill.
+    """
+    #override
+    def __init__(self, sh_b, data_list ):
+        self.sh_b = sh_b
+        self.data_list = data_list
+        super().__init__('ArrayFillerSubprocessor', ArrayFillerSubprocessor.Cli, 60)

+    #override
+    def process_info_generator(self):
+        for i in range(multiprocessing.cpu_count()):
+            yield 'CPU%d' % (i), {}, {'sh_b':self.sh_b}
+
+    #override
+    def get_data(self, host_dict):
+        if len(self.data_list) > 0:
+            return self.data_list.pop(0)
+
+        return None
+
+    #override
+    def on_data_return (self, host_dict, data):
+        self.data_list.insert(0, data)
+
+    #override
+    def on_result (self, host_dict, data, result):
+        pass
+
+    class Cli(Subprocessor.Cli):
+        #overridable optional
+        def on_initialize(self, client_dict):
+            self.sh_b = client_dict['sh_b']
+
+        def process_data(self, data):
+            offset, b = data
+            self.sh_b[offset:offset+len(b)]=b
+            return 0
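Note: the buffer that bake_data produces is a single contiguous RawArray holding an (N+1)-entry little-endian uint64 offset table followed by the concatenated pickles, so __getitem__ only has to unpack two adjacent table entries and unpickle one slice. Below is a minimal standalone sketch of that read/write arithmetic, assuming only the CPython stdlib; demo_bake and demo_get are hypothetical names, and the parallel ArrayFillerSubprocessor is replaced by a direct fill.

# Minimal sketch of the MPSharedList buffer layout (hypothetical helpers,
# not part of DeepFaceLab): an (N+1)-entry little-endian uint64 offset
# table, then the pickled objects back to back.
import multiprocessing
import pickle
import struct

def demo_bake(obj_list):
    pickled = [pickle.dumps(o, 4) for o in obj_list]
    table_size = (len(obj_list) + 1) * 8                 # N+1 offsets
    data_offset = table_size
    sh_b = multiprocessing.RawArray('B', table_size + sum(len(b) for b in pickled))

    offsets, off = [], 0
    for b in pickled:
        offsets.append(off)
        off += len(b)
    offsets.append(off)                                  # final end marker
    sh_b[0:table_size] = struct.pack('<' + 'Q'*len(offsets), *offsets)

    off = 0
    for b in pickled:                                    # direct fill instead of
        sh_b[data_offset+off : data_offset+off+len(b)] = b   # ArrayFillerSubprocessor
        off += len(b)
    return sh_b, data_offset

def demo_get(sh_b, data_offset, key):
    # Two adjacent table entries bound one pickled object, as in __getitem__.
    start, end = struct.unpack('<QQ', bytes(sh_b[key*8 : (key+2)*8]))
    return pickle.loads(bytes(sh_b[data_offset+start : data_offset+end]))

if __name__ == '__main__':
    sh_b, data_offset = demo_bake([{'a': 1}, 'two', [3, 3, 3]])
    print(demo_get(sh_b, data_offset, 1))                # 'two'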
core/mplib/__init__.py

@@ -1,3 +1,4 @@
+from .MPSharedList import MPSharedList
 import multiprocessing
 import threading
 import time
samplelib/SampleGeneratorFace.py

@@ -1,5 +1,4 @@
 import multiprocessing
-import pickle
 import time
 import traceback
 
@@ -12,7 +11,6 @@ from facelib import LandmarksProcessor
 from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
                        SampleType)
-
 
 '''
 arg
 output_sample_types = [
@@ -59,13 +57,10 @@ class SampleGeneratorFace(SampleGeneratorBase):
             ct_samples = None
             ct_index_host = None
 
-        pickled_samples = pickle.dumps(samples, 4)
-        ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None
-
         if self.debug:
-            self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
+            self.generators = [ThisThreadGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
         else:
-            self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \
+            self.generators = [SubprocessGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \
                                for i in range(self.generators_count) ]
 
             SubprocessGenerator.start_in_parallel( self.generators )
@@ -90,11 +85,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
         return next(generator)
 
     def batch_func(self, param ):
-        pickled_samples, index_host, ct_pickled_samples, ct_index_host = param
+        samples, index_host, ct_samples, ct_index_host = param
 
-        samples = pickle.loads(pickled_samples)
-        ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None
-
         bs = self.batch_size
         while True:
             batches = None
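Note: the RAM saving here comes from what crosses the process boundary. Each SubprocessGenerator worker previously received and unpickled its own copy of the full pickled sample list; an MPSharedList instead hands workers a reference to one shared buffer. A hedged standalone sketch of that behavior (names are illustrative, not DeepFaceLab API):

# Sketch: a multiprocessing.RawArray is shared with child processes, so the
# parent's buffer is read in place instead of being duplicated per worker.
import multiprocessing
import struct

def worker(sh_b):
    (value,) = struct.unpack('<Q', bytes(sh_b[0:8]))     # reads the shared buffer
    print('child sees', value)

if __name__ == '__main__':
    sh_b = multiprocessing.RawArray('B', 8)
    sh_b[0:8] = struct.pack('<Q', 12345)
    p = multiprocessing.Process(target=worker, args=(sh_b,))
    p.start()
    p.join()                                             # prints: child sees 12345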
samplelib/SampleGeneratorFaceXSeg.py

@@ -11,7 +11,7 @@ from core import imagelib, mplib, pathex
 from core.imagelib import sd
 from core.cv2ex import *
 from core.interact import interact as io
-from core.joblib import SubprocessGenerator, ThisThreadGenerator
+from core.joblib import Subprocessor, SubprocessGenerator, ThisThreadGenerator
 from facelib import LandmarksProcessor
 from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType)
 
@@ -23,28 +23,24 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
         super().__init__(debug, batch_size)
         self.initialized = False
 
-        samples = []
-        for path in paths:
-            samples += SampleLoader.load (SampleType.FACE, path)
-
-        seg_samples = [ sample for sample in samples if sample.seg_ie_polys.get_pts_count() != 0]
-        seg_samples_len = len(seg_samples)
+        samples = sum([ SampleLoader.load (SampleType.FACE, path) for path in paths ] )
+        seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples).run()
+
+        seg_samples_len = len(seg_sample_idxs)
+
         if seg_samples_len == 0:
             raise Exception(f"No segmented faces found.")
         else:
             io.log_info(f"Using {seg_samples_len} segmented samples.")
 
-        pickled_samples = pickle.dumps(seg_samples, 4)
-
         if self.debug:
             self.generators_count = 1
         else:
             self.generators_count = max(1, generators_count)
 
         if self.debug:
-            self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, resolution, face_type, data_format) )]
+            self.generators = [ThisThreadGenerator ( self.batch_func, (samples, seg_sample_idxs, resolution, face_type, data_format) )]
         else:
-            self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, resolution, face_type, data_format), start_now=False ) \
+            self.generators = [SubprocessGenerator ( self.batch_func, (samples, seg_sample_idxs, resolution, face_type, data_format), start_now=False ) \
                                for i in range(self.generators_count) ]
 
         SubprocessGenerator.start_in_parallel( self.generators )
@@ -66,12 +62,9 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
         return next(generator)
 
     def batch_func(self, param ):
-        pickled_samples, resolution, face_type, data_format = param
+        samples, seg_sample_idxs, resolution, face_type, data_format = param
 
-        samples = pickle.loads(pickled_samples)
-
         shuffle_idxs = []
-        idxs = [*range(len(samples))]
 
         random_flip = True
         rotation_range=[-10,10]
@@ -91,7 +84,7 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
             while n_batch < bs:
                 try:
                     if len(shuffle_idxs) == 0:
-                        shuffle_idxs = idxs.copy()
+                        shuffle_idxs = seg_sample_idxs.copy()
                         np.random.shuffle(shuffle_idxs)
                     idx = shuffle_idxs.pop()
 
@@ -146,3 +139,56 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
                 io.log_err ( traceback.format_exc() )
 
         yield [ np.array(batch) for batch in batches]
+
+
+class SegmentedSampleFilterSubprocessor(Subprocessor):
+    #override
+    def __init__(self, samples ):
+        self.samples = samples
+        self.samples_len = len(self.samples)
+
+        self.idxs = [*range(self.samples_len)]
+        self.result = []
+        super().__init__('SegmentedSampleFilterSubprocessor', SegmentedSampleFilterSubprocessor.Cli, 60)
+
+    #override
+    def process_info_generator(self):
+        for i in range(multiprocessing.cpu_count()):
+            yield 'CPU%d' % (i), {}, {'samples':self.samples}
+
+    #override
+    def on_clients_initialized(self):
+        io.progress_bar ("Filtering", self.samples_len)
+
+    #override
+    def on_clients_finalized(self):
+        io.progress_bar_close()
+
+    #override
+    def get_data(self, host_dict):
+        if len (self.idxs) > 0:
+            return self.idxs.pop(0)
+
+        return None
+
+    #override
+    def on_data_return (self, host_dict, data):
+        self.idxs.insert(0, data)
+
+    #override
+    def on_result (self, host_dict, data, result):
+        idx, is_ok = result
+        if is_ok:
+            self.result.append(idx)
+        io.progress_bar_inc(1)
+
+    def get_result(self):
+        return self.result
+
+    class Cli(Subprocessor.Cli):
+        #overridable optional
+        def on_initialize(self, client_dict):
+            self.samples = client_dict['samples']
+
+        def process_data(self, idx):
+            return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0
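Note: SegmentedSampleFilterSubprocessor returns indices rather than a filtered copy of the samples, so the segmented subset never leaves the shared buffer; batch_func then shuffles and pops those indices. A tiny illustration of the pattern (has_polys is a hypothetical stand-in for seg_ie_polys.get_pts_count() != 0):

# Illustrative only: filter by index so no second list of samples is built.
samples = ['plain', 'segmented_a', 'plain', 'segmented_b']   # stands in for an MPSharedList

def has_polys(s):                      # stand-in predicate
    return s.startswith('segmented')

seg_sample_idxs = [i for i in range(len(samples)) if has_polys(samples[i])]
shuffle_idxs = seg_sample_idxs.copy()  # batches draw indices; samples stay shared
print(shuffle_idxs)                    # [1, 3]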
samplelib/SampleLoader.py

@@ -6,6 +6,7 @@ from pathlib import Path
 
 import samplelib.PackedFaceset
 from core import pathex
+from core.mplib import MPSharedList
 from core.interact import interact as io
 from core.joblib import Subprocessor
 from DFLIMG import *
 
@@ -33,6 +34,9 @@ class SampleLoader:
 
     @staticmethod
    def load(sample_type, samples_path, subdirs=False):
+        """
+        Return MPSharedList of samples
+        """
         samples_cache = SampleLoader.samples_cache
 
         if str(samples_path) not in samples_cache.keys():
 
@@ -56,12 +60,12 @@ class SampleLoader:
 
             if result is None:
                 result = SampleLoader.load_face_samples( pathex.get_image_paths(samples_path, subdirs=subdirs) )
-                samples[sample_type] = result
+
+            samples[sample_type] = MPSharedList(result)
 
         elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
             result = SampleLoader.load (SampleType.FACE, samples_path)
             result = SampleLoader.upgradeToFaceTemporalSortedSamples(result)
-            samples[sample_type] = result
+            samples[sample_type] = MPSharedList(result)
 
         return samples[sample_type]
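Note: since SampleLoader.load now returns an MPSharedList, SampleGeneratorFaceXSeg can combine per-path results with sum([...]): Python's sum starts from the int 0, and MPSharedList.__radd__/__add__ treat an int operand as identity, so concatenation merges only the per-segment bookkeeping lists without touching the shared buffers. A standalone miniature of that trick (MiniSharedList is hypothetical, modeled on the class above):

# Miniature of MPSharedList's sum()-concat: __radd__ absorbs sum()'s initial 0,
# and __add__ merges bookkeeping lists instead of copying any data.
class MiniSharedList:
    def __init__(self, counts):
        self.obj_counts = list(counts)        # one entry per baked segment

    def __add__(self, o):
        if isinstance(o, MiniSharedList):
            return MiniSharedList(self.obj_counts + o.obj_counts)
        elif isinstance(o, int):              # sum() starts with 0
            return self
        raise ValueError(f"unsupported operand {o.__class__}")

    def __radd__(self, o):
        return self + o

    def __len__(self):
        return sum(self.obj_counts)

print(len(sum([MiniSharedList([2]), MiniSharedList([3])])))   # 5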