diff --git a/tests/unit/cli/storage_layout/__init__.py b/tests/unit/cli/storage_layout/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/cli/storage_layout/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py index ece2743b81..d490d2008f 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout.py +++ b/tests/unit/cli/storage_layout/test_storage_layout.py @@ -1,21 +1,6 @@ from vyper.compiler import compile_code -from vyper.evm.opcodes import version_check - -def _adjust_storage_layout_for_cancun(layout): - def _go(layout): - for _varname, item in layout.items(): - if "slot" in item and isinstance(item["slot"], int): - item["slot"] -= 1 - else: - # recurse to submodule - _go(item) - - if version_check(begin="cancun"): - layout["transient_storage_layout"] = { - "$.nonreentrant_key": layout["storage_layout"].pop("$.nonreentrant_key") - } - _go(layout["storage_layout"]) +from .utils import adjust_storage_layout_for_cancun def test_storage_layout(): @@ -55,19 +40,18 @@ def public_foo3(): pass """ - out = compile_code(code, output_formats=["layout"]) - expected = { "storage_layout": { - "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, - "foo": {"slot": 1, "type": "HashMap[address, uint256]"}, - "arr": {"slot": 2, "type": "DynArray[uint256, 3]"}, - "baz": {"slot": 6, "type": "Bytes[65]"}, - "bar": {"slot": 10, "type": "uint256"}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock", "n_slots": 1}, + "foo": {"slot": 1, "type": "HashMap[address, uint256]", "n_slots": 1}, + "arr": {"slot": 2, "type": "DynArray[uint256, 3]", "n_slots": 4}, + "baz": {"slot": 6, "type": "Bytes[65]", "n_slots": 4}, + "bar": {"slot": 10, "type": "uint256", "n_slots": 1}, } } - _adjust_storage_layout_for_cancun(expected) + adjust_storage_layout_for_cancun(expected) + out = compile_code(code, output_formats=["layout"]) assert out["layout"] == expected @@ -88,12 +72,9 @@ def __init__(): "SYMBOL": {"length": 64, "offset": 0, "type": "String[32]"}, "DECIMALS": {"length": 32, "offset": 64, "type": "uint8"}, }, - "storage_layout": { - "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, - "name": {"slot": 1, "type": "String[32]"}, - }, + "storage_layout": {"name": {"slot": 1, "type": "String[32]", "n_slots": 2}}, } - _adjust_storage_layout_for_cancun(expected_layout) + adjust_storage_layout_for_cancun(expected_layout) out = compile_code(code, output_formats=["layout"]) assert out["layout"] == expected_layout @@ -137,13 +118,12 @@ def __init__(): }, }, "storage_layout": { - "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, - "counter": {"slot": 1, "type": "uint256"}, - "counter2": {"slot": 2, "type": "uint256"}, - "a_library": {"supply": {"slot": 3, "type": "uint256"}}, + "counter": {"slot": 1, "type": "uint256", "n_slots": 1}, + "counter2": {"slot": 2, "type": "uint256", "n_slots": 1}, + "a_library": {"supply": {"slot": 3, "type": "uint256", "n_slots": 1}}, }, } - _adjust_storage_layout_for_cancun(expected_layout) + adjust_storage_layout_for_cancun(expected_layout) out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) assert out["layout"] == expected_layout @@ -187,13 +167,12 @@ def __init__(): }, }, "storage_layout": { - "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, - "counter": {"slot": 1, "type": "uint256"}, - "a_library": {"supply": {"slot": 2, "type": "uint256"}}, - "counter2": {"slot": 3, "type": "uint256"}, + "counter": {"slot": 1, "type": "uint256", "n_slots": 1}, + "a_library": {"supply": {"slot": 2, "type": "uint256", "n_slots": 1}}, + "counter2": {"slot": 3, "type": "uint256", "n_slots": 1}, }, } - _adjust_storage_layout_for_cancun(expected_layout) + adjust_storage_layout_for_cancun(expected_layout) out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) assert out["layout"] == expected_layout @@ -271,14 +250,14 @@ def bar(): }, }, "storage_layout": { - "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, - "counter": {"slot": 1, "type": "uint256"}, - "lib2": {"storage_variable": {"slot": 2, "type": "uint256"}}, - "counter2": {"slot": 3, "type": "uint256"}, - "a_library": {"supply": {"slot": 4, "type": "uint256"}}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock", "n_slots": 1}, + "counter": {"slot": 1, "type": "uint256", "n_slots": 1}, + "lib2": {"storage_variable": {"slot": 2, "type": "uint256", "n_slots": 1}}, + "counter2": {"slot": 3, "type": "uint256", "n_slots": 1}, + "a_library": {"supply": {"slot": 4, "type": "uint256", "n_slots": 1}}, }, } - _adjust_storage_layout_for_cancun(expected_layout) + adjust_storage_layout_for_cancun(expected_layout) out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) assert out["layout"] == expected_layout @@ -351,16 +330,15 @@ def foo() -> uint256: }, }, "storage_layout": { - "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, - "counter": {"slot": 1, "type": "uint256"}, + "counter": {"slot": 1, "type": "uint256", "n_slots": 1}, "lib2": { - "lib1": {"supply": {"slot": 2, "type": "uint256"}}, - "storage_variable": {"slot": 3, "type": "uint256"}, + "lib1": {"supply": {"slot": 2, "type": "uint256", "n_slots": 1}}, + "storage_variable": {"slot": 3, "type": "uint256", "n_slots": 1}, }, - "counter2": {"slot": 4, "type": "uint256"}, + "counter2": {"slot": 4, "type": "uint256", "n_slots": 1}, }, } - _adjust_storage_layout_for_cancun(expected_layout) + adjust_storage_layout_for_cancun(expected_layout) out = compile_code(code, input_bundle=input_bundle, output_formats=["layout"]) assert out["layout"] == expected_layout diff --git a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py index 707c94c3fc..f02a8471e2 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py @@ -3,6 +3,7 @@ import pytest from vyper.compiler import compile_code +from vyper.evm.opcodes import version_check from vyper.exceptions import StorageLayoutException @@ -12,11 +13,11 @@ def test_storage_layout_overrides(): b: uint256""" storage_layout_overrides = { - "a": {"type": "uint256", "slot": 1}, - "b": {"type": "uint256", "slot": 0}, + "a": {"type": "uint256", "slot": 1, "n_slots": 1}, + "b": {"type": "uint256", "slot": 0, "n_slots": 1}, } - expected_output = {"storage_layout": storage_layout_overrides, "code_layout": {}} + expected_output = {"storage_layout": storage_layout_overrides} out = compile_code( code, output_formats=["layout"], storage_layout_override=storage_layout_overrides @@ -61,18 +62,26 @@ def public_foo3(): """ storage_layout_override = { - "$.nonreentrant_key": {"type": "nonreentrant lock", "slot": 8}, - "foo": {"type": "HashMap[address, uint256]", "slot": 1}, - "baz": {"type": "Bytes[65]", "slot": 2}, - "bar": {"type": "uint256", "slot": 6}, + "$.nonreentrant_key": {"type": "nonreentrant lock", "slot": 8, "n_slots": 1}, + "foo": {"type": "HashMap[address, uint256]", "slot": 1, "n_slots": 1}, + "baz": {"type": "Bytes[65]", "slot": 2, "n_slots": 4}, + "bar": {"type": "uint256", "slot": 6, "n_slots": 1}, } + if version_check(begin="cancun"): + del storage_layout_override["$.nonreentrant_key"] - expected_output = {"storage_layout": storage_layout_override, "code_layout": {}} + expected_output = {"storage_layout": storage_layout_override} out = compile_code( code, output_formats=["layout"], storage_layout_override=storage_layout_override ) + # adjust transient storage layout + if version_check(begin="cancun"): + expected_output["transient_storage_layout"] = { + "$.nonreentrant_key": {"n_slots": 1, "slot": 0, "type": "nonreentrant lock"} + } + assert out["layout"] == expected_output @@ -118,16 +127,55 @@ def test_override_nonreentrant_slot(): def foo(): pass """ - storage_layout_override = {"$.nonreentrant_key": {"slot": 2**256, "type": "nonreentrant key"}} - exception_regex = re.escape( - f"Invalid storage slot for var $.nonreentrant_key, out of bounds: {2**256}" - ) - with pytest.raises(StorageLayoutException, match=exception_regex): - compile_code( - code, output_formats=["layout"], storage_layout_override=storage_layout_override + if version_check(begin="cancun"): + del storage_layout_override["$.nonreentrant_key"] + assert ( + compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) + is not None + ) + + else: + exception_regex = re.escape( + f"Invalid storage slot for var $.nonreentrant_key, out of bounds: {2**256}" ) + with pytest.raises(StorageLayoutException, match=exception_regex): + compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) + + +def test_override_missing_nonreentrant_key(): + code = """ +@nonreentrant +@external +def foo(): + pass + """ + + storage_layout_override = {} + + if version_check(begin="cancun"): + assert ( + compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) + is not None + ) + # in cancun, nonreentrant key is allocated in transient storage and can't be overridden + return + else: + exception_regex = re.escape( + "Could not find storage slot for $.nonreentrant_key." + " Have you used the correct storage layout file?" + ) + with pytest.raises(StorageLayoutException, match=exception_regex): + compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) def test_incomplete_overrides(): @@ -139,9 +187,225 @@ def test_incomplete_overrides(): with pytest.raises( StorageLayoutException, - match="Could not find storage_slot for symbol. " + match="Could not find storage slot for symbol. " "Have you used the correct storage layout file?", ): compile_code( code, output_formats=["layout"], storage_layout_override=storage_layout_override ) + + +@pytest.mark.requires_evm_version("cancun") +def test_override_with_immutables_and_transient(): + code = """ +some_local: transient(uint256) +some_immutable: immutable(uint256) +name: public(String[64]) + +@deploy +def __init__(): + some_immutable = 5 + """ + + storage_layout_override = {"name": {"slot": 10, "type": "String[64]", "n_slots": 3}} + + out = compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) + + expected_output = { + "storage_layout": storage_layout_override, + "transient_storage_layout": {"some_local": {"slot": 1, "type": "uint256", "n_slots": 1}}, + "code_layout": {"some_immutable": {"offset": 0, "type": "uint256", "length": 32}}, + } + + assert out["layout"] == expected_output + + +def test_override_modules(make_input_bundle): + # test module storage layout, with initializes in an imported module + # note code repetition with test_storage_layout.py; maybe refactor to + # some fixtures + lib1 = """ +supply: uint256 +SYMBOL: immutable(String[32]) +DECIMALS: immutable(uint8) + +@deploy +def __init__(): + SYMBOL = "VYPR" + DECIMALS = 18 + """ + lib2 = """ +import lib1 + +initializes: lib1 + +counter: uint256 +storage_variable: uint256 +immutable_variable: immutable(uint256) + +@deploy +def __init__(s: uint256): + immutable_variable = s + lib1.__init__() + +@internal +def decimals() -> uint8: + return lib1.DECIMALS + """ + code = """ +import lib1 as a_library +import lib2 + +counter: uint256 # test shadowing +some_immutable: immutable(DynArray[uint256, 10]) + +# for fun: initialize lib2 in front of lib1 +initializes: lib2 + +counter2: uint256 + +uses: a_library + +@deploy +def __init__(): + some_immutable = [1, 2, 3] + + lib2.__init__(17) + +@external +def foo() -> uint256: + return a_library.supply + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + override = { + "counter": {"slot": 5, "type": "uint256", "n_slots": 1}, + "lib2": { + "lib1": {"supply": {"slot": 12, "type": "uint256", "n_slots": 1}}, + "storage_variable": {"slot": 34, "type": "uint256", "n_slots": 1}, + "counter": {"slot": 15, "type": "uint256", "n_slots": 1}, + }, + "counter2": {"slot": 171, "type": "uint256", "n_slots": 1}, + } + out = compile_code( + code, output_formats=["layout"], input_bundle=input_bundle, storage_layout_override=override + ) + + expected_output = { + "storage_layout": override, + "code_layout": { + "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, + "lib2": { + "lib1": { + "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, + }, + "immutable_variable": {"length": 32, "offset": 448, "type": "uint256"}, + }, + }, + } + + assert out["layout"] == expected_output + + +def test_module_collision(make_input_bundle): + # test collisions between modules which are "siblings" in the import tree + # some fixtures + lib1 = """ +supply: uint256 + """ + lib2 = """ +counter: uint256 + """ + code = """ +import lib1 as a_library +import lib2 + +# for fun: initialize lib2 in front of lib1 +initializes: lib2 +initializes: a_library + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + override = { + "lib2": {"counter": {"slot": 15, "type": "uint256", "n_slots": 1}}, + "a_library": {"supply": {"slot": 15, "type": "uint256", "n_slots": 1}}, + } + + with pytest.raises( + StorageLayoutException, + match="Storage collision! Tried to assign 'a_library.supply' to" + " slot 15 but it has already been reserved by 'lib2.counter'", + ): + compile_code( + code, + output_formats=["layout"], + input_bundle=input_bundle, + storage_layout_override=override, + ) + + +def test_module_collision2(make_input_bundle): + # test "parent-child" collisions + lib1 = """ +supply: uint256 + """ + code = """ +import lib1 + +counter: uint256 + +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + override = { + "counter": {"slot": 15, "type": "uint256", "n_slots": 1}, + "lib1": {"supply": {"slot": 15, "type": "uint256", "n_slots": 1}}, + } + + with pytest.raises( + StorageLayoutException, + match="Storage collision! Tried to assign 'lib1.supply' to" + " slot 15 but it has already been reserved by 'counter'", + ): + compile_code( + code, + output_formats=["layout"], + input_bundle=input_bundle, + storage_layout_override=override, + ) + + +def test_module_overlap(make_input_bundle): + # test a collision which only overlaps on one word + lib1 = """ +supply: uint256[2] + """ + code = """ +import lib1 + +counter: uint256[2] + +initializes: lib1 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + override = { + "counter": {"slot": 15, "type": "uint256[2]", "n_slots": 2}, + "lib1": {"supply": {"slot": 16, "type": "uint256[2]", "n_slots": 2}}, + } + + with pytest.raises( + StorageLayoutException, + match="Storage collision! Tried to assign 'lib1.supply' to" + " slot 16 but it has already been reserved by 'counter'", + ): + compile_code( + code, + output_formats=["layout"], + input_bundle=input_bundle, + storage_layout_override=override, + ) diff --git a/tests/unit/cli/storage_layout/utils.py b/tests/unit/cli/storage_layout/utils.py new file mode 100644 index 0000000000..6e67886b0d --- /dev/null +++ b/tests/unit/cli/storage_layout/utils.py @@ -0,0 +1,17 @@ +from vyper.evm.opcodes import version_check + + +def adjust_storage_layout_for_cancun(layout): + def _go(layout): + for _varname, item in layout.items(): + if "slot" in item and isinstance(item["slot"], int): + item["slot"] -= 1 + else: + # recurse to submodule + _go(item) + + if version_check(begin="cancun"): + nonreentrant = layout["storage_layout"].pop("$.nonreentrant_key", None) + if nonreentrant is not None: + layout["transient_storage_layout"] = {"$.nonreentrant_key": nonreentrant} + _go(layout["storage_layout"]) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 10b4833e67..6f437395c6 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -12,6 +12,7 @@ from vyper.compiler.settings import OptimizationLevel, Settings, anchor_settings, merge_settings from vyper.ir import compile_ir, optimizer from vyper.semantics import analyze_module, set_data_positions, validate_compilation_target +from vyper.semantics.analysis.data_positions import generate_layout_export from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout @@ -180,7 +181,9 @@ def compilation_target(self): @cached_property def storage_layout(self) -> StorageLayout: module_ast = self.compilation_target - return set_data_positions(module_ast, self.storage_layout_override) + set_data_positions(module_ast, self.storage_layout_override) + + return generate_layout_export(module_ast) @property def global_ctx(self) -> ModuleT: diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 718581c20c..026e0626e7 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -194,7 +194,7 @@ def getter_ast(self) -> Optional[vy_ast.VyperNode]: def set_position(self, position: VarOffset) -> None: if self.position is not None: - raise CompilerPanic("Position was already assigned") + raise CompilerPanic(f"Position was already assigned: {self}") assert isinstance(position, VarOffset) # sanity check self.position = position @@ -207,6 +207,10 @@ def is_state_variable(self): def get_size(self) -> int: return self.typ.get_size_in(self.location) + @property + def is_storage(self): + return self.location == DataLocation.STORAGE + @property def is_transient(self): return self.location == DataLocation.TRANSIENT diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index e5e8b998ca..5f6702668f 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -1,5 +1,6 @@ +import json from collections import defaultdict -from typing import Generic, TypeVar +from typing import Generic, Optional, TypeVar from vyper import ast as vy_ast from vyper.evm.opcodes import version_check @@ -11,7 +12,7 @@ def set_data_positions( vyper_module: vy_ast.Module, storage_layout_overrides: StorageLayout = None -) -> StorageLayout: +) -> None: """ Parse the annotated Vyper AST, determine data positions for all variables, and annotate the AST nodes with the position data. @@ -22,14 +23,19 @@ def set_data_positions( Top-level Vyper AST node that has already been annotated with type data. """ if storage_layout_overrides is not None: - # extract code layout with no overrides - code_offsets = _allocate_layout_r(vyper_module, immutables_only=True)["code_layout"] - storage_slots = set_storage_slots_with_overrides(vyper_module, storage_layout_overrides) - return {"storage_layout": storage_slots, "code_layout": code_offsets} + # allocate code layout with no overrides + _allocate_layout_r(vyper_module, no_storage=True) + _allocate_with_overrides(vyper_module, storage_layout_overrides) - ret = _allocate_layout_r(vyper_module) - assert isinstance(ret, defaultdict) - return dict(ret) # convert back to dict + # sanity check that generated layout file is the same as the input. + roundtrip = generate_layout_export(vyper_module).get(_LAYOUT_KEYS[DataLocation.STORAGE], {}) + if roundtrip != storage_layout_overrides: + msg = "Computed storage layout does not match override file!\n" + msg += f"expected: {json.dumps(storage_layout_overrides)}\n\n" + msg += f"got:\n{json.dumps(roundtrip)}" + raise CompilerPanic(msg) + else: + _allocate_layout_r(vyper_module) _T = TypeVar("_T") @@ -45,6 +51,7 @@ def __setitem__(self, k, v): # some name that the user cannot assign to a variable GLOBAL_NONREENTRANT_KEY = "$.nonreentrant_key" +NONREENTRANT_KEY_SIZE = 1 class SimpleAllocator: @@ -55,7 +62,7 @@ def __init__(self, max_slot: int = 2**256, starting_slot: int = 0): self._slot = starting_slot self._max_slot = max_slot - def allocate_slot(self, n, var_name, node=None): + def allocate_slot(self, n, node=None): ret = self._slot if self._slot + n >= self._max_slot: raise StorageLayoutException( @@ -67,7 +74,7 @@ def allocate_slot(self, n, var_name, node=None): return ret def allocate_global_nonreentrancy_slot(self): - slot = self.allocate_slot(1, GLOBAL_NONREENTRANT_KEY) + slot = self.allocate_slot(NONREENTRANT_KEY_SIZE) assert slot == self._starting_slot return slot @@ -141,74 +148,105 @@ def _reserve_slot(self, slot: int, var_name: str) -> None: self.occupied_slots[slot] = var_name -def set_storage_slots_with_overrides( - vyper_module: vy_ast.Module, storage_layout_overrides: StorageLayout -) -> StorageLayout: +def _fetch_path(path: list[str], layout: StorageLayout, node: vy_ast.VyperNode): + tmp = layout + qualified_path = ".".join(path) + + for segment in path: + if segment not in tmp: + raise StorageLayoutException( + f"Could not find storage slot for {qualified_path}. " + "Have you used the correct storage layout file?", + node, + ) + tmp = tmp[segment] + + try: + ret = tmp["slot"] + except KeyError as e: + raise StorageLayoutException(f"no storage slot for {qualified_path}", node) from e + + return ret + + +def _allocate_with_overrides(vyper_module: vy_ast.Module, layout: StorageLayout): """ Set storage layout given a layout override file. - Returns the layout as a dict of variable name -> variable info - (Doesn't handle modules, or transient storage) """ - ret: InsertableOnceDict[str, dict] = InsertableOnceDict() - reserved_slots = OverridingStorageAllocator() + allocator = OverridingStorageAllocator() + + nonreentrant_slot = None + if GLOBAL_NONREENTRANT_KEY in layout: + nonreentrant_slot = layout[GLOBAL_NONREENTRANT_KEY]["slot"] + + _allocate_with_overrides_r(vyper_module, layout, allocator, nonreentrant_slot, []) + +def _allocate_with_overrides_r( + vyper_module: vy_ast.Module, + layout: StorageLayout, + allocator: OverridingStorageAllocator, + global_nonreentrant_slot: Optional[int], + path: list[str], +): # Search through function definitions to find non-reentrant functions for node in vyper_module.get_children(vy_ast.FunctionDef): - type_ = node._metadata["func_type"] + fn_t = node._metadata["func_type"] # Ignore functions without non-reentrant - if not type_.nonreentrant: + if not fn_t.nonreentrant: continue - variable_name = GLOBAL_NONREENTRANT_KEY - - # re-entrant key was already identified - if variable_name in ret: + # if reentrancy keys get allocated in transient storage, we don't + # override them + if get_reentrancy_key_location() == DataLocation.TRANSIENT: continue # Expect to find this variable within the storage layout override - if variable_name in storage_layout_overrides: - reentrant_slot = storage_layout_overrides[variable_name]["slot"] - # Ensure that this slot has not been used, and prevents other storage variables - # from using the same slot - reserved_slots.reserve_slot_range(reentrant_slot, 1, variable_name) - - type_.set_reentrancy_key_position(VarOffset(reentrant_slot)) - - ret[variable_name] = {"type": "nonreentrant lock", "slot": reentrant_slot} - else: + if global_nonreentrant_slot is None: raise StorageLayoutException( - f"Could not find storage_slot for {variable_name}. " + f"Could not find storage slot for {GLOBAL_NONREENTRANT_KEY}. " "Have you used the correct storage layout file?", node, ) - # Iterate through variables - for node in vyper_module.get_children(vy_ast.VariableDecl): - # Ignore immutable parameters - if node.get("annotation.func.id") == "immutable": + # prevent other storage variables from using the same slot + if allocator.occupied_slots.get(global_nonreentrant_slot) != GLOBAL_NONREENTRANT_KEY: + allocator.reserve_slot_range( + global_nonreentrant_slot, NONREENTRANT_KEY_SIZE, GLOBAL_NONREENTRANT_KEY + ) + + fn_t.set_reentrancy_key_position(VarOffset(global_nonreentrant_slot)) + + for node in _get_allocatable(vyper_module): + if isinstance(node, vy_ast.InitializesDecl): + module_info = node._metadata["initializes_info"].module_info + + sub_path = [*path, module_info.alias] + _allocate_with_overrides_r( + module_info.module_node, layout, allocator, global_nonreentrant_slot, sub_path + ) continue + # Iterate through variables + # Ignore immutables and transient variables varinfo = node.target._metadata["varinfo"] + if not varinfo.is_storage: + continue + # Expect to find this variable within the storage layout overrides - if node.target.id in storage_layout_overrides: - var_slot = storage_layout_overrides[node.target.id]["slot"] - storage_length = varinfo.typ.storage_size_in_words - # Ensure that all required storage slots are reserved, and prevents other variables - # from using these slots - reserved_slots.reserve_slot_range(var_slot, storage_length, node.target.id) - varinfo.set_position(VarOffset(var_slot)) - - ret[node.target.id] = {"type": str(varinfo.typ), "slot": var_slot} - else: - raise StorageLayoutException( - f"Could not find storage_slot for {node.target.id}. " - "Have you used the correct storage layout file?", - node, - ) + varname = node.target.id + varpath = [*path, varname] + qualified_varname = ".".join(varpath) - return ret + var_slot = _fetch_path(varpath, layout, node) + + storage_length = varinfo.typ.storage_size_in_words + # Ensure that all required storage slots are reserved, and + # prevent other variables from using these slots + allocator.reserve_slot_range(var_slot, storage_length, qualified_varname) + varinfo.set_position(VarOffset(var_slot)) def _get_allocatable(vyper_module: vy_ast.Module) -> list[vy_ast.VyperNode]: @@ -229,7 +267,7 @@ def get_reentrancy_key_location() -> DataLocation: } -def _allocate_nonreentrant_keys(vyper_module, allocators): +def _set_nonreentrant_keys(vyper_module, allocators): SLOT = allocators.get_global_nonreentrant_key_slot() for node in vyper_module.get_children(vy_ast.FunctionDef): @@ -244,73 +282,116 @@ def _allocate_nonreentrant_keys(vyper_module, allocators): def _allocate_layout_r( - vyper_module: vy_ast.Module, allocators: Allocators = None, immutables_only=False -) -> StorageLayout: + vyper_module: vy_ast.Module, allocators: Allocators = None, no_storage=False +): """ Parse module-level Vyper AST to calculate the layout of storage variables. Returns the layout as a dict of variable name -> variable info """ - global_ = False if allocators is None: - global_ = True allocators = Allocators() # always allocate nonreentrancy slot, so that adding or removing # reentrancy protection from a contract does not change its layout allocators.allocate_global_nonreentrancy_slot() - ret: defaultdict[str, InsertableOnceDict[str, dict]] = defaultdict(InsertableOnceDict) - # tag functions with the global nonreentrant key - if not immutables_only: - _allocate_nonreentrant_keys(vyper_module, allocators) - - layout_key = _LAYOUT_KEYS[get_reentrancy_key_location()] - # TODO this could have better typing but leave it untyped until - # we nail down the format better - if global_ and GLOBAL_NONREENTRANT_KEY not in ret[layout_key]: - slot = allocators.get_global_nonreentrant_key_slot() - ret[layout_key][GLOBAL_NONREENTRANT_KEY] = {"type": "nonreentrant lock", "slot": slot} + if not no_storage or get_reentrancy_key_location() == DataLocation.TRANSIENT: + _set_nonreentrant_keys(vyper_module, allocators) for node in _get_allocatable(vyper_module): if isinstance(node, vy_ast.InitializesDecl): module_info = node._metadata["initializes_info"].module_info - module_layout = _allocate_layout_r(module_info.module_node, allocators) - module_alias = module_info.alias - for layout_key in module_layout.keys(): - assert layout_key in _LAYOUT_KEYS.values() - ret[layout_key][module_alias] = module_layout[layout_key] + _allocate_layout_r(module_info.module_node, allocators, no_storage) continue assert isinstance(node, vy_ast.VariableDecl) - # skip non-state variables varinfo = node.target._metadata["varinfo"] + + # skip things we don't need to allocate, like constants if not varinfo.is_state_variable(): continue - location = varinfo.location - if immutables_only and location != DataLocation.CODE: + if no_storage and varinfo.is_storage: continue - allocator = allocators.get_allocator(location) + allocator = allocators.get_allocator(varinfo.location) size = varinfo.get_size() # CMC 2021-07-23 note that HashMaps get assigned a slot here # using the same allocator (even though there is not really # any risk of physical overlap) - offset = allocator.allocate_slot(size, node.target.id, node) - + offset = allocator.allocate_slot(size, node) varinfo.set_position(VarOffset(offset)) + +# get the layout for export +def generate_layout_export(vyper_module: vy_ast.Module): + return _generate_layout_export_r(vyper_module) + + +def _generate_layout_export_r(vyper_module): + ret: defaultdict[str, InsertableOnceDict[str, dict]] = defaultdict(InsertableOnceDict) + + for node in _get_allocatable(vyper_module): + if isinstance(node, vy_ast.InitializesDecl): + module_info = node._metadata["initializes_info"].module_info + module_layout = _generate_layout_export_r(module_info.module_node) + module_alias = module_info.alias + for layout_key in module_layout.keys(): + assert layout_key in _LAYOUT_KEYS.values() + + # lift the nonreentrancy key (if any) into the outer dict + # note that lifting can leave the inner dict empty, which + # should be filtered (below) for cleanliness + nonreentrant = module_layout[layout_key].pop(GLOBAL_NONREENTRANT_KEY, None) + if nonreentrant is not None and GLOBAL_NONREENTRANT_KEY not in ret[layout_key]: + ret[layout_key][GLOBAL_NONREENTRANT_KEY] = nonreentrant + + # add the module as a nested dict, but only if it is non-empty + if len(module_layout[layout_key]) != 0: + ret[layout_key][module_alias] = module_layout[layout_key] + + continue + + assert isinstance(node, vy_ast.VariableDecl) + varinfo = node.target._metadata["varinfo"] + # skip non-state variables + if not varinfo.is_state_variable(): + continue + + location = varinfo.location layout_key = _LAYOUT_KEYS[location] type_ = varinfo.typ + size = varinfo.get_size() + offset = varinfo.position.position + # this could have better typing but leave it untyped until # we understand the use case better if location == DataLocation.CODE: item = {"type": str(type_), "length": size, "offset": offset} elif location in (DataLocation.STORAGE, DataLocation.TRANSIENT): - item = {"type": str(type_), "slot": offset} + item = {"type": str(type_), "n_slots": size, "slot": offset} else: # pragma: nocover raise CompilerPanic("unreachable") ret[layout_key][node.target.id] = item + for fn in vyper_module.get_children(vy_ast.FunctionDef): + fn_t = fn._metadata["func_type"] + if not fn_t.nonreentrant: + continue + + location = get_reentrancy_key_location() + layout_key = _LAYOUT_KEYS[location] + + if GLOBAL_NONREENTRANT_KEY in ret[layout_key]: + break + + slot = fn_t.reentrancy_key_position.position + ret[layout_key][GLOBAL_NONREENTRANT_KEY] = { + "type": "nonreentrant lock", + "slot": slot, + "n_slots": NONREENTRANT_KEY_SIZE, + } + break + return ret