diff --git a/models/Model_SAEv2/Model.py b/models/Model_SAEv2/Model.py index 78d311c..ceed8f7 100644 --- a/models/Model_SAEv2/Model.py +++ b/models/Model_SAEv2/Model.py @@ -390,23 +390,23 @@ class SAEv2Model(ModelBase): self.target_srcm, self.target_dstm = Input(mask_shape), Input(mask_shape) warped_src_code = self.encoder (self.warped_src) - self.src_code = warped_src_inter_AB_code = self.inter_AB (warped_src_code) - src_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code]) + warped_src_inter_AB_code = self.inter_AB (warped_src_code) + self.src_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code]) warped_dst_code = self.encoder (self.warped_dst) - self.dst_code = warped_dst_inter_B_code = self.inter_B (warped_dst_code) + warped_dst_inter_B_code = self.inter_B (warped_dst_code) warped_dst_inter_AB_code = self.inter_AB (warped_dst_code) - dst_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code]) + self.dst_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code]) src_dst_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code]) - self.pred_src_src = self.decoder(src_code) - self.pred_dst_dst = self.decoder(dst_code) + self.pred_src_src = self.decoder(self.src_code) + self.pred_dst_dst = self.decoder(self.dst_code) self.pred_src_dst = self.decoder(src_dst_code) if learn_mask: - self.pred_src_srcm = self.decoderm(src_code) - self.pred_dst_dstm = self.decoderm(dst_code) + self.pred_src_srcm = self.decoderm(self.src_code) + self.pred_dst_dstm = self.decoderm(self.dst_code) self.pred_src_dstm = self.decoderm(src_dst_code) def get_model_filename_list(self, exclude_for_pretrain=False):