diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 5f64852..344c792 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -27,7 +27,6 @@ class SampleGeneratorFace(SampleGeneratorBase): output_sample_types=[], add_sample_idx=False, generators_count=4, - rnd_seed=None, **kwargs): super().__init__(samples_path, debug, batch_size) @@ -35,9 +34,6 @@ class SampleGeneratorFace(SampleGeneratorBase): self.output_sample_types = output_sample_types self.add_sample_idx = add_sample_idx - if rnd_seed is None: - rnd_seed = np.random.randint(0x80000000) - if self.debug: self.generators_count = 1 else: @@ -49,11 +45,11 @@ class SampleGeneratorFace(SampleGeneratorBase): if self.samples_len == 0: raise ValueError('No training data provided.') - index_host = mplib.IndexHost(self.samples_len, rnd_seed=rnd_seed) + index_host = mplib.IndexHost(self.samples_len) if random_ct_samples_path is not None: ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) - ct_index_host = mplib.IndexHost( len(ct_samples), rnd_seed=rnd_seed ) + ct_index_host = mplib.IndexHost( len(ct_samples) ) else: ct_samples = None ct_index_host = None @@ -62,9 +58,9 @@ class SampleGeneratorFace(SampleGeneratorBase): ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None if self.debug: - self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None, rnd_seed) )] + self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )] else: - self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None, rnd_seed+i), start_now=False ) \ + self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \ for i in range(self.generators_count) ] SubprocessGenerator.start_in_parallel( self.generators ) @@ -80,10 +76,8 @@ class SampleGeneratorFace(SampleGeneratorBase): return next(generator) def batch_func(self, param ): - pickled_samples, index_host, ct_pickled_samples, ct_index_host, rnd_seed = param + pickled_samples, index_host, ct_pickled_samples, ct_index_host = param - rnd_state = np.random.RandomState(rnd_seed) - samples = pickle.loads(pickled_samples) ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None @@ -104,7 +98,7 @@ class SampleGeneratorFace(SampleGeneratorBase): ct_sample = ct_samples[ct_indexes[n_batch]] try: - x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, rnd_state=rnd_state) + x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample) except: raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py index c30e5f0..8a64ed3 100644 --- a/samplelib/SampleProcessor.py +++ b/samplelib/SampleProcessor.py @@ -63,13 +63,10 @@ class SampleProcessor(object): } @staticmethod - def process (samples, sample_process_options, output_sample_types, debug, ct_sample=None, rnd_state=None): + def process (samples, sample_process_options, output_sample_types, debug, ct_sample=None): SPTF = SampleProcessor.Types - if rnd_state is None: - rnd_state = np.random.RandomState( np.random.randint(0x80000000) ) - - sample_rnd_seed = rnd_state.randint(0x80000000) + sample_rnd_seed = np.random.randint(0x80000000) outputs = [] for sample in samples: @@ -82,7 +79,7 @@ class SampleProcessor(object): if debug and is_face_sample: LandmarksProcessor.draw_landmarks (sample_bgr, sample.landmarks, (0, 1, 0)) - params = imagelib.gen_warp_params(sample_bgr, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range, rnd_state=rnd_state ) + params = imagelib.gen_warp_params(sample_bgr, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range ) outputs_sample = [] for opts in output_sample_types: