Skip to content

Commit

Permalink
fix: Sum value equality. Add unit tests (#1484)
Browse files Browse the repository at this point in the history
This was supposed to be part of #1481, but pushed it to the branch that
depended on it instead 🤦

- Adds the string/repr unit tests suggested by
#1481 (review)

- Tests—and fixes—equality comparation between Sum values.
  • Loading branch information
aborgna-q authored Aug 29, 2024
1 parent 9698420 commit a7b2718
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 33 deletions.
7 changes: 0 additions & 7 deletions hugr-py/src/hugr/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,6 @@ class Either(Sum):
In fallible contexts, the Left variant is used to represent success, and the
Right variant is used to represent failure.
Example:
>>> either = Either([Bool, Bool], [Bool])
>>> either
Either(left=[Bool, Bool], right=[Bool])
>>> str(either)
'Either((Bool, Bool), Bool)'
"""

def __init__(self, left: Iterable[Type], right: Iterable[Type]):
Expand Down
38 changes: 13 additions & 25 deletions hugr-py/src/hugr/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def _to_serial(self) -> sops.SumValue:
vs=ser_it(self.vals),
)

def __eq__(self, other: object) -> bool:
return (
isinstance(other, Sum)
and self.tag == other.tag
and self.typ == other.typ
and self.vals == other.vals
)


class UnitSum(Sum):
"""Simple :class:`Sum` with each variant being an empty row.
Expand Down Expand Up @@ -117,7 +125,7 @@ def bool_value(b: bool) -> UnitSum:
FALSE = bool_value(False)


@dataclass
@dataclass(eq=False)
class Tuple(Sum):
"""Tuple or product value, defined by a list of values.
Internally a :class:`Sum` with a single variant row.
Expand All @@ -131,9 +139,6 @@ class Tuple(Sum):
"""

#: The values of this tuple.
vals: list[Value]

def __init__(self, *vals: Value):
val_list = list(vals)
super().__init__(
Expand All @@ -151,24 +156,19 @@ def __repr__(self) -> str:
return f"Tuple({', '.join(map(repr, self.vals))})"


@dataclass
@dataclass(eq=False)
class Some(Sum):
"""Optional tuple of value, containing a list of values.
Example:
>>> some = Some(TRUE, FALSE)
>>> some
Some(TRUE, FALSE)
>>> str(some)
'Some(TRUE, FALSE)'
>>> some.type_()
Option(Bool, Bool)
"""

#: The values of this tuple.
vals: list[Value]

def __init__(self, *vals: Value):
val_list = list(vals)
super().__init__(
Expand All @@ -179,14 +179,12 @@ def __repr__(self) -> str:
return f"Some({', '.join(map(repr, self.vals))})"


@dataclass
@dataclass(eq=False)
class None_(Sum):
"""Optional tuple of value, containing no values.
Example:
>>> none = None_(tys.Bool)
>>> none
None(Bool)
>>> str(none)
'None'
>>> none.type_()
Expand All @@ -204,25 +202,20 @@ def __str__(self) -> str:
return "None"


@dataclass
@dataclass(eq=False)
class Left(Sum):
"""Left variant of a :class:`tys.Either` type, containing a list of values.
In fallible contexts, this represents the success variant.
Example:
>>> left = Left([TRUE, FALSE], [tys.Bool])
>>> left
Left(vals=[TRUE, FALSE], right_typ=[Bool])
>>> str(left)
'Left(TRUE, FALSE)'
>>> str(left.type_())
'Either((Bool, Bool), Bool)'
"""

#: The values of this tuple.
vals: list[Value]

def __init__(self, vals: Iterable[Value], right_typ: Iterable[tys.Type]):
val_list = list(vals)
super().__init__(
Expand All @@ -240,7 +233,7 @@ def __str__(self) -> str:
return f"Left({vals_str})"


@dataclass
@dataclass(eq=False)
class Right(Sum):
"""Right variant of a :class:`tys.Either` type, containing a list of values.
Expand All @@ -250,17 +243,12 @@ class Right(Sum):
Example:
>>> right = Right([tys.Bool, tys.Bool, tys.Bool], [TRUE, FALSE])
>>> right
Right(left_typ=[Bool, Bool, Bool], vals=[TRUE, FALSE])
>>> str(right)
'Right(TRUE, FALSE)'
>>> str(right.type_())
'Either((Bool, Bool, Bool), (Bool, Bool))'
"""

#: The values of this tuple.
vals: list[Value]

def __init__(self, left_typ: Iterable[tys.Type], vals: Iterable[Value]):
val_list = list(vals)
super().__init__(
Expand Down
34 changes: 33 additions & 1 deletion hugr-py/tests/test_tys.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,46 @@
from __future__ import annotations

from hugr.tys import Bool, Qubit, Sum, Tuple, UnitSum
import pytest

from hugr.tys import Bool, Either, Option, Qubit, Sum, Tuple, Type, UnitSum


def test_sums():
assert Sum([[Bool, Qubit]]) == Tuple(Bool, Qubit)
assert Tuple(Bool, Qubit) == Sum([[Bool, Qubit]])
assert Sum([[Bool, Qubit]]).as_tuple() == Sum([[Bool, Qubit]])

assert Sum([[Bool, Qubit], []]) == Option(Bool, Qubit)
assert Sum([[Bool, Qubit], []]) == Either([Bool, Qubit], [])
assert Option(Bool, Qubit) == Either([Bool, Qubit], [])
assert Sum([[Qubit], [Bool]]) == Either([Qubit], [Bool])

assert Tuple() == Sum([[]])
assert UnitSum(0) == Sum([])
assert UnitSum(1) == Tuple()
assert UnitSum(4) == Sum([[], [], [], []])


@pytest.mark.parametrize(
("ty", "string", "repr_str"),
[
(
Sum([[Bool], [Qubit], [Qubit, Bool]]),
"Sum([[Bool], [Qubit], [Qubit, Bool]])",
"Sum([[Bool], [Qubit], [Qubit, Bool]])",
),
(UnitSum(1), "Unit", "Unit"),
(UnitSum(2), "Bool", "Bool"),
(UnitSum(3), "UnitSum(3)", "UnitSum(3)"),
(Tuple(Bool, Qubit), "Tuple(Bool, Qubit)", "Tuple(Bool, Qubit)"),
(Option(Bool, Qubit), "Option(Bool, Qubit)", "Option(Bool, Qubit)"),
(
Either([Bool, Qubit], [Bool]),
"Either((Bool, Qubit), Bool)",
"Either(left=[Bool, Qubit], right=[Bool])",
),
],
)
def test_tys_sum_str(ty: Type, string: str, repr_str: str):
assert str(ty) == string
assert repr(ty) == repr_str
57 changes: 57 additions & 0 deletions hugr-py/tests/test_val.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import pytest

from hugr import tys
from hugr.val import FALSE, TRUE, Left, None_, Right, Some, Sum, Tuple, UnitSum, Value


def test_sums():
assert Sum(0, tys.Tuple(), []) == Tuple()
assert Sum(0, tys.Tuple(tys.Bool, tys.Bool), [TRUE, FALSE]) == Tuple(TRUE, FALSE)

ty = tys.Sum([[tys.Bool, tys.Bool], []])
assert Sum(0, ty, [TRUE, FALSE]) == Some(TRUE, FALSE)
assert Sum(0, ty, [TRUE, FALSE]) == Left([TRUE, FALSE], [])
assert Sum(1, ty, []) == None_(tys.Bool, tys.Bool)
assert Sum(1, ty, []) == Right([tys.Bool, tys.Bool], [])

ty = tys.Sum([[tys.Bool], [tys.Bool]])
assert Sum(0, ty, [TRUE]) == Left([TRUE], [tys.Bool])
assert Sum(1, ty, [FALSE]) == Right([tys.Bool], [FALSE])

assert Tuple() == Sum(0, tys.Tuple(), [])
assert UnitSum(0, size=1) == Tuple()
assert UnitSum(2, size=4) == Sum(2, tys.UnitSum(size=4), [])


@pytest.mark.parametrize(
("value", "string", "repr_str"),
[
(
Sum(0, tys.Sum([[tys.Bool], [tys.Qubit]]), [TRUE, FALSE]),
"Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])",
"Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])",
),
(UnitSum(0, size=1), "Unit", "Unit"),
(UnitSum(0, size=2), "FALSE", "FALSE"),
(UnitSum(1, size=2), "TRUE", "TRUE"),
(UnitSum(2, size=5), "UnitSum(2, 5)", "UnitSum(2, 5)"),
(Tuple(TRUE, FALSE), "Tuple(TRUE, FALSE)", "Tuple(TRUE, FALSE)"),
(Some(TRUE, FALSE), "Some(TRUE, FALSE)", "Some(TRUE, FALSE)"),
(None_(tys.Bool, tys.Bool), "None", "None(Bool, Bool)"),
(
Left([TRUE, FALSE], [tys.Bool]),
"Left(TRUE, FALSE)",
"Left(vals=[TRUE, FALSE], right_typ=[Bool])",
),
(
Right([tys.Bool, tys.Bool], [FALSE]),
"Right(FALSE)",
"Right(left_typ=[Bool, Bool], vals=[FALSE])",
),
],
)
def test_val_sum_str(value: Value, string: str, repr_str: str):
assert str(value) == string
assert repr(value) == repr_str

0 comments on commit a7b2718

Please sign in to comment.