From 5acb0516437ced396375b6c574811ffae0dc5c16 Mon Sep 17 00:00:00 2001
From: Orbital <orbitalturtle@protonmail.com>
Date: Fri, 16 Aug 2024 16:56:19 -0500
Subject: [PATCH] offers: don't choose unadvertised node as introduction node

---
 src/lnd.rs                 |   9 +-
 src/lndk_offers.rs         | 181 ++++++++++++++++++++++++++++++++-----
 tests/common/mod.rs        |   3 +-
 tests/integration_tests.rs |  74 ++++++++++++++-
 4 files changed, 235 insertions(+), 32 deletions(-)

diff --git a/src/lnd.rs b/src/lnd.rs
index db537ca4..4770ff11 100644
--- a/src/lnd.rs
+++ b/src/lnd.rs
@@ -20,8 +20,7 @@ use std::fmt::Display;
 use std::path::PathBuf;
 use std::{fmt, fs};
 use tonic_lnd::lnrpc::{
-    GetInfoResponse, HtlcAttempt, LightningNode, ListPeersResponse, Payment, QueryRoutesResponse,
-    Route,
+    GetInfoResponse, HtlcAttempt, ListPeersResponse, NodeInfo, Payment, QueryRoutesResponse, Route,
 };
 use tonic_lnd::signrpc::{KeyDescriptor, KeyLocator};
 use tonic_lnd::tonic::Status;
@@ -404,7 +403,11 @@ pub trait MessageSigner {
 pub trait PeerConnector {
     async fn list_peers(&mut self) -> Result<ListPeersResponse, Status>;
     async fn connect_peer(&mut self, node_id: String, addr: String) -> Result<(), Status>;
-    async fn get_node_info(&mut self, pub_key: String) -> Result<Option<LightningNode>, Status>;
+    async fn get_node_info(
+        &mut self,
+        pub_key: String,
+        include_channels: bool,
+    ) -> Result<NodeInfo, Status>;
 }
 
 /// InvoicePayer provides a layer of abstraction over the LND API for paying for a BOLT 12 invoice.
diff --git a/src/lndk_offers.rs b/src/lndk_offers.rs
index b8283382..8934180f 100644
--- a/src/lndk_offers.rs
+++ b/src/lndk_offers.rs
@@ -25,8 +25,8 @@ use std::fmt::Display;
 use std::str::FromStr;
 use tokio::task;
 use tonic_lnd::lnrpc::{
-    ChanInfoRequest, GetInfoRequest, HtlcAttempt, LightningNode, ListPeersRequest,
-    ListPeersResponse, Payment, QueryRoutesResponse, Route,
+    ChanInfoRequest, GetInfoRequest, HtlcAttempt, ListPeersRequest, ListPeersResponse, NodeInfo,
+    Payment, QueryRoutesResponse, Route,
 };
 use tonic_lnd::routerrpc::TrackPaymentRequest;
 use tonic_lnd::signrpc::{KeyDescriptor, KeyLocator, SignMessageReq};
@@ -251,10 +251,14 @@ impl OfferHandler {
         Ok((invoice_request, payment_id, validated_amount))
     }
 
-    /// create_reply_path creates a blinded path to provide to the offer maker when requesting an
+    /// create_reply_path creates a blinded path to provide to the offer node when requesting an
     /// invoice so they know where to send the invoice back to. We try to find a peer that we're
-    /// connected to with onion messaging support that we can use to form a blinded path,
-    /// otherwise we creae a blinded path directly to ourselves.
+    /// connected to with the necessary requirements to form a blinded path. The peer needs two
+    /// things:
+    /// 1) Onion messaging support.
+    /// 2) To be an advertised node with at least one public channel.
+    ///
+    /// Otherwise we create a blinded path directly to ourselves.
     pub async fn create_reply_path(
         &self,
         mut connector: impl PeerConnector + std::marker::Send + 'static,
@@ -271,7 +275,18 @@ impl OfferHandler {
             let pubkey = PublicKey::from_str(&peer.pub_key).unwrap();
             let onion_support = features_support_onion_messages(&peer.features);
             if onion_support {
+                // We also need to check that the candidate introduction node is actually an
+                // advertised node with at least one public channel.
+                match connector.get_node_info(peer.pub_key, true).await {
+                    Ok(node) => {
+                        if node.channels.is_empty() {
+                            continue;
+                        }
+                    }
+                    Err(_) => continue,
+                };
                 intro_node = Some(pubkey);
+                break;
             }
         }
 
@@ -426,11 +441,11 @@ pub async fn connect_to_peer(
     }
 
     let node = connector
-        .get_node_info(node_id_str.clone())
+        .get_node_info(node_id_str.clone(), false)
         .await
         .map_err(OfferError::PeerConnectError)?;
 
-    let node = match node {
+    let node = match node.node {
         Some(node) => node,
         None => return Err(OfferError::NodeAddressNotFound),
     };
@@ -477,16 +492,20 @@ impl PeerConnector for Client {
             .map(|_| ())
     }
 
-    async fn get_node_info(&mut self, pub_key: String) -> Result<Option<LightningNode>, Status> {
+    async fn get_node_info(
+        &mut self,
+        pub_key: String,
+        include_channels: bool,
+    ) -> Result<NodeInfo, Status> {
         let req = tonic_lnd::lnrpc::NodeInfoRequest {
             pub_key,
-            include_channels: false,
+            include_channels,
         };
 
         self.lightning()
             .get_node_info(req)
             .await
-            .map(|resp| resp.into_inner().node)
+            .map(|resp| resp.into_inner())
     }
 }
 
@@ -693,10 +712,11 @@ mod tests {
     use lightning::offers::merkle::SignError;
     use lightning::offers::offer::{OfferBuilder, Quantity};
     use mockall::mock;
+    use mockall::predicate::eq;
     use std::collections::HashMap;
     use std::str::FromStr;
     use std::time::{Duration, SystemTime};
-    use tonic_lnd::lnrpc::{NodeAddress, Payment};
+    use tonic_lnd::lnrpc::{ChannelEdge, LightningNode, NodeAddress, Payment};
 
     fn get_offer() -> String {
         "lno1qgsqvgnwgcg35z6ee2h3yczraddm72xrfua9uve2rlrm9deu7xyfzrcgqgn3qzsyvfkx26qkyypvr5hfx60h9w9k934lt8s2n6zc0wwtgqlulw7dythr83dqx8tzumg".to_string()
@@ -773,7 +793,7 @@ mod tests {
          #[async_trait]
          impl PeerConnector for TestPeerConnector {
              async fn list_peers(&mut self) -> Result<ListPeersResponse, Status>;
-             async fn get_node_info(&mut self, pub_key: String) -> Result<Option<LightningNode>, Status>;
+             async fn get_node_info(&mut self, pub_key: String, include_channels: bool) -> Result<NodeInfo, Status>;
              async fn connect_peer(&mut self, node_id: String, addr: String) -> Result<(), Status>;
          }
     }
@@ -923,17 +943,22 @@ mod tests {
             })
         });
 
-        connector_mock.expect_get_node_info().returning(|_| {
+        connector_mock.expect_get_node_info().returning(|_, _| {
             let node_addr = NodeAddress {
                 network: String::from("regtest"),
                 addr: String::from("127.0.0.1"),
             };
-            let node = LightningNode {
+            let node = Some(LightningNode {
                 addresses: vec![node_addr],
                 ..Default::default()
+            });
+
+            let node_info = NodeInfo {
+                node,
+                ..Default::default()
             };
 
-            Ok(Some(node))
+            Ok(node_info)
         });
 
         connector_mock
@@ -959,17 +984,20 @@ mod tests {
             })
         });
 
-        connector_mock.expect_get_node_info().returning(|_| {
+        connector_mock.expect_get_node_info().returning(|_, _| {
             let node_addr = NodeAddress {
                 network: String::from("regtest"),
                 addr: String::from("127.0.0.1"),
             };
-            let node = LightningNode {
+            let node = Some(LightningNode {
                 addresses: vec![node_addr],
                 ..Default::default()
-            };
+            });
 
-            Ok(Some(node))
+            Ok(NodeInfo {
+                node,
+                ..Default::default()
+            })
         });
 
         let pubkey = PublicKey::from_str(&get_pubkeys()[0]).unwrap();
@@ -985,17 +1013,20 @@ mod tests {
             })
         });
 
-        connector_mock.expect_get_node_info().returning(|_| {
+        connector_mock.expect_get_node_info().returning(|_, _| {
             let node_addr = NodeAddress {
                 network: String::from("regtest"),
                 addr: String::from("127.0.0.1"),
             };
-            let node = LightningNode {
+            let node = Some(LightningNode {
                 addresses: vec![node_addr],
                 ..Default::default()
-            };
+            });
 
-            Ok(Some(node))
+            Ok(NodeInfo {
+                node,
+                ..Default::default()
+            })
         });
 
         connector_mock
@@ -1025,6 +1056,17 @@ mod tests {
             Ok(ListPeersResponse { peers: vec![peer] })
         });
 
+        connector_mock.expect_get_node_info().returning(|_, _| {
+            let node = Some(LightningNode {
+                ..Default::default()
+            });
+
+            Ok(NodeInfo {
+                node,
+                ..Default::default()
+            })
+        });
+
         let receiver_node_id = PublicKey::from_str(&get_pubkeys()[0]).unwrap();
         let handler = OfferHandler::default();
         assert!(handler
@@ -1066,6 +1108,99 @@ mod tests {
             .is_err())
     }
 
+    #[tokio::test]
+    async fn test_create_reply_path_not_advertised() {
+        // First lets test that if we're only connected to one peer. It has onion support, but the
+        // node isn't advertised, meaning it has no public channels. This should return
+        // a blinded path with only one hop.
+        let mut connector_mock = MockTestPeerConnector::new();
+        connector_mock.expect_list_peers().returning(|| {
+            let feature = tonic_lnd::lnrpc::Feature {
+                ..Default::default()
+            };
+            let mut feature_entry = HashMap::new();
+            feature_entry.insert(38, feature);
+
+            let peer = tonic_lnd::lnrpc::Peer {
+                pub_key: get_pubkeys()[0].clone(),
+                features: feature_entry,
+                ..Default::default()
+            };
+            Ok(ListPeersResponse { peers: vec![peer] })
+        });
+
+        connector_mock
+            .expect_get_node_info()
+            .returning(|_, _| Err(Status::not_found("node was not found")));
+
+        let receiver_node_id = PublicKey::from_str(&get_pubkeys()[0]).unwrap();
+        let handler = OfferHandler::default();
+        let resp = handler
+            .create_reply_path(connector_mock, receiver_node_id)
+            .await;
+        assert!(resp.is_ok());
+        assert!(resp.unwrap().blinded_hops.len() == 1);
+
+        // Now let's test that we have two peers that both have onion support feature flags set.
+        // One isn't advertised (i.e. it has no public channels). But the second is. This
+        // should succeed.
+        let mut connector_mock = MockTestPeerConnector::new();
+        connector_mock.expect_list_peers().returning(|| {
+            let feature = tonic_lnd::lnrpc::Feature {
+                ..Default::default()
+            };
+            let mut feature_entry = HashMap::new();
+            feature_entry.insert(38, feature);
+
+            let keys = get_pubkeys();
+
+            let peer1 = tonic_lnd::lnrpc::Peer {
+                pub_key: keys[0].clone(),
+                features: feature_entry.clone(),
+                ..Default::default()
+            };
+            let peer2 = tonic_lnd::lnrpc::Peer {
+                pub_key: keys[1].clone(),
+                features: feature_entry,
+                ..Default::default()
+            };
+            Ok(ListPeersResponse {
+                peers: vec![peer1, peer2],
+            })
+        });
+
+        let keys = get_pubkeys();
+        connector_mock
+            .expect_get_node_info()
+            .with(eq(keys[0].clone()), eq(true))
+            .returning(|_, _| Err(Status::not_found("node was not found")));
+
+        connector_mock
+            .expect_get_node_info()
+            .with(eq(keys[1].clone()), eq(true))
+            .returning(|_, _| {
+                let node = Some(LightningNode {
+                    ..Default::default()
+                });
+
+                Ok(NodeInfo {
+                    node,
+                    channels: vec![ChannelEdge {
+                        ..Default::default()
+                    }],
+                    ..Default::default()
+                })
+            });
+
+        let receiver_node_id = PublicKey::from_str(&get_pubkeys()[0]).unwrap();
+        let handler = OfferHandler::default();
+        let resp = handler
+            .create_reply_path(connector_mock, receiver_node_id)
+            .await;
+        assert!(resp.is_ok());
+        assert!(resp.unwrap().blinded_hops.len() == 2);
+    }
+
     #[tokio::test]
     async fn test_send_payment() {
         let mut payer_mock = MockTestInvoicePayer::new();
diff --git a/tests/common/mod.rs b/tests/common/mod.rs
index 37ea513c..18479901 100644
--- a/tests/common/mod.rs
+++ b/tests/common/mod.rs
@@ -82,6 +82,7 @@ pub async fn setup_test_infrastructure(
 pub async fn connect_network(
     ldk1: &LdkNode,
     ldk2: &LdkNode,
+    announce_channel: bool,
     lnd: &mut LndNode,
     bitcoind: &BitcoindNode,
 ) -> (PublicKey, PublicKey, PublicKey) {
@@ -136,7 +137,7 @@ pub async fn connect_network(
         SocketAddr::from_str(&lnd_network_addr).unwrap(),
         200000,
         10000000,
-        true,
+        announce_channel,
     )
     .await
     .unwrap();
diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs
index a60f9f6c..d9bdd8c1 100644
--- a/tests/integration_tests.rs
+++ b/tests/integration_tests.rs
@@ -7,7 +7,7 @@ use bitcoin::Network;
 use bitcoincore_rpc::bitcoin::Network as RpcNetwork;
 use bitcoincore_rpc::RpcApi;
 use ldk_sample::node_api::Node as LdkNode;
-use lightning::blinded_path::BlindedPath;
+use lightning::blinded_path::{BlindedPath, IntroductionNode};
 use lightning::offers::offer::Quantity;
 use lightning::onion_message::messenger::Destination;
 use lndk::lnd::validate_lnd_creds;
@@ -309,7 +309,7 @@ async fn test_lndk_pay_offer() {
         common::setup_test_infrastructure(test_name).await;
 
     let (ldk1_pubkey, ldk2_pubkey, lnd_pubkey) =
-        common::connect_network(&ldk1, &ldk2, &mut lnd, &bitcoind).await;
+        common::connect_network(&ldk1, &ldk2, true, &mut lnd, &bitcoind).await;
 
     let path_pubkeys = vec![ldk2_pubkey, ldk1_pubkey];
     let expiration = SystemTime::now() + Duration::from_secs(24 * 60 * 60);
@@ -366,7 +366,7 @@ async fn test_lndk_pay_offer_concurrently() {
         common::setup_test_infrastructure(test_name).await;
 
     let (ldk1_pubkey, ldk2_pubkey, lnd_pubkey) =
-        common::connect_network(&ldk1, &ldk2, &mut lnd, &bitcoind).await;
+        common::connect_network(&ldk1, &ldk2, true, &mut lnd, &bitcoind).await;
 
     let path_pubkeys = vec![ldk2_pubkey, ldk1_pubkey];
     let expiration = SystemTime::now() + Duration::from_secs(24 * 60 * 60);
@@ -424,7 +424,7 @@ async fn test_lndk_pay_multiple_offers_concurrently() {
         common::setup_test_infrastructure(test_name).await;
 
     let (ldk1_pubkey, ldk2_pubkey, lnd_pubkey) =
-        common::connect_network(&ldk1, &ldk2, &mut lnd, &bitcoind).await;
+        common::connect_network(&ldk1, &ldk2, true, &mut lnd, &bitcoind).await;
 
     let path_pubkeys = &vec![ldk2_pubkey, ldk1_pubkey];
     let reply_path = &vec![ldk2_pubkey, lnd_pubkey];
@@ -464,7 +464,7 @@ async fn test_transient_keys() {
         common::setup_test_infrastructure(test_name).await;
 
     let (ldk1_pubkey, ldk2_pubkey, _) =
-        common::connect_network(&ldk1, &ldk2, &mut lnd, &bitcoind).await;
+        common::connect_network(&ldk1, &ldk2, true, &mut lnd, &bitcoind).await;
 
     let path_pubkeys = vec![ldk2_pubkey, ldk1_pubkey];
     let expiration = SystemTime::now() + Duration::from_secs(24 * 60 * 60);
@@ -513,3 +513,67 @@ async fn test_transient_keys() {
         }
     }
 }
+
+#[tokio::test(flavor = "multi_thread")]
+// We test that when creating a reply path for an offer node to send an invoice to, we don't
+// use a node that we're connected to as the introduction node if it's an unadvertised node that
+// is only connected by private channels.
+async fn test_reply_path_unannounced_peers() {
+    let test_name = "unannounced_peers";
+    let (bitcoind, mut lnd, ldk1, ldk2, lndk_dir) =
+        common::setup_test_infrastructure(test_name).await;
+
+    let (_, _, lnd_pubkey) =
+        common::connect_network(&ldk1, &ldk2, false, &mut lnd, &bitcoind).await;
+
+    let (_, handler, _, shutdown) =
+        common::setup_lndk(&lnd.cert_path, &lnd.macaroon_path, lnd.address, lndk_dir).await;
+
+    // In the small network we produced above, the lnd node is only connected to ldk2, which has a
+    // private channel and as such, is an unadvertised node. Because of that, create_reply_path
+    // should not use ldk2 as an introduction node and should return a reply path directly to
+    // itself.
+    let reply_path = handler
+        .create_reply_path(lnd.client.clone().unwrap(), lnd_pubkey)
+        .await;
+    assert!(reply_path.is_ok());
+    assert_eq!(reply_path.unwrap().blinded_hops.len(), 1);
+
+    shutdown.trigger();
+    ldk1.stop().await;
+    ldk2.stop().await;
+}
+
+#[tokio::test(flavor = "multi_thread")]
+// We test that when creating a reply path for an offer node to send an invoice to, we successfully
+// use a node that we're connected to as the introduction node *if* it's an advertised node with
+// public channels.
+async fn test_reply_path_announced_peers() {
+    let test_name = "announced_peers";
+    let (bitcoind, mut lnd, ldk1, ldk2, lndk_dir) =
+        common::setup_test_infrastructure(test_name).await;
+
+    let (_, ldk2_pubkey, lnd_pubkey) =
+        common::connect_network(&ldk1, &ldk2, true, &mut lnd, &bitcoind).await;
+
+    let (_, handler, _, shutdown) =
+        common::setup_lndk(&lnd.cert_path, &lnd.macaroon_path, lnd.address, lndk_dir).await;
+
+    // In the small network we produced above, the lnd node is only connected to ldk2, which has a
+    // public channel and as such, is indeed an advertised node. Because of this, we make sure
+    // create_reply_path produces a path of length two with ldk2 as the introduction node, as we
+    // expected.
+    let reply_path = handler
+        .create_reply_path(lnd.client.clone().unwrap(), lnd_pubkey)
+        .await;
+    assert!(reply_path.is_ok());
+    assert_eq!(reply_path.as_ref().unwrap().blinded_hops.len(), 2);
+    assert_eq!(
+        reply_path.as_ref().unwrap().introduction_node,
+        IntroductionNode::NodeId(ldk2_pubkey)
+    );
+
+    shutdown.trigger();
+    ldk1.stop().await;
+    ldk2.stop().await;
+}