Skip to content

Commit

Permalink
fix: call internal functions from constructor (#2496)
Browse files Browse the repository at this point in the history
this commit allows the user to call internal functions from the     
`__init__` function. it does this by generating a call graph during the
annotation phase and then generating code for the functions called from
the init function for during deploy code generation

this also has a performance benefit (compiler time) because we can get
rid of the two-pass method for tracing frame size.

now that we have a call graph, this commit also introduces a topsort of
functions based on the call dependency tree. this ensures we can compile
functions that call functions that occur after them in the source code.

lastly, this commit also refactors vyper/codegen/module.py so that the
payable logic is cleaner, it uses properties instead of calculations
more, and cleans up properties on IRnode, FunctionSignature and Context.
  • Loading branch information
charles-cooper authored May 10, 2022
1 parent 03b2f1d commit 4b44ee7
Show file tree
Hide file tree
Showing 22 changed files with 383 additions and 341 deletions.
36 changes: 18 additions & 18 deletions examples/stock/company.vy
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@ def __init__(_company: address, _total_shares: uint256, initial_price: uint256):
# The company holds all the shares at first, but can sell them all.
self.holdings[self.company] = _total_shares

# Find out how much stock the company holds
@view
@internal
def _stockAvailable() -> uint256:
return self.holdings[self.company]

# Public function to allow external access to _stockAvailable
@view
@external
Expand All @@ -69,12 +63,6 @@ def buyStock():
# Log the buy event.
log Buy(msg.sender, buy_order)

# Find out how much stock any address (that's owned by someone) has.
@view
@internal
def _getHolding(_stockholder: address) -> uint256:
return self.holdings[_stockholder]

# Public function to allow external access to _getHolding
@view
@external
Expand Down Expand Up @@ -135,12 +123,6 @@ def payBill(vendor: address, amount: uint256):
# Log the payment event.
log Pay(vendor, amount)

# Return the amount in wei that a company has raised in stock offerings.
@view
@internal
def _debt() -> uint256:
return (self.totalShares - self._stockAvailable()) * self.price

# Public function to allow external access to _debt
@view
@external
Expand All @@ -154,3 +136,21 @@ def debt() -> uint256:
@external
def worth() -> uint256:
return self.balance - self._debt()

# Return the amount in wei that a company has raised in stock offerings.
@view
@internal
def _debt() -> uint256:
return (self.totalShares - self._stockAvailable()) * self.price

# Find out how much stock the company holds
@view
@internal
def _stockAvailable() -> uint256:
return self.holdings[self.company]

# Find out how much stock any address (that's owned by someone) has.
@view
@internal
def _getHolding(_stockholder: address) -> uint256:
return self.holdings[_stockholder]
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ def get_contract_module(source_code, *args, **kwargs):


def get_compiler_gas_estimate(code, func):
ir_runtime = compiler.phases.CompilerData(code).ir_runtime
sigs = compiler.phases.CompilerData(code).function_signatures
if func:
return compiler.utils.build_gas_estimates(ir_runtime)[func] + 22000
return compiler.utils.build_gas_estimates(sigs)[func] + 22000
else:
return sum(compiler.utils.build_gas_estimates(ir_runtime).values()) + 22000
return sum(compiler.utils.build_gas_estimates(sigs).values()) + 22000


def check_gas_on_chain(w3, tester, code, func=None, res=None):
Expand Down
27 changes: 27 additions & 0 deletions tests/parser/features/test_immutable.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,30 @@ def get_idx_two() -> int128:
c = get_contract(code, *values)
assert c.get_my_list() == expected_values
assert c.get_idx_two() == expected_values[2][2][2]


@pytest.mark.parametrize("n", range(5))
def test_internal_function_with_immutables(get_contract, n):
code = """
@internal
def foo() -> uint256:
self.counter += 1
return self.counter
counter: uint256
VALUE: immutable(uint256)
@external
def __init__(x: uint256):
self.counter = x
self.foo()
VALUE = self.foo()
self.foo()
@external
def get_immutable() -> uint256:
return VALUE
"""

c = get_contract(code, n)
assert c.get_immutable() == n + 2
31 changes: 31 additions & 0 deletions tests/parser/features/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,34 @@ def __init__(a: uint256):
assert "CALLDATALOAD" in opcodes
assert "CALLDATACOPY" not in opcodes[:ir_return_idx]
assert "CALLDATALOAD" not in opcodes[:ir_return_idx]


def test_init_calls_internal(get_contract, assert_compile_failed, assert_tx_failed):
code = """
foo: public(uint8)
@internal
def bar(x: uint256) -> uint8:
return convert(x, uint8) * 7
@external
def __init__(a: uint256):
self.foo = self.bar(a)
@external
def baz() -> uint8:
return self.bar(convert(self.foo, uint256))
"""
n = 5
c = get_contract(code, n)
assert c.foo() == n * 7
assert c.baz() == 245 # 5*7*7

n = 6
c = get_contract(code, n)
assert c.foo() == n * 7
assert_tx_failed(lambda: c.baz())

n = 255
assert_compile_failed(lambda: get_contract(code, n))

n = 256
assert_compile_failed(lambda: get_contract(code, n))
2 changes: 1 addition & 1 deletion vyper/ast/signatures/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .function_signature import FunctionSignature, VariableRecord
from .function_signature import FrameInfo, FunctionSignature, VariableRecord
24 changes: 22 additions & 2 deletions vyper/ast/signatures/function_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from vyper.codegen.ir_node import Encoding
from vyper.codegen.types import NodeType, parse_type
from vyper.exceptions import StructureException
from vyper.utils import cached_property, mkalphanum
from vyper.utils import MemoryPositions, cached_property, mkalphanum

# dict from function names to signatures
FunctionSignatures = Dict[str, "FunctionSignature"]
Expand Down Expand Up @@ -66,6 +66,16 @@ class FunctionArg:
ast_source: vy_ast.VyperNode


@dataclass
class FrameInfo:
frame_start: int
frame_size: int

@property
def mem_used(self):
return self.frame_size + MemoryPositions.RESERVED_MEMORY


# Function signature object
class FunctionSignature:
def __init__(
Expand All @@ -84,19 +94,25 @@ def __init__(
self.return_type = return_type
self.mutability = mutability
self.internal = internal
self.gas = None
self.gas_estimate = None
self.nonreentrant_key = nonreentrant_key
self.func_ast_code = func_ast_code
self.is_from_json = is_from_json

self.set_default_args()

# frame info is metadata that will be generated during codegen.
self.frame_info: Optional[FrameInfo] = None

def __str__(self):
input_name = "def " + self.name + "(" + ",".join([str(arg.typ) for arg in self.args]) + ")"
if self.return_type:
return input_name + " -> " + str(self.return_type) + ":"
return input_name + ":"

def set_frame_info(self, frame_info):
self.frame_info = frame_info

@cached_property
def _ir_identifier(self) -> str:
# we could do a bit better than this but it just needs to be unique
Expand Down Expand Up @@ -228,3 +244,7 @@ def is_default_func(self):
@property
def is_init_func(self):
return self.name == "__init__"

@property
def is_regular_function(self):
return not self.is_default_func and not self.is_init_func
4 changes: 2 additions & 2 deletions vyper/ast/signatures/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ def mk_full_signature_from_json(abi):
def _get_external_signatures(global_ctx, sig_formatter=lambda x: x):
ret = []

for code in global_ctx._defs:
for func_ast in global_ctx._function_defs:
sig = FunctionSignature.from_definition(
code,
func_ast,
sigs=global_ctx._contracts,
custom_structs=global_ctx._structs,
)
Expand Down
51 changes: 17 additions & 34 deletions vyper/codegen/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vyper.ast import VyperNode
from vyper.ast.signatures.function_signature import VariableRecord
from vyper.codegen.types import NodeType
from vyper.exceptions import CompilerPanic, FunctionDeclarationException
from vyper.exceptions import CompilerPanic


class Constancy(enum.Enum):
Expand All @@ -22,11 +22,7 @@ def __init__(
vars_=None,
sigs=None,
forvars=None,
return_type=None,
constancy=Constancy.Mutable,
is_internal=False,
is_payable=False,
# method_id="",
sig=None,
):
# In-memory variables, in the form (name, memory location, type)
Expand All @@ -41,9 +37,6 @@ def __init__(
# Variables defined in for loops, e.g. for i in range(6): ...
self.forvars = forvars or {}

# Return type of the function
self.return_type = return_type

# Is the function constant?
self.constancy = constancy

Expand All @@ -53,14 +46,9 @@ def __init__(
# Whether we are currently parsing a range expression
self.in_range_expr = False

# Is the function payable?
self.is_payable = is_payable

# List of custom structs that have been defined.
self.structs = global_ctx._structs

self.is_internal = is_internal

# store global context
self.global_ctx = global_ctx

Expand All @@ -73,23 +61,25 @@ def __init__(
# Not intended to be accessed directly
self.memory_allocator = memory_allocator

self._callee_frame_sizes = []

# Intermented values, used for internal IDs
# Incremented values, used for internal IDs
self._internal_var_iter = 0
self._scope_id_iter = 0

def is_constant(self):
return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr

def register_callee(self, frame_size):
self._callee_frame_sizes.append(frame_size)
# convenience propreties
@property
def is_payable(self):
return self.sig.mutability == "payable"

@property
def max_callee_frame_size(self):
if len(self._callee_frame_sizes) == 0:
return 0
return max(self._callee_frame_sizes)
def is_internal(self):
return self.sig.internal

@property
def return_type(self):
return self.sig.return_type

#
# Context Managers
Expand Down Expand Up @@ -248,23 +238,16 @@ def lookup_internal_function(self, method_name, args_ir, ast_source):
the kwargs which need to be filled in by the compiler
"""

sig = self.sigs["self"].get(method_name, None)

def _check(cond, s="Unreachable"):
if not cond:
raise CompilerPanic(s)

sig = self.sigs["self"].get(method_name, None)
if sig is None:
raise FunctionDeclarationException(
"Function does not exist or has not been declared yet "
"(reminder: functions cannot call functions later in code "
"than themselves)",
ast_source,
)

_check(sig.internal) # sanity check
# should have been caught during type checking, sanity check anyway
# these should have been caught during type checking; sanity check
_check(sig is not None)
_check(sig.internal)
_check(len(sig.base_args) <= len(args_ir) <= len(sig.args))

# more sanity check, that the types match
# _check(all(l.typ == r.typ for (l, r) in zip(args_ir, sig.args))

Expand Down
2 changes: 1 addition & 1 deletion vyper/codegen/function_definitions/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .common import generate_ir_for_function, is_default_func, is_initializer # noqa
from .common import generate_ir_for_function # noqa
Loading

0 comments on commit 4b44ee7

Please sign in to comment.