diff --git a/examples/random_walk_qpe.py b/examples/random_walk_qpe.py index 0ba36f40..a3dc03dd 100644 --- a/examples/random_walk_qpe.py +++ b/examples/random_walk_qpe.py @@ -100,7 +100,7 @@ def main() -> int: mu, sigma, ) - result(0, eigenvalue) # Expected outcome is 0.5 + result("eigenvalue", eigenvalue) # Expected outcome is 0.5 return 0 diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 362ba21b..f5935e21 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -4,7 +4,8 @@ from contextlib import contextmanager from typing import Any, TypeGuard, TypeVar -from hugr.serialization import ops +from hugr.serialization import ops, tys +from typing_extensions import assert_never from guppylang.ast_util import AstVisitor, get_type, with_loc, with_type from guppylang.cfg.builder import tmp_vars @@ -31,9 +32,13 @@ TensorCall, TypeApply, ) -from guppylang.tys.arg import ConstArg, TypeArg -from guppylang.tys.builtin import bool_type, get_element_type, is_list_type -from guppylang.tys.const import ConstValue +from guppylang.tys.builtin import ( + bool_type, + get_element_type, + is_bool_type, + is_list_type, +) +from guppylang.tys.const import BoundConstVar, ConstValue, ExistentialConstVar from guppylang.tys.subst import Inst from guppylang.tys.ty import ( BoundTypeVar, @@ -297,14 +302,50 @@ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> OutPortV: return unpack.out_port(node.struct_ty.fields.index(node.field)) def visit_ResultExpr(self, node: ResultExpr) -> OutPortV: - type_args = [ - TypeArg(node.ty), - ConstArg(ConstValue(value=node.tag, ty=NumericType(NumericType.Kind.Nat))), - ] + extra_args = [] + if isinstance(node.base_ty, NumericType): + match node.base_ty.kind: + case NumericType.Kind.Nat: + base_name = "uint" + extra_args = [ + tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH)) + ] + case NumericType.Kind.Int: + base_name = "int" + extra_args = [ + tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH)) + ] + case NumericType.Kind.Float: + base_name = "f64" + case kind: + assert_never(kind) + else: + # The only other valid base type is bool + assert is_bool_type(node.base_ty) + base_name = "bool" + if node.array_len is not None: + op_name = f"result_array_{base_name}" + match node.array_len: + case ConstValue(value=value): + assert isinstance(value, int) + extra_args = [tys.TypeArg(tys.BoundedNatArg(n=value)), *extra_args] + case BoundConstVar(): + # TODO: We need to handle this once we allow function definitions + # that are generic over array lengths + raise NotImplementedError + case ExistentialConstVar() as var: + raise InternalGuppyError( + f"Unsolved existential variable during Hugr lowering: {var}" + ) + case c: + assert_never(c) + else: + op_name = f"result_{base_name}" + args = [tys.TypeArg(tys.StringArg(arg=node.tag)), *extra_args] op = ops.CustomOp( extension="tket2.result", - name="result_uint", - args=[arg.to_hugr() for arg in type_args], + name=op_name, + args=args, parent=UNDEFINED, ) self.graph.add_node(ops.OpType(op), inputs=[self.visit(node.value)]) diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 2b7cb5c9..6b57dad0 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any from guppylang.ast_util import AstNode +from guppylang.tys.const import Const from guppylang.tys.subst import Inst from guppylang.tys.ty import FunctionType, StructType, Type @@ -192,10 +193,12 @@ class ResultExpr(ast.expr): """A `result(tag, value)` expression.""" value: ast.expr - ty: Type - tag: int + base_ty: Type + #: Array length in case this is an array result, otherwise `None` + array_len: Const | None + tag: str - _fields = ("value", "ty", "tag") + _fields = ("value", "base_ty", "array_len", "tag") class NestedFunctionDef(ast.FunctionDef): diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index ed44b1aa..70558bf0 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -22,9 +22,15 @@ from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV from guppylang.nodes import GlobalCall, ResultExpr -from guppylang.tys.arg import ConstArg -from guppylang.tys.builtin import bool_type, int_type, list_type -from guppylang.tys.const import ConstValue +from guppylang.tys.arg import ConstArg, TypeArg +from guppylang.tys.builtin import ( + bool_type, + int_type, + is_array_type, + is_bool_type, + list_type, +) +from guppylang.tys.const import Const, ConstValue from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import ( FunctionType, @@ -277,28 +283,44 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: class ResultChecker(CustomCallChecker): - """Call checker for the `result` function. - - This is a temporary hack until we have implemented the proper results mechanism. - """ + """Call checker for the `result` function.""" def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: check_num_args(2, len(args), self.node) [tag, value] = args - if not isinstance(tag, ast.Constant) or not isinstance(tag.value, int): - raise GuppyTypeError("Expected an int literal", tag) + if not isinstance(tag, ast.Constant) or not isinstance(tag.value, str): + raise GuppyTypeError("Expected a string literal", tag) value, ty = ExprSynthesizer(self.ctx).synthesize(value) - if ty.linear: - raise GuppyTypeError( - f"Cannot use value with linear type `{ty}` as a result", value - ) - return with_loc(self.node, ResultExpr(value, ty, tag.value)), NoneType() + # We only allow numeric values or vectors of numeric values + err = ( + f"Expression of type `{ty}` is not a valid result. Only numeric values or " + "arrays thereof are allowed." + ) + if self._is_numeric_or_bool_type(ty): + base_ty = ty + array_len: Const | None = None + elif is_array_type(ty): + [ty_arg, len_arg] = ty.args + assert isinstance(ty_arg, TypeArg) + assert isinstance(len_arg, ConstArg) + if not self._is_numeric_or_bool_type(ty_arg.ty): + raise GuppyError(err, value) + base_ty = ty_arg.ty + array_len = len_arg.const + else: + raise GuppyError(err, value) + node = ResultExpr(value, base_ty, array_len, tag.value) + return with_loc(self.node, node), NoneType() def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: expr, res_ty = self.synthesize(args) subst, _ = check_type_against(res_ty, ty, self.node) return expr, subst + @staticmethod + def _is_numeric_or_bool_type(ty: Type) -> bool: + return isinstance(ty, NumericType) or is_bool_type(ty) + class NatTruedivCompiler(CustomCallCompiler): """Compiler for the `nat.__truediv__` method.""" diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 409e36fc..82a838e4 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, TypeGuard from hugr.serialization import tys @@ -227,6 +227,10 @@ def is_linst_type(ty: Type) -> bool: return isinstance(ty, OpaqueType) and ty.defn == linst_type_def +def is_array_type(ty: Type) -> TypeGuard[OpaqueType]: + return isinstance(ty, OpaqueType) and ty.defn == array_type_def + + def get_element_type(ty: Type) -> Type: assert isinstance(ty, OpaqueType) assert ty.defn in (list_type_def, linst_type_def) diff --git a/tests/error/misc_errors/result_array_not_numeric.err b/tests/error/misc_errors/result_array_not_numeric.err new file mode 100644 index 00000000..f7575eba --- /dev/null +++ b/tests/error/misc_errors/result_array_not_numeric.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:7 + +5: @compile_guppy +6: def foo(x: array[tuple[int, bool], 42]) -> None: +7: result("foo", x) + ^ +GuppyError: Expression of type `array[(int, bool), 42]` is not a valid result. Only numeric values or arrays thereof are allowed. diff --git a/tests/error/misc_errors/result_array_not_numeric.py b/tests/error/misc_errors/result_array_not_numeric.py new file mode 100644 index 00000000..f8129ec9 --- /dev/null +++ b/tests/error/misc_errors/result_array_not_numeric.py @@ -0,0 +1,7 @@ +from guppylang.prelude.builtins import result, array +from tests.util import compile_guppy + + +@compile_guppy +def foo(x: array[tuple[int, bool], 42]) -> None: + result("foo", x) diff --git a/tests/error/misc_errors/result_tag_not_static.err b/tests/error/misc_errors/result_tag_not_static.err index 9c20d4d1..d4dd0338 100644 --- a/tests/error/misc_errors/result_tag_not_static.err +++ b/tests/error/misc_errors/result_tag_not_static.err @@ -1,7 +1,7 @@ Guppy compilation failed. Error in file $FILE:7 5: @compile_guppy -6: def foo(x: int, y: bool) -> None: -7: result(x, y) - ^ -GuppyTypeError: Expected an int literal +6: def foo(y: bool) -> None: +7: result("foo" + "bar", y) + ^^^^^^^^^^^^^ +GuppyTypeError: Expected a string literal diff --git a/tests/error/misc_errors/result_tag_not_static.py b/tests/error/misc_errors/result_tag_not_static.py index 947105c2..100b1752 100644 --- a/tests/error/misc_errors/result_tag_not_static.py +++ b/tests/error/misc_errors/result_tag_not_static.py @@ -3,5 +3,5 @@ @compile_guppy -def foo(x: int, y: bool) -> None: - result(x, y) +def foo(y: bool) -> None: + result("foo" + "bar", y) diff --git a/tests/error/misc_errors/result_tag_not_int.err b/tests/error/misc_errors/result_tag_not_str.err similarity index 77% rename from tests/error/misc_errors/result_tag_not_int.err rename to tests/error/misc_errors/result_tag_not_str.err index e9151799..575e6cf1 100644 --- a/tests/error/misc_errors/result_tag_not_int.err +++ b/tests/error/misc_errors/result_tag_not_str.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:7 6: def foo(x: int) -> None: 7: result((), x) ^^ -GuppyTypeError: Expected an int literal +GuppyTypeError: Expected a string literal diff --git a/tests/error/misc_errors/result_tag_not_int.py b/tests/error/misc_errors/result_tag_not_str.py similarity index 100% rename from tests/error/misc_errors/result_tag_not_int.py rename to tests/error/misc_errors/result_tag_not_str.py diff --git a/tests/error/misc_errors/result_value_linear.err b/tests/error/misc_errors/result_value_linear.err index 9aa7b18d..514cc748 100644 --- a/tests/error/misc_errors/result_value_linear.err +++ b/tests/error/misc_errors/result_value_linear.err @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:14 12: @guppy(module) 13: def foo(q: qubit) -> None: -14: result(0, q) - ^ -GuppyTypeError: Cannot use value with linear type `qubit` as a result +14: result("foo", q) + ^ +GuppyError: Expression of type `qubit` is not a valid result. Only numeric values or arrays thereof are allowed. diff --git a/tests/error/misc_errors/result_value_linear.py b/tests/error/misc_errors/result_value_linear.py index f8e2bf22..69638f1f 100644 --- a/tests/error/misc_errors/result_value_linear.py +++ b/tests/error/misc_errors/result_value_linear.py @@ -11,7 +11,7 @@ @guppy(module) def foo(q: qubit) -> None: - result(0, q) + result("foo", q) module.compile() diff --git a/tests/integration/test_result.py b/tests/integration/test_result.py index bea82903..e2066480 100644 --- a/tests/integration/test_result.py +++ b/tests/integration/test_result.py @@ -1,37 +1,33 @@ -from guppylang.prelude.builtins import result +from guppylang.prelude.builtins import result, nat, array from tests.util import compile_guppy -def test_single(validate): +def test_basic(validate): @compile_guppy def main(x: int) -> None: - result(0, x) + result("foo", x) validate(main) -def test_value(validate): - @compile_guppy - def main(x: int) -> None: - return result(0, x) - - validate(main) - - -def test_nested(validate): +def test_multi(validate): @compile_guppy - def main(x: int, y: float, z: bool) -> None: - result(42, (x, (y, z))) + def main(w: nat, x: int, y: float, z: bool) -> None: + result("a", w) + result("b", x) + result("c", y) + result("d", z) validate(main) -def test_multi(validate): +def test_array(validate): @compile_guppy - def main(x: int, y: float, z: bool) -> None: - result(0, x) - result(1, y) - result(2, z) + def main(w: array[nat, 42], x: array[int, 5], y: array[float, 1], z: array[bool, 0]) -> None: + result("a", w) + result("b", x) + result("c", y) + result("d", z) validate(main) @@ -39,8 +35,8 @@ def main(x: int, y: float, z: bool) -> None: def test_same_tag(validate): @compile_guppy def main(x: int, y: float, z: bool) -> None: - result(0, x) - result(0, y) - result(0, z) + result("foo", x) + result("foo", y) + result("foo", z) validate(main) diff --git a/validator/Cargo.lock b/validator/Cargo.lock index 1a2e932a..6845457c 100644 --- a/validator/Cargo.lock +++ b/validator/Cargo.lock @@ -285,7 +285,8 @@ dependencies = [ "lazy_static", "pyo3", "serde_json", - "tket2", + "tket2 0.1.0-alpha.2", + "tket2-hseries", ] [[package]] @@ -954,6 +955,60 @@ dependencies = [ "zstd", ] +[[package]] +name = "tket2" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ded27cd4a131a8e071090d9043c048739196cf1fb6f763a217c1d97af3b67477" +dependencies = [ + "bytemuck", + "cgmath", + "chrono", + "crossbeam-channel", + "csv", + "delegate", + "derive_more", + "downcast-rs", + "fxhash", + "hugr", + "hugr-core", + "itertools", + "lazy_static", + "num-rational", + "petgraph", + "portgraph", + "priority-queue", + "rayon", + "serde", + "serde_json", + "smol_str", + "strum", + "strum_macros", + "thiserror", + "tket-json-rs", + "tracing", + "typetag", + "zstd", +] + +[[package]] +name = "tket2-hseries" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f25d9beb4668a4955504bb0c6f6af8a61c2510d4bf4fd5c3220ea0c3ee2a2e81" +dependencies = [ + "hugr", + "itertools", + "lazy_static", + "serde", + "serde_json", + "smol_str", + "strum", + "strum_macros", + "thiserror", + "tket2 0.1.0", +] + [[package]] name = "tracing" version = "0.1.40" diff --git a/validator/Cargo.toml b/validator/Cargo.toml index facadf06..0e664ec9 100644 --- a/validator/Cargo.toml +++ b/validator/Cargo.toml @@ -14,3 +14,4 @@ hugr = "0.10.0" lazy_static = "1.5.0" serde_json = "1.0.111" tket2 = { git = "https://github.com/CQCL/tket2.git", rev = "eca258bfbef5fcd82a0d3b3d70cb736e275b3487" } +tket2-hseries = "0.1.0" diff --git a/validator/src/lib.rs b/validator/src/lib.rs index cfe30be8..cf2ee99a 100644 --- a/validator/src/lib.rs +++ b/validator/src/lib.rs @@ -5,6 +5,7 @@ use hugr::std_extensions::logic; use lazy_static::lazy_static; use pyo3::prelude::*; use tket2::extension::{TKET1_EXTENSION, TKET2_EXTENSION}; +use tket2_hseries::extension::result; lazy_static! { pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ @@ -17,6 +18,7 @@ lazy_static! { collections::EXTENSION.to_owned(), TKET1_EXTENSION.to_owned(), TKET2_EXTENSION.to_owned(), + result::EXTENSION.to_owned(), ]) .unwrap(); }