Skip to content

Commit

Permalink
Move batching to web3 module; Re-structure Web3Module tests:
Browse files Browse the repository at this point in the history
- Move request batching to a higher level public API on the ``web3`` module;
  make the implementation methods on the manager private.
- Create a proper ``AsyncWeb3ModuleTest`` class that can test differences
  between synchronous and asynchronous methods on the module that now exist.
  This gets rid of the parameterized tests that previously tested the
  static methods on both the ``Web3`` and the ``AsyncWeb3`` classes.
- Tighten up typing related to batch requests; add unsupported methods for batching with tests
  • Loading branch information
fselmo committed Apr 30, 2024
1 parent 64e791a commit 24e15cb
Show file tree
Hide file tree
Showing 13 changed files with 256 additions and 111 deletions.
19 changes: 15 additions & 4 deletions tests/integration/go_ethereum/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
NetModuleTest,
Web3ModuleTest,
)
from web3._utils.module_testing.web3_module import (
AsyncWeb3ModuleTest,
)
from web3.types import (
BlockData,
)
Expand All @@ -27,7 +30,7 @@
)


class GoEthereumTest(Web3ModuleTest):
class GoEthereumWeb3ModuleTest(Web3ModuleTest):
def _check_web3_client_version(self, client_version):
assert client_version.startswith("Geth/")

Expand Down Expand Up @@ -85,12 +88,16 @@ class GoEthereumNetModuleTest(NetModuleTest):
pass


class GoEthereumAsyncNetModuleTest(AsyncNetModuleTest):
class GoEthereumAdminModuleTest(GoEthereumAdminModuleTest):
pass


class GoEthereumAdminModuleTest(GoEthereumAdminModuleTest):
pass
# --- async --- #


class GoEthereumAsyncWeb3ModuleTest(AsyncWeb3ModuleTest):
def _check_web3_client_version(self, client_version):
assert client_version.startswith("Geth/")


class GoEthereumAsyncEthModuleTest(AsyncEthModuleTest):
Expand All @@ -111,3 +118,7 @@ async def test_invalid_eth_sign_typed_data(
await super().test_invalid_eth_sign_typed_data(
async_w3, keyfile_account_address_dual_type, async_skip_if_testrpc
)


class GoEthereumAsyncNetModuleTest(AsyncNetModuleTest):
pass
9 changes: 7 additions & 2 deletions tests/integration/go_ethereum/test_goethereum_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
GoEthereumAsyncEthModuleTest,
GoEthereumAsyncNetModuleTest,
GoEthereumAsyncTxPoolModuleTest,
GoEthereumAsyncWeb3ModuleTest,
GoEthereumEthModuleTest,
GoEthereumNetModuleTest,
GoEthereumTest,
GoEthereumTxPoolModuleTest,
GoEthereumWeb3ModuleTest,
)
from .utils import (
wait_for_aiohttp,
Expand Down Expand Up @@ -70,7 +71,7 @@ def w3(geth_process, endpoint_uri):
return Web3(Web3.HTTPProvider(endpoint_uri))


class TestGoEthereumTest(GoEthereumTest):
class TestGoEthereumWeb3ModuleTest(GoEthereumWeb3ModuleTest):
pass


Expand Down Expand Up @@ -116,6 +117,10 @@ async def async_w3(geth_process, endpoint_uri):
return _w3


class TestGoEthereumAsyncWeb3ModuleTest(GoEthereumAsyncWeb3ModuleTest):
pass


class TestGoEthereumAsyncAdminModuleTest(GoEthereumAsyncAdminModuleTest):
@pytest.mark.asyncio
@pytest.mark.xfail(
Expand Down
30 changes: 19 additions & 11 deletions tests/integration/go_ethereum/test_goethereum_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
GoEthereumAdminModuleTest,
GoEthereumAsyncEthModuleTest,
GoEthereumAsyncNetModuleTest,
GoEthereumAsyncWeb3ModuleTest,
GoEthereumEthModuleTest,
GoEthereumNetModuleTest,
GoEthereumTest,
GoEthereumWeb3ModuleTest,
)
from .utils import (
wait_for_async_socket,
Expand Down Expand Up @@ -56,14 +57,15 @@ def w3(geth_process, geth_ipc_path):
return Web3(Web3.IPCProvider(geth_ipc_path, timeout=30))


@pytest_asyncio.fixture(scope="module")
async def async_w3(geth_process, geth_ipc_path):
await wait_for_async_socket(geth_ipc_path)
async with AsyncWeb3(AsyncIPCProvider(geth_ipc_path)) as _aw3:
yield _aw3
class TestGoEthereumWeb3ModuleTest(GoEthereumWeb3ModuleTest):
pass


class TestGoEthereumTest(GoEthereumTest):
class TestGoEthereumEthModuleTest(GoEthereumEthModuleTest):
pass


class TestGoEthereumNetModuleTest(GoEthereumNetModuleTest):
pass


Expand All @@ -87,15 +89,21 @@ def test_admin_start_stop_ws(self, w3: "Web3") -> None:
super().test_admin_start_stop_ws(w3)


class TestGoEthereumEthModuleTest(GoEthereumEthModuleTest):
pass
# -- async -- #


class TestGoEthereumAsyncEthModuleTest(GoEthereumAsyncEthModuleTest):
@pytest_asyncio.fixture(scope="module")
async def async_w3(geth_process, geth_ipc_path):
await wait_for_async_socket(geth_ipc_path)
async with AsyncWeb3(AsyncIPCProvider(geth_ipc_path)) as _aw3:
yield _aw3


class TestGoEthereumAsyncWeb3ModuleTest(GoEthereumAsyncWeb3ModuleTest):
pass


class TestGoEthereumNetModuleTest(GoEthereumNetModuleTest):
class TestGoEthereumAsyncEthModuleTest(GoEthereumAsyncEthModuleTest):
pass


Expand Down
4 changes: 2 additions & 2 deletions tests/integration/go_ethereum/test_goethereum_legacy_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
GoEthereumAdminModuleTest,
GoEthereumEthModuleTest,
GoEthereumNetModuleTest,
GoEthereumTest,
GoEthereumWeb3ModuleTest,
)


Expand Down Expand Up @@ -70,7 +70,7 @@ def w3(geth_process, endpoint_uri):
return _w3


class TestGoEthereumTest(GoEthereumTest):
class TestGoEthereumWeb3ModuleTest(GoEthereumWeb3ModuleTest):
pass


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..common import (
GoEthereumAsyncEthModuleTest,
GoEthereumAsyncNetModuleTest,
GoEthereumAsyncWeb3ModuleTest,
)
from ..utils import (
wait_for_aiohttp,
Expand All @@ -30,6 +31,10 @@ async def async_w3(geth_process, endpoint_uri):
return await AsyncWeb3(WebSocketProvider(endpoint_uri))


class TestGoEthereumAsyncWeb3ModuleTest(GoEthereumAsyncWeb3ModuleTest):
pass


class TestGoEthereumAsyncAdminModuleTest(GoEthereumAsyncAdminModuleTest):
@pytest.mark.asyncio
@pytest.mark.xfail(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..common import (
GoEthereumAsyncEthModuleTest,
GoEthereumAsyncNetModuleTest,
GoEthereumAsyncWeb3ModuleTest,
)
from ..utils import (
wait_for_aiohttp,
Expand All @@ -31,6 +32,10 @@ async def async_w3(geth_process, endpoint_uri):
yield w3


class TestGoEthereumAsyncWeb3ModuleTest(GoEthereumAsyncWeb3ModuleTest):
pass


class TestGoEthereumAsyncAdminModuleTest(GoEthereumAsyncAdminModuleTest):
@pytest.mark.asyncio
@pytest.mark.xfail(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..common import (
GoEthereumAsyncEthModuleTest,
GoEthereumAsyncNetModuleTest,
GoEthereumAsyncWeb3ModuleTest,
)
from ..utils import (
wait_for_aiohttp,
Expand All @@ -31,6 +32,10 @@ async def async_w3(geth_process, endpoint_uri):
return w3


class TestGoEthereumAsyncWeb3ModuleTest(GoEthereumAsyncWeb3ModuleTest):
pass


class TestGoEthereumAsyncAdminModuleTest(GoEthereumAsyncAdminModuleTest):
@pytest.mark.asyncio
@pytest.mark.xfail(
Expand Down
36 changes: 27 additions & 9 deletions web3/_utils/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
from web3.contract.contract import (
ContractFunction,
)
from web3.method import (
Method,
)
from web3.types import (
TFunc,
TReturn,
Expand All @@ -41,6 +38,9 @@
AsyncWeb3,
Web3,
)
from web3.method import ( # noqa: F401
Method,
)
from web3.providers import ( # noqa: F401
PersistentConnectionProvider,
)
Expand All @@ -53,6 +53,15 @@
)

BatchRequestInformation = Tuple[Tuple["RPCEndpoint", Any], Sequence[Any]]
RPC_METHODS_UNSUPPORTED_DURING_BATCH = {
"eth_subscribe",
"eth_unsubscribe",
"eth_sendRawTransaction",
"eth_sendTransaction",
"eth_signTransaction",
"eth_sign",
"eth_signTypedData",
}


class BatchRequestContextManager(Generic[TFunc]):
Expand All @@ -64,12 +73,12 @@ def __init__(self, web3: Union["AsyncWeb3", "Web3"]) -> None:
] = []

def add(self, batch_payload: TReturn) -> None:
# When batching, we don't make a request. Instead, we will get the request
# information and store it in the `_requests_info` list. So we have to cast the
# apparent "request" into the BatchRequestInformation type.
if isinstance(batch_payload, (ContractFunction, AsyncContractFunction)):
batch_payload = batch_payload.call() # type: ignore

# When batching, we don't make a request. Instead, we will get the request
# information and store it in the `_requests_info` list. So we have to cast the
# apparent "request" into the BatchRequestInformation type.
if self.web3.provider.is_async:
self._async_requests_info.append(
cast(Coroutine[Any, Any, BatchRequestInformation], batch_payload)
Expand All @@ -78,7 +87,16 @@ def add(self, batch_payload: TReturn) -> None:
self._requests_info.append(cast(BatchRequestInformation, batch_payload))

def add_mapping(
self, batch_payload: Dict[Method[Callable[..., Any]], List[Any]]
self,
batch_payload: Dict[
Union[
"Method[Callable[..., Any]]",
Callable[..., Any],
ContractFunction,
AsyncContractFunction,
],
List[Any],
],
) -> None:
for method, params in batch_payload.items():
for param in params:
Expand All @@ -97,7 +115,7 @@ def __exit__(
self.web3.provider._is_batching = False

def execute(self) -> List["RPCResponse"]:
return self.web3.manager.make_batch_request(self._requests_info)
return self.web3.manager._make_batch_request(self._requests_info)

# -- async -- #

Expand All @@ -122,6 +140,6 @@ async def __aexit__(
provider._batch_request_counter = None

async def async_execute(self) -> List["RPCResponse"]:
return await self.web3.manager.async_make_batch_request(
return await self.web3.manager._async_make_batch_request(
self._async_requests_info
)
41 changes: 0 additions & 41 deletions web3/_utils/module_testing/persistent_connection_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
)

if TYPE_CHECKING:
from web3.contract import (
AsyncContract,
)
from web3.main import (
AsyncWeb3,
)
Expand Down Expand Up @@ -397,41 +394,3 @@ async def test_asyncio_gather_for_multiple_requests_matches_the_responses(
assert isinstance(chain_id, int)
assert isinstance(chain_id2, int)
assert isinstance(chain_id3, int)

@pytest.mark.asyncio
async def test_async_batch_request(
self, async_w3: "AsyncWeb3", async_math_contract: "AsyncContract"
) -> None:
async with async_w3.manager.batch_requests() as batch:
batch.add(async_w3.eth.get_block(6))
batch.add(async_w3.eth.get_block(4))
batch.add(async_w3.eth.get_block(2))
batch.add(async_w3.eth.get_block(0))

batch.add(async_math_contract.functions.multiply7(0))

batch.add_mapping(
{
async_math_contract.functions.multiply7: [1, 2, 3],
async_w3.eth.get_block: [1, 3, 5],
}
)

responses = await batch.async_execute()

assert len(responses) == 11
assert all(isinstance(response, AttributeDict) for response in responses[:3])
assert responses[0]["number"] == 6
assert responses[1]["number"] == 4
assert responses[2]["number"] == 2
assert responses[3]["number"] == 0

assert responses[4] == 0
assert responses[5] == 7
assert responses[6] == 14
assert responses[7] == 21

assert all(isinstance(response, AttributeDict) for response in responses[8:])
assert responses[8]["number"] == 1
assert responses[9]["number"] == 3
assert responses[10]["number"] == 5
Loading

0 comments on commit 24e15cb

Please sign in to comment.