Made separable optional

This commit is contained in:
TalosOfCrete 2020-06-08 12:53:29 -05:00
commit 00a7a2357a
2 changed files with 187 additions and 100 deletions

View file

@ -14,17 +14,25 @@ class DeepFakeArchi(nn.ArchiBase):
if mod is None: if mod is None:
class Downscale(nn.ModelBase): class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ): def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, separable=False, *kwargs ):
self.in_ch = in_ch self.in_ch = in_ch
self.out_ch = out_ch self.out_ch = out_ch
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.dilations = dilations self.dilations = dilations
self.subpixel = subpixel self.subpixel = subpixel
self.use_activator = use_activator self.use_activator = use_activator
self.separable = separable
super().__init__(*kwargs) super().__init__(*kwargs)
def on_build(self, *args, **kwargs ): def on_build(self, *args, **kwargs ):
self.conv1 = nn.SeparableConv2D( self.in_ch, if self.separable:
self.conv1 = nn.SeparableConv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations)
else:
self.conv1 = nn.Conv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1), self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size, kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2, strides=1 if self.subpixel else 2,
@ -42,13 +50,13 @@ class DeepFakeArchi(nn.ArchiBase):
return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
class DownscaleBlock(nn.ModelBase): class DownscaleBlock(nn.ModelBase):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True): def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True, separable=False):
self.downs = [] self.downs = []
last_ch = in_ch last_ch = in_ch
for i in range(n_downscales): for i in range(n_downscales):
cur_ch = ch*( min(2**i, 8) ) cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) ) self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel, separable=separable) )
last_ch = self.downs[-1].get_out_ch() last_ch = self.downs[-1].get_out_ch()
def forward(self, inp): def forward(self, inp):
@ -58,8 +66,11 @@ class DeepFakeArchi(nn.ArchiBase):
return x return x
class Upscale(nn.ModelBase): class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ): def on_build(self, in_ch, out_ch, kernel_size=3, separable=False ):
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME') if separable:
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
else:
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
@ -68,9 +79,13 @@ class DeepFakeArchi(nn.ArchiBase):
return x return x
class ResidualBlock(nn.ModelBase): class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ): def on_build(self, ch, kernel_size=3, separable=False ):
self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME') if separable:
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME') self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
else:
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def forward(self, inp): def forward(self, inp):
x = self.conv1(inp) x = self.conv1(inp)
@ -80,10 +95,10 @@ class DeepFakeArchi(nn.ArchiBase):
return x return x
class UpdownResidualBlock(nn.ModelBase): class UpdownResidualBlock(nn.ModelBase):
def on_build(self, ch, inner_ch, kernel_size=3 ): def on_build(self, ch, inner_ch, kernel_size=3, separable=False ):
self.up = Upscale (ch, inner_ch, kernel_size=kernel_size) self.up = Upscale (ch, inner_ch, kernel_size=kernel_size, separable=separable)
self.res = ResidualBlock (inner_ch, kernel_size=kernel_size) self.res = ResidualBlock (inner_ch, kernel_size=kernel_size, separable=separable)
self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False) self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False, separable=separable)
def forward(self, inp): def forward(self, inp):
x = self.up(inp) x = self.up(inp)
@ -94,15 +109,16 @@ class DeepFakeArchi(nn.ArchiBase):
return x, upx return x, upx
class Encoder(nn.ModelBase): class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch, is_hd): def on_build(self, in_ch, e_ch, is_hd, separable=False):
self.is_hd=is_hd self.is_hd=is_hd
self.separable=separable
if self.is_hd: if self.is_hd:
self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1) self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1, separable=self.separable)
self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1) self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1, separable=self.separable)
self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2) self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2, separable=self.separable)
self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2) self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2, separable=self.separable)
else: else:
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False) self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False, separable=self.separable)
def forward(self, inp): def forward(self, inp):
if self.is_hd: if self.is_hd:
@ -117,16 +133,16 @@ class DeepFakeArchi(nn.ArchiBase):
lowest_dense_res = resolution // 16 lowest_dense_res = resolution // 16
class Inter(nn.ModelBase): class Inter(nn.ModelBase):
def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, **kwargs): def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, separable=False, **kwargs):
self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch self.in_ch, self.ae_ch, self.ae_out_ch, self.separable = in_ch, ae_ch, ae_out_ch, separable
super().__init__(**kwargs) super().__init__(**kwargs)
def on_build(self): def on_build(self):
in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch in_ch, ae_ch, ae_out_ch, separable = self.in_ch, self.ae_ch, self.ae_out_ch, self.separable
self.dense1 = nn.Dense( in_ch, ae_ch ) self.dense1 = nn.Dense( in_ch, ae_ch )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch ) self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
self.upscale1 = Upscale(ae_out_ch, ae_out_ch) self.upscale1 = Upscale(ae_out_ch, ae_out_ch, separable=separable)
def forward(self, inp): def forward(self, inp):
x = self.dense1(inp) x = self.dense1(inp)
@ -146,26 +162,34 @@ class DeepFakeArchi(nn.ArchiBase):
def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ): def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
self.is_hd = is_hd self.is_hd = is_hd
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3, separable=separable)
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3, separable=separable)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3, separable=separable)
if is_hd: if is_hd:
self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3) self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3, separable=separable)
self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3) self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3, separable=separable)
self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3) self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3, separable=separable)
self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3) self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3, separable=separable)
else: else:
self.res0 = ResidualBlock(d_ch*8, kernel_size=3) self.res0 = ResidualBlock(d_ch*8, kernel_size=3, separable=separable)
self.res1 = ResidualBlock(d_ch*4, kernel_size=3) self.res1 = ResidualBlock(d_ch*4, kernel_size=3, separable=separable)
self.res2 = ResidualBlock(d_ch*2, kernel_size=3) self.res2 = ResidualBlock(d_ch*2, kernel_size=3, separable=separable)
self.out_conv = nn.SeparableConv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) if separable:
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) self.out_conv = nn.SeparableConv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) else:
self.out_convm = nn.SeparableConv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME') self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3, separable=separable)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3, separable=separable)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3, separable=separable)
if separable:
self.out_convm = nn.SeparableConv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
else:
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
def forward(self, inp): def forward(self, inp):
z = inp z = inp
@ -200,17 +224,25 @@ class DeepFakeArchi(nn.ArchiBase):
elif mod == 'quick': elif mod == 'quick':
class Downscale(nn.ModelBase): class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ): def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, separable=False, *kwargs ):
self.in_ch = in_ch self.in_ch = in_ch
self.out_ch = out_ch self.out_ch = out_ch
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.dilations = dilations self.dilations = dilations
self.subpixel = subpixel self.subpixel = subpixel
self.use_activator = use_activator self.use_activator = use_activator
self.separable = separable
super().__init__(*kwargs) super().__init__(*kwargs)
def on_build(self, *args, **kwargs ): def on_build(self, *args, **kwargs ):
self.conv1 = nn.SeparableConv2D( self.in_ch, if self.separable:
self.conv1 = nn.SeparableConv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations )
else:
self.conv1 = nn.Conv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1), self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size, kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2, strides=1 if self.subpixel else 2,
@ -230,13 +262,13 @@ class DeepFakeArchi(nn.ArchiBase):
return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
class DownscaleBlock(nn.ModelBase): class DownscaleBlock(nn.ModelBase):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True): def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True, separable=False):
self.downs = [] self.downs = []
last_ch = in_ch last_ch = in_ch
for i in range(n_downscales): for i in range(n_downscales):
cur_ch = ch*( min(2**i, 8) ) cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) ) self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel, separable=separable) )
last_ch = self.downs[-1].get_out_ch() last_ch = self.downs[-1].get_out_ch()
def forward(self, inp): def forward(self, inp):
@ -246,8 +278,11 @@ class DeepFakeArchi(nn.ArchiBase):
return x return x
class Upscale(nn.ModelBase): class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ): def on_build(self, in_ch, out_ch, kernel_size=3, separable=False ):
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME') if separable:
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
else:
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
@ -256,9 +291,13 @@ class DeepFakeArchi(nn.ArchiBase):
return x return x
class ResidualBlock(nn.ModelBase): class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ): def on_build(self, ch, kernel_size=3, separable=False ):
self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME') if separable:
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME') self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
else:
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def forward(self, inp): def forward(self, inp):
x = self.conv1(inp) x = self.conv1(inp)
@ -269,25 +308,25 @@ class DeepFakeArchi(nn.ArchiBase):
return x return x
class Encoder(nn.ModelBase): class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch): def on_build(self, in_ch, e_ch, separable=False):
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5) self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, separable=separable)
def forward(self, inp): def forward(self, inp):
return nn.flatten(self.down1(inp)) return nn.flatten(self.down1(inp))
lowest_dense_res = resolution // 16 lowest_dense_res = resolution // 16
class Inter(nn.ModelBase): class Inter(nn.ModelBase):
def __init__(self, in_ch, ae_ch, ae_out_ch, d_ch, **kwargs): def __init__(self, in_ch, ae_ch, ae_out_ch, d_ch, separable=False, **kwargs):
self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch = in_ch, ae_ch, ae_out_ch, d_ch self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch, self.separable = in_ch, ae_ch, ae_out_ch, d_ch, separable
super().__init__(**kwargs) super().__init__(**kwargs)
def on_build(self): def on_build(self):
in_ch, ae_ch, ae_out_ch, d_ch = self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch in_ch, ae_ch, ae_out_ch, d_ch, separable = self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch, self.separable
self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal ) self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal ) self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal )
self.upscale1 = Upscale(ae_out_ch, d_ch*8) self.upscale1 = Upscale(ae_out_ch, d_ch*8, separable=separable)
self.res1 = ResidualBlock(d_ch*8) self.res1 = ResidualBlock(d_ch*8, separable=separable)
def forward(self, inp): def forward(self, inp):
x = self.dense1(inp) x = self.dense1(inp)
@ -301,20 +340,25 @@ class DeepFakeArchi(nn.ArchiBase):
return self.ae_out_ch return self.ae_out_ch
class Decoder(nn.ModelBase): class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch): def on_build(self, in_ch, d_ch, separable=False):
self.upscale1 = Upscale(in_ch, d_ch*4) self.upscale1 = Upscale(in_ch, d_ch*4, separable=separable)
self.res1 = ResidualBlock(d_ch*4) self.res1 = ResidualBlock(d_ch*4, separable=separable)
self.upscale2 = Upscale(d_ch*4, d_ch*2) self.upscale2 = Upscale(d_ch*4, d_ch*2, separable=separable)
self.res2 = ResidualBlock(d_ch*2) self.res2 = ResidualBlock(d_ch*2, separable=separable)
self.upscale3 = Upscale(d_ch*2, d_ch*1) self.upscale3 = Upscale(d_ch*2, d_ch*1, separable=separable)
self.res3 = ResidualBlock(d_ch*1) self.res3 = ResidualBlock(d_ch*1, separable=separable)
self.upscalem1 = Upscale(in_ch, d_ch) self.upscalem1 = Upscale(in_ch, d_ch, separable=separable)
self.upscalem2 = Upscale(d_ch, d_ch//2) self.upscalem2 = Upscale(d_ch, d_ch//2, separable=separable)
self.upscalem3 = Upscale(d_ch//2, d_ch//2) self.upscalem3 = Upscale(d_ch//2, d_ch//2, separable=separable)
self.out_conv = nn.SeparableConv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
self.out_convm = nn.SeparableConv2D( d_ch//2, 1, kernel_size=1, padding='SAME') if separable:
self.out_conv = nn.SeparableConv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
self.out_convm = nn.SeparableConv2D( d_ch//2, 1, kernel_size=1, padding='SAME')
else:
self.out_conv = nn.Conv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME')
def forward(self, inp): def forward(self, inp):
z = inp z = inp
@ -334,17 +378,25 @@ class DeepFakeArchi(nn.ArchiBase):
elif mod == 'uhd': elif mod == 'uhd':
class Downscale(nn.ModelBase): class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ): def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, separable=False, *kwargs ):
self.in_ch = in_ch self.in_ch = in_ch
self.out_ch = out_ch self.out_ch = out_ch
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.dilations = dilations self.dilations = dilations
self.subpixel = subpixel self.subpixel = subpixel
self.use_activator = use_activator self.use_activator = use_activator
self.separable = separable
super().__init__(*kwargs) super().__init__(*kwargs)
def on_build(self, *args, **kwargs ): def on_build(self, *args, **kwargs ):
self.conv1 = nn.SeparableConv2D( self.in_ch, if self.separable:
self.conv1 = nn.SeparableConv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations)
else:
self.conv1 = nn.Conv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1), self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size, kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2, strides=1 if self.subpixel else 2,
@ -362,13 +414,13 @@ class DeepFakeArchi(nn.ArchiBase):
return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
class DownscaleBlock(nn.ModelBase): class DownscaleBlock(nn.ModelBase):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True): def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True, separable=False):
self.downs = [] self.downs = []
last_ch = in_ch last_ch = in_ch
for i in range(n_downscales): for i in range(n_downscales):
cur_ch = ch*( min(2**i, 8) ) cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) ) self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel, separable=separable) )
last_ch = self.downs[-1].get_out_ch() last_ch = self.downs[-1].get_out_ch()
def forward(self, inp): def forward(self, inp):
@ -378,8 +430,11 @@ class DeepFakeArchi(nn.ArchiBase):
return x return x
class Upscale(nn.ModelBase): class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ): def on_build(self, in_ch, out_ch, kernel_size=3, separable=False ):
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME') if separable:
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
else:
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
@ -388,9 +443,13 @@ class DeepFakeArchi(nn.ArchiBase):
return x return x
class ResidualBlock(nn.ModelBase): class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ): def on_build(self, ch, kernel_size=3, separable=False ):
self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME') if separable:
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME') self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
else:
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def forward(self, inp): def forward(self, inp):
x = self.conv1(inp) x = self.conv1(inp)
@ -400,8 +459,8 @@ class DeepFakeArchi(nn.ArchiBase):
return x return x
class Encoder(nn.ModelBase): class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch, **kwargs): def on_build(self, in_ch, e_ch, separable=False, **kwargs):
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False) self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False, separable=separable)
def forward(self, inp): def forward(self, inp):
x = nn.flatten(self.down1(inp)) x = nn.flatten(self.down1(inp))
@ -410,12 +469,12 @@ class DeepFakeArchi(nn.ArchiBase):
lowest_dense_res = resolution // 16 lowest_dense_res = resolution // 16
class Inter(nn.ModelBase): class Inter(nn.ModelBase):
def on_build(self, in_ch, ae_ch, ae_out_ch, **kwargs): def on_build(self, in_ch, ae_ch, ae_out_ch, separable=False, **kwargs):
self.ae_out_ch = ae_out_ch self.ae_out_ch = ae_out_ch
self.dense_norm = nn.DenseNorm() self.dense_norm = nn.DenseNorm()
self.dense1 = nn.Dense( in_ch, ae_ch ) self.dense1 = nn.Dense( in_ch, ae_ch )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch ) self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
self.upscale1 = Upscale(ae_out_ch, ae_out_ch) self.upscale1 = Upscale(ae_out_ch, ae_out_ch, separable=separable)
def forward(self, inp): def forward(self, inp):
x = self.dense_norm(inp) x = self.dense_norm(inp)
@ -433,22 +492,30 @@ class DeepFakeArchi(nn.ArchiBase):
return self.ae_out_ch return self.ae_out_ch
class Decoder(nn.ModelBase): class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch, d_mask_ch, **kwargs ): def on_build(self, in_ch, d_ch, d_mask_ch, separable=False, **kwargs ):
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3, separable=separable)
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3, separable=separable)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3, separable=separable)
self.res0 = ResidualBlock(d_ch*8, kernel_size=3) self.res0 = ResidualBlock(d_ch*8, kernel_size=3, separable=separable)
self.res1 = ResidualBlock(d_ch*4, kernel_size=3) self.res1 = ResidualBlock(d_ch*4, kernel_size=3, separable=separable)
self.res2 = ResidualBlock(d_ch*2, kernel_size=3) self.res2 = ResidualBlock(d_ch*2, kernel_size=3, separable=separable)
self.out_conv = nn.SeparableConv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) if separable:
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) self.out_conv = nn.SeparableConv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) else:
self.out_convm = nn.SeparableConv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME') self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3, separable=separable)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3, separable=separable)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3, separable=separable)
if separable:
self.out_convm = nn.SeparableConv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
else:
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
def forward(self, inp): def forward(self, inp):
z = inp z = inp

View file

@ -32,6 +32,10 @@ class SAEHDModel(ModelBase):
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f') default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
default_archi = self.options['archi'] = self.load_or_def_option('archi', 'df') default_archi = self.options['archi'] = self.load_or_def_option('archi', 'df')
default_separable = self.options['separable'] = self.load_or_def_option('separable', False)
default_separable_enc = self.options['separable_enc'] = self.load_or_def_option('separable_enc', False)
default_separable_inter = self.options['separable_inter'] = self.load_or_def_option('separable_inter', False)
default_separable_dec = self.options['separable_dec'] = self.load_or_def_option('separable_dec', False)
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None)
@ -60,8 +64,18 @@ class SAEHDModel(ModelBase):
resolution = io.input_int("Resolution", default_resolution, add_info="64-512", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.") resolution = io.input_int("Resolution", default_resolution, add_info="64-512", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
resolution = np.clip ( (resolution // 16) * 16, 64, 512) resolution = np.clip ( (resolution // 16) * 16, 64, 512)
self.options['resolution'] = resolution self.options['resolution'] = resolution
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower() self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face including forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower()
self.options['archi'] = io.input_str ("AE architecture", default_archi, ['df','liae','dfhd','liaehd','dfuhd','liaeuhd'], help_message="'df' keeps faces more natural.\n'liae' can fix overly different face shapes.\n'hd' are experimental versions.").lower() self.options['archi'] = io.input_str ("AE architecture", default_archi, ['df','liae','dfhd','liaehd','dfuhd','liaeuhd'], help_message="'df' keeps faces more natural.\n'liae' can fix overly different face shapes.\n'hd' are experimental versions.").lower()
self.options['separable'] = io.input_bool ("Use depthwise separable convolutions throughout model", default_separable, help_message="Use lighter and more efficient layers in the encoder, bottleneck, and decoder. This speeds up iterations (~15-200%) and reduces memory usage considerably in the encoder and decoder (~60%) - allowing for increased settings or faster training - but it may take more iterations for the model to become good and possibly hurt quality. Set this to false if you wish to only have these layers in certain parts of the model (e.g. only in the encoder) or if you do not want them at all.").lower()
if not self.options['separable']:
self.options['separable_enc'] = io.input_bool ("Use depthwise separable convolutions in the encoder", default_separable_enc, help_message="This is the part of the model second most impacted by using these more efficient layers in terms of better iteration speed, memory savings, possible quality loss, and possible increase in required iterations as it may have anywhere from one to four downscaling operations for UHD and HD architectures respectively.").lower()
self.options['separable_inter'] = io.input_bool ("Use depthwise separable convolutions in the bottleneck", default_separable_inter, help_message="This is the part of the model least impacted by using these more efficient layers in terms of better iteration speed, memory savings, possible quality loss, and possible increase in required iterations as it only has a single upscaling operation regardless of architecture.").lower()
self.options['separable_dec'] = io.input_bool ("Use depthwise separable convolutions in the decoder", default_separable_dec, help_message="This is the part of the model most impacted by using these more efficient layers in terms of better iteration speed, memory savings, possible quality loss, and possible increase in required iterations as it has by far the most operations that these layers use out of all parts of the model.").lower()
else:
self.options['separable_enc'] = True
self.options['separable_inter'] = True
self.options['separable_dec'] = True
default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64 default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims) default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims)
@ -133,6 +147,12 @@ class SAEHDModel(ModelBase):
eyes_prio = self.options['eyes_prio'] eyes_prio = self.options['eyes_prio']
archi = self.options['archi'] archi = self.options['archi']
is_hd = 'hd' in archi is_hd = 'hd' in archi
separable_enc = self.options['separable_enc']
separable_inter = self.options['separable_inter']
separable_dec = self.options['separable_dec']
ae_dims = self.options['ae_dims'] ae_dims = self.options['ae_dims']
e_dims = self.options['e_dims'] e_dims = self.options['e_dims']
d_dims = self.options['d_dims'] d_dims = self.options['d_dims']
@ -173,14 +193,14 @@ class SAEHDModel(ModelBase):
with tf.device (models_opt_device): with tf.device (models_opt_device):
if 'df' in archi: if 'df' in archi:
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder') self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, separable=separable_enc, name='encoder')
encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape)) encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter') self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, separable=separable_inter, name='inter')
inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_src') self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, separable=separable_dec, name='decoder_src')
self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_dst') self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, separable=separable_dec, name='decoder_dst')
self.model_filename_list += [ [self.encoder, 'encoder.npy' ], self.model_filename_list += [ [self.encoder, 'encoder.npy' ],
[self.inter, 'inter.npy' ], [self.inter, 'inter.npy' ],
@ -193,16 +213,16 @@ class SAEHDModel(ModelBase):
self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ] self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]
elif 'liae' in archi: elif 'liae' in archi:
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder') self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, separable=separable_enc, name='encoder')
encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape)) encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB') self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, separable=separable_inter, name='inter_AB')
self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B') self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, separable=separable_inter, name='inter_B')
inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch))) inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
inters_out_ch = inter_AB_out_ch+inter_B_out_ch inters_out_ch = inter_AB_out_ch+inter_B_out_ch
self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder') self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, separable=separable_dec, name='decoder')
self.model_filename_list += [ [self.encoder, 'encoder.npy'], self.model_filename_list += [ [self.encoder, 'encoder.npy'],
[self.inter_AB, 'inter_AB.npy'], [self.inter_AB, 'inter_AB.npy'],