Fixed batchnorm2d

This commit is contained in:
Jose 2023-02-07 11:17:04 +01:00 committed by GitHub
parent 9ef04b2207
commit fcd398707f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -34,6 +34,12 @@ class BatchNorm2D(nn.LayerBase):
running_mean = tf.reshape ( self.running_mean, shape )
running_var = tf.reshape ( self.running_var , shape )
x_mean = tf.math.reduce_mean(x, axis=[0] + nn.conv2d_spatial_axes, keepdims=True )
running_mean = running_mean * self.momentum + x_mean * (1 - self.momentum)
x_var = tf.math.reduce_variance(x, axis=[0] + nn.conv2d_spatial_axes, keepdims=True )
running_var = running_var * self.momentum + x_var * (1 - self.momentum)
x = (x - running_mean) / tf.sqrt( running_var + self.eps )
x *= weight
x += bias