diff --git a/core/leras/layers/MsSsim.py b/core/leras/layers/MsSsim.py index fff5f74..84028bb 100644 --- a/core/leras/layers/MsSsim.py +++ b/core/leras/layers/MsSsim.py @@ -19,7 +19,16 @@ class MsSsim(nn.LayerBase): # Transpose images from NCHW to NHWC y_true_t = tf.transpose(tf.cast(y_true, tf.float32), [0, 2, 3, 1]) y_pred_t = tf.transpose(tf.cast(y_pred, tf.float32), [0, 2, 3, 1]) - loss = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors) - return (1.0 - loss) / 2.0 + + + def assign_device(op): + if op.type != 'Assert': + return '/gpu:0' + else: + return '/cpu:0' + + with tf.device(assign_device): + loss = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors) + return (1.0 - loss) / 2.0 nn.MsSsim = MsSsim