Skip to content

Commit

Permalink
Make session initialization implicit
Browse files Browse the repository at this point in the history
A session is now initialized by generating keys and sharing them out
of band. The semantics of the protocol are otherwise unchanged.
  • Loading branch information
DanGould committed Oct 10, 2024
1 parent a9b9a34 commit 8b40f3c
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 179 deletions.
18 changes: 1 addition & 17 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,32 +75,16 @@ impl AppTrait for App {
}

async fn receive_payjoin(self, amount_arg: &str) -> Result<()> {
use payjoin::receive::v2::SessionInitializer;

let address = self.bitcoind()?.get_new_address(None, None)?.assume_checked();
let amount = Amount::from_sat(amount_arg.parse()?);
let ohttp_keys = unwrap_ohttp_keys_or_else_fetch(&self.config).await?;
let mut initializer = SessionInitializer::new(
let session = ActiveSession::new(
address,
self.config.pj_directory.clone(),
ohttp_keys.clone(),
self.config.ohttp_relay.clone(),
None,
);
let (req, ctx) =
initializer.extract_req().map_err(|e| anyhow!("Failed to extract request {}", e))?;
println!("Starting new Payjoin session with {}", self.config.pj_directory);
let http = http_agent()?;
let ohttp_response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
let session = initializer
.process_res(ohttp_response.bytes().await?.to_vec().as_slice(), ctx)
.map_err(|e| anyhow!("Enrollment failed {}", e))?;
self.db.insert_recv_session(session.clone())?;
self.spawn_payjoin_receiver(session, Some(amount)).await
}
Expand Down
9 changes: 1 addition & 8 deletions payjoin-cli/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,14 +482,7 @@ mod e2e {
let db = docker.run(Redis::default());
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
println!("Database running on {}", db.get_host_port_ipv4(6379));
payjoin_directory::listen_tcp_with_tls(
format!("http://localhost:{}", port),
port,
db_host,
timeout,
local_cert_key,
)
.await
payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await
}

// generates or gets a DER encoded localhost cert and key.
Expand Down
68 changes: 5 additions & 63 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD;
use bitcoin::base64::Engine;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty, Full};
use hyper::body::{Body, Bytes, Incoming};
use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE, LOCATION};
use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode, Uri};
Expand All @@ -20,7 +18,6 @@ use tracing::{debug, error, info, trace};
pub const DEFAULT_DIR_PORT: u16 = 8080;
pub const DEFAULT_DB_HOST: &str = "localhost:6379";
pub const DEFAULT_TIMEOUT_SECS: u64 = 30;
pub const DEFAULT_BASE_URL: &str = "https://localhost";

const MAX_BUFFER_SIZE: usize = 65536;

Expand All @@ -32,7 +29,6 @@ mod db;
use crate::db::DbPool;

pub async fn listen_tcp(
base_url: String,
port: u16,
db_host: String,
timeout: Duration,
Expand All @@ -44,14 +40,13 @@ pub async fn listen_tcp(
while let Ok((stream, _)) = listener.accept().await {
let pool = pool.clone();
let ohttp = ohttp.clone();
let base_url = base_url.clone();
let io = TokioIo::new(stream);
tokio::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(
io,
service_fn(move |req| {
serve_payjoin_directory(req, pool.clone(), ohttp.clone(), base_url.clone())
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
}),
)
.with_upgrades()
Expand All @@ -67,7 +62,6 @@ pub async fn listen_tcp(

#[cfg(feature = "danger-local-https")]
pub async fn listen_tcp_with_tls(
base_url: String,
port: u16,
db_host: String,
timeout: Duration,
Expand All @@ -81,7 +75,6 @@ pub async fn listen_tcp_with_tls(
while let Ok((stream, _)) = listener.accept().await {
let pool = pool.clone();
let ohttp = ohttp.clone();
let base_url = base_url.clone();
let tls_acceptor = tls_acceptor.clone();
tokio::spawn(async move {
let tls_stream = match tls_acceptor.accept(stream).await {
Expand All @@ -95,7 +88,7 @@ pub async fn listen_tcp_with_tls(
.serve_connection(
TokioIo::new(tls_stream),
service_fn(move |req| {
serve_payjoin_directory(req, pool.clone(), ohttp.clone(), base_url.clone())
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
}),
)
.with_upgrades()
Expand Down Expand Up @@ -146,7 +139,6 @@ async fn serve_payjoin_directory(
req: Request<Incoming>,
pool: DbPool,
ohttp: Arc<Mutex<ohttp::Server>>,
base_url: String,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>> {
let path = req.uri().path().to_string();
let query = req.uri().query().unwrap_or_default().to_string();
Expand All @@ -155,7 +147,7 @@ async fn serve_payjoin_directory(
let path_segments: Vec<&str> = path.split('/').collect();
debug!("serve_payjoin_directory: {:?}", &path_segments);
let mut response = match (parts.method, path_segments.as_slice()) {
(Method::POST, ["", ""]) => handle_ohttp_gateway(body, pool, ohttp, base_url).await,
(Method::POST, ["", ""]) => handle_ohttp_gateway(body, pool, ohttp).await,
(Method::GET, ["", "ohttp-keys"]) => get_ohttp_keys(&ohttp).await,
(Method::POST, ["", id]) => post_fallback_v1(id, query, body, pool).await,
(Method::GET, ["", "health"]) => health_check().await,
Expand All @@ -173,7 +165,6 @@ async fn handle_ohttp_gateway(
body: Incoming,
pool: DbPool,
ohttp: Arc<Mutex<ohttp::Server>>,
base_url: String,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
// decapsulate
let ohttp_body =
Expand All @@ -199,7 +190,7 @@ async fn handle_ohttp_gateway(
}
let request = http_req.body(full(body))?;

let response = handle_v2(pool, base_url, request).await?;
let response = handle_v2(pool, request).await?;

let (parts, body) = response.into_parts();
let mut bhttp_res = bhttp::Message::response(parts.status.as_u16());
Expand All @@ -221,7 +212,6 @@ async fn handle_ohttp_gateway(

async fn handle_v2(
pool: DbPool,
base_url: String,
req: Request<BoxBody<Bytes, hyper::Error>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
let path = req.uri().path().to_string();
Expand All @@ -230,7 +220,6 @@ async fn handle_v2(
let path_segments: Vec<&str> = path.split('/').collect();
debug!("handle_v2: {:?}", &path_segments);
match (parts.method, path_segments.as_slice()) {
(Method::POST, &["", ""]) => post_session(base_url, body).await,
(Method::POST, &["", id]) => post_fallback_v2(id, body, pool).await,
(Method::GET, &["", id]) => get_fallback(id, pool).await,
(Method::PUT, &["", id]) => post_payjoin(id, body, pool).await,
Expand Down Expand Up @@ -282,24 +271,6 @@ impl From<hyper::http::Error> for HandlerError {
fn from(e: hyper::http::Error) -> Self { HandlerError::InternalServerError(e.into()) }
}

async fn post_session(
base_url: String,
body: BoxBody<Bytes, hyper::Error>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
let bytes = body.collect().await.map_err(|e| HandlerError::BadRequest(e.into()))?.to_bytes();
let base64_id =
String::from_utf8(bytes.to_vec()).map_err(|e| HandlerError::BadRequest(e.into()))?;
let pubkey_bytes: Vec<u8> =
BASE64_URL_SAFE_NO_PAD.decode(base64_id).map_err(|e| HandlerError::BadRequest(e.into()))?;
let pubkey = bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes)
.map_err(|e| HandlerError::BadRequest(e.into()))?;
tracing::info!("Initialized session with pubkey: {:?}", pubkey);
Ok(Response::builder()
.header(LOCATION, format!("{}/{}", base_url, pubkey))
.status(StatusCode::CREATED)
.body(empty())?)
}

async fn post_fallback_v1(
id: &str,
query: String,
Expand Down Expand Up @@ -425,32 +396,3 @@ fn empty() -> BoxBody<Bytes, hyper::Error> {
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into()).map_err(|never| match never {}).boxed()
}

#[cfg(test)]
mod tests {
use hyper::Request;

use super::*;

/// Ensure that the POST / endpoint returns a 201 Created with a Location header
/// as is semantically correct when creating a resource.
///
/// https://datatracker.ietf.org/doc/html/rfc9110#name-post
#[tokio::test]
async fn test_post_session() -> Result<(), Box<dyn std::error::Error>> {
let base_url = "https://localhost".to_string();
let body = full("A6z245ZfDfnlk7_HiAp6sPmNaVYwADih-vCGE3eysWp7");

let request = Request::builder().method(Method::POST).uri("/").body(body)?;

let response = post_session(base_url.clone(), request.into_body())
.await
.map_err(|e| format!("{:?}", e))?;

assert_eq!(response.status(), StatusCode::CREATED);
assert!(response.headers().contains_key(LOCATION));
let location_header = response.headers().get(LOCATION).ok_or("Missing LOCATION header")?;
assert!(location_header.to_str()?.starts_with(&base_url));
Ok(())
}
}
4 changes: 1 addition & 3 deletions payjoin-directory/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let db_host = env::var("PJ_DB_HOST").unwrap_or_else(|_| DEFAULT_DB_HOST.to_string());

let base_url = env::var("PJ_DIR_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());

payjoin_directory::listen_tcp(base_url, dir_port, db_host, timeout).await
payjoin_directory::listen_tcp(dir_port, db_host, timeout).await
}

fn init_logging() {
Expand Down
66 changes: 10 additions & 56 deletions payjoin/src/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,18 @@ where
Ok(address.assume_checked())
}

/// Initializes a new payjoin session, including necessary context
/// information for communication and cryptographic operations.
#[derive(Debug, Clone)]
pub struct SessionInitializer {
fn subdir_path_from_pubkey(pubkey: &HpkePublicKey) -> String {
BASE64_URL_SAFE_NO_PAD.encode(pubkey.to_compressed_bytes())
}

/// An active payjoin V2 session, allowing for polled requests to the
/// payjoin directory and response processing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ActiveSession {
context: SessionContext,
}

#[cfg(feature = "v2")]
impl SessionInitializer {
impl ActiveSession {
/// Creates a new `SessionInitializer` with the provided parameters.
///
/// # Parameters
Expand Down Expand Up @@ -90,56 +93,7 @@ impl SessionInitializer {
}
}

pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), Error> {
let url = self.context.ohttp_relay.clone();
let subdirectory = subdir_path_from_pubkey(self.context.s.public_key());
let (body, ctx) = crate::v2::ohttp_encapsulate(
&mut self.context.ohttp_keys,
"POST",
self.context.directory.as_str(),
Some(subdirectory.as_bytes()),
)?;
let req = Request::new_v2(url, body);
Ok((req, ctx))
}

pub fn process_res(
mut self,
mut res: impl std::io::Read,
ctx: ohttp::ClientResponse,
) -> Result<ActiveSession, Error> {
let mut buf = Vec::new();
let _ = res.read_to_end(&mut buf);
let response = crate::v2::ohttp_decapsulate(ctx, &buf)?;
if !response.status().is_success() {
return Err(Error::Server("Enrollment failed, expected success status".into()));
}
log::debug!("Received response headers: {:?}", response.headers());
let location = response
.headers()
.get("location")
.ok_or(Error::Server("Missing location header".into()))?
.to_str()
.map_err(|e| Error::Server(format!("Invalid location header: {}", e).into()))?;
self.context.subdirectory =
Some(url::Url::parse(location).map_err(|e| Error::Server(e.into()))?);

Ok(ActiveSession { context: self.context.clone() })
}
}

fn subdir_path_from_pubkey(pubkey: &HpkePublicKey) -> String {
BASE64_URL_SAFE_NO_PAD.encode(pubkey.to_compressed_bytes())
}

/// An active payjoin V2 session, allowing for polled requests to the
/// payjoin directory and response processing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ActiveSession {
context: SessionContext,
}

impl ActiveSession {
// OHTTP Encapsulated HTTP GET request for the Original PSBT
pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), SessionError> {
if SystemTime::now() > self.context.expiry {
return Err(InternalSessionError::Expired(self.context.expiry).into());
Expand Down
Loading

0 comments on commit 8b40f3c

Please sign in to comment.