Skip to content

Commit

Permalink
[MPSInductor] Implement bitcasts (pytorch#144638)
Browse files Browse the repository at this point in the history
That will be used to compile something like `torch.rand(32, device='mps').view(dtype=torch.int32)`

Pull Request resolved: pytorch#144638
Approved by: https://github.com/dcci
  • Loading branch information
malfet authored and pytorchmergebot committed Jan 12, 2025
1 parent 32a91de commit cec2458
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/inductor/test_mps_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class MPSBasicTests(TestCase):
test_tanh = CommonTemplate.test_tanh
test_view_as_complex = CommonTemplate.test_view_as_complex
test_views6 = CommonTemplate.test_views6
test_views7 = CommonTemplate.test_views7
test_zero_dim_reductions = CommonTemplate.test_zero_dim_reductions

@parametrize("dtype", MPS_DTYPES)
Expand Down
6 changes: 6 additions & 0 deletions torch/_inductor/codegen/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def to_dtype(
) -> str:
return f"static_cast<{DTYPE_TO_METAL[dtype]}>({x})"

@staticmethod
def to_dtype_bitcast(
x: CSEVariable, dtype: torch.dtype, src_dtype: torch.dtype
) -> str:
return f"*reinterpret_cast<thread {DTYPE_TO_METAL[dtype]}*>(&{x})"

@staticmethod
def constant(val: CSEVariable, dtype: torch.dtype) -> str:
return value_to_metal(val)
Expand Down

0 comments on commit cec2458

Please sign in to comment.