From ff9ee0cc1b864df6415def23171f5ca86a60750d Mon Sep 17 00:00:00 2001 From: benr-ml Date: Mon, 4 Mar 2024 19:52:42 +0200 Subject: [PATCH] remove default from, add tests --- fastcrypto/src/groups/bls12381.rs | 81 ++++++++++---------- fastcrypto/src/tests/bls12381_group_tests.rs | 79 +++++++++++++++++-- 2 files changed, 113 insertions(+), 47 deletions(-) diff --git a/fastcrypto/src/groups/bls12381.rs b/fastcrypto/src/groups/bls12381.rs index 705ca9dc50..e7c3b79dba 100644 --- a/fastcrypto/src/groups/bls12381.rs +++ b/fastcrypto/src/groups/bls12381.rs @@ -26,7 +26,6 @@ use blst::{ blst_scalar, blst_scalar_fr_check, blst_scalar_from_be_bytes, blst_scalar_from_bendian, blst_scalar_from_fr, p1_affines, p2_affines, BLS12_381_G1, BLS12_381_G2, BLST_ERROR, }; -use derive_more::From; use fastcrypto_derive::GroupOpsExtend; use hex_literal::hex; use once_cell::sync::OnceCell; @@ -36,22 +35,22 @@ use std::ops::{Add, Div, Mul, Neg, Sub}; use std::ptr; /// Elements of the group G_1 in BLS 12-381. -#[derive(From, Clone, Copy, Eq, PartialEq, GroupOpsExtend)] +#[derive(Clone, Copy, Eq, PartialEq, GroupOpsExtend)] #[repr(transparent)] pub struct G1Element(blst_p1); /// Elements of the group G_2 in BLS 12-381. -#[derive(From, Clone, Copy, Eq, PartialEq, GroupOpsExtend)] +#[derive(Clone, Copy, Eq, PartialEq, GroupOpsExtend)] #[repr(transparent)] pub struct G2Element(blst_p2); /// Elements of the subgroup G_T of F_q^{12} in BLS 12-381. Note that it is written in additive notation here. -#[derive(From, Clone, Copy, Eq, PartialEq, GroupOpsExtend)] +#[derive(Clone, Copy, Eq, PartialEq, GroupOpsExtend)] pub struct GTElement(blst_fp12); /// This represents a scalar modulo r = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 /// which is the order of the groups G1, G2 and GT. Note that r is a 255 bit prime. -#[derive(From, Clone, Copy, Eq, PartialEq, GroupOpsExtend)] +#[derive(Clone, Copy, Eq, PartialEq, GroupOpsExtend)] pub struct Scalar(blst_fr); pub const SCALAR_LENGTH: usize = 32; @@ -68,7 +67,7 @@ impl Add for G1Element { unsafe { blst_p1_add_or_double(&mut ret, &self.0, &rhs.0); } - Self::from(ret) + Self(ret) } } @@ -88,14 +87,14 @@ impl Neg for G1Element { unsafe { blst_p1_cneg(&mut ret, true); } - Self::from(ret) + Self(ret) } } /// The size of this scalar in bytes. fn size_in_bytes(scalar: &blst_scalar) -> usize { let mut i = scalar.b.len(); - assert_eq!(i, 32); + debug_assert_eq!(i, 32); while i != 0 && scalar.b[i - 1] == 0 { i -= 1; } @@ -152,7 +151,7 @@ impl Mul for G1Element { ); } - Self::from(result) + Self(result) } } @@ -174,7 +173,7 @@ impl MultiScalarMul for G1Element { } // The scalar field size is smaller than 2^255, so we need at most 255 bits. let res = points.mult(scalar_bytes.as_slice(), 255); - Ok(Self::from(res)) + Ok(Self(res)) } } @@ -190,7 +189,7 @@ impl GroupElement for G1Element { type ScalarType = Scalar; fn zero() -> Self { - Self::from(blst_p1::default()) + Self(blst_p1::default()) } fn generator() -> Self { @@ -198,7 +197,7 @@ impl GroupElement for G1Element { unsafe { blst_p1_from_affine(&mut ret, &BLS12_381_G1); } - Self::from(ret) + Self(ret) } } @@ -216,7 +215,7 @@ impl Pairing for G1Element { blst_miller_loop(&mut res, &other_affine, &self_affine); blst_final_exp(&mut res, &res); } - ::Output::from(res) + GTElement(res) } } @@ -234,7 +233,7 @@ impl HashToGroupElement for G1Element { 0, ); } - Self::from(res) + Self(res) } } @@ -252,7 +251,7 @@ impl ToFromByteArray for G1Element { return Err(FastCryptoError::InvalidInput); } } - Ok(G1Element::from(ret)) + Ok(G1Element(ret)) } fn to_byte_array(&self) -> [u8; G1_ELEMENT_BYTE_LENGTH] { @@ -282,7 +281,7 @@ impl Add for G2Element { unsafe { blst_p2_add_or_double(&mut ret, &self.0, &rhs.0); } - Self::from(ret) + Self(ret) } } @@ -302,7 +301,7 @@ impl Neg for G2Element { unsafe { blst_p2_cneg(&mut ret, true); } - Self::from(ret) + Self(ret) } } @@ -347,7 +346,7 @@ impl Mul for G2Element { ); } - Self::from(result) + Self(result) } } @@ -369,7 +368,7 @@ impl MultiScalarMul for G2Element { } // The scalar field size is smaller than 2^255, so we need at most 255 bits. let res = points.mult(scalar_bytes.as_slice(), 255); - Ok(Self::from(res)) + Ok(Self(res)) } } @@ -377,7 +376,7 @@ impl GroupElement for G2Element { type ScalarType = Scalar; fn zero() -> Self { - Self::from(blst_p2::default()) + Self(blst_p2::default()) } fn generator() -> Self { @@ -385,7 +384,7 @@ impl GroupElement for G2Element { unsafe { blst_p2_from_affine(&mut ret, &BLS12_381_G2); } - Self::from(ret) + Self(ret) } } @@ -403,7 +402,7 @@ impl HashToGroupElement for G2Element { 0, ); } - Self::from(res) + Self(res) } } @@ -421,7 +420,7 @@ impl ToFromByteArray for G2Element { return Err(FastCryptoError::InvalidInput); } } - Ok(G2Element::from(ret)) + Ok(G2Element(ret)) } fn to_byte_array(&self) -> [u8; G2_ELEMENT_BYTE_LENGTH] { @@ -451,7 +450,7 @@ impl Add for GTElement { unsafe { blst_fp12_mul(&mut ret, &self.0, &rhs.0); } - Self::from(ret) + Self(ret) } } @@ -471,7 +470,7 @@ impl Neg for GTElement { unsafe { blst_fp12_inverse(&mut ret, &self.0); } - Self::from(ret) + Self(ret) } } @@ -513,7 +512,7 @@ impl Mul for GTElement { blst_fr_rshift(&mut n, &n, 1); } y *= x; - Self::from(y) + Self(y) } } } @@ -522,12 +521,12 @@ impl GroupElement for GTElement { type ScalarType = Scalar; fn zero() -> Self { - unsafe { Self::from(*blst_fp12_one()) } + unsafe { Self(*blst_fp12_one()) } } fn generator() -> Self { static G: OnceCell = OnceCell::new(); - Self::from(*G.get_or_init(Self::compute_generator)) + Self(*G.get_or_init(Self::compute_generator)) } } @@ -547,7 +546,7 @@ const P_AS_BYTES: [u8; FP_BYTE_LENGTH] = hex!("1a0111ea397fe69a4b1ba7b6434bacd76 // Note that the serialization below is uncompressed, i.e. it uses 576 bytes. impl ToFromByteArray for GTElement { - fn from_byte_array(bytes: &[u8; GT_ELEMENT_BYTE_LENGTH]) -> Result { + fn from_byte_array(bytes: &[u8; GT_ELEMENT_BYTE_LENGTH]) -> FastCryptoResult { // The following is based on the order from // https://github.com/supranational/blst/blob/b4ebf88014251f1cfefeb6cf1cd4df7c40dc568f/src/fp12_tower.c#L773-L786C2 let mut gt: blst_fp12 = Default::default(); @@ -558,7 +557,7 @@ impl ToFromByteArray for GTElement { let mut fp = blst_fp::default(); let slice = &bytes[current..current + FP_BYTE_LENGTH]; // We compare with P_AS_BYTES to ensure that we process a canonical representation - // which is uses mod p elements. + // which uses mod p elements. if *slice >= P_AS_BYTES[..] { return Err(FastCryptoError::InvalidInput); } @@ -572,7 +571,7 @@ impl ToFromByteArray for GTElement { } match gt.in_group() { - true => Ok(Self::from(gt)), + true => Ok(Self(gt)), false => Err(FastCryptoError::InvalidInput), } } @@ -595,11 +594,11 @@ impl GroupElement for Scalar { type ScalarType = Self; fn zero() -> Self { - Self::from(blst_fr::default()) + Self(blst_fr::default()) } fn generator() -> Self { - Self::from(BLST_FR_ONE) + Self(BLST_FR_ONE) } } @@ -611,7 +610,7 @@ impl Add for Scalar { unsafe { blst_fr_add(&mut ret, &self.0, &rhs.0); } - Self::from(ret) + Self(ret) } } @@ -623,7 +622,7 @@ impl Sub for Scalar { unsafe { blst_fr_sub(&mut ret, &self.0, &rhs.0); } - Self::from(ret) + Self(ret) } } @@ -635,7 +634,7 @@ impl Neg for Scalar { unsafe { blst_fr_cneg(&mut ret, &self.0, true); } - Self::from(ret) + Self(ret) } } @@ -647,7 +646,7 @@ impl Mul for Scalar { unsafe { blst_fr_mul(&mut ret, &self.0, &rhs.0); } - Self::from(ret) + Self(ret) } } @@ -658,7 +657,7 @@ impl From for Scalar { unsafe { blst_fr_from_uint64(&mut ret, buff.as_ptr()); } - Self::from(ret) + Self(ret) } } @@ -687,7 +686,7 @@ impl ScalarType for Scalar { unsafe { blst_fr_inverse(&mut ret, &self.0); } - Ok(Self::from(ret)) + Ok(Self(ret)) } } @@ -705,7 +704,7 @@ pub(crate) fn reduce_mod_uniform_buffer(buffer: &[u8]) -> Scalar { blst_scalar_from_be_bytes(&mut tmp, buffer.as_ptr(), buffer.len()); blst_fr_from_scalar(&mut ret, &tmp); } - Scalar::from(ret) + Scalar(ret) } impl FiatShamirChallenge for Scalar { @@ -725,7 +724,7 @@ impl ToFromByteArray for Scalar { } blst_fr_from_scalar(&mut ret, &scalar); } - Ok(Scalar::from(ret)) + Ok(Scalar(ret)) } fn to_byte_array(&self) -> [u8; SCALAR_LENGTH] { diff --git a/fastcrypto/src/tests/bls12381_group_tests.rs b/fastcrypto/src/tests/bls12381_group_tests.rs index 80b1da115e..361796e92a 100644 --- a/fastcrypto/src/tests/bls12381_group_tests.rs +++ b/fastcrypto/src/tests/bls12381_group_tests.rs @@ -46,11 +46,6 @@ fn test_scalar_arithmetic() { let inv_two = two.inverse().unwrap(); assert_eq!(inv_two * two, one); - // Scalar::from_byte_array should not accept the order. - let order = - hex::decode("73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001").unwrap(); - assert!(Scalar::from_byte_array(<&[u8; 32]>::try_from(order.as_slice()).unwrap()).is_err()); - // Check that u128 is decoded correctly. let x: u128 = 2 << 66; let x_scalar = Scalar::from(x); @@ -82,10 +77,15 @@ fn test_g1_arithmetic() { let p6 = g * Scalar::zero(); assert_eq!(G1Element::zero(), p6); + let sc = Scalar::rand(&mut thread_rng()); + let p7 = g * sc; + assert_eq!(p7 * Scalar::from(1), p7); + assert_ne!(G1Element::zero(), g); assert_eq!(G1Element::zero(), g - g); assert!((G1Element::generator() / Scalar::zero()).is_err()); + assert_eq!((p5 / Scalar::from(5)).unwrap(), g); let identity = G1Element::zero(); assert_eq!(identity, identity - identity); @@ -136,7 +136,12 @@ fn test_g2_arithmetic() { let p6 = g * Scalar::zero(); assert_eq!(G2Element::zero(), p6); + let sc = Scalar::rand(&mut thread_rng()); + let p7 = g * sc; + assert_eq!(p7 * Scalar::from(1), p7); + assert!((G2Element::generator() / Scalar::zero()).is_err()); + assert_eq!((p5 / Scalar::from(5)).unwrap(), g); assert_ne!(G2Element::zero(), g); assert_eq!(G2Element::zero(), g - g); @@ -190,11 +195,16 @@ fn test_gt_arithmetic() { let p6 = g * Scalar::zero(); assert_eq!(GTElement::zero(), p6); + let sc = Scalar::rand(&mut thread_rng()); + let p7 = g * sc; + assert_eq!(p7 * Scalar::from(1), p7); + assert_ne!(GTElement::zero(), g); assert_eq!(GTElement::zero(), g - g); assert_eq!(GTElement::zero(), GTElement::zero() - GTElement::zero()); assert!((GTElement::generator() / Scalar::zero()).is_err()); + assert_eq!((p5 / Scalar::from(5)).unwrap(), g); } #[test] @@ -210,17 +220,37 @@ fn test_pairing_and_hash_to_curve() { let pk2 = G1Element::generator() * sk2; let sig2 = e2 * sk2; assert_eq!(pk2.pairing(&e2), G1Element::generator().pairing(&sig2)); + + assert_eq!( + G1Element::zero().pairing(&G2Element::zero()), + GTElement::zero() + ); + assert_eq!( + G1Element::zero().pairing(&G2Element::generator()), + GTElement::zero() + ); + assert_eq!( + G1Element::generator().pairing(&G2Element::zero()), + GTElement::zero() + ); + + // next should not fail + let _ = G1Element::hash_to_group_element(&[]); + let _ = G2Element::hash_to_group_element(&[]); + let _ = G1Element::hash_to_group_element(&[1]); + let _ = G2Element::hash_to_group_element(&[1]); } #[test] fn test_serde_and_regression() { - let s1 = Scalar::from(1); + let s1 = Scalar::generator(); let g1 = G1Element::generator(); let g2 = G2Element::generator(); let gt = GTElement::generator(); let id1 = G1Element::zero(); let id2 = G2Element::zero(); let id3 = GTElement::zero(); + let id4 = Scalar::zero(); verify_serialization( &s1, @@ -236,6 +266,14 @@ fn test_serde_and_regression() { verify_serialization(&id1, Some(hex::decode("c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000").unwrap().as_slice())); verify_serialization(&id2, Some(hex::decode("c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000").unwrap().as_slice())); verify_serialization(&id3, Some(hex::decode("000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000").unwrap().as_slice())); + verify_serialization( + &id4, + Some( + hex::decode("0000000000000000000000000000000000000000000000000000000000000000") + .unwrap() + .as_slice(), + ), + ); } #[test] @@ -265,6 +303,35 @@ fn test_consistent_bls12381_serialization() { assert_eq!(sig1, sig3); } +#[test] +fn test_serialization_scalar() { + let bytes = [0u8; 32]; + assert_eq!(Scalar::from_byte_array(&bytes).unwrap(), Scalar::zero()); + + // Scalar::from_byte_array should not accept the order or above it. + let order = + hex::decode("73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001").unwrap(); + assert!(Scalar::from_byte_array(<&[u8; 32]>::try_from(order.as_slice()).unwrap()).is_err()); + let order = + hex::decode("73eda753299d9d483339d80809a1d80553bda402fffe5bfeffffffff11000001").unwrap(); + assert!(Scalar::from_byte_array(<&[u8; 32]>::try_from(order.as_slice()).unwrap()).is_err()); + + // Scalar::from_byte_array should accept the order - 1. + let order_minus_one = + hex::decode("73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000000").unwrap(); + assert_eq!( + Scalar::from_byte_array(<&[u8; 32]>::try_from(order_minus_one.as_slice()).unwrap()) + .unwrap(), + Scalar::zero() - Scalar::generator() + ); + + for _ in 0..100 { + let s = Scalar::rand(&mut thread_rng()); + let bytes = s.to_byte_array(); + assert_eq!(s, Scalar::from_byte_array(&bytes).unwrap()); + } +} + #[test] fn test_serialization_g1() { let infinity_bit = 0x40;