SAE: added 'learn mask' option, default:Y,

SAE: added clipping mask at borders to remove artifact lines.
This commit is contained in:
iperov 2019-01-19 16:15:49 +04:00
parent dd10d963d1
commit fe66b6b2a1
2 changed files with 98 additions and 56 deletions

View file

@ -15,13 +15,14 @@ class ConverterMasked(ConverterBase):
face_type=FaceType.FULL,
base_erode_mask_modifier = 0,
base_blur_mask_modifier = 0,
clip_border_mask_per = 0,
**in_options):
super().__init__(predictor)
self.predictor_input_size = predictor_input_size
self.output_size = output_size
self.face_type = face_type
self.clip_border_mask_per = clip_border_mask_per
self.TFLabConverter = None
mode = input_int ("Choose mode: (1) overlay, (2) hist match, (3) hist match bw, (4) seamless (default), (5) seamless hist match, (6) raw : ", 4)
@ -182,6 +183,16 @@ class ConverterMasked(ConverterBase):
img_mask_blurry_aaa = np.clip( img_mask_blurry_aaa, 0, 1.0 )
if self.clip_border_mask_per > 0:
prd_border_rect_mask_a = np.ones ( prd_face_mask_a.shape, dtype=prd_face_mask_a.dtype)
prd_border_size = int ( prd_border_rect_mask_a.shape[1] * self.clip_border_mask_per )
prd_border_rect_mask_a[0:prd_border_size,:,:] = 0
prd_border_rect_mask_a[-prd_border_size:,:,:] = 0
prd_border_rect_mask_a[:,0:prd_border_size,:] = 0
prd_border_rect_mask_a[:,-prd_border_size:,:] = 0
prd_border_rect_mask_a = np.expand_dims(cv2.blur(prd_border_rect_mask_a, (prd_border_size, prd_border_size) ),-1)
if self.mode == 'hist-match-bw':
prd_face_bgr = cv2.cvtColor(prd_face_bgr, cv2.COLOR_BGR2GRAY)
prd_face_bgr = np.repeat( np.expand_dims (prd_face_bgr, -1), (3,), -1 )
@ -226,6 +237,11 @@ class ConverterMasked(ConverterBase):
if debug:
debugs += [out_img.copy()]
if self.clip_border_mask_per > 0:
img_prd_border_rect_mask_a = cv2.warpAffine( prd_border_rect_mask_a, face_output_mat, img_size, np.zeros(img_bgr.shape, dtype=np.float32), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, cv2.BORDER_TRANSPARENT )
img_prd_border_rect_mask_a = np.expand_dims (img_prd_border_rect_mask_a, -1)
img_mask_blurry_aaa *= img_prd_border_rect_mask_a
out_img = np.clip( img_bgr*(1-img_mask_blurry_aaa) + (out_img*img_mask_blurry_aaa) , 0, 1.0 )
if self.mode == 'seamless-hist-match':

View file

@ -30,10 +30,12 @@ class SAEModel(ModelBase):
self.options['resolution'] = input_int("Resolution (64,128 ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, sacrificing overall quality.")
self.options['learn_mask'] = input_bool ("Learn mask? (y/n, ?:help skip:y) : ", True, help_message="Choose NO to reduce model size. In this case converter forced to use 'not predicted mask' that is not smooth as predicted. Styled SAE can learn without mask and produce same quality fake if you choose high blur value in converter.")
else:
self.options['resolution'] = self.options.get('resolution', default_resolution)
self.options['archi'] = self.options.get('archi', default_archi)
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
self.options['learn_mask'] = self.options.get('learn_mask', True)
default_face_style_power = 10.0
if is_first_run or ask_override:
@ -101,14 +103,17 @@ class SAEModel(ModelBase):
inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (inter_output_Inputs)
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)
if self.options['learn_mask']:
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)
if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5))
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5))
if self.options['learn_mask']:
self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5))
warped_src_code = self.encoder (warped_src)
@ -116,19 +121,23 @@ class SAEModel(ModelBase):
warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
pred_src_src = self.decoder(warped_src_inter_code)
pred_src_srcm = self.decoderm(warped_src_inter_code)
if self.options['learn_mask']:
pred_src_srcm = self.decoderm(warped_src_inter_code)
warped_dst_code = self.encoder (warped_dst)
warped_dst_inter_B_code = self.inter_B (warped_dst_code)
warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
pred_dst_dst = self.decoder(warped_dst_inter_code)
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
if self.options['learn_mask']:
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
pred_src_dst = self.decoder(warped_src_dst_inter_code)
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
if self.options['learn_mask']:
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
else:
self.encoder = modelify(SAEModel.DFEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
@ -136,28 +145,28 @@ class SAEModel(ModelBase):
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
if self.options['learn_mask']:
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5))
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
if self.options['learn_mask']:
self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5))
self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
warped_src_code = self.encoder (warped_src)
pred_src_src = self.decoder_src(warped_src_code)
pred_src_srcm = self.decoder_srcm(warped_src_code)
warped_dst_code = self.encoder (warped_dst)
pred_src_src = self.decoder_src(warped_src_code)
pred_dst_dst = self.decoder_dst(warped_dst_code)
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
pred_src_dst = self.decoder_src(warped_dst_code)
pred_src_dstm = self.decoder_srcm(warped_dst_code)
if self.options['learn_mask']:
pred_src_srcm = self.decoder_srcm(warped_src_code)
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
pred_src_dstm = self.decoder_srcm(warped_dst_code)
ms_count = len(pred_src_src)
@ -195,14 +204,16 @@ class SAEModel(ModelBase):
if self.options['archi'] == 'liae':
src_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
src_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
if self.options['learn_mask']:
src_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
else:
src_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
if self.options['learn_mask']:
src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
@ -216,18 +227,23 @@ class SAEModel(ModelBase):
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) )
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], optimizer().get_updates(src_mask_loss, src_mask_loss_train_weights) )
dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) )
dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], optimizer().get_updates(dst_mask_loss, dst_mask_loss_train_weights) )
if self.options['learn_mask']:
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], optimizer().get_updates(src_mask_loss, src_mask_loss_train_weights) )
dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], optimizer().get_updates(dst_mask_loss, dst_mask_loss_train_weights) )
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] )
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_src_srcm[-1], pred_dst_dst[-1], pred_dst_dstm[-1], pred_src_dst[-1], pred_src_dstm[-1]] )
else:
self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_src_dstm[-1] ])
if self.options['learn_mask']:
self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_src_dstm[-1] ])
else:
self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1] ])
if self.is_training_mode:
f = SampleProcessor.TypeFlags
@ -251,19 +267,23 @@ class SAEModel(ModelBase):
#override
def onSave(self):
if self.options['archi'] == 'liae':
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
[self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)],
[self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)],
[self.decoder, self.get_strpath_storage_for_file(self.decoderH5)],
[self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)],
] )
ar = [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
[self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)],
[self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)],
[self.decoder, self.get_strpath_storage_for_file(self.decoderH5)]
]
if self.options['learn_mask']:
ar += [ [self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)] ]
else:
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
[self.decoder_srcm, self.get_strpath_storage_for_file(self.decoder_srcmH5)],
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)],
[self.decoder_dstm, self.get_strpath_storage_for_file(self.decoder_dstmH5)],
] )
ar = [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]
]
if self.options['learn_mask']:
ar += [ [self.decoder_srcm, self.get_strpath_storage_for_file(self.decoder_srcmH5)],
[self.decoder_dstm, self.get_strpath_storage_for_file(self.decoder_dstmH5)] ]
self.save_weights_safe(ar)
#override
def onTrainOneEpoch(self, sample):
@ -273,8 +293,9 @@ class SAEModel(ModelBase):
src_loss, = self.src_train ([warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
dst_loss, = self.dst_train ([warped_dst, target_dst, target_dst_mask])
src_mask_loss, = self.src_mask_train ([warped_src, target_src_mask])
dst_mask_loss, = self.dst_mask_train ([warped_dst, target_dst_mask])
if self.options['learn_mask']:
src_mask_loss, = self.src_mask_train ([warped_src, target_src_mask])
dst_mask_loss, = self.dst_mask_train ([warped_dst, target_dst_mask])
return ( ('src_loss', src_loss), ('dst_loss', dst_loss) )
@ -286,7 +307,7 @@ class SAEModel(ModelBase):
test_B = sample[1][1][0:4]
test_B_m = sample[1][2][0:4]
S, D, SS, SM, DD, DM, SD, SDM, = [ x / 2 + 0.5 for x in ([test_A,test_B] + self.AE_view ([test_A, test_B]) ) ]
S, D, SS, DD, SD, = [ x / 2 + 0.5 for x in ([test_A,test_B] + self.AE_view ([test_A, test_B]) ) ]
#SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ]
st_x3 = []
@ -300,10 +321,14 @@ class SAEModel(ModelBase):
return [ ('SAE', np.concatenate (st_x3, axis=0 )), ]
def predictor_func (self, face):
face = face * 2.0 - 1.0
face_128_bgr = face[...,0:3]
x, mx = [ (x[0] + 1.0) / 2.0 for x in self.AE_convert ( [ np.expand_dims(face_128_bgr,0) ] ) ]
return np.concatenate ( (x,mx), -1 )
face_tanh = face * 2.0 - 1.0
face_bgr = face_tanh[...,0:3]
prd = [ (x[0] + 1.0) / 2.0 for x in self.AE_convert ( [ np.expand_dims(face_bgr,0) ] ) ]
if not self.options['learn_mask']:
prd += [ np.expand_dims(face[...,3],-1) ]
return np.concatenate ( [prd[0], prd[1]], -1 )
#override
def get_converter(self, **in_options):
@ -320,6 +345,7 @@ class SAEModel(ModelBase):
face_type=face_type,
base_erode_mask_modifier=base_erode_mask_modifier,
base_blur_mask_modifier=base_blur_mask_modifier,
clip_border_mask_per=0.03125,
**in_options)
@staticmethod