Commit 83d6d6d

Add Python Module interface for MPS backend (#251)
- Enable global manual seeding via torch.manual_seed() + test case
- Add torch.mps.synchronize() to wait for the MPS stream to finish + test case
- Enable the following Python interfaces for MPS:
  torch.mps.get_rng_state(), torch.mps.set_rng_state(), torch.mps.is_available(),
  torch.mps.synchronize(), torch.mps.manual_seed(), torch.mps.seed(),
  torch.mps.is_initialized(), torch.mps.init()
1 parent 889a464 commit 83d6d6d
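
For a quick feel of the new surface, here is a minimal usage sketch (not part of the commit; it assumes an MPS-capable macOS build of this branch):

    import torch

    if torch.mps.is_available():
        torch.mps.manual_seed(42)            # seed the default MPS generator
        state = torch.mps.get_rng_state()    # snapshot the generator state
        x = torch.randn(3, device="mps")
        torch.mps.set_rng_state(state)       # restore the snapshot...
        y = torch.randn(3, device="mps")     # ...so the next draw repeats
        assert torch.equal(x, y)
        torch.mps.synchronize()              # block until queued MPS work finishes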

File tree: 14 files changed, +300 −16 lines

aten/src/ATen/detail/MPSHooksInterface.h

Lines changed: 8 additions & 0 deletions
@@ -28,13 +28,21 @@ struct TORCH_API MPSHooksInterface {
     return false;
   }

+  virtual bool isOnMacOS13orNewer() const {
+    return false;
+  }
+
   virtual const Generator& getDefaultMPSGenerator() const {
     AT_ERROR("Cannot get default MPS generator without MPS backend.");
   }

   virtual Allocator* getMPSDeviceAllocator() const {
     AT_ERROR("MPSDeviceAllocator requires MPS.");
   }
+
+  virtual void deviceSynchronize() const {
+    TORCH_CHECK(false, "Cannot synchronize MPS device without MPS backend. ");
+  }
 };

 struct TORCH_API MPSHooksArgs {};

aten/src/ATen/mps/MPSDevice.h

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ class TORCH_API MPSDevice {

 TORCH_API bool is_available();
 TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
-
+TORCH_API void device_synchronize();
 TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);

 } // namespace mps

aten/src/ATen/mps/MPSDevice.mm

Lines changed: 5 additions & 0 deletions
@@ -3,6 +3,7 @@
 #include <c10/util/CallOnce.h>

 #include <ATen/mps/MPSDevice.h>
+#include <ATen/mps/MPSStream.h>
 #include <ATen/mps/MPSAllocatorInterface.h>
 #include <ATen/mps/IndexKernels.h>

@@ -118,5 +119,9 @@ bool is_macos_13_or_newer(MacOSVersion version) {
   return MPSDevice::getInstance()->isMacOS13Plus(version);
 }

+void device_synchronize() {
+  getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
+}
+
 } // namespace mps
 } // namespace at

aten/src/ATen/mps/MPSHooks.cpp

Lines changed: 8 additions & 0 deletions
@@ -16,6 +16,10 @@ bool MPSHooks::hasMPS() const {
   return at::mps::is_available();
 }

+bool MPSHooks::isOnMacOS13orNewer() const {
+  return at::mps::is_macos_13_or_newer();
+}
+
 Allocator* MPSHooks::getMPSDeviceAllocator() const {
   return at::mps::GetMPSAllocator();
 }
@@ -24,6 +28,10 @@ const Generator& MPSHooks::getDefaultMPSGenerator() const {
   return at::mps::detail::getDefaultMPSGenerator();
 }

+void MPSHooks::deviceSynchronize() const {
+  at::mps::device_synchronize();
+}
+
 using at::MPSHooksRegistry;
 using at::RegistererMPSHooksRegistry;


aten/src/ATen/mps/MPSHooks.h

Lines changed: 2 additions & 0 deletions
@@ -13,8 +13,10 @@ struct MPSHooks : public at::MPSHooksInterface {
   MPSHooks(at::MPSHooksArgs) {}
   void initMPS() const override;
   bool hasMPS() const override;
+  bool isOnMacOS13orNewer() const override;
   Allocator* getMPSDeviceAllocator() const override;
   const Generator& getDefaultMPSGenerator() const override;
+  void deviceSynchronize() const override;
 };

 }} // at::mps

build_variables.bzl

Lines changed: 2 additions & 0 deletions
@@ -715,6 +715,7 @@ torch_cpp_srcs = [
     "torch/csrc/api/src/imethod.cpp",
     "torch/csrc/api/src/jit.cpp",
     "torch/csrc/api/src/serialize.cpp",
+    "torch/csrc/api/src/mps.cpp",
     "torch/csrc/api/src/nn/init.cpp",
     "torch/csrc/api/src/nn/module.cpp",
     "torch/csrc/api/src/nn/modules/_functions.cpp",
@@ -821,6 +822,7 @@ libtorch_python_core_sources = [
     "torch/csrc/dynamo/guards.cpp",
     "torch/csrc/dynamo/init.cpp",
     "torch/csrc/functorch/init.cpp",
+    "torch/csrc/mps/Module.cpp",
     "torch/csrc/jit/backends/backend_init.cpp",
     "torch/csrc/jit/python/init.cpp",
     "torch/csrc/jit/passes/onnx.cpp",

test/test_mps.py

Lines changed: 39 additions & 0 deletions
@@ -5803,6 +5803,45 @@ def test_mps_generator(self):
         mps_x = torch.randn(5, device='mps', generator=g_mps)
         self.assertEqual(mps_x, mps_y)

+    def test_default_mps_generator(self):
+        # manual seeding on the "default" MPS generator using
+        # the global torch.manual_seed()
+        torch.manual_seed(230)
+        mps_x = torch.randn(5, device='mps')
+        # manual seeding using torch.mps.manual_seed()
+        # which should set the "default" MPS generator
+        # like the global torch.manual_seed()
+        torch.mps.manual_seed(230)
+        mps_y = torch.randn(5, device='mps')
+        # seed values were the same, so the random tensor contents should match
+        self.assertEqual(mps_x, mps_y)
+
+        # save the default generator's state to restore it later
+        g_state = torch.mps.get_rng_state()
+
+        # generate random numbers without seeding
+        mps_x = torch.randn(5, device='mps')
+        # in this case, the random results must differ from the last generated random results
+        self.assertNotEqual(mps_x, mps_y)
+
+        # restore the previously saved state, and the results should match again
+        torch.mps.set_rng_state(g_state)
+        mps_x = torch.randn(5, device='mps')
+        self.assertEqual(mps_x, mps_y)
+
+    def test_device_synchronize(self):
+        # just running some ops each followed by a synchronize to wait for
+        # MPS stream to finish running each of them
+        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
+            .to(device='mps', dtype=torch.float)
+
+        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
+        torch.mps.synchronize()
+        x = net1(x)
+        torch.mps.synchronize()
+        x.backward(torch.randn_like(x))
+        torch.mps.synchronize()
+
     # Test random_.to and random_.from
     def test_random(self):
         def helper(shape, low, high, dtype=torch.int32):
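
The test_device_synchronize case above hints at the main use for torch.mps.synchronize(): work is queued asynchronously on the MPS stream, so timing code must drain the stream or it measures only the enqueue cost. A sketch (not from the commit), mirroring the usual torch.cuda.synchronize() pattern:

    import time
    import torch

    x = torch.randn(2048, 2048, device="mps")
    torch.mps.synchronize()                  # drain pending work before timing
    start = time.perf_counter()
    y = x @ x                                # queued on the MPS stream
    torch.mps.synchronize()                  # wait for the matmul to finish
    print(f"matmul: {time.perf_counter() - start:.4f}s")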

torch/_C/__init__.pyi.in

Lines changed: 6 additions & 2 deletions
@@ -903,8 +903,6 @@ def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: T
 def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function
 def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
 def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
-def _is_mps_available() -> _bool: ...
-def _is_mps_on_macos_13_or_newer() -> _bool: ...
 class _LinalgBackend:
     Default: _LinalgBackend
     Cusolver: _LinalgBackend
@@ -1200,6 +1198,12 @@ class _TensorBase(metaclass=_TensorMeta):
 # Defined in torch/csrc/multiprocessing/init.cpp
 def _multiprocessing_init() -> None: ...

+# Defined in torch/csrc/mps/Module.cpp
+def _mps_synchronize() -> None: ...
+def _mps_init() -> None: ...
+def _is_mps_available() -> _bool: ...
+def _is_mps_on_macos_13_or_newer() -> _bool: ...
+
 # Defined in torch/csrc/cuda/Module.cpp
 def _cuda_getCurrentStream(device: _int) -> Tuple: ...
 def _cuda_getCurrentRawStream(device: _int) -> _int: ...
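
These private _mps_* bindings back the public torch.mps module. The torch/mps.py wrappers this commit adds are not shown in this excerpt, so the following delegation sketch is an assumption, using only the names declared above:

    import torch

    def init() -> None:
        # eagerly initialize the MPS device state
        torch._C._mps_init()

    def is_available() -> bool:
        return torch._C._is_mps_available()

    def synchronize() -> None:
        # wait until all work queued on the MPS stream has completed
        torch._C._mps_synchronize()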

torch/csrc/Module.cpp

Lines changed: 7 additions & 13 deletions
@@ -87,10 +87,6 @@
 #endif
 #endif

-#if defined(USE_MPS)
-#include <ATen/mps/MPSDevice.h>
-#endif
-
 #if defined(USE_VALGRIND)
 #include <callgrind.h>
 #endif
@@ -1219,6 +1215,10 @@ void initIttBindings(PyObject* module);
 } // namespace torch
 #endif

+#ifdef USE_MPS
+PyMethodDef* MPSModule_methods();
+#endif
+
 namespace torch {
 void initVerboseBindings(PyObject* module);
 } // namespace torch
@@ -1274,6 +1274,9 @@ PyObject* initModule() {
 #ifdef USE_CUDA
   THPUtils_addPyMethodDefs(methods, THCPModule_methods());
 #endif
+#ifdef USE_MPS
+  THPUtils_addPyMethodDefs(methods, MPSModule_methods());
+#endif
 #if defined(USE_DISTRIBUTED) && defined(USE_C10D)
   THPUtils_addPyMethodDefs(
       methods, torch::distributed::c10d::python_functions());
@@ -1593,15 +1596,6 @@ Call this whenever a new thread is created in order to propagate values from

   ASSERT_TRUE(set_module_attr("has_cuda", has_cuda));
   ASSERT_TRUE(set_module_attr("has_mps", has_mps));
-  py_module.def("_is_mps_available", []() { return at::hasMPS(); });
-  py_module.def("_is_mps_on_macos_13_or_newer", []() {
-#ifdef USE_MPS
-    return at::mps::is_macos_13_or_newer();
-#else
-    return false;
-#endif
-  });
-
   ASSERT_TRUE(
       set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));


torch/csrc/api/include/torch/mps.h

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+#pragma once
+
+#include <torch/csrc/Export.h>
+
+#include <cstddef>
+#include <cstdint>
+
+namespace torch {
+namespace mps {
+
+/// Returns true if MPS device is available.
+bool TORCH_API is_available();
+
+/// Sets the seed for the current GPU.
+void TORCH_API manual_seed(uint64_t seed);
+
+/// Waits for all streams on a MPS device to complete.
+void TORCH_API synchronize();
+
+} // namespace mps
+} // namespace torch
