mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
update leras
This commit is contained in:
parent
a1cc297cef
commit
f07bbd5fb0
1 changed files with 26 additions and 13 deletions
|
@ -117,9 +117,27 @@ def initialize_tensor_ops(nn):
|
|||
return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
|
||||
nn.tf_upsample2d = tf_upsample2d
|
||||
|
||||
def tf_upsample2d_bilinear(x, size=2):
|
||||
return tf.image.resize_images(x, (x.shape[1]*size, x.shape[2]*size) )
|
||||
nn.tf_upsample2d_bilinear = tf_upsample2d_bilinear
|
||||
def tf_resize2d_bilinear(x, size=2):
|
||||
h = x.shape[nn.conv2d_spatial_axes[0]].value
|
||||
w = x.shape[nn.conv2d_spatial_axes[1]].value
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
x = tf.transpose(x, (0,2,3,1))
|
||||
|
||||
if size > 0:
|
||||
new_size = (h*size,w*size)
|
||||
else:
|
||||
new_size = (h//-size,w//-size)
|
||||
|
||||
x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.BILINEAR)
|
||||
|
||||
if nn.data_format == "NCHW":
|
||||
x = tf.transpose(x, (0,3,1,2))
|
||||
|
||||
return x
|
||||
nn.tf_resize2d_bilinear = tf_resize2d_bilinear
|
||||
|
||||
|
||||
|
||||
def tf_flatten(x):
|
||||
if nn.data_format == "NHWC":
|
||||
|
@ -181,16 +199,11 @@ def initialize_tensor_ops(nn):
|
|||
padding = None
|
||||
gauss_kernel = gauss_kernel[:,:,None,None]
|
||||
|
||||
outputs = []
|
||||
for i in range(input.shape[nn.conv2d_ch_axis]):
|
||||
x = input[:,:,:,i:i+1] if nn.data_format == "NHWC" \
|
||||
else input[:,i:i+1,:,:]
|
||||
|
||||
if padding is not None:
|
||||
x = tf.pad (x, padding)
|
||||
outputs += [ tf.nn.conv2d(x, tf.constant(gauss_kernel, dtype=input.dtype ), strides=[1,1,1,1], padding="VALID", data_format=nn.data_format) ]
|
||||
|
||||
return tf.concat (outputs, axis=nn.conv2d_ch_axis)
|
||||
x = input
|
||||
k = tf.tile (gauss_kernel, (1,1,x.shape[nn.conv2d_ch_axis],1) )
|
||||
x = tf.pad(x, padding )
|
||||
x = tf.nn.depthwise_conv2d(x, k, strides=[1,1,1,1], padding='VALID', data_format=nn.data_format)
|
||||
return x
|
||||
nn.tf_gaussian_blur = tf_gaussian_blur
|
||||
|
||||
def tf_style_loss(target, style, gaussian_blur_radius=0.0, loss_weight=1.0, step_size=1):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue