mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-08 05:51:40 -07:00
-
This commit is contained in:
parent
b8182ae42b
commit
bbe81b20af
1 changed files with 13 additions and 2 deletions
|
@ -853,15 +853,20 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
nnlib.BilinearInterpolation = BilinearInterpolation
|
||||
|
||||
class WScaleConv2DLayer(KL.Conv2D):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, gain=None, **kwargs):
|
||||
kwargs['kernel_initializer'] = keras.initializers.random_normal()
|
||||
|
||||
if gain is None:
|
||||
gain = np.sqrt(2)
|
||||
|
||||
self.gain = gain
|
||||
|
||||
super(WScaleConv2DLayer,self).__init__(*args,**kwargs)
|
||||
|
||||
def build(self, input_shape):
|
||||
super().build(input_shape)
|
||||
kernel_shape = K.int_shape(self.kernel)
|
||||
std = np.sqrt(2) / np.sqrt( np.prod(kernel_shape[:-1]) )
|
||||
std = np.sqrt(self.gain) / np.sqrt( np.prod(kernel_shape[:-1]) )
|
||||
self.wscale = K.constant(std, dtype=K.floatx() )
|
||||
|
||||
def call(self, input, **kwargs):
|
||||
|
@ -870,6 +875,12 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
x = super().call(input,**kwargs)
|
||||
self.kernel = k
|
||||
return x
|
||||
|
||||
def get_config(self):
|
||||
config = {"gain": self.gain}
|
||||
base_config = super(WScaleConv2DLayer, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
nnlib.WScaleConv2DLayer = WScaleConv2DLayer
|
||||
|
||||
class SelfAttention(KL.Layer):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue