Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Session Initialization Implicit #364

Merged
merged 12 commits into from
Oct 22, 2024
6 changes: 3 additions & 3 deletions payjoin-cli/src/app/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use bitcoin::TxIn;
use bitcoincore_rpc::bitcoin::Amount;
use bitcoincore_rpc::RpcApi;
use payjoin::bitcoin::psbt::Psbt;
use payjoin::send::RequestContext;
use payjoin::send::Sender;
use payjoin::{bitcoin, PjUri};

pub mod config;
Expand All @@ -30,7 +30,7 @@ pub trait App {
async fn send_payjoin(&self, bip21: &str, fee_rate: &f32) -> Result<()>;
async fn receive_payjoin(self, amount_arg: &str) -> Result<()>;

fn create_pj_request(&self, uri: &PjUri, fee_rate: &f32) -> Result<RequestContext> {
fn create_pj_request(&self, uri: &PjUri, fee_rate: &f32) -> Result<Sender> {
let amount = uri.amount.ok_or_else(|| anyhow!("please specify the amount in the Uri"))?;

// wallet_create_funded_psbt requires a HashMap<address: String, Amount>
Expand Down Expand Up @@ -66,7 +66,7 @@ pub trait App {
.psbt;
let psbt = Psbt::from_str(&psbt).with_context(|| "Failed to load PSBT from base64")?;
log::debug!("Original psbt: {:#?}", psbt);
let req_ctx = payjoin::send::RequestBuilder::from_psbt_and_uri(psbt, uri.clone())
let req_ctx = payjoin::send::SenderBuilder::from_psbt_and_uri(psbt, uri.clone())
.with_context(|| "Failed to build payjoin request")?
.build_recommended(fee_rate)
.with_context(|| "Failed to build payjoin request")?;
Expand Down
101 changes: 56 additions & 45 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use bitcoincore_rpc::RpcApi;
use payjoin::bitcoin::consensus::encode::serialize_hex;
use payjoin::bitcoin::psbt::Psbt;
use payjoin::bitcoin::{Amount, FeeRate};
use payjoin::receive::v2::ActiveSession;
use payjoin::send::RequestContext;
use payjoin::receive::v2::Receiver;
use payjoin::send::Sender;
use payjoin::{bitcoin, Error, Uri};
use tokio::signal;
use tokio::sync::watch;
Expand Down Expand Up @@ -75,39 +75,23 @@ impl AppTrait for App {
}

async fn receive_payjoin(self, amount_arg: &str) -> Result<()> {
use payjoin::receive::v2::SessionInitializer;

let address = self.bitcoind()?.get_new_address(None, None)?.assume_checked();
let amount = Amount::from_sat(amount_arg.parse()?);
let ohttp_keys = unwrap_ohttp_keys_or_else_fetch(&self.config).await?;
let mut initializer = SessionInitializer::new(
let session = Receiver::new(
address,
self.config.pj_directory.clone(),
ohttp_keys.clone(),
self.config.ohttp_relay.clone(),
None,
);
let (req, ctx) =
initializer.extract_req().map_err(|e| anyhow!("Failed to extract request {}", e))?;
println!("Starting new Payjoin session with {}", self.config.pj_directory);
let http = http_agent()?;
let ohttp_response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
let session = initializer
.process_res(ohttp_response.bytes().await?.to_vec().as_slice(), ctx)
.map_err(|e| anyhow!("Enrollment failed {}", e))?;
self.db.insert_recv_session(session.clone())?;
self.spawn_payjoin_receiver(session, Some(amount)).await
}
}

impl App {
async fn spawn_payjoin_sender(&self, mut req_ctx: RequestContext) -> Result<()> {
async fn spawn_payjoin_sender(&self, mut req_ctx: Sender) -> Result<()> {
let mut interrupt = self.interrupt.clone();
tokio::select! {
res = self.long_poll_post(&mut req_ctx) => {
Expand All @@ -123,7 +107,7 @@ impl App {

async fn spawn_payjoin_receiver(
&self,
mut session: ActiveSession,
mut session: Receiver,
amount: Option<Amount>,
) -> Result<()> {
println!("Receive session established");
Expand Down Expand Up @@ -213,38 +197,65 @@ impl App {
Ok(())
}

async fn long_poll_post(&self, req_ctx: &mut payjoin::send::RequestContext) -> Result<Psbt> {
loop {
let (req, ctx) = req_ctx.extract_v2(self.config.ohttp_relay.clone())?;
println!("Polling send request...");
let http = http_agent()?;
let response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;

println!("Sent fallback transaction");
match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) {
Ok(Some(psbt)) => return Ok(psbt),
Ok(None) => {
println!("No response yet.");
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
async fn long_poll_post(&self, req_ctx: &mut payjoin::send::Sender) -> Result<Psbt> {
let (req, ctx) = req_ctx.extract_highest_version(self.config.ohttp_relay.clone())?;
println!("Posting Original PSBT Payload request...");
DanGould marked this conversation as resolved.
Show resolved Hide resolved
let http = http_agent()?;
let response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
println!("Sent fallback transaction");
match ctx {
payjoin::send::Context::V2(ctx) => {
let v2_ctx = Arc::new(
ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?,
);
loop {
let (req, ohttp_ctx) = v2_ctx.extract_req(self.config.ohttp_relay.clone())?;
let response = http
.post(req.url)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the sender be polling GET requests here?

Copy link
Contributor Author

@DanGould DanGould Oct 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohttp requests are always POSTs so that they're indistinguishable. Only when the OHTTP encapsulation is removed does the target server see the inner GET request.

see https://github.com/payjoin/rust-payjoin/pull/364/files#diff-e1a2a0997d5241c211b358d2ea2be04fc1a4aa680ac543480188d7b4ab63ff72R448

I wonder if there's a better way to represent this in the state machine than POST/GET to avoid this confusion

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see now the GET request is constructed in extract_req, makes sense. It might help if the OHTTP post was extracted into a make_ohttp_request(req) function or something along those lines.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea I plan to write for both v2.rs and integration.rs in a follow up

.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
match v2_ctx.process_response(
&mut response.bytes().await?.to_vec().as_slice(),
ohttp_ctx,
) {
Ok(Some(psbt)) => return Ok(psbt),
Ok(None) => {
println!("No response yet.");
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
}
Err(re) => {
println!("{}", re);
log::debug!("{:?}", re);
DanGould marked this conversation as resolved.
Show resolved Hide resolved
return Err(anyhow!("Response error").context(re));
}
}
}
Err(re) => {
println!("{}", re);
log::debug!("{:?}", re);
return Err(anyhow!("Response error").context(re));
}
payjoin::send::Context::V1(ctx) => {
match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) {
Ok(psbt) => Ok(psbt),
Err(re) => {
println!("{}", re);
log::debug!("{:?}", re);
Err(anyhow!("Response error").context(re))
}
}
}
}
}

async fn long_poll_fallback(
&self,
session: &mut payjoin::receive::v2::ActiveSession,
session: &mut payjoin::receive::v2::Receiver,
) -> Result<payjoin::receive::v2::UncheckedProposal> {
loop {
let (req, context) = session.extract_req()?;
Expand Down
27 changes: 10 additions & 17 deletions payjoin-cli/src/db/v2.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use bitcoincore_rpc::jsonrpc::serde_json;
use payjoin::receive::v2::ActiveSession;
use payjoin::send::RequestContext;
use payjoin::receive::v2::Receiver;
use payjoin::send::Sender;
use sled::{IVec, Tree};
use url::Url;

use super::*;

impl Database {
pub(crate) fn insert_recv_session(&self, session: ActiveSession) -> Result<()> {
pub(crate) fn insert_recv_session(&self, session: Receiver) -> Result<()> {
let recv_tree = self.0.open_tree("recv_sessions")?;
let key = &session.id();
let value = serde_json::to_string(&session).map_err(Error::Serialize)?;
Expand All @@ -16,13 +16,12 @@ impl Database {
Ok(())
}

pub(crate) fn get_recv_sessions(&self) -> Result<Vec<ActiveSession>> {
pub(crate) fn get_recv_sessions(&self) -> Result<Vec<Receiver>> {
let recv_tree = self.0.open_tree("recv_sessions")?;
let mut sessions = Vec::new();
for item in recv_tree.iter() {
let (_, value) = item?;
let session: ActiveSession =
serde_json::from_slice(&value).map_err(Error::Deserialize)?;
let session: Receiver = serde_json::from_slice(&value).map_err(Error::Deserialize)?;
sessions.push(session);
}
Ok(sessions)
Expand All @@ -35,35 +34,29 @@ impl Database {
Ok(())
}

pub(crate) fn insert_send_session(
&self,
session: &mut RequestContext,
pj_url: &Url,
) -> Result<()> {
pub(crate) fn insert_send_session(&self, session: &mut Sender, pj_url: &Url) -> Result<()> {
let send_tree: Tree = self.0.open_tree("send_sessions")?;
let value = serde_json::to_string(session).map_err(Error::Serialize)?;
send_tree.insert(pj_url.to_string(), IVec::from(value.as_str()))?;
send_tree.flush()?;
Ok(())
}

pub(crate) fn get_send_sessions(&self) -> Result<Vec<RequestContext>> {
pub(crate) fn get_send_sessions(&self) -> Result<Vec<Sender>> {
let send_tree: Tree = self.0.open_tree("send_sessions")?;
let mut sessions = Vec::new();
for item in send_tree.iter() {
let (_, value) = item?;
let session: RequestContext =
serde_json::from_slice(&value).map_err(Error::Deserialize)?;
let session: Sender = serde_json::from_slice(&value).map_err(Error::Deserialize)?;
sessions.push(session);
}
Ok(sessions)
}

pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result<Option<RequestContext>> {
pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result<Option<Sender>> {
let send_tree = self.0.open_tree("send_sessions")?;
if let Some(val) = send_tree.get(pj_url.to_string())? {
let session: RequestContext =
serde_json::from_slice(&val).map_err(Error::Deserialize)?;
let session: Sender = serde_json::from_slice(&val).map_err(Error::Deserialize)?;
Ok(Some(session))
} else {
Ok(None)
Expand Down
9 changes: 1 addition & 8 deletions payjoin-cli/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,14 +482,7 @@ mod e2e {
let db = docker.run(Redis::default());
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
println!("Database running on {}", db.get_host_port_ipv4(6379));
payjoin_directory::listen_tcp_with_tls(
format!("http://localhost:{}", port),
port,
db_host,
timeout,
local_cert_key,
)
.await
payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await
}

// generates or gets a DER encoded localhost cert and key.
Expand Down
20 changes: 10 additions & 10 deletions payjoin-directory/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use futures::StreamExt;
use redis::{AsyncCommands, Client, ErrorKind, RedisError, RedisResult};
use tracing::debug;

const RES_COLUMN: &str = "res";
const REQ_COLUMN: &str = "req";
const DEFAULT_COLUMN: &str = "";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something descriptive maybe useful here when debugging in prod. even "default-col"

const PJ_V1_COLUMN: &str = "pjv1";

#[derive(Debug, Clone)]
pub(crate) struct DbPool {
Expand All @@ -19,20 +19,20 @@ impl DbPool {
Ok(Self { client, timeout })
}

pub async fn peek_req(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
self.peek_with_timeout(pubkey_id, REQ_COLUMN).await
pub async fn push_default(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
self.push(pubkey_id, DEFAULT_COLUMN, data).await
}

pub async fn peek_res(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
self.peek_with_timeout(pubkey_id, RES_COLUMN).await
pub async fn peek_default(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
self.peek_with_timeout(pubkey_id, DEFAULT_COLUMN).await
}

pub async fn push_req(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
self.push(pubkey_id, REQ_COLUMN, data).await
pub async fn push_v1(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
self.push(pubkey_id, PJ_V1_COLUMN, data).await
}

pub async fn push_res(&self, pubkey_id: &str, data: Vec<u8>) -> RedisResult<()> {
self.push(pubkey_id, RES_COLUMN, data).await
pub async fn peek_v1(&self, pubkey_id: &str) -> Option<RedisResult<Vec<u8>>> {
self.peek_with_timeout(pubkey_id, PJ_V1_COLUMN).await
}

async fn push(&self, pubkey_id: &str, channel_type: &str, data: Vec<u8>) -> RedisResult<()> {
Expand Down
Loading
Loading