From 5073507b10bb251b7feb83b81e4f9589b0cf9c8e Mon Sep 17 00:00:00 2001 From: jh Date: Thu, 12 Sep 2019 19:40:17 -0700 Subject: [PATCH] Add mask to 2nd column of previews --- models/Model_SAE/Model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 5db9c89..e806782 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -440,8 +440,9 @@ class SAEModel(ModelBase): if self.options['learn_mask']: self.AE_view = K.function([warped_src, warped_dst], - [pred_src_src[-1], pred_dst_dst[-1], pred_dst_dstm[-1], pred_src_dst[-1], - pred_src_dstm[-1]]) + [pred_src_src[-1], pred_src_srcm[-1], + pred_dst_dst[-1], pred_dst_dstm[-1], + pred_src_dst[-1], pred_src_dstm[-1]]) else: self.AE_view = K.function([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1]]) @@ -579,9 +580,9 @@ class SAEModel(ModelBase): test_D_m = sample[1][1+self.ms_count][0:4] if self.options['learn_mask']: - S, D, SS, DD, DDM, SD, SDM = [np.clip(x, 0.0, 1.0) for x in + S, D, SS, SSM, DD, DDM, SD, SDM = [np.clip(x, 0.0, 1.0) for x in ([test_S, test_D] + self.AE_view([test_S, test_D]))] - DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] + SSM, DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [SSM, DDM, SDM] ] else: S, D, SS, DD, SD, = [ np.clip(x, 0.0, 1.0) for x in ([test_S,test_D] + self.AE_view ([test_S, test_D]) ) ] @@ -596,7 +597,7 @@ class SAEModel(ModelBase): if self.options['learn_mask']: st_m = [] for i in range(0, len(test_S)): - ar = S[i]*test_S_m[i], SS[i], D[i]*test_D_m[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i]) + ar = S[i]*test_S_m[i], SS[i]*SSM[i], D[i]*test_D_m[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i]) st_m.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAE masked', np.concatenate (st_m, axis=0 )), ]