SAEHD: speed up for nvidia, duplicate code clean up

This commit is contained in:
Colombo 2019-10-08 21:02:20 +04:00
parent 627df082d7
commit 3f23135982
2 changed files with 187 additions and 171 deletions

View file

@ -140,7 +140,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
# Clear any CUDA_VISIBLE_DEVICES setting from the environment; pop() with a
# default is a no-op when the variable is not set, so no membership test is needed.
os.environ.pop('CUDA_VISIBLE_DEVICES', None)
os.environ['CUDA_CACHE_MAXSIZE'] = '536870912'  # 512Mb (32mb default)
os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'
@ -151,7 +151,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
import tensorflow as tf
nnlib.tf = tf
if device_config.cpu_only:
config = tf.ConfigProto(device_count={'GPU': 0})
else:
@ -473,50 +473,84 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
# Publish PixelShuffler on nnlib; SubpixelUpscaler is bound to the very same
# class, i.e. it is an alias rather than a separate implementation.
nnlib.PixelShuffler = PixelShuffler
nnlib.SubpixelUpscaler = PixelShuffler
# NOTE(review): this region appears to be a flattened diff view. Lines of a
# single-class SubpixelDownscaler and of a backend-dependent version are
# interleaved and indentation is lost, so it cannot be read top to bottom as
# one definition. The comments below describe individual statements only.
class SubpixelDownscaler(KL.Layer):
def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
super(SubpixelDownscaler, self).__init__(**kwargs)
self.data_format = data_format
self.size = size
def call(self, inputs):
# Backend dispatch: this class is defined when the backend string contains
# 'tensorflow'; the matching `else:` defines the class for other backends.
if 'tensorflow' in backend:
class SubpixelDownscaler(KL.Layer):
def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
super(SubpixelDownscaler, self).__init__(**kwargs)
# data_format is stored (and reported by get_config); no statement in this
# region reads it during the computation.
self.data_format = data_format
self.size = size
# K.shape(inputs) is a 1-D shape tensor; its static length is the input rank.
input_shape = K.shape(inputs)
if K.int_shape(input_shape)[0] != 4:
raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs)))
def call(self, inputs):
# Batch/height/width come from the dynamic shape, channels from the static shape.
batch_size, h, w, c = input_shape[0], input_shape[1], input_shape[2], K.int_shape(inputs)[-1]
rh, rw = self.size
oh, ow = h // rh, w // rw
oc = c * (rh * rw)
input_shape = K.shape(inputs)
if K.int_shape(input_shape)[0] != 4:
raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs)))
# Split axes 1 and 2 into (block index, offset in block), move both offset
# axes next to the channel axis, then fold them into the channel axis.
out = K.reshape(inputs, (batch_size, oh, rh, ow, rw, c))
out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
out = K.reshape(out, (batch_size, oh, ow, oc))
return out
# Single space_to_depth call in 'NHWC' layout. Only self.size[0] is passed as
# the block size, so self.size[1] is not consulted on this path.
return K.tf.space_to_depth(inputs, self.size[0], 'NHWC')
def compute_output_shape(self, input_shape):
if len(input_shape) != 4:
raise ValueError('Inputs should have rank ' +
str(4) +
'; Received input shape:', str(input_shape))
def compute_output_shape(self, input_shape):
if len(input_shape) != 4:
raise ValueError('Inputs should have rank ' +
str(4) +
'; Received input shape:', str(input_shape))
# Axes 1 and 2 shrink by the size factors (an unknown dimension stays None);
# the last axis grows by size[0] * size[1].
height = input_shape[1] // self.size[0] if input_shape[1] is not None else None
width = input_shape[2] // self.size[1] if input_shape[2] is not None else None
channels = input_shape[3] * self.size[0] * self.size[1]
height = input_shape[1] // self.size[0] if input_shape[1] is not None else None
width = input_shape[2] // self.size[1] if input_shape[2] is not None else None
channels = input_shape[3] * self.size[0] * self.size[1]
return (input_shape[0], height, width, channels)
return (input_shape[0], height, width, channels)
def get_config(self):
config = {'size': self.size,
'data_format': self.data_format}
base_config = super(SubpixelDownscaler, self).get_config()
def get_config(self):
config = {'size': self.size,
'data_format': self.data_format}
base_config = super(SubpixelDownscaler, self).get_config()
# Base-layer entries first, then this layer's own entries, so the layer's
# values win if a key appears in both.
return dict(list(base_config.items()) + list(config.items()))
else:
class SubpixelDownscaler(KL.Layer):
    """Space-to-depth layer built from generic backend reshape/permute ops.

    Folds every ``size[0] x size[1]`` spatial tile of a rank-4 input into the
    last axis, so ``(batch, h, w, c)`` becomes
    ``(batch, h // size[0], w // size[1], c * size[0] * size[1])``.

    # Arguments
        size: pair ``(rh, rw)`` of downscale factors for axes 1 and 2.
        data_format: stored and returned by ``get_config`` only; the
            computation always treats the last axis as channels.
    """

    def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
        super(SubpixelDownscaler, self).__init__(**kwargs)
        self.data_format = data_format
        self.size = size

    def call(self, inputs):
        """Rearrange ``inputs`` from space to depth.

        # Raises
            ValueError: if ``inputs`` is not a rank-4 tensor.
        """
        input_shape = K.shape(inputs)
        # K.shape() yields a 1-D tensor whose static length is the input rank.
        if K.int_shape(input_shape)[0] != 4:
            # Build ONE message string; passing two positional arguments to
            # ValueError makes the error render as a tuple.
            raise ValueError('Inputs should have rank 4; Received input shape: '
                             + str(K.int_shape(inputs)))

        # Batch/height/width are taken from the dynamic shape; the channel
        # count is taken from the static shape.
        batch_size, h, w = input_shape[0], input_shape[1], input_shape[2]
        c = K.int_shape(inputs)[-1]
        rh, rw = self.size
        oh, ow = h // rh, w // rw
        oc = c * (rh * rw)

        # (b, oh, rh, ow, rw, c) -> (b, oh, ow, rh, rw, c) -> (b, oh, ow, oc)
        out = K.reshape(inputs, (batch_size, oh, rh, ow, rw, c))
        out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
        out = K.reshape(out, (batch_size, oh, ow, oc))
        return out

    def compute_output_shape(self, input_shape):
        """Return the static output shape for a rank-4 ``input_shape``.

        # Raises
            ValueError: if ``input_shape`` does not have exactly 4 entries.
        """
        if len(input_shape) != 4:
            raise ValueError('Inputs should have rank 4; Received input shape: '
                             + str(input_shape))

        rh, rw = self.size[0], self.size[1]
        # Unknown spatial dimensions stay unknown.
        height = input_shape[1] // rh if input_shape[1] is not None else None
        width = input_shape[2] // rw if input_shape[2] is not None else None
        channels = input_shape[3] * rh * rw
        return (input_shape[0], height, width, channels)

    def get_config(self):
        """Return the base layer config extended with ``size`` and ``data_format``."""
        merged = dict(super(SubpixelDownscaler, self).get_config())
        # This layer's own entries override same-named base entries.
        merged.update({'size': self.size,
                       'data_format': self.data_format})
        return merged
return dict(list(base_config.items()) + list(config.items()))
# Publish the SubpixelDownscaler class currently bound to this name as an attribute of nnlib.
nnlib.SubpixelDownscaler = SubpixelDownscaler
class BlurPool(KL.Layer):
"""
https://arxiv.org/abs/1904.11486 https://github.com/adobe/antialiased-cnns