diff --git a/core/leras/archis/DeepFakeArchi.py b/core/leras/archis/DeepFakeArchi.py index 5dfd293..e630624 100644 --- a/core/leras/archis/DeepFakeArchi.py +++ b/core/leras/archis/DeepFakeArchi.py @@ -21,6 +21,13 @@ class DeepFakeArchi(nn.ArchiBase): conv_dtype = tf.float16 if use_fp16 else tf.float32 + if 'c' in opts: + def act(x, alpha=0.1): + return tf.nn.relu(x) + else: + def act(x, alpha=0.1): + return tf.nn.leaky_relu(x, alpha) + if mod is None: class Downscale(nn.ModelBase): def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ): @@ -34,7 +41,7 @@ class DeepFakeArchi(nn.ArchiBase): def forward(self, x): x = self.conv1(x) - x = tf.nn.leaky_relu(x, 0.1) + x = act(x, 0.1) return x def get_out_ch(self): @@ -62,7 +69,7 @@ class DeepFakeArchi(nn.ArchiBase): def forward(self, x): x = self.conv1(x) - x = tf.nn.leaky_relu(x, 0.1) + x = act(x, 0.1) x = nn.depth_to_space(x, 2) return x @@ -73,9 +80,9 @@ class DeepFakeArchi(nn.ArchiBase): def forward(self, inp): x = self.conv1(inp) - x = tf.nn.leaky_relu(x, 0.2) + x = act(x, 0.2) x = self.conv2(x) - x = tf.nn.leaky_relu(inp + x, 0.2) + x = act(inp + x, 0.2) return x class Encoder(nn.ModelBase): diff --git a/core/leras/layers/__init__.py b/core/leras/layers/__init__.py index d8f1c9d..5c9d43b 100644 --- a/core/leras/layers/__init__.py +++ b/core/leras/layers/__init__.py @@ -8,6 +8,7 @@ from .Dense import * from .BlurPool import * from .BatchNorm2D import * +from .InstanceNorm2D import * from .FRNorm2D import * from .TLU import * diff --git a/core/leras/ops/__init__.py b/core/leras/ops/__init__.py index e0c4f5b..3bb8688 100644 --- a/core/leras/ops/__init__.py +++ b/core/leras/ops/__init__.py @@ -108,10 +108,13 @@ nn.gelu = gelu def upsample2d(x, size=2): if nn.data_format == "NCHW": - b,c,h,w = x.shape.as_list() - x = tf.reshape (x, (-1,c,h,1,w,1) ) - x = tf.tile(x, (1,1,1,size,1,size) ) - x = tf.reshape (x, (-1,c,h*size,w*size) ) + x = tf.transpose(x, (0,2,3,1)) + x = tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) ) + x = tf.transpose(x, (0,3,1,2)) + #b,c,h,w = x.shape.as_list() + #x = tf.reshape (x, (-1,c,h,1,w,1) ) + #x = tf.tile(x, (1,1,1,size,1,size) ) + #x = tf.reshape (x, (-1,c,h*size,w*size) ) return x else: return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )