diff --git a/core/leras/nn.py b/core/leras/nn.py index d9c5850..8ac437b 100644 --- a/core/leras/nn.py +++ b/core/leras/nn.py @@ -112,7 +112,7 @@ class nn(): config = tf.ConfigProto(device_count={'GPU': 0}) else: nn.tf_default_device = "/GPU:0" - config = tf.ConfigProto() + config = tf.ConfigProto(allow_soft_placement=True) config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices]) config.gpu_options.force_gpu_compatible = True