Skip to content

Commit b51be30

Browse files
hawkinsp, vfdev-5
authored and committed
Enabled freethreading support in MLIR python bindings
- WIP freethreading
- Added free-threading CPython mode support in Python bindings - temporarily updated requirements
- Added lock on PyGlobals::get and PyMlirContext liveContexts
- WIP on adding multithreaded_tests
- More tests and added a lock to _cext.register_operation
- Updated tests + labelled passing and failing tests
- Updated locks and added docs
- Docs updates, tests and fixed cmake config issue
- Fixed nanobind target in mlir/cmake/modules/AddMLIRPython.cmake
- Reverted mlir/test/python/lib/PythonTestModulePybind11.cpp
- Removed old data-races stacks and removed fixed tests
- Revert some python deps versions
- Recoded mlir/test/python/multithreaded_tests.py without pytest
- Added a condition to skip running mlir/test/python/multithreaded_tests.py when has GIL
- Addressed some of the PR review comments
- Updated Python docs and mlir/test/python/multithreaded_tests.py
- Updated mlir/test/python/multithreaded_tests.py due to removed lock in ExecutionEngine
1 parent 32bc029 commit b51be30

File tree

10 files changed

+639
-17
lines changed

10 files changed

+639
-17
lines changed

Diff for: mlir/cmake/modules/AddMLIRPython.cmake

+20-1
Original file line numberDiff line numberDiff line change
@@ -668,12 +668,31 @@ function(add_mlir_python_extension libname extname)
668668
elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
669669
nanobind_add_module(${libname}
670670
NB_DOMAIN mlir
671+
FREE_THREADED
671672
${ARG_SOURCES}
672673
)
673674

674675
if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
675676
# Avoids warnings from upstream nanobind.
676-
target_compile_options(nanobind-static
677+
set(nanobind_target "nanobind-static")
678+
if (NOT TARGET ${nanobind_target})
679+
# Get correct nanobind target name: nanobind-static-ft or something else
680+
# It is set by nanobind_add_module function according to the passed options
681+
get_property(all_targets DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY BUILDSYSTEM_TARGETS)
682+
683+
# Iterate over the list of targets
684+
foreach(target ${all_targets})
685+
# Check if the target name matches the given string
686+
if("${target}" MATCHES "nanobind-")
687+
set(nanobind_target "${target}")
688+
endif()
689+
endforeach()
690+
691+
if (NOT TARGET ${nanobind_target})
692+
message(FATAL_ERROR "Could not find nanobind target to set compile options to")
693+
endif()
694+
endif()
695+
target_compile_options(${nanobind_target}
677696
PRIVATE
678697
-Wno-cast-qual
679698
-Wno-zero-length-array

Diff for: mlir/docs/Bindings/Python.md

+40
Original file line numberDiff line numberDiff line change
@@ -1187,3 +1187,43 @@ or nanobind and
11871187
utilities to connect to the rest of Python API. The bindings can be located in a
11881188
separate module or in the same module as attributes and types, and
11891189
loaded along with the dialect.
1190+
1191+
## Free-threading (No-GIL) support
1192+
1193+
Free-threading (or no-GIL) support refers to running on a CPython interpreter (>=3.13) in which the Global Interpreter Lock is optional. For details on the topic, see [PEP-703](https://peps.python.org/pep-0703/) and the [Python free-threading guide](https://py-free-threading.github.io/).
1194+
1195+
The MLIR Python bindings are free-threading compatible, with some exceptions (listed below), in the following sense: it is safe to work in multiple threads as long as each thread uses its own **independent** context. The following example shows safe usage:
1196+
1197+
```python
1198+
# python3.13t example.py
1199+
import concurrent.futures
1200+
1201+
import mlir.dialects.arith as arith
1202+
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
1203+
1204+
1205+
def func(py_value):
1206+
with Context() as ctx:
1207+
module = Module.create(loc=Location.file("foo.txt", 0, 0))
1208+
1209+
dtype = IntegerType.get_signless(64)
1210+
with InsertionPoint(module.body), Location.name("a"):
1211+
arith.constant(dtype, py_value)
1212+
1213+
return module
1214+
1215+
1216+
num_workers = 8
1217+
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
1218+
futures = []
1219+
for i in range(num_workers):
1220+
futures.append(executor.submit(func, i))
1221+
assert len(list(f.result() for f in futures)) == num_workers
1222+
```
1223+
1224+
Exceptions to free-threading compatibility:
1225+
- IR printing is unsafe, e.g. when using `PassManager` with `PassManager.enable_ir_printing()` which calls thread-unsafe `llvm::raw_ostream`.
1226+
- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
1227+
- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
1228+
- Usage of `mlir.dialects.transform.interpreter` is unsafe.
1229+
- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.

Diff for: mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
#include "Standalone-c/Dialects.h"
1313
#include "mlir/Bindings/Python/PybindAdaptors.h"
1414

15+
namespace py = pybind11;
16+
1517
using namespace mlir::python::adaptors;
1618

17-
PYBIND11_MODULE(_standaloneDialectsPybind11, m) {
19+
PYBIND11_MODULE(_standaloneDialectsPybind11, m, py::mod_gil_not_used()) {
1820
//===--------------------------------------------------------------------===//
1921
// standalone dialect
2022
//===--------------------------------------------------------------------===//

Diff for: mlir/lib/Bindings/Python/Globals.h

+11-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace mlir {
2424
namespace python {
2525

2626
/// Globals that are always accessible once the extension has been initialized.
27+
/// Methods of this class are thread-safe.
2728
class PyGlobals {
2829
public:
2930
PyGlobals();
@@ -37,12 +38,18 @@ class PyGlobals {
3738

3839
/// Get and set the list of parent modules to search for dialect
3940
/// implementation classes.
40-
std::vector<std::string> &getDialectSearchPrefixes() {
41+
std::vector<std::string> getDialectSearchPrefixes() {
42+
nanobind::ft_lock_guard lock(mutex);
4143
return dialectSearchPrefixes;
4244
}
4345
void setDialectSearchPrefixes(std::vector<std::string> newValues) {
46+
nanobind::ft_lock_guard lock(mutex);
4447
dialectSearchPrefixes.swap(newValues);
4548
}
49+
void addDialectSearchPrefix(std::string value) {
50+
nanobind::ft_lock_guard lock(mutex);
51+
dialectSearchPrefixes.push_back(std::move(value));
52+
}
4653

4754
/// Loads a python module corresponding to the given dialect namespace.
4855
/// No-ops if the module has already been loaded or is not found. Raises
@@ -109,6 +116,9 @@ class PyGlobals {
109116

110117
private:
111118
static PyGlobals *instance;
119+
120+
nanobind::ft_mutex mutex;
121+
112122
/// Module name prefixes to search under for dialect implementation modules.
113123
std::vector<std::string> dialectSearchPrefixes;
114124
/// Map of dialect namespace to external dialect class object.

Diff for: mlir/lib/Bindings/Python/IRCore.cpp

+27-4
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,15 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes,
243243

244244
/// Wrapper for the global LLVM debugging flag.
245245
struct PyGlobalDebugFlag {
246-
static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
246+
static void set(nb::object &o, bool enable) {
247+
nb::ft_lock_guard lock(mutex);
248+
mlirEnableGlobalDebug(enable);
249+
}
247250

248-
static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }
251+
static bool get(const nb::object &) {
252+
nb::ft_lock_guard lock(mutex);
253+
return mlirIsGlobalDebugEnabled();
254+
}
249255

250256
static void bind(nb::module_ &m) {
251257
// Debug flags.
@@ -255,6 +261,7 @@ struct PyGlobalDebugFlag {
255261
.def_static(
256262
"set_types",
257263
[](const std::string &type) {
264+
nb::ft_lock_guard lock(mutex);
258265
mlirSetGlobalDebugType(type.c_str());
259266
},
260267
"types"_a, "Sets specific debug types to be produced by LLVM")
@@ -263,11 +270,17 @@ struct PyGlobalDebugFlag {
263270
pointers.reserve(types.size());
264271
for (const std::string &str : types)
265272
pointers.push_back(str.c_str());
273+
nb::ft_lock_guard lock(mutex);
266274
mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
267275
});
268276
}
277+
278+
private:
279+
static nb::ft_mutex mutex;
269280
};
270281

282+
nb::ft_mutex PyGlobalDebugFlag::mutex;
283+
271284
struct PyAttrBuilderMap {
272285
static bool dunderContains(const std::string &attributeKind) {
273286
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
@@ -606,6 +619,7 @@ class PyOpOperandIterator {
606619

607620
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
608621
nb::gil_scoped_acquire acquire;
622+
nb::ft_lock_guard lock(live_contexts_mutex);
609623
auto &liveContexts = getLiveContexts();
610624
liveContexts[context.ptr] = this;
611625
}
@@ -615,7 +629,10 @@ PyMlirContext::~PyMlirContext() {
615629
// forContext method, which always puts the associated handle into
616630
// liveContexts.
617631
nb::gil_scoped_acquire acquire;
618-
getLiveContexts().erase(context.ptr);
632+
{
633+
nb::ft_lock_guard lock(live_contexts_mutex);
634+
getLiveContexts().erase(context.ptr);
635+
}
619636
mlirContextDestroy(context);
620637
}
621638

@@ -632,6 +649,7 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
632649

633650
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
634651
nb::gil_scoped_acquire acquire;
652+
nb::ft_lock_guard lock(live_contexts_mutex);
635653
auto &liveContexts = getLiveContexts();
636654
auto it = liveContexts.find(context.ptr);
637655
if (it == liveContexts.end()) {
@@ -647,12 +665,17 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
647665
return PyMlirContextRef(it->second, std::move(pyRef));
648666
}
649667

668+
nb::ft_mutex PyMlirContext::live_contexts_mutex;
669+
650670
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
651671
static LiveContextMap liveContexts;
652672
return liveContexts;
653673
}
654674

655-
size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
675+
size_t PyMlirContext::getLiveCount() {
676+
nb::ft_lock_guard lock(live_contexts_mutex);
677+
return getLiveContexts().size();
678+
}
656679

657680
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
658681

Diff for: mlir/lib/Bindings/Python/IRModule.cpp

+16-2
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,11 @@ PyGlobals::PyGlobals() {
3838
PyGlobals::~PyGlobals() { instance = nullptr; }
3939

4040
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
41-
if (loadedDialectModules.contains(dialectNamespace))
42-
return true;
41+
{
42+
nb::ft_lock_guard lock(mutex);
43+
if (loadedDialectModules.contains(dialectNamespace))
44+
return true;
45+
}
4346
// Since re-entrancy is possible, make a copy of the search prefixes.
4447
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
4548
nb::object loaded = nb::none();
@@ -62,12 +65,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
6265
return false;
6366
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
6467
// may have occurred, which may do anything.
68+
nb::ft_lock_guard lock(mutex);
6569
loadedDialectModules.insert(dialectNamespace);
6670
return true;
6771
}
6872

6973
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
7074
nb::callable pyFunc, bool replace) {
75+
nb::ft_lock_guard lock(mutex);
7176
nb::object &found = attributeBuilderMap[attributeKind];
7277
if (found && !replace) {
7378
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
@@ -81,6 +86,7 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
8186

8287
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
8388
nb::callable typeCaster, bool replace) {
89+
nb::ft_lock_guard lock(mutex);
8490
nb::object &found = typeCasterMap[mlirTypeID];
8591
if (found && !replace)
8692
throw std::runtime_error("Type caster is already registered with caster: " +
@@ -90,6 +96,7 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
9096

9197
void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
9298
nb::callable valueCaster, bool replace) {
99+
nb::ft_lock_guard lock(mutex);
93100
nb::object &found = valueCasterMap[mlirTypeID];
94101
if (found && !replace)
95102
throw std::runtime_error("Value caster is already registered: " +
@@ -99,6 +106,7 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
99106

100107
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
101108
nb::object pyClass) {
109+
nb::ft_lock_guard lock(mutex);
102110
nb::object &found = dialectClassMap[dialectNamespace];
103111
if (found) {
104112
throw std::runtime_error((llvm::Twine("Dialect namespace '") +
@@ -110,6 +118,7 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
110118

111119
void PyGlobals::registerOperationImpl(const std::string &operationName,
112120
nb::object pyClass, bool replace) {
121+
nb::ft_lock_guard lock(mutex);
113122
nb::object &found = operationClassMap[operationName];
114123
if (found && !replace) {
115124
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
@@ -121,6 +130,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
121130

122131
std::optional<nb::callable>
123132
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
133+
nb::ft_lock_guard lock(mutex);
124134
const auto foundIt = attributeBuilderMap.find(attributeKind);
125135
if (foundIt != attributeBuilderMap.end()) {
126136
assert(foundIt->second && "attribute builder is defined");
@@ -133,6 +143,7 @@ std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
133143
MlirDialect dialect) {
134144
// Try to load dialect module.
135145
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
146+
nb::ft_lock_guard lock(mutex);
136147
const auto foundIt = typeCasterMap.find(mlirTypeID);
137148
if (foundIt != typeCasterMap.end()) {
138149
assert(foundIt->second && "type caster is defined");
@@ -145,6 +156,7 @@ std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
145156
MlirDialect dialect) {
146157
// Try to load dialect module.
147158
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
159+
nb::ft_lock_guard lock(mutex);
148160
const auto foundIt = valueCasterMap.find(mlirTypeID);
149161
if (foundIt != valueCasterMap.end()) {
150162
assert(foundIt->second && "value caster is defined");
@@ -158,6 +170,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
158170
// Make sure dialect module is loaded.
159171
if (!loadDialectModule(dialectNamespace))
160172
return std::nullopt;
173+
nb::ft_lock_guard lock(mutex);
161174
const auto foundIt = dialectClassMap.find(dialectNamespace);
162175
if (foundIt != dialectClassMap.end()) {
163176
assert(foundIt->second && "dialect class is defined");
@@ -175,6 +188,7 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
175188
if (!loadDialectModule(dialectNamespace))
176189
return std::nullopt;
177190

191+
nb::ft_lock_guard lock(mutex);
178192
auto foundIt = operationClassMap.find(operationName);
179193
if (foundIt != operationClassMap.end()) {
180194
assert(foundIt->second && "OpView is defined");

Diff for: mlir/lib/Bindings/Python/IRModule.h

+1
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ class PyMlirContext {
260260
// Note that this holds a handle, which does not imply ownership.
261261
// Mappings will be removed when the context is destructed.
262262
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
263+
static nanobind::ft_mutex live_contexts_mutex;
263264
static LiveContextMap &getLiveContexts();
264265

265266
// Interns all live modules associated with this context. Modules tracked

Diff for: mlir/lib/Bindings/Python/MainModule.cpp

+2-7
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,8 @@ NB_MODULE(_mlir, m) {
3030
.def_prop_rw("dialect_search_modules",
3131
&PyGlobals::getDialectSearchPrefixes,
3232
&PyGlobals::setDialectSearchPrefixes)
33-
.def(
34-
"append_dialect_search_prefix",
35-
[](PyGlobals &self, std::string moduleName) {
36-
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
37-
},
38-
"module_name"_a)
33+
.def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
34+
"module_name"_a)
3935
.def(
4036
"_check_dialect_module_loaded",
4137
[](PyGlobals &self, const std::string &dialectNamespace) {
@@ -76,7 +72,6 @@ NB_MODULE(_mlir, m) {
7672
nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
7773
PyGlobals::get().registerOperationImpl(operationName, opClass,
7874
replace);
79-
8075
// Dict-stuff the new opClass by name onto the dialect class.
8176
nb::object opClassName = opClass.attr("__name__");
8277
dialectClass.attr(opClassName) = opClass;

Diff for: mlir/python/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ nanobind>=2.4, <3.0
22
numpy>=1.19.5, <=2.1.2
33
pybind11>=2.10.0, <=2.13.6
44
PyYAML>=5.4.0, <=6.0.1
5-
ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16
5+
ml_dtypes>=0.5.0, <=0.6.0 # provides several NumPy dtype extensions, including the bf16

0 commit comments

Comments
 (0)