mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-14 02:37:00 -07:00
upd SampleGenerator
This commit is contained in:
parent
1898bd6881
commit
9860a38907
4 changed files with 42 additions and 33 deletions
|
@ -99,15 +99,17 @@ class IndexHost():
|
|||
"""
|
||||
Provides random shuffled indexes for multiprocesses
|
||||
"""
|
||||
def __init__(self, indexes_count):
|
||||
def __init__(self, indexes_count, rnd_seed=None):
|
||||
self.sq = multiprocessing.Queue()
|
||||
self.cqs = []
|
||||
self.clis = []
|
||||
self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,) )
|
||||
self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,rnd_seed) )
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def host_thread(self, indexes_count):
|
||||
def host_thread(self, indexes_count, rnd_seed):
|
||||
rnd_state = np.random.RandomState(rnd_seed) if rnd_seed is not None else np.random
|
||||
|
||||
idxs = [*range(indexes_count)]
|
||||
shuffle_idxs = []
|
||||
sq = self.sq
|
||||
|
@ -121,7 +123,7 @@ class IndexHost():
|
|||
for i in range(count):
|
||||
if len(shuffle_idxs) == 0:
|
||||
shuffle_idxs = idxs.copy()
|
||||
np.random.shuffle(shuffle_idxs)
|
||||
rnd_state.shuffle(shuffle_idxs)
|
||||
result.append(shuffle_idxs.pop())
|
||||
self.cqs[cq_id].put (result)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue