diff --git a/core/leras/archis/DeepFakeArchi.py b/core/leras/archis/DeepFakeArchi.py index 720b6d3..cd30dc0 100644 --- a/core/leras/archis/DeepFakeArchi.py +++ b/core/leras/archis/DeepFakeArchi.py @@ -335,7 +335,7 @@ class DeepFakeArchi(nn.ArchiBase): elif mod == 'uhd': class Downscale(nn.ModelBase): - def __init__(self, in_ch, out_ch, kernel_size=5, depth_multiplier=1, dilations=1, subpixel=True, use_activator=True, *kwargs ): + def __init__(self, in_ch, out_ch, kernel_size=3, depth_multiplier=1, dilations=1, subpixel=True, use_activator=True, *kwargs ): self.in_ch = in_ch self.out_ch = out_ch self.kernel_size = kernel_size @@ -367,9 +367,9 @@ class DeepFakeArchi(nn.ArchiBase): def get_out_ch(self): return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch - + class DownscaleBlock(nn.ModelBase): - def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True, use_activator=True): + def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True, use_activator=False): self.downs = [] last_ch = in_ch @@ -377,9 +377,7 @@ class DeepFakeArchi(nn.ArchiBase): cur_ch = ch*( min(2**i, 8) ) self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel, use_activator=use_activator) ) last_ch = self.downs[-1].get_out_ch() - self.bp1 = nn.BlurPool(kernel_size) - def forward(self, inp): x = inp for down in self.downs: @@ -387,6 +385,38 @@ class DeepFakeArchi(nn.ArchiBase): x = self.bp1(x) return x + class DecayingDownscaleBlock(nn.ModelBase): + def on_build(self, in_ch, ch, n_downscales=4, init_kernel_size=5, kernel_floor=3, dilations=1, subpixel=True, use_activator=True): + + self.downs = [] + + if init_kernel_size % 2 == 0: + init_kernel_size += 1 + print("Initial kernel size has been adjusted up by 1 as it was an even number.") + if kernel_floor % 2 == 0: + kernel_floor += 1 + print("Kernel floor has been adjusted up by 1 as it was an even number.") + + if not init_kernel_size > kernel_floor: + raise ValueError("The initial kernel size must be larger than the kernel floor.") + + cur_kernel_size = init_kernel_size + + last_ch = in_ch + + for i in range(n_downscales): + cur_ch = ch*( min(2**i, 8) ) + self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=cur_kernel_size, dilations=dilations, subpixel=subpixel, use_activator=use_activator) ) + last_ch = self.downs[-1].get_out_ch() + if cur_kernel_size != kernel_floor and cur_kernel_size-2 != 1: + cur_kernel_size -= 2 + + def forward(self, inp): + x = inp + for down in self.downs: + x = down(x) + return x + class Upscale(nn.ModelBase): def on_build(self, in_ch, out_ch, kernel_size=3, depth_multiplier=1 ): self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME') @@ -441,17 +471,14 @@ class DeepFakeArchi(nn.ArchiBase): """ class Encoder(nn.ModelBase): def on_build(self, in_ch, e_ch, **kwargs): - self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, use_activator=False, subpixel=False) - #self.down2 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=1, use_activator=False) - #self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2, use_activator=False) - #self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2, use_activator=False) + self.down1 = DecayingDownscaleBlock(in_ch, e_ch*2, n_downscales=6, dilations=1, use_activator=False) + self.down2 = DecayingDownscaleBlock(in_ch, e_ch//2, n_downscales=6, dilations=2, use_activator=False) + def forward(self, inp): - x = nn.flatten(self.down1(inp)) - #x = tf.concat([ nn.flatten(self.down1(inp)), - #nn.flatten(self.down2(inp)) ], -1), - #nn.flatten(self.down3(inp)), - #nn.flatten(self.down4(inp)) ], -1 ) + #x = nn.flatten(self.down1(inp)) + x = tf.concat([ nn.flatten(self.down1(inp)), + nn.flatten(self.down2(inp)) ], -1 ) return x lowest_dense_res = resolution // 16