mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 13:09:56 -07:00
added fp16 to V3 gan, copied latest iperov (==v2) gan
This commit is contained in:
parent
8df69209c7
commit
d6a50a0cc6
1 changed files with 128 additions and 136 deletions
|
@ -76,9 +76,11 @@ class PatchDiscriminator(nn.ModelBase):
|
||||||
|
|
||||||
nn.PatchDiscriminator = PatchDiscriminator
|
nn.PatchDiscriminator = PatchDiscriminator
|
||||||
|
|
||||||
|
|
||||||
class UNetPatchDiscriminator(nn.ModelBase):
|
class UNetPatchDiscriminator(nn.ModelBase):
|
||||||
"""
|
"""
|
||||||
Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks"
|
Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks"
|
||||||
|
Based on iperov commit 11add4cd4f5a61df26a8659f4cc5c8d9467bf5f8 from Jan 3, 2021 with added fp16 option
|
||||||
"""
|
"""
|
||||||
def calc_receptive_field_size(self, layers):
|
def calc_receptive_field_size(self, layers):
|
||||||
"""
|
"""
|
||||||
|
@ -133,6 +135,130 @@ class UNetPatchDiscriminator(nn.ModelBase):
|
||||||
def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False):
|
def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False):
|
||||||
self.use_fp16 = use_fp16
|
self.use_fp16 = use_fp16
|
||||||
conv_dtype = tf.float16 if use_fp16 else tf.float32
|
conv_dtype = tf.float16 if use_fp16 else tf.float32
|
||||||
|
|
||||||
|
class ResidualBlock(nn.ModelBase):
|
||||||
|
def on_build(self, ch, kernel_size=3 ):
|
||||||
|
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||||
|
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||||
|
|
||||||
|
def forward(self, inp):
|
||||||
|
x = self.conv1(inp)
|
||||||
|
x = tf.nn.leaky_relu(x, 0.2)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = tf.nn.leaky_relu(inp + x, 0.2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
prev_ch = in_ch
|
||||||
|
self.convs = []
|
||||||
|
self.res1 = []
|
||||||
|
self.res2 = []
|
||||||
|
self.upconvs = []
|
||||||
|
self.upres1 = []
|
||||||
|
self.upres2 = []
|
||||||
|
layers = self.find_archi(patch_size)
|
||||||
|
|
||||||
|
level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
|
||||||
|
|
||||||
|
self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
|
||||||
|
|
||||||
|
for i, (kernel_size, strides) in enumerate(layers):
|
||||||
|
self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
|
||||||
|
|
||||||
|
self.res1.append ( ResidualBlock(level_chs[i]) )
|
||||||
|
self.res2.append ( ResidualBlock(level_chs[i]) )
|
||||||
|
|
||||||
|
self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
|
||||||
|
|
||||||
|
self.upres1.insert (0, ResidualBlock(level_chs[i-1]*2) )
|
||||||
|
self.upres2.insert (0, ResidualBlock(level_chs[i-1]*2) )
|
||||||
|
|
||||||
|
self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
|
||||||
|
|
||||||
|
self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
|
||||||
|
self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = tf.nn.leaky_relu( self.in_conv(x), 0.2 )
|
||||||
|
|
||||||
|
encs = []
|
||||||
|
for conv, res1,res2 in zip(self.convs, self.res1, self.res2):
|
||||||
|
encs.insert(0, x)
|
||||||
|
x = tf.nn.leaky_relu( conv(x), 0.2 )
|
||||||
|
x = res1(x)
|
||||||
|
x = res2(x)
|
||||||
|
|
||||||
|
center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 )
|
||||||
|
|
||||||
|
for i, (upconv, enc, upres1, upres2 ) in enumerate(zip(self.upconvs, encs, self.upres1, self.upres2)):
|
||||||
|
x = tf.nn.leaky_relu( upconv(x), 0.2 )
|
||||||
|
x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)
|
||||||
|
x = upres1(x)
|
||||||
|
x = upres2(x)
|
||||||
|
|
||||||
|
return center_out, self.out_conv(x)
|
||||||
|
|
||||||
|
nn.UNetPatchDiscriminator = UNetPatchDiscriminator
|
||||||
|
|
||||||
|
class UNetPatchDiscriminatorV2(nn.ModelBase):
|
||||||
|
"""
|
||||||
|
Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks"
|
||||||
|
Based on iperov commit 35877dbfd724c22040f421e93c1adbb7142e5b5d from Jul 14, 2021
|
||||||
|
"""
|
||||||
|
def calc_receptive_field_size(self, layers):
|
||||||
|
"""
|
||||||
|
result the same as https://fomoro.com/research/article/receptive-field-calculatorindex.html
|
||||||
|
"""
|
||||||
|
rf = 0
|
||||||
|
ts = 1
|
||||||
|
for i, (k, s) in enumerate(layers):
|
||||||
|
if i == 0:
|
||||||
|
rf = k
|
||||||
|
else:
|
||||||
|
rf += (k-1)*ts
|
||||||
|
ts *= s
|
||||||
|
return rf
|
||||||
|
|
||||||
|
def find_archi(self, target_patch_size, max_layers=9):
|
||||||
|
"""
|
||||||
|
Find the best configuration of layers using only 3x3 convs for target patch size
|
||||||
|
"""
|
||||||
|
s = {}
|
||||||
|
for layers_count in range(1,max_layers+1):
|
||||||
|
val = 1 << (layers_count-1)
|
||||||
|
while True:
|
||||||
|
val -= 1
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
sum_st = 0
|
||||||
|
layers.append ( [3, 2])
|
||||||
|
sum_st += 2
|
||||||
|
for i in range(layers_count-1):
|
||||||
|
st = 1 + (1 if val & (1 << i) !=0 else 0 )
|
||||||
|
layers.append ( [3, st ])
|
||||||
|
sum_st += st
|
||||||
|
|
||||||
|
rf = self.calc_receptive_field_size(layers)
|
||||||
|
|
||||||
|
s_rf = s.get(rf, None)
|
||||||
|
if s_rf is None:
|
||||||
|
s[rf] = (layers_count, sum_st, layers)
|
||||||
|
else:
|
||||||
|
if layers_count < s_rf[0] or \
|
||||||
|
( layers_count == s_rf[0] and sum_st > s_rf[1] ):
|
||||||
|
s[rf] = (layers_count, sum_st, layers)
|
||||||
|
|
||||||
|
if val == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
x = sorted(list(s.keys()))
|
||||||
|
q=x[np.abs(np.array(x)-target_patch_size).argmin()]
|
||||||
|
return s[q][2]
|
||||||
|
|
||||||
|
def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False):
|
||||||
|
self.use_fp16 = use_fp16
|
||||||
|
conv_dtype = tf.float16 if use_fp16 else tf.float32
|
||||||
|
|
||||||
class ResidualBlock(nn.ModelBase):
|
class ResidualBlock(nn.ModelBase):
|
||||||
def on_build(self, ch, kernel_size=3 ):
|
def on_build(self, ch, kernel_size=3 ):
|
||||||
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
||||||
|
@ -149,7 +275,7 @@ class UNetPatchDiscriminator(nn.ModelBase):
|
||||||
self.convs = []
|
self.convs = []
|
||||||
self.upconvs = []
|
self.upconvs = []
|
||||||
layers = self.find_archi(patch_size)
|
layers = self.find_archi(patch_size)
|
||||||
|
|
||||||
level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
|
level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
|
||||||
|
|
||||||
self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
|
self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
|
||||||
|
@ -157,14 +283,8 @@ class UNetPatchDiscriminator(nn.ModelBase):
|
||||||
for i, (kernel_size, strides) in enumerate(layers):
|
for i, (kernel_size, strides) in enumerate(layers):
|
||||||
self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
|
self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
|
||||||
|
|
||||||
self.res1.append ( ResidualBlock(level_chs[i]) )
|
|
||||||
self.res2.append ( ResidualBlock(level_chs[i]) )
|
|
||||||
|
|
||||||
self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
|
self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
|
||||||
|
|
||||||
self.upres1.insert (0, ResidualBlock(level_chs[i-1]*2) )
|
|
||||||
self.upres2.insert (0, ResidualBlock(level_chs[i-1]*2) )
|
|
||||||
|
|
||||||
self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
|
self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
|
||||||
|
|
||||||
self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
|
self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
|
||||||
|
@ -181,9 +301,7 @@ class UNetPatchDiscriminator(nn.ModelBase):
|
||||||
for conv in self.convs:
|
for conv in self.convs:
|
||||||
encs.insert(0, x)
|
encs.insert(0, x)
|
||||||
x = tf.nn.leaky_relu( conv(x), 0.2 )
|
x = tf.nn.leaky_relu( conv(x), 0.2 )
|
||||||
x = res1(x)
|
|
||||||
x = res2(x)
|
|
||||||
|
|
||||||
center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 )
|
center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 )
|
||||||
|
|
||||||
for i, (upconv, enc) in enumerate(zip(self.upconvs, encs)):
|
for i, (upconv, enc) in enumerate(zip(self.upconvs, encs)):
|
||||||
|
@ -197,131 +315,5 @@ class UNetPatchDiscriminator(nn.ModelBase):
|
||||||
x = tf.cast(x, tf.float32)
|
x = tf.cast(x, tf.float32)
|
||||||
|
|
||||||
return center_out, x
|
return center_out, x
|
||||||
|
|
||||||
nn.UNetPatchDiscriminator = UNetPatchDiscriminator
|
|
||||||
|
|
||||||
class UNetPatchDiscriminatorV2(nn.ModelBase):
|
|
||||||
"""
|
|
||||||
Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks"
|
|
||||||
"""
|
|
||||||
def calc_receptive_field_size(self, layers):
|
|
||||||
"""
|
|
||||||
result the same as https://fomoro.com/research/article/receptive-field-calculatorindex.html
|
|
||||||
"""
|
|
||||||
rf = 0
|
|
||||||
ts = 1
|
|
||||||
for i, (k, s) in enumerate(layers):
|
|
||||||
if i == 0:
|
|
||||||
rf = k
|
|
||||||
else:
|
|
||||||
rf += (k-1)*ts
|
|
||||||
ts *= s
|
|
||||||
return rf
|
|
||||||
|
|
||||||
def find_archi(self, target_patch_size, max_layers=6):
|
|
||||||
"""
|
|
||||||
Find the best configuration of layers using only 3x3 convs for target patch size
|
|
||||||
"""
|
|
||||||
s = {}
|
|
||||||
for layers_count in range(1,max_layers+1):
|
|
||||||
val = 1 << (layers_count-1)
|
|
||||||
while True:
|
|
||||||
val -= 1
|
|
||||||
|
|
||||||
layers = []
|
|
||||||
sum_st = 0
|
|
||||||
for i in range(layers_count-1):
|
|
||||||
st = 1 + (1 if val & (1 << i) !=0 else 0 )
|
|
||||||
layers.append ( [3, st ])
|
|
||||||
sum_st += st
|
|
||||||
layers.append ( [3, 2])
|
|
||||||
sum_st += 2
|
|
||||||
|
|
||||||
rf = self.calc_receptive_field_size(layers)
|
|
||||||
|
|
||||||
s_rf = s.get(rf, None)
|
|
||||||
if s_rf is None:
|
|
||||||
s[rf] = (layers_count, sum_st, layers)
|
|
||||||
else:
|
|
||||||
if layers_count < s_rf[0] or \
|
|
||||||
( layers_count == s_rf[0] and sum_st > s_rf[1] ):
|
|
||||||
s[rf] = (layers_count, sum_st, layers)
|
|
||||||
|
|
||||||
if val == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
x = sorted(list(s.keys()))
|
|
||||||
q=x[np.abs(np.array(x)-target_patch_size).argmin()]
|
|
||||||
return s[q][2]
|
|
||||||
|
|
||||||
def on_build(self, patch_size, in_ch, use_fp16 = False):
|
|
||||||
self.use_fp16 = use_fp16
|
|
||||||
conv_dtype = tf.float16 if use_fp16 else tf.float32
|
|
||||||
|
|
||||||
class ResidualBlock(nn.ModelBase):
|
|
||||||
def on_build(self, ch, kernel_size=3 ):
|
|
||||||
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
|
||||||
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
|
|
||||||
|
|
||||||
def forward(self, inp):
|
|
||||||
x = self.conv1(inp)
|
|
||||||
x = tf.nn.leaky_relu(x, 0.2)
|
|
||||||
x = self.conv2(x)
|
|
||||||
x = tf.nn.leaky_relu(inp + x, 0.2)
|
|
||||||
return x
|
|
||||||
|
|
||||||
prev_ch = in_ch
|
|
||||||
self.convs = []
|
|
||||||
self.res = []
|
|
||||||
self.upconvs = []
|
|
||||||
self.upres = []
|
|
||||||
layers = self.find_archi(patch_size)
|
|
||||||
base_ch = 16
|
|
||||||
|
|
||||||
level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
|
|
||||||
|
|
||||||
self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
|
|
||||||
|
|
||||||
for i, (kernel_size, strides) in enumerate(layers):
|
|
||||||
self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
|
|
||||||
|
|
||||||
self.res.append ( ResidualBlock(level_chs[i]) )
|
|
||||||
|
|
||||||
self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
|
|
||||||
|
|
||||||
self.upres.insert (0, ResidualBlock(level_chs[i-1]*2) )
|
|
||||||
|
|
||||||
self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
|
|
||||||
|
|
||||||
self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
|
|
||||||
self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.use_fp16:
|
|
||||||
x = tf.cast(x, tf.float16)
|
|
||||||
|
|
||||||
x = tf.nn.leaky_relu( self.in_conv(x), 0.1 )
|
|
||||||
|
|
||||||
encs = []
|
|
||||||
for conv, res in zip(self.convs, self.res):
|
|
||||||
encs.insert(0, x)
|
|
||||||
x = tf.nn.leaky_relu( conv(x), 0.1 )
|
|
||||||
x = res(x)
|
|
||||||
|
|
||||||
center_out, x = self.center_out(x), self.center_conv(x)
|
|
||||||
|
|
||||||
for i, (upconv, enc, upres) in enumerate(zip(self.upconvs, encs, self.upres)):
|
|
||||||
x = tf.nn.leaky_relu( upconv(x), 0.1 )
|
|
||||||
x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)
|
|
||||||
x = upres(x)
|
|
||||||
|
|
||||||
x = self.out_conv(x)
|
|
||||||
|
|
||||||
if self.use_fp16:
|
|
||||||
center_out = tf.cast(center_out, tf.float32)
|
|
||||||
x = tf.cast(x, tf.float32)
|
|
||||||
|
|
||||||
return center_out, x
|
|
||||||
|
|
||||||
nn.UNetPatchDiscriminatorV2 = UNetPatchDiscriminatorV2
|
nn.UNetPatchDiscriminatorV2 = UNetPatchDiscriminatorV2
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue