diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index a9f2807a81..89aee224f8 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -624,11 +624,35 @@ def get_receipt( ) hex_hash = HexBytes(txn_hash) + txn = {} + if transaction := kwargs.get("transaction"): + # perf: If called `send_transaction()`, we should already have the data! + txn = ( + transaction + if isinstance(transaction, dict) + else transaction.model_dump(by_alias=True, mode="json") + ) + + private = kwargs.get("private") + try: receipt_data = dict( self.web3.eth.wait_for_transaction_receipt(hex_hash, timeout=timeout) ) except TimeExhausted as err: + # Since private transactions can take longer, + # return a partial receipt instead of throwing a TimeExhausted error. + if private: + # Return with a partial receipt + data = { + "block_number": -1, + "required_confirmations": required_confirmations, + "txn_hash": txn_hash, + "status": TransactionStatusEnum.NO_ERROR, + **txn, + } + receipt = self._create_receipt(**data) + return receipt msg_str = str(err) if f"HexBytes('{txn_hash}')" in msg_str: msg_str = msg_str.replace(f"HexBytes('{txn_hash}')", f"'{txn_hash}'") @@ -641,18 +665,11 @@ def get_receipt( network_config: dict = ecosystem_config.get(self.network.name, {}) max_retries = network_config.get("max_get_transaction_retries", DEFAULT_MAX_RETRIES_TX) - if transaction := kwargs.get("transaction"): - # perf: If called `send_transaction()`, we should already have the data! - txn = ( - transaction - if isinstance(transaction, dict) - else transaction.model_dump(by_alias=True, mode="json") - ) + if transaction: if "effectiveGasPrice" in receipt_data: receipt_data["gasPrice"] = receipt_data["effectiveGasPrice"] else: - txn = {} for attempt in range(max_retries): try: txn = dict(self.web3.eth.get_transaction(HexStr(txn_hash))) diff --git a/tests/functional/test_provider.py b/tests/functional/test_provider.py index e30ae84bbd..ed8f8ebe03 100644 --- a/tests/functional/test_provider.py +++ b/tests/functional/test_provider.py @@ -9,7 +9,7 @@ from eth_utils import ValidationError, to_hex from hexbytes import HexBytes from requests import HTTPError -from web3.exceptions import ContractPanicError +from web3.exceptions import ContractPanicError, TimeExhausted from ape import convert from ape.exceptions import ( @@ -128,6 +128,52 @@ def test_get_receipt_exists_with_timeout(eth_tester_provider, vyper_contract_ins assert receipt_from_provider.receiver == vyper_contract_instance.address +def test_get_receipt_ignores_timeout_when_private( + eth_tester_provider, mock_web3, vyper_contract_instance, owner +): + receipt_from_invoke = vyper_contract_instance.setNumber(889, sender=owner) + real_web3 = eth_tester_provider._web3 + + mock_web3.eth.wait_for_transaction_receipt.side_effect = TimeExhausted + eth_tester_provider._web3 = mock_web3 + try: + receipt_from_provider = eth_tester_provider.get_receipt( + receipt_from_invoke.txn_hash, timeout=5, private=True + ) + + finally: + eth_tester_provider._web3 = real_web3 + + assert receipt_from_provider.txn_hash == receipt_from_invoke.txn_hash + assert not receipt_from_provider.confirmed + + +def test_get_receipt_passes_receipt_when_private( + eth_tester_provider, mock_web3, vyper_contract_instance, owner +): + receipt_from_invoke = vyper_contract_instance.setNumber(890, sender=owner) + real_web3 = eth_tester_provider._web3 + + mock_web3.eth.wait_for_transaction_receipt.side_effect = TimeExhausted + eth_tester_provider._web3 = mock_web3 + try: + receipt_from_provider = eth_tester_provider.get_receipt( + receipt_from_invoke.txn_hash, + timeout=5, + private=True, + transaction=receipt_from_invoke.transaction, + ) + + finally: + eth_tester_provider._web3 = real_web3 + + assert receipt_from_provider.txn_hash == receipt_from_invoke.txn_hash + assert not receipt_from_provider.confirmed + + # Receiver comes from the transaction. + assert receipt_from_provider.receiver == vyper_contract_instance.address + + def test_get_contracts_logs_all_logs(chain, contract_instance, owner, eth_tester_provider): start_block = chain.blocks.height stop_block = start_block + 100