From 13dab937a0622b22c09fee0717dc475913e697a2 Mon Sep 17 00:00:00 2001 From: Andrew Date: Mon, 30 May 2022 18:57:02 -0400 Subject: [PATCH 1/2] add reentrant call mock --- tests/mocks/account_reentrant_call_mock.cairo | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 tests/mocks/account_reentrant_call_mock.cairo diff --git a/tests/mocks/account_reentrant_call_mock.cairo b/tests/mocks/account_reentrant_call_mock.cairo new file mode 100644 index 000000000..26b99cd1e --- /dev/null +++ b/tests/mocks/account_reentrant_call_mock.cairo @@ -0,0 +1,54 @@ +%lang starknet + +from starkware.starknet.common.syscalls import call_contract, get_caller_address, get_tx_info, get_contract_address +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin + +from starkware.cairo.common.alloc import alloc + +const GET_NONCE = 756703644403488948674317127005533987569832834207225504298785384568821383277 +const EXECUTE = 617075754465154585683856897856256838130216341506379215893724690153393808813 +const SET_PUBLIC_KEY = 1307260637166823203998179679098545329314629630090003875272134084395659334905 + +@external +func account_takeover{ + syscall_ptr : felt*, + pedersen_ptr : HashBuiltin*, + range_check_ptr + }(): + alloc_locals + let (caller) = get_caller_address() + + let (empty_calldata: felt*) = alloc() + let res = call_contract( + contract_address=caller, + function_selector=GET_NONCE, # get_nonce + calldata_size=0, + calldata=empty_calldata, + ) + let nonce = res.retdata[0] + + let (call_calldata: felt*) = alloc() + + # call_array + assert call_calldata[0] = 1 + assert call_calldata[1] = caller + assert call_calldata[2] = SET_PUBLIC_KEY + assert call_calldata[3] = 0 + assert call_calldata[4] = 1 + + # calldata + assert call_calldata[5] = 1 + assert call_calldata[6] = 123 # new public key + + # nonce + assert call_calldata[7] = nonce + + call_contract( + contract_address=caller, + function_selector=EXECUTE, + calldata_size=8, + calldata=call_calldata, + ) + + return () +end From 6ee70917aed6d67cef6be5fb1530499b824648db Mon Sep 17 00:00:00 2001 From: Andrew Date: Mon, 30 May 2022 19:52:20 -0400 Subject: [PATCH 2/2] add account reentrancy test --- src/openzeppelin/account/library.cairo | 5 + tests/account/test_Account.py | 97 +++++++++++++------ ...ll_mock.cairo => account_reentrancy.cairo} | 0 3 files changed, 71 insertions(+), 31 deletions(-) rename tests/mocks/{account_reentrant_call_mock.cairo => account_reentrancy.cairo} (100%) diff --git a/src/openzeppelin/account/library.cairo b/src/openzeppelin/account/library.cairo index 0fd2fa030..6f5c27f3d 100644 --- a/src/openzeppelin/account/library.cairo +++ b/src/openzeppelin/account/library.cairo @@ -155,6 +155,11 @@ namespace Account: ) -> (response_len: felt, response: felt*): alloc_locals + let (caller) = get_caller_address() + with_attr error_message("Account: no reentrant call"): + assert caller = 0 + end + let (__fp__, _) = get_fp_and_pc() let (tx_info) = get_tx_info() let (_current_nonce) = Account_current_nonce.read() diff --git a/tests/account/test_Account.py b/tests/account/test_Account.py index 0a066a3a1..2c65a8499 100644 --- a/tests/account/test_Account.py +++ b/tests/account/test_Account.py @@ -2,34 +2,70 @@ from starkware.starknet.testing.starknet import Starknet from starkware.starkware_utils.error_handling import StarkException from starkware.starknet.definitions.error_codes import StarknetErrorCode -from utils import TestSigner, assert_revert, contract_path +from utils import TestSigner, assert_revert, get_contract_def, cached_contract, TRUE signer = TestSigner(123456789987654321) other = TestSigner(987654321123456789) IACCOUNT_ID = 0xf10dbd44 -TRUE = 1 @pytest.fixture(scope='module') -async def account_factory(): +def contract_defs(): + account_def = get_contract_def('openzeppelin/account/Account.cairo') + init_def = get_contract_def("tests/mocks/Initializable.cairo") + attacker_def = get_contract_def("tests/mocks/account_reentrancy.cairo") + + return account_def, init_def, attacker_def + + +@pytest.fixture(scope='module') +async def account_init(contract_defs): + account_def, init_def, attacker_def = contract_defs starknet = await Starknet.empty() - account = await starknet.deploy( - contract_path("openzeppelin/account/Account.cairo"), + + account1 = await starknet.deploy( + contract_def=account_def, + constructor_calldata=[signer.public_key] + ) + account2 = await starknet.deploy( + contract_def=account_def, constructor_calldata=[signer.public_key] ) - bad_account = await starknet.deploy( - contract_path("openzeppelin/account/Account.cairo"), - constructor_calldata=[signer.public_key], + initializable1 = await starknet.deploy( + contract_def=init_def, + constructor_calldata=[], ) + initializable2 = await starknet.deploy( + contract_def=init_def, + constructor_calldata=[], + ) + attacker = await starknet.deploy( + contract_def=attacker_def, + constructor_calldata=[], + ) + + return starknet.state, account1, account2, initializable1, initializable2, attacker - return starknet, account, bad_account + +@pytest.fixture +def account_factory(contract_defs, account_init): + account_def, init_def, attacker_def = contract_defs + state, account1, account2, initializable1, initializable2, attacker = account_init + _state = state.copy() + account1 = cached_contract(_state, account_def, account1) + account2 = cached_contract(_state, account_def, account2) + initializable1 = cached_contract(_state, init_def, initializable1) + initializable2 = cached_contract(_state, init_def, initializable2) + attacker = cached_contract(_state, attacker_def, attacker) + + return account1, account2, initializable1, initializable2, attacker @pytest.mark.asyncio async def test_constructor(account_factory): - _, account, _ = account_factory + account, *_ = account_factory execution_info = await account.get_public_key().call() assert execution_info.result == (signer.public_key,) @@ -40,10 +76,7 @@ async def test_constructor(account_factory): @pytest.mark.asyncio async def test_execute(account_factory): - starknet, account, _ = account_factory - initializable = await starknet.deploy( - contract_path("tests/mocks/Initializable.cairo") - ) + account, _, initializable, *_ = account_factory execution_info = await initializable.initialized().call() assert execution_info.result == (0,) @@ -56,13 +89,7 @@ async def test_execute(account_factory): @pytest.mark.asyncio async def test_multicall(account_factory): - starknet, account, _ = account_factory - initializable_1 = await starknet.deploy( - contract_path("tests/mocks/Initializable.cairo") - ) - initializable_2 = await starknet.deploy( - contract_path("tests/mocks/Initializable.cairo") - ) + account, _, initializable_1, initializable_2, _ = account_factory execution_info = await initializable_1.initialized().call() assert execution_info.result == (0,) @@ -85,10 +112,7 @@ async def test_multicall(account_factory): @pytest.mark.asyncio async def test_return_value(account_factory): - starknet, account, _ = account_factory - initializable = await starknet.deploy( - contract_path("tests/mocks/Initializable.cairo") - ) + account, _, initializable, *_ = account_factory # initialize, set `initialized = 1` await signer.send_transactions(account, [(initializable.contract_address, 'initialize', [])]) @@ -101,10 +125,8 @@ async def test_return_value(account_factory): @ pytest.mark.asyncio async def test_nonce(account_factory): - starknet, account, _ = account_factory - initializable = await starknet.deploy( - contract_path("tests/mocks/Initializable.cairo") - ) + account, _, initializable, *_ = account_factory + execution_info = await account.get_nonce().call() current_nonce = execution_info.result.res @@ -132,7 +154,7 @@ async def test_nonce(account_factory): @pytest.mark.asyncio async def test_public_key_setter(account_factory): - _, account, _ = account_factory + account, *_ = account_factory execution_info = await account.get_public_key().call() assert execution_info.result == (signer.public_key,) @@ -146,7 +168,7 @@ async def test_public_key_setter(account_factory): @pytest.mark.asyncio async def test_public_key_setter_different_account(account_factory): - _, account, bad_account = account_factory + account, bad_account, *_ = account_factory # set new pubkey await assert_revert( @@ -156,3 +178,16 @@ async def test_public_key_setter_different_account(account_factory): ), reverted_with="Account: caller is not this account" ) + + +@pytest.mark.asyncio +async def test_account_takeover_with_reentrant_call(account_factory): + account, _, _, _, attacker = account_factory + + await assert_revert( + signer.send_transaction(account, attacker.contract_address, 'account_takeover', []), + reverted_with="Account: no reentrant call" + ) + + execution_info = await account.get_public_key().call() + assert execution_info.result == (signer.public_key,) diff --git a/tests/mocks/account_reentrant_call_mock.cairo b/tests/mocks/account_reentrancy.cairo similarity index 100% rename from tests/mocks/account_reentrant_call_mock.cairo rename to tests/mocks/account_reentrancy.cairo