Try assign device

This commit is contained in:
jh 2021-03-12 16:00:20 -08:00
commit d35549fc5d

View file

@ -329,8 +329,16 @@ def ms_ssim(img1, img2, resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=
# Transpose images from NCHW to NHWC
img1_t = tf.transpose(img1, [0, 2, 3, 1])
img2_t = tf.transpose(img2, [0, 2, 3, 1])
ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors,
filter_size=kernel_size, k1=k1, k2=k2)
def assign_device(op):
if op.type != 'ListDiff':
return '/gpu:0'
else:
return '/cpu:0'
with tf.device(assign_device):
ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors,
filter_size=kernel_size, k1=k1, k2=k2)
ms_ssim_loss = (1.0 - ms_ssim_val) / 2.0
if img_dtype != tf.float32: