Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[tool]: delay global constraint check #3810

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,3 +1208,23 @@ def test_ownership_decl_errors_not_swallowed(make_input_bundle):
with pytest.raises(UndeclaredDefinition) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "'lib2' has not been declared."


def test_partial_compilation(make_input_bundle):
lib1 = """
counter: uint256
"""
main = """
import lib1

uses: lib1

@internal
def use_lib1():
lib1.counter += 1
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
assert (
compile_code(main, input_bundle=input_bundle, output_formats=["annotated_ast_dict"])
is not None
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
)
2 changes: 1 addition & 1 deletion tests/unit/ast/nodes/test_hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ def foo():
def test_invalid_checksum(code, dummy_input_bundle):
with pytest.raises(InvalidLiteral):
vyper_module = vy_ast.parse_to_ast(code)
semantics.validate_semantics(vyper_module, dummy_input_bundle)
semantics.analyze_module(vyper_module, dummy_input_bundle)
12 changes: 6 additions & 6 deletions tests/unit/semantics/analysis/test_array_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
TypeMismatch,
UndeclaredDefinition,
)
from vyper.semantics.analysis import validate_semantics
from vyper.semantics.analysis import analyze_module


@pytest.mark.parametrize("value", ["address", "Bytes[10]", "decimal", "bool"])
Expand All @@ -22,7 +22,7 @@ def foo(b: {value}):
"""
vyper_module = parse_to_ast(code)
with pytest.raises(TypeMismatch):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


@pytest.mark.parametrize("value", ["1.0", "0.0", "'foo'", "0x00", "b'\x01'", "False"])
Expand All @@ -37,7 +37,7 @@ def foo():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(TypeMismatch):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


@pytest.mark.parametrize("value", [-1, 3, -(2**127), 2**127 - 1, 2**256 - 1])
Expand All @@ -52,7 +52,7 @@ def foo():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ArrayIndexException):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


@pytest.mark.parametrize("value", ["b", "self.b"])
Expand All @@ -67,7 +67,7 @@ def foo():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(UndeclaredDefinition):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


@pytest.mark.parametrize("value", ["a", "foo", "int128"])
Expand All @@ -82,4 +82,4 @@ def foo():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(InvalidReference):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)
10 changes: 5 additions & 5 deletions tests/unit/semantics/analysis/test_cyclic_function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from vyper.ast import parse_to_ast
from vyper.exceptions import CallViolation, StructureException
from vyper.semantics.analysis import validate_semantics
from vyper.semantics.analysis import analyze_module


def test_self_function_call(dummy_input_bundle):
Expand All @@ -13,7 +13,7 @@ def foo():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(CallViolation):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


def test_cyclic_function_call(dummy_input_bundle):
Expand All @@ -28,7 +28,7 @@ def bar():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(CallViolation):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


def test_multi_cyclic_function_call(dummy_input_bundle):
Expand All @@ -51,7 +51,7 @@ def potato():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(CallViolation):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


def test_global_ann_assign_callable_no_crash(dummy_input_bundle):
Expand All @@ -64,5 +64,5 @@ def foo(to : address):
"""
vyper_module = parse_to_ast(code)
with pytest.raises(StructureException) as excinfo:
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)
assert excinfo.value.message == "HashMap[address, uint256] is not callable"
28 changes: 14 additions & 14 deletions tests/unit/semantics/analysis/test_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from vyper.ast import parse_to_ast
from vyper.exceptions import ArgumentException, ImmutableViolation, StructureException, TypeMismatch
from vyper.semantics.analysis import validate_semantics
from vyper.semantics.analysis import analyze_module


def test_modify_iterator_function_outside_loop(dummy_input_bundle):
Expand All @@ -21,7 +21,7 @@ def bar():
pass
"""
vyper_module = parse_to_ast(code)
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


def test_pass_memory_var_to_other_function(dummy_input_bundle):
Expand All @@ -41,7 +41,7 @@ def bar():
self.foo(a)
"""
vyper_module = parse_to_ast(code)
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


def test_modify_iterator(dummy_input_bundle):
Expand All @@ -56,7 +56,7 @@ def bar():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ImmutableViolation):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


def test_bad_keywords(dummy_input_bundle):
Expand All @@ -70,7 +70,7 @@ def bar(n: uint256):
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ArgumentException):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


def test_bad_bound(dummy_input_bundle):
Expand All @@ -84,7 +84,7 @@ def bar(n: uint256):
"""
vyper_module = parse_to_ast(code)
with pytest.raises(StructureException):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


def test_modify_iterator_function_call(dummy_input_bundle):
Expand All @@ -103,7 +103,7 @@ def bar():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ImmutableViolation):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


def test_modify_iterator_recursive_function_call(dummy_input_bundle):
Expand All @@ -126,7 +126,7 @@ def baz():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ImmutableViolation):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


def test_modify_iterator_recursive_function_call_topsort(dummy_input_bundle):
Expand All @@ -149,7 +149,7 @@ def foo():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ImmutableViolation) as e:
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)

assert e.value._message == "Cannot modify loop variable `a`"

Expand All @@ -170,7 +170,7 @@ def foo():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ImmutableViolation) as e:
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)

assert e.value._message == "Cannot modify loop variable `a`"

Expand All @@ -189,7 +189,7 @@ def foo():
self.b[self.a[1]] = i
"""
vyper_module = parse_to_ast(code)
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


def test_modify_iterator_siblings(dummy_input_bundle):
Expand All @@ -207,7 +207,7 @@ def foo():
self.f.b += i
"""
vyper_module = parse_to_ast(code)
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)


def test_modify_subscript_barrier(dummy_input_bundle):
Expand All @@ -229,7 +229,7 @@ def foo():
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ImmutableViolation) as e:
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)

assert e.value._message == "Cannot modify loop variable `b`"

Expand Down Expand Up @@ -272,4 +272,4 @@ def foo():
def test_iterator_type_inference_checker(code, dummy_input_bundle):
vyper_module = parse_to_ast(code)
with pytest.raises(TypeMismatch):
validate_semantics(vyper_module, dummy_input_bundle)
analyze_module(vyper_module, dummy_input_bundle)
20 changes: 14 additions & 6 deletions vyper/compiler/phases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from vyper.compiler.settings import OptimizationLevel, Settings
from vyper.exceptions import StructureException
from vyper.ir import compile_ir, optimizer
from vyper.semantics import set_data_positions, validate_semantics
from vyper.semantics import analyze_module, set_data_positions, validate_compilation_target
from vyper.semantics.types.function import ContractFunctionT
from vyper.semantics.types.module import ModuleT
from vyper.typing import StorageLayout
Expand Down Expand Up @@ -156,9 +156,19 @@ def vyper_module(self):
def annotated_vyper_module(self) -> vy_ast.Module:
return generate_annotated_ast(self.vyper_module, self.input_bundle)

@cached_property
def compilation_target(self):
"""
Get the annotated AST, and additionally run the global checks
required for a compilation target.
"""
module_t = self.annotated_vyper_module._metadata["type"]
validate_compilation_target(module_t)
return self.annotated_vyper_module

@cached_property
def storage_layout(self) -> StorageLayout:
module_ast = self.annotated_vyper_module
module_ast = self.compilation_target
return set_data_positions(module_ast, self.storage_layout_override)

@property
Expand Down Expand Up @@ -251,13 +261,11 @@ def generate_annotated_ast(vyper_module: vy_ast.Module, input_bundle: InputBundl
-------
vy_ast.Module
Annotated Vyper AST
StorageLayout
Layout of variables in storage
"""
vyper_module = copy.deepcopy(vyper_module)
with input_bundle.search_path(Path(vyper_module.resolved_path).parent):
# note: validate_semantics does type inference on the AST
validate_semantics(vyper_module, input_bundle)
# note: analyze_module does type inference on the AST
analyze_module(vyper_module, input_bundle)

return vyper_module

Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .analysis import validate_semantics
from .analysis import analyze_module, validate_compilation_target
from .analysis.data_positions import set_data_positions
5 changes: 3 additions & 2 deletions vyper/semantics/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .. import types # break a dependency cycle.
from .global_ import validate_semantics
from .global_ import validate_compilation_target
from .module import analyze_module

__all__ = ["validate_semantics"]
__all__ = [validate_compilation_target, analyze_module] # type: ignore[misc]
10 changes: 2 additions & 8 deletions vyper/semantics/analysis/global_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@

from vyper.exceptions import ExceptionList, InitializerException
from vyper.semantics.analysis.base import InitializesInfo, UsesInfo
from vyper.semantics.analysis.import_graph import ImportGraph
from vyper.semantics.analysis.module import validate_module_semantics_r
from vyper.semantics.types.module import ModuleT


def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT:
ret = validate_module_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface)

_validate_global_initializes_constraint(ret)

return ret
def validate_compilation_target(module_t: ModuleT):
_validate_global_initializes_constraint(module_t)


def _collect_used_modules_r(module_t):
Expand Down
31 changes: 21 additions & 10 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,29 @@
from vyper.utils import OrderedSet


def validate_module_semantics_r(
def analyze_module(
module_ast: vy_ast.Module,
input_bundle: InputBundle,
import_graph: ImportGraph,
is_interface: bool,
import_graph: ImportGraph = None,
is_interface: bool = False,
) -> ModuleT:
"""
Analyze a Vyper module AST node, add all module-level objects to the
namespace, type-check/validate semantics and annotate with type and analysis info
Analyze a Vyper module AST node, recursively analyze all its imports,
add all module-level objects to the namespace, type-check/validate
semantics and annotate with type and analysis info
"""
if import_graph is None:
import_graph = ImportGraph()

return _analyze_module_r(module_ast, input_bundle, import_graph, is_interface)


def _analyze_module_r(
module_ast: vy_ast.Module,
input_bundle: InputBundle,
import_graph: ImportGraph,
is_interface: bool = False,
):
if "type" in module_ast._metadata:
# we don't need to analyse again, skip out
assert isinstance(module_ast._metadata["type"], ModuleT)
Expand Down Expand Up @@ -742,7 +755,7 @@ def _load_import_helper(
module_ast = self._ast_from_file(file)

with override_global_namespace(Namespace()):
module_t = validate_module_semantics_r(
module_t = _analyze_module_r(
module_ast,
self.input_bundle,
import_graph=self._import_graph,
Expand All @@ -762,7 +775,7 @@ def _load_import_helper(
module_ast = self._ast_from_file(file)

with override_global_namespace(Namespace()):
validate_module_semantics_r(
_analyze_module_r(
module_ast,
self.input_bundle,
import_graph=self._import_graph,
Expand Down Expand Up @@ -871,7 +884,5 @@ def _load_builtin_import(level: int, module_str: str) -> InterfaceT:
interface_ast = _parse_and_fold_ast(file)

with override_global_namespace(Namespace()):
module_t = validate_module_semantics_r(
interface_ast, input_bundle, ImportGraph(), is_interface=True
)
module_t = _analyze_module_r(interface_ast, input_bundle, ImportGraph(), is_interface=True)
return module_t.interface
Loading