upd liae loss

This commit is contained in:
Colombo 2019-10-06 10:06:22 +04:00
parent e4a360e5ff
commit 4c2cb44643

View file

@ -390,23 +390,23 @@ class SAEv2Model(ModelBase):
self.target_srcm, self.target_dstm = Input(mask_shape), Input(mask_shape)
warped_src_code = self.encoder (self.warped_src)
self.src_code = warped_src_inter_AB_code = self.inter_AB (warped_src_code)
src_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
warped_src_inter_AB_code = self.inter_AB (warped_src_code)
self.src_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
warped_dst_code = self.encoder (self.warped_dst)
self.dst_code = warped_dst_inter_B_code = self.inter_B (warped_dst_code)
warped_dst_inter_B_code = self.inter_B (warped_dst_code)
warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
dst_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
self.dst_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
src_dst_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
self.pred_src_src = self.decoder(src_code)
self.pred_dst_dst = self.decoder(dst_code)
self.pred_src_src = self.decoder(self.src_code)
self.pred_dst_dst = self.decoder(self.dst_code)
self.pred_src_dst = self.decoder(src_dst_code)
if learn_mask:
self.pred_src_srcm = self.decoderm(src_code)
self.pred_dst_dstm = self.decoderm(dst_code)
self.pred_src_srcm = self.decoderm(self.src_code)
self.pred_dst_dstm = self.decoderm(self.dst_code)
self.pred_src_dstm = self.decoderm(src_dst_code)
def get_model_filename_list(self, exclude_for_pretrain=False):