refactoring

This commit is contained in:
iperov 2019-03-21 18:58:38 +04:00
commit 565af4d1da

View file

@ -44,34 +44,35 @@ class nnlib(object):
""" """
keras = nnlib.keras keras = nnlib.keras
K = keras.backend K = keras.backend
KL = keras.layers
Input = keras.layers.Input Input = KL.Input
Dense = keras.layers.Dense Dense = KL.Dense
Conv2D = keras.layers.Conv2D Conv2D = KL.Conv2D
Conv2DTranspose = keras.layers.Conv2DTranspose Conv2DTranspose = KL.Conv2DTranspose
SeparableConv2D = keras.layers.SeparableConv2D SeparableConv2D = KL.SeparableConv2D
MaxPooling2D = keras.layers.MaxPooling2D MaxPooling2D = KL.MaxPooling2D
UpSampling2D = keras.layers.UpSampling2D UpSampling2D = KL.UpSampling2D
BatchNormalization = keras.layers.BatchNormalization BatchNormalization = KL.BatchNormalization
LeakyReLU = keras.layers.LeakyReLU LeakyReLU = KL.LeakyReLU
ReLU = keras.layers.ReLU ReLU = KL.ReLU
PReLU = keras.layers.PReLU PReLU = KL.PReLU
tanh = keras.layers.Activation('tanh') tanh = KL.Activation('tanh')
sigmoid = keras.layers.Activation('sigmoid') sigmoid = KL.Activation('sigmoid')
Dropout = keras.layers.Dropout Dropout = KL.Dropout
Softmax = keras.layers.Softmax Softmax = KL.Softmax
Lambda = keras.layers.Lambda Lambda = KL.Lambda
Add = keras.layers.Add Add = KL.Add
Concatenate = keras.layers.Concatenate Concatenate = KL.Concatenate
Flatten = keras.layers.Flatten Flatten = KL.Flatten
Reshape = keras.layers.Reshape Reshape = KL.Reshape
ZeroPadding2D = keras.layers.ZeroPadding2D ZeroPadding2D = KL.ZeroPadding2D
RandomNormal = keras.initializers.RandomNormal RandomNormal = keras.initializers.RandomNormal
Model = keras.models.Model Model = keras.models.Model
@ -86,6 +87,8 @@ dssim = nnlib.dssim
PixelShuffler = nnlib.PixelShuffler PixelShuffler = nnlib.PixelShuffler
SubpixelUpscaler = nnlib.SubpixelUpscaler SubpixelUpscaler = nnlib.SubpixelUpscaler
Scale = nnlib.Scale Scale = nnlib.Scale
Capsule = nnlib.Capsule
CAInitializerMP = nnlib.CAInitializerMP CAInitializerMP = nnlib.CAInitializerMP
#ReflectionPadding2D = nnlib.ReflectionPadding2D #ReflectionPadding2D = nnlib.ReflectionPadding2D
@ -198,6 +201,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
def __initialize_keras_functions(): def __initialize_keras_functions():
keras = nnlib.keras keras = nnlib.keras
K = keras.backend K = keras.backend
KL = keras.layers
def modelify(model_functor): def modelify(model_functor):
def func(tensor): def func(tensor):
@ -319,7 +323,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
nnlib.dssim = dssim nnlib.dssim = dssim
class PixelShuffler(keras.layers.Layer): class PixelShuffler(KL.Layer):
def __init__(self, size=(2, 2), data_format='channels_last', **kwargs): def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
super(PixelShuffler, self).__init__(**kwargs) super(PixelShuffler, self).__init__(**kwargs)
self.data_format = data_format self.data_format = data_format
@ -401,7 +405,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
nnlib.PixelShuffler = PixelShuffler nnlib.PixelShuffler = PixelShuffler
nnlib.SubpixelUpscaler = PixelShuffler nnlib.SubpixelUpscaler = PixelShuffler
class Scale(keras.layers.Layer): class Scale(KL.Layer):
""" """
GAN Custom Scal Layer GAN Custom Scal Layer
Code borrows from https://github.com/flyyufelix/cnn_finetune Code borrows from https://github.com/flyyufelix/cnn_finetune
@ -891,4 +895,3 @@ class CAInitializerMPSubprocessor(Subprocessor):
#override #override
def get_result(self): def get_result(self):
return self.result return self.result