Mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-07-30 19:50:08 -07:00
update leras
Commit: e0a55ff1c3 (parent: e1635ff760)
3 changed files with 39 additions and 4 deletions
@@ -243,15 +243,15 @@ def initialize_models(nn):
     nn.ModelBase = ModelBase
 
     class PatchDiscriminator(nn.ModelBase):
-        def on_build(self, patch_size, in_ch, base_ch=256, kernel_initializer=None):
+        def on_build(self, patch_size, in_ch, base_ch=256, conv_kernel_initializer=None):
             prev_ch = in_ch
             self.convs = []
             for i, (kernel_size, strides) in enumerate(patch_discriminator_kernels[patch_size]):
                 cur_ch = base_ch * min( (2**i), 8 )
-                self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=kernel_initializer) )
+                self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) )
                 prev_ch = cur_ch
 
-            self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=kernel_initializer)
+            self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=conv_kernel_initializer)
 
         def forward(self, x):
             for conv in self.convs:
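For context, a minimal usage sketch of the discriminator with the renamed keyword; it assumes nn.initialize() has been called, that leras models forward their constructor kwargs to on_build as ModelBase does elsewhere in this file, and that patch_size is a valid key of patch_discriminator_kernels. The concrete values below are placeholders, not part of the commit:

    # Hypothetical usage sketch; None falls back to the Conv2D default initializer.
    patch_disc = nn.PatchDiscriminator( patch_size=16, in_ch=3, base_ch=256,
                                        conv_kernel_initializer=None )
    score_map = patch_disc(x)   # x: image batch in nn.data_format layout -> per-patch logits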
@@ -260,7 +260,32 @@ def initialize_models(nn):
 
     nn.PatchDiscriminator = PatchDiscriminator
 
+    class IllumDiscriminator(nn.ModelBase):
+        def on_build(self, patch_size, in_ch, base_ch=256, conv_kernel_initializer=None):
+            prev_ch = in_ch
+            self.convs = []
+            for i, (kernel_size, strides) in enumerate(patch_discriminator_kernels[patch_size]):
+                cur_ch = base_ch * min( (2**i), 8 )
+                self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) )
+                prev_ch = cur_ch
+
+            self.out1 = nn.Conv2D( 1, 1024, kernel_size=1, strides=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
+            self.out2 = nn.Conv2D( 1024, 1, kernel_size=1, strides=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
+
+        def forward(self, x):
+            for conv in self.convs:
+                x = tf.nn.leaky_relu( conv(x), 0.1 )
+
+            x = tf.reduce_mean(x, axis=nn.conv2d_ch_axis, keep_dims=True)
+
+            x = self.out1(x)
+            x = tf.nn.leaky_relu(x, 0.1 )
+            x = self.out2(x)
+
+            return x
+
+    nn.IllumDiscriminator = IllumDiscriminator
 
     patch_discriminator_kernels = \
         { 1 : [ [1,1] ],
           2 : [ [2,1] ],
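The new IllumDiscriminator reuses the PatchDiscriminator backbone but averages over the channel axis before a pair of 1x1 convolutions (1 -> 1024 -> 1). A rough shape trace, assuming NHWC layout and illustrative sizes only:

    # Hypothetical shape walk-through (values not from the commit):
    #   x                      : (B, H, W, in_ch)
    #   after strided convs    : (B, h, w, base_ch * min(2**i, 8))
    #   reduce_mean over nn.conv2d_ch_axis, keep_dims=True
    #                          : (B, h, w, 1)
    #   out1, 1x1 conv         : (B, h, w, 1024)
    #   out2, 1x1 conv         : (B, h, w, 1)    # per-patch score map
    illum_disc = nn.IllumDiscriminator( patch_size=16, in_ch=3 )   # placeholder args
    score_map = illum_disc(x)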
@@ -55,6 +55,7 @@ class nn():
     tf_upsample2d = None
     tf_upsample2d_bilinear = None
     tf_flatten = None
+    tf_max_pool = None
     tf_reshape_4D = None
     tf_random_binomial = None
     tf_gaussian_blur = None
@@ -82,7 +83,8 @@ class nn():
 
     # Models
     PatchDiscriminator = None
 
+    IllumDiscriminator = None
 
     @staticmethod
     def initialize(device_config=None, floatx="float32", data_format="NHWC"):
@@ -129,6 +129,14 @@ def initialize_tensor_ops(nn):
 
     nn.tf_flatten = tf_flatten
 
+    def tf_max_pool(x, kernel_size, strides):
+        if nn.data_format == "NHWC":
+            return tf.nn.max_pool(x, [1,kernel_size,kernel_size,1], [1,strides,strides,1], "VALID", data_format=nn.data_format)
+        else:
+            return tf.nn.max_pool(x, [1,1,kernel_size,kernel_size], [1,1,strides,strides], "VALID", data_format=nn.data_format)
+
+    nn.tf_max_pool = tf_max_pool
+
     def tf_reshape_4D(x, w,h,c):
         if nn.data_format == "NHWC":
             # match NCHW version in order to switch data_format without problems
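The new tf_max_pool helper places the pooling window only in the spatial slots of the ksize/strides vectors (batch and channel slots stay at 1), so the same call works under either NHWC or NCHW. A hedged usage sketch; the tensor shape and pooling sizes are placeholders:

    # Assumes nn.initialize(...) has set nn.data_format to "NHWC".
    x = tf.zeros( (1, 32, 32, 3) )
    y = nn.tf_max_pool(x, kernel_size=2, strides=2)   # -> (1, 16, 16, 3), VALID padding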