Mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-07-07 21:42:08 -07:00
Changed SubpixelUpscaler to support variable H,W dims.
tensorflow backend: SubpixelUpscaler now uses depth_to_space, so training speed increased by 4%.
This commit is contained in:
parent 689aefeb2f
commit 4683c362ac
1 changed file with 114 additions and 63 deletions
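The speed-up comes from replacing the reshape/permute/reshape pixel shuffle with a single fused op. The following is a minimal standalone sketch, not DeepFaceLab code, assuming plain TensorFlow 2.x and NumPy (the names r, x, ref are illustrative); it checks that tf.nn.depth_to_space produces the same rearrangement as the old reshape-based path, without ever needing concrete H and W at graph-build time.

# Standalone sketch (assumption: TF 2.x + NumPy installed), not repo code.
import numpy as np
import tensorflow as tf

r = 2                                                        # upscale factor
x = np.random.rand(1, 6, 5, 3 * r * r).astype(np.float32)    # NHWC, any H, W

# Reference: the reshape -> permute -> reshape shuffle (old PixelShuffler path)
n, h, w, c = x.shape
oc = c // (r * r)
ref = x.reshape(n, h, w, r, r, oc)
ref = ref.transpose(0, 1, 3, 2, 4, 5)
ref = ref.reshape(n, h * r, w * r, oc)

# Fused op used by the new tensorflow-backend PixelShuffler
out = tf.nn.depth_to_space(tf.constant(x), block_size=r, data_format='NHWC')
assert np.allclose(out.numpy(), ref)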
nnlib/nnlib.py (177 changed lines: +114, −63)
@@ -321,84 +321,135 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
 nnlib.dssim = dssim
 
-class PixelShuffler(KL.Layer):
-    def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
-        super(PixelShuffler, self).__init__(**kwargs)
-        self.data_format = data_format
-        self.size = size
-
-    def call(self, inputs):
-        input_shape = K.int_shape(inputs)
-        if len(input_shape) != 4:
-            raise ValueError('Inputs should have rank ' +
-                             str(4) +
-                             '; Received input shape:', str(input_shape))
-
-        if self.data_format == 'channels_first':
-            batch_size, c, h, w = input_shape
-            if batch_size is None:
-                batch_size = -1
-            rh, rw = self.size
-            oh, ow = h * rh, w * rw
-            oc = c // (rh * rw)
-
-            out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w))
-            out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
-            out = K.reshape(out, (batch_size, oc, oh, ow))
-            return out
-
-        elif self.data_format == 'channels_last':
-            batch_size, h, w, c = input_shape
-            if batch_size is None:
-                batch_size = -1
-            rh, rw = self.size
-            oh, ow = h * rh, w * rw
-            oc = c // (rh * rw)
-
-            out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc))
-            out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
-            out = K.reshape(out, (batch_size, oh, ow, oc))
-            return out
-
-    def compute_output_shape(self, input_shape):
-        if len(input_shape) != 4:
-            raise ValueError('Inputs should have rank ' +
-                             str(4) +
-                             '; Received input shape:', str(input_shape))
-
-        if self.data_format == 'channels_first':
-            height = input_shape[2] * self.size[0] if input_shape[2] is not None else None
-            width = input_shape[3] * self.size[1] if input_shape[3] is not None else None
-            channels = input_shape[1] // self.size[0] // self.size[1]
-
-            if channels * self.size[0] * self.size[1] != input_shape[1]:
-                raise ValueError('channels of input and size are incompatible')
-
-            return (input_shape[0],
-                    channels,
-                    height,
-                    width)
-
-        elif self.data_format == 'channels_last':
-            height = input_shape[1] * self.size[0] if input_shape[1] is not None else None
-            width = input_shape[2] * self.size[1] if input_shape[2] is not None else None
-            channels = input_shape[3] // self.size[0] // self.size[1]
-
-            if channels * self.size[0] * self.size[1] != input_shape[3]:
-                raise ValueError('channels of input and size are incompatible')
-
-            return (input_shape[0],
-                    height,
-                    width,
-                    channels)
-
-    def get_config(self):
-        config = {'size': self.size,
-                  'data_format': self.data_format}
-        base_config = super(PixelShuffler, self).get_config()
-
-        return dict(list(base_config.items()) + list(config.items()))
+if 'tensorflow' in backend:
+    class PixelShuffler(keras.layers.Layer):
+        def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
+            super(PixelShuffler, self).__init__(**kwargs)
+            self.data_format = data_format
+            self.size = size
+
+        def call(self, inputs):
+            input_shape = K.shape(inputs)
+            if K.int_shape(input_shape)[0] != 4:
+                raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs)))
+
+            if self.data_format == 'channels_first':
+                return K.tf.depth_to_space(inputs, self.size[0], 'NCHW')
+
+            elif self.data_format == 'channels_last':
+                return K.tf.depth_to_space(inputs, self.size[0], 'NHWC')
+
+        def compute_output_shape(self, input_shape):
+            if len(input_shape) != 4:
+                raise ValueError('Inputs should have rank ' +
+                                 str(4) +
+                                 '; Received input shape:', str(input_shape))
+
+            if self.data_format == 'channels_first':
+                height = input_shape[2] * self.size[0] if input_shape[2] is not None else None
+                width = input_shape[3] * self.size[1] if input_shape[3] is not None else None
+                channels = input_shape[1] // self.size[0] // self.size[1]
+
+                if channels * self.size[0] * self.size[1] != input_shape[1]:
+                    raise ValueError('channels of input and size are incompatible')
+
+                return (input_shape[0],
+                        channels,
+                        height,
+                        width)
+
+            elif self.data_format == 'channels_last':
+                height = input_shape[1] * self.size[0] if input_shape[1] is not None else None
+                width = input_shape[2] * self.size[1] if input_shape[2] is not None else None
+                channels = input_shape[3] // self.size[0] // self.size[1]
+
+                if channels * self.size[0] * self.size[1] != input_shape[3]:
+                    raise ValueError('channels of input and size are incompatible')
+
+                return (input_shape[0],
+                        height,
+                        width,
+                        channels)
+
+        def get_config(self):
+            config = {'size': self.size,
+                      'data_format': self.data_format}
+            base_config = super(PixelShuffler, self).get_config()
+
+            return dict(list(base_config.items()) + list(config.items()))
+else:
+    class PixelShuffler(KL.Layer):
+        def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
+            super(PixelShuffler, self).__init__(**kwargs)
+            self.data_format = data_format
+            self.size = size
+
+        def call(self, inputs):
+            input_shape = K.shape(inputs)
+            if K.int_shape(input_shape)[0] != 4:
+                raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs)))
+
+            if self.data_format == 'channels_first':
+                batch_size, c, h, w = input_shape[0], K.int_shape(inputs)[1], input_shape[2], input_shape[3]
+                rh, rw = self.size
+                oh, ow = h * rh, w * rw
+                oc = c // (rh * rw)
+
+                out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w))
+                out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
+                out = K.reshape(out, (batch_size, oc, oh, ow))
+                return out
+
+            elif self.data_format == 'channels_last':
+                batch_size, h, w, c = input_shape[0], input_shape[1], input_shape[2], K.int_shape(inputs)[-1]
+                rh, rw = self.size
+                oh, ow = h * rh, w * rw
+                oc = c // (rh * rw)
+
+                out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc))
+                out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
+                out = K.reshape(out, (batch_size, oh, ow, oc))
+                return out
+
+        def compute_output_shape(self, input_shape):
+            if len(input_shape) != 4:
+                raise ValueError('Inputs should have rank ' +
+                                 str(4) +
+                                 '; Received input shape:', str(input_shape))
+
+            if self.data_format == 'channels_first':
+                height = input_shape[2] * self.size[0] if input_shape[2] is not None else None
+                width = input_shape[3] * self.size[1] if input_shape[3] is not None else None
+                channels = input_shape[1] // self.size[0] // self.size[1]
+
+                if channels * self.size[0] * self.size[1] != input_shape[1]:
+                    raise ValueError('channels of input and size are incompatible')
+
+                return (input_shape[0],
+                        channels,
+                        height,
+                        width)
+
+            elif self.data_format == 'channels_last':
+                height = input_shape[1] * self.size[0] if input_shape[1] is not None else None
+                width = input_shape[2] * self.size[1] if input_shape[2] is not None else None
+                channels = input_shape[3] // self.size[0] // self.size[1]
+
+                if channels * self.size[0] * self.size[1] != input_shape[3]:
+                    raise ValueError('channels of input and size are incompatible')
+
+                return (input_shape[0],
+                        height,
+                        width,
+                        channels)
+
+        def get_config(self):
+            config = {'size': self.size,
+                      'data_format': self.data_format}
+            base_config = super(PixelShuffler, self).get_config()
+
+            return dict(list(base_config.items()) + list(config.items()))
 
 nnlib.PixelShuffler = PixelShuffler
 nnlib.SubpixelUpscaler = PixelShuffler
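Because the fused op never needs the spatial dimensions as Python integers, the layer can be built on inputs whose H and W are only known at runtime. A rough usage sketch under the assumption of a TF 2.x Keras environment; SubpixelUpscale is a hypothetical stand-in for illustration, not the nnlib class:

import tensorflow as tf

class SubpixelUpscale(tf.keras.layers.Layer):
    """Illustrative stand-in for nnlib's PixelShuffler / SubpixelUpscaler."""
    def __init__(self, scale=2, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        # No reshape with concrete H/W needed; only channels % scale**2 == 0 matters.
        return tf.nn.depth_to_space(inputs, self.scale, data_format='NHWC')

    def get_config(self):
        return dict(super().get_config(), scale=self.scale)

inp = tf.keras.Input(shape=(None, None, 12))    # H and W left variable
out = SubpixelUpscale(2)(inp)
print(tf.keras.Model(inp, out).output_shape)    # (None, None, None, 3)

The inferred output shape keeps H and W as None while the channel count shrinks by scale**2, which is exactly the behaviour the commit's compute_output_shape reports.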