Skip to content

Commit

Permalink
Merge pull request #25 from gburca/master
Browse files Browse the repository at this point in the history
Fix Set encoding
  • Loading branch information
kizzx2 authored Nov 3, 2024
2 parents 67f3e66 + 61a0579 commit 4f34455
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 4 deletions.
20 changes: 16 additions & 4 deletions sbe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def __init__(self, name: str, primitiveType: PrimitiveType, nullValue: Optional[
self.nullValue = float(nullValue)
else:
self.nullValue = int(nullValue)
else:
self.nullValue = None

def __repr__(self):
rv = self.name + " ("
Expand Down Expand Up @@ -277,7 +279,8 @@ class Set:

def encode(self, vals: Iterable[str]) -> int:
vals = set(vals)
return bitstring.BitArray(v.name in vals for i, v in enumerate(self.choices)).uint
assert vals.issubset({c.name for c in self.choices}), f"{vals} is not a subset of {self.choices}"
return bitstring.BitArray(v.name in vals for v in reversed(self.choices)).uint

def decode(self, val: int) -> List[str]:
if isinstance(self.encodingType, SetEncodingType):
Expand Down Expand Up @@ -496,7 +499,8 @@ def encode(self, message: Message, obj: dict, header: Optional[dict] = None) ->
fmts = []
vals = []
cursor = Cursor(0)
_walk_fields_encode(self, message.fields, obj, fmts, vals, cursor)
_walk_fields_encode(self, message.fields, obj, fmts, vals,
message.blockLength, cursor)
fmt = "<" + ''.join(fmts)

header = {
Expand Down Expand Up @@ -696,6 +700,8 @@ def _prettify_type(_schema: Schema, t: Type, v):
t.characterEncoding == CharacterEncoding.ASCII or t.characterEncoding is None
):
return v.split(b'\x00', 1)[0].decode('ascii', errors='ignore').strip()
if t.nullValue is not None and v == t.nullValue:
return None

return v

Expand Down Expand Up @@ -737,16 +743,22 @@ def _walk_fields_encode_composite(
cursor.val += FORMAT_SIZES[t1]


def _walk_fields_encode(schema: Schema, fields: List[Union[Group, Field]], obj: dict, fmt: list, vals: list, cursor: Cursor):
def _walk_fields_encode(schema: Schema, fields: List[Union[Group, Field]],
obj: dict, fmt: list, vals: list, blockLength: int,
cursor: Cursor):
for f in fields:
if isinstance(f, Group):
if cursor.val < blockLength:
fmt.append(str(blockLength - cursor.val) + 'x')
cursor.val = blockLength
xs = obj[f.name]

fmt1 = []
vals1 = []
block_length = None
for x in xs:
_walk_fields_encode(schema, f.fields, x, fmt1, vals1, Cursor(0))
_walk_fields_encode(schema, f.fields, x, fmt1, vals1,
f.blockLength, Cursor(0))
if block_length is None:
block_length = struct.calcsize("<" + ''.join(fmt1))

Expand Down
10 changes: 10 additions & 0 deletions tests/dat/example-schema.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
<type name="Ron" primitiveType="uint8" minValue="90" maxValue="110"/>
<type name="someNumbers" primitiveType="uint32" length="4"/>
<type name="Percentage" primitiveType="int8" minValue="0" maxValue="100"/>
<type name="IntNULL" primitiveType="uint16" nullValue="65535"/>
<composite name="Booster">
<enum name="BoostType" encodingType="char">
<validValue name="TURBO">T</validValue>
Expand Down Expand Up @@ -74,4 +75,13 @@
<data name="model" id="19" type="varStringEncoding"/>
<data name="activationCode" id="20" type="varAsciiEncoding"/>
</sbe:message>
<sbe:message name="TestNullValue" id="2">
<field name="nullable" id="1" type="IntNULL"/>
</sbe:message>
<sbe:message name="TestBlockLength" id="3" blockLength="4">
<field name="year" id="1" type="ModelYear"/>
<group name="AGroup" id="2" dimensionType="groupSizeEncoding" blockLength="6">
<field name="numbers" id="1" type="someNumbers"/>
</group>
</sbe:message>
</sbe:messageSchema>
30 changes: 30 additions & 0 deletions tests/test_sbe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,33 @@ def test_parse1():
def test_parse2():
with open('tests/dat/b3-entrypoint-messages-8.0.0.xml', 'r') as f:
sbe.Schema.parse(f)

def test_nullValue():
    """Round-trip a nullable field with a regular value and with None.

    Uses the message stored under ``messages[2]`` of the example schema
    (presumably ``TestNullValue``, declared with id 2 -- confirm how
    ``Schema.messages`` is keyed).
    """
    with open('tests/dat/example-schema.xml', 'r') as schema_file:
        schema = sbe.Schema.parse(schema_file)
        message = schema.messages[2]

        # A regular value must come back unchanged after encode -> decode.
        int_payload = schema.encode(message, {'nullable': 5})
        int_result = schema.decode(int_payload)
        assert int_result.value['nullable'] == 5

        # None goes in, and None is expected back out after decoding.
        null_payload = schema.encode(message, {'nullable': None})
        null_result = schema.decode(null_payload)
        assert null_result.value['nullable'] is None

def test_blockLength():
    """Check that encoding pads the root block and each group entry.

    Uses the message stored under ``messages[3]`` of the example schema
    (presumably ``TestBlockLength``, declared with id 3 -- confirm how
    ``Schema.messages`` is keyed).
    """
    with open('tests/dat/example-schema.xml', 'r') as schema_file:
        schema = sbe.Schema.parse(schema_file)
        message = schema.messages[3]

        group_entries = [{'numbers': 123}, {'numbers': 456}]
        payload = schema.encode(message, {'year': 1990, 'AGroup': group_entries})

        # Size accounting, in bytes, as laid out by this test:
        #   block header    = 8
        #   body            = 2 (year) + 2 (padding)
        #   repeating group = 2 entries * (4 (numbers) + 2 (padding))
        expected_size = 8 + 4 + 2*6
        assert len(payload) == expected_size, "Encoded SBE not padded properly"

        # The padded message must still decode to the original values.
        result = schema.decode(payload)
        assert result.value['year'] == 1990
        assert result.value['AGroup'][1]['numbers'] == 456

0 comments on commit 4f34455

Please sign in to comment.