Skip to content

Commit

Permalink
feat(hugr-py)!: user facing Extension class (#1413)
Browse files Browse the repository at this point in the history
Allows users to define extensions in Python, attaching operation, type
and value definitions to them. Also provides utilities for then using
those objects when building HUGRs.

Diff is big, but a lot of it is poetry and schema changes. Open to
suggestions as to how to break up but I think it is fairly
self-contained.

Closes #1374
Closes #1412 
Also as a drive-by adds the missing `binary` field to serialised opdef
signatures.


BREAKING CHANGE: `AsCustomOp` replaced with `AsExtOp`, so all such
operations now need to be attached to an extension.

---------

Co-authored-by: Alec Edgington <54802828+cqc-alec@users.noreply.github.com>
  • Loading branch information
ss2165 and cqc-alec authored Aug 14, 2024
1 parent 8dfea09 commit c6473c9
Show file tree
Hide file tree
Showing 15 changed files with 1,146 additions and 249 deletions.
472 changes: 472 additions & 0 deletions hugr-py/src/hugr/ext.py

Large diffs are not rendered by default.

13 changes: 12 additions & 1 deletion hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
ToNode,
_SubPort,
)
from hugr.ops import Call, Const, DataflowOp, Module, Op
from hugr.ops import Call, Const, Custom, DataflowOp, Module, Op
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.tys import Kind, Type, ValueKind
Expand All @@ -34,6 +34,7 @@
from .exceptions import ParentBeforeChild

if TYPE_CHECKING:
from hugr import ext
from hugr.val import Value


Expand Down Expand Up @@ -598,6 +599,16 @@ def _constrain_offset(self, p: P) -> PortOffset:

return offset

def resolve_extensions(self, registry: ext.ExtensionRegistry) -> Hugr:
"""Resolve extension types and operations in the HUGR by matching them to
extensions in the registry.
"""
for node in self:
op = self[node].op
if isinstance(op, Custom):
self[node].op = op.resolve(registry)
return self

@classmethod
def from_serial(cls, serial: SerialHugr) -> Hugr:
"""Load a HUGR from a serialized form."""
Expand Down
164 changes: 133 additions & 31 deletions hugr-py/src/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from hugr import ext
from hugr.serialization.ops import BaseOp


Expand Down Expand Up @@ -201,79 +202,94 @@ def _set_in_types(self, types: tys.TypeRow) -> None:


@runtime_checkable
class AsCustomOp(DataflowOp, Protocol):
class AsExtOp(DataflowOp, Protocol):
"""Abstract interface that types can implement
to behave as a custom dataflow operation.
to behave as an extension dataflow operation.
"""

@dataclass(frozen=True)
class InvalidCustomOp(Exception):
"""Custom operation does not match the expected type."""
class InvalidExtOp(Exception):
"""Extension operation does not match the expected type."""

msg: str

@cached_property
def custom_op(self) -> Custom:
""":class:`Custom` operation that this type represents.
def ext_op(self) -> ExtOp:
""":class:`ExtOp` operation that this type represents.
Computed once using :meth:`to_custom` and cached - should be deterministic.
Computed once using :meth:`op_def` :meth:`type_args` and :meth:`type_args`.
Each of those methods should be deterministic.
"""
return self.to_custom()
return ExtOp(self.op_def(), self.cached_signature(), self.type_args())

def to_custom(self) -> Custom:
"""Convert this type to a :class:`Custom` operation.
def op_def(self) -> ext.OpDef:
"""The :class:`tys.OpDef` for this operation.
Used by :attr:`custom_op`, so must be deterministic.
Used by :attr:`ext_op`, so must be deterministic.
"""
... # pragma: no cover

def type_args(self) -> list[tys.TypeArg]:
"""Type arguments of the operation.
Used by :attr:`op_def`, so must be deterministic.
"""
return []

def cached_signature(self) -> tys.FunctionType | None:
"""Cached signature of the operation, if there is one.
Used by :attr:`op_def`, so must be deterministic.
"""
return None

@classmethod
def from_custom(cls, custom: Custom) -> Self | None:
"""Load from a :class:`Custom` operation.
def from_ext(cls, ext_op: ExtOp) -> Self | None:
"""Load from a :class:`ExtOp` operation.
By default assumes the type of `cls` is a singleton,
and compares the result of :meth:`to_custom` with the given `custom`.
and compares the result of :meth:`to_ext` with the given `ext_op`.
If successful, returns the singleton, else None.
Non-singleton types should override this method.
Raises:
InvalidCustomOp: If the given `custom` does not match the expected one for a
InvalidCustomOp: If the given `ext_op` does not match the expected one for a
given extension/operation name.
"""
default = cls()
if default.custom_op == custom:
if default.ext_op == ext_op:
return default
return None

def __eq__(self, other: object) -> bool:
if not isinstance(other, AsCustomOp):
if not isinstance(other, AsExtOp):
return NotImplemented
slf, other = self.custom_op, other.custom_op
slf, other = self.ext_op, other.ext_op
return (
slf.extension == other.extension
and slf.name == other.name
and slf.signature == other.signature
slf._op_def == other._op_def
and slf.outer_signature() == other.outer_signature()
and slf.args == other.args
)

def outer_signature(self) -> tys.FunctionType:
return self.custom_op.signature
return self.ext_op.outer_signature()

def to_serial(self, parent: Node) -> sops.CustomOp:
return self.custom_op.to_serial(parent)
return self.ext_op.to_serial(parent)

@property
def num_out(self) -> int:
return len(self.custom_op.signature.output)
return len(self.outer_signature().output)


@dataclass(frozen=True, eq=False)
class Custom(AsCustomOp):
"""A non-core dataflow operation defined in an extension."""
class Custom(DataflowOp):
"""Serializable version of non-core dataflow operation defined in an extension."""

name: str
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)
Expand All @@ -291,17 +307,103 @@ def to_serial(self, parent: Node) -> sops.CustomOp:
args=ser_it(self.args),
)

def to_custom(self) -> Custom:
return self
def outer_signature(self) -> tys.FunctionType:
return self.signature

@classmethod
def from_custom(cls, custom: Custom) -> Custom:
return custom
@property
def num_out(self) -> int:
return len(self.outer_signature().output)

def check_id(self, extension: tys.ExtensionId, name: str) -> bool:
"""Check if the operation matches the given extension and operation name."""
return self.extension == extension and self.name == name

def resolve(self, registry: ext.ExtensionRegistry) -> ExtOp | Custom:
"""Resolve the custom operation to an :class:`ExtOp`.
If extension or operation is not found, returns itself.
"""
from hugr.ext import ExtensionRegistry, Extension # noqa: I001 # no circular import

try:
op_def = registry.get_extension(self.extension).get_op(self.name)
except (
Extension.OperationNotFound,
ExtensionRegistry.ExtensionNotFound,
):
return self

signature = self.signature.resolve(registry)
args = [arg.resolve(registry) for arg in self.args]
# TODO check signature matches op_def reported signature
# if/once op_def can compute signature from type scheme + args
return ExtOp(op_def, signature, args)


@dataclass(frozen=True, eq=False)
class ExtOp(AsExtOp):
"""A non-core dataflow operation defined in an extension."""

_op_def: ext.OpDef
signature: tys.FunctionType | None = None
args: list[tys.TypeArg] = field(default_factory=list)

def to_custom_op(self) -> Custom:
ext = self._op_def._extension
if self.signature is None:
poly_func = self._op_def.signature.poly_func
if poly_func is None or len(poly_func.params) > 0:
msg = "For polymorphic ops signature must be cached."
raise ValueError(msg)
sig = poly_func.body
else:
sig = self.signature

return Custom(
name=self._op_def.name,
signature=sig,
extension=ext.name if ext else "",
args=self.args,
)

def to_serial(self, parent: Node) -> sops.CustomOp:
return self.to_custom_op().to_serial(parent)

def op_def(self) -> ext.OpDef:
return self._op_def

def type_args(self) -> list[tys.TypeArg]:
return self.args

def cached_signature(self) -> tys.FunctionType | None:
return self.signature

@classmethod
def from_ext(cls, custom: ExtOp) -> ExtOp:
return custom

def outer_signature(self) -> tys.FunctionType:
if self.signature is not None:
return self.signature
poly_func = self._op_def.signature.poly_func
if poly_func is None:
msg = "Polymorphic signature must be cached."
raise ValueError(msg)
return poly_func.body


class RegisteredOp(AsExtOp):
"""Base class for operations that are registered with an extension using
:meth:`Extension.register_op <hugr.ext.Extension.register_op>`.
"""

#: Known operation definition.
const_op_def: ext.OpDef # must be initialised by register_op

def op_def(self) -> ext.OpDef:
# override for AsExtOp.op_def
return self.const_op_def


@dataclass()
class MakeTuple(DataflowOp, _PartialOp):
Expand Down
Loading

0 comments on commit c6473c9

Please sign in to comment.