diff --git a/sbe/__init__.py b/sbe/__init__.py
index cc1cf58..d5d1cef 100644
--- a/sbe/__init__.py
+++ b/sbe/__init__.py
@@ -107,6 +107,8 @@ def __init__(self, name: str, primitiveType: PrimitiveType, nullValue: Optional[
self.nullValue = float(nullValue)
else:
self.nullValue = int(nullValue)
+ else:
+ self.nullValue = None
def __repr__(self):
rv = self.name + " ("
@@ -277,7 +279,8 @@ class Set:
def encode(self, vals: Iterable[str]) -> int:
vals = set(vals)
- return bitstring.BitArray(v.name in vals for i, v in enumerate(self.choices)).uint
+ assert vals.issubset({c.name for c in self.choices}), f"{vals} is not a subset of {self.choices}"
+ return bitstring.BitArray(v.name in vals for v in reversed(self.choices)).uint
def decode(self, val: int) -> List[str]:
if isinstance(self.encodingType, SetEncodingType):
@@ -496,7 +499,8 @@ def encode(self, message: Message, obj: dict, header: Optional[dict] = None) ->
fmts = []
vals = []
cursor = Cursor(0)
- _walk_fields_encode(self, message.fields, obj, fmts, vals, cursor)
+ _walk_fields_encode(self, message.fields, obj, fmts, vals,
+ message.blockLength, cursor)
fmt = "<" + ''.join(fmts)
header = {
@@ -696,6 +700,8 @@ def _prettify_type(_schema: Schema, t: Type, v):
t.characterEncoding == CharacterEncoding.ASCII or t.characterEncoding is None
):
return v.split(b'\x00', 1)[0].decode('ascii', errors='ignore').strip()
+ if t.nullValue is not None and v == t.nullValue:
+ return None
return v
@@ -737,16 +743,22 @@ def _walk_fields_encode_composite(
cursor.val += FORMAT_SIZES[t1]
-def _walk_fields_encode(schema: Schema, fields: List[Union[Group, Field]], obj: dict, fmt: list, vals: list, cursor: Cursor):
+def _walk_fields_encode(schema: Schema, fields: List[Union[Group, Field]],
+ obj: dict, fmt: list, vals: list, blockLength: int,
+ cursor: Cursor):
for f in fields:
if isinstance(f, Group):
+ if cursor.val < blockLength:
+ fmt.append(str(blockLength - cursor.val) + 'x')
+ cursor.val = blockLength
xs = obj[f.name]
fmt1 = []
vals1 = []
block_length = None
for x in xs:
- _walk_fields_encode(schema, f.fields, x, fmt1, vals1, Cursor(0))
+ _walk_fields_encode(schema, f.fields, x, fmt1, vals1,
+ f.blockLength, Cursor(0))
if block_length is None:
block_length = struct.calcsize("<" + ''.join(fmt1))
diff --git a/tests/dat/example-schema.xml b/tests/dat/example-schema.xml
index b4a39e7..746e7d7 100644
--- a/tests/dat/example-schema.xml
+++ b/tests/dat/example-schema.xml
@@ -14,6 +14,7 @@
+
T
@@ -74,4 +75,13 @@
+
+
+
+
+
+
+
+
+
diff --git a/tests/test_sbe.py b/tests/test_sbe.py
index c92f601..61e28ce 100644
--- a/tests/test_sbe.py
+++ b/tests/test_sbe.py
@@ -7,3 +7,33 @@ def test_parse1():
def test_parse2():
with open('tests/dat/b3-entrypoint-messages-8.0.0.xml', 'r') as f:
sbe.Schema.parse(f)
+
+def test_nullValue():
+ with open('tests/dat/example-schema.xml', 'r') as f:
+ s = sbe.Schema.parse(f)
+ nullable = s.messages[2]
+
+ encodedInt = s.encode(nullable, {'nullable': 5})
+ decodedInt = s.decode(encodedInt)
+ assert decodedInt.value['nullable'] == 5
+
+ encodedNull = s.encode(nullable, {'nullable': None})
+ decodedNull = s.decode(encodedNull)
+ assert decodedNull.value['nullable'] is None
+
+def test_blockLength():
+ with open('tests/dat/example-schema.xml', 'r') as f:
+ s = sbe.Schema.parse(f)
+ msg = s.messages[3]
+
+ encoded = s.encode(msg, {'year': 1990, 'AGroup': [{'numbers': 123},
+ {'numbers': 456}]})
+ # BlockHeader = 8b
+ # Body = 2b year + 2b padding
+ # Repeating group = 2 * (4b numbers + 2b padding)
+ expLen = 8 + 4 + 2*6
+ assert len(encoded) == expLen, "Encoded SBE not padded properly"
+
+ decoded = s.decode(encoded)
+ assert decoded.value['year'] == 1990
+ assert decoded.value['AGroup'][1]['numbers'] == 456