update GAN model

This commit is contained in:
iperov 2021-01-03 19:25:39 +04:00
commit 11add4cd4f
2 changed files with 7 additions and 7 deletions

View file

@ -94,7 +94,7 @@ class UNetPatchDiscriminator(nn.ModelBase):
ts *= s
return rf
def find_archi(self, target_patch_size, max_layers=6):
def find_archi(self, target_patch_size, max_layers=9):
"""
Find the best configuration of layers using only 3x3 convs for target patch size
"""
@ -106,12 +106,12 @@ class UNetPatchDiscriminator(nn.ModelBase):
layers = []
sum_st = 0
layers.append ( [3, 2])
sum_st += 2
for i in range(layers_count-1):
st = 1 + (1 if val & (1 << i) !=0 else 0 )
layers.append ( [3, st ])
sum_st += st
layers.append ( [3, 2])
sum_st += 2
sum_st += st
rf = self.calc_receptive_field_size(layers)
@ -152,7 +152,7 @@ class UNetPatchDiscriminator(nn.ModelBase):
self.upres1 = []
self.upres2 = []
layers = self.find_archi(patch_size)
level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID')