diff --git a/core/leras/archis/DeepFakeArchi.py b/core/leras/archis/DeepFakeArchi.py index 324d028..482b290 100644 --- a/core/leras/archis/DeepFakeArchi.py +++ b/core/leras/archis/DeepFakeArchi.py @@ -352,14 +352,20 @@ class DeepFakeArchi(nn.ArchiBase): depth_multiplier=self.depth_multiplier, strides=1 if self.subpixel else 2, padding='SAME', dilations=self.dilations) - #self.frn1 = nn.FRNorm2D(self.out_ch//(4 if self.subpixel else 1)) - #self.tlu1 = nn.TLU(self.out_ch//(4 if self.subpixel else 1)) + self.frn1 = nn.FRNorm2D(self.out_ch//(4 if self.subpixel else 1)) + self.tlu1 = nn.TLU(self.out_ch//(4 if self.subpixel else 1)) def forward(self, x): x = self.conv1(x) - #x = self.frn1(x) - #x = self.tlu1(x) + x = self.frn1(x) + x = self.tlu1(x) if self.subpixel: + if (x.get_shape().as_list()[-2] % 2 != 0): #or (x.get_shape().as_list()[-1] % 2 != 0): + #padding = self.kernel_size//2 + if nn.data_format == "NHWC": + x = tf.pad(x, [ [0,0], [1,0], [1,0], [0,0] ]) + else: + x = tf.pad(x, [ [0,0], [0,0], [1,0], [1,0] ]) x = nn.space_to_depth(x, 2) if self.use_activator: x = tf.nn.leaky_relu(x, 0.1) @@ -386,7 +392,7 @@ class DeepFakeArchi(nn.ArchiBase): return x class DecayingDownscaleBlock(nn.ModelBase): - def on_build(self, in_ch, ch, n_downscales=4, init_kernel_size=5, kernel_floor=3, dilations=1, subpixel=True, use_activator=True): + def on_build(self, in_ch, ch, n_downscales=4, init_kernel_size=5, kernel_floor=3, alternating_dilations=True, dilations=1, subpixel=True, use_activator=True): self.downs = [] @@ -399,6 +405,8 @@ class DeepFakeArchi(nn.ArchiBase): if not init_kernel_size > kernel_floor: raise ValueError("The initial kernel size must be larger than the kernel floor.") + + d = 0 cur_kernel_size = init_kernel_size @@ -406,50 +414,65 @@ class DeepFakeArchi(nn.ArchiBase): for i in range(n_downscales): cur_ch = ch*( min(2**i, 8) ) + + if (d % 2 == 0) and alternating_dilations: + dil = True + d += 1 + else: + dil = False + + if dil: + dilations=dilations + else: + dilations=1 + self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=cur_kernel_size, dilations=dilations, subpixel=subpixel, use_activator=use_activator) ) last_ch = self.downs[-1].get_out_ch() + if cur_kernel_size != kernel_floor and cur_kernel_size-2 != 1: cur_kernel_size -= 2 + self.bp1 = nn.BlurPool() + def forward(self, inp): x = inp for down in self.downs: x = down(x) + x = self.bp1(x) return x class Upscale(nn.ModelBase): def on_build(self, in_ch, out_ch, kernel_size=3, depth_multiplier=1 ): self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME') - #self.frn1 = nn.FRNorm2D(out_ch*4) - #self.tlu1 = nn.TLU(out_ch*4) + self.frn1 = nn.FRNorm2D(out_ch*4) + self.tlu1 = nn.TLU(out_ch*4) def forward(self, x): x = self.conv1(x) - #x = self.frn1(x) - #x = self.tlu1(x) - x = tf.nn.leaky_relu(x, 0.1)# (TLU replaces relu) + x = self.frn1(x) + x = self.tlu1(x) + #x = tf.nn.leaky_relu(x, 0.1) (TLU replaces relu) x = nn.depth_to_space(x, 2) return x class ResidualBlock(nn.ModelBase): def on_build(self, ch, kernel_size=3, depth_multiplier=1 ): self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME') - #self.frn1 = nn.FRNorm2D(ch) - #self.tlu1 = nn.TLU(ch) + self.frn1 = nn.FRNorm2D(ch) + self.tlu1 = nn.TLU(ch) self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME') - #self.frn2 = nn.FRNorm2D(ch) - #self.tlu2 = nn.TLU(ch) + self.frn2 = nn.FRNorm2D(ch) + self.tlu2 = nn.TLU(ch) def forward(self, inp): x = self.conv1(inp) - x = tf.nn.leaky_relu(x, 0.2) + #x = tf.nn.leaky_relu(x, 0.2) x = self.frn1(x) x = self.tlu1(x) x = self.conv2(x) x = self.frn2(x) - x = tf.nn.leaky_relu(inp + x, 0.2) - #x = self.tlu2(inp + x) + x = self.tlu2(inp + x) return x """ class UpdownResidualBlock(nn.ModelBase): @@ -472,7 +495,7 @@ class DeepFakeArchi(nn.ArchiBase): """ class Encoder(nn.ModelBase): def on_build(self, in_ch, e_ch, **kwargs): - self.down1 = DownscaleBlock(in_ch, e_ch, kernel_size=5, n_downscales=4, dilations=1) + self.down1 = DecayingDownscaleBlock(in_ch, e_ch, init_kernel_size=7, n_downscales=8, dilations=2, use_activator=False) #self.down2 = DecayingDownscaleBlock(in_ch, e_ch//2, n_downscales=6, dilations=2, use_activator=False)