upd SampleGenerator

This commit is contained in:
Colombo 2020-02-27 09:58:46 +04:00
commit 9860a38907
4 changed files with 42 additions and 33 deletions

View file

@ -2,14 +2,12 @@ import numpy as np
import cv2
from core import randomex
def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_seed=None ):
def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None ):
h,w,c = source.shape
if (h != w):
raise ValueError ('gen_warp_params accepts only square images.')
if rnd_seed != None:
rnd_state = np.random.RandomState (rnd_seed)
else:
if rnd_state is None:
rnd_state = np.random
rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] )

View file

@ -99,15 +99,17 @@ class IndexHost():
"""
Provides random shuffled indexes for multiprocesses
"""
def __init__(self, indexes_count):
def __init__(self, indexes_count, rnd_seed=None):
self.sq = multiprocessing.Queue()
self.cqs = []
self.clis = []
self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,) )
self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,rnd_seed) )
self.thread.daemon = True
self.thread.start()
def host_thread(self, indexes_count):
def host_thread(self, indexes_count, rnd_seed):
rnd_state = np.random.RandomState(rnd_seed) if rnd_seed is not None else np.random
idxs = [*range(indexes_count)]
shuffle_idxs = []
sq = self.sq
@ -121,7 +123,7 @@ class IndexHost():
for i in range(count):
if len(shuffle_idxs) == 0:
shuffle_idxs = idxs.copy()
np.random.shuffle(shuffle_idxs)
rnd_state.shuffle(shuffle_idxs)
result.append(shuffle_idxs.pop())
self.cqs[cq_id].put (result)