diff --git a/core/leras/layers/MsSsim.py b/core/leras/layers/MsSsim.py index 84028bb..b631491 100644 --- a/core/leras/layers/MsSsim.py +++ b/core/leras/layers/MsSsim.py @@ -22,7 +22,7 @@ class MsSsim(nn.LayerBase): def assign_device(op): - if op.type != 'Assert': + if op.type != 'Assert' or op.type != 'ListDiff': return '/gpu:0' else: return '/cpu:0'