diff --git a/guppylang/std/_internal/checker.py b/guppylang/std/_internal/checker.py index bd1d95df..9afc259b 100644 --- a/guppylang/std/_internal/checker.py +++ b/guppylang/std/_internal/checker.py @@ -267,8 +267,10 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: assert isinstance(len_arg, ConstArg) if not self._is_numeric_or_bool_type(ty_arg.ty): raise GuppyError(err) - base_ty = ty_arg.ty - array_len = len_arg.const + _base_ty = ty_arg.ty + _array_len = len_arg.const + # See https://github.com/CQCL/guppylang/issues/631 + raise GuppyError(UnsupportedError(value, "Array results")) else: raise GuppyError(err) node = ResultExpr(value, base_ty, array_len, tag.value) diff --git a/guppylang/std/_internal/compiler/array.py b/guppylang/std/_internal/compiler/array.py index b17b8243..7c8f035f 100644 --- a/guppylang/std/_internal/compiler/array.py +++ b/guppylang/std/_internal/compiler/array.py @@ -99,7 +99,8 @@ class NewArrayCompiler(ArrayCompiler): def build_classical_array(self, elems: list[Wire]) -> Wire: """Lowers a call to `array.__new__` for classical arrays.""" - return self.builder.add_op(array_new(self.elem_ty, len(elems)), *elems) + # See https://github.com/CQCL/guppylang/issues/629 + return self.build_linear_array(elems) def build_linear_array(self, elems: list[Wire]) -> Wire: """Lowers a call to `array.__new__` for linear arrays.""" @@ -121,9 +122,12 @@ class ArrayGetitemCompiler(ArrayCompiler): def build_classical_getitem(self, array: Wire, idx: Wire) -> CallReturnWires: """Lowers a call to `array.__getitem__` for classical arrays.""" + # See https://github.com/CQCL/guppylang/issues/629 + elem_opt_ty = ht.Option(self.elem_ty) idx = self.builder.add_op(convert_itousize(), idx) - result = self.builder.add_op(array_get(self.elem_ty, self.length), array, idx) - elem = build_unwrap(self.builder, result, "Array index out of bounds") + result = self.builder.add_op(array_get(elem_opt_ty, self.length), array, idx) + elem_opt = build_unwrap(self.builder, result, "Array index out of bounds") + elem = build_unwrap(self.builder, elem_opt, "array.__getitem__: Internal error") return CallReturnWires(regular_returns=[elem], inout_returns=[array]) def build_linear_getitem(self, array: Wire, idx: Wire) -> CallReturnWires: @@ -163,9 +167,12 @@ def build_classical_setitem( self, array: Wire, idx: Wire, elem: Wire ) -> CallReturnWires: """Lowers a call to `array.__setitem__` for classical arrays.""" + # See https://github.com/CQCL/guppylang/issues/629 + elem_opt_ty = ht.Option(self.elem_ty) idx = self.builder.add_op(convert_itousize(), idx) + elem_opt = self.builder.add_op(ops.Tag(1, elem_opt_ty), elem) result = self.builder.add_op( - array_set(self.elem_ty, self.length), array, idx, elem + array_set(elem_opt_ty, self.length), array, idx, elem_opt ) # Unwrap the result, but we don't have to hold onto the returned old value _, array = build_unwrap_right(self.builder, result, "Array index out of bounds") diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index e091d245..a42aa877 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -133,9 +133,8 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type: # Linear elements are turned into an optional to enable unsafe indexing. # See `ArrayGetitemCompiler` for details. - elem_ty = ( - ht.Option(ty_arg.ty.to_hugr()) if ty_arg.ty.linear else ty_arg.ty.to_hugr() - ) + # Same also for classical arrays, see https://github.com/CQCL/guppylang/issues/629 + elem_ty = ht.Option(ty_arg.ty.to_hugr()) array = hugr.std.PRELUDE.get_type("array") return array.instantiate([len_arg.to_hugr(), ht.TypeTypeArg(elem_ty)]) diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index f98bc33e..24cfa5d3 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -227,6 +227,26 @@ def main(a: A @owned, i: int, j: int, k: int) -> A: validate(module.compile()) + +def test_generic_function(validate): + module = GuppyModule("test") + module.load(qubit) + T = guppy.type_var("T", linear=True, module=module) + n = guppy.nat_var("n", module=module) + + @guppy(module) + def foo(xs: array[T, n] @owned) -> array[T, n]: + return xs + + @guppy(module) + def main() -> tuple[array[int, 3], array[qubit, 2]]: + xs = array(1, 2, 3) + ys = array(qubit(), qubit()) + return foo(xs), foo(ys) + + validate(module.compile()) + + def test_exec_array(validate, run_int_fn): module = GuppyModule("test") diff --git a/tests/integration/test_result.py b/tests/integration/test_result.py index 33b12210..3ebc4ee4 100644 --- a/tests/integration/test_result.py +++ b/tests/integration/test_result.py @@ -1,3 +1,5 @@ +import pytest + from guppylang.std.builtins import result, nat, array from tests.util import compile_guppy @@ -21,6 +23,7 @@ def main(w: nat, x: int, y: float, z: bool) -> None: validate(main) +@pytest.mark.skip("See https://github.com/CQCL/guppylang/issues/631") def test_array(validate): @compile_guppy def main(