Skip to content

Commit

Permalink
Refactor batch logic to allow for object assignment as per #832
Browse files Browse the repository at this point in the history
  • Loading branch information
fselmo committed Apr 30, 2024
1 parent 6eea4db commit 809c2bd
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 14 deletions.
54 changes: 40 additions & 14 deletions web3/_utils/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from web3.contract.contract import (
ContractFunction,
)
from web3.exceptions import (
Web3ValueError,
)
from web3.types import (
TFunc,
TReturn,
Expand Down Expand Up @@ -71,8 +74,31 @@ def __init__(self, web3: Union["AsyncWeb3", "Web3"]) -> None:
self._async_requests_info: List[
Coroutine[Any, Any, BatchRequestInformation]
] = []
self._start()

def _validate_is_batching(self) -> None:
if not self.web3.provider._is_batching:
raise Web3ValueError(
"Batching is not started or batch has already been executed. Start "
"a new batch using `web3.batch_requests()`."
)

def _start(self) -> None:
self.web3.provider._is_batching = True
if self.web3.provider.has_persistent_connection:
provider = cast("PersistentConnectionProvider", self.web3.provider)
provider._batch_request_counter = next(copy(provider.request_counter))

def _cleanup(self) -> None:
if self.web3.provider._is_batching:
self.web3.provider._is_batching = False
if self.web3.provider.has_persistent_connection:
provider = cast("PersistentConnectionProvider", self.web3.provider)
provider._batch_request_counter = None

def add(self, batch_payload: TReturn) -> None:
self._validate_is_batching()

if isinstance(batch_payload, (ContractFunction, AsyncContractFunction)):
batch_payload = batch_payload.call() # type: ignore

Expand All @@ -98,12 +124,14 @@ def add_mapping(
List[Any],
],
) -> None:
self._validate_is_batching()

for method, params in batch_payload.items():
for param in params:
self.add(method(param))

def __enter__(self) -> Self:
self.web3.provider._is_batching = True
self._start()
return self

def __exit__(
Expand All @@ -112,19 +140,18 @@ def __exit__(
exc_val: BaseException,
exc_tb: TracebackType,
) -> None:
self.web3.provider._is_batching = False
self._cleanup()

def execute(self) -> List["RPCResponse"]:
return self.web3.manager._make_batch_request(self._requests_info)
self._validate_is_batching()
responses = self.web3.manager._make_batch_request(self._requests_info)
self._cleanup()
return responses

# -- async -- #

async def __aenter__(self) -> Self:
provider = cast("AsyncJSONBaseProvider", self.web3.provider)
provider._is_batching = True
if provider.has_persistent_connection:
provider = cast("PersistentConnectionProvider", provider)
provider._batch_request_counter = next(copy(provider.request_counter))
self._start()
return self

async def __aexit__(
Expand All @@ -133,13 +160,12 @@ async def __aexit__(
exc_val: BaseException,
exc_tb: TracebackType,
) -> None:
provider = cast("AsyncJSONBaseProvider", self.web3.provider)
provider._is_batching = False
if provider.has_persistent_connection:
provider = cast("PersistentConnectionProvider", provider)
provider._batch_request_counter = None
self._cleanup()

async def async_execute(self) -> List["RPCResponse"]:
return await self.web3.manager._async_make_batch_request(
self._validate_is_batching()
responses = await self.web3.manager._async_make_batch_request(
self._async_requests_info
)
self._cleanup()
return responses
46 changes: 46 additions & 0 deletions web3/_utils/module_testing/web3_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,27 @@ def test_batch_requests(self, w3: "Web3", math_contract: Contract) -> None:
assert last_three_responses[1]["number"] == 3
assert last_three_responses[2]["number"] == 5

def test_batch_requests_initialized_as_object(
self, w3: "Web3", math_contract: Contract
) -> None:
batch = w3.batch_requests()
batch.add(w3.eth.get_block(1))
batch.add(w3.eth.get_block(2))
batch.add(math_contract.functions.multiply7(0))
batch.add_mapping(
{math_contract.functions.multiply7: [1, 2], w3.eth.get_block: [3, 4]}
)

b1, b2, m0, m1, m2, b3, b4 = batch.execute()

assert cast(BlockData, b1)["number"] == 1
assert cast(BlockData, b2)["number"] == 2
assert cast(int, m0) == 0
assert cast(int, m1) == 7
assert cast(int, m2) == 14
assert cast(BlockData, b3)["number"] == 3
assert cast(BlockData, b4)["number"] == 4

def test_batch_requests_raises_for_common_unsupported_methods(
self, w3: "Web3", math_contract: Contract
) -> None:
Expand Down Expand Up @@ -438,6 +459,31 @@ async def test_batch_requests(
assert last_three_responses[1]["number"] == 3
assert last_three_responses[2]["number"] == 5

@pytest.mark.asyncio
async def test_batch_requests_initialized_as_object(
self, async_w3: AsyncWeb3, async_math_contract: "AsyncContract"
) -> None:
batch = async_w3.batch_requests()
batch.add(async_w3.eth.get_block(1))
batch.add(async_w3.eth.get_block(2))
batch.add(async_math_contract.functions.multiply7(0))
batch.add_mapping(
{
async_math_contract.functions.multiply7: [1, 2],
async_w3.eth.get_block: [3, 4],
}
)

b1, b2, m0, m1, m2, b3, b4 = await batch.async_execute()

assert cast(BlockData, b1)["number"] == 1
assert cast(BlockData, b2)["number"] == 2
assert cast(int, m0) == 0
assert cast(int, m1) == 7
assert cast(int, m2) == 14
assert cast(BlockData, b3)["number"] == 3
assert cast(BlockData, b4)["number"] == 4

@pytest.mark.asyncio
async def test_batch_requests_raises_for_common_unsupported_methods(
self, async_w3: AsyncWeb3, async_math_contract: "AsyncContract"
Expand Down

0 comments on commit 809c2bd

Please sign in to comment.