mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
nnlib: implemented ReflectionPadding2D for plaidML,
wrapping Conv2D with new padding param
This commit is contained in:
parent
673d4915d9
commit
4f4447d719
1 changed files with 80 additions and 16 deletions
|
@ -16,6 +16,8 @@ class nnlib(object):
|
||||||
DeviceConfig = device.Config
|
DeviceConfig = device.Config
|
||||||
active_DeviceConfig = DeviceConfig() #default is one best GPU
|
active_DeviceConfig = DeviceConfig() #default is one best GPU
|
||||||
|
|
||||||
|
backend = ""
|
||||||
|
|
||||||
dlib = None
|
dlib = None
|
||||||
|
|
||||||
keras = None
|
keras = None
|
||||||
|
@ -49,7 +51,7 @@ KL = keras.layers
|
||||||
Input = KL.Input
|
Input = KL.Input
|
||||||
|
|
||||||
Dense = KL.Dense
|
Dense = KL.Dense
|
||||||
Conv2D = KL.Conv2D
|
Conv2D = nnlib.Conv2D
|
||||||
Conv2DTranspose = KL.Conv2DTranspose
|
Conv2DTranspose = KL.Conv2DTranspose
|
||||||
SeparableConv2D = KL.SeparableConv2D
|
SeparableConv2D = KL.SeparableConv2D
|
||||||
MaxPooling2D = KL.MaxPooling2D
|
MaxPooling2D = KL.MaxPooling2D
|
||||||
|
@ -158,22 +160,21 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||||
if nnlib.keras is not None:
|
if nnlib.keras is not None:
|
||||||
return nnlib.code_import_keras
|
return nnlib.code_import_keras
|
||||||
|
|
||||||
if "tensorflow" in device_config.backend:
|
nnlib.backend = device_config.backend
|
||||||
|
|
||||||
|
if "tensorflow" in nnlib.backend:
|
||||||
nnlib._import_tf(device_config)
|
nnlib._import_tf(device_config)
|
||||||
elif device_config.backend == "plaidML":
|
elif nnlib.backend == "plaidML":
|
||||||
os.environ["KERAS_BACKEND"] = "plaidml.keras.backend"
|
os.environ["KERAS_BACKEND"] = "plaidml.keras.backend"
|
||||||
os.environ["PLAIDML_DEVICE_IDS"] = ",".join ( [ nnlib.device.getDeviceID(idx) for idx in device_config.gpu_idxs] )
|
os.environ["PLAIDML_DEVICE_IDS"] = ",".join ( [ nnlib.device.getDeviceID(idx) for idx in device_config.gpu_idxs] )
|
||||||
|
|
||||||
if 'TF_SUPPRESS_STD' in os.environ.keys() and os.environ['TF_SUPPRESS_STD'] == '1':
|
#if "tensorflow" in nnlib.backend:
|
||||||
suppressor = std_utils.suppress_stdout_stderr().__enter__()
|
|
||||||
|
|
||||||
#if "tensorflow" in device_config.backend:
|
|
||||||
# nnlib.keras = nnlib.tf.keras
|
# nnlib.keras = nnlib.tf.keras
|
||||||
#else:
|
#else:
|
||||||
import keras as keras_
|
import keras as keras_
|
||||||
nnlib.keras = keras_
|
nnlib.keras = keras_
|
||||||
|
|
||||||
if device_config.backend == "plaidML":
|
if nnlib.backend == "plaidML":
|
||||||
import plaidml
|
import plaidml
|
||||||
import plaidml.tile
|
import plaidml.tile
|
||||||
nnlib.PML = plaidml
|
nnlib.PML = plaidml
|
||||||
|
@ -183,14 +184,11 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||||
if device_config.use_fp16:
|
if device_config.use_fp16:
|
||||||
nnlib.keras.backend.set_floatx('float16')
|
nnlib.keras.backend.set_floatx('float16')
|
||||||
|
|
||||||
if "tensorflow" in device_config.backend:
|
if "tensorflow" in nnlib.backend:
|
||||||
nnlib.keras.backend.set_session(nnlib.tf_sess)
|
nnlib.keras.backend.set_session(nnlib.tf_sess)
|
||||||
|
|
||||||
nnlib.keras.backend.set_image_data_format('channels_last')
|
nnlib.keras.backend.set_image_data_format('channels_last')
|
||||||
|
|
||||||
if 'TF_SUPPRESS_STD' in os.environ.keys() and os.environ['TF_SUPPRESS_STD'] == '1':
|
|
||||||
suppressor.__exit__()
|
|
||||||
|
|
||||||
nnlib.code_import_keras = compile (nnlib.code_import_keras_string,'','exec')
|
nnlib.code_import_keras = compile (nnlib.code_import_keras_string,'','exec')
|
||||||
nnlib.__initialize_keras_functions()
|
nnlib.__initialize_keras_functions()
|
||||||
|
|
||||||
|
@ -201,6 +199,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||||
keras = nnlib.keras
|
keras = nnlib.keras
|
||||||
K = keras.backend
|
K = keras.backend
|
||||||
KL = keras.layers
|
KL = keras.layers
|
||||||
|
backend = nnlib.backend
|
||||||
|
|
||||||
def modelify(model_functor):
|
def modelify(model_functor):
|
||||||
def func(tensor):
|
def func(tensor):
|
||||||
|
@ -547,8 +546,52 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||||
nnlib.CAInitializerMP = CAInitializerMP
|
nnlib.CAInitializerMP = CAInitializerMP
|
||||||
|
|
||||||
|
|
||||||
'''
|
if backend == "plaidML":
|
||||||
not implemented in plaidML
|
class TileOP_ReflectionPadding2D(nnlib.PMLTile.Operation):
|
||||||
|
def __init__(self, input, w_pad, h_pad):
|
||||||
|
if K.image_data_format() == 'channels_last':
|
||||||
|
if input.shape.ndims == 4:
|
||||||
|
H, W = input.shape.dims[1:3]
|
||||||
|
if (type(H) == int and h_pad >= H) or \
|
||||||
|
(type(W) == int and w_pad >= W):
|
||||||
|
raise ValueError("Paddings must be less than dimensions.")
|
||||||
|
|
||||||
|
c = """ function (I[B, H, W, C] ) -> (O) {{
|
||||||
|
WE = W + {w_pad}*2;
|
||||||
|
HE = H + {h_pad}*2;
|
||||||
|
""".format(h_pad=h_pad, w_pad=w_pad)
|
||||||
|
if w_pad > 0:
|
||||||
|
c += """
|
||||||
|
LEFT_PAD [b, h, w , c : B, H, WE, C ] = =(I[b, h, {w_pad}-w, c]), w < {w_pad} ;
|
||||||
|
HCENTER [b, h, w , c : B, H, WE, C ] = =(I[b, h, w-{w_pad}, c]), w < W+{w_pad}-1 ;
|
||||||
|
RIGHT_PAD[b, h, w , c : B, H, WE, C ] = =(I[b, h, 2*W - (w-{w_pad}) -2, c]);
|
||||||
|
LCR = LEFT_PAD+HCENTER+RIGHT_PAD;
|
||||||
|
""".format(h_pad=h_pad, w_pad=w_pad)
|
||||||
|
else:
|
||||||
|
c += "LCR = I;"
|
||||||
|
|
||||||
|
if h_pad > 0:
|
||||||
|
c += """
|
||||||
|
TOP_PAD [b, h, w , c : B, HE, WE, C ] = =(LCR[b, {h_pad}-h, w, c]), h < {h_pad};
|
||||||
|
VCENTER [b, h, w , c : B, HE, WE, C ] = =(LCR[b, h-{h_pad}, w, c]), h < H+{h_pad}-1 ;
|
||||||
|
BOTTOM_PAD[b, h, w , c : B, HE, WE, C ] = =(LCR[b, 2*H - (h-{h_pad}) -2, w, c]);
|
||||||
|
TVB = TOP_PAD+VCENTER+BOTTOM_PAD;
|
||||||
|
""".format(h_pad=h_pad, w_pad=w_pad)
|
||||||
|
else:
|
||||||
|
c += "TVB = LCR;"
|
||||||
|
|
||||||
|
c += "O = TVB; }"
|
||||||
|
|
||||||
|
inp_dims = input.shape.dims
|
||||||
|
out_dims = (inp_dims[0], inp_dims[1]+h_pad*2, inp_dims[2]+w_pad*2, inp_dims[3])
|
||||||
|
else:
|
||||||
|
raise NotImplemented
|
||||||
|
else:
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
|
super(TileOP_ReflectionPadding2D, self).__init__(c, [('I', input) ],
|
||||||
|
[('O', nnlib.PMLTile.Shape(input.shape.dtype, out_dims ) )])
|
||||||
|
|
||||||
class ReflectionPadding2D(keras.layers.Layer):
|
class ReflectionPadding2D(keras.layers.Layer):
|
||||||
def __init__(self, padding=(1, 1), **kwargs):
|
def __init__(self, padding=(1, 1), **kwargs):
|
||||||
self.padding = tuple(padding)
|
self.padding = tuple(padding)
|
||||||
|
@ -561,11 +604,32 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||||
|
|
||||||
def call(self, x, mask=None):
|
def call(self, x, mask=None):
|
||||||
w_pad,h_pad = self.padding
|
w_pad,h_pad = self.padding
|
||||||
return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')
|
if "tensorflow" in backend:
|
||||||
|
return K.tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')
|
||||||
|
elif backend == "plaidML":
|
||||||
|
return TileOP_ReflectionPadding2D.function(x, self.padding[0], self.padding[1])
|
||||||
|
|
||||||
nnlib.ReflectionPadding2D = ReflectionPadding2D
|
nnlib.ReflectionPadding2D = ReflectionPadding2D
|
||||||
'''
|
|
||||||
|
|
||||||
|
class Conv2D():
|
||||||
|
def __init__ (self, *args, **kwargs):
|
||||||
|
self.reflect_pad = False
|
||||||
|
padding = kwargs.get('padding','')
|
||||||
|
if padding == 'zero':
|
||||||
|
kwargs['padding'] = 'same'
|
||||||
|
if padding == 'reflect':
|
||||||
|
kernel_size = kwargs['kernel_size']
|
||||||
|
if (kernel_size % 2) == 1:
|
||||||
|
self.pad = (kernel_size // 2,)*2
|
||||||
|
kwargs['padding'] = 'valid'
|
||||||
|
self.reflect_pad = True
|
||||||
|
self.func = keras.layers.Conv2D (*args, **kwargs)
|
||||||
|
|
||||||
|
def __call__(self,x):
|
||||||
|
if self.reflect_pad:
|
||||||
|
x = ReflectionPadding2D( self.pad ) (x)
|
||||||
|
return self.func(x)
|
||||||
|
nnlib.Conv2D = Conv2D
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def import_keras_contrib(device_config):
|
def import_keras_contrib(device_config):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue