Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-116738: Make _abc module thread-safe #117488

Merged
merged 4 commits into from
Apr 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Include/internal/pycore_typeobject.h
Original file line number Diff line number Diff line change
@@ -152,6 +152,18 @@ PyAPI_FUNC(PyObject*) _PySuper_Lookup(PyTypeObject *su_type, PyObject *su_obj,

extern PyObject* _PyType_GetFullyQualifiedName(PyTypeObject *type, char sep);

// Perform the following operation, in a thread-safe way when required by the
// build mode.
//
// self->tp_flags = (self->tp_flags & ~mask) | flags;
//
// Bits set in `mask` are cleared first, then the bits in `flags` are OR-ed in.
// Passing mask == 0 therefore only adds flags and never clears any.
extern void _PyType_SetFlags(PyTypeObject *self, unsigned long mask,
unsigned long flags);

// Like _PyType_SetFlags(), but apply the operation to self and any of its
// subclasses without Py_TPFLAGS_IMMUTABLETYPE set.
//
// The walk does not descend below a type that is immutable or whose
// (tp_flags & mask) already equals flags.
extern void _PyType_SetFlagsRecursive(PyTypeObject *self, unsigned long mask,
unsigned long flags);


#ifdef __cplusplus
}
262 changes: 147 additions & 115 deletions Modules/_abc.c
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ PyDoc_STRVAR(_abc__doc__,

typedef struct {
PyTypeObject *_abc_data_type;
unsigned long long abc_invalidation_counter;
uint64_t abc_invalidation_counter;
} _abcmodule_state;

static inline _abcmodule_state*
@@ -32,17 +32,61 @@ get_abc_state(PyObject *module)
return (_abcmodule_state *)state;
}

/* Return the module-wide ABC invalidation counter.
   When Py_GIL_DISABLED is defined the value is read with an atomic load;
   otherwise it is a plain read of the field. */
static inline uint64_t
get_invalidation_counter(_abcmodule_state *state)
{
    uint64_t *counter = &state->abc_invalidation_counter;
#ifdef Py_GIL_DISABLED
    return _Py_atomic_load_uint64(counter);
#else
    return *counter;
#endif
}

/* Advance the module-wide ABC invalidation counter by one.
   When Py_GIL_DISABLED is defined the increment is an atomic add;
   otherwise it is a plain in-place increment. */
static inline void
increment_invalidation_counter(_abcmodule_state *state)
{
    uint64_t *counter = &state->abc_invalidation_counter;
#ifdef Py_GIL_DISABLED
    _Py_atomic_add_uint64(counter, 1);
#else
    *counter += 1;
#endif
}

/* This object stores internal state for ABCs.
Note that we can use normal sets for caches,
since they are never iterated over. */
typedef struct {
PyObject_HEAD
/* These sets of weak references are lazily created. Once created, they
will point to the same sets until the ABCMeta object is destroyed or
cleared, both of which will only happen while the object is visible to a
single thread. */
PyObject *_abc_registry;
PyObject *_abc_cache; /* Normal set of weak references. */
PyObject *_abc_negative_cache; /* Normal set of weak references. */
unsigned long long _abc_negative_cache_version;
PyObject *_abc_cache;
PyObject *_abc_negative_cache;
uint64_t _abc_negative_cache_version;
} _abc_data;

/* Return the negative-cache version recorded on this _abc_data object.
   When Py_GIL_DISABLED is defined the value is read with an atomic load;
   otherwise it is a plain read of the field. */
static inline uint64_t
get_cache_version(_abc_data *impl)
{
    uint64_t *version = &impl->_abc_negative_cache_version;
#ifdef Py_GIL_DISABLED
    return _Py_atomic_load_uint64(version);
#else
    return *version;
#endif
}

/* Record `version` as the negative-cache version of this _abc_data object.
   When Py_GIL_DISABLED is defined the value is written with an atomic store;
   otherwise it is a plain assignment to the field. */
static inline void
set_cache_version(_abc_data *impl, uint64_t version)
{
    uint64_t *slot = &impl->_abc_negative_cache_version;
#ifdef Py_GIL_DISABLED
    _Py_atomic_store_uint64(slot, version);
#else
    *slot = version;
#endif
}

static int
abc_data_traverse(_abc_data *self, visitproc visit, void *arg)
{
@@ -90,7 +134,7 @@ abc_data_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
self->_abc_registry = NULL;
self->_abc_cache = NULL;
self->_abc_negative_cache = NULL;
self->_abc_negative_cache_version = state->abc_invalidation_counter;
self->_abc_negative_cache_version = get_invalidation_counter(state);
return (PyObject *) self;
}

@@ -130,8 +174,12 @@ _get_impl(PyObject *module, PyObject *self)
}

static int
_in_weak_set(PyObject *set, PyObject *obj)
_in_weak_set(_abc_data *impl, PyObject **pset, PyObject *obj)
{
PyObject *set;
Py_BEGIN_CRITICAL_SECTION(impl);
set = *pset;
Py_END_CRITICAL_SECTION();
if (set == NULL || PySet_GET_SIZE(set) == 0) {
return 0;
}
@@ -168,16 +216,19 @@ static PyMethodDef _destroy_def = {
};

static int
_add_to_weak_set(PyObject **pset, PyObject *obj)
_add_to_weak_set(_abc_data *impl, PyObject **pset, PyObject *obj)
{
if (*pset == NULL) {
*pset = PySet_New(NULL);
if (*pset == NULL) {
return -1;
}
PyObject *set;
Py_BEGIN_CRITICAL_SECTION(impl);
set = *pset;
if (set == NULL) {
set = *pset = PySet_New(NULL);
}
Py_END_CRITICAL_SECTION();
if (set == NULL) {
return -1;
}

PyObject *set = *pset;
PyObject *ref, *wr;
PyObject *destroy_cb;
wr = PyWeakref_NewRef(set, NULL);
@@ -220,7 +271,11 @@ _abc__reset_registry(PyObject *module, PyObject *self)
if (impl == NULL) {
return NULL;
}
if (impl->_abc_registry != NULL && PySet_Clear(impl->_abc_registry) < 0) {
PyObject *registry;
Py_BEGIN_CRITICAL_SECTION(impl);
registry = impl->_abc_registry;
Py_END_CRITICAL_SECTION();
if (registry != NULL && PySet_Clear(registry) < 0) {
Py_DECREF(impl);
return NULL;
}
@@ -247,13 +302,17 @@ _abc__reset_caches(PyObject *module, PyObject *self)
if (impl == NULL) {
return NULL;
}
if (impl->_abc_cache != NULL && PySet_Clear(impl->_abc_cache) < 0) {
PyObject *cache, *negative_cache;
Py_BEGIN_CRITICAL_SECTION(impl);
cache = impl->_abc_cache;
negative_cache = impl->_abc_negative_cache;
Py_END_CRITICAL_SECTION();
if (cache != NULL && PySet_Clear(cache) < 0) {
Py_DECREF(impl);
return NULL;
}
/* also the second cache */
if (impl->_abc_negative_cache != NULL &&
PySet_Clear(impl->_abc_negative_cache) < 0) {
if (negative_cache != NULL && PySet_Clear(negative_cache) < 0) {
Py_DECREF(impl);
return NULL;
}
@@ -282,11 +341,14 @@ _abc__get_dump(PyObject *module, PyObject *self)
if (impl == NULL) {
return NULL;
}
PyObject *res = Py_BuildValue("NNNK",
PySet_New(impl->_abc_registry),
PySet_New(impl->_abc_cache),
PySet_New(impl->_abc_negative_cache),
impl->_abc_negative_cache_version);
PyObject *res;
Py_BEGIN_CRITICAL_SECTION(impl);
res = Py_BuildValue("NNNK",
PySet_New(impl->_abc_registry),
PySet_New(impl->_abc_cache),
PySet_New(impl->_abc_negative_cache),
get_cache_version(impl));
Py_END_CRITICAL_SECTION();
Py_DECREF(impl);
return res;
}
@@ -453,56 +515,27 @@ _abc__abc_init(PyObject *module, PyObject *self)
if (PyType_Check(self)) {
PyTypeObject *cls = (PyTypeObject *)self;
PyObject *dict = _PyType_GetDict(cls);
PyObject *flags = PyDict_GetItemWithError(dict,
&_Py_ID(__abc_tpflags__));
if (flags == NULL) {
if (PyErr_Occurred()) {
return NULL;
}
PyObject *flags = NULL;
if (PyDict_Pop(dict, &_Py_ID(__abc_tpflags__), &flags) < 0) {
return NULL;
}
else {
if (PyLong_CheckExact(flags)) {
long val = PyLong_AsLong(flags);
if (val == -1 && PyErr_Occurred()) {
return NULL;
}
if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
return NULL;
}
((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS);
}
if (PyDict_DelItem(dict, &_Py_ID(__abc_tpflags__)) < 0) {
return NULL;
}
if (flags == NULL || !PyLong_CheckExact(flags)) {
Py_XDECREF(flags);
Py_RETURN_NONE;
}
}
Py_RETURN_NONE;
}

static void
set_collection_flag_recursive(PyTypeObject *child, unsigned long flag)
{
assert(flag == Py_TPFLAGS_MAPPING || flag == Py_TPFLAGS_SEQUENCE);
if (PyType_HasFeature(child, Py_TPFLAGS_IMMUTABLETYPE) ||
(child->tp_flags & COLLECTION_FLAGS) == flag)
{
return;
}

child->tp_flags &= ~COLLECTION_FLAGS;
child->tp_flags |= flag;

PyObject *grandchildren = _PyType_GetSubclasses(child);
if (grandchildren == NULL) {
return;
}

for (Py_ssize_t i = 0; i < PyList_GET_SIZE(grandchildren); i++) {
PyObject *grandchild = PyList_GET_ITEM(grandchildren, i);
set_collection_flag_recursive((PyTypeObject *)grandchild, flag);
long val = PyLong_AsLong(flags);
Py_DECREF(flags);
if (val == -1 && PyErr_Occurred()) {
return NULL;
}
if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
return NULL;
}
_PyType_SetFlags((PyTypeObject *)self, 0, val & COLLECTION_FLAGS);
}
Py_DECREF(grandchildren);
Py_RETURN_NONE;
}

/*[clinic input]
@@ -545,20 +578,23 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
if (impl == NULL) {
return NULL;
}
if (_add_to_weak_set(&impl->_abc_registry, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_registry, subclass) < 0) {
Py_DECREF(impl);
return NULL;
}
Py_DECREF(impl);

/* Invalidate negative cache */
get_abc_state(module)->abc_invalidation_counter++;
increment_invalidation_counter(get_abc_state(module));

/* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
/* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
if (PyType_Check(self)) {
unsigned long collection_flag = ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS;
unsigned long collection_flag =
PyType_GetFlags((PyTypeObject *)self) & COLLECTION_FLAGS;
if (collection_flag) {
set_collection_flag_recursive((PyTypeObject *)subclass, collection_flag);
_PyType_SetFlagsRecursive((PyTypeObject *)subclass,
COLLECTION_FLAGS,
collection_flag);
}
}
return Py_NewRef(subclass);
@@ -592,7 +628,7 @@ _abc__abc_instancecheck_impl(PyObject *module, PyObject *self,
return NULL;
}
/* Inline the cache checking. */
int incache = _in_weak_set(impl->_abc_cache, subclass);
int incache = _in_weak_set(impl, &impl->_abc_cache, subclass);
if (incache < 0) {
goto end;
}
@@ -602,8 +638,8 @@ _abc__abc_instancecheck_impl(PyObject *module, PyObject *self,
}
subtype = (PyObject *)Py_TYPE(instance);
if (subtype == subclass) {
if (impl->_abc_negative_cache_version == get_abc_state(module)->abc_invalidation_counter) {
incache = _in_weak_set(impl->_abc_negative_cache, subclass);
if (get_cache_version(impl) == get_invalidation_counter(get_abc_state(module))) {
incache = _in_weak_set(impl, &impl->_abc_negative_cache, subclass);
if (incache < 0) {
goto end;
}
@@ -681,7 +717,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
}

/* 1. Check cache. */
incache = _in_weak_set(impl->_abc_cache, subclass);
incache = _in_weak_set(impl, &impl->_abc_cache, subclass);
if (incache < 0) {
goto end;
}
@@ -692,17 +728,20 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,

state = get_abc_state(module);
/* 2. Check negative cache; may have to invalidate. */
if (impl->_abc_negative_cache_version < state->abc_invalidation_counter) {
uint64_t invalidation_counter = get_invalidation_counter(state);
if (get_cache_version(impl) < invalidation_counter) {
/* Invalidate the negative cache. */
if (impl->_abc_negative_cache != NULL &&
PySet_Clear(impl->_abc_negative_cache) < 0)
{
PyObject *negative_cache;
Py_BEGIN_CRITICAL_SECTION(impl);
negative_cache = impl->_abc_negative_cache;
Py_END_CRITICAL_SECTION();
if (negative_cache != NULL && PySet_Clear(negative_cache) < 0) {
goto end;
}
impl->_abc_negative_cache_version = state->abc_invalidation_counter;
set_cache_version(impl, invalidation_counter);
}
else {
incache = _in_weak_set(impl->_abc_negative_cache, subclass);
incache = _in_weak_set(impl, &impl->_abc_negative_cache, subclass);
if (incache < 0) {
goto end;
}
@@ -720,15 +759,15 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
}
if (ok == Py_True) {
Py_DECREF(ok);
if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
goto end;
}
result = Py_True;
goto end;
}
if (ok == Py_False) {
Py_DECREF(ok);
if (_add_to_weak_set(&impl->_abc_negative_cache, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_negative_cache, subclass) < 0) {
goto end;
}
result = Py_False;
@@ -744,7 +783,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,

/* 4. Check if it's a direct subclass. */
if (PyType_IsSubtype((PyTypeObject *)subclass, (PyTypeObject *)self)) {
if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
goto end;
}
result = Py_True;
@@ -767,12 +806,14 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
goto end;
}
for (pos = 0; pos < PyList_GET_SIZE(subclasses); pos++) {
PyObject *scls = PyList_GET_ITEM(subclasses, pos);
Py_INCREF(scls);
PyObject *scls = PyList_GetItemRef(subclasses, pos);
if (scls == NULL) {
goto end;
}
int r = PyObject_IsSubclass(subclass, scls);
Py_DECREF(scls);
if (r > 0) {
if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
goto end;
}
result = Py_True;
@@ -784,7 +825,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
}

/* No dice; update negative cache. */
if (_add_to_weak_set(&impl->_abc_negative_cache, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_negative_cache, subclass) < 0) {
goto end;
}
result = Py_False;
@@ -801,7 +842,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
PyObject **result)
{
// Fast path: check subclass is in weakref directly.
int ret = _in_weak_set(impl->_abc_registry, subclass);
int ret = _in_weak_set(impl, &impl->_abc_registry, subclass);
if (ret < 0) {
*result = NULL;
return -1;
@@ -811,33 +852,27 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
return 1;
}

if (impl->_abc_registry == NULL) {
PyObject *registry_shared;
Py_BEGIN_CRITICAL_SECTION(impl);
registry_shared = impl->_abc_registry;
Py_END_CRITICAL_SECTION();
if (registry_shared == NULL) {
return 0;
}
Py_ssize_t registry_size = PySet_Size(impl->_abc_registry);
if (registry_size == 0) {
return 0;
}
// Weakref callback may remove entry from set.
// So we take snapshot of registry first.
PyObject **copy = PyMem_Malloc(sizeof(PyObject*) * registry_size);
if (copy == NULL) {
PyErr_NoMemory();

// Make a local copy of the registry to protect against concurrent
// modifications of _abc_registry.
PyObject *registry = PySet_New(registry_shared);
if (registry == NULL) {
return -1;
}
PyObject *key;
Py_ssize_t pos = 0;
Py_hash_t hash;
Py_ssize_t i = 0;

while (_PySet_NextEntry(impl->_abc_registry, &pos, &key, &hash)) {
copy[i++] = Py_NewRef(key);
}
assert(i == registry_size);

for (i = 0; i < registry_size; i++) {
while (_PySet_NextEntry(registry, &pos, &key, &hash)) {
PyObject *rkey;
if (PyWeakref_GetRef(copy[i], &rkey) < 0) {
if (PyWeakref_GetRef(key, &rkey) < 0) {
// Someone injected a non-weakref type into the registry.
ret = -1;
break;
@@ -853,7 +888,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
break;
}
if (r > 0) {
if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
ret = -1;
break;
}
@@ -863,10 +898,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
}
}

for (i = 0; i < registry_size; i++) {
Py_DECREF(copy[i]);
}
PyMem_Free(copy);
Py_DECREF(registry);
return ret;
}

@@ -885,7 +917,7 @@ _abc_get_cache_token_impl(PyObject *module)
/*[clinic end generated code: output=c7d87841e033dacc input=70413d1c423ad9f9]*/
{
_abcmodule_state *state = get_abc_state(module);
return PyLong_FromUnsignedLongLong(state->abc_invalidation_counter);
return PyLong_FromUnsignedLongLong(get_invalidation_counter(state));
}

static struct PyMethodDef _abcmodule_methods[] = {
46 changes: 46 additions & 0 deletions Objects/typeobject.c
Original file line number Diff line number Diff line change
@@ -5117,6 +5117,52 @@ _PyType_LookupId(PyTypeObject *type, _Py_Identifier *name)
return _PyType_Lookup(type, oname);
}

/* Clear the bits selected by `mask` in self->tp_flags, then OR in `flags`.
   ASSERT_TYPE_LOCK_HELD() is invoked first, so callers are expected to
   already be inside BEGIN_TYPE_LOCK()/END_TYPE_LOCK(). */
static void
set_flags(PyTypeObject *self, unsigned long mask, unsigned long flags)
{
    ASSERT_TYPE_LOCK_HELD();
    unsigned long kept = self->tp_flags & ~mask;
    self->tp_flags = kept | flags;
}

// Apply `tp_flags = (tp_flags & ~mask) | flags` to a single type.
// The update is bracketed by BEGIN_TYPE_LOCK()/END_TYPE_LOCK() because
// set_flags() asserts that the type lock is held.
void
_PyType_SetFlags(PyTypeObject *self, unsigned long mask, unsigned long flags)
{
BEGIN_TYPE_LOCK();
set_flags(self, mask, flags);
END_TYPE_LOCK();
}

/* Apply set_flags() to `self` and then to each type returned by
   _PyType_GetSubclasses(self), recursively.

   The walk stops at (and does not descend below) a type that has
   Py_TPFLAGS_IMMUTABLETYPE set, or whose masked flag bits already equal
   `flags`.  If the subclass list cannot be obtained, the function returns
   after having updated `self` only. */
static void
set_flags_recursive(PyTypeObject *self, unsigned long mask, unsigned long flags)
{
    /* Immutable types are left untouched. */
    if (PyType_HasFeature(self, Py_TPFLAGS_IMMUTABLETYPE)) {
        return;
    }
    /* Nothing to do when the masked bits already have the requested value. */
    if ((self->tp_flags & mask) == flags) {
        return;
    }

    set_flags(self, mask, flags);

    PyObject *subclasses = _PyType_GetSubclasses(self);
    if (subclasses != NULL) {
        Py_ssize_t idx = 0;
        while (idx < PyList_GET_SIZE(subclasses)) {
            PyTypeObject *sub =
                (PyTypeObject *)PyList_GET_ITEM(subclasses, idx);
            set_flags_recursive(sub, mask, flags);
            idx++;
        }
        Py_DECREF(subclasses);
    }
}

// Apply `tp_flags = (tp_flags & ~mask) | flags` to `self` and, through
// set_flags_recursive(), to its subclasses.
// BEGIN_TYPE_LOCK()/END_TYPE_LOCK() bracket the entire recursive walk, not
// each individual type update.
void
_PyType_SetFlagsRecursive(PyTypeObject *self, unsigned long mask, unsigned long flags)
{
BEGIN_TYPE_LOCK();
set_flags_recursive(self, mask, flags);
END_TYPE_LOCK();
}

/* This is similar to PyObject_GenericGetAttr(),
but uses _PyType_Lookup() instead of just looking in type->tp_dict.