Code refactoring.

lr_dropout is now disabled in pretraining mode
changed help message for lr_dropout and random_warp
This commit is contained in:
Colombo 2020-03-07 13:59:47 +04:00
parent 9ccdd271a4
commit ada60ccefe
4 changed files with 612 additions and 232 deletions

View file

@@ -1,4 +1,6 @@
import types
import numpy as np
from core.interact import interact as io
def initialize_models(nn):
tf = nn.tf
@@ -296,6 +298,28 @@ def initialize_models(nn):
nn.IllumDiscriminator = IllumDiscriminator
class CodeDiscriminator(nn.ModelBase):
    # Small conv-stack discriminator that operates on a latent code tensor.
    # Uses `nn` and `tf` from the enclosing initialize_models() scope.

    def on_build(self, in_ch, code_res, ch=256, conv_kernel_initializer=None):
        """Build the layer stack.

        in_ch    -- number of input channels of the code tensor
        code_res -- spatial resolution of the code; determines depth
        ch       -- base channel count, doubled per level (capped at 8x)
        conv_kernel_initializer -- optional initializer; defaults to
                                   convolution-aware init (nn.initializers.ca)
        """
        kernel_init = conv_kernel_initializer
        if kernel_init is None:
            kernel_init = nn.initializers.ca()

        downscale_count = 1 + code_res // 8

        self.convs = []
        ch_in = in_ch
        for level in range(downscale_count):
            ch_out = ch * min(2 ** level, 8)
            # First level uses a wider 4x4 kernel; later levels use 3x3.
            self.convs.append(
                nn.Conv2D(ch_in, ch_out,
                          kernel_size=(4 if level == 0 else 3),
                          strides=2, padding='SAME',
                          kernel_initializer=kernel_init))
            ch_in = ch_out

        # 1x1 head collapsing features to a single-channel score map.
        self.out_conv = nn.Conv2D(ch_in, 1, kernel_size=1, padding='VALID',
                                  kernel_initializer=kernel_init)

    def forward(self, x):
        """Run the conv stack with leaky-ReLU (alpha=0.1) activations."""
        out = x
        for layer in self.convs:
            out = tf.nn.leaky_relu(layer(out), 0.1)
        return self.out_conv(out)

nn.CodeDiscriminator = CodeDiscriminator
patch_discriminator_kernels = \
{ 1 : (512, [ [1,1] ]),
2 : (512, [ [2,1] ]),
@@ -335,4 +359,4 @@ patch_discriminator_kernels = \
36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]),
37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]),
38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]),
}