Skip to content

Commit 446a0e8

Browse files
authored
Merge pull request #445 from tacmota/coin_selection_fix
Coin selection fix
2 parents d250316 + 793e861 commit 446a0e8

File tree

9 files changed

+221
-42
lines changed

9 files changed

+221
-42
lines changed

integration-test/test/base.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,21 @@ class TestBase:
5757
payment_key_pair = PaymentKeyPair.generate()
5858
stake_key_pair = StakeKeyPair.generate()
5959

60-
@retry(tries=TEST_RETRIES, delay=3)
61-
def assert_output(self, target_address, target_output):
60+
@retry(tries=10, delay=3)
61+
def assert_output(self, target_address, target):
6262
utxos = self.chain_context.utxos(target_address)
6363
found = False
6464

6565
for utxo in utxos:
66-
output = utxo.output
67-
if output == target_output:
68-
found = True
66+
if isinstance(target, UTxO):
67+
if utxo == target:
68+
found = True
69+
if isinstance(target, TransactionOutput):
70+
if utxo.output == target:
71+
found = True
72+
if isinstance(target, TransactionId):
73+
if utxo.input.transaction_id == target:
74+
found = True
6975

7076
assert found, f"Cannot find target UTxO in address: {target_address}"
7177

@@ -84,4 +90,5 @@ def fund(self, source_address, source_key, target_address, amount=5000000):
8490
print(signed_tx.to_cbor_hex())
8591
print("############### Submitting transaction ###############")
8692
self.chain_context.submit_tx(signed_tx)
87-
self.assert_output(target_address, target_output=output)
93+
target_utxo = UTxO(TransactionInput(signed_tx.id, 0), output)
94+
self.assert_output(target_address, target_utxo)

integration-test/test/test_mint.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def load_or_create_key_pair(base_dir, base_name):
169169
@retry(tries=TEST_RETRIES, backoff=1.3, delay=2, jitter=(0, 10))
170170
def test_mint_nft_with_script(self):
171171
address = Address(self.payment_vkey.hash(), network=self.NETWORK)
172+
# Create a collateral
173+
self.fund(address, self.payment_skey, address)
172174

173175
with open("./plutus_scripts/fortytwoV2.plutus", "r") as f:
174176
script_hex = f.read()
@@ -229,13 +231,13 @@ def test_mint_nft_with_script(self):
229231
nft_output = TransactionOutput(address, Value(min_val, my_nft))
230232
builder.add_output(nft_output)
231233

232-
# Create a collateral
233-
self.fund(address, self.payment_skey, address)
234-
235234
non_nft_utxo = None
236235
for utxo in self.chain_context.utxos(address):
237236
# multi_asset should be empty for collateral utxo
238-
if not utxo.output.amount.multi_asset:
237+
if (
238+
not utxo.output.amount.multi_asset
239+
and utxo.output.amount.coin >= 5000000
240+
):
239241
non_nft_utxo = utxo
240242
break
241243

integration-test/test/test_plutus.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,8 @@ def test_plutus_v3_unroll(self):
463463

464464
builder = TransactionBuilder(self.chain_context)
465465
builder.add_input_address(giver_address)
466-
builder.add_output(TransactionOutput(script_address, 50000000, datum=Unit()))
466+
output = TransactionOutput(script_address, 50000000, datum=Unit())
467+
builder.add_output(output)
467468

468469
signed_tx = builder.build_and_sign([self.payment_skey], giver_address)
469470

@@ -472,7 +473,7 @@ def test_plutus_v3_unroll(self):
472473
print(signed_tx.to_cbor_hex())
473474
print("############### Submitting transaction ###############")
474475
self.chain_context.submit_tx(signed_tx)
475-
time.sleep(3)
476+
time.sleep(6)
476477

477478
# ----------- Taker take ---------------
478479

pycardano/coinselection.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import random
6+
from copy import deepcopy
67
from typing import Iterable, List, Optional, Tuple
78

89
from pycardano.address import Address
@@ -36,6 +37,7 @@ def select(
3637
max_input_count: Optional[int] = None,
3738
include_max_fee: Optional[bool] = True,
3839
respect_min_utxo: Optional[bool] = True,
40+
existing_amount: Optional[Value] = None,
3941
) -> Tuple[List[UTxO], Value]:
4042
"""From an input list of UTxOs, select a subset of UTxOs whose sum (including ADA and multi-assets)
4143
is equal to or larger than the sum of a set of outputs.
@@ -50,6 +52,7 @@ def select(
5052
respect_min_utxo (bool): Respect minimum amount of ADA required to hold a multi-asset bundle in the change.
5153
Defaults to True. If disabled, the selection will not add addition amount of ADA to change even
5254
when the amount is too small to hold a multi-asset bundle.
55+
existing_amount (Value): A starting amount already existed before selection. Defaults to 0.
5356
5457
Returns:
5558
Tuple[List[UTxO], Value]: A tuple containing:
@@ -83,6 +86,7 @@ def select(
8386
max_input_count: Optional[int] = None,
8487
include_max_fee: Optional[bool] = True,
8588
respect_min_utxo: Optional[bool] = True,
89+
existing_amount: Optional[Value] = None,
8690
) -> Tuple[List[UTxO], Value]:
8791
available: List[UTxO] = sorted(utxos, key=lambda utxo: utxo.output.lovelace)
8892
max_fee = max_tx_fee(context) if include_max_fee else 0
@@ -91,15 +95,14 @@ def select(
9195
total_requested += o.amount
9296

9397
selected = []
94-
selected_amount = Value()
98+
selected_amount = existing_amount if existing_amount is not None else Value()
9599

96100
while not total_requested <= selected_amount:
97101
if not available:
98102
raise InsufficientUTxOBalanceException("UTxO Balance insufficient!")
99103
to_add = available.pop()
100104
selected.append(to_add)
101105
selected_amount += to_add.output.amount
102-
103106
if max_input_count and len(selected) > max_input_count:
104107
raise MaxInputCountExceededException(
105108
f"Max input count: {max_input_count} exceeded!"
@@ -108,9 +111,8 @@ def select(
108111
if respect_min_utxo:
109112
change = selected_amount - total_requested
110113
min_change_amount = min_lovelace_post_alonzo(
111-
TransactionOutput(_FAKE_ADDR, change), context
114+
TransactionOutput(_FAKE_ADDR, deepcopy(change)), context
112115
)
113-
114116
if change.coin < min_change_amount:
115117
additional, _ = self.select(
116118
available,
@@ -127,7 +129,6 @@ def select(
127129
for u in additional:
128130
selected.append(u)
129131
selected_amount += u.output.amount
130-
131132
return selected, selected_amount - total_requested
132133

133134

@@ -218,10 +219,9 @@ def _find_diff_by_former(a: Value, b: Value) -> int:
218219
else:
219220
policy_id = list(a.multi_asset.keys())[0]
220221
asset_name = list(a.multi_asset[policy_id].keys())[0]
221-
return (
222-
a.multi_asset[policy_id][asset_name]
223-
- b.multi_asset[policy_id][asset_name]
224-
)
222+
return a.multi_asset[policy_id].get(asset_name, 0) - b.multi_asset[
223+
policy_id
224+
].get(asset_name, 0)
225225

226226
def _improve(
227227
self,
@@ -272,6 +272,7 @@ def select(
272272
max_input_count: Optional[int] = None,
273273
include_max_fee: Optional[bool] = True,
274274
respect_min_utxo: Optional[bool] = True,
275+
existing_amount: Optional[Value] = None,
275276
) -> Tuple[List[UTxO], Value]:
276277
# Shallow copy the list
277278
remaining = list(utxos)
@@ -281,11 +282,13 @@ def select(
281282
request_sum += o.amount
282283

283284
assets = self._split_by_asset(request_sum)
285+
284286
request_sorted = sorted(assets, key=self._get_single_asset_val, reverse=True)
285287

286288
# Phase 1 - random select
287289
selected: List[UTxO] = []
288-
selected_amount = Value()
290+
selected_amount = existing_amount if existing_amount is not None else Value()
291+
289292
for r in request_sorted:
290293
self._random_select_subset(r, remaining, selected, selected_amount)
291294
if max_input_count and len(selected) > max_input_count:
@@ -315,7 +318,7 @@ def select(
315318
if respect_min_utxo:
316319
change = selected_amount - request_sum
317320
min_change_amount = min_lovelace_post_alonzo(
318-
TransactionOutput(_FAKE_ADDR, change), context
321+
TransactionOutput(_FAKE_ADDR, deepcopy(change)), context
319322
)
320323

321324
if change.coin < min_change_amount:

pycardano/transaction.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,18 @@ def __le__(self, other: Asset) -> bool:
131131
return False
132132
return True
133133

134+
def __lt__(self, other: Asset):
135+
return self <= other and self != other
136+
137+
def __ge__(self, other: Asset) -> bool:
138+
for n in other:
139+
if n not in self or self[n] < other[n]:
140+
return False
141+
return True
142+
143+
def __gt__(self, other: Asset) -> bool:
144+
return self >= other and self != other
145+
134146
@classmethod
135147
@limit_primitive_type(dict)
136148
def from_primitive(cls: Type[DictBase], value: dict) -> DictBase:
@@ -191,12 +203,28 @@ def __eq__(self, other):
191203
return False
192204
return True
193205

206+
def __ge__(self, other: MultiAsset) -> bool:
207+
for n in other:
208+
if n not in self:
209+
return False
210+
if not self[n] >= other[n]:
211+
return False
212+
return True
213+
214+
def __gt__(self, other: MultiAsset) -> bool:
215+
return self >= other and self != other
216+
194217
def __le__(self, other: MultiAsset):
195218
for p in self:
196-
if p not in other or not self[p] <= other[p]:
219+
if p not in other:
220+
return False
221+
if not self[p] <= other[p]:
197222
return False
198223
return True
199224

225+
def __lt__(self, other: MultiAsset):
226+
return self <= other and self != other
227+
200228
def filter(
201229
self, criteria=Callable[[ScriptHash, AssetName, int], bool]
202230
) -> MultiAsset:
@@ -297,6 +325,14 @@ def __le__(self, other: Union[Value, int]):
297325
def __lt__(self, other: Union[Value, int]):
298326
return self <= other and self != other
299327

328+
def __ge__(self, other: Union[Value, int]):
329+
if isinstance(other, int):
330+
other = Value(other)
331+
return self.coin >= other.coin and self.multi_asset >= other.multi_asset
332+
333+
def __gt__(self, other: Union[Value, int]):
334+
return self >= other and self != other
335+
300336
def to_shallow_primitive(self):
301337
if self.multi_asset:
302338
return super().to_shallow_primitive()

pycardano/txbuilder.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class TransactionBuilder:
105105
context: ChainContext
106106

107107
utxo_selectors: List[UTxOSelector] = field(
108-
default_factory=lambda: [RandomImproveMultiAsset(), LargestFirstSelector()]
108+
default_factory=lambda: [LargestFirstSelector(), RandomImproveMultiAsset()]
109109
)
110110

111111
execution_memory_buffer: float = 0.2
@@ -641,8 +641,13 @@ def _calc_change(
641641

642642
provided.coin -= self._get_total_key_deposit()
643643
provided.coin -= self._get_total_proposal_deposit()
644-
645-
if not requested < provided:
644+
provided.multi_asset.filter(
645+
lambda p, n, v: p in requested.multi_asset and n in requested.multi_asset[p]
646+
)
647+
if (
648+
provided.coin < requested.coin
649+
or requested.multi_asset > provided.multi_asset
650+
):
646651
raise InvalidTransactionException(
647652
f"The input UTxOs cannot cover the transaction outputs and tx fee. \n"
648653
f"Inputs: {inputs} \n"
@@ -733,6 +738,7 @@ def _merge_changes(changes):
733738

734739
# Set fee to max
735740
self.fee = self._estimate_fee()
741+
736742
changes = self._calc_change(
737743
self.fee,
738744
self.inputs,
@@ -1344,10 +1350,15 @@ def build(
13441350

13451351
unfulfilled_amount = requested_amount - trimmed_selected_amount
13461352

1353+
remaining = trimmed_selected_amount - requested_amount
1354+
remaining.multi_asset = remaining.multi_asset.filter(lambda p, n, v: v > 0)
1355+
remaining.coin = max(0, remaining.coin)
1356+
13471357
if change_address is not None and not can_merge_change:
13481358
# If change address is provided and remainder is smaller than minimum ADA required in change,
13491359
# we need to select additional UTxOs available from the address
13501360
if unfulfilled_amount.coin < 0:
1361+
13511362
unfulfilled_amount.coin = max(
13521363
0,
13531364
unfulfilled_amount.coin
@@ -1401,11 +1412,12 @@ def build(
14011412
self.context,
14021413
include_max_fee=False,
14031414
respect_min_utxo=not can_merge_change,
1415+
existing_amount=remaining,
14041416
)
1417+
14051418
for s in selected:
14061419
selected_amount += s.output.amount
14071420
selected_utxos.append(s)
1408-
14091421
break
14101422

14111423
except UTxOSelectionException as e:

0 commit comments

Comments
 (0)