mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-21 14:03:19 -07:00
upgrade to tf 2.4.0rc1
This commit is contained in:
parent
8b9b907682
commit
0eb7e06ac1
6 changed files with 23 additions and 59 deletions
|
@ -70,19 +70,23 @@ class nn():
|
|||
first_run = True
|
||||
os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path)
|
||||
|
||||
os.environ['CUDA_CACHE_MAXSIZE'] = '536870912' #512Mb (32mb default)
|
||||
#nvcuda.dll ignores this param : os.environ['CUDA_CACHE_MAXSIZE'] = '536870912' #512Mb (32mb default)
|
||||
os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # tf log errors only
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # tf log errors only
|
||||
|
||||
if first_run:
|
||||
io.log_info("Caching GPU kernels...")
|
||||
|
||||
import tensorflow as tf
|
||||
nn.tf = tf
|
||||
#import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
import logging
|
||||
# Disable tensorflow warnings
|
||||
logging.getLogger('tensorflow').setLevel(logging.ERROR)
|
||||
tf_logger = logging.getLogger('tensorflow')
|
||||
tf_logger.setLevel(logging.ERROR)
|
||||
|
||||
tf.disable_v2_behavior()
|
||||
nn.tf = tf
|
||||
|
||||
# Initialize framework
|
||||
import core.leras.ops
|
||||
|
|
|
@ -333,7 +333,17 @@ def depth_to_space(x, size):
|
|||
x = tf.reshape(x, (-1, oh, ow, oc, ))
|
||||
return x
|
||||
else:
|
||||
|
||||
b,c,h,w = x.shape.as_list()
|
||||
oh, ow = h * size, w * size
|
||||
oc = c // (size * size)
|
||||
|
||||
x = tf.reshape(x, (-1, size, size, oc, h, w, ) )
|
||||
x = tf.transpose(x, (0, 3, 4, 1, 5, 2))
|
||||
x = tf.reshape(x, (-1, oc, oh, ow))
|
||||
return x
|
||||
return tf.depth_to_space(x, size, data_format=nn.data_format)
|
||||
|
||||
nn.depth_to_space = depth_to_space
|
||||
|
||||
def rgb_to_lab(srgb):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue