upd PatchDiscriminator

This commit is contained in:
iperov 2020-12-31 17:38:41 +04:00
commit ad5733c5bb
2 changed files with 5 additions and 5 deletions

View file

@ -152,17 +152,17 @@ class UNetPatchDiscriminator(nn.ModelBase):
self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID')
def forward(self, x):
x = tf.nn.leaky_relu( self.in_conv(x), 0.1 )
x = tf.nn.leaky_relu( self.in_conv(x), 0.2 )
encs = []
for conv in self.convs:#, self.res):
encs.insert(0, x)
x = tf.nn.leaky_relu( conv(x), 0.1 )
x = tf.nn.leaky_relu( conv(x), 0.2 )
center_out, x = self.center_out(x), self.center_conv(x)
center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 )
for i, (upconv, enc,) in enumerate(zip(self.upconvs, encs)):# self.upres
x = tf.nn.leaky_relu( upconv(x), 0.1 )
x = tf.nn.leaky_relu( upconv(x), 0.2 )
x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)
return center_out, self.out_conv(x)