diff --git a/src/cipherstate.rs b/src/cipherstate.rs index 6ef741a..5597233 100644 --- a/src/cipherstate.rs +++ b/src/cipherstate.rs @@ -167,6 +167,10 @@ impl StatelessCipherState { pub fn rekey_manually(&mut self, key: &[u8]) { self.cipher.set(key); } + + pub fn key(&self) -> &[u8] { + self.cipher.get() + } } impl From for StatelessCipherState { diff --git a/src/resolvers/default.rs b/src/resolvers/default.rs index 8c52a26..d5f8154 100644 --- a/src/resolvers/default.rs +++ b/src/resolvers/default.rs @@ -195,6 +195,10 @@ impl Cipher for CipherAesGcm { copy_slices!(key, &mut self.key) } + fn get(&self) -> &[u8] { + &self.key + } + fn encrypt(&self, nonce: u64, authtext: &[u8], plaintext: &[u8], out: &mut [u8]) -> usize { let aead = aes_gcm::Aes256Gcm::new(&self.key.into()); @@ -248,6 +252,10 @@ impl Cipher for CipherChaChaPoly { copy_slices!(key, &mut self.key); } + fn get(&self) -> &[u8] { + &self.key + } + fn encrypt(&self, nonce: u64, authtext: &[u8], plaintext: &[u8], out: &mut [u8]) -> usize { let mut nonce_bytes = [0u8; 12]; copy_slices!(nonce.to_le_bytes(), &mut nonce_bytes[4..]); diff --git a/src/stateless_transportstate.rs b/src/stateless_transportstate.rs index 6ad7f7c..230fb4e 100644 --- a/src/stateless_transportstate.rs +++ b/src/stateless_transportstate.rs @@ -140,6 +140,24 @@ impl StatelessTransportState { pub fn is_initiator(&self) -> bool { self.initiator } + + /// Gets the initiator's current symmetric key. + /// + /// This can be passed to `rekey_initiator_manually` to restore the state to + /// before a reykeying, in case you need to work with messages that have come in + /// before then. + pub fn initiator_key(&self) -> &[u8] { + self.cipherstates.0.key() + } + + /// Gets the responder's current symmetric key. + /// + /// This can be passed to `rekey_responder_manually` to restore the state to + /// before a reykeying, in case you need to work with messages that have come in + /// before then. + pub fn responder_key(&self) -> &[u8] { + self.cipherstates.1.key() + } } impl fmt::Debug for StatelessTransportState { diff --git a/src/types.rs b/src/types.rs index 3aa3bfb..e2f5720 100644 --- a/src/types.rs +++ b/src/types.rs @@ -44,6 +44,9 @@ pub trait Cipher: Send + Sync { /// Set the key fn set(&mut self, key: &[u8]); + /// Get the key + fn get(&self) -> &[u8]; + /// Encrypt (with associated data) a given plaintext. fn encrypt(&self, nonce: u64, authtext: &[u8], plaintext: &[u8], out: &mut [u8]) -> usize; diff --git a/tests/general.rs b/tests/general.rs index a919135..0c723ec 100644 --- a/tests/general.rs +++ b/tests/general.rs @@ -803,3 +803,44 @@ fn test_handshake_read_oob_error() { // This shouldn't panic, but it *should* return an error. let _ = h_i.read_message(&buffer_msg[..len], &mut buffer_out); } + +#[test] +fn test_stateless_get_set_key() { + let params: NoiseParams = "Noise_NN_25519_ChaChaPoly_SHA256".parse().unwrap(); + let mut h_i = Builder::new(params.clone()).build_initiator().unwrap(); + let mut h_r = Builder::new(params).build_responder().unwrap(); + + let mut buffer_msg = [0u8; 200]; + let mut buffer_out = [0u8; 200]; + let mut buffer_pre_rekey = [0u8; 200]; + let mut buffer_post_rekey = [0u8; 200]; + let len = h_i.write_message(b"abc", &mut buffer_msg).unwrap(); + h_r.read_message(&buffer_msg[..len], &mut buffer_out).unwrap(); + + let len = h_r.write_message(b"defg", &mut buffer_msg).unwrap(); + h_i.read_message(&buffer_msg[..len], &mut buffer_out).unwrap(); + + let mut h_i = h_i.into_stateless_transport_mode().unwrap(); + let mut h_r = h_r.into_stateless_transport_mode().unwrap(); + + let pre_len = h_i.write_message(0, b"hello world", &mut buffer_pre_rekey).unwrap(); + let pre_key = Vec::from(h_r.initiator_key()); + + h_i.rekey_outgoing(); + h_r.rekey_incoming(); + + let post_len = h_i.write_message(1, b"goodbye world", &mut buffer_post_rekey).unwrap(); + let post_key = Vec::from(h_r.initiator_key()); + + assert_ne!(pre_key, post_key); + + assert!(h_r.read_message(0, &buffer_pre_rekey[0..pre_len], &mut buffer_out).is_err()); + let len = h_r.read_message(1, &buffer_post_rekey[0..post_len], &mut buffer_out).unwrap(); + assert_eq!(&buffer_out[0..len], b"goodbye world"); + + h_r.rekey_initiator_manually(&pre_key); + + assert!(h_r.read_message(1, &buffer_post_rekey[0..post_len], &mut buffer_out).is_err()); + let len = h_r.read_message(0, &buffer_pre_rekey[0..pre_len], &mut buffer_out).unwrap(); + assert_eq!(&buffer_out[0..len], b"hello world"); +}