Skip to content

Commit

Permalink
Support backwards compatiblity
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Oct 14, 2024
1 parent f8e459d commit 8f5aca5
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 58 deletions.
19 changes: 12 additions & 7 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,12 @@ impl App {
.send()
.await
.map_err(map_reqwest_err)?;

match ctx {
payjoin::send::Context::V2(ctx) => {
let v2_ctx = Arc::new(ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?);
payjoin::send::Context::V2(ctx) => {
let v2_ctx = Arc::new(
ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?,
);
loop {
let (req, ohttp_ctx) = v2_ctx.extract_req(self.config.ohttp_relay.clone())?;
let response = http
Expand All @@ -221,7 +223,10 @@ impl App {
.send()
.await
.map_err(map_reqwest_err)?;
match v2_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx) {
match v2_ctx.process_response(
&mut response.bytes().await?.to_vec().as_slice(),
ohttp_ctx,
) {
Ok(Some(psbt)) => return Ok(psbt),
Ok(None) => {
println!("No response yet.");
Expand All @@ -234,7 +239,7 @@ impl App {
}
}
}
},
}
payjoin::send::Context::V1(ctx) => {
match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) {
Ok(psbt) => return Ok(psbt),
Expand All @@ -244,8 +249,8 @@ impl App {
return Err(anyhow!("Response error").context(re));
}
}
},
_ => panic!("V1 context expected")
}
_ => panic!("V1 context expected"),
};
println!("Sent fallback transaction");
}
Expand Down
12 changes: 3 additions & 9 deletions payjoin-cli/src/db/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,7 @@ impl Database {
Ok(())
}

pub(crate) fn insert_send_session(
&self,
session: &mut Sender,
pj_url: &Url,
) -> Result<()> {
pub(crate) fn insert_send_session(&self, session: &mut Sender, pj_url: &Url) -> Result<()> {
let send_tree: Tree = self.0.open_tree("send_sessions")?;
let value = serde_json::to_string(session).map_err(Error::Serialize)?;
send_tree.insert(pj_url.to_string(), IVec::from(value.as_str()))?;
Expand All @@ -52,8 +48,7 @@ impl Database {
let mut sessions = Vec::new();
for item in send_tree.iter() {
let (_, value) = item?;
let session: Sender =
serde_json::from_slice(&value).map_err(Error::Deserialize)?;
let session: Sender = serde_json::from_slice(&value).map_err(Error::Deserialize)?;
sessions.push(session);
}
Ok(sessions)
Expand All @@ -62,8 +57,7 @@ impl Database {
pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result<Option<Sender>> {
let send_tree = self.0.open_tree("send_sessions")?;
if let Some(val) = send_tree.get(pj_url.to_string())? {
let session: Sender =
serde_json::from_slice(&val).map_err(Error::Deserialize)?;
let session: Sender = serde_json::from_slice(&val).map_err(Error::Deserialize)?;
Ok(Some(session))
} else {
Ok(None)
Expand Down
14 changes: 12 additions & 2 deletions payjoin-directory/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use redis::{AsyncCommands, Client, ErrorKind, RedisError, RedisResult};
use tracing::debug;

const DEFAULT_COLUMN: &str = "";
const PJ_V1_COLUMN: &str = "pjv1";

const RES_COLUMN: &str = "res";
const REQ_COLUMN: &str = "req";

Expand All @@ -20,14 +22,22 @@ impl DbPool {
Ok(Self { client, timeout })
}

pub async fn push_payload(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
pub async fn push_default(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
self.push(pubkey_id, DEFAULT_COLUMN, data).await
}

pub async fn peek_payload(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
pub async fn peek_default(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
self.peek_with_timeout(pubkey_id, DEFAULT_COLUMN).await
}

pub async fn push_v1(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
self.push(pubkey_id, PJ_V1_COLUMN, data).await
}

pub async fn peek_v1(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
self.peek_with_timeout(pubkey_id, PJ_V1_COLUMN).await
}

async fn push(&self, pubkey_id: &str, channel_type: &str, data: Vec<u8>) -> RedisResult<()> {
let mut conn = self.client.get_async_connection().await?;
let key = channel_name(pubkey_id, channel_type);
Expand Down
27 changes: 21 additions & 6 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,11 @@ async fn post_fallback_v1(
};

let v2_compat_body = format!("{}\n{}", body_str, query);
pool.push_payload(&id, v2_compat_body.into()).await.map_err(|e| HandlerError::BadRequest(e.into()))?;
match pool.peek_payload(&id).await {
let id = shorten_string(id);
pool.push_default(&id, v2_compat_body.into())
.await
.map_err(|e| HandlerError::BadRequest(e.into()))?;
match pool.peek_v1(&id).await {
Some(result) => match result {
Ok(buffered_req) => Ok(Response::new(full(buffered_req))),
Err(e) => Err(HandlerError::BadRequest(e.into())),
Expand All @@ -310,8 +313,20 @@ async fn put_payjoin_v1(
body: BoxBody<Bytes, hyper::Error>,
pool: DbPool,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
trace!("Put payjoin v1");
post_subdir(id, body, pool).await
trace!("Put_payjoin_v1");
let ok_response = Response::builder().status(StatusCode::OK).body(empty())?;

let id = shorten_string(id);
let req =
body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes();
if req.len() > MAX_BUFFER_SIZE {
return Err(HandlerError::PayloadTooLarge);
}

match pool.push_v1(&id, req.into()).await {
Ok(_) => Ok(ok_response),
Err(e) => return Err(HandlerError::BadRequest(e.into())),
}
}

async fn post_subdir(
Expand All @@ -329,7 +344,7 @@ async fn post_subdir(
return Err(HandlerError::PayloadTooLarge);
}

match pool.push_payload(&id, req.into()).await {
match pool.push_default(&id, req.into()).await {
Ok(_) => Ok(none_response),
Err(e) => return Err(HandlerError::BadRequest(e.into())),
}
Expand All @@ -341,7 +356,7 @@ async fn get_subdir(
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
trace!("GET fallback");
let id = shorten_string(id);
match pool.peek_payload(&id).await {
match pool.peek_default(&id).await {
Some(result) => match result {
Ok(buffered_req) => Ok(Response::new(full(buffered_req))),
Err(e) => Err(HandlerError::BadRequest(e.into())),
Expand Down
11 changes: 8 additions & 3 deletions payjoin/src/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,17 +490,22 @@ impl PayjoinProposal {
// Prepare v2 payload
let payjoin_bytes = self.inner.payjoin_psbt.serialize();
let sender_subdir = subdir_path_from_pubkey(e);
target_resource = self.context.directory.join(&sender_subdir).map_err(|e| Error::Server(e.into()))?;
target_resource =
self.context.directory.join(&sender_subdir).map_err(|e| Error::Server(e.into()))?;
body = crate::v2::encrypt_message_b(payjoin_bytes, &self.context.s, e).unwrap();
method = "POST";
} else {
// Prepare v2 wrapped and backwards-compatible v1 payload
body = self.extract_v1_req().as_bytes().to_vec();
let receiver_subdir = subdir_path_from_pubkey(self.context.s.public_key());
target_resource = self.context.directory.join(&receiver_subdir).map_err(|e| Error::Server(e.into()))?;
target_resource = self
.context
.directory
.join(&receiver_subdir)
.map_err(|e| Error::Server(e.into()))?;
method = "PUT";
}
log::debug!("Payjoin post target: {}", target_resource.as_str());
log::debug!("Payjoin PSBT target: {}", target_resource.as_str());
let (body, ctx) = crate::v2::ohttp_encapsulate(
&mut self.context.ohttp_keys,
method,
Expand Down
50 changes: 26 additions & 24 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use url::Url;
use crate::psbt::{InputPair, PsbtExt};
use crate::request::Request;
#[cfg(feature = "v2")]
use crate::v2::{HpkePublicKey, HpkeKeyPair};
use crate::v2::{HpkeKeyPair, HpkePublicKey};
use crate::PjUri;

// See usize casts
Expand Down Expand Up @@ -100,10 +100,7 @@ impl<'a> SenderBuilder<'a> {
// The minfeerate parameter is set if the contribution is available in change.
//
// This method fails if no recommendation can be made or if the PSBT is malformed.
pub fn build_recommended(
self,
min_fee_rate: FeeRate,
) -> Result<Sender, CreateRequestError> {
pub fn build_recommended(self, min_fee_rate: FeeRate) -> Result<Sender, CreateRequestError> {
// TODO support optional batched payout scripts. This would require a change to
// build() which now checks for a single payee.
let mut payout_scripts = std::iter::once(self.uri.address.script_pubkey());
Expand Down Expand Up @@ -270,8 +267,8 @@ impl Sender {
fee_contribution: self.fee_contribution,
payee: self.payee.clone(),
min_fee_rate: self.min_fee_rate,
}
}
},
},
))
}

Expand Down Expand Up @@ -341,10 +338,7 @@ impl Sender {
payee: self.payee.clone(),
min_fee_rate: self.min_fee_rate,
},
hpke_ctx: HpkeContext {
rs: rs,
e: self.e.clone(),
},
hpke_ctx: HpkeContext { rs: rs, e: self.e.clone() },
ohttp_ctx,
}),
))
Expand Down Expand Up @@ -417,7 +411,7 @@ impl V2PostContext {
psbt_ctx: self.psbt_ctx,
hpke_ctx: self.hpke_ctx,
})
},
}
_ => return Err(InternalValidationError::UnexpectedStatusCode)?,
}
}
Expand All @@ -432,24 +426,29 @@ pub struct V2GetContext {

#[cfg(feature = "v2")]
impl V2GetContext {
pub fn extract_req(&self, ohttp_relay: Url) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> {
pub fn extract_req(
&self,
ohttp_relay: Url,
) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> {
use crate::uri::UrlExt;
let mut url = self.endpoint.clone();
let subdir = BASE64_URL_SAFE_NO_PAD.encode(self.hpke_ctx.e.public_key().to_compressed_bytes());
let subdir =
BASE64_URL_SAFE_NO_PAD.encode(self.hpke_ctx.e.public_key().to_compressed_bytes());
url.set_path(&subdir);
println!("sender subdir from sender: {:?}", &url);
let body = crate::v2::encrypt_message_a(Vec::new(), &self.hpke_ctx.e.secret_key().clone(), &self.hpke_ctx.rs.clone())
.map_err(InternalCreateRequestError::Hpke)?;
let body = crate::v2::encrypt_message_a(
Vec::new(),
&self.hpke_ctx.e.secret_key().clone(),
&self.hpke_ctx.rs.clone(),
)
.map_err(InternalCreateRequestError::Hpke)?;
let mut ohttp =
self.endpoint.ohttp().ok_or(InternalCreateRequestError::MissingOhttpConfig)?;
let (body, ohttp_ctx) =
crate::v2::ohttp_encapsulate(&mut ohttp, "GET", url.as_str(), Some(&body))
.map_err(InternalCreateRequestError::OhttpEncapsulation)?;

Ok((
Request::new_v2(ohttp_relay, body),
ohttp_ctx,
))

Ok((Request::new_v2(ohttp_relay, body), ohttp_ctx))
}

pub fn process_response(
Expand All @@ -468,16 +467,19 @@ impl V2GetContext {
http::StatusCode::ACCEPTED => return Ok(None),
_ => return Err(InternalValidationError::UnexpectedStatusCode)?,
};
let psbt = crate::v2::decrypt_message_b(&body, self.hpke_ctx.rs.clone(), self.hpke_ctx.e.secret_key().clone())
.map_err(InternalValidationError::Hpke)?;
let psbt = crate::v2::decrypt_message_b(
&body,
self.hpke_ctx.rs.clone(),
self.hpke_ctx.e.secret_key().clone(),
)
.map_err(InternalValidationError::Hpke)?;

let proposal = Psbt::deserialize(&psbt).map_err(InternalValidationError::Psbt)?;
let processed_proposal = self.psbt_ctx.clone().process_proposal(proposal)?;
Ok(Some(processed_proposal))
}
}


/// Data required for validation of response.
///
/// This type is used to process the response. Get it from [`RequestBuilder`](crate::send::RequestBuilder)'s build methods.
Expand Down
15 changes: 8 additions & 7 deletions payjoin/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ mod integration {
req_ctx.extract_highest_version(directory.to_owned())?;
let send_ctx = match send_ctx {
Context::V2(ctx) => ctx,
_ => panic!("V2 context expected")
_ => panic!("V2 context expected"),
};
let response = agent
.post(url.clone())
Expand Down Expand Up @@ -319,7 +319,8 @@ mod integration {
// Inside the Sender:
// Sender checks, signs, finalizes, extracts, and broadcasts
// Replay post fallback to get the response
let (Request { url, body, content_type, .. }, ohttp_ctx) = send_ctx.extract_req(directory.to_owned())?;
let (Request { url, body, content_type, .. }, ohttp_ctx) =
send_ctx.extract_req(directory.to_owned())?;
let response = agent
.post(url.clone())
.header("Content-Type", content_type)
Expand All @@ -328,7 +329,9 @@ mod integration {
.await
.unwrap();
log::info!("Response: {:#?}", &response);
let checked_payjoin_proposal_psbt = send_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)?.unwrap();
let checked_payjoin_proposal_psbt = send_ctx
.process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)?
.unwrap();
let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?;
sender.send_raw_transaction(&payjoin_tx)?;
log::info!("sent");
Expand Down Expand Up @@ -382,10 +385,9 @@ mod integration {
// Sender checks, signs, finalizes, extracts, and broadcasts
let ctx = match ctx {
Context::V1(ctx) => ctx,
_ => panic!("V1 context expected")
_ => panic!("V1 context expected"),
};
let checked_payjoin_proposal_psbt =
ctx.process_response(&mut response.as_bytes())?;
let checked_payjoin_proposal_psbt = ctx.process_response(&mut response.as_bytes())?;
let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?;
sender.send_raw_transaction(&payjoin_tx)?;

Expand Down Expand Up @@ -461,7 +463,6 @@ mod integration {
.body(body.clone())
.send()
.await;
dbg!(&res);
assert!(res.as_ref().unwrap().status() == StatusCode::SERVICE_UNAVAILABLE);

// **********************
Expand Down

0 comments on commit 8f5aca5

Please sign in to comment.