diff --git a/src/lib.rs b/src/lib.rs index e7f9cd79..66550086 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ mod rate_limit; use crate::lnd::{ features_support_onion_messages, get_lnd_client, string_to_network, LndCfg, LndNodeSigner, }; -use crate::lndk_offers::OfferError; +use crate::lndk_offers::{OfferError, PayInvoiceParams}; use crate::onion_messenger::MessengerUtilities; use bitcoin::network::constants::Network; use bitcoin::secp256k1::{Error as Secp256k1Error, PublicKey, Secp256k1}; @@ -16,6 +16,7 @@ use home::home_dir; use lightning::blinded_path::BlindedPath; use lightning::ln::inbound_payment::ExpandedKey; use lightning::ln::peer_handler::IgnoringMessageHandler; +use lightning::offers::invoice::Bolt12Invoice; use lightning::offers::invoice_error::InvoiceError; use lightning::offers::offer::Offer; use lightning::onion_message::messenger::{ @@ -33,6 +34,7 @@ use std::collections::HashMap; use std::str::FromStr; use std::sync::{Mutex, Once}; use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::time::{sleep, Duration}; use tonic_lnd::lnrpc::GetInfoRequest; use tonic_lnd::Client; use triggered::{Listener, Trigger}; @@ -189,6 +191,7 @@ enum OfferState { pub struct OfferHandler { active_offers: Mutex>, + active_invoices: Mutex>, pending_messages: Mutex>>, pub messenger_utils: MessengerUtilities, expanded_key: ExpandedKey, @@ -214,6 +217,7 @@ impl OfferHandler { OfferHandler { active_offers: Mutex::new(HashMap::new()), + active_invoices: Mutex::new(Vec::new()), pending_messages: Mutex::new(Vec::new()), messenger_utils, expanded_key, @@ -226,8 +230,42 @@ impl OfferHandler { cfg: PayOfferParams, started: Receiver, ) -> Result<(), OfferError> { - self.send_invoice_request(cfg, started).await?; - Ok(()) + let client_clone = cfg.client.clone(); + let offer_id = cfg.offer.clone().to_string(); + let validated_amount = self.send_invoice_request(cfg, started).await?; + + let invoice = self.wait_for_invoice().await; + { + let mut active_offers = self.active_offers.lock().unwrap(); + active_offers.insert(offer_id.clone(), OfferState::InvoiceReceived); + } + + let payment_hash = invoice.payment_hash(); + let path_info = invoice.payment_paths()[0].clone(); + + let params = PayInvoiceParams { + path: path_info.1, + cltv_expiry_delta: path_info.0.cltv_expiry_delta, + fee_base_msat: path_info.0.fee_base_msat, + payment_hash: payment_hash.0, + msats: validated_amount, + offer_id, + }; + + self.pay_invoice(client_clone, params).await + } + + /// wait_for_invoice waits for the offer creator to respond with an invoice. + pub async fn wait_for_invoice(&self) -> Bolt12Invoice { + loop { + { + let mut active_invoices = self.active_invoices.lock().unwrap(); + if active_invoices.len() == 1 { + return active_invoices.pop().unwrap(); + } + } + sleep(Duration::from_secs(2)).await; + } } } @@ -252,13 +290,20 @@ impl OffersMessageHandler for OfferHandler { // returned payment id below to check if we already processed an invoice for // this payment. Right now it's safe to let this be because we won't try to pay // a second invoice (if it comes through). - Ok(_payment_id) => Some(OffersMessage::Invoice(invoice)), + Ok(_payment_id) => { + let mut active_invoices = self.active_invoices.lock().unwrap(); + active_invoices.push(invoice.clone()); + Some(OffersMessage::Invoice(invoice)) + } Err(()) => Some(OffersMessage::InvoiceError(InvoiceError::from_string( String::from("invoice verification failure"), ))), } } - OffersMessage::InvoiceError(_error) => None, + OffersMessage::InvoiceError(error) => { + log::error!("Invoice error received: {}", error); + None + } } } diff --git a/src/lndk_offers.rs b/src/lndk_offers.rs index bcc156cf..ad749e44 100644 --- a/src/lndk_offers.rs +++ b/src/lndk_offers.rs @@ -98,7 +98,7 @@ impl OfferHandler { &self, mut cfg: PayOfferParams, mut started: Receiver, - ) -> Result<(), OfferError> { + ) -> Result> { // Wait for onion messenger to give us the signal that it's ready. Once the onion messenger drops // the channel sender, recv will return None and we'll stop blocking here. while (started.recv().await).is_some() { @@ -154,7 +154,7 @@ impl OfferHandler { pending_messages.push(pending_message); std::mem::drop(pending_messages); - Ok(()) + Ok(validated_amount) } // create_invoice_request builds and signs an invoice request, the first step in the BOLT 12 process of paying an offer. @@ -259,30 +259,45 @@ impl OfferHandler { pub(crate) async fn pay_invoice( &self, mut payer: impl InvoicePayer + std::marker::Send + 'static, - path: BlindedPath, - cltv_expiry_delta: u16, - fee_base_msat: u32, - payment_hash: [u8; 32], - msats: u64, + params: PayInvoiceParams, ) -> Result<(), OfferError> { let resp = payer - .query_routes(path, cltv_expiry_delta, fee_base_msat, msats) + .query_routes( + params.path, + params.cltv_expiry_delta, + params.fee_base_msat, + params.msats, + ) .await .map_err(OfferError::RouteFailure)?; let _ = payer - .send_to_route(payment_hash, resp.routes[0].clone()) + .send_to_route(params.payment_hash, resp.routes[0].clone()) .await .map_err(OfferError::RouteFailure)?; - // The payment is still in flight. We'll track it until it settles. + { + let mut active_offers = self.active_offers.lock().unwrap(); + active_offers.insert(params.offer_id, OfferState::InvoicePaymentDispatched); + } + + // We'll track the payment until it settles. payer - .track_payment(payment_hash) + .track_payment(params.payment_hash) .await .map_err(|_| OfferError::PaymentFailure) } } +pub struct PayInvoiceParams { + pub path: BlindedPath, + pub cltv_expiry_delta: u16, + pub fee_base_msat: u32, + pub payment_hash: [u8; 32], + pub msats: u64, + pub offer_id: String, +} + // Checks that the user-provided amount matches the offer. pub async fn validate_amount( offer: &Offer, @@ -922,10 +937,15 @@ mod tests { let blinded_path = get_blinded_path(); let payment_hash = MessengerUtilities::new().get_secure_random_bytes(); let handler = OfferHandler::new(); - assert!(handler - .pay_invoice(payer_mock, blinded_path, 200, 1, payment_hash, 2000) - .await - .is_ok()); + let params = PayInvoiceParams { + path: blinded_path, + cltv_expiry_delta: 200, + fee_base_msat: 1, + payment_hash: payment_hash, + msats: 2000, + offer_id: get_offer(), + }; + assert!(handler.pay_invoice(payer_mock, params).await.is_ok()); } #[tokio::test] @@ -939,10 +959,15 @@ mod tests { let blinded_path = get_blinded_path(); let payment_hash = MessengerUtilities::new().get_secure_random_bytes(); let handler = OfferHandler::new(); - assert!(handler - .pay_invoice(payer_mock, blinded_path, 200, 1, payment_hash, 2000) - .await - .is_err()); + let params = PayInvoiceParams { + path: blinded_path, + cltv_expiry_delta: 200, + fee_base_msat: 1, + payment_hash: payment_hash, + msats: 2000, + offer_id: get_offer(), + }; + assert!(handler.pay_invoice(payer_mock, params).await.is_err()); } #[tokio::test] @@ -966,9 +991,14 @@ mod tests { let blinded_path = get_blinded_path(); let payment_hash = MessengerUtilities::new().get_secure_random_bytes(); let handler = OfferHandler::new(); - assert!(handler - .pay_invoice(payer_mock, blinded_path, 200, 1, payment_hash, 2000) - .await - .is_err()); + let params = PayInvoiceParams { + path: blinded_path, + cltv_expiry_delta: 200, + fee_base_msat: 1, + payment_hash: payment_hash, + msats: 2000, + offer_id: get_offer(), + }; + assert!(handler.pay_invoice(payer_mock, params).await.is_err()); } } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index b001ea5c..c9d7c630 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -11,6 +11,7 @@ use lightning::blinded_path::BlindedPath; use lightning::offers::offer::Quantity; use lndk::onion_messenger::MessengerUtilities; use lndk::{LifecycleSignals, PayOfferParams}; +use std::net::SocketAddr; use std::path::PathBuf; use std::str::FromStr; use std::time::SystemTime; @@ -295,3 +296,147 @@ async fn test_lndk_send_invoice_request() { } } } + +#[tokio::test(flavor = "multi_thread")] +// Here we test that we're able to fully pay an offer. +async fn test_lndk_pay_offer() { + let test_name = "lndk_pay_offer"; + let (bitcoind, mut lnd, ldk1, ldk2, lndk_dir) = + common::setup_test_infrastructure(test_name).await; + + // Here we'll produce a little network of channels: + // + // ldk1 <- ldk2 <- lnd + // + // ldk1 will be the offer creator, which will build a blinded route from ldk2 to ldk1. + let (pubkey, addr) = ldk1.get_node_info(); + let (pubkey_2, addr_2) = ldk2.get_node_info(); + let lnd_info = lnd.get_info().await; + let lnd_pubkey = PublicKey::from_str(&lnd_info.identity_pubkey).unwrap(); + + ldk1.connect_to_peer(pubkey_2, addr_2).await.unwrap(); + lnd.connect_to_peer(pubkey_2, addr_2).await; + + let ldk2_fund_addr = ldk2.bitcoind_client.get_new_address().await; + let lnd_fund_addr = lnd.new_address().await.address; + + // We need to convert funding addresses to the form that the bitcoincore_rpc library recognizes. + let ldk2_addr_string = ldk2_fund_addr.to_string(); + let ldk2_addr = bitcoind::bitcoincore_rpc::bitcoin::Address::from_str(&ldk2_addr_string) + .unwrap() + .require_network(RpcNetwork::Regtest) + .unwrap(); + let lnd_addr = bitcoind::bitcoincore_rpc::bitcoin::Address::from_str(&lnd_fund_addr) + .unwrap() + .require_network(RpcNetwork::Regtest) + .unwrap(); + let lnd_network_addr = lnd + .address + .replace("localhost", "127.0.0.1") + .replace("https://", ""); + + // Fund both of these nodes, open the channels, and synchronize the network. + bitcoind + .node + .client + .generate_to_address(6, &lnd_addr) + .unwrap(); + + lnd.wait_for_chain_sync().await; + + ldk2.open_channel(pubkey, addr, 200000, 0, false) + .await + .unwrap(); + + lnd.wait_for_graph_sync().await; + + ldk2.open_channel( + lnd_pubkey, + SocketAddr::from_str(&lnd_network_addr).unwrap(), + 200000, + 10000000, + true, + ) + .await + .unwrap(); + + lnd.wait_for_graph_sync().await; + + bitcoind + .node + .client + .generate_to_address(20, &ldk2_addr) + .unwrap(); + + lnd.wait_for_chain_sync().await; + + let path_pubkeys = vec![pubkey_2, pubkey]; + let expiration = SystemTime::now() + Duration::from_secs(24 * 60 * 60); + let offer = ldk1 + .create_offer( + &path_pubkeys, + Network::Regtest, + 20_000, + Quantity::One, + expiration, + ) + .await + .expect("should create offer"); + + let (shutdown, listener) = triggered::trigger(); + let lnd_cfg = lndk::lnd::LndCfg::new( + lnd.address.clone(), + PathBuf::from_str(&lnd.cert_path).unwrap(), + PathBuf::from_str(&lnd.macaroon_path).unwrap(), + ); + let (tx, rx): (Sender, Receiver) = mpsc::channel(1); + + let signals = LifecycleSignals { + shutdown: shutdown.clone(), + listener, + started: tx, + }; + + let lndk_cfg = lndk::Cfg { + lnd: lnd_cfg, + log_dir: Some( + lndk_dir + .join(format!("lndk-logs.txt")) + .to_str() + .unwrap() + .to_string(), + ), + signals, + }; + + let messenger_utils = MessengerUtilities::new(); + let client = lnd.client.clone().unwrap(); + let blinded_path = offer.paths()[0].clone(); + let secp_ctx = Secp256k1::new(); + let reply_path = + BlindedPath::new_for_message(&[pubkey_2, lnd_pubkey], &messenger_utils, &secp_ctx).unwrap(); + + // Make sure lndk successfully sends the invoice_request. + let handler = lndk::OfferHandler::new(); + let messenger = lndk::LndkOnionMessenger::new(handler); + let pay_cfg = PayOfferParams { + offer, + amount: Some(20_000), + network: Network::Regtest, + client: client.clone(), + blinded_path: blinded_path.clone(), + reply_path: Some(reply_path), + }; + select! { + val = messenger.run(lndk_cfg) => { + panic!("lndk should not have completed first {:?}", val); + }, + // We wait for ldk2 to receive the onion message. + res = messenger.offer_handler.pay_offer(pay_cfg, rx) => { + assert!(res.is_ok()); + shutdown.trigger(); + ldk1.stop().await; + ldk2.stop().await; + } + } +}