diff --git a/core/leras/layers.py b/core/leras/layers.py index a7f2bd0..827f28c 100644 --- a/core/leras/layers.py +++ b/core/leras/layers.py @@ -78,6 +78,7 @@ def initialize_layers(nn): def init_weights(self): nn.tf_init_weights(self.get_weights()) + nn.Saveable = Saveable class LayerBase(): @@ -318,7 +319,7 @@ def initialize_layers(nn): dim_size = dim_size * stride_size return dim_size nn.Conv2DTranspose = Conv2DTranspose - + class BlurPool(LayerBase): def __init__(self, filt_size=3, stride=2, **kwargs ): @@ -439,6 +440,44 @@ def initialize_layers(nn): return x nn.Dense = Dense + class InstanceNorm2D(LayerBase): + def __init__(self, in_ch, dtype=None, **kwargs): + self.in_ch = in_ch + + if dtype is None: + dtype = nn.tf_floatx + self.dtype = dtype + + super().__init__(**kwargs) + + def build_weights(self): + kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype) + self.weight = tf.get_variable("weight", (self.in_ch,), dtype=self.dtype, initializer=kernel_initializer ) + self.bias = tf.get_variable("bias", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros() ) + + def get_weights(self): + return [self.weight, self.bias] + + def __call__(self, x): + if nn.data_format == "NHWC": + shape = (1,1,1,self.in_ch) + else: + shape = (1,self.in_ch,1,1) + + weight = tf.reshape ( self.weight , shape ) + bias = tf.reshape ( self.bias , shape ) + + x_mean = tf.reduce_mean(x, axis=nn.conv2d_spatial_axes, keepdims=True ) + x_std = tf.math.reduce_std(x, axis=nn.conv2d_spatial_axes, keepdims=True ) + 1e-5 + + x = (x - x_mean) / x_std + x *= weight + x += bias + + return x + + nn.InstanceNorm2D = InstanceNorm2D + class BatchNorm2D(LayerBase): """ currently not for training @@ -477,4 +516,58 @@ def initialize_layers(nn): x += bias return x - nn.BatchNorm2D = BatchNorm2D \ No newline at end of file + nn.BatchNorm2D = BatchNorm2D + + class AdaIN(LayerBase): + """ + """ + def __init__(self, in_ch, mlp_ch, kernel_initializer=None, dtype=None, **kwargs): + self.in_ch = in_ch + self.mlp_ch = mlp_ch + self.kernel_initializer = kernel_initializer + + if dtype is None: + dtype = nn.tf_floatx + self.dtype = dtype + + super().__init__(**kwargs) + + def build_weights(self): + kernel_initializer = self.kernel_initializer + if kernel_initializer is None: + kernel_initializer = tf.initializers.he_normal()#(dtype=self.dtype) + + self.weight1 = tf.get_variable("weight1", (self.mlp_ch, self.in_ch), dtype=self.dtype, initializer=kernel_initializer) + self.bias1 = tf.get_variable("bias1", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros()) + self.weight2 = tf.get_variable("weight2", (self.mlp_ch, self.in_ch), dtype=self.dtype, initializer=kernel_initializer) + self.bias2 = tf.get_variable("bias2", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros()) + + def get_weights(self): + return [self.weight1, self.bias1, self.weight2, self.bias2] + + def __call__(self, inputs): + x, mlp = inputs + + gamma = tf.matmul(mlp, self.weight1) + gamma = tf.add(gamma, tf.reshape(self.bias1, (1,self.in_ch) ) ) + + beta = tf.matmul(mlp, self.weight2) + beta = tf.add(beta, tf.reshape(self.bias2, (1,self.in_ch) ) ) + + + if nn.data_format == "NHWC": + shape = (-1,1,1,self.in_ch) + else: + shape = (-1,self.in_ch,1,1) + + x_mean = tf.reduce_mean(x, axis=nn.conv2d_spatial_axes, keepdims=True ) + x_std = tf.math.reduce_std(x, axis=nn.conv2d_spatial_axes, keepdims=True ) + 1e-5 + + x = (x - x_mean) / x_std + x *= tf.reshape(gamma, shape) + + x += tf.reshape(beta, shape) + + return x + + nn.AdaIN = AdaIN \ No newline at end of file diff --git a/core/leras/nn.py b/core/leras/nn.py index 7cb862f..6abe6a0 100644 --- a/core/leras/nn.py +++ b/core/leras/nn.py @@ -72,7 +72,9 @@ class nn(): Conv2DTranspose = None BlurPool = None Dense = None + InstanceNorm2D = None BatchNorm2D = None + AdaIN = None # Initializers initializers = None