diff --git a/core/leras/models.py b/core/leras/models.py index 231dc85..1962c08 100644 --- a/core/leras/models.py +++ b/core/leras/models.py @@ -243,11 +243,17 @@ def initialize_models(nn): nn.ModelBase = ModelBase class PatchDiscriminator(nn.ModelBase): - def on_build(self, patch_size, in_ch, base_ch=256, conv_kernel_initializer=None): + def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None): + suggested_base_ch, kernels_strides = patch_discriminator_kernels[patch_size] + + if base_ch is None: + base_ch = suggested_base_ch + prev_ch = in_ch self.convs = [] - for i, (kernel_size, strides) in enumerate(patch_discriminator_kernels[patch_size]): + for i, (kernel_size, strides) in enumerate(kernels_strides): cur_ch = base_ch * min( (2**i), 8 ) + self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) ) prev_ch = cur_ch @@ -261,10 +267,14 @@ def initialize_models(nn): nn.PatchDiscriminator = PatchDiscriminator class IllumDiscriminator(nn.ModelBase): - def on_build(self, patch_size, in_ch, base_ch=256, conv_kernel_initializer=None): + def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None): + suggested_base_ch, kernels_strides = patch_discriminator_kernels[patch_size] + if base_ch is None: + base_ch = suggested_base_ch + prev_ch = in_ch self.convs = [] - for i, (kernel_size, strides) in enumerate(patch_discriminator_kernels[patch_size]): + for i, (kernel_size, strides) in enumerate(kernels_strides): cur_ch = base_ch * min( (2**i), 8 ) self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) ) prev_ch = cur_ch @@ -287,23 +297,42 @@ def initialize_models(nn): nn.IllumDiscriminator = IllumDiscriminator patch_discriminator_kernels = \ - { 1 : [ [1,1] ], - 2 : [ [2,1] ], - 3 : [ [2,1], [2,1] ], - 4 : [ [2,2], [2,2] ], - 5 : [ [3,2], [2,2] ], - 6 : [ [4,2], [2,2] ], - 7 : [ [3,2], [3,2] ], - 8 : [ [4,2], [3,2] ], - 9 : [ [3,2], [4,2] ], - 10 : [ [4,2], [4,2] ], - 11 : [ [3,2], [3,2], [2,1] ], - 12 : [ [4,2], [3,2], [2,1] ], - 13 : [ [3,2], [4,2], [2,1] ], - 14 : [ [4,2], [4,2], [2,1] ], - 15 : [ [3,2], [3,2], [3,1] ], - 16 : [ [4,2], [3,2], [3,1] ], - - - 28 : [ [4,2], [3,2], [4,2], [2,1] ] - } \ No newline at end of file + { 1 : (512, [ [1,1] ]), + 2 : (512, [ [2,1] ]), + 3 : (512, [ [2,1], [2,1] ]), + 4 : (512, [ [2,2], [2,2] ]), + 5 : (512, [ [3,2], [2,2] ]), + 6 : (512, [ [4,2], [2,2] ]), + 7 : (512, [ [3,2], [3,2] ]), + 8 : (512, [ [4,2], [3,2] ]), + 9 : (512, [ [3,2], [4,2] ]), + 10 : (512, [ [4,2], [4,2] ]), + 11 : (512, [ [3,2], [3,2], [2,1] ]), + 12 : (512, [ [4,2], [3,2], [2,1] ]), + 13 : (512, [ [3,2], [4,2], [2,1] ]), + 14 : (512, [ [4,2], [4,2], [2,1] ]), + 15 : (512, [ [3,2], [3,2], [3,1] ]), + 16 : (512, [ [4,2], [3,2], [3,1] ]), + 17 : (512, [ [3,2], [4,2], [3,1] ]), + 18 : (512, [ [4,2], [4,2], [3,1] ]), + 19 : (512, [ [3,2], [3,2], [4,1] ]), + 20 : (512, [ [4,2], [3,2], [4,1] ]), + 21 : (512, [ [3,2], [4,2], [4,1] ]), + 22 : (512, [ [4,2], [4,2], [4,1] ]), + 23 : (256, [ [3,2], [3,2], [3,2], [2,1] ]), + 24 : (256, [ [4,2], [3,2], [3,2], [2,1] ]), + 25 : (256, [ [3,2], [4,2], [3,2], [2,1] ]), + 26 : (256, [ [4,2], [4,2], [3,2], [2,1] ]), + 27 : (256, [ [3,2], [4,2], [4,2], [2,1] ]), + 28 : (256, [ [4,2], [3,2], [4,2], [2,1] ]), + 29 : (256, [ [3,2], [4,2], [4,2], [2,1] ]), + 30 : (256, [ [4,2], [4,2], [4,2], [2,1] ]), + 31 : (256, [ [3,2], [3,2], [3,2], [3,1] ]), + 32 : (256, [ [4,2], [3,2], [3,2], [3,1] ]), + 33 : (256, [ [3,2], [4,2], [3,2], [3,1] ]), + 34 : (256, [ [4,2], [4,2], [3,2], [3,1] ]), + 35 : (256, [ [3,2], [4,2], [4,2], [3,1] ]), + 36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]), + 37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]), + 38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]), + } \ No newline at end of file