diff --git a/core/leras/nn.py b/core/leras/nn.py index c4420e1..76750fb 100644 --- a/core/leras/nn.py +++ b/core/leras/nn.py @@ -76,15 +76,19 @@ class nn(): if first_run: io.log_info("Caching GPU kernels...") - #import tensorflow as tf - import tensorflow.compat.v1 as tf - + import tensorflow + if tensorflow.VERSION[0] == '2': + tf = tensorflow.compat.v1 + else: + tf = tensorflow + import logging # Disable tensorflow warnings tf_logger = logging.getLogger('tensorflow') tf_logger.setLevel(logging.ERROR) - tf.disable_v2_behavior() + if tensorflow.VERSION[0] == '2': + tf.disable_v2_behavior() nn.tf = tf # Initialize framework