-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
196 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import ast | ||
from dataclasses import dataclass, field | ||
|
||
from hugr.serialization import ops | ||
|
||
from guppylang.ast_util import AstNode | ||
from guppylang.checker.core import Globals | ||
from guppylang.compiler.core import CompiledGlobals, DFContainer | ||
from guppylang.definition.common import CompilableDef, ParsableDef | ||
from guppylang.definition.value import CompiledValueDef, ValueDef | ||
from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV, VNode | ||
from guppylang.tys.parsing import type_from_ast | ||
|
||
|
||
@dataclass(frozen=True) | ||
class RawExternDef(ParsableDef): | ||
"""A raw extern symbol definition provided by the user.""" | ||
|
||
symbol: str | ||
constant: bool | ||
type_ast: ast.expr | ||
|
||
description: str = field(default="extern", init=False) | ||
|
||
def parse(self, globals: Globals) -> "ExternDef": | ||
"""Parses and checks the user-provided signature of the function.""" | ||
return ExternDef( | ||
self.id, | ||
self.name, | ||
self.defined_at, | ||
type_from_ast(self.type_ast, globals, None), | ||
self.symbol, | ||
self.constant, | ||
self.type_ast, | ||
) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ExternDef(RawExternDef, ValueDef, CompilableDef): | ||
"""An extern symbol definition.""" | ||
|
||
def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledExternDef": | ||
"""Adds a Hugr constant node for the extern definition to the provided graph.""" | ||
custom_const = { | ||
"symbol": self.symbol, | ||
"typ": self.ty.to_hugr(), | ||
"constant": self.constant, | ||
} | ||
value = ops.ExtensionValue( | ||
extensions=["prelude"], | ||
typ=self.ty.to_hugr(), | ||
value=ops.CustomConst(c="ConstExternalSymbol", v=custom_const), | ||
) | ||
const_node = graph.add_constant(ops.Value(value), self.ty, parent) | ||
return CompiledExternDef( | ||
self.id, | ||
self.name, | ||
self.defined_at, | ||
self.ty, | ||
self.symbol, | ||
self.constant, | ||
self.type_ast, | ||
const_node, | ||
) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class CompiledExternDef(ExternDef, CompiledValueDef): | ||
"""An extern symbol definition that has been compiled to a Hugr constant.""" | ||
|
||
const_node: VNode | ||
|
||
def load( | ||
self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode | ||
) -> OutPortV: | ||
"""Loads the extern value into a local Hugr dataflow graph.""" | ||
return graph.add_load_constant(self.const_node.out_port(0)).out_port(0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Guppy compilation failed. Error in file $FILE:9 | ||
|
||
7: module = GuppyModule("test") | ||
8: | ||
9: guppy.extern(module, "x", ty="float[int]") | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
GuppyError: Type `float` is not parameterized |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
from guppylang.prelude.quantum import qubit | ||
|
||
import guppylang.prelude.quantum as quantum | ||
|
||
|
||
module = GuppyModule("test") | ||
|
||
guppy.extern(module, "x", ty="float[int]") | ||
|
||
module.compile() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from hugr.serialization import ops | ||
|
||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
|
||
|
||
def test_extern_float(validate): | ||
module = GuppyModule("module") | ||
|
||
guppy.extern(module, "ext", ty="float") | ||
|
||
@guppy(module) | ||
def main() -> float: | ||
return ext + ext # noqa: F821 | ||
|
||
hg = module.compile() | ||
validate(hg) | ||
|
||
[c] = [n.op.root for n in hg.nodes() if isinstance(n.op.root, ops.Const)] | ||
assert isinstance(c.v.root, ops.ExtensionValue) | ||
assert c.v.root.value.v["symbol"] == "ext" | ||
|
||
|
||
def test_extern_alt_symbol(validate): | ||
module = GuppyModule("module") | ||
|
||
guppy.extern(module, "ext", ty="int", symbol="foo") | ||
|
||
@guppy(module) | ||
def main() -> int: | ||
return ext # noqa: F821 | ||
|
||
hg = module.compile() | ||
validate(hg) | ||
|
||
[c] = [n.op.root for n in hg.nodes() if isinstance(n.op.root, ops.Const)] | ||
assert isinstance(c.v.root, ops.ExtensionValue) | ||
assert c.v.root.value.v["symbol"] == "foo" | ||
|
||
def test_extern_tuple(validate): | ||
module = GuppyModule("module") | ||
|
||
guppy.extern(module, "ext", ty="tuple[int, float]") | ||
|
||
@guppy(module) | ||
def main() -> float: | ||
x, y = ext # noqa: F821 | ||
return x + y | ||
|
||
validate(module.compile()) |