Temporary fix

This commit is contained in:
TalosOfCrete 2020-05-21 21:00:09 -05:00
commit 30b4c4c679

View file

@ -352,13 +352,13 @@ class DeepFakeArchi(nn.ArchiBase):
depth_multiplier=self.depth_multiplier, depth_multiplier=self.depth_multiplier,
strides=1 if self.subpixel else 2, strides=1 if self.subpixel else 2,
padding='SAME', dilations=self.dilations) padding='SAME', dilations=self.dilations)
self.frn1 = nn.FRNorm2D(self.out_ch//(4 if self.subpixel else 1)) #self.frn1 = nn.FRNorm2D(self.out_ch//(4 if self.subpixel else 1))
self.tlu1 = nn.TLU(self.out_ch//(4 if self.subpixel else 1)) #self.tlu1 = nn.TLU(self.out_ch//(4 if self.subpixel else 1))
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
x = self.frn1(x) #x = self.frn1(x)
x = self.tlu1(x) #x = self.tlu1(x)
if self.subpixel: if self.subpixel:
x = nn.space_to_depth(x, 2) x = nn.space_to_depth(x, 2)
if self.use_activator: if self.use_activator:
@ -377,12 +377,12 @@ class DeepFakeArchi(nn.ArchiBase):
cur_ch = ch*( min(2**i, 8) ) cur_ch = ch*( min(2**i, 8) )
self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel, use_activator=use_activator) ) self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel, use_activator=use_activator) )
last_ch = self.downs[-1].get_out_ch() last_ch = self.downs[-1].get_out_ch()
self.bp1 = nn.BlurPool(kernel_size) #self.bp1 = nn.BlurPool(kernel_size)
def forward(self, inp): def forward(self, inp):
x = inp x = inp
for down in self.downs: for down in self.downs:
x = down(x) x = down(x)
x = self.bp1(x) #x = self.bp1(x)
return x return x
class DecayingDownscaleBlock(nn.ModelBase): class DecayingDownscaleBlock(nn.ModelBase):
@ -420,35 +420,36 @@ class DeepFakeArchi(nn.ArchiBase):
class Upscale(nn.ModelBase): class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3, depth_multiplier=1 ): def on_build(self, in_ch, out_ch, kernel_size=3, depth_multiplier=1 ):
self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME') self.conv1 = nn.SeparableConv2D( in_ch, out_ch*4, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME')
self.frn1 = nn.FRNorm2D(out_ch*4) #self.frn1 = nn.FRNorm2D(out_ch*4)
self.tlu1 = nn.TLU(out_ch*4) #self.tlu1 = nn.TLU(out_ch*4)
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
x = self.frn1(x) #x = self.frn1(x)
x = self.tlu1(x) #x = self.tlu1(x)
#x = tf.nn.leaky_relu(x, 0.1) (TLU replaces relu) x = tf.nn.leaky_relu(x, 0.1)# (TLU replaces relu)
x = nn.depth_to_space(x, 2) x = nn.depth_to_space(x, 2)
return x return x
class ResidualBlock(nn.ModelBase): class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3, depth_multiplier=1 ): def on_build(self, ch, kernel_size=3, depth_multiplier=1 ):
self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME') self.conv1 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME')
self.frn1 = nn.FRNorm2D(ch) #self.frn1 = nn.FRNorm2D(ch)
self.tlu1 = nn.TLU(ch) #self.tlu1 = nn.TLU(ch)
self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME') self.conv2 = nn.SeparableConv2D( ch, ch, kernel_size=kernel_size, depth_multiplier=depth_multiplier, padding='SAME')
self.frn2 = nn.FRNorm2D(ch) #self.frn2 = nn.FRNorm2D(ch)
self.tlu2 = nn.TLU(ch) #self.tlu2 = nn.TLU(ch)
def forward(self, inp): def forward(self, inp):
x = self.conv1(inp) x = self.conv1(inp)
#x = tf.nn.leaky_relu(x, 0.2) x = tf.nn.leaky_relu(x, 0.2)
x = self.frn1(x) x = self.frn1(x)
x = self.tlu1(x) x = self.tlu1(x)
x = self.conv2(x) x = self.conv2(x)
x = self.frn2(x) x = self.frn2(x)
x = self.tlu2(inp + x) x = tf.nn.leaky_relu(inp + x, 0.2)
#x = self.tlu2(inp + x)
return x return x
""" """
class UpdownResidualBlock(nn.ModelBase): class UpdownResidualBlock(nn.ModelBase):
@ -471,14 +472,14 @@ class DeepFakeArchi(nn.ArchiBase):
""" """
class Encoder(nn.ModelBase): class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch, **kwargs): def on_build(self, in_ch, e_ch, **kwargs):
self.down1 = DecayingDownscaleBlock(in_ch, e_ch*2, n_downscales=6, dilations=1, use_activator=False) self.down1 = DownscaleBlock(in_ch, e_ch, kernel_size=5, n_downscales=4, dilations=1)
self.down2 = DecayingDownscaleBlock(in_ch, e_ch//2, n_downscales=6, dilations=2, use_activator=False) #self.down2 = DecayingDownscaleBlock(in_ch, e_ch//2, n_downscales=6, dilations=2, use_activator=False)
def forward(self, inp): def forward(self, inp):
#x = nn.flatten(self.down1(inp)) x = nn.flatten(self.down1(inp))
x = tf.concat([ nn.flatten(self.down1(inp)), #x = tf.concat([ nn.flatten(self.down1(inp)),
nn.flatten(self.down2(inp)) ], -1 ) #nn.flatten(self.down2(inp)) ], -1 )
return x return x
lowest_dense_res = resolution // 16 lowest_dense_res = resolution // 16