From 3bd831af90fe288faeeae106d4d332a0fe9cabe3 Mon Sep 17 00:00:00 2001
From: Orbital <orbitalturtle@protonmail.com>
Date: Wed, 10 Jul 2024 20:59:17 -0500
Subject: [PATCH] multi: move to derive_new_key for key gen

We should use a new key to sign each invoice request to improve sender privacy
and to ensure we can pay a particular CLN offer more than once.
---
 src/lnd.rs                 |  5 +--
 src/lndk_offers.rs         | 71 +++++++++++++++++++++++---------------
 tests/integration_tests.rs | 60 ++++++++++++++++++++++++++++++++
 3 files changed, 107 insertions(+), 29 deletions(-)

diff --git a/src/lnd.rs b/src/lnd.rs
index 124561a6..db537ca4 100644
--- a/src/lnd.rs
+++ b/src/lnd.rs
@@ -23,9 +23,10 @@ use tonic_lnd::lnrpc::{
     GetInfoResponse, HtlcAttempt, LightningNode, ListPeersResponse, Payment, QueryRoutesResponse,
     Route,
 };
-use tonic_lnd::signrpc::KeyLocator;
+use tonic_lnd::signrpc::{KeyDescriptor, KeyLocator};
 use tonic_lnd::tonic::Status;
 use tonic_lnd::verrpc::Version;
+use tonic_lnd::walletrpc::KeyReq;
 use tonic_lnd::{Client, ConnectError};
 
 const ONION_MESSAGES_REQUIRED: u32 = 38;
@@ -384,7 +385,7 @@ pub fn string_to_network(network_str: &str) -> Result<Network, NetworkParseError
 /// MessageSigner provides a layer of abstraction over the LND API for message signing.
 #[async_trait]
 pub trait MessageSigner {
-    async fn derive_key(&mut self, key_loc: KeyLocator) -> Result<Vec<u8>, Status>;
+    async fn derive_next_key(&mut self, key_loc: KeyReq) -> Result<KeyDescriptor, Status>;
     async fn sign_message(
         &mut self,
         key_loc: KeyLocator,
diff --git a/src/lndk_offers.rs b/src/lndk_offers.rs
index c6d63661..2a24cfdf 100644
--- a/src/lndk_offers.rs
+++ b/src/lndk_offers.rs
@@ -29,8 +29,9 @@ use tonic_lnd::lnrpc::{
     ListPeersResponse, Payment, QueryRoutesResponse, Route,
 };
 use tonic_lnd::routerrpc::TrackPaymentRequest;
-use tonic_lnd::signrpc::{KeyLocator, SignMessageReq};
+use tonic_lnd::signrpc::{KeyDescriptor, KeyLocator, SignMessageReq};
 use tonic_lnd::tonic::Status;
+use tonic_lnd::walletrpc::KeyReq;
 use tonic_lnd::Client;
 
 #[derive(Debug)]
@@ -183,19 +184,20 @@ impl OfferHandler {
     ) -> Result<(InvoiceRequest, PaymentId, u64), OfferError> {
         let validated_amount = validate_amount(offer.amount(), msats).await?;
 
-        // We use KeyFamily KeyFamilyNodeKey (6) to derive a key to represent our node id. See:
-        // https://github.com/lightningnetwork/lnd/blob/a3f8011ed695f6204ec6a13ad5c2a67ac542b109/keychain/derivation.go#L103
-        let key_loc = KeyLocator {
-            key_family: 6,
-            key_index: 1,
+        // We use KeyFamily KeyFamilyNodeKey (3) to derive a key. For better privacy, the key
+        // shouldn't correspond to our node id.
+        // https://github.com/lightningnetwork/lnd/blob/a3f8011ed695f6204ec6a13ad5c2a67ac542b109/keychain/derivation.go#L86
+        let key_loc = KeyReq {
+            key_family: 3,
+            ..Default::default()
         };
 
-        let pubkey_bytes = signer
-            .derive_key(key_loc.clone())
+        let key_descriptor = signer
+            .derive_next_key(key_loc.clone())
             .await
             .map_err(OfferError::DeriveKeyFailure)?;
-        let pubkey =
-            PublicKey::from_slice(&pubkey_bytes).expect("failed to deserialize public key");
+        let pubkey = PublicKey::from_slice(&key_descriptor.raw_key_bytes)
+            .expect("failed to deserialize public key");
 
         // Generate a new payment id for this payment.
         let payment_id = PaymentId(self.messenger_utils.get_secure_random_bytes());
@@ -227,10 +229,11 @@ impl OfferHandler {
         // To create a valid invoice request, we also need to sign it. This is spawned in a blocking
         // task because we need to call block_on on sign_message so that sign_closure can be a
         // synchronous closure.
-        let invoice_request =
-            task::spawn_blocking(move || signer.sign_uir(key_loc, unsigned_invoice_req))
-                .await
-                .unwrap()?;
+        let invoice_request = task::spawn_blocking(move || {
+            signer.sign_uir(key_descriptor.key_loc.unwrap(), unsigned_invoice_req)
+        })
+        .await
+        .unwrap()?;
 
         {
             let mut active_payments = self.active_payments.lock().unwrap();
@@ -489,9 +492,9 @@ impl PeerConnector for Client {
 
 #[async_trait]
 impl MessageSigner for Client {
-    async fn derive_key(&mut self, key_loc: KeyLocator) -> Result<Vec<u8>, Status> {
-        match self.wallet().derive_key(key_loc).await {
-            Ok(resp) => Ok(resp.into_inner().raw_key_bytes),
+    async fn derive_next_key(&mut self, key_req: KeyReq) -> Result<KeyDescriptor, Status> {
+        match self.wallet().derive_next_key(key_req).await {
+            Ok(resp) => Ok(resp.into_inner()),
             Err(e) => Err(e),
         }
     }
@@ -754,7 +757,7 @@ mod tests {
 
          #[async_trait]
          impl MessageSigner for TestBolt12Signer {
-             async fn derive_key(&mut self, key_loc: KeyLocator) -> Result<Vec<u8>, Status>;
+             async fn derive_next_key(&mut self, key_req: KeyReq) -> Result<KeyDescriptor, Status>;
              async fn sign_message(&mut self, key_loc: KeyLocator, merkle_hash: Hash, tag: String) -> Result<Vec<u8>, Status>;
              fn sign_uir(&mut self, key_loc: KeyLocator, unsigned_invoice_req: UnsignedInvoiceRequest) -> Result<InvoiceRequest, OfferError>;
          }
@@ -786,9 +789,17 @@ mod tests {
     async fn test_request_invoice() {
         let mut signer_mock = MockTestBolt12Signer::new();
 
-        signer_mock.expect_derive_key().returning(|_| {
-            let pubkey = PublicKey::from_str(&get_pubkey()).unwrap();
-            Ok(pubkey.serialize().to_vec())
+        signer_mock.expect_derive_next_key().returning(|_| {
+            Ok(KeyDescriptor {
+                raw_key_bytes: PublicKey::from_str(&get_pubkey())
+                    .unwrap()
+                    .serialize()
+                    .to_vec(),
+                key_loc: Some(KeyLocator {
+                    key_family: 3,
+                    ..Default::default()
+                }),
+            })
         });
 
         let offer = decode(get_offer()).unwrap();
@@ -821,7 +832,7 @@ mod tests {
         let mut signer_mock = MockTestBolt12Signer::new();
 
         signer_mock
-            .expect_derive_key()
+            .expect_derive_next_key()
             .returning(|_| Err(Status::unknown("error testing")));
 
         signer_mock
@@ -846,11 +857,17 @@ mod tests {
     async fn test_request_invoice_signer_error() {
         let mut signer_mock = MockTestBolt12Signer::new();
 
-        signer_mock.expect_derive_key().returning(|_| {
-            Ok(PublicKey::from_str(&get_pubkey())
-                .unwrap()
-                .serialize()
-                .to_vec())
+        signer_mock.expect_derive_next_key().returning(|_| {
+            Ok(KeyDescriptor {
+                raw_key_bytes: PublicKey::from_str(&get_pubkey())
+                    .unwrap()
+                    .serialize()
+                    .to_vec(),
+                key_loc: Some(KeyLocator {
+                    key_family: 3,
+                    ..Default::default()
+                }),
+            })
         });
 
         signer_mock
diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs
index b1368f3f..a60f9f6c 100644
--- a/tests/integration_tests.rs
+++ b/tests/integration_tests.rs
@@ -453,3 +453,63 @@ async fn test_lndk_pay_multiple_offers_concurrently() {
         }
     }
 }
+
+#[tokio::test(flavor = "multi_thread")]
+// Here we test that a new key is created with each call to create_invoice_request. Transient keys
+// improve privacy and we also need them to successfully make multiple payments to the same CLN
+// offer.
+async fn test_transient_keys() {
+    let test_name = "transient_keys";
+    let (bitcoind, mut lnd, ldk1, ldk2, lndk_dir) =
+        common::setup_test_infrastructure(test_name).await;
+
+    let (ldk1_pubkey, ldk2_pubkey, _) =
+        common::connect_network(&ldk1, &ldk2, &mut lnd, &bitcoind).await;
+
+    let path_pubkeys = vec![ldk2_pubkey, ldk1_pubkey];
+    let expiration = SystemTime::now() + Duration::from_secs(24 * 60 * 60);
+    let offer = ldk1
+        .create_offer(
+            &path_pubkeys,
+            Network::Regtest,
+            20_000,
+            Quantity::One,
+            expiration,
+        )
+        .await
+        .expect("should create offer");
+
+    let (lndk_cfg, handler, messenger, shutdown) =
+        common::setup_lndk(&lnd.cert_path, &lnd.macaroon_path, lnd.address, lndk_dir).await;
+
+    select! {
+        val = messenger.run(lndk_cfg, Arc::clone(&handler)) => {
+            panic!("lndk should not have completed first {:?}", val);
+        },
+        res1 = handler.create_invoice_request(
+            lnd.client.clone().unwrap(),
+            offer.clone(),
+            Network::Regtest,
+            None,
+            None,
+        ) => {
+            let res2 = handler.create_invoice_request(
+                lnd.client.clone().unwrap(),
+                offer.clone(),
+                Network::Regtest,
+                None,
+                None,
+            ).await;
+
+            let pubkey1 = res1.unwrap().0.payer_id();
+            let pubkey2 = res2.unwrap().0.payer_id();
+
+            // Verify that the signing pubkeys for each invoice request are different.
+            assert_ne!(pubkey1, pubkey2);
+
+            shutdown.trigger();
+            ldk1.stop().await;
+            ldk2.stop().await;
+        }
+    }
+}