From 24e15cbb2f51e9071eda4aa72fcc009f4ed643bb Mon Sep 17 00:00:00 2001 From: fselmo Date: Tue, 30 Apr 2024 12:21:31 -0600 Subject: [PATCH] Move batching to web3 module; Re-structure Web3Module tests: - Move request batching to a higher level public API on the ``web3`` module; make the implementation methods on the manager private. - Create a proper ``AsyncWeb3ModuleTest`` class that can test differences between synchronous and asynchronous methods on the module that now exist. This gets rid of the parameterized tests that previously tested the static methods on both the ``Web3`` and the ``AsyncWeb3`` classes. - Tighten up typing related to batch requests; add unsupported methods for batching with tests --- tests/integration/go_ethereum/common.py | 19 +- .../go_ethereum/test_goethereum_http.py | 9 +- .../go_ethereum/test_goethereum_ipc.py | 30 +-- .../go_ethereum/test_goethereum_legacy_ws.py | 4 +- .../test_goethereum_ws/test_async_await_w3.py | 5 + .../test_async_ctx_manager_w3.py | 5 + .../test_async_iterator_w3.py | 5 + web3/_utils/batching.py | 36 +++- .../persistent_connection_provider.py | 41 ---- web3/_utils/module_testing/web3_module.py | 182 ++++++++++++++---- web3/main.py | 14 +- web3/manager.py | 8 +- web3/method.py | 9 + 13 files changed, 256 insertions(+), 111 deletions(-) diff --git a/tests/integration/go_ethereum/common.py b/tests/integration/go_ethereum/common.py index 59e4614ab5..e0df00bdd8 100644 --- a/tests/integration/go_ethereum/common.py +++ b/tests/integration/go_ethereum/common.py @@ -17,6 +17,9 @@ NetModuleTest, Web3ModuleTest, ) +from web3._utils.module_testing.web3_module import ( + AsyncWeb3ModuleTest, +) from web3.types import ( BlockData, ) @@ -27,7 +30,7 @@ ) -class GoEthereumTest(Web3ModuleTest): +class GoEthereumWeb3ModuleTest(Web3ModuleTest): def _check_web3_client_version(self, client_version): assert client_version.startswith("Geth/") @@ -85,12 +88,16 @@ class GoEthereumNetModuleTest(NetModuleTest): pass -class GoEthereumAsyncNetModuleTest(AsyncNetModuleTest): +class GoEthereumAdminModuleTest(GoEthereumAdminModuleTest): pass -class GoEthereumAdminModuleTest(GoEthereumAdminModuleTest): - pass +# --- async --- # + + +class GoEthereumAsyncWeb3ModuleTest(AsyncWeb3ModuleTest): + def _check_web3_client_version(self, client_version): + assert client_version.startswith("Geth/") class GoEthereumAsyncEthModuleTest(AsyncEthModuleTest): @@ -111,3 +118,7 @@ async def test_invalid_eth_sign_typed_data( await super().test_invalid_eth_sign_typed_data( async_w3, keyfile_account_address_dual_type, async_skip_if_testrpc ) + + +class GoEthereumAsyncNetModuleTest(AsyncNetModuleTest): + pass diff --git a/tests/integration/go_ethereum/test_goethereum_http.py b/tests/integration/go_ethereum/test_goethereum_http.py index 54633b3cb0..5ae82ad624 100644 --- a/tests/integration/go_ethereum/test_goethereum_http.py +++ b/tests/integration/go_ethereum/test_goethereum_http.py @@ -21,10 +21,11 @@ GoEthereumAsyncEthModuleTest, GoEthereumAsyncNetModuleTest, GoEthereumAsyncTxPoolModuleTest, + GoEthereumAsyncWeb3ModuleTest, GoEthereumEthModuleTest, GoEthereumNetModuleTest, - GoEthereumTest, GoEthereumTxPoolModuleTest, + GoEthereumWeb3ModuleTest, ) from .utils import ( wait_for_aiohttp, @@ -70,7 +71,7 @@ def w3(geth_process, endpoint_uri): return Web3(Web3.HTTPProvider(endpoint_uri)) -class TestGoEthereumTest(GoEthereumTest): +class TestGoEthereumWeb3ModuleTest(GoEthereumWeb3ModuleTest): pass @@ -116,6 +117,10 @@ async def async_w3(geth_process, endpoint_uri): return _w3 +class TestGoEthereumAsyncWeb3ModuleTest(GoEthereumAsyncWeb3ModuleTest): + pass + + class TestGoEthereumAsyncAdminModuleTest(GoEthereumAsyncAdminModuleTest): @pytest.mark.asyncio @pytest.mark.xfail( diff --git a/tests/integration/go_ethereum/test_goethereum_ipc.py b/tests/integration/go_ethereum/test_goethereum_ipc.py index 77eb58d25f..d651e2a5da 100644 --- a/tests/integration/go_ethereum/test_goethereum_ipc.py +++ b/tests/integration/go_ethereum/test_goethereum_ipc.py @@ -17,9 +17,10 @@ GoEthereumAdminModuleTest, GoEthereumAsyncEthModuleTest, GoEthereumAsyncNetModuleTest, + GoEthereumAsyncWeb3ModuleTest, GoEthereumEthModuleTest, GoEthereumNetModuleTest, - GoEthereumTest, + GoEthereumWeb3ModuleTest, ) from .utils import ( wait_for_async_socket, @@ -56,14 +57,15 @@ def w3(geth_process, geth_ipc_path): return Web3(Web3.IPCProvider(geth_ipc_path, timeout=30)) -@pytest_asyncio.fixture(scope="module") -async def async_w3(geth_process, geth_ipc_path): - await wait_for_async_socket(geth_ipc_path) - async with AsyncWeb3(AsyncIPCProvider(geth_ipc_path)) as _aw3: - yield _aw3 +class TestGoEthereumWeb3ModuleTest(GoEthereumWeb3ModuleTest): + pass -class TestGoEthereumTest(GoEthereumTest): +class TestGoEthereumEthModuleTest(GoEthereumEthModuleTest): + pass + + +class TestGoEthereumNetModuleTest(GoEthereumNetModuleTest): pass @@ -87,15 +89,21 @@ def test_admin_start_stop_ws(self, w3: "Web3") -> None: super().test_admin_start_stop_ws(w3) -class TestGoEthereumEthModuleTest(GoEthereumEthModuleTest): - pass +# -- async -- # -class TestGoEthereumAsyncEthModuleTest(GoEthereumAsyncEthModuleTest): +@pytest_asyncio.fixture(scope="module") +async def async_w3(geth_process, geth_ipc_path): + await wait_for_async_socket(geth_ipc_path) + async with AsyncWeb3(AsyncIPCProvider(geth_ipc_path)) as _aw3: + yield _aw3 + + +class TestGoEthereumAsyncWeb3ModuleTest(GoEthereumAsyncWeb3ModuleTest): pass -class TestGoEthereumNetModuleTest(GoEthereumNetModuleTest): +class TestGoEthereumAsyncEthModuleTest(GoEthereumAsyncEthModuleTest): pass diff --git a/tests/integration/go_ethereum/test_goethereum_legacy_ws.py b/tests/integration/go_ethereum/test_goethereum_legacy_ws.py index 709620e0ad..9b9c77473a 100644 --- a/tests/integration/go_ethereum/test_goethereum_legacy_ws.py +++ b/tests/integration/go_ethereum/test_goethereum_legacy_ws.py @@ -17,7 +17,7 @@ GoEthereumAdminModuleTest, GoEthereumEthModuleTest, GoEthereumNetModuleTest, - GoEthereumTest, + GoEthereumWeb3ModuleTest, ) @@ -70,7 +70,7 @@ def w3(geth_process, endpoint_uri): return _w3 -class TestGoEthereumTest(GoEthereumTest): +class TestGoEthereumWeb3ModuleTest(GoEthereumWeb3ModuleTest): pass diff --git a/tests/integration/go_ethereum/test_goethereum_ws/test_async_await_w3.py b/tests/integration/go_ethereum/test_goethereum_ws/test_async_await_w3.py index 0037211f03..40b8bfbad3 100644 --- a/tests/integration/go_ethereum/test_goethereum_ws/test_async_await_w3.py +++ b/tests/integration/go_ethereum/test_goethereum_ws/test_async_await_w3.py @@ -16,6 +16,7 @@ from ..common import ( GoEthereumAsyncEthModuleTest, GoEthereumAsyncNetModuleTest, + GoEthereumAsyncWeb3ModuleTest, ) from ..utils import ( wait_for_aiohttp, @@ -30,6 +31,10 @@ async def async_w3(geth_process, endpoint_uri): return await AsyncWeb3(WebSocketProvider(endpoint_uri)) +class TestGoEthereumAsyncWeb3ModuleTest(GoEthereumAsyncWeb3ModuleTest): + pass + + class TestGoEthereumAsyncAdminModuleTest(GoEthereumAsyncAdminModuleTest): @pytest.mark.asyncio @pytest.mark.xfail( diff --git a/tests/integration/go_ethereum/test_goethereum_ws/test_async_ctx_manager_w3.py b/tests/integration/go_ethereum/test_goethereum_ws/test_async_ctx_manager_w3.py index 6e0262de55..594be64421 100644 --- a/tests/integration/go_ethereum/test_goethereum_ws/test_async_ctx_manager_w3.py +++ b/tests/integration/go_ethereum/test_goethereum_ws/test_async_ctx_manager_w3.py @@ -16,6 +16,7 @@ from ..common import ( GoEthereumAsyncEthModuleTest, GoEthereumAsyncNetModuleTest, + GoEthereumAsyncWeb3ModuleTest, ) from ..utils import ( wait_for_aiohttp, @@ -31,6 +32,10 @@ async def async_w3(geth_process, endpoint_uri): yield w3 +class TestGoEthereumAsyncWeb3ModuleTest(GoEthereumAsyncWeb3ModuleTest): + pass + + class TestGoEthereumAsyncAdminModuleTest(GoEthereumAsyncAdminModuleTest): @pytest.mark.asyncio @pytest.mark.xfail( diff --git a/tests/integration/go_ethereum/test_goethereum_ws/test_async_iterator_w3.py b/tests/integration/go_ethereum/test_goethereum_ws/test_async_iterator_w3.py index b8074bb06b..c834d12502 100644 --- a/tests/integration/go_ethereum/test_goethereum_ws/test_async_iterator_w3.py +++ b/tests/integration/go_ethereum/test_goethereum_ws/test_async_iterator_w3.py @@ -16,6 +16,7 @@ from ..common import ( GoEthereumAsyncEthModuleTest, GoEthereumAsyncNetModuleTest, + GoEthereumAsyncWeb3ModuleTest, ) from ..utils import ( wait_for_aiohttp, @@ -31,6 +32,10 @@ async def async_w3(geth_process, endpoint_uri): return w3 +class TestGoEthereumAsyncWeb3ModuleTest(GoEthereumAsyncWeb3ModuleTest): + pass + + class TestGoEthereumAsyncAdminModuleTest(GoEthereumAsyncAdminModuleTest): @pytest.mark.asyncio @pytest.mark.xfail( diff --git a/web3/_utils/batching.py b/web3/_utils/batching.py index 295a4ac6f7..9f4be4d75e 100644 --- a/web3/_utils/batching.py +++ b/web3/_utils/batching.py @@ -28,9 +28,6 @@ from web3.contract.contract import ( ContractFunction, ) -from web3.method import ( - Method, -) from web3.types import ( TFunc, TReturn, @@ -41,6 +38,9 @@ AsyncWeb3, Web3, ) + from web3.method import ( # noqa: F401 + Method, + ) from web3.providers import ( # noqa: F401 PersistentConnectionProvider, ) @@ -53,6 +53,15 @@ ) BatchRequestInformation = Tuple[Tuple["RPCEndpoint", Any], Sequence[Any]] +RPC_METHODS_UNSUPPORTED_DURING_BATCH = { + "eth_subscribe", + "eth_unsubscribe", + "eth_sendRawTransaction", + "eth_sendTransaction", + "eth_signTransaction", + "eth_sign", + "eth_signTypedData", +} class BatchRequestContextManager(Generic[TFunc]): @@ -64,12 +73,12 @@ def __init__(self, web3: Union["AsyncWeb3", "Web3"]) -> None: ] = [] def add(self, batch_payload: TReturn) -> None: - # When batching, we don't make a request. Instead, we will get the request - # information and store it in the `_requests_info` list. So we have to cast the - # apparent "request" into the BatchRequestInformation type. if isinstance(batch_payload, (ContractFunction, AsyncContractFunction)): batch_payload = batch_payload.call() # type: ignore + # When batching, we don't make a request. Instead, we will get the request + # information and store it in the `_requests_info` list. So we have to cast the + # apparent "request" into the BatchRequestInformation type. if self.web3.provider.is_async: self._async_requests_info.append( cast(Coroutine[Any, Any, BatchRequestInformation], batch_payload) @@ -78,7 +87,16 @@ def add(self, batch_payload: TReturn) -> None: self._requests_info.append(cast(BatchRequestInformation, batch_payload)) def add_mapping( - self, batch_payload: Dict[Method[Callable[..., Any]], List[Any]] + self, + batch_payload: Dict[ + Union[ + "Method[Callable[..., Any]]", + Callable[..., Any], + ContractFunction, + AsyncContractFunction, + ], + List[Any], + ], ) -> None: for method, params in batch_payload.items(): for param in params: @@ -97,7 +115,7 @@ def __exit__( self.web3.provider._is_batching = False def execute(self) -> List["RPCResponse"]: - return self.web3.manager.make_batch_request(self._requests_info) + return self.web3.manager._make_batch_request(self._requests_info) # -- async -- # @@ -122,6 +140,6 @@ async def __aexit__( provider._batch_request_counter = None async def async_execute(self) -> List["RPCResponse"]: - return await self.web3.manager.async_make_batch_request( + return await self.web3.manager._async_make_batch_request( self._async_requests_info ) diff --git a/web3/_utils/module_testing/persistent_connection_provider.py b/web3/_utils/module_testing/persistent_connection_provider.py index c05a35c275..810318ea6b 100644 --- a/web3/_utils/module_testing/persistent_connection_provider.py +++ b/web3/_utils/module_testing/persistent_connection_provider.py @@ -26,9 +26,6 @@ ) if TYPE_CHECKING: - from web3.contract import ( - AsyncContract, - ) from web3.main import ( AsyncWeb3, ) @@ -397,41 +394,3 @@ async def test_asyncio_gather_for_multiple_requests_matches_the_responses( assert isinstance(chain_id, int) assert isinstance(chain_id2, int) assert isinstance(chain_id3, int) - - @pytest.mark.asyncio - async def test_async_batch_request( - self, async_w3: "AsyncWeb3", async_math_contract: "AsyncContract" - ) -> None: - async with async_w3.manager.batch_requests() as batch: - batch.add(async_w3.eth.get_block(6)) - batch.add(async_w3.eth.get_block(4)) - batch.add(async_w3.eth.get_block(2)) - batch.add(async_w3.eth.get_block(0)) - - batch.add(async_math_contract.functions.multiply7(0)) - - batch.add_mapping( - { - async_math_contract.functions.multiply7: [1, 2, 3], - async_w3.eth.get_block: [1, 3, 5], - } - ) - - responses = await batch.async_execute() - - assert len(responses) == 11 - assert all(isinstance(response, AttributeDict) for response in responses[:3]) - assert responses[0]["number"] == 6 - assert responses[1]["number"] == 4 - assert responses[2]["number"] == 2 - assert responses[3]["number"] == 0 - - assert responses[4] == 0 - assert responses[5] == 7 - assert responses[6] == 14 - assert responses[7] == 21 - - assert all(isinstance(response, AttributeDict) for response in responses[8:]) - assert responses[8]["number"] == 1 - assert responses[9]["number"] == 3 - assert responses[10]["number"] == 5 diff --git a/web3/_utils/module_testing/web3_module.py b/web3/_utils/module_testing/web3_module.py index 13cc1eccdf..da1e814f1b 100644 --- a/web3/_utils/module_testing/web3_module.py +++ b/web3/_utils/module_testing/web3_module.py @@ -1,12 +1,14 @@ import pytest from typing import ( + TYPE_CHECKING, Any, NoReturn, Sequence, - Union, + cast, ) from eth_typing import ( + Address, ChecksumAddress, HexAddress, HexStr, @@ -26,13 +28,19 @@ from web3.contract import ( Contract, ) -from web3.datastructures import ( - AttributeDict, -) from web3.exceptions import ( InvalidAddress, + MethodNotSupported, +) +from web3.types import ( + BlockData, ) +if TYPE_CHECKING: + from web3.contract import ( # noqa: F401 + AsyncContract, + ) + class Web3ModuleTest: def test_web3_client_version(self, w3: Web3) -> None: @@ -229,16 +237,9 @@ def _check_web3_client_version(self, client_version: str) -> NoReturn: ), ), ) - @pytest.mark.parametrize( - "w3", - ( - Web3, - AsyncWeb3, - ), - ) def test_solidity_keccak( self, - w3: Union["Web3", "AsyncWeb3"], + w3: "Web3", types: Sequence[TypeStr], values: Sequence[Any], expected: HexBytes, @@ -270,16 +271,9 @@ def test_solidity_keccak( ), ), ) - @pytest.mark.parametrize( - "w3", - ( - Web3(), - AsyncWeb3(), - ), - ) def test_solidity_keccak_ens( self, - w3: Union["Web3", "AsyncWeb3"], + w3: "Web3", types: Sequence[TypeStr], values: Sequence[str], expected: HexBytes, @@ -320,8 +314,8 @@ def test_solidity_keccak_same_number_of_types_and_values( def test_is_connected(self, w3: "Web3") -> None: assert w3.is_connected() - def test_batch_request(self, w3: "Web3", math_contract: Contract) -> None: - with w3.manager.batch_requests() as batch: + def test_batch_requests(self, w3: "Web3", math_contract: Contract) -> None: + with w3.batch_requests() as batch: batch.add(w3.eth.get_block(6)) batch.add(w3.eth.get_block(4)) batch.add(w3.eth.get_block(2)) @@ -338,18 +332,132 @@ def test_batch_request(self, w3: "Web3", math_contract: Contract) -> None: responses = batch.execute() assert len(responses) == 11 - assert all(isinstance(response, AttributeDict) for response in responses[:4]) - assert responses[0]["number"] == 6 - assert responses[1]["number"] == 4 - assert responses[2]["number"] == 2 - assert responses[3]["number"] == 0 - - assert responses[4] == 0 - assert responses[5] == 7 - assert responses[6] == 14 - assert responses[7] == 21 - - assert all(isinstance(response, AttributeDict) for response in responses[8:]) - assert responses[8]["number"] == 1 - assert responses[9]["number"] == 3 - assert responses[10]["number"] == 5 + + first_four_responses: Sequence[BlockData] = cast( + Sequence[BlockData], responses[:4] + ) + assert first_four_responses[0]["number"] == 6 + assert first_four_responses[1]["number"] == 4 + assert first_four_responses[2]["number"] == 2 + assert first_four_responses[3]["number"] == 0 + + responses_five_through_eight: Sequence[int] = cast( + Sequence[int], responses[4:8] + ) + assert responses_five_through_eight[0] == 0 + assert responses_five_through_eight[1] == 7 + assert responses_five_through_eight[2] == 14 + assert responses_five_through_eight[3] == 21 + + last_three_responses: Sequence[BlockData] = cast( + Sequence[BlockData], responses[8:] + ) + assert last_three_responses[0]["number"] == 1 + assert last_three_responses[1]["number"] == 3 + assert last_three_responses[2]["number"] == 5 + + def test_batch_requests_raises_for_common_unsupported_methods( + self, w3: "Web3", math_contract: Contract + ) -> None: + with w3.batch_requests() as batch: + with pytest.raises(MethodNotSupported, match="eth_sendTransaction"): + batch.add(w3.eth.send_transaction({})) + batch.execute() + + with w3.batch_requests() as batch: + with pytest.raises(MethodNotSupported, match="eth_sendTransaction"): + batch.add(math_contract.functions.multiply7(1).transact({})) + batch.execute() + + with w3.batch_requests() as batch: + with pytest.raises(MethodNotSupported, match="eth_sendRawTransaction"): + batch.add(w3.eth.send_raw_transaction(b"")) + batch.execute() + + with w3.batch_requests() as batch: + with pytest.raises(MethodNotSupported, match="eth_sign"): + batch.add(w3.eth.sign(Address(b"\x00" * 20))) + batch.execute() + + +# -- async -- # + + +class AsyncWeb3ModuleTest(Web3ModuleTest): + # Note: Any test that overrides the synchronous test from `Web3ModuleTest` with + # an asynchronous test should have the exact same name. + + @pytest.mark.asyncio + async def test_web3_client_version(self, async_w3: AsyncWeb3) -> None: + client_version = await async_w3.client_version + self._check_web3_client_version(client_version) + + @pytest.mark.asyncio + async def test_batch_requests( + self, async_w3: AsyncWeb3, async_math_contract: "AsyncContract" + ) -> None: + async with async_w3.batch_requests() as batch: + batch.add(async_w3.eth.get_block(6)) + batch.add(async_w3.eth.get_block(4)) + batch.add(async_w3.eth.get_block(2)) + batch.add(async_w3.eth.get_block(0)) + + batch.add(async_math_contract.functions.multiply7(0)) + + batch.add_mapping( + { + async_math_contract.functions.multiply7: [1, 2, 3], + async_w3.eth.get_block: [1, 3, 5], + } + ) + + responses = await batch.async_execute() + + assert len(responses) == 11 + + first_four_responses: Sequence[BlockData] = cast( + Sequence[BlockData], responses[:4] + ) + assert first_four_responses[0]["number"] == 6 + assert first_four_responses[1]["number"] == 4 + assert first_four_responses[2]["number"] == 2 + assert first_four_responses[3]["number"] == 0 + + responses_five_through_eight: Sequence[int] = cast( + Sequence[int], responses[4:8] + ) + assert responses_five_through_eight[0] == 0 + assert responses_five_through_eight[1] == 7 + assert responses_five_through_eight[2] == 14 + assert responses_five_through_eight[3] == 21 + + last_three_responses: Sequence[BlockData] = cast( + Sequence[BlockData], responses[8:] + ) + assert last_three_responses[0]["number"] == 1 + assert last_three_responses[1]["number"] == 3 + assert last_three_responses[2]["number"] == 5 + + @pytest.mark.asyncio + async def test_batch_requests_raises_for_common_unsupported_methods( + self, async_w3: AsyncWeb3, async_math_contract: "AsyncContract" + ) -> None: + async with async_w3.batch_requests() as batch: + with pytest.raises(MethodNotSupported, match="eth_sendTransaction"): + batch.add(async_w3.eth.send_transaction({})) + await batch.async_execute() + + async with async_w3.batch_requests() as batch: + with pytest.raises(MethodNotSupported, match="eth_sendTransaction"): + batch.add(async_math_contract.functions.multiply7(1).transact({})) + await batch.async_execute() + + async with async_w3.batch_requests() as batch: + with pytest.raises(MethodNotSupported, match="eth_sendRawTransaction"): + batch.add(async_w3.eth.send_raw_transaction(b"")) + await batch.async_execute() + + async with async_w3.batch_requests() as batch: + with pytest.raises(MethodNotSupported, match="eth_sign"): + batch.add(async_w3.eth.sign(Address(b"\x00" * 20))) + await batch.async_execute() diff --git a/web3/main.py b/web3/main.py index 7bfe7dc103..22fd6cd625 100644 --- a/web3/main.py +++ b/web3/main.py @@ -34,6 +34,7 @@ TYPE_CHECKING, Any, AsyncIterator, + Callable, Dict, Generator, List, @@ -100,6 +101,9 @@ RequestManager as DefaultRequestManager, ) from web3.middleware.base import MiddlewareOnion +from web3.method import ( + Method, +) from web3.module import ( Module, ) @@ -143,6 +147,7 @@ ) if TYPE_CHECKING: + from web3._utils.batching import BatchRequestContextManager # noqa: F401 from web3._utils.empty import Empty # noqa: F401 @@ -336,6 +341,13 @@ def attach_modules( def is_encodable(self, _type: TypeStr, value: Any) -> bool: return self.codec.is_encodable(_type, value) + # -- APIs for high-level requests -- # + + def batch_requests( + self, + ) -> "BatchRequestContextManager[Method[Callable[..., Any]]]": + return self.manager._batch_requests() + class Web3(BaseWeb3): # mypy types @@ -468,7 +480,7 @@ def ens(self, new_ens: Union[AsyncENS, "Empty"]) -> None: new_ens.w3 = self # set self object reference for ``AsyncENS.w3`` self._ens = new_ens - # -- persistent connection methods -- # + # -- persistent connection methods -- # @property @persistent_connection_provider_method() diff --git a/web3/manager.py b/web3/manager.py index a46135e4b5..16f48256c0 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -387,15 +387,15 @@ async def coro_request( response, params, error_formatters, null_result_formatters ) - # -- batch requests -- # + # -- batch requests management -- # - def batch_requests(self) -> BatchRequestContextManager[Method[Callable[..., Any]]]: + def _batch_requests(self) -> BatchRequestContextManager[Method[Callable[..., Any]]]: """ Context manager for making batch requests """ return BatchRequestContextManager(self.w3) - def make_batch_request( + def _make_batch_request( self, requests_info: List[Tuple[Tuple["RPCEndpoint", Any], Sequence[Any]]] ) -> List[RPCResponse]: """ @@ -421,7 +421,7 @@ def make_batch_request( ] return list(formatted_responses) - async def async_make_batch_request( + async def _async_make_batch_request( self, requests_info: List[ Coroutine[Any, Any, Tuple[Tuple["RPCEndpoint", Any], Sequence[Any]]] diff --git a/web3/method.py b/web3/method.py index d83dc7c7ca..96b0939d09 100644 --- a/web3/method.py +++ b/web3/method.py @@ -21,6 +21,9 @@ pipe, ) +from web3._utils.batching import ( + RPC_METHODS_UNSUPPORTED_DURING_BATCH, +) from web3._utils.method_formatters import ( get_error_formatters, get_null_result_formatters, @@ -31,6 +34,7 @@ RPC, ) from web3.exceptions import ( + MethodNotSupported, Web3TypeError, Web3ValidationError, Web3ValueError, @@ -161,6 +165,11 @@ def __get__( "usually attached to a web3 instance." ) if module.w3.provider._is_batching: + if self.json_rpc_method in RPC_METHODS_UNSUPPORTED_DURING_BATCH: + raise MethodNotSupported( + f"Method `{self.json_rpc_method}` is not supported within a batch " + "request." + ) return module.retrieve_request_information(self) else: return module.retrieve_caller_fn(self)