diff --git a/docs/built-in-functions.rst b/docs/built-in-functions.rst index 796e776801..6343d28010 100644 --- a/docs/built-in-functions.rst +++ b/docs/built-in-functions.rst @@ -289,7 +289,19 @@ Vyper has three built-ins for contract creation; all three contract creation bui def foo(_topic: bytes32, _data: Bytes[100]): raw_log([_topic], _data) -.. py:function:: raw_revert(data: Bytes) -> None +.. py:function:: raw_return(data: Bytes[...]) -> None + + Provides low level access to the ``RETURN`` opcode, reverting execution with the specified data returned. + + * ``data``: Data representing the error message causing the revert. + + .. code-block:: vyper + + @external + def foo(_data: Bytes[100]): + raw_return(_data) + +.. py:function:: raw_revert(data: Bytes[...]) -> None Provides low level access to the ``REVERT`` opcode, reverting execution with the specified data returned. diff --git a/tests/functional/builtins/codegen/test_raw_return.py b/tests/functional/builtins/codegen/test_raw_return.py new file mode 100644 index 0000000000..2148695bd4 --- /dev/null +++ b/tests/functional/builtins/codegen/test_raw_return.py @@ -0,0 +1,56 @@ +from eth.codecs import abi + + +def test_raw_return(env, get_contract): + code = """ +@external +def foo(data: Bytes[128]) -> DynArray[uint256, 2]: + raw_return(data) + """ + + c = get_contract(code) + + data = [1, 2] + abi_encoded = abi.encode("(uint256[])", (data,)) + assert c.foo(abi_encoded) == data + + +def test_proxy_raw_return(env, get_contract): + impl1 = """ +@external +def greet() -> String[32]: + return "Hello" + """ + + impl2 = """ +# test delegate calling with a different type, but byte-compatible in abi +@external +def greet() -> Bytes[32]: + return b"Goodbye" + """ + + proxy = """ +target: address + +@external +def set_implementation(target: address): + self.target = target + +@external +def greet() -> String[32]: + # forward msg.data to the implementation contract + data: Bytes[128] = raw_call(self.target, msg.data, is_delegate_call=True, max_outsize=128) + raw_return(data) + """ + + impl_c1 = get_contract(impl1) + impl_c2 = get_contract(impl2) + + proxy_c = get_contract(proxy) + + proxy_c.set_implementation(impl_c1.address) + assert proxy_c.greet() == impl_c1.greet() == "Hello" + + proxy_c.set_implementation(impl_c2.address) + assert impl_c2.greet() == b"Goodbye" + assert proxy_c.greet() == "Goodbye" # note: unsafe casted from bytes diff --git a/tests/functional/codegen/features/test_reverting.py b/tests/functional/builtins/codegen/test_raw_revert.py similarity index 100% rename from tests/functional/codegen/features/test_reverting.py rename to tests/functional/builtins/codegen/test_raw_revert.py diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index d9f8f48e01..72f5c82b57 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -1225,8 +1225,7 @@ def build_IR(self, expr, args, kwargs, contact): return IRnode.from_list(["blobhash", args[0]], typ=BYTES32_T) -class RawRevert(BuiltinFunctionT): - _id = "raw_revert" +class _RawReturnOrRevert(BuiltinFunctionT): _inputs = [("data", BytesT.any())] _return_type = None _is_terminus = True @@ -1239,12 +1238,27 @@ def infer_arg_types(self, node, expected_return_typ=None): data_type = get_possible_types_from_node(node.args[0]).pop() return [data_type] + @property + def OPCODE(self): + # must be implemented by subclass + raise NotImplementedError() + @process_inputs def build_IR(self, expr, args, kwargs, context): - with ensure_in_memory(args[0], context).cache_when_complex("err_buf") as (b, buf): + with ensure_in_memory(args[0], context).cache_when_complex("buf") as (b, buf): data = bytes_data_ptr(buf) len_ = get_bytearray_length(buf) - return b.resolve(IRnode.from_list(["revert", data, len_])) + return b.resolve(IRnode.from_list([self.OPCODE, data, len_])) + + +class RawRevert(_RawReturnOrRevert): + _id = "raw_revert" + OPCODE = "revert" + + +class RawReturn(_RawReturnOrRevert): + _id = "raw_return" + OPCODE = "return" class RawLog(BuiltinFunctionT): @@ -2687,6 +2701,7 @@ def _try_fold(self, node): "raw_call": RawCall(), "raw_log": RawLog(), "raw_revert": RawRevert(), + "raw_return": RawReturn(), "create_minimal_proxy_to": CreateMinimalProxyTo(), "create_forwarder_to": CreateForwarderTo(), "create_copy_of": CreateCopyOf(),