DeepFaceLab/core/leras/tensor_ops.py
Colombo 76ca79216e Upgraded to TF version 1.13.2
Removed the wait at first launch for most graphics cards.

Increased speed of training by 10-20%, but you have to retrain all models from scratch.

SAEHD:

added option 'use float16'
	Experimental option. Reduces the model size by half.
	Increases the speed of training.
	Decreases the accuracy of the model.
	The model may collapse or not train.
	The model may not learn the mask at large resolutions.

The true_face_training option is replaced by
"True face power" (range 0.0000 .. 1.0).
Experimental option. Discriminates the result face so that it looks more like the src face. A higher value gives stronger discrimination.
Comparison - https://i.imgur.com/czScS9q.png
2020-01-25 21:58:19 +04:00

332 lines
No EOL
13 KiB
Python

import numpy as np
def initialize_tensor_ops(nn):
    """Install TensorFlow helper ops on the `nn` namespace object.

    `nn.tf` must already hold the TensorFlow module. The helper functions
    that follow read `nn` and the `tf` alias bound here, and each registers
    itself as an `nn.tf_*` attribute.
    """
    tf = nn.tf
    from tensorflow.python.ops import (
        array_ops,
        gradients,
        math_ops,
        random_ops,
        sparse_ops,
    )
    from tensorflow.python.framework import sparse_tensor
def tf_get_value(tensor):
    """Evaluate `tensor` (or any fetchable) in `nn.tf_sess` and return the result."""
    session = nn.tf_sess
    return session.run(tensor)
nn.tf_get_value = tf_get_value
def tf_batch_set_value(tuples):
    """Assign values to many variables with a single `nn.tf_sess.run` call.

    tuples : iterable of (variable, value) pairs. If `value` is a
             `tf.Operation` it is run as-is (the paired variable is not used).
             Otherwise `value` is converted to the variable's numpy dtype and
             fed through a placeholder into an assign op.

    The placeholder and assign op for each variable are built once and cached
    on the variable object. Repeated calls (e.g. reloading weights) reuse them
    instead of adding new nodes to the graph on every call.
    """
    tuples = list(tuples)
    if not tuples:
        return
    with nn.tf.device('/CPU:0'):
        assign_ops = []
        feed_dict = {}
        for x, value in tuples:
            if isinstance(value, nn.tf.Operation):
                assign_ops.append(value)
                continue
            value = np.asarray(value, dtype=x.dtype.as_numpy_dtype)
            assign_placeholder = getattr(x, '_tf_batch_set_placeholder', None)
            assign_op = getattr(x, '_tf_batch_set_op', None)
            if assign_placeholder is None or assign_placeholder.shape.ndims != value.ndim:
                # First assignment to this variable (or a rank change): build
                # the feed path once and remember it on the variable.
                assign_placeholder = nn.tf.placeholder(x.dtype.base_dtype, shape=[None]*value.ndim)
                assign_op = nn.tf.assign(x, assign_placeholder)
                x._tf_batch_set_placeholder = assign_placeholder
                x._tf_batch_set_op = assign_op
            assign_ops.append(assign_op)
            feed_dict[assign_placeholder] = value
        nn.tf_sess.run(assign_ops, feed_dict=feed_dict)
nn.tf_batch_set_value = tf_batch_set_value
def tf_gradients(loss, vars):
    """Compute the gradients of `loss` with respect to each variable in `vars`.

    Gradient ops are colocated with their forward ops.

    Returns:
        list of (gradient, variable) pairs, in the order of `vars`.

    Raises:
        Exception: if some variable gets no gradient (`loss` does not depend
        on it). The message names that variable.
    """
    grads = gradients.gradients(loss, vars, colocate_gradients_with_ops=True)
    gv = [*zip(grads, vars)]
    for g, v in gv:
        if g is None:
            # The original literal lacked formatting and printed "{v.name}".
            raise Exception("No gradient for variable {}".format(v.name))
    return gv
nn.tf_gradients = tf_gradients
def tf_average_gv_list(grad_var_list, tf_device_string=None):
    """Average several (gradient, variable) lists element-wise.

    grad_var_list    : list of lists of (gradient, variable) pairs, one inner
                       list per source. All inner lists are assumed to cover
                       the same variables in the same order; the variable from
                       the first list is the one kept.
    tf_device_string : optional device under which the averaging ops are built.

    Returns:
        list of (mean_gradient, variable) tuples. A single-entry input is
        returned unchanged.
    """
    if len(grad_var_list) == 1:
        return grad_var_list[0]

    def average():
        # stacked[j] = [[g_j from each source, with a new leading axis], v_j]
        stacked = []
        for i, gv in enumerate(grad_var_list):
            for j, (g, v) in enumerate(gv):
                g = tf.expand_dims(g, 0)
                if i == 0:
                    stacked.append([[g], v])
                else:
                    stacked[j][0].append(g)
        return [(tf.reduce_mean(tf.concat(gs, 0), 0), v) for gs, v in stacked]

    # A `with` block pops the device scope even if building the ops raises.
    # Calling __enter__/__exit__ by hand would leave it active on error.
    if tf_device_string is None:
        return average()
    with tf.device(tf_device_string):
        return average()
nn.tf_average_gv_list = tf_average_gv_list
def tf_average_tensor_list(tensors_list, tf_device_string=None):
    """Element-wise mean of a list of same-shaped tensors.

    tensors_list     : tensors to average. They are stacked along a new
                       leading axis, so they must share a shape.
    tf_device_string : optional device under which the averaging ops are built.

    A single-element list is returned as that tensor, with no ops added.
    """
    if len(tensors_list) == 1:
        return tensors_list[0]

    def average():
        return tf.reduce_mean(tf.concat([tf.expand_dims(t, 0) for t in tensors_list], 0), 0)

    # A `with` block pops the device scope even if building the ops raises.
    if tf_device_string is None:
        return average()
    with tf.device(tf_device_string):
        return average()
nn.tf_average_tensor_list = tf_average_tensor_list
def tf_concat(tensors_list, axis):
    """Concatenate `tensors_list` along `axis`.

    Unlike a bare tf.concat, a one-element list is returned as that tensor
    itself, so no concat op is added to the graph.
    """
    if len(tensors_list) != 1:
        return tf.concat(tensors_list, axis)
    return tensors_list[0]
nn.tf_concat = tf_concat
def tf_gelu(x):
    """GELU activation, tanh approximation (https://arxiv.org/abs/1606.08415):

        x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    """
    inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
    cdf = 0.5 * (1.0 + tf.nn.tanh(inner))
    return x * cdf
nn.tf_gelu = tf_gelu
def tf_upsample2d(x, size=2):
    """Nearest-neighbour upsample of both spatial dims by integer factor `size`.

    NCHW: built from reshape + tile, so c, h and w must be statically known.
    Any other data format goes through tf.image.resize_nearest_neighbor.
    """
    if nn.data_format != "NCHW":
        return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
    _, c, h, w = x.shape.as_list()
    # Add unit axes after h and w, repeat them `size` times, then fold back.
    x = tf.reshape(x, (-1, c, h, 1, w, 1))
    x = tf.tile(x, (1, 1, 1, size, 1, size))
    return tf.reshape(x, (-1, c, h*size, w*size))
nn.tf_upsample2d = tf_upsample2d
def tf_upsample2d_bilinear(x, size=2):
    """Bilinear upsample of both spatial dims by integer factor `size`.

    Works in either `nn.data_format`. tf.image.resize_images only accepts
    [batch, height, width, channels] input, so NCHW tensors are transposed to
    NHWC, resized, and transposed back.
    """
    if nn.data_format == "NCHW":
        # Passing NCHW straight through would resize the C and H axes
        # instead of H and W.
        x = tf.transpose(x, (0, 2, 3, 1))
        x = tf.image.resize_images(x, (x.shape[1]*size, x.shape[2]*size))
        return tf.transpose(x, (0, 3, 1, 2))
    return tf.image.resize_images(x, (x.shape[1]*size, x.shape[2]*size))
nn.tf_upsample2d_bilinear = tf_upsample2d_bilinear
def tf_flatten(x):
    """Collapse every non-batch dimension into one.

    NHWC input is first transposed to NCHW, so the flattened element order
    is the same whichever data format is active.
    """
    if nn.data_format == "NHWC":
        x = tf.transpose(x, (0, 3, 1, 2))
    features = np.prod(x.shape[1:])
    return tf.reshape(x, (-1, features))
nn.tf_flatten = tf_flatten
def tf_reshape_4D(x, w, h, c):
    """Reshape flat `x` into a 4-D image tensor laid out in `nn.data_format`.

    The data is always read in (c, h, w) order, the order tf_flatten
    produces. In NHWC mode the result is then transposed to (h, w, c).
    """
    x = tf.reshape(x, (-1, c, h, w))
    if nn.data_format == "NHWC":
        x = tf.transpose(x, (0, 2, 3, 1))
    return x
nn.tf_reshape_4D = tf_reshape_4D
def tf_random_binomial(shape, p=0.0, dtype=None, seed=None):
    """Random 0/1 tensor of `shape`; each element is 1 with probability `p`.

    dtype : output dtype, tf.float32 when None.
    seed  : op-level seed; drawn from numpy's global RNG when None.

    NOTE(review): the uniform samples are drawn as float16 whatever `dtype`
    is, so the `< p` comparison happens at float16 precision.
    """
    out_dtype = tf.float32 if dtype is None else dtype
    op_seed = np.random.randint(10e6) if seed is None else seed
    hit = random_ops.random_uniform(shape, dtype=tf.float16, seed=op_seed) < p
    ones = array_ops.ones(shape, dtype=out_dtype)
    zeros = array_ops.zeros(shape, dtype=out_dtype)
    return array_ops.where(hit, ones, zeros)
nn.tf_random_binomial = tf_random_binomial
def tf_gaussian_blur(input, radius=2.0):
    """Blur each channel of a 4-D tensor with a normalized 2-D Gaussian.

    input  : 4-D tensor laid out in `nn.data_format`.
    radius : Gaussian sigma. Must be > 0; the kernel formula divides by
             2*sigma**2.

    The kernel is odd-sized (at least 3) and sums to 1. Input is zero-padded
    by kernel_size//2 on each spatial side and convolved with VALID padding,
    so the spatial size is preserved. Channels are convolved one at a time
    and concatenated back on `nn.conv2d_ch_axis`.
    """
    def gaussian(x, mu, sigma):
        return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2))

    def make_kernel(sigma):
        kernel_size = max(3, int(2 * 2 * sigma + 1))
        if kernel_size % 2 == 0:
            # An even kernel has no centre tap. It would shift the image by
            # half a pixel and make the output one pixel larger than the input.
            kernel_size += 1
        mean = np.floor(0.5 * kernel_size)
        kernel_1d = np.array([gaussian(x, mean, sigma) for x in range(kernel_size)])
        np_kernel = np.outer(kernel_1d, kernel_1d).astype(np.float32)
        kernel = np_kernel / np.sum(np_kernel)
        return kernel, kernel_size

    gauss_kernel, kernel_size = make_kernel(radius)
    pad = kernel_size // 2  # >= 1 because kernel_size >= 3
    if nn.data_format == "NHWC":
        padding = [[0, 0], [pad, pad], [pad, pad], [0, 0]]
    else:
        padding = [[0, 0], [0, 0], [pad, pad], [pad, pad]]
    # One shared [k, k, 1, 1] filter constant instead of one per channel.
    kernel = tf.constant(gauss_kernel[:, :, None, None], dtype=input.dtype)

    outputs = []
    for i in range(input.shape[nn.conv2d_ch_axis]):
        x = input[:, :, :, i:i+1] if nn.data_format == "NHWC" else input[:, i:i+1, :, :]
        x = tf.pad(x, padding)
        outputs.append(tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="VALID", data_format=nn.data_format))
    return tf.concat(outputs, axis=nn.conv2d_ch_axis)
nn.tf_gaussian_blur = tf_gaussian_blur
def tf_style_loss(target, style, gaussian_blur_radius=0.0, loss_weight=1.0, step_size=1):
    """Per-sample style loss from channel-wise spatial mean and std.

    For each sample, sums the squared differences of the per-channel means
    and of the per-channel standard deviations (sqrt(var + 1e-5)) between
    `target` and `style`. The sum is scaled by loss_weight / channel_count.
    If gaussian_blur_radius > 0, both inputs are blurred with
    tf_gaussian_blur first. `step_size` is accepted but unused.

    Raises:
        Exception: if `target` and `style` have different channel counts.
    """
    if gaussian_blur_radius > 0.0:
        target = tf_gaussian_blur(target, gaussian_blur_radius)
        style = tf_gaussian_blur(style, gaussian_blur_radius)

    channels = target.shape[nn.conv2d_ch_axis]
    if channels != style.shape[nn.conv2d_ch_axis]:
        raise Exception("style_loss() content_nc != style_nc")

    t_mean, t_var = tf.nn.moments(target, axes=nn.conv2d_spatial_axes, keep_dims=True)
    s_mean, s_var = tf.nn.moments(style, axes=nn.conv2d_spatial_axes, keep_dims=True)
    t_std = tf.sqrt(t_var + 1e-5)
    s_std = tf.sqrt(s_var + 1e-5)

    # Statistics keep their reduced dims, so summing axes 1..3 leaves one
    # value per batch element in either data format.
    mean_term = tf.reduce_sum(tf.square(t_mean - s_mean), axis=[1, 2, 3])
    std_term = tf.reduce_sum(tf.square(t_std - s_std), axis=[1, 2, 3])
    return (mean_term + std_term) * (loss_weight / channels.value)
nn.tf_style_loss = tf_style_loss
def tf_dssim(img1, img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
    """Structural dissimilarity (1 - SSIM) / 2 between `img1` and `img2`.

    max_val     : dynamic range of the pixel values (used for c1, c2).
    filter_size : side of the square Gaussian window.
    filter_sigma: sigma of the Gaussian window.
    k1, k2      : SSIM stabilisation constants.

    Math is done in float32; the result is cast back to the input dtype. The
    SSIM map is averaged over `nn.conv2d_spatial_axes`.

    Raises:
        ValueError: if the two images have different dtypes.
    """
    if img1.dtype != img2.dtype:
        raise ValueError("img1.dtype != img2.dtype")

    input_dtype = img1.dtype
    needs_cast = input_dtype != tf.float32
    if needs_cast:
        img1 = tf.cast(img1, tf.float32)
        img2 = tf.cast(img2, tf.float32)

    # 2-D Gaussian window built in log space: log_g[i] + log_g[j]; a softmax
    # over the flattened grid exponentiates it and normalises it to sum 1.
    offsets = np.arange(0, filter_size, dtype=np.float32)
    offsets -= (filter_size - 1) / 2.0
    log_g = offsets**2
    log_g *= (-0.5 / (filter_sigma**2))
    log_grid = np.reshape(log_g, (1, -1)) + np.reshape(log_g, (-1, 1))
    window = tf.nn.softmax(tf.constant(np.reshape(log_grid, (1, -1)), dtype=tf.float32))
    window = tf.reshape(window, (filter_size, filter_size, 1, 1))
    window = tf.tile(window, (1, 1, img1.shape[nn.conv2d_ch_axis], 1))

    def reducer(x):
        return tf.nn.depthwise_conv2d(x, window, strides=[1,1,1,1], padding='VALID', data_format=nn.data_format)

    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2

    mean0 = reducer(img1)
    mean1 = reducer(img2)
    num0 = mean0 * mean1 * 2.0
    den0 = tf.square(mean0) + tf.square(mean1)
    luminance = (num0 + c1) / (den0 + c1)

    num1 = reducer(img1 * img2) * 2.0
    den1 = reducer(tf.square(img1) + tf.square(img2))
    contrast_structure = (num1 - num0 + c2) / (den1 - den0 + c2)

    ssim_val = tf.reduce_mean(luminance * contrast_structure, axis=nn.conv2d_spatial_axes)
    dssim = (1.0 - ssim_val) / 2.0
    if needs_cast:
        dssim = tf.cast(dssim, input_dtype)
    return dssim
nn.tf_dssim = tf_dssim
def tf_space_to_depth(x, size):
    """Move each `size` x `size` spatial block into the channel axis.

    Output pixel (i, j), channel (s1*size + s2)*c + ch, takes input pixel
    (i*size + s1, j*size + s2), channel ch. This is the ordering that
    tf.space_to_depth uses in NCHW, and the exact inverse of the NHWC branch
    of tf_depth_to_space.
    """
    if nn.data_format == "NHWC":
        # Match the NCHW version so the data_format can be switched freely.
        # H must split as (oh, size) and W as (ow, size), so that each block
        # is contiguous. The previous (size, oh) split gathered rows `oh`
        # apart instead.
        _, h, w, c = x.shape.as_list()
        oh, ow = h // size, w // size
        x = tf.reshape(x, (-1, oh, size, ow, size, c))
        x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
        x = tf.reshape(x, (-1, oh, ow, size * size * c))
        return x
    else:
        return tf.space_to_depth(x, size, data_format=nn.data_format)
nn.tf_space_to_depth = tf_space_to_depth
def tf_depth_to_space(x, size):
    """Rearrange channel groups into `size` x `size` spatial tiles (pixel shuffle).

    Non-NHWC formats use tf.depth_to_space. In NHWC the operation is written
    out by hand so that output pixel (i*size + s1, j*size + s2), channel ch,
    takes input pixel (i, j), channel (s1*size + s2)*oc + ch.
    """
    if nn.data_format != "NHWC":
        return tf.depth_to_space(x, size, data_format=nn.data_format)
    _, h, w, c = x.shape.as_list()
    oc = c // (size * size)
    x = tf.reshape(x, (-1, h, w, size, size, oc))
    x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
    return tf.reshape(x, (-1, h * size, w * size, oc))
nn.tf_depth_to_space = tf_depth_to_space
def tf_rgb_to_lab(srgb):
    """Convert sRGB colours to CIE L*a*b*.

    srgb : float32 tensor. It is reshaped to [-1, 3], so consecutive triples
           (normally the last axis) are read as (R, G, B). Values are
           presumably in [0, 1]; TODO confirm with callers.

    Returns a tensor of the same shape holding (L, a, b) triples.

    The power branches of both piecewise functions are fed clamped inputs.
    Their masked-out entries would otherwise still be evaluated and can give
    NaN/Inf (negative base for **2.4, infinite cube-root slope at 0), and
    NaN * 0 stays NaN, so black pixels gave NaN gradients. Values for inputs
    on the active side of each threshold are unchanged.
    """
    srgb_pixels = tf.reshape(srgb, [-1, 3])

    # sRGB companding -> linear RGB.
    linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
    exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
    gamma_base = (tf.maximum(srgb_pixels, 0.04045) + 0.055) / 1.055
    rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (gamma_base ** 2.4) * exponential_mask

    rgb_to_xyz = tf.constant([
        #    X         Y         Z
        [0.412453, 0.212671, 0.019334], # R
        [0.357580, 0.715160, 0.119193], # G
        [0.180423, 0.072169, 0.950227], # B
    ])
    xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)

    # Normalise by the reference white (0.950456, 1.0, 1.088754).
    xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])

    epsilon = 6/29
    linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
    exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
    cbrt_base = tf.maximum(xyz_normalized_pixels, epsilon**3)
    fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (cbrt_base ** (1/3)) * exponential_mask

    fxfyfz_to_lab = tf.constant([
        #  l       a       b
        [  0.0,  500.0,    0.0], # fx
        [116.0, -500.0,  200.0], # fy
        [  0.0,    0.0, -200.0], # fz
    ])
    lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
    return tf.reshape(lab_pixels, tf.shape(srgb))
nn.tf_rgb_to_lab = tf_rgb_to_lab
def tf_suppress_lower_mean(t, eps=0.00001):
    """Soft-threshold a rank-1 tensor at its mean, zeroing the lower values.

    With m = mean(t), each element becomes:
      t <= m - eps    -> 0
      t >= m          -> t (unchanged)
      in between      -> t times a factor rising linearly from 0 to 1

    Raises:
        ValueError: if `t` is not rank 1.
    """
    if t.shape.ndims != 1:
        raise ValueError("tf_suppress_lower_mean: t rank must be 1")
    t_mean_eps = tf.reduce_mean(t) - eps
    # Ramp in [0, eps]: 0 at or below mean - eps, eps at or above the mean.
    q = tf.clip_by_value(t, t_mean_eps, tf.reduce_max(t) )
    q = tf.clip_by_value(q-t_mean_eps, 0, eps)
    return q * (t/eps)
# Previously defined but never registered, so it was unreachable.
nn.tf_suppress_lower_mean = tf_suppress_lower_mean
# NOTE(review): the triple-quoted string below is a disabled Keras-based GeLU
# layer kept for reference; it is never executed. tf_gelu provides the same
# tanh-approximation GELU as a plain function.
"""
class GeLU(KL.Layer):
Gaussian Error Linear Unit.
A smoother version of ReLU generally used
in the BERT or BERT architecture based models.
Original paper: https://arxiv.org/abs/1606.08415
Input shape:
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
Output shape:
Same shape as the input.
def __init__(self, approximate=True, **kwargs):
super(GeLU, self).__init__(**kwargs)
self.approximate = approximate
self.supports_masking = True
def call(self, inputs):
cdf = 0.5 * (1.0 + K.tanh((np.sqrt(2 / np.pi) * (inputs + 0.044715 * K.pow(inputs, 3)))))
return inputs * cdf
def get_config(self):
config = {'approximate': self.approximate}
base_config = super(GeLU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def compute_output_shape(self, input_shape):
return input_shape
nn.GeLU = GeLU
"""