Skip to content

Commit 2ddc347

Browse files
committed
refine _DEVICES
1 parent 64e431c commit 2ddc347

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

test/prototype/test_quantized_training.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,8 @@
3939
if common_utils.SEED is None:
4040
common_utils.SEED = 1234
4141

42-
_DEVICES = (
43-
["cpu"]
44-
+ (["cuda"] if torch.cuda.is_available() else [])
45-
+ (["xpu"] if torch.xpu.is_available() else [])
46-
)
4742
_DEVICE = get_current_accelerator_device()
43+
_DEVICES = ["cpu"] + ([_DEVICE] if torch.accelerator.is_available() else [])
4844

4945

5046
def _reset():

0 commit comments

Comments
 (0)