Skip to content

Commit

Permalink
Do not encode UNSET values in struct arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
chuckwondo committed Aug 22, 2024
1 parent 2c37da0 commit a8abf89
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
30 changes: 27 additions & 3 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -12702,17 +12702,36 @@ mpack_encode_struct_array(
int tagged = tag_value != NULL;
PyObject *fields = struct_type->struct_encode_fields;
Py_ssize_t nfields = PyTuple_GET_SIZE(fields);
Py_ssize_t len = nfields + tagged;
Py_ssize_t len = nfields + tagged, actual_len = len;

if (Py_EnterRecursiveCall(" while serializing an object")) return -1;

Py_ssize_t header_offset = self->output_len;
if (mpack_encode_array_header(self, len, "structs") < 0) goto cleanup;
if (tagged) {
if (mpack_encode(self, tag_value) < 0) goto cleanup;
}
for (Py_ssize_t i = 0; i < nfields; i++) {
PyObject *val = Struct_get_index(obj, i);
if (val == NULL || mpack_encode(self, val) < 0) goto cleanup;
if (val == UNSET) {
actual_len--;
} else if (val == NULL || mpack_encode(self, val) < 0) {
goto cleanup;
}
}
if (MS_UNLIKELY(actual_len != len)) {
/* Fixup the header length after we know how many fields were
* actually written */
char *header_loc = self->output_buffer_raw + header_offset;
if (len < 16) {
*header_loc = MP_FIXARRAY | actual_len;
} else if (len < (1 << 16)) {
*header_loc++ = MP_ARRAY16;
_msgspec_store16(header_loc, (uint16_t)actual_len);
} else {
*header_loc++ = MP_ARRAY32;
_msgspec_store32(header_loc, (uint32_t)actual_len);
}
}
status = 0;
cleanup:
Expand Down Expand Up @@ -14100,11 +14119,16 @@ json_encode_struct_array(
for (Py_ssize_t i = 0; i < nfields; i++) {
PyObject *val = Struct_get_index(obj, i);
if (val == NULL) goto cleanup;
if (val == UNSET) continue;
if (json_encode(self, val) < 0) goto cleanup;
if (ms_write(self, ",", 1) < 0) goto cleanup;
}
/* Overwrite trailing comma with ] */
*(self->output_buffer_raw + self->output_len - 1) = ']';
if (*(self->output_buffer_raw + self->output_len - 1) == ',') {
*(self->output_buffer_raw + self->output_len - 1) = ']';
} else {
if (ms_write(self, "]", 1) < 0) goto cleanup;
}
status = 0;
cleanup:
Py_LeaveRecursiveCall();
Expand Down
11 changes: 11 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4061,6 +4061,17 @@ class Ex(Struct, omit_defaults=True):
sol = proto.encode(y)
assert res == sol

def test_unset_encode_struct_array_like(self, proto):
class Ex(Struct, array_like=True):
x: Union[int, UnsetType] = UNSET
y: Union[int, UnsetType] = UNSET
z: int = 0

for x, y in [(Ex(), [0]), (Ex(x=1), [1, 0]), (Ex(y=2), [2, 0])]:
res = proto.encode(x)
sol = proto.encode(y)
assert res == sol


class TestOrder:
def test_encoder_order_attribute(self, proto):
Expand Down

0 comments on commit a8abf89

Please sign in to comment.