From 7c3cf61665ddc3abb2432f84868b3b6da2be80dd Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 4 May 2023 12:55:06 -0700 Subject: [PATCH] fix: rewrite typechecker journal to handle nested commits (#3375) this commit fixes a bug which was introduced in 66930fdfc. when the typechecker enters a nested loop, it can typecheck the inner loop (committing it), and result in invalid state if the outer loop fails to typecheck. this commit implements a checkpointing system in the typechecker state committer so that it can handle nested changes. it also cleans up the data structure used so that there is a single entry point and users of the data structure do not need to think about implementation details. one drawback of this approach is that it *only* handles changes to the metadata dict. non-idempotent changes to the AST during typechecking (such as the use case we are using them for - caching) should then be restricted to changes to the metadata dict so that they can register automatically with the node metadata journal. --- .../semantics/analysis/test_for_loop.py | 43 +++++++++- vyper/ast/metadata.py | 80 +++++++++++++++++++ vyper/ast/nodes.py | 3 +- vyper/semantics/analysis/__init__.py | 3 - vyper/semantics/analysis/local.py | 11 +-- vyper/semantics/analysis/utils.py | 25 ------ 6 files changed, 128 insertions(+), 37 deletions(-) create mode 100644 vyper/ast/metadata.py diff --git a/tests/functional/semantics/analysis/test_for_loop.py b/tests/functional/semantics/analysis/test_for_loop.py index 13f309181f..71e38d253c 100644 --- a/tests/functional/semantics/analysis/test_for_loop.py +++ b/tests/functional/semantics/analysis/test_for_loop.py @@ -1,7 +1,7 @@ import pytest from vyper.ast import parse_to_ast -from vyper.exceptions import ImmutableViolation +from vyper.exceptions import ImmutableViolation, TypeMismatch from vyper.semantics.analysis import validate_semantics @@ -99,3 +99,44 @@ def baz(): vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): validate_semantics(vyper_module, {}) + + +iterator_inference_codes = [ + """ +@external +def main(): + for j in range(3): + x: uint256 = j + y: uint16 = j + """, # issue 3212 + """ +@external +def foo(): + for i in [1]: + a:uint256 = i + b:uint16 = i + """, # issue 3374 + """ +@external +def foo(): + for i in [1]: + for j in [1]: + a:uint256 = i + b:uint16 = i + """, # issue 3374 + """ +@external +def foo(): + for i in [1,2,3]: + for j in [1,2,3]: + b:uint256 = j + i + c:uint16 = i + """, # issue 3374 +] + + +@pytest.mark.parametrize("code", iterator_inference_codes) +def test_iterator_type_inference_checker(namespace, code): + vyper_module = parse_to_ast(code) + with pytest.raises(TypeMismatch): + validate_semantics(vyper_module, {}) diff --git a/vyper/ast/metadata.py b/vyper/ast/metadata.py new file mode 100644 index 0000000000..30e06e0016 --- /dev/null +++ b/vyper/ast/metadata.py @@ -0,0 +1,80 @@ +import contextlib +from typing import Any + +from vyper.exceptions import VyperException + + +# a commit/rollback scheme for metadata caching. in the case that an +# exception is thrown and caught during type checking (currently, only +# during for loop iterator variable type inference), we can roll back +# any state updates due to type checking. +# this is implemented as a stack of changesets, because we need to +# handle nested rollbacks in the case of nested for loops +class _NodeMetadataJournal: + _NOT_FOUND = object() + + def __init__(self): + self._node_updates: list[dict[tuple[int, str, Any], NodeMetadata]] = [] + + def register_update(self, metadata, k): + prev = metadata.get(k, self._NOT_FOUND) + self._node_updates[-1][(id(metadata), k)] = (metadata, prev) + + @contextlib.contextmanager + def enter(self): + self._node_updates.append({}) + try: + yield + except VyperException as e: + # note: would be better to only catch typechecker exceptions here. + self._rollback_inner() + raise e from e + else: + self._commit_inner() + + def _rollback_inner(self): + for (_, k), (metadata, prev) in self._node_updates[-1].items(): + if prev is self._NOT_FOUND: + metadata.pop(k, None) + else: + metadata[k] = prev + self._pop_inner() + + def _commit_inner(self): + inner = self._pop_inner() + + if len(self._node_updates) == 0: + return + + outer = self._node_updates[-1] + + # register with previous frame in case inner gets commited + # but outer needs to be rolled back + for (_, k), (metadata, prev) in inner.items(): + if (id(metadata), k) not in outer: + outer[(id(metadata), k)] = (metadata, prev) + + def _pop_inner(self): + return self._node_updates.pop() + + +class NodeMetadata(dict): + """ + A data structure which allows for journaling. + """ + + _JOURNAL: _NodeMetadataJournal = _NodeMetadataJournal() + + def __setitem__(self, k, v): + # if we are in a context where we need to journal, add + # this to the changeset. + if len(self._JOURNAL._node_updates) != 0: + self._JOURNAL.register_update(self, k) + + super().__setitem__(k, v) + + @classmethod + @contextlib.contextmanager + def enter_typechecker_speculation(cls): + with cls._JOURNAL.enter(): + yield diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 19e50d8895..5e6f8473a0 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -5,6 +5,7 @@ import sys from typing import Any, Optional, Union +from vyper.ast.metadata import NodeMetadata from vyper.compiler.settings import VYPER_ERROR_CONTEXT_LINES, VYPER_ERROR_LINE_NUMBERS from vyper.exceptions import ( ArgumentException, @@ -254,7 +255,7 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): """ self.set_parent(parent) self._children: set = set() - self._metadata: dict = {} + self._metadata: NodeMetadata = NodeMetadata() for field_name in NODE_SRC_ATTRIBUTES: # when a source offset is not available, use the parent's source offset diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 306f876558..5977a87812 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -12,6 +12,3 @@ def validate_semantics(vyper_ast, interface_codes): with namespace.enter_scope(): add_module_namespace(vyper_ast, interface_codes) validate_functions(vyper_ast) - - # clean up. not sure if this is necessary, but do it for hygiene's sake. - _ExprAnalyser._reset_taint() diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 73fb5c9167..f9bd0db297 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,6 +1,7 @@ from typing import Optional from vyper import ast as vy_ast +from vyper.ast.metadata import NodeMetadata from vyper.ast.validation import validate_call_args from vyper.exceptions import ( ExceptionList, @@ -21,7 +22,6 @@ from vyper.semantics.analysis.base import DataLocation, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( - _ExprAnalyser, get_common_types, get_exact_type_from_node, get_expr_info, @@ -453,20 +453,17 @@ def visit_For(self, node): raise exc.with_annotation(node) from None try: - for n in node.body: - self.visit(n) + with NodeMetadata.enter_typechecker_speculation(): + for n in node.body: + self.visit(n) except (TypeMismatch, InvalidOperation) as exc: for_loop_exceptions.append(exc) - # rollback any changes to the tree - _ExprAnalyser._rollback_taint() else: # type information is applied directly here because the # scope is closed prior to the call to # `StatementAnnotationVisitor` node.target._metadata["type"] = type_ - # perf - persist all calculated types - _ExprAnalyser._commit_taint() # success -- bail out instead of error handling. return diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index ae154e98bc..136012f9ea 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -57,13 +57,6 @@ class _ExprAnalyser: class's method resolution order is examined to decide which method to call. """ - # this allows for a very simple commit/rollback scheme for metadata - # caching. in the case that an exception is thrown and caught during - # type checking (currently, only during for loop iterator variable - # type inference), we can roll back any state updates due to type - # checking. - _tainted_nodes: set[tuple[vy_ast.VyperNode, str]] = set() - def __init__(self): self.namespace = get_namespace() @@ -171,27 +164,9 @@ def get_possible_types_from_node(self, node, include_type_exprs=False): ret.sort(key=lambda k: (k.bits, not k.is_signed), reverse=True) node._metadata[k] = ret - # register with list of tainted nodes, in case the cache - # needs to be invalidated in case of a state rollback - self._tainted_nodes.add((node, k)) return node._metadata[k].copy() - @classmethod - def _rollback_taint(cls): - for node, k in cls._tainted_nodes: - node._metadata.pop(k, None) - # taint has been rolled back, no need to track it anymore - cls._reset_taint() - - @classmethod - def _commit_taint(cls): - cls._reset_taint() - - @classmethod - def _reset_taint(cls): - cls._tainted_nodes.clear() - def _find_fn(self, node): # look for a type-check method for each class in the given class mro for name in [i.__name__ for i in type(node).mro()]: