added fp16 to V3 gan, copied latest iperov (==v2) gan

commit d6a50a0cc6
Author: Jan
Date:   2021-11-23 17:25:13 +01:00


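Most hunks below thread a use_fp16 switch through the discriminators: convolutions are created with dtype=tf.float16 and, in iperov's variant of the forward pass, the input is cast down to float16 and the logits are cast back to float32 before they reach the loss. A minimal sketch of that pattern, assuming plain tf.keras layers instead of DeepFaceLab's nn.Conv2D wrapper (layer names and channel counts here are illustrative, not taken from this commit):

    import tensorflow as tf

    def make_convs(ch=16, n=3, use_fp16=True):
        # convolution weights live in float16 when the option is enabled
        dtype = tf.float16 if use_fp16 else tf.float32
        return [tf.keras.layers.Conv2D(ch, 3, strides=2, padding='same', dtype=dtype)
                for _ in range(n)]

    def fp16_forward(x, convs, use_fp16=True):
        if use_fp16:
            x = tf.cast(x, tf.float16)    # run the conv stack in half precision
        for conv in convs:
            x = tf.nn.leaky_relu(conv(x), 0.2)
        if use_fp16:
            x = tf.cast(x, tf.float32)    # hand float32 logits back to the loss
        return x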
@@ -76,9 +76,11 @@ class PatchDiscriminator(nn.ModelBase):
 nn.PatchDiscriminator = PatchDiscriminator
 
 class UNetPatchDiscriminator(nn.ModelBase):
     """
     Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks"
+    Based on iperov commit 11add4cd4f5a61df26a8659f4cc5c8d9467bf5f8 from Jan 3, 2021 with added fp16 option
     """
     def calc_receptive_field_size(self, layers):
         """
@@ -133,6 +135,7 @@ class UNetPatchDiscriminator(nn.ModelBase):
     def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False):
         self.use_fp16 = use_fp16
         conv_dtype = tf.float16 if use_fp16 else tf.float32
 
         class ResidualBlock(nn.ModelBase):
             def on_build(self, ch, kernel_size=3 ):
                 self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
@@ -147,7 +150,11 @@ class UNetPatchDiscriminator(nn.ModelBase):
         prev_ch = in_ch
 
         self.convs = []
+        self.res1 = []
+        self.res2 = []
         self.upconvs = []
+        self.upres1 = []
+        self.upres2 = []
         layers = self.find_archi(patch_size)
 
         level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
@@ -172,13 +179,10 @@ class UNetPatchDiscriminator(nn.ModelBase):
     def forward(self, x):
-        if self.use_fp16:
-            x = tf.cast(x, tf.float16)
-
         x = tf.nn.leaky_relu( self.in_conv(x), 0.2 )
 
         encs = []
-        for conv in self.convs:
+        for conv, res1,res2 in zip(self.convs, self.res1, self.res2):
             encs.insert(0, x)
             x = tf.nn.leaky_relu( conv(x), 0.2 )
             x = res1(x)
@@ -186,23 +190,20 @@ class UNetPatchDiscriminator(nn.ModelBase):
         center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 )
 
-        for i, (upconv, enc) in enumerate(zip(self.upconvs, encs)):
+        for i, (upconv, enc, upres1, upres2 ) in enumerate(zip(self.upconvs, encs, self.upres1, self.upres2)):
             x = tf.nn.leaky_relu( upconv(x), 0.2 )
             x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)
+            x = upres1(x)
+            x = upres2(x)
 
-        x = self.out_conv(x)
-
-        if self.use_fp16:
-            center_out = tf.cast(center_out, tf.float32)
-            x = tf.cast(x, tf.float32)
-
-        return center_out, x
+        return center_out, self.out_conv(x)
 
 nn.UNetPatchDiscriminator = UNetPatchDiscriminator
 
 class UNetPatchDiscriminatorV2(nn.ModelBase):
     """
     Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks"
+    Based on iperov commit 35877dbfd724c22040f421e93c1adbb7142e5b5d from Jul 14, 2021
     """
     def calc_receptive_field_size(self, layers):
         """
@@ -218,7 +219,7 @@ class UNetPatchDiscriminatorV2(nn.ModelBase):
             ts *= s
         return rf
 
-    def find_archi(self, target_patch_size, max_layers=6):
+    def find_archi(self, target_patch_size, max_layers=9):
         """
         Find the best configuration of layers using only 3x3 convs for target patch size
         """
@@ -230,12 +231,12 @@ class UNetPatchDiscriminatorV2(nn.ModelBase):
                 layers = []
                 sum_st = 0
+                layers.append ( [3, 2])
+                sum_st += 2
                 for i in range(layers_count-1):
                     st = 1 + (1 if val & (1 << i) !=0 else 0 )
                     layers.append ( [3, st ])
                     sum_st += st
-                layers.append ( [3, 2])
-                sum_st += 2
 
                 rf = self.calc_receptive_field_size(layers)
@@ -244,7 +245,7 @@ class UNetPatchDiscriminatorV2(nn.ModelBase):
                     s[rf] = (layers_count, sum_st, layers)
                 else:
                     if layers_count < s_rf[0] or \
                        ( layers_count == s_rf[0] and sum_st > s_rf[1] ):
                         s[rf] = (layers_count, sum_st, layers)
 
                 if val == 0:
@@ -254,7 +255,7 @@ class UNetPatchDiscriminatorV2(nn.ModelBase):
         q=x[np.abs(np.array(x)-target_patch_size).argmin()]
         return s[q][2]
 
-    def on_build(self, patch_size, in_ch, use_fp16 = False):
+    def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False):
         self.use_fp16 = use_fp16
         conv_dtype = tf.float16 if use_fp16 else tf.float32
@@ -272,11 +273,8 @@ class UNetPatchDiscriminatorV2(nn.ModelBase):
         prev_ch = in_ch
 
         self.convs = []
-        self.res = []
         self.upconvs = []
-        self.upres = []
         layers = self.find_archi(patch_size)
-        base_ch = 16
 
         level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
@@ -285,12 +283,8 @@ class UNetPatchDiscriminatorV2(nn.ModelBase):
         for i, (kernel_size, strides) in enumerate(layers):
             self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
-            self.res.append ( ResidualBlock(level_chs[i]) )
 
             self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
-            self.upres.insert (0, ResidualBlock(level_chs[i-1]*2) )
 
         self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
 
         self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
@@ -301,20 +295,18 @@ class UNetPatchDiscriminatorV2(nn.ModelBase):
         if self.use_fp16:
             x = tf.cast(x, tf.float16)
 
-        x = tf.nn.leaky_relu( self.in_conv(x), 0.1 )
+        x = tf.nn.leaky_relu( self.in_conv(x), 0.2 )
 
         encs = []
-        for conv, res in zip(self.convs, self.res):
+        for conv in self.convs:
             encs.insert(0, x)
-            x = tf.nn.leaky_relu( conv(x), 0.1 )
-            x = res(x)
+            x = tf.nn.leaky_relu( conv(x), 0.2 )
 
-        center_out, x = self.center_out(x), self.center_conv(x)
+        center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 )
 
-        for i, (upconv, enc, upres) in enumerate(zip(self.upconvs, encs, self.upres)):
-            x = tf.nn.leaky_relu( upconv(x), 0.1 )
+        for i, (upconv, enc) in enumerate(zip(self.upconvs, encs)):
+            x = tf.nn.leaky_relu( upconv(x), 0.2 )
             x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)
-            x = upres(x)
 
         x = self.out_conv(x)
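For reference, a hedged usage sketch: in DeepFaceLab's leras, keyword arguments given to a model constructor are forwarded to on_build, so after this commit the two discriminators should be reachable roughly as below. The patch_size/in_ch values, the name argument, and image_batch are placeholders, not taken from this commit:

    # build both discriminator variants with the fp16 option enabled
    D_v1 = nn.UNetPatchDiscriminator(patch_size=28, in_ch=3, base_ch=16, use_fp16=True, name="D_src")
    D_v2 = nn.UNetPatchDiscriminatorV2(patch_size=28, in_ch=3, base_ch=16, use_fp16=True, name="D_src_v2")

    # each call returns (center_out, per_pixel_logits) for a batch of images
    score_whole, score_pixels = D_v2(image_batch)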