diff --git a/web3/_utils/batching.py b/web3/_utils/batching.py index 9f4be4d75e..534bd998b0 100644 --- a/web3/_utils/batching.py +++ b/web3/_utils/batching.py @@ -28,6 +28,9 @@ from web3.contract.contract import ( ContractFunction, ) +from web3.exceptions import ( + Web3ValueError, +) from web3.types import ( TFunc, TReturn, @@ -71,8 +74,31 @@ def __init__(self, web3: Union["AsyncWeb3", "Web3"]) -> None: self._async_requests_info: List[ Coroutine[Any, Any, BatchRequestInformation] ] = [] + self._start() + + def _validate_is_batching(self) -> None: + if not self.web3.provider._is_batching: + raise Web3ValueError( + "Batching is not started or batch has already been executed. Start " + "a new batch using `web3.batch_requests()`." + ) + + def _start(self) -> None: + self.web3.provider._is_batching = True + if self.web3.provider.has_persistent_connection: + provider = cast("PersistentConnectionProvider", self.web3.provider) + provider._batch_request_counter = next(copy(provider.request_counter)) + + def _cleanup(self) -> None: + if self.web3.provider._is_batching: + self.web3.provider._is_batching = False + if self.web3.provider.has_persistent_connection: + provider = cast("PersistentConnectionProvider", self.web3.provider) + provider._batch_request_counter = None def add(self, batch_payload: TReturn) -> None: + self._validate_is_batching() + if isinstance(batch_payload, (ContractFunction, AsyncContractFunction)): batch_payload = batch_payload.call() # type: ignore @@ -98,12 +124,14 @@ def add_mapping( List[Any], ], ) -> None: + self._validate_is_batching() + for method, params in batch_payload.items(): for param in params: self.add(method(param)) def __enter__(self) -> Self: - self.web3.provider._is_batching = True + self._start() return self def __exit__( @@ -112,19 +140,18 @@ def __exit__( exc_val: BaseException, exc_tb: TracebackType, ) -> None: - self.web3.provider._is_batching = False + self._cleanup() def execute(self) -> List["RPCResponse"]: - return self.web3.manager._make_batch_request(self._requests_info) + self._validate_is_batching() + responses = self.web3.manager._make_batch_request(self._requests_info) + self._cleanup() + return responses # -- async -- # async def __aenter__(self) -> Self: - provider = cast("AsyncJSONBaseProvider", self.web3.provider) - provider._is_batching = True - if provider.has_persistent_connection: - provider = cast("PersistentConnectionProvider", provider) - provider._batch_request_counter = next(copy(provider.request_counter)) + self._start() return self async def __aexit__( @@ -133,13 +160,12 @@ async def __aexit__( exc_val: BaseException, exc_tb: TracebackType, ) -> None: - provider = cast("AsyncJSONBaseProvider", self.web3.provider) - provider._is_batching = False - if provider.has_persistent_connection: - provider = cast("PersistentConnectionProvider", provider) - provider._batch_request_counter = None + self._cleanup() async def async_execute(self) -> List["RPCResponse"]: - return await self.web3.manager._async_make_batch_request( + self._validate_is_batching() + responses = await self.web3.manager._async_make_batch_request( self._async_requests_info ) + self._cleanup() + return responses diff --git a/web3/_utils/module_testing/web3_module.py b/web3/_utils/module_testing/web3_module.py index da1e814f1b..7396bf4c1e 100644 --- a/web3/_utils/module_testing/web3_module.py +++ b/web3/_utils/module_testing/web3_module.py @@ -356,6 +356,27 @@ def test_batch_requests(self, w3: "Web3", math_contract: Contract) -> None: assert last_three_responses[1]["number"] == 3 assert last_three_responses[2]["number"] == 5 + def test_batch_requests_initialized_as_object( + self, w3: "Web3", math_contract: Contract + ) -> None: + batch = w3.batch_requests() + batch.add(w3.eth.get_block(1)) + batch.add(w3.eth.get_block(2)) + batch.add(math_contract.functions.multiply7(0)) + batch.add_mapping( + {math_contract.functions.multiply7: [1, 2], w3.eth.get_block: [3, 4]} + ) + + b1, b2, m0, m1, m2, b3, b4 = batch.execute() + + assert cast(BlockData, b1)["number"] == 1 + assert cast(BlockData, b2)["number"] == 2 + assert cast(int, m0) == 0 + assert cast(int, m1) == 7 + assert cast(int, m2) == 14 + assert cast(BlockData, b3)["number"] == 3 + assert cast(BlockData, b4)["number"] == 4 + def test_batch_requests_raises_for_common_unsupported_methods( self, w3: "Web3", math_contract: Contract ) -> None: @@ -438,6 +459,31 @@ async def test_batch_requests( assert last_three_responses[1]["number"] == 3 assert last_three_responses[2]["number"] == 5 + @pytest.mark.asyncio + async def test_batch_requests_initialized_as_object( + self, async_w3: AsyncWeb3, async_math_contract: "AsyncContract" + ) -> None: + batch = async_w3.batch_requests() + batch.add(async_w3.eth.get_block(1)) + batch.add(async_w3.eth.get_block(2)) + batch.add(async_math_contract.functions.multiply7(0)) + batch.add_mapping( + { + async_math_contract.functions.multiply7: [1, 2], + async_w3.eth.get_block: [3, 4], + } + ) + + b1, b2, m0, m1, m2, b3, b4 = await batch.async_execute() + + assert cast(BlockData, b1)["number"] == 1 + assert cast(BlockData, b2)["number"] == 2 + assert cast(int, m0) == 0 + assert cast(int, m1) == 7 + assert cast(int, m2) == 14 + assert cast(BlockData, b3)["number"] == 3 + assert cast(BlockData, b4)["number"] == 4 + @pytest.mark.asyncio async def test_batch_requests_raises_for_common_unsupported_methods( self, async_w3: AsyncWeb3, async_math_contract: "AsyncContract"