diff --git a/src/lnd.rs b/src/lnd.rs index 124561a6..db537ca4 100644 --- a/src/lnd.rs +++ b/src/lnd.rs @@ -23,9 +23,10 @@ use tonic_lnd::lnrpc::{ GetInfoResponse, HtlcAttempt, LightningNode, ListPeersResponse, Payment, QueryRoutesResponse, Route, }; -use tonic_lnd::signrpc::KeyLocator; +use tonic_lnd::signrpc::{KeyDescriptor, KeyLocator}; use tonic_lnd::tonic::Status; use tonic_lnd::verrpc::Version; +use tonic_lnd::walletrpc::KeyReq; use tonic_lnd::{Client, ConnectError}; const ONION_MESSAGES_REQUIRED: u32 = 38; @@ -384,7 +385,7 @@ pub fn string_to_network(network_str: &str) -> Result Result, Status>; + async fn derive_next_key(&mut self, key_loc: KeyReq) -> Result; async fn sign_message( &mut self, key_loc: KeyLocator, diff --git a/src/lndk_offers.rs b/src/lndk_offers.rs index c6d63661..eff7044e 100644 --- a/src/lndk_offers.rs +++ b/src/lndk_offers.rs @@ -29,8 +29,9 @@ use tonic_lnd::lnrpc::{ ListPeersResponse, Payment, QueryRoutesResponse, Route, }; use tonic_lnd::routerrpc::TrackPaymentRequest; -use tonic_lnd::signrpc::{KeyLocator, SignMessageReq}; +use tonic_lnd::signrpc::{KeyDescriptor, KeyLocator, SignMessageReq}; use tonic_lnd::tonic::Status; +use tonic_lnd::walletrpc::KeyReq; use tonic_lnd::Client; #[derive(Debug)] @@ -183,19 +184,20 @@ impl OfferHandler { ) -> Result<(InvoiceRequest, PaymentId, u64), OfferError> { let validated_amount = validate_amount(offer.amount(), msats).await?; - // We use KeyFamily KeyFamilyNodeKey (6) to derive a key to represent our node id. See: - // https://github.com/lightningnetwork/lnd/blob/a3f8011ed695f6204ec6a13ad5c2a67ac542b109/keychain/derivation.go#L103 - let key_loc = KeyLocator { - key_family: 6, - key_index: 1, + // We use KeyFamily KeyFamilyNodeKey (3) to derive a key. For better privacy, the key + // shouldn't correspond to our node id. + // https://github.com/lightningnetwork/lnd/blob/a3f8011ed695f6204ec6a13ad5c2a67ac542b109/keychain/derivation.go#L86 + let key_loc = KeyReq { + key_family: 3, + ..Default::default() }; - let pubkey_bytes = signer - .derive_key(key_loc.clone()) + let key_descriptor = signer + .derive_next_key(key_loc.clone()) .await .map_err(OfferError::DeriveKeyFailure)?; - let pubkey = - PublicKey::from_slice(&pubkey_bytes).expect("failed to deserialize public key"); + let pubkey = PublicKey::from_slice(&key_descriptor.raw_key_bytes) + .expect("failed to deserialize public key"); // Generate a new payment id for this payment. let payment_id = PaymentId(self.messenger_utils.get_secure_random_bytes()); @@ -227,10 +229,11 @@ impl OfferHandler { // To create a valid invoice request, we also need to sign it. This is spawned in a blocking // task because we need to call block_on on sign_message so that sign_closure can be a // synchronous closure. - let invoice_request = - task::spawn_blocking(move || signer.sign_uir(key_loc, unsigned_invoice_req)) - .await - .unwrap()?; + let invoice_request = task::spawn_blocking(move || { + signer.sign_uir(key_descriptor.key_loc.unwrap(), unsigned_invoice_req) + }) + .await + .unwrap()?; { let mut active_payments = self.active_payments.lock().unwrap(); @@ -489,9 +492,9 @@ impl PeerConnector for Client { #[async_trait] impl MessageSigner for Client { - async fn derive_key(&mut self, key_loc: KeyLocator) -> Result, Status> { - match self.wallet().derive_key(key_loc).await { - Ok(resp) => Ok(resp.into_inner().raw_key_bytes), + async fn derive_next_key(&mut self, key_req: KeyReq) -> Result { + match self.wallet().derive_next_key(key_req).await { + Ok(resp) => Ok(resp.into_inner()), Err(e) => Err(e), } } @@ -600,7 +603,7 @@ impl InvoicePayer for Client { let blinded_payment_paths = tonic_lnd::lnrpc::BlindedPaymentPath { blinded_path, - total_cltv_delta: u32::from(cltv_expiry_delta) + 120, + total_cltv_delta: u32::from(cltv_expiry_delta), base_fee_msat: u64::from(fee_base_msat), proportional_fee_msat: u64::from(fee_ppm), ..Default::default() @@ -754,7 +757,7 @@ mod tests { #[async_trait] impl MessageSigner for TestBolt12Signer { - async fn derive_key(&mut self, key_loc: KeyLocator) -> Result, Status>; + async fn derive_next_key(&mut self, key_req: KeyReq) -> Result; async fn sign_message(&mut self, key_loc: KeyLocator, merkle_hash: Hash, tag: String) -> Result, Status>; fn sign_uir(&mut self, key_loc: KeyLocator, unsigned_invoice_req: UnsignedInvoiceRequest) -> Result; } @@ -786,9 +789,17 @@ mod tests { async fn test_request_invoice() { let mut signer_mock = MockTestBolt12Signer::new(); - signer_mock.expect_derive_key().returning(|_| { - let pubkey = PublicKey::from_str(&get_pubkey()).unwrap(); - Ok(pubkey.serialize().to_vec()) + signer_mock.expect_derive_next_key().returning(|_| { + Ok(KeyDescriptor { + raw_key_bytes: PublicKey::from_str(&get_pubkey()) + .unwrap() + .serialize() + .to_vec(), + key_loc: Some(KeyLocator { + key_family: 3, + ..Default::default() + }), + }) }); let offer = decode(get_offer()).unwrap(); @@ -821,7 +832,7 @@ mod tests { let mut signer_mock = MockTestBolt12Signer::new(); signer_mock - .expect_derive_key() + .expect_derive_next_key() .returning(|_| Err(Status::unknown("error testing"))); signer_mock @@ -846,11 +857,17 @@ mod tests { async fn test_request_invoice_signer_error() { let mut signer_mock = MockTestBolt12Signer::new(); - signer_mock.expect_derive_key().returning(|_| { - Ok(PublicKey::from_str(&get_pubkey()) - .unwrap() - .serialize() - .to_vec()) + signer_mock.expect_derive_next_key().returning(|_| { + Ok(KeyDescriptor { + raw_key_bytes: PublicKey::from_str(&get_pubkey()) + .unwrap() + .serialize() + .to_vec(), + key_loc: Some(KeyLocator { + key_family: 3, + ..Default::default() + }), + }) }); signer_mock diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index b1368f3f..a60f9f6c 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -453,3 +453,63 @@ async fn test_lndk_pay_multiple_offers_concurrently() { } } } + +#[tokio::test(flavor = "multi_thread")] +// Here we test that a new key is created with each call to create_invoice_request. Transient keys +// improve privacy and we also need them to successfully make multiple payments to the same CLN +// offer. +async fn test_transient_keys() { + let test_name = "transient_keys"; + let (bitcoind, mut lnd, ldk1, ldk2, lndk_dir) = + common::setup_test_infrastructure(test_name).await; + + let (ldk1_pubkey, ldk2_pubkey, _) = + common::connect_network(&ldk1, &ldk2, &mut lnd, &bitcoind).await; + + let path_pubkeys = vec![ldk2_pubkey, ldk1_pubkey]; + let expiration = SystemTime::now() + Duration::from_secs(24 * 60 * 60); + let offer = ldk1 + .create_offer( + &path_pubkeys, + Network::Regtest, + 20_000, + Quantity::One, + expiration, + ) + .await + .expect("should create offer"); + + let (lndk_cfg, handler, messenger, shutdown) = + common::setup_lndk(&lnd.cert_path, &lnd.macaroon_path, lnd.address, lndk_dir).await; + + select! { + val = messenger.run(lndk_cfg, Arc::clone(&handler)) => { + panic!("lndk should not have completed first {:?}", val); + }, + res1 = handler.create_invoice_request( + lnd.client.clone().unwrap(), + offer.clone(), + Network::Regtest, + None, + None, + ) => { + let res2 = handler.create_invoice_request( + lnd.client.clone().unwrap(), + offer.clone(), + Network::Regtest, + None, + None, + ).await; + + let pubkey1 = res1.unwrap().0.payer_id(); + let pubkey2 = res2.unwrap().0.payer_id(); + + // Verify that the signing pubkeys for each invoice request are different. + assert_ne!(pubkey1, pubkey2); + + shutdown.trigger(); + ldk1.stop().await; + ldk2.stop().await; + } + } +}