mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-11 07:37:03 -07:00
optimized sample generator
This commit is contained in:
parent
b5c234dac3
commit
21b25038ac
6 changed files with 201 additions and 160 deletions
|
@ -7,8 +7,8 @@ import numpy as np
|
||||||
from facelib import LandmarksProcessor
|
from facelib import LandmarksProcessor
|
||||||
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
|
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
|
||||||
SampleType)
|
SampleType)
|
||||||
from utils import iter_utils
|
from utils import iter_utils, mp_utils
|
||||||
from utils import mp_utils
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
arg
|
arg
|
||||||
|
@ -30,8 +30,13 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
self.output_sample_types = output_sample_types
|
self.output_sample_types = output_sample_types
|
||||||
self.add_sample_idx = add_sample_idx
|
self.add_sample_idx = add_sample_idx
|
||||||
|
|
||||||
samples_host = SampleHost.mp_host (SampleType.FACE, self.samples_path)
|
if self.debug:
|
||||||
self.samples_len = len(samples_host.get_list())
|
self.generators_count = 1
|
||||||
|
else:
|
||||||
|
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 6)
|
||||||
|
|
||||||
|
samples_clis = SampleHost.host (SampleType.FACE, self.samples_path, number_of_clis=self.generators_count)
|
||||||
|
self.samples_len = len(samples_clis[0])
|
||||||
|
|
||||||
if self.samples_len == 0:
|
if self.samples_len == 0:
|
||||||
raise ValueError('No training data provided.')
|
raise ValueError('No training data provided.')
|
||||||
|
@ -39,18 +44,16 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
index_host = mp_utils.IndexHost(self.samples_len)
|
index_host = mp_utils.IndexHost(self.samples_len)
|
||||||
|
|
||||||
if random_ct_samples_path is not None:
|
if random_ct_samples_path is not None:
|
||||||
ct_samples_host = SampleHost.mp_host (SampleType.FACE, random_ct_samples_path)
|
ct_samples_clis = SampleHost.host (SampleType.FACE, random_ct_samples_path, number_of_clis=self.generators_count)
|
||||||
ct_index_host = mp_utils.IndexHost( len(ct_samples_host.get_list()) )
|
ct_index_host = mp_utils.IndexHost( len(ct_samples_clis[0]) )
|
||||||
else:
|
else:
|
||||||
ct_samples_host = None
|
ct_samples_clis = None
|
||||||
ct_index_host = None
|
ct_index_host = None
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
self.generators_count = 1
|
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_clis[0], index_host.create_cli(), ct_samples_clis[0] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )]
|
||||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index_host.create_cli(), ct_samples_host.create_cli() if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )]
|
|
||||||
else:
|
else:
|
||||||
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 4)
|
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_clis[i], index_host.create_cli(), ct_samples_clis[i] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
|
||||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_host.create_cli(), index_host.create_cli(), ct_samples_host.create_cli() if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
|
|
||||||
|
|
||||||
self.generator_counter = -1
|
self.generator_counter = -1
|
||||||
|
|
||||||
|
@ -72,13 +75,16 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
while True:
|
while True:
|
||||||
batches = None
|
batches = None
|
||||||
|
|
||||||
indexes = index_host.get(bs)
|
indexes = index_host.multi_get(bs)
|
||||||
ct_indexes = ct_index_host.get(bs) if ct_samples is not None else None
|
ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None
|
||||||
|
|
||||||
|
batch_samples = samples.multi_get (indexes)
|
||||||
|
batch_ct_samples = ct_samples.multi_get (ct_indexes) if ct_samples is not None else None
|
||||||
|
|
||||||
for n_batch in range(bs):
|
for n_batch in range(bs):
|
||||||
sample_idx = indexes[n_batch]
|
sample_idx = indexes[n_batch]
|
||||||
sample = samples[ sample_idx ]
|
sample = batch_samples[n_batch]
|
||||||
ct_sample = ct_samples[ ct_indexes[n_batch] ] if ct_samples is not None else None
|
ct_sample = batch_ct_samples[n_batch] if ct_samples is not None else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
|
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
|
||||||
|
|
|
@ -30,6 +30,7 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
||||||
self.output_sample_types = output_sample_types
|
self.output_sample_types = output_sample_types
|
||||||
self.person_id_mode = person_id_mode
|
self.person_id_mode = person_id_mode
|
||||||
|
|
||||||
|
raise NotImplementedError("Currently SampleGeneratorFacePerson is not implemented.")
|
||||||
|
|
||||||
samples_host = SampleHost.mp_host (SampleType.FACE, self.samples_path)
|
samples_host = SampleHost.mp_host (SampleType.FACE, self.samples_path)
|
||||||
samples = samples_host.get_list()
|
samples = samples_host.get_list()
|
||||||
|
|
|
@ -20,14 +20,17 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
||||||
self.sample_process_options = sample_process_options
|
self.sample_process_options = sample_process_options
|
||||||
self.output_sample_types = output_sample_types
|
self.output_sample_types = output_sample_types
|
||||||
|
|
||||||
self.samples = SampleHost.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
|
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
self.generators_count = 1
|
self.generators_count = 1
|
||||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )]
|
|
||||||
else:
|
else:
|
||||||
self.generators_count = min ( generators_count, len(self.samples) )
|
self.generators_count = generators_count
|
||||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, i ) for i in range(self.generators_count) ]
|
|
||||||
|
samples_clis = SampleHost.host (SampleType.FACE_TEMPORAL_SORTED, self.samples_path, number_of_clis=self.generators_count)
|
||||||
|
|
||||||
|
if self.debug:
|
||||||
|
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples_clis[0]) )]
|
||||||
|
else:
|
||||||
|
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples_clis[i]) ) for i in range(self.generators_count) ]
|
||||||
|
|
||||||
self.generator_counter = -1
|
self.generator_counter = -1
|
||||||
|
|
||||||
|
@ -39,8 +42,9 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
||||||
generator = self.generators[self.generator_counter % len(self.generators) ]
|
generator = self.generators[self.generator_counter % len(self.generators) ]
|
||||||
return next(generator)
|
return next(generator)
|
||||||
|
|
||||||
def batch_func(self, generator_id):
|
def batch_func(self, param):
|
||||||
samples = self.samples
|
generator_id, samples = param
|
||||||
|
|
||||||
samples_len = len(samples)
|
samples_len = len(samples)
|
||||||
if samples_len == 0:
|
if samples_len == 0:
|
||||||
raise ValueError('No training data provided.')
|
raise ValueError('No training data provided.')
|
||||||
|
@ -56,10 +60,8 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
||||||
shuffle_idxs = []
|
shuffle_idxs = []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
batches = None
|
batches = None
|
||||||
for n_batch in range(self.batch_size):
|
for n_batch in range(self.batch_size):
|
||||||
|
|
||||||
if len(shuffle_idxs) == 0:
|
if len(shuffle_idxs) == 0:
|
||||||
shuffle_idxs = samples_idxs.copy()
|
shuffle_idxs = samples_idxs.copy()
|
||||||
np.random.shuffle (shuffle_idxs)
|
np.random.shuffle (shuffle_idxs)
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
import gc
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import operator
|
import operator
|
||||||
import pickle
|
|
||||||
import traceback
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
@ -16,9 +14,11 @@ from .Sample import Sample, SampleType
|
||||||
|
|
||||||
|
|
||||||
class SampleHost:
|
class SampleHost:
|
||||||
samples_cache = dict()
|
|
||||||
host_cache = dict()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
samples_cache = dict()
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_person_id_max_count(samples_path):
|
def get_person_id_max_count(samples_path):
|
||||||
samples = None
|
samples = None
|
||||||
|
@ -35,7 +35,7 @@ class SampleHost:
|
||||||
return len(list(persons_name_idxs.keys()))
|
return len(list(persons_name_idxs.keys()))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(sample_type, samples_path):
|
def host(sample_type, samples_path, number_of_clis):
|
||||||
samples_cache = SampleHost.samples_cache
|
samples_cache = SampleHost.samples_cache
|
||||||
|
|
||||||
if str(samples_path) not in samples_cache.keys():
|
if str(samples_path) not in samples_cache.keys():
|
||||||
|
@ -46,9 +46,11 @@ class SampleHost:
|
||||||
if sample_type == SampleType.IMAGE:
|
if sample_type == SampleType.IMAGE:
|
||||||
if samples[sample_type] is None:
|
if samples[sample_type] is None:
|
||||||
samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ]
|
samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ]
|
||||||
elif sample_type == SampleType.FACE:
|
elif sample_type == SampleType.FACE or \
|
||||||
|
sample_type == SampleType.FACE_TEMPORAL_SORTED:
|
||||||
|
result = None
|
||||||
|
|
||||||
if samples[sample_type] is None:
|
if samples[sample_type] is None:
|
||||||
result = None
|
|
||||||
try:
|
try:
|
||||||
result = samplelib.PackedFaceset.load(samples_path)
|
result = samplelib.PackedFaceset.load(samples_path)
|
||||||
except:
|
except:
|
||||||
|
@ -60,33 +62,26 @@ class SampleHost:
|
||||||
if result is None:
|
if result is None:
|
||||||
result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) )
|
result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) )
|
||||||
|
|
||||||
result_dmp = pickle.dumps(result)
|
samples[sample_type] = mp_utils.ListHost()
|
||||||
del result
|
|
||||||
gc.collect()
|
|
||||||
result = pickle.loads(result_dmp)
|
|
||||||
|
|
||||||
samples[sample_type] = result
|
if sample_type == SampleType.FACE_TEMPORAL_SORTED:
|
||||||
|
result = SampleHost.upgradeToFaceTemporalSortedSamples(result)
|
||||||
|
|
||||||
elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
|
list_host = samples[sample_type]
|
||||||
if samples[sample_type] is None:
|
|
||||||
samples[sample_type] = SampleHost.upgradeToFaceTemporalSortedSamples( SampleHost.load(SampleType.FACE, samples_path) )
|
clis = [ list_host.create_cli() for _ in range(number_of_clis) ]
|
||||||
|
|
||||||
|
if result is not None:
|
||||||
|
while True:
|
||||||
|
if len(result) == 0:
|
||||||
|
break
|
||||||
|
items = result[0:10000]
|
||||||
|
del result[0:10000]
|
||||||
|
clis[0].extend(items)
|
||||||
|
return clis
|
||||||
|
|
||||||
return samples[sample_type]
|
return samples[sample_type]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def mp_host(sample_type, samples_path):
|
|
||||||
result = SampleHost.load (sample_type, samples_path)
|
|
||||||
|
|
||||||
host_cache = SampleHost.host_cache
|
|
||||||
if str(samples_path) not in host_cache.keys():
|
|
||||||
host_cache[str(samples_path)] = [None]*SampleType.QTY
|
|
||||||
hosts = host_cache[str(samples_path)]
|
|
||||||
|
|
||||||
if hosts[sample_type] is None:
|
|
||||||
hosts[sample_type] = mp_utils.ListHost(result)
|
|
||||||
|
|
||||||
return hosts[sample_type]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_face_samples ( image_paths):
|
def load_face_samples ( image_paths):
|
||||||
result = FaceSamplesLoaderSubprocessor(image_paths).run()
|
result = FaceSamplesLoaderSubprocessor(image_paths).run()
|
||||||
|
|
|
@ -22,7 +22,7 @@ class ThisThreadGenerator(object):
|
||||||
return next(self.generator_func)
|
return next(self.generator_func)
|
||||||
|
|
||||||
class SubprocessGenerator(object):
|
class SubprocessGenerator(object):
|
||||||
def __init__(self, generator_func, user_param=None, prefetch=2, start_now=False):
|
def __init__(self, generator_func, user_param=None, prefetch=3, start_now=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.prefetch = prefetch
|
self.prefetch = prefetch
|
||||||
self.generator_func = generator_func
|
self.generator_func = generator_func
|
||||||
|
|
|
@ -1,22 +1,25 @@
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class Index2DHost():
|
class Index2DHost():
|
||||||
"""
|
"""
|
||||||
Provides random shuffled 2D indexes for multiprocesses
|
Provides random shuffled 2D indexes for multiprocesses
|
||||||
"""
|
"""
|
||||||
def __init__(self, indexes2D):
|
def __init__(self, indexes2D, max_number_of_clis=128):
|
||||||
self.sq = multiprocessing.Queue()
|
self.sq = multiprocessing.Queue()
|
||||||
self.cqs = []
|
self.cqs = [ multiprocessing.Queue() for _ in range(max_number_of_clis) ]
|
||||||
self.clis = []
|
self.n_clis = 0
|
||||||
self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
|
self.max_number_of_clis = max_number_of_clis
|
||||||
self.thread.daemon = True
|
|
||||||
self.thread.start()
|
|
||||||
|
|
||||||
def host_thread(self, indexes2D):
|
self.p = multiprocessing.Process(target=self.host_proc, args=(indexes2D, self.sq, self.cqs) )
|
||||||
|
self.p.daemon = True
|
||||||
|
self.p.start()
|
||||||
|
|
||||||
|
def host_proc(self, indexes2D, sq, cqs):
|
||||||
indexes_counts_len = len(indexes2D)
|
indexes_counts_len = len(indexes2D)
|
||||||
|
|
||||||
idxs = [*range(indexes_counts_len)]
|
idxs = [*range(indexes_counts_len)]
|
||||||
|
@ -27,8 +30,6 @@ class Index2DHost():
|
||||||
idxs_2D[i] = indexes2D[i]
|
idxs_2D[i] = indexes2D[i]
|
||||||
shuffle_idxs_2D[i] = []
|
shuffle_idxs_2D[i] = []
|
||||||
|
|
||||||
sq = self.sq
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
while not sq.empty():
|
while not sq.empty():
|
||||||
obj = sq.get()
|
obj = sq.get()
|
||||||
|
@ -43,7 +44,7 @@ class Index2DHost():
|
||||||
shuffle_idxs = idxs.copy()
|
shuffle_idxs = idxs.copy()
|
||||||
np.random.shuffle(shuffle_idxs)
|
np.random.shuffle(shuffle_idxs)
|
||||||
result.append(shuffle_idxs.pop())
|
result.append(shuffle_idxs.pop())
|
||||||
self.cqs[cq_id].put (result)
|
cqs[cq_id].put (result)
|
||||||
elif cmd == 1: #get_2D
|
elif cmd == 1: #get_2D
|
||||||
targ_idxs,count = obj[2], obj[3]
|
targ_idxs,count = obj[2], obj[3]
|
||||||
result = []
|
result = []
|
||||||
|
@ -57,7 +58,7 @@ class Index2DHost():
|
||||||
np.random.shuffle(ar)
|
np.random.shuffle(ar)
|
||||||
sub_idxs.append(ar.pop())
|
sub_idxs.append(ar.pop())
|
||||||
result.append (sub_idxs)
|
result.append (sub_idxs)
|
||||||
self.cqs[cq_id].put (result)
|
cqs[cq_id].put (result)
|
||||||
|
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
|
|
||||||
|
@ -99,18 +100,19 @@ class IndexHost():
|
||||||
"""
|
"""
|
||||||
Provides random shuffled indexes for multiprocesses
|
Provides random shuffled indexes for multiprocesses
|
||||||
"""
|
"""
|
||||||
def __init__(self, indexes_count):
|
def __init__(self, indexes_count, max_number_of_clis=128):
|
||||||
self.sq = multiprocessing.Queue()
|
self.sq = multiprocessing.Queue()
|
||||||
self.cqs = []
|
self.cqs = [ multiprocessing.Queue() for _ in range(max_number_of_clis) ]
|
||||||
self.clis = []
|
self.n_clis = 0
|
||||||
self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,) )
|
self.max_number_of_clis = max_number_of_clis
|
||||||
self.thread.daemon = True
|
|
||||||
self.thread.start()
|
|
||||||
|
|
||||||
def host_thread(self, indexes_count):
|
self.p = multiprocessing.Process(target=self.host_proc, args=(indexes_count, self.sq, self.cqs) )
|
||||||
|
self.p.daemon = True
|
||||||
|
self.p.start()
|
||||||
|
|
||||||
|
def host_proc(self, indexes_count, sq, cqs):
|
||||||
idxs = [*range(indexes_count)]
|
idxs = [*range(indexes_count)]
|
||||||
shuffle_idxs = []
|
shuffle_idxs = []
|
||||||
sq = self.sq
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
while not sq.empty():
|
while not sq.empty():
|
||||||
|
@ -123,15 +125,18 @@ class IndexHost():
|
||||||
shuffle_idxs = idxs.copy()
|
shuffle_idxs = idxs.copy()
|
||||||
np.random.shuffle(shuffle_idxs)
|
np.random.shuffle(shuffle_idxs)
|
||||||
result.append(shuffle_idxs.pop())
|
result.append(shuffle_idxs.pop())
|
||||||
self.cqs[cq_id].put (result)
|
cqs[cq_id].put (result)
|
||||||
|
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
|
|
||||||
def create_cli(self):
|
def create_cli(self):
|
||||||
cq = multiprocessing.Queue()
|
if self.n_clis == self.max_number_of_clis:
|
||||||
self.cqs.append ( cq )
|
raise Exception("")
|
||||||
cq_id = len(self.cqs)-1
|
|
||||||
return IndexHost.Cli(self.sq, cq, cq_id)
|
cq_id = self.n_clis
|
||||||
|
self.n_clis += 1
|
||||||
|
|
||||||
|
return IndexHost.Cli(self.sq, self.cqs[cq_id], cq_id)
|
||||||
|
|
||||||
# disable pickling
|
# disable pickling
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
|
@ -145,7 +150,7 @@ class IndexHost():
|
||||||
self.cq = cq
|
self.cq = cq
|
||||||
self.cq_id = cq_id
|
self.cq_id = cq_id
|
||||||
|
|
||||||
def get(self, count):
|
def multi_get(self, count):
|
||||||
self.sq.put ( (self.cq_id,count) )
|
self.sq.put ( (self.cq_id,count) )
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
@ -154,37 +159,50 @@ class IndexHost():
|
||||||
time.sleep(0.001)
|
time.sleep(0.001)
|
||||||
|
|
||||||
class ListHost():
|
class ListHost():
|
||||||
def __init__(self, list_):
|
def __init__(self, list_=None, max_number_of_clis=128):
|
||||||
self.sq = multiprocessing.Queue()
|
self.sq = multiprocessing.Queue()
|
||||||
self.cqs = []
|
self.cqs = [ multiprocessing.Queue() for _ in range(max_number_of_clis) ]
|
||||||
self.clis = []
|
self.n_clis = 0
|
||||||
self.list_ = list_
|
self.max_number_of_clis = max_number_of_clis
|
||||||
self.thread = threading.Thread(target=self.host_thread)
|
|
||||||
self.thread.daemon = True
|
self.p = multiprocessing.Process(target=self.host_proc, args=(self.sq, self.cqs) )
|
||||||
self.thread.start()
|
self.p.daemon = True
|
||||||
|
self.p.start()
|
||||||
|
|
||||||
|
def host_proc(self, sq, cqs):
|
||||||
|
m_list = list()
|
||||||
|
|
||||||
def host_thread(self):
|
|
||||||
sq = self.sq
|
|
||||||
while True:
|
while True:
|
||||||
while not sq.empty():
|
while not sq.empty():
|
||||||
obj = sq.get()
|
obj = sq.get()
|
||||||
cq_id, cmd = obj[0], obj[1]
|
cq_id, cmd = obj[0], obj[1]
|
||||||
if cmd == 0:
|
if cmd == 0:
|
||||||
item = self.list_[ obj[2] ]
|
cqs[cq_id].put ( len(m_list) )
|
||||||
self.cqs[cq_id].put ( item )
|
|
||||||
|
|
||||||
elif cmd == 1:
|
elif cmd == 1:
|
||||||
self.cqs[cq_id].put ( len(self.list_) )
|
idx = obj[2]
|
||||||
|
item = m_list[idx ]
|
||||||
|
cqs[cq_id].put ( item )
|
||||||
|
elif cmd == 2:
|
||||||
|
result = []
|
||||||
|
for item in obj[2]:
|
||||||
|
result.append ( m_list[item] )
|
||||||
|
cqs[cq_id].put ( result )
|
||||||
|
elif cmd == 3:
|
||||||
|
m_list.insert(obj[2], obj[3])
|
||||||
|
elif cmd == 4:
|
||||||
|
m_list.append(obj[2])
|
||||||
|
elif cmd == 5:
|
||||||
|
m_list.extend(obj[2])
|
||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
|
|
||||||
def create_cli(self):
|
def create_cli(self):
|
||||||
cq = multiprocessing.Queue()
|
if self.n_clis == self.max_number_of_clis:
|
||||||
self.cqs.append ( cq )
|
raise Exception("")
|
||||||
cq_id = len(self.cqs)-1
|
|
||||||
return ListHost.Cli(self.sq, cq, cq_id)
|
|
||||||
|
|
||||||
def get_list(self):
|
cq_id = self.n_clis
|
||||||
return self.list_
|
self.n_clis += 1
|
||||||
|
|
||||||
|
return ListHost.Cli(self.sq, self.cqs[cq_id], cq_id)
|
||||||
|
|
||||||
# disable pickling
|
# disable pickling
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
|
@ -198,22 +216,41 @@ class ListHost():
|
||||||
self.cq = cq
|
self.cq = cq
|
||||||
self.cq_id = cq_id
|
self.cq_id = cq_id
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
self.sq.put ( (self.cq_id,0,key) )
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if not self.cq.empty():
|
|
||||||
return self.cq.get()
|
|
||||||
time.sleep(0.001)
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
self.sq.put ( (self.cq_id,1) )
|
self.sq.put ( (self.cq_id,0) )
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if not self.cq.empty():
|
if not self.cq.empty():
|
||||||
return self.cq.get()
|
return self.cq.get()
|
||||||
time.sleep(0.001)
|
time.sleep(0.001)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
self.sq.put ( (self.cq_id,1,key) )
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if not self.cq.empty():
|
||||||
|
return self.cq.get()
|
||||||
|
time.sleep(0.001)
|
||||||
|
|
||||||
|
def multi_get(self, keys):
|
||||||
|
self.sq.put ( (self.cq_id,2,keys) )
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if not self.cq.empty():
|
||||||
|
return self.cq.get()
|
||||||
|
time.sleep(0.001)
|
||||||
|
|
||||||
|
def insert(self, index, item):
|
||||||
|
self.sq.put ( (self.cq_id,3,index,item) )
|
||||||
|
|
||||||
|
def append(self, item):
|
||||||
|
self.sq.put ( (self.cq_id,4,item) )
|
||||||
|
|
||||||
|
def extend(self, items):
|
||||||
|
self.sq.put ( (self.cq_id,5,items) )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DictHost():
|
class DictHost():
|
||||||
def __init__(self, d, num_users):
|
def __init__(self, d, num_users):
|
||||||
self.sqs = [ multiprocessing.Queue() for _ in range(num_users) ]
|
self.sqs = [ multiprocessing.Queue() for _ in range(num_users) ]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue