Skip to content

Coin selection fix #445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions integration-test/test/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,21 @@ class TestBase:
payment_key_pair = PaymentKeyPair.generate()
stake_key_pair = StakeKeyPair.generate()

@retry(tries=TEST_RETRIES, delay=3)
def assert_output(self, target_address, target_output):
@retry(tries=10, delay=3)
def assert_output(self, target_address, target):
utxos = self.chain_context.utxos(target_address)
found = False

for utxo in utxos:
output = utxo.output
if output == target_output:
found = True
if isinstance(target, UTxO):
if utxo == target:
found = True
if isinstance(target, TransactionOutput):
if utxo.output == target:
found = True
if isinstance(target, TransactionId):
if utxo.input.transaction_id == target:
found = True

assert found, f"Cannot find target UTxO in address: {target_address}"

Expand All @@ -84,4 +90,5 @@ def fund(self, source_address, source_key, target_address, amount=5000000):
print(signed_tx.to_cbor_hex())
print("############### Submitting transaction ###############")
self.chain_context.submit_tx(signed_tx)
self.assert_output(target_address, target_output=output)
target_utxo = UTxO(TransactionInput(signed_tx.id, 0), output)
self.assert_output(target_address, target_utxo)
10 changes: 6 additions & 4 deletions integration-test/test/test_mint.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def load_or_create_key_pair(base_dir, base_name):
@retry(tries=TEST_RETRIES, backoff=1.3, delay=2, jitter=(0, 10))
def test_mint_nft_with_script(self):
address = Address(self.payment_vkey.hash(), network=self.NETWORK)
# Create a collateral
self.fund(address, self.payment_skey, address)

with open("./plutus_scripts/fortytwoV2.plutus", "r") as f:
script_hex = f.read()
Expand Down Expand Up @@ -229,13 +231,13 @@ def test_mint_nft_with_script(self):
nft_output = TransactionOutput(address, Value(min_val, my_nft))
builder.add_output(nft_output)

# Create a collateral
self.fund(address, self.payment_skey, address)

non_nft_utxo = None
for utxo in self.chain_context.utxos(address):
# multi_asset should be empty for collateral utxo
if not utxo.output.amount.multi_asset:
if (
not utxo.output.amount.multi_asset
and utxo.output.amount.coin >= 5000000
):
non_nft_utxo = utxo
break

Expand Down
5 changes: 3 additions & 2 deletions integration-test/test/test_plutus.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,8 @@ def test_plutus_v3_unroll(self):

builder = TransactionBuilder(self.chain_context)
builder.add_input_address(giver_address)
builder.add_output(TransactionOutput(script_address, 50000000, datum=Unit()))
output = TransactionOutput(script_address, 50000000, datum=Unit())
builder.add_output(output)

signed_tx = builder.build_and_sign([self.payment_skey], giver_address)

Expand All @@ -472,7 +473,7 @@ def test_plutus_v3_unroll(self):
print(signed_tx.to_cbor_hex())
print("############### Submitting transaction ###############")
self.chain_context.submit_tx(signed_tx)
time.sleep(3)
time.sleep(6)

# ----------- Taker take ---------------

Expand Down
25 changes: 14 additions & 11 deletions pycardano/coinselection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import random
from copy import deepcopy
from typing import Iterable, List, Optional, Tuple

from pycardano.address import Address
Expand Down Expand Up @@ -36,6 +37,7 @@ def select(
max_input_count: Optional[int] = None,
include_max_fee: Optional[bool] = True,
respect_min_utxo: Optional[bool] = True,
existing_amount: Optional[Value] = None,
) -> Tuple[List[UTxO], Value]:
"""From an input list of UTxOs, select a subset of UTxOs whose sum (including ADA and multi-assets)
is equal to or larger than the sum of a set of outputs.
Expand All @@ -50,6 +52,7 @@ def select(
respect_min_utxo (bool): Respect minimum amount of ADA required to hold a multi-asset bundle in the change.
Defaults to True. If disabled, the selection will not add addition amount of ADA to change even
when the amount is too small to hold a multi-asset bundle.
existing_amount (Value): A starting amount already existed before selection. Defaults to 0.

Returns:
Tuple[List[UTxO], Value]: A tuple containing:
Expand Down Expand Up @@ -83,6 +86,7 @@ def select(
max_input_count: Optional[int] = None,
include_max_fee: Optional[bool] = True,
respect_min_utxo: Optional[bool] = True,
existing_amount: Optional[Value] = None,
) -> Tuple[List[UTxO], Value]:
available: List[UTxO] = sorted(utxos, key=lambda utxo: utxo.output.lovelace)
max_fee = max_tx_fee(context) if include_max_fee else 0
Expand All @@ -91,15 +95,14 @@ def select(
total_requested += o.amount

selected = []
selected_amount = Value()
selected_amount = existing_amount if existing_amount is not None else Value()

while not total_requested <= selected_amount:
if not available:
raise InsufficientUTxOBalanceException("UTxO Balance insufficient!")
to_add = available.pop()
selected.append(to_add)
selected_amount += to_add.output.amount

if max_input_count and len(selected) > max_input_count:
raise MaxInputCountExceededException(
f"Max input count: {max_input_count} exceeded!"
Expand All @@ -108,9 +111,8 @@ def select(
if respect_min_utxo:
change = selected_amount - total_requested
min_change_amount = min_lovelace_post_alonzo(
TransactionOutput(_FAKE_ADDR, change), context
TransactionOutput(_FAKE_ADDR, deepcopy(change)), context
)

if change.coin < min_change_amount:
additional, _ = self.select(
available,
Expand All @@ -127,7 +129,6 @@ def select(
for u in additional:
selected.append(u)
selected_amount += u.output.amount

return selected, selected_amount - total_requested


Expand Down Expand Up @@ -218,10 +219,9 @@ def _find_diff_by_former(a: Value, b: Value) -> int:
else:
policy_id = list(a.multi_asset.keys())[0]
asset_name = list(a.multi_asset[policy_id].keys())[0]
return (
a.multi_asset[policy_id][asset_name]
- b.multi_asset[policy_id][asset_name]
)
return a.multi_asset[policy_id].get(asset_name, 0) - b.multi_asset[
policy_id
].get(asset_name, 0)

def _improve(
self,
Expand Down Expand Up @@ -272,6 +272,7 @@ def select(
max_input_count: Optional[int] = None,
include_max_fee: Optional[bool] = True,
respect_min_utxo: Optional[bool] = True,
existing_amount: Optional[Value] = None,
) -> Tuple[List[UTxO], Value]:
# Shallow copy the list
remaining = list(utxos)
Expand All @@ -281,11 +282,13 @@ def select(
request_sum += o.amount

assets = self._split_by_asset(request_sum)

request_sorted = sorted(assets, key=self._get_single_asset_val, reverse=True)

# Phase 1 - random select
selected: List[UTxO] = []
selected_amount = Value()
selected_amount = existing_amount if existing_amount is not None else Value()

for r in request_sorted:
self._random_select_subset(r, remaining, selected, selected_amount)
if max_input_count and len(selected) > max_input_count:
Expand Down Expand Up @@ -315,7 +318,7 @@ def select(
if respect_min_utxo:
change = selected_amount - request_sum
min_change_amount = min_lovelace_post_alonzo(
TransactionOutput(_FAKE_ADDR, change), context
TransactionOutput(_FAKE_ADDR, deepcopy(change)), context
)

if change.coin < min_change_amount:
Expand Down
38 changes: 37 additions & 1 deletion pycardano/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ def __le__(self, other: Asset) -> bool:
return False
return True

def __lt__(self, other: Asset):
return self <= other and self != other

def __ge__(self, other: Asset) -> bool:
for n in other:
if n not in self or self[n] < other[n]:
return False
return True

def __gt__(self, other: Asset) -> bool:
return self >= other and self != other

@classmethod
@limit_primitive_type(dict)
def from_primitive(cls: Type[DictBase], value: dict) -> DictBase:
Expand Down Expand Up @@ -191,12 +203,28 @@ def __eq__(self, other):
return False
return True

def __ge__(self, other: MultiAsset) -> bool:
for n in other:
if n not in self:
return False
if not self[n] >= other[n]:
return False
return True

def __gt__(self, other: MultiAsset) -> bool:
return self >= other and self != other

def __le__(self, other: MultiAsset):
for p in self:
if p not in other or not self[p] <= other[p]:
if p not in other:
return False
if not self[p] <= other[p]:
return False
return True

def __lt__(self, other: MultiAsset):
return self <= other and self != other

def filter(
self, criteria=Callable[[ScriptHash, AssetName, int], bool]
) -> MultiAsset:
Expand Down Expand Up @@ -297,6 +325,14 @@ def __le__(self, other: Union[Value, int]):
def __lt__(self, other: Union[Value, int]):
return self <= other and self != other

def __ge__(self, other: Union[Value, int]):
if isinstance(other, int):
other = Value(other)
return self.coin >= other.coin and self.multi_asset >= other.multi_asset

def __gt__(self, other: Union[Value, int]):
return self >= other and self != other

def to_shallow_primitive(self):
if self.multi_asset:
return super().to_shallow_primitive()
Expand Down
20 changes: 16 additions & 4 deletions pycardano/txbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class TransactionBuilder:
context: ChainContext

utxo_selectors: List[UTxOSelector] = field(
default_factory=lambda: [RandomImproveMultiAsset(), LargestFirstSelector()]
default_factory=lambda: [LargestFirstSelector(), RandomImproveMultiAsset()]
)

execution_memory_buffer: float = 0.2
Expand Down Expand Up @@ -641,8 +641,13 @@ def _calc_change(

provided.coin -= self._get_total_key_deposit()
provided.coin -= self._get_total_proposal_deposit()

if not requested < provided:
provided.multi_asset.filter(
lambda p, n, v: p in requested.multi_asset and n in requested.multi_asset[p]
)
if (
provided.coin < requested.coin
or requested.multi_asset > provided.multi_asset
):
raise InvalidTransactionException(
f"The input UTxOs cannot cover the transaction outputs and tx fee. \n"
f"Inputs: {inputs} \n"
Expand Down Expand Up @@ -733,6 +738,7 @@ def _merge_changes(changes):

# Set fee to max
self.fee = self._estimate_fee()

changes = self._calc_change(
self.fee,
self.inputs,
Expand Down Expand Up @@ -1344,10 +1350,15 @@ def build(

unfulfilled_amount = requested_amount - trimmed_selected_amount

remaining = trimmed_selected_amount - requested_amount
remaining.multi_asset = remaining.multi_asset.filter(lambda p, n, v: v > 0)
remaining.coin = max(0, remaining.coin)

if change_address is not None and not can_merge_change:
# If change address is provided and remainder is smaller than minimum ADA required in change,
# we need to select additional UTxOs available from the address
if unfulfilled_amount.coin < 0:

unfulfilled_amount.coin = max(
0,
unfulfilled_amount.coin
Expand Down Expand Up @@ -1401,11 +1412,12 @@ def build(
self.context,
include_max_fee=False,
respect_min_utxo=not can_merge_change,
existing_amount=remaining,
)

for s in selected:
selected_amount += s.output.amount
selected_utxos.append(s)

break

except UTxOSelectionException as e:
Expand Down
Loading