Skip to content

Commit

Permalink
Merge branch 'main' into docs/snek
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Sep 20, 2024
2 parents 5c135eb + eac9d5e commit 528151c
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 9 deletions.
33 changes: 25 additions & 8 deletions src/ape_ethereum/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,11 +624,35 @@ def get_receipt(
)
hex_hash = HexBytes(txn_hash)

txn = {}
if transaction := kwargs.get("transaction"):
# perf: If called `send_transaction()`, we should already have the data!
txn = (
transaction
if isinstance(transaction, dict)
else transaction.model_dump(by_alias=True, mode="json")
)

private = kwargs.get("private")

try:
receipt_data = dict(
self.web3.eth.wait_for_transaction_receipt(hex_hash, timeout=timeout)
)
except TimeExhausted as err:
# Since private transactions can take longer,
# return a partial receipt instead of throwing a TimeExhausted error.
if private:
# Return with a partial receipt
data = {
"block_number": -1,
"required_confirmations": required_confirmations,
"txn_hash": txn_hash,
"status": TransactionStatusEnum.NO_ERROR,
**txn,
}
receipt = self._create_receipt(**data)
return receipt
msg_str = str(err)
if f"HexBytes('{txn_hash}')" in msg_str:
msg_str = msg_str.replace(f"HexBytes('{txn_hash}')", f"'{txn_hash}'")
Expand All @@ -641,18 +665,11 @@ def get_receipt(
network_config: dict = ecosystem_config.get(self.network.name, {})
max_retries = network_config.get("max_get_transaction_retries", DEFAULT_MAX_RETRIES_TX)

if transaction := kwargs.get("transaction"):
# perf: If called `send_transaction()`, we should already have the data!
txn = (
transaction
if isinstance(transaction, dict)
else transaction.model_dump(by_alias=True, mode="json")
)
if transaction:
if "effectiveGasPrice" in receipt_data:
receipt_data["gasPrice"] = receipt_data["effectiveGasPrice"]

else:
txn = {}
for attempt in range(max_retries):
try:
txn = dict(self.web3.eth.get_transaction(HexStr(txn_hash)))
Expand Down
48 changes: 47 additions & 1 deletion tests/functional/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from eth_utils import ValidationError, to_hex
from hexbytes import HexBytes
from requests import HTTPError
from web3.exceptions import ContractPanicError
from web3.exceptions import ContractPanicError, TimeExhausted

from ape import convert
from ape.exceptions import (
Expand Down Expand Up @@ -128,6 +128,52 @@ def test_get_receipt_exists_with_timeout(eth_tester_provider, vyper_contract_ins
assert receipt_from_provider.receiver == vyper_contract_instance.address


def test_get_receipt_ignores_timeout_when_private(
eth_tester_provider, mock_web3, vyper_contract_instance, owner
):
receipt_from_invoke = vyper_contract_instance.setNumber(889, sender=owner)
real_web3 = eth_tester_provider._web3

mock_web3.eth.wait_for_transaction_receipt.side_effect = TimeExhausted
eth_tester_provider._web3 = mock_web3
try:
receipt_from_provider = eth_tester_provider.get_receipt(
receipt_from_invoke.txn_hash, timeout=5, private=True
)

finally:
eth_tester_provider._web3 = real_web3

assert receipt_from_provider.txn_hash == receipt_from_invoke.txn_hash
assert not receipt_from_provider.confirmed


def test_get_receipt_passes_receipt_when_private(
eth_tester_provider, mock_web3, vyper_contract_instance, owner
):
receipt_from_invoke = vyper_contract_instance.setNumber(890, sender=owner)
real_web3 = eth_tester_provider._web3

mock_web3.eth.wait_for_transaction_receipt.side_effect = TimeExhausted
eth_tester_provider._web3 = mock_web3
try:
receipt_from_provider = eth_tester_provider.get_receipt(
receipt_from_invoke.txn_hash,
timeout=5,
private=True,
transaction=receipt_from_invoke.transaction,
)

finally:
eth_tester_provider._web3 = real_web3

assert receipt_from_provider.txn_hash == receipt_from_invoke.txn_hash
assert not receipt_from_provider.confirmed

# Receiver comes from the transaction.
assert receipt_from_provider.receiver == vyper_contract_instance.address


def test_get_contracts_logs_all_logs(chain, contract_instance, owner, eth_tester_provider):
start_block = chain.blocks.height
stop_block = start_block + 100
Expand Down

0 comments on commit 528151c

Please sign in to comment.