From beb8c30970b310ce1fac421e1448726451a80845 Mon Sep 17 00:00:00 2001 From: jh Date: Fri, 12 Mar 2021 18:27:00 -0800 Subject: [PATCH] assign device --- core/leras/layers/MsSsim.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/core/leras/layers/MsSsim.py b/core/leras/layers/MsSsim.py index fff5f74..84028bb 100644 --- a/core/leras/layers/MsSsim.py +++ b/core/leras/layers/MsSsim.py @@ -19,7 +19,16 @@ class MsSsim(nn.LayerBase): # Transpose images from NCHW to NHWC y_true_t = tf.transpose(tf.cast(y_true, tf.float32), [0, 2, 3, 1]) y_pred_t = tf.transpose(tf.cast(y_pred, tf.float32), [0, 2, 3, 1]) - loss = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors) - return (1.0 - loss) / 2.0 + + + def assign_device(op): + if op.type != 'Assert': + return '/gpu:0' + else: + return '/cpu:0' + + with tf.device(assign_device): + loss = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors) + return (1.0 - loss) / 2.0 nn.MsSsim = MsSsim