update leras

This commit is contained in:
Colombo 2020-03-07 16:05:53 +04:00
parent 123bccf01a
commit 0bac399841
2 changed files with 173 additions and 165 deletions

View file

@ -2,7 +2,8 @@
def initialize_archis(nn):
tf = nn.tf
def get_ae_models():
def get_ae_models(resolution):
lowest_dense_res = resolution // 16
conv_kernel_initializer = nn.initializers.ca()
class Downscale(nn.ModelBase):
@ -107,12 +108,12 @@ def initialize_archis(nn):
return x
class Inter(nn.ModelBase):
def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, is_hd=False, **kwargs):
self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch
def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, **kwargs):
self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
super().__init__(**kwargs)
def on_build(self):
in_ch, lowest_dense_res, ae_ch, ae_out_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch
in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch
self.dense1 = nn.Dense( in_ch, ae_ch )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
@ -121,7 +122,7 @@ def initialize_archis(nn):
def forward(self, inp):
x = self.dense1(inp)
x = self.dense2(x)
x = nn.tf_reshape_4D (x, self.lowest_dense_res, self.lowest_dense_res, self.ae_out_ch)
x = nn.tf_reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
x = self.upscale1(x)
return x
@ -202,11 +203,171 @@ def initialize_archis(nn):
return Encoder, Inter, Decoder
return lowest_dense_res, Encoder, Inter, Decoder
nn.get_ae_models = get_ae_models
def get_ae_models_chervoniy(resolution):
lowest_dense_res = resolution // 32
"""
by @chervoniy
"""
conv_kernel_initializer = nn.initializers.ca()
class Downscale(nn.ModelBase):
def __init__(self, in_ch, kernel_size=3, dilations=1, *kwargs ):
self.in_ch = in_ch
self.kernel_size = kernel_size
self.dilations = dilations
super().__init__(*kwargs)
def on_build(self, *args, **kwargs ):
self.conv_base1 = nn.Conv2D( self.in_ch, self.in_ch//2, kernel_size=1, strides=1, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
self.conv_l1 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=1, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
self.conv_l2 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
self.conv_base2 = nn.Conv2D( self.in_ch, self.in_ch//2, kernel_size=1, strides=1, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
self.conv_r1 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
self.pool_size = [1,1,2,2] if nn.data_format == 'NCHW' else [1,2,2,1]
def forward(self, x):
x_l = self.conv_base1(x)
x_l = self.conv_l1(x_l)
x_l = self.conv_l2(x_l)
x_r = self.conv_base2(x)
x_r = self.conv_r1(x_r)
x_pool = tf.nn.max_pool(x, ksize=self.pool_size, strides=self.pool_size, padding='SAME', data_format=nn.data_format)
x = tf.concat([x_l, x_r, x_pool], axis=nn.conv2d_ch_axis)
x = nn.tf_gelu(x)
return x
class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.conv2 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.conv3 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.conv4 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
def forward(self, x):
x0 = self.conv1(x)
x1 = self.conv2(x0)
x2 = self.conv3(x1)
x3 = self.conv4(x2)
x = tf.concat([x0, x1, x2, x3], axis=nn.conv2d_ch_axis)
x = nn.tf_gelu(x)
x = nn.tf_depth_to_space(x, 2)
return x
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
def forward(self, inp):
x = self.conv1(inp)
x = tf.nn.leaky_relu(x, 0.2)
x = self.conv2(x)
x = tf.nn.leaky_relu(inp + x, 0.2)
return x
class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch, **kwargs):
self.conv0 = nn.Conv2D(in_ch, e_ch, kernel_size=3, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.down0 = Downscale(e_ch)
self.down1 = Downscale(e_ch*2)
self.down2 = Downscale(e_ch*4)
self.down3 = Downscale(e_ch*8)
self.down4 = Downscale(e_ch*16)
def forward(self, inp):
x = self.conv0(inp)
x = self.down0(x)
x = self.down1(x)
x = self.down2(x)
x = self.down3(x)
x = self.down4(x)
x = nn.tf_flatten(x)
return x
class Inter(nn.ModelBase):
def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs):
self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch
super().__init__(**kwargs)
def on_build(self, **kwargs):
in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch
self.dense_l = nn.Dense( in_ch, ae_ch//2, kernel_initializer=tf.initializers.orthogonal)
self.dense_r = nn.Dense( in_ch, ae_ch//2, maxout_features=4, kernel_initializer=tf.initializers.orthogonal)
self.dense = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * (ae_out_ch//2), kernel_initializer=tf.initializers.orthogonal)
self.upscale1 = Upscale(ae_out_ch//2, ae_out_ch//2)
def forward(self, inp):
x0 = self.dense_l(inp)
x1 = self.dense_r(inp)
x = tf.concat([x0, x1], axis=-1)
x = self.dense(x)
x = nn.tf_reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch//2)
x = self.upscale1(x)
return x
def get_out_ch(self):
return self.ae_out_ch//2
class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch, d_mask_ch, **kwargs):
self.upscale0 = Upscale(in_ch, d_ch*8)
self.upscale1 = Upscale(d_ch*8, d_ch*4)
self.upscale2 = Upscale(d_ch*4, d_ch*2)
self.upscale3 = Upscale(d_ch*2, d_ch)
self.out_conv = nn.Conv2D( d_ch, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_ch, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
def get_weights_ex(self, include_mask):
# Call internal get_weights in order to initialize inner logic
self.get_weights()
weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() + self.upscale3.get_weights() + self.out_conv.get_weights()
if include_mask:
weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() + self.upscale3.get_weights() + self.out_convm.get_weights()
return weights
def forward(self, inp):
z = inp
x = self.upscale0(inp)
x = self.upscale1(x)
x = self.upscale2(x)
x = self.upscale3(x)
m = self.upscalem0(z)
m = self.upscalem1(m)
m = self.upscalem2(m)
m = self.upscalem3(m)
return tf.nn.sigmoid(self.out_conv(x)), \
tf.nn.sigmoid(self.out_convm(m))
return lowest_dense_res, Encoder, Inter, Decoder
def get_ae_models2():
nn.get_ae_models_chervoniy = get_ae_models_chervoniy
"""
def get_ae_models2():
conv_kernel_initializer = nn.initializers.ca()
class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, is_hd=False, *kwargs ):
@ -405,161 +566,4 @@ def initialize_archis(nn):
return Encoder, Inter, Decoder
nn.get_ae_models2 = get_ae_models2
def get_ae_models_chervoniy():
"""
by @chervoniy
"""
conv_kernel_initializer = nn.initializers.ca()
class Downscale(nn.ModelBase):
def __init__(self, in_ch, kernel_size=3, dilations=1, *kwargs ):
self.in_ch = in_ch
self.kernel_size = kernel_size
self.dilations = dilations
super().__init__(*kwargs)
def on_build(self, *args, **kwargs ):
self.conv_base1 = nn.Conv2D( self.in_ch, self.in_ch//2, kernel_size=1, strides=1, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
self.conv_l1 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=1, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
self.conv_l2 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
self.conv_base2 = nn.Conv2D( self.in_ch, self.in_ch//2, kernel_size=1, strides=1, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
self.conv_r1 = nn.Conv2D( self.in_ch//2, self.in_ch//2, kernel_size=self.kernel_size, strides=2, padding='SAME', dilations=self.dilations, kernel_initializer=conv_kernel_initializer)
self.pool_size = [1,1,2,2] if nn.data_format == 'NCHW' else [1,2,2,1]
def forward(self, x):
x_l = self.conv_base1(x)
x_l = self.conv_l1(x_l)
x_l = self.conv_l2(x_l)
x_r = self.conv_base2(x)
x_r = self.conv_r1(x_r)
x_pool = tf.nn.max_pool(x, ksize=self.pool_size, strides=self.pool_size, padding='SAME', data_format=nn.data_format)
x = tf.concat([x_l, x_r, x_pool], axis=nn.conv2d_ch_axis)
x = nn.tf_gelu(x)
return x
class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.conv2 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.conv3 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.conv4 = nn.Conv2D( out_ch, out_ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
def forward(self, x):
x0 = self.conv1(x)
x1 = self.conv2(x0)
x2 = self.conv3(x1)
x3 = self.conv4(x2)
x = tf.concat([x0, x1, x2, x3], axis=nn.conv2d_ch_axis)
x = nn.tf_gelu(x)
x = nn.tf_depth_to_space(x, 2)
return x
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', kernel_initializer=conv_kernel_initializer)
def forward(self, inp):
x = self.conv1(inp)
x = tf.nn.leaky_relu(x, 0.2)
x = self.conv2(x)
x = tf.nn.leaky_relu(inp + x, 0.2)
return x
class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch):
self.conv0 = nn.Conv2D(in_ch, e_ch, kernel_size=3, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.down0 = Downscale(e_ch)
self.down1 = Downscale(e_ch*2)
self.down2 = Downscale(e_ch*4)
self.down3 = Downscale(e_ch*8)
self.down4 = Downscale(e_ch*16)
def forward(self, inp):
x = self.conv0(inp)
x = self.down0(x)
x = self.down1(x)
x = self.down2(x)
x = self.down3(x)
x = self.down4(x)
x = nn.tf_flatten(x)
return x
class Inter(nn.ModelBase):
def __init__(self, in_ch, lowest_dense_res, ae_ch, ae_out_ch, **kwargs):
self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch = in_ch, lowest_dense_res, ae_ch, ae_out_ch
super().__init__(**kwargs)
def on_build(self):
in_ch, lowest_dense_res, ae_ch, ae_out_ch = self.in_ch, self.lowest_dense_res, self.ae_ch, self.ae_out_ch
self.dense_l = nn.Dense( in_ch, ae_ch//2, kernel_initializer=tf.initializers.orthogonal)
self.dense_r = nn.Dense( in_ch, ae_ch//2, maxout_features=4, kernel_initializer=tf.initializers.orthogonal)
self.dense = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * (ae_out_ch//2), kernel_initializer=tf.initializers.orthogonal)
self.upscale1 = Upscale(ae_out_ch//2, ae_out_ch//2)
def forward(self, inp):
x0 = self.dense_l(inp)
x1 = self.dense_r(inp)
x = tf.concat([x0, x1], axis=-1)
x = self.dense(x)
x = nn.tf_reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch//2)
x = self.upscale1(x)
return x
def get_out_ch(self):
return self.ae_out_ch//2
class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch, d_mask_ch ):
self.upscale0 = Upscale(in_ch, d_ch*8)
self.upscale1 = Upscale(d_ch*8, d_ch*4)
self.upscale2 = Upscale(d_ch*4, d_ch*2)
self.upscale3 = Upscale(d_ch*2, d_ch)
self.out_conv = nn.Conv2D( d_ch, 3, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_ch, 1, kernel_size=1, padding='SAME', kernel_initializer=conv_kernel_initializer)
def get_weights_ex(self, include_mask):
# Call internal get_weights in order to initialize inner logic
self.get_weights()
weights = self.upscale0.get_weights() + self.upscale1.get_weights() + self.upscale2.get_weights() + self.upscale3.get_weights() + self.out_conv.get_weights()
if include_mask:
weights += self.upscalem0.get_weights() + self.upscalem1.get_weights() + self.upscalem2.get_weights() + self.upscale3.get_weights() + self.out_convm.get_weights()
return weights
def forward(self, inp):
z = inp
x = self.upscale0(inp)
x = self.upscale1(x)
x = self.upscale2(x)
x = self.upscale3(x)
m = self.upscalem0(z)
m = self.upscalem1(m)
m = self.upscalem2(m)
m = self.upscalem3(m)
return tf.nn.sigmoid(self.out_conv(x)), \
tf.nn.sigmoid(self.out_convm(m))
return Encoder, Inter, Decoder
return get_ae_models_chervoniy
"""

View file

@ -88,6 +88,10 @@ class nn():
IllumDiscriminator = None
CodeDiscriminator = None
# Arhis
get_ae_models = None
get_ae_models_chervoniy = None
@staticmethod
def initialize(device_config=None, floatx="float32", data_format="NHWC"):