Skip to content

Commit

Permalink
feat: Add SizedIter wrapper type (#611)
Browse files Browse the repository at this point in the history
Adds a new wrapper type `SizedIter` to the standard library that
annotates an iterator with a static size hint. This will be used for
example in array comprehensions to infer the size of the resulting
array. The type gets erased when lowering to Hugr.

Updates the `range` function to emit such a size hint when the range
stop is given by a static number. For example the expression `range(10)`
is typed as `SizedIter[Range, 10]` whereas `range(n + 1)` is typed as
`Range`. Currently, this is implemented via a `CustomCallChecker`, but
in the future we could use function overloading and `Literal` types to
implement this in Guppy source:

```python
@guppy.overloaded
def range[n: nat](stop: Literal[n]) -> SizedIter[Range, n]:
   return SizedIter(Range(0, stop))

@guppy.overloaded
def range(stop: int) -> Range:
   return Range(0, stop)
```

Closes #610.
  • Loading branch information
mark-koch authored Nov 4, 2024
1 parent f5670f6 commit 2e9da6b
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 19 deletions.
2 changes: 2 additions & 0 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
list_type_def,
nat_type_def,
none_type_def,
sized_iter_type_def,
tuple_type_def,
)
from guppylang.tys.ty import (
Expand Down Expand Up @@ -222,6 +223,7 @@ def default() -> "Globals":
float_type_def,
list_type_def,
array_type_def,
sized_iter_type_def,
]
defs = {defn.id: defn for defn in builtins}
names = {defn.name: defn.id for defn in builtins}
Expand Down
6 changes: 5 additions & 1 deletion guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,16 @@ def _setup(self, ctx: Context, node: AstNode, func: CustomFunctionDef) -> None:
self.node = node
self.func = func

@abstractmethod
def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
"""Checks the return value against a given type.
Returns a (possibly) transformed and annotated AST node for the call.
"""
from guppylang.checker.expr_checker import check_type_against

expr, res_ty = self.synthesize(args)
subst, _ = check_type_against(res_ty, ty, self.node)
return expr, subst

@abstractmethod
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
Expand Down
44 changes: 43 additions & 1 deletion guppylang/prelude/_internal/checker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
from typing import cast

from guppylang.ast_util import AstNode, with_loc
from guppylang.ast_util import AstNode, with_loc, with_type
from guppylang.checker.core import Context
from guppylang.checker.expr_checker import (
ExprChecker,
Expand All @@ -15,6 +16,7 @@
CustomFunctionDef,
DefaultCallChecker,
)
from guppylang.definition.struct import CheckedStructDef, RawStructDef
from guppylang.definition.value import CallableDef
from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang.nodes import GlobalCall, ResultExpr
Expand All @@ -25,13 +27,15 @@
int_type,
is_array_type,
is_bool_type,
sized_iter_type,
)
from guppylang.tys.const import Const, ConstValue
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import (
FunctionType,
NoneType,
NumericType,
StructType,
Type,
unify,
)
Expand Down Expand Up @@ -279,3 +283,41 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
@staticmethod
def _is_numeric_or_bool_type(ty: Type) -> bool:
return isinstance(ty, NumericType) or is_bool_type(ty)


class RangeChecker(CustomCallChecker):
"""Call checker for the `range` function."""

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
check_num_args(1, len(args), self.node)
[stop] = args
stop, _ = ExprChecker(self.ctx).check(stop, int_type(), "argument")
range_iter, range_ty = self.make_range(stop)
if isinstance(stop, ast.Constant):
return to_sized_iter(range_iter, range_ty, stop.value, self.ctx)
return range_iter, range_ty

def range_ty(self) -> StructType:
from guppylang.prelude.builtins import Range

def_id = cast(RawStructDef, Range).id
range_type_def = self.ctx.globals.defs[def_id]
assert isinstance(range_type_def, CheckedStructDef)
return StructType([], range_type_def)

def make_range(self, stop: ast.expr) -> tuple[ast.expr, Type]:
make_range = self.ctx.globals.get_instance_func(self.range_ty(), "__new__")
assert make_range is not None
start = with_type(int_type(), with_loc(self.node, ast.Constant(value=0)))
return make_range.synthesize_call([start, stop], self.node, self.ctx)


def to_sized_iter(
iterator: ast.expr, range_ty: Type, size: int, ctx: Context
) -> tuple[ast.expr, Type]:
"""Adds a static size annotation to an iterator."""
sized_iter_ty = sized_iter_type(range_ty, size)
make_sized_iter = ctx.globals.get_instance_func(sized_iter_ty, "__new__")
assert make_sized_iter is not None
sized_iter, _ = make_sized_iter.check_call([iterator], sized_iter_ty, iterator, ctx)
return sized_iter, sized_iter_ty
50 changes: 34 additions & 16 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
CoercingChecker,
DunderChecker,
NewArrayChecker,
RangeChecker,
ResultChecker,
ReversingChecker,
UnsupportedChecker,
Expand Down Expand Up @@ -53,6 +54,7 @@
int_type_def,
list_type_def,
nat_type_def,
sized_iter_type_def,
)

guppy.init_module(import_builtins=False)
Expand Down Expand Up @@ -559,6 +561,32 @@ def __len__(self: array[L, n]) -> int: ...
def __new__(): ...


@guppy.extend_type(sized_iter_type_def)
class SizedIter:
"""A wrapper around an iterator type `T` promising that the iterator will yield
exactly `n` values.
Annotating an iterator with an incorrect size is undefined behaviour.
"""

def __class_getitem__(cls, item: Any) -> type:
# Dummy implementation to allow subscripting of the `SizedIter` type in
# positions that are evaluated by the Python interpreter
return cls

@guppy.custom(NoopCompiler())
def __new__(iterator: L @ owned) -> "SizedIter[L, n]": # type: ignore[type-arg]
"""Casts an iterator into a `SizedIter`."""

@guppy.custom(NoopCompiler())
def unwrap_iter(self: "SizedIter[L, n]" @ owned) -> L:
"""Extracts the actual iterator."""

@guppy.custom(NoopCompiler())
def __iter__(self: "SizedIter[L, n]" @ owned) -> L:
"""Extracts the actual iterator."""


# TODO: This is a temporary hack until we have implemented the proper results mechanism.
@guppy.custom(checker=ResultChecker(), higher_order_value=False)
def result(tag, value): ...
Expand Down Expand Up @@ -769,43 +797,33 @@ def property(x): ...

@guppy.struct
class Range:
stop: int

@guppy
def __iter__(self: "Range") -> "RangeIter":
return RangeIter(0, self.stop) # type: ignore[call-arg]


@guppy.struct
class RangeIter:
next: int
stop: int

@guppy
def __iter__(self: "RangeIter") -> "RangeIter":
def __iter__(self: "Range") -> "Range":
return self

@guppy
def __hasnext__(self: "RangeIter") -> tuple[bool, "RangeIter"]:
def __hasnext__(self: "Range") -> tuple[bool, "Range"]:
return (self.next < self.stop, self)

@guppy
def __next__(self: "RangeIter") -> tuple[int, "RangeIter"]:
def __next__(self: "Range") -> tuple[int, "Range"]:
# Fine not to check bounds while we can only be called from inside a `for` loop.
# if self.start >= self.stop:
# raise StopIteration
return (self.next, RangeIter(self.next + 1, self.stop)) # type: ignore[call-arg]
return (self.next, Range(self.next + 1, self.stop)) # type: ignore[call-arg]

@guppy
def __end__(self: "RangeIter") -> None:
def __end__(self: "Range") -> None:
pass


@guppy
@guppy.custom(checker=RangeChecker(), higher_order_value=False)
def range(stop: int) -> Range:
"""Limited version of python range().
Only a single argument (stop/limit) is supported."""
return Range(stop) # type: ignore[call-arg]


@guppy.custom(checker=UnsupportedChecker(), higher_order_value=False)
Expand Down
39 changes: 39 additions & 0 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type:
return array.instantiate([len_arg.to_hugr(), ht.TypeTypeArg(elem_ty)])


def _sized_iter_to_hugr(args: Sequence[Argument]) -> ht.Type:
[ty_arg, len_arg] = args
assert isinstance(ty_arg, TypeArg)
assert isinstance(len_arg, ConstArg)
return ty_arg.ty.to_hugr()


callable_type_def = CallableTypeDef(DefId.fresh(), None)
tuple_type_def = _TupleTypeDef(DefId.fresh(), None)
none_type_def = _NoneTypeDef(DefId.fresh(), None)
Expand Down Expand Up @@ -179,6 +186,17 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type:
always_linear=False,
to_hugr=_array_to_hugr,
)
sized_iter_type_def = OpaqueTypeDef(
id=DefId.fresh(),
name="SizedIter",
defined_at=None,
params=[
TypeParam(0, "T", can_be_linear=True),
ConstParam(1, "n", NumericType(NumericType.Kind.Nat)),
],
always_linear=False,
to_hugr=_sized_iter_to_hugr,
)


def bool_type() -> OpaqueType:
Expand All @@ -200,6 +218,13 @@ def array_type(element_ty: Type, length: int) -> OpaqueType:
)


def sized_iter_type(iter_type: Type, size: int) -> OpaqueType:
nat_type = NumericType(NumericType.Kind.Nat)
return OpaqueType(
[TypeArg(iter_type), ConstArg(ConstValue(nat_type, size))], sized_iter_type_def
)


def is_bool_type(ty: Type) -> bool:
return isinstance(ty, OpaqueType) and ty.defn == bool_type_def

Expand All @@ -212,9 +237,23 @@ def is_array_type(ty: Type) -> TypeGuard[OpaqueType]:
return isinstance(ty, OpaqueType) and ty.defn == array_type_def


def is_sized_iter_type(ty: Type) -> TypeGuard[OpaqueType]:
return isinstance(ty, OpaqueType) and ty.defn == sized_iter_type_def


def get_element_type(ty: Type) -> Type:
assert isinstance(ty, OpaqueType)
assert ty.defn == list_type_def
(arg,) = ty.args
assert isinstance(arg, TypeArg)
return arg.ty


def get_iter_size(ty: Type) -> int:
assert isinstance(ty, OpaqueType)
assert ty.defn == sized_iter_type_def
match ty.args:
case [_, ConstArg(ConstValue(value=int(size)))]:
return size
case _:
raise InternalGuppyError("Unexpected type args")
22 changes: 21 additions & 1 deletion tests/integration/test_range.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from guppylang.decorator import guppy
from guppylang.prelude.builtins import nat, range
from guppylang.prelude.builtins import nat, range, SizedIter, Range
from guppylang.module import GuppyModule
from tests.util import compile_guppy

Expand All @@ -20,7 +20,27 @@ def negative() -> int:
total += 100 + x
return total

@guppy(module)
def non_static() -> int:
total = 0
n = 4
for x in range(n + 1):
total += x + 100 # Make the initial 0 obvious
return total

compiled = module.compile()
validate(compiled)
run_int_fn(compiled, expected=510)
run_int_fn(compiled, expected=0, fn_name="negative")
run_int_fn(compiled, expected=510, fn_name="non_static")


def test_static_size(validate):
module = GuppyModule("test")

@guppy(module)
def negative() -> SizedIter[Range, 10]:
return range(10)

validate(module.compile())

0 comments on commit 2e9da6b

Please sign in to comment.