From d6a50a0cc606048006a32c8767fb41a67239cb06 Mon Sep 17 00:00:00 2001 From: Jan Date: Tue, 23 Nov 2021 17:25:13 +0100 Subject: [PATCH] added fp16 to V3 gan, copied latest iperov (==v2) gan --- core/leras/models/PatchDiscriminator.py | 264 ++++++++++++------------ 1 file changed, 128 insertions(+), 136 deletions(-) diff --git a/core/leras/models/PatchDiscriminator.py b/core/leras/models/PatchDiscriminator.py index 63dd2b5..66ab02e 100644 --- a/core/leras/models/PatchDiscriminator.py +++ b/core/leras/models/PatchDiscriminator.py @@ -76,9 +76,11 @@ class PatchDiscriminator(nn.ModelBase): nn.PatchDiscriminator = PatchDiscriminator + class UNetPatchDiscriminator(nn.ModelBase): """ Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks" + Based on iperov commit 11add4cd4f5a61df26a8659f4cc5c8d9467bf5f8 from Jan 3, 2021 with added fp16 option """ def calc_receptive_field_size(self, layers): """ @@ -133,6 +135,130 @@ class UNetPatchDiscriminator(nn.ModelBase): def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False): self.use_fp16 = use_fp16 conv_dtype = tf.float16 if use_fp16 else tf.float32 + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = tf.nn.leaky_relu(inp + x, 0.2) + return x + + prev_ch = in_ch + self.convs = [] + self.res1 = [] + self.res2 = [] + self.upconvs = [] + self.upres1 = [] + self.upres2 = [] + layers = self.find_archi(patch_size) + + level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) } + + self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype) + + for i, (kernel_size, strides) in enumerate(layers): + self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) ) + + self.res1.append ( ResidualBlock(level_chs[i]) ) + self.res2.append ( ResidualBlock(level_chs[i]) ) + + self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) ) + + self.upres1.insert (0, ResidualBlock(level_chs[i-1]*2) ) + self.upres2.insert (0, ResidualBlock(level_chs[i-1]*2) ) + + self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype) + + self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype) + self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID', dtype=conv_dtype) + + + def forward(self, x): + x = tf.nn.leaky_relu( self.in_conv(x), 0.2 ) + + encs = [] + for conv, res1,res2 in zip(self.convs, self.res1, self.res2): + encs.insert(0, x) + x = tf.nn.leaky_relu( conv(x), 0.2 ) + x = res1(x) + x = res2(x) + + center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 ) + + for i, (upconv, enc, upres1, upres2 ) in enumerate(zip(self.upconvs, encs, self.upres1, self.upres2)): + x = tf.nn.leaky_relu( upconv(x), 0.2 ) + x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis) + x = upres1(x) + x = upres2(x) + + return center_out, self.out_conv(x) + +nn.UNetPatchDiscriminator = UNetPatchDiscriminator + +class UNetPatchDiscriminatorV2(nn.ModelBase): + """ + Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks" + Based on iperov commit 35877dbfd724c22040f421e93c1adbb7142e5b5d from Jul 14, 2021 + """ + def calc_receptive_field_size(self, layers): + """ + result the same as https://fomoro.com/research/article/receptive-field-calculatorindex.html + """ + rf = 0 + ts = 1 + for i, (k, s) in enumerate(layers): + if i == 0: + rf = k + else: + rf += (k-1)*ts + ts *= s + return rf + + def find_archi(self, target_patch_size, max_layers=9): + """ + Find the best configuration of layers using only 3x3 convs for target patch size + """ + s = {} + for layers_count in range(1,max_layers+1): + val = 1 << (layers_count-1) + while True: + val -= 1 + + layers = [] + sum_st = 0 + layers.append ( [3, 2]) + sum_st += 2 + for i in range(layers_count-1): + st = 1 + (1 if val & (1 << i) !=0 else 0 ) + layers.append ( [3, st ]) + sum_st += st + + rf = self.calc_receptive_field_size(layers) + + s_rf = s.get(rf, None) + if s_rf is None: + s[rf] = (layers_count, sum_st, layers) + else: + if layers_count < s_rf[0] or \ + ( layers_count == s_rf[0] and sum_st > s_rf[1] ): + s[rf] = (layers_count, sum_st, layers) + + if val == 0: + break + + x = sorted(list(s.keys())) + q=x[np.abs(np.array(x)-target_patch_size).argmin()] + return s[q][2] + + def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False): + self.use_fp16 = use_fp16 + conv_dtype = tf.float16 if use_fp16 else tf.float32 + class ResidualBlock(nn.ModelBase): def on_build(self, ch, kernel_size=3 ): self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) @@ -149,7 +275,7 @@ class UNetPatchDiscriminator(nn.ModelBase): self.convs = [] self.upconvs = [] layers = self.find_archi(patch_size) - + level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) } self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype) @@ -157,14 +283,8 @@ class UNetPatchDiscriminator(nn.ModelBase): for i, (kernel_size, strides) in enumerate(layers): self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) ) - self.res1.append ( ResidualBlock(level_chs[i]) ) - self.res2.append ( ResidualBlock(level_chs[i]) ) - self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) ) - self.upres1.insert (0, ResidualBlock(level_chs[i-1]*2) ) - self.upres2.insert (0, ResidualBlock(level_chs[i-1]*2) ) - self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype) self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype) @@ -181,9 +301,7 @@ class UNetPatchDiscriminator(nn.ModelBase): for conv in self.convs: encs.insert(0, x) x = tf.nn.leaky_relu( conv(x), 0.2 ) - x = res1(x) - x = res2(x) - + center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 ) for i, (upconv, enc) in enumerate(zip(self.upconvs, encs)): @@ -197,131 +315,5 @@ class UNetPatchDiscriminator(nn.ModelBase): x = tf.cast(x, tf.float32) return center_out, x - -nn.UNetPatchDiscriminator = UNetPatchDiscriminator - -class UNetPatchDiscriminatorV2(nn.ModelBase): - """ - Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks" - """ - def calc_receptive_field_size(self, layers): - """ - result the same as https://fomoro.com/research/article/receptive-field-calculatorindex.html - """ - rf = 0 - ts = 1 - for i, (k, s) in enumerate(layers): - if i == 0: - rf = k - else: - rf += (k-1)*ts - ts *= s - return rf - - def find_archi(self, target_patch_size, max_layers=6): - """ - Find the best configuration of layers using only 3x3 convs for target patch size - """ - s = {} - for layers_count in range(1,max_layers+1): - val = 1 << (layers_count-1) - while True: - val -= 1 - - layers = [] - sum_st = 0 - for i in range(layers_count-1): - st = 1 + (1 if val & (1 << i) !=0 else 0 ) - layers.append ( [3, st ]) - sum_st += st - layers.append ( [3, 2]) - sum_st += 2 - - rf = self.calc_receptive_field_size(layers) - - s_rf = s.get(rf, None) - if s_rf is None: - s[rf] = (layers_count, sum_st, layers) - else: - if layers_count < s_rf[0] or \ - ( layers_count == s_rf[0] and sum_st > s_rf[1] ): - s[rf] = (layers_count, sum_st, layers) - - if val == 0: - break - - x = sorted(list(s.keys())) - q=x[np.abs(np.array(x)-target_patch_size).argmin()] - return s[q][2] - - def on_build(self, patch_size, in_ch, use_fp16 = False): - self.use_fp16 = use_fp16 - conv_dtype = tf.float16 if use_fp16 else tf.float32 - - class ResidualBlock(nn.ModelBase): - def on_build(self, ch, kernel_size=3 ): - self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) - self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) - - def forward(self, inp): - x = self.conv1(inp) - x = tf.nn.leaky_relu(x, 0.2) - x = self.conv2(x) - x = tf.nn.leaky_relu(inp + x, 0.2) - return x - - prev_ch = in_ch - self.convs = [] - self.res = [] - self.upconvs = [] - self.upres = [] - layers = self.find_archi(patch_size) - base_ch = 16 - - level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) } - - self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype) - - for i, (kernel_size, strides) in enumerate(layers): - self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) ) - - self.res.append ( ResidualBlock(level_chs[i]) ) - - self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) ) - - self.upres.insert (0, ResidualBlock(level_chs[i-1]*2) ) - - self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype) - - self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype) - self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID', dtype=conv_dtype) - - - def forward(self, x): - if self.use_fp16: - x = tf.cast(x, tf.float16) - - x = tf.nn.leaky_relu( self.in_conv(x), 0.1 ) - - encs = [] - for conv, res in zip(self.convs, self.res): - encs.insert(0, x) - x = tf.nn.leaky_relu( conv(x), 0.1 ) - x = res(x) - - center_out, x = self.center_out(x), self.center_conv(x) - - for i, (upconv, enc, upres) in enumerate(zip(self.upconvs, encs, self.upres)): - x = tf.nn.leaky_relu( upconv(x), 0.1 ) - x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis) - x = upres(x) - - x = self.out_conv(x) - - if self.use_fp16: - center_out = tf.cast(center_out, tf.float32) - x = tf.cast(x, tf.float32) - - return center_out, x nn.UNetPatchDiscriminatorV2 = UNetPatchDiscriminatorV2