From 8ff34be5e4a78a4470fad941c5cacbe156fc4ce3 Mon Sep 17 00:00:00 2001 From: iperov Date: Fri, 1 Jan 2021 17:37:12 +0400 Subject: [PATCH 1/2] leras.nn : support for tf ver 1 --- core/leras/nn.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/core/leras/nn.py b/core/leras/nn.py index c4420e1..76750fb 100644 --- a/core/leras/nn.py +++ b/core/leras/nn.py @@ -76,15 +76,19 @@ class nn(): if first_run: io.log_info("Caching GPU kernels...") - #import tensorflow as tf - import tensorflow.compat.v1 as tf - + import tensorflow + if tensorflow.VERSION[0] == '2': + tf = tensorflow.compat.v1 + else: + tf = tensorflow + import logging # Disable tensorflow warnings tf_logger = logging.getLogger('tensorflow') tf_logger.setLevel(logging.ERROR) - tf.disable_v2_behavior() + if tensorflow.VERSION[0] == '2': + tf.disable_v2_behavior() nn.tf = tf # Initialize framework From 4f2efd7985b1bd25ffd96a53f05c5889038f36df Mon Sep 17 00:00:00 2001 From: iperov Date: Fri, 1 Jan 2021 17:59:57 +0400 Subject: [PATCH 2/2] fix support for v1/v2 --- core/leras/nn.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/core/leras/nn.py b/core/leras/nn.py index 76750fb..ef5c2c9 100644 --- a/core/leras/nn.py +++ b/core/leras/nn.py @@ -77,7 +77,14 @@ class nn(): io.log_info("Caching GPU kernels...") import tensorflow - if tensorflow.VERSION[0] == '2': + + tf_version = getattr(tensorflow,'VERSION', None) + if tf_version is None: + tf_version = tensorflow.version.GIT_VERSION + if tf_version[0] == 'v': + tf_version = tf_version[1:] + + if tf_version[0] == '2': tf = tensorflow.compat.v1 else: tf = tensorflow @@ -87,7 +94,7 @@ class nn(): tf_logger = logging.getLogger('tensorflow') tf_logger.setLevel(logging.ERROR) - if tensorflow.VERSION[0] == '2': + if tf_version[0] == '2': tf.disable_v2_behavior() nn.tf = tf