diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 478fad4..78221c0 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -387,9 +387,9 @@ class SAEModel(ModelBase): #override def onGetPreview(self, sample): test_S = sample[0][1][0:4] #first 4 samples - test_S_m = sample[0][2][0:4] #first 4 samples + test_S_m = sample[0][1+self.ms_count][0:4] #first 4 samples test_D = sample[1][1][0:4] - test_D_m = sample[1][2][0:4] + test_D_m = sample[1][1+self.ms_count][0:4] if self.options['learn_mask']: S, D, SS, DD, DDM, SD, SDM = [ np.clip(x, 0.0, 1.0) for x in ([test_S,test_D] + self.AE_view ([test_S, test_D]) ) ]