all models: removed options 'src_scale_mod', and 'sort samples by yaw as target'

If you want, you can manually remove unnecessary angles from src faceset after sort by yaw.

Optimized sample generators (CPU workers). Now they consume less amount of RAM and work faster.

added
4.2.other) data_src/dst util faceset pack.bat
	Packs /aligned/ samples into one /aligned/samples.pak file.
	After that, all faces will be deleted.

4.2.other) data_src/dst util faceset unpack.bat
	unpacks faces from /aligned/samples.pak to /aligned/ dir.
	After that, samples.pak will be deleted.

Packed faceset load and work faster.
This commit is contained in:
Colombo 2019-12-21 23:16:55 +04:00
parent 8866dce22e
commit 50f892d57d
26 changed files with 577 additions and 433 deletions

10
main.py
View file

@ -140,6 +140,14 @@ if __name__ == "__main__":
if arguments.restore_faceset_metadata:
Util.restore_faceset_metadata_folder (input_path=arguments.input_dir)
if arguments.pack_faceset:
from samplelib import PackedFaceset
PackedFaceset.pack( Path(arguments.input_dir) )
if arguments.unpack_faceset:
from samplelib import PackedFaceset
PackedFaceset.unpack( Path(arguments.input_dir) )
p = subparsers.add_parser( "util", help="Utilities.")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
p.add_argument('--convert-png-to-jpg', action="store_true", dest="convert_png_to_jpg", default=False, help="Convert DeepFaceLAB PNG files to JPEG.")
@ -149,6 +157,8 @@ if __name__ == "__main__":
p.add_argument('--remove-ie-polys', action="store_true", dest="remove_ie_polys", default=False, help="Remove ie_polys from aligned faces.")
p.add_argument('--save-faceset-metadata', action="store_true", dest="save_faceset_metadata", default=False, help="Save faceset metadata to file.")
p.add_argument('--restore-faceset-metadata', action="store_true", dest="restore_faceset_metadata", default=False, help="Restore faceset metadata to file. Image filenames must be the same as used with save.")
p.add_argument('--pack-faceset', action="store_true", dest="pack_faceset", default=False, help="")
p.add_argument('--unpack-faceset', action="store_true", dest="unpack_faceset", default=False, help="")
p.set_defaults (func=process_util)

View file

@ -28,12 +28,10 @@ class ModelBase(object):
ask_write_preview_history=True,
ask_target_iter=True,
ask_batch_size=True,
ask_sort_by_yaw=True,
ask_random_flip=True,
ask_src_scale_mod=True):
ask_random_flip=True, **kwargs):
device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
device_args['cpu_only'] = device_args.get('cpu_only',False)
device_args['cpu_only'] = True if debug else device_args.get('cpu_only',False)
if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
@ -115,13 +113,6 @@ class ModelBase(object):
else:
self.batch_size = self.options.get('batch_size', 0)
if ask_sort_by_yaw:
if (self.iter == 0 or ask_override):
default_sort_by_yaw = self.options.get('sort_by_yaw', False)
self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s) : " % (yn_str[default_sort_by_yaw]), default_sort_by_yaw, help_message="NN will not learn src face directions that don't match dst face directions. Do not use if the dst face has hair that covers the jaw." )
else:
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
if ask_random_flip:
default_random_flip = self.options.get('random_flip', True)
if (self.iter == 0 or ask_override):
@ -129,12 +120,6 @@ class ModelBase(object):
else:
self.options['random_flip'] = self.options.get('random_flip', default_random_flip)
if ask_src_scale_mod:
if (self.iter == 0):
self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
else:
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
self.autobackup = self.options.get('autobackup', False)
if not self.autobackup and 'autobackup' in self.options:
self.options.pop('autobackup')
@ -151,10 +136,6 @@ class ModelBase(object):
self.sort_by_yaw = self.options.get('sort_by_yaw',False)
self.random_flip = self.options.get('random_flip',True)
self.src_scale_mod = self.options.get('src_scale_mod',0)
if self.src_scale_mod == 0 and 'src_scale_mod' in self.options:
self.options.pop('src_scale_mod')
self.onInitializeOptions(self.iter == 0, ask_override)
nnlib.import_all(self.device_config)

View file

@ -16,9 +16,7 @@ class AVATARModel(ModelBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs,
ask_sort_by_yaw=False,
ask_random_flip=False,
ask_src_scale_mod=False)
ask_random_flip=False)
#override
def onInitializeOptions(self, is_first_run, ask_override):

View file

@ -13,9 +13,7 @@ class Model(ModelBase):
ask_enable_autobackup=False,
ask_write_preview_history=False,
ask_target_iter=False,
ask_sort_by_yaw=False,
ask_random_flip=False,
ask_src_scale_mod=False)
ask_random_flip=False)
#override
def onInitializeOptions(self, is_first_run, ask_override):

View file

@ -16,9 +16,7 @@ class FUNITModel(ModelBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs,
ask_sort_by_yaw=False,
ask_random_flip=False,
ask_src_scale_mod=False)
ask_random_flip=False)
#override
def onInitializeOptions(self, is_first_run, ask_override):
@ -87,19 +85,19 @@ class FUNITModel(ModelBase):
self.set_training_data_generators ([
SampleGeneratorFacePerson(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=True, rotation_range=[0,0] ),
output_sample_types=output_sample_types, person_id_mode=1, use_caching=True, generators_count=1 ),
output_sample_types=output_sample_types, person_id_mode=1, ),
SampleGeneratorFacePerson(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=True, rotation_range=[0,0] ),
output_sample_types=output_sample_types, person_id_mode=1, use_caching=True, generators_count=1 ),
output_sample_types=output_sample_types, person_id_mode=1, ),
SampleGeneratorFacePerson(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=True, rotation_range=[0,0]),
output_sample_types=output_sample_types1, person_id_mode=1, use_caching=True, generators_count=1 ),
output_sample_types=output_sample_types1, person_id_mode=1, ),
SampleGeneratorFacePerson(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=True, rotation_range=[0,0]),
output_sample_types=output_sample_types1, person_id_mode=1, use_caching=True, generators_count=1 ),
output_sample_types=output_sample_types1, person_id_mode=1, ),
])
#override

View file

@ -15,9 +15,7 @@ class Model(ModelBase):
ask_enable_autobackup=False,
ask_write_preview_history=False,
ask_target_iter=False,
ask_sort_by_yaw=False,
ask_random_flip=False,
ask_src_scale_mod=False)
ask_random_flip=False)
#override
def onInitializeOptions(self, is_first_run, ask_override):

View file

@ -50,9 +50,8 @@ class Model(ModelBase):
{ 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_M), 'resolution':128} ]
self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) ),
output_sample_types=output_sample_types),
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,

View file

@ -60,9 +60,8 @@ class Model(ModelBase):
{ 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_HALF, t.MODE_M), 'resolution':128} ]
self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) ),
output_sample_types=output_sample_types ),
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,

View file

@ -61,9 +61,8 @@ class Model(ModelBase):
{ 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_HALF, t.MODE_M), 'resolution':64} ]
self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) ),
output_sample_types=output_sample_types),
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,

View file

@ -55,9 +55,8 @@ class Model(ModelBase):
{ 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_M), 'resolution':128} ]
self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) ),
output_sample_types=output_sample_types),
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,

View file

@ -18,14 +18,12 @@ class Quick96Model(ModelBase):
ask_write_preview_history=False,
ask_target_iter=False,
ask_batch_size=False,
ask_sort_by_yaw=False,
ask_random_flip=False,
ask_src_scale_mod=False)
ask_random_flip=False)
#override
def onInitialize(self):
exec(nnlib.import_all(), locals(), globals())
self.set_vram_batch_requirements({1.5:2,2:4})#,3:4,4:8})
self.set_vram_batch_requirements({1.5:2,2:4})
resolution = self.resolution = 96
@ -171,7 +169,7 @@ class Quick96Model(ModelBase):
self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=False, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
sample_process_options=SampleProcessor.Options(random_flip=False, scale_range=np.array([-0.05, 0.05]) ),
output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'resolution': resolution, 'normalize_tanh':True },
{'types' : (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'resolution': resolution, 'normalize_tanh':True },
{'types' : (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_M), 'resolution': resolution } ]

View file

@ -466,18 +466,15 @@ class SAEModel(ModelBase):
training_data_src_path = self.training_data_src_path
training_data_dst_path = self.training_data_dst_path
sort_by_yaw = self.sort_by_yaw
if self.pretrain and self.pretraining_data_path is not None:
training_data_src_path = self.pretraining_data_path
training_data_dst_path = self.pretraining_data_path
sort_by_yaw = False
self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' else None,
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) ),
output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution, 'ct_mode': self.options['ct_mode'] },
{'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution } ]

View file

@ -529,26 +529,23 @@ class SAEHDModel(ModelBase):
training_data_src_path = self.training_data_src_path
training_data_dst_path = self.training_data_dst_path
sort_by_yaw = self.sort_by_yaw
if self.pretrain and self.pretraining_data_path is not None:
training_data_src_path = self.pretraining_data_path
training_data_dst_path = self.pretraining_data_path
sort_by_yaw = False
t_img_warped = t.IMG_WARPED_TRANSFORMED if self.options['random_warp'] else t.IMG_TRANSFORMED
self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' else None,
debug=self.is_debug(), batch_size=self.batch_size, use_caching=False,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) ),
output_sample_types = [ {'types' : (t_img_warped, face_type, t_mode_bgr), 'resolution':resolution, 'ct_mode': self.options['ct_mode'] },
{'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution, 'ct_mode': self.options['ct_mode'] },
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution } ]
),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, use_caching=False,
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
output_sample_types = [ {'types' : (t_img_warped, face_type, t_mode_bgr), 'resolution':resolution},
{'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution},

107
samplelib/PackedFaceset.py Normal file
View file

@ -0,0 +1,107 @@
import pickle
import struct
from pathlib import Path
from interact import interact as io
from utils import Path_utils
import samplelib.SampleHost
packed_faceset_filename = 'samples.pak'
class PackedFaceset():
VERSION = 1
@staticmethod
def pack(samples_path):
samples_dat_path = samples_path / packed_faceset_filename
if samples_dat_path.exists():
io.log_info(f"{samples_dat_path} : file already exists !")
io.input_bool("Press enter to continue and overwrite.", False)
of = open(samples_dat_path, "wb")
image_paths = Path_utils.get_image_paths(samples_path)
samples = samplelib.SampleHost.load_face_samples(image_paths)
for sample in samples:
sample.filename = str(Path(sample.filename).relative_to(samples_path))
samples_len = len(samples)
samples_bytes = pickle.dumps(samples, 4)
of.write ( struct.pack ("Q", PackedFaceset.VERSION ) )
of.write ( struct.pack ("Q", len(samples_bytes) ) )
of.write ( samples_bytes )
sample_data_table_offset = of.tell()
of.write ( bytes( 8*(samples_len+1) ) ) #sample data offset table
data_start_offset = of.tell()
offsets = []
for sample in io.progress_bar_generator(samples, "Packing"):
try:
with open( samples_path / sample.filename, "rb") as f:
b = f.read()
offsets.append ( of.tell() - data_start_offset )
of.write(b)
except:
raise Exception(f"error while processing sample {sample.filename}")
offsets.append ( of.tell() )
of.seek(sample_data_table_offset, 0)
for offset in offsets:
of.write ( struct.pack("Q", offset) )
of.seek(0,2)
of.close()
for filename in io.progress_bar_generator(image_paths,"Deleting"):
Path(filename).unlink()
@staticmethod
def unpack(samples_path):
samples_dat_path = samples_path / packed_faceset_filename
if not samples_dat_path.exists():
io.log_info(f"{samples_dat_path} : file not found.")
return
samples = PackedFaceset.load(samples_path)
for sample in io.progress_bar_generator(samples, "Unpacking"):
with open(samples_path / sample.filename, "wb") as f:
f.write( sample.read_raw_file() )
samples_dat_path.unlink()
@staticmethod
def load(samples_path):
samples_dat_path = samples_path / packed_faceset_filename
if not samples_dat_path.exists():
return None
f = open(samples_dat_path, "rb")
version, = struct.unpack("Q", f.read(8) )
if version != PackedFaceset.VERSION:
raise NotImplementedError
sizeof_samples_bytes, = struct.unpack("Q", f.read(8) )
samples = pickle.loads ( f.read(sizeof_samples_bytes) )
offsets = [ struct.unpack("Q", f.read(8) )[0] for _ in range(len(samples)+1) ]
data_start_offset = f.tell()
f.close()
for i, sample in enumerate(samples):
start_offset, end_offset = offsets[i], offsets[i+1]
sample.set_filename_offset_size( str(samples_dat_path), data_start_offset+start_offset, end_offset-start_offset )
return samples

View file

@ -14,15 +14,28 @@ class SampleType(IntEnum):
FACE_BEGIN = 1
FACE = 1 #aligned face unsorted
FACE_YAW_SORTED = 2 #sorted by yaw
FACE_YAW_SORTED_AS_TARGET = 3 #sorted by yaw and included only yaws which exist in TARGET also automatic mirrored
FACE_TEMPORAL_SORTED = 4
FACE_END = 4
FACE_TEMPORAL_SORTED = 2 #sorted by source filename
FACE_END = 2
QTY = 5
class Sample(object):
def __init__(self, sample_type=None, filename=None, person_id=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch_yaw_roll=None, eyebrows_expand_mod=None, source_filename=None, mirror=None, close_target_list=None, fanseg_mask_exist=False):
__slots__ = ['sample_type',
'filename',
'person_id',
'face_type',
'shape',
'landmarks',
'ie_polys',
'pitch_yaw_roll',
'eyebrows_expand_mod',
'source_filename',
'mirror',
'fanseg_mask_exist',
'_filename_offset_size',
]
def __init__(self, sample_type=None, filename=None, person_id=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch_yaw_roll=None, eyebrows_expand_mod=None, source_filename=None, mirror=None, fanseg_mask_exist=False):
self.sample_type = sample_type if sample_type is not None else SampleType.IMAGE
self.filename = filename
self.person_id = person_id
@ -34,10 +47,14 @@ class Sample(object):
self.eyebrows_expand_mod = eyebrows_expand_mod
self.source_filename = source_filename
self.mirror = mirror
self.close_target_list = close_target_list
self.fanseg_mask_exist = fanseg_mask_exist
def copy_and_set(self, sample_type=None, filename=None, person_id=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch_yaw_roll=None, eyebrows_expand_mod=None, source_filename=None, mirror=None, close_target_list=None, fanseg_mask=None, fanseg_mask_exist=None):
self._filename_offset_size = None
def set_filename_offset_size(self, filename, offset, size):
self._filename_offset_size = (filename, offset, size)
def copy_and_set(self, sample_type=None, filename=None, person_id=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch_yaw_roll=None, eyebrows_expand_mod=None, source_filename=None, mirror=None, fanseg_mask=None, fanseg_mask_exist=None):
return Sample(
sample_type=sample_type if sample_type is not None else self.sample_type,
filename=filename if filename is not None else self.filename,
@ -50,11 +67,20 @@ class Sample(object):
eyebrows_expand_mod=eyebrows_expand_mod if eyebrows_expand_mod is not None else self.eyebrows_expand_mod,
source_filename=source_filename if source_filename is not None else self.source_filename,
mirror=mirror if mirror is not None else self.mirror,
close_target_list=close_target_list if close_target_list is not None else self.close_target_list,
fanseg_mask_exist=fanseg_mask_exist if fanseg_mask_exist is not None else self.fanseg_mask_exist)
def read_raw_file(self, filename=None):
if self._filename_offset_size is not None:
filename, offset, size = self._filename_offset_size
with open(filename, "rb") as f:
f.seek( offset, 0)
return f.read (size)
else:
with open(filename, "rb") as f:
return f.read()
def load_bgr(self):
img = cv2_imread (self.filename).astype(np.float32) / 255.0
img = cv2_imread (self.filename, loader_func=self.read_raw_file).astype(np.float32) / 255.0
if self.mirror:
img = img[:,::-1].copy()
return img
@ -63,16 +89,12 @@ class Sample(object):
if self.fanseg_mask_exist:
filepath = Path(self.filename)
if filepath.suffix == '.png':
dflimg = DFLPNG.load ( str(filepath) )
dflimg = DFLPNG.load ( str(filepath), loader_func=self.read_raw_file )
elif filepath.suffix == '.jpg':
dflimg = DFLJPG.load ( str(filepath) )
dflimg = DFLJPG.load ( str(filepath), loader_func=self.read_raw_file )
else:
dflimg = None
return dflimg.get_fanseg_mask()
return None
def get_random_close_target_sample(self):
if self.close_target_list is None:
return None
return self.close_target_list[randint (0, len(self.close_target_list)-1)]

View file

@ -5,10 +5,10 @@ import cv2
import numpy as np
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
SampleType)
from utils import iter_utils
from utils import mp_utils
'''
arg
@ -19,15 +19,10 @@ output_sample_types = [
'''
class SampleGeneratorFace(SampleGeneratorBase):
def __init__ (self, samples_path, debug=False, batch_size=1,
sort_by_yaw=False,
sort_by_yaw_target_samples_path=None,
random_ct_samples_path=None,
sample_process_options=SampleProcessor.Options(),
output_sample_types=[],
add_sample_idx=False,
use_caching=False,
generators_count=2,
generators_random_seed=None,
**kwargs):
super().__init__(samples_path, debug, batch_size)
@ -35,33 +30,27 @@ class SampleGeneratorFace(SampleGeneratorBase):
self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx
if sort_by_yaw_target_samples_path is not None:
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
elif sort_by_yaw:
self.sample_type = SampleType.FACE_YAW_SORTED
else:
self.sample_type = SampleType.FACE
if generators_random_seed is not None and len(generators_random_seed) != generators_count:
raise ValueError("len(generators_random_seed) != generators_count")
self.generators_random_seed = generators_random_seed
samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path, use_caching=use_caching)
np.random.shuffle(samples)
self.samples_len = len(samples)
samples_host = SampleHost.mp_host (SampleType.FACE, self.samples_path)
self.samples_len = len(samples_host)
if self.samples_len == 0:
raise ValueError('No training data provided.')
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path, use_caching=use_caching) if random_ct_samples_path is not None else None
index_host = mp_utils.IndexHost(self.samples_len)
if random_ct_samples_path is not None:
ct_samples_host = SampleHost.mp_host (SampleType.FACE, random_ct_samples_path)
ct_index_host = mp_utils.IndexHost( len(ct_samples_host) )
else:
ct_samples_host = None
ct_index_host = None
if self.debug:
self.generators_count = 1
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )]
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index_host.create_cli(), ct_samples_host.create_cli() if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )]
else:
self.generators_count = min ( generators_count, self.samples_len )
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ]
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 4)
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_host.create_cli(), index_host.create_cli(), ct_samples_host.create_cli() if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
self.generator_counter = -1
@ -78,62 +67,19 @@ class SampleGeneratorFace(SampleGeneratorBase):
return next(generator)
def batch_func(self, param ):
generator_id, samples, ct_samples = param
if self.generators_random_seed is not None:
np.random.seed ( self.generators_random_seed[generator_id] )
samples_len = len(samples)
samples_idxs = [*range(samples_len)]
ct_samples_len = len(ct_samples) if ct_samples is not None else 0
if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
if all ( [ samples[idx] == None for idx in samples_idxs] ):
raise ValueError('Not enough training data. Gather more faces!')
if self.sample_type == SampleType.FACE:
shuffle_idxs = []
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
shuffle_idxs = []
shuffle_idxs_2D = [[]]*samples_len
samples, index_host, ct_samples, ct_index_host = param
bs = self.batch_size
while True:
batches = None
for n_batch in range(self.batch_size):
while True:
sample = None
if self.sample_type == SampleType.FACE:
if len(shuffle_idxs) == 0:
shuffle_idxs = samples_idxs.copy()
np.random.shuffle(shuffle_idxs)
indexes = index_host.get(bs)
ct_indexes = ct_index_host.get(bs) if ct_samples is not None else None
idx = shuffle_idxs.pop()
sample = samples[ idx ]
for n_batch in range(bs):
sample = samples[ indexes[n_batch] ]
ct_sample = ct_samples[ ct_indexes[n_batch] ] if ct_samples is not None else None
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
if len(shuffle_idxs) == 0:
shuffle_idxs = samples_idxs.copy()
np.random.shuffle(shuffle_idxs)
idx = shuffle_idxs.pop()
if samples[idx] != None:
if len(shuffle_idxs_2D[idx]) == 0:
a = shuffle_idxs_2D[idx] = [ *range(len(samples[idx])) ]
np.random.shuffle (a)
idx2 = shuffle_idxs_2D[idx].pop()
sample = samples[idx][idx2]
idx = (idx << 16) | (idx2 & 0xFFFF)
if sample is not None:
try:
ct_sample=None
if ct_samples is not None:
ct_sample=ct_samples[np.random.randint(ct_samples_len)]
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
except:
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
@ -149,11 +95,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
if self.add_sample_idx:
batches[i_sample_idx].append (idx)
break
yield [ np.array(batch) for batch in batches]
@staticmethod
def get_person_id_max_count(samples_path):
return SampleLoader.get_person_id_max_count(samples_path)
return SampleHost.get_person_id_max_count(samples_path)

View file

@ -6,7 +6,7 @@ import cv2
import numpy as np
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
SampleType)
from utils import iter_utils
@ -37,7 +37,7 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
raise ValueError("len(generators_random_seed) != generators_count")
self.generators_random_seed = generators_random_seed
samples = SampleLoader.load (SampleType.FACE, self.samples_path, person_id_mode=True, use_caching=use_caching)
samples = SampleHost.load (SampleType.FACE, self.samples_path, person_id_mode=True, use_caching=use_caching)
samples = copy.copy(samples)
for i in range(len(samples)):
samples[i] = copy.copy(samples[i])
@ -275,4 +275,4 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
@staticmethod
def get_person_id_max_count(samples_path):
return SampleLoader.get_person_id_max_count(samples_path)
return SampleHost.get_person_id_max_count(samples_path)

View file

@ -4,7 +4,7 @@ import cv2
from utils import iter_utils
from samplelib import SampleType, SampleProcessor, SampleLoader, SampleGeneratorBase
from samplelib import SampleType, SampleProcessor, SampleHost, SampleGeneratorBase
'''
output_sample_types = [
@ -20,7 +20,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.samples = SampleLoader.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
self.samples = SampleHost.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
if self.debug:
self.generators_count = 1

View file

@ -4,7 +4,7 @@ import cv2
from utils import iter_utils
from samplelib import SampleType, SampleProcessor, SampleLoader, SampleGeneratorBase
from samplelib import SampleType, SampleProcessor, SampleHost, SampleGeneratorBase
'''
output_sample_types = [
@ -20,7 +20,7 @@ class SampleGeneratorImageTemporal(SampleGeneratorBase):
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.samples = SampleLoader.load (SampleType.IMAGE, self.samples_path)
self.samples = SampleHost.load (SampleType.IMAGE, self.samples_path)
self.generator_samples = [ self.samples ]
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )] if self.debug else \

116
samplelib/SampleHost.py Normal file
View file

@ -0,0 +1,116 @@
import operator
import traceback
from pathlib import Path
from facelib import FaceType, LandmarksProcessor
from interact import interact as io
from utils import Path_utils, mp_utils
from utils.DFLJPG import DFLJPG
from utils.DFLPNG import DFLPNG
from .Sample import Sample, SampleType
import samplelib.PackedFaceset
class SampleHost:
samples_cache = dict()
host_cache = dict()
@staticmethod
def get_person_id_max_count(samples_path):
return len ( Path_utils.get_all_dir_names(samples_path) )
@staticmethod
def load(sample_type, samples_path):
samples_cache = SampleHost.samples_cache
if str(samples_path) not in samples_cache.keys():
samples_cache[str(samples_path)] = [None]*SampleType.QTY
samples = samples_cache[str(samples_path)]
if sample_type == SampleType.IMAGE:
if samples[sample_type] is None:
samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ]
elif sample_type == SampleType.FACE:
if samples[sample_type] is None:
result = None
try:
result = samplelib.PackedFaceset.load(samples_path)
except:
io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_dat_path)}, {traceback.format_exc()}")
if result is not None:
io.log_info (f"Loaded packed samples from {samples_path}")
if result is None:
result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) )
samples[sample_type] = result
elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
if samples[sample_type] is None:
samples[sample_type] = SampleHost.upgradeToFaceTemporalSortedSamples( SampleHost.load(SampleType.FACE, samples_path) )
return samples[sample_type]
@staticmethod
def mp_host(sample_type, samples_path):
result = SampleHost.load (sample_type, samples_path)
host_cache = SampleHost.host_cache
if str(samples_path) not in host_cache.keys():
host_cache[str(samples_path)] = [None]*SampleType.QTY
hosts = host_cache[str(samples_path)]
if hosts[sample_type] is None:
hosts[sample_type] = mp_utils.ListHost(result)
return hosts[sample_type]
@staticmethod
def load_face_samples ( image_paths, silent=False):
sample_list = []
for filename in (image_paths if silent else io.progress_bar_generator( image_paths, "Loading")):
filename_path = Path(filename)
try:
if filename_path.suffix == '.png':
dflimg = DFLPNG.load ( str(filename_path) )
elif filename_path.suffix == '.jpg':
dflimg = DFLJPG.load ( str(filename_path) )
else:
dflimg = None
if dflimg is None:
io.log_err ("load_face_samples: %s is not a dfl image file required for training" % (filename_path.name) )
continue
landmarks = dflimg.get_landmarks()
pitch_yaw_roll = dflimg.get_pitch_yaw_roll()
eyebrows_expand_mod = dflimg.get_eyebrows_expand_mod()
if pitch_yaw_roll is None:
pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(landmarks)
sample_list.append( Sample(filename=filename,
sample_type=SampleType.FACE,
face_type=FaceType.fromString (dflimg.get_face_type()),
shape=dflimg.get_shape(),
landmarks=landmarks,
ie_polys=dflimg.get_ie_polys(),
pitch_yaw_roll=pitch_yaw_roll,
eyebrows_expand_mod=eyebrows_expand_mod,
source_filename=dflimg.get_source_filename(),
fanseg_mask_exist=dflimg.get_fanseg_mask() is not None, ) )
except:
io.log_err ("Unable to load %s , error: %s" % (filename, traceback.format_exc() ) )
return sample_list
@staticmethod
def upgradeToFaceTemporalSortedSamples( samples ):
new_s = [ (s, s.source_filename) for s in samples]
new_s = sorted(new_s, key=operator.itemgetter(1))
return [ s[0] for s in new_s]

View file

@ -1,204 +0,0 @@
import operator
import pickle
import traceback
from enum import IntEnum
from pathlib import Path
import cv2
import numpy as np
from facelib import FaceType, LandmarksProcessor
from interact import interact as io
from utils import Path_utils
from utils.DFLJPG import DFLJPG
from utils.DFLPNG import DFLPNG
from .Sample import Sample, SampleType
class SampleLoader:
cache = dict()
@staticmethod
def get_person_id_max_count(samples_path):
return len ( Path_utils.get_all_dir_names(samples_path) )
@staticmethod
def load(sample_type, samples_path, target_samples_path=None, person_id_mode=False, use_caching=False):
cache = SampleLoader.cache
if str(samples_path) not in cache.keys():
cache[str(samples_path)] = [None]*SampleType.QTY
datas = cache[str(samples_path)]
if sample_type == SampleType.IMAGE:
if datas[sample_type] is None:
datas[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ]
elif sample_type == SampleType.FACE:
if datas[sample_type] is None:
if not use_caching:
datas[sample_type] = SampleLoader.upgradeToFaceSamples( [ Sample(filename=filename) for filename in Path_utils.get_image_paths(samples_path) ] )
else:
samples_dat = samples_path / 'samples.dat'
if samples_dat.exists():
io.log_info (f"Using saved samples info from '{samples_dat}' ")
all_samples = pickle.loads(samples_dat.read_bytes())
if person_id_mode:
for samples in all_samples:
for sample in samples:
sample.filename = str( samples_path / Path(sample.filename) )
else:
for sample in all_samples:
sample.filename = str( samples_path / Path(sample.filename) )
datas[sample_type] = all_samples
else:
if person_id_mode:
dir_names = Path_utils.get_all_dir_names(samples_path)
all_samples = []
for i, dir_name in io.progress_bar_generator( [*enumerate(dir_names)] , "Loading"):
all_samples += [ SampleLoader.upgradeToFaceSamples( [ Sample(filename=filename, person_id=i) for filename in Path_utils.get_image_paths( samples_path / dir_name ) ], silent=True ) ]
datas[sample_type] = all_samples
else:
datas[sample_type] = all_samples = SampleLoader.upgradeToFaceSamples( [ Sample(filename=filename) for filename in Path_utils.get_image_paths(samples_path) ] )
if person_id_mode:
for samples in all_samples:
for sample in samples:
sample.filename = str(Path(sample.filename).relative_to(samples_path))
else:
for sample in all_samples:
sample.filename = str(Path(sample.filename).relative_to(samples_path))
samples_dat.write_bytes (pickle.dumps(all_samples))
if person_id_mode:
for samples in all_samples:
for sample in samples:
sample.filename = str( samples_path / Path(sample.filename) )
else:
for sample in all_samples:
sample.filename = str( samples_path / Path(sample.filename) )
elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
if datas[sample_type] is None:
datas[sample_type] = SampleLoader.upgradeToFaceTemporalSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) )
elif sample_type == SampleType.FACE_YAW_SORTED:
if datas[sample_type] is None:
datas[sample_type] = SampleLoader.upgradeToFaceYawSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) )
elif sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
if datas[sample_type] is None:
if target_samples_path is None:
raise Exception('target_samples_path is None for FACE_YAW_SORTED_AS_TARGET')
datas[sample_type] = SampleLoader.upgradeToFaceYawSortedAsTargetSamples( SampleLoader.load(SampleType.FACE_YAW_SORTED, samples_path), SampleLoader.load(SampleType.FACE_YAW_SORTED, target_samples_path) )
return datas[sample_type]
@staticmethod
def upgradeToFaceSamples ( samples, silent=False ):
sample_list = []
for s in (samples if silent else io.progress_bar_generator(samples, "Loading")):
s_filename_path = Path(s.filename)
try:
if s_filename_path.suffix == '.png':
dflimg = DFLPNG.load ( str(s_filename_path) )
elif s_filename_path.suffix == '.jpg':
dflimg = DFLJPG.load ( str(s_filename_path) )
else:
dflimg = None
if dflimg is None:
print ("%s is not a dfl image file required for training" % (s_filename_path.name) )
continue
landmarks = dflimg.get_landmarks()
pitch_yaw_roll = dflimg.get_pitch_yaw_roll()
eyebrows_expand_mod = dflimg.get_eyebrows_expand_mod()
if pitch_yaw_roll is None:
pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(landmarks)
sample_list.append( s.copy_and_set(sample_type=SampleType.FACE,
face_type=FaceType.fromString (dflimg.get_face_type()),
shape=dflimg.get_shape(),
landmarks=landmarks,
ie_polys=dflimg.get_ie_polys(),
pitch_yaw_roll=pitch_yaw_roll,
eyebrows_expand_mod=eyebrows_expand_mod,
source_filename=dflimg.get_source_filename(),
fanseg_mask_exist=dflimg.get_fanseg_mask() is not None, ) )
except:
print ("Unable to load %s , error: %s" % (str(s_filename_path), traceback.format_exc() ) )
return sample_list
@staticmethod
def upgradeToFaceTemporalSortedSamples( samples ):
new_s = [ (s, s.source_filename) for s in samples]
new_s = sorted(new_s, key=operator.itemgetter(1))
return [ s[0] for s in new_s]
@staticmethod
def upgradeToFaceYawSortedSamples( samples ):
lowest_yaw, highest_yaw = -1.0, 1.0
gradations = 64
diff_rot_per_grad = abs(highest_yaw-lowest_yaw) / gradations
yaws_sample_list = [None]*gradations
for i in io.progress_bar_generator(range(gradations), "Sorting"):
yaw = lowest_yaw + i*diff_rot_per_grad
next_yaw = lowest_yaw + (i+1)*diff_rot_per_grad
yaw_samples = []
for s in samples:
s_yaw = s.pitch_yaw_roll[1]
if (i == 0 and s_yaw < next_yaw) or \
(i < gradations-1 and s_yaw >= yaw and s_yaw < next_yaw) or \
(i == gradations-1 and s_yaw >= yaw):
yaw_samples.append ( s.copy_and_set(sample_type=SampleType.FACE_YAW_SORTED) )
if len(yaw_samples) > 0:
yaws_sample_list[i] = yaw_samples
return yaws_sample_list
@staticmethod
def upgradeToFaceYawSortedAsTargetSamples (s, t):
l = len(s)
if l != len(t):
raise Exception('upgradeToFaceYawSortedAsTargetSamples() s_len != t_len')
b = l // 2
s_idxs = np.argwhere ( np.array ( [ 1 if x != None else 0 for x in s] ) == 1 )[:,0]
t_idxs = np.argwhere ( np.array ( [ 1 if x != None else 0 for x in t] ) == 1 )[:,0]
new_s = [None]*l
for t_idx in t_idxs:
search_idxs = []
for i in range(0,l):
search_idxs += [t_idx - i, (l-t_idx-1) - i, t_idx + i, (l-t_idx-1) + i]
for search_idx in search_idxs:
if search_idx in s_idxs:
mirrored = ( t_idx != search_idx and ((t_idx < b and search_idx >= b) or (search_idx < b and t_idx >= b)) )
new_s[t_idx] = [ sample.copy_and_set(sample_type=SampleType.FACE_YAW_SORTED_AS_TARGET,
mirror=True,
pitch_yaw_roll=(sample.pitch_yaw_roll[0],-sample.pitch_yaw_roll[1],sample.pitch_yaw_roll[2]),
landmarks=LandmarksProcessor.mirror_landmarks (sample.landmarks, sample.shape[1] ))
for sample in s[search_idx]
] if mirrored else s[search_idx]
break
return new_s

View file

@ -1,9 +1,10 @@
from .Sample import Sample
from .Sample import SampleType
from .SampleLoader import SampleLoader
from .SampleHost import SampleHost
from .SampleProcessor import SampleProcessor
from .SampleGeneratorBase import SampleGeneratorBase
from .SampleGeneratorFace import SampleGeneratorFace
from .SampleGeneratorFacePerson import SampleGeneratorFacePerson
from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal
from .PackedFaceset import PackedFaceset

View file

@ -18,8 +18,11 @@ class DFLJPG(object):
self.shape = (0,0,0)
@staticmethod
def load_raw(filename):
def load_raw(filename, loader_func=None):
try:
if loader_func is not None:
data = loader_func(filename)
else:
with open(filename, "rb") as f:
data = f.read()
except:
@ -116,9 +119,9 @@ class DFLJPG(object):
raise Exception ("Corrupted JPG file: %s" % (str(e)))
@staticmethod
def load(filename):
def load(filename, loader_func=None):
try:
inst = DFLJPG.load_raw (filename)
inst = DFLJPG.load_raw (filename, loader_func=loader_func)
inst.dfl_dict = None
for chunk in inst.chunks:

View file

@ -225,8 +225,11 @@ class DFLPNG(object):
self.dfl_dict = None
@staticmethod
def load_raw(filename):
def load_raw(filename, loader_func=None):
try:
if loader_func is not None:
data = loader_func(filename)
else:
with open(filename, "rb") as f:
data = f.read()
except:
@ -252,9 +255,9 @@ class DFLPNG(object):
return inst
@staticmethod
def load(filename):
def load(filename, loader_func=None):
try:
inst = DFLPNG.load_raw (filename)
inst = DFLPNG.load_raw (filename, loader_func=loader_func)
inst.dfl_dict = inst.getDFLDictData()
if inst.dfl_dict is not None:

View file

@ -3,8 +3,11 @@ import numpy as np
from pathlib import Path
#allows to open non-english characters path
def cv2_imread(filename, flags=cv2.IMREAD_UNCHANGED):
def cv2_imread(filename, flags=cv2.IMREAD_UNCHANGED, loader_func=None):
try:
if loader_func is not None:
bytes = bytearray(loader_func(filename))
else:
with open(filename, "rb") as stream:
bytes = bytearray(stream.read())
numpyarray = np.asarray(bytes, dtype=np.uint8)

179
utils/mp_utils.py Normal file
View file

@ -0,0 +1,179 @@
import multiprocessing
import threading
import time
import numpy as np
class IndexHost():
"""
Provides random shuffled indexes for multiprocesses
"""
def __init__(self, indexes_count):
self.sq = multiprocessing.Queue()
self.cqs = []
self.clis = []
self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,) )
self.thread.daemon = True
self.thread.start()
def host_thread(self, indexes_count):
idxs = [*range(indexes_count)]
shuffle_idxs = []
sq = self.sq
while True:
while not sq.empty():
obj = sq.get()
cq_id, count = obj[0], obj[1]
result = []
for i in range(count):
if len(shuffle_idxs) == 0:
shuffle_idxs = idxs.copy()
np.random.shuffle(shuffle_idxs)
result.append(shuffle_idxs.pop())
self.cqs[cq_id].put (result)
time.sleep(0.005)
def create_cli(self):
cq = multiprocessing.Queue()
self.cqs.append ( cq )
cq_id = len(self.cqs)-1
return IndexHost.Cli(self.sq, cq, cq_id)
# disable pickling
def __getstate__(self):
return dict()
def __setstate__(self, d):
self.__dict__.update(d)
class Cli():
def __init__(self, sq, cq, cq_id):
self.sq = sq
self.cq = cq
self.cq_id = cq_id
def get(self, count):
self.sq.put ( (self.cq_id,count) )
while True:
if not self.cq.empty():
return self.cq.get()
time.sleep(0.001)
class ListHost():
def __init__(self, list_):
self.sq = multiprocessing.Queue()
self.cqs = []
self.clis = []
self.list_ = list_
self.thread = threading.Thread(target=self.host_thread)
self.thread.daemon = True
self.thread.start()
def host_thread(self):
sq = self.sq
while True:
while not sq.empty():
obj = sq.get()
cq_id, cmd = obj[0], obj[1]
if cmd == 0:
item = self.list_[ obj[2] ]
self.cqs[cq_id].put ( item )
elif cmd == 1:
self.cqs[cq_id].put ( len(self.list_) )
time.sleep(0.005)
def create_cli(self):
cq = multiprocessing.Queue()
self.cqs.append ( cq )
cq_id = len(self.cqs)-1
return ListHost.Cli(self.sq, cq, cq_id)
def __len__(self):
return len(self.list_)
# disable pickling
def __getstate__(self):
return dict()
def __setstate__(self, d):
self.__dict__.update(d)
class Cli():
def __init__(self, sq, cq, cq_id):
self.sq = sq
self.cq = cq
self.cq_id = cq_id
def __getitem__(self, key):
self.sq.put ( (self.cq_id,0,key) )
while True:
if not self.cq.empty():
return self.cq.get()
time.sleep(0.001)
def __len__(self):
self.sq.put ( (self.cq_id,1) )
while True:
if not self.cq.empty():
return self.cq.get()
time.sleep(0.001)
class DictHost():
def __init__(self, d, num_users):
self.sqs = [ multiprocessing.Queue() for _ in range(num_users) ]
self.cqs = [ multiprocessing.Queue() for _ in range(num_users) ]
self.thread = threading.Thread(target=self.host_thread, args=(d,) )
self.thread.daemon = True
self.thread.start()
self.clis = [ DictHostCli(sq,cq) for sq, cq in zip(self.sqs, self.cqs) ]
def host_thread(self, d):
while True:
for sq, cq in zip(self.sqs, self.cqs):
if not sq.empty():
obj = sq.get()
cmd = obj[0]
if cmd == 0:
cq.put (d[ obj[1] ])
elif cmd == 1:
cq.put ( list(d.keys()) )
time.sleep(0.005)
def get_cli(self, n_user):
return self.clis[n_user]
# disable pickling
def __getstate__(self):
return dict()
def __setstate__(self, d):
self.__dict__.update(d)
class DictHostCli():
def __init__(self, sq, cq):
self.sq = sq
self.cq = cq
def __getitem__(self, key):
self.sq.put ( (0,key) )
while True:
if not self.cq.empty():
return self.cq.get()
time.sleep(0.001)
def keys(self):
self.sq.put ( (1,) )
while True:
if not self.cq.empty():
return self.cq.get()
time.sleep(0.001)