diff --git a/rflx/pyrflx/package.py b/rflx/pyrflx/package.py index 5c12a8ef4..34fd8aec8 100644 --- a/rflx/pyrflx/package.py +++ b/rflx/pyrflx/package.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Iterator +from typing import Callable, Dict, Iterator, Mapping, Union from rflx.common import Base from rflx.pyrflx import PyRFLXError @@ -14,8 +14,13 @@ def __init__(self, name: str) -> None: def name(self) -> str: return self.__name - def new_message(self, key: str) -> MessageValue: - return self.__messages[key].clone() + def new_message( + self, key: str, parameters: Mapping[str, Union[bool, int, str]] = None + ) -> MessageValue: + message = self.__messages[key].clone() + if parameters: + message.add_parameters(parameters) + return message def set_message(self, key: str, value: MessageValue) -> None: self.__messages[key] = value diff --git a/rflx/pyrflx/typevalue.py b/rflx/pyrflx/typevalue.py index f45305d62..f9e8eb037 100644 --- a/rflx/pyrflx/typevalue.py +++ b/rflx/pyrflx/typevalue.py @@ -518,12 +518,14 @@ def __init__( model: Message, refinements: ty.Sequence["RefinementValue"] = None, skip_verification: bool = False, + parameters: ty.Mapping[Name, Expr] = None, state: "MessageValue.State" = None, ) -> None: super().__init__(model) self._skip_verification = skip_verification self._refinements = refinements or [] self._path: ty.List[Link] = [] + self.__parameters = parameters or {} self._fields: ty.Mapping[str, MessageValue.Field] = ( state.fields @@ -563,6 +565,7 @@ def __init__( for k, v in t.items() } ) + self.__additional_enum_literals: ty.Dict[Name, Expr] = {} self.__message_first_name = First("Message") initial = self._fields[INITIAL.name] @@ -583,11 +586,29 @@ def __init__( def add_refinement(self, refinement: "RefinementValue") -> None: self._refinements = [*(self._refinements or []), refinement] + def add_parameters(self, parameters: ty.Mapping[str, ty.Union[bool, int, str]]) -> None: + _parameters: ty.Dict[Name, Expr] = {} + expr: Expr + for name, value in parameters.items(): + if isinstance(value, bool): + expr = Variable("True") if value else Variable("False") + elif isinstance(value, int): + expr = Number(value) + elif isinstance(value, str): + expr = Variable(value) + else: + raise PyRFLXError(f"{type(value)} is no supported parameter type") + _parameters[Variable(name)] = expr + self.__parameters = _parameters + if not self._skip_verification: + self._preset_fields(INITIAL.name) + def clone(self) -> "MessageValue": return MessageValue( self._type, self._refinements, self._skip_verification, + self.__parameters, MessageValue.State( { k: MessageValue.Field( @@ -919,6 +940,7 @@ def _preset_fields(self, fld: str) -> None: assert not self._skip_verification nxt = self._next_field(fld) fields: ty.List[str] = [] + while nxt and nxt != FINAL.name: field = self._fields[nxt] first = self._get_first(nxt) @@ -1139,7 +1161,9 @@ def _is_valid_composite_field(self, field: str) -> bool: return False return all( - (v.name in self._fields and self._fields[v.name].set) or v.name == "Message" + (v.name in self._fields and self._fields[v.name].set) + or v in self.__parameters + or v.name == "Message" for v in valid_edge.size.variables() ) @@ -1236,6 +1260,9 @@ def subst(expression: Expr) -> Expr: if expression in self.__additional_enum_literals: assert isinstance(expression, Name) return self.__additional_enum_literals[expression] + if expression in self.__parameters: + assert isinstance(expression, Name) + return self.__parameters[expression] return expression return expr.substituted(func=subst).substituted(func=subst).simplified() diff --git a/tests/data/fixtures/pyrflx.py b/tests/data/fixtures/pyrflx.py index b22a713f7..bbe4791df 100644 --- a/tests/data/fixtures/pyrflx.py +++ b/tests/data/fixtures/pyrflx.py @@ -26,6 +26,7 @@ def fixture_pyrflx() -> pyrflx.PyRFLX: f"{SPEC_DIR}/message_size.rflx", f"{SPEC_DIR}/message_type_size_condition.rflx", f"{SPEC_DIR}/always_valid_aspect.rflx", + f"{SPEC_DIR}/parameterized.rflx", ], skip_model_verification=True, ) @@ -223,3 +224,8 @@ def fixture_always_valid_aspect_value( always_valid_aspect_package: pyrflx.Package, ) -> pyrflx.MessageValue: return always_valid_aspect_package.new_message("Message") + + +@pytest.fixture(name="parameterized_package", scope="session") +def fixture_parameterized_package(pyrflx_: pyrflx.PyRFLX) -> pyrflx.Package: + return pyrflx_.package("Parameterized") diff --git a/tests/data/specs/parameterized.rflx b/tests/data/specs/parameterized.rflx new file mode 100644 index 000000000..9a5b77c1d --- /dev/null +++ b/tests/data/specs/parameterized.rflx @@ -0,0 +1,19 @@ +package Parameterized is + + type Length is range 0 .. 2**16 - 1 with Size => 16; + type Tag is (Tag_A, Tag_B) with Size => 8; + + type Message (Length : Length; Has_Tag : Boolean; Tag_Value : Tag) is + message + Payload : Opaque + with Size => Length * 8 + then Tag + if Has_Tag = True + then null + if Has_Tag = False; + Tag : Tag + then null + if Tag = Tag_Value; + end message; + +end Parameterized; diff --git a/tests/unit/pyrflx_test.py b/tests/unit/pyrflx_test.py index 04550f602..3aa3e17ac 100644 --- a/tests/unit/pyrflx_test.py +++ b/tests/unit/pyrflx_test.py @@ -58,6 +58,11 @@ def test_package_iterator(tlv_package: Package) -> None: assert [m.name for m in tlv_package] == ["Message"] +def test_package_set_item(tlv_package: Package) -> None: + msg = Message("TLV::Msg", [], {}) + tlv_package["Msg"] = MessageValue(msg) + + def test_pyrflx_iterator(pyrflx_: PyRFLX) -> None: assert {p.name for p in pyrflx_} == { "Ethernet", @@ -71,6 +76,7 @@ def test_pyrflx_iterator(pyrflx_: PyRFLX) -> None: "Sequence_Message", "Sequence_Type", "Null_Message", + "Parameterized", "TLV_With_Checksum", "No_Conditionals", "Message_Type_Size_Condition", @@ -1362,3 +1368,40 @@ def test_get_path(icmp_message_value: MessageValue) -> None: def test_get_model(icmp_message_value: MessageValue) -> None: assert isinstance(icmp_message_value.model, Message) + + +def test_parameterized_message(parameterized_package: Package) -> None: + message = parameterized_package.new_message( + "Message", {"Length": 8, "Has_Tag": False, "Tag_Value": "Tag_A"} + ) + assert message.fields == ["Payload", "Tag"] + assert message.required_fields == ["Payload"] + message.set("Payload", bytes(8)) + assert message.required_fields == [] + assert message.valid_message + assert message.bytestring == bytes(8) + + +def test_parameterized_message_no_verification() -> None: + pyrflx_ = PyRFLX.from_specs( + [SPEC_DIR / "parameterized.rflx"], + skip_model_verification=True, + skip_message_verification=True, + ) + message_unv = pyrflx_.package("Parameterized").new_message( + "Message", {"Length": 8, "Has_Tag": False, "Tag_Value": "Tag_A"} + ) + assert message_unv.fields == ["Payload", "Tag"] + message_unv.set("Payload", bytes(8)) + assert message_unv.valid_message + assert message_unv.bytestring == bytes(8) + + +def test_parameterized_message_invalid_type(parameterized_package: Package) -> None: + with pytest.raises( + PyRFLXError, match=f"^pyrflx: error: {type(bytes())} is no supported parameter type" + ): + parameterized_package.new_message( + "Message", + {"Length": bytes(8)}, # type: ignore[dict-item] + )