added AMD/Intel cards support via DirectX12 ( DirectML backend )

This commit is contained in:
iperov 2021-04-22 18:19:15 +04:00
commit fdb143ff47
7 changed files with 166 additions and 116 deletions

View file

@ -33,8 +33,8 @@ class nn():
tf = None
tf_sess = None
tf_sess_config = None
tf_default_device = None
tf_default_device_name = None
data_format = None
conv2d_ch_axis = None
conv2d_spatial_axes = None
@ -50,9 +50,6 @@ class nn():
nn.setCurrentDeviceConfig(device_config)
# Manipulate environment variables before import tensorflow
if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
os.environ.pop('CUDA_VISIBLE_DEVICES')
first_run = False
if len(device_config.devices) != 0:
@ -68,22 +65,19 @@ class nn():
compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache' + devices_str)
if not compute_cache_path.exists():
first_run = True
compute_cache_path.mkdir(parents=True, exist_ok=True)
os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path)
os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # tf log errors only
if first_run:
io.log_info("Caching GPU kernels...")
import tensorflow
tf_version = getattr(tensorflow,'VERSION', None)
if tf_version is None:
tf_version = tensorflow.version.GIT_VERSION
if tf_version[0] == 'v':
tf_version = tf_version[1:]
tf_version = tensorflow.version.VERSION
#if tf_version is None:
# tf_version = tensorflow.version.GIT_VERSION
if tf_version[0] == 'v':
tf_version = tf_version[1:]
if tf_version[0] == '2':
tf = tensorflow.compat.v1
else:
@ -108,13 +102,14 @@ class nn():
# Configure tensorflow session-config
if len(device_config.devices) == 0:
nn.tf_default_device = "/CPU:0"
config = tf.ConfigProto(device_count={'GPU': 0})
nn.tf_default_device_name = '/CPU:0'
else:
nn.tf_default_device = "/GPU:0"
nn.tf_default_device_name = f'/{device_config.devices[0].tf_dev_type}:0'
config = tf.ConfigProto()
config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices])
config.gpu_options.force_gpu_compatible = True
config.gpu_options.allow_growth = True
nn.tf_sess_config = config
@ -202,14 +197,6 @@ class nn():
nn.tf_sess.close()
nn.tf_sess = None
@staticmethod
def get_current_device():
# Undocumented access to last tf.device(...)
objs = nn.tf.get_default_graph()._device_function_stack.peek_objs()
if len(objs) != 0:
return objs[0].display_name
return nn.tf_default_device
@staticmethod
def ask_choose_device_idxs(choose_only_one=False, allow_cpu=True, suggest_best_multi_gpu=False, suggest_all_gpu=False):
devices = Devices.getDevices()