mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 13:32:09 -07:00
+WScaleConv2DLayer
This commit is contained in:
parent
efd8c16daa
commit
5efb430f2a
1 changed files with 23 additions and 9 deletions
|
@ -59,6 +59,7 @@ Input = KL.Input
|
|||
|
||||
Dense = KL.Dense
|
||||
Conv2D = KL.Conv2D
|
||||
WScaleConv2DLayer = nnlib.WScaleConv2DLayer
|
||||
Conv2DTranspose = KL.Conv2DTranspose
|
||||
EqualConv2D = nnlib.EqualConv2D
|
||||
SeparableConv2D = KL.SeparableConv2D
|
||||
|
@ -788,12 +789,9 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
base = K.tile(pixels_batch, (1, output_h * output_w ) )
|
||||
base = K.flatten(base)
|
||||
|
||||
# base_y0 = base + (y0 * width)
|
||||
base_y0 = y0 * width
|
||||
base_y0 = base + base_y0
|
||||
# base_y1 = base + (y1 * width)
|
||||
base_y1 = y1 * width
|
||||
base_y1 = base_y1 + base
|
||||
base_y0 = base + y0 * width
|
||||
|
||||
base_y1 = base + y1 * width
|
||||
|
||||
indices_a = base_y0 + x0
|
||||
indices_b = base_y1 + x0
|
||||
|
@ -854,10 +852,26 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
|
||||
nnlib.BilinearInterpolation = BilinearInterpolation
|
||||
|
||||
|
||||
|
||||
|
||||
class WScaleConv2DLayer(KL.Conv2D):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['kernel_initializer'] = keras.initializers.random_normal()
|
||||
|
||||
super(WScaleConv2DLayer,self).__init__(*args,**kwargs)
|
||||
|
||||
def build(self, input_shape):
|
||||
super().build(input_shape)
|
||||
kernel_shape = K.int_shape(self.kernel)
|
||||
std = np.sqrt(2) / np.sqrt( np.prod(kernel_shape[:-1]) )
|
||||
self.wscale = K.constant(std, dtype=K.floatx() )
|
||||
|
||||
def call(self, input, **kwargs):
|
||||
k = self.kernel
|
||||
self.kernel = self.kernel*self.wscale
|
||||
x = super().call(input,**kwargs)
|
||||
self.kernel = k
|
||||
return x
|
||||
nnlib.WScaleConv2DLayer = WScaleConv2DLayer
|
||||
|
||||
class SelfAttention(KL.Layer):
|
||||
def __init__(self, nc, squeeze_factor=8, **kwargs):
|
||||
assert nc//squeeze_factor > 0, f"Input channels must be >= {squeeze_factor}, recieved nc={nc}"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue