SAEHD: speed up for nvidia, duplicate code clean up

This commit is contained in:
Colombo 2019-10-08 21:02:20 +04:00
parent 627df082d7
commit 3f23135982
2 changed files with 187 additions and 171 deletions

View file

@@ -112,7 +112,31 @@ class SAEv2Model(ModelBase):
self.true_face_training = self.options.get('true_face_training', False)
masked_training = True
class SAEDFModel(object):
class CommonModel(object):
def downscale (self, dim, kernel_size=5, dilation_rate=1, use_activator=True):
def func(x):
if not use_activator:
return SubpixelDownscaler()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))
else:
return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
return func
def upscale (self, dim, size=(2,2)):
def func(x):
return SubpixelUpscaler(size=size)(LeakyReLU(0.1)(Conv2D(dim * np.prod(size) , kernel_size=3, strides=1, padding='same')(x)))
return func
def ResidualBlock(self, dim):
def func(inp):
x = Conv2D(dim, kernel_size=3, padding='same')(inp)
x = LeakyReLU(0.2)(x)
x = Conv2D(dim, kernel_size=3, padding='same')(x)
x = Add()([x, inp])
x = LeakyReLU(0.2)(x)
return x
return func
class SAEDFModel(CommonModel):
def __init__(self, resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask):
super().__init__()
self.learn_mask = learn_mask
@@ -123,18 +147,7 @@ class SAEv2Model(ModelBase):
lowest_dense_res = resolution // 16
e_dims = output_nc*e_ch_dims
def downscale (dim, kernel_size=5, dilation_rate=1, use_activator=True):
def func(x):
if not use_activator:
return SubpixelDownscaler()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))
else:
return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
return func
def upscale (dim, size=(2,2)):
def func(x):
return SubpixelUpscaler(size=size)(LeakyReLU(0.1)(Conv2D(dim * np.prod(size) , kernel_size=3, strides=1, padding='same')(x)))
return func
def enc_flow(e_ch_dims, ae_dims, lowest_dense_res):
dims = output_nc * e_ch_dims
@@ -142,32 +155,32 @@ class SAEv2Model(ModelBase):
dims += 1
def func(inp):
x = downscale(dims , 3, 1 )(inp)
x = downscale(dims*2, 3, 1 )(x)
x = downscale(dims*4, 3, 1 )(x)
x0 = downscale(dims*8, 3, 1 )(x)
x = self.downscale(dims , 3, 1 )(inp)
x = self.downscale(dims*2, 3, 1 )(x)
x = self.downscale(dims*4, 3, 1 )(x)
x0 = self.downscale(dims*8, 3, 1 )(x)
x = downscale(dims , 5, 1 )(inp)
x = downscale(dims*2, 5, 1 )(x)
x = downscale(dims*4, 5, 1 )(x)
x1 = downscale(dims*8, 5, 1 )(x)
x = self.downscale(dims , 5, 1 )(inp)
x = self.downscale(dims*2, 5, 1 )(x)
x = self.downscale(dims*4, 5, 1 )(x)
x1 = self.downscale(dims*8, 5, 1 )(x)
x = downscale(dims , 5, 2 )(inp)
x = downscale(dims*2, 5, 2 )(x)
x = downscale(dims*4, 5, 2 )(x)
x2 = downscale(dims*8, 5, 2 )(x)
x = self.downscale(dims , 5, 2 )(inp)
x = self.downscale(dims*2, 5, 2 )(x)
x = self.downscale(dims*4, 5, 2 )(x)
x2 = self.downscale(dims*8, 5, 2 )(x)
x = downscale(dims , 7, 2 )(inp)
x = downscale(dims*2, 7, 2 )(x)
x = downscale(dims*4, 7, 2 )(x)
x3 = downscale(dims*8, 7, 2 )(x)
x = self.downscale(dims , 7, 2 )(inp)
x = self.downscale(dims*2, 7, 2 )(x)
x = self.downscale(dims*4, 7, 2 )(x)
x3 = self.downscale(dims*8, 7, 2 )(x)
x = Concatenate()([x0,x1,x2,x3])
x = Dense(ae_dims)(Flatten()(x))
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
x = upscale(ae_dims)(x)
x = self.upscale(ae_dims)(x)
return x
return func
@@ -176,26 +189,18 @@ class SAEv2Model(ModelBase):
if dims % 2 != 0:
dims += 1
def ResidualBlock(dim):
def func(inp):
x = Conv2D(dim, kernel_size=3, padding='same')(inp)
x = LeakyReLU(0.2)(x)
x = Conv2D(dim, kernel_size=3, padding='same')(x)
x = Add()([x, inp])
x = LeakyReLU(0.2)(x)
return x
return func
def func(x):
for i in [8,4,2]:
x = upscale(dims*i)(x)
x = self.upscale(dims*i)(x)
if not is_mask:
x0 = x
x = upscale( (dims*i)//2 )(x)
x = ResidualBlock( (dims*i)//2 )(x)
x = downscale( dims*i, use_activator=False ) (x)
x = self.upscale( (dims*i)//2 )(x)
x = self.ResidualBlock( (dims*i)//2 )(x)
x = self.downscale( dims*i, use_activator=False ) (x)
x = Add()([x, x0])
x = LeakyReLU(0.2)(x)
@@ -243,7 +248,7 @@ class SAEv2Model(ModelBase):
[self.decoder_dstm, 'decoder_dstm.h5'] ]
return ar
class SAELIAEModel(object):
class SAELIAEModel(CommonModel):
def __init__(self, resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask):
super().__init__()
self.learn_mask = learn_mask
@@ -254,44 +259,31 @@ class SAEv2Model(ModelBase):
lowest_dense_res = resolution // 16
def downscale (dim, kernel_size=5, dilation_rate=1, use_activator=True):
def func(x):
if not use_activator:
return SubpixelDownscaler()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))
else:
return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
return func
def upscale (dim):
def func(x):
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same')(x)))
return func
def enc_flow(e_ch_dims):
dims = output_nc*e_ch_dims
if dims % 2 != 0:
dims += 1
def func(inp):
x = downscale(dims , 3, 1 )(inp)
x = downscale(dims*2, 3, 1 )(x)
x = downscale(dims*4, 3, 1 )(x)
x0 = downscale(dims*8, 3, 1 )(x)
x = self.downscale(dims , 3, 1 )(inp)
x = self.downscale(dims*2, 3, 1 )(x)
x = self.downscale(dims*4, 3, 1 )(x)
x0 = self.downscale(dims*8, 3, 1 )(x)
x = downscale(dims , 5, 1 )(inp)
x = downscale(dims*2, 5, 1 )(x)
x = downscale(dims*4, 5, 1 )(x)
x1 = downscale(dims*8, 5, 1 )(x)
x = self.downscale(dims , 5, 1 )(inp)
x = self.downscale(dims*2, 5, 1 )(x)
x = self.downscale(dims*4, 5, 1 )(x)
x1 = self.downscale(dims*8, 5, 1 )(x)
x = downscale(dims , 5, 2 )(inp)
x = downscale(dims*2, 5, 2 )(x)
x = downscale(dims*4, 5, 2 )(x)
x2 = downscale(dims*8, 5, 2 )(x)
x = self.downscale(dims , 5, 2 )(inp)
x = self.downscale(dims*2, 5, 2 )(x)
x = self.downscale(dims*4, 5, 2 )(x)
x2 = self.downscale(dims*8, 5, 2 )(x)
x = downscale(dims , 7, 2 )(inp)
x = downscale(dims*2, 7, 2 )(x)
x = downscale(dims*4, 7, 2 )(x)
x3 = downscale(dims*8, 7, 2 )(x)
x = self.downscale(dims , 7, 2 )(inp)
x = self.downscale(dims*2, 7, 2 )(x)
x = self.downscale(dims*4, 7, 2 )(x)
x3 = self.downscale(dims*8, 7, 2 )(x)
x = Concatenate()([x0,x1,x2,x3])
@@ -304,7 +296,7 @@ class SAEv2Model(ModelBase):
x = Dense(ae_dims)(x)
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
x = upscale(ae_dims*2)(x)
x = self.upscale(ae_dims*2)(x)
return x
return func
@@ -313,26 +305,16 @@ class SAEv2Model(ModelBase):
if dims % 2 != 0:
dims += 1
def ResidualBlock(dim):
def func(inp):
x = Conv2D(dim, kernel_size=3, padding='same')(inp)
x = LeakyReLU(0.2)(x)
x = Conv2D(dim, kernel_size=3, padding='same')(x)
x = Add()([x, inp])
x = LeakyReLU(0.2)(x)
return x
return func
def func(x):
for i in [8,4,2]:
x = upscale(dims*i)(x)
x = self.upscale(dims*i)(x)
if not is_mask:
x0 = x
x = upscale( (dims*i)//2 )(x)
x = ResidualBlock( (dims*i)//2 )(x)
x = downscale( dims*i, use_activator=False ) (x)
x = self.upscale( (dims*i)//2 )(x)
x = self.ResidualBlock( (dims*i)//2 )(x)
x = self.downscale( dims*i, use_activator=False ) (x)
x = Add()([x, x0])
x = LeakyReLU(0.2)(x)

View file

@@ -474,46 +474,80 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
nnlib.PixelShuffler = PixelShuffler
nnlib.SubpixelUpscaler = PixelShuffler
class SubpixelDownscaler(KL.Layer):
def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
super(SubpixelDownscaler, self).__init__(**kwargs)
self.data_format = data_format
self.size = size
if 'tensorflow' in backend:
class SubpixelDownscaler(KL.Layer):
def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
super(SubpixelDownscaler, self).__init__(**kwargs)
self.data_format = data_format
self.size = size
def call(self, inputs):
def call(self, inputs):
input_shape = K.shape(inputs)
if K.int_shape(input_shape)[0] != 4:
raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs)))
input_shape = K.shape(inputs)
if K.int_shape(input_shape)[0] != 4:
raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs)))
batch_size, h, w, c = input_shape[0], input_shape[1], input_shape[2], K.int_shape(inputs)[-1]
rh, rw = self.size
oh, ow = h // rh, w // rw
oc = c * (rh * rw)
return K.tf.space_to_depth(inputs, self.size[0], 'NHWC')
out = K.reshape(inputs, (batch_size, oh, rh, ow, rw, c))
out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
out = K.reshape(out, (batch_size, oh, ow, oc))
return out
def compute_output_shape(self, input_shape):
if len(input_shape) != 4:
raise ValueError('Inputs should have rank ' +
str(4) +
'; Received input shape:', str(input_shape))
def compute_output_shape(self, input_shape):
if len(input_shape) != 4:
raise ValueError('Inputs should have rank ' +
str(4) +
'; Received input shape:', str(input_shape))
height = input_shape[1] // self.size[0] if input_shape[1] is not None else None
width = input_shape[2] // self.size[1] if input_shape[2] is not None else None
channels = input_shape[3] * self.size[0] * self.size[1]
height = input_shape[1] // self.size[0] if input_shape[1] is not None else None
width = input_shape[2] // self.size[1] if input_shape[2] is not None else None
channels = input_shape[3] * self.size[0] * self.size[1]
return (input_shape[0], height, width, channels)
return (input_shape[0], height, width, channels)
def get_config(self):
config = {'size': self.size,
'data_format': self.data_format}
base_config = super(SubpixelDownscaler, self).get_config()
def get_config(self):
config = {'size': self.size,
'data_format': self.data_format}
base_config = super(SubpixelDownscaler, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
else:
class SubpixelDownscaler(KL.Layer):
def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
super(SubpixelDownscaler, self).__init__(**kwargs)
self.data_format = data_format
self.size = size
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
input_shape = K.shape(inputs)
if K.int_shape(input_shape)[0] != 4:
raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs)))
batch_size, h, w, c = input_shape[0], input_shape[1], input_shape[2], K.int_shape(inputs)[-1]
rh, rw = self.size
oh, ow = h // rh, w // rw
oc = c * (rh * rw)
out = K.reshape(inputs, (batch_size, oh, rh, ow, rw, c))
out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
out = K.reshape(out, (batch_size, oh, ow, oc))
return out
def compute_output_shape(self, input_shape):
if len(input_shape) != 4:
raise ValueError('Inputs should have rank ' +
str(4) +
'; Received input shape:', str(input_shape))
height = input_shape[1] // self.size[0] if input_shape[1] is not None else None
width = input_shape[2] // self.size[1] if input_shape[2] is not None else None
channels = input_shape[3] * self.size[0] * self.size[1]
return (input_shape[0], height, width, channels)
def get_config(self):
config = {'size': self.size,
'data_format': self.data_format}
base_config = super(SubpixelDownscaler, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
nnlib.SubpixelDownscaler = SubpixelDownscaler