Skip to content

Commit

Permalink
Request with implicit initialization pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Oct 14, 2024
1 parent 8b40f3c commit a6c2a8a
Show file tree
Hide file tree
Showing 8 changed files with 339 additions and 218 deletions.
6 changes: 3 additions & 3 deletions payjoin-cli/src/app/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use anyhow::{anyhow, Context, Result};
use bitcoincore_rpc::bitcoin::Amount;
use bitcoincore_rpc::RpcApi;
use payjoin::bitcoin::psbt::Psbt;
use payjoin::send::RequestContext;
use payjoin::send::Sender;
use payjoin::{bitcoin, PjUri};

pub mod config;
Expand All @@ -28,7 +28,7 @@ pub trait App {
async fn send_payjoin(&self, bip21: &str, fee_rate: &f32) -> Result<()>;
async fn receive_payjoin(self, amount_arg: &str) -> Result<()>;

fn create_pj_request(&self, uri: &PjUri, fee_rate: &f32) -> Result<RequestContext> {
fn create_pj_request(&self, uri: &PjUri, fee_rate: &f32) -> Result<Sender> {
let amount = uri.amount.ok_or_else(|| anyhow!("please specify the amount in the Uri"))?;

// wallet_create_funded_psbt requires a HashMap<address: String, Amount>
Expand Down Expand Up @@ -64,7 +64,7 @@ pub trait App {
.psbt;
let psbt = Psbt::from_str(&psbt).with_context(|| "Failed to load PSBT from base64")?;
log::debug!("Original psbt: {:#?}", psbt);
let req_ctx = payjoin::send::RequestBuilder::from_psbt_and_uri(psbt, uri.clone())
let req_ctx = payjoin::send::SenderBuilder::from_psbt_and_uri(psbt, uri.clone())
.with_context(|| "Failed to build payjoin request")?
.build_recommended(fee_rate)
.with_context(|| "Failed to build payjoin request")?;
Expand Down
77 changes: 52 additions & 25 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use payjoin::bitcoin::consensus::encode::serialize_hex;
use payjoin::bitcoin::psbt::Psbt;
use payjoin::bitcoin::{Amount, FeeRate};
use payjoin::receive::v2::ActiveSession;
use payjoin::send::RequestContext;
use payjoin::send::Sender;
use payjoin::{bitcoin, Error, Uri};
use tokio::signal;
use tokio::sync::watch;
Expand Down Expand Up @@ -91,7 +91,7 @@ impl AppTrait for App {
}

impl App {
async fn spawn_payjoin_sender(&self, mut req_ctx: RequestContext) -> Result<()> {
async fn spawn_payjoin_sender(&self, mut req_ctx: Sender) -> Result<()> {
let mut interrupt = self.interrupt.clone();
tokio::select! {
res = self.long_poll_post(&mut req_ctx) => {
Expand Down Expand Up @@ -197,30 +197,57 @@ impl App {
Ok(())
}

async fn long_poll_post(&self, req_ctx: &mut payjoin::send::RequestContext) -> Result<Psbt> {
loop {
let (req, ctx) = req_ctx.extract_v2(self.config.ohttp_relay.clone())?;
println!("Polling send request...");
let http = http_agent()?;
let response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;

println!("Sent fallback transaction");
match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) {
Ok(Some(psbt)) => return Ok(psbt),
Ok(None) => {
println!("No response yet.");
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
async fn long_poll_post(&self, req_ctx: &mut payjoin::send::Sender) -> Result<Psbt> {
let (req, ctx) = req_ctx.extract_highest_version(self.config.ohttp_relay.clone())?;
println!("Posting Original PSBT Payload request...");
let http = http_agent()?;
let response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
println!("Sent fallback transaction");
match ctx {
payjoin::send::Context::V2(ctx) => {
let v2_ctx = Arc::new(
ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?,
);
loop {
let (req, ohttp_ctx) = v2_ctx.extract_req(self.config.ohttp_relay.clone())?;
let response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
match v2_ctx.process_response(
&mut response.bytes().await?.to_vec().as_slice(),
ohttp_ctx,
) {
Ok(Some(psbt)) => return Ok(psbt),
Ok(None) => {
println!("No response yet.");
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
}
Err(re) => {
println!("{}", re);
log::debug!("{:?}", re);
return Err(anyhow!("Response error").context(re));
}
}
}
Err(re) => {
println!("{}", re);
log::debug!("{:?}", re);
return Err(anyhow!("Response error").context(re));
}
payjoin::send::Context::V1(ctx) => {
match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) {
Ok(psbt) => Ok(psbt),
Err(re) => {
println!("{}", re);
log::debug!("{:?}", re);
Err(anyhow!("Response error").context(re))
}
}
}
}
Expand Down
18 changes: 6 additions & 12 deletions payjoin-cli/src/db/v2.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use bitcoincore_rpc::jsonrpc::serde_json;
use payjoin::receive::v2::ActiveSession;
use payjoin::send::RequestContext;
use payjoin::send::Sender;
use sled::{IVec, Tree};
use url::Url;

Expand Down Expand Up @@ -35,35 +35,29 @@ impl Database {
Ok(())
}

pub(crate) fn insert_send_session(
&self,
session: &mut RequestContext,
pj_url: &Url,
) -> Result<()> {
pub(crate) fn insert_send_session(&self, session: &mut Sender, pj_url: &Url) -> Result<()> {
let send_tree: Tree = self.0.open_tree("send_sessions")?;
let value = serde_json::to_string(session).map_err(Error::Serialize)?;
send_tree.insert(pj_url.to_string(), IVec::from(value.as_str()))?;
send_tree.flush()?;
Ok(())
}

pub(crate) fn get_send_sessions(&self) -> Result<Vec<RequestContext>> {
pub(crate) fn get_send_sessions(&self) -> Result<Vec<Sender>> {
let send_tree: Tree = self.0.open_tree("send_sessions")?;
let mut sessions = Vec::new();
for item in send_tree.iter() {
let (_, value) = item?;
let session: RequestContext =
serde_json::from_slice(&value).map_err(Error::Deserialize)?;
let session: Sender = serde_json::from_slice(&value).map_err(Error::Deserialize)?;
sessions.push(session);
}
Ok(sessions)
}

pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result<Option<RequestContext>> {
pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result<Option<Sender>> {
let send_tree = self.0.open_tree("send_sessions")?;
if let Some(val) = send_tree.get(pj_url.to_string())? {
let session: RequestContext =
serde_json::from_slice(&val).map_err(Error::Deserialize)?;
let session: Sender = serde_json::from_slice(&val).map_err(Error::Deserialize)?;
Ok(Some(session))
} else {
Ok(None)
Expand Down
20 changes: 10 additions & 10 deletions payjoin-directory/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use futures::StreamExt;
use redis::{AsyncCommands, Client, ErrorKind, RedisError, RedisResult};
use tracing::debug;

const RES_COLUMN: &str = "res";
const REQ_COLUMN: &str = "req";
const DEFAULT_COLUMN: &str = "";
const PJ_V1_COLUMN: &str = "pjv1";

#[derive(Debug, Clone)]
pub(crate) struct DbPool {
Expand All @@ -19,20 +19,20 @@ impl DbPool {
Ok(Self { client, timeout })
}

pub async fn peek_req(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
self.peek_with_timeout(pubkey_id, REQ_COLUMN).await
pub async fn push_default(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
self.push(pubkey_id, DEFAULT_COLUMN, data).await
}

pub async fn peek_res(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
self.peek_with_timeout(pubkey_id, RES_COLUMN).await
pub async fn peek_default(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
self.peek_with_timeout(pubkey_id, DEFAULT_COLUMN).await
}

pub async fn push_req(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
self.push(pubkey_id, REQ_COLUMN, data).await
pub async fn push_v1(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
self.push(pubkey_id, PJ_V1_COLUMN, data).await
}

pub async fn push_res(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
self.push(pubkey_id, RES_COLUMN, data).await
pub async fn peek_v1(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
self.peek_with_timeout(pubkey_id, PJ_V1_COLUMN).await
}

async fn push(&self, pubkey_id: &str, channel_type: &str, data: Vec<u8>) -> RedisResult<()> {
Expand Down
80 changes: 39 additions & 41 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ async fn handle_v2(
let path_segments: Vec<&str> = path.split('/').collect();
debug!("handle_v2: {:?}", &path_segments);
match (parts.method, path_segments.as_slice()) {
(Method::POST, &["", id]) => post_fallback_v2(id, body, pool).await,
(Method::GET, &["", id]) => get_fallback(id, pool).await,
(Method::PUT, &["", id]) => post_payjoin(id, body, pool).await,
(Method::POST, &["", id]) => post_subdir(id, body, pool).await,
(Method::GET, &["", id]) => get_subdir(id, pool).await,
(Method::PUT, &["", id]) => put_payjoin_v1(id, body, pool).await,
_ => Ok(not_found()),
}
}
Expand Down Expand Up @@ -294,55 +294,69 @@ async fn post_fallback_v1(
Err(_) => return Ok(bad_request_body_res),
};

let v2_compat_body = full(format!("{}\n{}", body_str, query));
post_fallback(id, v2_compat_body, pool, none_response).await
let v2_compat_body = format!("{}\n{}", body_str, query);
let id = shorten_string(id);
pool.push_default(&id, v2_compat_body.into())
.await
.map_err(|e| HandlerError::BadRequest(e.into()))?;
match pool.peek_v1(&id).await {
Some(result) => match result {
Ok(buffered_req) => Ok(Response::new(full(buffered_req))),
Err(e) => Err(HandlerError::BadRequest(e.into())),
},
None => Ok(none_response),
}
}

async fn post_fallback_v2(
async fn put_payjoin_v1(
id: &str,
body: BoxBody<Bytes, hyper::Error>,
pool: DbPool,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
trace!("Post fallback v2");
let none_response = Response::builder().status(StatusCode::ACCEPTED).body(empty())?;
post_fallback(id, body, pool, none_response).await
trace!("Put_payjoin_v1");
let ok_response = Response::builder().status(StatusCode::OK).body(empty())?;

let id = shorten_string(id);
let req =
body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes();
if req.len() > MAX_BUFFER_SIZE {
return Err(HandlerError::PayloadTooLarge);
}

match pool.push_v1(&id, req.into()).await {
Ok(_) => Ok(ok_response),
Err(e) => Err(HandlerError::BadRequest(e.into())),
}
}

async fn post_fallback(
async fn post_subdir(
id: &str,
body: BoxBody<Bytes, hyper::Error>,
pool: DbPool,
none_response: Response<BoxBody<Bytes, hyper::Error>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
tracing::trace!("Post fallback");
let none_response = Response::builder().status(StatusCode::OK).body(empty())?;
tracing::trace!("Post subdir");

let id = shorten_string(id);
let req =
body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes();
if req.len() > MAX_BUFFER_SIZE {
return Err(HandlerError::PayloadTooLarge);
}

match pool.push_req(&id, req.into()).await {
Ok(_) => (),
Err(e) => return Err(HandlerError::BadRequest(e.into())),
};

match pool.peek_res(&id).await {
Some(result) => match result {
Ok(buffered_res) => Ok(Response::new(full(buffered_res))),
Err(e) => Err(HandlerError::BadRequest(e.into())),
},
None => Ok(none_response),
match pool.push_default(&id, req.into()).await {
Ok(_) => Ok(none_response),
Err(e) => Err(HandlerError::BadRequest(e.into())),
}
}

async fn get_fallback(
async fn get_subdir(
id: &str,
pool: DbPool,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
trace!("GET fallback");
let id = shorten_string(id);
match pool.peek_req(&id).await {
match pool.peek_default(&id).await {
Some(result) => match result {
Ok(buffered_req) => Ok(Response::new(full(buffered_req))),
Err(e) => Err(HandlerError::BadRequest(e.into())),
Expand All @@ -351,22 +365,6 @@ async fn get_fallback(
}
}

async fn post_payjoin(
id: &str,
body: BoxBody<Bytes, hyper::Error>,
pool: DbPool,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
trace!("POST payjoin");
let id = shorten_string(id);
let res =
body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes();

match pool.push_res(&id, res.into()).await {
Ok(_) => Ok(Response::builder().status(StatusCode::NO_CONTENT).body(empty())?),
Err(e) => Err(HandlerError::BadRequest(e.into())),
}
}

fn not_found() -> Response<BoxBody<Bytes, hyper::Error>> {
let mut res = Response::default();
*res.status_mut() = StatusCode::NOT_FOUND;
Expand Down
Loading

0 comments on commit a6c2a8a

Please sign in to comment.