Skip to content

Commit

Permalink
update PR template, docs
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudiaComito committed Dec 12, 2024
1 parent 103dc7e commit 80a867e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- Implementation:
- [ ] unit tests: all split configurations tested
- [ ] unit tests: multiple dtypes tested
- [ ] **NEW** unit tests: MPS tested (1 MPI process, 1 GPU)
- [ ] benchmarks: created for new functionality
- [ ] benchmarks: performance improved or maintained
- [ ] documentation updated where needed
Expand Down
6 changes: 4 additions & 2 deletions heat/core/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class Device:
device(cpu:0)
>>> ht.Device("gpu", 0, "cuda:0")
device(gpu:0)
>>> ht.Device("gpu", 0, "mps:0") # on Apple M1/M2
device(gpu:0)
"""

def __init__(self, device_type: str, device_id: int, torch_device: str):
Expand Down Expand Up @@ -134,12 +136,12 @@ def __eq__(self, other: Any) -> bool:
__all__.append("gpu")

elif torch.backends.mps.is_built() and torch.backends.mps.is_available():
# Apple Metal Performance Shaders (MPS) available
# Apple MPS available
gpu_id = 0
# create a new GPU device
gpu = Device("gpu", gpu_id, "mps:{}".format(gpu_id))
"""
The standard GPU Device on Apple M1
The standard GPU Device on Apple M1/M2
Examples
--------
Expand Down

0 comments on commit 80a867e

Please sign in to comment.