Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-91052: Add PyDict_Unwatch for unwatching a dictionary #98055

Merged
merged 2 commits into from
Oct 8, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion Doc/c-api/dict.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,24 +246,41 @@ Dictionary Objects
of error (e.g. no more watcher IDs available), return ``-1`` and set an
exception.

.. versionadded:: 3.12

.. c:function:: int PyDict_ClearWatcher(int watcher_id)

Clear watcher identified by *watcher_id* previously returned from
:c:func:`PyDict_AddWatcher`. Return ``0`` on success, ``-1`` on error (e.g.
if the given *watcher_id* was never registered.)

.. versionadded:: 3.12

.. c:function:: int PyDict_Watch(int watcher_id, PyObject *dict)

Mark dictionary *dict* as watched. The callback granted *watcher_id* by
:c:func:`PyDict_AddWatcher` will be called when *dict* is modified or
deallocated.
deallocated. Return ``0`` on success or ``-1`` on error.

.. versionadded:: 3.12

.. c:function:: int PyDict_Unwatch(int watcher_id, PyObject *dict)

Mark dictionary *dict* as no longer watched. The callback granted
*watcher_id* by :c:func:`PyDict_AddWatcher` will no longer be called when
*dict* is modified or deallocated. The dict must previously have been
watched by this watcher. Return ``0`` on success or ``-1`` on error.

.. versionadded:: 3.12

.. c:type:: PyDict_WatchEvent

Enumeration of possible dictionary watcher events: ``PyDict_EVENT_ADDED``,
``PyDict_EVENT_MODIFIED``, ``PyDict_EVENT_DELETED``, ``PyDict_EVENT_CLONED``,
``PyDict_EVENT_CLEARED``, or ``PyDict_EVENT_DEALLOCATED``.

.. versionadded:: 3.12

.. c:type:: int (*PyDict_WatchCallback)(PyDict_WatchEvent event, PyObject *dict, PyObject *key, PyObject *new_value)

Type of a dict watcher callback function.
Expand All @@ -289,3 +306,5 @@ Dictionary Objects
If the callback returns with an exception set, it must return ``-1``; this
exception will be printed as an unraisable exception using
:c:func:`PyErr_WriteUnraisable`. Otherwise it should return ``0``.

.. versionadded:: 3.12
5 changes: 5 additions & 0 deletions Doc/whatsnew/3.12.rst
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,11 @@ New Features
which sets the vectorcall field of a given :c:type:`PyFunctionObject`.
(Contributed by Andrew Frost in :gh:`92257`.)

* The C API now permits registering callbacks via :c:func:`PyDict_AddWatcher`,
:c:func:`PyDict_AddWatch` and related APIs to be called whenever a dictionary
is modified. This is intended for use by optimizing interpreters, JIT
compilers, or debuggers.

Porting to Python 3.12
----------------------

Expand Down
1 change: 1 addition & 0 deletions Include/cpython/dictobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,4 @@ PyAPI_FUNC(int) PyDict_ClearWatcher(int watcher_id);

// Mark given dictionary as "watched" (callback will be called if it is modified)
PyAPI_FUNC(int) PyDict_Watch(int watcher_id, PyObject* dict);
PyAPI_FUNC(int) PyDict_Unwatch(int watcher_id, PyObject* dict);
60 changes: 44 additions & 16 deletions Lib/test/test_capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import weakref
from test import support
from test.support import MISSING_C_DOCSTRINGS
from test.support import catch_unraisable_exception
from test.support import import_helper
from test.support import threading_helper
from test.support import warnings_helper
Expand Down Expand Up @@ -1421,6 +1422,9 @@ def assert_events(self, expected):
def watch(self, wid, d):
_testcapi.watch_dict(wid, d)

def unwatch(self, wid, d):
_testcapi.unwatch_dict(wid, d)

def test_set_new_item(self):
d = {}
with self.watcher() as wid:
Expand Down Expand Up @@ -1477,27 +1481,24 @@ def test_dealloc(self):
del d
self.assert_events(["dealloc"])

def test_unwatch(self):
d = {}
with self.watcher() as wid:
self.watch(wid, d)
d["foo"] = "bar"
self.unwatch(wid, d)
d["hmm"] = "baz"
self.assert_events(["new:foo:bar"])

def test_error(self):
d = {}
unraisables = []
def unraisable_hook(unraisable):
unraisables.append(unraisable)
with self.watcher(kind=self.ERROR) as wid:
self.watch(wid, d)
orig_unraisable_hook = sys.unraisablehook
sys.unraisablehook = unraisable_hook
try:
with catch_unraisable_exception() as cm:
d["foo"] = "bar"
finally:
sys.unraisablehook = orig_unraisable_hook
self.assertIs(cm.unraisable.object, d)
self.assertEqual(str(cm.unraisable.exc_value), "boom!")
self.assert_events([])
self.assertEqual(len(unraisables), 1)
unraisable = unraisables[0]
self.assertIs(unraisable.object, d)
self.assertEqual(str(unraisable.exc_value), "boom!")
# avoid leaking reference cycles
del unraisable
del unraisables

def test_two_watchers(self):
d1 = {}
Expand All @@ -1522,11 +1523,38 @@ def test_watch_out_of_range_watcher_id(self):
with self.assertRaisesRegex(ValueError, r"Invalid dict watcher ID 8"):
self.watch(8, d) # DICT_MAX_WATCHERS = 8

def test_unassigned_watcher_id(self):
def test_watch_unassigned_watcher_id(self):
d = {}
with self.assertRaisesRegex(ValueError, r"No dict watcher set for ID 1"):
self.watch(1, d)

def test_unwatch_non_dict(self):
with self.watcher() as wid:
with self.assertRaisesRegex(ValueError, r"Cannot watch non-dictionary"):
self.unwatch(wid, 1)

def test_unwatch_out_of_range_watcher_id(self):
d = {}
with self.assertRaisesRegex(ValueError, r"Invalid dict watcher ID -1"):
self.unwatch(-1, d)
with self.assertRaisesRegex(ValueError, r"Invalid dict watcher ID 8"):
self.unwatch(8, d) # DICT_MAX_WATCHERS = 8

def test_unwatch_unassigned_watcher_id(self):
d = {}
with self.assertRaisesRegex(ValueError, r"No dict watcher set for ID 1"):
self.unwatch(1, d)

def test_clear_out_of_range_watcher_id(self):
with self.assertRaisesRegex(ValueError, r"Invalid dict watcher ID -1"):
self.clear_watcher(-1)
with self.assertRaisesRegex(ValueError, r"Invalid dict watcher ID 8"):
self.clear_watcher(8) # DICT_MAX_WATCHERS = 8

def test_clear_unassigned_watcher_id(self):
with self.assertRaisesRegex(ValueError, r"No dict watcher set for ID 1"):
self.clear_watcher(1)


if __name__ == "__main__":
unittest.main()
15 changes: 15 additions & 0 deletions Modules/_testcapimodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -5296,6 +5296,20 @@ watch_dict(PyObject *self, PyObject *args)
Py_RETURN_NONE;
}

static PyObject *
unwatch_dict(PyObject *self, PyObject *args)
{
PyObject *dict;
int watcher_id;
if (!PyArg_ParseTuple(args, "iO", &watcher_id, &dict)) {
return NULL;
}
if (PyDict_Unwatch(watcher_id, dict)) {
return NULL;
}
Py_RETURN_NONE;
}

static PyObject *
get_dict_watcher_events(PyObject *self, PyObject *Py_UNUSED(args))
{
Expand Down Expand Up @@ -5904,6 +5918,7 @@ static PyMethodDef TestMethods[] = {
{"add_dict_watcher", add_dict_watcher, METH_O, NULL},
{"clear_dict_watcher", clear_dict_watcher, METH_O, NULL},
{"watch_dict", watch_dict, METH_VARARGS, NULL},
{"unwatch_dict", unwatch_dict, METH_VARARGS, NULL},
{"get_dict_watcher_events", get_dict_watcher_events, METH_NOARGS, NULL},
{NULL, NULL} /* sentinel */
};
Expand Down
41 changes: 30 additions & 11 deletions Objects/dictobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -5720,23 +5720,47 @@ uint32_t _PyDictKeys_GetVersionForCurrentState(PyDictKeysObject *dictkeys)
return v;
}

static inline int
validate_watcher_id(PyInterpreterState *interp, int watcher_id)
{
if (watcher_id < 0 || watcher_id >= DICT_MAX_WATCHERS) {
PyErr_Format(PyExc_ValueError, "Invalid dict watcher ID %d", watcher_id);
return -1;
}
if (!interp->dict_watchers[watcher_id]) {
PyErr_Format(PyExc_ValueError, "No dict watcher set for ID %d", watcher_id);
return -1;
}
return 0;
}

int
PyDict_Watch(int watcher_id, PyObject* dict)
{
if (!PyDict_Check(dict)) {
PyErr_SetString(PyExc_ValueError, "Cannot watch non-dictionary");
return -1;
}
if (watcher_id < 0 || watcher_id >= DICT_MAX_WATCHERS) {
PyErr_Format(PyExc_ValueError, "Invalid dict watcher ID %d", watcher_id);
PyInterpreterState *interp = _PyInterpreterState_GET();
if (validate_watcher_id(interp, watcher_id)) {
return -1;
}
((PyDictObject*)dict)->ma_version_tag |= (1LL << watcher_id);
return 0;
}

int
PyDict_Unwatch(int watcher_id, PyObject* dict)
{
if (!PyDict_Check(dict)) {
PyErr_SetString(PyExc_ValueError, "Cannot watch non-dictionary");
return -1;
}
PyInterpreterState *interp = _PyInterpreterState_GET();
if (!interp->dict_watchers[watcher_id]) {
PyErr_Format(PyExc_ValueError, "No dict watcher set for ID %d", watcher_id);
if (validate_watcher_id(interp, watcher_id)) {
return -1;
}
((PyDictObject*)dict)->ma_version_tag |= (1LL << watcher_id);
((PyDictObject*)dict)->ma_version_tag &= ~(1LL << watcher_id);
return 0;
}

Expand All @@ -5759,13 +5783,8 @@ PyDict_AddWatcher(PyDict_WatchCallback callback)
int
PyDict_ClearWatcher(int watcher_id)
{
if (watcher_id < 0 || watcher_id >= DICT_MAX_WATCHERS) {
PyErr_Format(PyExc_ValueError, "Invalid dict watcher ID %d", watcher_id);
return -1;
}
PyInterpreterState *interp = _PyInterpreterState_GET();
if (!interp->dict_watchers[watcher_id]) {
PyErr_Format(PyExc_ValueError, "No dict watcher set for ID %d", watcher_id);
if (validate_watcher_id(interp, watcher_id)) {
return -1;
}
interp->dict_watchers[watcher_id] = NULL;
Expand Down