Skip to content

Commit

Permalink
fix(py): Invalid serialization of float and int constants (#1427)
Browse files Browse the repository at this point in the history
Fixes #1424 

I updated the roundtrip validation to compare decoded jsons rather than
strings.

The values were manually compared against a rust serialization output.
We cannot do better without python-side extension loading.
  • Loading branch information
aborgna-q committed Aug 14, 2024
1 parent c6473c9 commit b89c08f
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 6 deletions.
7 changes: 6 additions & 1 deletion hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ class ExtensionValue(BaseValue):
value: CustomConst

def deserialize(self) -> val.Value:
return val.Extension(self.value.c, self.typ.deserialize(), self.value.v)
return val.Extension(
name=self.value.c,
typ=self.typ.deserialize(),
val=self.value.v,
extensions=self.extensions,
)


class FunctionValue(BaseValue):
Expand Down
6 changes: 5 additions & 1 deletion hugr-py/src/hugr/std/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,8 @@ class FloatVal(val.ExtensionValue):
v: float

def to_value(self) -> val.Extension:
return val.Extension("float", FLOAT_T, self.v)
name = "ConstF64"
payload = {"value": self.v}
return val.Extension(
name, typ=FLOAT_T, val=payload, extensions=[EXTENSION.name]
)
9 changes: 8 additions & 1 deletion hugr-py/src/hugr/std/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,14 @@ class IntVal(val.ExtensionValue):
width: int = field(default=5)

def to_value(self) -> val.Extension:
return val.Extension("int", int_t(self.width), self.v)
name = "ConstInt"
payload = {"log_width": self.width, "value": self.v}
return val.Extension(
name,
typ=int_t(self.width),
val=payload,
extensions=[INT_TYPES_EXTENSION.name],
)


INT_OPS_EXTENSION = ext.Extension("arithmetic.int", ext.Version(0, 1, 0))
Expand Down
6 changes: 4 additions & 2 deletions hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,10 @@ def validate(h: Hugr, roundtrip: bool = True):
_run_hugr_cmd(serial, cmd)

if roundtrip:
h2 = Hugr.from_serial(SerialHugr.load_json(json.loads(serial)))
assert serial == h2.to_serial().to_json()
starting_json = json.loads(serial)
h2 = Hugr.from_serial(SerialHugr.load_json(starting_json))
roundtrip_json = json.loads(h2.to_serial().to_json())
assert roundtrip_json == starting_json


def _run_hugr_cmd(serial: str, cmd: list[str]):
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/tests/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_dom_edge() -> None:


def test_asymm_types() -> None:
# test different types going to entry block's susccessors
# test different types going to entry block's successors
with Cfg() as cfg:
with cfg.add_entry() as entry:
int_load = entry.load(IntVal(34))
Expand Down

0 comments on commit b89c08f

Please sign in to comment.