diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py index 1588f56..2c60741 100644 --- a/models/Model_AMP/Model.py +++ b/models/Model_AMP/Model.py @@ -207,7 +207,7 @@ class AMPModel(ModelBase): class Inter(nn.ModelBase): def on_build(self): - self.dense2 = nn.Dense( ae_dims, lowest_dense_res * lowest_dense_res * ae_dims ) + self.dense2 = nn.Dense( ae_dims, lowest_dense_res * lowest_dense_res * inter_dims ) def forward(self, inp): x = inp