diff --git a/core/leras/nn.py b/core/leras/nn.py index 07cbdf6..7c28874 100644 --- a/core/leras/nn.py +++ b/core/leras/nn.py @@ -107,7 +107,7 @@ class nn(): else: nn.tf_default_device_name = f'/{device_config.devices[0].tf_dev_type}:0' - config = tf.ConfigProto() + config = tf.ConfigProto(allow_soft_placement=True) config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices]) config.gpu_options.force_gpu_compatible = True