Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: storage corruption from re-entrancy locks #2379

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions tests/functional/test_storage_slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ def __init__():
]
self.foo[0] = [987, 654, 321]
self.foo[1] = [123, 456, 789]

@external
@nonreentrant('lock')
def with_lock():
pass


@external
@nonreentrant('otherlock')
def with_other_lock():
pass
"""


Expand All @@ -57,3 +68,27 @@ def test_storage_slots(get_contract):
]
assert [c.foo(0, i) for i in range(3)] == [987, 654, 321]
assert [c.foo(1, i) for i in range(3)] == [123, 456, 789]


def test_reentrancy_lock(get_contract):
c = get_contract(code)

# if re-entrancy locks are incorrectly placed within storage, these
# calls will either revert or correupt the data that we read later
c.with_lock()
c.with_other_lock()

assert c.a() == ["ok", [4, 5, 6]]
assert [c.b(i) for i in range(2)] == [7, 8]
assert c.c() == b"thisisthirtytwobytesokhowdoyoudo"
assert [c.d(i) for i in range(4)] == [-1, -2, -3, -4]
assert c.e() == "A realllllly long string but we wont use it all"
assert c.f(0) == 33
assert c.g(0) == [b"hello", [-66, 420], "another string"]
assert c.g(1) == [
b"gbye",
[1337, 888],
"whatifthisstringtakesuptheentirelengthwouldthatbesobadidothinkso",
]
assert [c.foo(0, i) for i in range(3)] == [987, 654, 321]
assert [c.foo(1, i) for i in range(3)] == [123, 456, 789]
24 changes: 14 additions & 10 deletions vyper/parser/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from vyper import ast as vy_ast
from vyper.exceptions import CompilerPanic, InvalidType, StructureException
from vyper.signatures.function_signature import ContractRecord, VariableRecord
from vyper.types import InterfaceType, parse_type
from vyper.types import InterfaceType, MappingType, parse_type
from vyper.typing import InterfaceImports


Expand Down Expand Up @@ -185,6 +185,9 @@ def add_globals_and_events(self, item):

item_name, item_attributes = self.get_item_name_and_attributes(item, item_attributes)

# references to `len(self._globals)` are remnants of deprecated code, retained
# to preserve existing interfaces while we complete a larger refactor. location
# and size of storage vars is handled in `vyper.context.validation.data_positions`
if item_name in self._contracts or item_name in self._interfaces:
if self.get_call_func_name(item) == "address":
raise StructureException(
Expand Down Expand Up @@ -221,19 +224,20 @@ def get_nonrentrant_counter(self, key):
"""
Nonrentrant locks use a prefix with a counter to minimise deployment cost of a contract.

All storage variables are allocated exactly one slot, incrementing from 0.
For types that require >1 slot the actual location is determined from a keccak
https://github.com/vyperlang/vyper/issues/769

We're able to set the initial re-entrant counter using `len(self._globals)`
because all storage slots are allocated while parsing the module-scope,
and re-entrancy locks aren't allocated until later when parsing individual
function scopes.
We're able to set the initial re-entrant counter using the sum of the sizes
of all the storage slots because all storage slots are allocated while parsing
the module-scope, and re-entrancy locks aren't allocated until later when parsing
individual function scopes. This relies on the deprecated _globals attribute
because the new way of doing things (set_data_positions) doesn't expose the
next unallocated storage location.
"""
if key in self._nonrentrant_keys:
return self._nonrentrant_keys[key]
else:
counter = len(self._globals) + self._nonrentrant_counter
counter = (
sum(v.size for v in self._globals.values() if not isinstance(v.typ, MappingType))
iamdefinitelyahuman marked this conversation as resolved.
Show resolved Hide resolved
+ self._nonrentrant_counter
)
self._nonrentrant_keys[key] = counter
self._nonrentrant_counter += 1
return counter