Skip to content

Commit

Permalink
feat: Export the collections extension (#1506)
Browse files Browse the repository at this point in the history
#1450 embedded the standard extension definitions in hugr-py, including
`collections`, but it didn't add a way to load it as it did with all the
others.

This PR adds a `hugr.std.collections` module that just loads the bundled
json.

drive-by: Implement `__str__` for `FloatVal` and `IntVal`
  • Loading branch information
aborgna-q authored Sep 3, 2024
1 parent 6ab1f75 commit 70e0a64
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 0 deletions.
40 changes: 40 additions & 0 deletions hugr-py/src/hugr/std/collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Collection types and operations."""

from __future__ import annotations

from dataclasses import dataclass

import hugr.tys as tys
from hugr import val
from hugr.std import _load_extension
from hugr.utils import comma_sep_str

EXTENSION = _load_extension("collections")


def list_type(ty: tys.Type) -> tys.ExtType:
"""Returns a list type with a fixed element type."""
arg = tys.TypeTypeArg(ty)
return EXTENSION.types["List"].instantiate([arg])


@dataclass
class ListVal(val.ExtensionValue):
"""Constant value for a list of elements."""

v: list[val.Value]
ty: tys.Type

def __init__(self, v: list[val.Value], elem_ty: tys.Type) -> None:
self.v = v
self.ty = list_type(elem_ty)

def to_value(self) -> val.Extension:
name = "ListValue"
# The value list must be serialized at this point, otherwise the
# `Extension` value would not be serializable.
vs = [v._to_serial_root() for v in self.v]
return val.Extension(name, typ=self.ty, val=vs, extensions=[EXTENSION.name])

def __str__(self) -> str:
return f"[{comma_sep_str(self.v)}]"
3 changes: 3 additions & 0 deletions hugr-py/src/hugr/std/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ def to_value(self) -> val.Extension:
return val.Extension(
name, typ=FLOAT_T, val=payload, extensions=[EXTENSION.name]
)

def __str__(self) -> str:
return f"{self.v}"
5 changes: 5 additions & 0 deletions hugr-py/src/hugr/std/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
if TYPE_CHECKING:
from hugr.ops import Command, ComWire

CONVERSIONS_EXTENSION = _load_extension("arithmetic.conversions")

INT_TYPES_EXTENSION = _load_extension("arithmetic.int.types")
_INT_PARAM = tys.BoundedNatParam(7)

Expand Down Expand Up @@ -66,6 +68,9 @@ def to_value(self) -> val.Extension:
extensions=[INT_TYPES_EXTENSION.name],
)

def __str__(self) -> str:
return f"{self.v}"


INT_OPS_EXTENSION = _load_extension("arithmetic.int")

Expand Down

0 comments on commit 70e0a64

Please sign in to comment.