+WScaleConv2DLayer

This commit is contained in:
Colombo 2020-01-01 11:15:55 +04:00
parent efd8c16daa
commit 5efb430f2a

View file

@ -59,6 +59,7 @@ Input = KL.Input
Dense = KL.Dense
Conv2D = KL.Conv2D
WScaleConv2DLayer = nnlib.WScaleConv2DLayer
Conv2DTranspose = KL.Conv2DTranspose
EqualConv2D = nnlib.EqualConv2D
SeparableConv2D = KL.SeparableConv2D
@ -788,12 +789,9 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
base = K.tile(pixels_batch, (1, output_h * output_w ) )
base = K.flatten(base)
# base_y0 = base + (y0 * width)
base_y0 = y0 * width
base_y0 = base + base_y0
# base_y1 = base + (y1 * width)
base_y1 = y1 * width
base_y1 = base_y1 + base
base_y0 = base + y0 * width
base_y1 = base + y1 * width
indices_a = base_y0 + x0
indices_b = base_y1 + x0
@ -854,10 +852,26 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
nnlib.BilinearInterpolation = BilinearInterpolation
class WScaleConv2DLayer(KL.Conv2D):
def __init__(self, *args, **kwargs):
kwargs['kernel_initializer'] = keras.initializers.random_normal()
super(WScaleConv2DLayer,self).__init__(*args,**kwargs)
def build(self, input_shape):
super().build(input_shape)
kernel_shape = K.int_shape(self.kernel)
std = np.sqrt(2) / np.sqrt( np.prod(kernel_shape[:-1]) )
self.wscale = K.constant(std, dtype=K.floatx() )
def call(self, input, **kwargs):
k = self.kernel
self.kernel = self.kernel*self.wscale
x = super().call(input,**kwargs)
self.kernel = k
return x
nnlib.WScaleConv2DLayer = WScaleConv2DLayer
class SelfAttention(KL.Layer):
def __init__(self, nc, squeeze_factor=8, **kwargs):
assert nc//squeeze_factor > 0, f"Input channels must be >= {squeeze_factor}, recieved nc={nc}"