From 248ab19e4be780ab71c6b1d3ea14635528c4c76e Mon Sep 17 00:00:00 2001 From: Cioscos Date: Sun, 3 Oct 2021 18:13:26 +0200 Subject: [PATCH] Fixed retraining_samples --- models/Model_SAEHD/Model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index a1b71f7..b2703af 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -915,13 +915,13 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) - for i in range(bs): - self.last_src_samples_loss.append ( (target_src[i], target_srcm[i], target_srcm_em[i], src_loss[i] ) ) - self.last_dst_samples_loss.append ( (target_dst[i], target_dstm[i], target_dstm_em[i], dst_loss[i] ) ) - if self.options['retraining_samples']: bs = self.get_batch_size() + for i in range(bs): + self.last_src_samples_loss.append ( (target_src[i], target_srcm[i], target_srcm_em[i], src_loss[i] ) ) + self.last_dst_samples_loss.append ( (target_dst[i], target_dstm[i], target_dstm_em[i], dst_loss[i] ) ) + if len(self.last_src_samples_loss) >= bs*16: src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True) dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)