Skip to content

Commit

Permalink
Add msgspec.json.Encoder.encode_lines
Browse files Browse the repository at this point in the history
This adds a new method on the JSON `Encoder` class for encoding an
iterable of items as newline delimited JSON, one item per line. This is
roughly equivalent to (but up to 3x faster than):

```
def encode_lines(items):
    return b''.join(
        msgspec.json.encode(item) + b'\n' for item in items
    )
```

We limit it to an implementation on the `Encoder` class alone for now
(rather than a new top-level function) since this is a relatively niche
feature.
  • Loading branch information
jcrist committed Jul 11, 2023
1 parent 8d70fc0 commit 473d6d3
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ JSON
.. currentmodule:: msgspec.json

.. autoclass:: Encoder
:members: encode, encode_into
:members: encode, encode_lines, encode_into

.. autoclass:: Decoder
:members: decode
Expand Down
73 changes: 73 additions & 0 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -8345,6 +8345,7 @@ static PyTypeObject Ext_Type = {
*************************************************************************/

#define ENC_INIT_BUFSIZE 32
#define ENC_LINES_INIT_BUFSIZE 1024

typedef struct EncoderState {
MsgspecState *mod; /* module reference */
Expand Down Expand Up @@ -12729,6 +12730,74 @@ JSONEncoder_encode(Encoder *self, PyObject *const *args, Py_ssize_t nargs)
return encoder_encode_common(self, args, nargs, &json_encode);
}

PyDoc_STRVAR(JSONEncoder_encode_lines__doc__,
"encode_lines(self, items)\n"
"--\n"
"\n"
"Encode an iterable of items as newline-delimited JSON, one item per line.\n"
"\n"
"Parameters\n"
"----------\n"
"items : iterable\n"
" An iterable of items to encode.\n"
"\n"
"Returns\n"
"-------\n"
"data : bytes\n"
" The items encoded as newline-delimited JSON, one item per line.\n"
"\n"
"Examples\n"
"--------\n"
">>> import msgspec\n"
">>> items = [{\"name\": \"alice\"}, {\"name\": \"ben\"}]\n"
">>> encoder = msgspec.json.Encoder()\n"
">>> encoder.encode_lines(items)\n"
"b'{\"name\":\"alice\"}\\n{\"name\":\"ben\"}\\n'"
);
static PyObject *
JSONEncoder_encode_lines(Encoder *self, PyObject *const *args, Py_ssize_t nargs)
{
if (!check_positional_nargs(nargs, 1, 1)) return NULL;

EncoderState state = {
.mod = self->mod,
.enc_hook = self->enc_hook,
.decimal_as_string = self->decimal_as_string,
.output_len = 0,
.max_output_len = ENC_LINES_INIT_BUFSIZE,
.resize_buffer = &ms_resize_bytes
};
state.output_buffer = PyBytes_FromStringAndSize(NULL, state.max_output_len);
if (state.output_buffer == NULL) return NULL;
state.output_buffer_raw = PyBytes_AS_STRING(state.output_buffer);

PyObject *input = args[0];
if (MS_LIKELY(PyList_Check(input))) {
for (Py_ssize_t i = 0; i < PyList_GET_SIZE(input); i++) {
if (json_encode(&state, PyList_GET_ITEM(input, i)) < 0) goto error;
if (ms_write(&state, "\n", 1) < 0) goto error;
}
}
else {
PyObject *iter = PyObject_GetIter(input);
if (iter == NULL) goto error;

PyObject *item;
while ((item = PyIter_Next(iter))) {
if (json_encode(&state, item) < 0) goto error;
if (ms_write(&state, "\n", 1) < 0) goto error;
}
if (PyErr_Occurred()) goto error;
}

FAST_BYTES_SHRINK(state.output_buffer, state.output_len);
return state.output_buffer;

error:
Py_DECREF(state.output_buffer);
return NULL;
}

static struct PyMethodDef JSONEncoder_methods[] = {
{
"encode", (PyCFunction) JSONEncoder_encode, METH_FASTCALL,
Expand All @@ -12738,6 +12807,10 @@ static struct PyMethodDef JSONEncoder_methods[] = {
"encode_into", (PyCFunction) JSONEncoder_encode_into, METH_FASTCALL,
Encoder_encode_into__doc__,
},
{
"encode_lines", (PyCFunction) JSONEncoder_encode_lines, METH_FASTCALL,
JSONEncoder_encode_lines__doc__,
},
{NULL, NULL} /* sentinel */
};

Expand Down
2 changes: 2 additions & 0 deletions msgspec/json.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ from typing import (
Callable,
Dict,
Generic,
Iterable,
Literal,
Optional,
Tuple,
Expand All @@ -29,6 +30,7 @@ class Encoder:
decimal_format: Literal["string", "number"] = "string",
): ...
def encode(self, obj: Any) -> bytes: ...
def encode_lines(self, items: Iterable) -> bytes: ...
def encode_into(
self, obj: Any, buffer: bytearray, offset: Optional[int] = 0
) -> None: ...
Expand Down
10 changes: 10 additions & 0 deletions tests/basic_typing_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,16 @@ def check_json_Encoder_encode() -> None:
reveal_type(b) # assert "bytes" in typ


def check_json_Encoder_encode_lines() -> None:
enc = msgspec.json.Encoder()
items = [{"x": 1}, 2]
b = enc.encode_lines(items)
b2 = enc.encode_lines((i for i in items))

reveal_type(b) # assert "bytes" in typ
reveal_type(b2) # assert "bytes" in typ


def check_json_Encoder_encode_into() -> None:
enc = msgspec.json.Encoder()
buf = bytearray(48)
Expand Down
53 changes: 53 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,59 @@ def test_encode_into_handles_errors_properly(self):
out2 = enc.encode([1, 2, 3])
assert out1 == out2

@pytest.mark.parametrize("n", range(3))
@pytest.mark.parametrize("iterable", [False, True])
def test_encode_lines(self, n, iterable):
class custom:
def __init__(self, x):
self.x = x

def __str__(self):
return f"<{self.x}>"

enc = msgspec.json.Encoder(enc_hook=str)

items = [{"x": i, "y": custom(i)} for i in range(n)]
sol = b"".join(enc.encode(i) + b"\n" for i in items)
if iterable:
items = (i for i in items)

res = enc.encode_lines(items)
assert res == sol

@pytest.mark.parametrize("iterable", [False, True])
def test_encode_lines_iterable_unsupported_item_errors(self, iterable):
enc = msgspec.json.Encoder()

def gen():
yield 1
yield object()

items = gen() if iterable else list(gen())

with pytest.raises(TypeError):
enc.encode_lines(items)

def test_encode_lines_iterable_iter_error(self):
enc = msgspec.json.Encoder()

class noiter:
def __iter__(self):
raise ValueError("Oh no!")

with pytest.raises(ValueError, match="Oh no!"):
enc.encode_lines(noiter())

def test_encode_lines_iterable_next_error(self):
enc = msgspec.json.Encoder()

def gen():
yield 1
raise ValueError("Oh no!")

with pytest.raises(ValueError, match="Oh no!"):
enc.encode_lines(gen())


class TestDecodeFunction:
def test_decode(self):
Expand Down

0 comments on commit 473d6d3

Please sign in to comment.