diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index a4f59d6..711edb8 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -1,15 +1,18 @@ +import contextlib +import multiprocessing import os import sys -import contextlib +from pathlib import Path + import numpy as np -from .CAInitializer import CAGenerateWeights -import multiprocessing -from joblib import Subprocessor - -from utils import std_utils -from .device import device from interact import interact as io +from joblib import Subprocessor +from utils import std_utils + +from .CAInitializer import CAGenerateWeights +from .device import device + class nnlib(object): device = device #forwards nnlib.devicelib to device in order to use nnlib as standalone lib @@ -173,6 +176,11 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator os.environ.pop('CUDA_VISIBLE_DEVICES') os.environ['CUDA_​CACHE_​MAXSIZE'] = '536870912' #512Mb (32mb default) + + if sys.platform[0:3] == 'win': + if len(device_config.gpu_idxs) == 1: + os.environ['CUDA_CACHE_PATH'] = \ + str(Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache_' + device_config.gpu_names[0].replace(' ','_'))) os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #tf log errors only