diff --git a/core/leras/ops/__init__.py b/core/leras/ops/__init__.py index 583a67d..b460f17 100644 --- a/core/leras/ops/__init__.py +++ b/core/leras/ops/__init__.py @@ -329,8 +329,16 @@ def ms_ssim(img1, img2, resolution, kernel_size=11, k1=0.01, k2=0.03, max_value= # Transpose images from NCHW to NHWC img1_t = tf.transpose(img1, [0, 2, 3, 1]) img2_t = tf.transpose(img2, [0, 2, 3, 1]) - ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors, - filter_size=kernel_size, k1=k1, k2=k2) + + def assign_device(op): + if op.type != 'ListDiff': + return '/gpu:0' + else: + return '/cpu:0' + + with tf.device(assign_device): + ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors, + filter_size=kernel_size, k1=k1, k2=k2) ms_ssim_loss = (1.0 - ms_ssim_val) / 2.0 if img_dtype != tf.float32: