mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
fix
This commit is contained in:
parent
21b25038ac
commit
ea33541177
2 changed files with 92 additions and 100 deletions
|
@ -1,25 +1,22 @@
|
|||
import multiprocessing
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import numpy as np
|
||||
|
||||
import numpy as np
|
||||
|
||||
class Index2DHost():
|
||||
"""
|
||||
Provides random shuffled 2D indexes for multiprocesses
|
||||
"""
|
||||
def __init__(self, indexes2D, max_number_of_clis=128):
|
||||
def __init__(self, indexes2D):
|
||||
self.sq = multiprocessing.Queue()
|
||||
self.cqs = [ multiprocessing.Queue() for _ in range(max_number_of_clis) ]
|
||||
self.n_clis = 0
|
||||
self.max_number_of_clis = max_number_of_clis
|
||||
self.cqs = []
|
||||
self.clis = []
|
||||
self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) )
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
self.p = multiprocessing.Process(target=self.host_proc, args=(indexes2D, self.sq, self.cqs) )
|
||||
self.p.daemon = True
|
||||
self.p.start()
|
||||
|
||||
def host_proc(self, indexes2D, sq, cqs):
|
||||
def host_thread(self, indexes2D):
|
||||
indexes_counts_len = len(indexes2D)
|
||||
|
||||
idxs = [*range(indexes_counts_len)]
|
||||
|
@ -30,6 +27,8 @@ class Index2DHost():
|
|||
idxs_2D[i] = indexes2D[i]
|
||||
shuffle_idxs_2D[i] = []
|
||||
|
||||
sq = self.sq
|
||||
|
||||
while True:
|
||||
while not sq.empty():
|
||||
obj = sq.get()
|
||||
|
@ -44,7 +43,7 @@ class Index2DHost():
|
|||
shuffle_idxs = idxs.copy()
|
||||
np.random.shuffle(shuffle_idxs)
|
||||
result.append(shuffle_idxs.pop())
|
||||
cqs[cq_id].put (result)
|
||||
self.cqs[cq_id].put (result)
|
||||
elif cmd == 1: #get_2D
|
||||
targ_idxs,count = obj[2], obj[3]
|
||||
result = []
|
||||
|
@ -58,7 +57,7 @@ class Index2DHost():
|
|||
np.random.shuffle(ar)
|
||||
sub_idxs.append(ar.pop())
|
||||
result.append (sub_idxs)
|
||||
cqs[cq_id].put (result)
|
||||
self.cqs[cq_id].put (result)
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
|
@ -100,19 +99,18 @@ class IndexHost():
|
|||
"""
|
||||
Provides random shuffled indexes for multiprocesses
|
||||
"""
|
||||
def __init__(self, indexes_count, max_number_of_clis=128):
|
||||
def __init__(self, indexes_count):
|
||||
self.sq = multiprocessing.Queue()
|
||||
self.cqs = [ multiprocessing.Queue() for _ in range(max_number_of_clis) ]
|
||||
self.n_clis = 0
|
||||
self.max_number_of_clis = max_number_of_clis
|
||||
self.cqs = []
|
||||
self.clis = []
|
||||
self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,) )
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
self.p = multiprocessing.Process(target=self.host_proc, args=(indexes_count, self.sq, self.cqs) )
|
||||
self.p.daemon = True
|
||||
self.p.start()
|
||||
|
||||
def host_proc(self, indexes_count, sq, cqs):
|
||||
def host_thread(self, indexes_count):
|
||||
idxs = [*range(indexes_count)]
|
||||
shuffle_idxs = []
|
||||
sq = self.sq
|
||||
|
||||
while True:
|
||||
while not sq.empty():
|
||||
|
@ -125,18 +123,15 @@ class IndexHost():
|
|||
shuffle_idxs = idxs.copy()
|
||||
np.random.shuffle(shuffle_idxs)
|
||||
result.append(shuffle_idxs.pop())
|
||||
cqs[cq_id].put (result)
|
||||
self.cqs[cq_id].put (result)
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
def create_cli(self):
|
||||
if self.n_clis == self.max_number_of_clis:
|
||||
raise Exception("")
|
||||
|
||||
cq_id = self.n_clis
|
||||
self.n_clis += 1
|
||||
|
||||
return IndexHost.Cli(self.sq, self.cqs[cq_id], cq_id)
|
||||
cq = multiprocessing.Queue()
|
||||
self.cqs.append ( cq )
|
||||
cq_id = len(self.cqs)-1
|
||||
return IndexHost.Cli(self.sq, cq, cq_id)
|
||||
|
||||
# disable pickling
|
||||
def __getstate__(self):
|
||||
|
@ -159,50 +154,50 @@ class IndexHost():
|
|||
time.sleep(0.001)
|
||||
|
||||
class ListHost():
|
||||
def __init__(self, list_=None, max_number_of_clis=128):
|
||||
def __init__(self, list_):
|
||||
self.sq = multiprocessing.Queue()
|
||||
self.cqs = [ multiprocessing.Queue() for _ in range(max_number_of_clis) ]
|
||||
self.n_clis = 0
|
||||
self.max_number_of_clis = max_number_of_clis
|
||||
|
||||
self.p = multiprocessing.Process(target=self.host_proc, args=(self.sq, self.cqs) )
|
||||
self.p.daemon = True
|
||||
self.p.start()
|
||||
self.cqs = []
|
||||
self.clis = []
|
||||
self.m_list = list_
|
||||
self.thread = threading.Thread(target=self.host_thread)
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def host_proc(self, sq, cqs):
|
||||
m_list = list()
|
||||
|
||||
def host_thread(self):
|
||||
sq = self.sq
|
||||
while True:
|
||||
while not sq.empty():
|
||||
obj = sq.get()
|
||||
cq_id, cmd = obj[0], obj[1]
|
||||
|
||||
if cmd == 0:
|
||||
cqs[cq_id].put ( len(m_list) )
|
||||
self.cqs[cq_id].put ( len(self.m_list) )
|
||||
elif cmd == 1:
|
||||
idx = obj[2]
|
||||
item = m_list[idx ]
|
||||
cqs[cq_id].put ( item )
|
||||
idx = obj[2]
|
||||
item = self.m_list[idx ]
|
||||
self.cqs[cq_id].put ( item )
|
||||
elif cmd == 2:
|
||||
result = []
|
||||
for item in obj[2]:
|
||||
result.append ( m_list[item] )
|
||||
cqs[cq_id].put ( result )
|
||||
result.append ( self.m_list[item] )
|
||||
self.cqs[cq_id].put ( result )
|
||||
elif cmd == 3:
|
||||
m_list.insert(obj[2], obj[3])
|
||||
self.m_list.insert(obj[2], obj[3])
|
||||
elif cmd == 4:
|
||||
m_list.append(obj[2])
|
||||
self.m_list.append(obj[2])
|
||||
elif cmd == 5:
|
||||
m_list.extend(obj[2])
|
||||
self.m_list.extend(obj[2])
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
def create_cli(self):
|
||||
if self.n_clis == self.max_number_of_clis:
|
||||
raise Exception("")
|
||||
|
||||
cq_id = self.n_clis
|
||||
self.n_clis += 1
|
||||
|
||||
return ListHost.Cli(self.sq, self.cqs[cq_id], cq_id)
|
||||
def create_cli(self):
|
||||
cq = multiprocessing.Queue()
|
||||
self.cqs.append ( cq )
|
||||
cq_id = len(self.cqs)-1
|
||||
return ListHost.Cli(self.sq, cq, cq_id)
|
||||
|
||||
def get_list(self):
|
||||
return self.list_
|
||||
|
||||
# disable pickling
|
||||
def __getstate__(self):
|
||||
|
@ -223,7 +218,7 @@ class ListHost():
|
|||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
self.sq.put ( (self.cq_id,1,key) )
|
||||
|
||||
|
@ -231,7 +226,7 @@ class ListHost():
|
|||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
|
||||
def multi_get(self, keys):
|
||||
self.sq.put ( (self.cq_id,2,keys) )
|
||||
|
||||
|
@ -239,17 +234,17 @@ class ListHost():
|
|||
if not self.cq.empty():
|
||||
return self.cq.get()
|
||||
time.sleep(0.001)
|
||||
|
||||
|
||||
def insert(self, index, item):
|
||||
self.sq.put ( (self.cq_id,3,index,item) )
|
||||
|
||||
|
||||
def append(self, item):
|
||||
self.sq.put ( (self.cq_id,4,item) )
|
||||
|
||||
|
||||
def extend(self, items):
|
||||
self.sq.put ( (self.cq_id,5,items) )
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class DictHost():
|
||||
def __init__(self, d, num_users):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue