Skip to content

Commit

Permalink
remove default from, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
benr-ml committed Mar 4, 2024
1 parent 352ed61 commit ff9ee0c
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 47 deletions.
81 changes: 40 additions & 41 deletions fastcrypto/src/groups/bls12381.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ use blst::{
blst_scalar, blst_scalar_fr_check, blst_scalar_from_be_bytes, blst_scalar_from_bendian,
blst_scalar_from_fr, p1_affines, p2_affines, BLS12_381_G1, BLS12_381_G2, BLST_ERROR,
};
use derive_more::From;
use fastcrypto_derive::GroupOpsExtend;
use hex_literal::hex;
use once_cell::sync::OnceCell;
Expand All @@ -36,22 +35,22 @@ use std::ops::{Add, Div, Mul, Neg, Sub};
use std::ptr;

/// Elements of the group G_1 in BLS 12-381.
#[derive(From, Clone, Copy, Eq, PartialEq, GroupOpsExtend)]
#[derive(Clone, Copy, Eq, PartialEq, GroupOpsExtend)]
#[repr(transparent)]
pub struct G1Element(blst_p1);

/// Elements of the group G_2 in BLS 12-381.
#[derive(From, Clone, Copy, Eq, PartialEq, GroupOpsExtend)]
#[derive(Clone, Copy, Eq, PartialEq, GroupOpsExtend)]
#[repr(transparent)]
pub struct G2Element(blst_p2);

/// Elements of the subgroup G_T of F_q^{12} in BLS 12-381. Note that it is written in additive notation here.
#[derive(From, Clone, Copy, Eq, PartialEq, GroupOpsExtend)]
#[derive(Clone, Copy, Eq, PartialEq, GroupOpsExtend)]
pub struct GTElement(blst_fp12);

/// This represents a scalar modulo r = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
/// which is the order of the groups G1, G2 and GT. Note that r is a 255 bit prime.
#[derive(From, Clone, Copy, Eq, PartialEq, GroupOpsExtend)]
#[derive(Clone, Copy, Eq, PartialEq, GroupOpsExtend)]
pub struct Scalar(blst_fr);

pub const SCALAR_LENGTH: usize = 32;
Expand All @@ -68,7 +67,7 @@ impl Add for G1Element {
unsafe {
blst_p1_add_or_double(&mut ret, &self.0, &rhs.0);
}
Self::from(ret)
Self(ret)
}
}

Expand All @@ -88,14 +87,14 @@ impl Neg for G1Element {
unsafe {
blst_p1_cneg(&mut ret, true);
}
Self::from(ret)
Self(ret)
}
}

/// The size of this scalar in bytes.
fn size_in_bytes(scalar: &blst_scalar) -> usize {
let mut i = scalar.b.len();
assert_eq!(i, 32);
debug_assert_eq!(i, 32);
while i != 0 && scalar.b[i - 1] == 0 {
i -= 1;
}
Expand Down Expand Up @@ -152,7 +151,7 @@ impl Mul<Scalar> for G1Element {
);
}

Self::from(result)
Self(result)
}
}

Expand All @@ -174,7 +173,7 @@ impl MultiScalarMul for G1Element {
}
// The scalar field size is smaller than 2^255, so we need at most 255 bits.
let res = points.mult(scalar_bytes.as_slice(), 255);
Ok(Self::from(res))
Ok(Self(res))
}
}

Expand All @@ -190,15 +189,15 @@ impl GroupElement for G1Element {
type ScalarType = Scalar;

fn zero() -> Self {
Self::from(blst_p1::default())
Self(blst_p1::default())
}

fn generator() -> Self {
let mut ret = blst_p1::default();
unsafe {
blst_p1_from_affine(&mut ret, &BLS12_381_G1);
}
Self::from(ret)
Self(ret)
}
}

Expand All @@ -216,7 +215,7 @@ impl Pairing for G1Element {
blst_miller_loop(&mut res, &other_affine, &self_affine);
blst_final_exp(&mut res, &res);
}
<Self as Pairing>::Output::from(res)
GTElement(res)
}
}

Expand All @@ -234,7 +233,7 @@ impl HashToGroupElement for G1Element {
0,
);
}
Self::from(res)
Self(res)
}
}

Expand All @@ -252,7 +251,7 @@ impl ToFromByteArray<G1_ELEMENT_BYTE_LENGTH> for G1Element {
return Err(FastCryptoError::InvalidInput);
}
}
Ok(G1Element::from(ret))
Ok(G1Element(ret))
}

fn to_byte_array(&self) -> [u8; G1_ELEMENT_BYTE_LENGTH] {
Expand Down Expand Up @@ -282,7 +281,7 @@ impl Add for G2Element {
unsafe {
blst_p2_add_or_double(&mut ret, &self.0, &rhs.0);
}
Self::from(ret)
Self(ret)
}
}

Expand All @@ -302,7 +301,7 @@ impl Neg for G2Element {
unsafe {
blst_p2_cneg(&mut ret, true);
}
Self::from(ret)
Self(ret)
}
}

Expand Down Expand Up @@ -347,7 +346,7 @@ impl Mul<Scalar> for G2Element {
);
}

Self::from(result)
Self(result)
}
}

Expand All @@ -369,23 +368,23 @@ impl MultiScalarMul for G2Element {
}
// The scalar field size is smaller than 2^255, so we need at most 255 bits.
let res = points.mult(scalar_bytes.as_slice(), 255);
Ok(Self::from(res))
Ok(Self(res))
}
}

impl GroupElement for G2Element {
type ScalarType = Scalar;

fn zero() -> Self {
Self::from(blst_p2::default())
Self(blst_p2::default())
}

fn generator() -> Self {
let mut ret = blst_p2::default();
unsafe {
blst_p2_from_affine(&mut ret, &BLS12_381_G2);
}
Self::from(ret)
Self(ret)
}
}

Expand All @@ -403,7 +402,7 @@ impl HashToGroupElement for G2Element {
0,
);
}
Self::from(res)
Self(res)
}
}

Expand All @@ -421,7 +420,7 @@ impl ToFromByteArray<G2_ELEMENT_BYTE_LENGTH> for G2Element {
return Err(FastCryptoError::InvalidInput);
}
}
Ok(G2Element::from(ret))
Ok(G2Element(ret))
}

fn to_byte_array(&self) -> [u8; G2_ELEMENT_BYTE_LENGTH] {
Expand Down Expand Up @@ -451,7 +450,7 @@ impl Add for GTElement {
unsafe {
blst_fp12_mul(&mut ret, &self.0, &rhs.0);
}
Self::from(ret)
Self(ret)
}
}

Expand All @@ -471,7 +470,7 @@ impl Neg for GTElement {
unsafe {
blst_fp12_inverse(&mut ret, &self.0);
}
Self::from(ret)
Self(ret)
}
}

Expand Down Expand Up @@ -513,7 +512,7 @@ impl Mul<Scalar> for GTElement {
blst_fr_rshift(&mut n, &n, 1);
}
y *= x;
Self::from(y)
Self(y)
}
}
}
Expand All @@ -522,12 +521,12 @@ impl GroupElement for GTElement {
type ScalarType = Scalar;

fn zero() -> Self {
unsafe { Self::from(*blst_fp12_one()) }
unsafe { Self(*blst_fp12_one()) }
}

fn generator() -> Self {
static G: OnceCell<blst_fp12> = OnceCell::new();
Self::from(*G.get_or_init(Self::compute_generator))
Self(*G.get_or_init(Self::compute_generator))
}
}

Expand All @@ -547,7 +546,7 @@ const P_AS_BYTES: [u8; FP_BYTE_LENGTH] = hex!("1a0111ea397fe69a4b1ba7b6434bacd76

// Note that the serialization below is uncompressed, i.e. it uses 576 bytes.
impl ToFromByteArray<GT_ELEMENT_BYTE_LENGTH> for GTElement {
fn from_byte_array(bytes: &[u8; GT_ELEMENT_BYTE_LENGTH]) -> Result<Self, FastCryptoError> {
fn from_byte_array(bytes: &[u8; GT_ELEMENT_BYTE_LENGTH]) -> FastCryptoResult<Self> {
// The following is based on the order from
// https://github.com/supranational/blst/blob/b4ebf88014251f1cfefeb6cf1cd4df7c40dc568f/src/fp12_tower.c#L773-L786C2
let mut gt: blst_fp12 = Default::default();
Expand All @@ -558,7 +557,7 @@ impl ToFromByteArray<GT_ELEMENT_BYTE_LENGTH> for GTElement {
let mut fp = blst_fp::default();
let slice = &bytes[current..current + FP_BYTE_LENGTH];
// We compare with P_AS_BYTES to ensure that we process a canonical representation
// which is uses mod p elements.
// which uses mod p elements.
if *slice >= P_AS_BYTES[..] {
return Err(FastCryptoError::InvalidInput);
}
Expand All @@ -572,7 +571,7 @@ impl ToFromByteArray<GT_ELEMENT_BYTE_LENGTH> for GTElement {
}

match gt.in_group() {
true => Ok(Self::from(gt)),
true => Ok(Self(gt)),
false => Err(FastCryptoError::InvalidInput),
}
}
Expand All @@ -595,11 +594,11 @@ impl GroupElement for Scalar {
type ScalarType = Self;

fn zero() -> Self {
Self::from(blst_fr::default())
Self(blst_fr::default())
}

fn generator() -> Self {
Self::from(BLST_FR_ONE)
Self(BLST_FR_ONE)
}
}

Expand All @@ -611,7 +610,7 @@ impl Add for Scalar {
unsafe {
blst_fr_add(&mut ret, &self.0, &rhs.0);
}
Self::from(ret)
Self(ret)
}
}

Expand All @@ -623,7 +622,7 @@ impl Sub for Scalar {
unsafe {
blst_fr_sub(&mut ret, &self.0, &rhs.0);
}
Self::from(ret)
Self(ret)
}
}

Expand All @@ -635,7 +634,7 @@ impl Neg for Scalar {
unsafe {
blst_fr_cneg(&mut ret, &self.0, true);
}
Self::from(ret)
Self(ret)
}
}

Expand All @@ -647,7 +646,7 @@ impl Mul<Scalar> for Scalar {
unsafe {
blst_fr_mul(&mut ret, &self.0, &rhs.0);
}
Self::from(ret)
Self(ret)
}
}

Expand All @@ -658,7 +657,7 @@ impl From<u128> for Scalar {
unsafe {
blst_fr_from_uint64(&mut ret, buff.as_ptr());
}
Self::from(ret)
Self(ret)
}
}

Expand Down Expand Up @@ -687,7 +686,7 @@ impl ScalarType for Scalar {
unsafe {
blst_fr_inverse(&mut ret, &self.0);
}
Ok(Self::from(ret))
Ok(Self(ret))
}
}

Expand All @@ -705,7 +704,7 @@ pub(crate) fn reduce_mod_uniform_buffer(buffer: &[u8]) -> Scalar {
blst_scalar_from_be_bytes(&mut tmp, buffer.as_ptr(), buffer.len());
blst_fr_from_scalar(&mut ret, &tmp);
}
Scalar::from(ret)
Scalar(ret)
}

impl FiatShamirChallenge for Scalar {
Expand All @@ -725,7 +724,7 @@ impl ToFromByteArray<SCALAR_LENGTH> for Scalar {
}
blst_fr_from_scalar(&mut ret, &scalar);
}
Ok(Scalar::from(ret))
Ok(Scalar(ret))
}

fn to_byte_array(&self) -> [u8; SCALAR_LENGTH] {
Expand Down
Loading

0 comments on commit ff9ee0c

Please sign in to comment.