Skip to content

Commit

Permalink
fix!: Use latest results extension spec (#370)
Browse files Browse the repository at this point in the history
Fixes #369 and fixes #365.

Make Guppy conform with the latest results spec introduced in
CQCL/tket2#494

BREAKING CHANGE:
* Result tags are now strings instead of ints
* Only numeric values and arrays thereof are allowed as results
  • Loading branch information
mark-koch authored Aug 12, 2024
1 parent 7bf419f commit d9cc8d2
Show file tree
Hide file tree
Showing 17 changed files with 201 additions and 63 deletions.
2 changes: 1 addition & 1 deletion examples/random_walk_qpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def main() -> int:
mu,
sigma,
)
result(0, eigenvalue) # Expected outcome is 0.5
result("eigenvalue", eigenvalue) # Expected outcome is 0.5
return 0


Expand Down
61 changes: 51 additions & 10 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from contextlib import contextmanager
from typing import Any, TypeGuard, TypeVar

from hugr.serialization import ops
from hugr.serialization import ops, tys
from typing_extensions import assert_never

from guppylang.ast_util import AstVisitor, get_type, with_loc, with_type
from guppylang.cfg.builder import tmp_vars
Expand All @@ -31,9 +32,13 @@
TensorCall,
TypeApply,
)
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.builtin import bool_type, get_element_type, is_list_type
from guppylang.tys.const import ConstValue
from guppylang.tys.builtin import (
bool_type,
get_element_type,
is_bool_type,
is_list_type,
)
from guppylang.tys.const import BoundConstVar, ConstValue, ExistentialConstVar
from guppylang.tys.subst import Inst
from guppylang.tys.ty import (
BoundTypeVar,
Expand Down Expand Up @@ -297,14 +302,50 @@ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> OutPortV:
return unpack.out_port(node.struct_ty.fields.index(node.field))

def visit_ResultExpr(self, node: ResultExpr) -> OutPortV:
type_args = [
TypeArg(node.ty),
ConstArg(ConstValue(value=node.tag, ty=NumericType(NumericType.Kind.Nat))),
]
extra_args = []
if isinstance(node.base_ty, NumericType):
match node.base_ty.kind:
case NumericType.Kind.Nat:
base_name = "uint"
extra_args = [
tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))
]
case NumericType.Kind.Int:
base_name = "int"
extra_args = [
tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))
]
case NumericType.Kind.Float:
base_name = "f64"
case kind:
assert_never(kind)
else:
# The only other valid base type is bool
assert is_bool_type(node.base_ty)
base_name = "bool"
if node.array_len is not None:
op_name = f"result_array_{base_name}"
match node.array_len:
case ConstValue(value=value):
assert isinstance(value, int)
extra_args = [tys.TypeArg(tys.BoundedNatArg(n=value)), *extra_args]
case BoundConstVar():
# TODO: We need to handle this once we allow function definitions
# that are generic over array lengths
raise NotImplementedError
case ExistentialConstVar() as var:
raise InternalGuppyError(
f"Unsolved existential variable during Hugr lowering: {var}"
)
case c:
assert_never(c)
else:
op_name = f"result_{base_name}"
args = [tys.TypeArg(tys.StringArg(arg=node.tag)), *extra_args]
op = ops.CustomOp(
extension="tket2.result",
name="result_uint",
args=[arg.to_hugr() for arg in type_args],
name=op_name,
args=args,
parent=UNDEFINED,
)
self.graph.add_node(ops.OpType(op), inputs=[self.visit(node.value)])
Expand Down
9 changes: 6 additions & 3 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Any

from guppylang.ast_util import AstNode
from guppylang.tys.const import Const
from guppylang.tys.subst import Inst
from guppylang.tys.ty import FunctionType, StructType, Type

Expand Down Expand Up @@ -192,10 +193,12 @@ class ResultExpr(ast.expr):
"""A `result(tag, value)` expression."""

value: ast.expr
ty: Type
tag: int
base_ty: Type
#: Array length in case this is an array result, otherwise `None`
array_len: Const | None
tag: str

_fields = ("value", "ty", "tag")
_fields = ("value", "base_ty", "array_len", "tag")


class NestedFunctionDef(ast.FunctionDef):
Expand Down
50 changes: 36 additions & 14 deletions guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,15 @@
from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV
from guppylang.nodes import GlobalCall, ResultExpr
from guppylang.tys.arg import ConstArg
from guppylang.tys.builtin import bool_type, int_type, list_type
from guppylang.tys.const import ConstValue
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.builtin import (
bool_type,
int_type,
is_array_type,
is_bool_type,
list_type,
)
from guppylang.tys.const import Const, ConstValue
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import (
FunctionType,
Expand Down Expand Up @@ -277,28 +283,44 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:


class ResultChecker(CustomCallChecker):
"""Call checker for the `result` function.
This is a temporary hack until we have implemented the proper results mechanism.
"""
"""Call checker for the `result` function."""

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
check_num_args(2, len(args), self.node)
[tag, value] = args
if not isinstance(tag, ast.Constant) or not isinstance(tag.value, int):
raise GuppyTypeError("Expected an int literal", tag)
if not isinstance(tag, ast.Constant) or not isinstance(tag.value, str):
raise GuppyTypeError("Expected a string literal", tag)
value, ty = ExprSynthesizer(self.ctx).synthesize(value)
if ty.linear:
raise GuppyTypeError(
f"Cannot use value with linear type `{ty}` as a result", value
)
return with_loc(self.node, ResultExpr(value, ty, tag.value)), NoneType()
# We only allow numeric values or vectors of numeric values
err = (
f"Expression of type `{ty}` is not a valid result. Only numeric values or "
"arrays thereof are allowed."
)
if self._is_numeric_or_bool_type(ty):
base_ty = ty
array_len: Const | None = None
elif is_array_type(ty):
[ty_arg, len_arg] = ty.args
assert isinstance(ty_arg, TypeArg)
assert isinstance(len_arg, ConstArg)
if not self._is_numeric_or_bool_type(ty_arg.ty):
raise GuppyError(err, value)
base_ty = ty_arg.ty
array_len = len_arg.const
else:
raise GuppyError(err, value)
node = ResultExpr(value, base_ty, array_len, tag.value)
return with_loc(self.node, node), NoneType()

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
expr, res_ty = self.synthesize(args)
subst, _ = check_type_against(res_ty, ty, self.node)
return expr, subst

@staticmethod
def _is_numeric_or_bool_type(ty: Type) -> bool:
return isinstance(ty, NumericType) or is_bool_type(ty)


class NatTruedivCompiler(CustomCallCompiler):
"""Compiler for the `nat.__truediv__` method."""
Expand Down
6 changes: 5 additions & 1 deletion guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, TypeGuard

from hugr.serialization import tys

Expand Down Expand Up @@ -227,6 +227,10 @@ def is_linst_type(ty: Type) -> bool:
return isinstance(ty, OpaqueType) and ty.defn == linst_type_def


def is_array_type(ty: Type) -> TypeGuard[OpaqueType]:
return isinstance(ty, OpaqueType) and ty.defn == array_type_def


def get_element_type(ty: Type) -> Type:
assert isinstance(ty, OpaqueType)
assert ty.defn in (list_type_def, linst_type_def)
Expand Down
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_array_not_numeric.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:7

5: @compile_guppy
6: def foo(x: array[tuple[int, bool], 42]) -> None:
7: result("foo", x)
^
GuppyError: Expression of type `array[(int, bool), 42]` is not a valid result. Only numeric values or arrays thereof are allowed.
7 changes: 7 additions & 0 deletions tests/error/misc_errors/result_array_not_numeric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from guppylang.prelude.builtins import result, array
from tests.util import compile_guppy


@compile_guppy
def foo(x: array[tuple[int, bool], 42]) -> None:
result("foo", x)
8 changes: 4 additions & 4 deletions tests/error/misc_errors/result_tag_not_static.err
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Guppy compilation failed. Error in file $FILE:7

5: @compile_guppy
6: def foo(x: int, y: bool) -> None:
7: result(x, y)
^
GuppyTypeError: Expected an int literal
6: def foo(y: bool) -> None:
7: result("foo" + "bar", y)
^^^^^^^^^^^^^
GuppyTypeError: Expected a string literal
4 changes: 2 additions & 2 deletions tests/error/misc_errors/result_tag_not_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@


@compile_guppy
def foo(x: int, y: bool) -> None:
result(x, y)
def foo(y: bool) -> None:
result("foo" + "bar", y)
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:7
6: def foo(x: int) -> None:
7: result((), x)
^^
GuppyTypeError: Expected an int literal
GuppyTypeError: Expected a string literal
6 changes: 3 additions & 3 deletions tests/error/misc_errors/result_value_linear.err
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:14

12: @guppy(module)
13: def foo(q: qubit) -> None:
14: result(0, q)
^
GuppyTypeError: Cannot use value with linear type `qubit` as a result
14: result("foo", q)
^
GuppyError: Expression of type `qubit` is not a valid result. Only numeric values or arrays thereof are allowed.
2 changes: 1 addition & 1 deletion tests/error/misc_errors/result_value_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@guppy(module)
def foo(q: qubit) -> None:
result(0, q)
result("foo", q)


module.compile()
40 changes: 18 additions & 22 deletions tests/integration/test_result.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,42 @@
from guppylang.prelude.builtins import result
from guppylang.prelude.builtins import result, nat, array
from tests.util import compile_guppy


def test_single(validate):
def test_basic(validate):
@compile_guppy
def main(x: int) -> None:
result(0, x)
result("foo", x)

validate(main)


def test_value(validate):
@compile_guppy
def main(x: int) -> None:
return result(0, x)

validate(main)


def test_nested(validate):
def test_multi(validate):
@compile_guppy
def main(x: int, y: float, z: bool) -> None:
result(42, (x, (y, z)))
def main(w: nat, x: int, y: float, z: bool) -> None:
result("a", w)
result("b", x)
result("c", y)
result("d", z)

validate(main)


def test_multi(validate):
def test_array(validate):
@compile_guppy
def main(x: int, y: float, z: bool) -> None:
result(0, x)
result(1, y)
result(2, z)
def main(w: array[nat, 42], x: array[int, 5], y: array[float, 1], z: array[bool, 0]) -> None:
result("a", w)
result("b", x)
result("c", y)
result("d", z)

validate(main)


def test_same_tag(validate):
@compile_guppy
def main(x: int, y: float, z: bool) -> None:
result(0, x)
result(0, y)
result(0, z)
result("foo", x)
result("foo", y)
result("foo", z)

validate(main)
Loading

0 comments on commit d9cc8d2

Please sign in to comment.