mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
SAE: added 'learn mask' option, default:Y,
SAE: added clipping mask at borders to remove artifact lines.
This commit is contained in:
parent
dd10d963d1
commit
fe66b6b2a1
2 changed files with 98 additions and 56 deletions
|
@ -15,13 +15,14 @@ class ConverterMasked(ConverterBase):
|
||||||
face_type=FaceType.FULL,
|
face_type=FaceType.FULL,
|
||||||
base_erode_mask_modifier = 0,
|
base_erode_mask_modifier = 0,
|
||||||
base_blur_mask_modifier = 0,
|
base_blur_mask_modifier = 0,
|
||||||
|
clip_border_mask_per = 0,
|
||||||
**in_options):
|
**in_options):
|
||||||
|
|
||||||
super().__init__(predictor)
|
super().__init__(predictor)
|
||||||
self.predictor_input_size = predictor_input_size
|
self.predictor_input_size = predictor_input_size
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
self.face_type = face_type
|
self.face_type = face_type
|
||||||
|
self.clip_border_mask_per = clip_border_mask_per
|
||||||
self.TFLabConverter = None
|
self.TFLabConverter = None
|
||||||
|
|
||||||
mode = input_int ("Choose mode: (1) overlay, (2) hist match, (3) hist match bw, (4) seamless (default), (5) seamless hist match, (6) raw : ", 4)
|
mode = input_int ("Choose mode: (1) overlay, (2) hist match, (3) hist match bw, (4) seamless (default), (5) seamless hist match, (6) raw : ", 4)
|
||||||
|
@ -182,6 +183,16 @@ class ConverterMasked(ConverterBase):
|
||||||
|
|
||||||
img_mask_blurry_aaa = np.clip( img_mask_blurry_aaa, 0, 1.0 )
|
img_mask_blurry_aaa = np.clip( img_mask_blurry_aaa, 0, 1.0 )
|
||||||
|
|
||||||
|
if self.clip_border_mask_per > 0:
|
||||||
|
prd_border_rect_mask_a = np.ones ( prd_face_mask_a.shape, dtype=prd_face_mask_a.dtype)
|
||||||
|
prd_border_size = int ( prd_border_rect_mask_a.shape[1] * self.clip_border_mask_per )
|
||||||
|
|
||||||
|
prd_border_rect_mask_a[0:prd_border_size,:,:] = 0
|
||||||
|
prd_border_rect_mask_a[-prd_border_size:,:,:] = 0
|
||||||
|
prd_border_rect_mask_a[:,0:prd_border_size,:] = 0
|
||||||
|
prd_border_rect_mask_a[:,-prd_border_size:,:] = 0
|
||||||
|
prd_border_rect_mask_a = np.expand_dims(cv2.blur(prd_border_rect_mask_a, (prd_border_size, prd_border_size) ),-1)
|
||||||
|
|
||||||
if self.mode == 'hist-match-bw':
|
if self.mode == 'hist-match-bw':
|
||||||
prd_face_bgr = cv2.cvtColor(prd_face_bgr, cv2.COLOR_BGR2GRAY)
|
prd_face_bgr = cv2.cvtColor(prd_face_bgr, cv2.COLOR_BGR2GRAY)
|
||||||
prd_face_bgr = np.repeat( np.expand_dims (prd_face_bgr, -1), (3,), -1 )
|
prd_face_bgr = np.repeat( np.expand_dims (prd_face_bgr, -1), (3,), -1 )
|
||||||
|
@ -226,6 +237,11 @@ class ConverterMasked(ConverterBase):
|
||||||
if debug:
|
if debug:
|
||||||
debugs += [out_img.copy()]
|
debugs += [out_img.copy()]
|
||||||
|
|
||||||
|
if self.clip_border_mask_per > 0:
|
||||||
|
img_prd_border_rect_mask_a = cv2.warpAffine( prd_border_rect_mask_a, face_output_mat, img_size, np.zeros(img_bgr.shape, dtype=np.float32), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, cv2.BORDER_TRANSPARENT )
|
||||||
|
img_prd_border_rect_mask_a = np.expand_dims (img_prd_border_rect_mask_a, -1)
|
||||||
|
img_mask_blurry_aaa *= img_prd_border_rect_mask_a
|
||||||
|
|
||||||
out_img = np.clip( img_bgr*(1-img_mask_blurry_aaa) + (out_img*img_mask_blurry_aaa) , 0, 1.0 )
|
out_img = np.clip( img_bgr*(1-img_mask_blurry_aaa) + (out_img*img_mask_blurry_aaa) , 0, 1.0 )
|
||||||
|
|
||||||
if self.mode == 'seamless-hist-match':
|
if self.mode == 'seamless-hist-match':
|
||||||
|
|
|
@ -30,10 +30,12 @@ class SAEModel(ModelBase):
|
||||||
self.options['resolution'] = input_int("Resolution (64,128 ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
|
self.options['resolution'] = input_int("Resolution (64,128 ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
|
||||||
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
|
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
|
||||||
self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, sacrificing overall quality.")
|
self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, sacrificing overall quality.")
|
||||||
|
self.options['learn_mask'] = input_bool ("Learn mask? (y/n, ?:help skip:y) : ", True, help_message="Choose NO to reduce model size. In this case converter forced to use 'not predicted mask' that is not smooth as predicted. Styled SAE can learn without mask and produce same quality fake if you choose high blur value in converter.")
|
||||||
else:
|
else:
|
||||||
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
||||||
self.options['archi'] = self.options.get('archi', default_archi)
|
self.options['archi'] = self.options.get('archi', default_archi)
|
||||||
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
|
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
|
||||||
|
self.options['learn_mask'] = self.options.get('learn_mask', True)
|
||||||
|
|
||||||
default_face_style_power = 10.0
|
default_face_style_power = 10.0
|
||||||
if is_first_run or ask_override:
|
if is_first_run or ask_override:
|
||||||
|
@ -101,14 +103,17 @@ class SAEModel(ModelBase):
|
||||||
inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
|
inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
|
||||||
|
|
||||||
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (inter_output_Inputs)
|
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (inter_output_Inputs)
|
||||||
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)
|
|
||||||
|
if self.options['learn_mask']:
|
||||||
|
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)
|
||||||
|
|
||||||
if not self.is_first_run():
|
if not self.is_first_run():
|
||||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||||
self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
|
self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
|
||||||
self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5))
|
self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5))
|
||||||
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
|
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
|
||||||
self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5))
|
if self.options['learn_mask']:
|
||||||
|
self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5))
|
||||||
|
|
||||||
warped_src_code = self.encoder (warped_src)
|
warped_src_code = self.encoder (warped_src)
|
||||||
|
|
||||||
|
@ -116,19 +121,23 @@ class SAEModel(ModelBase):
|
||||||
warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
|
warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
|
||||||
|
|
||||||
pred_src_src = self.decoder(warped_src_inter_code)
|
pred_src_src = self.decoder(warped_src_inter_code)
|
||||||
pred_src_srcm = self.decoderm(warped_src_inter_code)
|
if self.options['learn_mask']:
|
||||||
|
pred_src_srcm = self.decoderm(warped_src_inter_code)
|
||||||
|
|
||||||
warped_dst_code = self.encoder (warped_dst)
|
warped_dst_code = self.encoder (warped_dst)
|
||||||
warped_dst_inter_B_code = self.inter_B (warped_dst_code)
|
warped_dst_inter_B_code = self.inter_B (warped_dst_code)
|
||||||
warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
|
warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
|
||||||
warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
|
warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
|
||||||
pred_dst_dst = self.decoder(warped_dst_inter_code)
|
pred_dst_dst = self.decoder(warped_dst_inter_code)
|
||||||
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
|
|
||||||
|
if self.options['learn_mask']:
|
||||||
|
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
|
||||||
|
|
||||||
warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
|
warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
|
||||||
pred_src_dst = self.decoder(warped_src_dst_inter_code)
|
pred_src_dst = self.decoder(warped_src_dst_inter_code)
|
||||||
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
|
||||||
|
if self.options['learn_mask']:
|
||||||
|
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
||||||
else:
|
else:
|
||||||
self.encoder = modelify(SAEModel.DFEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
|
self.encoder = modelify(SAEModel.DFEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
|
||||||
|
|
||||||
|
@ -136,28 +145,28 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
|
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
|
||||||
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
|
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
|
||||||
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
|
|
||||||
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
|
if self.options['learn_mask']:
|
||||||
|
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
|
||||||
|
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
|
||||||
|
|
||||||
if not self.is_first_run():
|
if not self.is_first_run():
|
||||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||||
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
|
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
|
||||||
self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5))
|
|
||||||
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
|
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
|
||||||
self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
|
if self.options['learn_mask']:
|
||||||
|
self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5))
|
||||||
|
self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
|
||||||
|
|
||||||
warped_src_code = self.encoder (warped_src)
|
warped_src_code = self.encoder (warped_src)
|
||||||
|
|
||||||
pred_src_src = self.decoder_src(warped_src_code)
|
|
||||||
pred_src_srcm = self.decoder_srcm(warped_src_code)
|
|
||||||
|
|
||||||
warped_dst_code = self.encoder (warped_dst)
|
warped_dst_code = self.encoder (warped_dst)
|
||||||
|
pred_src_src = self.decoder_src(warped_src_code)
|
||||||
pred_dst_dst = self.decoder_dst(warped_dst_code)
|
pred_dst_dst = self.decoder_dst(warped_dst_code)
|
||||||
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
|
|
||||||
|
|
||||||
pred_src_dst = self.decoder_src(warped_dst_code)
|
pred_src_dst = self.decoder_src(warped_dst_code)
|
||||||
pred_src_dstm = self.decoder_srcm(warped_dst_code)
|
if self.options['learn_mask']:
|
||||||
|
pred_src_srcm = self.decoder_srcm(warped_src_code)
|
||||||
|
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
|
||||||
|
pred_src_dstm = self.decoder_srcm(warped_dst_code)
|
||||||
|
|
||||||
ms_count = len(pred_src_src)
|
ms_count = len(pred_src_src)
|
||||||
|
|
||||||
|
@ -195,14 +204,16 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
if self.options['archi'] == 'liae':
|
if self.options['archi'] == 'liae':
|
||||||
src_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
src_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||||
src_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
|
||||||
dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||||
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
if self.options['learn_mask']:
|
||||||
|
src_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
||||||
|
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
||||||
else:
|
else:
|
||||||
src_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
|
src_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
|
||||||
src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
|
|
||||||
dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
|
dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
|
||||||
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
|
if self.options['learn_mask']:
|
||||||
|
src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
|
||||||
|
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
|
||||||
|
|
||||||
src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
|
src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
|
||||||
|
|
||||||
|
@ -216,18 +227,23 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) )
|
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) )
|
||||||
|
|
||||||
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
|
|
||||||
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], optimizer().get_updates(src_mask_loss, src_mask_loss_train_weights) )
|
|
||||||
|
|
||||||
dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
|
dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
|
||||||
self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) )
|
self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) )
|
||||||
|
|
||||||
dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
|
if self.options['learn_mask']:
|
||||||
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], optimizer().get_updates(dst_mask_loss, dst_mask_loss_train_weights) )
|
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
|
||||||
|
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], optimizer().get_updates(src_mask_loss, src_mask_loss_train_weights) )
|
||||||
|
|
||||||
|
dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
|
||||||
|
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], optimizer().get_updates(dst_mask_loss, dst_mask_loss_train_weights) )
|
||||||
|
|
||||||
|
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] )
|
||||||
|
|
||||||
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_src_srcm[-1], pred_dst_dst[-1], pred_dst_dstm[-1], pred_src_dst[-1], pred_src_dstm[-1]] )
|
|
||||||
else:
|
else:
|
||||||
self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_src_dstm[-1] ])
|
if self.options['learn_mask']:
|
||||||
|
self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_src_dstm[-1] ])
|
||||||
|
else:
|
||||||
|
self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1] ])
|
||||||
|
|
||||||
if self.is_training_mode:
|
if self.is_training_mode:
|
||||||
f = SampleProcessor.TypeFlags
|
f = SampleProcessor.TypeFlags
|
||||||
|
@ -251,19 +267,23 @@ class SAEModel(ModelBase):
|
||||||
#override
|
#override
|
||||||
def onSave(self):
|
def onSave(self):
|
||||||
if self.options['archi'] == 'liae':
|
if self.options['archi'] == 'liae':
|
||||||
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
ar = [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
||||||
[self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)],
|
[self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)],
|
||||||
[self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)],
|
[self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)],
|
||||||
[self.decoder, self.get_strpath_storage_for_file(self.decoderH5)],
|
[self.decoder, self.get_strpath_storage_for_file(self.decoderH5)]
|
||||||
[self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)],
|
]
|
||||||
] )
|
if self.options['learn_mask']:
|
||||||
|
ar += [ [self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)] ]
|
||||||
else:
|
else:
|
||||||
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
ar = [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
|
||||||
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
|
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
|
||||||
[self.decoder_srcm, self.get_strpath_storage_for_file(self.decoder_srcmH5)],
|
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]
|
||||||
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)],
|
]
|
||||||
[self.decoder_dstm, self.get_strpath_storage_for_file(self.decoder_dstmH5)],
|
if self.options['learn_mask']:
|
||||||
] )
|
ar += [ [self.decoder_srcm, self.get_strpath_storage_for_file(self.decoder_srcmH5)],
|
||||||
|
[self.decoder_dstm, self.get_strpath_storage_for_file(self.decoder_dstmH5)] ]
|
||||||
|
|
||||||
|
self.save_weights_safe(ar)
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onTrainOneEpoch(self, sample):
|
def onTrainOneEpoch(self, sample):
|
||||||
|
@ -273,8 +293,9 @@ class SAEModel(ModelBase):
|
||||||
src_loss, = self.src_train ([warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
|
src_loss, = self.src_train ([warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
|
||||||
dst_loss, = self.dst_train ([warped_dst, target_dst, target_dst_mask])
|
dst_loss, = self.dst_train ([warped_dst, target_dst, target_dst_mask])
|
||||||
|
|
||||||
src_mask_loss, = self.src_mask_train ([warped_src, target_src_mask])
|
if self.options['learn_mask']:
|
||||||
dst_mask_loss, = self.dst_mask_train ([warped_dst, target_dst_mask])
|
src_mask_loss, = self.src_mask_train ([warped_src, target_src_mask])
|
||||||
|
dst_mask_loss, = self.dst_mask_train ([warped_dst, target_dst_mask])
|
||||||
|
|
||||||
return ( ('src_loss', src_loss), ('dst_loss', dst_loss) )
|
return ( ('src_loss', src_loss), ('dst_loss', dst_loss) )
|
||||||
|
|
||||||
|
@ -286,7 +307,7 @@ class SAEModel(ModelBase):
|
||||||
test_B = sample[1][1][0:4]
|
test_B = sample[1][1][0:4]
|
||||||
test_B_m = sample[1][2][0:4]
|
test_B_m = sample[1][2][0:4]
|
||||||
|
|
||||||
S, D, SS, SM, DD, DM, SD, SDM, = [ x / 2 + 0.5 for x in ([test_A,test_B] + self.AE_view ([test_A, test_B]) ) ]
|
S, D, SS, DD, SD, = [ x / 2 + 0.5 for x in ([test_A,test_B] + self.AE_view ([test_A, test_B]) ) ]
|
||||||
#SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ]
|
#SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ]
|
||||||
|
|
||||||
st_x3 = []
|
st_x3 = []
|
||||||
|
@ -300,10 +321,14 @@ class SAEModel(ModelBase):
|
||||||
return [ ('SAE', np.concatenate (st_x3, axis=0 )), ]
|
return [ ('SAE', np.concatenate (st_x3, axis=0 )), ]
|
||||||
|
|
||||||
def predictor_func (self, face):
|
def predictor_func (self, face):
|
||||||
face = face * 2.0 - 1.0
|
face_tanh = face * 2.0 - 1.0
|
||||||
face_128_bgr = face[...,0:3]
|
face_bgr = face_tanh[...,0:3]
|
||||||
x, mx = [ (x[0] + 1.0) / 2.0 for x in self.AE_convert ( [ np.expand_dims(face_128_bgr,0) ] ) ]
|
prd = [ (x[0] + 1.0) / 2.0 for x in self.AE_convert ( [ np.expand_dims(face_bgr,0) ] ) ]
|
||||||
return np.concatenate ( (x,mx), -1 )
|
|
||||||
|
if not self.options['learn_mask']:
|
||||||
|
prd += [ np.expand_dims(face[...,3],-1) ]
|
||||||
|
|
||||||
|
return np.concatenate ( [prd[0], prd[1]], -1 )
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def get_converter(self, **in_options):
|
def get_converter(self, **in_options):
|
||||||
|
@ -320,6 +345,7 @@ class SAEModel(ModelBase):
|
||||||
face_type=face_type,
|
face_type=face_type,
|
||||||
base_erode_mask_modifier=base_erode_mask_modifier,
|
base_erode_mask_modifier=base_erode_mask_modifier,
|
||||||
base_blur_mask_modifier=base_blur_mask_modifier,
|
base_blur_mask_modifier=base_blur_mask_modifier,
|
||||||
|
clip_border_mask_per=0.03125,
|
||||||
**in_options)
|
**in_options)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue