refactoring

This commit is contained in:
iperov 2019-03-21 18:58:38 +04:00
commit 565af4d1da

View file

@ -44,34 +44,35 @@ class nnlib(object):
"""
keras = nnlib.keras
K = keras.backend
KL = keras.layers
Input = keras.layers.Input
Input = KL.Input
Dense = keras.layers.Dense
Conv2D = keras.layers.Conv2D
Conv2DTranspose = keras.layers.Conv2DTranspose
SeparableConv2D = keras.layers.SeparableConv2D
MaxPooling2D = keras.layers.MaxPooling2D
UpSampling2D = keras.layers.UpSampling2D
BatchNormalization = keras.layers.BatchNormalization
Dense = KL.Dense
Conv2D = KL.Conv2D
Conv2DTranspose = KL.Conv2DTranspose
SeparableConv2D = KL.SeparableConv2D
MaxPooling2D = KL.MaxPooling2D
UpSampling2D = KL.UpSampling2D
BatchNormalization = KL.BatchNormalization
LeakyReLU = keras.layers.LeakyReLU
ReLU = keras.layers.ReLU
PReLU = keras.layers.PReLU
tanh = keras.layers.Activation('tanh')
sigmoid = keras.layers.Activation('sigmoid')
Dropout = keras.layers.Dropout
Softmax = keras.layers.Softmax
LeakyReLU = KL.LeakyReLU
ReLU = KL.ReLU
PReLU = KL.PReLU
tanh = KL.Activation('tanh')
sigmoid = KL.Activation('sigmoid')
Dropout = KL.Dropout
Softmax = KL.Softmax
Lambda = keras.layers.Lambda
Add = keras.layers.Add
Concatenate = keras.layers.Concatenate
Lambda = KL.Lambda
Add = KL.Add
Concatenate = KL.Concatenate
Flatten = keras.layers.Flatten
Reshape = keras.layers.Reshape
Flatten = KL.Flatten
Reshape = KL.Reshape
ZeroPadding2D = keras.layers.ZeroPadding2D
ZeroPadding2D = KL.ZeroPadding2D
RandomNormal = keras.initializers.RandomNormal
Model = keras.models.Model
@ -86,6 +87,8 @@ dssim = nnlib.dssim
PixelShuffler = nnlib.PixelShuffler
SubpixelUpscaler = nnlib.SubpixelUpscaler
Scale = nnlib.Scale
Capsule = nnlib.Capsule
CAInitializerMP = nnlib.CAInitializerMP
#ReflectionPadding2D = nnlib.ReflectionPadding2D
@ -198,6 +201,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
def __initialize_keras_functions():
keras = nnlib.keras
K = keras.backend
KL = keras.layers
def modelify(model_functor):
def func(tensor):
@ -319,7 +323,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
nnlib.dssim = dssim
class PixelShuffler(keras.layers.Layer):
class PixelShuffler(KL.Layer):
def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
super(PixelShuffler, self).__init__(**kwargs)
self.data_format = data_format
@ -401,7 +405,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
nnlib.PixelShuffler = PixelShuffler
nnlib.SubpixelUpscaler = PixelShuffler
class Scale(keras.layers.Layer):
class Scale(KL.Layer):
"""
GAN Custom Scal Layer
Code borrows from https://github.com/flyyufelix/cnn_finetune
@ -891,4 +895,3 @@ class CAInitializerMPSubprocessor(Subprocessor):
#override
def get_result(self):
return self.result