Merge pull request #144 from faceshiftlabs/feat/image-degradation-random-downsample

Feat/image degradation random downsample
This commit is contained in:
Jeremy Hummel 2021-05-22 23:47:27 -07:00 committed by GitHub
commit 75bad1ba0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 2 deletions

View file

@@ -56,6 +56,7 @@ class SAEHDModel(ModelBase):
default_loss_function = self.options['loss_function'] = self.load_or_def_option('loss_function', 'SSIM')
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
default_random_downsample = self.options['random_downsample'] = self.load_or_def_option('random_downsample', False)
default_background_power = self.options['background_power'] = self.load_or_def_option('background_power', 0.0)
default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0)
default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0)
@@ -159,6 +160,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.")
self.options['random_downsample'] = io.input_bool("Enable random downsample of samples", default_random_downsample, help_message="")
# self.options['random_noise'] = io.input_bool("Enable random noise added to samples", False, help_message="")
# self.options['random_blur'] = io.input_bool("Enable random blur of samples", False, help_message="")
# self.options['random_jpeg'] = io.input_bool("Enable random jpeg compression of samples", False, help_message="")
self.options['gan_version'] = np.clip (io.input_int("GAN version", default_gan_version, add_info="2 or 3", help_message="Choose GAN version (v2: 7/16/2020, v3: 1/3/2021):"), 2, 3)
if self.options['gan_version'] == 2:
@@ -745,7 +751,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : channel_type, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'random_downsample': self.options['random_downsample'], 'transform':True, 'channel_type' : channel_type, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : channel_type, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
@@ -755,7 +761,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : channel_type, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'random_downsample': self.options['random_downsample'], 'transform':True, 'channel_type' : channel_type, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : channel_type, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
@@ -830,6 +836,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
S, D, SS, SSM, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
SW, DW = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([warped_src,warped_dst]) ]
SSM, DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [SSM, DDM, SDM] ]
target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
@@ -845,6 +852,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
st.append ( np.concatenate ( ar, axis=1) )
result += [ ('SAEHD', np.concatenate (st, axis=0 )), ]
wt = []
for i in range(n_samples):
ar = SW[i], SS[i], DW[i], DD[i], SD[i]
wt.append ( np.concatenate ( ar, axis=1) )
result += [ ('SAEHD warped', np.concatenate (wt, axis=0 )), ]
st_m = []
for i in range(n_samples):
@@ -875,6 +887,23 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
st.append ( np.concatenate ( ar, axis=1) )
result += [ ('SAEHD pred', np.concatenate (st, axis=0 )), ]
wt = []
for i in range(n_samples):
ar = SW[i], SS[i]
wt.append ( np.concatenate ( ar, axis=1) )
result += [ ('SAEHD warped src-src', np.concatenate (wt, axis=0 )), ]
wt = []
for i in range(n_samples):
ar = DW[i], DD[i]
wt.append ( np.concatenate ( ar, axis=1) )
result += [ ('SAEHD warped dst-dst', np.concatenate (wt, axis=0 )), ]
wt = []
for i in range(n_samples):
ar = DW[i], SD[i]
wt.append ( np.concatenate ( ar, axis=1) )
result += [ ('SAEHD warped pred', np.concatenate (wt, axis=0 )), ]
st_m = []
for i in range(n_samples):

View file

@@ -112,6 +112,7 @@ class SampleProcessor(object):
nearest_resize_to = opts.get('nearest_resize_to', None)
warp = opts.get('warp', False)
transform = opts.get('transform', False)
random_downsample = opts.get('random_downsample', False)
motion_blur = opts.get('motion_blur', None)
gaussian_blur = opts.get('gaussian_blur', None)
random_bilinear_resize = opts.get('random_bilinear_resize', None)
@@ -213,6 +214,11 @@ class SampleProcessor(object):
ct_sample_bgr = ct_sample.load_bgr()
img = imagelib.color_transfer (ct_mode, img, cv2.resize( ct_sample_bgr, (resolution,resolution), interpolation=cv2.INTER_LINEAR ) )
# Apply random downsampling
if random_downsample:
down_res = np.random.randint(int(0.125*resolution), int(0.25*resolution))
img = cv2.resize(img, (down_res, down_res), interpolation=cv2.INTER_CUBIC)
img = cv2.resize(img, (resolution, resolution), interpolation=cv2.INTER_CUBIC)
img = imagelib.warp_by_params (params_per_resolution[resolution], img, warp, transform, can_flip=True, border_replicate=border_replicate)
img = np.clip(img.astype(np.float32), 0, 1)