Skip to content

Commit

Permalink
Have directory give US a port
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Nov 27, 2024
1 parent 0148630 commit fe9d146
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 50 deletions.
106 changes: 70 additions & 36 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,70 @@ const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message"
mod db;
use crate::db::DbPool;

type BoxError = Box<dyn std::error::Error + Send + Sync>;

#[cfg(feature = "danger-local-https")]
pub async fn listen_tcp_with_tls_on_free_port(
db_host: String,
timeout: Duration,
cert_key: (Vec<u8>, Vec<u8>),
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxError>>), BoxError> {
let listener = std::net::TcpListener::bind("0.0.0.0:0")?;
let port = listener.local_addr()?.port();
println!("Directory server binding to port {}", port);

let listener = tokio::net::TcpListener::from_std(listener)?;
println!("tokio listener created");
let handle = listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await?;
println!("Directory server started");
Ok((port, handle))
}

// Helper function to avoid code duplication
async fn listen_tcp_with_tls_on_listener(
listener: tokio::net::TcpListener,
db_host: String,
timeout: Duration,
tls_config: (Vec<u8>, Vec<u8>),
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
let pool = DbPool::new(timeout, db_host).await?;
let ohttp = Arc::new(Mutex::new(init_ohttp()?));
let tls_acceptor = init_tls_acceptor(tls_config)?;
// Spawn the connection handling loop in a separate task
let handle = tokio::spawn(async move {
while let Ok((stream, _)) = listener.accept().await {
let pool = pool.clone();
let ohttp = ohttp.clone();
let tls_acceptor = tls_acceptor.clone();
tokio::spawn(async move {
let tls_stream = match tls_acceptor.accept(stream).await {
Ok(tls_stream) => tls_stream,
Err(e) => {
error!("TLS accept error: {}", e);
return;
}
};
if let Err(err) = http1::Builder::new()
.serve_connection(
TokioIo::new(tls_stream),
service_fn(move |req| {
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
}),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}
Ok(())
});
Ok(handle)
}

// Modify existing listen_tcp_with_tls to use the new helper

pub async fn listen_tcp(
port: u16,
db_host: String,
Expand Down Expand Up @@ -68,42 +132,12 @@ pub async fn listen_tcp(
pub async fn listen_tcp_with_tls(
port: u16,
db_host: String,
timeout: Duration,
tls_config: (Vec<u8>, Vec<u8>),
) -> Result<(), Box<dyn std::error::Error>> {
let pool = DbPool::new(timeout, db_host).await?;
let ohttp = Arc::new(Mutex::new(init_ohttp()?));
let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);
let tls_acceptor = init_tls_acceptor(tls_config)?;
let listener = TcpListener::bind(bind_addr).await?;
while let Ok((stream, _)) = listener.accept().await {
let pool = pool.clone();
let ohttp = ohttp.clone();
let tls_acceptor = tls_acceptor.clone();
tokio::spawn(async move {
let tls_stream = match tls_acceptor.accept(stream).await {
Ok(tls_stream) => tls_stream,
Err(e) => {
error!("TLS accept error: {}", e);
return;
}
};
if let Err(err) = http1::Builder::new()
.serve_connection(
TokioIo::new(tls_stream),
service_fn(move |req| {
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
}),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}

Ok(())
timeout: Duration,
cert_key: (Vec<u8>, Vec<u8>),
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
let addr = format!("0.0.0.0:{}", port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await
}

#[cfg(feature = "danger-local-https")]
Expand Down
51 changes: 37 additions & 14 deletions payjoin/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod integration {
use url::Url;

type BoxError = Box<dyn std::error::Error + 'static>;
type BoxSendSyncError = Box<dyn std::error::Error + Send + Sync>;

static INIT_TRACING: OnceCell<()> = OnceCell::new();
static EXAMPLE_URL: Lazy<Url> =
Expand Down Expand Up @@ -192,15 +193,28 @@ mod integration {

#[tokio::test]
async fn test_bad_ohttp_keys() {
dbg!("Starting test_bad_ohttp_keys");
let bad_ohttp_keys =
OhttpKeys::from_str("AQO6SMScPUqSo60A7MY6Ak2hDO0CGAxz7BLYp60syRu0gw")
.expect("Invalid OhttpKeys");

let (cert, key) = local_cert_key();
let port = find_free_port();
dbg!("G");
let (port, directory_future) = init_directory((cert.clone(), key))
.await
.expect("Failed to init directory");
println!("Directory server started on port IN TEST FN {}", port);
let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap();

// Spawn the directory server task
let directory_task = tokio::spawn(async move {
if let Err(e) = directory_future.await {
eprintln!("Directory server error: {:?}", e);
}
});

tokio::select!(
_ = init_directory(port, (cert.clone(), key)) => panic!("Directory server is long running"),
_ = directory_task => panic!("Directory server is long running"),
res = try_request_with_bad_keys(directory, bad_ohttp_keys, cert) => {
assert_eq!(
res.unwrap().headers().get("content-type").unwrap(),
Expand All @@ -214,8 +228,10 @@ mod integration {
bad_ohttp_keys: OhttpKeys,
cert_der: Vec<u8>,
) -> Result<Response, Error> {
println!("Trying request with bad keys");
let agent = Arc::new(http_agent(cert_der.clone()).unwrap());
wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap();
println!("Service ready");
let mock_ohttp_relay = directory.clone(); // pass through to directory
let mock_address = Address::from_str("tb1q6d3a2w975yny0asuvd9a67ner4nks58ff0q8g4")
.unwrap()
Expand All @@ -234,12 +250,14 @@ mod integration {
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let (directory_port, directory_handle) = init_directory((cert.clone(), key))
.await
.expect("Failed to init directory");
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
_ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay is long running"),
_ = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server is long running"),
_ = directory_handle => panic!("Directory server is long running"),
res = do_expiration_tests(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);

Expand Down Expand Up @@ -303,12 +321,14 @@ mod integration {
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let (directory_port, directory_future) = init_directory((cert.clone(), key))
.await
.expect("Failed to init directory");
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
_ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay is long running"),
_ = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server is long running"),
_ = directory_future => panic!("Directory server is long running"),
res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);

Expand Down Expand Up @@ -435,12 +455,14 @@ mod integration {
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let (directory_port, directory_future) = init_directory((cert.clone(), key))
.await
.expect("Failed to init directory");
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
_ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay is long running"),
_ = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server is long running"),
_ = directory_future => panic!("Directory server is long running"),
res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);

Expand Down Expand Up @@ -653,12 +675,14 @@ mod integration {
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let (directory_port, directory_future) = init_directory((cert.clone(), key))
.await
.expect("Failed to init directory");
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
_ = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay is long running"),
_ = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server is long running"),
_ = directory_future => panic!("Directory server is long running"),
res = do_v1_to_v2(ohttp_relay, directory, cert) => assert!(res.is_ok()),
);

Expand Down Expand Up @@ -780,15 +804,14 @@ mod integration {
}

async fn init_directory(
port: u16,
local_cert_key: (Vec<u8>, Vec<u8>),
) -> Result<(), BoxError> {
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxSendSyncError>>), BoxSendSyncError> {
let docker: Cli = Cli::default();
let timeout = Duration::from_secs(2);
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
println!("Database running on {}", db.get_host_port_ipv4(6379));
payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await
payjoin_directory::listen_tcp_with_tls_on_free_port(db_host, timeout, local_cert_key).await
}

// generates or gets a DER encoded localhost cert and key.
Expand Down Expand Up @@ -929,7 +952,7 @@ mod integration {
while start.elapsed() < *TESTS_TIMEOUT {
let request_result =
agent.get(health_url.as_str()).send().await.map_err(|_| "Bad request")?;

println!("awaiting Service ready: {:?}", request_result.status());
match request_result.status() {
StatusCode::OK => return Ok(()),
StatusCode::NOT_FOUND => return Err("Endpoint not found"),
Expand Down

0 comments on commit fe9d146

Please sign in to comment.