mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-14 00:53:48 -07:00
upd liae loss
This commit is contained in:
parent
e4a360e5ff
commit
4c2cb44643
1 changed files with 8 additions and 8 deletions
|
@ -390,23 +390,23 @@ class SAEv2Model(ModelBase):
|
||||||
self.target_srcm, self.target_dstm = Input(mask_shape), Input(mask_shape)
|
self.target_srcm, self.target_dstm = Input(mask_shape), Input(mask_shape)
|
||||||
|
|
||||||
warped_src_code = self.encoder (self.warped_src)
|
warped_src_code = self.encoder (self.warped_src)
|
||||||
self.src_code = warped_src_inter_AB_code = self.inter_AB (warped_src_code)
|
warped_src_inter_AB_code = self.inter_AB (warped_src_code)
|
||||||
src_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
|
self.src_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
|
||||||
|
|
||||||
warped_dst_code = self.encoder (self.warped_dst)
|
warped_dst_code = self.encoder (self.warped_dst)
|
||||||
self.dst_code = warped_dst_inter_B_code = self.inter_B (warped_dst_code)
|
warped_dst_inter_B_code = self.inter_B (warped_dst_code)
|
||||||
warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
|
warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
|
||||||
dst_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
|
self.dst_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
|
||||||
|
|
||||||
src_dst_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
|
src_dst_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
|
||||||
|
|
||||||
self.pred_src_src = self.decoder(src_code)
|
self.pred_src_src = self.decoder(self.src_code)
|
||||||
self.pred_dst_dst = self.decoder(dst_code)
|
self.pred_dst_dst = self.decoder(self.dst_code)
|
||||||
self.pred_src_dst = self.decoder(src_dst_code)
|
self.pred_src_dst = self.decoder(src_dst_code)
|
||||||
|
|
||||||
if learn_mask:
|
if learn_mask:
|
||||||
self.pred_src_srcm = self.decoderm(src_code)
|
self.pred_src_srcm = self.decoderm(self.src_code)
|
||||||
self.pred_dst_dstm = self.decoderm(dst_code)
|
self.pred_dst_dstm = self.decoderm(self.dst_code)
|
||||||
self.pred_src_dstm = self.decoderm(src_dst_code)
|
self.pred_src_dstm = self.decoderm(src_dst_code)
|
||||||
|
|
||||||
def get_model_filename_list(self, exclude_for_pretrain=False):
|
def get_model_filename_list(self, exclude_for_pretrain=False):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue