diff --git a/src/descriptor/policy.rs b/src/descriptor/policy.rs index 6923ce3fb..2063aeb77 100644 --- a/src/descriptor/policy.rs +++ b/src/descriptor/policy.rs @@ -659,11 +659,11 @@ impl Policy { (0..*threshold).collect() } SatisfiableItem::Multisig { keys, .. } => (0..keys.len()).collect(), - _ => vec![], + _ => HashSet::new(), }; - let selected = match path.get(&self.id) { - Some(arr) => arr, - _ => &default, + let selected: HashSet<_> = match path.get(&self.id) { + Some(arr) => arr.iter().copied().collect(), + _ => default, }; match &self.item { @@ -671,14 +671,24 @@ impl Policy { let mapped_req = items .iter() .map(|i| i.get_condition(path)) - .collect::, _>>()?; + .collect::>(); // if all the requirements are null we don't care about `selected` because there // are no requirements - if mapped_req.iter().all(Condition::is_null) { + if mapped_req + .iter() + .all(|cond| matches!(cond, Ok(c) if c.is_null())) + { return Ok(Condition::default()); } + // make sure all the indexes in the `selected` list are within range + for index in &selected { + if *index >= items.len() { + return Err(PolicyError::IndexOutOfRange(*index)); + } + } + // if we have something, make sure we have enough items. note that the user can set // an empty value for this step in case of n-of-n, because `selected` is set to all // the elements above @@ -687,23 +697,18 @@ impl Policy { } // check the selected items, see if there are conflicting requirements - let mut requirements = Condition::default(); - for item_index in selected { - requirements = requirements.merge( - mapped_req - .get(*item_index) - .ok_or(PolicyError::IndexOutOfRange(*item_index))?, - )?; - } - - Ok(requirements) + mapped_req + .into_iter() + .enumerate() + .filter(|(index, _)| selected.contains(index)) + .try_fold(Condition::default(), |acc, (_, cond)| acc.merge(&cond?)) } SatisfiableItem::Multisig { keys, threshold } => { if selected.len() < *threshold { return Err(PolicyError::NotEnoughItemsSelected(self.id.clone())); } - if let Some(item) = selected.iter().find(|i| **i >= keys.len()) { - return Err(PolicyError::IndexOutOfRange(*item)); + if let Some(item) = selected.into_iter().find(|&i| i >= keys.len()) { + return Err(PolicyError::IndexOutOfRange(item)); } Ok(Condition::default()) diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index f6fd26b49..55a2ef223 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -2936,6 +2936,25 @@ pub(crate) mod test { assert_eq!(psbt.unsigned_tx.input[0].sequence, Sequence(144)); } + #[test] + fn test_create_tx_policy_path_ignored_subtree_with_csv() { + let (wallet, _, _) = get_funded_wallet("wsh(or_d(pk(cRjo6jqfVNP33HhSS76UhXETZsGTZYx8FMFvR9kpbtCSV1PmdZdu),or_i(and_v(v:pkh(cVpPVruEDdmutPzisEsYvtST1usBR3ntr8pXSyt6D2YYqXRyPcFW),older(30)),and_v(v:pkh(cMnkdebixpXMPfkcNEjjGin7s94hiehAH4mLbYkZoh9KSiNNmqC8),older(90)))))"); + + let external_policy = wallet.policies(KeychainKind::External).unwrap().unwrap(); + let root_id = external_policy.id; + // child #0 is pk(cRjo6jqfVNP33HhSS76UhXETZsGTZYx8FMFvR9kpbtCSV1PmdZdu) + let path = vec![(root_id, vec![0])].into_iter().collect(); + + let addr = Address::from_str("2N1Ffz3WaNzbeLFBb51xyFMHYSEUXcbiSoX").unwrap(); + let mut builder = wallet.build_tx(); + builder + .add_recipient(addr.script_pubkey(), 30_000) + .policy_path(path, KeychainKind::External); + let (psbt, _) = builder.finish().unwrap(); + + assert_eq!(psbt.unsigned_tx.input[0].sequence, Sequence(0xFFFFFFFE)); + } + #[test] fn test_create_tx_global_xpubs_with_origin() { use bitcoin::hashes::hex::FromHex;