fix for plaidML

This commit is contained in:
iperov 2019-05-24 09:30:39 +04:00
parent 3114ae9d7b
commit c9da4cd89b
2 changed files with 6 additions and 4 deletions

View file

@ -161,7 +161,6 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
return nnlib.code_import_keras return nnlib.code_import_keras
nnlib.backend = device_config.backend nnlib.backend = device_config.backend
if "tensorflow" in nnlib.backend: if "tensorflow" in nnlib.backend:
nnlib._import_tf(device_config) nnlib._import_tf(device_config)
elif nnlib.backend == "plaidML": elif nnlib.backend == "plaidML":
@ -174,6 +173,9 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
import keras as keras_ import keras as keras_
nnlib.keras = keras_ nnlib.keras = keras_
if 'KERAS_BACKEND' in os.environ:
os.environ.pop('KERAS_BACKEND')
if nnlib.backend == "plaidML": if nnlib.backend == "plaidML":
import plaidml import plaidml
import plaidml.tile import plaidml.tile