This commit is contained in:
Colombo 2020-02-27 12:03:01 +04:00
parent 30c93a9bdb
commit acb0b34811
2 changed files with 9 additions and 18 deletions

View file

@ -27,7 +27,6 @@ class SampleGeneratorFace(SampleGeneratorBase):
output_sample_types=[], output_sample_types=[],
add_sample_idx=False, add_sample_idx=False,
generators_count=4, generators_count=4,
rnd_seed=None,
**kwargs): **kwargs):
super().__init__(samples_path, debug, batch_size) super().__init__(samples_path, debug, batch_size)
@ -35,9 +34,6 @@ class SampleGeneratorFace(SampleGeneratorBase):
self.output_sample_types = output_sample_types self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx self.add_sample_idx = add_sample_idx
if rnd_seed is None:
rnd_seed = np.random.randint(0x80000000)
if self.debug: if self.debug:
self.generators_count = 1 self.generators_count = 1
else: else:
@ -49,11 +45,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
if self.samples_len == 0: if self.samples_len == 0:
raise ValueError('No training data provided.') raise ValueError('No training data provided.')
index_host = mplib.IndexHost(self.samples_len, rnd_seed=rnd_seed) index_host = mplib.IndexHost(self.samples_len)
if random_ct_samples_path is not None: if random_ct_samples_path is not None:
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path)
ct_index_host = mplib.IndexHost( len(ct_samples), rnd_seed=rnd_seed ) ct_index_host = mplib.IndexHost( len(ct_samples) )
else: else:
ct_samples = None ct_samples = None
ct_index_host = None ct_index_host = None
@ -62,9 +58,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None
if self.debug: if self.debug:
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None, rnd_seed) )] self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
else: else:
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None, rnd_seed+i), start_now=False ) \ self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \
for i in range(self.generators_count) ] for i in range(self.generators_count) ]
SubprocessGenerator.start_in_parallel( self.generators ) SubprocessGenerator.start_in_parallel( self.generators )
@ -80,10 +76,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
return next(generator) return next(generator)
def batch_func(self, param ): def batch_func(self, param ):
pickled_samples, index_host, ct_pickled_samples, ct_index_host, rnd_seed = param pickled_samples, index_host, ct_pickled_samples, ct_index_host = param
rnd_state = np.random.RandomState(rnd_seed)
samples = pickle.loads(pickled_samples) samples = pickle.loads(pickled_samples)
ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None
@ -104,7 +98,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
ct_sample = ct_samples[ct_indexes[n_batch]] ct_sample = ct_samples[ct_indexes[n_batch]]
try: try:
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, rnd_state=rnd_state) x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
except: except:
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )

View file

@ -63,13 +63,10 @@ class SampleProcessor(object):
} }
@staticmethod @staticmethod
def process (samples, sample_process_options, output_sample_types, debug, ct_sample=None, rnd_state=None): def process (samples, sample_process_options, output_sample_types, debug, ct_sample=None):
SPTF = SampleProcessor.Types SPTF = SampleProcessor.Types
if rnd_state is None: sample_rnd_seed = np.random.randint(0x80000000)
rnd_state = np.random.RandomState( np.random.randint(0x80000000) )
sample_rnd_seed = rnd_state.randint(0x80000000)
outputs = [] outputs = []
for sample in samples: for sample in samples:
@ -82,7 +79,7 @@ class SampleProcessor(object):
if debug and is_face_sample: if debug and is_face_sample:
LandmarksProcessor.draw_landmarks (sample_bgr, sample.landmarks, (0, 1, 0)) LandmarksProcessor.draw_landmarks (sample_bgr, sample.landmarks, (0, 1, 0))
params = imagelib.gen_warp_params(sample_bgr, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range, rnd_state=rnd_state ) params = imagelib.gen_warp_params(sample_bgr, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range )
outputs_sample = [] outputs_sample = []
for opts in output_sample_types: for opts in output_sample_types: