Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ml-kem: fix potential kyberslash attack #18

Merged
merged 9 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ml-kem/src/algebra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ pub struct FieldElement(pub Integer);
impl FieldElement {
pub const Q: Integer = 3329;
pub const Q32: u32 = Self::Q as u32;
const Q64: u64 = Self::Q as u64;
pub const Q64: u64 = Self::Q as u64;
const BARRETT_SHIFT: usize = 24;
#[allow(clippy::integer_division_remainder_used)]
const BARRETT_MULTIPLIER: u64 = (1 << Self::BARRETT_SHIFT) / Self::Q64;

// A fast modular reduction for small numbers `x < 2*q`
Expand Down Expand Up @@ -263,6 +264,7 @@ impl NttPolynomial {
#[allow(clippy::cast_possible_truncation)]
const ZETA_POW_BITREV: [FieldElement; 128] = {
const ZETA: u64 = 17;
#[allow(clippy::integer_division_remainder_used)]
const fn bitrev7(x: usize) -> usize {
((x >> 6) % 2)
| (((x >> 5) % 2) << 1)
Expand All @@ -277,6 +279,7 @@ const ZETA_POW_BITREV: [FieldElement; 128] = {
let mut pow = [FieldElement(0); 128];
let mut i = 0;
let mut curr = 1u64;
#[allow(clippy::integer_division_remainder_used)]
while i < 128 {
pow[i] = FieldElement(curr as u16);
i += 1;
Expand All @@ -300,6 +303,7 @@ const GAMMA: [FieldElement; 128] = {
let mut i = 0;
while i < 128 {
let zpr = ZETA_POW_BITREV[i].0 as u64;
#[allow(clippy::integer_division_remainder_used)]
let g = (zpr * zpr * ZETA) % FieldElement::Q64;
gamma[i] = FieldElement(g as u16);
i += 1;
Expand Down
64 changes: 37 additions & 27 deletions ml-kem/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use crate::util::Truncate;
pub trait CompressionFactor: EncodingSize {
const POW2_HALF: u32;
const MASK: Integer;
const DIV_SHIFT: u32;
const DIV_MUL: u64;
}

impl<T> CompressionFactor for T
Expand All @@ -14,6 +16,9 @@ where
{
const POW2_HALF: u32 = 1 << (T::USIZE - 1);
const MASK: Integer = ((1 as Integer) << T::USIZE) - 1;
const DIV_SHIFT: u32 = 28 + (T::U32 >> 3) * 4;
#[allow(clippy::integer_division_remainder_used)]
const DIV_MUL: u64 = (1 << T::DIV_SHIFT) / FieldElement::Q64;
}

// Traits for objects that allow compression / decompression
Expand All @@ -25,16 +30,18 @@ pub trait Compress {
impl Compress for FieldElement {
// Equation 4.5: Compress_d(x) = round((2^d / q) x)
//
// Here and in decompression, we leverage the following fact:
// Here and in decompression, we leverage the following facts:
//
// round(a / b) = floor((a + b/2) / b)
// a / q ~= (a * x) >> s where x >> s ~= 1/q
fn compress<D: CompressionFactor>(&mut self) -> &Self {
const Q_HALF: u32 = (FieldElement::Q32 - 1) / 2;
let x = u32::from(self.0);
let y = ((x << D::USIZE) + Q_HALF) / FieldElement::Q32;
const Q_HALF: u64 = (FieldElement::Q64 - 1) >> 1;
let x = u64::from(self.0);
let y = ((((x << D::USIZE) + Q_HALF) * D::DIV_MUL) >> D::DIV_SHIFT).truncate();
self.0 = y.truncate() & D::MASK;
self
}

// Equation 4.6: Decomporess_d(x) = round((q / 2^d) x)
fn decompress<D: CompressionFactor>(&mut self) -> &Self {
let x = u32::from(self.0);
Expand Down Expand Up @@ -85,36 +92,39 @@ pub(crate) mod test {
use super::*;
use hybrid_array::typenum::{U1, U10, U11, U12, U4, U5, U6};

// Verify that the integer compression routine produces the same results as rounding with
// floats.
fn compression_known_answer_test<D: CompressionFactor>() {
let fq: f64 = FieldElement::Q as f64;
let f2d: f64 = 2.0_f64.powi(D::I32);

// Verify against inequality 4.7
#[allow(clippy::integer_division_remainder_used)]
fn compression_decompression_inequality<D: CompressionFactor>() {
let half_q: i32 = i32::from(FieldElement::Q) / 2;
let error_threshold =
((f64::from(FieldElement::Q)) / f64::from(1 << (D::U32 + 1))).round() as i32;
for x in 0..FieldElement::Q {
let fx = x as f64;
let mut x = FieldElement(x);
let mut y = FieldElement(x);

supinie marked this conversation as resolved.
Show resolved Hide resolved
// Verify equivalence of compression
x.compress::<D>();
let fcx = ((f2d / fq * fx).round() as Integer) % (1 << D::USIZE);
assert_eq!(x.0, fcx);
y.compress::<D>();
y.decompress::<D>();

// Verify equivalence of decompression
x.decompress::<D>();
let fdx = (fq / f2d * (fcx as f64)).round() as Integer;
assert_eq!(x.0, fdx);
let mut error = (i32::from(y.0) - i32::from(x)) % half_q;
if error < -half_q {
error += half_q;
}

assert!(
error <= error_threshold,
"Inequality failed for x = {x}: error = {error}, error_threshold = {error_threshold}, D = {:?}",
D::USIZE
);
}
}

#[test]
fn compress_decompress() {
compression_known_answer_test::<U1>();
compression_known_answer_test::<U4>();
compression_known_answer_test::<U5>();
compression_known_answer_test::<U6>();
compression_known_answer_test::<U10>();
compression_known_answer_test::<U11>();
compression_known_answer_test::<U12>();
compression_decompression_inequality::<U1>();
supinie marked this conversation as resolved.
Show resolved Hide resolved
compression_decompression_inequality::<U4>();
compression_decompression_inequality::<U5>();
compression_decompression_inequality::<U6>();
compression_decompression_inequality::<U10>();
compression_decompression_inequality::<U11>();
compression_decompression_inequality::<U12>();
}
}
3 changes: 3 additions & 0 deletions ml-kem/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,13 @@ pub(crate) mod test {
D: ArraySize + Rem<N>,
Mod<D, N>: Zero,
{
#[allow(clippy::integer_division_remainder_used)]
fn repeat(&self) -> Array<T, D> {
Array::from_fn(|i| self[i % N::USIZE].clone())
}
}

#[allow(clippy::integer_division_remainder_used)]
fn byte_codec_test<D>(decoded: DecodedValue, encoded: EncodedPolynomial<D>)
where
D: EncodingSize,
Expand Down Expand Up @@ -247,6 +249,7 @@ pub(crate) mod test {
byte_codec_test::<U12>(decoded, encoded);
}

#[allow(clippy::integer_division_remainder_used)]
#[test]
fn byte_codec_12_mod() {
// DecodeBytes_12 is required to reduce mod q
Expand Down
1 change: 1 addition & 0 deletions ml-kem/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg"
)]
#![warn(clippy::pedantic)] // Be pedantic by default
#![warn(clippy::integer_division_remainder_used)]
supinie marked this conversation as resolved.
Show resolved Hide resolved
#![allow(non_snake_case)] // Allow notation matching the spec
#![allow(clippy::clone_on_copy)] // Be explicit about moving data
#![deny(missing_docs)] // Require all public interfaces to be documented
Expand Down
1 change: 1 addition & 0 deletions ml-kem/src/param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ where
let mut x = 0usize;
while x < max {
let mut y = 0usize;
#[allow(clippy::integer_division_remainder_used)]
while y < max {
let x_ones = x.count_ones() as u16;
let y_ones = y.count_ones() as u16;
Expand Down
Loading