mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
Fixed batchnorm2d
This commit is contained in:
parent
9ef04b2207
commit
fcd398707f
1 changed files with 6 additions and 0 deletions
|
@ -34,6 +34,12 @@ class BatchNorm2D(nn.LayerBase):
|
||||||
running_mean = tf.reshape ( self.running_mean, shape )
|
running_mean = tf.reshape ( self.running_mean, shape )
|
||||||
running_var = tf.reshape ( self.running_var , shape )
|
running_var = tf.reshape ( self.running_var , shape )
|
||||||
|
|
||||||
|
x_mean = tf.math.reduce_mean(x, axis=[0] + nn.conv2d_spatial_axes, keepdims=True )
|
||||||
|
running_mean = running_mean * self.momentum + x_mean * (1 - self.momentum)
|
||||||
|
x_var = tf.math.reduce_variance(x, axis=[0] + nn.conv2d_spatial_axes, keepdims=True )
|
||||||
|
running_var = running_var * self.momentum + x_var * (1 - self.momentum)
|
||||||
|
|
||||||
|
|
||||||
x = (x - running_mean) / tf.sqrt( running_var + self.eps )
|
x = (x - running_mean) / tf.sqrt( running_var + self.eps )
|
||||||
x *= weight
|
x *= weight
|
||||||
x += bias
|
x += bias
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue