This commit is contained in:
Colombo 2020-01-22 10:29:17 +04:00
parent 9797a70fd3
commit beed145d29
2 changed files with 17 additions and 43 deletions

View file

@ -42,7 +42,7 @@ class QModel(ModelBase):
x = tf.nn.space_to_depth(x, 2)
if self.use_activator:
x = tf.nn.leaky_relu(x, 0.2)
x = nn.tf_gelu(x)
return x
def get_out_ch(self):
@ -70,24 +70,10 @@ class QModel(ModelBase):
def forward(self, x):
x = self.conv1(x)
x = tf.nn.leaky_relu(x, 0.2)
x = nn.tf_gelu(x)
x = tf.nn.depth_to_space(x, 2)
return x
class UpdownResidualBlock(nn.ModelBase):
def on_build(self, ch, inner_ch, kernel_size=3 ):
self.up = Upscale (ch, inner_ch, kernel_size=kernel_size)
self.res = ResidualBlock (inner_ch, kernel_size=kernel_size)
self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False)
def forward(self, inp):
x = self.up(inp)
x = upx = self.res(x)
x = self.down(x)
x = x + inp
x = tf.nn.leaky_relu(x, 0.2)
return x, upx
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
@ -95,10 +81,10 @@ class QModel(ModelBase):
def forward(self, inp):
x = self.conv1(inp)
x = tf.nn.leaky_relu(x, 0.2)
x = nn.tf_gelu(x)
x = self.conv2(x)
x = inp + x
x = tf.nn.leaky_relu(x, 0.2)
x = nn.tf_gelu(x)
return x
class Encoder(nn.ModelBase):
@ -116,7 +102,7 @@ class QModel(ModelBase):
in_ch, lowest_dense_res, ae_ch, ae_out_ch, d_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch, self.d_ch
self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, maxout_features=2, kernel_initializer=tf.initializers.orthogonal )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, maxout_features=4, kernel_initializer=tf.initializers.orthogonal )
self.upscale1 = Upscale(ae_out_ch, d_ch*8)
self.res1 = ResidualBlock(d_ch*8)
@ -133,13 +119,12 @@ class QModel(ModelBase):
class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch):
self.upscale1 = Upscale(in_ch, d_ch*4)
self.res1 = UpdownResidualBlock(d_ch*4, d_ch*2)
self.upscale1 = Upscale(in_ch, d_ch*4)
self.res1 = ResidualBlock(d_ch*4)
self.upscale2 = Upscale(d_ch*4, d_ch*2)
self.res2 = UpdownResidualBlock(d_ch*2, d_ch)
self.res2 = ResidualBlock(d_ch*2)
self.upscale3 = Upscale(d_ch*2, d_ch*1)
self.res3 = UpdownResidualBlock(d_ch, d_ch//2)
self.res3 = ResidualBlock(d_ch*1)
self.upscalem1 = Upscale(in_ch, d_ch)
self.upscalem2 = Upscale(d_ch, d_ch//2)
@ -149,27 +134,13 @@ class QModel(ModelBase):
self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
def forward(self, inp):
z = inp
x = self.upscale1(z)
x, upx = self.res1(x)
x = self.upscale2(x)
x = tf.nn.leaky_relu(x + upx, 0.2)
x, upx = self.res2(x)
x = self.upscale3(x)
x = tf.nn.leaky_relu(x + upx, 0.2)
x, upx = self.res3(x)
"""
x = self.upscale1 (z)
z = inp
x = self.upscale1 (z)
x = self.res1 (x)
x = self.upscale2 (x)
x = self.res2 (x)
x = self.upscale3 (x)
x = self.res3 (x)
"""
y = self.upscalem1 (z)
y = self.upscalem2 (y)
@ -185,7 +156,7 @@ class QModel(ModelBase):
ae_dims = 128
e_dims = 128
d_dims = 64
self.pretrain = True
self.pretrain = False
self.pretrain_just_disabled = False
masked_training = True