Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AtomicDict.reduce() #26

Merged
merged 3 commits into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 32 additions & 25 deletions src/cereggii/_cereggii.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from typing import Callable, NewType, SupportsComplex, SupportsFloat, SupportsInt

Key = NewType("Key", object)
Value = NewType("Value", object | None)
Value = NewType("Value", object)
Cancel = NewType("Cancel", object)

Number = SupportsInt | SupportsFloat | SupportsComplex
Expand Down Expand Up @@ -137,29 +137,36 @@ class AtomicDict:

:returns: the input :param:`batch` dictionary, with substituted values.
"""
# def aggregate(self, iterator: Iterator[tuple[Key, Value]], aggregation: Callable[[Key, Value, Value], Value])
# -> None:
# """
# Aggregate the values in this dictionary with those found in ``iterator``, as computed by ``aggregation``.
#
# The ``aggregation`` parameter expects a function that takes as input a key, the value currently stored
# in the dictionary, and the new value from ``iterator``, and then returns the aggregated value.
#
# For instance, to aggregate counts::
#
# d = AtomicDict()
#
# it = [
# ("red", 1),
# ("green", 3),
# ("blue", 40),
# ("red", 1),
# ]
#
# d.aggregate(it, lambda key, current, new:
# new if current is cereggii.NOT_FOUND else current + new
# )
# """
def reduce(
    self,
    iterable: Iterable[tuple[Key, Value]],
    aggregate: Callable[[Key, Value, Value], Value],
    chunk_size: int = 0,
) -> None:
    """
    Aggregate the values in this dictionary with those found in ``iterable``, as computed by ``aggregate``.

    The ``aggregate`` function takes as input a key, the value currently stored in the dictionary,
    and the new value from ``iterable``. It returns the aggregated value.
    When no value is stored for the key yet, ``cereggii.NOT_FOUND`` is passed as the current value.

    Each item of ``iterable`` must be exactly a 2-tuple ``(key, value)``; otherwise a
    :class:`TypeError` is raised. Only ``chunk_size=0`` is supported at the moment; any other
    value raises a :class:`ValueError`.

    For instance, to aggregate some simple counts::

        d = AtomicDict()

        data = [
            ("red", 1),
            ("green", 42),
            ("blue", 3),
            ("red", 5),
        ]

        def count(key, current, new):
            if current is cereggii.NOT_FOUND:
                return new
            return current + new

        d.reduce(data, count)
    """
def compact(self) -> None: ...
def debug(self) -> dict: ...
def rehash(self, o: object) -> int: ...
Expand Down
224 changes: 224 additions & 0 deletions src/cereggii/atomic_dict/insert.c
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,227 @@ AtomicDict_SetItem(AtomicDict *self, PyObject *key, PyObject *value)
fail:
return -1;
}


#define REDUCE_FLUSH_CHUNK_SIZE 64


/*
 * Flush a locally aggregated buffer into the shared dictionary.
 *
 * ``local_buffer`` is a regular dict mapping each key to a 2-tuple
 * ``(expected, desired)``: the value the caller believes is stored in
 * ``self`` for that key, and the value it wants to store instead.
 *
 * Items are processed in chunks of REDUCE_FLUSH_CHUNK_SIZE: the chunk is
 * first read out of the buffer, then the index slot of every key is
 * prefetched, and finally each key is written with a compare-and-set.
 * When the compare-and-set reports EXPECTATION_FAILED, the value currently
 * stored is read back, ``aggregate(key, current, desired)`` is called to
 * compute a new value, and the compare-and-set is retried.
 *
 * Returns 0 on success, -1 on failure.
 */
static inline int
reduce_flush(AtomicDict *self, PyObject *local_buffer, PyObject *aggregate)
{
    AtomicDict_Meta *meta = NULL;
    PyObject **keys = NULL;
    PyObject **expected = NULL;
    PyObject **desired = NULL;
    /* Strong reference to the most recent re-aggregated value, if any. */
    PyObject *owned_new = NULL;
    Py_ssize_t pos = 0;

    keys = PyMem_RawMalloc(REDUCE_FLUSH_CHUNK_SIZE * sizeof(PyObject *));
    if (keys == NULL)
        goto fail;

    expected = PyMem_RawMalloc(REDUCE_FLUSH_CHUNK_SIZE * sizeof(PyObject *));
    if (expected == NULL)
        goto fail;

    desired = PyMem_RawMalloc(REDUCE_FLUSH_CHUNK_SIZE * sizeof(PyObject *));
    if (desired == NULL)
        goto fail;

get_meta:
    /* Release the previously held metadata (if any) before re-reading it,
     * otherwise every jump back to this label leaks one reference. */
    Py_XDECREF(meta);
    meta = (AtomicDict_Meta *) AtomicRef_Get(self->metadata);
    if (meta == NULL)
        goto fail;

    for (;;) {
        int num_items = 0;
        int done = 0;

        /* Read up to one chunk of (key, expected, desired) out of the buffer.
         * All three are borrowed references owned by local_buffer. */
        for (int i = 0; i < REDUCE_FLUSH_CHUNK_SIZE; ++i) {
            PyObject *values;
            if (PyDict_Next(local_buffer, &pos, &keys[num_items], &values)) {
                expected[num_items] = PyTuple_GetItem(values, 0);
                desired[num_items] = PyTuple_GetItem(values, 1);

                if (expected[num_items] == NULL || desired[num_items] == NULL)
                    goto fail;

                num_items++;
            } else {
                done = 1;
                break;
            }
        }

        for (int i = 0; i < num_items; ++i) {
            // prefetch index
            Py_hash_t hash = PyObject_Hash(keys[i]);
            if (hash == -1)
                goto fail;

            __builtin_prefetch(AtomicDict_IndexAddressOf(
                AtomicDict_Distance0Of(hash, meta),
                meta
            ));
        }

        for (int i = 0; i < num_items; ++i) {
            PyObject *result = NULL;
            PyObject *new = desired[i];  /* borrowed from the buffer tuple */

        retry:
            /* NOTE(review): ownership of `result` is not visible from this
             * function; if AtomicDict_CompareAndSet returns a new reference
             * it should be released here -- confirm. */
            result = AtomicDict_CompareAndSet(self, keys[i], expected[i], new);

            if (result == NULL)
                goto fail;

            if (result == EXPECTATION_FAILED) {
                /* NOTE(review): if AtomicDict_GetItemOrDefault returns a new
                 * reference, `current` is leaked once expected[i] is
                 * overwritten or the loop moves on -- confirm. */
                PyObject *current = AtomicDict_GetItemOrDefault(self, keys[i], NOT_FOUND);
                if (current == NULL)
                    goto fail;

                expected[i] = current;

                /* The previously computed value (if any) was rejected:
                 * drop it before computing a replacement. */
                Py_XDECREF(owned_new);
                owned_new = PyObject_CallFunctionObjArgs(aggregate, keys[i], expected[i], desired[i], NULL);
                if (owned_new == NULL)
                    goto fail;

                new = owned_new;
                goto retry;
            }

            /* The compare-and-set is also called with borrowed values above,
             * so it cannot be stealing references: release our own. */
            Py_CLEAR(owned_new);
        }

        if (done)
            break;

        /* Presumably a newer generation of the metadata has been published;
         * re-read it before processing the next chunk. */
        if (meta->new_gen_metadata != NULL)
            goto get_meta;
    }

    PyMem_RawFree(keys);
    PyMem_RawFree(expected);
    PyMem_RawFree(desired);
    Py_DECREF(meta);
    return 0;

fail:
    /* PyMem_RawFree(NULL) is a no-op, so no NULL checks are needed. */
    PyMem_RawFree(keys);
    PyMem_RawFree(expected);
    PyMem_RawFree(desired);
    Py_XDECREF(owned_new);
    Py_XDECREF(meta);
    return -1;
}


/*
 * Aggregate the (key, value) pairs produced by ``iterable`` into ``self``.
 *
 * The pairs are first folded into a private dict (``local_buffer``) that maps
 * each key to a 2-tuple ``(expected, desired)``; ``aggregate(key, current,
 * value)`` is called once per input item, with ``current`` being NOT_FOUND
 * the first time a key is seen. The buffer is then written into the shared
 * dictionary by reduce_flush().
 *
 * Returns 0 on success, -1 with an exception set on failure:
 *   - TypeError if ``aggregate`` is not callable, ``iterable`` is not
 *     iterable, or an item is not exactly a 2-tuple;
 *   - ValueError if ``chunk_size`` is not 0;
 *   - whatever ``aggregate``, the iterator, or the flush raised.
 */
int
AtomicDict_Reduce(AtomicDict *self, PyObject *iterable, PyObject *aggregate, int chunk_size)
{
    PyObject *local_buffer = NULL;
    PyObject *iterator = NULL;
    PyObject *item = NULL;
    PyObject *desired = NULL;
    PyObject *tuple = NULL;

    if (!PyCallable_Check(aggregate)) {
        PyErr_Format(PyExc_TypeError, "%R is not callable.", aggregate);
        goto fail;
    }

    if (chunk_size != 0) {
        PyErr_SetString(PyExc_ValueError, "only supporting chunk_size=0 at the moment.");
        goto fail;
    }

    local_buffer = PyDict_New();
    if (local_buffer == NULL)
        goto fail;

    iterator = PyObject_GetIter(iterable);
    if (iterator == NULL) {
        PyErr_Format(PyExc_TypeError, "%R is not iterable.", iterable);
        goto fail;
    }

    while ((item = PyIter_Next(iterator))) {
        PyObject *key = NULL;
        PyObject *value = NULL;
        PyObject *previous = NULL;
        PyObject *current = NULL;
        PyObject *expected = NULL;
        int status;

        if (!PyTuple_CheckExact(item)) {
            PyErr_Format(PyExc_TypeError, "type(%R) != tuple", item);
            goto fail;
        }

        if (PyTuple_Size(item) != 2) {
            PyErr_Format(PyExc_TypeError, "len(%R) != 2", item);
            goto fail;
        }

        /* Borrowed from `item`, which stays alive until the end of the iteration. */
        key = PyTuple_GetItem(item, 0);
        value = PyTuple_GetItem(item, 1);

        /* A single lookup that reports errors (e.g. an unhashable key)
         * instead of silently treating them as "key present". */
        previous = PyDict_GetItemWithError(local_buffer, key);
        if (previous != NULL) {
            /* Borrowed from the buffered tuple; `expected` is re-packed below
             * before that tuple is replaced. */
            expected = PyTuple_GetItem(previous, 0);
            current = PyTuple_GetItem(previous, 1);
            if (expected == NULL || current == NULL)
                goto fail;
        } else {
            if (PyErr_Occurred())
                goto fail;

            expected = NOT_FOUND;
            current = expected;
        }

        desired = PyObject_CallFunctionObjArgs(aggregate, key, current, value, NULL);
        if (desired == NULL)
            goto fail;

        /* PyTuple_Pack takes its own references: release ours right away. */
        tuple = PyTuple_Pack(2, expected, desired);
        Py_CLEAR(desired);
        if (tuple == NULL)
            goto fail;

        /* PyDict_SetItem does not steal the value either. */
        status = PyDict_SetItem(local_buffer, key, tuple);
        Py_CLEAR(tuple);
        if (status < 0)
            goto fail;

        Py_CLEAR(item);
    }

    /* PyIter_Next returns NULL both on exhaustion and on error. */
    if (PyErr_Occurred()) {
        goto fail;
    }

    Py_CLEAR(iterator);

    /* Propagate flush failures instead of reporting success with an exception set. */
    if (reduce_flush(self, local_buffer, aggregate) < 0)
        goto fail;

    Py_DECREF(local_buffer);
    return 0;

fail:
    Py_XDECREF(local_buffer);
    Py_XDECREF(iterator);
    Py_XDECREF(item);
    Py_XDECREF(desired);
    Py_XDECREF(tuple);
    return -1;
}

/*
 * Python-facing wrapper for AtomicDict_Reduce().
 *
 * Parses ``(iterable, aggregate, chunk_size=0)`` from positional and keyword
 * arguments, delegates to AtomicDict_Reduce(), and returns None on success
 * or NULL when parsing or the reduction fails.
 */
PyObject *
AtomicDict_Reduce_callable(AtomicDict *self, PyObject *args, PyObject *kwargs)
{
    char *keywords[] = {"iterable", "aggregate", "chunk_size", NULL};
    PyObject *iterable = NULL;
    PyObject *aggregate = NULL;
    int chunk_size = 0;

    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO|i", keywords,
                                     &iterable, &aggregate, &chunk_size)) {
        return NULL;
    }

    if (AtomicDict_Reduce(self, iterable, aggregate, chunk_size) < 0) {
        return NULL;
    }

    Py_RETURN_NONE;
}
1 change: 1 addition & 0 deletions src/cereggii/cereggii.c
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ static PyMethodDef AtomicDict_methods[] = {
{"fast_iter", (PyCFunction) AtomicDict_FastIter, METH_VARARGS | METH_KEYWORDS, NULL},
{"compare_and_set", (PyCFunction) AtomicDict_CompareAndSet_callable, METH_VARARGS | METH_KEYWORDS, NULL},
{"batch_getitem", (PyCFunction) AtomicDict_BatchGetItem, METH_VARARGS | METH_KEYWORDS, NULL},
{"reduce", (PyCFunction) AtomicDict_Reduce_callable, METH_VARARGS | METH_KEYWORDS, NULL},
{NULL}
};

Expand Down
4 changes: 4 additions & 0 deletions src/include/atomic_dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ PyObject *AtomicDict_CompareAndSet(AtomicDict *self, PyObject *key, PyObject *ex

PyObject *AtomicDict_CompareAndSet_callable(AtomicDict *self, PyObject *args, PyObject *kwargs);

int AtomicDict_Reduce(AtomicDict *self, PyObject *iterable, PyObject *aggregate, int chunk_size);

PyObject *AtomicDict_Reduce_callable(AtomicDict *self, PyObject *args, PyObject *kwargs);

int AtomicDict_Compact(AtomicDict *self);

PyObject *AtomicDict_Compact_callable(AtomicDict *self);
Expand Down
22 changes: 22 additions & 0 deletions tests/test_atomic_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,3 +719,25 @@ def test_len():

assert len(d) == 10
assert len(d) == 10 # test twice for len_dirty


def test_reduce():
    """Check that reduce() combines repeated keys through the aggregate function."""
    counts = AtomicDict()
    pairs = [("red", 1), ("green", 42), ("blue", 3), ("red", 5)]

    def count(key, current, new):  # noqa: ARG001
        # First occurrence of a key: nothing stored yet, keep the new value.
        return new if current is cereggii.NOT_FOUND else current + new

    counts.reduce(pairs, count)

    totals = {"red": 6, "green": 42, "blue": 3}
    for color, total in totals.items():
        assert counts[color] == total
Loading