Skip to content

Commit cbe9530

Browse files
committed
Add Python Module interface for MPS backend (#251)
- Enable global manual seeding via torch.manual_seed(), with a test case.
- Add torch.mps.synchronize() to wait for the MPS stream to finish, with a test case.
- Enable the following Python interfaces for MPS: torch.mps.get_rng_state(), torch.mps.set_rng_state(), torch.mps.is_available(), torch.mps.synchronize(), torch.mps.manual_seed(), torch.mps.seed(), torch.mps.is_initialized(), torch.mps.init().
1 parent 4e984cb commit cbe9530

File tree

14 files changed

+301
-16
lines changed

14 files changed

+301
-16
lines changed

aten/src/ATen/detail/MPSHooksInterface.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,21 @@ struct TORCH_API MPSHooksInterface {
2828
return false;
2929
}
3030

31+
virtual bool isOnMacOS13orNewer() const {
32+
return false;
33+
}
34+
3135
virtual const Generator& getDefaultMPSGenerator() const {
3236
AT_ERROR("Cannot get default MPS generator without MPS backend.");
3337
}
3438

3539
virtual Allocator* getMPSDeviceAllocator() const {
3640
AT_ERROR("MPSDeviceAllocator requires MPS.");
3741
}
42+
43+
virtual void deviceSynchronize() const {
44+
TORCH_CHECK(false, "Cannot synchronize MPS device without MPS backend. ");
45+
}
3846
};
3947

4048
struct TORCH_API MPSHooksArgs {};

aten/src/ATen/mps/MPSDevice.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class TORCH_API MPSDevice {
7272

7373
TORCH_API bool is_available();
7474
TORCH_API bool is_macos_13_or_newer();
75-
75+
TORCH_API void device_synchronize();
7676
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
7777

7878
} // namespace mps

aten/src/ATen/mps/MPSDevice.mm

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <c10/util/CallOnce.h>
44

55
#include <ATen/mps/MPSDevice.h>
6+
#include <ATen/mps/MPSStream.h>
67
#include <ATen/mps/MPSAllocatorInterface.h>
78
#include <ATen/mps/IndexKernels.h>
89

@@ -107,5 +108,9 @@ bool is_macos_13_or_newer() {
107108
return MPSDevice::getInstance()->isMacOS13Plus();
108109
}
109110

111+
void device_synchronize() {
112+
getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
113+
}
114+
110115
} // namespace mps
111116
} // namespace at

aten/src/ATen/mps/MPSHooks.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ bool MPSHooks::hasMPS() const {
1616
return at::mps::is_available();
1717
}
1818

19+
bool MPSHooks::isOnMacOS13orNewer() const {
20+
return at::mps::is_macos_13_or_newer();
21+
}
22+
1923
Allocator* MPSHooks::getMPSDeviceAllocator() const {
2024
return at::mps::GetMPSAllocator();
2125
}
@@ -24,6 +28,10 @@ const Generator& MPSHooks::getDefaultMPSGenerator() const {
2428
return at::mps::detail::getDefaultMPSGenerator();
2529
}
2630

31+
void MPSHooks::deviceSynchronize() const {
32+
at::mps::device_synchronize();
33+
}
34+
2735
using at::MPSHooksRegistry;
2836
using at::RegistererMPSHooksRegistry;
2937

aten/src/ATen/mps/MPSHooks.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ struct MPSHooks : public at::MPSHooksInterface {
1313
MPSHooks(at::MPSHooksArgs) {}
1414
void initMPS() const override;
1515
bool hasMPS() const override;
16+
bool isOnMacOS13orNewer() const override;
1617
Allocator* getMPSDeviceAllocator() const override;
1718
const Generator& getDefaultMPSGenerator() const override;
19+
void deviceSynchronize() const override;
1820
};
1921

2022
}} // at::mps

build_variables.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,7 @@ torch_cpp_srcs = [
715715
"torch/csrc/api/src/imethod.cpp",
716716
"torch/csrc/api/src/jit.cpp",
717717
"torch/csrc/api/src/serialize.cpp",
718+
"torch/csrc/api/src/mps.cpp",
718719
"torch/csrc/api/src/nn/init.cpp",
719720
"torch/csrc/api/src/nn/module.cpp",
720721
"torch/csrc/api/src/nn/modules/_functions.cpp",
@@ -821,6 +822,7 @@ libtorch_python_core_sources = [
821822
"torch/csrc/dynamo/guards.cpp",
822823
"torch/csrc/dynamo/init.cpp",
823824
"torch/csrc/functorch/init.cpp",
825+
"torch/csrc/mps/Module.cpp",
824826
"torch/csrc/jit/backends/backend_init.cpp",
825827
"torch/csrc/jit/python/init.cpp",
826828
"torch/csrc/jit/passes/onnx.cpp",

test/test_mps.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from torch.testing import make_tensor
2626
from torch.testing._comparison import TensorLikePair
2727
from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
28+
import torch.mps
2829
import torch.backends.mps
2930
from torch.distributions import Uniform, Exponential
3031
from functools import partial
@@ -5741,6 +5742,45 @@ def test_mps_generator(self):
57415742
mps_x = torch.randn(5, device='mps', generator=g_mps)
57425743
self.assertEqual(mps_x, mps_y)
57435744

5745+
def test_default_mps_generator(self):
5746+
# manual seeding on the "default" MPS generator using
5747+
# the global torch.manual_seed()
5748+
torch.manual_seed(230)
5749+
mps_x = torch.randn(5, device='mps')
5750+
# manual seeding using torch.mps.manual_seed()
5751+
# which should set the "default" MPS generator
5752+
# like the global torch.manual_seed()
5753+
torch.mps.manual_seed(230)
5754+
mps_y = torch.randn(5, device='mps')
5755+
# seed values were the same, so the random tensor contents should match
5756+
self.assertEqual(mps_x, mps_y)
5757+
5758+
# save the default generator's state to restore it later
5759+
g_state = torch.mps.get_rng_state()
5760+
5761+
# generate random numbers without seeding
5762+
mps_x = torch.randn(5, device='mps')
5763+
# in this case, the random results must differ from the last generated random results
5764+
self.assertNotEqual(mps_x, mps_y)
5765+
5766+
# restore the previously saved state, and the results should match again
5767+
torch.mps.set_rng_state(g_state)
5768+
mps_x = torch.randn(5, device='mps')
5769+
self.assertEqual(mps_x, mps_y)
5770+
5771+
def test_device_synchronize(self):
5772+
# just running some ops each followed by a synchronize to wait for
5773+
# MPS stream to finish running each of them
5774+
net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
5775+
.to(device='mps', dtype=torch.float)
5776+
5777+
x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
5778+
torch.mps.synchronize()
5779+
x = net1(x)
5780+
torch.mps.synchronize()
5781+
x.backward(torch.randn_like(x))
5782+
torch.mps.synchronize()
5783+
57445784
# Test random_.to and random_.from_int
57455785
def test_random(self):
57465786
def helper(shape, low, high, dtype=torch.int32):

torch/_C/__init__.pyi.in

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -904,8 +904,6 @@ def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: T
904904
def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function
905905
def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
906906
def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
907-
def _is_mps_available() -> _bool: ...
908-
def _is_mps_on_macos_13_or_newer() -> _bool: ...
909907
class _LinalgBackend:
910908
Default: _LinalgBackend
911909
Cusolver: _LinalgBackend
@@ -1201,6 +1199,12 @@ class _TensorBase(metaclass=_TensorMeta):
12011199
# Defined in torch/csrc/multiprocessing/init.cpp
12021200
def _multiprocessing_init() -> None: ...
12031201

1202+
# Defined in torch/csrc/mps/Module.cpp
1203+
def _mps_synchronize() -> None: ...
1204+
def _mps_init() -> None: ...
1205+
def _is_mps_available() -> _bool: ...
1206+
def _is_mps_on_macos_13_or_newer() -> _bool: ...
1207+
12041208
# Defined in torch/csrc/cuda/Module.cpp
12051209
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
12061210
def _cuda_getCurrentRawStream(device: _int) -> _int: ...

torch/csrc/Module.cpp

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,6 @@
8787
#endif
8888
#endif
8989

90-
#if defined(USE_MPS)
91-
#include <ATen/mps/MPSDevice.h>
92-
#endif
93-
9490
#if defined(USE_VALGRIND)
9591
#include <callgrind.h>
9692
#endif
@@ -1219,6 +1215,10 @@ void initIttBindings(PyObject* module);
12191215
} // namespace torch
12201216
#endif
12211217

1218+
#ifdef USE_MPS
1219+
PyMethodDef* MPSModule_methods();
1220+
#endif
1221+
12221222
namespace torch {
12231223
void initVerboseBindings(PyObject* module);
12241224
} // namespace torch
@@ -1274,6 +1274,9 @@ PyObject* initModule() {
12741274
#ifdef USE_CUDA
12751275
THPUtils_addPyMethodDefs(methods, THCPModule_methods());
12761276
#endif
1277+
#ifdef USE_MPS
1278+
THPUtils_addPyMethodDefs(methods, MPSModule_methods());
1279+
#endif
12771280
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
12781281
THPUtils_addPyMethodDefs(
12791282
methods, torch::distributed::c10d::python_functions());
@@ -1593,15 +1596,6 @@ Call this whenever a new thread is created in order to propagate values from
15931596

15941597
ASSERT_TRUE(set_module_attr("has_cuda", has_cuda));
15951598
ASSERT_TRUE(set_module_attr("has_mps", has_mps));
1596-
py_module.def("_is_mps_available", []() { return at::hasMPS(); });
1597-
py_module.def("_is_mps_on_macos_13_or_newer", []() {
1598-
#ifdef USE_MPS
1599-
return at::mps::is_macos_13_or_newer();
1600-
#else
1601-
return false;
1602-
#endif
1603-
});
1604-
16051599
ASSERT_TRUE(
16061600
set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));
16071601

torch/csrc/api/include/torch/mps.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#pragma once
2+
3+
#include <torch/csrc/Export.h>
4+
5+
#include <cstddef>
6+
#include <cstdint>
7+
8+
namespace torch {
9+
namespace mps {
10+
11+
/// Returns true if MPS device is available.
12+
bool TORCH_API is_available();
13+
14+
/// Sets the seed for the current GPU.
15+
void TORCH_API manual_seed(uint64_t seed);
16+
17+
/// Waits for all streams on a MPS device to complete.
18+
void TORCH_API synchronize();
19+
20+
} // namespace mps
21+
} // namespace torch

0 commit comments

Comments (0)