diff --git a/src/cereggii/_cereggii.pyi b/src/cereggii/_cereggii.pyi index bc7221ac..0a5752f4 100644 --- a/src/cereggii/_cereggii.pyi +++ b/src/cereggii/_cereggii.pyi @@ -1,8 +1,8 @@ -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from typing import Callable, NewType, SupportsComplex, SupportsFloat, SupportsInt Key = NewType("Key", object) -Value = NewType("Value", object | None) +Value = NewType("Value", object) Cancel = NewType("Cancel", object) Number = SupportsInt | SupportsFloat | SupportsComplex @@ -137,29 +137,36 @@ class AtomicDict: :returns: the input :param:`batch` dictionary, with substituted values. """ - # def aggregate(self, iterator: Iterator[tuple[Key, Value]], aggregation: Callable[[Key, Value, Value], Value]) - # -> None: - # """ - # Aggregate the values in this dictionary with those found in ``iterator``, as computed by ``aggregation``. - # - # The ``aggregation`` parameter expects a function that takes as input a key, the value currently stored - # in the dictionary, and the new value from ``iterator``, and then returns the aggregated value. - # - # For instance, to aggregate counts:: - # - # d = AtomicDict() - # - # it = [ - # ("red", 1), - # ("green", 3), - # ("blue", 40), - # ("red", 1), - # ] - # - # d.aggregate(it, lambda key, current, new: - # new if current is cereggii.NOT_FOUND else current + new - # ) - # """ + def reduce( + self, + iterable: Iterable[tuple[Key, Value]], + aggregate: Callable[[Key, Value, Value], Value], + chunk_size: int = 0, + ) -> None: + """ + Aggregate the values in this dictionary with those found in ``iterable``, as computed by ``aggregate``. + + The ``aggregate`` function takes as input a key, the value currently stored in the dictionary, + and the new value from ``iterator``. It returns the aggregated value. + + For instance, to aggregate some simple counts:: + + d = AtomicDict() + + data = [ + ("red", 1), + ("green", 42), + ("blue", 3), + ("red", 5), + ] + + def count(key, current, new): + if current is cereggii.NOT_FOUND: + return new + return current + new + + d.reduce(data, count) + """ def compact(self) -> None: ... def debug(self) -> dict: ... def rehash(self, o: object) -> int: ... diff --git a/src/cereggii/atomic_dict/insert.c b/src/cereggii/atomic_dict/insert.c index 55f1a8d9..54fe0b80 100644 --- a/src/cereggii/atomic_dict/insert.c +++ b/src/cereggii/atomic_dict/insert.c @@ -488,3 +488,227 @@ AtomicDict_SetItem(AtomicDict *self, PyObject *key, PyObject *value) fail: return -1; } + + +#define REDUCE_FLUSH_CHUNK_SIZE 64 + + +static inline int +reduce_flush(AtomicDict *self, PyObject *local_buffer, PyObject *aggregate) +{ + AtomicDict_Meta *meta = NULL; + PyObject **keys = NULL; + PyObject **expected = NULL; + PyObject **desired = NULL; + + keys = PyMem_RawMalloc(REDUCE_FLUSH_CHUNK_SIZE * sizeof(PyObject *)); + if (keys == NULL) + goto fail; + + expected = PyMem_RawMalloc(REDUCE_FLUSH_CHUNK_SIZE * sizeof(PyObject *)); + if (expected == NULL) + goto fail; + + desired = PyMem_RawMalloc(REDUCE_FLUSH_CHUNK_SIZE * sizeof(PyObject *)); + if (desired == NULL) + goto fail; + + Py_ssize_t pos = 0; + + get_meta: + meta = (AtomicDict_Meta *) AtomicRef_Get(self->metadata); + if (meta == NULL) + goto fail; + + for (;;) { + int num_items = 0; + int done = 0; + + for (int i = 0; i < REDUCE_FLUSH_CHUNK_SIZE; ++i) { + PyObject *values; + if (PyDict_Next(local_buffer, &pos, &keys[num_items], &values)) { + expected[num_items] = PyTuple_GetItem(values, 0); + desired[num_items] = PyTuple_GetItem(values, 1); + + if (expected[num_items] == NULL || desired[num_items] == NULL) + goto fail; + + num_items++; + } else { + done = 1; + break; + } + } + + for (int i = 0; i < num_items; ++i) { + // prefetch index + Py_hash_t hash = PyObject_Hash(keys[i]); + if (hash == -1) + goto fail; + + __builtin_prefetch(AtomicDict_IndexAddressOf( + AtomicDict_Distance0Of(hash, meta), + meta + )); + } + + for (int i = 0; i < num_items; ++i) { + PyObject *result = NULL; + PyObject *new = desired[i]; + + retry: + result = AtomicDict_CompareAndSet(self, keys[i], expected[i], new); + + if (result == NULL) + goto fail; + + if (result == EXPECTATION_FAILED) { + PyObject *current = AtomicDict_GetItemOrDefault(self, keys[i], NOT_FOUND); + if (current == NULL) + goto fail; + + expected[i] = current; + new = PyObject_CallFunctionObjArgs(aggregate, keys[i], expected[i], desired[i], NULL); + if (new == NULL) + goto fail; + + goto retry; + } + } + + if (done) + break; + + + if (meta->new_gen_metadata != NULL) + goto get_meta; + } + + PyMem_RawFree(keys); + PyMem_RawFree(expected); + PyMem_RawFree(desired); + Py_DECREF(meta); + return 0; + + fail: + if (keys != NULL) { + PyMem_RawFree(keys); + } + if (expected != NULL) { + PyMem_RawFree(expected); + } + if (desired != NULL) { + PyMem_RawFree(desired); + } + Py_XDECREF(meta); + return -1; +} + + +int +AtomicDict_Reduce(AtomicDict *self, PyObject *iterable, PyObject *aggregate, int chunk_size) +{ + PyObject *local_buffer = NULL; + PyObject *item = NULL; + + if (!PyCallable_Check(aggregate)) { + PyErr_Format(PyExc_TypeError, "%R is not callable.", aggregate); + goto fail; + } + + if (chunk_size != 0) { + PyErr_SetString(PyExc_ValueError, "only supporting chunk_size=0 at the moment."); + goto fail; + } + + local_buffer = PyDict_New(); + if (local_buffer == NULL) + goto fail; + + PyObject *iterator = PyObject_GetIter(iterable); + + if (iterator == NULL) { + PyErr_Format(PyExc_TypeError, "%R is not iterable.", iterable); + goto fail; + } + + while ((item = PyIter_Next(iterator))) { + PyObject *key = NULL; + PyObject *value = NULL; + PyObject *current = NULL; + PyObject *expected = NULL; + PyObject *desired = NULL; + + if (!PyTuple_CheckExact(item)) { + PyErr_Format(PyExc_TypeError, "type(%R) != tuple", item); + goto fail; + } + + if (PyTuple_Size(item) != 2) { + PyErr_Format(PyExc_TypeError, "len(%R) != 2", item); + goto fail; + } + + key = PyTuple_GetItem(item, 0); + value = PyTuple_GetItem(item, 1); + if (PyDict_Contains(local_buffer, key)) { + current = PyDict_GetItem(local_buffer, key); + expected = PyTuple_GetItem(current, 0); + current = PyTuple_GetItem(current, 1); + } else { + expected = NOT_FOUND; + current = expected; + } + + desired = PyObject_CallFunctionObjArgs(aggregate, key, current, value, NULL); + if (desired == NULL) + goto fail; + + PyObject *tuple = PyTuple_Pack(2, expected, desired); + if (tuple == NULL) + goto fail; + + if (PyDict_SetItem(local_buffer, key, tuple) < 0) + goto fail; + + Py_DECREF(item); + } + + Py_DECREF(iterator); + + if (PyErr_Occurred()) { + goto fail; + } + + reduce_flush(self, local_buffer, aggregate); + + Py_DECREF(local_buffer); + return 0; + + fail: + Py_XDECREF(local_buffer); + Py_XDECREF(item); + return -1; +} + +PyObject * +AtomicDict_Reduce_callable(AtomicDict *self, PyObject *args, PyObject *kwargs) +{ + PyObject *iterable = NULL; + PyObject *aggregate = NULL; + int chunk_size = 0; + + char *kw_list[] = {"iterable", "aggregate", "chunk_size", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO|i", kw_list, &iterable, &aggregate, &chunk_size)) + goto fail; + + int res = AtomicDict_Reduce(self, iterable, aggregate, chunk_size); + if (res < 0) + goto fail; + + Py_RETURN_NONE; + + fail: + return NULL; +} diff --git a/src/cereggii/cereggii.c b/src/cereggii/cereggii.c index d4842fb7..179ce595 100644 --- a/src/cereggii/cereggii.c +++ b/src/cereggii/cereggii.c @@ -222,6 +222,7 @@ static PyMethodDef AtomicDict_methods[] = { {"fast_iter", (PyCFunction) AtomicDict_FastIter, METH_VARARGS | METH_KEYWORDS, NULL}, {"compare_and_set", (PyCFunction) AtomicDict_CompareAndSet_callable, METH_VARARGS | METH_KEYWORDS, NULL}, {"batch_getitem", (PyCFunction) AtomicDict_BatchGetItem, METH_VARARGS | METH_KEYWORDS, NULL}, + {"reduce", (PyCFunction) AtomicDict_Reduce_callable, METH_VARARGS | METH_KEYWORDS, NULL}, {NULL} }; diff --git a/src/include/atomic_dict.h b/src/include/atomic_dict.h index 910b41b6..883be576 100644 --- a/src/include/atomic_dict.h +++ b/src/include/atomic_dict.h @@ -48,6 +48,10 @@ PyObject *AtomicDict_CompareAndSet(AtomicDict *self, PyObject *key, PyObject *ex PyObject *AtomicDict_CompareAndSet_callable(AtomicDict *self, PyObject *args, PyObject *kwargs); +int AtomicDict_Reduce(AtomicDict *self, PyObject *iterable, PyObject *aggregate, int chunk_size); + +PyObject *AtomicDict_Reduce_callable(AtomicDict *self, PyObject *args, PyObject *kwargs); + int AtomicDict_Compact(AtomicDict *self); PyObject *AtomicDict_Compact_callable(AtomicDict *self); diff --git a/tests/test_atomic_dict.py b/tests/test_atomic_dict.py index 47ee5255..7d0d8675 100644 --- a/tests/test_atomic_dict.py +++ b/tests/test_atomic_dict.py @@ -719,3 +719,25 @@ def test_len(): assert len(d) == 10 assert len(d) == 10 # test twice for len_dirty + + +def test_reduce(): + d = AtomicDict() + + data = [ + ("red", 1), + ("green", 42), + ("blue", 3), + ("red", 5), + ] + + def count(key, current, new): # noqa: ARG001 + if current is cereggii.NOT_FOUND: + return new + return current + new + + d.reduce(data, count) + + assert d["red"] == 6 + assert d["green"] == 42 + assert d["blue"] == 3