diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py index 0a5bc18..e3ed27d 100644 --- a/models/Model_AMP/Model.py +++ b/models/Model_AMP/Model.py @@ -622,24 +622,22 @@ class AMPModel(ModelBase): src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) for i in range(bs): - self.last_src_samples_loss.append ( (src_loss[i], warped_src[i], target_src[i], target_srcm[i], target_srcm_em[i]) ) - self.last_dst_samples_loss.append ( (dst_loss[i], warped_dst[i], target_dst[i], target_dstm[i], target_dstm_em[i]) ) + self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) ) + self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) ) if len(self.last_src_samples_loss) >= bs*16: src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True) dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True) - warped_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] ) - target_src = np.stack( [ x[2] for x in src_samples_loss[:bs] ] ) - target_srcm = np.stack( [ x[3] for x in src_samples_loss[:bs] ] ) - target_srcm_em = np.stack( [ x[4] for x in src_samples_loss[:bs] ] ) + target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] ) + target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] ) + target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] ) - warped_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] ) - target_dst = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] ) - target_dstm = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] ) - target_dstm_em = np.stack( [ x[4] for x in dst_samples_loss[:bs] ] ) + target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] ) + target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] ) + target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] ) - src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em) self.last_src_samples_loss = [] self.last_dst_samples_loss = []