Made separable optional

This commit is contained in:
TalosOfCrete 2020-06-08 12:53:29 -05:00
commit 00a7a2357a
2 changed files with 187 additions and 100 deletions

View file

@ -14,17 +14,25 @@ class DeepFakeArchi(nn.ArchiBase):
if mod is None:
class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, separable=False, *kwargs ):
self.in_ch = in_ch
self.out_ch = out_ch
self.kernel_size = kernel_size
self.dilations = dilations
self.subpixel = subpixel
self.use_activator = use_activator
self.separable = separable
super().__init__(*kwargs)
def on_build(self, *args, **kwargs ):
self.conv1 = nn.SeparableConv2D( self.in_ch,
if self.separable:
self.conv1 = nn.SeparableConv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations)
else:
self.conv1 = nn.Conv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
@ -42,13 +50,13 @@ class DeepFakeArchi(nn.ArchiBase):
return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
class DownscaleBlock(nn.ModelBase):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True, separable=False):
self.downs = []
last_ch = in_ch
for i in range(n_downscales):
cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel, separable=separable) )
last_ch = self.downs[-1].get_out_ch()
def forward(self, inp):
@ -58,8 +66,11 @@ class DeepFakeArchi(nn.ArchiBase):
return x
class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def on_build(self, in_ch, out_ch, kernel_size=3, separable=False ):
if separable:
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
else:
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def forward(self, x):
x = self.conv1(x)
@ -68,9 +79,13 @@ class DeepFakeArchi(nn.ArchiBase):
return x
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def on_build(self, ch, kernel_size=3, separable=False ):
if separable:
self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
else:
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def forward(self, inp):
x = self.conv1(inp)
@ -80,10 +95,10 @@ class DeepFakeArchi(nn.ArchiBase):
return x
class UpdownResidualBlock(nn.ModelBase):
def on_build(self, ch, inner_ch, kernel_size=3 ):
self.up = Upscale (ch, inner_ch, kernel_size=kernel_size)
self.res = ResidualBlock (inner_ch, kernel_size=kernel_size)
self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False)
def on_build(self, ch, inner_ch, kernel_size=3, separable=False ):
self.up = Upscale (ch, inner_ch, kernel_size=kernel_size, separable=separable)
self.res = ResidualBlock (inner_ch, kernel_size=kernel_size, separable=separable)
self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False, separable=separable)
def forward(self, inp):
x = self.up(inp)
@ -94,15 +109,16 @@ class DeepFakeArchi(nn.ArchiBase):
return x, upx
class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch, is_hd):
def on_build(self, in_ch, e_ch, is_hd, separable=False):
self.is_hd=is_hd
self.separable=separable
if self.is_hd:
self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1)
self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1)
self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2)
self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2)
self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1, separable=self.separable)
self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1, separable=self.separable)
self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2, separable=self.separable)
self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2, separable=self.separable)
else:
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False, separable=self.separable)
def forward(self, inp):
if self.is_hd:
@ -117,16 +133,16 @@ class DeepFakeArchi(nn.ArchiBase):
lowest_dense_res = resolution // 16
class Inter(nn.ModelBase):
def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, **kwargs):
self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, separable=False **kwargs):
self.in_ch, self.ae_ch, self.ae_out_ch, self.separable = in_ch, ae_ch, ae_out_ch, separable
super().__init__(**kwargs)
def on_build(self):
in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch
in_ch, ae_ch, ae_out_ch, separable = self.in_ch, self.ae_ch, self.ae_out_ch, self.separable
self.dense1 = nn.Dense( in_ch, ae_ch )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
self.upscale1 = Upscale(ae_out_ch, ae_out_ch, separable)
def forward(self, inp):
x = self.dense1(inp)
@ -146,26 +162,34 @@ class DeepFakeArchi(nn.ArchiBase):
def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ):
self.is_hd = is_hd
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3, separable=separable)
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3, separable=separable)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3, separable=separable)
if is_hd:
self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3)
self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3)
self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3)
self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3)
self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3, separable=separable)
self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3, separable=separable)
self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3, separable=separable)
self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3, separable=separable)
else:
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
self.res0 = ResidualBlock(d_ch*8, kernel_size=3, separable=separable)
self.res1 = ResidualBlock(d_ch*4, kernel_size=3, separable=separable)
self.res2 = ResidualBlock(d_ch*2, kernel_size=3, separable=separable)
self.out_conv = nn.SeparableConv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
if separable:
self.out_conv = nn.SeparableConv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
else:
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.out_convm = nn.SeparableConv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3, separable=separable)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3, separable=separable)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3, separable=separable)
if separable:
self.out_convm = nn.SeparableConv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
else:
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
def forward(self, inp):
z = inp
@ -200,17 +224,25 @@ class DeepFakeArchi(nn.ArchiBase):
elif mod == 'quick':
class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, separable=False, *kwargs ):
self.in_ch = in_ch
self.out_ch = out_ch
self.kernel_size = kernel_size
self.dilations = dilations
self.subpixel = subpixel
self.use_activator = use_activator
self.separable = separable
super().__init__(*kwargs)
def on_build(self, *args, **kwargs ):
self.conv1 = nn.SeparableConv2D( self.in_ch,
if self.separable:
self.conv1 = nn.SeparableConv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations )
else:
self.conv1 = nn.Conv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
@ -230,13 +262,13 @@ class DeepFakeArchi(nn.ArchiBase):
return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
class DownscaleBlock(nn.ModelBase):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True, separable=False):
self.downs = []
last_ch = in_ch
for i in range(n_downscales):
cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel, separable=separable) )
last_ch = self.downs[-1].get_out_ch()
def forward(self, inp):
@ -246,8 +278,11 @@ class DeepFakeArchi(nn.ArchiBase):
return x
class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def on_build(self, in_ch, out_ch, kernel_size=3, separable=False ):
if separable:
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
else:
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def forward(self, x):
x = self.conv1(x)
@ -256,9 +291,13 @@ class DeepFakeArchi(nn.ArchiBase):
return x
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def on_build(self, ch, kernel_size=3, separable=False ):
if separable:
self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
else:
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def forward(self, inp):
x = self.conv1(inp)
@ -269,25 +308,25 @@ class DeepFakeArchi(nn.ArchiBase):
return x
class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch):
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)
def on_build(self, in_ch, e_ch, separable=False):
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, separable=separable)
def forward(self, inp):
return nn.flatten(self.down1(inp))
lowest_dense_res = resolution // 16
class Inter(nn.ModelBase):
def __init__(self, in_ch, ae_ch, ae_out_ch, d_ch, **kwargs):
self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch = in_ch, ae_ch, ae_out_ch, d_ch
def __init__(self, in_ch, ae_ch, ae_out_ch, d_ch, separable=False**kwargs):
self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch, self.separable = in_ch, ae_ch, ae_out_ch, d_ch, separable
super().__init__(**kwargs)
def on_build(self):
in_ch, ae_ch, ae_out_ch, d_ch = self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch
in_ch, ae_ch, ae_out_ch, d_ch, separable = self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch, self.separable
self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal )
self.upscale1 = Upscale(ae_out_ch, d_ch*8)
self.res1 = ResidualBlock(d_ch*8)
self.upscale1 = Upscale(ae_out_ch, d_ch*8, separable=separable)
self.res1 = ResidualBlock(d_ch*8, separable=separable)
def forward(self, inp):
x = self.dense1(inp)
@ -301,20 +340,25 @@ class DeepFakeArchi(nn.ArchiBase):
return self.ae_out_ch
class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch):
self.upscale1 = Upscale(in_ch, d_ch*4)
self.res1 = ResidualBlock(d_ch*4)
self.upscale2 = Upscale(d_ch*4, d_ch*2)
self.res2 = ResidualBlock(d_ch*2)
self.upscale3 = Upscale(d_ch*2, d_ch*1)
self.res3 = ResidualBlock(d_ch*1)
def on_build(self, in_ch, d_ch, separable=False):
self.upscale1 = Upscale(in_ch, d_ch*4, separable=separable)
self.res1 = ResidualBlock(d_ch*4, separable=separable)
self.upscale2 = Upscale(d_ch*4, d_ch*2, separable=separable)
self.res2 = ResidualBlock(d_ch*2, separable=separable)
self.upscale3 = Upscale(d_ch*2, d_ch*1, separable=separable)
self.res3 = ResidualBlock(d_ch*1, separable=separable)
self.upscalem1 = Upscale(in_ch, d_ch)
self.upscalem2 = Upscale(d_ch, d_ch//2)
self.upscalem3 = Upscale(d_ch//2, d_ch//2)
self.upscalem1 = Upscale(in_ch, d_ch, separable=separable)
self.upscalem2 = Upscale(d_ch, d_ch//2, separable=separable)
self.upscalem3 = Upscale(d_ch//2, d_ch//2, separable=separable)
self.out_conv = nn.SeparableConv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
self.out_convm = nn.SeparableConv2D( d_ch//2, 1, kernel_size=1, padding='SAME')
if separable:
self.out_conv = nn.SeparableConv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
self.out_convm = nn.SeparableConv2D( d_ch//2, 1, kernel_size=1, padding='SAME')
else:
self.out_conv = nn.Conv2D( d_ch*1, 3, kernel_size=1, padding='SAME')
self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME')
def forward(self, inp):
z = inp
@ -334,17 +378,25 @@ class DeepFakeArchi(nn.ArchiBase):
elif mod == 'uhd':
class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, separable=False, *kwargs ):
self.in_ch = in_ch
self.out_ch = out_ch
self.kernel_size = kernel_size
self.dilations = dilations
self.subpixel = subpixel
self.use_activator = use_activator
self.separable = separable
super().__init__(*kwargs)
def on_build(self, *args, **kwargs ):
self.conv1 = nn.SeparableConv2D( self.in_ch,
if self.separable:
self.conv1 = nn.SeparableConv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations)
else:
self.conv1 = nn.Conv2D( self.in_ch,
self.out_ch // (4 if self.subpixel else 1),
kernel_size=self.kernel_size,
strides=1 if self.subpixel else 2,
@ -362,13 +414,13 @@ class DeepFakeArchi(nn.ArchiBase):
return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch
class DownscaleBlock(nn.ModelBase):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True):
def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True, separable=False):
self.downs = []
last_ch = in_ch
for i in range(n_downscales):
cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel, separable=separable) )
last_ch = self.downs[-1].get_out_ch()
def forward(self, inp):
@ -378,8 +430,11 @@ class DeepFakeArchi(nn.ArchiBase):
return x
class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def on_build(self, in_ch, out_ch, kernel_size=3, separable=False ):
if separable:
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
else:
self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def forward(self, x):
x = self.conv1(x)
@ -388,9 +443,13 @@ class DeepFakeArchi(nn.ArchiBase):
return x
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def on_build(self, ch, kernel_size=3, separable=False ):
if separable:
self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
else:
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def forward(self, inp):
x = self.conv1(inp)
@ -400,8 +459,8 @@ class DeepFakeArchi(nn.ArchiBase):
return x
class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch, **kwargs):
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False)
def on_build(self, in_ch, e_ch, separable=False, **kwargs):
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False, separable=separable)
def forward(self, inp):
x = nn.flatten(self.down1(inp))
@ -410,12 +469,12 @@ class DeepFakeArchi(nn.ArchiBase):
lowest_dense_res = resolution // 16
class Inter(nn.ModelBase):
def on_build(self, in_ch, ae_ch, ae_out_ch, **kwargs):
def on_build(self, in_ch, ae_ch, ae_out_ch, separable=False, **kwargs):
self.ae_out_ch = ae_out_ch
self.dense_norm = nn.DenseNorm()
self.dense1 = nn.Dense( in_ch, ae_ch )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
self.upscale1 = Upscale(ae_out_ch, ae_out_ch, separable=separable)
def forward(self, inp):
x = self.dense_norm(inp)
@ -433,22 +492,30 @@ class DeepFakeArchi(nn.ArchiBase):
return self.ae_out_ch
class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch, d_mask_ch, **kwargs ):
def on_build(self, in_ch, d_ch, d_mask_ch, separable=False, **kwargs ):
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3, separable=separable)
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3, separable=separable)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3, separable=separable)
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
self.res0 = ResidualBlock(d_ch*8, kernel_size=3, separable=separable)
self.res1 = ResidualBlock(d_ch*4, kernel_size=3, separable=separable)
self.res2 = ResidualBlock(d_ch*2, kernel_size=3, separable=separable)
self.out_conv = nn.SeparableConv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
if separable:
self.out_conv = nn.SeparableConv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
else:
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.out_convm = nn.SeparableConv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3, separable=separable)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3, separable=separable)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3, separable=separable)
if separable:
self.out_convm = nn.SeparableConv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
else:
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME')
def forward(self, inp):
z = inp

View file

@ -32,6 +32,10 @@ class SAEHDModel(ModelBase):
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
default_archi = self.options['archi'] = self.load_or_def_option('archi', 'df')
default_separable = self.options['separable'] = self.load_or_def_option('separable', False)
default_separable_enc = self.options['separable_enc'] = self.load_or_def_option('separable_enc', False)
default_separable_inter = self.options['separable_inter'] = self.load_or_def_option('separable_inter', False)
default_separable_dec = self.options['separable_dec'] = self.load_or_def_option('separable_dec', False)
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None)
@ -60,8 +64,18 @@ class SAEHDModel(ModelBase):
resolution = io.input_int("Resolution", default_resolution, add_info="64-512", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
resolution = np.clip ( (resolution // 16) * 16, 64, 512)
self.options['resolution'] = resolution
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower()
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face including forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower()
self.options['archi'] = io.input_str ("AE architecture", default_archi, ['df','liae','dfhd','liaehd','dfuhd','liaeuhd'], help_message="'df' keeps faces more natural.\n'liae' can fix overly different face shapes.\n'hd' are experimental versions.").lower()
self.options['separable'] = io.input_bool ("Use depthwise separable convolutions throughout model", default_separable, help_message="Use lighter and more effecient layers in the encoder, bottleneck, and decoder. This speeds up iterations (~15-200%) and reduces memory usage considerably in the encoder and decoder (~60%) - allowing for increased settings or faster training - but it may take more iterations for the model to become good and possibly hurt quality. Set this to false if you wish to only have these layers in certain parts of the model (e.g. only in the encoder) or if you do not want them at all.").lower()
if not self.options['separable']:
self.options['separable_enc'] = io.input_bool ("Use depthwise separable convolutions in the encoder", default_separable_enc, help_message="This is the part of the model second most impacted by using these more efficient layers in terms of better iteration speed, memory savings, possible quality loss, and possible increase in required iterations as it may have anywhere from one to four downscaling operations for UHD and HD architectures respectively.").lower()
self.options['separable_inter'] = io.input_bool ("Use depthwise separable convolutions in the bottleneck", default_separable_inter, help_message="This is the part of the model least impacted by using these more efficient layers in terms of better iteration speed, memory savings, possible quality loss, and possible increase in required iterations as it only has a single upscaling operation regardless of architecture.").lower()
self.options['separable_dec'] = io.input_bool ("Use depthwise separable convolutions in the decoder", default_separable_dec, help_message="This is the part of the model most impacted by using these more efficient layers in terms of better iteration speed, memory savings, possible quality loss, and possible increase in required iterations as it has by far the most operations that these layers use out of all parts of the model.").lower()
else:
self.options['separable_enc'] = True
self.options['separable_inter'] = True
self.options['separable_dec'] = True
default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims)
@ -133,6 +147,12 @@ class SAEHDModel(ModelBase):
eyes_prio = self.options['eyes_prio']
archi = self.options['archi']
is_hd = 'hd' in archi
separable_enc = self.options['separable_enc']
separable_inter = self.options['separable_inter']
separable_dec = self.options['separable_dec']
ae_dims = self.options['ae_dims']
e_dims = self.options['e_dims']
d_dims = self.options['d_dims']
@ -173,14 +193,14 @@ class SAEHDModel(ModelBase):
with tf.device (models_opt_device):
if 'df' in archi:
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, separable=separable_enc, name='encoder')
encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, name='inter')
self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, is_hd=is_hd, separable=separable_inter, name='inter')
inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_src')
self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder_dst')
self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, separable=separable_dec, name='decoder_src')
self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, separable=separable_dec, name='decoder_dst')
self.model_filename_list += [ [self.encoder, 'encoder.npy' ],
[self.inter, 'inter.npy' ],
@ -193,16 +213,16 @@ class SAEHDModel(ModelBase):
self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]
elif 'liae' in archi:
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, name='encoder')
self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, is_hd=is_hd, separable=separable_enc, name='encoder')
encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_AB')
self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, name='inter_B')
self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, separable=separable_inter, name='inter_AB')
self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, is_hd=is_hd, separable=separable_inter, name='inter_B')
inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
inters_out_ch = inter_AB_out_ch+inter_B_out_ch
self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, name='decoder')
self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, is_hd=is_hd, separable=separable_dec, name='decoder')
self.model_filename_list += [ [self.encoder, 'encoder.npy'],
[self.inter_AB, 'inter_AB.npy'],