Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]: add revert_on_failure keyword for all external calls #4517

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,59 @@ def get_lucky(contract_address: address) -> int128:
assert c2.get_lucky(c1.address) == 1


def test_external_contract_call_revert_on_failure(get_contract, tx_failed):
target_source = """
@external
def fail(should_raise: bool):
if should_raise:
raise "fail"

@external
def return_value(should_raise: bool) -> uint256:
if should_raise:
raise "fail"
return 123
"""

caller_source = """
interface Target:
def fail(should_raise: bool): nonpayable
def return_value(should_raise: bool) -> uint256: nonpayable

@external
def call_target_fail(target: address, should_raise: bool) -> bool:
success: bool = extcall Target(target).fail(should_raise, revert_on_failure=False)
return success

@external
def call_target_return(target: address, should_raise: bool) -> (bool, uint256):
success: bool = False
result: uint256 = 0
success, result = extcall Target(target).return_value(should_raise, revert_on_failure=False)

return success, result
"""

target = get_contract(target_source)
caller = get_contract(caller_source)

# Test successful call
assert caller.call_target_fail(target.address, False) is True

# Test failed call
assert caller.call_target_fail(target.address, True) is False

# Test successful call with return value
success, result = caller.call_target_return(target.address, False)
assert success is True
assert result == 123

# Test failed call with return value
success, result = caller.call_target_return(target.address, True)
assert success is False
assert result == 0 # Default value


def test_complex_external_contract_call_declaration(get_contract):
contract_1 = """
@external
Expand Down
82 changes: 74 additions & 8 deletions vyper/codegen/external_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.evm.address_space import MEMORY
from vyper.exceptions import TypeCheckFailure
from vyper.semantics.types import InterfaceT, TupleT
from vyper.semantics.types import BoolT, InterfaceT, TupleT
from vyper.semantics.types.function import StateMutability


Expand All @@ -30,6 +30,7 @@ class _CallKwargs:
gas: IRnode
skip_contract_check: bool
default_return_value: IRnode
revert_on_failure: bool


def _pack_arguments(fn_type, args, context):
Expand Down Expand Up @@ -176,6 +177,7 @@ def _bool(x):
gas=unwrap_location(call_kwargs.pop("gas", IRnode("gas"))),
skip_contract_check=_bool(call_kwargs.pop("skip_contract_check", IRnode(0))),
default_return_value=call_kwargs.pop("default_return_value", None),
revert_on_failure=_bool(call_kwargs.pop("revert_on_failure", IRnode(1))), # Default to True
)

if len(call_kwargs) != 0: # pragma: nocover
Expand All @@ -194,21 +196,28 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con
# sanity check
assert fn_type.n_positional_args <= len(args_ir) <= fn_type.n_total_args

# Check for revert_on_failure and get return type
revert_on_failure = call_kwargs.revert_on_failure
return_t = fn_type.return_type

# Prepare the sequence IR
ret = ["seq"]

# this is a sanity check to prevent double evaluation of the external call
# in the codegen pipeline. if the external call gets doubly evaluated,
# a duplicate label exception will get thrown during assembly.
ret.append(eval_once_check(_freshname(call_expr.node_source_code)))

# Pack the arguments
buf, arg_packer, args_ofst, args_len = _pack_arguments(fn_type, args_ir, context)
ret += arg_packer

# Process the return data unpacker for return types
ret_unpacker, ret_ofst, ret_len = _unpack_returndata(
buf, fn_type, call_kwargs, contract_address, context, call_expr
)

ret += arg_packer

# Check contract existence if no return data expected
if fn_type.return_type is None and not call_kwargs.skip_contract_check:
# if we do not expect return data, check that a contract exists at the
# target address. we must perform this check BEFORE the call because
Expand All @@ -218,25 +227,82 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con
# selfdestructs).
ret.append(_extcodesize_check(contract_address))

# Prepare call parameters
gas = call_kwargs.gas
value = call_kwargs.value

# Determine if we need a static call
use_staticcall = fn_type.mutability in (StateMutability.VIEW, StateMutability.PURE)
if context.is_constant():
assert use_staticcall, "typechecker missed this"

# Create the call operation
if use_staticcall:
call_op = ["staticcall", gas, contract_address, args_ofst, args_len, buf, ret_len]
else:
call_op = ["call", gas, contract_address, value, args_ofst, args_len, buf, ret_len]

ret.append(check_external_call(call_op))
# Handle standard case (revert_on_failure=True)
if revert_on_failure:
ret.append(check_external_call(call_op))

return_t = fn_type.return_type
if return_t is not None:
ret.append(ret_unpacker)
if return_t is not None:
ret.append(ret_unpacker)

return IRnode.from_list(ret, typ=return_t, location=MEMORY)

# Handle non-reverting case (revert_on_failure=False)
if return_t is None:
# Return just the success flag when no return type
ret.append(call_op)
return IRnode.from_list(ret, typ=BoolT())
else:
# Create an output buffer for success flag
bool_ty = BoolT()
success_var = context.new_internal_variable(bool_ty)
success_loc = ["mload", success_var]

# Create a temporary variable and store the call result
ret.append(["mstore", success_var, call_op])

# Create a memory location for any default value
default_val_var = context.new_internal_variable(return_t)
# Initialize with zero by default (for integers)
ret.append(["mstore", default_val_var, 0])

# Store the success flag and return data
# Process return data if call succeeded, otherwise use default value
ret.append(
[
"if",
success_var, # If call succeeded
ret_unpacker, # Process return data
["pass"], # Otherwise use default (already zeroed)
]
)

return IRnode.from_list(ret, typ=return_t, location=MEMORY)
# Create a memory location for the result tuple
tuple_loc = context.new_internal_variable(TupleT([bool_ty, return_t]))

# Create an array of operations to copy the return data
# If success=true, use the return data, otherwise use the default value (0)
copy_return_data = [
"seq",
# Success flag (0 or 1)
["mstore", tuple_loc, success_loc],
# If success, use ret_unpacker[2], else use default_val_var
[
"if",
success_loc, # If success
["mstore", ["add", tuple_loc, 32], ["mload", ret_unpacker[2]]], # Use returned data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this do?

["mstore", ["add", tuple_loc, 32], ["mload", default_val_var]], # Use default value
],
]

ret.append(copy_return_data)

# Return the tuple
return IRnode.from_list(ret + [tuple_loc], typ=TupleT([bool_ty, return_t]), location=MEMORY)


def ir_for_external_call(call_expr, context):
Expand Down
17 changes: 17 additions & 0 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def call_site_kwargs(self):
"value": KwargSettings(UINT256_T, 0),
"skip_contract_check": KwargSettings(BoolT(), False, require_literal=True),
"default_return_value": KwargSettings(self.return_type, None),
"revert_on_failure": KwargSettings(BoolT(), True, require_literal=True),
}

def __repr__(self):
Expand Down Expand Up @@ -639,6 +640,16 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
except TypeMismatch as e:
raise self._enhance_call_exception(e, expected.ast_source or self.ast_def)

# Check if revert_on_failure is explicitly set to False
revert_kwarg = next((kw for kw in node.keywords if kw.arg == "revert_on_failure"), None)
revert_on_failure = True
if (
revert_kwarg
and isinstance(revert_kwarg.value, vy_ast.NameConstant)
and revert_kwarg.value.value is False
):
revert_on_failure = False

# TODO this should be moved to validate_call_args
for kwarg in node.keywords:
if kwarg.arg in self.call_site_kwargs:
Expand Down Expand Up @@ -670,6 +681,12 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
hint = f"Try removing the kwarg: `{modified_line}`"
raise ArgumentException(msg, kwarg, hint=hint)

# Return a tuple of (bool, return_type) when revert_on_failure=False
if not revert_on_failure:
if self.return_type is None:
return BoolT()
return TupleT((BoolT(), self.return_type))

return self.return_type

def to_toplevel_abi_dict(self):
Expand Down