Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Update top-level definitions to use new diagnostics #604

Merged
merged 4 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
45 changes: 45 additions & 0 deletions guppylang/checker/errors/generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from dataclasses import dataclass
from typing import ClassVar

from guppylang.diagnostic import Error


@dataclass(frozen=True)
class UnsupportedError(Error):
title: ClassVar[str] = "Unsupported"
span_label: ClassVar[str] = "{things} {is_are} not supported{extra}"
things: str
singular: bool = False
unsupported_in: str = ""

@property
def is_are(self) -> str:
return "is" if self.singular else "are"

@property
def extra(self) -> str:
return f" in {self.unsupported_in}" if self.unsupported_in else ""


@dataclass(frozen=True)
class UnexpectedError(Error):
title: ClassVar[str] = "Unexpected {things}"
span_label: ClassVar[str] = "Unexpected {things}{extra}"
things: str
unexpected_in: str = ""

@property
def extra(self) -> str:
return f" in {self.unexpected_in}" if self.unexpected_in else ""


@dataclass(frozen=True)
class ExpectedError(Error):
title: ClassVar[str] = "Expected {things}"
span_label: ClassVar[str] = "Expected {things}{extra}"
things: str
got: str = ""

@property
def extra(self) -> str:
return f", got {self.got}" if self.got else ""
14 changes: 4 additions & 10 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from hugr.package import FuncDefnPointer, ModulePointer

import guppylang
from guppylang.ast_util import annotate_location, has_empty_body
from guppylang.ast_util import annotate_location
from guppylang.definition.common import DefId, Definition
from guppylang.definition.const import RawConstDef
from guppylang.definition.custom import (
Expand All @@ -28,7 +28,6 @@
from guppylang.definition.function import (
CompiledFunctionDef,
RawFunctionDef,
parse_py_func,
)
from guppylang.definition.parameter import ConstVarDef, TypeVarDef
from guppylang.definition.struct import RawStructDef
Expand Down Expand Up @@ -309,17 +308,12 @@ def custom(
mod = module or self.get_module()

def dec(f: PyFunc) -> RawCustomFunctionDef:
func_ast, docstring = parse_py_func(f, self._sources)
if not has_empty_body(func_ast):
raise GuppyError(
"Body of custom function declaration must be empty",
func_ast.body[0],
)
call_checker = checker or DefaultCallChecker()
func = RawCustomFunctionDef(
DefId.fresh(mod),
name or func_ast.name,
func_ast,
name or f.__name__,
None,
f,
call_checker,
compiler or NotImplementedCallCompiler(),
higher_order_value,
Expand Down
10 changes: 10 additions & 0 deletions guppylang/definition/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from hugr.build.dfg import DefinitionBuilder, OpVar

from guppylang.diagnostic import Fatal
from guppylang.span import SourceMap

if TYPE_CHECKING:
Expand Down Expand Up @@ -157,3 +158,12 @@ def compile_inner(self, globals: "CompiledGlobals") -> None:
Opposed to `CompilableDef.compile()`, we have access to all other compiled
definitions here, which allows things like mutual recursion.
"""


@dataclass(frozen=True)
class UnknownSourceError(Fatal):
title: ClassVar[str] = "Cannot find source"
message: ClassVar[str] = (
"Unable to look up the source code for Python object `{obj}`"
)
obj: object
72 changes: 55 additions & 17 deletions guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, ClassVar

from hugr import Wire, ops
from hugr import tys as ht
from hugr.build.dfg import DfBase

from guppylang.ast_util import AstNode, get_type, with_loc, with_type
from guppylang.ast_util import AstNode, get_type, has_empty_body, with_loc, with_type
from guppylang.checker.core import Context, Globals
from guppylang.checker.expr_checker import check_call, synthesize_call
from guppylang.checker.func_checker import check_signature
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.definition.common import ParsableDef
from guppylang.definition.value import CallReturnWires, CompiledCallableDef
from guppylang.diagnostic import Error, Help
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.nodes import GlobalCall
from guppylang.span import SourceMap
Expand All @@ -27,6 +29,42 @@
type_to_row,
)

if TYPE_CHECKING:
from guppylang.definition.function import PyFunc


@dataclass(frozen=True)
class BodyNotEmptyError(Error):
title: ClassVar[str] = "Unexpected function body"
span_label: ClassVar[str] = "Body of custom function `{name}` must be empty"
name: str


@dataclass(frozen=True)
class NoSignatureError(Error):
title: ClassVar[str] = "Type signature missing"
span_label: ClassVar[str] = "Custom function `{name}` requires a type signature"
name: str

@dataclass(frozen=True)
class Suggestion(Help):
message: ClassVar[str] = (
"Annotate the type signature of `{name}` or disallow the use of `{name}` "
"as a higher-order value: `@guppy.custom(..., higher_order_value=False)`"
)

def __post_init__(self) -> None:
self.add_sub_diagnostic(NoSignatureError.Suggestion(None))


@dataclass(frozen=True)
class NotHigherOrderError(Error):
title: ClassVar[str] = "Not higher-order"
span_label: ClassVar[str] = (
"Function `{name}` may not be used as a higher-order value"
)
name: str


@dataclass(frozen=True)
class RawCustomFunctionDef(ParsableDef):
Expand All @@ -47,7 +85,7 @@ class RawCustomFunctionDef(ParsableDef):
higher_order_value: Whether the function may be used as a higher-order value.
"""

defined_at: ast.FunctionDef
python_func: "PyFunc"
call_checker: "CustomCallChecker"
call_compiler: "CustomInoutCallCompiler"

Expand All @@ -69,12 +107,17 @@ def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
code. The only information we need to access is that it's a function type and
that there are no unsolved existential vars.
"""
sig = self._get_signature(globals)
from guppylang.definition.function import parse_py_func

func_ast, docstring = parse_py_func(self.python_func, sources)
if not has_empty_body(func_ast):
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
sig = self._get_signature(func_ast, globals)
ty = sig or FunctionType([], NoneType())
return CustomFunctionDef(
self.id,
self.name,
self.defined_at,
func_ast,
ty,
self.call_checker,
self.call_compiler,
Expand Down Expand Up @@ -104,7 +147,9 @@ def compile_call(
)
return self.call_compiler.compile_with_inouts(args).regular_returns

def _get_signature(self, globals: Globals) -> FunctionType | None:
def _get_signature(
self, node: ast.FunctionDef, globals: Globals
) -> FunctionType | None:
"""Returns the type of the function, if known.

Type annotations are needed if we rely on the default call checker or
Expand All @@ -117,19 +162,15 @@ def _get_signature(self, globals: Globals) -> FunctionType | None:
requires_type_annotation = (
isinstance(self.call_checker, DefaultCallChecker) or self.higher_order_value
)
has_type_annotation = self.defined_at.returns or any(
arg.annotation for arg in self.defined_at.args.args
has_type_annotation = node.returns or any(
arg.annotation for arg in node.args.args
)

if requires_type_annotation and not has_type_annotation:
raise GuppyError(
f"Type signature for function `{self.name}` is required. "
"Alternatively, try passing `higher_order_value=False` on definition.",
self.defined_at,
)
raise GuppyError(NoSignatureError(node, self.name))

if requires_type_annotation:
return check_signature(self.defined_at, globals)
return check_signature(node, globals)
else:
return None

Expand Down Expand Up @@ -196,10 +237,7 @@ def load_with_args(
"""
# TODO: This should be raised during checking, not compilation!
if not self.higher_order_value:
raise GuppyError(
"This function does not support usage in a higher-order context",
node,
)
raise GuppyError(NotHigherOrderError(node, self.name))
assert len(self.ty.params) == len(type_args)

# We create a `FunctionDef` that takes some inputs, compiles a call to the
Expand Down
13 changes: 10 additions & 3 deletions guppylang/definition/declaration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
from dataclasses import dataclass, field
from typing import ClassVar

from hugr import Node, Wire
from hugr import tys as ht
Expand All @@ -14,13 +15,21 @@
from guppylang.definition.common import CompilableDef, ParsableDef
from guppylang.definition.function import PyFunc, parse_py_func
from guppylang.definition.value import CallableDef, CallReturnWires, CompiledCallableDef
from guppylang.diagnostic import Error
from guppylang.error import GuppyError
from guppylang.nodes import GlobalCall
from guppylang.span import SourceMap
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import Type, type_to_row


@dataclass(frozen=True)
class BodyNotEmptyError(Error):
title: ClassVar[str] = "Unexpected function body"
span_label: ClassVar[str] = "Body of declared function `{name}` must be empty"
name: str


@dataclass(frozen=True)
class RawFunctionDecl(ParsableDef):
"""A raw function declaration provided by the user.
Expand All @@ -38,9 +47,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
func_ast, docstring = parse_py_func(self.python_func, sources)
ty = check_signature(func_ast, globals.with_python_scope(self.python_scope))
if not has_empty_body(func_ast):
raise GuppyError(
"Body of function declaration must be empty", func_ast.body[0]
)
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
return CheckedFunctionDecl(
self.id,
self.name,
Expand Down
16 changes: 10 additions & 6 deletions guppylang/definition/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from guppylang.ast_util import AstNode, annotate_location, with_loc
from guppylang.checker.cfg_checker import CheckedCFG
from guppylang.checker.core import Context, Globals, Place, PyScope
from guppylang.checker.errors.generic import ExpectedError, UnsupportedError
from guppylang.checker.expr_checker import check_call, synthesize_call
from guppylang.checker.func_checker import (
check_global_func_def,
Expand All @@ -22,7 +23,12 @@
)
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.compiler.func_compiler import compile_global_func_def
from guppylang.definition.common import CheckableDef, CompilableDef, ParsableDef
from guppylang.definition.common import (
CheckableDef,
CompilableDef,
ParsableDef,
UnknownSourceError,
)
from guppylang.definition.value import CallableDef, CallReturnWires, CompiledCallableDef
from guppylang.error import GuppyError
from guppylang.ipython_inspect import find_ipython_def, is_running_ipython
Expand Down Expand Up @@ -60,9 +66,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
func_ast, docstring = parse_py_func(self.python_func, sources)
ty = check_signature(func_ast, globals.with_python_scope(self.python_scope))
if ty.parametrized:
raise GuppyError(
"Generic function definitions are not supported yet", func_ast
)
raise GuppyError(UnsupportedError(func_ast, "Generic function definitions"))
return ParsedFunctionDef(
self.id, self.name, func_ast, ty, self.python_scope, docstring
)
Expand Down Expand Up @@ -251,9 +255,9 @@ def parse_py_func(f: PyFunc, sources: SourceMap) -> tuple[ast.FunctionDef, str |
else:
file = inspect.getsourcefile(f)
if file is None:
raise GuppyError("Couldn't determine source file for function")
raise GuppyError(UnknownSourceError(None, f))
doug-q marked this conversation as resolved.
Show resolved Hide resolved
sources.add_file(file)
annotate_location(func_ast, source, file, line_offset)
if not isinstance(func_ast, ast.FunctionDef):
raise GuppyError("Expected a function definition", func_ast)
raise GuppyError(ExpectedError(func_ast, "a function definition"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no coverage, ok?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test 👍

return parse_function_with_docstring(func_ast)
Loading
Loading