This commit is contained in:
Colombo 2020-01-05 13:58:25 +04:00
parent 21b25038ac
commit ea33541177
2 changed files with 92 additions and 100 deletions

View file

@ -1,5 +1,6 @@
import multiprocessing import multiprocessing
import operator import operator
import pickle
import traceback import traceback
from pathlib import Path from pathlib import Path
@ -14,23 +15,23 @@ from .Sample import Sample, SampleType
class SampleHost: class SampleHost:
samples_cache = dict() samples_cache = dict()
@staticmethod @staticmethod
def get_person_id_max_count(samples_path): def get_person_id_max_count(samples_path):
samples = None samples = None
try: try:
samples = samplelib.PackedFaceset.load(samples_path) samples = samplelib.PackedFaceset.load(samples_path)
except: except:
io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_dat_path)}, {traceback.format_exc()}") io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_dat_path)}, {traceback.format_exc()}")
if samples is None: if samples is None:
raise ValueError("packed faceset not found.") raise ValueError("packed faceset not found.")
persons_name_idxs = {} persons_name_idxs = {}
for sample in samples: for sample in samples:
persons_name_idxs[sample.person_name] = 0 persons_name_idxs[sample.person_name] = 0
return len(list(persons_name_idxs.keys())) return len(list(persons_name_idxs.keys()))
@ -49,41 +50,37 @@ class SampleHost:
elif sample_type == SampleType.FACE or \ elif sample_type == SampleType.FACE or \
sample_type == SampleType.FACE_TEMPORAL_SORTED: sample_type == SampleType.FACE_TEMPORAL_SORTED:
result = None result = None
if samples[sample_type] is None: if samples[sample_type] is None:
try: try:
result = samplelib.PackedFaceset.load(samples_path) result = samplelib.PackedFaceset.load(samples_path)
except: except:
io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_dat_path)}, {traceback.format_exc()}") io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_dat_path)}, {traceback.format_exc()}")
if result is not None: if result is not None:
io.log_info (f"Loaded {len(result)} packed faces from {samples_path}") io.log_info (f"Loaded {len(result)} packed faces from {samples_path}")
if result is None: if result is None:
result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) ) result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) )
samples[sample_type] = mp_utils.ListHost()
if sample_type == SampleType.FACE_TEMPORAL_SORTED: if sample_type == SampleType.FACE_TEMPORAL_SORTED:
result = SampleHost.upgradeToFaceTemporalSortedSamples(result) result = SampleHost.upgradeToFaceTemporalSortedSamples(result)
result_dumped = pickle.dumps(result)
del result
result = pickle.loads(result_dumped)
samples[sample_type] = mp_utils.ListHost(result)
list_host = samples[sample_type] list_host = samples[sample_type]
clis = [ list_host.create_cli() for _ in range(number_of_clis) ] clis = [ list_host.create_cli() for _ in range(number_of_clis) ]
if result is not None:
while True:
if len(result) == 0:
break
items = result[0:10000]
del result[0:10000]
clis[0].extend(items)
return clis return clis
return samples[sample_type] return samples[sample_type]
@staticmethod @staticmethod
def load_face_samples ( image_paths): def load_face_samples ( image_paths):
result = FaceSamplesLoaderSubprocessor(image_paths).run() result = FaceSamplesLoaderSubprocessor(image_paths).run()
sample_list = [] sample_list = []
@ -93,7 +90,7 @@ class SampleHost:
landmarks, landmarks,
ie_polys, ie_polys,
eyebrows_expand_mod, eyebrows_expand_mod,
source_filename, source_filename,
) in result: ) in result:
sample_list.append( Sample(filename=filename, sample_list.append( Sample(filename=filename,
sample_type=SampleType.FACE, sample_type=SampleType.FACE,
@ -102,7 +99,7 @@ class SampleHost:
landmarks=landmarks, landmarks=landmarks,
ie_polys=ie_polys, ie_polys=ie_polys,
eyebrows_expand_mod=eyebrows_expand_mod, eyebrows_expand_mod=eyebrows_expand_mod,
source_filename=source_filename, source_filename=source_filename,
)) ))
return sample_list return sample_list
@ -112,14 +109,14 @@ class SampleHost:
new_s = sorted(new_s, key=operator.itemgetter(1)) new_s = sorted(new_s, key=operator.itemgetter(1))
return [ s[0] for s in new_s] return [ s[0] for s in new_s]
class FaceSamplesLoaderSubprocessor(Subprocessor): class FaceSamplesLoaderSubprocessor(Subprocessor):
#override #override
def __init__(self, image_paths ): def __init__(self, image_paths ):
self.image_paths = image_paths self.image_paths = image_paths
self.image_paths_len = len(image_paths) self.image_paths_len = len(image_paths)
self.idxs = [*range(self.image_paths_len)] self.idxs = [*range(self.image_paths_len)]
self.result = [None]*self.image_paths_len self.result = [None]*self.image_paths_len
super().__init__('FaceSamplesLoader', FaceSamplesLoaderSubprocessor.Cli, 60, initialize_subprocesses_in_serial=False) super().__init__('FaceSamplesLoader', FaceSamplesLoaderSubprocessor.Cli, 60, initialize_subprocesses_in_serial=False)
@ -169,7 +166,7 @@ class FaceSamplesLoaderSubprocessor(Subprocessor):
def process_data(self, data): def process_data(self, data):
idx, filename = data idx, filename = data
dflimg = DFLIMG.load (Path(filename)) dflimg = DFLIMG.load (Path(filename))
if dflimg is None: if dflimg is None:
self.log_err (f"FaceSamplesLoader: {filename} is not a dfl image file.") self.log_err (f"FaceSamplesLoader: {filename} is not a dfl image file.")
data = None data = None
@ -179,8 +176,8 @@ class FaceSamplesLoaderSubprocessor(Subprocessor):
dflimg.get_landmarks(), dflimg.get_landmarks(),
dflimg.get_ie_polys(), dflimg.get_ie_polys(),
dflimg.get_eyebrows_expand_mod(), dflimg.get_eyebrows_expand_mod(),
dflimg.get_source_filename() ) dflimg.get_source_filename() )
return idx, data return idx, data
#override #override

View file

@ -1,25 +1,22 @@
import multiprocessing import multiprocessing
import threading import threading
import time import time
import traceback
import numpy as np
import numpy as np
class Index2DHost(): class Index2DHost():
""" """
Provides random shuffled 2D indexes for multiprocesses Provides random shuffled 2D indexes for multiprocesses
""" """
def __init__(self, indexes2D, max_number_of_clis=128): def __init__(self, indexes2D):
self.sq = multiprocessing.Queue() self.sq = multiprocessing.Queue()
self.cqs = [ multiprocessing.Queue() for _ in range(max_number_of_clis) ] self.cqs = []
self.n_clis = 0 self.clis = []
self.max_number_of_clis = max_number_of_clis self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
self.thread.daemon = True
self.thread.start()
self.p = multiprocessing.Process(target=self.host_proc, args=(indexes2D, self.sq, self.cqs) ) def host_thread(self, indexes2D):
self.p.daemon = True
self.p.start()
def host_proc(self, indexes2D, sq, cqs):
indexes_counts_len = len(indexes2D) indexes_counts_len = len(indexes2D)
idxs = [*range(indexes_counts_len)] idxs = [*range(indexes_counts_len)]
@ -30,6 +27,8 @@ class Index2DHost():
idxs_2D[i] = indexes2D[i] idxs_2D[i] = indexes2D[i]
shuffle_idxs_2D[i] = [] shuffle_idxs_2D[i] = []
sq = self.sq
while True: while True:
while not sq.empty(): while not sq.empty():
obj = sq.get() obj = sq.get()
@ -44,7 +43,7 @@ class Index2DHost():
shuffle_idxs = idxs.copy() shuffle_idxs = idxs.copy()
np.random.shuffle(shuffle_idxs) np.random.shuffle(shuffle_idxs)
result.append(shuffle_idxs.pop()) result.append(shuffle_idxs.pop())
cqs[cq_id].put (result) self.cqs[cq_id].put (result)
elif cmd == 1: #get_2D elif cmd == 1: #get_2D
targ_idxs,count = obj[2], obj[3] targ_idxs,count = obj[2], obj[3]
result = [] result = []
@ -58,7 +57,7 @@ class Index2DHost():
np.random.shuffle(ar) np.random.shuffle(ar)
sub_idxs.append(ar.pop()) sub_idxs.append(ar.pop())
result.append (sub_idxs) result.append (sub_idxs)
cqs[cq_id].put (result) self.cqs[cq_id].put (result)
time.sleep(0.005) time.sleep(0.005)
@ -100,19 +99,18 @@ class IndexHost():
""" """
Provides random shuffled indexes for multiprocesses Provides random shuffled indexes for multiprocesses
""" """
def __init__(self, indexes_count, max_number_of_clis=128): def __init__(self, indexes_count):
self.sq = multiprocessing.Queue() self.sq = multiprocessing.Queue()
self.cqs = [ multiprocessing.Queue() for _ in range(max_number_of_clis) ] self.cqs = []
self.n_clis = 0 self.clis = []
self.max_number_of_clis = max_number_of_clis self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,) )
self.thread.daemon = True
self.thread.start()
self.p = multiprocessing.Process(target=self.host_proc, args=(indexes_count, self.sq, self.cqs) ) def host_thread(self, indexes_count):
self.p.daemon = True
self.p.start()
def host_proc(self, indexes_count, sq, cqs):
idxs = [*range(indexes_count)] idxs = [*range(indexes_count)]
shuffle_idxs = [] shuffle_idxs = []
sq = self.sq
while True: while True:
while not sq.empty(): while not sq.empty():
@ -125,18 +123,15 @@ class IndexHost():
shuffle_idxs = idxs.copy() shuffle_idxs = idxs.copy()
np.random.shuffle(shuffle_idxs) np.random.shuffle(shuffle_idxs)
result.append(shuffle_idxs.pop()) result.append(shuffle_idxs.pop())
cqs[cq_id].put (result) self.cqs[cq_id].put (result)
time.sleep(0.005) time.sleep(0.005)
def create_cli(self): def create_cli(self):
if self.n_clis == self.max_number_of_clis: cq = multiprocessing.Queue()
raise Exception("") self.cqs.append ( cq )
cq_id = len(self.cqs)-1
cq_id = self.n_clis return IndexHost.Cli(self.sq, cq, cq_id)
self.n_clis += 1
return IndexHost.Cli(self.sq, self.cqs[cq_id], cq_id)
# disable pickling # disable pickling
def __getstate__(self): def __getstate__(self):
@ -159,50 +154,50 @@ class IndexHost():
time.sleep(0.001) time.sleep(0.001)
class ListHost(): class ListHost():
def __init__(self, list_=None, max_number_of_clis=128): def __init__(self, list_):
self.sq = multiprocessing.Queue() self.sq = multiprocessing.Queue()
self.cqs = [ multiprocessing.Queue() for _ in range(max_number_of_clis) ] self.cqs = []
self.n_clis = 0 self.clis = []
self.max_number_of_clis = max_number_of_clis self.m_list = list_
self.thread = threading.Thread(target=self.host_thread)
self.p = multiprocessing.Process(target=self.host_proc, args=(self.sq, self.cqs) ) self.thread.daemon = True
self.p.daemon = True self.thread.start()
self.p.start()
def host_proc(self, sq, cqs): def host_thread(self):
m_list = list() sq = self.sq
while True: while True:
while not sq.empty(): while not sq.empty():
obj = sq.get() obj = sq.get()
cq_id, cmd = obj[0], obj[1] cq_id, cmd = obj[0], obj[1]
if cmd == 0: if cmd == 0:
cqs[cq_id].put ( len(m_list) ) self.cqs[cq_id].put ( len(self.m_list) )
elif cmd == 1: elif cmd == 1:
idx = obj[2] idx = obj[2]
item = m_list[idx ] item = self.m_list[idx ]
cqs[cq_id].put ( item ) self.cqs[cq_id].put ( item )
elif cmd == 2: elif cmd == 2:
result = [] result = []
for item in obj[2]: for item in obj[2]:
result.append ( m_list[item] ) result.append ( self.m_list[item] )
cqs[cq_id].put ( result ) self.cqs[cq_id].put ( result )
elif cmd == 3: elif cmd == 3:
m_list.insert(obj[2], obj[3]) self.m_list.insert(obj[2], obj[3])
elif cmd == 4: elif cmd == 4:
m_list.append(obj[2]) self.m_list.append(obj[2])
elif cmd == 5: elif cmd == 5:
m_list.extend(obj[2]) self.m_list.extend(obj[2])
time.sleep(0.005) time.sleep(0.005)
def create_cli(self): def create_cli(self):
if self.n_clis == self.max_number_of_clis: cq = multiprocessing.Queue()
raise Exception("") self.cqs.append ( cq )
cq_id = len(self.cqs)-1
cq_id = self.n_clis return ListHost.Cli(self.sq, cq, cq_id)
self.n_clis += 1
def get_list(self):
return ListHost.Cli(self.sq, self.cqs[cq_id], cq_id) return self.list_
# disable pickling # disable pickling
def __getstate__(self): def __getstate__(self):
@ -223,7 +218,7 @@ class ListHost():
if not self.cq.empty(): if not self.cq.empty():
return self.cq.get() return self.cq.get()
time.sleep(0.001) time.sleep(0.001)
def __getitem__(self, key): def __getitem__(self, key):
self.sq.put ( (self.cq_id,1,key) ) self.sq.put ( (self.cq_id,1,key) )
@ -231,7 +226,7 @@ class ListHost():
if not self.cq.empty(): if not self.cq.empty():
return self.cq.get() return self.cq.get()
time.sleep(0.001) time.sleep(0.001)
def multi_get(self, keys): def multi_get(self, keys):
self.sq.put ( (self.cq_id,2,keys) ) self.sq.put ( (self.cq_id,2,keys) )
@ -239,17 +234,17 @@ class ListHost():
if not self.cq.empty(): if not self.cq.empty():
return self.cq.get() return self.cq.get()
time.sleep(0.001) time.sleep(0.001)
def insert(self, index, item): def insert(self, index, item):
self.sq.put ( (self.cq_id,3,index,item) ) self.sq.put ( (self.cq_id,3,index,item) )
def append(self, item): def append(self, item):
self.sq.put ( (self.cq_id,4,item) ) self.sq.put ( (self.cq_id,4,item) )
def extend(self, items): def extend(self, items):
self.sq.put ( (self.cq_id,5,items) ) self.sq.put ( (self.cq_id,5,items) )
class DictHost(): class DictHost():
def __init__(self, d, num_users): def __init__(self, d, num_users):