Skip to content

Commit

Permalink
gh-112066: Add PyDict_SetDefaultRef function. (#112123)
Browse files Browse the repository at this point in the history
The `PyDict_SetDefaultRef` function is similar to `PyDict_SetDefault`,
but returns a strong reference through the optional `**result` pointer
instead of a borrowed reference.

Co-authored-by: Petr Viktorin <encukou@gmail.com>
  • Loading branch information
colesbury and encukou authored Feb 6, 2024
1 parent 0e2ab73 commit de61d4b
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 20 deletions.
20 changes: 20 additions & 0 deletions Doc/c-api/dict.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,26 @@ Dictionary Objects
.. versionadded:: 3.4
.. c:function:: int PyDict_SetDefaultRef(PyObject *p, PyObject *key, PyObject *default_value, PyObject **result)
Inserts *default_value* into the dictionary *p* with a key of *key* if the
key is not already present in the dictionary. If *result* is not ``NULL``,
then *\*result* is set to a :term:`strong reference` to either
*default_value*, if the key was not present, or the existing value, if *key*
was already present in the dictionary.
Returns ``1`` if the key was present and *default_value* was not inserted,
or ``0`` if the key was not present and *default_value* was inserted.
On failure, returns ``-1``, sets an exception, and sets ``*result``
to ``NULL``.
For clarity: if you have a strong reference to *default_value* before
calling this function, then after it returns, you hold a strong reference
to both *default_value* and *\*result* (if it's not ``NULL``).
These may refer to the same object: in that case you hold two separate
references to it.
.. versionadded:: 3.13
.. c:function:: int PyDict_Pop(PyObject *p, PyObject *key, PyObject **result)
Remove *key* from dictionary *p* and optionally return the removed value.
Expand Down
6 changes: 6 additions & 0 deletions Doc/whatsnew/3.13.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1440,6 +1440,12 @@ New Features
not needed.
(Contributed by Victor Stinner in :gh:`106004`.)

* Added :c:func:`PyDict_SetDefaultRef`, which is similar to
:c:func:`PyDict_SetDefault` but returns a :term:`strong reference` instead of
a :term:`borrowed reference`. This function returns ``-1`` on error, ``0`` on
insertion, and ``1`` if the key was already present in the dictionary.
(Contributed by Sam Gross in :gh:`112066`.)

* Add :c:func:`PyDict_ContainsString` function: same as
:c:func:`PyDict_Contains`, but *key* is specified as a :c:expr:`const char*`
UTF-8 encoded bytes string, rather than a :c:expr:`PyObject*`.
Expand Down
10 changes: 10 additions & 0 deletions Include/cpython/dictobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ PyAPI_FUNC(PyObject *) _PyDict_GetItemStringWithError(PyObject *, const char *);
PyAPI_FUNC(PyObject *) PyDict_SetDefault(
PyObject *mp, PyObject *key, PyObject *defaultobj);

// Inserts `key` with a value `default_value`, if `key` is not already present
// in the dictionary. If `result` is not NULL, then the value associated
// with `key` is returned in `*result` (either the existing value, or the now
// inserted `default_value`).
// Returns:
// -1 on error
// 0 if `key` was not present and `default_value` was inserted
// 1 if `key` was present and `default_value` was not inserted
PyAPI_FUNC(int) PyDict_SetDefaultRef(PyObject *mp, PyObject *key, PyObject *default_value, PyObject **result);

/* Get the number of items of a dictionary. */
static inline Py_ssize_t PyDict_GET_SIZE(PyObject *op) {
PyDictObject *mp;
Expand Down
22 changes: 22 additions & 0 deletions Lib/test/test_capi/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,28 @@ def test_dict_setdefault(self):
# CRASHES setdefault({}, 'a', NULL)
# CRASHES setdefault(NULL, 'a', 5)

def test_dict_setdefaultref(self):
setdefault = _testcapi.dict_setdefaultref
dct = {}
self.assertEqual(setdefault(dct, 'a', 5), 5)
self.assertEqual(dct, {'a': 5})
self.assertEqual(setdefault(dct, 'a', 8), 5)
self.assertEqual(dct, {'a': 5})

dct2 = DictSubclass()
self.assertEqual(setdefault(dct2, 'a', 5), 5)
self.assertEqual(dct2, {'a': 5})
self.assertEqual(setdefault(dct2, 'a', 8), 5)
self.assertEqual(dct2, {'a': 5})

self.assertRaises(TypeError, setdefault, {}, [], 5) # unhashable
self.assertRaises(SystemError, setdefault, UserDict(), 'a', 5)
self.assertRaises(SystemError, setdefault, [1], 0, 5)
self.assertRaises(SystemError, setdefault, 42, 'a', 5)
# CRASHES setdefault({}, NULL, 5)
# CRASHES setdefault({}, 'a', NULL)
# CRASHES setdefault(NULL, 'a', 5)

def test_mapping_keys_valuesitems(self):
class BadMapping(dict):
def keys(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Add :c:func:`PyDict_SetDefaultRef`: insert a key and value into a dictionary
if the key is not already present. This is similar to
:meth:`dict.setdefault`, but returns an integer value indicating if the key
was already present. It is also similar to :c:func:`PyDict_SetDefault`, but
returns a strong reference instead of a borrowed reference.
26 changes: 26 additions & 0 deletions Modules/_testcapi/dict.c
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,31 @@ dict_setdefault(PyObject *self, PyObject *args)
return PyDict_SetDefault(mapping, key, defaultobj);
}

static PyObject *
dict_setdefaultref(PyObject *self, PyObject *args)
{
PyObject *obj, *key, *default_value, *result = UNINITIALIZED_PTR;
if (!PyArg_ParseTuple(args, "OOO", &obj, &key, &default_value)) {
return NULL;
}
NULLABLE(obj);
NULLABLE(key);
NULLABLE(default_value);
switch (PyDict_SetDefaultRef(obj, key, default_value, &result)) {
case -1:
assert(result == NULL);
return NULL;
case 0:
assert(result == default_value);
return result;
case 1:
return result;
default:
Py_FatalError("PyDict_SetDefaultRef() returned invalid code");
Py_UNREACHABLE();
}
}

static PyObject *
dict_delitem(PyObject *self, PyObject *args)
{
Expand Down Expand Up @@ -433,6 +458,7 @@ static PyMethodDef test_methods[] = {
{"dict_delitem", dict_delitem, METH_VARARGS},
{"dict_delitemstring", dict_delitemstring, METH_VARARGS},
{"dict_setdefault", dict_setdefault, METH_VARARGS},
{"dict_setdefaultref", dict_setdefaultref, METH_VARARGS},
{"dict_keys", dict_keys, METH_O},
{"dict_values", dict_values, METH_O},
{"dict_items", dict_items, METH_O},
Expand Down
91 changes: 71 additions & 20 deletions Objects/dictobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -3355,8 +3355,9 @@ dict_get_impl(PyDictObject *self, PyObject *key, PyObject *default_value)
return Py_NewRef(val);
}

PyObject *
PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
static int
dict_setdefault_ref(PyObject *d, PyObject *key, PyObject *default_value,
PyObject **result, int incref_result)
{
PyDictObject *mp = (PyDictObject *)d;
PyObject *value;
Expand All @@ -3365,41 +3366,64 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)

if (!PyDict_Check(d)) {
PyErr_BadInternalCall();
return NULL;
if (result) {
*result = NULL;
}
return -1;
}

if (!PyUnicode_CheckExact(key) || (hash = unicode_get_hash(key)) == -1) {
hash = PyObject_Hash(key);
if (hash == -1)
return NULL;
if (hash == -1) {
if (result) {
*result = NULL;
}
return -1;
}
}

if (mp->ma_keys == Py_EMPTY_KEYS) {
if (insert_to_emptydict(interp, mp, Py_NewRef(key), hash,
Py_NewRef(defaultobj)) < 0) {
return NULL;
Py_NewRef(default_value)) < 0) {
if (result) {
*result = NULL;
}
return -1;
}
if (result) {
*result = incref_result ? Py_NewRef(default_value) : default_value;
}
return defaultobj;
return 0;
}

if (!PyUnicode_CheckExact(key) && DK_IS_UNICODE(mp->ma_keys)) {
if (insertion_resize(interp, mp, 0) < 0) {
return NULL;
if (result) {
*result = NULL;
}
return -1;
}
}

Py_ssize_t ix = _Py_dict_lookup(mp, key, hash, &value);
if (ix == DKIX_ERROR)
return NULL;
if (ix == DKIX_ERROR) {
if (result) {
*result = NULL;
}
return -1;
}

if (ix == DKIX_EMPTY) {
uint64_t new_version = _PyDict_NotifyEvent(
interp, PyDict_EVENT_ADDED, mp, key, defaultobj);
interp, PyDict_EVENT_ADDED, mp, key, default_value);
mp->ma_keys->dk_version = 0;
value = defaultobj;
value = default_value;
if (mp->ma_keys->dk_usable <= 0) {
if (insertion_resize(interp, mp, 1) < 0) {
return NULL;
if (result) {
*result = NULL;
}
return -1;
}
}
Py_ssize_t hashpos = find_empty_slot(mp->ma_keys, hash);
Expand Down Expand Up @@ -3431,22 +3455,50 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
mp->ma_keys->dk_usable--;
mp->ma_keys->dk_nentries++;
assert(mp->ma_keys->dk_usable >= 0);
ASSERT_CONSISTENT(mp);
if (result) {
*result = incref_result ? Py_NewRef(value) : value;
}
return 0;
}
else if (value == NULL) {
uint64_t new_version = _PyDict_NotifyEvent(
interp, PyDict_EVENT_ADDED, mp, key, defaultobj);
value = defaultobj;
interp, PyDict_EVENT_ADDED, mp, key, default_value);
value = default_value;
assert(_PyDict_HasSplitTable(mp));
assert(mp->ma_values->values[ix] == NULL);
MAINTAIN_TRACKING(mp, key, value);
mp->ma_values->values[ix] = Py_NewRef(value);
_PyDictValues_AddToInsertionOrder(mp->ma_values, ix);
mp->ma_used++;
mp->ma_version_tag = new_version;
ASSERT_CONSISTENT(mp);
if (result) {
*result = incref_result ? Py_NewRef(value) : value;
}
return 0;
}

ASSERT_CONSISTENT(mp);
return value;
if (result) {
*result = incref_result ? Py_NewRef(value) : value;
}
return 1;
}

int
PyDict_SetDefaultRef(PyObject *d, PyObject *key, PyObject *default_value,
PyObject **result)
{
return dict_setdefault_ref(d, key, default_value, result, 1);
}

PyObject *
PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
{
PyObject *result;
dict_setdefault_ref(d, key, defaultobj, &result, 0);
return result;
}

/*[clinic input]
Expand All @@ -3467,9 +3519,8 @@ dict_setdefault_impl(PyDictObject *self, PyObject *key,
/*[clinic end generated code: output=f8c1101ebf69e220 input=0f063756e815fd9d]*/
{
PyObject *val;

val = PyDict_SetDefault((PyObject *)self, key, default_value);
return Py_XNewRef(val);
PyDict_SetDefaultRef((PyObject *)self, key, default_value, &val);
return val;
}


Expand Down

0 comments on commit de61d4b

Please sign in to comment.