diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 16630da..33eaf1d 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -580,7 +580,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... x = tf.cast(x, tf.float32) x = tf.math.scalar_mul(1-smoothing, x) # x = x + (smoothing/num_labels) - x = tf.reshape(x, (self.batch_size,) + tensor.shape[1:]) + x = tf.reshape(x, (self.batch_size,) + tensor.shape.as_list()[1:]) return x smoothing = self.options['gan_smoothing']