This commit is contained in:
iperov 2021-01-03 01:08:20 +04:00
commit 54fc3162ed
2 changed files with 36 additions and 9 deletions

View file

@ -131,11 +131,26 @@ class UNetPatchDiscriminator(nn.ModelBase):
return s[q][2]
def on_build(self, patch_size, in_ch, base_ch = 16):
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def forward(self, inp):
x = self.conv1(inp)
x = tf.nn.leaky_relu(x, 0.2)
x = self.conv2(x)
x = tf.nn.leaky_relu(inp + x, 0.2)
return x
prev_ch = in_ch
self.convs = []
self.res = []
self.res1 = []
self.res2 = []
self.upconvs = []
self.upres = []
self.upres1 = []
self.upres2 = []
layers = self.find_archi(patch_size)
level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
@ -144,26 +159,38 @@ class UNetPatchDiscriminator(nn.ModelBase):
for i, (kernel_size, strides) in enumerate(layers):
self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME') )
self.res1.append ( ResidualBlock(level_chs[i]) )
self.res2.append ( ResidualBlock(level_chs[i]) )
self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME') )
self.upres1.insert (0, ResidualBlock(level_chs[i-1]*2) )
self.upres2.insert (0, ResidualBlock(level_chs[i-1]*2) )
self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID')
self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID')
self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID')
def forward(self, x):
x = tf.nn.leaky_relu( self.in_conv(x), 0.2 )
encs = []
for conv in self.convs:#, self.res):
for conv, res1,res2 in zip(self.convs, self.res1, self.res2):
encs.insert(0, x)
x = tf.nn.leaky_relu( conv(x), 0.2 )
x = res1(x)
x = res2(x)
center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 )
for i, (upconv, enc,) in enumerate(zip(self.upconvs, encs)):# self.upres
for i, (upconv, enc, upres1, upres2 ) in enumerate(zip(self.upconvs, encs, self.upres1, self.upres2)):
x = tf.nn.leaky_relu( upconv(x), 0.2 )
x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)
x = upres1(x)
x = upres2(x)
return center_out, self.out_conv(x)