Skip to content

Commit

Permalink
Introduce #[pyclass(unsendable)]
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Jun 29, 2020
1 parent 6335a7f commit 292c3ef
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 21 deletions.
4 changes: 4 additions & 0 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,10 @@ impl pyo3::class::proto_methods::HasProtoRegistry for MyClass {
&REGISTRY
}
}

impl pyo3::pyclass::PyClassSend for MyClass {
type ThreadChecker = pyo3::pyclass::ThreadCheckerStub<MyClass>;
}
# let gil = Python::acquire_gil();
# let py = gil.python();
# let cls = py.get_type::<MyClass>();
Expand Down
48 changes: 31 additions & 17 deletions pyo3-derive-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct PyClassArgs {
pub flags: Vec<syn::Expr>,
pub base: syn::TypePath,
pub has_extends: bool,
pub has_unsendable: bool,
pub module: Option<syn::LitStr>,
}

Expand All @@ -45,6 +46,7 @@ impl Default for PyClassArgs {
flags: vec![parse_quote! { 0 }],
base: parse_quote! { pyo3::PyAny },
has_extends: false,
has_unsendable: false,
}
}
}
Expand All @@ -60,7 +62,7 @@ impl PyClassArgs {
}
}

/// Match a single flag
/// Match a key/value flag
fn add_assign(&mut self, assign: &syn::ExprAssign) -> syn::Result<()> {
let syn::ExprAssign { left, right, .. } = assign;
let key = match &**left {
Expand Down Expand Up @@ -120,31 +122,28 @@ impl PyClassArgs {
Ok(())
}

/// Match a key/value flag
/// Match a single flag
fn add_path(&mut self, exp: &syn::ExprPath) -> syn::Result<()> {
let flag = exp.path.segments.first().unwrap().ident.to_string();
let path = match flag.as_str() {
"gc" => {
parse_quote! {pyo3::type_flags::GC}
}
"weakref" => {
parse_quote! {pyo3::type_flags::WEAKREF}
}
"subclass" => {
parse_quote! {pyo3::type_flags::BASETYPE}
}
"dict" => {
parse_quote! {pyo3::type_flags::DICT}
let mut push_flag = |flag| {
self.flags.push(syn::Expr::Path(flag));
};
match flag.as_str() {
"gc" => push_flag(parse_quote! {pyo3::type_flags::GC}),
"weakref" => push_flag(parse_quote! {pyo3::type_flags::WEAKREF}),
"subclass" => push_flag(parse_quote! {pyo3::type_flags::BASETYPE}),
"dict" => push_flag(parse_quote! {pyo3::type_flags::DICT}),
"unsendable" => {
// unsendable gives a manuall `PyClassSend` implementation, so sets no flag
self.has_unsendable = true;
}
_ => {
return Err(syn::Error::new_spanned(
&exp.path,
"Expected one of gc/weakref/subclass/dict",
"Expected one of gc/weakref/subclass/dict/unsendable",
))
}
};

self.flags.push(syn::Expr::Path(path));
Ok(())
}
}
Expand Down Expand Up @@ -386,6 +385,20 @@ fn impl_class(
quote! {}
};

let pysend = if attr.has_unsendable {
quote! {
impl pyo3::pyclass::PyClassSend for #cls {
type ThreadChecker = pyo3::pyclass::ThreadCheckerImpl<#cls>;
}
}
} else {
quote! {
impl pyo3::pyclass::PyClassSend for #cls {
type ThreadChecker = pyo3::pyclass::ThreadCheckerStub<#cls>;
}
}
};

Ok(quote! {
unsafe impl pyo3::type_object::PyTypeInfo for #cls {
type Type = #cls;
Expand Down Expand Up @@ -434,6 +447,7 @@ fn impl_class(

#gc_impl

#pysend
})
}

Expand Down
8 changes: 7 additions & 1 deletion src/pycell.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
//! Includes `PyCell` implementation.
use crate::conversion::{AsPyPointer, FromPyPointer, ToPyObject};
use crate::pyclass::{PyClass, PyClassThreadChecker};
use crate::pyclass_init::PyClassInitializer;
use crate::pyclass_slots::{PyClassDict, PyClassWeakRef};
use crate::type_object::{PyBorrowFlagLayout, PyLayout, PySizedLayout, PyTypeInfo};
use crate::types::PyAny;
use crate::{ffi, FromPy, PyClass, PyErr, PyNativeType, PyObject, PyResult, Python};
use crate::{ffi, FromPy, PyErr, PyNativeType, PyObject, PyResult, Python};
use std::cell::{Cell, UnsafeCell};
use std::fmt;
use std::mem::ManuallyDrop;
Expand Down Expand Up @@ -161,6 +162,7 @@ pub struct PyCell<T: PyClass> {
inner: PyCellInner<T>,
dict: T::Dict,
weakref: T::WeakRef,
thread_checker: T::ThreadChecker,
}

unsafe impl<T: PyClass> PyNativeType for PyCell<T> {}
Expand Down Expand Up @@ -227,6 +229,7 @@ impl<T: PyClass> PyCell<T> {
/// }
/// ```
pub fn try_borrow(&self) -> Result<PyRef<'_, T>, PyBorrowError> {
self.thread_checker.ensure();
let flag = self.inner.get_borrow_flag();
if flag == BorrowFlag::HAS_MUTABLE_BORROW {
Err(PyBorrowError { _private: () })
Expand Down Expand Up @@ -258,6 +261,7 @@ impl<T: PyClass> PyCell<T> {
/// assert!(c.try_borrow_mut().is_ok());
/// ```
pub fn try_borrow_mut(&self) -> Result<PyRefMut<'_, T>, PyBorrowMutError> {
self.thread_checker.ensure();
if self.inner.get_borrow_flag() != BorrowFlag::UNUSED {
Err(PyBorrowMutError { _private: () })
} else {
Expand Down Expand Up @@ -296,6 +300,7 @@ impl<T: PyClass> PyCell<T> {
/// }
/// ```
pub unsafe fn try_borrow_unguarded(&self) -> Result<&T, PyBorrowError> {
self.thread_checker.ensure();
if self.inner.get_borrow_flag() == BorrowFlag::HAS_MUTABLE_BORROW {
Err(PyBorrowError { _private: () })
} else {
Expand Down Expand Up @@ -352,6 +357,7 @@ impl<T: PyClass> PyCell<T> {
let self_ = base as *mut Self;
(*self_).dict = T::Dict::new();
(*self_).weakref = T::WeakRef::new();
(*self_).thread_checker = T::ThreadChecker::new();
Ok(self_)
}
}
Expand Down
42 changes: 40 additions & 2 deletions src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ use crate::type_object::{type_flags, PyLayout};
use crate::types::PyAny;
use crate::{class, ffi, PyCell, PyErr, PyNativeType, PyResult, PyTypeInfo, Python};
use std::ffi::CString;
use std::marker::PhantomData;
use std::os::raw::c_void;
use std::ptr;
use std::{ptr, thread};

#[inline]
pub(crate) unsafe fn default_new<T: PyTypeInfo>(
Expand Down Expand Up @@ -91,10 +92,10 @@ pub(crate) unsafe fn tp_free_fallback(obj: *mut ffi::PyObject) {
pub trait PyClass:
PyTypeInfo<Layout = PyCell<Self>, AsRefTarget = PyCell<Self>>
+ Sized
+ PyClassSend
+ PyClassAlloc
+ PyMethods
+ PyProtoMethods
+ Send
{
/// Specify this class has `#[pyclass(dict)]` or not.
type Dict: PyClassDict;
Expand Down Expand Up @@ -308,3 +309,40 @@ fn py_class_properties<T: PyMethods>() -> Vec<ffi::PyGetSetDef> {

defs.values().cloned().collect()
}

pub trait PyClassSend {
type ThreadChecker: PyClassThreadChecker;
}

pub trait PyClassThreadChecker: Sized {
fn ensure(&self);
fn new() -> Self;
private_decl! {}
}

pub struct ThreadCheckerStub<T: Send>(PhantomData<T>);

impl<T: Send> PyClassThreadChecker for ThreadCheckerStub<T> {
fn ensure(&self) {}
fn new() -> Self {
ThreadCheckerStub(PhantomData)
}
private_impl! {}
}

pub struct ThreadCheckerImpl<T>(thread::ThreadId, PhantomData<T>);

impl<T> PyClassThreadChecker for ThreadCheckerImpl<T> {
fn ensure(&self) {
if thread::current().id() != self.0 {
panic!(
"{} is unsendable, but sent to another thread!",
std::any::type_name::<T>()
);
}
}
fn new() -> Self {
ThreadCheckerImpl(thread::current().id(), PhantomData)
}
private_impl! {}
}
48 changes: 48 additions & 0 deletions tests/test_class_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,51 @@ fn class_with_object_field() {
py_assert!(py, ty, "ty(5).value == 5");
py_assert!(py, ty, "ty(None).value == None");
}

#[pyclass(unsendable)]
struct UnsendableClass {
rc: std::rc::Rc<usize>,
}

#[pymethods]
impl UnsendableClass {
fn value(&self) -> usize {
*self.rc.as_ref()
}
}

/// Checks that `subclass.__new__` works correctly.
/// See https://github.com/PyO3/pyo3/issues/947 for the corresponding bug.
#[test]
fn panic_unsendable() {
let gil = Python::acquire_gil();
let py = gil.python();
let unsendable = PyCell::new(
py,
UnsendableClass {
rc: std::rc::Rc::new(0),
},
)
.unwrap();

let source = pyo3::indoc::indoc!(
r#"
def value():
return unsendable.value()
import concurrent.futures
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
future = executor.submit(value)
try:
result = future.result()
assert False, 'future must panic'
except BaseException as e:
assert str(e) == 'test_class_basics::UnsendableClass is unsendable, but sent to another thread!'
"#
);
let globals = PyModule::import(py, "__main__").unwrap().dict();
globals.set_item("unsendable", unsendable).unwrap();
py.run(source, Some(globals), None)
.map_err(|e| e.print(py))
.unwrap();
}
2 changes: 1 addition & 1 deletion tests/ui/invalid_pyclass_args.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ error: Expected string literal (e.g., "my_mod")
12 | #[pyclass(module = my_module)]
| ^^^^^^^^^

error: Expected one of gc/weakref/subclass/dict
error: Expected one of gc/weakref/subclass/dict/unsendable
--> $DIR/invalid_pyclass_args.rs:15:11
|
15 | #[pyclass(weakrev)]
Expand Down

0 comments on commit 292c3ef

Please sign in to comment.