diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index 61c7aef..975cf94 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -853,15 +853,20 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator nnlib.BilinearInterpolation = BilinearInterpolation class WScaleConv2DLayer(KL.Conv2D): - def __init__(self, *args, **kwargs): + def __init__(self, *args, gain=None, **kwargs): kwargs['kernel_initializer'] = keras.initializers.random_normal() + if gain is None: + gain = np.sqrt(2) + + self.gain = gain + super(WScaleConv2DLayer,self).__init__(*args,**kwargs) def build(self, input_shape): super().build(input_shape) kernel_shape = K.int_shape(self.kernel) - std = np.sqrt(2) / np.sqrt( np.prod(kernel_shape[:-1]) ) + std = np.sqrt(self.gain) / np.sqrt( np.prod(kernel_shape[:-1]) ) self.wscale = K.constant(std, dtype=K.floatx() ) def call(self, input, **kwargs): @@ -870,6 +875,12 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator x = super().call(input,**kwargs) self.kernel = k return x + + def get_config(self): + config = {"gain": self.gain} + base_config = super(WScaleConv2DLayer, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + nnlib.WScaleConv2DLayer = WScaleConv2DLayer class SelfAttention(KL.Layer):