Make API globals thread safe using atomics #222

Merged 1 commit on Nov 25, 2021
34 changes: 21 additions & 13 deletions src/npyffi/array.rs
@@ -2,7 +2,8 @@
use libc::FILE;
use pyo3::ffi::{self, PyObject, PyTypeObject};
use std::os::raw::*;
use std::{cell::Cell, ptr};
use std::ptr::null_mut;
use std::sync::atomic::{AtomicPtr, Ordering};

use crate::npyffi::*;

@@ -12,7 +13,7 @@ const CAPSULE_NAME: &str = "_ARRAY_API";
/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
/// pointer to [Numpy Array API](https://numpy.org/doc/stable/reference/c-api/array.html).
///
/// You can acceess raw c APIs via this variable and its Deref implementation.
/// You can access raw C APIs via this variable.
///
/// See [PyArrayAPI](struct.PyArrayAPI.html) for what methods you can use via this variable.
///
@@ -31,28 +32,35 @@ pub static PY_ARRAY_API: PyArrayAPI = PyArrayAPI::new();

/// See [PY_ARRAY_API] for more.
pub struct PyArrayAPI {
api: Cell<*const *const c_void>,
api: AtomicPtr<*const c_void>,
}

impl PyArrayAPI {
const fn new() -> Self {
Self {
api: Cell::new(ptr::null_mut()),
api: AtomicPtr::new(null_mut()),
}
}
fn get(&self, offset: isize) -> *const *const c_void {
if self.api.get().is_null() {
Python::with_gil(|py| {
let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
self.api.set(api);
});
#[cold]
fn init(&self) -> *const *const c_void {
Python::with_gil(|py| {
let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void;
if api.is_null() {
api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
Review comment (Member):

As a potential gotcha, can get_numpy_api lead to a temporary release of the GIL? That would potentially allow multiple threads to run this initialization.

Reply (Member Author):

As CPython's import implementation can be hooked, I think one cannot prevent this from happening in general. But I also think that multiple threads performing the initialization is only an issue of efficiency.

If a hook is releasing the GIL for whatever reason, it needs to be reacquired and all threads will only progress back here with the GIL held and at most store the same capsule pointer redundantly. (Doing the double-checking here on my part was only motivated by efficiency, i.e. we already have to take the lock so why not use this to avoid redundant initialization as we are already on the slow path.)

(If multiple threads importing the same module yields a different capsule and hence API pointer, I think all bets are off and we would need external synchronization like using std::sync::Once.)
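
The double-checked pattern described above can be sketched without Python at all. This is only an illustration under stated assumptions: a `Mutex` named `LOCK` stands in for the GIL, `lookup_table` is a hypothetical stand-in for get_numpy_api that always yields the same pointer, and `TABLE`/`LOOKUPS` exist purely for the demonstration. Unlike the real GIL, this mutex is never released during the lookup, so the sketch shows the efficient case only.

```rust
use std::ptr::null_mut;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
use std::sync::Mutex;

// Hypothetical stand-in for the function table behind the capsule.
static TABLE: [usize; 3] = [10, 20, 30];

// Stand-in for the GIL: a plain mutex that is only taken on the slow path.
static LOCK: Mutex<()> = Mutex::new(());

// Counts how often the (supposedly expensive) lookup actually ran.
static LOOKUPS: AtomicUsize = AtomicUsize::new(0);

// The cached pointer, null until the first successful initialization.
static API: AtomicPtr<usize> = AtomicPtr::new(null_mut());

// Hypothetical stand-in for get_numpy_api: always returns the same pointer.
fn lookup_table() -> *mut usize {
    LOOKUPS.fetch_add(1, Ordering::Relaxed);
    TABLE.as_ptr() as *mut usize
}

#[cold]
fn init() -> *mut usize {
    let _guard = LOCK.lock().unwrap();
    // Second check, now under the lock: another thread may have finished
    // the initialization while we were waiting.
    let mut api = API.load(Ordering::Relaxed);
    if api.is_null() {
        api = lookup_table();
        API.store(api, Ordering::Release);
    }
    api
}

pub fn get(offset: usize) -> usize {
    // Fast path: a single atomic load, no lock.
    let mut api = API.load(Ordering::Acquire);
    if api.is_null() {
        api = init();
    }
    assert!(offset < TABLE.len());
    // Safety: `api` points at TABLE and `offset` was bounds-checked above.
    unsafe { *api.add(offset) }
}
```

Calling `get(1)` from several threads returns `20` on each of them, and because the stand-in lock is held for the whole slow path, `lookup_table` runs once.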

Reply (Member Author, @adamreichold, Nov 23, 2021):

> (If multiple threads importing the same module yields a different capsule and hence API pointer, I think all bets are off and we would need external synchronization like using std::sync::Once.)

Well, we could compare_exchange the pointer instead of storing it, updating it only if it is still NULL and otherwise discarding our just-initialized value in favour of the "old" one returned by compare_exchange.

But having the global at all seems weird if we expect get_numpy_api to return different capsules when called from different threads or at different times.
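
A minimal sketch of the compare_exchange alternative mentioned above, using only `std`. The name `publish_once` and the `u32` pointee are hypothetical and chosen for illustration. The PR itself keeps the plain `store`.

```rust
use std::ptr::null_mut;
use std::sync::atomic::{AtomicPtr, Ordering};

// Cached pointer, null until some thread publishes a value.
static API: AtomicPtr<u32> = AtomicPtr::new(null_mut());

/// Publishes `candidate` only if nothing has been stored yet and
/// returns whichever pointer ended up in the global.
pub fn publish_once(candidate: *mut u32) -> *mut u32 {
    match API.compare_exchange(
        null_mut(),        // expected: still uninitialized
        candidate,         // value to install if the expectation holds
        Ordering::AcqRel,  // on success: publish our writes, see earlier ones
        Ordering::Acquire, // on failure: see the winner's writes
    ) {
        // We won the race, so our candidate is now the global value.
        Ok(_) => candidate,
        // Someone else was first: keep their "old" pointer, discard ours.
        Err(existing) => existing,
    }
}
```

With this variant the first published pointer is never overwritten, so a caller that loses the race is responsible for disposing of its own candidate.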

Reply (Member):

Agreed, I think most likely this code is fine as-is thanks to the global nature. In GILOnceCell I chose to drop any surplus values produced by other threads if a race occurred. This was kind of necessary because of the API contract of it being write-once.
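
The "drop any surplus value" behaviour of a write-once cell can be illustrated with the standard library's `std::sync::OnceLock`. This is an analogy only, not PyO3's GILOnceCell implementation, and `get_or_compute` is a hypothetical helper.

```rust
use std::sync::OnceLock;

// Write-once cell: the first successful `set` wins for good.
static CELL: OnceLock<String> = OnceLock::new();

/// Returns the stored value, computing a candidate first if the cell
/// looks empty. The candidate is computed outside of any lock, so two
/// racing threads may both compute one.
pub fn get_or_compute(make: impl FnOnce() -> String) -> &'static str {
    if let Some(value) = CELL.get() {
        return value.as_str();
    }
    let candidate = make();
    // If another thread filled the cell in the meantime, `set` returns our
    // candidate inside `Err`, and it is simply dropped here.
    let _ = CELL.set(candidate);
    CELL.get()
        .expect("cell was filled by this thread or by a racing one")
        .as_str()
}
```

Once the cell is filled, later candidates are never observed: every caller gets the first value.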

self.api.store(api as *mut _, Ordering::Release);
}
api
})
}
unsafe fn get(&self, offset: isize) -> *const *const c_void {
let mut api = self.api.load(Ordering::Acquire) as *const *const c_void;
if api.is_null() {
api = self.init();
}
unsafe { self.api.get().offset(offset) }
api.offset(offset)
}
}

unsafe impl Sync for PyArrayAPI {}

impl PyArrayAPI {
impl_api![0; PyArray_GetNDArrayCVersion() -> c_uint];
impl_api![40; PyArray_SetNumericOps(dict: *mut PyObject) -> c_int];
32 changes: 20 additions & 12 deletions src/npyffi/ufunc.rs
@@ -1,7 +1,8 @@
//! Low-Level binding for [UFunc API](https://numpy.org/doc/stable/reference/c-api/ufunc.html)

use std::os::raw::*;
use std::{cell::Cell, ptr};
use std::ptr::null_mut;
use std::sync::atomic::{AtomicPtr, Ordering};

use pyo3::ffi::PyObject;
use pyo3::Python;
@@ -18,28 +19,35 @@ const CAPSULE_NAME: &str = "_UFUNC_API";
pub static PY_UFUNC_API: PyUFuncAPI = PyUFuncAPI::new();

pub struct PyUFuncAPI {
api: Cell<*const *const c_void>,
api: AtomicPtr<*const c_void>,
}

impl PyUFuncAPI {
const fn new() -> Self {
Self {
api: Cell::new(ptr::null_mut()),
api: AtomicPtr::new(null_mut()),
}
}
fn get(&self, offset: isize) -> *const *const c_void {
if self.api.get().is_null() {
Python::with_gil(|py| {
let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
self.api.set(api);
});
#[cold]
fn init(&self) -> *const *const c_void {
Python::with_gil(|py| {
let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void;
if api.is_null() {
api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
self.api.store(api as *mut _, Ordering::Release);
}
api
})
}
unsafe fn get(&self, offset: isize) -> *const *const c_void {
let mut api = self.api.load(Ordering::Acquire) as *const *const c_void;
if api.is_null() {
api = self.init();
}
unsafe { self.api.get().offset(offset) }
api.offset(offset)
}
}

unsafe impl Sync for PyUFuncAPI {}

impl PyUFuncAPI {
impl_api![1; PyUFunc_FromFuncAndData(func: *mut PyUFuncGenericFunction, data: *mut *mut c_void, types: *mut c_char, ntypes: c_int, nin: c_int, nout: c_int, identity: c_int, name: *const c_char, doc: *const c_char, unused: c_int) -> *mut PyObject];
impl_api![2; PyUFunc_RegisterLoopForType(ufunc: *mut PyUFuncObject, usertype: c_int, function: PyUFuncGenericFunction, arg_types: *mut c_int, data: *mut c_void) -> c_int];