Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]: add revert_on_failure keyword for all external calls #4517

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,185 @@ def get_lucky(contract_address: address) -> int128:
assert c1.get_lucky() == 1
assert c2.get_lucky(c1.address) == 1

def test_external_contract_call_revert_on_failure_noreturn(get_contract, tx_failed):
target_source = """
@external
def fail(should_raise: bool):
if should_raise:
raise "fail"
"""

caller_source = """
interface Target:
def fail(should_raise: bool): nonpayable

@external
def call_target_fail(target: address, should_raise: bool) -> bool:
success: bool = extcall Target(target).fail(should_raise, revert_on_failure=False)
return success
"""

target = get_contract(target_source)
caller = get_contract(caller_source)

# Test successful call
assert caller.call_target_fail(target.address, False) is True

# Test failed call
assert caller.call_target_fail(target.address, True) is False


def test_external_contract_call_revert_on_failure(get_contract, tx_failed):
target_source = """
@external
def return_value(should_raise: bool) -> uint256:
if should_raise:
raise "fail"
return 123
"""

caller_source = """
interface Target:
def return_value(should_raise: bool) -> uint256: nonpayable

@external
def call_target_return(target: address, should_raise: bool) -> (bool, uint256):
success: bool = False
result: uint256 = 0
success, result = extcall Target(target).return_value(should_raise, revert_on_failure=False)

return success, result
"""

target = get_contract(target_source)
caller = get_contract(caller_source)

# Test successful call with return value
success, result = caller.call_target_return(target.address, False)
assert success is True
assert result == 123

# Test failed call with return value
success, result = caller.call_target_return(target.address, True)
assert success is False
assert result == 0 # Default value

def test_external_call_with_struct_return_type_revert_on_failure(get_contract):
target_source = """
struct Point:
x: uint256
y: uint256

@external
def return_point(should_raise: bool) -> Point:
if should_raise:
raise "fail"
return Point(x=45, y=67)
"""

caller_source = """
struct Point:
x: uint256
y: uint256

interface Target:
def return_point(should_raise: bool) -> Point: nonpayable

@external
def call_target_point(target: address, should_raise: bool) -> (bool, Point):
success: bool = False
result: Point = Point(x=0, y=0)
success, result = extcall Target(target).return_point(should_raise, revert_on_failure=False)

return success, result
"""

target = get_contract(target_source)
caller = get_contract(caller_source)

# Test successful call with struct return value
success, result = caller.call_target_point(target.address, False)
assert success is True
assert result[0] == 45
assert result[1] == 67

# Test failed call with struct return value
success, result = caller.call_target_point(target.address, True)
assert success is False
assert result[0] == 0
assert result[1] == 0


def test_external_call_with_array_return_type_revert_on_failure(get_contract):
target_source = """
@external
def return_array(should_raise: bool) -> DynArray[uint256, 5]:
if should_raise:
raise "fail"
return [1, 2, 3, 4, 5]
"""

caller_source = """
interface Target:
def return_array(should_raise: bool) -> DynArray[uint256, 5]: nonpayable

@external
def call_target_array(target: address, should_raise: bool) -> (bool, DynArray[uint256, 5]):
success: bool = False
result: DynArray[uint256, 5] = []
success, result = extcall Target(target).return_array(should_raise, revert_on_failure=False)

return success, result
"""

target = get_contract(target_source)
caller = get_contract(caller_source)

# Test successful call with array return value
success, result = caller.call_target_array(target.address, False)
assert success is True
assert result == [1, 2, 3, 4, 5]

# Test failed call with array return value
success, result = caller.call_target_array(target.address, True)
assert success is False
assert result == [] # Default empty array


def test_external_call_with_string_return_type_revert_on_failure(get_contract):
target_source = """
@external
def return_string(should_raise: bool) -> String[11]:
if should_raise:
raise "fail"
return "hello vyper"
"""

caller_source = """
interface Target:
def return_string(should_raise: bool) -> String[11]: nonpayable

@external
def call_target_string(target: address, should_raise: bool) -> (bool, String[11]):
success: bool = False
result: String[11] = ""
success, result = extcall Target(target).return_string(should_raise, revert_on_failure=False)

return success, result
"""

target = get_contract(target_source)
caller = get_contract(caller_source)

# Test successful call with string return value
success, result = caller.call_target_string(target.address, False)
assert success is True
assert result == "hello vyper"

# Test failed call with string return value
success, result = caller.call_target_string(target.address, True)
assert success is False
assert result == "" # Default empty string

def test_complex_external_contract_call_declaration(get_contract):
contract_1 = """
Expand Down
61 changes: 53 additions & 8 deletions vyper/codegen/external_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.evm.address_space import MEMORY
from vyper.exceptions import TypeCheckFailure
from vyper.semantics.types import InterfaceT, TupleT
from vyper.semantics.types import BoolT, InterfaceT, TupleT
from vyper.semantics.types.function import StateMutability


Expand All @@ -30,6 +30,7 @@ class _CallKwargs:
gas: IRnode
skip_contract_check: bool
default_return_value: IRnode
revert_on_failure: bool


def _pack_arguments(fn_type, args, context):
Expand Down Expand Up @@ -176,6 +177,7 @@ def _bool(x):
gas=unwrap_location(call_kwargs.pop("gas", IRnode("gas"))),
skip_contract_check=_bool(call_kwargs.pop("skip_contract_check", IRnode(0))),
default_return_value=call_kwargs.pop("default_return_value", None),
revert_on_failure=_bool(call_kwargs.pop("revert_on_failure", IRnode(1))), # Default to True
)

if len(call_kwargs) != 0: # pragma: nocover
Expand All @@ -194,21 +196,28 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con
# sanity check
assert fn_type.n_positional_args <= len(args_ir) <= fn_type.n_total_args

# Check for revert_on_failure and get return type
revert_on_failure = call_kwargs.revert_on_failure
return_t = fn_type.return_type

# Prepare the sequence IR
ret = ["seq"]

# this is a sanity check to prevent double evaluation of the external call
# in the codegen pipeline. if the external call gets doubly evaluated,
# a duplicate label exception will get thrown during assembly.
ret.append(eval_once_check(_freshname(call_expr.node_source_code)))

# Pack the arguments
buf, arg_packer, args_ofst, args_len = _pack_arguments(fn_type, args_ir, context)
ret += arg_packer

# Process the return data unpacker for return types
ret_unpacker, ret_ofst, ret_len = _unpack_returndata(
buf, fn_type, call_kwargs, contract_address, context, call_expr
)

ret += arg_packer

# Check contract existence if no return data expected
if fn_type.return_type is None and not call_kwargs.skip_contract_check:
# if we do not expect return data, check that a contract exists at the
# target address. we must perform this check BEFORE the call because
Expand All @@ -218,25 +227,61 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con
# selfdestructs).
ret.append(_extcodesize_check(contract_address))

# Prepare call parameters
gas = call_kwargs.gas
value = call_kwargs.value

# Determine if we need a static call
use_staticcall = fn_type.mutability in (StateMutability.VIEW, StateMutability.PURE)
if context.is_constant():
assert use_staticcall, "typechecker missed this"

# Create the call operation
if use_staticcall:
call_op = ["staticcall", gas, contract_address, args_ofst, args_len, buf, ret_len]
else:
call_op = ["call", gas, contract_address, value, args_ofst, args_len, buf, ret_len]

ret.append(check_external_call(call_op))
# Handle standard case (revert_on_failure=True)
if revert_on_failure:
ret.append(check_external_call(call_op))

if return_t is not None:
ret.append(ret_unpacker)

return IRnode.from_list(ret, typ=return_t, location=MEMORY)

else:
bool_ty = BoolT()
if return_t is None:
ret.append(call_op)
return IRnode.from_list(ret, typ=bool_ty)

tuple_t = TupleT([bool_ty, return_t])

success_buf = context.new_internal_variable(bool_ty)
tuple_buf = context.new_internal_variable(tuple_t)

store_success = IRnode.from_list(["mstore", success_buf, "success"])
conditional_unpacker = IRnode.from_list(["if", "success", ret_unpacker, "pass"])

handler = IRnode.from_list(["with", "success", call_op,
["seq",
store_success,
conditional_unpacker]])

ret.append(handler)

ret_ofst.typ = return_t

multi = IRnode.from_list(["multi", success_buf, ret_ofst], typ=tuple_t)
res_setter = make_setter(tuple_buf, multi)

ret.append(res_setter)
ret.append(tuple_buf)

return_t = fn_type.return_type
if return_t is not None:
ret.append(ret_unpacker)

return IRnode.from_list(ret, typ=return_t, location=MEMORY)
return IRnode.from_list(ret, typ=tuple_t, location=MEMORY)


def ir_for_external_call(call_expr, context):
Expand Down
17 changes: 17 additions & 0 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def call_site_kwargs(self):
"value": KwargSettings(UINT256_T, 0),
"skip_contract_check": KwargSettings(BoolT(), False, require_literal=True),
"default_return_value": KwargSettings(self.return_type, None),
"revert_on_failure": KwargSettings(BoolT(), True, require_literal=True),
}

def __repr__(self):
Expand Down Expand Up @@ -639,6 +640,16 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
except TypeMismatch as e:
raise self._enhance_call_exception(e, expected.ast_source or self.ast_def)

# Check if revert_on_failure is explicitly set to False
revert_kwarg = next((kw for kw in node.keywords if kw.arg == "revert_on_failure"), None)
revert_on_failure = True
if (
revert_kwarg
and isinstance(revert_kwarg.value, vy_ast.NameConstant)
and revert_kwarg.value.value is False
):
revert_on_failure = False

# TODO this should be moved to validate_call_args
for kwarg in node.keywords:
if kwarg.arg in self.call_site_kwargs:
Expand Down Expand Up @@ -670,6 +681,12 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
hint = f"Try removing the kwarg: `{modified_line}`"
raise ArgumentException(msg, kwarg, hint=hint)

# Return a tuple of (bool, return_type) when revert_on_failure=False
if not revert_on_failure:
if self.return_type is None:
return BoolT()
return TupleT((BoolT(), self.return_type))

return self.return_type

def to_toplevel_abi_dict(self):
Expand Down
Loading