update leras

This commit is contained in:
Colombo 2020-02-17 18:26:19 +04:00
parent e1635ff760
commit e0a55ff1c3
3 changed files with 39 additions and 4 deletions

View file

@ -243,15 +243,15 @@ def initialize_models(nn):
nn.ModelBase = ModelBase
class PatchDiscriminator(nn.ModelBase):
def on_build(self, patch_size, in_ch, base_ch=256, kernel_initializer=None):
def on_build(self, patch_size, in_ch, base_ch=256, conv_kernel_initializer=None):
prev_ch = in_ch
self.convs = []
for i, (kernel_size, strides) in enumerate(patch_discriminator_kernels[patch_size]):
cur_ch = base_ch * min( (2**i), 8 )
self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=kernel_initializer) )
self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) )
prev_ch = cur_ch
self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=kernel_initializer)
self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=conv_kernel_initializer)
def forward(self, x):
for conv in self.convs:
@ -260,7 +260,32 @@ def initialize_models(nn):
nn.PatchDiscriminator = PatchDiscriminator
class IllumDiscriminator(nn.ModelBase):
def on_build(self, patch_size, in_ch, base_ch=256, conv_kernel_initializer=None):
prev_ch = in_ch
self.convs = []
for i, (kernel_size, strides) in enumerate(patch_discriminator_kernels[patch_size]):
cur_ch = base_ch * min( (2**i), 8 )
self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) )
prev_ch = cur_ch
self.out1 = nn.Conv2D( 1, 1024, kernel_size=1, strides=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.out2 = nn.Conv2D( 1024, 1, kernel_size=1, strides=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
def forward(self, x):
for conv in self.convs:
x = tf.nn.leaky_relu( conv(x), 0.1 )
x = tf.reduce_mean(x, axis=nn.conv2d_ch_axis, keep_dims=True)
x = self.out1(x)
x = tf.nn.leaky_relu(x, 0.1 )
x = self.out2(x)
return x
nn.IllumDiscriminator = IllumDiscriminator
patch_discriminator_kernels = \
{ 1 : [ [1,1] ],
2 : [ [2,1] ],

View file

@ -55,6 +55,7 @@ class nn():
tf_upsample2d = None
tf_upsample2d_bilinear = None
tf_flatten = None
tf_max_pool = None
tf_reshape_4D = None
tf_random_binomial = None
tf_gaussian_blur = None
@ -82,7 +83,8 @@ class nn():
# Models
PatchDiscriminator = None
IllumDiscriminator = None
@staticmethod
def initialize(device_config=None, floatx="float32", data_format="NHWC"):

View file

@ -129,6 +129,14 @@ def initialize_tensor_ops(nn):
nn.tf_flatten = tf_flatten
def tf_max_pool(x, kernel_size, strides):
if nn.data_format == "NHWC":
return tf.nn.max_pool(x, [1,kernel_size,kernel_size,1], [1,strides,strides,1], "VALID", data_format=nn.data_format)
else:
return tf.nn.max_pool(x, [1,1,kernel_size,kernel_size], [1,1,strides,strides], "VALID", data_format=nn.data_format)
nn.tf_max_pool = tf_max_pool
def tf_reshape_4D(x, w,h,c):
if nn.data_format == "NHWC":
# match NCHW version in order to switch data_format without problems