diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index dacc66e..c86ba5b 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -166,6 +166,9 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator if 'TF_SUPPRESS_STD' in os.environ.keys() and os.environ['TF_SUPPRESS_STD'] == '1': suppressor = std_utils.suppress_stdout_stderr().__enter__() + #if "tensorflow" in device_config.backend: + # nnlib.keras = nnlib.tf.keras + #else: import keras as keras_ nnlib.keras = keras_ @@ -318,13 +321,12 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator nnlib.dssim = dssim class PixelShuffler(keras.layers.Layer): - def __init__(self, size=(2, 2), data_format=None, **kwargs): + def __init__(self, size=(2, 2), data_format='channels_last', **kwargs): super(PixelShuffler, self).__init__(**kwargs) - self.data_format = K.normalize_data_format(data_format) - self.size = keras.utils.conv_utils.normalize_tuple(size, 2, 'size') + self.data_format = data_format + self.size = size def call(self, inputs): - input_shape = K.int_shape(inputs) if len(input_shape) != 4: raise ValueError('Inputs should have rank ' + @@ -395,7 +397,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator 'data_format': self.data_format} base_config = super(PixelShuffler, self).get_config() - return dict(list(base_config.items()) + list(config.items())) + return dict(list(base_config.items()) + list(config.items())) nnlib.PixelShuffler = PixelShuffler nnlib.SubpixelUpscaler = PixelShuffler