fix error

This commit is contained in:
jh 2021-03-17 08:34:00 -07:00
commit faf2a4a885

View file

@ -508,8 +508,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
probs = tf.math.log([[noise, 1-noise]]) if label == 1 else tf.math.log([[1-noise, noise]])
x = tf.random.categorical(probs, num_labels)
x = tf.cast(x, tf.float32)
x = x * (1-smoothing) + (smoothing/num_labels)
# x = tf.math.scalar_mul(1-smoothing, x) + (smoothing/x.shape[1])
# x = x * (1-smoothing) + (smoothing/num_labels)
x = tf.math.scalar_mul(1-smoothing, x)
x = x + (smoothing/num_labels)
x = tf.reshape(x, (self.batch_size,) + tensor.shape[1:])
return x