diff --git a/guppylang/decorator.py b/guppylang/decorator.py index 994b224a..2b5f94f5 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -1,3 +1,4 @@ +import ast import inspect from collections.abc import Callable from dataclasses import dataclass, field @@ -7,7 +8,7 @@ from hugr.serialization import ops, tys -from guppylang.ast_util import has_empty_body +from guppylang.ast_util import annotate_location, has_empty_body from guppylang.definition.common import DefId from guppylang.definition.custom import ( CustomCallChecker, @@ -18,6 +19,7 @@ RawCustomFunctionDef, ) from guppylang.definition.declaration import RawFunctionDecl +from guppylang.definition.extern import RawExternDef from guppylang.definition.function import RawFunctionDef, parse_py_func from guppylang.definition.parameter import TypeVarDef from guppylang.definition.struct import RawStructDef @@ -226,6 +228,43 @@ def dec(f: Callable[..., Any]) -> RawFunctionDecl: return dec + def extern( + self, + module: GuppyModule, + name: str, + ty: str, + symbol: str | None = None, + constant: bool = True, + ) -> RawExternDef: + """Adds an extern symbol to a module.""" + try: + type_ast = ast.parse(ty, mode="eval").body + except SyntaxError: + err = f"Not a valid Guppy type: `{ty}`" + raise GuppyError(err) from None + + # Try to annotate the type AST with source information. This requires us to + # inspect the stack frame of the caller + if frame := inspect.currentframe(): # noqa: SIM102 + if caller_frame := frame.f_back: # noqa: SIM102 + if caller_module := inspect.getmodule(caller_frame): + info = inspect.getframeinfo(caller_frame) + source_lines, _ = inspect.getsourcelines(caller_module) + source = "".join(source_lines) + annotate_location(type_ast, source, info.filename, 0) + # Modify the AST so that all sub-nodes span the entire line. We + # can't give a better location since we don't know the column + # offset of the `ty` argument + for node in [type_ast, *ast.walk(type_ast)]: + node.lineno, node.col_offset = info.lineno, 0 + node.end_col_offset = len(source_lines[info.lineno - 1]) + + defn = RawExternDef( + DefId.fresh(module), name, None, symbol or name, constant, type_ast + ) + module.register_def(defn) + return defn + def load(self, m: ModuleType | GuppyModule) -> None: caller = self._get_python_caller() if caller not in self._modules: diff --git a/guppylang/definition/extern.py b/guppylang/definition/extern.py new file mode 100644 index 00000000..ba5d9e6b --- /dev/null +++ b/guppylang/definition/extern.py @@ -0,0 +1,77 @@ +import ast +from dataclasses import dataclass, field + +from hugr.serialization import ops + +from guppylang.ast_util import AstNode +from guppylang.checker.core import Globals +from guppylang.compiler.core import CompiledGlobals, DFContainer +from guppylang.definition.common import CompilableDef, ParsableDef +from guppylang.definition.value import CompiledValueDef, ValueDef +from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV, VNode +from guppylang.tys.parsing import type_from_ast + + +@dataclass(frozen=True) +class RawExternDef(ParsableDef): + """A raw extern symbol definition provided by the user.""" + + symbol: str + constant: bool + type_ast: ast.expr + + description: str = field(default="extern", init=False) + + def parse(self, globals: Globals) -> "ExternDef": + """Parses and checks the user-provided signature of the function.""" + return ExternDef( + self.id, + self.name, + self.defined_at, + type_from_ast(self.type_ast, globals, None), + self.symbol, + self.constant, + self.type_ast, + ) + + +@dataclass(frozen=True) +class ExternDef(RawExternDef, ValueDef, CompilableDef): + """An extern symbol definition.""" + + def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledExternDef": + """Adds a Hugr constant node for the extern definition to the provided graph.""" + custom_const = { + "symbol": self.symbol, + "typ": self.ty.to_hugr(), + "constant": self.constant, + } + value = ops.ExtensionValue( + extensions=["prelude"], + typ=self.ty.to_hugr(), + value=ops.CustomConst(c="ConstExternalSymbol", v=custom_const), + ) + const_node = graph.add_constant(ops.Value(value), self.ty, parent) + return CompiledExternDef( + self.id, + self.name, + self.defined_at, + self.ty, + self.symbol, + self.constant, + self.type_ast, + const_node, + ) + + +@dataclass(frozen=True) +class CompiledExternDef(ExternDef, CompiledValueDef): + """An extern symbol definition that has been compiled to a Hugr constant.""" + + const_node: VNode + + def load( + self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode + ) -> OutPortV: + """Loads the extern value into a local Hugr dataflow graph.""" + return graph.add_load_constant(self.const_node.out_port(0)).out_port(0) diff --git a/tests/error/misc_errors/extern_bad_type.err b/tests/error/misc_errors/extern_bad_type.err new file mode 100644 index 00000000..382e00ea --- /dev/null +++ b/tests/error/misc_errors/extern_bad_type.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: module = GuppyModule("test") +8: +9: guppy.extern(module, "x", ty="float[int]") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +GuppyError: Type `float` is not parameterized diff --git a/tests/error/misc_errors/extern_bad_type.py b/tests/error/misc_errors/extern_bad_type.py new file mode 100644 index 00000000..4f8e044d --- /dev/null +++ b/tests/error/misc_errors/extern_bad_type.py @@ -0,0 +1,12 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.quantum import qubit + +import guppylang.prelude.quantum as quantum + + +module = GuppyModule("test") + +guppy.extern(module, "x", ty="float[int]") + +module.compile() diff --git a/tests/error/test_misc_errors.py b/tests/error/test_misc_errors.py index 03803e65..a82593c3 100644 --- a/tests/error/test_misc_errors.py +++ b/tests/error/test_misc_errors.py @@ -1,6 +1,8 @@ import pathlib import pytest +from guppylang import GuppyModule, guppy +from guppylang.error import GuppyError from tests.error.util import run_error_test path = pathlib.Path(__file__).parent.resolve() / "misc_errors" @@ -19,5 +21,12 @@ @pytest.mark.parametrize("file", files) -def test_type_errors(file, capsys): +def test_misc_errors(file, capsys): run_error_test(file, capsys) + + +def test_extern_bad_type_syntax(): + module = GuppyModule("test") + + with pytest.raises(GuppyError, match="Not a valid Guppy type: `foo bar`"): + guppy.extern(module, name="x", ty="foo bar") diff --git a/tests/integration/test_extern.py b/tests/integration/test_extern.py new file mode 100644 index 00000000..76573328 --- /dev/null +++ b/tests/integration/test_extern.py @@ -0,0 +1,50 @@ +from hugr.serialization import ops + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +def test_extern_float(validate): + module = GuppyModule("module") + + guppy.extern(module, "ext", ty="float") + + @guppy(module) + def main() -> float: + return ext + ext # noqa: F821 + + hg = module.compile() + validate(hg) + + [c] = [n.op.root for n in hg.nodes() if isinstance(n.op.root, ops.Const)] + assert isinstance(c.v.root, ops.ExtensionValue) + assert c.v.root.value.v["symbol"] == "ext" + + +def test_extern_alt_symbol(validate): + module = GuppyModule("module") + + guppy.extern(module, "ext", ty="int", symbol="foo") + + @guppy(module) + def main() -> int: + return ext # noqa: F821 + + hg = module.compile() + validate(hg) + + [c] = [n.op.root for n in hg.nodes() if isinstance(n.op.root, ops.Const)] + assert isinstance(c.v.root, ops.ExtensionValue) + assert c.v.root.value.v["symbol"] == "foo" + +def test_extern_tuple(validate): + module = GuppyModule("module") + + guppy.extern(module, "ext", ty="tuple[int, float]") + + @guppy(module) + def main() -> float: + x, y = ext # noqa: F821 + return x + y + + validate(module.compile())