From 8121333ee5878e11bc557e76019a8fcf797b27a5 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 10 Jul 2024 12:28:40 +0100 Subject: [PATCH 1/2] feat(hugr-py): `AsCustomOp` protocol for user-defined custom op types. Replace existing custom operation types with this. Follow ups: - Similar thing for custom types. - Optional: allow these types to register themselves with `serialization.ops.CustomOp` so they can be deserialized directly. --- hugr-py/src/hugr/ops.py | 80 ++++++++++++++++++++++++++--- hugr-py/src/hugr/std/int.py | 54 +++++++++++++------- hugr-py/src/hugr/std/logic.py | 20 +++----- hugr-py/tests/conftest.py | 95 +++++++++++++++++++++++------------ hugr-py/tests/test_custom.py | 42 ++++++++++++++++ 5 files changed, 220 insertions(+), 71 deletions(-) create mode 100644 hugr-py/tests/test_custom.py diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index f7451742f..e98eb91cb 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -3,8 +3,11 @@ from __future__ import annotations from dataclasses import dataclass, field +from functools import cached_property from typing import TYPE_CHECKING, Protocol, TypeVar, runtime_checkable +from typing_extensions import Self + import hugr.serialization.ops as sops from hugr import tys, val from hugr.node_port import Direction, InPort, Node, OutPort, Wire @@ -197,8 +200,69 @@ def _set_in_types(self, types: tys.TypeRow) -> None: self._types = types -@dataclass(frozen=True) -class Custom(DataflowOp): +@runtime_checkable +class AsCustomOp(DataflowOp, Protocol): + """Abstract interface that types can implement + to behave as a custom dataflow operation. + """ + + @cached_property + def custom_op(self) -> Custom: + """:class:`Custom` operation that this type represents. + + Computed once using :meth:`to_custom` and cached - should be deterministic. + """ + return self.to_custom() + + def to_custom(self) -> Custom: + """Convert this type to a :class:`Custom` operation. + + + Used by :attr:`custom_op`, so must be deterministic. + """ + ... # pragma: no cover + + @classmethod + def from_custom(cls, custom: Custom) -> Self | None: + """Load from a :class:`Custom` operation. + + + By default assumes the type of `cls` is a singleton, + and compares the result of :meth:`to_custom` with the given `custom`. + + If successful, returns the singleton, else None. + + Non-singleton types should override this method. + """ + default = cls() + if default.custom_op == custom: + return default + return None + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AsCustomOp): + return NotImplemented + slf, other = self.custom_op, other.custom_op + return ( + slf.extension == other.extension + and slf.op_name == other.op_name + and slf.signature == other.signature + and slf.args == other.args + ) + + def outer_signature(self) -> tys.FunctionType: + return self.custom_op.signature + + def to_serial(self, parent: Node) -> sops.CustomOp: + return self.custom_op.to_serial(parent) + + @property + def num_out(self) -> int: + return len(self.custom_op.signature.output) + + +@dataclass(frozen=True, eq=False) +class Custom(AsCustomOp): """A non-core dataflow operation defined in an extension.""" op_name: str @@ -207,10 +271,6 @@ class Custom(DataflowOp): extension: tys.ExtensionId = "" args: list[tys.TypeArg] = field(default_factory=list) - @property - def num_out(self) -> int: - return len(self.signature.output) - def to_serial(self, parent: Node) -> sops.CustomOp: return sops.CustomOp( parent=parent.idx, @@ -221,8 +281,12 @@ def to_serial(self, parent: Node) -> sops.CustomOp: args=ser_it(self.args), ) - def outer_signature(self) -> tys.FunctionType: - return self.signature + def to_custom(self) -> Custom: + return self + + @classmethod + def from_custom(cls, custom: Custom) -> Custom: + return custom @dataclass() diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 20f7b91cd..34c826472 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -2,10 +2,16 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar + +from typing_extensions import Self from hugr import tys, val -from hugr.ops import Custom +from hugr.ops import AsCustomOp, Custom, DataflowOp + +if TYPE_CHECKING: + from hugr.ops import Command, ComWire def int_t(width: int) -> tys.Opaque: @@ -44,27 +50,39 @@ def to_value(self) -> val.Extension: return val.Extension("int", INT_T, self.v) -@dataclass(frozen=True) -class IntOps(Custom): - """Base class for integer operations.""" - - extension: tys.ExtensionId = "arithmetic.int" - - -_ARG_I32 = tys.BoundedNatArg(n=5) +OPS_EXTENSION: tys.ExtensionId = "arithmetic.int" @dataclass(frozen=True) -class _DivModDef(IntOps): +class _DivModDef(AsCustomOp): """DivMod operation, has two inputs and two outputs.""" - num_out: int = 2 - extension: tys.ExtensionId = "arithmetic.int" - op_name: str = "idivmod_u" - signature: tys.FunctionType = field( - default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2) - ) - args: list[tys.TypeArg] = field(default_factory=lambda: [_ARG_I32, _ARG_I32]) + op_name: ClassVar[str] = "idivmod_u" + arg1: int = 5 + arg2: int = 5 + + def to_custom(self) -> Custom: + return Custom( + "idivmod_u", + tys.FunctionType( + input=[int_t(self.arg1)] * 2, output=[int_t(self.arg2)] * 2 + ), + extension=OPS_EXTENSION, + args=[tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)], + ) + + @classmethod + def from_custom(cls, custom: Custom) -> Self | None: + if not (custom.extension == OPS_EXTENSION and custom.op_name == cls.op_name): + return None + match custom.args: + case [tys.BoundedNatArg(n=a1), tys.BoundedNatArg(n=a2)]: + return cls(arg1=a1, arg2=a2) + case _: + return None + + def __call__(self, a: ComWire, b: ComWire) -> Command: + return DataflowOp.__call__(self, a, b) #: DivMod operation. diff --git a/hugr-py/src/hugr/std/logic.py b/hugr-py/src/hugr/std/logic.py index b2890e6e4..1291a61c5 100644 --- a/hugr-py/src/hugr/std/logic.py +++ b/hugr-py/src/hugr/std/logic.py @@ -6,32 +6,24 @@ from typing import TYPE_CHECKING from hugr import tys -from hugr.ops import Command, Custom +from hugr.ops import AsCustomOp, Command, Custom, DataflowOp if TYPE_CHECKING: from hugr.ops import ComWire -@dataclass(frozen=True) -class LogicOps(Custom): - """Base class for logic operations.""" - - extension: tys.ExtensionId = "logic" - - -_NotSig = tys.FunctionType.endo([tys.Bool]) +EXTENSION_ID: tys.ExtensionId = "logic" @dataclass(frozen=True) -class _NotDef(LogicOps): +class _NotDef(AsCustomOp): """Not operation.""" - num_out: int = 1 - op_name: str = "Not" - signature: tys.FunctionType = _NotSig + def to_custom(self) -> Custom: + return Custom("Not", tys.FunctionType.endo([tys.Bool]), extension=EXTENSION_ID) def __call__(self, a: ComWire) -> Command: - return super().__call__(a) + return DataflowOp.__call__(self, a) #: Not operation diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index e2d67b0c2..61ccd09ee 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -5,11 +5,14 @@ import pathlib import subprocess from dataclasses import dataclass -from typing import TYPE_CHECKING +from enum import Enum +from typing import TYPE_CHECKING, TypeVar + +from typing_extensions import Self from hugr import tys from hugr.hugr import Hugr -from hugr.ops import Command, Custom +from hugr.ops import AsCustomOp, Command, Custom, DataflowOp from hugr.serialization.serial_hugr import SerialHugr from hugr.std.float import FLOAT_T @@ -17,50 +20,79 @@ from hugr.ops import ComWire -@dataclass(frozen=True) -class QuantumOps(Custom): - extension: tys.ExtensionId = "tket2.quantum" +QUANTUM_EXTENSION_ID: tys.ExtensionId = "quantum.tket2" + +E = TypeVar("E", bound=Enum) -_OneQbSig = tys.FunctionType.endo([tys.Qubit]) +def _load_enum(enum_cls: type[E], custom: Custom) -> E | None: + if ( + custom.extension == QUANTUM_EXTENSION_ID + and custom.op_name in enum_cls.__members__ + ): + return enum_cls(custom.op_name) + return None @dataclass(frozen=True) -class OneQbGate(QuantumOps): - op_name: str - num_out: int = 1 - signature: tys.FunctionType = _OneQbSig +class OneQbGate(AsCustomOp): + # Have to nest enum to avoid meta class conflict + class _Enum(Enum): + H = "H" + + _enum: _Enum def __call__(self, q: ComWire) -> Command: - return super().__call__(q) + return DataflowOp.__call__(self, q) + def to_custom(self) -> Custom: + return Custom( + self._enum.value, + tys.FunctionType.endo([tys.Qubit]), + extension=QUANTUM_EXTENSION_ID, + ) -H = OneQbGate("H") + @classmethod + def from_custom(cls, custom: Custom) -> Self | None: + return cls(e) if (e := _load_enum(cls._Enum, custom)) else None -_TwoQbSig = tys.FunctionType.endo([tys.Qubit] * 2) +H = OneQbGate(OneQbGate._Enum.H) @dataclass(frozen=True) -class TwoQbGate(QuantumOps): - op_name: str - num_out: int = 2 - signature: tys.FunctionType = _TwoQbSig +class TwoQbGate(AsCustomOp): + class _Enum(Enum): + CX = "CX" - def __call__(self, q0: ComWire, q1: ComWire) -> Command: - return super().__call__(q0, q1) + _enum: _Enum + + def to_custom(self) -> Custom: + return Custom( + self._enum.value, + tys.FunctionType.endo([tys.Qubit] * 2), + extension=QUANTUM_EXTENSION_ID, + ) + @classmethod + def from_custom(cls, custom: Custom) -> Self | None: + return cls(e) if (e := _load_enum(cls._Enum, custom)) else None -CX = TwoQbGate("CX") + def __call__(self, q0: ComWire, q1: ComWire) -> Command: + return DataflowOp.__call__(self, q0, q1) -_MeasSig = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]) + +CX = TwoQbGate(TwoQbGate._Enum.CX) @dataclass(frozen=True) -class MeasureDef(QuantumOps): - op_name: str = "Measure" - num_out: int = 2 - signature: tys.FunctionType = _MeasSig +class MeasureDef(AsCustomOp): + def to_custom(self) -> Custom: + return Custom( + "Measure", + tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]), + extension=QUANTUM_EXTENSION_ID, + ) def __call__(self, q: ComWire) -> Command: return super().__call__(q) @@ -68,14 +100,15 @@ def __call__(self, q: ComWire) -> Command: Measure = MeasureDef() -_RzSig = tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]) - @dataclass(frozen=True) -class RzDef(QuantumOps): - op_name: str = "Rz" - num_out: int = 1 - signature: tys.FunctionType = _RzSig +class RzDef(AsCustomOp): + def to_custom(self) -> Custom: + return Custom( + "Rz", + tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]), + extension=QUANTUM_EXTENSION_ID, + ) def __call__(self, q: ComWire, fl_wire: ComWire) -> Command: return super().__call__(q, fl_wire) diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py new file mode 100644 index 000000000..8c9fcabb7 --- /dev/null +++ b/hugr-py/tests/test_custom.py @@ -0,0 +1,42 @@ +import pytest + +from hugr import tys +from hugr.node_port import Node +from hugr.ops import AsCustomOp, Custom +from hugr.std.int import DivMod +from hugr.std.logic import EXTENSION_ID, Not + +from .conftest import CX, H, Measure, Rz + + +@pytest.mark.parametrize( + "as_custom", + [Not, DivMod, H, CX, Measure, Rz], +) +def test_custom(as_custom: AsCustomOp): + custom = as_custom.to_custom() + + assert custom.to_custom() == custom + assert Custom.from_custom(custom) == custom + + assert type(as_custom).from_custom(custom) == as_custom + assert as_custom.to_serial(Node(0)).deserialize() == custom + assert custom == as_custom + assert as_custom == custom + + +def test_custom_bad_eq(): + assert Not != DivMod + + bad_custom_sig = Custom("Not", extension=EXTENSION_ID) # empty signature + + assert Not != bad_custom_sig + + bad_custom_args = Custom( + "Not", + extension=EXTENSION_ID, + signature=tys.FunctionType.endo([tys.Bool]), + args=[tys.Bool.type_arg()], + ) + + assert Not != bad_custom_args From b3824bb6191b145022c22d574385f1946081283c Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 10 Jul 2024 16:38:27 +0100 Subject: [PATCH 2/2] unexpected ops error --- hugr-py/src/hugr/ops.py | 14 ++++++++++++++ hugr-py/src/hugr/std/int.py | 5 +++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index e98eb91cb..4a76c6d59 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -206,6 +206,12 @@ class AsCustomOp(DataflowOp, Protocol): to behave as a custom dataflow operation. """ + @dataclass(frozen=True) + class InvalidCustomOp(Exception): + """Custom operation does not match the expected type.""" + + msg: str + @cached_property def custom_op(self) -> Custom: """:class:`Custom` operation that this type represents. @@ -233,6 +239,10 @@ def from_custom(cls, custom: Custom) -> Self | None: If successful, returns the singleton, else None. Non-singleton types should override this method. + + Raises: + InvalidCustomOp: If the given `custom` does not match the expected one for a + given extension/operation name. """ default = cls() if default.custom_op == custom: @@ -288,6 +298,10 @@ def to_custom(self) -> Custom: def from_custom(cls, custom: Custom) -> Custom: return custom + def check_id(self, extension: tys.ExtensionId, op_name: str) -> bool: + """Check if the operation matches the given extension and operation name.""" + return self.extension == extension and self.op_name == op_name + @dataclass() class MakeTuple(DataflowOp, _PartialOp): diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 34c826472..391ebc7eb 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -73,13 +73,14 @@ def to_custom(self) -> Custom: @classmethod def from_custom(cls, custom: Custom) -> Self | None: - if not (custom.extension == OPS_EXTENSION and custom.op_name == cls.op_name): + if not custom.check_id(OPS_EXTENSION, "idivmod_u"): return None match custom.args: case [tys.BoundedNatArg(n=a1), tys.BoundedNatArg(n=a2)]: return cls(arg1=a1, arg2=a2) case _: - return None + msg = f"Invalid args: {custom.args}" + raise AsCustomOp.InvalidCustomOp(msg) def __call__(self, a: ComWire, b: ComWire) -> Command: return DataflowOp.__call__(self, a, b)