Skip to content

Commit

Permalink
enforce roundtrip equality
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed May 24, 2024
1 parent 8fd93d6 commit b9acf98
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 25 deletions.
58 changes: 40 additions & 18 deletions tests/unit/cli/storage_layout/test_storage_layout_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import pytest

from vyper.compiler import compile_code
from vyper.evm.opcodes import version_check
from vyper.exceptions import StorageLayoutException

from .utils import adjust_storage_layout_for_cancun


def test_storage_layout_overrides():
code = """
Expand Down Expand Up @@ -68,14 +67,19 @@ def public_foo3():
"baz": {"type": "Bytes[65]", "slot": 2, "n_slots": 4},
"bar": {"type": "uint256", "slot": 6, "n_slots": 1},
}
if version_check(begin="cancun"):
del storage_layout_override["$.nonreentrant_key"]

expected_output = {"storage_layout": storage_layout_override}

out = compile_code(
code, output_formats=["layout"], storage_layout_override=storage_layout_override
)

adjust_storage_layout_for_cancun(expected_output, do_adjust_slots=False)
# adjust transient storage layout
expected_output["transient_storage_layout"] = {
"$.nonreentrant_key": {"n_slots": 1, "slot": 0, "type": "nonreentrant lock"}
}

assert out["layout"] == expected_output

Expand Down Expand Up @@ -122,17 +126,26 @@ def test_override_nonreentrant_slot():
def foo():
pass
"""

storage_layout_override = {"$.nonreentrant_key": {"slot": 2**256, "type": "nonreentrant key"}}

exception_regex = re.escape(
f"Invalid storage slot for var $.nonreentrant_key, out of bounds: {2**256}"
)
with pytest.raises(StorageLayoutException, match=exception_regex):
compile_code(
code, output_formats=["layout"], storage_layout_override=storage_layout_override
if version_check(begin="cancun"):
del storage_layout_override["$.nonreentrant_key"]
assert (
compile_code(
code, output_formats=["layout"], storage_layout_override=storage_layout_override
)
is not None
)

else:
exception_regex = re.escape(
f"Invalid storage slot for var $.nonreentrant_key, out of bounds: {2**256}"
)
with pytest.raises(StorageLayoutException, match=exception_regex):
compile_code(
code, output_formats=["layout"], storage_layout_override=storage_layout_override
)


def test_override_missing_nonreentrant_key():
code = """
Expand All @@ -144,14 +157,24 @@ def foo():

storage_layout_override = {}

exception_regex = re.escape(
"Could not find storage slot for $.nonreentrant_key."
" Have you used the correct storage layout file?"
)
with pytest.raises(StorageLayoutException, match=exception_regex):
compile_code(
code, output_formats=["layout"], storage_layout_override=storage_layout_override
if version_check(begin="cancun"):
assert (
compile_code(
code, output_formats=["layout"], storage_layout_override=storage_layout_override
)
is not None
)
# in cancun, nonreentrant key is allocated in transient storage and can't be overridden
return
else:
exception_regex = re.escape(
"Could not find storage slot for $.nonreentrant_key."
" Have you used the correct storage layout file?"
)
with pytest.raises(StorageLayoutException, match=exception_regex):
compile_code(
code, output_formats=["layout"], storage_layout_override=storage_layout_override
)


def test_incomplete_overrides():
Expand Down Expand Up @@ -282,7 +305,6 @@ def foo() -> uint256:
},
},
}
# adjust_storage_layout_for_cancun(expected_output)

assert out["layout"] == expected_output

Expand Down
5 changes: 2 additions & 3 deletions tests/unit/cli/storage_layout/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from vyper.evm.opcodes import version_check


def adjust_storage_layout_for_cancun(layout, do_adjust_slots=True):
def adjust_storage_layout_for_cancun(layout):
def _go(layout):
for _varname, item in layout.items():
if "slot" in item and isinstance(item["slot"], int):
if do_adjust_slots:
item["slot"] -= 1
item["slot"] -= 1
else:
# recurse to submodule
_go(item)
Expand Down
16 changes: 12 additions & 4 deletions vyper/semantics/analysis/data_positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def set_data_positions(
# allocate code layout with no overrides
_allocate_layout_r(vyper_module, no_storage=True)
_allocate_with_overrides(vyper_module, storage_layout_overrides)

# sanity check that generated layout file is the same as the input.
roundtrip = generate_layout_export(vyper_module).get(_LAYOUT_KEYS[DataLocation.STORAGE], {})
assert roundtrip == storage_layout_overrides, roundtrip
else:
_allocate_layout_r(vyper_module)

Expand Down Expand Up @@ -188,6 +192,11 @@ def _allocate_with_overrides_r(
if not fn_t.nonreentrant:
continue

# if reentrancy keys get allocated in transient storage, we don't
# override them
if get_reentrancy_key_location() == DataLocation.TRANSIENT:
continue

# Expect to find this variable within the storage layout override
if global_nonreentrant_slot is None:
raise StorageLayoutException(
Expand Down Expand Up @@ -253,7 +262,7 @@ def get_reentrancy_key_location() -> DataLocation:
}


def _allocate_nonreentrant_keys(vyper_module, allocators):
def _set_nonreentrant_keys(vyper_module, allocators):
SLOT = allocators.get_global_nonreentrant_key_slot()

for node in vyper_module.get_children(vy_ast.FunctionDef):
Expand Down Expand Up @@ -281,9 +290,8 @@ def _allocate_layout_r(
allocators.allocate_global_nonreentrancy_slot()

# tag functions with the global nonreentrant key
# `no_storage` is slightly confusing, maybe should be `if not overrides`
if not no_storage:
_allocate_nonreentrant_keys(vyper_module, allocators)
if not no_storage or get_reentrancy_key_location() == DataLocation.TRANSIENT:
_set_nonreentrant_keys(vyper_module, allocators)

for node in _get_allocatable(vyper_module):
if isinstance(node, vy_ast.InitializesDecl):
Expand Down

0 comments on commit b9acf98

Please sign in to comment.