This commit is contained in:
Colombo 2020-01-08 15:12:43 +04:00
parent b8182ae42b
commit bbe81b20af

View file

@ -853,15 +853,20 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
nnlib.BilinearInterpolation = BilinearInterpolation nnlib.BilinearInterpolation = BilinearInterpolation
class WScaleConv2DLayer(KL.Conv2D): class WScaleConv2DLayer(KL.Conv2D):
def __init__(self, *args, **kwargs): def __init__(self, *args, gain=None, **kwargs):
kwargs['kernel_initializer'] = keras.initializers.random_normal() kwargs['kernel_initializer'] = keras.initializers.random_normal()
if gain is None:
gain = np.sqrt(2)
self.gain = gain
super(WScaleConv2DLayer,self).__init__(*args,**kwargs) super(WScaleConv2DLayer,self).__init__(*args,**kwargs)
def build(self, input_shape): def build(self, input_shape):
super().build(input_shape) super().build(input_shape)
kernel_shape = K.int_shape(self.kernel) kernel_shape = K.int_shape(self.kernel)
std = np.sqrt(2) / np.sqrt( np.prod(kernel_shape[:-1]) ) std = np.sqrt(self.gain) / np.sqrt( np.prod(kernel_shape[:-1]) )
self.wscale = K.constant(std, dtype=K.floatx() ) self.wscale = K.constant(std, dtype=K.floatx() )
def call(self, input, **kwargs): def call(self, input, **kwargs):
@ -870,6 +875,12 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
x = super().call(input,**kwargs) x = super().call(input,**kwargs)
self.kernel = k self.kernel = k
return x return x
def get_config(self):
config = {"gain": self.gain}
base_config = super(WScaleConv2DLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
nnlib.WScaleConv2DLayer = WScaleConv2DLayer nnlib.WScaleConv2DLayer = WScaleConv2DLayer
class SelfAttention(KL.Layer): class SelfAttention(KL.Layer):