Skip to content

Commit

Permalink
feat: Cache accessed source files (#551)
Browse files Browse the repository at this point in the history
We used to store the entire source code that was used to build an AST as
a field within *every node* of the AST. This was a hack so that we can
render the source line from an AST node when an error occurs. With our
new diagnostics system, we should do better than that!

With this PR, we cache the source of every file accessed by the compiler
in a new `SourceMap` class. Then, we can look up the source associated
with a span when rendering a diagnostic.

Note that we can't rely on Pythons `linecache` since we also need to
keep track of cells in jupyter notebooks as individual "files".
  • Loading branch information
mark-koch authored Oct 10, 2024
1 parent d671bb5 commit 7f07cc7
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 21 deletions.
18 changes: 14 additions & 4 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
find_guppy_module_in_py_module,
get_calling_frame,
)
from guppylang.span import SourceMap
from guppylang.tys.subst import Inst
from guppylang.tys.ty import NumericType

Expand Down Expand Up @@ -77,8 +78,12 @@ class _Guppy:
# The currently-alive GuppyModules, associated with a Python file/module
_modules: dict[ModuleIdentifier, GuppyModule]

# Storage for source code that has been read by the compiler
_sources: SourceMap

def __init__(self) -> None:
self._modules = {}
self._sources = SourceMap()

@overload
def __call__(self, arg: PyFunc) -> RawFunctionDef: ...
Expand Down Expand Up @@ -300,7 +305,7 @@ def custom(
mod = module or self.get_module()

def dec(f: PyFunc) -> RawCustomFunctionDef:
func_ast, docstring = parse_py_func(f)
func_ast, docstring = parse_py_func(f, self._sources)
if not has_empty_body(func_ast):
raise GuppyError(
"Body of custom function declaration must be empty",
Expand Down Expand Up @@ -360,7 +365,9 @@ def constant(
) -> RawConstDef:
"""Adds a constant to a module, backed by a `hugr.val.Value`."""
module = module or self.get_module()
type_ast = _parse_expr_string(ty, f"Not a valid Guppy type: `{ty}`")
type_ast = _parse_expr_string(
ty, f"Not a valid Guppy type: `{ty}`", self._sources
)
defn = RawConstDef(DefId.fresh(module), name, None, type_ast, value)
module.register_def(defn)
return defn
Expand All @@ -375,7 +382,9 @@ def extern(
) -> RawExternDef:
"""Adds an extern symbol to a module."""
module = module or self.get_module()
type_ast = _parse_expr_string(ty, f"Not a valid Guppy type: `{ty}`")
type_ast = _parse_expr_string(
ty, f"Not a valid Guppy type: `{ty}`", self._sources
)
defn = RawExternDef(
DefId.fresh(module), name, None, symbol or name, constant, type_ast
)
Expand Down Expand Up @@ -444,7 +453,7 @@ def registered_modules(self) -> KeysView[ModuleIdentifier]:
guppy = _Guppy()


def _parse_expr_string(ty_str: str, parse_err: str) -> ast.expr:
def _parse_expr_string(ty_str: str, parse_err: str, sources: SourceMap) -> ast.expr:
"""Helper function to parse expressions that are provided as strings.
Tries to infer the source location were the given string was defined by inspecting
Expand All @@ -460,6 +469,7 @@ def _parse_expr_string(ty_str: str, parse_err: str) -> ast.expr:
if caller_frame := get_calling_frame():
info = inspect.getframeinfo(caller_frame)
if caller_module := inspect.getmodule(caller_frame):
sources.add_file(info.filename)
source_lines, _ = inspect.getsourcelines(caller_module)
source = "".join(source_lines)
annotate_location(expr_ast, source, info.filename, 0)
Expand Down
4 changes: 3 additions & 1 deletion guppylang/definition/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from hugr.build.dfg import DefinitionBuilder, OpVar
from hugr.ext import Package

from guppylang.span import SourceMap

if TYPE_CHECKING:
from guppylang.checker.core import Globals
from guppylang.compiler.core import CompiledGlobals
Expand Down Expand Up @@ -92,7 +94,7 @@ class ParsableDef(Definition):
"""

@abstractmethod
def parse(self, globals: "Globals") -> ParsedDef:
def parse(self, globals: "Globals", sources: SourceMap) -> ParsedDef:
"""Performs parsing and validation, returning a definition that can be checked.
The provided globals contain all other raw definitions that have been defined.
Expand Down
3 changes: 2 additions & 1 deletion guppylang/definition/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.definition.common import CompilableDef, ParsableDef
from guppylang.definition.value import CompiledValueDef, ValueDef
from guppylang.span import SourceMap
from guppylang.tys.parsing import type_from_ast


Expand All @@ -22,7 +23,7 @@ class RawConstDef(ParsableDef):

description: str = field(default="constant", init=False)

def parse(self, globals: Globals) -> "ConstDef":
def parse(self, globals: Globals, sources: SourceMap) -> "ConstDef":
"""Parses and checks the user-provided signature of the function."""
return ConstDef(
self.id,
Expand Down
3 changes: 2 additions & 1 deletion guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from guppylang.definition.value import CallReturnWires, CompiledCallableDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.nodes import GlobalCall
from guppylang.span import SourceMap
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import (
FuncInput,
Expand Down Expand Up @@ -56,7 +57,7 @@ class RawCustomFunctionDef(ParsableDef):

description: str = field(default="function", init=False)

def parse(self, globals: "Globals") -> "CustomFunctionDef":
def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
"""Parses and checks the user-provided signature of the custom function.
The signature is optional if custom type checking logic is provided by the user.
Expand Down
5 changes: 3 additions & 2 deletions guppylang/definition/declaration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from guppylang.definition.value import CallableDef, CallReturnWires, CompiledCallableDef
from guppylang.error import GuppyError
from guppylang.nodes import GlobalCall
from guppylang.span import SourceMap
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import Type, type_to_row

Expand All @@ -32,9 +33,9 @@ class RawFunctionDecl(ParsableDef):
python_scope: PyScope
description: str = field(default="function", init=False)

def parse(self, globals: Globals) -> "CheckedFunctionDecl":
def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
"""Parses and checks the user-provided signature of the function."""
func_ast, docstring = parse_py_func(self.python_func)
func_ast, docstring = parse_py_func(self.python_func, sources)
ty = check_signature(func_ast, globals.with_python_scope(self.python_scope))
if not has_empty_body(func_ast):
raise GuppyError(
Expand Down
3 changes: 2 additions & 1 deletion guppylang/definition/extern.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.definition.common import CompilableDef, ParsableDef
from guppylang.definition.value import CompiledValueDef, ValueDef
from guppylang.span import SourceMap
from guppylang.tys.parsing import type_from_ast


Expand All @@ -22,7 +23,7 @@ class RawExternDef(ParsableDef):

description: str = field(default="extern", init=False)

def parse(self, globals: Globals) -> "ExternDef":
def parse(self, globals: Globals, sources: SourceMap) -> "ExternDef":
"""Parses and checks the user-provided signature of the function."""
return ExternDef(
self.id,
Expand Down
13 changes: 8 additions & 5 deletions guppylang/definition/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from guppylang.error import GuppyError
from guppylang.ipython_inspect import find_ipython_def, is_running_ipython
from guppylang.nodes import GlobalCall
from guppylang.span import SourceMap
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FunctionType, Type, type_to_row

Expand Down Expand Up @@ -53,9 +54,9 @@ class RawFunctionDef(ParsableDef):

description: str = field(default="function", init=False)

def parse(self, globals: Globals) -> "ParsedFunctionDef":
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
"""Parses and checks the user-provided signature of the function."""
func_ast, docstring = parse_py_func(self.python_func)
func_ast, docstring = parse_py_func(self.python_func, sources)
ty = check_signature(func_ast, globals.with_python_scope(self.python_scope))
if ty.parametrized:
raise GuppyError(
Expand Down Expand Up @@ -220,7 +221,7 @@ def compile_inner(self, globals: CompiledGlobals) -> None:
compile_global_func_def(self, self.func_def, globals)


def parse_py_func(f: PyFunc) -> tuple[ast.FunctionDef, str | None]:
def parse_py_func(f: PyFunc, sources: SourceMap) -> tuple[ast.FunctionDef, str | None]:
source_lines, line_offset = inspect.getsourcelines(f)
source = "".join(source_lines) # Lines already have trailing \n's
source = textwrap.dedent(source)
Expand All @@ -234,10 +235,12 @@ def parse_py_func(f: PyFunc) -> tuple[ast.FunctionDef, str | None]:
defn = find_ipython_def(func_ast.name)
if defn is not None:
file = f"<{defn.cell_name}>"
sources.add_file(file, source)
else:
file = inspect.getsourcefile(f)
if file is None:
raise GuppyError("Couldn't determine source file for function")
if file is None:
raise GuppyError("Couldn't determine source file for function")
sources.add_file(file)
annotate_location(func_ast, source, file, line_offset)
if not isinstance(func_ast, ast.FunctionDef):
raise GuppyError("Expected a function definition", func_ast)
Expand Down
9 changes: 6 additions & 3 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.ipython_inspect import find_ipython_def, is_running_ipython
from guppylang.span import SourceMap
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter, check_all_args
from guppylang.tys.parsing import type_from_ast
Expand Down Expand Up @@ -63,9 +64,9 @@ def __getitem__(self, item: Any) -> "RawStructDef":
"""
return self

def parse(self, globals: Globals) -> "ParsedStructDef":
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedStructDef":
"""Parses the raw class object into an AST and checks that it is well-formed."""
cls_def = parse_py_class(self.python_class)
cls_def = parse_py_class(self.python_class, sources)
if cls_def.keywords:
raise GuppyError("Unexpected keyword", cls_def.keywords[0])

Expand Down Expand Up @@ -232,7 +233,7 @@ def compile(self, args: list[Wire]) -> list[Wire]:
return [constructor_def]


def parse_py_class(cls: type) -> ast.ClassDef:
def parse_py_class(cls: type, sources: SourceMap) -> ast.ClassDef:
"""Parses a Python class object into an AST."""
# If we are running IPython, `inspect.getsourcelines` works only for builtins
# (guppy stdlib), but not for most/user-defined classes - see:
Expand All @@ -254,6 +255,8 @@ def parse_py_class(cls: type) -> ast.ClassDef:
file = inspect.getsourcefile(cls)
if file is None:
raise GuppyError("Couldn't determine source file for class")
# Store the source file in our cache
sources.add_file(file)
annotate_location(cls_ast, source, file, line_offset)
if not isinstance(cls_ast, ast.ClassDef):
raise GuppyError("Expected a class definition", cls_ast)
Expand Down
15 changes: 12 additions & 3 deletions guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError, pretty_errors
from guppylang.experimental import enable_experimental_features
from guppylang.span import SourceMap

PyClass = type
PyFunc = Callable[..., Any]
Expand Down Expand Up @@ -73,6 +74,9 @@ class GuppyModule:
# `_register_buffered_instance_funcs` is called. This way, we can associate
_instance_func_buffer: dict[str, RawDef] | None

# Storage for source code that has been read by the compiler
_sources: SourceMap

def __init__(self, name: str, import_builtins: bool = True):
self.name = name
self._globals = Globals({}, {}, {}, {})
Expand All @@ -86,6 +90,10 @@ def __init__(self, name: str, import_builtins: bool = True):
self._raw_type_defs = {}
self._checked_defs = {}

from guppylang.decorator import guppy

self._sources = guppy._sources

# Import builtin module
if import_builtins:
import guppylang.prelude.builtins as builtins
Expand Down Expand Up @@ -249,14 +257,15 @@ def checked(self) -> bool:
def compiled(self) -> bool:
return self._compiled

@staticmethod
def _check_defs(
raw_defs: Mapping[DefId, RawDef], globals: Globals
self, raw_defs: Mapping[DefId, RawDef], globals: Globals
) -> dict[DefId, CheckedDef]:
"""Helper method to parse and check raw definitions."""
raw_globals = globals | Globals(dict(raw_defs), {}, {}, {})
parsed = {
def_id: defn.parse(raw_globals) if isinstance(defn, ParsableDef) else defn
def_id: defn.parse(raw_globals, self._sources)
if isinstance(defn, ParsableDef)
else defn
for def_id, defn in raw_defs.items()
}
parsed_globals = globals | Globals(dict(parsed), {}, {}, {})
Expand Down
29 changes: 29 additions & 0 deletions guppylang/span.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Source spans representing locations in the code being compiled."""

import ast
import linecache
from dataclasses import dataclass
from typing import TypeAlias

Expand Down Expand Up @@ -80,3 +81,31 @@ def to_span(x: ToSpan) -> Span:
x.end_col_offset or x.col_offset,
)
return Span(start, end)


#: List of source lines in a file
SourceLines: TypeAlias = list[str]


class SourceMap:
"""Map holding the source code for all files accessed by the compiler.
Can be used to look up the source code associated with a span.
"""

sources: dict[str, SourceLines]

def __init__(self) -> None:
self.sources = {}

def add_file(self, file: str, content: str | None = None) -> None:
"""Registers a new source file."""
if content is None:
self.sources[file] = [line.rstrip() for line in linecache.getlines(file)]
else:
self.sources[file] = content.splitlines(keepends=False)

def span_lines(self, span: Span, prefix_lines: int = 0) -> list[str]:
return self.sources[span.file][
span.start.line - prefix_lines - 1 : span.end.line
]

0 comments on commit 7f07cc7

Please sign in to comment.