mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 14:24:40 -07:00
Try assign device
This commit is contained in:
parent
3588b0efa9
commit
d35549fc5d
1 changed files with 10 additions and 2 deletions
|
@ -329,6 +329,14 @@ def ms_ssim(img1, img2, resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=
|
||||||
# Transpose images from NCHW to NHWC
|
# Transpose images from NCHW to NHWC
|
||||||
img1_t = tf.transpose(img1, [0, 2, 3, 1])
|
img1_t = tf.transpose(img1, [0, 2, 3, 1])
|
||||||
img2_t = tf.transpose(img2, [0, 2, 3, 1])
|
img2_t = tf.transpose(img2, [0, 2, 3, 1])
|
||||||
|
|
||||||
|
def assign_device(op):
|
||||||
|
if op.type != 'ListDiff':
|
||||||
|
return '/gpu:0'
|
||||||
|
else:
|
||||||
|
return '/cpu:0'
|
||||||
|
|
||||||
|
with tf.device(assign_device):
|
||||||
ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors,
|
ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors,
|
||||||
filter_size=kernel_size, k1=k1, k2=k2)
|
filter_size=kernel_size, k1=k1, k2=k2)
|
||||||
ms_ssim_loss = (1.0 - ms_ssim_val) / 2.0
|
ms_ssim_loss = (1.0 - ms_ssim_val) / 2.0
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue