mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 21:13:20 -07:00
Added new features in leras
This commit is contained in:
parent
0e2cacce89
commit
c0ddc6fc5d
3 changed files with 19 additions and 8 deletions
|
@ -21,6 +21,13 @@ class DeepFakeArchi(nn.ArchiBase):
|
||||||
|
|
||||||
conv_dtype = tf.float16 if use_fp16 else tf.float32
|
conv_dtype = tf.float16 if use_fp16 else tf.float32
|
||||||
|
|
||||||
|
if 'c' in opts:
|
||||||
|
def act(x, alpha=0.1):
|
||||||
|
return tf.nn.relu(x)
|
||||||
|
else:
|
||||||
|
def act(x, alpha=0.1):
|
||||||
|
return tf.nn.leaky_relu(x, alpha)
|
||||||
|
|
||||||
if mod is None:
|
if mod is None:
|
||||||
class Downscale(nn.ModelBase):
|
class Downscale(nn.ModelBase):
|
||||||
def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ):
|
def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ):
|
||||||
|
@ -34,7 +41,7 @@ class DeepFakeArchi(nn.ArchiBase):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = tf.nn.leaky_relu(x, 0.1)
|
x = act(x, 0.1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def get_out_ch(self):
|
def get_out_ch(self):
|
||||||
|
@ -62,7 +69,7 @@ class DeepFakeArchi(nn.ArchiBase):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = tf.nn.leaky_relu(x, 0.1)
|
x = act(x, 0.1)
|
||||||
x = nn.depth_to_space(x, 2)
|
x = nn.depth_to_space(x, 2)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -73,9 +80,9 @@ class DeepFakeArchi(nn.ArchiBase):
|
||||||
|
|
||||||
def forward(self, inp):
|
def forward(self, inp):
|
||||||
x = self.conv1(inp)
|
x = self.conv1(inp)
|
||||||
x = tf.nn.leaky_relu(x, 0.2)
|
x = act(x, 0.2)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = tf.nn.leaky_relu(inp + x, 0.2)
|
x = act(inp + x, 0.2)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class Encoder(nn.ModelBase):
|
class Encoder(nn.ModelBase):
|
||||||
|
|
|
@ -8,6 +8,7 @@ from .Dense import *
|
||||||
from .BlurPool import *
|
from .BlurPool import *
|
||||||
|
|
||||||
from .BatchNorm2D import *
|
from .BatchNorm2D import *
|
||||||
|
from .InstanceNorm2D import *
|
||||||
from .FRNorm2D import *
|
from .FRNorm2D import *
|
||||||
|
|
||||||
from .TLU import *
|
from .TLU import *
|
||||||
|
|
|
@ -108,10 +108,13 @@ nn.gelu = gelu
|
||||||
|
|
||||||
def upsample2d(x, size=2):
|
def upsample2d(x, size=2):
|
||||||
if nn.data_format == "NCHW":
|
if nn.data_format == "NCHW":
|
||||||
b,c,h,w = x.shape.as_list()
|
x = tf.transpose(x, (0,2,3,1))
|
||||||
x = tf.reshape (x, (-1,c,h,1,w,1) )
|
x = tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
|
||||||
x = tf.tile(x, (1,1,1,size,1,size) )
|
x = tf.transpose(x, (0,3,1,2))
|
||||||
x = tf.reshape (x, (-1,c,h*size,w*size) )
|
#b,c,h,w = x.shape.as_list()
|
||||||
|
#x = tf.reshape (x, (-1,c,h,1,w,1) )
|
||||||
|
#x = tf.tile(x, (1,1,1,size,1,size) )
|
||||||
|
#x = tf.reshape (x, (-1,c,h*size,w*size) )
|
||||||
return x
|
return x
|
||||||
else:
|
else:
|
||||||
return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
|
return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue