From 6a68dd5e5954f0ba394494d5dbf913b73b45b6d3 Mon Sep 17 00:00:00 2001 From: jh Date: Fri, 12 Mar 2021 18:36:41 -0800 Subject: [PATCH] tf2 --- core/leras/layers/MsSsim.py | 13 +++---------- core/leras/nn.py | 19 +++++++++++-------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/core/leras/layers/MsSsim.py b/core/leras/layers/MsSsim.py index b631491..1f7ad6d 100644 --- a/core/leras/layers/MsSsim.py +++ b/core/leras/layers/MsSsim.py @@ -1,5 +1,6 @@ from core.leras import nn tf = nn.tf +tf2 = nn.tf2 class MsSsim(nn.LayerBase): default_power_factors = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333) @@ -20,15 +21,7 @@ class MsSsim(nn.LayerBase): y_true_t = tf.transpose(tf.cast(y_true, tf.float32), [0, 2, 3, 1]) y_pred_t = tf.transpose(tf.cast(y_pred, tf.float32), [0, 2, 3, 1]) - - def assign_device(op): - if op.type != 'Assert' or op.type != 'ListDiff': - return '/gpu:0' - else: - return '/cpu:0' - - with tf.device(assign_device): - loss = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors) - return (1.0 - loss) / 2.0 + loss = tf2.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors) + return (1.0 - loss) / 2.0 nn.MsSsim = MsSsim diff --git a/core/leras/nn.py b/core/leras/nn.py index ef5c2c9..8b10fba 100644 --- a/core/leras/nn.py +++ b/core/leras/nn.py @@ -31,6 +31,7 @@ class nn(): current_DeviceConfig = None tf = None + tf2 = None tf_sess = None tf_sess_config = None tf_default_device = None @@ -40,7 +41,7 @@ class nn(): conv2d_spatial_axes = None floatx = None - + @staticmethod def initialize(device_config=None, floatx="float32", data_format="NHWC"): @@ -50,7 +51,7 @@ class nn(): nn.setCurrentDeviceConfig(device_config) # Manipulate environment variables before import tensorflow - + if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): os.environ.pop('CUDA_VISIBLE_DEVICES') @@ -77,15 +78,17 @@ class nn(): io.log_info("Caching GPU kernels...") import tensorflow - + tf_version = getattr(tensorflow,'VERSION', None) if tf_version is None: tf_version = tensorflow.version.GIT_VERSION if tf_version[0] == 'v': tf_version = tf_version[1:] - + if tf_version[0] == '2': tf = tensorflow.compat.v1 + tf2 = tensorflow + nn.tf2 = tf2 else: tf = tensorflow @@ -93,7 +96,7 @@ class nn(): # Disable tensorflow warnings tf_logger = logging.getLogger('tensorflow') tf_logger.setLevel(logging.ERROR) - + if tf_version[0] == '2': tf.disable_v2_behavior() nn.tf = tf @@ -105,7 +108,7 @@ class nn(): import core.leras.optimizers import core.leras.models import core.leras.archis - + # Configure tensorflow session-config if len(device_config.devices) == 0: nn.tf_default_device = "/CPU:0" @@ -118,7 +121,7 @@ class nn(): config.gpu_options.force_gpu_compatible = True config.gpu_options.allow_growth = True nn.tf_sess_config = config - + if nn.tf_sess is None: nn.tf_sess = tf.Session(config=nn.tf_sess_config) @@ -273,7 +276,7 @@ class nn(): @staticmethod def ask_choose_device(*args, **kwargs): return nn.DeviceConfig.GPUIndexes( nn.ask_choose_device_idxs(*args,**kwargs) ) - + def __init__ (self, devices=None): devices = devices or []