From 6c42cd3a8f21ab7052576ef8f6c932a07101e269 Mon Sep 17 00:00:00 2001 From: jh Date: Fri, 12 Mar 2021 18:11:05 -0800 Subject: [PATCH] fix cast --- core/leras/layers/MsSsim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/leras/layers/MsSsim.py b/core/leras/layers/MsSsim.py index 6fa7c69..fff5f74 100644 --- a/core/leras/layers/MsSsim.py +++ b/core/leras/layers/MsSsim.py @@ -17,8 +17,8 @@ class MsSsim(nn.LayerBase): def __call__(self, y_true, y_pred, max_val): # Transpose images from NCHW to NHWC - y_true_t = tf.transpose(tf.cast(y_true), [0, 2, 3, 1]) - y_pred_t = tf.transpose(tf.cast(y_pred), [0, 2, 3, 1]) + y_true_t = tf.transpose(tf.cast(y_true, tf.float32), [0, 2, 3, 1]) + y_pred_t = tf.transpose(tf.cast(y_pred, tf.float32), [0, 2, 3, 1]) loss = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors) return (1.0 - loss) / 2.0