diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index c52829d..50003a6 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -321,84 +321,135 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator nnlib.dssim = dssim - class PixelShuffler(KL.Layer): - def __init__(self, size=(2, 2), data_format='channels_last', **kwargs): - super(PixelShuffler, self).__init__(**kwargs) - self.data_format = data_format - self.size = size + if 'tensorflow' in backend: + class PixelShuffler(keras.layers.Layer): + def __init__(self, size=(2, 2), data_format='channels_last', **kwargs): + super(PixelShuffler, self).__init__(**kwargs) + self.data_format = data_format + self.size = size - def call(self, inputs): - input_shape = K.int_shape(inputs) - if len(input_shape) != 4: - raise ValueError('Inputs should have rank ' + - str(4) + - '; Received input shape:', str(input_shape)) + def call(self, inputs): + input_shape = K.shape(inputs) + if K.int_shape(input_shape)[0] != 4: + raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs))) - if self.data_format == 'channels_first': - batch_size, c, h, w = input_shape - if batch_size is None: - batch_size = -1 - rh, rw = self.size - oh, ow = h * rh, w * rw - oc = c // (rh * rw) + if self.data_format == 'channels_first': + return K.tf.depth_to_space(inputs, self.size[0], 'NCHW') - out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w)) - out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2)) - out = K.reshape(out, (batch_size, oc, oh, ow)) - return out + elif self.data_format == 'channels_last': + return K.tf.depth_to_space(inputs, self.size[0], 'NHWC') - elif self.data_format == 'channels_last': - batch_size, h, w, c = input_shape - if batch_size is None: - batch_size = -1 - rh, rw = self.size - oh, ow = h * rh, w * rw - oc = c // (rh * rw) + def compute_output_shape(self, input_shape): + if len(input_shape) != 4: + raise ValueError('Inputs should have rank ' + + str(4) + + '; Received input shape:', str(input_shape)) - out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc)) - out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5)) - out = K.reshape(out, (batch_size, oh, ow, oc)) - return out + if self.data_format == 'channels_first': + height = input_shape[2] * self.size[0] if input_shape[2] is not None else None + width = input_shape[3] * self.size[1] if input_shape[3] is not None else None + channels = input_shape[1] // self.size[0] // self.size[1] - def compute_output_shape(self, input_shape): + if channels * self.size[0] * self.size[1] != input_shape[1]: + raise ValueError('channels of input and size are incompatible') - if len(input_shape) != 4: - raise ValueError('Inputs should have rank ' + - str(4) + - '; Received input shape:', str(input_shape)) + return (input_shape[0], + channels, + height, + width) - if self.data_format == 'channels_first': - height = input_shape[2] * self.size[0] if input_shape[2] is not None else None - width = input_shape[3] * self.size[1] if input_shape[3] is not None else None - channels = input_shape[1] // self.size[0] // self.size[1] + elif self.data_format == 'channels_last': + height = input_shape[1] * self.size[0] if input_shape[1] is not None else None + width = input_shape[2] * self.size[1] if input_shape[2] is not None else None + channels = input_shape[3] // self.size[0] // self.size[1] - if channels * self.size[0] * self.size[1] != input_shape[1]: - raise ValueError('channels of input and size are incompatible') + if channels * self.size[0] * self.size[1] != input_shape[3]: + raise ValueError('channels of input and size are incompatible') - return (input_shape[0], - channels, - height, - width) + return (input_shape[0], + height, + width, + channels) - elif self.data_format == 'channels_last': - height = input_shape[1] * self.size[0] if input_shape[1] is not None else None - width = input_shape[2] * self.size[1] if input_shape[2] is not None else None - channels = input_shape[3] // self.size[0] // self.size[1] + def get_config(self): + config = {'size': self.size, + 'data_format': self.data_format} + base_config = super(PixelShuffler, self).get_config() - if channels * self.size[0] * self.size[1] != input_shape[3]: - raise ValueError('channels of input and size are incompatible') + return dict(list(base_config.items()) + list(config.items())) + else: + class PixelShuffler(KL.Layer): + def __init__(self, size=(2, 2), data_format='channels_last', **kwargs): + super(PixelShuffler, self).__init__(**kwargs) + self.data_format = data_format + self.size = size - return (input_shape[0], - height, - width, - channels) + def call(self, inputs): - def get_config(self): - config = {'size': self.size, - 'data_format': self.data_format} - base_config = super(PixelShuffler, self).get_config() + input_shape = K.shape(inputs) + if K.int_shape(input_shape)[0] != 4: + raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs))) - return dict(list(base_config.items()) + list(config.items())) + if self.data_format == 'channels_first': + batch_size, c, h, w = input_shape[0], K.int_shape(inputs)[1], input_shape[2], input_shape[3] + rh, rw = self.size + oh, ow = h * rh, w * rw + oc = c // (rh * rw) + + out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w)) + out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2)) + out = K.reshape(out, (batch_size, oc, oh, ow)) + return out + + elif self.data_format == 'channels_last': + batch_size, h, w, c = input_shape[0], input_shape[1], input_shape[2], K.int_shape(inputs)[-1] + rh, rw = self.size + oh, ow = h * rh, w * rw + oc = c // (rh * rw) + + out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc)) + out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5)) + out = K.reshape(out, (batch_size, oh, ow, oc)) + return out + + def compute_output_shape(self, input_shape): + if len(input_shape) != 4: + raise ValueError('Inputs should have rank ' + + str(4) + + '; Received input shape:', str(input_shape)) + + if self.data_format == 'channels_first': + height = input_shape[2] * self.size[0] if input_shape[2] is not None else None + width = input_shape[3] * self.size[1] if input_shape[3] is not None else None + channels = input_shape[1] // self.size[0] // self.size[1] + + if channels * self.size[0] * self.size[1] != input_shape[1]: + raise ValueError('channels of input and size are incompatible') + + return (input_shape[0], + channels, + height, + width) + + elif self.data_format == 'channels_last': + height = input_shape[1] * self.size[0] if input_shape[1] is not None else None + width = input_shape[2] * self.size[1] if input_shape[2] is not None else None + channels = input_shape[3] // self.size[0] // self.size[1] + + if channels * self.size[0] * self.size[1] != input_shape[3]: + raise ValueError('channels of input and size are incompatible') + + return (input_shape[0], + height, + width, + channels) + + def get_config(self): + config = {'size': self.size, + 'data_format': self.data_format} + base_config = super(PixelShuffler, self).get_config() + + return dict(list(base_config.items()) + list(config.items())) nnlib.PixelShuffler = PixelShuffler nnlib.SubpixelUpscaler = PixelShuffler