diff --git a/core/leras/layers/BatchNorm2D.py b/core/leras/layers/BatchNorm2D.py index 62de521..8ba1030 100644 --- a/core/leras/layers/BatchNorm2D.py +++ b/core/leras/layers/BatchNorm2D.py @@ -34,6 +34,12 @@ class BatchNorm2D(nn.LayerBase): running_mean = tf.reshape ( self.running_mean, shape ) running_var = tf.reshape ( self.running_var , shape ) + x_mean = tf.math.reduce_mean(x, axis=[0] + nn.conv2d_spatial_axes, keepdims=True ) + running_mean = running_mean * self.momentum + x_mean * (1 - self.momentum) + x_var = tf.math.reduce_variance(x, axis=[0] + nn.conv2d_spatial_axes, keepdims=True ) + running_var = running_var * self.momentum + x_var * (1 - self.momentum) + + x = (x - running_mean) / tf.sqrt( running_var + self.eps ) x *= weight x += bias