mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
SAEHD: new GAN
This commit is contained in:
parent
241d1a9c35
commit
ae9e16b4a5
2 changed files with 41 additions and 36 deletions
|
@ -130,26 +130,13 @@ class UNetPatchDiscriminator(nn.ModelBase):
|
|||
q=x[np.abs(np.array(x)-target_patch_size).argmin()]
|
||||
return s[q][2]
|
||||
|
||||
def on_build(self, patch_size, in_ch):
|
||||
class ResidualBlock(nn.ModelBase):
|
||||
def on_build(self, ch, kernel_size=3 ):
|
||||
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
|
||||
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
|
||||
|
||||
def forward(self, inp):
|
||||
x = self.conv1(inp)
|
||||
x = tf.nn.leaky_relu(x, 0.2)
|
||||
x = self.conv2(x)
|
||||
x = tf.nn.leaky_relu(inp + x, 0.2)
|
||||
return x
|
||||
|
||||
def on_build(self, patch_size, in_ch, base_ch = 16):
|
||||
prev_ch = in_ch
|
||||
self.convs = []
|
||||
self.res = []
|
||||
self.upconvs = []
|
||||
self.upres = []
|
||||
layers = self.find_archi(patch_size)
|
||||
base_ch = 16
|
||||
|
||||
level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
|
||||
|
||||
|
@ -157,34 +144,26 @@ class UNetPatchDiscriminator(nn.ModelBase):
|
|||
|
||||
for i, (kernel_size, strides) in enumerate(layers):
|
||||
self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME') )
|
||||
|
||||
self.res.append ( ResidualBlock(level_chs[i]) )
|
||||
|
||||
self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME') )
|
||||
|
||||
self.upres.insert (0, ResidualBlock(level_chs[i-1]*2) )
|
||||
|
||||
|
||||
self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID')
|
||||
|
||||
self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID')
|
||||
self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID')
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = tf.nn.leaky_relu( self.in_conv(x), 0.1 )
|
||||
|
||||
encs = []
|
||||
for conv, res in zip(self.convs, self.res):
|
||||
for conv in self.convs:#, self.res):
|
||||
encs.insert(0, x)
|
||||
x = tf.nn.leaky_relu( conv(x), 0.1 )
|
||||
x = res(x)
|
||||
|
||||
center_out, x = self.center_out(x), self.center_conv(x)
|
||||
|
||||
for i, (upconv, enc, upres) in enumerate(zip(self.upconvs, encs, self.upres)):
|
||||
for i, (upconv, enc,) in enumerate(zip(self.upconvs, encs)):# self.upres
|
||||
x = tf.nn.leaky_relu( upconv(x), 0.1 )
|
||||
x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)
|
||||
x = upres(x)
|
||||
|
||||
return center_out, self.out_conv(x)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue