from core.leras import nn

tf = nn.tf


class DenseNorm(nn.LayerBase):
    """Rescale ``x`` so each vector along its last axis has unit root-mean-square.

    The output is ``x / sqrt(mean(x**2, axis=-1) + eps)``. No mean is
    subtracted, and this class creates no trainable variables of its own
    (there is no gain or bias).
    """

    def __init__(self, dense=False, eps=1e-06, dtype=None, **kwargs):
        """Create the layer.

        Args:
            dense: Stored as ``self.dense``; nothing in this class reads it.
                NOTE(review): appears to be kept only for API compatibility —
                confirm whether any caller depends on it.
            eps: Added to the mean square before the reciprocal square root,
                so an all-zero input vector yields 0 rather than 0 * inf.
            dtype: dtype of the epsilon constant; ``None`` selects ``nn.floatx``.
            **kwargs: Forwarded unchanged to ``nn.LayerBase.__init__``.
        """
        self.dense = dense
        epsilon_dtype = nn.floatx if dtype is None else dtype
        self.eps = tf.constant(eps, dtype=epsilon_dtype, name="epsilon")
        super().__init__(**kwargs)

    def __call__(self, x):
        """Return ``x`` divided by its root-mean-square over the last axis."""
        # keepdims=True leaves a size-1 trailing axis so the scale broadcasts over x.
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        inv_rms = tf.rsqrt(mean_square + self.eps)
        return x * inv_rms


nn.DenseNorm = DenseNorm