upd leras

This commit is contained in:
Colombo 2020-02-27 16:20:14 +04:00
parent acb0b34811
commit 59be114485

View file

@ -243,11 +243,17 @@ def initialize_models(nn):
nn.ModelBase = ModelBase nn.ModelBase = ModelBase
class PatchDiscriminator(nn.ModelBase): class PatchDiscriminator(nn.ModelBase):
def on_build(self, patch_size, in_ch, base_ch=256, conv_kernel_initializer=None): def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None):
suggested_base_ch, kernels_strides = patch_discriminator_kernels[patch_size]
if base_ch is None:
base_ch = suggested_base_ch
prev_ch = in_ch prev_ch = in_ch
self.convs = [] self.convs = []
for i, (kernel_size, strides) in enumerate(patch_discriminator_kernels[patch_size]): for i, (kernel_size, strides) in enumerate(kernels_strides):
cur_ch = base_ch * min( (2**i), 8 ) cur_ch = base_ch * min( (2**i), 8 )
self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) ) self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) )
prev_ch = cur_ch prev_ch = cur_ch
@ -261,10 +267,14 @@ def initialize_models(nn):
nn.PatchDiscriminator = PatchDiscriminator nn.PatchDiscriminator = PatchDiscriminator
class IllumDiscriminator(nn.ModelBase): class IllumDiscriminator(nn.ModelBase):
def on_build(self, patch_size, in_ch, base_ch=256, conv_kernel_initializer=None): def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None):
suggested_base_ch, kernels_strides = patch_discriminator_kernels[patch_size]
if base_ch is None:
base_ch = suggested_base_ch
prev_ch = in_ch prev_ch = in_ch
self.convs = [] self.convs = []
for i, (kernel_size, strides) in enumerate(patch_discriminator_kernels[patch_size]): for i, (kernel_size, strides) in enumerate(kernels_strides):
cur_ch = base_ch * min( (2**i), 8 ) cur_ch = base_ch * min( (2**i), 8 )
self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) ) self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) )
prev_ch = cur_ch prev_ch = cur_ch
@ -287,23 +297,42 @@ def initialize_models(nn):
nn.IllumDiscriminator = IllumDiscriminator nn.IllumDiscriminator = IllumDiscriminator
patch_discriminator_kernels = \ patch_discriminator_kernels = \
{ 1 : [ [1,1] ], { 1 : (512, [ [1,1] ]),
2 : [ [2,1] ], 2 : (512, [ [2,1] ]),
3 : [ [2,1], [2,1] ], 3 : (512, [ [2,1], [2,1] ]),
4 : [ [2,2], [2,2] ], 4 : (512, [ [2,2], [2,2] ]),
5 : [ [3,2], [2,2] ], 5 : (512, [ [3,2], [2,2] ]),
6 : [ [4,2], [2,2] ], 6 : (512, [ [4,2], [2,2] ]),
7 : [ [3,2], [3,2] ], 7 : (512, [ [3,2], [3,2] ]),
8 : [ [4,2], [3,2] ], 8 : (512, [ [4,2], [3,2] ]),
9 : [ [3,2], [4,2] ], 9 : (512, [ [3,2], [4,2] ]),
10 : [ [4,2], [4,2] ], 10 : (512, [ [4,2], [4,2] ]),
11 : [ [3,2], [3,2], [2,1] ], 11 : (512, [ [3,2], [3,2], [2,1] ]),
12 : [ [4,2], [3,2], [2,1] ], 12 : (512, [ [4,2], [3,2], [2,1] ]),
13 : [ [3,2], [4,2], [2,1] ], 13 : (512, [ [3,2], [4,2], [2,1] ]),
14 : [ [4,2], [4,2], [2,1] ], 14 : (512, [ [4,2], [4,2], [2,1] ]),
15 : [ [3,2], [3,2], [3,1] ], 15 : (512, [ [3,2], [3,2], [3,1] ]),
16 : [ [4,2], [3,2], [3,1] ], 16 : (512, [ [4,2], [3,2], [3,1] ]),
17 : (512, [ [3,2], [4,2], [3,1] ]),
18 : (512, [ [4,2], [4,2], [3,1] ]),
28 : [ [4,2], [3,2], [4,2], [2,1] ] 19 : (512, [ [3,2], [3,2], [4,1] ]),
20 : (512, [ [4,2], [3,2], [4,1] ]),
21 : (512, [ [3,2], [4,2], [4,1] ]),
22 : (512, [ [4,2], [4,2], [4,1] ]),
23 : (256, [ [3,2], [3,2], [3,2], [2,1] ]),
24 : (256, [ [4,2], [3,2], [3,2], [2,1] ]),
25 : (256, [ [3,2], [4,2], [3,2], [2,1] ]),
26 : (256, [ [4,2], [4,2], [3,2], [2,1] ]),
27 : (256, [ [3,2], [4,2], [4,2], [2,1] ]),
28 : (256, [ [4,2], [3,2], [4,2], [2,1] ]),
29 : (256, [ [3,2], [4,2], [4,2], [2,1] ]),
30 : (256, [ [4,2], [4,2], [4,2], [2,1] ]),
31 : (256, [ [3,2], [3,2], [3,2], [3,1] ]),
32 : (256, [ [4,2], [3,2], [3,2], [3,1] ]),
33 : (256, [ [3,2], [4,2], [3,2], [3,1] ]),
34 : (256, [ [4,2], [4,2], [3,2], [3,1] ]),
35 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]),
37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]),
} }