Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): AsCustomOp protocol for user-defined custom op types. #1290

Merged
merged 2 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 72 additions & 8 deletions hugr-py/src/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from __future__ import annotations

from dataclasses import dataclass, field
from functools import cached_property
from typing import TYPE_CHECKING, Protocol, TypeVar, runtime_checkable

from typing_extensions import Self

import hugr.serialization.ops as sops
from hugr import tys, val
from hugr.node_port import Direction, InPort, Node, OutPort, Wire
Expand Down Expand Up @@ -197,8 +200,69 @@
self._types = types


@dataclass(frozen=True)
class Custom(DataflowOp):
@runtime_checkable
class AsCustomOp(DataflowOp, Protocol):
"""Abstract interface that types can implement
to behave as a custom dataflow operation.
"""

@cached_property
def custom_op(self) -> Custom:
""":class:`Custom` operation that this type represents.

Computed once using :meth:`to_custom` and cached - should be deterministic.
"""
return self.to_custom()

def to_custom(self) -> Custom:
"""Convert this type to a :class:`Custom` operation.


Used by :attr:`custom_op`, so must be deterministic.
"""
... # pragma: no cover

@classmethod
def from_custom(cls, custom: Custom) -> Self | None:
"""Load from a :class:`Custom` operation.


By default assumes the type of `cls` is a singleton,
and compares the result of :meth:`to_custom` with the given `custom`.

If successful, returns the singleton, else None.

Non-singleton types should override this method.
"""
default = cls()
if default.custom_op == custom:
return default
return None

Check warning on line 240 in hugr-py/src/hugr/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/ops.py#L240

Added line #L240 was not covered by tests

def __eq__(self, other: object) -> bool:
if not isinstance(other, AsCustomOp):
return NotImplemented

Check warning on line 244 in hugr-py/src/hugr/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/ops.py#L244

Added line #L244 was not covered by tests
slf, other = self.custom_op, other.custom_op
return (
slf.extension == other.extension
and slf.op_name == other.op_name
and slf.signature == other.signature
and slf.args == other.args
)

def outer_signature(self) -> tys.FunctionType:
return self.custom_op.signature

def to_serial(self, parent: Node) -> sops.CustomOp:
return self.custom_op.to_serial(parent)

@property
def num_out(self) -> int:
return len(self.custom_op.signature.output)


@dataclass(frozen=True, eq=False)
class Custom(AsCustomOp):
"""A non-core dataflow operation defined in an extension."""

op_name: str
Expand All @@ -207,10 +271,6 @@
extension: tys.ExtensionId = ""
args: list[tys.TypeArg] = field(default_factory=list)

@property
def num_out(self) -> int:
return len(self.signature.output)

def to_serial(self, parent: Node) -> sops.CustomOp:
return sops.CustomOp(
parent=parent.idx,
Expand All @@ -221,8 +281,12 @@
args=ser_it(self.args),
)

def outer_signature(self) -> tys.FunctionType:
return self.signature
def to_custom(self) -> Custom:
return self

@classmethod
def from_custom(cls, custom: Custom) -> Custom:
return custom


@dataclass()
Expand Down
54 changes: 36 additions & 18 deletions hugr-py/src/hugr/std/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@

from __future__ import annotations

from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar

from typing_extensions import Self

from hugr import tys, val
from hugr.ops import Custom
from hugr.ops import AsCustomOp, Custom, DataflowOp

if TYPE_CHECKING:
from hugr.ops import Command, ComWire

Check warning on line 14 in hugr-py/src/hugr/std/int.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/std/int.py#L14

Added line #L14 was not covered by tests


def int_t(width: int) -> tys.Opaque:
Expand Down Expand Up @@ -44,27 +50,39 @@
return val.Extension("int", INT_T, self.v)


@dataclass(frozen=True)
class IntOps(Custom):
"""Base class for integer operations."""

extension: tys.ExtensionId = "arithmetic.int"


_ARG_I32 = tys.BoundedNatArg(n=5)
OPS_EXTENSION: tys.ExtensionId = "arithmetic.int"


@dataclass(frozen=True)
class _DivModDef(IntOps):
class _DivModDef(AsCustomOp):
"""DivMod operation, has two inputs and two outputs."""

num_out: int = 2
extension: tys.ExtensionId = "arithmetic.int"
op_name: str = "idivmod_u"
signature: tys.FunctionType = field(
default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2)
)
args: list[tys.TypeArg] = field(default_factory=lambda: [_ARG_I32, _ARG_I32])
op_name: ClassVar[str] = "idivmod_u"
arg1: int = 5
arg2: int = 5

def to_custom(self) -> Custom:
return Custom(
"idivmod_u",
tys.FunctionType(
input=[int_t(self.arg1)] * 2, output=[int_t(self.arg2)] * 2
),
extension=OPS_EXTENSION,
args=[tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)],
)

@classmethod
def from_custom(cls, custom: Custom) -> Self | None:
if not (custom.extension == OPS_EXTENSION and custom.op_name == cls.op_name):
return None

Check warning on line 77 in hugr-py/src/hugr/std/int.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/std/int.py#L77

Added line #L77 was not covered by tests
match custom.args:
case [tys.BoundedNatArg(n=a1), tys.BoundedNatArg(n=a2)]:
return cls(arg1=a1, arg2=a2)
case _:
return None

Check warning on line 82 in hugr-py/src/hugr/std/int.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/std/int.py#L81-L82

Added lines #L81 - L82 were not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this raise an error? If the name and extension match then we should always be able to extract the op


def __call__(self, a: ComWire, b: ComWire) -> Command:
return DataflowOp.__call__(self, a, b)


#: DivMod operation.
Expand Down
20 changes: 6 additions & 14 deletions hugr-py/src/hugr/std/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,24 @@
from typing import TYPE_CHECKING

from hugr import tys
from hugr.ops import Command, Custom
from hugr.ops import AsCustomOp, Command, Custom, DataflowOp

if TYPE_CHECKING:
from hugr.ops import ComWire


@dataclass(frozen=True)
class LogicOps(Custom):
"""Base class for logic operations."""

extension: tys.ExtensionId = "logic"


_NotSig = tys.FunctionType.endo([tys.Bool])
EXTENSION_ID: tys.ExtensionId = "logic"


@dataclass(frozen=True)
class _NotDef(LogicOps):
class _NotDef(AsCustomOp):
"""Not operation."""

num_out: int = 1
op_name: str = "Not"
signature: tys.FunctionType = _NotSig
def to_custom(self) -> Custom:
return Custom("Not", tys.FunctionType.endo([tys.Bool]), extension=EXTENSION_ID)

def __call__(self, a: ComWire) -> Command:
return super().__call__(a)
return DataflowOp.__call__(self, a)


#: Not operation
Expand Down
95 changes: 64 additions & 31 deletions hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,77 +5,110 @@
import pathlib
import subprocess
from dataclasses import dataclass
from typing import TYPE_CHECKING
from enum import Enum
from typing import TYPE_CHECKING, TypeVar

from typing_extensions import Self

from hugr import tys
from hugr.hugr import Hugr
from hugr.ops import Command, Custom
from hugr.ops import AsCustomOp, Command, Custom, DataflowOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.std.float import FLOAT_T

if TYPE_CHECKING:
from hugr.ops import ComWire


@dataclass(frozen=True)
class QuantumOps(Custom):
extension: tys.ExtensionId = "tket2.quantum"
QUANTUM_EXTENSION_ID: tys.ExtensionId = "quantum.tket2"

E = TypeVar("E", bound=Enum)


_OneQbSig = tys.FunctionType.endo([tys.Qubit])
def _load_enum(enum_cls: type[E], custom: Custom) -> E | None:
if (
custom.extension == QUANTUM_EXTENSION_ID
and custom.op_name in enum_cls.__members__
):
return enum_cls(custom.op_name)
return None


@dataclass(frozen=True)
class OneQbGate(QuantumOps):
op_name: str
num_out: int = 1
signature: tys.FunctionType = _OneQbSig
class OneQbGate(AsCustomOp):
# Have to nest enum to avoid meta class conflict
class _Enum(Enum):
H = "H"

_enum: _Enum

def __call__(self, q: ComWire) -> Command:
return super().__call__(q)
return DataflowOp.__call__(self, q)

def to_custom(self) -> Custom:
return Custom(
self._enum.value,
tys.FunctionType.endo([tys.Qubit]),
extension=QUANTUM_EXTENSION_ID,
)

H = OneQbGate("H")
@classmethod
def from_custom(cls, custom: Custom) -> Self | None:
return cls(e) if (e := _load_enum(cls._Enum, custom)) else None


_TwoQbSig = tys.FunctionType.endo([tys.Qubit] * 2)
H = OneQbGate(OneQbGate._Enum.H)


@dataclass(frozen=True)
class TwoQbGate(QuantumOps):
op_name: str
num_out: int = 2
signature: tys.FunctionType = _TwoQbSig
class TwoQbGate(AsCustomOp):
class _Enum(Enum):
CX = "CX"

def __call__(self, q0: ComWire, q1: ComWire) -> Command:
return super().__call__(q0, q1)
_enum: _Enum

def to_custom(self) -> Custom:
return Custom(
self._enum.value,
tys.FunctionType.endo([tys.Qubit] * 2),
extension=QUANTUM_EXTENSION_ID,
)

@classmethod
def from_custom(cls, custom: Custom) -> Self | None:
return cls(e) if (e := _load_enum(cls._Enum, custom)) else None

CX = TwoQbGate("CX")
def __call__(self, q0: ComWire, q1: ComWire) -> Command:
return DataflowOp.__call__(self, q0, q1)

_MeasSig = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool])

CX = TwoQbGate(TwoQbGate._Enum.CX)


@dataclass(frozen=True)
class MeasureDef(QuantumOps):
op_name: str = "Measure"
num_out: int = 2
signature: tys.FunctionType = _MeasSig
class MeasureDef(AsCustomOp):
def to_custom(self) -> Custom:
return Custom(
"Measure",
tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]),
extension=QUANTUM_EXTENSION_ID,
)

def __call__(self, q: ComWire) -> Command:
return super().__call__(q)


Measure = MeasureDef()

_RzSig = tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit])


@dataclass(frozen=True)
class RzDef(QuantumOps):
op_name: str = "Rz"
num_out: int = 1
signature: tys.FunctionType = _RzSig
class RzDef(AsCustomOp):
def to_custom(self) -> Custom:
return Custom(
"Rz",
tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]),
extension=QUANTUM_EXTENSION_ID,
)

def __call__(self, q: ComWire, fl_wire: ComWire) -> Command:
return super().__call__(q, fl_wire)
Expand Down
Loading
Loading