diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 9b9cd2d..79ffa30 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -56,6 +56,7 @@ class SAEHDModel(ModelBase): default_loss_function = self.options['loss_function'] = self.load_or_def_option('loss_function', 'SSIM') default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) + default_random_downsample = self.options['random_downsample'] = self.load_or_def_option('random_downsample', False) default_background_power = self.options['background_power'] = self.load_or_def_option('background_power', 0.0) default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0) default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0) @@ -159,6 +160,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") + self.options['random_downsample'] = io.input_bool("Enable random downsample of samples", default_random_downsample, help_message="") + # self.options['random_noise'] = io.input_bool("Enable random noise added to samples", False, help_message="") + # self.options['random_blur'] = io.input_bool("Enable random blur of samples", False, help_message="") + # self.options['random_jpeg'] = io.input_bool("Enable random jpeg compression of samples", False, help_message="") + self.options['gan_version'] = np.clip (io.input_int("GAN version", default_gan_version, add_info="2 or 3", help_message="Choose GAN version (v2: 7/16/2020, v3: 1/3/2021):"), 2, 3) if self.options['gan_version'] == 2: @@ -745,7 +751,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... self.set_training_data_generators ([ SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), - output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : channel_type, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'random_downsample': self.options['random_downsample'], 'transform':True, 'channel_type' : channel_type, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : channel_type, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, @@ -755,7 +761,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip), - output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : channel_type, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'random_downsample': self.options['random_downsample'], 'transform':True, 'channel_type' : channel_type, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : channel_type, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, @@ -830,6 +836,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples S, D, SS, SSM, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] + SW, DW = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([warped_src,warped_dst]) ] SSM, DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [SSM, DDM, SDM] ] target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] @@ -845,6 +852,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... st.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD', np.concatenate (st, axis=0 )), ] + wt = [] + for i in range(n_samples): + ar = SW[i], SS[i], DW[i], DD[i], SD[i] + wt.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD warped', np.concatenate (wt, axis=0 )), ] st_m = [] for i in range(n_samples): @@ -875,6 +887,23 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... st.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD pred', np.concatenate (st, axis=0 )), ] + wt = [] + for i in range(n_samples): + ar = SW[i], SS[i] + wt.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD warped src-src', np.concatenate (wt, axis=0 )), ] + + wt = [] + for i in range(n_samples): + ar = DW[i], DD[i] + wt.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD warped dst-dst', np.concatenate (wt, axis=0 )), ] + + wt = [] + for i in range(n_samples): + ar = DW[i], SD[i] + wt.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD warped pred', np.concatenate (wt, axis=0 )), ] st_m = [] for i in range(n_samples): diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py index 4df44a3..1f1b96f 100644 --- a/samplelib/SampleProcessor.py +++ b/samplelib/SampleProcessor.py @@ -112,6 +112,7 @@ class SampleProcessor(object): nearest_resize_to = opts.get('nearest_resize_to', None) warp = opts.get('warp', False) transform = opts.get('transform', False) + random_downsample = opts.get('random_downsample', False) motion_blur = opts.get('motion_blur', None) gaussian_blur = opts.get('gaussian_blur', None) random_bilinear_resize = opts.get('random_bilinear_resize', None) @@ -213,6 +214,11 @@ class SampleProcessor(object): ct_sample_bgr = ct_sample.load_bgr() img = imagelib.color_transfer (ct_mode, img, cv2.resize( ct_sample_bgr, (resolution,resolution), interpolation=cv2.INTER_LINEAR ) ) + # Apply random downsampling + if random_downsample: + down_res = np.random.randint(int(0.125*resolution), int(0.25*resolution)) + img = cv2.resize(img, (down_res, down_res), interpolation=cv2.INTER_CUBIC) + img = cv2.resize(img, (resolution, resolution), interpolation=cv2.INTER_CUBIC) img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate) img = np.clip(img.astype(np.float32), 0, 1)