From cf510c46ff244ef72ecd9efa023b6b786999ebf9 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Mon, 9 Dec 2024 10:27:31 -0800 Subject: [PATCH] ec/suite_b: Optimize away slice bounds checks. Help the compiler see that COMMON_OPS.num_limbs, which is used in all the slicing, is always less than the size of the array, so no bounds checks need to be emitted. --- mk/generate_curves.py | 4 +- src/ec/suite_b/ops.rs | 97 +++++++++++++++++++++----------------- src/ec/suite_b/ops/elem.rs | 34 +++++++++---- src/ec/suite_b/ops/p256.rs | 4 +- src/ec/suite_b/ops/p384.rs | 4 +- 5 files changed, 88 insertions(+), 55 deletions(-) diff --git a/mk/generate_curves.py b/mk/generate_curves.py index 08e2596b4d..86a882401a 100644 --- a/mk/generate_curves.py +++ b/mk/generate_curves.py @@ -36,8 +36,10 @@ elem_sqr_mul, elem_sqr_mul_acc, Modulus, *, }; +pub(super) const NUM_LIMBS: usize = (%(bits)d + LIMB_BITS - 1) / LIMB_BITS; + pub static COMMON_OPS: CommonOps = CommonOps { - num_limbs: (%(bits)d + LIMB_BITS - 1) / LIMB_BITS, + num_limbs: elem::NumLimbs::P%(bits)s, order_bits: %(bits)d, q: Modulus { diff --git a/src/ec/suite_b/ops.rs b/src/ec/suite_b/ops.rs index 3b18367505..e11703ac58 100644 --- a/src/ec/suite_b/ops.rs +++ b/src/ec/suite_b/ops.rs @@ -18,7 +18,7 @@ use crate::{ }; use core::marker::PhantomData; -pub use self::elem::*; +use elem::{mul_mont, unary_op, unary_op_assign, unary_op_from_binary_op_assign}; /// A field element, i.e. an element of ℤ/qℤ for the curve's field modulus /// *q*. @@ -44,20 +44,20 @@ pub struct Point { // `ops.num_limbs` elements are the Y coordinate, and the next // `ops.num_limbs` elements are the Z coordinate. This layout is dictated // by the requirements of the nistz256 code. - xyz: [Limb; 3 * MAX_LIMBS], + xyz: [Limb; 3 * elem::NumLimbs::MAX], } impl Point { pub fn new_at_infinity() -> Self { Self { - xyz: [0; 3 * MAX_LIMBS], + xyz: [0; 3 * elem::NumLimbs::MAX], } } } /// Operations and values needed by all curve operations. pub struct CommonOps { - num_limbs: usize, + num_limbs: elem::NumLimbs, q: Modulus, n: PublicElem, @@ -75,17 +75,17 @@ impl CommonOps { // The length of a field element, which is the same as the length of a // scalar, in bytes. pub fn len(&self) -> usize { - self.num_limbs * LIMB_BYTES + self.num_limbs.into() * LIMB_BYTES } #[cfg(test)] pub(super) fn n_limbs(&self) -> &[Limb] { - &self.n.limbs[..self.num_limbs] + &self.n.limbs[..self.num_limbs.into()] } #[inline] pub fn elem_add(&self, a: &mut Elem, b: &Elem) { - let num_limbs = self.num_limbs; + let num_limbs = self.num_limbs.into(); limbs_add_assign_mod( &mut a.limbs[..num_limbs], &b.limbs[..num_limbs], @@ -95,7 +95,8 @@ impl CommonOps { #[inline] pub fn elems_are_equal(&self, a: &Elem, b: &Elem) -> LimbMask { - limbs_equal_limbs_consttime(&a.limbs[..self.num_limbs], &b.limbs[..self.num_limbs]) + let num_limbs = self.num_limbs.into(); + limbs_equal_limbs_consttime(&a.limbs[..num_limbs], &b.limbs[..num_limbs]) } #[inline] @@ -105,7 +106,7 @@ impl CommonOps { #[inline] pub fn elem_mul(&self, a: &mut Elem, b: &Elem) { - binary_op_assign(self.elem_mul_mont, a, b) + elem::binary_op_assign(self.elem_mul_mont, a, b) } #[inline] @@ -132,7 +133,8 @@ impl CommonOps { #[inline] pub fn is_zero(&self, a: &elem::Elem) -> bool { - limbs_are_zero_constant_time(&a.limbs[..self.num_limbs]).leak() + let num_limbs = self.num_limbs.into(); + limbs_are_zero_constant_time(&a.limbs[..num_limbs]).leak() } pub fn elem_verify_is_not_zero(&self, a: &Elem) -> Result<(), error::Unspecified> { @@ -152,28 +154,30 @@ impl CommonOps { } pub fn point_x(&self, p: &Point) -> Elem { + let num_limbs = self.num_limbs.into(); let mut r = Elem::zero(); - r.limbs[..self.num_limbs].copy_from_slice(&p.xyz[0..self.num_limbs]); + r.limbs[..num_limbs].copy_from_slice(&p.xyz[0..num_limbs]); r } pub fn point_y(&self, p: &Point) -> Elem { + let num_limbs = self.num_limbs.into(); let mut r = Elem::zero(); - r.limbs[..self.num_limbs].copy_from_slice(&p.xyz[self.num_limbs..(2 * self.num_limbs)]); + r.limbs[..num_limbs].copy_from_slice(&p.xyz[num_limbs..(2 * num_limbs)]); r } pub fn point_z(&self, p: &Point) -> Elem { + let num_limbs = self.num_limbs.into(); let mut r = Elem::zero(); - r.limbs[..self.num_limbs] - .copy_from_slice(&p.xyz[(2 * self.num_limbs)..(3 * self.num_limbs)]); + r.limbs[..num_limbs].copy_from_slice(&p.xyz[(2 * num_limbs)..(3 * num_limbs)]); r } } struct Modulus { - p: [LeakyLimb; MAX_LIMBS], - rr: [LeakyLimb; MAX_LIMBS], + p: [LeakyLimb; elem::NumLimbs::MAX], + rr: [LeakyLimb; elem::NumLimbs::MAX], } /// Operations on private keys, for ECDH and ECDSA signing. @@ -191,7 +195,7 @@ pub struct PrivateKeyOps { impl PrivateKeyOps { pub fn leak_limbs<'a>(&self, a: &'a Elem) -> &'a [Limb] { - &a.limbs[..self.common.num_limbs] + &a.limbs[..self.common.num_limbs.into()] } #[inline(always)] @@ -273,7 +277,7 @@ impl ScalarOps { } pub fn leak_limbs<'s>(&self, s: &'s Scalar) -> &'s [Limb] { - &s.limbs[..self.common.num_limbs] + &s.limbs[..self.common.num_limbs.into()] } #[inline] @@ -320,12 +324,12 @@ impl PublicScalarOps { } pub fn elem_equals_vartime(&self, a: &Elem, b: &Elem) -> bool { - a.limbs[..self.public_key_ops.common.num_limbs] - == b.limbs[..self.public_key_ops.common.num_limbs] + let num_limbs = self.public_key_ops.common.num_limbs.into(); + a.limbs[..num_limbs] == b.limbs[..num_limbs] } pub fn elem_less_than(&self, a: &Elem, b: &PublicElem) -> bool { - let num_limbs = self.public_key_ops.common.num_limbs; + let num_limbs = self.public_key_ops.common.num_limbs.into(); limbs_less_than_limbs_vartime(&a.limbs[..num_limbs], &b.limbs[..num_limbs]) } @@ -376,7 +380,7 @@ fn twin_mul_inefficient( // This assumes n < q < 2*n. pub fn elem_reduced_to_scalar(ops: &CommonOps, elem: &Elem) -> Scalar { - let num_limbs = ops.num_limbs; + let num_limbs = ops.num_limbs.into(); let mut r_limbs = elem.limbs; limbs_reduce_once_constant_time(&mut r_limbs[..num_limbs], &ops.n.limbs[..num_limbs]); Scalar { @@ -387,10 +391,11 @@ pub fn elem_reduced_to_scalar(ops: &CommonOps, elem: &Elem) -> Scalar } pub fn scalar_sum(ops: &CommonOps, a: &Scalar, mut b: Scalar) -> Scalar { + let num_limbs = ops.num_limbs.into(); limbs_add_assign_mod( - &mut b.limbs[..ops.num_limbs], - &a.limbs[..ops.num_limbs], - &ops.n.limbs[..ops.num_limbs], + &mut b.limbs[..num_limbs], + &a.limbs[..num_limbs], + &ops.n.limbs[..num_limbs], ); b } @@ -436,13 +441,14 @@ pub fn scalar_parse_big_endian_variable( allow_zero: AllowZero, bytes: untrusted::Input, ) -> Result { + let num_limbs = ops.num_limbs.into(); let n = ops.n.limbs.map(Limb::from); let mut r = Scalar::zero(); parse_big_endian_in_range_and_pad_consttime( bytes, allow_zero, - &n[..ops.num_limbs], - &mut r.limbs[..ops.num_limbs], + &n[..num_limbs], + &mut r.limbs[..num_limbs], )?; Ok(r) } @@ -451,12 +457,13 @@ pub fn scalar_parse_big_endian_partially_reduced_variable_consttime( ops: &CommonOps, bytes: untrusted::Input, ) -> Result { + let num_limbs = ops.num_limbs.into(); let mut r = Scalar::zero(); { - let r = &mut r.limbs[..ops.num_limbs]; + let r = &mut r.limbs[..num_limbs]; parse_big_endian_and_pad_consttime(bytes, r)?; - limbs_reduce_once_constant_time(r, &ops.n.limbs[..ops.num_limbs]); + limbs_reduce_once_constant_time(r, &ops.n.limbs[..num_limbs]); } Ok(r) @@ -466,8 +473,9 @@ fn parse_big_endian_fixed_consttime( ops: &CommonOps, bytes: untrusted::Input, allow_zero: AllowZero, - max_exclusive: &[LeakyLimb; MAX_LIMBS], + max_exclusive: &[LeakyLimb; elem::NumLimbs::MAX], ) -> Result, error::Unspecified> { + let num_limbs = ops.num_limbs.into(); let max_exclusive = max_exclusive.map(Limb::from); if bytes.len() != ops.len() { @@ -477,8 +485,8 @@ fn parse_big_endian_fixed_consttime( parse_big_endian_in_range_and_pad_consttime( bytes, allow_zero, - &max_exclusive[..ops.num_limbs], - &mut r.limbs[..ops.num_limbs], + &max_exclusive[..num_limbs], + &mut r.limbs[..num_limbs], )?; Ok(r) } @@ -491,7 +499,7 @@ mod tests { use alloc::{format, vec, vec::Vec}; const ZERO_SCALAR: Scalar = Scalar { - limbs: [0; MAX_LIMBS], + limbs: [0; elem::NumLimbs::MAX], m: PhantomData, encoding: PhantomData, }; @@ -796,7 +804,7 @@ mod tests { { let mut actual_result: Scalar = Scalar { - limbs: [0; MAX_LIMBS], + limbs: [0; elem::NumLimbs::MAX], m: PhantomData, encoding: PhantomData, }; @@ -1127,7 +1135,7 @@ mod tests { } struct AffinePoint { - xy: [Limb; 2 * MAX_LIMBS], + xy: [Limb; 2 * elem::NumLimbs::MAX], } fn consume_affine_point( @@ -1139,7 +1147,7 @@ mod tests { let elems = input.split(", ").collect::>(); assert_eq!(elems.len(), 2); let mut p = AffinePoint { - xy: [0; 2 * MAX_LIMBS], + xy: [0; 2 * elem::NumLimbs::MAX], }; consume_point_elem(ops.common, &mut p.xy, &elems, 0); consume_point_elem(ops.common, &mut p.xy, &elems, 1); @@ -1147,12 +1155,12 @@ mod tests { } fn consume_point_elem(ops: &CommonOps, limbs_out: &mut [Limb], elems: &[&str], i: usize) { + let num_limbs = ops.num_limbs.into(); let bytes = test::from_hex(elems[i]).unwrap(); let bytes = untrusted::Input::from(&bytes); let r: Elem = elem_parse_big_endian_fixed_consttime(ops, bytes).unwrap(); // XXX: “Transmute” this to `Elem` limbs. - limbs_out[(i * ops.num_limbs)..((i + 1) * ops.num_limbs)] - .copy_from_slice(&r.limbs[..ops.num_limbs]); + limbs_out[(i * num_limbs)..((i + 1) * num_limbs)].copy_from_slice(&r.limbs[..num_limbs]); } enum TestPoint { @@ -1195,17 +1203,18 @@ mod tests { fn assert_limbs_are_equal( ops: &CommonOps, - actual: &[Limb; MAX_LIMBS], - expected: &[Limb; MAX_LIMBS], + actual: &[Limb; elem::NumLimbs::MAX], + expected: &[Limb; elem::NumLimbs::MAX], ) { - if actual[..ops.num_limbs] != expected[..ops.num_limbs] { + let num_limbs = ops.num_limbs.into(); + if actual[..num_limbs] != expected[..num_limbs] { let mut actual_s = alloc::string::String::new(); let mut expected_s = alloc::string::String::new(); - for j in 0..ops.num_limbs { + for j in 0..num_limbs { let width = LIMB_BITS / 4; - let formatted = format!("{:0width$x}", actual[ops.num_limbs - j - 1]); + let formatted = format!("{:0width$x}", actual[num_limbs - j - 1]); actual_s.push_str(&formatted); - let formatted = format!("{:0width$x}", expected[ops.num_limbs - j - 1]); + let formatted = format!("{:0width$x}", expected[num_limbs - j - 1]); expected_s.push_str(&formatted); } panic!( diff --git a/src/ec/suite_b/ops/elem.rs b/src/ec/suite_b/ops/elem.rs index f63a56fc69..25eb36b17c 100644 --- a/src/ec/suite_b/ops/elem.rs +++ b/src/ec/suite_b/ops/elem.rs @@ -12,21 +12,41 @@ // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +use crate::ec::suite_b::ops::{ + p256::NUM_LIMBS as P256_NUM_LIMBS, p384::NUM_LIMBS as P384_NUM_LIMBS, +}; use crate::{ arithmetic::{ limbs_from_hex, montgomery::{Encoding, ProductEncoding, Unencoded}, }, - limb::{LeakyLimb, Limb, LIMB_BITS}, + limb::{LeakyLimb, Limb}, }; use core::marker::PhantomData; +#[derive(Clone, Copy)] +pub(super) enum NumLimbs { + P256, + P384, +} + +impl NumLimbs { + pub(super) const MAX: usize = Self::P384.into(); + + pub(super) const fn into(self) -> usize { + match self { + NumLimbs::P256 => P256_NUM_LIMBS, + NumLimbs::P384 => P384_NUM_LIMBS, + } + } +} + /// Elements of ℤ/mℤ for some modulus *m*. Elements are always fully reduced /// with respect to *m*; i.e. the 0 <= x < m for every value x. #[derive(Clone, Copy)] pub struct Elem { // XXX: pub - pub(super) limbs: [Limb; MAX_LIMBS], + pub(super) limbs: [Limb; NumLimbs::MAX], /// The modulus *m* for the ring ℤ/mℤ for which this element is a value. pub(super) m: PhantomData, @@ -37,7 +57,7 @@ pub struct Elem { } pub struct PublicElem { - pub(super) limbs: [LeakyLimb; MAX_LIMBS], + pub(super) limbs: [LeakyLimb; NumLimbs::MAX], pub(super) m: PhantomData, pub(super) encoding: PhantomData, } @@ -58,7 +78,7 @@ impl Elem { // as inputs for constructing a zero-valued element. pub fn zero() -> Self { Self { - limbs: [0; MAX_LIMBS], + limbs: [0; NumLimbs::MAX], m: PhantomData, encoding: PhantomData, } @@ -103,7 +123,7 @@ pub fn binary_op( b: &Elem, ) -> Elem { let mut r = Elem { - limbs: [0; MAX_LIMBS], + limbs: [0; NumLimbs::MAX], m: PhantomData, encoding: PhantomData, }; @@ -128,7 +148,7 @@ pub fn unary_op( a: &Elem, ) -> Elem { let mut r = Elem { - limbs: [0; MAX_LIMBS], + limbs: [0; NumLimbs::MAX], m: PhantomData, encoding: PhantomData, }; @@ -153,5 +173,3 @@ pub fn unary_op_from_binary_op_assign( ) { unsafe { f(a.limbs.as_mut_ptr(), a.limbs.as_ptr(), a.limbs.as_ptr()) } } - -pub const MAX_LIMBS: usize = (384 + (LIMB_BITS - 1)) / LIMB_BITS; diff --git a/src/ec/suite_b/ops/p256.rs b/src/ec/suite_b/ops/p256.rs index 853a0d6ff9..d586587a47 100644 --- a/src/ec/suite_b/ops/p256.rs +++ b/src/ec/suite_b/ops/p256.rs @@ -17,8 +17,10 @@ use super::{ elem_sqr_mul, elem_sqr_mul_acc, Modulus, *, }; +pub(super) const NUM_LIMBS: usize = 256 / LIMB_BITS; + pub static COMMON_OPS: CommonOps = CommonOps { - num_limbs: 256 / LIMB_BITS, + num_limbs: elem::NumLimbs::P256, q: Modulus { p: limbs_from_hex("ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"), diff --git a/src/ec/suite_b/ops/p384.rs b/src/ec/suite_b/ops/p384.rs index ec77829b73..faef4e07c5 100644 --- a/src/ec/suite_b/ops/p384.rs +++ b/src/ec/suite_b/ops/p384.rs @@ -17,8 +17,10 @@ use super::{ elem_sqr_mul, elem_sqr_mul_acc, Modulus, *, }; +pub(super) const NUM_LIMBS: usize = 384 / LIMB_BITS; + pub static COMMON_OPS: CommonOps = CommonOps { - num_limbs: 384 / LIMB_BITS, + num_limbs: elem::NumLimbs::P384, q: Modulus { p: limbs_from_hex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff"),