diff --git a/core/leras/models.py b/core/leras/models.py
index e15d40d..46b7217 100644
--- a/core/leras/models.py
+++ b/core/leras/models.py
@@ -243,15 +243,15 @@ def initialize_models(nn):
     nn.ModelBase = ModelBase
 
     class PatchDiscriminator(nn.ModelBase):
-        def on_build(self, patch_size, in_ch, base_ch=256, kernel_initializer=None):
+        def on_build(self, patch_size, in_ch, base_ch=256, conv_kernel_initializer=None):
             prev_ch = in_ch
             self.convs = []
             for i, (kernel_size, strides) in enumerate(patch_discriminator_kernels[patch_size]):
                 cur_ch = base_ch * min( (2**i), 8 )
-                self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=kernel_initializer) )
+                self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) )
                 prev_ch = cur_ch
 
-            self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=kernel_initializer)
+            self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=conv_kernel_initializer)
 
         def forward(self, x):
             for conv in self.convs:
@@ -260,7 +260,32 @@ def initialize_models(nn):
 
     nn.PatchDiscriminator = PatchDiscriminator
 
+    class IllumDiscriminator(nn.ModelBase):
+        def on_build(self, patch_size, in_ch, base_ch=256, conv_kernel_initializer=None):
+            prev_ch = in_ch
+            self.convs = []
+            for i, (kernel_size, strides) in enumerate(patch_discriminator_kernels[patch_size]):
+                cur_ch = base_ch * min( (2**i), 8 )
+                self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) )
+                prev_ch = cur_ch
+
+            self.out1 = nn.Conv2D( 1, 1024, kernel_size=1, strides=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
+            self.out2 = nn.Conv2D( 1024, 1, kernel_size=1, strides=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
+        def forward(self, x):
+            for conv in self.convs:
+                x = tf.nn.leaky_relu( conv(x), 0.1 )
+
+            x = tf.reduce_mean(x, axis=nn.conv2d_ch_axis, keep_dims=True)
+
+            x = self.out1(x)
+            x = tf.nn.leaky_relu(x, 0.1 )
+            x = self.out2(x)
+
+            return x
+
+    nn.IllumDiscriminator = IllumDiscriminator
 
+
 patch_discriminator_kernels = \
     { 1 : [ [1,1] ],
       2 : [ [2,1] ],
diff --git a/core/leras/nn.py b/core/leras/nn.py
index 9199b6d..c5ba9cf 100644
--- a/core/leras/nn.py
+++ b/core/leras/nn.py
@@ -55,6 +55,7 @@ class nn():
     tf_upsample2d = None
     tf_upsample2d_bilinear = None
     tf_flatten = None
+    tf_max_pool = None
     tf_reshape_4D = None
     tf_random_binomial = None
     tf_gaussian_blur = None
@@ -82,7 +83,8 @@ class nn():
 
     # Models
     PatchDiscriminator = None
-
+    IllumDiscriminator = None
+
     @staticmethod
     def initialize(device_config=None, floatx="float32", data_format="NHWC"):
 
diff --git a/core/leras/tensor_ops.py b/core/leras/tensor_ops.py
index db99f0e..9b0f177 100644
--- a/core/leras/tensor_ops.py
+++ b/core/leras/tensor_ops.py
@@ -129,6 +129,14 @@ def initialize_tensor_ops(nn):
 
     nn.tf_flatten = tf_flatten
 
+    def tf_max_pool(x, kernel_size, strides):
+        if nn.data_format == "NHWC":
+            return tf.nn.max_pool(x, [1,kernel_size,kernel_size,1], [1,strides,strides,1], "VALID", data_format=nn.data_format)
+        else:
+            return tf.nn.max_pool(x, [1,1,kernel_size,kernel_size], [1,1,strides,strides], "VALID", data_format=nn.data_format)
+
+    nn.tf_max_pool = tf_max_pool
+
     def tf_reshape_4D(x, w,h,c):
         if nn.data_format == "NHWC":
             # match NCHW version in order to switch data_format without problems