diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index b434f71..4044ee9 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -436,7 +436,8 @@ class SAEModel(ModelBase): else: if self.options['learn_mask']: - self.AE_convert = K.function([warped_dst], [pred_src_dst[-1], pred_dst_dstm[-1], pred_src_dstm[-1]]) + from keras import K + self.AE_convert = K.function([warped_dst], [pred_src_dst[-1][:4], pred_dst_dstm[-1], pred_src_dstm[-1]]) else: self.AE_convert = K.function([warped_dst], [pred_src_dst[-1]])