diff --git a/payjoin/src/input_type.rs b/payjoin/src/input_type.rs deleted file mode 100644 index bc8d11f9..00000000 --- a/payjoin/src/input_type.rs +++ /dev/null @@ -1,303 +0,0 @@ -use std::fmt; - -use bitcoin::blockdata::script::{Instruction, Instructions, Script}; -use bitcoin::blockdata::transaction::TxOut; -use bitcoin::psbt::Input as PsbtInput; - -/// Takes the script out of script_sig assuming script_sig signs p2sh script -fn unpack_p2sh(script_sig: &Script) -> Option<&Script> { - match script_sig.instructions().last()?.ok()? { - Instruction::PushBytes(bytes) => Some(Script::from_bytes(bytes.as_bytes())), - Instruction::Op(_) => None, - } -} - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub(crate) enum InputType { - P2Pk, - P2Pkh, - P2Sh, - SegWitV0 { ty: SegWitV0Type, nested: bool }, - Taproot, -} - -#[cfg(feature = "v2")] -impl serde::Serialize for InputType { - fn serialize(&self, serializer: S) -> Result { - use InputType::*; - - match self { - P2Pk => serializer.serialize_str("P2PK"), - P2Pkh => serializer.serialize_str("P2PKH"), - P2Sh => serializer.serialize_str("P2SH"), - SegWitV0 { ty, nested } => - serializer.serialize_str(&format!("SegWitV0: type={}, nested={}", ty, nested)), - Taproot => serializer.serialize_str("Taproot"), - } - } -} - -#[cfg(feature = "v2")] -impl<'de> serde::Deserialize<'de> for InputType { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - use InputType::*; - - let s = String::deserialize(deserializer)?; - if let Some(rest) = s.strip_prefix("SegWitV0: ") { - let parts: Vec<&str> = rest.split(", ").collect(); - if parts.len() != 2 { - return Err(serde::de::Error::custom("invalid format for SegWitV0")); - } - log::debug!("parts: {:?}", parts); - let ty = match parts[0].strip_prefix("type=") { - Some("pubkey") => SegWitV0Type::Pubkey, - Some("script") => SegWitV0Type::Script, - _ => return Err(serde::de::Error::custom("invalid SegWitV0 type")), - }; - - let nested = match parts[1].strip_prefix("nested=") { - Some("true") => true, - Some("false") => false, - _ => return Err(serde::de::Error::custom("invalid SegWitV0 nested value")), - }; - - Ok(SegWitV0 { ty, nested }) - } else { - match s.as_str() { - "P2PK" => Ok(P2Pk), - "P2PKH" => Ok(P2Pkh), - "P2SH" => Ok(P2Sh), - "Taproot" => Ok(Taproot), - _ => Err(serde::de::Error::custom("invalid type")), - } - } - } -} - -impl InputType { - pub(crate) fn from_spent_input( - txout: &TxOut, - txin: &PsbtInput, - ) -> Result { - if txout.script_pubkey.is_p2pk() { - Ok(InputType::P2Pk) - } else if txout.script_pubkey.is_p2pkh() { - Ok(InputType::P2Pkh) - } else if txout.script_pubkey.is_p2sh() { - match &txin - .final_script_sig - .as_ref() - .and_then(|script_buf| unpack_p2sh(script_buf.as_ref())) - { - Some(script) if script.is_witness_program() => - Self::segwit_from_script(script, true), - Some(_) => Ok(InputType::P2Sh), - None => Err(InputTypeError::NotFinalized), - } - } else if txout.script_pubkey.is_witness_program() { - Self::segwit_from_script(&txout.script_pubkey, false) - } else { - Err(InputTypeError::UnknownInputType) - } - } - - fn segwit_from_script(script: &Script, nested: bool) -> Result { - let mut instructions = script.instructions(); - let witness_version = instructions - .next() - .ok_or(InputTypeError::UnknownInputType)? - .map_err(|_| InputTypeError::UnknownInputType)?; - match witness_version { - Instruction::PushBytes(bytes) if bytes.is_empty() => - Ok(InputType::SegWitV0 { ty: instructions.try_into()?, nested }), - Instruction::Op(bitcoin::blockdata::opcodes::all::OP_PUSHNUM_1) => { - let instruction = instructions - .next() - .ok_or(InputTypeError::UnknownInputType)? - .map_err(|_| InputTypeError::UnknownInputType)?; - match instruction { - Instruction::PushBytes(bytes) if bytes.len() == 32 => Ok(InputType::Taproot), - Instruction::PushBytes(_) | Instruction::Op(_) => - Err(InputTypeError::UnknownInputType), - } - } - _ => Err(InputTypeError::UnknownInputType), - } - } - - pub(crate) fn expected_input_weight(&self) -> bitcoin::Weight { - use InputType::*; - - bitcoin::Weight::from_non_witness_data_size(match self { - P2Pk => unimplemented!(), - P2Pkh => 148, - P2Sh => unimplemented!(), - SegWitV0 { ty: SegWitV0Type::Pubkey, nested: false } => 68, - SegWitV0 { ty: SegWitV0Type::Pubkey, nested: true } => 91, - SegWitV0 { ty: SegWitV0Type::Script, nested: _ } => unimplemented!(), - Taproot => 58, - }) - } -} - -impl fmt::Display for InputType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - InputType::P2Pk => write!(f, "P2PK"), - InputType::P2Pkh => write!(f, "P2PKH"), - InputType::P2Sh => write!(f, "P2SH"), - InputType::SegWitV0 { ty, nested } => - write!(f, "SegWitV0: type={}, nested={}", ty, nested), - InputType::Taproot => write!(f, "Taproot"), - } - } -} - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub(crate) enum SegWitV0Type { - Pubkey, - Script, -} - -impl TryFrom> for SegWitV0Type { - type Error = InputTypeError; - - fn try_from( - mut instructions: bitcoin::blockdata::script::Instructions<'_>, - ) -> Result { - let push = instructions - .next() - .ok_or(InputTypeError::UnknownInputType)? - .map_err(|_| InputTypeError::UnknownInputType)?; - if instructions.next().is_some() { - return Err(InputTypeError::UnknownInputType); - } - match push { - Instruction::PushBytes(bytes) if bytes.len() == 20 => Ok(SegWitV0Type::Pubkey), - Instruction::PushBytes(bytes) if bytes.len() == 32 => Ok(SegWitV0Type::Script), - Instruction::PushBytes(_) | Instruction::Op(_) => Err(InputTypeError::UnknownInputType), - } - } -} - -impl fmt::Display for SegWitV0Type { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - SegWitV0Type::Pubkey => write!(f, "pubkey"), - SegWitV0Type::Script => write!(f, "script"), - } - } -} - -#[derive(Debug)] -pub(crate) enum InputTypeError { - UnknownInputType, - NotFinalized, -} - -impl fmt::Display for InputTypeError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - InputTypeError::UnknownInputType => write!(f, "unknown input type"), - InputTypeError::NotFinalized => write!(f, "input is not finalized"), - } - } -} - -impl std::error::Error for InputTypeError {} - -#[cfg(test)] -mod tests { - use bitcoin::psbt::Input as PsbtInput; - use bitcoin::script::PushBytesBuf; - use bitcoin::{Amount, PublicKey, ScriptBuf}; - - use super::*; - - static FORTY_TWO: Amount = Amount::from_sat(42); - - fn wrap_p2sh_script(script: &Script) -> ScriptBuf { - let bytes: PushBytesBuf = script - .to_bytes() - .try_into() - .expect("script should be valid from ScriptBuf::new_v0_p2wsh"); - bitcoin::blockdata::script::Builder::new().push_slice(bytes.as_ref()).into_script() - } - - #[test] - fn test_p2pk() { - let input_type = InputType::from_spent_input(&TxOut { script_pubkey: ScriptBuf::new_p2pk(&PublicKey::from_slice(b"\x02\x50\x86\x3A\xD6\x4A\x87\xAE\x8A\x2F\xE8\x3C\x1A\xF1\xA8\x40\x3C\xB5\x3F\x53\xE4\x86\xD8\x51\x1D\xAD\x8A\x04\x88\x7E\x5B\x23\x52").unwrap()), value: FORTY_TWO, }, &Default::default()).unwrap(); - assert_eq!(input_type, InputType::P2Pk); - } - - #[test] - fn test_p2pkh() { - let input_type = InputType::from_spent_input(&TxOut { script_pubkey: ScriptBuf::new_p2pkh(&PublicKey::from_slice(b"\x02\x50\x86\x3A\xD6\x4A\x87\xAE\x8A\x2F\xE8\x3C\x1A\xF1\xA8\x40\x3C\xB5\x3F\x53\xE4\x86\xD8\x51\x1D\xAD\x8A\x04\x88\x7E\x5B\x23\x52").unwrap().pubkey_hash()), value: FORTY_TWO, }, &Default::default()).unwrap(); - assert_eq!(input_type, InputType::P2Pkh); - } - - #[test] - fn test_p2sh() { - let script = ScriptBuf::new_op_return(&[42]); - let input_type = InputType::from_spent_input( - &TxOut { script_pubkey: ScriptBuf::new_p2sh(&script.script_hash()), value: FORTY_TWO }, - &PsbtInput { final_script_sig: Some(script), ..Default::default() }, - ) - .unwrap(); - assert_eq!(input_type, InputType::P2Sh); - } - - #[test] - fn test_p2wpkh() { - let input_type = InputType::from_spent_input(&TxOut { script_pubkey: ScriptBuf::new_p2wpkh(&PublicKey::from_slice(b"\x02\x50\x86\x3A\xD6\x4A\x87\xAE\x8A\x2F\xE8\x3C\x1A\xF1\xA8\x40\x3C\xB5\x3F\x53\xE4\x86\xD8\x51\x1D\xAD\x8A\x04\x88\x7E\x5B\x23\x52").unwrap().wpubkey_hash().expect("WTF, the key is uncompressed")), value: FORTY_TWO, }, &Default::default()).unwrap(); - assert_eq!(input_type, InputType::SegWitV0 { ty: SegWitV0Type::Pubkey, nested: false }); - } - - #[test] - fn test_p2wsh() { - let script = ScriptBuf::new_op_return(&[42]); - let input_type = InputType::from_spent_input( - &TxOut { - script_pubkey: ScriptBuf::new_p2wsh(&script.wscript_hash()), - value: FORTY_TWO, - }, - &PsbtInput { final_script_sig: Some(script), ..Default::default() }, - ) - .unwrap(); - assert_eq!(input_type, InputType::SegWitV0 { ty: SegWitV0Type::Script, nested: false }); - } - - #[test] - fn test_p2sh_p2wpkh() { - let segwit_script = ScriptBuf::new_p2wpkh(&PublicKey::from_slice(b"\x02\x50\x86\x3A\xD6\x4A\x87\xAE\x8A\x2F\xE8\x3C\x1A\xF1\xA8\x40\x3C\xB5\x3F\x53\xE4\x86\xD8\x51\x1D\xAD\x8A\x04\x88\x7E\x5B\x23\x52").unwrap().wpubkey_hash().expect("WTF, the key is uncompressed")); - let segwit_script_hash = segwit_script.script_hash(); - let script_sig = wrap_p2sh_script(&segwit_script); - - let input_type = InputType::from_spent_input( - &TxOut { script_pubkey: ScriptBuf::new_p2sh(&segwit_script_hash), value: FORTY_TWO }, - &PsbtInput { final_script_sig: Some(script_sig), ..Default::default() }, - ) - .unwrap(); - assert_eq!(input_type, InputType::SegWitV0 { ty: SegWitV0Type::Pubkey, nested: true }); - } - - #[test] - fn test_p2sh_p2wsh() { - let script = ScriptBuf::new_op_return(&[42]); - let segwit_script = ScriptBuf::new_p2wsh(&script.wscript_hash()); - let segwit_script_hash = segwit_script.script_hash(); - let script_sig = wrap_p2sh_script(&segwit_script); - - let input_type = InputType::from_spent_input( - &TxOut { script_pubkey: ScriptBuf::new_p2sh(&segwit_script_hash), value: FORTY_TWO }, - &PsbtInput { final_script_sig: Some(script_sig), ..Default::default() }, - ) - .unwrap(); - assert_eq!(input_type, InputType::SegWitV0 { ty: SegWitV0Type::Script, nested: true }); - } - - // TODO: test p2tr -} diff --git a/payjoin/src/lib.rs b/payjoin/src/lib.rs index fa2bc000..b1bb0345 100644 --- a/payjoin/src/lib.rs +++ b/payjoin/src/lib.rs @@ -35,8 +35,6 @@ pub use v2::OhttpKeys; #[cfg(feature = "io")] pub mod io; -#[cfg(any(feature = "send", feature = "receive"))] -pub(crate) mod input_type; #[cfg(any(feature = "send", feature = "receive"))] pub(crate) mod psbt; #[cfg(any(feature = "send", all(feature = "receive", feature = "v2")))] @@ -45,8 +43,6 @@ mod request; pub use request::*; mod uri; -#[cfg(any(feature = "send", feature = "receive"))] -pub(crate) mod weight; #[cfg(feature = "base64")] pub use bitcoin::base64; diff --git a/payjoin/src/output_type.rs b/payjoin/src/output_type.rs deleted file mode 100644 index e59d60da..00000000 --- a/payjoin/src/output_type.rs +++ /dev/null @@ -1,25 +0,0 @@ -crate::weight::Weight; - -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -#[non_exhaustive] -pub enum OutputType { - P2Pkh, - P2Sh, - SegWitV0 { ty: SegWitV0Type, nested: bool }, - Taproot, -} - -impl OutputType { - pub(crate) fn output_only_weight(&self) -> Weight { - use OutputType::*; - - match self { - P2Pkh => Weight::from_non_witness_data_size(1 /* OP_DUP */ + 1 /* OP_HASH160 */ + 1 /* OP_PUSH */ + 160 / 8 /* ripemd160 hash size */ + 1 /* OP_EQUALVERIFY */ + 1 /* OP_CHECKSIG */), - P2Sh => Weight::from_non_witness_data_size(1 /* OP_HASH160 */ + 1 /* OP_PUSH */ + 160 / 8 /* ripemd160 hash size */ + 1 /* OP_EQUAL */), - SegWitV0 { ty: _, nested: true } => Weight::from_non_witness_data_size(1 /* OP_HASH160 */ + 1 /* OP_PUSH */ + 160 / 8 /* ripemd160 hash size */ + 1 /* OP_EQUAL */), - SegWitV0 { ty: SegWitV0Type::Pubkey, nested: false } => Weight::from_non_witness_data_size(1 /* OP_PUSH0 */ + 1 /* OP_PUSH */ + 160 / 8 /* ripemd160 hash size */), - SegWitV0 { ty: SegWitV0Type::Script, nested: false } => Weight::from_non_witness_data_size(1 /* OP_PUSH0 */ + 1 /* OP_PUSH */ + 256 / 8 /* ripemd160 hash size */), - Taproot => Weight::from_non_witness_data_size(1 /* OP_PUSH0 */ + 1 /* OP_PUSH */ + 256 / 8 /* ripemd160 hash size */), - } - } -} diff --git a/payjoin/src/psbt.rs b/payjoin/src/psbt.rs index c8fa043e..2bc04cca 100644 --- a/payjoin/src/psbt.rs +++ b/payjoin/src/psbt.rs @@ -3,8 +3,11 @@ use std::collections::BTreeMap; use std::fmt; +use bitcoin::address::FromScriptError; +use bitcoin::blockdata::script::Instruction; use bitcoin::psbt::Psbt; -use bitcoin::{bip32, psbt, TxIn, TxOut}; +use bitcoin::transaction::InputWeightPrediction; +use bitcoin::{bip32, psbt, Address, AddressType, Network, Script, TxIn, TxOut, Weight}; #[derive(Debug)] pub(crate) enum InconsistentPsbt { @@ -37,7 +40,6 @@ pub(crate) trait PsbtExt: Sized { /// thing for outputs. fn validate(self) -> Result; fn validate_input_utxos(&self, treat_missing_as_error: bool) -> Result<(), PsbtInputsError>; - fn calculate_fee(&self) -> bitcoin::Amount; } impl PsbtExt for Psbt { @@ -89,24 +91,21 @@ impl PsbtExt for Psbt { .map_err(|error| PsbtInputsError { index, error }) }) } +} - fn calculate_fee(&self) -> bitcoin::Amount { - let mut total_outputs = bitcoin::Amount::ZERO; - let mut total_inputs = bitcoin::Amount::ZERO; - - for output in &self.unsigned_tx.output { - total_outputs += output.value; - } - - for input in self.input_pairs() { - total_inputs += input.previous_txout().unwrap().value; - } - log::debug!(" total_inputs: {}", total_inputs); - log::debug!("- total_outputs: {}", total_outputs); - total_inputs - total_outputs +/// Gets redeemScript from the script_sig following BIP16 rules regarding P2SH spending. +fn redeem_script(script_sig: &Script) -> Option<&Script> { + match script_sig.instructions().last()?.ok()? { + Instruction::PushBytes(bytes) => Some(Script::from_bytes(bytes.as_bytes())), + Instruction::Op(_) => None, } } +// input script: 0x160014{20-byte-key-hash} = 23 bytes +// witness: = 72, 33 bytes +// https://github.com/bitcoin/bips/blob/master/bip-0141.mediawiki#p2wpkh-nested-in-bip16-p2sh +const NESTED_P2WPKH_MAX: InputWeightPrediction = InputWeightPrediction::from_slice(23, &[72, 33]); + pub(crate) struct InputPair<'a> { pub txin: &'a TxIn, pub psbtin: &'a psbt::Input, @@ -180,6 +179,43 @@ impl<'a> InputPair<'a> { (Some(_), Some(_)) => Err(PsbtInputError::UnequalTxid), } } + + pub fn address_type(&self) -> Result { + let txo = self.previous_txout()?; + // HACK: Network doesn't matter for our use case of only getting the address type + // but is required in the `from_script` interface. Hardcoded to mainnet. + Address::from_script(&txo.script_pubkey, Network::Bitcoin)? + .address_type() + .ok_or(AddressTypeError::UnknownAddressType) + } + + pub fn expected_input_weight(&self) -> Result { + use bitcoin::AddressType::*; + + // Get the input weight prediction corresponding to spending an output of this address type + let iwp = match self.address_type()? { + P2pkh => Ok(InputWeightPrediction::P2PKH_COMPRESSED_MAX), + P2sh => + match self.psbtin.final_script_sig.as_ref().and_then(|s| redeem_script(s.as_ref())) + { + // Nested segwit p2wpkh. + Some(script) if script.is_witness_program() && script.is_p2wpkh() => + Ok(NESTED_P2WPKH_MAX), + // Other script or witness program. + Some(_) => Err(InputWeightError::NotSupported), + // No redeem script provided. Cannot determine the script type. + None => Err(InputWeightError::NotFinalized), + }, + P2wpkh => Ok(InputWeightPrediction::P2WPKH_MAX), + P2wsh => Err(InputWeightError::NotSupported), + P2tr => Ok(InputWeightPrediction::P2TR_KEY_DEFAULT_SIGHASH), + _ => Err(AddressTypeError::UnknownAddressType.into()), + }?; + + // Lengths of txid, index and sequence: (32, 4, 4). + let input_weight = iwp.weight() + Weight::from_non_witness_data_size(32 + 4 + 4); + Ok(input_weight) + } } #[derive(Debug)] @@ -248,3 +284,68 @@ impl fmt::Display for PsbtInputsError { impl std::error::Error for PsbtInputsError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { Some(&self.error) } } + +#[derive(Debug)] +pub(crate) enum AddressTypeError { + PrevTxOut(PrevTxOutError), + InvalidScript(FromScriptError), + UnknownAddressType, +} + +impl fmt::Display for AddressTypeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::PrevTxOut(_) => write!(f, "invalid previous transaction output"), + Self::InvalidScript(_) => write!(f, "invalid script"), + Self::UnknownAddressType => write!(f, "unknown address type"), + } + } +} + +impl std::error::Error for AddressTypeError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::PrevTxOut(error) => Some(error), + Self::InvalidScript(error) => Some(error), + Self::UnknownAddressType => None, + } + } +} + +impl From for AddressTypeError { + fn from(value: PrevTxOutError) -> Self { Self::PrevTxOut(value) } +} + +impl From for AddressTypeError { + fn from(value: FromScriptError) -> Self { Self::InvalidScript(value) } +} + +#[derive(Debug)] +pub(crate) enum InputWeightError { + AddressType(AddressTypeError), + NotFinalized, + NotSupported, +} + +impl fmt::Display for InputWeightError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::AddressType(_) => write!(f, "invalid address type"), + Self::NotFinalized => write!(f, "input not finalized"), + Self::NotSupported => write!(f, "weight prediction not supported"), + } + } +} + +impl std::error::Error for InputWeightError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::AddressType(error) => Some(error), + Self::NotFinalized => None, + Self::NotSupported => None, + } + } +} +impl From for InputWeightError { + fn from(value: AddressTypeError) -> Self { Self::AddressType(value) } +} diff --git a/payjoin/src/receive/error.rs b/payjoin/src/receive/error.rs index 2056bcf8..f1ba34d4 100644 --- a/payjoin/src/receive/error.rs +++ b/payjoin/src/receive/error.rs @@ -73,9 +73,11 @@ pub(crate) enum InternalRequestError { /// The sender is trying to spend the receiver input InputOwned(bitcoin::ScriptBuf), /// The original psbt has mixed input address types that could harm privacy - MixedInputScripts(crate::input_type::InputType, crate::input_type::InputType), - /// Unrecognized input type - InputType(crate::input_type::InputTypeError), + MixedInputScripts(bitcoin::AddressType, bitcoin::AddressType), + /// The address type could not be determined + AddressType(crate::psbt::AddressTypeError), + /// The expected input weight cannot be determined + InputWeight(crate::psbt::InputWeightError), /// Original PSBT input has been seen before. Only automatic receivers, aka "interactive" in the spec /// look out for these to prevent probing attacks. InputSeen(bitcoin::OutPoint), @@ -155,8 +157,10 @@ impl fmt::Display for RequestError { "original-psbt-rejected", &format!("Mixed input scripts: {}; {}.", type_a, type_b), ), - InternalRequestError::InputType(e) => - write_error(f, "original-psbt-rejected", &format!("Input Type Error: {}.", e)), + InternalRequestError::AddressType(e) => + write_error(f, "original-psbt-rejected", &format!("AddressType Error: {}", e)), + InternalRequestError::InputWeight(e) => + write_error(f, "original-psbt-rejected", &format!("InputWeight Error: {}", e)), InternalRequestError::InputSeen(_) => write_error(f, "original-psbt-rejected", "The receiver rejected the original PSBT."), #[cfg(feature = "v2")] @@ -196,6 +200,8 @@ impl std::error::Error for RequestError { InternalRequestError::SenderParams(e) => Some(e), InternalRequestError::InconsistentPsbt(e) => Some(e), InternalRequestError::PrevTxOut(e) => Some(e), + InternalRequestError::AddressType(e) => Some(e), + InternalRequestError::InputWeight(e) => Some(e), #[cfg(feature = "v2")] InternalRequestError::ParsePsbt(e) => Some(e), #[cfg(feature = "v2")] diff --git a/payjoin/src/receive/mod.rs b/payjoin/src/receive/mod.rs index c3bcdcd5..8fe524bc 100644 --- a/payjoin/src/receive/mod.rs +++ b/payjoin/src/receive/mod.rs @@ -45,7 +45,6 @@ use error::{ }; use optional_parameters::Params; -use crate::input_type::InputType; use crate::psbt::PsbtExt; pub trait Headers { @@ -228,16 +227,10 @@ impl MaybeMixedInputScripts { let input_scripts = self .psbt .input_pairs() - .scan(&mut err, |err, input| match input.previous_txout() { - Ok(txout) => match InputType::from_spent_input(txout, input.psbtin) { - Ok(input_script) => Some(input_script), - Err(e) => { - **err = Err(RequestError::from(InternalRequestError::InputType(e))); - None - } - }, + .scan(&mut err, |err, input| match input.address_type() { + Ok(address_type) => Some(address_type), Err(e) => { - **err = Err(RequestError::from(InternalRequestError::PrevTxOut(e))); + **err = Err(RequestError::from(InternalRequestError::AddressType(e))); None } }) @@ -757,13 +750,11 @@ impl ProvisionalProposal { .input_pairs() .next() .ok_or(InternalRequestError::OriginalPsbtNotBroadcastable)?; - // Calculate the additional fee contribution - let txo = input_pair.previous_txout().map_err(InternalRequestError::PrevTxOut)?; - let input_type = InputType::from_spent_input(txo, &self.payjoin_psbt.inputs[0]) - .map_err(InternalRequestError::InputType)?; + // Calculate the additional weight contribution let input_count = self.payjoin_psbt.inputs.len() - self.original_psbt.inputs.len(); log::trace!("input_count : {}", input_count); - let weight_per_input = input_type.expected_input_weight(); + let weight_per_input = + input_pair.expected_input_weight().map_err(InternalRequestError::InputWeight)?; log::trace!("weight_per_input : {}", weight_per_input); let contribution_weight = weight_per_input * input_count as u64; log::trace!("contribution_weight: {}", contribution_weight); @@ -1012,6 +1003,57 @@ mod test { assert!(psbt.is_err(), "Payjoin exceeds receiver fee preference and should error"); } + #[test] + fn additional_input_weight_matches_known_weight() { + // All expected input weights pulled from: + // https://bitcoin.stackexchange.com/questions/84004/how-do-virtual-size-stripped-size-and-raw-size-compare-between-legacy-address-f#84006 + // Input weight for a single P2PKH (legacy) receiver input + let p2pkh_proposal = ProvisionalProposal { + original_psbt: Psbt::from_str("cHNidP8BAHECAAAAAb2qhegy47hqffxh/UH5Qjd/G3sBH6cW2QSXZ86nbY3nAAAAAAD9////AhXKBSoBAAAAFgAU4TiLFD14YbpddFVrZa3+Zmz96yQQJwAAAAAAABYAFB4zA2o+5MsNRT/j+0twLi5VbwO9AAAAAAABAIcCAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA/////wMBSgD/////AgDyBSoBAAAAGXapFGUxpU6cGldVpjUm9rV2B+jTlphDiKwAAAAAAAAAACZqJKohqe3i9hw/cdHe/T+pmd+jaVN1XGkGiXmZYrSL69g2l06M+QAAAAABB2pHMEQCIGsOxO/bBv20bd68sBnEU3cxHR8OxEcUroL3ENhhjtN3AiB+9yWuBGKXu41hcfO4KP7IyLLEYc6j8hGowmAlCPCMPAEhA6WNSN4CqJ9F+42YKPlIFN0wJw7qawWbdelGRMkAbBRnACICAsdIAjsfMLKgfL2J9rfIa8yKdO1BOpSGRIFbFMBdTsc9GE4roNNUAACAAQAAgAAAAIABAAAAAAAAAAAA").unwrap(), + payjoin_psbt: Psbt::from_str("cHNidP8BAJoCAAAAAtTRxwAtk38fRMP3ffdKkIi5r+Ss9AjaO8qEv+eQ/ho3AAAAAAD9////vaqF6DLjuGp9/GH9QflCN38bewEfpxbZBJdnzqdtjecAAAAAAP3///8CgckFKgEAAAAWABThOIsUPXhhul10VWtlrf5mbP3rJBAZBioBAAAAFgAUiDIby0wSbj1kv3MlvwoEKw3vNZUAAAAAAAEAhwIAAAABAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/////AwFoAP////8CAPIFKgEAAAAZdqkUPXhu3I6D9R0wUpvTvvUm+VGNcNuIrAAAAAAAAAAAJmokqiGp7eL2HD9x0d79P6mZ36NpU3VcaQaJeZlitIvr2DaXToz5AAAAAAEBIgDyBSoBAAAAGXapFD14btyOg/UdMFKb0771JvlRjXDbiKwBB2pHMEQCIGzKy8QfhHoAY0+LZCpQ7ZOjyyXqaSBnr89hH3Eg/xsGAiB3n8hPRuXCX/iWtURfXoJNUFu3sLeQVFf1dDFCZPN0dAEhA8rTfrwcq6dEBSNOrUfNb8+dm7q77vCtfdOmWx0HfajRAAEAhwIAAAABAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/////AwFKAP////8CAPIFKgEAAAAZdqkUZTGlTpwaV1WmNSb2tXYH6NOWmEOIrAAAAAAAAAAAJmokqiGp7eL2HD9x0d79P6mZ36NpU3VcaQaJeZlitIvr2DaXToz5AAAAAAAAAA==").unwrap(), + params: Params::default(), + change_vout: 0 + }; + assert_eq!( + p2pkh_proposal.additional_input_weight().expect("should calculate input weight"), + Weight::from_wu(592) + ); + + // Input weight for a single nested P2WPKH (nested segwit) receiver input + let nested_p2wpkh_proposal = ProvisionalProposal { + original_psbt: Psbt::from_str("cHNidP8BAHECAAAAAX57euL5j6xOst5JB/e/gp58RihmmpxXpsc2hEKKcVFkAAAAAAD9////AhAnAAAAAAAAFgAUtjrU62JOASAnPQ4e30wBM/Exk7ZM0QKVAAAAABYAFL6xh6gjSHmznJnPMbolG7wbGuwtAAAAAAABAIYCAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA/////wQCqgAA/////wIA+QKVAAAAABepFOyefe4gjXozL4pzi5vcPrjMeCJwhwAAAAAAAAAAJmokqiGp7eL2HD9x0d79P6mZ36NpU3VcaQaJeZlitIvr2DaXToz5AAAAAAEBIAD5ApUAAAAAF6kU7J597iCNejMvinOLm9w+uMx4InCHAQcXFgAUd6fhKfAd+JIJGpIGkMfMpjd/26sBCGsCRzBEAiBaCDgIrTw5bB1VZrB8RPycgKGNPw/YS6P+psUyxOUwgwIgbJkcbHlMoZxG7vBOVWnQQWayDTSvub6L20dDo1R5SS8BIQK2GCTydo2dJXC6C5wcSKzQ2pCsSygXa0+cMlJrRRnKtwAAIgIC0VgJvaoW2/lbq5atJhxfcgVzs6/gnpafsJHbz+ei484YDOqFk1QAAIABAACAAAAAgAEAAAACAAAAAA==").unwrap(), + payjoin_psbt: Psbt::from_str("cHNidP8BAJoCAAAAAn57euL5j6xOst5JB/e/gp58RihmmpxXpsc2hEKKcVFkAAAAAAD9////VinByqmVDo3wPNB9LnNELJoJ0g+hOdWiTSXzWEUVtiAAAAAAAP3///8CEBkGKgEAAAAWABSZUDn7eqenP01ziWRBnTCrpwwD6vHQApUAAAAAFgAUvrGHqCNIebOcmc8xuiUbvBsa7C0AAAAAAAEAhgIAAAABAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/////BAKqAAD/////AgD5ApUAAAAAF6kU7J597iCNejMvinOLm9w+uMx4InCHAAAAAAAAAAAmaiSqIant4vYcP3HR3v0/qZnfo2lTdVxpBol5mWK0i+vYNpdOjPkAAAAAAQEgAPkClQAAAAAXqRTsnn3uII16My+Kc4ub3D64zHgicIcAAQCEAgAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAP////8CYAD/////AgDyBSoBAAAAF6kUx/+ZHBBBZ+6E/US1N2Oe7IDItXiHAAAAAAAAAAAmaiSqIant4vYcP3HR3v0/qZnfo2lTdVxpBol5mWK0i+vYNpdOjPkAAAAAAQEgAPIFKgEAAAAXqRTH/5kcEEFn7oT9RLU3Y57sgMi1eIcBBxcWABRDVkPBhZHK7tVQqp2uWqQC/GGTCgEIawJHMEQCIEv8/8VpUz0dK4MCcVzS7zoyt+hPRvWwLskZBuaurnFiAiBIuyt1IRaHqFSspDbjDNM607nrDQz4lmDnekNqMNn07AEhAp1Ol7vKvG2Oi8RSrsb7uSPTET83/YXuknx63PhfCG/zAAAA").unwrap(), + params: Params::default(), + change_vout: 0 + }; + // Currently nested segwit is not supported, see https://github.com/payjoin/rust-payjoin/issues/358 + assert!(nested_p2wpkh_proposal.additional_input_weight().is_err()); + + // Input weight for a single P2WPKH (native segwit) receiver input + let p2wpkh_proposal = ProvisionalProposal { + original_psbt: Psbt::from_str("cHNidP8BAHECAAAAASom13OiXZIr3bKk+LtUndZJYqdHQQU8dMs1FZ93IctIAAAAAAD9////AmPKBSoBAAAAFgAU6H98YM9NE1laARQ/t9/90nFraf4QJwAAAAAAABYAFBPJFmYuJBsrIaBBp9ur98pMSKxhAAAAAAABAIQCAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA/////wMBWwD/////AgDyBSoBAAAAFgAUjTJXmC73n+URSNdfgbS6Oa6JyQYAAAAAAAAAACZqJKohqe3i9hw/cdHe/T+pmd+jaVN1XGkGiXmZYrSL69g2l06M+QAAAAABAR8A8gUqAQAAABYAFI0yV5gu95/lEUjXX4G0ujmuickGAQhrAkcwRAIgUqbHS0difIGTRwN56z2/EiqLQFWerfJspyjuwsGSCXcCIA3IRTu8FVgniU5E4gecAMeegVnlTbTVfFyusWhQ2kVVASEDChVRm26KidHNWLdCLBTq5jspGJr+AJyyMqmUkvPkwFsAIgIDeBqmRB3ESjFWIp+wUXn/adGZU3kqWGjdkcnKpk8bAyUY94v8N1QAAIABAACAAAAAgAEAAAAAAAAAAAA=").unwrap(), + payjoin_psbt: Psbt::from_str("cHNidP8BAJoCAAAAAiom13OiXZIr3bKk+LtUndZJYqdHQQU8dMs1FZ93IctIAAAAAAD9////NG21aH8Vat3thaVmPvWDV/lvRmymFHeePcfUjlyngHIAAAAAAP3///8CH8oFKgEAAAAWABTof3xgz00TWVoBFD+33/3ScWtp/hAZBioBAAAAFgAU1mbnqky3bMxfmm0OgFaQCAs5fsoAAAAAAAEAhAIAAAABAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD/////AwFbAP////8CAPIFKgEAAAAWABSNMleYLvef5RFI11+BtLo5ronJBgAAAAAAAAAAJmokqiGp7eL2HD9x0d79P6mZ36NpU3VcaQaJeZlitIvr2DaXToz5AAAAAAEBHwDyBSoBAAAAFgAUjTJXmC73n+URSNdfgbS6Oa6JyQYAAQCEAgAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAP////8DAWcA/////wIA8gUqAQAAABYAFJFtkfHTt3y1EDMaN6CFjjNWtpCRAAAAAAAAAAAmaiSqIant4vYcP3HR3v0/qZnfo2lTdVxpBol5mWK0i+vYNpdOjPkAAAAAAQEfAPIFKgEAAAAWABSRbZHx07d8tRAzGjeghY4zVraQkQEIawJHMEQCIDTC49IB9AnItqd8zy5RDc05f2ApBAfJ5x4zYfj3bsD2AiAQvvSt5ipScHcUwdlYB9vFnEi68hmh55M5a5e+oWvxMAEhAqErVSVulFb97/r5KQryOS1Xgghff8R7AOuEnvnmslQ5AAAA").unwrap(), + params: Params::default(), + change_vout: 0 + }; + assert_eq!( + p2wpkh_proposal.additional_input_weight().expect("should calculate input weight"), + Weight::from_wu(272) + ); + + // Input weight for a single P2TR (taproot) receiver input + let p2tr_proposal = ProvisionalProposal { + original_psbt: Psbt::from_str("cHNidP8BAHECAAAAAU/CHxd1oi9Lq1xOD2GnHe0hsQdGJ2mkpYkmeasTj+w1AAAAAAD9////Am3KBSoBAAAAFgAUqJL/PDPnHeihhNhukTz8QEdZbZAQJwAAAAAAABYAFInyO0NQF7YR22Sm0YTPGm6yf19YAAAAAAABASsA8gUqAQAAACJRIGOPekNKFs9ASLj3FdlCLiou/jdPUegJGzlA111A80MAAQhCAUC3zX8eSeL8+bAo6xO0cpon83UsJdttiuwfMn/pBwub82rzMsoS6HZNXzg7hfcB3p1uj8JmqsBkZwm8k6fnU2peACICA+u+FjwmhEgWdjhEQbO49D0NG8iCYUoqhlfsj0LN7hiRGOcVI65UAACAAQAAgAAAAIABAAAAAAAAAAAA").unwrap(), + payjoin_psbt: Psbt::from_str("cHNidP8BAJoCAAAAAk/CHxd1oi9Lq1xOD2GnHe0hsQdGJ2mkpYkmeasTj+w1AAAAAAD9////Fz+ELsYp/55j6+Jl2unG9sGvpHTiSyzSORBvtu1GEB4AAAAAAP3///8CM8oFKgEAAAAWABSokv88M+cd6KGE2G6RPPxAR1ltkBAZBioBAAAAFgAU68J5imRcKy3g5JCT3bEoP9IXEn0AAAAAAAEBKwDyBSoBAAAAIlEgY496Q0oWz0BIuPcV2UIuKi7+N09R6AkbOUDXXUDzQwAAAQErAPIFKgEAAAAiUSCfbbX+FHJbzC71eEFLsMjDouMJbu8ogeR0eNoNxMM9CwEIQwFBeyOLUebV/YwpaLTpLIaTXaSiPS7Dn6o39X4nlUzQLfb6YyvCAsLA5GTxo+Zb0NUINZ8DaRyUWknOpU/Jzuwn2gEAAAA=").unwrap(), + params: Params::default(), + change_vout: 0 + }; + assert_eq!( + p2tr_proposal.additional_input_weight().expect("should calculate input weight"), + Weight::from_wu(230) + ); + } + #[test] fn test_interleave_shuffle() { let mut original1 = vec![1, 2, 3]; diff --git a/payjoin/src/send/error.rs b/payjoin/src/send/error.rs index 0c955199..114ff962 100644 --- a/payjoin/src/send/error.rs +++ b/payjoin/src/send/error.rs @@ -2,9 +2,7 @@ use std::fmt::{self, Display}; use bitcoin::locktime::absolute::LockTime; use bitcoin::transaction::Version; -use bitcoin::Sequence; - -use crate::input_type::{InputType, InputTypeError}; +use bitcoin::{AddressType, Sequence}; /// Error that may occur when the response from receiver is malformed. /// @@ -19,8 +17,8 @@ pub struct ValidationError { pub(crate) enum InternalValidationError { Parse, Io(std::io::Error), - InvalidInputType(InputTypeError), - InvalidProposedInput(crate::psbt::PrevTxOutError), + InvalidAddressType(crate::psbt::AddressTypeError), + NoInputs, VersionsDontMatch { proposed: Version, original: Version, @@ -43,8 +41,8 @@ pub(crate) enum InternalValidationError { ReceiverTxinMissingUtxoInfo, MixedSequence, MixedInputTypes { - proposed: InputType, - original: InputType, + proposed: AddressType, + original: AddressType, }, MissingOrShuffledInputs, TxOutContainsKeyPaths, @@ -52,26 +50,27 @@ pub(crate) enum InternalValidationError { DisallowedOutputSubstitution, OutputValueDecreased, MissingOrShuffledOutputs, - Inflation, AbsoluteFeeDecreased, PayeeTookContributedFee, FeeContributionPaysOutputSizeIncrease, FeeRateBelowMinimum, + Psbt(bitcoin::psbt::Error), #[cfg(feature = "v2")] Hpke(crate::v2::HpkeError), #[cfg(feature = "v2")] OhttpEncapsulation(crate::v2::OhttpEncapsulationError), #[cfg(feature = "v2")] - Psbt(bitcoin::psbt::Error), - #[cfg(feature = "v2")] UnexpectedStatusCode, } impl From for ValidationError { fn from(value: InternalValidationError) -> Self { ValidationError { internal: value } } } -impl From for InternalValidationError { - fn from(value: InputTypeError) -> Self { InternalValidationError::InvalidInputType(value) } + +impl From for InternalValidationError { + fn from(value: crate::psbt::AddressTypeError) -> Self { + InternalValidationError::InvalidAddressType(value) + } } impl fmt::Display for ValidationError { @@ -81,8 +80,8 @@ impl fmt::Display for ValidationError { match &self.internal { Parse => write!(f, "couldn't decode as PSBT or JSON",), Io(e) => write!(f, "couldn't read PSBT: {}", e), - InvalidInputType(e) => write!(f, "invalid transaction input type: {}", e), - InvalidProposedInput(e) => write!(f, "invalid proposed transaction input: {}", e), + InvalidAddressType(e) => write!(f, "invalid input address type: {}", e), + NoInputs => write!(f, "PSBT doesn't have any inputs"), VersionsDontMatch { proposed, original, } => write!(f, "proposed transaction version {} doesn't match the original {}", proposed, original), LockTimesDontMatch { proposed, original, } => write!(f, "proposed transaction lock time {} doesn't match the original {}", proposed, original), SenderTxinSequenceChanged { proposed, original, } => write!(f, "proposed transaction sequence number {} doesn't match the original {}", proposed, original), @@ -102,18 +101,16 @@ impl fmt::Display for ValidationError { DisallowedOutputSubstitution => write!(f, "the receiver change output despite it being disallowed"), OutputValueDecreased => write!(f, "the amount in our non-fee output was decreased"), MissingOrShuffledOutputs => write!(f, "proposed transaction is missing outputs of the sender or they are shuffled"), - Inflation => write!(f, "proposed transaction is attempting inflation"), AbsoluteFeeDecreased => write!(f, "abslute fee of proposed transaction is lower than original"), PayeeTookContributedFee => write!(f, "payee tried to take fee contribution for himself"), FeeContributionPaysOutputSizeIncrease => write!(f, "fee contribution pays for additional outputs"), FeeRateBelowMinimum => write!(f, "the fee rate of proposed transaction is below minimum"), + Psbt(e) => write!(f, "psbt error: {}", e), #[cfg(feature = "v2")] Hpke(e) => write!(f, "v2 error: {}", e), #[cfg(feature = "v2")] OhttpEncapsulation(e) => write!(f, "Ohttp encapsulation error: {}", e), #[cfg(feature = "v2")] - Psbt(e) => write!(f, "psbt error: {}", e), - #[cfg(feature = "v2")] UnexpectedStatusCode => write!(f, "unexpected status code"), } } @@ -126,8 +123,8 @@ impl std::error::Error for ValidationError { match &self.internal { Parse => None, Io(error) => Some(error), - InvalidInputType(error) => Some(error), - InvalidProposedInput(error) => Some(error), + InvalidAddressType(error) => Some(error), + NoInputs => None, VersionsDontMatch { proposed: _, original: _ } => None, LockTimesDontMatch { proposed: _, original: _ } => None, SenderTxinSequenceChanged { proposed: _, original: _ } => None, @@ -147,18 +144,16 @@ impl std::error::Error for ValidationError { DisallowedOutputSubstitution => None, OutputValueDecreased => None, MissingOrShuffledOutputs => None, - Inflation => None, AbsoluteFeeDecreased => None, PayeeTookContributedFee => None, FeeContributionPaysOutputSizeIncrease => None, FeeRateBelowMinimum => None, + Psbt(error) => Some(error), #[cfg(feature = "v2")] Hpke(error) => Some(error), #[cfg(feature = "v2")] OhttpEncapsulation(error) => Some(error), #[cfg(feature = "v2")] - Psbt(error) => Some(error), - #[cfg(feature = "v2")] UnexpectedStatusCode => None, } } @@ -186,8 +181,8 @@ pub(crate) enum InternalCreateRequestError { ChangeIndexOutOfBounds, ChangeIndexPointsAtPayee, Url(url::ParseError), - PrevTxOut(crate::psbt::PrevTxOutError), - InputType(crate::input_type::InputTypeError), + AddressType(crate::psbt::AddressTypeError), + InputWeight(crate::psbt::InputWeightError), #[cfg(feature = "v2")] Hpke(crate::v2::HpkeError), #[cfg(feature = "v2")] @@ -217,8 +212,8 @@ impl fmt::Display for CreateRequestError { ChangeIndexOutOfBounds => write!(f, "fee output index is points out of bounds"), ChangeIndexPointsAtPayee => write!(f, "fee output index is points at output belonging to the payee"), Url(e) => write!(f, "cannot parse url: {:#?}", e), - PrevTxOut(e) => write!(f, "invalid previous transaction output: {}", e), - InputType(e) => write!(f, "invalid input type: {}", e), + AddressType(e) => write!(f, "can not determine input address type: {}", e), + InputWeight(e) => write!(f, "can not determine expected input weight: {}", e), #[cfg(feature = "v2")] Hpke(e) => write!(f, "v2 error: {}", e), #[cfg(feature = "v2")] @@ -250,8 +245,8 @@ impl std::error::Error for CreateRequestError { ChangeIndexOutOfBounds => None, ChangeIndexPointsAtPayee => None, Url(error) => Some(error), - PrevTxOut(error) => Some(error), - InputType(error) => Some(error), + AddressType(error) => Some(error), + InputWeight(error) => Some(error), #[cfg(feature = "v2")] Hpke(error) => Some(error), #[cfg(feature = "v2")] @@ -270,6 +265,12 @@ impl From for CreateRequestError { fn from(value: InternalCreateRequestError) -> Self { CreateRequestError(value) } } +impl From for CreateRequestError { + fn from(value: crate::psbt::AddressTypeError) -> Self { + CreateRequestError(InternalCreateRequestError::AddressType(value)) + } +} + #[cfg(feature = "v2")] impl From for CreateRequestError { fn from(value: ParseSubdirectoryError) -> Self { diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 2e4b0abf..99f10584 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -27,19 +27,17 @@ use std::str::FromStr; use bitcoin::psbt::Psbt; -use bitcoin::{FeeRate, Script, ScriptBuf, Sequence, TxOut, Weight}; +use bitcoin::{Amount, FeeRate, Script, ScriptBuf, TxOut, Weight}; pub use error::{CreateRequestError, ResponseError, ValidationError}; pub(crate) use error::{InternalCreateRequestError, InternalValidationError}; #[cfg(feature = "v2")] use serde::{Deserialize, Serialize}; use url::Url; -use crate::input_type::InputType; -use crate::psbt::PsbtExt; +use crate::psbt::{InputPair, PsbtExt}; use crate::request::Request; #[cfg(feature = "v2")] use crate::v2::{HpkePublicKey, HpkeSecretKey}; -use crate::weight::{varint_size, ComputeWeight}; use crate::PjUri; // See usize casts @@ -125,26 +123,24 @@ impl<'a> RequestBuilder<'a> { .find(|(_, txo)| payout_scripts.all(|script| script != txo.script_pubkey)) .map(|(i, txo)| (i, txo.value)) { - let input_types = self - .psbt - .input_pairs() - .map(|input| { - let txo = - input.previous_txout().map_err(InternalCreateRequestError::PrevTxOut)?; - InputType::from_spent_input(txo, input.psbtin) - .map_err(InternalCreateRequestError::InputType) - }) - .collect::, InternalCreateRequestError>>()?; - - let first_type = input_types.first().ok_or(InternalCreateRequestError::NoInputs)?; - // use cheapest default if mixed input types - let mut input_vsize = InputType::Taproot.expected_input_weight(); - // Check if all inputs are the same type - if input_types.iter().all(|input_type| input_type == first_type) { - input_vsize = first_type.expected_input_weight(); + let mut input_pairs = self.psbt.input_pairs().collect::>().into_iter(); + let first_input_pair = + input_pairs.next().ok_or(InternalCreateRequestError::NoInputs)?; + let mut input_weight = first_input_pair + .expected_input_weight() + .map_err(InternalCreateRequestError::InputWeight)?; + for input_pair in input_pairs { + // use cheapest default if mixed input types + if input_pair.address_type()? != first_input_pair.address_type()? { + input_weight = + bitcoin::transaction::InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH.weight() + // Lengths of txid, index and sequence: (32, 4, 4). + + Weight::from_non_witness_data_size(32 + 4 + 4); + break; + } } - let recommended_additional_fee = min_fee_rate * input_vsize; + let recommended_additional_fee = min_fee_rate * input_weight; if fee_available < recommended_additional_fee { log::warn!("Insufficient funds to maintain specified minimum feerate."); return self.build_with_additional_fee( @@ -225,21 +221,12 @@ impl<'a> RequestBuilder<'a> { )?; clear_unneeded_fields(&mut psbt); - let zeroth_input = psbt.input_pairs().next().ok_or(InternalCreateRequestError::NoInputs)?; - - let sequence = zeroth_input.txin.sequence; - let txout = zeroth_input.previous_txout().map_err(InternalCreateRequestError::PrevTxOut)?; - let input_type = InputType::from_spent_input(txout, zeroth_input.psbtin) - .map_err(InternalCreateRequestError::InputType)?; - Ok(RequestContext { psbt, endpoint, disable_output_substitution, fee_contribution, payee, - input_type, - sequence, min_fee_rate: self.min_fee_rate, #[cfg(feature = "v2")] e: crate::v2::HpkeKeyPair::gen_keypair().secret_key().clone(), @@ -255,8 +242,6 @@ pub struct RequestContext { disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, - input_type: InputType, - sequence: Sequence, payee: ScriptBuf, #[cfg(feature = "v2")] e: crate::v2::HpkeSecretKey, @@ -281,8 +266,6 @@ impl RequestContext { disable_output_substitution: self.disable_output_substitution, fee_contribution: self.fee_contribution, payee: self.payee.clone(), - input_type: self.input_type, - sequence: self.sequence, min_fee_rate: self.min_fee_rate, }, )) @@ -351,8 +334,6 @@ impl RequestContext { disable_output_substitution: self.disable_output_substitution, fee_contribution: self.fee_contribution, payee: self.payee.clone(), - input_type: self.input_type, - sequence: self.sequence, min_fee_rate: self.min_fee_rate, }, rs: Some(self.extract_rs_pubkey()?), @@ -395,8 +376,6 @@ pub struct ContextV1 { disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, - input_type: InputType, - sequence: Sequence, payee: ScriptBuf, } @@ -480,61 +459,37 @@ impl ContextV1 { fn process_proposal(self, mut proposal: Psbt) -> InternalResult { self.basic_checks(&proposal)?; - let in_stats = self.check_inputs(&proposal)?; - let out_stats = self.check_outputs(&proposal)?; - self.check_fees(&proposal, in_stats, out_stats)?; + self.check_inputs(&proposal)?; + let contributed_fee = self.check_outputs(&proposal)?; self.restore_original_utxos(&mut proposal)?; + self.check_fees(&proposal, contributed_fee)?; Ok(proposal) } - fn check_fees( - &self, - proposal: &Psbt, - in_stats: InputStats, - out_stats: OutputStats, - ) -> InternalResult<()> { - if out_stats.total_value > in_stats.total_value { - return Err(InternalValidationError::Inflation); - } - let proposed_psbt_fee = in_stats.total_value - out_stats.total_value; - let original_fee = self.original_psbt.calculate_fee(); - ensure!(original_fee <= proposed_psbt_fee, AbsoluteFeeDecreased); - ensure!( - out_stats.contributed_fee <= proposed_psbt_fee - original_fee, - PayeeTookContributedFee - ); - let original_weight = Weight::from_wu(u64::from(self.original_psbt.unsigned_tx.weight())); + fn check_fees(&self, proposal: &Psbt, contributed_fee: Amount) -> InternalResult<()> { + let proposed_fee = proposal.fee().map_err(InternalValidationError::Psbt)?; + let original_fee = self.original_psbt.fee().map_err(InternalValidationError::Psbt)?; + ensure!(original_fee <= proposed_fee, AbsoluteFeeDecreased); + ensure!(contributed_fee <= proposed_fee - original_fee, PayeeTookContributedFee); + let original_weight = self.original_psbt.clone().extract_tx_unchecked_fee_rate().weight(); let original_fee_rate = original_fee / original_weight; + // TODO: This should support mixed input types ensure!( - out_stats.contributed_fee + contributed_fee <= original_fee_rate - * self.input_type.expected_input_weight() + * self + .original_psbt + .input_pairs() + .next() + .expect("This shouldn't happen. Failed to get an original input.") + .expected_input_weight() + .expect("This shouldn't happen. Weight should have been calculated successfully before.") * (proposal.inputs.len() - self.original_psbt.inputs.len()) as u64, FeeContributionPaysOutputSizeIncrease ); if self.min_fee_rate > FeeRate::ZERO { - let non_input_output_size = - // version - 4 + - // count variants - varint_size(proposal.unsigned_tx.input.len() as u64) + - varint_size(proposal.unsigned_tx.output.len() as u64) + - // lock time - 4; - let weight_without_witnesses = - Weight::from_non_witness_data_size(non_input_output_size) - + in_stats.total_weight - + out_stats.total_weight; - let total_weight = if in_stats.inputs_with_witnesses == 0 { - weight_without_witnesses - } else { - weight_without_witnesses - + Weight::from_wu( - (proposal.unsigned_tx.input.len() - in_stats.inputs_with_witnesses + 2) - as u64, - ) - }; - ensure!(proposed_psbt_fee / total_weight >= self.min_fee_rate, FeeRateBelowMinimum); + let proposed_weight = proposal.clone().extract_tx_unchecked_fee_rate().weight(); + ensure!(proposed_fee / proposed_weight >= self.min_fee_rate, FeeRateBelowMinimum); } Ok(()) } @@ -554,13 +509,8 @@ impl ContextV1 { Ok(()) } - fn check_inputs(&self, proposal: &Psbt) -> InternalResult { - use crate::weight::ComputeSize; - + fn check_inputs(&self, proposal: &Psbt) -> InternalResult<()> { let mut original_inputs = self.original_psbt.input_pairs().peekable(); - let mut total_value = bitcoin::Amount::ZERO; - let mut total_weight = Weight::ZERO; - let mut inputs_with_witnesses = 0; for proposed in proposal.input_pairs() { ensure!(proposed.psbtin.bip32_derivation.is_empty(), TxInContainsKeyPaths); @@ -588,61 +538,34 @@ impl ContextV1 { proposed.psbtin.final_script_witness.is_none(), SenderTxinContainsFinalScriptWitness ); - let prevout = original.previous_txout().expect("We've validated this before"); - total_value += prevout.value; - // We assume the signture will be the same size - // I know sigs can be slightly different size but there isn't much to do about - // it other than prefer Taproot. - total_weight += original.txin.weight(); - if !original.txin.witness.is_empty() { - inputs_with_witnesses += 1; - } - original_inputs.next(); } // theirs (receiver) None | Some(_) => { + let original = self + .original_psbt + .input_pairs() + .next() + .ok_or(InternalValidationError::NoInputs)?; // Verify the PSBT input is finalized ensure!( proposed.psbtin.final_script_sig.is_some() || proposed.psbtin.final_script_witness.is_some(), ReceiverTxinNotFinalized ); - if let Some(script_sig) = &proposed.psbtin.final_script_sig { - // The weight of the TxIn when it's included in a legacy transaction - // (i.e., a transaction having only legacy inputs). - total_weight += Weight::from_non_witness_data_size( - 32 /* txid */ + 4 /* vout */ + 4 /* sequence */ + script_sig.encoded_size(), - ); - } - if let Some(script_witness) = &proposed.psbtin.final_script_witness { - if !script_witness.is_empty() { - inputs_with_witnesses += 1; - total_weight += crate::weight::witness_weight(script_witness); - }; - } - // Verify that non_witness_utxo or witness_utxo are filled in. ensure!( proposed.psbtin.witness_utxo.is_some() || proposed.psbtin.non_witness_utxo.is_some(), ReceiverTxinMissingUtxoInfo ); - ensure!(proposed.txin.sequence == self.sequence, MixedSequence); - let txout = proposed - .previous_txout() - .map_err(InternalValidationError::InvalidProposedInput)?; - total_value += txout.value; - check_eq!( - InputType::from_spent_input(txout, proposed.psbtin)?, - self.input_type, - MixedInputTypes - ); + ensure!(proposed.txin.sequence == original.txin.sequence, MixedSequence); + check_eq!(proposed.address_type()?, original.address_type()?, MixedInputTypes); } } } ensure!(original_inputs.peek().is_none(), MissingOrShuffledInputs); - Ok(InputStats { total_value, total_weight, inputs_with_witnesses }) + Ok(()) } // Restore Original PSBT utxos that the receiver stripped. @@ -668,19 +591,15 @@ impl ContextV1 { Ok(()) } - fn check_outputs(&self, proposal: &Psbt) -> InternalResult { + fn check_outputs(&self, proposal: &Psbt) -> InternalResult { let mut original_outputs = self.original_psbt.unsigned_tx.output.iter().enumerate().peekable(); - let mut total_value = bitcoin::Amount::ZERO; - let mut contributed_fee = bitcoin::Amount::ZERO; - let mut total_weight = Weight::ZERO; + let mut contributed_fee = Amount::ZERO; for (proposed_txout, proposed_psbtout) in proposal.unsigned_tx.output.iter().zip(&proposal.outputs) { ensure!(proposed_psbtout.bip32_derivation.is_empty(), TxOutContainsKeyPaths); - total_value += proposed_txout.value; - total_weight += proposed_txout.weight(); match (original_outputs.peek(), self.fee_contribution) { // fee output ( @@ -692,7 +611,7 @@ impl ContextV1 { if proposed_txout.value < original_output.value { contributed_fee = original_output.value - proposed_txout.value; ensure!(contributed_fee <= max_fee_contrib, FeeContributionExceedsMaximum); - //The remaining fee checks are done in the caller + // The remaining fee checks are done in later in `check_fees` } original_outputs.next(); } @@ -721,22 +640,10 @@ impl ContextV1 { } ensure!(original_outputs.peek().is_none(), MissingOrShuffledOutputs); - Ok(OutputStats { total_value, contributed_fee, total_weight }) + Ok(contributed_fee) } } -struct OutputStats { - total_value: bitcoin::Amount, - contributed_fee: bitcoin::Amount, - total_weight: Weight, -} - -struct InputStats { - total_value: bitcoin::Amount, - total_weight: Weight, - inputs_with_witnesses: usize, -} - fn check_single_payee( psbt: &Psbt, script_pubkey: &Script, @@ -920,19 +827,15 @@ mod test { const PAYJOIN_PROPOSAL: &str = "cHNidP8BAJwCAAAAAo8nutGgJdyYGXWiBEb45Hoe9lWGbkxh/6bNiOJdCDuDAAAAAAD+////jye60aAl3JgZdaIERvjkeh72VYZuTGH/ps2I4l0IO4MBAAAAAP7///8CJpW4BQAAAAAXqRQd6EnwadJ0FQ46/q6NcutaawlEMIcACT0AAAAAABepFHdAltvPSGdDwi9DR+m0af6+i2d6h9MAAAAAAQEgqBvXBQAAAAAXqRTeTh6QYcpZE1sDWtXm1HmQRUNU0IcBBBYAFMeKRXJTVYKNVlgHTdUmDV/LaYUwIgYDFZrAGqDVh1TEtNi300ntHt/PCzYrT2tVEGcjooWPhRYYSFzWUDEAAIABAACAAAAAgAEAAAAAAAAAAAEBIICEHgAAAAAAF6kUyPLL+cphRyyI5GTUazV0hF2R2NWHAQcXFgAUX4BmVeWSTJIEwtUb5TlPS/ntohABCGsCRzBEAiBnu3tA3yWlT0WBClsXXS9j69Bt+waCs9JcjWtNjtv7VgIge2VYAaBeLPDB6HGFlpqOENXMldsJezF9Gs5amvDQRDQBIQJl1jz1tBt8hNx2owTm+4Du4isx0pmdKNMNIjjaMHFfrQABABYAFEb2Giu6c4KO5YW0pfw3lGp9jMUUIgICygvBWB5prpfx61y1HDAwo37kYP3YRJBvAjtunBAur3wYSFzWUDEAAIABAACAAAAAgAEAAAABAAAAAAA="; fn create_v1_context() -> super::ContextV1 { - use crate::input_type::{InputType, SegWitV0Type}; let original_psbt = Psbt::from_str(ORIGINAL_PSBT).unwrap(); eprintln!("original: {:#?}", original_psbt); let payee = original_psbt.unsigned_tx.output[1].script_pubkey.clone(); - let sequence = original_psbt.unsigned_tx.input[0].sequence; let ctx = super::ContextV1 { original_psbt, disable_output_substitution: false, fee_contribution: Some((bitcoin::Amount::from_sat(182), 0)), min_fee_rate: FeeRate::ZERO, payee, - input_type: InputType::SegWitV0 { ty: SegWitV0Type::Pubkey, nested: true }, - sequence, }; ctx } @@ -987,11 +890,6 @@ mod test { disable_output_substitution: false, fee_contribution: None, min_fee_rate: FeeRate::ZERO, - input_type: InputType::SegWitV0 { - ty: crate::input_type::SegWitV0Type::Pubkey, - nested: true, - }, - sequence: Sequence::MAX, payee: ScriptBuf::from(vec![0x00]), e: HpkeSecretKey( ::PrivateKey::from_bytes(&[0x01; 32]) diff --git a/payjoin/src/weight.rs b/payjoin/src/weight.rs deleted file mode 100644 index e4e46541..00000000 --- a/payjoin/src/weight.rs +++ /dev/null @@ -1,53 +0,0 @@ -//! Implements advanced weight calculations for fee estimation. -use bitcoin::{OutPoint, Script, TxIn, Weight, Witness}; - -pub(crate) trait ComputeWeight { - fn weight(&self) -> Weight; -} - -pub(crate) trait ComputeSize { - fn encoded_size(&self) -> u64; -} - -pub(crate) fn varint_size(number: u64) -> u64 { - match number { - 0..=0xfc => 1, - 0xfd..=0xffff => 3, - 0x10000..=0xffffffff => 5, - 0x100000000..=0xffffffffffffffff => 9, - } -} - -pub(crate) fn witness_weight(witness: &Witness) -> Weight { - if witness.is_empty() { - return Weight::ZERO; - } - let mut size = varint_size(witness.len() as u64); - - for item in witness.iter() { - size += varint_size(item.len() as u64) + item.len() as u64; - } - - Weight::from_witness_data_size(size) -} - -impl ComputeSize for Script { - fn encoded_size(&self) -> u64 { self.len() as u64 + varint_size(self.len() as u64) } -} - -impl ComputeWeight for TxIn { - fn weight(&self) -> Weight { - Weight::from_non_witness_data_size( - self.script_sig.encoded_size() + 4, /* bytes encoding u32 sequence number */ - ) + self.previous_output.weight() - + witness_weight(&self.witness) - } -} - -impl ComputeWeight for OutPoint { - fn weight(&self) -> Weight { - Weight::from_non_witness_data_size( - 32 /* bytes encoding previous hash */ + 4, /* bytes encoding u32 output index */ - ) - } -}