Skip to content

Commit

Permalink
Move provider type validation to _batch_requests method.
Browse files Browse the repository at this point in the history
  • Loading branch information
fselmo committed Apr 30, 2024
1 parent 24e15cb commit f7114f4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
9 changes: 7 additions & 2 deletions tests/integration/test_ethereum_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,13 @@ class TestEthereumTesterWeb3Module(Web3ModuleTest):
def _check_web3_client_version(self, client_version):
assert client_version.startswith("EthereumTester/")

test_batch_request = not_implemented(
Web3ModuleTest.test_batch_request, Web3TypeError
test_batch_requests = not_implemented(
Web3ModuleTest.test_batch_requests, Web3TypeError
)

test_batch_requests_raises_for_common_unsupported_methods = not_implemented(
Web3ModuleTest.test_batch_requests_raises_for_common_unsupported_methods,
Web3TypeError,
)


Expand Down
18 changes: 6 additions & 12 deletions web3/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ def _batch_requests(self) -> BatchRequestContextManager[Method[Callable[..., Any
"""
Context manager for making batch requests
"""
if not isinstance(self.provider, (AsyncJSONBaseProvider, JSONBaseProvider)):
raise Web3TypeError("Batch requests are not supported by this provider.")
return BatchRequestContextManager(self.w3)

def _make_batch_request(
Expand All @@ -401,12 +403,8 @@ def _make_batch_request(
"""
Make a batch request using the provider
"""
if not isinstance(self.provider, JSONBaseProvider):
raise Web3TypeError(
"Only JSONBaseProvider classes support batched requests."
)

request_func = self.provider.batch_request_func(
provider = cast(JSONBaseProvider, self.provider)
request_func = provider.batch_request_func(
cast("Web3", self.w3), cast("MiddlewareOnion", self.middleware_onion)
)
responses = request_func(
Expand All @@ -430,12 +428,8 @@ async def _async_make_batch_request(
"""
Make an asynchronous batch request using the provider
"""
if not isinstance(self.provider, AsyncJSONBaseProvider):
raise Web3TypeError(
"Only AsyncJSONBaseProvider classes support batched requests."
)

request_func = await self.provider.batch_request_func(
provider = cast(AsyncJSONBaseProvider, self.provider)
request_func = await provider.batch_request_func(
cast("AsyncWeb3", self.w3),
cast("MiddlewareOnion", self.middleware_onion),
)
Expand Down

0 comments on commit f7114f4

Please sign in to comment.