From 8b80fcf2b4e9cd4c08a2684e0e7329781a1fbc69 Mon Sep 17 00:00:00 2001 From: DanGould Date: Wed, 9 Oct 2024 23:13:04 -0400 Subject: [PATCH 01/12] Make session initialization implicit A session is now initialized by generating keys and sharing them out of band. The semantics of the protocol are otherwise unchanged. Because sessions are implicit the typestate is now called `Receiver`. --- payjoin-cli/src/app/v2.rs | 24 ++--------- payjoin-cli/src/db/v2.rs | 9 ++--- payjoin-cli/tests/e2e.rs | 9 +---- payjoin-directory/src/lib.rs | 68 +++---------------------------- payjoin-directory/src/main.rs | 4 +- payjoin/src/receive/v2/mod.rs | 76 +++++++---------------------------- payjoin/tests/integration.rs | 58 ++++++-------------------- 7 files changed, 43 insertions(+), 205 deletions(-) diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index 3a8c8a18..6a6e27c9 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -7,7 +7,7 @@ use bitcoincore_rpc::RpcApi; use payjoin::bitcoin::consensus::encode::serialize_hex; use payjoin::bitcoin::psbt::Psbt; use payjoin::bitcoin::{Amount, FeeRate}; -use payjoin::receive::v2::ActiveSession; +use payjoin::receive::v2::Receiver; use payjoin::send::RequestContext; use payjoin::{bitcoin, Error, Uri}; use tokio::signal; @@ -75,32 +75,16 @@ impl AppTrait for App { } async fn receive_payjoin(self, amount_arg: &str) -> Result<()> { - use payjoin::receive::v2::SessionInitializer; - let address = self.bitcoind()?.get_new_address(None, None)?.assume_checked(); let amount = Amount::from_sat(amount_arg.parse()?); let ohttp_keys = unwrap_ohttp_keys_or_else_fetch(&self.config).await?; - let mut initializer = SessionInitializer::new( + let session = Receiver::new( address, self.config.pj_directory.clone(), ohttp_keys.clone(), self.config.ohttp_relay.clone(), None, ); - let (req, ctx) = - initializer.extract_req().map_err(|e| anyhow!("Failed to extract request {}", e))?; - println!("Starting new Payjoin session with {}", self.config.pj_directory); - let http = http_agent()?; - let ohttp_response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; - let session = initializer - .process_res(ohttp_response.bytes().await?.to_vec().as_slice(), ctx) - .map_err(|e| anyhow!("Enrollment failed {}", e))?; self.db.insert_recv_session(session.clone())?; self.spawn_payjoin_receiver(session, Some(amount)).await } @@ -123,7 +107,7 @@ impl App { async fn spawn_payjoin_receiver( &self, - mut session: ActiveSession, + mut session: Receiver, amount: Option, ) -> Result<()> { println!("Receive session established"); @@ -244,7 +228,7 @@ impl App { async fn long_poll_fallback( &self, - session: &mut payjoin::receive::v2::ActiveSession, + session: &mut payjoin::receive::v2::Receiver, ) -> Result { loop { let (req, context) = session.extract_req()?; diff --git a/payjoin-cli/src/db/v2.rs b/payjoin-cli/src/db/v2.rs index 8ec7250b..bd6d8030 100644 --- a/payjoin-cli/src/db/v2.rs +++ b/payjoin-cli/src/db/v2.rs @@ -1,5 +1,5 @@ use bitcoincore_rpc::jsonrpc::serde_json; -use payjoin::receive::v2::ActiveSession; +use payjoin::receive::v2::Receiver; use payjoin::send::RequestContext; use sled::{IVec, Tree}; use url::Url; @@ -7,7 +7,7 @@ use url::Url; use super::*; impl Database { - pub(crate) fn insert_recv_session(&self, session: ActiveSession) -> Result<()> { + pub(crate) fn insert_recv_session(&self, session: Receiver) -> Result<()> { let recv_tree = self.0.open_tree("recv_sessions")?; let key = &session.id(); let value = serde_json::to_string(&session).map_err(Error::Serialize)?; @@ -16,13 +16,12 @@ impl Database { Ok(()) } - pub(crate) fn get_recv_sessions(&self) -> Result> { + pub(crate) fn get_recv_sessions(&self) -> Result> { let recv_tree = self.0.open_tree("recv_sessions")?; let mut sessions = Vec::new(); for item in recv_tree.iter() { let (_, value) = item?; - let session: ActiveSession = - serde_json::from_slice(&value).map_err(Error::Deserialize)?; + let session: Receiver = serde_json::from_slice(&value).map_err(Error::Deserialize)?; sessions.push(session); } Ok(sessions) diff --git a/payjoin-cli/tests/e2e.rs b/payjoin-cli/tests/e2e.rs index 7eae8b43..5db46539 100644 --- a/payjoin-cli/tests/e2e.rs +++ b/payjoin-cli/tests/e2e.rs @@ -482,14 +482,7 @@ mod e2e { let db = docker.run(Redis::default()); let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); println!("Database running on {}", db.get_host_port_ipv4(6379)); - payjoin_directory::listen_tcp_with_tls( - format!("http://localhost:{}", port), - port, - db_host, - timeout, - local_cert_key, - ) - .await + payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await } // generates or gets a DER encoded localhost cert and key. diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index 875ff9fb..eb1a2b65 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -3,12 +3,10 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Result; -use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; -use bitcoin::base64::Engine; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Empty, Full}; use hyper::body::{Body, Bytes, Incoming}; -use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE, LOCATION}; +use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE}; use hyper::server::conn::http1; use hyper::service::service_fn; use hyper::{Method, Request, Response, StatusCode, Uri}; @@ -20,7 +18,6 @@ use tracing::{debug, error, info, trace}; pub const DEFAULT_DIR_PORT: u16 = 8080; pub const DEFAULT_DB_HOST: &str = "localhost:6379"; pub const DEFAULT_TIMEOUT_SECS: u64 = 30; -pub const DEFAULT_BASE_URL: &str = "https://localhost"; const MAX_BUFFER_SIZE: usize = 65536; @@ -32,7 +29,6 @@ mod db; use crate::db::DbPool; pub async fn listen_tcp( - base_url: String, port: u16, db_host: String, timeout: Duration, @@ -44,14 +40,13 @@ pub async fn listen_tcp( while let Ok((stream, _)) = listener.accept().await { let pool = pool.clone(); let ohttp = ohttp.clone(); - let base_url = base_url.clone(); let io = TokioIo::new(stream); tokio::spawn(async move { if let Err(err) = http1::Builder::new() .serve_connection( io, service_fn(move |req| { - serve_payjoin_directory(req, pool.clone(), ohttp.clone(), base_url.clone()) + serve_payjoin_directory(req, pool.clone(), ohttp.clone()) }), ) .with_upgrades() @@ -67,7 +62,6 @@ pub async fn listen_tcp( #[cfg(feature = "danger-local-https")] pub async fn listen_tcp_with_tls( - base_url: String, port: u16, db_host: String, timeout: Duration, @@ -81,7 +75,6 @@ pub async fn listen_tcp_with_tls( while let Ok((stream, _)) = listener.accept().await { let pool = pool.clone(); let ohttp = ohttp.clone(); - let base_url = base_url.clone(); let tls_acceptor = tls_acceptor.clone(); tokio::spawn(async move { let tls_stream = match tls_acceptor.accept(stream).await { @@ -95,7 +88,7 @@ pub async fn listen_tcp_with_tls( .serve_connection( TokioIo::new(tls_stream), service_fn(move |req| { - serve_payjoin_directory(req, pool.clone(), ohttp.clone(), base_url.clone()) + serve_payjoin_directory(req, pool.clone(), ohttp.clone()) }), ) .with_upgrades() @@ -146,7 +139,6 @@ async fn serve_payjoin_directory( req: Request, pool: DbPool, ohttp: Arc>, - base_url: String, ) -> Result>> { let path = req.uri().path().to_string(); let query = req.uri().query().unwrap_or_default().to_string(); @@ -155,7 +147,7 @@ async fn serve_payjoin_directory( let path_segments: Vec<&str> = path.split('/').collect(); debug!("serve_payjoin_directory: {:?}", &path_segments); let mut response = match (parts.method, path_segments.as_slice()) { - (Method::POST, ["", ""]) => handle_ohttp_gateway(body, pool, ohttp, base_url).await, + (Method::POST, ["", ""]) => handle_ohttp_gateway(body, pool, ohttp).await, (Method::GET, ["", "ohttp-keys"]) => get_ohttp_keys(&ohttp).await, (Method::POST, ["", id]) => post_fallback_v1(id, query, body, pool).await, (Method::GET, ["", "health"]) => health_check().await, @@ -173,7 +165,6 @@ async fn handle_ohttp_gateway( body: Incoming, pool: DbPool, ohttp: Arc>, - base_url: String, ) -> Result>, HandlerError> { // decapsulate let ohttp_body = @@ -199,7 +190,7 @@ async fn handle_ohttp_gateway( } let request = http_req.body(full(body))?; - let response = handle_v2(pool, base_url, request).await?; + let response = handle_v2(pool, request).await?; let (parts, body) = response.into_parts(); let mut bhttp_res = bhttp::Message::response(parts.status.as_u16()); @@ -221,7 +212,6 @@ async fn handle_ohttp_gateway( async fn handle_v2( pool: DbPool, - base_url: String, req: Request>, ) -> Result>, HandlerError> { let path = req.uri().path().to_string(); @@ -230,7 +220,6 @@ async fn handle_v2( let path_segments: Vec<&str> = path.split('/').collect(); debug!("handle_v2: {:?}", &path_segments); match (parts.method, path_segments.as_slice()) { - (Method::POST, &["", ""]) => post_session(base_url, body).await, (Method::POST, &["", id]) => post_fallback_v2(id, body, pool).await, (Method::GET, &["", id]) => get_fallback(id, pool).await, (Method::PUT, &["", id]) => post_payjoin(id, body, pool).await, @@ -282,24 +271,6 @@ impl From for HandlerError { fn from(e: hyper::http::Error) -> Self { HandlerError::InternalServerError(e.into()) } } -async fn post_session( - base_url: String, - body: BoxBody, -) -> Result>, HandlerError> { - let bytes = body.collect().await.map_err(|e| HandlerError::BadRequest(e.into()))?.to_bytes(); - let base64_id = - String::from_utf8(bytes.to_vec()).map_err(|e| HandlerError::BadRequest(e.into()))?; - let pubkey_bytes: Vec = - BASE64_URL_SAFE_NO_PAD.decode(base64_id).map_err(|e| HandlerError::BadRequest(e.into()))?; - let pubkey = bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes) - .map_err(|e| HandlerError::BadRequest(e.into()))?; - tracing::info!("Initialized session with pubkey: {:?}", pubkey); - Ok(Response::builder() - .header(LOCATION, format!("{}/{}", base_url, pubkey)) - .status(StatusCode::CREATED) - .body(empty())?) -} - async fn post_fallback_v1( id: &str, query: String, @@ -425,32 +396,3 @@ fn empty() -> BoxBody { fn full>(chunk: T) -> BoxBody { Full::new(chunk.into()).map_err(|never| match never {}).boxed() } - -#[cfg(test)] -mod tests { - use hyper::Request; - - use super::*; - - /// Ensure that the POST / endpoint returns a 201 Created with a Location header - /// as is semantically correct when creating a resource. - /// - /// https://datatracker.ietf.org/doc/html/rfc9110#name-post - #[tokio::test] - async fn test_post_session() -> Result<(), Box> { - let base_url = "https://localhost".to_string(); - let body = full("A6z245ZfDfnlk7_HiAp6sPmNaVYwADih-vCGE3eysWp7"); - - let request = Request::builder().method(Method::POST).uri("/").body(body)?; - - let response = post_session(base_url.clone(), request.into_body()) - .await - .map_err(|e| format!("{:?}", e))?; - - assert_eq!(response.status(), StatusCode::CREATED); - assert!(response.headers().contains_key(LOCATION)); - let location_header = response.headers().get(LOCATION).ok_or("Missing LOCATION header")?; - assert!(location_header.to_str()?.starts_with(&base_url)); - Ok(()) - } -} diff --git a/payjoin-directory/src/main.rs b/payjoin-directory/src/main.rs index 13d04cff..39dcd8c6 100644 --- a/payjoin-directory/src/main.rs +++ b/payjoin-directory/src/main.rs @@ -17,9 +17,7 @@ async fn main() -> Result<(), Box> { let db_host = env::var("PJ_DB_HOST").unwrap_or_else(|_| DEFAULT_DB_HOST.to_string()); - let base_url = env::var("PJ_DIR_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()); - - payjoin_directory::listen_tcp(base_url, dir_port, db_host, timeout).await + payjoin_directory::listen_tcp(dir_port, db_host, timeout).await } fn init_logging() { diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 2c250bfa..54543c72 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -45,16 +45,19 @@ where Ok(address.assume_checked()) } -/// Initializes a new payjoin session, including necessary context -/// information for communication and cryptographic operations. -#[derive(Debug, Clone)] -pub struct SessionInitializer { +fn subdir_path_from_pubkey(pubkey: &HpkePublicKey) -> String { + BASE64_URL_SAFE_NO_PAD.encode(pubkey.to_compressed_bytes()) +} + +/// A payjoin V2 receiver, allowing for polled requests to the +/// payjoin directory and response processing. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Receiver { context: SessionContext, } -#[cfg(feature = "v2")] -impl SessionInitializer { - /// Creates a new `SessionInitializer` with the provided parameters. +impl Receiver { + /// Creates a new `Receiver` with the provided parameters. /// /// # Parameters /// - `address`: The Bitcoin address for the payjoin session. @@ -64,7 +67,7 @@ impl SessionInitializer { /// - `expire_after`: The duration after which the session expires. /// /// # Returns - /// A new instance of `SessionInitializer`. + /// A new instance of `Receiver`. /// /// # References /// - [BIP 77: Payjoin Version 2: Serverless Payjoin](https://github.com/bitcoin/bips/pull/1483) @@ -90,56 +93,7 @@ impl SessionInitializer { } } - pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), Error> { - let url = self.context.ohttp_relay.clone(); - let subdirectory = subdir_path_from_pubkey(self.context.s.public_key()); - let (body, ctx) = crate::v2::ohttp_encapsulate( - &mut self.context.ohttp_keys, - "POST", - self.context.directory.as_str(), - Some(subdirectory.as_bytes()), - )?; - let req = Request::new_v2(url, body); - Ok((req, ctx)) - } - - pub fn process_res( - mut self, - mut res: impl std::io::Read, - ctx: ohttp::ClientResponse, - ) -> Result { - let mut buf = Vec::new(); - let _ = res.read_to_end(&mut buf); - let response = crate::v2::ohttp_decapsulate(ctx, &buf)?; - if !response.status().is_success() { - return Err(Error::Server("Enrollment failed, expected success status".into())); - } - log::debug!("Received response headers: {:?}", response.headers()); - let location = response - .headers() - .get("location") - .ok_or(Error::Server("Missing location header".into()))? - .to_str() - .map_err(|e| Error::Server(format!("Invalid location header: {}", e).into()))?; - self.context.subdirectory = - Some(url::Url::parse(location).map_err(|e| Error::Server(e.into()))?); - - Ok(ActiveSession { context: self.context.clone() }) - } -} - -fn subdir_path_from_pubkey(pubkey: &HpkePublicKey) -> String { - BASE64_URL_SAFE_NO_PAD.encode(pubkey.to_compressed_bytes()) -} - -/// An active payjoin V2 session, allowing for polled requests to the -/// payjoin directory and response processing. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct ActiveSession { - context: SessionContext, -} - -impl ActiveSession { + // OHTTP Encapsulated HTTP GET request for the Original PSBT pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), SessionError> { if SystemTime::now() > self.context.expiry { return Err(InternalSessionError::Expired(self.context.expiry).into()); @@ -561,7 +515,7 @@ mod test { #[test] #[cfg(feature = "v2")] - fn active_session_ser_de_roundtrip() { + fn receiver_ser_de_roundtrip() { use ohttp::hpke::{Aead, Kdf, Kem}; use ohttp::{KeyId, SymmetricSuite}; const KEY_ID: KeyId = 1; @@ -569,7 +523,7 @@ mod test { const SYMMETRIC: &[SymmetricSuite] = &[ohttp::SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)]; - let session = ActiveSession { + let session = Receiver { context: SessionContext { address: Address::from_str("tb1q6d3a2w975yny0asuvd9a67ner4nks58ff0q8g4") .unwrap() @@ -586,7 +540,7 @@ mod test { }, }; let serialized = serde_json::to_string(&session).unwrap(); - let deserialized: ActiveSession = serde_json::from_str(&serialized).unwrap(); + let deserialized: Receiver = serde_json::from_str(&serialized).unwrap(); assert_eq!(session, deserialized); } } diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index 1b19feff..e39abd92 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -178,9 +178,7 @@ mod integration { use bitcoin::Address; use http::StatusCode; - use payjoin::receive::v2::{ - ActiveSession, PayjoinProposal, SessionInitializer, UncheckedProposal, - }; + use payjoin::receive::v2::{PayjoinProposal, Receiver, UncheckedProposal}; use payjoin::{OhttpKeys, PjUri, UriExt}; use reqwest::{Client, ClientBuilder, Error, Response}; use testcontainers_modules::redis::Redis; @@ -202,7 +200,7 @@ mod integration { let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap(); tokio::select!( _ = init_directory(port, (cert.clone(), key)) => assert!(false, "Directory server is long running"), - res = enroll_with_bad_keys(directory, bad_ohttp_keys, cert) => { + res = try_request_with_bad_keys(directory, bad_ohttp_keys, cert) => { assert_eq!( res.unwrap().headers().get("content-type").unwrap(), "application/problem+json" @@ -210,7 +208,7 @@ mod integration { } ); - async fn enroll_with_bad_keys( + async fn try_request_with_bad_keys( directory: Url, bad_ohttp_keys: OhttpKeys, cert_der: Vec, @@ -221,13 +219,8 @@ mod integration { let mock_address = Address::from_str("tb1q6d3a2w975yny0asuvd9a67ner4nks58ff0q8g4") .unwrap() .assume_checked(); - let mut bad_initializer = SessionInitializer::new( - mock_address, - directory, - bad_ohttp_keys, - mock_ohttp_relay, - None, - ); + let mut bad_initializer = + Receiver::new(mock_address, directory, bad_ohttp_keys, mock_ohttp_relay, None); let (req, _ctx) = bad_initializer.extract_req().expect("Failed to extract request"); agent.post(req.url).body(req.body).send().await } @@ -270,10 +263,8 @@ mod integration { address.clone(), directory.clone(), ohttp_keys.clone(), - cert_der, Some(Duration::from_secs(0)), - ) - .await?; + ); match session.extract_req() { // Internal error types are private, so check against a string Err(err) => assert!(err.to_string().contains("expired")), @@ -340,10 +331,8 @@ mod integration { address.clone(), directory.clone(), ohttp_keys.clone(), - cert_der.clone(), None, - ) - .await?; + ); println!("session: {:#?}", &session); let pj_uri_string = session.pj_uri_builder().build().to_string(); // Poll receive request @@ -661,14 +650,7 @@ mod integration { .await?; let address = receiver.get_new_address(None, None)?.assume_checked(); - let mut session = initialize_session( - address, - directory, - ohttp_keys.clone(), - cert_der.clone(), - None, - ) - .await?; + let mut session = initialize_session(address, directory, ohttp_keys.clone(), None); let pj_uri_string = session.pj_uri_builder().build().to_string(); @@ -780,14 +762,7 @@ mod integration { let db = docker.run(Redis::default()); let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); println!("Database running on {}", db.get_host_port_ipv4(6379)); - payjoin_directory::listen_tcp_with_tls( - format!("http://localhost:{}", port), - port, - db_host, - timeout, - local_cert_key, - ) - .await + payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await } // generates or gets a DER encoded localhost cert and key. @@ -802,27 +777,20 @@ mod integration { (cert_der, key_der) } - async fn initialize_session( + fn initialize_session( address: Address, directory: Url, ohttp_keys: OhttpKeys, - cert_der: Vec, custom_expire_after: Option, - ) -> Result { + ) -> Receiver { let mock_ohttp_relay = directory.clone(); // pass through to directory - let mut initializer = SessionInitializer::new( + Receiver::new( address, directory.clone(), ohttp_keys, mock_ohttp_relay.clone(), custom_expire_after, - ); - let (req, ctx) = initializer.extract_req()?; - println!("enroll req: {:#?}", &req); - let response = - http_agent(cert_der).unwrap().post(req.url).body(req.body).send().await?; - assert!(response.status().is_success()); - Ok(initializer.process_res(response.bytes().await?.to_vec().as_slice(), ctx)?) + ) } fn handle_directory_proposal( From 4e7cd2808c81281a23ad099bd2115cb83ed1e3c0 Mon Sep 17 00:00:00 2001 From: DanGould Date: Fri, 11 Oct 2024 12:13:14 -0400 Subject: [PATCH 02/12] Request with implicit initialization pattern --- payjoin-cli/src/app/mod.rs | 6 +- payjoin-cli/src/app/v2.rs | 77 +++++++---- payjoin-cli/src/db/v2.rs | 18 +-- payjoin-directory/src/db.rs | 20 +-- payjoin-directory/src/lib.rs | 82 ++++++----- payjoin/src/receive/v2/mod.rs | 42 ++++-- payjoin/src/send/mod.rs | 250 +++++++++++++++++++++------------- payjoin/tests/integration.rs | 107 ++++++++++----- 8 files changed, 367 insertions(+), 235 deletions(-) diff --git a/payjoin-cli/src/app/mod.rs b/payjoin-cli/src/app/mod.rs index 5d8c000b..468a4b66 100644 --- a/payjoin-cli/src/app/mod.rs +++ b/payjoin-cli/src/app/mod.rs @@ -7,7 +7,7 @@ use bitcoin::TxIn; use bitcoincore_rpc::bitcoin::Amount; use bitcoincore_rpc::RpcApi; use payjoin::bitcoin::psbt::Psbt; -use payjoin::send::RequestContext; +use payjoin::send::Sender; use payjoin::{bitcoin, PjUri}; pub mod config; @@ -30,7 +30,7 @@ pub trait App { async fn send_payjoin(&self, bip21: &str, fee_rate: &f32) -> Result<()>; async fn receive_payjoin(self, amount_arg: &str) -> Result<()>; - fn create_pj_request(&self, uri: &PjUri, fee_rate: &f32) -> Result { + fn create_pj_request(&self, uri: &PjUri, fee_rate: &f32) -> Result { let amount = uri.amount.ok_or_else(|| anyhow!("please specify the amount in the Uri"))?; // wallet_create_funded_psbt requires a HashMap @@ -66,7 +66,7 @@ pub trait App { .psbt; let psbt = Psbt::from_str(&psbt).with_context(|| "Failed to load PSBT from base64")?; log::debug!("Original psbt: {:#?}", psbt); - let req_ctx = payjoin::send::RequestBuilder::from_psbt_and_uri(psbt, uri.clone()) + let req_ctx = payjoin::send::SenderBuilder::from_psbt_and_uri(psbt, uri.clone()) .with_context(|| "Failed to build payjoin request")? .build_recommended(fee_rate) .with_context(|| "Failed to build payjoin request")?; diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index 6a6e27c9..49a1e26b 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -8,7 +8,7 @@ use payjoin::bitcoin::consensus::encode::serialize_hex; use payjoin::bitcoin::psbt::Psbt; use payjoin::bitcoin::{Amount, FeeRate}; use payjoin::receive::v2::Receiver; -use payjoin::send::RequestContext; +use payjoin::send::Sender; use payjoin::{bitcoin, Error, Uri}; use tokio::signal; use tokio::sync::watch; @@ -91,7 +91,7 @@ impl AppTrait for App { } impl App { - async fn spawn_payjoin_sender(&self, mut req_ctx: RequestContext) -> Result<()> { + async fn spawn_payjoin_sender(&self, mut req_ctx: Sender) -> Result<()> { let mut interrupt = self.interrupt.clone(); tokio::select! { res = self.long_poll_post(&mut req_ctx) => { @@ -197,30 +197,57 @@ impl App { Ok(()) } - async fn long_poll_post(&self, req_ctx: &mut payjoin::send::RequestContext) -> Result { - loop { - let (req, ctx) = req_ctx.extract_v2(self.config.ohttp_relay.clone())?; - println!("Polling send request..."); - let http = http_agent()?; - let response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; - - println!("Sent fallback transaction"); - match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { - Ok(Some(psbt)) => return Ok(psbt), - Ok(None) => { - println!("No response yet."); - tokio::time::sleep(std::time::Duration::from_secs(5)).await; + async fn long_poll_post(&self, req_ctx: &mut payjoin::send::Sender) -> Result { + let (req, ctx) = req_ctx.extract_highest_version(self.config.ohttp_relay.clone())?; + println!("Posting Original PSBT Payload request..."); + let http = http_agent()?; + let response = http + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(map_reqwest_err)?; + println!("Sent fallback transaction"); + match ctx { + payjoin::send::Context::V2(ctx) => { + let v2_ctx = Arc::new( + ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, + ); + loop { + let (req, ohttp_ctx) = v2_ctx.extract_req(self.config.ohttp_relay.clone())?; + let response = http + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(map_reqwest_err)?; + match v2_ctx.process_response( + &mut response.bytes().await?.to_vec().as_slice(), + ohttp_ctx, + ) { + Ok(Some(psbt)) => return Ok(psbt), + Ok(None) => { + println!("No response yet."); + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + } + Err(re) => { + println!("{}", re); + log::debug!("{:?}", re); + return Err(anyhow!("Response error").context(re)); + } + } } - Err(re) => { - println!("{}", re); - log::debug!("{:?}", re); - return Err(anyhow!("Response error").context(re)); + } + payjoin::send::Context::V1(ctx) => { + match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { + Ok(psbt) => Ok(psbt), + Err(re) => { + println!("{}", re); + log::debug!("{:?}", re); + Err(anyhow!("Response error").context(re)) + } } } } diff --git a/payjoin-cli/src/db/v2.rs b/payjoin-cli/src/db/v2.rs index bd6d8030..a2168647 100644 --- a/payjoin-cli/src/db/v2.rs +++ b/payjoin-cli/src/db/v2.rs @@ -1,6 +1,6 @@ use bitcoincore_rpc::jsonrpc::serde_json; use payjoin::receive::v2::Receiver; -use payjoin::send::RequestContext; +use payjoin::send::Sender; use sled::{IVec, Tree}; use url::Url; @@ -34,11 +34,7 @@ impl Database { Ok(()) } - pub(crate) fn insert_send_session( - &self, - session: &mut RequestContext, - pj_url: &Url, - ) -> Result<()> { + pub(crate) fn insert_send_session(&self, session: &mut Sender, pj_url: &Url) -> Result<()> { let send_tree: Tree = self.0.open_tree("send_sessions")?; let value = serde_json::to_string(session).map_err(Error::Serialize)?; send_tree.insert(pj_url.to_string(), IVec::from(value.as_str()))?; @@ -46,23 +42,21 @@ impl Database { Ok(()) } - pub(crate) fn get_send_sessions(&self) -> Result> { + pub(crate) fn get_send_sessions(&self) -> Result> { let send_tree: Tree = self.0.open_tree("send_sessions")?; let mut sessions = Vec::new(); for item in send_tree.iter() { let (_, value) = item?; - let session: RequestContext = - serde_json::from_slice(&value).map_err(Error::Deserialize)?; + let session: Sender = serde_json::from_slice(&value).map_err(Error::Deserialize)?; sessions.push(session); } Ok(sessions) } - pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result> { + pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result> { let send_tree = self.0.open_tree("send_sessions")?; if let Some(val) = send_tree.get(pj_url.to_string())? { - let session: RequestContext = - serde_json::from_slice(&val).map_err(Error::Deserialize)?; + let session: Sender = serde_json::from_slice(&val).map_err(Error::Deserialize)?; Ok(Some(session)) } else { Ok(None) diff --git a/payjoin-directory/src/db.rs b/payjoin-directory/src/db.rs index 26b69864..679a0f40 100644 --- a/payjoin-directory/src/db.rs +++ b/payjoin-directory/src/db.rs @@ -4,8 +4,8 @@ use futures::StreamExt; use redis::{AsyncCommands, Client, ErrorKind, RedisError, RedisResult}; use tracing::debug; -const RES_COLUMN: &str = "res"; -const REQ_COLUMN: &str = "req"; +const DEFAULT_COLUMN: &str = ""; +const PJ_V1_COLUMN: &str = "pjv1"; #[derive(Debug, Clone)] pub(crate) struct DbPool { @@ -19,20 +19,20 @@ impl DbPool { Ok(Self { client, timeout }) } - pub async fn peek_req(&self, pubkey_id: &str) -> Option>> { - self.peek_with_timeout(pubkey_id, REQ_COLUMN).await + pub async fn push_default(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { + self.push(pubkey_id, DEFAULT_COLUMN, data).await } - pub async fn peek_res(&self, pubkey_id: &str) -> Option>> { - self.peek_with_timeout(pubkey_id, RES_COLUMN).await + pub async fn peek_default(&self, pubkey_id: &str) -> Option>> { + self.peek_with_timeout(pubkey_id, DEFAULT_COLUMN).await } - pub async fn push_req(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { - self.push(pubkey_id, REQ_COLUMN, data).await + pub async fn push_v1(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { + self.push(pubkey_id, PJ_V1_COLUMN, data).await } - pub async fn push_res(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { - self.push(pubkey_id, RES_COLUMN, data).await + pub async fn peek_v1(&self, pubkey_id: &str) -> Option>> { + self.peek_with_timeout(pubkey_id, PJ_V1_COLUMN).await } async fn push(&self, pubkey_id: &str, channel_type: &str, data: Vec) -> RedisResult<()> { diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index eb1a2b65..ef267c86 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -220,9 +220,9 @@ async fn handle_v2( let path_segments: Vec<&str> = path.split('/').collect(); debug!("handle_v2: {:?}", &path_segments); match (parts.method, path_segments.as_slice()) { - (Method::POST, &["", id]) => post_fallback_v2(id, body, pool).await, - (Method::GET, &["", id]) => get_fallback(id, pool).await, - (Method::PUT, &["", id]) => post_payjoin(id, body, pool).await, + (Method::POST, &["", id]) => post_subdir(id, body, pool).await, + (Method::GET, &["", id]) => get_subdir(id, pool).await, + (Method::PUT, &["", id]) => put_payjoin_v1(id, body, pool).await, _ => Ok(not_found()), } } @@ -294,27 +294,49 @@ async fn post_fallback_v1( Err(_) => return Ok(bad_request_body_res), }; - let v2_compat_body = full(format!("{}\n{}", body_str, query)); - post_fallback(id, v2_compat_body, pool, none_response).await + let v2_compat_body = format!("{}\n{}", body_str, query); + let id = shorten_string(id); + pool.push_default(&id, v2_compat_body.into()) + .await + .map_err(|e| HandlerError::BadRequest(e.into()))?; + match pool.peek_v1(&id).await { + Some(result) => match result { + Ok(buffered_req) => Ok(Response::new(full(buffered_req))), + Err(e) => Err(HandlerError::BadRequest(e.into())), + }, + None => Ok(none_response), + } } -async fn post_fallback_v2( +async fn put_payjoin_v1( id: &str, body: BoxBody, pool: DbPool, ) -> Result>, HandlerError> { - trace!("Post fallback v2"); - let none_response = Response::builder().status(StatusCode::ACCEPTED).body(empty())?; - post_fallback(id, body, pool, none_response).await + trace!("Put_payjoin_v1"); + let ok_response = Response::builder().status(StatusCode::OK).body(empty())?; + + let id = shorten_string(id); + let req = + body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); + if req.len() > MAX_BUFFER_SIZE { + return Err(HandlerError::PayloadTooLarge); + } + + match pool.push_v1(&id, req.into()).await { + Ok(_) => Ok(ok_response), + Err(e) => Err(HandlerError::BadRequest(e.into())), + } } -async fn post_fallback( +async fn post_subdir( id: &str, body: BoxBody, pool: DbPool, - none_response: Response>, ) -> Result>, HandlerError> { - tracing::trace!("Post fallback"); + let none_response = Response::builder().status(StatusCode::OK).body(empty())?; + trace!("post_subdir"); + let id = shorten_string(id); let req = body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); @@ -322,27 +344,19 @@ async fn post_fallback( return Err(HandlerError::PayloadTooLarge); } - match pool.push_req(&id, req.into()).await { - Ok(_) => (), - Err(e) => return Err(HandlerError::BadRequest(e.into())), - }; - - match pool.peek_res(&id).await { - Some(result) => match result { - Ok(buffered_res) => Ok(Response::new(full(buffered_res))), - Err(e) => Err(HandlerError::BadRequest(e.into())), - }, - None => Ok(none_response), + match pool.push_default(&id, req.into()).await { + Ok(_) => Ok(none_response), + Err(e) => Err(HandlerError::BadRequest(e.into())), } } -async fn get_fallback( +async fn get_subdir( id: &str, pool: DbPool, ) -> Result>, HandlerError> { - trace!("GET fallback"); + trace!("get_subdir"); let id = shorten_string(id); - match pool.peek_req(&id).await { + match pool.peek_default(&id).await { Some(result) => match result { Ok(buffered_req) => Ok(Response::new(full(buffered_req))), Err(e) => Err(HandlerError::BadRequest(e.into())), @@ -351,22 +365,6 @@ async fn get_fallback( } } -async fn post_payjoin( - id: &str, - body: BoxBody, - pool: DbPool, -) -> Result>, HandlerError> { - trace!("POST payjoin"); - let id = shorten_string(id); - let res = - body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); - - match pool.push_res(&id, res.into()).await { - Ok(_) => Ok(Response::builder().status(StatusCode::NO_CONTENT).body(empty())?), - Err(e) => Err(HandlerError::BadRequest(e.into())), - } -} - fn not_found() -> Response> { let mut res = Response::default(); *res.status_mut() = StatusCode::NOT_FOUND; diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 54543c72..396bf66a 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -93,7 +93,7 @@ impl Receiver { } } - // OHTTP Encapsulated HTTP GET request for the Original PSBT + /// Extratct an OHTTP Encapsulated HTTP GET request for the Original PSBT pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), SessionError> { if SystemTime::now() > self.context.expiry { return Err(InternalSessionError::Expired(self.context.expiry).into()); @@ -461,22 +461,34 @@ impl PayjoinProposal { #[cfg(feature = "v2")] pub fn extract_v2_req(&mut self) -> Result<(Request, ohttp::ClientResponse), Error> { - let body = match &self.context.e { - Some(e) => { - let payjoin_bytes = self.inner.payjoin_psbt.serialize(); - log::debug!("THERE IS AN e: {:?}", e); - crate::v2::encrypt_message_b(payjoin_bytes, &self.context.s, e) - } - None => Ok(self.extract_v1_req().as_bytes().to_vec()), - }?; - let subdir_path = subdir_path_from_pubkey(self.context.s.public_key()); - let post_payjoin_target = - self.context.directory.join(&subdir_path).map_err(|e| Error::Server(e.into()))?; - log::debug!("Payjoin post target: {}", post_payjoin_target.as_str()); + let target_resource: Url; + let body: Vec; + let method: &str; + + if let Some(e) = &self.context.e { + // Prepare v2 payload + let payjoin_bytes = self.inner.payjoin_psbt.serialize(); + let sender_subdir = subdir_path_from_pubkey(e); + target_resource = + self.context.directory.join(&sender_subdir).map_err(|e| Error::Server(e.into()))?; + body = crate::v2::encrypt_message_b(payjoin_bytes, &self.context.s, e)?; + method = "POST"; + } else { + // Prepare v2 wrapped and backwards-compatible v1 payload + body = self.extract_v1_req().as_bytes().to_vec(); + let receiver_subdir = subdir_path_from_pubkey(self.context.s.public_key()); + target_resource = self + .context + .directory + .join(&receiver_subdir) + .map_err(|e| Error::Server(e.into()))?; + method = "PUT"; + } + log::debug!("Payjoin PSBT target: {}", target_resource.as_str()); let (body, ctx) = crate::v2::ohttp_encapsulate( &mut self.context.ohttp_keys, - "PUT", - post_payjoin_target.as_str(), + method, + target_resource.as_str(), Some(&body), )?; let url = self.context.ohttp_relay.clone(); diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 774104df..8a859c19 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -24,6 +24,8 @@ use std::str::FromStr; +#[cfg(feature = "v2")] +use bitcoin::base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; use bitcoin::psbt::Psbt; use bitcoin::{Amount, FeeRate, Script, ScriptBuf, TxOut, Weight}; pub use error::{CreateRequestError, ResponseError, ValidationError}; @@ -35,7 +37,7 @@ use url::Url; use crate::psbt::PsbtExt; use crate::request::Request; #[cfg(feature = "v2")] -use crate::v2::{HpkePublicKey, HpkeSecretKey}; +use crate::v2::{HpkeKeyPair, HpkePublicKey}; use crate::PjUri; // See usize casts @@ -47,7 +49,7 @@ mod error; type InternalResult = Result; #[derive(Clone)] -pub struct RequestBuilder<'a> { +pub struct SenderBuilder<'a> { psbt: Psbt, uri: PjUri<'a>, disable_output_substitution: bool, @@ -61,7 +63,7 @@ pub struct RequestBuilder<'a> { min_fee_rate: FeeRate, } -impl<'a> RequestBuilder<'a> { +impl<'a> SenderBuilder<'a> { /// Prepare an HTTP request and request context to process the response /// /// An HTTP client will own the Request data while Context sticks around so @@ -96,10 +98,7 @@ impl<'a> RequestBuilder<'a> { // The minfeerate parameter is set if the contribution is available in change. // // This method fails if no recommendation can be made or if the PSBT is malformed. - pub fn build_recommended( - self, - min_fee_rate: FeeRate, - ) -> Result { + pub fn build_recommended(self, min_fee_rate: FeeRate) -> Result { // TODO support optional batched payout scripts. This would require a change to // build() which now checks for a single payee. let mut payout_scripts = std::iter::once(self.uri.address.script_pubkey()); @@ -177,7 +176,7 @@ impl<'a> RequestBuilder<'a> { change_index: Option, min_fee_rate: FeeRate, clamp_fee_contribution: bool, - ) -> Result { + ) -> Result { self.fee_contribution = Some((max_fee_contribution, change_index)); self.clamp_fee_contribution = clamp_fee_contribution; self.min_fee_rate = min_fee_rate; @@ -191,7 +190,7 @@ impl<'a> RequestBuilder<'a> { pub fn build_non_incentivizing( mut self, min_fee_rate: FeeRate, - ) -> Result { + ) -> Result { // since this is a builder, these should already be cleared // but we'll reset them to be sure self.fee_contribution = None; @@ -200,7 +199,7 @@ impl<'a> RequestBuilder<'a> { self.build() } - fn build(self) -> Result { + fn build(self) -> Result { let mut psbt = self.psbt.validate().map_err(InternalCreateRequestError::InconsistentOriginalPsbt)?; psbt.validate_input_utxos(true) @@ -219,7 +218,7 @@ impl<'a> RequestBuilder<'a> { )?; clear_unneeded_fields(&mut psbt); - Ok(RequestContext { + Ok(Sender { psbt, endpoint, disable_output_substitution, @@ -227,14 +226,14 @@ impl<'a> RequestBuilder<'a> { payee, min_fee_rate: self.min_fee_rate, #[cfg(feature = "v2")] - e: crate::v2::HpkeKeyPair::gen_keypair().secret_key().clone(), + e: crate::v2::HpkeKeyPair::gen_keypair(), }) } } #[derive(Clone, PartialEq, Eq)] #[cfg_attr(feature = "v2", derive(Serialize, Deserialize))] -pub struct RequestContext { +pub struct Sender { psbt: Psbt, endpoint: Url, disable_output_substitution: bool, @@ -242,12 +241,12 @@ pub struct RequestContext { min_fee_rate: FeeRate, payee: ScriptBuf, #[cfg(feature = "v2")] - e: crate::v2::HpkeSecretKey, + e: crate::v2::HpkeKeyPair, } -impl RequestContext { +impl Sender { /// Extract serialized V1 Request and Context from a Payjoin Proposal - pub fn extract_v1(&self) -> Result<(Request, ContextV1), CreateRequestError> { + pub fn extract_v1(&self) -> Result<(Request, V1Context), CreateRequestError> { let url = serialize_url( self.endpoint.clone(), self.disable_output_substitution, @@ -259,13 +258,15 @@ impl RequestContext { let body = self.psbt.to_string().as_bytes().to_vec(); Ok(( Request::new_v1(url, body), - ContextV1 { - original_psbt: self.psbt.clone(), - disable_output_substitution: self.disable_output_substitution, - fee_contribution: self.fee_contribution, - payee: self.payee.clone(), - min_fee_rate: self.min_fee_rate, - allow_mixed_input_scripts: false, + V1Context { + psbt_context: PsbtContext { + original_psbt: self.psbt.clone(), + disable_output_substitution: self.disable_output_substitution, + fee_contribution: self.fee_contribution, + payee: self.payee.clone(), + min_fee_rate: self.min_fee_rate, + allow_mixed_input_scripts: false, + }, }, )) } @@ -277,10 +278,10 @@ impl RequestContext { /// /// The `ohttp_relay` merely passes the encrypted payload to the ohttp gateway of the receiver #[cfg(feature = "v2")] - pub fn extract_v2( + pub fn extract_highest_version( &mut self, ohttp_relay: Url, - ) -> Result<(Request, ContextV2), CreateRequestError> { + ) -> Result<(Request, Context), CreateRequestError> { use crate::uri::UrlExt; if let Some(expiry) = self.endpoint.exp() { @@ -290,11 +291,11 @@ impl RequestContext { } match self.extract_rs_pubkey() { - Ok(rs) => self.extract_v2_strict(ohttp_relay, rs), + Ok(rs) => self.extract_v2(ohttp_relay, rs), Err(e) => { log::warn!("Failed to extract `rs` pubkey, falling back to v1: {}", e); let (req, context_v1) = self.extract_v1()?; - Ok((req, ContextV2 { context_v1, rs: None, e: None, ohttp_res: None })) + Ok((req, Context::V1(context_v1))) } } } @@ -304,11 +305,11 @@ impl RequestContext { /// This method requires the `rs` pubkey to be extracted from the endpoint /// and has no fallback to v1. #[cfg(feature = "v2")] - fn extract_v2_strict( + fn extract_v2( &mut self, ohttp_relay: Url, rs: HpkePublicKey, - ) -> Result<(Request, ContextV2), CreateRequestError> { + ) -> Result<(Request, Context), CreateRequestError> { use crate::uri::UrlExt; let url = self.endpoint.clone(); let body = serialize_v2_body( @@ -317,18 +318,19 @@ impl RequestContext { self.fee_contribution, self.min_fee_rate, )?; - let body = crate::v2::encrypt_message_a(body, &self.e.clone(), &rs) + let body = crate::v2::encrypt_message_a(body, &self.e.secret_key().clone(), &rs) .map_err(InternalCreateRequestError::Hpke)?; let mut ohttp = self.endpoint.ohttp().ok_or(InternalCreateRequestError::MissingOhttpConfig)?; - let (body, ohttp_res) = + let (body, ohttp_ctx) = crate::v2::ohttp_encapsulate(&mut ohttp, "POST", url.as_str(), Some(&body)) .map_err(InternalCreateRequestError::OhttpEncapsulation)?; log::debug!("ohttp_relay_url: {:?}", ohttp_relay); Ok(( Request::new_v2(ohttp_relay, body), - ContextV2 { - context_v1: ContextV1 { + Context::V2(V2PostContext { + endpoint: self.endpoint.clone(), + psbt_ctx: PsbtContext { original_psbt: self.psbt.clone(), disable_output_substitution: self.disable_output_substitution, fee_contribution: self.fee_contribution, @@ -336,17 +338,14 @@ impl RequestContext { min_fee_rate: self.min_fee_rate, allow_mixed_input_scripts: true, }, - rs: Some(self.extract_rs_pubkey()?), - e: Some(self.e.clone()), - ohttp_res: Some(ohttp_res), - }, + hpke_ctx: HpkeContext { rs, e: self.e.clone() }, + ohttp_ctx, + }), )) } #[cfg(feature = "v2")] fn extract_rs_pubkey(&self) -> Result { - use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; - use bitcoin::base64::Engine; use error::ParseSubdirectoryError; let subdirectory = self @@ -366,12 +365,123 @@ impl RequestContext { pub fn endpoint(&self) -> &Url { &self.endpoint } } +pub enum Context { + V1(V1Context), + #[cfg(feature = "v2")] + V2(V2PostContext), +} + +pub struct V1Context { + psbt_context: PsbtContext, +} + +impl V1Context { + pub fn process_response( + self, + response: &mut impl std::io::Read, + ) -> Result { + self.psbt_context.process_response(response) + } +} + +#[cfg(feature = "v2")] +pub struct V2PostContext { + endpoint: Url, + psbt_ctx: PsbtContext, + hpke_ctx: HpkeContext, + ohttp_ctx: ohttp::ClientResponse, +} + +#[cfg(feature = "v2")] +impl V2PostContext { + pub fn process_response( + self, + response: &mut impl std::io::Read, + ) -> Result { + let mut res_buf = Vec::new(); + response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; + let response = crate::v2::ohttp_decapsulate(self.ohttp_ctx, &res_buf) + .map_err(InternalValidationError::OhttpEncapsulation)?; + match response.status() { + http::StatusCode::OK => { + // return OK with new Typestate + Ok(V2GetContext { + endpoint: self.endpoint, + psbt_ctx: self.psbt_ctx, + hpke_ctx: self.hpke_ctx, + }) + } + _ => Err(InternalValidationError::UnexpectedStatusCode)?, + } + } +} + +#[cfg(feature = "v2")] +pub struct V2GetContext { + endpoint: Url, + psbt_ctx: PsbtContext, + hpke_ctx: HpkeContext, +} + +#[cfg(feature = "v2")] +impl V2GetContext { + pub fn extract_req( + &self, + ohttp_relay: Url, + ) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> { + use crate::uri::UrlExt; + let mut url = self.endpoint.clone(); + let subdir = + BASE64_URL_SAFE_NO_PAD.encode(self.hpke_ctx.e.public_key().to_compressed_bytes()); + url.set_path(&subdir); + let body = crate::v2::encrypt_message_a( + Vec::new(), + &self.hpke_ctx.e.secret_key().clone(), + &self.hpke_ctx.rs.clone(), + ) + .map_err(InternalCreateRequestError::Hpke)?; + let mut ohttp = + self.endpoint.ohttp().ok_or(InternalCreateRequestError::MissingOhttpConfig)?; + let (body, ohttp_ctx) = + crate::v2::ohttp_encapsulate(&mut ohttp, "GET", url.as_str(), Some(&body)) + .map_err(InternalCreateRequestError::OhttpEncapsulation)?; + + Ok((Request::new_v2(ohttp_relay, body), ohttp_ctx)) + } + + pub fn process_response( + &self, + response: &mut impl std::io::Read, + ohttp_ctx: ohttp::ClientResponse, + ) -> Result, ResponseError> { + let mut res_buf = Vec::new(); + response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; + let response = crate::v2::ohttp_decapsulate(ohttp_ctx, &res_buf) + .map_err(InternalValidationError::OhttpEncapsulation)?; + let body = match response.status() { + http::StatusCode::OK => response.body().to_vec(), + http::StatusCode::ACCEPTED => return Ok(None), + _ => return Err(InternalValidationError::UnexpectedStatusCode)?, + }; + let psbt = crate::v2::decrypt_message_b( + &body, + self.hpke_ctx.rs.clone(), + self.hpke_ctx.e.secret_key().clone(), + ) + .map_err(InternalValidationError::Hpke)?; + + let proposal = Psbt::deserialize(&psbt).map_err(InternalValidationError::Psbt)?; + let processed_proposal = self.psbt_ctx.clone().process_proposal(proposal)?; + Ok(Some(processed_proposal)) + } +} + /// Data required for validation of response. /// /// This type is used to process the response. Get it from [`RequestBuilder`]'s build methods. /// Then you only need to call [`Self::process_response`] on it to continue BIP78 flow. #[derive(Debug, Clone)] -pub struct ContextV1 { +pub struct PsbtContext { original_psbt: Psbt, disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, @@ -379,13 +489,10 @@ pub struct ContextV1 { payee: ScriptBuf, allow_mixed_input_scripts: bool, } - #[cfg(feature = "v2")] -pub struct ContextV2 { - context_v1: ContextV1, - rs: Option, - e: Option, - ohttp_res: Option, +struct HpkeContext { + rs: HpkePublicKey, + e: HpkeKeyPair, } macro_rules! check_eq { @@ -406,43 +513,7 @@ macro_rules! ensure { }; } -#[cfg(feature = "v2")] -impl ContextV2 { - /// Decodes and validates the response. - /// - /// Call this method with response from receiver to continue BIP-??? flow. - /// A successful response can either be None if the directory has not response yet or Some(Psbt). - /// - /// If the response is some valid PSBT you should sign and broadcast. - #[inline] - pub fn process_response( - self, - response: &mut impl std::io::Read, - ) -> Result, ResponseError> { - match (self.ohttp_res, self.rs, self.e) { - (Some(ohttp_res), Some(rs), Some(e)) => { - let mut res_buf = Vec::new(); - response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; - let response = crate::v2::ohttp_decapsulate(ohttp_res, &res_buf) - .map_err(InternalValidationError::OhttpEncapsulation)?; - let body = match response.status() { - http::StatusCode::OK => response.body().to_vec(), - http::StatusCode::ACCEPTED => return Ok(None), - _ => return Err(InternalValidationError::UnexpectedStatusCode)?, - }; - let psbt = crate::v2::decrypt_message_b(&body, rs, e) - .map_err(InternalValidationError::Hpke)?; - - let proposal = Psbt::deserialize(&psbt).map_err(InternalValidationError::Psbt)?; - let processed_proposal = self.context_v1.process_proposal(proposal)?; - Ok(Some(processed_proposal)) - } - _ => self.context_v1.process_response(response).map(Some), - } - } -} - -impl ContextV1 { +impl PsbtContext { /// Decodes and validates the response. /// /// Call this method with response from receiver to continue BIP78 flow. If the response is @@ -850,11 +921,11 @@ mod test { const ORIGINAL_PSBT: &str = "cHNidP8BAHMCAAAAAY8nutGgJdyYGXWiBEb45Hoe9lWGbkxh/6bNiOJdCDuDAAAAAAD+////AtyVuAUAAAAAF6kUHehJ8GnSdBUOOv6ujXLrWmsJRDCHgIQeAAAAAAAXqRR3QJbbz0hnQ8IvQ0fptGn+votneofTAAAAAAEBIKgb1wUAAAAAF6kU3k4ekGHKWRNbA1rV5tR5kEVDVNCHAQcXFgAUx4pFclNVgo1WWAdN1SYNX8tphTABCGsCRzBEAiB8Q+A6dep+Rz92vhy26lT0AjZn4PRLi8Bf9qoB/CMk0wIgP/Rj2PWZ3gEjUkTlhDRNAQ0gXwTO7t9n+V14pZ6oljUBIQMVmsAaoNWHVMS02LfTSe0e388LNitPa1UQZyOihY+FFgABABYAFEb2Giu6c4KO5YW0pfw3lGp9jMUUAAA="; const PAYJOIN_PROPOSAL: &str = "cHNidP8BAJwCAAAAAo8nutGgJdyYGXWiBEb45Hoe9lWGbkxh/6bNiOJdCDuDAAAAAAD+////jye60aAl3JgZdaIERvjkeh72VYZuTGH/ps2I4l0IO4MBAAAAAP7///8CJpW4BQAAAAAXqRQd6EnwadJ0FQ46/q6NcutaawlEMIcACT0AAAAAABepFHdAltvPSGdDwi9DR+m0af6+i2d6h9MAAAAAAQEgqBvXBQAAAAAXqRTeTh6QYcpZE1sDWtXm1HmQRUNU0IcBBBYAFMeKRXJTVYKNVlgHTdUmDV/LaYUwIgYDFZrAGqDVh1TEtNi300ntHt/PCzYrT2tVEGcjooWPhRYYSFzWUDEAAIABAACAAAAAgAEAAAAAAAAAAAEBIICEHgAAAAAAF6kUyPLL+cphRyyI5GTUazV0hF2R2NWHAQcXFgAUX4BmVeWSTJIEwtUb5TlPS/ntohABCGsCRzBEAiBnu3tA3yWlT0WBClsXXS9j69Bt+waCs9JcjWtNjtv7VgIge2VYAaBeLPDB6HGFlpqOENXMldsJezF9Gs5amvDQRDQBIQJl1jz1tBt8hNx2owTm+4Du4isx0pmdKNMNIjjaMHFfrQABABYAFEb2Giu6c4KO5YW0pfw3lGp9jMUUIgICygvBWB5prpfx61y1HDAwo37kYP3YRJBvAjtunBAur3wYSFzWUDEAAIABAACAAAAAgAEAAAABAAAAAAA="; - fn create_v1_context() -> super::ContextV1 { + fn create_v1_context() -> super::PsbtContext { let original_psbt = Psbt::from_str(ORIGINAL_PSBT).unwrap(); eprintln!("original: {:#?}", original_psbt); let payee = original_psbt.unsigned_tx.output[1].script_pubkey.clone(); - let ctx = super::ContextV1 { + let ctx = super::PsbtContext { original_psbt, disable_output_substitution: false, fee_contribution: Some((bitcoin::Amount::from_sat(182), 0)), @@ -906,20 +977,15 @@ mod test { #[test] #[cfg(feature = "v2")] fn req_ctx_ser_de_roundtrip() { - use hpke::Deserializable; - use super::*; - let req_ctx = RequestContext { + let req_ctx = Sender { psbt: Psbt::from_str(ORIGINAL_PSBT).unwrap(), endpoint: Url::parse("http://localhost:1234").unwrap(), disable_output_substitution: false, fee_contribution: None, min_fee_rate: FeeRate::ZERO, payee: ScriptBuf::from(vec![0x00]), - e: HpkeSecretKey( - ::PrivateKey::from_bytes(&[0x01; 32]) - .unwrap(), - ), + e: HpkeKeyPair::gen_keypair(), }; let serialized = serde_json::to_string(&req_ctx).unwrap(); let deserialized = serde_json::from_str(&serialized).unwrap(); diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index e39abd92..ce304c00 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -12,7 +12,7 @@ mod integration { use bitcoind::bitcoincore_rpc::{self, RpcApi}; use log::{log_enabled, Level}; use once_cell::sync::{Lazy, OnceCell}; - use payjoin::send::RequestBuilder; + use payjoin::send::SenderBuilder; use payjoin::{PjUri, PjUriBuilder, Request, Uri}; use tracing_subscriber::{EnvFilter, FmtSubscriber}; use url::Url; @@ -92,7 +92,7 @@ mod integration { .unwrap(); let psbt = build_original_psbt(&sender, &uri)?; debug!("Original psbt: {:#?}", psbt); - let (req, ctx) = RequestBuilder::from_psbt_and_uri(psbt, uri)? + let (req, ctx) = SenderBuilder::from_psbt_and_uri(psbt, uri)? .build_with_additional_fee(Amount::from_sat(10000), None, FeeRate::ZERO, false)? .extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); @@ -157,7 +157,7 @@ mod integration { .unwrap(); let psbt = build_original_psbt(&sender, &uri)?; debug!("Original psbt: {:#?}", psbt); - let (req, _ctx) = RequestBuilder::from_psbt_and_uri(psbt, uri)? + let (req, _ctx) = SenderBuilder::from_psbt_and_uri(psbt, uri)? .build_with_additional_fee(Amount::from_sat(10000), None, FeeRate::ZERO, false)? .extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); @@ -179,6 +179,7 @@ mod integration { use bitcoin::Address; use http::StatusCode; use payjoin::receive::v2::{PayjoinProposal, Receiver, UncheckedProposal}; + use payjoin::send::Context; use payjoin::{OhttpKeys, PjUri, UriExt}; use reqwest::{Client, ClientBuilder, Error, Response}; use testcontainers_modules::redis::Redis; @@ -283,9 +284,9 @@ mod integration { Some(std::time::SystemTime::now()), ) .build(); - let mut expired_req_ctx = RequestBuilder::from_psbt_and_uri(psbt, expired_pj_uri)? + let mut expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)? .build_non_incentivizing(FeeRate::BROADCAST_MIN)?; - match expired_req_ctx.extract_v2(directory.to_owned()) { + match expired_req_ctx.extract_highest_version(directory.to_owned()) { // Internal error types are private, so check against a string Err(err) => assert!(err.to_string().contains("expired")), _ => assert!(false, "Expired send session should error"), @@ -353,10 +354,14 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_sweep_psbt(&sender, &pj_uri)?; - let mut req_ctx = RequestBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; let (Request { url, body, content_type, .. }, send_ctx) = - req_ctx.extract_v2(directory.to_owned())?; + req_ctx.extract_highest_version(directory.to_owned())?; + let send_ctx = match send_ctx { + Context::V2(ctx) => ctx, + _ => panic!("V2 context expected"), + }; let response = agent .post(url.clone()) .header("Content-Type", content_type) @@ -366,10 +371,9 @@ mod integration { .unwrap(); log::info!("Response: {:#?}", &response); assert!(response.status().is_success()); - let response_body = + let send_ctx = send_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?; - // No response body yet since we are async and pushed fallback_psbt to the buffer - assert!(response_body.is_none()); + // POST Original PSBT // ********************** // Inside the Receiver: @@ -383,7 +387,12 @@ mod integration { let mut payjoin_proposal = handle_directory_proposal(&receiver, proposal, None); assert!(!payjoin_proposal.is_output_substitution_disabled()); let (req, ctx) = payjoin_proposal.extract_v2_req()?; - let response = agent.post(req.url).body(req.body).send().await?; + let response = agent + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await?; let res = response.bytes().await?.to_vec(); payjoin_proposal.process_res(res, ctx)?; @@ -391,11 +400,18 @@ mod integration { // Inside the Sender: // Sender checks, signs, finalizes, extracts, and broadcasts // Replay post fallback to get the response - let (Request { url, body, .. }, send_ctx) = - req_ctx.extract_v2(directory.to_owned())?; - let response = agent.post(url).body(body).send().await?; + let (Request { url, body, content_type, .. }, ohttp_ctx) = + send_ctx.extract_req(directory.to_owned())?; + let response = agent + .post(url.clone()) + .header("Content-Type", content_type) + .body(body.clone()) + .send() + .await + .unwrap(); + log::info!("Response: {:#?}", &response); let checked_payjoin_proposal_psbt = send_ctx - .process_response(&mut response.bytes().await?.to_vec().as_slice())? + .process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)? .unwrap(); let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?; @@ -482,10 +498,8 @@ mod integration { address.clone(), directory.clone(), ohttp_keys.clone(), - cert_der.clone(), None, - ) - .await?; + ); println!("session: {:#?}", &session); let pj_uri_string = session.pj_uri_builder().build().to_string(); // Poll receive request @@ -506,10 +520,10 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_sweep_psbt(&sender, &pj_uri)?; - let mut req_ctx = RequestBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; - let (Request { url, body, content_type, .. }, send_ctx) = - req_ctx.extract_v2(directory.to_owned())?; + let (Request { url, body, content_type, .. }, post_ctx) = + req_ctx.extract_highest_version(directory.to_owned())?; let response = agent .post(url.clone()) .header("Content-Type", content_type) @@ -519,10 +533,23 @@ mod integration { .unwrap(); log::info!("Response: {:#?}", &response); assert!(response.status().is_success()); - let response_body = - send_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?; + let get_ctx = match post_ctx { + Context::V2(ctx) => + ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, + _ => panic!("V2 context expected"), + }; + let (Request { url, body, content_type, .. }, ohttp_ctx) = + get_ctx.extract_req(directory.to_owned())?; + let response = agent + .post(url.clone()) + .header("Content-Type", content_type) + .body(body.clone()) + .send() + .await?; // No response body yet since we are async and pushed fallback_psbt to the buffer - assert!(response_body.is_none()); + assert!(get_ctx + .process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)? + .is_none()); // ********************** // Inside the Receiver: @@ -546,11 +573,16 @@ mod integration { // Inside the Sender: // Sender checks, signs, finalizes, extracts, and broadcasts // Replay post fallback to get the response - let (Request { url, body, .. }, send_ctx) = - req_ctx.extract_v2(directory.to_owned())?; - let response = agent.post(url).body(body).send().await?; - let checked_payjoin_proposal_psbt = send_ctx - .process_response(&mut response.bytes().await?.to_vec().as_slice())? + let (Request { url, body, content_type, .. }, ohttp_ctx) = + get_ctx.extract_req(directory.to_owned())?; + let response = agent + .post(url.clone()) + .header("Content-Type", content_type) + .body(body.clone()) + .send() + .await?; + let checked_payjoin_proposal_psbt = get_ctx + .process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)? .unwrap(); let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?; @@ -589,9 +621,9 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_original_psbt(&sender, &pj_uri)?; - let mut req_ctx = RequestBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; - let (req, ctx) = req_ctx.extract_v2(EXAMPLE_URL.to_owned())?; + let (req, ctx) = req_ctx.extract_highest_version(EXAMPLE_URL.to_owned())?; let headers = HeaderMock::new(&req.body, req.content_type); // ********************** @@ -603,8 +635,11 @@ mod integration { // ********************** // Inside the Sender: // Sender checks, signs, finalizes, extracts, and broadcasts - let checked_payjoin_proposal_psbt = - ctx.process_response(&mut response.as_bytes())?.unwrap(); + let ctx = match ctx { + Context::V1(ctx) => ctx, + _ => panic!("V1 context expected"), + }; + let checked_payjoin_proposal_psbt = ctx.process_response(&mut response.as_bytes())?; let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?; @@ -664,7 +699,7 @@ mod integration { .unwrap(); let psbt = build_original_psbt(&sender, &pj_uri)?; let (Request { url, body, content_type, .. }, send_ctx) = - RequestBuilder::from_psbt_and_uri(psbt, pj_uri)? + SenderBuilder::from_psbt_and_uri(psbt, pj_uri)? .build_with_additional_fee( Amount::from_sat(10000), None, @@ -992,7 +1027,7 @@ mod integration { let psbt = build_original_psbt(&sender, &uri)?; log::debug!("Original psbt: {:#?}", psbt); let max_additional_fee = Amount::from_sat(1000); - let (req, ctx) = RequestBuilder::from_psbt_and_uri(psbt.clone(), uri)? + let (req, ctx) = SenderBuilder::from_psbt_and_uri(psbt.clone(), uri)? .build_with_additional_fee(max_additional_fee, None, FeeRate::ZERO, false)? .extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); @@ -1069,7 +1104,7 @@ mod integration { .unwrap(); let psbt = build_original_psbt(&sender, &uri)?; log::debug!("Original psbt: {:#?}", psbt); - let (req, ctx) = RequestBuilder::from_psbt_and_uri(psbt.clone(), uri)? + let (req, ctx) = SenderBuilder::from_psbt_and_uri(psbt.clone(), uri)? .build_with_additional_fee(Amount::from_sat(10000), None, FeeRate::ZERO, false)? .extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); From 86e37819d9cdc09adade911a2491284e51e65672 Mon Sep 17 00:00:00 2001 From: DanGould Date: Tue, 15 Oct 2024 11:18:37 -0400 Subject: [PATCH 03/12] Separate hpke and ohttp v2 modules --- payjoin/src/{v2.rs => hpke.rs} | 254 +------------------------------ payjoin/src/lib.rs | 6 +- payjoin/src/ohttp.rs | 255 ++++++++++++++++++++++++++++++++ payjoin/src/receive/error.rs | 8 +- payjoin/src/receive/v2/error.rs | 10 +- payjoin/src/receive/v2/mod.rs | 25 ++-- payjoin/src/send/error.rs | 10 +- payjoin/src/send/mod.rs | 30 ++-- 8 files changed, 299 insertions(+), 299 deletions(-) rename payjoin/src/{v2.rs => hpke.rs} (51%) create mode 100644 payjoin/src/ohttp.rs diff --git a/payjoin/src/v2.rs b/payjoin/src/hpke.rs similarity index 51% rename from payjoin/src/v2.rs rename to payjoin/src/hpke.rs index 6c62f3ef..ee7de0bd 100644 --- a/payjoin/src/v2.rs +++ b/payjoin/src/hpke.rs @@ -1,8 +1,6 @@ -use std::ops::{Deref, DerefMut}; +use std::ops::Deref; use std::{error, fmt}; -use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; -use bitcoin::base64::Engine; use bitcoin::key::constants::UNCOMPRESSED_PUBLIC_KEY_SIZE; use hpke::aead::ChaCha20Poly1305; use hpke::kdf::HkdfSha256; @@ -272,253 +270,3 @@ impl error::Error for HpkeError { } } } - -pub fn ohttp_encapsulate( - ohttp_keys: &mut ohttp::KeyConfig, - method: &str, - target_resource: &str, - body: Option<&[u8]>, -) -> Result<(Vec, ohttp::ClientResponse), OhttpEncapsulationError> { - use std::fmt::Write; - - let ctx = ohttp::ClientRequest::from_config(ohttp_keys)?; - let url = url::Url::parse(target_resource)?; - let authority_bytes = url.host().map_or_else(Vec::new, |host| { - let mut authority = host.to_string(); - if let Some(port) = url.port() { - write!(authority, ":{}", port).unwrap(); - } - authority.into_bytes() - }); - let mut bhttp_message = bhttp::Message::request( - method.as_bytes().to_vec(), - url.scheme().as_bytes().to_vec(), - authority_bytes, - url.path().as_bytes().to_vec(), - ); - // None of our messages include headers, so we don't add them - if let Some(body) = body { - bhttp_message.write_content(body); - } - let mut bhttp_req = Vec::new(); - let _ = bhttp_message.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_req); - let encapsulated = ctx.encapsulate(&bhttp_req)?; - Ok(encapsulated) -} - -/// decapsulate ohttp, bhttp response and return http response body and status code -pub fn ohttp_decapsulate( - res_ctx: ohttp::ClientResponse, - ohttp_body: &[u8], -) -> Result>, OhttpEncapsulationError> { - let bhttp_body = res_ctx.decapsulate(ohttp_body)?; - let mut r = std::io::Cursor::new(bhttp_body); - let m: bhttp::Message = bhttp::Message::read_bhttp(&mut r)?; - let mut builder = http::Response::builder(); - for field in m.header().iter() { - builder = builder.header(field.name(), field.value()); - } - builder - .status(m.control().status().unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR.into())) - .body(m.content().to_vec()) - .map_err(OhttpEncapsulationError::Http) -} - -/// Error from de/encapsulating an Oblivious HTTP request or response. -#[derive(Debug)] -pub enum OhttpEncapsulationError { - Http(http::Error), - Ohttp(ohttp::Error), - Bhttp(bhttp::Error), - ParseUrl(url::ParseError), -} - -impl From for OhttpEncapsulationError { - fn from(value: http::Error) -> Self { Self::Http(value) } -} - -impl From for OhttpEncapsulationError { - fn from(value: ohttp::Error) -> Self { Self::Ohttp(value) } -} - -impl From for OhttpEncapsulationError { - fn from(value: bhttp::Error) -> Self { Self::Bhttp(value) } -} - -impl From for OhttpEncapsulationError { - fn from(value: url::ParseError) -> Self { Self::ParseUrl(value) } -} - -impl fmt::Display for OhttpEncapsulationError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - use OhttpEncapsulationError::*; - - match &self { - Http(e) => e.fmt(f), - Ohttp(e) => e.fmt(f), - Bhttp(e) => e.fmt(f), - ParseUrl(e) => e.fmt(f), - } - } -} - -impl error::Error for OhttpEncapsulationError { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - use OhttpEncapsulationError::*; - - match &self { - Http(e) => Some(e), - Ohttp(e) => Some(e), - Bhttp(e) => Some(e), - ParseUrl(e) => Some(e), - } - } -} - -#[derive(Debug, Clone)] -pub struct OhttpKeys(pub ohttp::KeyConfig); - -impl OhttpKeys { - /// Decode an OHTTP KeyConfig - pub fn decode(bytes: &[u8]) -> Result { - ohttp::KeyConfig::decode(bytes).map(Self) - } -} - -const KEM_ID: &[u8] = b"\x00\x16"; // DHKEM(secp256k1, HKDF-SHA256) -const SYMMETRIC_LEN: &[u8] = b"\x00\x04"; // 4 bytes -const SYMMETRIC_KDF_AEAD: &[u8] = b"\x00\x01\x00\x03"; // KDF(HKDF-SHA256), AEAD(ChaCha20Poly1305) - -impl fmt::Display for OhttpKeys { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let bytes = self.encode().map_err(|_| fmt::Error)?; - let key_id = bytes[0]; - let pubkey = &bytes[3..68]; - - let compressed_pubkey = - bitcoin::secp256k1::PublicKey::from_slice(pubkey).map_err(|_| fmt::Error)?.serialize(); - - let mut buf = vec![key_id]; - buf.extend_from_slice(&compressed_pubkey); - - let encoded = BASE64_URL_SAFE_NO_PAD.encode(buf); - write!(f, "{}", encoded) - } -} - -impl std::str::FromStr for OhttpKeys { - type Err = ParseOhttpKeysError; - - /// Parses a base64URL-encoded string into OhttpKeys. - /// The string format is: key_id || compressed_public_key - fn from_str(s: &str) -> Result { - let bytes = BASE64_URL_SAFE_NO_PAD.decode(s).map_err(ParseOhttpKeysError::DecodeBase64)?; - - let key_id = *bytes.first().ok_or(ParseOhttpKeysError::InvalidFormat)?; - let compressed_pk = bytes.get(1..34).ok_or(ParseOhttpKeysError::InvalidFormat)?; - - let pubkey = bitcoin::secp256k1::PublicKey::from_slice(compressed_pk) - .map_err(|_| ParseOhttpKeysError::InvalidPublicKey)?; - - let mut buf = vec![key_id]; - buf.extend_from_slice(KEM_ID); - buf.extend_from_slice(&pubkey.serialize_uncompressed()); - buf.extend_from_slice(SYMMETRIC_LEN); - buf.extend_from_slice(SYMMETRIC_KDF_AEAD); - - ohttp::KeyConfig::decode(&buf).map(Self).map_err(ParseOhttpKeysError::DecodeKeyConfig) - } -} - -impl PartialEq for OhttpKeys { - fn eq(&self, other: &Self) -> bool { - match (self.encode(), other.encode()) { - (Ok(self_encoded), Ok(other_encoded)) => self_encoded == other_encoded, - // If OhttpKeys::encode(&self) is Err, return false - _ => false, - } - } -} - -impl Eq for OhttpKeys {} - -impl Deref for OhttpKeys { - type Target = ohttp::KeyConfig; - - fn deref(&self) -> &Self::Target { &self.0 } -} - -impl DerefMut for OhttpKeys { - fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } -} - -impl<'de> serde::Deserialize<'de> for OhttpKeys { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let bytes = Vec::::deserialize(deserializer)?; - OhttpKeys::decode(&bytes).map_err(serde::de::Error::custom) - } -} - -impl serde::Serialize for OhttpKeys { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let bytes = self.encode().map_err(serde::ser::Error::custom)?; - bytes.serialize(serializer) - } -} - -#[derive(Debug)] -pub enum ParseOhttpKeysError { - InvalidFormat, - InvalidPublicKey, - DecodeBase64(bitcoin::base64::DecodeError), - DecodeKeyConfig(ohttp::Error), -} - -impl std::fmt::Display for ParseOhttpKeysError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ParseOhttpKeysError::InvalidFormat => write!(f, "Invalid format"), - ParseOhttpKeysError::InvalidPublicKey => write!(f, "Invalid public key"), - ParseOhttpKeysError::DecodeBase64(e) => write!(f, "Failed to decode base64: {}", e), - ParseOhttpKeysError::DecodeKeyConfig(e) => - write!(f, "Failed to decode KeyConfig: {}", e), - } - } -} - -impl std::error::Error for ParseOhttpKeysError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - ParseOhttpKeysError::DecodeBase64(e) => Some(e), - ParseOhttpKeysError::DecodeKeyConfig(e) => Some(e), - ParseOhttpKeysError::InvalidFormat | ParseOhttpKeysError::InvalidPublicKey => None, - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_ohttp_keys_roundtrip() { - use std::str::FromStr; - - use ohttp::hpke::{Aead, Kdf, Kem}; - use ohttp::{KeyId, SymmetricSuite}; - const KEY_ID: KeyId = 1; - const KEM: Kem = Kem::K256Sha256; - const SYMMETRIC: &[SymmetricSuite] = - &[ohttp::SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)]; - let keys = OhttpKeys(ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap()); - let serialized = &keys.to_string(); - let deserialized = OhttpKeys::from_str(serialized).unwrap(); - assert_eq!(keys.encode().unwrap(), deserialized.encode().unwrap()); - } -} diff --git a/payjoin/src/lib.rs b/payjoin/src/lib.rs index b1bb0345..8ebf3762 100644 --- a/payjoin/src/lib.rs +++ b/payjoin/src/lib.rs @@ -28,9 +28,11 @@ pub use crate::receive::Error; pub mod send; #[cfg(feature = "v2")] -pub(crate) mod v2; +pub(crate) mod hpke; #[cfg(feature = "v2")] -pub use v2::OhttpKeys; +pub(crate) mod ohttp; +#[cfg(feature = "v2")] +pub use crate::ohttp::OhttpKeys; #[cfg(feature = "io")] pub mod io; diff --git a/payjoin/src/ohttp.rs b/payjoin/src/ohttp.rs new file mode 100644 index 00000000..9bd7d147 --- /dev/null +++ b/payjoin/src/ohttp.rs @@ -0,0 +1,255 @@ +use std::ops::{Deref, DerefMut}; +use std::{error, fmt}; + +use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; +use bitcoin::base64::Engine; + +pub fn ohttp_encapsulate( + ohttp_keys: &mut ohttp::KeyConfig, + method: &str, + target_resource: &str, + body: Option<&[u8]>, +) -> Result<(Vec, ohttp::ClientResponse), OhttpEncapsulationError> { + use std::fmt::Write; + + let ctx = ohttp::ClientRequest::from_config(ohttp_keys)?; + let url = url::Url::parse(target_resource)?; + let authority_bytes = url.host().map_or_else(Vec::new, |host| { + let mut authority = host.to_string(); + if let Some(port) = url.port() { + write!(authority, ":{}", port).unwrap(); + } + authority.into_bytes() + }); + let mut bhttp_message = bhttp::Message::request( + method.as_bytes().to_vec(), + url.scheme().as_bytes().to_vec(), + authority_bytes, + url.path().as_bytes().to_vec(), + ); + // None of our messages include headers, so we don't add them + if let Some(body) = body { + bhttp_message.write_content(body); + } + let mut bhttp_req = Vec::new(); + let _ = bhttp_message.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_req); + let encapsulated = ctx.encapsulate(&bhttp_req)?; + Ok(encapsulated) +} + +/// decapsulate ohttp, bhttp response and return http response body and status code +pub fn ohttp_decapsulate( + res_ctx: ohttp::ClientResponse, + ohttp_body: &[u8], +) -> Result>, OhttpEncapsulationError> { + let bhttp_body = res_ctx.decapsulate(ohttp_body)?; + let mut r = std::io::Cursor::new(bhttp_body); + let m: bhttp::Message = bhttp::Message::read_bhttp(&mut r)?; + let mut builder = http::Response::builder(); + for field in m.header().iter() { + builder = builder.header(field.name(), field.value()); + } + builder + .status(m.control().status().unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR.into())) + .body(m.content().to_vec()) + .map_err(OhttpEncapsulationError::Http) +} + +/// Error from de/encapsulating an Oblivious HTTP request or response. +#[derive(Debug)] +pub enum OhttpEncapsulationError { + Http(http::Error), + Ohttp(ohttp::Error), + Bhttp(bhttp::Error), + ParseUrl(url::ParseError), +} + +impl From for OhttpEncapsulationError { + fn from(value: http::Error) -> Self { Self::Http(value) } +} + +impl From for OhttpEncapsulationError { + fn from(value: ohttp::Error) -> Self { Self::Ohttp(value) } +} + +impl From for OhttpEncapsulationError { + fn from(value: bhttp::Error) -> Self { Self::Bhttp(value) } +} + +impl From for OhttpEncapsulationError { + fn from(value: url::ParseError) -> Self { Self::ParseUrl(value) } +} + +impl fmt::Display for OhttpEncapsulationError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use OhttpEncapsulationError::*; + + match &self { + Http(e) => e.fmt(f), + Ohttp(e) => e.fmt(f), + Bhttp(e) => e.fmt(f), + ParseUrl(e) => e.fmt(f), + } + } +} + +impl error::Error for OhttpEncapsulationError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + use OhttpEncapsulationError::*; + + match &self { + Http(e) => Some(e), + Ohttp(e) => Some(e), + Bhttp(e) => Some(e), + ParseUrl(e) => Some(e), + } + } +} + +#[derive(Debug, Clone)] +pub struct OhttpKeys(pub ohttp::KeyConfig); + +impl OhttpKeys { + /// Decode an OHTTP KeyConfig + pub fn decode(bytes: &[u8]) -> Result { + ohttp::KeyConfig::decode(bytes).map(Self) + } +} + +const KEM_ID: &[u8] = b"\x00\x16"; // DHKEM(secp256k1, HKDF-SHA256) +const SYMMETRIC_LEN: &[u8] = b"\x00\x04"; // 4 bytes +const SYMMETRIC_KDF_AEAD: &[u8] = b"\x00\x01\x00\x03"; // KDF(HKDF-SHA256), AEAD(ChaCha20Poly1305) + +impl fmt::Display for OhttpKeys { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let bytes = self.encode().map_err(|_| fmt::Error)?; + let key_id = bytes[0]; + let pubkey = &bytes[3..68]; + + let compressed_pubkey = + bitcoin::secp256k1::PublicKey::from_slice(pubkey).map_err(|_| fmt::Error)?.serialize(); + + let mut buf = vec![key_id]; + buf.extend_from_slice(&compressed_pubkey); + + let encoded = BASE64_URL_SAFE_NO_PAD.encode(buf); + write!(f, "{}", encoded) + } +} + +impl std::str::FromStr for OhttpKeys { + type Err = ParseOhttpKeysError; + + /// Parses a base64URL-encoded string into OhttpKeys. + /// The string format is: key_id || compressed_public_key + fn from_str(s: &str) -> Result { + let bytes = BASE64_URL_SAFE_NO_PAD.decode(s).map_err(ParseOhttpKeysError::DecodeBase64)?; + + let key_id = *bytes.first().ok_or(ParseOhttpKeysError::InvalidFormat)?; + let compressed_pk = bytes.get(1..34).ok_or(ParseOhttpKeysError::InvalidFormat)?; + + let pubkey = bitcoin::secp256k1::PublicKey::from_slice(compressed_pk) + .map_err(|_| ParseOhttpKeysError::InvalidPublicKey)?; + + let mut buf = vec![key_id]; + buf.extend_from_slice(KEM_ID); + buf.extend_from_slice(&pubkey.serialize_uncompressed()); + buf.extend_from_slice(SYMMETRIC_LEN); + buf.extend_from_slice(SYMMETRIC_KDF_AEAD); + + ohttp::KeyConfig::decode(&buf).map(Self).map_err(ParseOhttpKeysError::DecodeKeyConfig) + } +} + +impl PartialEq for OhttpKeys { + fn eq(&self, other: &Self) -> bool { + match (self.encode(), other.encode()) { + (Ok(self_encoded), Ok(other_encoded)) => self_encoded == other_encoded, + // If OhttpKeys::encode(&self) is Err, return false + _ => false, + } + } +} + +impl Eq for OhttpKeys {} + +impl Deref for OhttpKeys { + type Target = ohttp::KeyConfig; + + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl DerefMut for OhttpKeys { + fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } +} + +impl<'de> serde::Deserialize<'de> for OhttpKeys { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let bytes = Vec::::deserialize(deserializer)?; + OhttpKeys::decode(&bytes).map_err(serde::de::Error::custom) + } +} + +impl serde::Serialize for OhttpKeys { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let bytes = self.encode().map_err(serde::ser::Error::custom)?; + bytes.serialize(serializer) + } +} + +#[derive(Debug)] +pub enum ParseOhttpKeysError { + InvalidFormat, + InvalidPublicKey, + DecodeBase64(bitcoin::base64::DecodeError), + DecodeKeyConfig(ohttp::Error), +} + +impl std::fmt::Display for ParseOhttpKeysError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ParseOhttpKeysError::InvalidFormat => write!(f, "Invalid format"), + ParseOhttpKeysError::InvalidPublicKey => write!(f, "Invalid public key"), + ParseOhttpKeysError::DecodeBase64(e) => write!(f, "Failed to decode base64: {}", e), + ParseOhttpKeysError::DecodeKeyConfig(e) => + write!(f, "Failed to decode KeyConfig: {}", e), + } + } +} + +impl std::error::Error for ParseOhttpKeysError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + ParseOhttpKeysError::DecodeBase64(e) => Some(e), + ParseOhttpKeysError::DecodeKeyConfig(e) => Some(e), + ParseOhttpKeysError::InvalidFormat | ParseOhttpKeysError::InvalidPublicKey => None, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_ohttp_keys_roundtrip() { + use std::str::FromStr; + + use ohttp::hpke::{Aead, Kdf, Kem}; + use ohttp::{KeyId, SymmetricSuite}; + const KEY_ID: KeyId = 1; + const KEM: Kem = Kem::K256Sha256; + const SYMMETRIC: &[SymmetricSuite] = + &[ohttp::SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)]; + let keys = OhttpKeys(ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap()); + let serialized = &keys.to_string(); + let deserialized = OhttpKeys::from_str(serialized).unwrap(); + assert_eq!(keys.encode().unwrap(), deserialized.encode().unwrap()); + } +} diff --git a/payjoin/src/receive/error.rs b/payjoin/src/receive/error.rs index 82bcbb77..a8479db6 100644 --- a/payjoin/src/receive/error.rs +++ b/payjoin/src/receive/error.rs @@ -36,13 +36,13 @@ impl From for Error { } #[cfg(feature = "v2")] -impl From for Error { - fn from(e: crate::v2::HpkeError) -> Self { Error::Server(Box::new(e)) } +impl From for Error { + fn from(e: crate::hpke::HpkeError) -> Self { Error::Server(Box::new(e)) } } #[cfg(feature = "v2")] -impl From for Error { - fn from(e: crate::v2::OhttpEncapsulationError) -> Self { Error::Server(Box::new(e)) } +impl From for Error { + fn from(e: crate::ohttp::OhttpEncapsulationError) -> Self { Error::Server(Box::new(e)) } } /// Error that may occur when the request from sender is malformed. diff --git a/payjoin/src/receive/v2/error.rs b/payjoin/src/receive/v2/error.rs index c6d7daf2..1a934dd3 100644 --- a/payjoin/src/receive/v2/error.rs +++ b/payjoin/src/receive/v2/error.rs @@ -1,7 +1,7 @@ use core::fmt; use std::error; -use crate::v2::OhttpEncapsulationError; +use crate::ohttp::OhttpEncapsulationError; #[derive(Debug)] pub struct SessionError(InternalSessionError); @@ -11,14 +11,14 @@ pub(crate) enum InternalSessionError { /// The session has expired Expired(std::time::SystemTime), /// OHTTP Encapsulation failed - OhttpEncapsulationError(OhttpEncapsulationError), + OhttpEncapsulation(OhttpEncapsulationError), } impl fmt::Display for SessionError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match &self.0 { InternalSessionError::Expired(expiry) => write!(f, "Session expired at {:?}", expiry), - InternalSessionError::OhttpEncapsulationError(e) => + InternalSessionError::OhttpEncapsulation(e) => write!(f, "OHTTP Encapsulation Error: {}", e), } } @@ -28,7 +28,7 @@ impl error::Error for SessionError { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match &self.0 { InternalSessionError::Expired(_) => None, - InternalSessionError::OhttpEncapsulationError(e) => Some(e), + InternalSessionError::OhttpEncapsulation(e) => Some(e), } } } @@ -39,6 +39,6 @@ impl From for SessionError { impl From for SessionError { fn from(e: OhttpEncapsulationError) -> Self { - SessionError(InternalSessionError::OhttpEncapsulationError(e)) + SessionError(InternalSessionError::OhttpEncapsulation(e)) } } diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 396bf66a..cb19ffbd 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -14,10 +14,11 @@ use super::{ Error, InputContributionError, InternalRequestError, OutputSubstitutionError, RequestError, SelectionError, }; +use crate::hpke::{decrypt_message_a, encrypt_message_b, HpkeKeyPair, HpkePublicKey}; +use crate::ohttp::{ohttp_decapsulate, ohttp_encapsulate, OhttpEncapsulationError, OhttpKeys}; use crate::psbt::PsbtExt; use crate::receive::optional_parameters::Params; -use crate::v2::{HpkeKeyPair, HpkePublicKey, OhttpEncapsulationError}; -use crate::{OhttpKeys, PjUriBuilder, Request}; +use crate::{PjUriBuilder, Request}; pub(crate) mod error; @@ -99,7 +100,7 @@ impl Receiver { return Err(InternalSessionError::Expired(self.context.expiry).into()); } let (body, ohttp_ctx) = - self.fallback_req_body().map_err(InternalSessionError::OhttpEncapsulationError)?; + self.fallback_req_body().map_err(InternalSessionError::OhttpEncapsulation)?; let url = self.context.ohttp_relay.clone(); let req = Request::new_v2(url, body); Ok((req, ohttp_ctx)) @@ -115,7 +116,7 @@ impl Receiver { let mut buf = Vec::new(); let _ = body.read_to_end(&mut buf); log::trace!("decapsulating directory response"); - let response = crate::v2::ohttp_decapsulate(context, &buf)?; + let response = ohttp_decapsulate(context, &buf)?; if response.body().is_empty() { log::debug!("response is empty"); return Ok(None); @@ -132,12 +133,7 @@ impl Receiver { &mut self, ) -> Result<(Vec, ohttp::ClientResponse), OhttpEncapsulationError> { let fallback_target = self.pj_url(); - crate::v2::ohttp_encapsulate( - &mut self.context.ohttp_keys, - "GET", - fallback_target.as_str(), - None, - ) + ohttp_encapsulate(&mut self.context.ohttp_keys, "GET", fallback_target.as_str(), None) } fn extract_proposal_from_v1(&mut self, response: String) -> Result { @@ -145,8 +141,7 @@ impl Receiver { } fn extract_proposal_from_v2(&mut self, response: Vec) -> Result { - let (payload_bytes, e) = - crate::v2::decrypt_message_a(&response, self.context.s.secret_key().clone())?; + let (payload_bytes, e) = decrypt_message_a(&response, self.context.s.secret_key().clone())?; self.context.e = Some(e); let payload = String::from_utf8(payload_bytes).map_err(InternalRequestError::Utf8)?; Ok(self.unchecked_from_payload(payload)?) @@ -471,7 +466,7 @@ impl PayjoinProposal { let sender_subdir = subdir_path_from_pubkey(e); target_resource = self.context.directory.join(&sender_subdir).map_err(|e| Error::Server(e.into()))?; - body = crate::v2::encrypt_message_b(payjoin_bytes, &self.context.s, e)?; + body = encrypt_message_b(payjoin_bytes, &self.context.s, e)?; method = "POST"; } else { // Prepare v2 wrapped and backwards-compatible v1 payload @@ -485,7 +480,7 @@ impl PayjoinProposal { method = "PUT"; } log::debug!("Payjoin PSBT target: {}", target_resource.as_str()); - let (body, ctx) = crate::v2::ohttp_encapsulate( + let (body, ctx) = ohttp_encapsulate( &mut self.context.ohttp_keys, method, target_resource.as_str(), @@ -509,7 +504,7 @@ impl PayjoinProposal { res: Vec, ohttp_context: ohttp::ClientResponse, ) -> Result<(), Error> { - let res = crate::v2::ohttp_decapsulate(ohttp_context, &res)?; + let res = ohttp_decapsulate(ohttp_context, &res)?; if res.status().is_success() { Ok(()) } else { diff --git a/payjoin/src/send/error.rs b/payjoin/src/send/error.rs index f453d866..94c13737 100644 --- a/payjoin/src/send/error.rs +++ b/payjoin/src/send/error.rs @@ -58,9 +58,9 @@ pub(crate) enum InternalValidationError { FeeRateBelowMinimum, Psbt(bitcoin::psbt::Error), #[cfg(feature = "v2")] - Hpke(crate::v2::HpkeError), + Hpke(crate::hpke::HpkeError), #[cfg(feature = "v2")] - OhttpEncapsulation(crate::v2::OhttpEncapsulationError), + OhttpEncapsulation(crate::ohttp::OhttpEncapsulationError), #[cfg(feature = "v2")] UnexpectedStatusCode, } @@ -190,9 +190,9 @@ pub(crate) enum InternalCreateRequestError { AddressType(crate::psbt::AddressTypeError), InputWeight(crate::psbt::InputWeightError), #[cfg(feature = "v2")] - Hpke(crate::v2::HpkeError), + Hpke(crate::hpke::HpkeError), #[cfg(feature = "v2")] - OhttpEncapsulation(crate::v2::OhttpEncapsulationError), + OhttpEncapsulation(crate::ohttp::OhttpEncapsulationError), #[cfg(feature = "v2")] ParseSubdirectory(ParseSubdirectoryError), #[cfg(feature = "v2")] @@ -289,7 +289,7 @@ impl From for CreateRequestError { pub(crate) enum ParseSubdirectoryError { MissingSubdirectory, SubdirectoryNotBase64(bitcoin::base64::DecodeError), - SubdirectoryInvalidPubkey(crate::v2::HpkeError), + SubdirectoryInvalidPubkey(crate::hpke::HpkeError), } #[cfg(feature = "v2")] diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 8a859c19..53f41c9d 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -34,10 +34,12 @@ pub(crate) use error::{InternalCreateRequestError, InternalValidationError}; use serde::{Deserialize, Serialize}; use url::Url; +#[cfg(feature = "v2")] +use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeKeyPair, HpkePublicKey}; +#[cfg(feature = "v2")] +use crate::ohttp::{ohttp_decapsulate, ohttp_encapsulate}; use crate::psbt::PsbtExt; use crate::request::Request; -#[cfg(feature = "v2")] -use crate::v2::{HpkeKeyPair, HpkePublicKey}; use crate::PjUri; // See usize casts @@ -226,7 +228,7 @@ impl<'a> SenderBuilder<'a> { payee, min_fee_rate: self.min_fee_rate, #[cfg(feature = "v2")] - e: crate::v2::HpkeKeyPair::gen_keypair(), + e: HpkeKeyPair::gen_keypair(), }) } } @@ -241,7 +243,7 @@ pub struct Sender { min_fee_rate: FeeRate, payee: ScriptBuf, #[cfg(feature = "v2")] - e: crate::v2::HpkeKeyPair, + e: HpkeKeyPair, } impl Sender { @@ -318,13 +320,12 @@ impl Sender { self.fee_contribution, self.min_fee_rate, )?; - let body = crate::v2::encrypt_message_a(body, &self.e.secret_key().clone(), &rs) + let body = encrypt_message_a(body, &self.e.secret_key().clone(), &rs) .map_err(InternalCreateRequestError::Hpke)?; let mut ohttp = self.endpoint.ohttp().ok_or(InternalCreateRequestError::MissingOhttpConfig)?; - let (body, ohttp_ctx) = - crate::v2::ohttp_encapsulate(&mut ohttp, "POST", url.as_str(), Some(&body)) - .map_err(InternalCreateRequestError::OhttpEncapsulation)?; + let (body, ohttp_ctx) = ohttp_encapsulate(&mut ohttp, "POST", url.as_str(), Some(&body)) + .map_err(InternalCreateRequestError::OhttpEncapsulation)?; log::debug!("ohttp_relay_url: {:?}", ohttp_relay); Ok(( Request::new_v2(ohttp_relay, body), @@ -400,7 +401,7 @@ impl V2PostContext { ) -> Result { let mut res_buf = Vec::new(); response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; - let response = crate::v2::ohttp_decapsulate(self.ohttp_ctx, &res_buf) + let response = ohttp_decapsulate(self.ohttp_ctx, &res_buf) .map_err(InternalValidationError::OhttpEncapsulation)?; match response.status() { http::StatusCode::OK => { @@ -434,7 +435,7 @@ impl V2GetContext { let subdir = BASE64_URL_SAFE_NO_PAD.encode(self.hpke_ctx.e.public_key().to_compressed_bytes()); url.set_path(&subdir); - let body = crate::v2::encrypt_message_a( + let body = encrypt_message_a( Vec::new(), &self.hpke_ctx.e.secret_key().clone(), &self.hpke_ctx.rs.clone(), @@ -442,9 +443,8 @@ impl V2GetContext { .map_err(InternalCreateRequestError::Hpke)?; let mut ohttp = self.endpoint.ohttp().ok_or(InternalCreateRequestError::MissingOhttpConfig)?; - let (body, ohttp_ctx) = - crate::v2::ohttp_encapsulate(&mut ohttp, "GET", url.as_str(), Some(&body)) - .map_err(InternalCreateRequestError::OhttpEncapsulation)?; + let (body, ohttp_ctx) = ohttp_encapsulate(&mut ohttp, "GET", url.as_str(), Some(&body)) + .map_err(InternalCreateRequestError::OhttpEncapsulation)?; Ok((Request::new_v2(ohttp_relay, body), ohttp_ctx)) } @@ -456,14 +456,14 @@ impl V2GetContext { ) -> Result, ResponseError> { let mut res_buf = Vec::new(); response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; - let response = crate::v2::ohttp_decapsulate(ohttp_ctx, &res_buf) + let response = ohttp_decapsulate(ohttp_ctx, &res_buf) .map_err(InternalValidationError::OhttpEncapsulation)?; let body = match response.status() { http::StatusCode::OK => response.body().to_vec(), http::StatusCode::ACCEPTED => return Ok(None), _ => return Err(InternalValidationError::UnexpectedStatusCode)?, }; - let psbt = crate::v2::decrypt_message_b( + let psbt = decrypt_message_b( &body, self.hpke_ctx.rs.clone(), self.hpke_ctx.e.secret_key().clone(), From 84bec83aa0963ac043e2d2955e996ef53a9b615e Mon Sep 17 00:00:00 2001 From: DanGould Date: Tue, 15 Oct 2024 13:24:00 -0400 Subject: [PATCH 04/12] Encrypt Message A with ephemeral key like Noise IK Encrypting Message A with an ephemeral "encapsulation key" allows the sender "reply key" corresponding to its subdirectory to be hidden from the directory. --- payjoin/src/hpke.rs | 43 ++++++++++++++++++++++-------------- payjoin/src/send/mod.rs | 48 +++++++++++++++++++++++++++++------------ 2 files changed, 61 insertions(+), 30 deletions(-) diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index ee7de0bd..6a9c9883 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -20,8 +20,6 @@ pub type SecretKey = ::PrivateKey; pub type PublicKey = ::PublicKey; pub type EncappedKey = ::EncappedKey; -fn sk_to_pk(sk: &SecretKey) -> PublicKey { ::sk_to_pk(sk) } - #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct HpkeKeyPair(pub HpkeSecretKey, pub HpkePublicKey); @@ -130,19 +128,24 @@ impl<'de> serde::Deserialize<'de> for HpkePublicKey { /// Message A is sent from the sender to the receiver containing an Original PSBT payload #[cfg(feature = "send")] pub fn encrypt_message_a( - mut plaintext: Vec, - sender_sk: &HpkeSecretKey, + body: Vec, + encapsulation_pair: &HpkeKeyPair, + reply_pk: &HpkePublicKey, receiver_pk: &HpkePublicKey, ) -> Result, HpkeError> { - let pk = sk_to_pk(&sender_sk.0); let (encapsulated_key, mut encryption_context) = hpke::setup_sender::( - &OpModeS::Auth((sender_sk.0.clone(), pk.clone())), + &OpModeS::Auth(( + encapsulation_pair.secret_key().0.clone(), + encapsulation_pair.public_key().0.clone(), + )), &receiver_pk.0, INFO_A, &mut OsRng, )?; - let aad = pk.to_bytes().to_vec(); + let aad = encapsulation_pair.public_key().to_bytes().to_vec(); + let mut plaintext = reply_pk.to_bytes().to_vec(); + plaintext.extend(body); let plaintext = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_A_LENGTH)?; let ciphertext = encryption_context.seal(plaintext, &aad)?; let mut message_a = encapsulated_key.to_bytes().to_vec(); @@ -156,18 +159,26 @@ pub fn decrypt_message_a( message_a: &[u8], receiver_sk: HpkeSecretKey, ) -> Result<(Vec, HpkePublicKey), HpkeError> { - let enc = message_a.get(..65).ok_or(HpkeError::PayloadTooShort)?; + let enc = message_a.get(..UNCOMPRESSED_PUBLIC_KEY_SIZE).ok_or(HpkeError::PayloadTooShort)?; let enc = EncappedKey::from_bytes(enc)?; - let aad = message_a.get(65..130).ok_or(HpkeError::PayloadTooShort)?; - let pk_s = PublicKey::from_bytes(aad)?; - let mut decryption_ctx = hpke::setup_receiver::< - ChaCha20Poly1305, - HkdfSha256, - SecpK256HkdfSha256, - >(&OpModeR::Auth(pk_s.clone()), &receiver_sk.0, &enc, INFO_A)?; + let aad = message_a + .get(UNCOMPRESSED_PUBLIC_KEY_SIZE..(UNCOMPRESSED_PUBLIC_KEY_SIZE * 2)) + .ok_or(HpkeError::PayloadTooShort)?; + let encapsulation_pk = PublicKey::from_bytes(aad)?; + let mut decryption_ctx = + hpke::setup_receiver::( + &OpModeR::Auth(encapsulation_pk.clone()), + &receiver_sk.0, + &enc, + INFO_A, + )?; let ciphertext = message_a.get(130..).ok_or(HpkeError::PayloadTooShort)?; let plaintext = decryption_ctx.open(ciphertext, aad)?; - Ok((plaintext, HpkePublicKey(pk_s))) + let reply_pk = + plaintext.get(..UNCOMPRESSED_PUBLIC_KEY_SIZE).ok_or(HpkeError::PayloadTooShort)?; + let reply_pk = HpkePublicKey(PublicKey::from_bytes(reply_pk)?); + let body = plaintext.get(UNCOMPRESSED_PUBLIC_KEY_SIZE..).ok_or(HpkeError::PayloadTooShort)?; + Ok((body.to_vec(), reply_pk)) } /// Message B is sent from the receiver to the sender containing a Payjoin PSBT payload or an error diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 53f41c9d..5a02abc6 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -228,7 +228,7 @@ impl<'a> SenderBuilder<'a> { payee, min_fee_rate: self.min_fee_rate, #[cfg(feature = "v2")] - e: HpkeKeyPair::gen_keypair(), + reply_pair: HpkeKeyPair::gen_keypair(), }) } } @@ -243,7 +243,7 @@ pub struct Sender { min_fee_rate: FeeRate, payee: ScriptBuf, #[cfg(feature = "v2")] - e: HpkeKeyPair, + reply_pair: HpkeKeyPair, } impl Sender { @@ -320,8 +320,14 @@ impl Sender { self.fee_contribution, self.min_fee_rate, )?; - let body = encrypt_message_a(body, &self.e.secret_key().clone(), &rs) - .map_err(InternalCreateRequestError::Hpke)?; + let hpke_ctx = HpkeContext::new(rs); + let body = encrypt_message_a( + body, + &hpke_ctx.encapsulation_pair.clone(), + &hpke_ctx.reply_pair.public_key().clone(), + &hpke_ctx.receiver.clone(), + ) + .map_err(InternalCreateRequestError::Hpke)?; let mut ohttp = self.endpoint.ohttp().ok_or(InternalCreateRequestError::MissingOhttpConfig)?; let (body, ohttp_ctx) = ohttp_encapsulate(&mut ohttp, "POST", url.as_str(), Some(&body)) @@ -339,7 +345,7 @@ impl Sender { min_fee_rate: self.min_fee_rate, allow_mixed_input_scripts: true, }, - hpke_ctx: HpkeContext { rs, e: self.e.clone() }, + hpke_ctx, ohttp_ctx, }), )) @@ -432,13 +438,14 @@ impl V2GetContext { ) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> { use crate::uri::UrlExt; let mut url = self.endpoint.clone(); - let subdir = - BASE64_URL_SAFE_NO_PAD.encode(self.hpke_ctx.e.public_key().to_compressed_bytes()); + let subdir = BASE64_URL_SAFE_NO_PAD + .encode(self.hpke_ctx.reply_pair.public_key().to_compressed_bytes()); url.set_path(&subdir); let body = encrypt_message_a( Vec::new(), - &self.hpke_ctx.e.secret_key().clone(), - &self.hpke_ctx.rs.clone(), + &self.hpke_ctx.encapsulation_pair.clone(), + &self.hpke_ctx.reply_pair.public_key().clone(), + &self.hpke_ctx.receiver.clone(), ) .map_err(InternalCreateRequestError::Hpke)?; let mut ohttp = @@ -465,8 +472,8 @@ impl V2GetContext { }; let psbt = decrypt_message_b( &body, - self.hpke_ctx.rs.clone(), - self.hpke_ctx.e.secret_key().clone(), + self.hpke_ctx.receiver.clone(), + self.hpke_ctx.reply_pair.secret_key().clone(), ) .map_err(InternalValidationError::Hpke)?; @@ -489,10 +496,23 @@ pub struct PsbtContext { payee: ScriptBuf, allow_mixed_input_scripts: bool, } + #[cfg(feature = "v2")] struct HpkeContext { - rs: HpkePublicKey, - e: HpkeKeyPair, + receiver: HpkePublicKey, + encapsulation_pair: HpkeKeyPair, + reply_pair: HpkeKeyPair, +} + +#[cfg(feature = "v2")] +impl HpkeContext { + pub fn new(receiver: HpkePublicKey) -> Self { + Self { + receiver, + encapsulation_pair: HpkeKeyPair::gen_keypair(), + reply_pair: HpkeKeyPair::gen_keypair(), + } + } } macro_rules! check_eq { @@ -985,7 +1005,7 @@ mod test { fee_contribution: None, min_fee_rate: FeeRate::ZERO, payee: ScriptBuf::from(vec![0x00]), - e: HpkeKeyPair::gen_keypair(), + reply_pair: HpkeKeyPair::gen_keypair(), }; let serialized = serde_json::to_string(&req_ctx).unwrap(); let deserialized = serde_json::from_str(&serialized).unwrap(); From 9b2635a3211e6a44ba2f2f31dd09f93e12247fd0 Mon Sep 17 00:00:00 2001 From: DanGould Date: Tue, 15 Oct 2024 14:06:27 -0400 Subject: [PATCH 05/12] Remove obsolete `Sender` keypair --- payjoin/src/send/mod.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 5a02abc6..66afa3ab 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -227,8 +227,6 @@ impl<'a> SenderBuilder<'a> { fee_contribution, payee, min_fee_rate: self.min_fee_rate, - #[cfg(feature = "v2")] - reply_pair: HpkeKeyPair::gen_keypair(), }) } } @@ -242,8 +240,6 @@ pub struct Sender { fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, payee: ScriptBuf, - #[cfg(feature = "v2")] - reply_pair: HpkeKeyPair, } impl Sender { @@ -1005,7 +1001,6 @@ mod test { fee_contribution: None, min_fee_rate: FeeRate::ZERO, payee: ScriptBuf::from(vec![0x00]), - reply_pair: HpkeKeyPair::gen_keypair(), }; let serialized = serde_json::to_string(&req_ctx).unwrap(); let deserialized = serde_json::from_str(&serialized).unwrap(); From affb88b9a62ab4cf468556c8453fe5f63d0eaafb Mon Sep 17 00:00:00 2001 From: DanGould Date: Tue, 15 Oct 2024 15:59:47 -0400 Subject: [PATCH 06/12] Seek in hpke en/decryption instead of magic numbers --- payjoin/src/hpke.rs | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index 6a9c9883..fe916e37 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -159,12 +159,18 @@ pub fn decrypt_message_a( message_a: &[u8], receiver_sk: HpkeSecretKey, ) -> Result<(Vec, HpkePublicKey), HpkeError> { - let enc = message_a.get(..UNCOMPRESSED_PUBLIC_KEY_SIZE).ok_or(HpkeError::PayloadTooShort)?; - let enc = EncappedKey::from_bytes(enc)?; - let aad = message_a - .get(UNCOMPRESSED_PUBLIC_KEY_SIZE..(UNCOMPRESSED_PUBLIC_KEY_SIZE * 2)) - .ok_or(HpkeError::PayloadTooShort)?; - let encapsulation_pk = PublicKey::from_bytes(aad)?; + use std::io::{Cursor, Read}; + + let mut cursor = Cursor::new(message_a); + + let mut enc = [0u8; UNCOMPRESSED_PUBLIC_KEY_SIZE]; + cursor.read_exact(&mut enc).map_err(|_| HpkeError::PayloadTooShort)?; + let enc = EncappedKey::from_bytes(&enc)?; + + let mut aad = [0u8; UNCOMPRESSED_PUBLIC_KEY_SIZE]; + cursor.read_exact(&mut aad).map_err(|_| HpkeError::PayloadTooShort)?; + let encapsulation_pk = PublicKey::from_bytes(&aad)?; + let mut decryption_ctx = hpke::setup_receiver::( &OpModeR::Auth(encapsulation_pk.clone()), @@ -172,12 +178,16 @@ pub fn decrypt_message_a( &enc, INFO_A, )?; - let ciphertext = message_a.get(130..).ok_or(HpkeError::PayloadTooShort)?; - let plaintext = decryption_ctx.open(ciphertext, aad)?; - let reply_pk = - plaintext.get(..UNCOMPRESSED_PUBLIC_KEY_SIZE).ok_or(HpkeError::PayloadTooShort)?; - let reply_pk = HpkePublicKey(PublicKey::from_bytes(reply_pk)?); - let body = plaintext.get(UNCOMPRESSED_PUBLIC_KEY_SIZE..).ok_or(HpkeError::PayloadTooShort)?; + + let mut ciphertext = Vec::new(); + cursor.read_to_end(&mut ciphertext).map_err(|_| HpkeError::PayloadTooShort)?; + let plaintext = decryption_ctx.open(&ciphertext, &aad)?; + + let reply_pk_bytes = &plaintext[..UNCOMPRESSED_PUBLIC_KEY_SIZE]; + let reply_pk = HpkePublicKey(PublicKey::from_bytes(reply_pk_bytes)?); + + let body = &plaintext[UNCOMPRESSED_PUBLIC_KEY_SIZE..]; + Ok((body.to_vec(), reply_pk)) } @@ -218,8 +228,10 @@ pub fn decrypt_message_b( HkdfSha256, SecpK256HkdfSha256, >(&OpModeR::Auth(receiver_pk.0), &sender_sk.0, &enc, INFO_B)?; - let plaintext = - decryption_ctx.open(message_b.get(65..).ok_or(HpkeError::PayloadTooShort)?, &[])?; + let plaintext = decryption_ctx.open( + message_b.get(UNCOMPRESSED_PUBLIC_KEY_SIZE..).ok_or(HpkeError::PayloadTooShort)?, + &[], + )?; Ok(plaintext) } From 38d8777a95c9205ad5124670c72a737dad41cbde Mon Sep 17 00:00:00 2001 From: DanGould Date: Tue, 15 Oct 2024 16:04:44 -0400 Subject: [PATCH 07/12] Redact hpke secret key from Debug --- payjoin/src/hpke.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index fe916e37..d002125d 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -47,7 +47,7 @@ impl Deref for HpkeSecretKey { impl core::fmt::Debug for HpkeSecretKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "SecpHpkeSecretKey({:?})", self.0.to_bytes()) + write!(f, "SecpHpkeSecretKey([REDACTED])") } } From 3135d57da770ed2b2fc6cccf836ec3bc0bb82498 Mon Sep 17 00:00:00 2001 From: Yuval Kogman Date: Fri, 18 Oct 2024 14:44:42 +0200 Subject: [PATCH 08/12] make payloads uniform by removing sender auth key Since the encapsulation keypair was ephemeral and not known to the receiver, but used in the Auth pattern it was included as authenticated associated data in the payload. This means that encrypt_message_a and encrypt_message_b had distinguishable bit patterns, the former starting with two uncompressed curve points (one for the DHKEM and one for this auth key), whereas the latter only had one (the DHKEM point). Since the sender's first message establishes a reply key, that key could be used in a second Auth HPKE setup after the Base setup, in order to prove that the sender can decrypt the receiver's reply. However, incentives are for the sender to provide a valid point, and the reply key is included in AEAD ciphertext, so this would add complexity without meaningful improving security or incentive compatibility. --- payjoin/src/hpke.rs | 28 ++++++++-------------------- payjoin/src/send/mod.rs | 9 +-------- 2 files changed, 9 insertions(+), 28 deletions(-) diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index d002125d..cdbbf097 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -129,27 +129,21 @@ impl<'de> serde::Deserialize<'de> for HpkePublicKey { #[cfg(feature = "send")] pub fn encrypt_message_a( body: Vec, - encapsulation_pair: &HpkeKeyPair, reply_pk: &HpkePublicKey, receiver_pk: &HpkePublicKey, ) -> Result, HpkeError> { let (encapsulated_key, mut encryption_context) = hpke::setup_sender::( - &OpModeS::Auth(( - encapsulation_pair.secret_key().0.clone(), - encapsulation_pair.public_key().0.clone(), - )), + &OpModeS::Base, &receiver_pk.0, INFO_A, &mut OsRng, )?; - let aad = encapsulation_pair.public_key().to_bytes().to_vec(); let mut plaintext = reply_pk.to_bytes().to_vec(); plaintext.extend(body); let plaintext = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_A_LENGTH)?; - let ciphertext = encryption_context.seal(plaintext, &aad)?; + let ciphertext = encryption_context.seal(plaintext, &[])?; let mut message_a = encapsulated_key.to_bytes().to_vec(); - message_a.extend(&aad); message_a.extend(&ciphertext); Ok(message_a.to_vec()) } @@ -167,21 +161,15 @@ pub fn decrypt_message_a( cursor.read_exact(&mut enc).map_err(|_| HpkeError::PayloadTooShort)?; let enc = EncappedKey::from_bytes(&enc)?; - let mut aad = [0u8; UNCOMPRESSED_PUBLIC_KEY_SIZE]; - cursor.read_exact(&mut aad).map_err(|_| HpkeError::PayloadTooShort)?; - let encapsulation_pk = PublicKey::from_bytes(&aad)?; - - let mut decryption_ctx = - hpke::setup_receiver::( - &OpModeR::Auth(encapsulation_pk.clone()), - &receiver_sk.0, - &enc, - INFO_A, - )?; + let mut decryption_ctx = hpke::setup_receiver::< + ChaCha20Poly1305, + HkdfSha256, + SecpK256HkdfSha256, + >(&OpModeR::Base, &receiver_sk.0, &enc, INFO_A)?; let mut ciphertext = Vec::new(); cursor.read_to_end(&mut ciphertext).map_err(|_| HpkeError::PayloadTooShort)?; - let plaintext = decryption_ctx.open(&ciphertext, &aad)?; + let plaintext = decryption_ctx.open(&ciphertext, &[])?; let reply_pk_bytes = &plaintext[..UNCOMPRESSED_PUBLIC_KEY_SIZE]; let reply_pk = HpkePublicKey(PublicKey::from_bytes(reply_pk_bytes)?); diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 66afa3ab..d52d0a04 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -319,7 +319,6 @@ impl Sender { let hpke_ctx = HpkeContext::new(rs); let body = encrypt_message_a( body, - &hpke_ctx.encapsulation_pair.clone(), &hpke_ctx.reply_pair.public_key().clone(), &hpke_ctx.receiver.clone(), ) @@ -439,7 +438,6 @@ impl V2GetContext { url.set_path(&subdir); let body = encrypt_message_a( Vec::new(), - &self.hpke_ctx.encapsulation_pair.clone(), &self.hpke_ctx.reply_pair.public_key().clone(), &self.hpke_ctx.receiver.clone(), ) @@ -496,18 +494,13 @@ pub struct PsbtContext { #[cfg(feature = "v2")] struct HpkeContext { receiver: HpkePublicKey, - encapsulation_pair: HpkeKeyPair, reply_pair: HpkeKeyPair, } #[cfg(feature = "v2")] impl HpkeContext { pub fn new(receiver: HpkePublicKey) -> Self { - Self { - receiver, - encapsulation_pair: HpkeKeyPair::gen_keypair(), - reply_pair: HpkeKeyPair::gen_keypair(), - } + Self { receiver, reply_pair: HpkeKeyPair::gen_keypair() } } } From 9c4880c932fb8f084e041d29df84df10e6ebb18d Mon Sep 17 00:00:00 2001 From: DanGould Date: Sat, 19 Oct 2024 14:06:28 -0400 Subject: [PATCH 09/12] Define INFO_x msg as constant size array ref --- payjoin/src/hpke.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index cdbbf097..afa7068a 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -13,8 +13,8 @@ pub const PADDED_MESSAGE_BYTES: usize = 7168; pub const PADDED_PLAINTEXT_A_LENGTH: usize = PADDED_MESSAGE_BYTES - UNCOMPRESSED_PUBLIC_KEY_SIZE * 2; pub const PADDED_PLAINTEXT_B_LENGTH: usize = PADDED_MESSAGE_BYTES - UNCOMPRESSED_PUBLIC_KEY_SIZE; -pub const INFO_A: &[u8] = b"PjV2MsgA"; -pub const INFO_B: &[u8] = b"PjV2MsgB"; +pub const INFO_A: &[u8; 8] = b"PjV2MsgA"; +pub const INFO_B: &[u8; 8] = b"PjV2MsgB"; pub type SecretKey = ::PrivateKey; pub type PublicKey = ::PublicKey; From 69c780051ace3ca7a16c568f5311233a19867ef1 Mon Sep 17 00:00:00 2001 From: DanGould Date: Sun, 20 Oct 2024 12:15:01 -0400 Subject: [PATCH 10/12] Ellswift Encode public keys in hpke wire protocol --- payjoin/src/hpke.rs | 45 ++++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index afa7068a..618824a3 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -1,7 +1,8 @@ use std::ops::Deref; use std::{error, fmt}; -use bitcoin::key::constants::UNCOMPRESSED_PUBLIC_KEY_SIZE; +use bitcoin::key::constants::{ELLSWIFT_ENCODING_SIZE, UNCOMPRESSED_PUBLIC_KEY_SIZE}; +use bitcoin::secp256k1::ellswift::ElligatorSwift; use hpke::aead::ChaCha20Poly1305; use hpke::kdf::HkdfSha256; use hpke::kem::SecpK256HkdfSha256; @@ -10,8 +11,7 @@ use hpke::{Deserializable, OpModeR, OpModeS, Serializable}; use serde::{Deserialize, Serialize}; pub const PADDED_MESSAGE_BYTES: usize = 7168; -pub const PADDED_PLAINTEXT_A_LENGTH: usize = - PADDED_MESSAGE_BYTES - UNCOMPRESSED_PUBLIC_KEY_SIZE * 2; +pub const PADDED_PLAINTEXT_A_LENGTH: usize = PADDED_MESSAGE_BYTES - ELLSWIFT_ENCODING_SIZE; pub const PADDED_PLAINTEXT_B_LENGTH: usize = PADDED_MESSAGE_BYTES - UNCOMPRESSED_PUBLIC_KEY_SIZE; pub const INFO_A: &[u8; 8] = b"PjV2MsgA"; pub const INFO_B: &[u8; 8] = b"PjV2MsgB"; @@ -36,6 +36,23 @@ impl HpkeKeyPair { pub fn public_key(&self) -> &HpkePublicKey { &self.1 } } +fn encapped_key_from_ellswift_bytes(encoded: &[u8]) -> Result { + let mut buf = [0u8; ELLSWIFT_ENCODING_SIZE]; + buf.copy_from_slice(encoded); + let ellswift = ElligatorSwift::from_array(buf); + let pk = bitcoin::secp256k1::PublicKey::from_ellswift(ellswift); + Ok(EncappedKey::from_bytes(pk.serialize_uncompressed().as_slice())?) +} + +fn ellswift_bytes_from_encapped_key( + enc: &EncappedKey, +) -> Result<[u8; ELLSWIFT_ENCODING_SIZE], HpkeError> { + let uncompressed = enc.to_bytes(); + let pk = bitcoin::secp256k1::PublicKey::from_slice(&uncompressed)?; + let ellswift = ElligatorSwift::from_pubkey(pk); + Ok(ellswift.to_array()) +} + #[derive(Clone, PartialEq, Eq)] pub struct HpkeSecretKey(pub SecretKey); @@ -143,7 +160,7 @@ pub fn encrypt_message_a( plaintext.extend(body); let plaintext = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_A_LENGTH)?; let ciphertext = encryption_context.seal(plaintext, &[])?; - let mut message_a = encapsulated_key.to_bytes().to_vec(); + let mut message_a = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec(); message_a.extend(&ciphertext); Ok(message_a.to_vec()) } @@ -157,9 +174,9 @@ pub fn decrypt_message_a( let mut cursor = Cursor::new(message_a); - let mut enc = [0u8; UNCOMPRESSED_PUBLIC_KEY_SIZE]; - cursor.read_exact(&mut enc).map_err(|_| HpkeError::PayloadTooShort)?; - let enc = EncappedKey::from_bytes(&enc)?; + let mut enc_bytes = [0u8; ELLSWIFT_ENCODING_SIZE]; + cursor.read_exact(&mut enc_bytes).map_err(|_| HpkeError::PayloadTooShort)?; + let enc = encapped_key_from_ellswift_bytes(&enc_bytes)?; let mut decryption_ctx = hpke::setup_receiver::< ChaCha20Poly1305, @@ -196,9 +213,9 @@ pub fn encrypt_message_b( INFO_B, &mut OsRng, )?; - let plaintext = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_B_LENGTH)?; + let plaintext: &[u8] = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_B_LENGTH)?; let ciphertext = encryption_context.seal(plaintext, &[])?; - let mut message_b = encapsulated_key.to_bytes().to_vec(); + let mut message_b = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec(); message_b.extend(&ciphertext); Ok(message_b.to_vec()) } @@ -209,17 +226,15 @@ pub fn decrypt_message_b( receiver_pk: HpkePublicKey, sender_sk: HpkeSecretKey, ) -> Result, HpkeError> { - let enc = message_b.get(..65).ok_or(HpkeError::PayloadTooShort)?; - let enc = EncappedKey::from_bytes(enc)?; + let enc = message_b.get(..ELLSWIFT_ENCODING_SIZE).ok_or(HpkeError::PayloadTooShort)?; + let enc = encapped_key_from_ellswift_bytes(enc)?; let mut decryption_ctx = hpke::setup_receiver::< ChaCha20Poly1305, HkdfSha256, SecpK256HkdfSha256, >(&OpModeR::Auth(receiver_pk.0), &sender_sk.0, &enc, INFO_B)?; - let plaintext = decryption_ctx.open( - message_b.get(UNCOMPRESSED_PUBLIC_KEY_SIZE..).ok_or(HpkeError::PayloadTooShort)?, - &[], - )?; + let plaintext = decryption_ctx + .open(message_b.get(ELLSWIFT_ENCODING_SIZE..).ok_or(HpkeError::PayloadTooShort)?, &[])?; Ok(plaintext) } From 5d8373757b23380dc3c8bf0b8ff28622329192cf Mon Sep 17 00:00:00 2001 From: Yuval Kogman Date: Sun, 20 Oct 2024 18:44:37 +0200 Subject: [PATCH 11/12] test HPKE encryption/decryption The constants PADDED_PLAINTEXT_{A,B}_LENGTH now represent the maximum payload size for the input, whereas before the message A constant included the reply key size. This makes the PayloadTooLarge error represent a maximum and actual size that correspond to the inputs to encrypt_message_a and encrypt_message_b. --- payjoin/src/hpke.rs | 181 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 176 insertions(+), 5 deletions(-) diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index 618824a3..c5cfb429 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -11,8 +11,11 @@ use hpke::{Deserializable, OpModeR, OpModeS, Serializable}; use serde::{Deserialize, Serialize}; pub const PADDED_MESSAGE_BYTES: usize = 7168; -pub const PADDED_PLAINTEXT_A_LENGTH: usize = PADDED_MESSAGE_BYTES - ELLSWIFT_ENCODING_SIZE; -pub const PADDED_PLAINTEXT_B_LENGTH: usize = PADDED_MESSAGE_BYTES - UNCOMPRESSED_PUBLIC_KEY_SIZE; +pub const PADDED_PLAINTEXT_A_LENGTH: usize = PADDED_MESSAGE_BYTES + - (ELLSWIFT_ENCODING_SIZE + UNCOMPRESSED_PUBLIC_KEY_SIZE + POLY1305_TAG_SIZE); +pub const PADDED_PLAINTEXT_B_LENGTH: usize = + PADDED_MESSAGE_BYTES - (ELLSWIFT_ENCODING_SIZE + POLY1305_TAG_SIZE); +pub const POLY1305_TAG_SIZE: usize = 16; // FIXME there is a U16 defined for poly1305, should bitcoin hpke re-export it? pub const INFO_A: &[u8; 8] = b"PjV2MsgA"; pub const INFO_B: &[u8; 8] = b"PjV2MsgB"; @@ -156,10 +159,11 @@ pub fn encrypt_message_a( INFO_A, &mut OsRng, )?; + let mut body = body; + pad_plaintext(&mut body, PADDED_PLAINTEXT_A_LENGTH)?; let mut plaintext = reply_pk.to_bytes().to_vec(); plaintext.extend(body); - let plaintext = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_A_LENGTH)?; - let ciphertext = encryption_context.seal(plaintext, &[])?; + let ciphertext = encryption_context.seal(&plaintext, &[])?; let mut message_a = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec(); message_a.extend(&ciphertext); Ok(message_a.to_vec()) @@ -247,7 +251,7 @@ fn pad_plaintext(msg: &mut Vec, padded_length: usize) -> Result<&[u8], HpkeE } /// Error from de/encrypting a v2 Hybrid Public Key Encryption payload. -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum HpkeError { Secp256k1(bitcoin::secp256k1::Error), Hpke(hpke::HpkeError), @@ -296,3 +300,170 @@ impl error::Error for HpkeError { } } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn message_a_round_trip() { + let mut plaintext = "foo".as_bytes().to_vec(); + + let reply_keypair = HpkeKeyPair::gen_keypair(); + let receiver_keypair = HpkeKeyPair::gen_keypair(); + + let message_a = encrypt_message_a( + plaintext.clone(), + reply_keypair.public_key(), + receiver_keypair.public_key(), + ) + .expect("encryption should work"); + assert_eq!(message_a.len(), PADDED_MESSAGE_BYTES); + + let decrypted = decrypt_message_a(&message_a, receiver_keypair.secret_key().clone()) + .expect("decryption should work"); + + assert_eq!(decrypted.0.len(), PADDED_PLAINTEXT_A_LENGTH); + + // decrypted plaintext is padded, so pad the expected plaintext + plaintext.resize(PADDED_PLAINTEXT_A_LENGTH, 0); + assert_eq!(decrypted, (plaintext.to_vec(), reply_keypair.public_key().clone())); + + // ensure full plaintext round trips + plaintext[PADDED_PLAINTEXT_A_LENGTH - 1] = 42; + let message_a = encrypt_message_a( + plaintext.clone(), + reply_keypair.public_key(), + receiver_keypair.public_key(), + ) + .expect("encryption should work"); + + let decrypted = decrypt_message_a(&message_a, receiver_keypair.secret_key().clone()) + .expect("decryption should work"); + + assert_eq!(decrypted.0.len(), plaintext.len()); + assert_eq!(decrypted, (plaintext.to_vec(), reply_keypair.public_key().clone())); + + let unrelated_keypair = HpkeKeyPair::gen_keypair(); + assert_eq!( + decrypt_message_a(&message_a, unrelated_keypair.secret_key().clone()), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + + let mut corrupted_message_a = message_a.clone(); + corrupted_message_a[3] ^= 1; // corrupt dhkem + assert_eq!( + decrypt_message_a(&corrupted_message_a, receiver_keypair.secret_key().clone()), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + let mut corrupted_message_a = message_a.clone(); + corrupted_message_a[PADDED_MESSAGE_BYTES - 3] ^= 1; // corrupt aead ciphertext + assert_eq!( + decrypt_message_a(&corrupted_message_a, receiver_keypair.secret_key().clone()), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + + plaintext.resize(PADDED_PLAINTEXT_A_LENGTH + 1, 0); + assert_eq!( + encrypt_message_a( + plaintext.clone(), + reply_keypair.public_key(), + receiver_keypair.public_key(), + ), + Err(HpkeError::PayloadTooLarge { + actual: PADDED_PLAINTEXT_A_LENGTH + 1, + max: PADDED_PLAINTEXT_A_LENGTH, + }) + ); + } + + #[test] + fn message_b_round_trip() { + let mut plaintext = "foo".as_bytes().to_vec(); + + let reply_keypair = HpkeKeyPair::gen_keypair(); + let receiver_keypair = HpkeKeyPair::gen_keypair(); + + let message_b = + encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()) + .expect("encryption should work"); + + assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES); + + let decrypted = decrypt_message_b( + &message_b, + receiver_keypair.public_key().clone(), + reply_keypair.secret_key().clone(), + ) + .expect("decryption should work"); + + assert_eq!(decrypted.len(), PADDED_PLAINTEXT_B_LENGTH); + // decrypted plaintext is padded, so pad the expected plaintext + plaintext.resize(PADDED_PLAINTEXT_B_LENGTH, 0); + assert_eq!(decrypted, plaintext.to_vec()); + + plaintext[PADDED_PLAINTEXT_B_LENGTH - 1] = 42; + let message_b = + encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()) + .expect("encryption should work"); + + assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES); + + let decrypted = decrypt_message_b( + &message_b, + receiver_keypair.public_key().clone(), + reply_keypair.secret_key().clone(), + ) + .expect("decryption should work"); + assert_eq!(decrypted.len(), plaintext.len()); + assert_eq!(decrypted, plaintext.to_vec()); + + let unrelated_keypair = HpkeKeyPair::gen_keypair(); + assert_eq!( + decrypt_message_b( + &message_b, + receiver_keypair.public_key().clone(), + unrelated_keypair.secret_key().clone() // wrong decryption key + ), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + assert_eq!( + decrypt_message_b( + &message_b, + unrelated_keypair.public_key().clone(), // wrong auth key + reply_keypair.secret_key().clone() + ), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + + let mut corrupted_message_b = message_b.clone(); + corrupted_message_b[3] ^= 1; // corrupt dhkem + assert_eq!( + decrypt_message_b( + &corrupted_message_b, + receiver_keypair.public_key().clone(), + reply_keypair.secret_key().clone() + ), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + let mut corrupted_message_b = message_b.clone(); + corrupted_message_b[PADDED_MESSAGE_BYTES - 3] ^= 1; // corrupt aead ciphertext + assert_eq!( + decrypt_message_b( + &corrupted_message_b, + receiver_keypair.public_key().clone(), + reply_keypair.secret_key().clone() + ), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + + plaintext.resize(PADDED_PLAINTEXT_B_LENGTH + 1, 0); + assert_eq!( + encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()), + Err(HpkeError::PayloadTooLarge { + actual: PADDED_PLAINTEXT_B_LENGTH + 1, + max: PADDED_PLAINTEXT_B_LENGTH + }) + ); + } +} From 96c0eb609dfd9db59f4e1671bf01be0280f06089 Mon Sep 17 00:00:00 2001 From: DanGould Date: Mon, 21 Oct 2024 13:45:03 -0400 Subject: [PATCH 12/12] Adjust send documentation to use Sender --- payjoin/src/send/mod.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index d52d0a04..f9db3547 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -9,11 +9,10 @@ //! 2. Construct URI request parameters, a finalized “Original PSBT” paying .amount to .address //! 3. (optional) Spawn a thread or async task that will broadcast the original PSBT fallback after //! delay (e.g. 1 minute) unless canceled -//! 4. Construct the request using [`RequestBuilder`] with the PSBT and payjoin uri -//! 5. Send the request and receive response -//! 6. Process the response with [`ContextV1::process_response`] -//! 7. Sign and finalize the Payjoin Proposal PSBT -//! 8. Broadcast the Payjoin Transaction (and cancel the optional fallback broadcast) +//! 4. Construct the [`Sender`] using [`SenderBuilder`] with the PSBT and payjoin uri +//! 5. Send the request(s) and receive response(s) by following on the extracted [`Context`] +//! 6. Sign and finalize the Payjoin Proposal PSBT +//! 7. Broadcast the Payjoin Transaction (and cancel the optional fallback broadcast) //! //! This crate is runtime-agnostic. Data persistence, chain interactions, and networking may be //! provided by custom implementations or copy the reference