changing SubpixelUpscaler to variable H,W dims,

tensorflow backend : using depth_to_space in SubpixelUpscaler, so training speed increased by 4%
This commit is contained in:
iperov 2019-03-28 17:55:42 +04:00
parent 689aefeb2f
commit 4683c362ac

View file

@ -321,84 +321,135 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
nnlib.dssim = dssim nnlib.dssim = dssim
class PixelShuffler(KL.Layer): if 'tensorflow' in backend:
def __init__(self, size=(2, 2), data_format='channels_last', **kwargs): class PixelShuffler(keras.layers.Layer):
super(PixelShuffler, self).__init__(**kwargs) def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
self.data_format = data_format super(PixelShuffler, self).__init__(**kwargs)
self.size = size self.data_format = data_format
self.size = size
def call(self, inputs): def call(self, inputs):
input_shape = K.int_shape(inputs) input_shape = K.shape(inputs)
if len(input_shape) != 4: if K.int_shape(input_shape)[0] != 4:
raise ValueError('Inputs should have rank ' + raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs)))
str(4) +
'; Received input shape:', str(input_shape))
if self.data_format == 'channels_first': if self.data_format == 'channels_first':
batch_size, c, h, w = input_shape return K.tf.depth_to_space(inputs, self.size[0], 'NCHW')
if batch_size is None:
batch_size = -1
rh, rw = self.size
oh, ow = h * rh, w * rw
oc = c // (rh * rw)
out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w)) elif self.data_format == 'channels_last':
out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2)) return K.tf.depth_to_space(inputs, self.size[0], 'NHWC')
out = K.reshape(out, (batch_size, oc, oh, ow))
return out
elif self.data_format == 'channels_last': def compute_output_shape(self, input_shape):
batch_size, h, w, c = input_shape if len(input_shape) != 4:
if batch_size is None: raise ValueError('Inputs should have rank ' +
batch_size = -1 str(4) +
rh, rw = self.size '; Received input shape:', str(input_shape))
oh, ow = h * rh, w * rw
oc = c // (rh * rw)
out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc)) if self.data_format == 'channels_first':
out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5)) height = input_shape[2] * self.size[0] if input_shape[2] is not None else None
out = K.reshape(out, (batch_size, oh, ow, oc)) width = input_shape[3] * self.size[1] if input_shape[3] is not None else None
return out channels = input_shape[1] // self.size[0] // self.size[1]
def compute_output_shape(self, input_shape): if channels * self.size[0] * self.size[1] != input_shape[1]:
raise ValueError('channels of input and size are incompatible')
if len(input_shape) != 4: return (input_shape[0],
raise ValueError('Inputs should have rank ' + channels,
str(4) + height,
'; Received input shape:', str(input_shape)) width)
if self.data_format == 'channels_first': elif self.data_format == 'channels_last':
height = input_shape[2] * self.size[0] if input_shape[2] is not None else None height = input_shape[1] * self.size[0] if input_shape[1] is not None else None
width = input_shape[3] * self.size[1] if input_shape[3] is not None else None width = input_shape[2] * self.size[1] if input_shape[2] is not None else None
channels = input_shape[1] // self.size[0] // self.size[1] channels = input_shape[3] // self.size[0] // self.size[1]
if channels * self.size[0] * self.size[1] != input_shape[1]: if channels * self.size[0] * self.size[1] != input_shape[3]:
raise ValueError('channels of input and size are incompatible') raise ValueError('channels of input and size are incompatible')
return (input_shape[0], return (input_shape[0],
channels, height,
height, width,
width) channels)
elif self.data_format == 'channels_last': def get_config(self):
height = input_shape[1] * self.size[0] if input_shape[1] is not None else None config = {'size': self.size,
width = input_shape[2] * self.size[1] if input_shape[2] is not None else None 'data_format': self.data_format}
channels = input_shape[3] // self.size[0] // self.size[1] base_config = super(PixelShuffler, self).get_config()
if channels * self.size[0] * self.size[1] != input_shape[3]: return dict(list(base_config.items()) + list(config.items()))
raise ValueError('channels of input and size are incompatible') else:
class PixelShuffler(KL.Layer):
def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
super(PixelShuffler, self).__init__(**kwargs)
self.data_format = data_format
self.size = size
return (input_shape[0], def call(self, inputs):
height,
width,
channels)
def get_config(self): input_shape = K.shape(inputs)
config = {'size': self.size, if K.int_shape(input_shape)[0] != 4:
'data_format': self.data_format} raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs)))
base_config = super(PixelShuffler, self).get_config()
return dict(list(base_config.items()) + list(config.items())) if self.data_format == 'channels_first':
batch_size, c, h, w = input_shape[0], K.int_shape(inputs)[1], input_shape[2], input_shape[3]
rh, rw = self.size
oh, ow = h * rh, w * rw
oc = c // (rh * rw)
out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w))
out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
out = K.reshape(out, (batch_size, oc, oh, ow))
return out
elif self.data_format == 'channels_last':
batch_size, h, w, c = input_shape[0], input_shape[1], input_shape[2], K.int_shape(inputs)[-1]
rh, rw = self.size
oh, ow = h * rh, w * rw
oc = c // (rh * rw)
out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc))
out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
out = K.reshape(out, (batch_size, oh, ow, oc))
return out
def compute_output_shape(self, input_shape):
if len(input_shape) != 4:
raise ValueError('Inputs should have rank ' +
str(4) +
'; Received input shape:', str(input_shape))
if self.data_format == 'channels_first':
height = input_shape[2] * self.size[0] if input_shape[2] is not None else None
width = input_shape[3] * self.size[1] if input_shape[3] is not None else None
channels = input_shape[1] // self.size[0] // self.size[1]
if channels * self.size[0] * self.size[1] != input_shape[1]:
raise ValueError('channels of input and size are incompatible')
return (input_shape[0],
channels,
height,
width)
elif self.data_format == 'channels_last':
height = input_shape[1] * self.size[0] if input_shape[1] is not None else None
width = input_shape[2] * self.size[1] if input_shape[2] is not None else None
channels = input_shape[3] // self.size[0] // self.size[1]
if channels * self.size[0] * self.size[1] != input_shape[3]:
raise ValueError('channels of input and size are incompatible')
return (input_shape[0],
height,
width,
channels)
def get_config(self):
config = {'size': self.size,
'data_format': self.data_format}
base_config = super(PixelShuffler, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
nnlib.PixelShuffler = PixelShuffler nnlib.PixelShuffler = PixelShuffler
nnlib.SubpixelUpscaler = PixelShuffler nnlib.SubpixelUpscaler = PixelShuffler