diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py index f2b3d64..af6f22e 100644 --- a/models/Model_AMP/Model.py +++ b/models/Model_AMP/Model.py @@ -653,7 +653,7 @@ class AMPModel(ModelBase): for i in range(bs): self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) ) - self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dst[i], target_dstm[i], target_dstm_em[i]) ) + self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) ) if len(self.last_src_samples_loss) >= bs*16: src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)