mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 21:13:20 -07:00
AMP, SAEHD: reverted GAN to december version.
This commit is contained in:
parent
b4b72d056f
commit
ee1bc83a14
3 changed files with 5 additions and 19 deletions
|
@ -146,11 +146,7 @@ class UNetPatchDiscriminator(nn.ModelBase):
|
|||
|
||||
prev_ch = in_ch
|
||||
self.convs = []
|
||||
self.res1 = []
|
||||
self.res2 = []
|
||||
self.upconvs = []
|
||||
self.upres1 = []
|
||||
self.upres2 = []
|
||||
layers = self.find_archi(patch_size)
|
||||
|
||||
level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
|
||||
|
@ -160,14 +156,8 @@ class UNetPatchDiscriminator(nn.ModelBase):
|
|||
for i, (kernel_size, strides) in enumerate(layers):
|
||||
self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME') )
|
||||
|
||||
self.res1.append ( ResidualBlock(level_chs[i]) )
|
||||
self.res2.append ( ResidualBlock(level_chs[i]) )
|
||||
|
||||
self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME') )
|
||||
|
||||
self.upres1.insert (0, ResidualBlock(level_chs[i-1]*2) )
|
||||
self.upres2.insert (0, ResidualBlock(level_chs[i-1]*2) )
|
||||
|
||||
self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID')
|
||||
|
||||
self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID')
|
||||
|
@ -178,19 +168,15 @@ class UNetPatchDiscriminator(nn.ModelBase):
|
|||
x = tf.nn.leaky_relu( self.in_conv(x), 0.2 )
|
||||
|
||||
encs = []
|
||||
for conv, res1,res2 in zip(self.convs, self.res1, self.res2):
|
||||
for conv in self.convs:
|
||||
encs.insert(0, x)
|
||||
x = tf.nn.leaky_relu( conv(x), 0.2 )
|
||||
x = res1(x)
|
||||
x = res2(x)
|
||||
|
||||
center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 )
|
||||
|
||||
for i, (upconv, enc, upres1, upres2 ) in enumerate(zip(self.upconvs, encs, self.upres1, self.upres2)):
|
||||
for i, (upconv, enc) in enumerate(zip(self.upconvs, encs)):
|
||||
x = tf.nn.leaky_relu( upconv(x), 0.2 )
|
||||
x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)
|
||||
x = upres1(x)
|
||||
x = upres2(x)
|
||||
|
||||
return center_out, self.out_conv(x)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue