diff --git a/sbe/__init__.py b/sbe/__init__.py index cc1cf58..d5d1cef 100644 --- a/sbe/__init__.py +++ b/sbe/__init__.py @@ -107,6 +107,8 @@ def __init__(self, name: str, primitiveType: PrimitiveType, nullValue: Optional[ self.nullValue = float(nullValue) else: self.nullValue = int(nullValue) + else: + self.nullValue = None def __repr__(self): rv = self.name + " (" @@ -277,7 +279,8 @@ class Set: def encode(self, vals: Iterable[str]) -> int: vals = set(vals) - return bitstring.BitArray(v.name in vals for i, v in enumerate(self.choices)).uint + assert vals.issubset({c.name for c in self.choices}), f"{vals} is not a subset of {self.choices}" + return bitstring.BitArray(v.name in vals for v in reversed(self.choices)).uint def decode(self, val: int) -> List[str]: if isinstance(self.encodingType, SetEncodingType): @@ -496,7 +499,8 @@ def encode(self, message: Message, obj: dict, header: Optional[dict] = None) -> fmts = [] vals = [] cursor = Cursor(0) - _walk_fields_encode(self, message.fields, obj, fmts, vals, cursor) + _walk_fields_encode(self, message.fields, obj, fmts, vals, + message.blockLength, cursor) fmt = "<" + ''.join(fmts) header = { @@ -696,6 +700,8 @@ def _prettify_type(_schema: Schema, t: Type, v): t.characterEncoding == CharacterEncoding.ASCII or t.characterEncoding is None ): return v.split(b'\x00', 1)[0].decode('ascii', errors='ignore').strip() + if t.nullValue is not None and v == t.nullValue: + return None return v @@ -737,16 +743,22 @@ def _walk_fields_encode_composite( cursor.val += FORMAT_SIZES[t1] -def _walk_fields_encode(schema: Schema, fields: List[Union[Group, Field]], obj: dict, fmt: list, vals: list, cursor: Cursor): +def _walk_fields_encode(schema: Schema, fields: List[Union[Group, Field]], + obj: dict, fmt: list, vals: list, blockLength: int, + cursor: Cursor): for f in fields: if isinstance(f, Group): + if cursor.val < blockLength: + fmt.append(str(blockLength - cursor.val) + 'x') + cursor.val = blockLength xs = obj[f.name] fmt1 = [] vals1 = [] block_length = None for x in xs: - _walk_fields_encode(schema, f.fields, x, fmt1, vals1, Cursor(0)) + _walk_fields_encode(schema, f.fields, x, fmt1, vals1, + f.blockLength, Cursor(0)) if block_length is None: block_length = struct.calcsize("<" + ''.join(fmt1)) diff --git a/tests/dat/example-schema.xml b/tests/dat/example-schema.xml index b4a39e7..746e7d7 100644 --- a/tests/dat/example-schema.xml +++ b/tests/dat/example-schema.xml @@ -14,6 +14,7 @@ + T @@ -74,4 +75,13 @@ + + + + + + + + + diff --git a/tests/test_sbe.py b/tests/test_sbe.py index c92f601..61e28ce 100644 --- a/tests/test_sbe.py +++ b/tests/test_sbe.py @@ -7,3 +7,33 @@ def test_parse1(): def test_parse2(): with open('tests/dat/b3-entrypoint-messages-8.0.0.xml', 'r') as f: sbe.Schema.parse(f) + +def test_nullValue(): + with open('tests/dat/example-schema.xml', 'r') as f: + s = sbe.Schema.parse(f) + nullable = s.messages[2] + + encodedInt = s.encode(nullable, {'nullable': 5}) + decodedInt = s.decode(encodedInt) + assert decodedInt.value['nullable'] == 5 + + encodedNull = s.encode(nullable, {'nullable': None}) + decodedNull = s.decode(encodedNull) + assert decodedNull.value['nullable'] is None + +def test_blockLength(): + with open('tests/dat/example-schema.xml', 'r') as f: + s = sbe.Schema.parse(f) + msg = s.messages[3] + + encoded = s.encode(msg, {'year': 1990, 'AGroup': [{'numbers': 123}, + {'numbers': 456}]}) + # BlockHeader = 8b + # Body = 2b year + 2b padding + # Repeating group = 2 * (4b numbers + 2b padding) + expLen = 8 + 4 + 2*6 + assert len(encoded) == expLen, "Encoded SBE not padded properly" + + decoded = s.decode(encoded) + assert decoded.value['year'] == 1990 + assert decoded.value['AGroup'][1]['numbers'] == 456