diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index 711edb8..61c7aef 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -59,6 +59,7 @@ Input = KL.Input Dense = KL.Dense Conv2D = KL.Conv2D +WScaleConv2DLayer = nnlib.WScaleConv2DLayer Conv2DTranspose = KL.Conv2DTranspose EqualConv2D = nnlib.EqualConv2D SeparableConv2D = KL.SeparableConv2D @@ -788,12 +789,9 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator base = K.tile(pixels_batch, (1, output_h * output_w ) ) base = K.flatten(base) - # base_y0 = base + (y0 * width) - base_y0 = y0 * width - base_y0 = base + base_y0 - # base_y1 = base + (y1 * width) - base_y1 = y1 * width - base_y1 = base_y1 + base + base_y0 = base + y0 * width + + base_y1 = base + y1 * width indices_a = base_y0 + x0 indices_b = base_y1 + x0 @@ -854,10 +852,26 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator nnlib.BilinearInterpolation = BilinearInterpolation - - - + class WScaleConv2DLayer(KL.Conv2D): + def __init__(self, *args, **kwargs): + kwargs['kernel_initializer'] = keras.initializers.random_normal() + + super(WScaleConv2DLayer,self).__init__(*args,**kwargs) + def build(self, input_shape): + super().build(input_shape) + kernel_shape = K.int_shape(self.kernel) + std = np.sqrt(2) / np.sqrt( np.prod(kernel_shape[:-1]) ) + self.wscale = K.constant(std, dtype=K.floatx() ) + + def call(self, input, **kwargs): + k = self.kernel + self.kernel = self.kernel*self.wscale + x = super().call(input,**kwargs) + self.kernel = k + return x + nnlib.WScaleConv2DLayer = WScaleConv2DLayer + class SelfAttention(KL.Layer): def __init__(self, nc, squeeze_factor=8, **kwargs): assert nc//squeeze_factor > 0, f"Input channels must be >= {squeeze_factor}, recieved nc={nc}"