diff --git a/guide/src/class.md b/guide/src/class.md index 679b61c992e..9bfeb5e8782 100644 --- a/guide/src/class.md +++ b/guide/src/class.md @@ -974,6 +974,10 @@ impl pyo3::class::proto_methods::HasProtoRegistry for MyClass { ®ISTRY } } + +impl pyo3::pyclass::PyClassSend for MyClass { + type ThreadChecker = pyo3::pyclass::ThreadCheckerStub; +} # let gil = Python::acquire_gil(); # let py = gil.python(); # let cls = py.get_type::(); diff --git a/pyo3-derive-backend/src/pyclass.rs b/pyo3-derive-backend/src/pyclass.rs index cf52f334f1c..d2eff105007 100644 --- a/pyo3-derive-backend/src/pyclass.rs +++ b/pyo3-derive-backend/src/pyclass.rs @@ -19,6 +19,7 @@ pub struct PyClassArgs { pub flags: Vec, pub base: syn::TypePath, pub has_extends: bool, + pub has_unsendable: bool, pub module: Option, } @@ -45,6 +46,7 @@ impl Default for PyClassArgs { flags: vec![parse_quote! { 0 }], base: parse_quote! { pyo3::PyAny }, has_extends: false, + has_unsendable: false, } } } @@ -60,7 +62,7 @@ impl PyClassArgs { } } - /// Match a single flag + /// Match a key/value flag fn add_assign(&mut self, assign: &syn::ExprAssign) -> syn::Result<()> { let syn::ExprAssign { left, right, .. } = assign; let key = match &**left { @@ -120,31 +122,28 @@ impl PyClassArgs { Ok(()) } - /// Match a key/value flag + /// Match a single flag fn add_path(&mut self, exp: &syn::ExprPath) -> syn::Result<()> { let flag = exp.path.segments.first().unwrap().ident.to_string(); - let path = match flag.as_str() { - "gc" => { - parse_quote! {pyo3::type_flags::GC} - } - "weakref" => { - parse_quote! {pyo3::type_flags::WEAKREF} - } - "subclass" => { - parse_quote! {pyo3::type_flags::BASETYPE} - } - "dict" => { - parse_quote! {pyo3::type_flags::DICT} + let mut push_flag = |flag| { + self.flags.push(syn::Expr::Path(flag)); + }; + match flag.as_str() { + "gc" => push_flag(parse_quote! {pyo3::type_flags::GC}), + "weakref" => push_flag(parse_quote! {pyo3::type_flags::WEAKREF}), + "subclass" => push_flag(parse_quote! {pyo3::type_flags::BASETYPE}), + "dict" => push_flag(parse_quote! {pyo3::type_flags::DICT}), + "unsendable" => { + // unsendable gives a manuall `PyClassSend` implementation, so sets no flag + self.has_unsendable = true; } _ => { return Err(syn::Error::new_spanned( &exp.path, - "Expected one of gc/weakref/subclass/dict", + "Expected one of gc/weakref/subclass/dict/unsendable", )) } }; - - self.flags.push(syn::Expr::Path(path)); Ok(()) } } @@ -386,6 +385,20 @@ fn impl_class( quote! {} }; + let thread_checker = if attr.has_unsendable { + quote! { pyo3::pyclass::ThreadCheckerImpl<#cls> } + } else if attr.has_extends { + quote! { + pyo3::pyclass::ThreadCheckerInherited< + #cls, + <<#cls as pyo3::type_object::PyTypeInfo>::BaseType as + pyo3::derive_utils::PyBaseTypeUtils>::ThreadChecker + > + } + } else { + quote! { pyo3::pyclass::ThreadCheckerStub<#cls> } + }; + Ok(quote! { unsafe impl pyo3::type_object::PyTypeInfo for #cls { type Type = #cls; @@ -424,6 +437,10 @@ fn impl_class( type Target = pyo3::PyRefMut<'a, #cls>; } + impl pyo3::pyclass::PyClassSend for #cls { + type ThreadChecker = #thread_checker; + } + #into_pyobject #impl_inventory @@ -433,7 +450,6 @@ fn impl_class( #extra #gc_impl - }) } diff --git a/src/derive_utils.rs b/src/derive_utils.rs index 528d6a2ab92..54f97ff78ee 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -162,6 +162,7 @@ pub trait PyBaseTypeUtils { type WeakRef; type LayoutAsBase; type BaseNativeType; + type ThreadChecker; } impl PyBaseTypeUtils for T { @@ -169,6 +170,7 @@ impl PyBaseTypeUtils for T { type WeakRef = T::WeakRef; type LayoutAsBase = crate::pycell::PyCellInner; type BaseNativeType = T::BaseNativeType; + type ThreadChecker = T::ThreadChecker; } /// Utility trait to enable &PyClass as a pymethod/function argument diff --git a/src/pycell.rs b/src/pycell.rs index 2fa672882b0..19f0d1995ab 100644 --- a/src/pycell.rs +++ b/src/pycell.rs @@ -1,10 +1,11 @@ //! Includes `PyCell` implementation. use crate::conversion::{AsPyPointer, FromPyPointer, ToPyObject}; +use crate::pyclass::{PyClass, PyClassThreadChecker}; use crate::pyclass_init::PyClassInitializer; use crate::pyclass_slots::{PyClassDict, PyClassWeakRef}; use crate::type_object::{PyBorrowFlagLayout, PyLayout, PySizedLayout, PyTypeInfo}; use crate::types::PyAny; -use crate::{ffi, FromPy, PyClass, PyErr, PyNativeType, PyObject, PyResult, Python}; +use crate::{ffi, FromPy, PyErr, PyNativeType, PyObject, PyResult, Python}; use std::cell::{Cell, UnsafeCell}; use std::fmt; use std::mem::ManuallyDrop; @@ -161,6 +162,7 @@ pub struct PyCell { inner: PyCellInner, dict: T::Dict, weakref: T::WeakRef, + thread_checker: T::ThreadChecker, } unsafe impl PyNativeType for PyCell {} @@ -227,6 +229,7 @@ impl PyCell { /// } /// ``` pub fn try_borrow(&self) -> Result, PyBorrowError> { + self.thread_checker.ensure(); let flag = self.inner.get_borrow_flag(); if flag == BorrowFlag::HAS_MUTABLE_BORROW { Err(PyBorrowError { _private: () }) @@ -258,6 +261,7 @@ impl PyCell { /// assert!(c.try_borrow_mut().is_ok()); /// ``` pub fn try_borrow_mut(&self) -> Result, PyBorrowMutError> { + self.thread_checker.ensure(); if self.inner.get_borrow_flag() != BorrowFlag::UNUSED { Err(PyBorrowMutError { _private: () }) } else { @@ -296,6 +300,7 @@ impl PyCell { /// } /// ``` pub unsafe fn try_borrow_unguarded(&self) -> Result<&T, PyBorrowError> { + self.thread_checker.ensure(); if self.inner.get_borrow_flag() == BorrowFlag::HAS_MUTABLE_BORROW { Err(PyBorrowError { _private: () }) } else { @@ -352,6 +357,7 @@ impl PyCell { let self_ = base as *mut Self; (*self_).dict = T::Dict::new(); (*self_).weakref = T::WeakRef::new(); + (*self_).thread_checker = T::ThreadChecker::new(); Ok(self_) } } diff --git a/src/pyclass.rs b/src/pyclass.rs index 28cadef6b39..1c709a19cb8 100644 --- a/src/pyclass.rs +++ b/src/pyclass.rs @@ -7,8 +7,9 @@ use crate::type_object::{type_flags, PyLayout}; use crate::types::PyAny; use crate::{class, ffi, PyCell, PyErr, PyNativeType, PyResult, PyTypeInfo, Python}; use std::ffi::CString; +use std::marker::PhantomData; use std::os::raw::c_void; -use std::ptr; +use std::{ptr, thread}; #[inline] pub(crate) unsafe fn default_new( @@ -91,10 +92,10 @@ pub(crate) unsafe fn tp_free_fallback(obj: *mut ffi::PyObject) { pub trait PyClass: PyTypeInfo, AsRefTarget = PyCell> + Sized + + PyClassSend + PyClassAlloc + PyMethods + PyProtoMethods - + Send { /// Specify this class has `#[pyclass(dict)]` or not. type Dict: PyClassDict; @@ -308,3 +309,52 @@ fn py_class_properties() -> Vec { defs.values().cloned().collect() } + +pub trait PyClassSend { + type ThreadChecker: PyClassThreadChecker; +} + +pub trait PyClassThreadChecker: Sized { + fn ensure(&self); + fn new() -> Self; + private_decl! {} +} + +pub struct ThreadCheckerStub(PhantomData); + +impl PyClassThreadChecker for ThreadCheckerStub { + fn ensure(&self) {} + fn new() -> Self { + ThreadCheckerStub(PhantomData) + } + private_impl! {} +} + +pub struct ThreadCheckerImpl(thread::ThreadId, PhantomData); + +impl PyClassThreadChecker for ThreadCheckerImpl { + fn ensure(&self) { + if thread::current().id() != self.0 { + panic!( + "{} is unsendable, but sent to another thread!", + std::any::type_name::() + ); + } + } + fn new() -> Self { + ThreadCheckerImpl(thread::current().id(), PhantomData) + } + private_impl! {} +} + +pub struct ThreadCheckerInherited(PhantomData, U); + +impl PyClassThreadChecker for ThreadCheckerInherited { + fn ensure(&self) { + self.1.ensure(); + } + fn new() -> Self { + ThreadCheckerInherited(PhantomData, U::new()) + } + private_impl! {} +} diff --git a/src/types/mod.rs b/src/types/mod.rs index c238e72aec3..fafc358038d 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -75,6 +75,7 @@ macro_rules! pyobject_native_type { type WeakRef = $crate::pyclass_slots::PyClassDummySlot; type LayoutAsBase = $crate::pycell::PyCellBase<$name>; type BaseNativeType = $name; + type ThreadChecker = $crate::pyclass::ThreadCheckerStub<$crate::PyObject>; } pyobject_native_type_named!($name $(,$type_param)*); pyobject_native_type_convert!($name, $layout, $typeobject, $module, $checkfunction $(,$type_param)*); diff --git a/tests/test_class_basics.rs b/tests/test_class_basics.rs index 82b4810c7de..03a13103f6b 100644 --- a/tests/test_class_basics.rs +++ b/tests/test_class_basics.rs @@ -163,3 +163,55 @@ fn class_with_object_field() { py_assert!(py, ty, "ty(5).value == 5"); py_assert!(py, ty, "ty(None).value == None"); } + +#[pyclass(unsendable)] +struct UnsendableBase { + rc: std::rc::Rc, +} + +#[pymethods] +impl UnsendableBase { + fn value(&self) -> usize { + *self.rc.as_ref() + } +} + +#[pyclass(extends=UnsendableBase)] +struct UnsendableChild {} + +/// If a class is marked as `unsendable`, it panics when accessed by another thread. +#[test] +fn panic_unsendable() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let base = || UnsendableBase { + rc: std::rc::Rc::new(0), + }; + let unsendable_base = PyCell::new(py, base()).unwrap(); + let unsendable_child = PyCell::new(py, (UnsendableChild {}, base())).unwrap(); + + let source = pyo3::indoc::indoc!( + r#" +def value(): + return unsendable.value() + +import concurrent.futures +executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) +future = executor.submit(value) +try: + result = future.result() + assert False, 'future must panic' +except BaseException as e: + assert str(e) == 'test_class_basics::UnsendableBase is unsendable, but sent to another thread!' +"# + ); + let globals = PyModule::import(py, "__main__").unwrap().dict(); + let test = |unsendable| { + globals.set_item("unsendable", unsendable).unwrap(); + py.run(source, Some(globals), None) + .map_err(|e| e.print(py)) + .unwrap(); + }; + test(unsendable_base.as_ref()); + test(unsendable_child.as_ref()); +} diff --git a/tests/ui/invalid_pyclass_args.stderr b/tests/ui/invalid_pyclass_args.stderr index 72373cd6d3e..42b2c460874 100644 --- a/tests/ui/invalid_pyclass_args.stderr +++ b/tests/ui/invalid_pyclass_args.stderr @@ -22,7 +22,7 @@ error: Expected string literal (e.g., "my_mod") 12 | #[pyclass(module = my_module)] | ^^^^^^^^^ -error: Expected one of gc/weakref/subclass/dict +error: Expected one of gc/weakref/subclass/dict/unsendable --> $DIR/invalid_pyclass_args.rs:15:11 | 15 | #[pyclass(weakrev)]