mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-05 20:42:11 -07:00
Fixed batchnorm2d
This commit is contained in:
parent
9ef04b2207
commit
fcd398707f
1 changed files with 6 additions and 0 deletions
|
@ -34,6 +34,12 @@ class BatchNorm2D(nn.LayerBase):
|
|||
running_mean = tf.reshape ( self.running_mean, shape )
|
||||
running_var = tf.reshape ( self.running_var , shape )
|
||||
|
||||
x_mean = tf.math.reduce_mean(x, axis=[0] + nn.conv2d_spatial_axes, keepdims=True )
|
||||
running_mean = running_mean * self.momentum + x_mean * (1 - self.momentum)
|
||||
x_var = tf.math.reduce_variance(x, axis=[0] + nn.conv2d_spatial_axes, keepdims=True )
|
||||
running_var = running_var * self.momentum + x_var * (1 - self.momentum)
|
||||
|
||||
|
||||
x = (x - running_mean) / tf.sqrt( running_var + self.eps )
|
||||
x *= weight
|
||||
x += bias
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue