Skip to content

Commit

Permalink
ec/suite_b: Optimize away slice bounds checks.
Browse files Browse the repository at this point in the history
Help the compiler see that COMMON_OPS.num_limbs, which is used in all
the slicing, never exceeds `elem::NumLimbs::MAX` (the number of limbs each
array is sized for), so no bounds checks need to be emitted.
  • Loading branch information
briansmith committed Dec 9, 2024
1 parent 1e539cd commit cf510c4
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 55 deletions.
4 changes: 3 additions & 1 deletion mk/generate_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@
elem_sqr_mul, elem_sqr_mul_acc, Modulus, *,
};
pub(super) const NUM_LIMBS: usize = (%(bits)d + LIMB_BITS - 1) / LIMB_BITS;
pub static COMMON_OPS: CommonOps = CommonOps {
num_limbs: (%(bits)d + LIMB_BITS - 1) / LIMB_BITS,
num_limbs: elem::NumLimbs::P%(bits)s,
order_bits: %(bits)d,
q: Modulus {
Expand Down
97 changes: 53 additions & 44 deletions src/ec/suite_b/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
};
use core::marker::PhantomData;

pub use self::elem::*;
use elem::{mul_mont, unary_op, unary_op_assign, unary_op_from_binary_op_assign};

/// A field element, i.e. an element of ℤ/qℤ for the curve's field modulus
/// *q*.
Expand All @@ -44,20 +44,20 @@ pub struct Point {
// `ops.num_limbs` elements are the Y coordinate, and the next
// `ops.num_limbs` elements are the Z coordinate. This layout is dictated
// by the requirements of the nistz256 code.
xyz: [Limb; 3 * MAX_LIMBS],
xyz: [Limb; 3 * elem::NumLimbs::MAX],
}

impl Point {
pub fn new_at_infinity() -> Self {
Self {
xyz: [0; 3 * MAX_LIMBS],
xyz: [0; 3 * elem::NumLimbs::MAX],
}
}
}

/// Operations and values needed by all curve operations.
pub struct CommonOps {
num_limbs: usize,
num_limbs: elem::NumLimbs,
q: Modulus,
n: PublicElem<Unencoded>,

Expand All @@ -75,17 +75,17 @@ impl CommonOps {
// The length of a field element, which is the same as the length of a
// scalar, in bytes.
pub fn len(&self) -> usize {
self.num_limbs * LIMB_BYTES
self.num_limbs.into() * LIMB_BYTES
}

#[cfg(test)]
pub(super) fn n_limbs(&self) -> &[Limb] {
&self.n.limbs[..self.num_limbs]
&self.n.limbs[..self.num_limbs.into()]
}

#[inline]
pub fn elem_add<E: Encoding>(&self, a: &mut Elem<E>, b: &Elem<E>) {
let num_limbs = self.num_limbs;
let num_limbs = self.num_limbs.into();
limbs_add_assign_mod(
&mut a.limbs[..num_limbs],
&b.limbs[..num_limbs],
Expand All @@ -95,7 +95,8 @@ impl CommonOps {

#[inline]
pub fn elems_are_equal(&self, a: &Elem<R>, b: &Elem<R>) -> LimbMask {
limbs_equal_limbs_consttime(&a.limbs[..self.num_limbs], &b.limbs[..self.num_limbs])
let num_limbs = self.num_limbs.into();
limbs_equal_limbs_consttime(&a.limbs[..num_limbs], &b.limbs[..num_limbs])
}

#[inline]
Expand All @@ -105,7 +106,7 @@ impl CommonOps {

#[inline]
pub fn elem_mul(&self, a: &mut Elem<R>, b: &Elem<R>) {
binary_op_assign(self.elem_mul_mont, a, b)
elem::binary_op_assign(self.elem_mul_mont, a, b)
}

#[inline]
Expand All @@ -132,7 +133,8 @@ impl CommonOps {

#[inline]
pub fn is_zero<M, E: Encoding>(&self, a: &elem::Elem<M, E>) -> bool {
limbs_are_zero_constant_time(&a.limbs[..self.num_limbs]).leak()
let num_limbs = self.num_limbs.into();
limbs_are_zero_constant_time(&a.limbs[..num_limbs]).leak()
}

pub fn elem_verify_is_not_zero(&self, a: &Elem<R>) -> Result<(), error::Unspecified> {
Expand All @@ -152,28 +154,30 @@ impl CommonOps {
}

pub fn point_x(&self, p: &Point) -> Elem<R> {
let num_limbs = self.num_limbs.into();
let mut r = Elem::zero();
r.limbs[..self.num_limbs].copy_from_slice(&p.xyz[0..self.num_limbs]);
r.limbs[..num_limbs].copy_from_slice(&p.xyz[0..num_limbs]);
r
}

pub fn point_y(&self, p: &Point) -> Elem<R> {
let num_limbs = self.num_limbs.into();
let mut r = Elem::zero();
r.limbs[..self.num_limbs].copy_from_slice(&p.xyz[self.num_limbs..(2 * self.num_limbs)]);
r.limbs[..num_limbs].copy_from_slice(&p.xyz[num_limbs..(2 * num_limbs)]);
r
}

pub fn point_z(&self, p: &Point) -> Elem<R> {
let num_limbs = self.num_limbs.into();
let mut r = Elem::zero();
r.limbs[..self.num_limbs]
.copy_from_slice(&p.xyz[(2 * self.num_limbs)..(3 * self.num_limbs)]);
r.limbs[..num_limbs].copy_from_slice(&p.xyz[(2 * num_limbs)..(3 * num_limbs)]);
r
}
}

struct Modulus {
p: [LeakyLimb; MAX_LIMBS],
rr: [LeakyLimb; MAX_LIMBS],
p: [LeakyLimb; elem::NumLimbs::MAX],
rr: [LeakyLimb; elem::NumLimbs::MAX],
}

/// Operations on private keys, for ECDH and ECDSA signing.
Expand All @@ -191,7 +195,7 @@ pub struct PrivateKeyOps {

impl PrivateKeyOps {
pub fn leak_limbs<'a>(&self, a: &'a Elem<Unencoded>) -> &'a [Limb] {
&a.limbs[..self.common.num_limbs]
&a.limbs[..self.common.num_limbs.into()]
}

#[inline(always)]
Expand Down Expand Up @@ -273,7 +277,7 @@ impl ScalarOps {
}

pub fn leak_limbs<'s>(&self, s: &'s Scalar) -> &'s [Limb] {
&s.limbs[..self.common.num_limbs]
&s.limbs[..self.common.num_limbs.into()]
}

#[inline]
Expand Down Expand Up @@ -320,12 +324,12 @@ impl PublicScalarOps {
}

pub fn elem_equals_vartime(&self, a: &Elem<Unencoded>, b: &Elem<Unencoded>) -> bool {
a.limbs[..self.public_key_ops.common.num_limbs]
== b.limbs[..self.public_key_ops.common.num_limbs]
let num_limbs = self.public_key_ops.common.num_limbs.into();
a.limbs[..num_limbs] == b.limbs[..num_limbs]
}

pub fn elem_less_than(&self, a: &Elem<Unencoded>, b: &PublicElem<Unencoded>) -> bool {
let num_limbs = self.public_key_ops.common.num_limbs;
let num_limbs = self.public_key_ops.common.num_limbs.into();
limbs_less_than_limbs_vartime(&a.limbs[..num_limbs], &b.limbs[..num_limbs])
}

Expand Down Expand Up @@ -376,7 +380,7 @@ fn twin_mul_inefficient(

// This assumes n < q < 2*n.
pub fn elem_reduced_to_scalar(ops: &CommonOps, elem: &Elem<Unencoded>) -> Scalar<Unencoded> {
let num_limbs = ops.num_limbs;
let num_limbs = ops.num_limbs.into();
let mut r_limbs = elem.limbs;
limbs_reduce_once_constant_time(&mut r_limbs[..num_limbs], &ops.n.limbs[..num_limbs]);
Scalar {
Expand All @@ -387,10 +391,11 @@ pub fn elem_reduced_to_scalar(ops: &CommonOps, elem: &Elem<Unencoded>) -> Scalar
}

pub fn scalar_sum(ops: &CommonOps, a: &Scalar, mut b: Scalar) -> Scalar {
let num_limbs = ops.num_limbs.into();
limbs_add_assign_mod(
&mut b.limbs[..ops.num_limbs],
&a.limbs[..ops.num_limbs],
&ops.n.limbs[..ops.num_limbs],
&mut b.limbs[..num_limbs],
&a.limbs[..num_limbs],
&ops.n.limbs[..num_limbs],
);
b
}
Expand Down Expand Up @@ -436,13 +441,14 @@ pub fn scalar_parse_big_endian_variable(
allow_zero: AllowZero,
bytes: untrusted::Input,
) -> Result<Scalar, error::Unspecified> {
let num_limbs = ops.num_limbs.into();
let n = ops.n.limbs.map(Limb::from);
let mut r = Scalar::zero();
parse_big_endian_in_range_and_pad_consttime(
bytes,
allow_zero,
&n[..ops.num_limbs],
&mut r.limbs[..ops.num_limbs],
&n[..num_limbs],
&mut r.limbs[..num_limbs],
)?;
Ok(r)
}
Expand All @@ -451,12 +457,13 @@ pub fn scalar_parse_big_endian_partially_reduced_variable_consttime(
ops: &CommonOps,
bytes: untrusted::Input,
) -> Result<Scalar, error::Unspecified> {
let num_limbs = ops.num_limbs.into();
let mut r = Scalar::zero();

{
let r = &mut r.limbs[..ops.num_limbs];
let r = &mut r.limbs[..num_limbs];
parse_big_endian_and_pad_consttime(bytes, r)?;
limbs_reduce_once_constant_time(r, &ops.n.limbs[..ops.num_limbs]);
limbs_reduce_once_constant_time(r, &ops.n.limbs[..num_limbs]);
}

Ok(r)
Expand All @@ -466,8 +473,9 @@ fn parse_big_endian_fixed_consttime<M>(
ops: &CommonOps,
bytes: untrusted::Input,
allow_zero: AllowZero,
max_exclusive: &[LeakyLimb; MAX_LIMBS],
max_exclusive: &[LeakyLimb; elem::NumLimbs::MAX],
) -> Result<elem::Elem<M, Unencoded>, error::Unspecified> {
let num_limbs = ops.num_limbs.into();
let max_exclusive = max_exclusive.map(Limb::from);

if bytes.len() != ops.len() {
Expand All @@ -477,8 +485,8 @@ fn parse_big_endian_fixed_consttime<M>(
parse_big_endian_in_range_and_pad_consttime(
bytes,
allow_zero,
&max_exclusive[..ops.num_limbs],
&mut r.limbs[..ops.num_limbs],
&max_exclusive[..num_limbs],
&mut r.limbs[..num_limbs],
)?;
Ok(r)
}
Expand All @@ -491,7 +499,7 @@ mod tests {
use alloc::{format, vec, vec::Vec};

const ZERO_SCALAR: Scalar = Scalar {
limbs: [0; MAX_LIMBS],
limbs: [0; elem::NumLimbs::MAX],
m: PhantomData,
encoding: PhantomData,
};
Expand Down Expand Up @@ -796,7 +804,7 @@ mod tests {

{
let mut actual_result: Scalar<R> = Scalar {
limbs: [0; MAX_LIMBS],
limbs: [0; elem::NumLimbs::MAX],
m: PhantomData,
encoding: PhantomData,
};
Expand Down Expand Up @@ -1127,7 +1135,7 @@ mod tests {
}

struct AffinePoint {
xy: [Limb; 2 * MAX_LIMBS],
xy: [Limb; 2 * elem::NumLimbs::MAX],
}

fn consume_affine_point(
Expand All @@ -1139,20 +1147,20 @@ mod tests {
let elems = input.split(", ").collect::<Vec<&str>>();
assert_eq!(elems.len(), 2);
let mut p = AffinePoint {
xy: [0; 2 * MAX_LIMBS],
xy: [0; 2 * elem::NumLimbs::MAX],
};
consume_point_elem(ops.common, &mut p.xy, &elems, 0);
consume_point_elem(ops.common, &mut p.xy, &elems, 1);
p
}

fn consume_point_elem(ops: &CommonOps, limbs_out: &mut [Limb], elems: &[&str], i: usize) {
let num_limbs = ops.num_limbs.into();
let bytes = test::from_hex(elems[i]).unwrap();
let bytes = untrusted::Input::from(&bytes);
let r: Elem<Unencoded> = elem_parse_big_endian_fixed_consttime(ops, bytes).unwrap();
// XXX: “Transmute” this to `Elem<R>` limbs.
limbs_out[(i * ops.num_limbs)..((i + 1) * ops.num_limbs)]
.copy_from_slice(&r.limbs[..ops.num_limbs]);
limbs_out[(i * num_limbs)..((i + 1) * num_limbs)].copy_from_slice(&r.limbs[..num_limbs]);
}

enum TestPoint<E: Encoding> {
Expand Down Expand Up @@ -1195,17 +1203,18 @@ mod tests {

fn assert_limbs_are_equal(
ops: &CommonOps,
actual: &[Limb; MAX_LIMBS],
expected: &[Limb; MAX_LIMBS],
actual: &[Limb; elem::NumLimbs::MAX],
expected: &[Limb; elem::NumLimbs::MAX],
) {
if actual[..ops.num_limbs] != expected[..ops.num_limbs] {
let num_limbs = ops.num_limbs.into();
if actual[..num_limbs] != expected[..num_limbs] {
let mut actual_s = alloc::string::String::new();
let mut expected_s = alloc::string::String::new();
for j in 0..ops.num_limbs {
for j in 0..num_limbs {
let width = LIMB_BITS / 4;
let formatted = format!("{:0width$x}", actual[ops.num_limbs - j - 1]);
let formatted = format!("{:0width$x}", actual[num_limbs - j - 1]);
actual_s.push_str(&formatted);
let formatted = format!("{:0width$x}", expected[ops.num_limbs - j - 1]);
let formatted = format!("{:0width$x}", expected[num_limbs - j - 1]);
expected_s.push_str(&formatted);
}
panic!(
Expand Down
Loading

0 comments on commit cf510c4

Please sign in to comment.