Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: call graph stability #3370

Merged
68 changes: 68 additions & 0 deletions tests/parser/test_call_graph_stability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import random
import string

import hypothesis.strategies as st
import pytest
from hypothesis import given, settings

import vyper.ast as vy_ast
from vyper.compiler.phases import CompilerData


# random names for functions
@settings(max_examples=20, deadline=None)
@given(
st.lists(
st.tuples(
st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]),
st.text(alphabet=string.ascii_lowercase, min_size=1),
),
unique_by=lambda x: x[1], # unique on function name
min_size=1,
max_size=10,
)
)
@pytest.mark.fuzzing
def test_call_graph_stability_fuzz(funcs):
def generate_func_def(mutability, func_name, i):
return f"""
@internal
{mutability}
def {func_name}() -> uint256:
return {i}
"""

func_defs = "\n".join(generate_func_def(m, s, i) for i, (m, s) in enumerate(funcs))

for _ in range(10):
func_names = [f for (_, f) in funcs]
random.shuffle(func_names)

self_calls = "\n".join(f" self.{f}()" for f in func_names)
code = f"""
{func_defs}

@external
def foo():
{self_calls}
"""
t = CompilerData(code)

# check the .called_functions data structure on foo() directly
foo = t.vyper_module_folded.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0]
foo_t = foo._metadata["type"]
assert [f.name for f in foo_t.called_functions] == func_names

# now for sanity, ensure the order that the function definitions appear
# in the IR is the same as the order of the calls
sigs = t.function_signatures
del sigs["foo"]
ir = t.ir_runtime
ir_funcs = []
# search for function labels
for d in ir.args: # currently: (seq ... (seq (label foo ...)) ...)
if d.value == "seq" and d.args[0].value == "label":
r = d.args[0].args[0].value
if isinstance(r, str) and r.startswith("internal"):
ir_funcs.append(r)
assert ir_funcs == [f.internal_function_label for f in sigs.values()]
6 changes: 3 additions & 3 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import warnings
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Tuple

from vyper import ast as vy_ast
from vyper.ast.validation import validate_call_args
Expand All @@ -28,7 +28,7 @@
from vyper.semantics.types.shortcuts import UINT256_T
from vyper.semantics.types.subscriptable import TupleT
from vyper.semantics.types.utils import type_from_abi, type_from_annotation
from vyper.utils import keccak256
from vyper.utils import OrderedSet, keccak256


class ContractFunctionT(VyperType):
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(
self.nonreentrant = nonreentrant

# a list of internal functions this function calls
self.called_functions: Set["ContractFunctionT"] = set()
self.called_functions = OrderedSet()

# special kwargs that are allowed in call site
self.call_site_kwargs = {
Expand Down
13 changes: 13 additions & 0 deletions vyper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@
from vyper.exceptions import DecimalOverrideException, InvalidLiteral


class OrderedSet(dict):
"""
a minimal "ordered set" class. this is needed in some places
because, while dict guarantees you can recover insertion order
vanilla sets do not.
no attempt is made to fully implement the set API, will add
functionality as needed.
"""

def add(self, item):
self[item] = None


class DecimalContextOverride(decimal.Context):
def __setattr__(self, name, value):
if name == "prec":
Expand Down