diff --git a/core/imagelib/warp.py b/core/imagelib/warp.py index 857b13d..f8be299 100644 --- a/core/imagelib/warp.py +++ b/core/imagelib/warp.py @@ -2,14 +2,12 @@ import numpy as np import cv2 from core import randomex -def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_seed=None ): +def gen_warp_params (source, flip, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None ): h,w,c = source.shape if (h != w): raise ValueError ('gen_warp_params accepts only square images.') - if rnd_seed != None: - rnd_state = np.random.RandomState (rnd_seed) - else: + if rnd_state is None: rnd_state = np.random rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] ) diff --git a/core/mplib/__init__.py b/core/mplib/__init__.py index 63a14ea..ee4c40b 100644 --- a/core/mplib/__init__.py +++ b/core/mplib/__init__.py @@ -99,15 +99,17 @@ class IndexHost(): """ Provides random shuffled indexes for multiprocesses """ - def __init__(self, indexes_count): + def __init__(self, indexes_count, rnd_seed=None): self.sq = multiprocessing.Queue() self.cqs = [] self.clis = [] - self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,) ) + self.thread = threading.Thread(target=self.host_thread, args=(indexes_count,rnd_seed) ) self.thread.daemon = True self.thread.start() - def host_thread(self, indexes_count): + def host_thread(self, indexes_count, rnd_seed): + rnd_state = np.random.RandomState(rnd_seed) if rnd_seed is not None else np.random + idxs = [*range(indexes_count)] shuffle_idxs = [] sq = self.sq @@ -121,7 +123,7 @@ class IndexHost(): for i in range(count): if len(shuffle_idxs) == 0: shuffle_idxs = idxs.copy() - np.random.shuffle(shuffle_idxs) + rnd_state.shuffle(shuffle_idxs) result.append(shuffle_idxs.pop()) self.cqs[cq_id].put (result) diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index a7e622f..36ebee0 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -27,12 +27,16 @@ class SampleGeneratorFace(SampleGeneratorBase): output_sample_types=[], add_sample_idx=False, generators_count=4, + rnd_seed=None, **kwargs): super().__init__(samples_path, debug, batch_size) self.sample_process_options = sample_process_options self.output_sample_types = output_sample_types self.add_sample_idx = add_sample_idx + + if rnd_seed is None: + rnd_seed = np.random.randint(0x80000000) if self.debug: self.generators_count = 1 @@ -45,11 +49,11 @@ class SampleGeneratorFace(SampleGeneratorBase): if self.samples_len == 0: raise ValueError('No training data provided.') - index_host = mplib.IndexHost(self.samples_len) + index_host = mplib.IndexHost(self.samples_len, rnd_seed=rnd_seed) if random_ct_samples_path is not None: ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) - ct_index_host = mplib.IndexHost( len(ct_samples) ) + ct_index_host = mplib.IndexHost( len(ct_samples), rnd_seed=rnd_seed ) else: ct_samples = None ct_index_host = None @@ -58,9 +62,9 @@ class SampleGeneratorFace(SampleGeneratorBase): ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None if self.debug: - self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )] + self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None, rnd_seed) )] else: - self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \ + self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None, rnd_seed), start_now=False ) \ for i in range(self.generators_count) ] SubprocessGenerator.start_in_parallel( self.generators ) @@ -76,7 +80,9 @@ class SampleGeneratorFace(SampleGeneratorBase): return next(generator) def batch_func(self, param ): - pickled_samples, index_host, ct_pickled_samples, ct_index_host = param + pickled_samples, index_host, ct_pickled_samples, ct_index_host, rnd_seed = param + + rnd_state = np.random.RandomState(rnd_seed) samples = pickle.loads(pickled_samples) ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None @@ -98,7 +104,7 @@ class SampleGeneratorFace(SampleGeneratorBase): ct_sample = ct_samples[ct_indexes[n_batch]] try: - x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample) + x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, rnd_state=rnd_state) except: raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py index 929933c..c30e5f0 100644 --- a/samplelib/SampleProcessor.py +++ b/samplelib/SampleProcessor.py @@ -63,11 +63,14 @@ class SampleProcessor(object): } @staticmethod - def process (samples, sample_process_options, output_sample_types, debug, ct_sample=None): + def process (samples, sample_process_options, output_sample_types, debug, ct_sample=None, rnd_state=None): SPTF = SampleProcessor.Types - sample_rnd_seed = np.random.randint(0x80000000) - + if rnd_state is None: + rnd_state = np.random.RandomState( np.random.randint(0x80000000) ) + + sample_rnd_seed = rnd_state.randint(0x80000000) + outputs = [] for sample in samples: sample_bgr = sample.load_bgr() @@ -79,7 +82,7 @@ class SampleProcessor(object): if debug and is_face_sample: LandmarksProcessor.draw_landmarks (sample_bgr, sample.landmarks, (0, 1, 0)) - params = imagelib.gen_warp_params(sample_bgr, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range, rnd_seed=sample_rnd_seed ) + params = imagelib.gen_warp_params(sample_bgr, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range, rnd_state=rnd_state ) outputs_sample = [] for opts in output_sample_types: @@ -186,10 +189,10 @@ class SampleProcessor(object): chance, mb_max_size = motion_blur chance = np.clip(chance, 0, 100) - rnd_state = np.random.RandomState (sample_rnd_seed) - mblur_rnd_chance = rnd_state.randint(100) - mblur_rnd_kernel = rnd_state.randint(mb_max_size)+1 - mblur_rnd_deg = rnd_state.randint(360) + l_rnd_state = np.random.RandomState (sample_rnd_seed) + mblur_rnd_chance = l_rnd_state.randint(100) + mblur_rnd_kernel = l_rnd_state.randint(mb_max_size)+1 + mblur_rnd_deg = l_rnd_state.randint(360) if mblur_rnd_chance < chance: img = imagelib.LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg ) @@ -198,9 +201,9 @@ class SampleProcessor(object): chance, kernel_max_size = gaussian_blur chance = np.clip(chance, 0, 100) - rnd_state = np.random.RandomState (sample_rnd_seed+1) - gblur_rnd_chance = rnd_state.randint(100) - gblur_rnd_kernel = rnd_state.randint(kernel_max_size)*2+1 + l_rnd_state = np.random.RandomState (sample_rnd_seed+1) + gblur_rnd_chance = l_rnd_state.randint(100) + gblur_rnd_kernel = l_rnd_state.randint(kernel_max_size)*2+1 if gblur_rnd_chance < chance: img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0) @@ -260,22 +263,22 @@ class SampleProcessor(object): if mode_type == SPTF.MODE_BGR: out_sample = img elif mode_type == SPTF.MODE_BGR_SHUFFLE: - rnd_state = np.random.RandomState (sample_rnd_seed) - out_sample = np.take (img, rnd_state.permutation(img.shape[-1]), axis=-1) + l_rnd_state = np.random.RandomState (sample_rnd_seed) + out_sample = np.take (img, l_rnd_state.permutation(img.shape[-1]), axis=-1) elif mode_type == SPTF.MODE_BGR_RANDOM_HSV_SHIFT: - rnd_state = np.random.RandomState (sample_rnd_seed) + l_rnd_state = np.random.RandomState (sample_rnd_seed) hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) h, s, v = cv2.split(hsv) - h = (h + rnd_state.randint(360) ) % 360 - s = np.clip ( s + rnd_state.random()-0.5, 0, 1 ) - v = np.clip ( v + rnd_state.random()-0.5, 0, 1 ) + h = (h + l_rnd_state.randint(360) ) % 360 + s = np.clip ( s + l_rnd_state.random()-0.5, 0, 1 ) + v = np.clip ( v + l_rnd_state.random()-0.5, 0, 1 ) hsv = cv2.merge([h, s, v]) out_sample = np.clip( cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) , 0, 1 ) elif mode_type == SPTF.MODE_BGR_RANDOM_RGB_LEVELS: - rnd_state = np.random.RandomState (sample_rnd_seed) - np_rnd = rnd_state.rand + l_rnd_state = np.random.RandomState (sample_rnd_seed) + np_rnd = l_rnd_state.rand inBlack = np.array([np_rnd()*0.25 , np_rnd()*0.25 , np_rnd()*0.25], dtype=np.float32) inWhite = np.array([1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25], dtype=np.float32)