diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 0040cf7..02c0581 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -282,6 +282,9 @@ class SAEHDModel(ModelBase): weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() \ + self.res0.get_weights() + self.res1.get_weights() + self.res2.get_weights() + self.out_conv.get_weights() + + if self.is_hd: + weights += self.res3.get_weights() if include_mask: weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() \