diff --git a/Cargo-minimal.lock b/Cargo-minimal.lock index adc611bd..736cbddb 100644 --- a/Cargo-minimal.lock +++ b/Cargo-minimal.lock @@ -217,9 +217,9 @@ dependencies = [ [[package]] name = "bitcoin" -version = "0.32.4" +version = "0.32.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "788902099d47c8682efe6a7afb01c8d58b9794ba66c06affd81c3d6b560743eb" +checksum = "ce6bc65742dea50536e35ad42492b234c27904a27f0abdcbce605015cb4ea026" dependencies = [ "base58ck", "base64 0.21.7", diff --git a/Cargo-recent.lock b/Cargo-recent.lock index adc611bd..736cbddb 100644 --- a/Cargo-recent.lock +++ b/Cargo-recent.lock @@ -217,9 +217,9 @@ dependencies = [ [[package]] name = "bitcoin" -version = "0.32.4" +version = "0.32.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "788902099d47c8682efe6a7afb01c8d58b9794ba66c06affd81c3d6b560743eb" +checksum = "ce6bc65742dea50536e35ad42492b234c27904a27f0abdcbce605015cb4ea026" dependencies = [ "base58ck", "base64 0.21.7", diff --git a/payjoin-directory/src/db.rs b/payjoin-directory/src/db.rs index 41b9b6dc..6165abf9 100644 --- a/payjoin-directory/src/db.rs +++ b/payjoin-directory/src/db.rs @@ -7,19 +7,6 @@ use tracing::debug; const DEFAULT_COLUMN: &str = ""; const PJ_V1_COLUMN: &str = "pjv1"; -// TODO move to payjoin crate as pub? -// TODO impl From for ShortId -// TODO impl Display for ShortId (Base64) -// TODO impl TryFrom<&str> for ShortId (Base64) -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) struct ShortId(pub [u8; 8]); - -impl ShortId { - pub fn column_key(&self, column: &str) -> Vec { - self.0.iter().chain(column.as_bytes()).copied().collect() - } -} - #[derive(Debug, Clone)] pub(crate) struct DbPool { client: Client, @@ -32,30 +19,30 @@ impl DbPool { Ok(Self { client, timeout }) } - pub async fn push_default(&self, pubkey_id: &ShortId, data: Vec) -> RedisResult<()> { - self.push(pubkey_id, DEFAULT_COLUMN, data).await + pub async fn push_default(&self, subdirectory_id: &str, data: Vec) -> RedisResult<()> { + self.push(subdirectory_id, DEFAULT_COLUMN, data).await } - pub async fn peek_default(&self, pubkey_id: &ShortId) -> Option>> { - self.peek_with_timeout(pubkey_id, DEFAULT_COLUMN).await + pub async fn peek_default(&self, subdirectory_id: &str) -> Option>> { + self.peek_with_timeout(subdirectory_id, DEFAULT_COLUMN).await } - pub async fn push_v1(&self, pubkey_id: &ShortId, data: Vec) -> RedisResult<()> { - self.push(pubkey_id, PJ_V1_COLUMN, data).await + pub async fn push_v1(&self, subdirectory_id: &str, data: Vec) -> RedisResult<()> { + self.push(subdirectory_id, PJ_V1_COLUMN, data).await } - pub async fn peek_v1(&self, pubkey_id: &ShortId) -> Option>> { - self.peek_with_timeout(pubkey_id, PJ_V1_COLUMN).await + pub async fn peek_v1(&self, subdirectory_id: &str) -> Option>> { + self.peek_with_timeout(subdirectory_id, PJ_V1_COLUMN).await } async fn push( &self, - pubkey_id: &ShortId, + subdirectory_id: &str, channel_type: &str, data: Vec, ) -> RedisResult<()> { let mut conn = self.client.get_async_connection().await?; - let key = pubkey_id.column_key(channel_type); + let key = channel_name(subdirectory_id, channel_type); () = conn.set(&key, data.clone()).await?; () = conn.publish(&key, "updated").await?; Ok(()) @@ -63,17 +50,17 @@ impl DbPool { async fn peek_with_timeout( &self, - pubkey_id: &ShortId, + subdirectory_id: &str, channel_type: &str, ) -> Option>> { - tokio::time::timeout(self.timeout, self.peek(pubkey_id, channel_type)).await.ok() + tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await.ok() } - async fn peek(&self, pubkey_id: &ShortId, channel_type: &str) -> RedisResult> { + async fn peek(&self, subdirectory_id: &str, channel_type: &str) -> RedisResult> { let mut conn = self.client.get_async_connection().await?; - let key = pubkey_id.column_key(channel_type); + let key = channel_name(subdirectory_id, channel_type); - // Attempt to fetch existing content for the given pubkey_id and channel_type + // Attempt to fetch existing content for the given subdirectory_id and channel_type if let Ok(data) = conn.get::<_, Vec>(&key).await { if !data.is_empty() { return Ok(data); @@ -83,7 +70,7 @@ impl DbPool { // Set up a temporary listener for changes let mut pubsub_conn = self.client.get_async_connection().await?.into_pubsub(); - let channel_name = pubkey_id.column_key(channel_type); + let channel_name = channel_name(subdirectory_id, channel_type); pubsub_conn.subscribe(&channel_name).await?; // Use a block to limit the scope of the mutable borrow @@ -116,3 +103,7 @@ impl DbPool { Ok(data) } } + +fn channel_name(subdirectory_id: &str, channel_type: &str) -> Vec { + (subdirectory_id.to_owned() + channel_type).into_bytes() +} diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index 04a05de2..0fde99a4 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -3,8 +3,6 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Result; -use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; -use bitcoin::base64::Engine; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Empty, Full}; use hyper::body::{Body, Bytes, Incoming}; @@ -17,8 +15,6 @@ use tokio::net::TcpListener; use tokio::sync::Mutex; use tracing::{debug, error, info, trace}; -use crate::db::ShortId; - pub const DEFAULT_DIR_PORT: u16 = 8080; pub const DEFAULT_DB_HOST: &str = "localhost:6379"; pub const DEFAULT_TIMEOUT_SECS: u64 = 30; @@ -34,6 +30,9 @@ const V1_REJECT_RES_JSON: &str = r#"{{"errorCode": "original-psbt-rejected ", "message": "Body is not a string"}}"#; const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message": "V2 receiver offline. V1 sends require synchronous communications."}}"#; +// 8 bytes as bech32 is 12.8 characters +const ID_LENGTH: usize = 13; + mod db; use crate::db::DbPool; @@ -306,11 +305,11 @@ async fn post_fallback_v1( }; let v2_compat_body = format!("{}\n{}", body_str, query); - let id = decode_short_id(id)?; - pool.push_default(&id, v2_compat_body.into()) + let id = check_id_length(id)?; + pool.push_default(id, v2_compat_body.into()) .await .map_err(|e| HandlerError::BadRequest(e.into()))?; - match pool.peek_v1(&id).await { + match pool.peek_v1(id).await { Some(result) => match result { Ok(buffered_req) => Ok(Response::new(full(buffered_req))), Err(e) => Err(HandlerError::BadRequest(e.into())), @@ -327,19 +326,29 @@ async fn put_payjoin_v1( trace!("Put_payjoin_v1"); let ok_response = Response::builder().status(StatusCode::OK).body(empty())?; - let id = decode_short_id(id)?; + let id = check_id_length(id)?; let req = body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); if req.len() > V1_MAX_BUFFER_SIZE { return Err(HandlerError::PayloadTooLarge); } - match pool.push_v1(&id, req.into()).await { + match pool.push_v1(id, req.into()).await { Ok(_) => Ok(ok_response), Err(e) => Err(HandlerError::BadRequest(e.into())), } } +fn check_id_length(id: &str) -> Result<&str, HandlerError> { + if id.len() != ID_LENGTH { + return Err(HandlerError::BadRequest(anyhow::anyhow!( + "subdirectory ID must be 13 bech32 characters", + ))); + } + + Ok(id) +} + async fn post_subdir( id: &str, body: BoxBody, @@ -348,14 +357,15 @@ async fn post_subdir( let none_response = Response::builder().status(StatusCode::OK).body(empty())?; trace!("post_subdir"); - let id = decode_short_id(id)?; + let id = check_id_length(id)?; + let req = body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); if req.len() > V1_MAX_BUFFER_SIZE { return Err(HandlerError::PayloadTooLarge); } - match pool.push_default(&id, req.into()).await { + match pool.push_default(id, req.into()).await { Ok(_) => Ok(none_response), Err(e) => Err(HandlerError::BadRequest(e.into())), } @@ -366,8 +376,8 @@ async fn get_subdir( pool: DbPool, ) -> Result>, HandlerError> { trace!("get_subdir"); - let id = decode_short_id(id)?; - match pool.peek_default(&id).await { + let id = check_id_length(id)?; + match pool.peek_default(id).await { Some(result) => match result { Ok(buffered_req) => Ok(Response::new(full(buffered_req))), Err(e) => Err(HandlerError::BadRequest(e.into())), @@ -396,16 +406,6 @@ async fn get_ohttp_keys( Ok(res) } -fn decode_short_id(input: &str) -> Result { - let decoded = - BASE64_URL_SAFE_NO_PAD.decode(input).map_err(|e| HandlerError::BadRequest(e.into()))?; - - decoded[..8] - .try_into() - .map_err(|_| HandlerError::BadRequest(anyhow::anyhow!("Invalid subdirectory ID"))) - .map(ShortId) -} - fn empty() -> BoxBody { Empty::::new().map_err(|never| match never {}).boxed() } diff --git a/payjoin/Cargo.toml b/payjoin/Cargo.toml index ff503aa5..d75b6199 100644 --- a/payjoin/Cargo.toml +++ b/payjoin/Cargo.toml @@ -19,12 +19,12 @@ exclude = ["tests"] send = [] receive = ["bitcoin/rand"] base64 = ["bitcoin/base64"] -v2 = ["bitcoin/rand", "bitcoin/serde", "hpke", "dep:http", "bhttp", "ohttp", "serde", "url/serde"] +v2 = ["bitcoin/rand", "bitcoin/serde", "hpke", "dep:http", "bhttp", "ohttp", "serde", "url/serde" ] io = ["reqwest/rustls-tls"] danger-local-https = ["io", "reqwest/rustls-tls", "rustls"] [dependencies] -bitcoin = { version = "0.32.4", features = ["base64"] } +bitcoin = { version = "0.32.5", features = ["base64"] } bip21 = "0.5.0" hpke = { package = "bitcoin-hpke", version = "0.13.0", optional = true } log = { version = "0.4.14"} diff --git a/payjoin/src/bech32.rs b/payjoin/src/bech32.rs new file mode 100644 index 00000000..59eb3105 --- /dev/null +++ b/payjoin/src/bech32.rs @@ -0,0 +1,49 @@ +use std::fmt; + +use bitcoin::bech32::primitives::decode::{CheckedHrpstring, CheckedHrpstringError}; +use bitcoin::bech32::{self, EncodeError, Hrp, NoChecksum}; + +pub mod nochecksum { + use super::*; + + pub fn decode(encoded: &str) -> Result<(Hrp, Vec), CheckedHrpstringError> { + let hrp_string = CheckedHrpstring::new::(encoded)?; + Ok((hrp_string.hrp(), hrp_string.byte_iter().collect::>())) + } + + pub fn encode(hrp: Hrp, data: &[u8]) -> Result { + bech32::encode_upper::(hrp, data) + } + + pub fn encode_to_fmt(f: &mut fmt::Formatter, hrp: Hrp, data: &[u8]) -> Result<(), EncodeError> { + bech32::encode_upper_to_fmt::(f, hrp, data) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn bech32_for_qr() { + let bytes = vec![0u8, 1, 2, 3, 31, 32, 33, 95, 0, 96, 127, 128, 129, 254, 255, 0]; + let hrp = Hrp::parse("STUFF").unwrap(); + let encoded = nochecksum::encode(hrp, &bytes).unwrap(); + let decoded = nochecksum::decode(&encoded).unwrap(); + assert_eq!(decoded, (hrp, bytes.to_vec())); + + // no checksum + assert_eq!( + encoded.len() as f32, + (hrp.as_str().len() + 1) as f32 + (bytes.len() as f32 * 8.0 / 5.0).ceil() + ); + + // TODO assert uppercase + + // should not error + let corrupted = encoded + "QQPP"; + let decoded = nochecksum::decode(&corrupted).unwrap(); + assert_eq!(decoded.0, hrp); + assert_ne!(decoded, (hrp, bytes.to_vec())); + } +} diff --git a/payjoin/src/lib.rs b/payjoin/src/lib.rs index 2bc3a2f3..5d77ce58 100644 --- a/payjoin/src/lib.rs +++ b/payjoin/src/lib.rs @@ -35,6 +35,8 @@ pub use crate::hpke::{HpkeKeyPair, HpkePublicKey}; pub(crate) mod ohttp; #[cfg(feature = "v2")] pub use crate::ohttp::OhttpKeys; +#[cfg(feature = "v2")] +pub(crate) mod bech32; #[cfg(feature = "io")] pub mod io; diff --git a/payjoin/src/ohttp.rs b/payjoin/src/ohttp.rs index b28ad45b..f62ee6fe 100644 --- a/payjoin/src/ohttp.rs +++ b/payjoin/src/ohttp.rs @@ -1,8 +1,7 @@ use std::ops::{Deref, DerefMut}; use std::{error, fmt}; -use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; -use bitcoin::base64::Engine; +use bitcoin::bech32::{self, EncodeError}; use bitcoin::key::constants::UNCOMPRESSED_PUBLIC_KEY_SIZE; pub const ENCAPSULATED_MESSAGE_BYTES: usize = 8192; @@ -145,19 +144,19 @@ impl fmt::Display for OhttpKeys { let mut buf = vec![key_id]; buf.extend_from_slice(&compressed_pubkey); - let encoded = BASE64_URL_SAFE_NO_PAD.encode(buf); - write!(f, "{}", encoded) + let oh_hrp: bech32::Hrp = bech32::Hrp::parse("OH").unwrap(); + + crate::bech32::nochecksum::encode_to_fmt(f, oh_hrp, &buf).map_err(|e| match e { + EncodeError::Fmt(e) => e, + _ => fmt::Error, + }) } } -impl std::str::FromStr for OhttpKeys { - type Err = ParseOhttpKeysError; - - /// Parses a base64URL-encoded string into OhttpKeys. - /// The string format is: key_id || compressed_public_key - fn from_str(s: &str) -> Result { - let bytes = BASE64_URL_SAFE_NO_PAD.decode(s).map_err(ParseOhttpKeysError::DecodeBase64)?; +impl TryFrom<&[u8]> for OhttpKeys { + type Error = ParseOhttpKeysError; + fn try_from(bytes: &[u8]) -> Result { let key_id = *bytes.first().ok_or(ParseOhttpKeysError::InvalidFormat)?; let compressed_pk = bytes.get(1..34).ok_or(ParseOhttpKeysError::InvalidFormat)?; @@ -174,6 +173,26 @@ impl std::str::FromStr for OhttpKeys { } } +impl std::str::FromStr for OhttpKeys { + type Err = ParseOhttpKeysError; + + /// Parses a base64URL-encoded string into OhttpKeys. + /// The string format is: key_id || compressed_public_key + fn from_str(s: &str) -> Result { + // TODO extract to utility function + let oh_hrp: bech32::Hrp = bech32::Hrp::parse("OH").unwrap(); + + let (hrp, bytes) = + crate::bech32::nochecksum::decode(s).map_err(ParseOhttpKeysError::DecodeBech32)?; + + if hrp != oh_hrp { + return Err(ParseOhttpKeysError::InvalidFormat); + } + + Self::try_from(&bytes[..]) + } +} + impl PartialEq for OhttpKeys { fn eq(&self, other: &Self) -> bool { match (self.encode(), other.encode()) { @@ -220,7 +239,7 @@ impl serde::Serialize for OhttpKeys { pub enum ParseOhttpKeysError { InvalidFormat, InvalidPublicKey, - DecodeBase64(bitcoin::base64::DecodeError), + DecodeBech32(bech32::primitives::decode::CheckedHrpstringError), DecodeKeyConfig(ohttp::Error), } @@ -229,7 +248,7 @@ impl std::fmt::Display for ParseOhttpKeysError { match self { ParseOhttpKeysError::InvalidFormat => write!(f, "Invalid format"), ParseOhttpKeysError::InvalidPublicKey => write!(f, "Invalid public key"), - ParseOhttpKeysError::DecodeBase64(e) => write!(f, "Failed to decode base64: {}", e), + ParseOhttpKeysError::DecodeBech32(e) => write!(f, "Failed to decode base64: {}", e), ParseOhttpKeysError::DecodeKeyConfig(e) => write!(f, "Failed to decode KeyConfig: {}", e), } @@ -239,7 +258,7 @@ impl std::fmt::Display for ParseOhttpKeysError { impl std::error::Error for ParseOhttpKeysError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { - ParseOhttpKeysError::DecodeBase64(e) => Some(e), + ParseOhttpKeysError::DecodeBech32(e) => Some(e), ParseOhttpKeysError::DecodeKeyConfig(e) => Some(e), ParseOhttpKeysError::InvalidFormat | ParseOhttpKeysError::InvalidPublicKey => None, } diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 39cd2a1a..b2d619f9 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -1,8 +1,6 @@ use std::str::FromStr; use std::time::{Duration, SystemTime}; -use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; -use bitcoin::base64::Engine; use bitcoin::hashes::{sha256, Hash}; use bitcoin::psbt::Psbt; use bitcoin::{Address, FeeRate, OutPoint, Script, TxOut}; @@ -20,6 +18,7 @@ use crate::ohttp::{ohttp_decapsulate, ohttp_encapsulate, OhttpEncapsulationError use crate::psbt::PsbtExt; use crate::receive::optional_parameters::Params; use crate::receive::InputPair; +use crate::uri::ShortId; use crate::{PjUriBuilder, Request}; pub(crate) mod error; @@ -48,9 +47,8 @@ where Ok(address.assume_checked()) } -fn subdir_path_from_pubkey(pubkey: &HpkePublicKey) -> String { - let hash = sha256::Hash::hash(&pubkey.to_compressed_bytes()); - BASE64_URL_SAFE_NO_PAD.encode(&hash.as_byte_array()[..8]) +fn subdir_path_from_pubkey(pubkey: &HpkePublicKey) -> ShortId { + sha256::Hash::hash(&pubkey.to_compressed_bytes()).into() } /// A payjoin V2 receiver, allowing for polled requests to the @@ -200,22 +198,18 @@ impl Receiver { // The contents of the `&pj=` query parameter. // This identifies a session at the payjoin directory server. pub fn pj_url(&self) -> Url { - let id_base64 = BASE64_URL_SAFE_NO_PAD.encode(self.id()); let mut url = self.context.directory.clone(); { let mut path_segments = url.path_segments_mut().expect("Payjoin Directory URL cannot be a base"); - path_segments.push(&id_base64); + path_segments.push(&self.id().to_string()); } url } /// The per-session identifier - pub fn id(&self) -> [u8; 8] { - let hash = sha256::Hash::hash(&self.context.s.public_key().to_compressed_bytes()); - hash.as_byte_array()[..8] - .try_into() - .expect("truncating SHA256 to 8 bytes should always succeed") + pub fn id(&self) -> ShortId { + sha256::Hash::hash(&self.context.s.public_key().to_compressed_bytes()).into() } } @@ -479,8 +473,11 @@ impl PayjoinProposal { // Prepare v2 payload let payjoin_bytes = self.inner.payjoin_psbt.serialize(); let sender_subdir = subdir_path_from_pubkey(e); - target_resource = - self.context.directory.join(&sender_subdir).map_err(|e| Error::Server(e.into()))?; + target_resource = self + .context + .directory + .join(&sender_subdir.to_string()) + .map_err(|e| Error::Server(e.into()))?; body = encrypt_message_b(payjoin_bytes, &self.context.s, e)?; method = "POST"; } else { @@ -490,7 +487,7 @@ impl PayjoinProposal { target_resource = self .context .directory - .join(&receiver_subdir) + .join(&receiver_subdir.to_string()) .map_err(|e| Error::Server(e.into()))?; method = "PUT"; } diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 07ef0946..2292a00a 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -23,8 +23,6 @@ use std::str::FromStr; -#[cfg(feature = "v2")] -use bitcoin::base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; #[cfg(feature = "v2")] use bitcoin::hashes::{sha256, Hash}; use bitcoin::psbt::Psbt; @@ -41,6 +39,8 @@ use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeKeyPair, HpkePublicK use crate::ohttp::{ohttp_decapsulate, ohttp_encapsulate}; use crate::psbt::PsbtExt; use crate::request::Request; +#[cfg(feature = "v2")] +use crate::uri::ShortId; use crate::PjUri; // See usize casts @@ -405,8 +405,8 @@ impl V2GetContext { // TODO unify with receiver's fn subdir_path_from_pubkey let hash = sha256::Hash::hash(&self.hpke_ctx.reply_pair.public_key().to_compressed_bytes()); - let subdir = BASE64_URL_SAFE_NO_PAD.encode(&hash.as_byte_array()[..8]); - url.set_path(&subdir); + let subdir: ShortId = hash.into(); + url.set_path(&subdir.to_string()); let body = encrypt_message_a( Vec::new(), &self.hpke_ctx.reply_pair.public_key().clone(), diff --git a/payjoin/src/uri/error.rs b/payjoin/src/uri/error.rs index 7bd94281..f4ef7fab 100644 --- a/payjoin/src/uri/error.rs +++ b/payjoin/src/uri/error.rs @@ -15,7 +15,8 @@ pub(crate) enum InternalPjParseError { #[derive(Debug)] pub(crate) enum ParseReceiverPubkeyError { MissingPubkey, - PubkeyNotBase64(bitcoin::base64::DecodeError), + InvalidHrp(bitcoin::bech32::Hrp), + DecodeBech32(bitcoin::bech32::primitives::decode::CheckedHrpstringError), InvalidPubkey(crate::hpke::HpkeError), } @@ -26,7 +27,8 @@ impl std::fmt::Display for ParseReceiverPubkeyError { match &self { MissingPubkey => write!(f, "receiver public key is missing"), - PubkeyNotBase64(e) => write!(f, "receiver public is not valid base64: {}", e), + InvalidHrp(h) => write!(f, "incorrect hrp for receiver key: {}", h), + DecodeBech32(e) => write!(f, "receiver public is not valid base64: {}", e), InvalidPubkey(e) => write!(f, "receiver public key does not represent a valid pubkey: {}", e), } @@ -40,7 +42,8 @@ impl std::error::Error for ParseReceiverPubkeyError { match &self { MissingPubkey => None, - PubkeyNotBase64(error) => Some(error), + InvalidHrp(_) => None, + DecodeBech32(error) => Some(error), InvalidPubkey(error) => Some(error), } } diff --git a/payjoin/src/uri/mod.rs b/payjoin/src/uri/mod.rs index 6281446c..197e3402 100644 --- a/payjoin/src/uri/mod.rs +++ b/payjoin/src/uri/mod.rs @@ -17,6 +17,64 @@ pub mod error; #[cfg(feature = "v2")] pub(crate) mod url_ext; +#[cfg(feature = "v2")] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ShortId(pub [u8; 8]); + +#[cfg(feature = "v2")] +impl ShortId { + pub fn as_bytes(&self) -> &[u8] { &self.0 } + pub fn as_slice(&self) -> &[u8] { &self.0 } +} + +#[cfg(feature = "v2")] +impl std::fmt::Display for ShortId { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let id_hrp = bitcoin::bech32::Hrp::parse("ID").unwrap(); + f.write_str( + crate::bech32::nochecksum::encode(id_hrp, &self.0) + .expect("bech32 encoding of short ID must succeed") + .strip_prefix("ID1") + .expect("human readable part must be ID1"), + ) + } +} + +#[cfg(feature = "v2")] +#[derive(Debug)] +pub enum ShortIdError { + DecodeBech32(bitcoin::bech32::primitives::decode::CheckedHrpstringError), + IncorrectLength(std::array::TryFromSliceError), +} + +#[cfg(feature = "v2")] +impl std::convert::From for ShortId { + fn from(h: bitcoin::hashes::sha256::Hash) -> Self { + bitcoin::hashes::Hash::as_byte_array(&h)[..8] + .try_into() + .expect("truncating SHA256 to 8 bytes should always succeed") + } +} + +#[cfg(feature = "v2")] +impl std::convert::TryFrom<&[u8]> for ShortId { + type Error = ShortIdError; + fn try_from(bytes: &[u8]) -> Result { + let bytes: [u8; 8] = bytes.try_into().map_err(ShortIdError::IncorrectLength)?; + Ok(Self(bytes)) + } +} + +#[cfg(feature = "v2")] +impl std::str::FromStr for ShortId { + type Err = ShortIdError; + fn from_str(s: &str) -> Result { + let (_, bytes) = crate::bech32::nochecksum::decode(&("ID1".to_string() + s)) + .map_err(ShortIdError::DecodeBech32)?; + (&bytes[..]).try_into() + } +} + #[derive(Debug, Clone)] pub enum MaybePayjoinExtras { Supported(PayjoinExtras), @@ -123,11 +181,17 @@ impl PjUriBuilder { #[allow(unused_mut)] let mut pj = origin; #[cfg(feature = "v2")] - pj.set_receiver_pubkey(receiver_pubkey); + if let Some(receiver_pubkey) = receiver_pubkey { + pj.set_receiver_pubkey(receiver_pubkey); + } #[cfg(feature = "v2")] - pj.set_ohttp(ohttp_keys); + if let Some(ohttp_keys) = ohttp_keys { + pj.set_ohttp(ohttp_keys); + } #[cfg(feature = "v2")] - pj.set_exp(expiry); + if let Some(expiry) = expiry { + pj.set_exp(expiry); + } Self { address, amount: None, message: None, label: None, pj, pjos: false } } /// Set the amount you want to receive. @@ -205,9 +269,19 @@ impl bip21::SerializeParams for &PayjoinExtras { type Iterator = std::vec::IntoIter<(Self::Key, Self::Value)>; fn serialize_params(self) -> Self::Iterator { + // normalizing to uppercase enables QR alphanumeric mode encoding + // unfortunately Url normalizes these to be lowercase + let scheme = self.endpoint.scheme(); + let host = self.endpoint.host_str().expect("host must be set"); + let endpoint_str = self + .endpoint + .as_str() + .replacen(scheme, &scheme.to_uppercase(), 1) + .replacen(host, &host.to_uppercase(), 1); + vec![ - ("pj", self.endpoint.as_str().to_string()), ("pjos", if self.disable_output_substitution { "1" } else { "0" }.to_string()), + ("pj", endpoint_str), ] .into_iter() } diff --git a/payjoin/src/uri/url_ext.rs b/payjoin/src/uri/url_ext.rs index 824be1a2..26d4267d 100644 --- a/payjoin/src/uri/url_ext.rs +++ b/payjoin/src/uri/url_ext.rs @@ -1,7 +1,8 @@ use std::str::FromStr; -use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; -use bitcoin::base64::Engine; +use bitcoin::bech32::Hrp; +use bitcoin::consensus::encode::Decodable; +use bitcoin::consensus::Encodable; use url::Url; use super::error::ParseReceiverPubkeyError; @@ -11,65 +12,86 @@ use crate::OhttpKeys; /// Parse and set fragment parameters from `&pj=` URI parameter URLs pub(crate) trait UrlExt { fn receiver_pubkey(&self) -> Result; - fn set_receiver_pubkey(&mut self, exp: Option); + fn set_receiver_pubkey(&mut self, exp: HpkePublicKey); fn ohttp(&self) -> Option; - fn set_ohttp(&mut self, ohttp: Option); + fn set_ohttp(&mut self, ohttp: OhttpKeys); fn exp(&self) -> Option; - fn set_exp(&mut self, exp: Option); + fn set_exp(&mut self, exp: std::time::SystemTime); } impl UrlExt for Url { /// Retrieve the receiver's public key from the URL fragment fn receiver_pubkey(&self) -> Result { - let value = get_param(self, "rk=", |v| Some(v.to_owned())) + let value = get_param(self, "RK1", |v| Some(v.to_owned())) .ok_or(ParseReceiverPubkeyError::MissingPubkey)?; - let decoded = BASE64_URL_SAFE_NO_PAD - .decode(&value) - .map_err(ParseReceiverPubkeyError::PubkeyNotBase64)?; + let (hrp, bytes) = crate::bech32::nochecksum::decode(&value) + .map_err(ParseReceiverPubkeyError::DecodeBech32)?; - HpkePublicKey::from_compressed_bytes(&decoded) + let rk_hrp: Hrp = Hrp::parse("RK").unwrap(); + if hrp != rk_hrp { + return Err(ParseReceiverPubkeyError::InvalidHrp(hrp)); + } + + HpkePublicKey::from_compressed_bytes(&bytes[..]) .map_err(ParseReceiverPubkeyError::InvalidPubkey) } /// Set the receiver's public key in the URL fragment - fn set_receiver_pubkey(&mut self, pubkey: Option) { + fn set_receiver_pubkey(&mut self, pubkey: HpkePublicKey) { + let rk_hrp: Hrp = Hrp::parse("RK").unwrap(); + set_param( self, - "rk=", - pubkey.map(|k| BASE64_URL_SAFE_NO_PAD.encode(k.to_compressed_bytes())), + "RK1", + &crate::bech32::nochecksum::encode(rk_hrp, &pubkey.to_compressed_bytes()) + .expect("encoding compressed pubkey bytes should never fail"), ) } /// Retrieve the ohttp parameter from the URL fragment fn ohttp(&self) -> Option { - get_param(self, "ohttp=", |value| OhttpKeys::from_str(value).ok()) + get_param(self, "OH1", |value| OhttpKeys::from_str(value).ok()) } /// Set the ohttp parameter in the URL fragment - fn set_ohttp(&mut self, ohttp: Option) { - set_param(self, "ohttp=", ohttp.map(|o| o.to_string())) - } + fn set_ohttp(&mut self, ohttp: OhttpKeys) { set_param(self, "OH1", &ohttp.to_string()) } /// Retrieve the exp parameter from the URL fragment fn exp(&self) -> Option { - get_param(self, "exp=", |value| { - value - .parse::() + get_param(self, "EX1", |value| { + let (hrp, bytes) = crate::bech32::nochecksum::decode(value).ok()?; + + let ex_hrp: Hrp = Hrp::parse("EX").unwrap(); + if hrp != ex_hrp { + return None; + } + + let mut cursor = &bytes[..]; + u32::consensus_decode(&mut cursor) + .map(|timestamp| { + std::time::UNIX_EPOCH + std::time::Duration::from_secs(timestamp as u64) + }) .ok() - .map(|timestamp| std::time::UNIX_EPOCH + std::time::Duration::from_secs(timestamp)) }) } /// Set the exp parameter in the URL fragment - fn set_exp(&mut self, exp: Option) { - let exp_str = exp.map(|e| { - match e.duration_since(std::time::UNIX_EPOCH) { - Ok(duration) => duration.as_secs().to_string(), - Err(_) => "0".to_string(), // Handle times before Unix epoch by setting to "0" - } - }); - set_param(self, "exp=", exp_str) + fn set_exp(&mut self, exp: std::time::SystemTime) { + let t = match exp.duration_since(std::time::UNIX_EPOCH) { + Ok(duration) => duration.as_secs().try_into().unwrap(), // TODO Result type instead of Option & unwrap + Err(_) => 0u32, + }; + + let mut buf = [0u8; 4]; + t.consensus_encode(&mut &mut buf[..]).unwrap(); // TODO no unwrap + + let ex_hrp: Hrp = Hrp::parse("EX").unwrap(); + + let exp_str = crate::bech32::nochecksum::encode(ex_hrp, &buf) + .expect("encoding u32 timestamp should never fail"); + + set_param(self, "EX1", &exp_str) } } @@ -78,33 +100,30 @@ where F: Fn(&str) -> Option, { if let Some(fragment) = url.fragment() { - for param in fragment.split('&') { - if let Some(value) = param.strip_prefix(prefix) { - return parse(value); + for param in fragment.split('+') { + if param.starts_with(prefix) { + return parse(param); } } } None } -fn set_param(url: &mut Url, prefix: &str, value: Option) { +fn set_param(url: &mut Url, prefix: &str, param: &str) { let fragment = url.fragment().unwrap_or(""); let mut fragment = fragment.to_string(); if let Some(start) = fragment.find(prefix) { - let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i); + let end = fragment[start..].find('+').map_or(fragment.len(), |i| start + i); fragment.replace_range(start..end, ""); - if fragment.ends_with('&') { + if fragment.ends_with('+') { fragment.pop(); } } - if let Some(value) = value { - let new_param = format!("{}{}", prefix, value); - if !fragment.is_empty() { - fragment.push('&'); - } - fragment.push_str(&new_param); + if !fragment.is_empty() { + fragment.push('+'); } + fragment.push_str(param); url.set_fragment(if fragment.is_empty() { None } else { Some(&fragment) }); } @@ -118,15 +137,12 @@ mod tests { fn test_ohttp_get_set() { let mut url = Url::parse("https://example.com").unwrap(); - let ohttp_keys = - OhttpKeys::from_str("AQO6SMScPUqSo60A7MY6Ak2hDO0CGAxz7BLYp60syRu0gw").unwrap(); - url.set_ohttp(Some(ohttp_keys.clone())); - assert_eq!(url.fragment(), Some("ohttp=AQO6SMScPUqSo60A7MY6Ak2hDO0CGAxz7BLYp60syRu0gw")); + let serialized = "OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC"; + let ohttp_keys = OhttpKeys::from_str(serialized).unwrap(); + url.set_ohttp(ohttp_keys.clone()); + assert_eq!(url.fragment(), Some(serialized)); assert_eq!(url.ohttp(), Some(ohttp_keys)); - - url.set_ohttp(None); - assert_eq!(url.fragment(), None); } #[test] @@ -135,30 +151,17 @@ mod tests { let exp_time = std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(1720547781); - url.set_exp(Some(exp_time)); - assert_eq!(url.fragment(), Some("exp=1720547781")); + url.set_exp(exp_time); + assert_eq!(url.fragment(), Some("EX1C4UC6ES")); assert_eq!(url.exp(), Some(exp_time)); - - url.set_exp(None); - assert_eq!(url.fragment(), None); - } - - #[test] - fn test_invalid_v2_url_fragment_on_bip21() { - // fragment is not percent encoded so `&ohttp=` is parsed as a query parameter, not a fragment parameter - let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?amount=0.01\ - &pj=https://example.com\ - #exp=1720547781&ohttp=AQO6SMScPUqSo60A7MY6Ak2hDO0CGAxz7BLYp60syRu0gw"; - let uri = Uri::try_from(uri).unwrap().assume_checked().check_pj_supported().unwrap(); - assert!(uri.extras.endpoint().ohttp().is_none()); } #[test] fn test_valid_v2_url_fragment_on_bip21() { let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?amount=0.01\ &pj=https://example.com\ - #ohttp%3DAQO6SMScPUqSo60A7MY6Ak2hDO0CGAxz7BLYp60syRu0gw"; + #OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC"; let uri = Uri::try_from(uri).unwrap().assume_checked().check_pj_supported().unwrap(); assert!(uri.extras.endpoint().ohttp().is_some()); } diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index becf82a3..3223742a 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -193,7 +193,7 @@ mod integration { #[tokio::test] async fn test_bad_ohttp_keys() { let bad_ohttp_keys = - OhttpKeys::from_str("AQO6SMScPUqSo60A7MY6Ak2hDO0CGAxz7BLYp60syRu0gw") + OhttpKeys::from_str("OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC") .expect("Invalid OhttpKeys"); let (cert, key) = local_cert_key();