From 18d93376fcf7075f47dc93705b6ef0be5c9e6441 Mon Sep 17 00:00:00 2001 From: Colombo Date: Sun, 8 Mar 2020 10:34:48 +0400 Subject: [PATCH] update FANSeg --- models/ModelBase.py | 8 ++++- models/Model_FANSeg/Model.py | 51 +++++++++++++++++++++++--------- samplelib/SampleGeneratorBase.py | 4 +++ samplelib/SampleGeneratorFace.py | 15 ++++++++-- 4 files changed, 61 insertions(+), 17 deletions(-) diff --git a/models/ModelBase.py b/models/ModelBase.py index ec48f83..ea26d4a 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -412,7 +412,13 @@ class ModelBase(object): return imagelib.equalize_and_stack_square (images) def generate_next_samples(self): - self.last_sample = sample = [ generator.generate_next() for generator in self.generator_list] + sample = [] + for generator in self.generator_list: + if generator.is_initialized(): + sample.append ( generator.generate_next() ) + else: + sample.append ( [] ) + self.last_sample = sample return sample def train_one_iter(self): diff --git a/models/Model_FANSeg/Model.py b/models/Model_FANSeg/Model.py index d8babde..3cb3990 100644 --- a/models/Model_FANSeg/Model.py +++ b/models/Model_FANSeg/Model.py @@ -24,7 +24,6 @@ class FANSegModel(ModelBase): ask_override = self.ask_override() if self.is_first_run() or ask_override: self.ask_autobackup_hour() - self.ask_write_preview_history() self.ask_target_iter() self.ask_batch_size(4) @@ -117,21 +116,30 @@ class FANSegModel(ModelBase): # initializing sample generators training_data_src_path = self.training_data_src_path - #training_data_dst_path = self.training_data_dst_path + training_data_dst_path = self.training_data_dst_path cpu_count = min(multiprocessing.cpu_count(), 8) src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 src_generators_count = int(src_generators_count * 1.5) - self.set_training_data_generators ([ - SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=True), - output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'ct_mode':'idt', 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution}, - {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.NONE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, - ], - generators_count=src_generators_count ), - ]) + src_generator = SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=True), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'ct_mode':'idt', 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.NONE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + generators_count=src_generators_count ) + + dst_generator = SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=True), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution}, + ], + generators_count=dst_generators_count, + raise_on_no_data=False ) + if not dst_generator.is_initialized(): + io.log_info(f"\nTo view the model on unseen faces, place any aligned faces in {training_data_dst_path}.\n") + + self.set_training_data_generators ([src_generator, dst_generator]) #override def get_model_filename_list(self): @@ -143,7 +151,7 @@ class FANSegModel(ModelBase): #override def onTrainOneIter(self): - ( (source_np, target_np), ) = self.generate_next_samples() + source_np, target_np = self.generate_next_samples()[0] loss = self.train (source_np, target_np) return ( ('loss', loss ), ) @@ -152,7 +160,8 @@ class FANSegModel(ModelBase): def onGetPreview(self, samples): n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) - ( (source_np, target_np), ) = samples + src_samples, dst_samples = samples + source_np, target_np = src_samples S, T, SM, = [ np.clip(x, 0.0, 1.0) for x in ([source_np,target_np] + self.view (source_np) ) ] T, SM, = [ np.repeat (x, (3,), -1) for x in [T, SM] ] @@ -164,8 +173,22 @@ class FANSegModel(ModelBase): ar = S[i], T[i], SM[i], S[i]*SM[i] #todo green bg st.append ( np.concatenate ( ar, axis=1) ) - result += [ ('FANSeg', np.concatenate (st, axis=0 )), ] - + result += [ ('FANSeg training faces', np.concatenate (st, axis=0 )), ] + + if len(dst_samples) != 0: + dst_np, = dst_samples + + D, DM, = [ np.clip(x, 0.0, 1.0) for x in ([dst_np] + self.view (dst_np) ) ] + DM, = [ np.repeat (x, (3,), -1) for x in [DM] ] + + st = [] + for i in range(n_samples): + ar = D[i], DM[i], D[i]*DM[i] + #todo green bg + st.append ( np.concatenate ( ar, axis=1) ) + + result += [ ('FANSeg unseen faces', np.concatenate (st, axis=0 )), ] + return result Model = FANSegModel diff --git a/samplelib/SampleGeneratorBase.py b/samplelib/SampleGeneratorBase.py index ef98974..19d02b3 100644 --- a/samplelib/SampleGeneratorBase.py +++ b/samplelib/SampleGeneratorBase.py @@ -33,3 +33,7 @@ class SampleGeneratorBase(object): def __next__(self): #implement your own iterator return None + + #overridable + def is_initialized(self): + return True \ No newline at end of file diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 344c792..f386797 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -27,6 +27,7 @@ class SampleGeneratorFace(SampleGeneratorBase): output_sample_types=[], add_sample_idx=False, generators_count=4, + raise_on_no_data=True, **kwargs): super().__init__(samples_path, debug, batch_size) @@ -42,8 +43,12 @@ class SampleGeneratorFace(SampleGeneratorBase): samples = SampleLoader.load (SampleType.FACE, self.samples_path) self.samples_len = len(samples) + self.initialized = False if self.samples_len == 0: - raise ValueError('No training data provided.') + if raise_on_no_data: + raise ValueError('No training data provided.') + else: + return index_host = mplib.IndexHost(self.samples_len) @@ -66,7 +71,13 @@ class SampleGeneratorFace(SampleGeneratorBase): SubprocessGenerator.start_in_parallel( self.generators ) self.generator_counter = -1 - + + self.initialized = True + + #overridable + def is_initialized(self): + return self.initialized + def __iter__(self): return self