Skip to content

Commit

Permalink
torch/accelerator: fix device type comparison (#143541)
Browse files Browse the repository at this point in the history
This was failing without the fix:
```
python -c 'import torch; d=torch.device("xpu:0"); torch.accelerator.current_stream(d)'
```
with:
```
ValueError: xpu doesn't match the current accelerator xpu.
```

CC: @guangyey, @EikanWang

Pull Request resolved: #143541
Approved by: https://github.com/guangyey, https://github.com/albanD
  • Loading branch information
dvrogozh authored and pytorchmergebot committed Dec 23, 2024
1 parent 434e0c2 commit 7314cf4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
11 changes: 11 additions & 0 deletions test/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ def test_generic_stream_behavior(self):
self.assertTrue(event.query())
self.assertEqual(c_acc.cpu(), c)

def test_current_stream_query(self):
s = torch.accelerator.current_stream()
self.assertEqual(torch.accelerator.current_stream(s.device), s)
self.assertEqual(torch.accelerator.current_stream(s.device.index), s)
self.assertEqual(torch.accelerator.current_stream(str(s.device)), s)
other_device = torch.device("cpu")
with self.assertRaisesRegex(
ValueError, "doesn't match the current accelerator"
):
torch.accelerator.current_stream(other_device)


if __name__ == "__main__":
run_tests()
2 changes: 1 addition & 1 deletion torch/accelerator/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def _get_device_index(device: _device_t, optional: bool = False) -> int:
device = torch.device(device)
device_index: Optional[int] = None
if isinstance(device, torch.device):
if torch.accelerator.current_accelerator() != device.type:
if torch.accelerator.current_accelerator().type != device.type:
raise ValueError(
f"{device.type} doesn't match the current accelerator {torch.accelerator.current_accelerator()}."
)
Expand Down

0 comments on commit 7314cf4

Please sign in to comment.