Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LocalProvider snapshots #61

Merged
merged 2 commits into from
Dec 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ Testing utilities

.. autoclass:: LocalProvider
:show-inheritance:
:members: disable_auto_mine_transactions, enable_auto_mine_transactions
:members: disable_auto_mine_transactions, enable_auto_mine_transactions, take_snapshot, revert_to_snapshot

.. autoclass:: HTTPProviderServer
:members:
Expand Down
2 changes: 2 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Added
- Various methods that had a default ``Amount(0)`` for a parameter can now take ``None``. (PR_57_)
- Support for overloaded methods via ``MultiMethod``. (PR_59_)
- Expose ``HTTPProviderServer``, ``LocalProvider``, ``compile_contract_file`` that can be used for tests of Ethereum-using applications. These are gated behind optional features. (PR_54_)
- ``LocalProvider.take_snapshot()`` and ``revert_to_snapshot()``. (PR_61_)


Fixed
Expand All @@ -35,6 +36,7 @@ Fixed
.. _PR_56: https://github.com/fjarri/pons/pull/56
.. _PR_57: https://github.com/fjarri/pons/pull/57
.. _PR_59: https://github.com/fjarri/pons/pull/59
.. _PR_61: https://github.com/fjarri/pons/pull/61


0.7.0 (09-07-2023)
Expand Down
6 changes: 3 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ A quick usage example:
global root_signer
global http_provider

test_provider = LocalProvider(root_balance=Amount.ether(100))
root_signer = test_provider.root
local_provider = LocalProvider(root_balance=Amount.ether(100))
root_signer = local_provider.root

async with trio.open_nursery() as nursery:
handle = HTTPProviderServer(test_provider)
handle = HTTPProviderServer(local_provider)
http_provider = handle.http_provider
await nursery.start(handle)
await func()
Expand Down
3 changes: 2 additions & 1 deletion pons/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
PriorityFallback,
)
from ._http_provider_server import HTTPProviderServer
from ._local_provider import LocalProvider
from ._local_provider import LocalProvider, SnapshotID
from ._provider import (
JSON,
HTTPProvider,
Expand Down Expand Up @@ -107,6 +107,7 @@
"RPCErrorCode",
"HTTPProviderServer",
"Signer",
"SnapshotID",
"TransactionFailed",
"TxHash",
"Unreachable",
Expand Down
15 changes: 15 additions & 0 deletions pons/_local_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ def normalize_return_value(value: Normalizable) -> JSON:
return value


class SnapshotID:
"""An ID of a snapshot in a :py:class:`LocalProvider`."""

def __init__(self, id_: int):
self.id_ = id_


class LocalProvider(Provider):
"""A provider maintaining its own chain state, useful for tests."""

Expand Down Expand Up @@ -137,6 +144,14 @@ def enable_auto_mine_transactions(self) -> None:
"""
self._ethereum_tester.enable_auto_mine_transactions()

def take_snapshot(self) -> SnapshotID:
"""Creates a snapshot of the chain state internally and returns its ID."""
return SnapshotID(self._ethereum_tester.take_snapshot())

def revert_to_snapshot(self, snapshot_id: SnapshotID) -> None:
"""Restores the chain state to the snapshot with the given ID."""
self._ethereum_tester.revert_to_snapshot(snapshot_id.id_)

def rpc(self, method: str, *args: Any) -> JSON:
dispatch = dict(
net_version=self.net_version,
Expand Down
File renamed without changes.
10 changes: 5 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@


@pytest.fixture
def test_provider():
def local_provider():
return LocalProvider(root_balance=Amount.ether(100))


@pytest.fixture
async def session(test_provider):
client = Client(provider=test_provider)
async def session(local_provider):
client = Client(provider=local_provider)
async with client.session() as session:
yield session


@pytest.fixture
def root_signer(test_provider):
return test_provider.root
def root_signer(local_provider):
return local_provider.root


@pytest.fixture
Expand Down
62 changes: 31 additions & 31 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def normalize_topics(topics):
return tuple((elem,) for elem in topics)


async def test_net_version(test_provider, session):
async def test_net_version(local_provider, session):
net_version1 = await session.net_version()
assert net_version1 == "0"

Expand All @@ -63,19 +63,19 @@ def wrong_net_version():
raise NotImplementedError # pragma: no cover

# The result should have been cached the first time
with monkeypatched(test_provider, "net_version", wrong_net_version):
with monkeypatched(local_provider, "net_version", wrong_net_version):
net_version2 = await session.net_version()
assert net_version1 == net_version2


async def test_net_version_type_check(test_provider, session):
async def test_net_version_type_check(local_provider, session):
# Provider returning a bad value
with monkeypatched(test_provider, "net_version", lambda: 0):
with monkeypatched(local_provider, "net_version", lambda: 0):
with pytest.raises(BadResponseFormat, match="net_version: expected a string result"):
await session.net_version()


async def test_eth_chain_id(test_provider, session):
async def test_eth_chain_id(local_provider, session):
chain_id1 = await session.eth_chain_id()
assert chain_id1 == 2299111 * 57099167

Expand All @@ -84,7 +84,7 @@ def wrong_chain_id():
raise NotImplementedError # pragma: no cover

# The result should have been cached the first time
with monkeypatched(test_provider, "eth_chain_id", wrong_chain_id):
with monkeypatched(local_provider, "eth_chain_id", wrong_chain_id):
chain_id2 = await session.eth_chain_id()
assert chain_id1 == chain_id2

Expand All @@ -101,15 +101,15 @@ async def test_eth_get_balance(session, root_signer, another_signer):
assert balance == Amount.ether(0)


async def test_eth_get_transaction_receipt(test_provider, session, root_signer, another_signer):
test_provider.disable_auto_mine_transactions()
async def test_eth_get_transaction_receipt(local_provider, session, root_signer, another_signer):
local_provider.disable_auto_mine_transactions()
tx_hash = await session.broadcast_transfer(
root_signer, another_signer.address, Amount.ether(10)
)
receipt = await session.eth_get_transaction_receipt(tx_hash)
assert receipt is None

test_provider.enable_auto_mine_transactions()
local_provider.enable_auto_mine_transactions()
receipt = await session.eth_get_transaction_receipt(tx_hash)
assert receipt.succeeded

Expand All @@ -125,10 +125,10 @@ async def test_eth_get_transaction_count(session, root_signer, another_signer):


async def test_wait_for_transaction_receipt(
test_provider, session, root_signer, another_signer, autojump_clock
local_provider, session, root_signer, another_signer, autojump_clock
):
to_transfer = Amount.ether(10)
test_provider.disable_auto_mine_transactions()
local_provider.disable_auto_mine_transactions()
tx_hash = await session.broadcast_transfer(root_signer, another_signer.address, to_transfer)

# The receipt won't be available until we mine, so the waiting should time out
Expand All @@ -149,7 +149,7 @@ async def get_receipt():

async def delayed_enable_mining():
await trio.sleep(timeout)
test_provider.enable_auto_mine_transactions()
local_provider.enable_auto_mine_transactions()

async with trio.open_nursery() as nursery:
nursery.start_soon(get_receipt)
Expand Down Expand Up @@ -269,23 +269,23 @@ async def test_transfer_custom_gas(session, root_signer, another_signer):
await session.transfer(root_signer, another_signer.address, to_transfer, gas=20000)


async def test_transfer_failed(test_provider, session, root_signer, another_signer):
async def test_transfer_failed(local_provider, session, root_signer, another_signer):
# TODO: it would be nice to reproduce the actual situation where this could happen
# (tranfer was accepted for mining, but failed in the process,
# and the resulting receipt has a 0 status).
orig_get_transaction_receipt = test_provider.eth_get_transaction_receipt
orig_get_transaction_receipt = local_provider.eth_get_transaction_receipt

def mock_get_transaction_receipt(tx_hash_hex):
receipt = orig_get_transaction_receipt(tx_hash_hex)
receipt["status"] = "0x0"
return receipt

with monkeypatched(test_provider, "eth_get_transaction_receipt", mock_get_transaction_receipt):
with monkeypatched(local_provider, "eth_get_transaction_receipt", mock_get_transaction_receipt):
with pytest.raises(TransactionFailed, match="Transfer failed"):
await session.transfer(root_signer, another_signer.address, Amount.ether(10))


async def test_deploy(test_provider, session, compiled_contracts, root_signer):
async def test_deploy(local_provider, session, compiled_contracts, root_signer):
basic_contract = compiled_contracts["BasicContract"]
construction_error = compiled_contracts["TestErrors"]
payable_constructor = compiled_contracts["PayableConstructor"]
Expand Down Expand Up @@ -314,14 +314,14 @@ async def test_deploy(test_provider, session, compiled_contracts, root_signer):
await session.deploy(root_signer, construction_error.constructor(0), gas=300000)

# Test the provider returning an empty `contractAddress`
orig_get_transaction_receipt = test_provider.eth_get_transaction_receipt
orig_get_transaction_receipt = local_provider.eth_get_transaction_receipt

def mock_get_transaction_receipt(tx_hash_hex):
receipt = orig_get_transaction_receipt(tx_hash_hex)
receipt["contractAddress"] = None
return receipt

with monkeypatched(test_provider, "eth_get_transaction_receipt", mock_get_transaction_receipt):
with monkeypatched(local_provider, "eth_get_transaction_receipt", mock_get_transaction_receipt):
with pytest.raises(
BadResponseFormat,
match=(
Expand Down Expand Up @@ -365,7 +365,7 @@ async def test_transact(session, compiled_contracts, root_signer):


async def test_transact_and_return_events(
autojump_clock, test_provider, session, compiled_contracts, root_signer, another_signer
autojump_clock, local_provider, session, compiled_contracts, root_signer, another_signer
):
await session.transfer(root_signer, another_signer.address, Amount.ether(1))

Expand Down Expand Up @@ -395,7 +395,7 @@ def results_for(x):
# Two transactions for the same method in the same block -
# we need to be able to only pick up the results from the relevant transaction receipt

test_provider.disable_auto_mine_transactions()
local_provider.disable_auto_mine_transactions()

results = {}

Expand All @@ -407,7 +407,7 @@ async def transact(signer, x):

async def delayed_enable_mining():
await trio.sleep(5)
test_provider.enable_auto_mine_transactions()
local_provider.enable_auto_mine_transactions()

x1 = 1
x2 = 2
Expand Down Expand Up @@ -455,8 +455,8 @@ async def test_eth_get_transaction_by_hash(session, root_signer, another_signer)
assert tx_info is None


async def test_eth_get_filter_changes_bad_response(test_provider, session, monkeypatch):
monkeypatch.setattr(test_provider, "eth_get_filter_changes", lambda _filter_id: {"foo": 1})
async def test_eth_get_filter_changes_bad_response(local_provider, session, monkeypatch):
monkeypatch.setattr(local_provider, "eth_get_filter_changes", lambda _filter_id: {"foo": 1})

block_filter = await session.eth_new_block_filter()

Expand Down Expand Up @@ -492,12 +492,12 @@ async def test_block_filter(session, root_signer, another_signer):
assert len(block_hashes) == 0


async def test_pending_transaction_filter(test_provider, session, root_signer, another_signer):
async def test_pending_transaction_filter(local_provider, session, root_signer, another_signer):
transaction_filter = await session.eth_new_pending_transaction_filter()

to_transfer = Amount.ether(1)

test_provider.disable_auto_mine_transactions()
local_provider.disable_auto_mine_transactions()
tx_hash = await session.broadcast_transfer(root_signer, another_signer.address, to_transfer)
tx_hashes = await session.eth_get_filter_changes(transaction_filter)
assert tx_hashes == (tx_hash,)
Expand Down Expand Up @@ -751,13 +751,13 @@ async def observer():
assert events[2] == {"from_": another_signer.address, "id": b"1111", "value": 3, "value2": 4}


async def test_unknown_rpc_status_code(test_provider, session, monkeypatch):
async def test_unknown_rpc_status_code(local_provider, session, monkeypatch):
def faulty_net_version():
# This is a known exception type, and it will be transferred through the network
# keeping the status code.
raise RPCError(666, "this method is possessed")

monkeypatch.setattr(test_provider, "net_version", faulty_net_version)
monkeypatch.setattr(local_provider, "net_version", faulty_net_version)

with pytest.raises(ProviderError, match=r"Provider error \(666\): this method is possessed"):
await session.net_version()
Expand Down Expand Up @@ -870,7 +870,7 @@ async def test_contract_exceptions_high_level(session, root_signer, compiled_con
assert exc.value.data == {"x": 4}


async def test_unknown_error_reasons(test_provider, session, compiled_contracts, root_signer):
async def test_unknown_error_reasons(local_provider, session, compiled_contracts, root_signer):
compiled_contract = compiled_contracts["TestErrors"]
contract = await session.deploy(root_signer, compiled_contract.constructor(999))

Expand All @@ -881,7 +881,7 @@ def eth_estimate_gas(*_args, **_kwargs):
data = PANIC_ERROR.selector + encode_args((abi.uint(256), 888))
raise RPCError(RPCErrorCode.EXECUTION_ERROR, "execution reverted", rpc_encode_data(data))

with monkeypatched(test_provider, "eth_estimate_gas", eth_estimate_gas):
with monkeypatched(local_provider, "eth_estimate_gas", eth_estimate_gas):
with pytest.raises(ContractPanic, match=r"ContractPanicReason.UNKNOWN"):
await session.estimate_transact(contract.method.transactPanic(999))

Expand All @@ -892,7 +892,7 @@ def eth_estimate_gas(*_args, **_kwargs):
data = b"1234" + encode_args((abi.uint(256), 1))
raise RPCError(RPCErrorCode.EXECUTION_ERROR, "execution reverted", rpc_encode_data(data))

with monkeypatched(test_provider, "eth_estimate_gas", eth_estimate_gas):
with monkeypatched(local_provider, "eth_estimate_gas", eth_estimate_gas):
with pytest.raises(
ProviderError, match=r"Provider error \(EXECUTION_ERROR\): execution reverted"
):
Expand All @@ -905,6 +905,6 @@ def eth_estimate_gas(*_args, **_kwargs):
data = PANIC_ERROR.selector + encode_args((abi.uint(256), 0))
raise RPCError(12345, "execution reverted", rpc_encode_data(data))

with monkeypatched(test_provider, "eth_estimate_gas", eth_estimate_gas):
with monkeypatched(local_provider, "eth_estimate_gas", eth_estimate_gas):
with pytest.raises(ProviderError, match=r"Provider error \(12345\): execution reverted"):
await session.estimate_transact(contract.method.transactPanic(999))
8 changes: 4 additions & 4 deletions tests/test_http_provider_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@


@pytest.fixture
async def test_server(nursery, test_provider):
handle = HTTPProviderServer(test_provider)
async def server(nursery, local_provider):
handle = HTTPProviderServer(local_provider)
await nursery.start(handle)
yield handle
await handle.shutdown()


@pytest.fixture
async def provider_session(test_server):
async with test_server.http_provider.session() as session:
async def provider_session(server):
async with server.http_provider.session() as session:
yield session


Expand Down
17 changes: 16 additions & 1 deletion tests/test_local_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,21 @@ async def test_auto_mine(provider, session, root_signer, another_signer):
assert provider.eth_get_balance(dest.rpc_encode(), latest) == Amount.ether(2).rpc_encode()


async def test_snapshots(provider, session, root_signer, another_signer):
amount = Amount.ether(1)
double_amount = Amount.ether(2)
dest = another_signer.address
latest = rpc_encode_block(Block.LATEST)

await session.broadcast_transfer(root_signer, dest, amount)
snapshot_id = provider.take_snapshot()
await session.broadcast_transfer(root_signer, dest, amount)
assert provider.eth_get_balance(dest.rpc_encode(), latest) == double_amount.rpc_encode()

provider.revert_to_snapshot(snapshot_id)
assert provider.eth_get_balance(dest.rpc_encode(), latest) == amount.rpc_encode()


def test_net_version(provider):
assert provider.net_version() == "0"

Expand Down Expand Up @@ -153,7 +168,7 @@ async def test_eth_send_raw_transaction(provider, root_signer, another_signer):


async def test_eth_call(provider, session, root_signer):
path = Path(__file__).resolve().parent / "TestTestProvider.sol"
path = Path(__file__).resolve().parent / "TestLocalProvider.sol"
compiled = compile_contract_file(path)
compiled_contract = compiled["BasicContract"]

Expand Down
Loading