Mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-07-07 13:32:09 -07:00
+WScaleConv2DLayer

commit 5efb430f2a (parent efd8c16daa)
1 changed file with 23 additions and 9 deletions
@@ -59,6 +59,7 @@ Input = KL.Input
 Dense = KL.Dense
 Conv2D = KL.Conv2D
+WScaleConv2DLayer = nnlib.WScaleConv2DLayer
 Conv2DTranspose = KL.Conv2DTranspose
 EqualConv2D = nnlib.EqualConv2D
 SeparableConv2D = KL.SeparableConv2D
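The hunk above adds WScaleConv2DLayer to the block of layer aliases that nnlib exposes to model code, next to the stock Keras layers and the existing EqualConv2D. A minimal usage sketch, assuming the alias is in scope the way the neighboring Dense/Conv2D aliases are (the tensor x is illustrative):

    # Hypothetical use once the alias is in a model's namespace:
    # a weight-scaled 3x3 convolution, drop-in for KL.Conv2D.
    x = WScaleConv2DLayer(64, kernel_size=3, padding='same')(x)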
@@ -788,12 +789,9 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
             base = K.tile(pixels_batch, (1, output_h * output_w ) )
             base = K.flatten(base)
 
-            # base_y0 = base + (y0 * width)
-            base_y0 = y0 * width
-            base_y0 = base + base_y0
-            # base_y1 = base + (y1 * width)
-            base_y1 = y1 * width
-            base_y1 = base_y1 + base
+            base_y0 = base + y0 * width
+            base_y1 = base + y1 * width
 
             indices_a = base_y0 + x0
             indices_b = base_y1 + x0
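Both the deleted two-step form and the added one-liners build the same flat gather offsets for a batch of flattened images: base holds each batch element's starting offset repeated once per output pixel, and adding y * width (then x) gives the flat index of pixel (x, y). A minimal NumPy sketch of the equivalence the commit relies on, with array names mirroring the diff (shapes and random coordinates are illustrative):

    import numpy as np

    batch, height, width = 2, 4, 5
    output_h, output_w = 3, 3

    # One starting offset per batch element: image b begins at b*height*width
    # once the whole batch is flattened into a single vector.
    pixels_batch = (np.arange(batch) * height * width).reshape(batch, 1)
    base = np.tile(pixels_batch, (1, output_h * output_w)).flatten()

    # Integer sampling coordinates, one per output pixel per batch element.
    y0 = np.random.randint(0, height, size=base.shape)
    x0 = np.random.randint(0, width, size=base.shape)

    # The deleted two-step form and the added one-liner give the same offsets.
    base_y0_old = y0 * width
    base_y0_old = base + base_y0_old
    base_y0 = base + y0 * width
    assert np.array_equal(base_y0, base_y0_old)

    # Adding the column index yields indices_a from the diff.
    indices_a = base_y0 + x0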
@@ -854,10 +852,26 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
         nnlib.BilinearInterpolation = BilinearInterpolation
 
+        class WScaleConv2DLayer(KL.Conv2D):
+            def __init__(self, *args, **kwargs):
+                kwargs['kernel_initializer'] = keras.initializers.random_normal()
+                super(WScaleConv2DLayer,self).__init__(*args,**kwargs)
+            def build(self, input_shape):
+                super().build(input_shape)
+                kernel_shape = K.int_shape(self.kernel)
+                std = np.sqrt(2) / np.sqrt( np.prod(kernel_shape[:-1]) )
+                self.wscale = K.constant(std, dtype=K.floatx() )
+            def call(self, input, **kwargs):
+                k = self.kernel
+                self.kernel = self.kernel*self.wscale
+                x = super().call(input,**kwargs)
+                self.kernel = k
+                return x
+        nnlib.WScaleConv2DLayer = WScaleConv2DLayer
 
         class SelfAttention(KL.Layer):
             def __init__(self, nc, squeeze_factor=8, **kwargs):
                 assert nc//squeeze_factor > 0, f"Input channels must be >= {squeeze_factor}, recieved nc={nc}"
|
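The new layer implements the weight-scaling ("equalized learning rate") trick from Karras et al.'s Progressive Growing of GANs: kernels are stored at their raw random initialization, build() precomputes the He constant std = sqrt(2) / sqrt(prod(kernel_shape[:-1])), i.e. sqrt(2 / fan_in), and call() temporarily swaps in the scaled kernel for the convolution before restoring the original, so the stored weights always stay unscaled. A minimal NumPy sketch of the scaling rule alone (illustrative, not the repo's code):

    import numpy as np

    def wscale_constant(kernel_shape):
        # kernel_shape = (kh, kw, in_channels, out_channels); every axis but
        # the last contributes to the fan-in of one output filter.
        fan_in = np.prod(kernel_shape[:-1])
        return np.sqrt(2.0 / fan_in)  # equals sqrt(2)/sqrt(prod(...)) in build()

    # Weights are drawn once from a normal distribution...
    kernel = np.random.normal(size=(3, 3, 64, 128))

    # ...and rescaled at run time rather than at initialization, so every
    # layer's stored parameters share the same dynamic range for the optimizer.
    scaled_kernel = kernel * wscale_constant(kernel.shape)

    print(scaled_kernel.std())  # ~= sqrt(2 / (3*3*64)) ~= 0.059

One detail worth noting: keras.initializers.random_normal() defaults to stddev=0.05 in Keras 2 rather than the unit normal the paper assumes, so the layer's effective overall scale differs from the paper's formulation by that constant factor.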