Skip to content

Commit

Permalink
[signalapp#474] add docstrings and doctests to crypto.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmicexplorer committed Sep 9, 2022
1 parent 87bca22 commit 0a1bc67
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 178 deletions.
279 changes: 191 additions & 88 deletions rust/protocol/src/crypto.rs
Original file line number Diff line number Diff line change
@@ -1,153 +1,256 @@
//
// Copyright 2020 Signal Messenger, LLC.
// Copyright 2020-2022 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//

use std::convert::TryInto;
use std::result::Result;
//! Symmetric cryptographic primitives, including [HMAC] and [AES].
//!
//! AES may be used in either [CBC] or [CTR] modes via the
//! `aes_256_{cbc,ctr}_{en,de}crypt()` methods.
//!
//! [HMAC]: https://en.wikipedia.org/wiki/HMAC
//! [AES]: https://en.wikipedia.org/wiki/Advanced_Encryption_Standard
//! [CBC]: https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#CBC
//! [CTR]: https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#CTR
#![warn(missing_docs)]

use aes::cipher::{NewCipher, StreamCipher};
use aes::{Aes256, Aes256Ctr};
use arrayref::array_ref;
use block_modes::block_padding::Pkcs7;
use block_modes::{BlockMode, Cbc};
use displaydoc::Display;
use hmac::{Hmac, Mac, NewMac};
use sha2::Sha256;
use subtle::ConstantTimeEq;
use thiserror::Error;

#[derive(Debug)]
pub(crate) enum EncryptionError {
/// The key or IV is the wrong length.
BadKeyOrIv,
}

#[derive(Debug)]
pub(crate) enum DecryptionError {
/// The key or IV is the wrong length.
BadKeyOrIv,
/// Failures that may occur during symmetric decryption.
#[derive(Debug, Display, Error)]
#[ignore_extra_doc_attributes]
pub enum DecryptionError {
/// bad ciphertext: {0}
///
/// Either the input is malformed, or the MAC doesn't match on decryption.
///
/// These cases should not be distinguished; message corruption can cause either problem.
BadCiphertext(&'static str),
}

fn aes_256_ctr_encrypt(ptext: &[u8], key: &[u8]) -> Result<Vec<u8>, EncryptionError> {
let key: [u8; 32] = key.try_into().map_err(|_| EncryptionError::BadKeyOrIv)?;
/// TODO: Could be nice to have a type-safe library for manipulating units of bytes safely.
const BITS_PER_BYTE: usize = std::mem::size_of::<u8>() * 8;

let zero_nonce = [0u8; 16];
let mut cipher = Aes256Ctr::new(key[..].into(), zero_nonce[..].into());
/// The length of the key we use for AES encryption in this crate.
pub const AES_256_KEY_SIZE: usize = 256 / BITS_PER_BYTE;

let mut ctext = ptext.to_vec();
cipher.apply_keystream(&mut ctext);
Ok(ctext)
}
/// The size of the generated nonce we use for AES encryption in this crate.
pub const AES_NONCE_SIZE: usize = 128 / BITS_PER_BYTE;

fn aes_256_ctr_decrypt(ctext: &[u8], key: &[u8]) -> Result<Vec<u8>, DecryptionError> {
aes_256_ctr_encrypt(ctext, key).map_err(|e| match e {
EncryptionError::BadKeyOrIv => DecryptionError::BadKeyOrIv,
})
}
/// Use AES-256 in CTR mode.
///
///```
/// # fn main() -> Result<(), libsignal_protocol::crypto::DecryptionError> {
/// use rand::{Rng, rngs::OsRng};
/// use libsignal_protocol::crypto::{self, ctr};
///
/// let key: [u8; crypto::AES_256_KEY_SIZE] = (&mut OsRng).gen();
/// let ptext: [u8; 30] = (&mut OsRng).gen();
///
/// // Without an HMAC.
/// let ctext = ctr::aes_256_ctr_encrypt(&ptext, &key);
/// let decrypted = ctr::aes_256_ctr_decrypt(&ctext, &key);
/// assert_eq!(ptext.as_ref(), &decrypted);
///
/// // With an HMAC.
/// let mac_key: [u8; ctr::MAC_KEY_LENGTH] = (&mut OsRng).gen();
/// let ctext = ctr::aes_256_ctr_hmac_encrypt(&ptext, &key, &mac_key);
/// let decrypted = ctr::aes_256_ctr_hmac_decrypt(&ctext, &key, &mac_key)?;
/// assert_eq!(ptext.as_ref(), &decrypted);
/// # Ok(())
/// # }
///```
pub mod ctr {
use super::*;

/// Encrypt plaintext `ptext` using key `key` with AES-256 in CTR mode.
pub fn aes_256_ctr_encrypt(ptext: &[u8], key: &[u8; AES_256_KEY_SIZE]) -> Vec<u8> {
let zero_nonce = [0u8; AES_NONCE_SIZE];
let mut cipher = Aes256Ctr::new(key[..].into(), zero_nonce[..].into());

let mut ctext = ptext.to_vec();
cipher.apply_keystream(&mut ctext);
ctext
}

/// Decrypt ciphertext `ctext` using key `key` with AES-256 in CTR mode.
pub fn aes_256_ctr_decrypt(ctext: &[u8], key: &[u8; AES_256_KEY_SIZE]) -> Vec<u8> {
aes_256_ctr_encrypt(ctext, key)
}

/// Length in bytes of the [`Hmac`] key used for [`aes_256_ctr_hmac_encrypt`] and
/// [`aes_256_ctr_hmac_decrypt`].
pub const MAC_KEY_LENGTH: usize = 80 / BITS_PER_BYTE;

/// Encrypt plaintext `msg` with AES-256 and embed a computed HMAC into the returned bytes.
///
/// *Implementation note: within the body of this method, only the first [`MAC_KEY_LENGTH`]
/// bytes of the computed MAC are used.*
pub fn aes_256_ctr_hmac_encrypt(
msg: &[u8],
cipher_key: &[u8; AES_256_KEY_SIZE],
mac_key: &[u8; MAC_KEY_LENGTH],
) -> Vec<u8> {
let mut ctext = aes_256_ctr_encrypt(msg, cipher_key);
let mac = hmac_sha256(mac_key, &ctext);
ctext.extend_from_slice(&mac[..MAC_KEY_LENGTH]);
ctext
}

pub(crate) fn aes_256_cbc_encrypt(
ptext: &[u8],
key: &[u8],
iv: &[u8],
) -> Result<Vec<u8>, EncryptionError> {
match Cbc::<Aes256, Pkcs7>::new_from_slices(key, iv) {
Ok(mode) => Ok(mode.encrypt_vec(ptext)),
Err(block_modes::InvalidKeyIvLength) => Err(EncryptionError::BadKeyOrIv),
/// Validate the HMAC `mac_key` against the ciphertext `ctext`, then decrypt `ctext` using
/// AES-256 with `cipher_key` and [`aes_256_ctr_decrypt`].
///
/// *Implementation note: the last [`MAC_KEY_LENGTH`] bytes of the `ctext` slice represent the
/// truncated HMAC of the rest of the message, as generated by [`aes_256_ctr_hmac_encrypt`].*
pub fn aes_256_ctr_hmac_decrypt(
ctext: &[u8],
cipher_key: &[u8; AES_256_KEY_SIZE],
mac_key: &[u8; MAC_KEY_LENGTH],
) -> Result<Vec<u8>, DecryptionError> {
if ctext.len() < MAC_KEY_LENGTH {
return Err(DecryptionError::BadCiphertext("truncated ciphertext"));
}
let (ctext, ctext_mac) = ctext.split_at(ctext.len() - MAC_KEY_LENGTH);
let ctext_mac = array_ref![ctext_mac, 0, MAC_KEY_LENGTH];
let our_mac = hmac_sha256(mac_key, ctext);
let our_mac = array_ref![&our_mac, 0, MAC_KEY_LENGTH];
let same: bool = our_mac.ct_eq(ctext_mac).into();
dbg!(our_mac);
dbg!(ctext_mac);
if !same {
return Err(DecryptionError::BadCiphertext("MAC verification failed"));
}
Ok(aes_256_ctr_decrypt(ctext, cipher_key))
}
}

pub(crate) fn aes_256_cbc_decrypt(
ctext: &[u8],
key: &[u8],
iv: &[u8],
) -> Result<Vec<u8>, DecryptionError> {
if ctext.is_empty() || ctext.len() % 16 != 0 {
return Err(DecryptionError::BadCiphertext(
"ciphertext length must be a non-zero multiple of 16",
));
/// Use AES-256 in CBC mode.
///
///```
/// # fn main() -> Result<(), libsignal_protocol::crypto::DecryptionError> {
/// use rand::{Rng, rngs::OsRng};
/// use libsignal_protocol::crypto::{self, cbc};
///
/// let key: [u8; crypto::AES_256_KEY_SIZE] = (&mut OsRng).gen();
/// let iv: [u8; crypto::AES_NONCE_SIZE] = (&mut OsRng).gen();
/// let ptext: [u8; 30] = (&mut OsRng).gen();
///
/// let ctext = cbc::aes_256_cbc_encrypt(&ptext, &key, &iv);
/// let decrypted = cbc::aes_256_cbc_decrypt(&ctext, &key, &iv)?;
/// assert_eq!(ptext.as_ref(), &decrypted);
/// # Ok(())
/// # }
///```
pub mod cbc {
use super::*;

/// Encrypt plaintext `ptext` using key `key` and initialization vector `iv` with AES-256 in
/// CBC mode.
pub fn aes_256_cbc_encrypt(
ptext: &[u8],
key: &[u8; AES_256_KEY_SIZE],
iv: &[u8; AES_NONCE_SIZE],
) -> Vec<u8> {
let mode =
Cbc::<Aes256, Pkcs7>::new_from_slices(key, iv).expect("key and iv were fixed length");
mode.encrypt_vec(ptext)
}

let mode =
Cbc::<Aes256, Pkcs7>::new_from_slices(key, iv).map_err(|_| DecryptionError::BadKeyOrIv)?;
mode.decrypt_vec(ctext)
.map_err(|_| DecryptionError::BadCiphertext("failed to decrypt"))
/// Decrypt ciphertext `ctext` using key `key` and initialization vector `iv` with AES-256 in
/// CBC mode.
pub fn aes_256_cbc_decrypt(
ctext: &[u8],
key: &[u8; AES_256_KEY_SIZE],
iv: &[u8; AES_NONCE_SIZE],
) -> Result<Vec<u8>, DecryptionError> {
if ctext.is_empty() || ctext.len() % 16 != 0 {
return Err(DecryptionError::BadCiphertext(
"ciphertext length must be a non-zero multiple of 16",
));
}

let mode = Cbc::<Aes256, Pkcs7>::new_from_slices(key.as_ref(), iv.as_ref())
.expect("key and iv were fixed length");
mode.decrypt_vec(ctext)
.map_err(|_| DecryptionError::BadCiphertext("failed to decrypt"))
}
}

pub(crate) fn hmac_sha256(key: &[u8], input: &[u8]) -> [u8; 32] {
/// The statically-known size of the output of [`hmac_sha256`].
pub const HMAC_OUTPUT_SIZE: usize = 256 / BITS_PER_BYTE;

/// Calculate the [`Hmac`]-[`Sha256`] code over `input` using `key`.
pub fn hmac_sha256(key: &[u8], input: &[u8]) -> [u8; HMAC_OUTPUT_SIZE] {
let mut hmac =
Hmac::<Sha256>::new_from_slice(key).expect("HMAC-SHA256 should accept any size key");
hmac.update(input);
hmac.finalize().into_bytes().into()
}

pub(crate) fn aes256_ctr_hmacsha256_encrypt(
msg: &[u8],
cipher_key: &[u8],
mac_key: &[u8],
) -> Result<Vec<u8>, EncryptionError> {
let mut ctext = aes_256_ctr_encrypt(msg, cipher_key)?;
let mac = hmac_sha256(mac_key, &ctext);
ctext.extend_from_slice(&mac[..10]);
Ok(ctext)
}

pub(crate) fn aes256_ctr_hmacsha256_decrypt(
ctext: &[u8],
cipher_key: &[u8],
mac_key: &[u8],
) -> Result<Vec<u8>, DecryptionError> {
if ctext.len() < 10 {
return Err(DecryptionError::BadCiphertext("truncated ciphertext"));
}
let ptext_len = ctext.len() - 10;
let our_mac = hmac_sha256(mac_key, &ctext[..ptext_len]);
let same: bool = our_mac[..10].ct_eq(&ctext[ptext_len..]).into();
if !same {
return Err(DecryptionError::BadCiphertext("MAC verification failed"));
}
aes_256_ctr_decrypt(&ctext[..ptext_len], cipher_key)
}

#[cfg(test)]
mod test {
use super::*;

use std::convert::TryInto;

#[test]
fn aes_cbc_test() {
let key = hex::decode("4e22eb16d964779994222e82192ce9f747da72dc4abe49dfdeeb71d0ffe3796e")
.expect("valid hex");
let iv = hex::decode("6f8a557ddc0a140c878063a6d5f31d3d").expect("valid hex");
let key: [u8; AES_256_KEY_SIZE] =
hex::decode("4e22eb16d964779994222e82192ce9f747da72dc4abe49dfdeeb71d0ffe3796e")
.expect("valid hex")
.try_into()
.expect("correct array size");
let iv: [u8; AES_NONCE_SIZE] = hex::decode("6f8a557ddc0a140c878063a6d5f31d3d")
.expect("valid hex")
.try_into()
.expect("correct array size");

let ptext = hex::decode("30736294a124482a4159").expect("valid hex");

let ctext = aes_256_cbc_encrypt(&ptext, &key, &iv).expect("valid key and IV");
let ctext = cbc::aes_256_cbc_encrypt(&ptext, &key, &iv);
assert_eq!(
hex::encode(ctext.clone()),
"dd3f573ab4508b9ed0e45e0baf5608f3"
);

let recovered = aes_256_cbc_decrypt(&ctext, &key, &iv).expect("valid");
let recovered = cbc::aes_256_cbc_decrypt(&ctext, &key, &iv).expect("valid");
assert_eq!(hex::encode(ptext), hex::encode(recovered.clone()));

// padding is invalid:
assert!(aes_256_cbc_decrypt(&recovered, &key, &iv).is_err());
assert!(aes_256_cbc_decrypt(&ctext, &key, &ctext).is_err());
assert!(cbc::aes_256_cbc_decrypt(&recovered, &key, &iv).is_err());
assert!(
cbc::aes_256_cbc_decrypt(&ctext, &key, array_ref![&ctext, 0, AES_NONCE_SIZE]).is_err()
);

// bitflip the IV to cause a change in the recovered text
let bad_iv = hex::decode("ef8a557ddc0a140c878063a6d5f31d3d").expect("valid hex");
let recovered = aes_256_cbc_decrypt(&ctext, &key, &bad_iv).expect("still valid");
let bad_iv: [u8; AES_NONCE_SIZE] = hex::decode("ef8a557ddc0a140c878063a6d5f31d3d")
.expect("valid hex")
.try_into()
.expect("correct array size");
let recovered = cbc::aes_256_cbc_decrypt(&ctext, &key, &bad_iv).expect("still valid");
assert_eq!(hex::encode(recovered), "b0736294a124482a4159");
}

#[test]
fn aes_ctr_test() {
let key = hex::decode("603DEB1015CA71BE2B73AEF0857D77811F352C073B6108D72D9810A30914DFF4")
.expect("valid hex");
let key: [u8; AES_256_KEY_SIZE] =
hex::decode("603DEB1015CA71BE2B73AEF0857D77811F352C073B6108D72D9810A30914DFF4")
.expect("valid hex")
.try_into()
.expect("correct array size");
let ptext = [0u8; 35];

let ctext = aes_256_ctr_encrypt(&ptext, &key).expect("valid key");
let ctext = ctr::aes_256_ctr_encrypt(&ptext, &key);
assert_eq!(
hex::encode(ctext),
"e568f68194cf76d6174d4cc04310a85491151e5d0b7a1f1bc0d7acd0ae3e51e4170e23"
Expand Down
Loading

0 comments on commit 0a1bc67

Please sign in to comment.