diff --git a/Cargo.toml b/Cargo.toml index be7fb3a7..e659cb48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,10 +19,11 @@ digest = "0.8.1" ff-zeroize = "0.6.3" funty = "=1.1.0" generic-array = "0.14" -hex = "0.4" +hex = { version = "0.4", features = ["serde"] } hmac = "0.7.1" +thiserror = "1" merkle-sha3 = "^0.1" -lazy_static = "1.4.0" +lazy_static = "1.4" num-traits = "0.2" num-integer = "0.1" pairing-plus = "0.19" @@ -33,7 +34,6 @@ serde = { version = "1.0", features = ["derive"] } serde_derive = "1.0" sha2 = "0.8.0" sha3 = "0.8.2" -thiserror = "1" zeroize = "1" rust-gmp-kzen = { version = "0.5", features = ["serde_support"], optional = true } diff --git a/src/elliptic/curves/mod.rs b/src/elliptic/curves/mod.rs index cc22a67f..2b3d2375 100644 --- a/src/elliptic/curves/mod.rs +++ b/src/elliptic/curves/mod.rs @@ -1,6 +1,11 @@ -pub mod bls12_381; -pub mod curve_ristretto; -pub mod ed25519; -pub mod p256; +// pub mod bls12_381; +// pub mod curve_ristretto; +// pub mod ed25519; +// pub mod p256; pub mod secp256_k1; -pub mod traits; + +mod traits; +mod wrappers; + +pub use self::secp256_k1::Secp256k1; +pub use self::{traits::*, wrappers::*}; diff --git a/src/elliptic/curves/secp256_k1.rs b/src/elliptic/curves/secp256_k1.rs index 41f79c25..58ecbd0c 100644 --- a/src/elliptic/curves/secp256_k1.rs +++ b/src/elliptic/curves/secp256_k1.rs @@ -16,31 +16,51 @@ // The Public Key codec: Point <> SecretKey // -use super::traits::{ECPoint, ECScalar}; -use crate::arithmetic::traits::*; -use crate::BigInt; -use crate::ErrorKey; - -#[cfg(feature = "merkle")] -use crypto::digest::Digest; -#[cfg(feature = "merkle")] -use crypto::sha3::Sha3; -#[cfg(feature = "merkle")] -use merkle::Hashable; +use std::ops::{self, Deref}; +use std::ptr; +use std::sync::atomic; + use rand::thread_rng; use secp256k1::constants::{ - CURVE_ORDER, GENERATOR_X, GENERATOR_Y, SECRET_KEY_SIZE, UNCOMPRESSED_PUBLIC_KEY_SIZE, + self, GENERATOR_X, GENERATOR_Y, SECRET_KEY_SIZE, UNCOMPRESSED_PUBLIC_KEY_SIZE, }; -use secp256k1::{PublicKey, Secp256k1, SecretKey, VerifyOnly}; -use serde::de::{self, Error, MapAccess, SeqAccess, Visitor}; -use serde::ser::SerializeStruct; -use serde::ser::{Serialize, Serializer}; -use serde::{Deserialize, Deserializer}; -use std::fmt; -use std::ops::{Add, Mul}; -use std::ptr; -use std::sync::{atomic, Once}; -use zeroize::Zeroize; +use secp256k1::{PublicKey, SecretKey}; +use zeroize::{Zeroize, Zeroizing}; + +use crate::arithmetic::*; + +use super::traits::*; + +lazy_static::lazy_static! { + static ref CONTEXT: secp256k1::Secp256k1 = secp256k1::Secp256k1::verification_only(); + + static ref CURVE_ORDER: BigInt = BigInt::from_bytes(&constants::CURVE_ORDER); + + static ref GENERATOR_UNCOMRESSED: Vec = { + let mut g = vec![4_u8]; + g.extend_from_slice(&GENERATOR_X); + g.extend_from_slice(&GENERATOR_Y); + g + }; + + static ref BASE_POINT2_UNCOMPRESSED: Vec = { + let mut g = vec![4_u8]; + g.extend_from_slice(BASE_POINT2_X.as_ref()); + g.extend_from_slice(BASE_POINT2_Y.as_ref()); + g + }; + + static ref GENERATOR: Secp256k1Point = Secp256k1Point { + purpose: "generator", + ge: Some(PK(PublicKey::from_slice(&GENERATOR_UNCOMRESSED).unwrap())), + }; + + static ref BASE_POINT2: Secp256k1Point = Secp256k1Point { + purpose: "base_point2", + ge: Some(PK(PublicKey::from_slice(&BASE_POINT2_UNCOMPRESSED).unwrap())), + }; +} + /* X coordinate of a point of unknown discrete logarithm. Computed using a deterministic algorithm with the generator as input. See test_base_point2 */ @@ -54,415 +74,251 @@ const BASE_POINT2_Y: [u8; 32] = [ 0x80, 0x7b, 0xcb, 0xa1, 0xdf, 0x0d, 0xf0, 0x7a, 0x82, 0x17, 0xe9, 0xf7, 0xf7, 0xc2, 0xbe, 0x88, ]; -pub type SK = SecretKey; -pub type PK = PublicKey; +/// SK wraps secp256k1::SecretKey and implements Zeroize to it +#[derive(Clone, PartialEq, Debug)] +pub struct SK(pub SecretKey); +/// PK wraps secp256k1::PublicKey and implements Zeroize to it +#[derive(Copy, Clone, PartialEq, Debug)] +pub struct PK(pub PublicKey); -#[derive(Clone, Debug, Copy)] -pub struct Secp256k1Scalar { - purpose: &'static str, - fe: SK, +impl ops::Deref for SK { + type Target = SecretKey; + + fn deref(&self) -> &Self::Target { + &self.0 + } } -#[derive(Clone, Debug, Copy)] -pub struct Secp256k1Point { - purpose: &'static str, - ge: PK, +impl ops::DerefMut for SK { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } -pub type GE = Secp256k1Point; -pub type FE = Secp256k1Scalar; - -impl Secp256k1Point { - pub fn random_point() -> Secp256k1Point { - let random_scalar: Secp256k1Scalar = Secp256k1Scalar::new_random(); - let base_point = Secp256k1Point::generator(); - let pk = base_point.scalar_mul(&random_scalar.get_element()); - Secp256k1Point { - purpose: "random_point", - ge: pk.get_element(), - } + +impl ops::Deref for PK { + type Target = PublicKey; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl ops::DerefMut for PK { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } -impl Zeroize for Secp256k1Scalar { +impl Zeroize for SK { fn zeroize(&mut self) { - unsafe { ptr::write_volatile(self, FE::zero()) }; + let sk = self.0.as_mut_ptr(); + let sk_bytes = unsafe { std::slice::from_raw_parts_mut(sk, 32) }; + sk_bytes.zeroize() + } +} + +impl Zeroize for PK { + fn zeroize(&mut self) { + let zeroed = unsafe { secp256k1::ffi::PublicKey::new() }; + unsafe { ptr::write_volatile(self.0.as_mut_ptr(), zeroed) }; atomic::fence(atomic::Ordering::SeqCst); atomic::compiler_fence(atomic::Ordering::SeqCst); } } +#[derive(Clone, Debug, PartialEq)] +pub enum Secp256k1 {} + +impl Curve for Secp256k1 { + type Point = GE; + type Scalar = FE; + + fn curve_name() -> &'static str { + "secp256k1" + } +} + +#[derive(Clone, Debug)] +pub struct Secp256k1Scalar { + purpose: &'static str, + /// Zeroizing wraps SK and zeroize it on drop + /// + /// `fe` might be None — special case for scalar being zero + fe: zeroize::Zeroizing>, +} +#[derive(Clone, Debug, Copy)] +pub struct Secp256k1Point { + purpose: &'static str, + ge: Option, +} + +type GE = Secp256k1Point; +type FE = Secp256k1Scalar; + impl ECScalar for Secp256k1Scalar { - type SecretKey = SK; + type Underlying = Option; - fn new_random() -> Secp256k1Scalar { + fn random() -> Secp256k1Scalar { + let sk = SK(SecretKey::new(&mut thread_rng())); Secp256k1Scalar { purpose: "random", - fe: SecretKey::new(&mut thread_rng()), + fe: Zeroizing::new(Some(sk)), } } fn zero() -> Secp256k1Scalar { - let zero_arr = [0u8; 32]; - let zero = unsafe { std::mem::transmute::<[u8; 32], SecretKey>(zero_arr) }; Secp256k1Scalar { purpose: "zero", - fe: zero, + fe: Zeroizing::new(None), } } - fn get_element(&self) -> SK { - self.fe + fn is_zero(&self) -> bool { + self.fe.is_none() } - fn set_element(&mut self, element: SK) { - self.fe = element - } + fn from_bigint(n: &BigInt) -> Secp256k1Scalar { + if n.is_zero() { + return Self::zero(); + } - fn from(n: &BigInt) -> Secp256k1Scalar { - let curve_order = FE::q(); - let n_reduced = BigInt::mod_add(n, &BigInt::from(0), &curve_order); - let mut v = BigInt::to_bytes(&n_reduced); + let curve_order = Self::curve_order(); + let n_reduced = n.modulus(curve_order); + let bytes = BigInt::to_bytes(&n_reduced); - if v.len() < SECRET_KEY_SIZE { - let mut template = vec![0; SECRET_KEY_SIZE - v.len()]; - template.extend_from_slice(&v); - v = template; - } + let bytes = if bytes.len() < SECRET_KEY_SIZE { + let mut zero_prepended = vec![0; SECRET_KEY_SIZE - bytes.len()]; + zero_prepended.extend_from_slice(&bytes); + zero_prepended + } else { + bytes + }; Secp256k1Scalar { - purpose: "from_big_int", - fe: SK::from_slice(&v).unwrap(), + purpose: "from_bigint", + fe: Zeroizing::new(SecretKey::from_slice(&bytes).map(SK).ok()), } } - fn to_big_int(&self) -> BigInt { - BigInt::from_bytes(&(self.fe[0..self.fe.len()])) - } - - fn q() -> BigInt { - BigInt::from_bytes(CURVE_ORDER.as_ref()) + fn to_bigint(&self) -> BigInt { + match self.fe.deref() { + Some(sk) => BigInt::from_bytes(&sk[..]), + None => BigInt::zero(), + } } - fn add(&self, other: &SK) -> Secp256k1Scalar { - let mut other_scalar: FE = ECScalar::new_random(); - other_scalar.set_element(*other); - let res: FE = ECScalar::from(&BigInt::mod_add( - &self.to_big_int(), - &other_scalar.to_big_int(), - &FE::q(), - )); + fn add(&self, other: &Self) -> Secp256k1Scalar { + // TODO: use add_assign? + // https://docs.rs/secp256k1/0.20.3/secp256k1/key/struct.SecretKey.html#method.add_assign + let n = BigInt::mod_add(&self.to_bigint(), &other.to_bigint(), Self::curve_order()); Secp256k1Scalar { purpose: "add", - fe: res.get_element(), + fe: Self::from_bigint(&n).fe, } } - fn mul(&self, other: &SK) -> Secp256k1Scalar { - let mut other_scalar: FE = ECScalar::new_random(); - other_scalar.set_element(*other); - let res: FE = ECScalar::from(&BigInt::mod_mul( - &self.to_big_int(), - &other_scalar.to_big_int(), - &FE::q(), - )); + fn mul(&self, other: &Self) -> Secp256k1Scalar { + // TODO: use mul_assign? + // https://docs.rs/secp256k1/0.20.3/secp256k1/key/struct.SecretKey.html#method.mul_assign + let n = BigInt::mod_mul(&self.to_bigint(), &other.to_bigint(), Self::curve_order()); Secp256k1Scalar { purpose: "mul", - fe: res.get_element(), + fe: Self::from_bigint(&n).fe, } } - fn sub(&self, other: &SK) -> Secp256k1Scalar { - let mut other_scalar: FE = ECScalar::new_random(); - other_scalar.set_element(*other); - let res: FE = ECScalar::from(&BigInt::mod_sub( - &self.to_big_int(), - &other_scalar.to_big_int(), - &FE::q(), - )); + fn sub(&self, other: &Self) -> Secp256k1Scalar { + // TODO: use negate+add_assign? + // https://docs.rs/secp256k1/0.20.3/secp256k1/key/struct.SecretKey.html#method.negate_assign + // https://docs.rs/secp256k1/0.20.3/secp256k1/key/struct.SecretKey.html#method.add_assign + let n = BigInt::mod_sub(&self.to_bigint(), &other.to_bigint(), Self::curve_order()); Secp256k1Scalar { purpose: "sub", - fe: res.get_element(), + fe: Self::from_bigint(&n).fe, } } - fn invert(&self) -> Secp256k1Scalar { - let bignum = self.to_big_int(); - let bn_inv = BigInt::mod_inv(&bignum, &FE::q()).unwrap(); - ECScalar::from(&bn_inv) - } -} -impl Mul for Secp256k1Scalar { - type Output = Secp256k1Scalar; - fn mul(self, other: Secp256k1Scalar) -> Secp256k1Scalar { - (&self).mul(&other.get_element()) - } -} - -impl<'o> Mul<&'o Secp256k1Scalar> for Secp256k1Scalar { - type Output = Secp256k1Scalar; - fn mul(self, other: &'o Secp256k1Scalar) -> Secp256k1Scalar { - (&self).mul(&other.get_element()) - } -} - -impl Add for Secp256k1Scalar { - type Output = Secp256k1Scalar; - fn add(self, other: Secp256k1Scalar) -> Secp256k1Scalar { - (&self).add(&other.get_element()) + fn neg(&self) -> Self { + let n = BigInt::mod_sub(&BigInt::zero(), &self.to_bigint(), Self::curve_order()); + Secp256k1Scalar { + purpose: "neg", + fe: Self::from_bigint(&n).fe, + } } -} -impl<'o> Add<&'o Secp256k1Scalar> for Secp256k1Scalar { - type Output = Secp256k1Scalar; - fn add(self, other: &'o Secp256k1Scalar) -> Secp256k1Scalar { - (&self).add(&other.get_element()) + fn invert(&self) -> Option { + let n = self.to_bigint(); + let n_inv = BigInt::mod_inv(&n, Self::curve_order()); + n_inv.map(|i| Secp256k1Scalar { + purpose: "invert", + fe: Self::from_bigint(&i).fe, + }) } -} -impl Serialize for Secp256k1Scalar { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_str(&self.to_big_int().to_hex()) + fn curve_order() -> &'static BigInt { + &CURVE_ORDER } -} -impl<'de> Deserialize<'de> for Secp256k1Scalar { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - deserializer.deserialize_str(Secp256k1ScalarVisitor) + fn underlying_ref(&self) -> &Self::Underlying { + &self.fe } -} - -struct Secp256k1ScalarVisitor; -impl<'de> Visitor<'de> for Secp256k1ScalarVisitor { - type Value = Secp256k1Scalar; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("Secp256k1Scalar") + fn underlying_mut(&mut self) -> &mut Self::Underlying { + &mut self.fe } - fn visit_str(self, s: &str) -> Result { - let v = BigInt::from_hex(s).map_err(E::custom)?; - Ok(ECScalar::from(&v)) + fn from_underlying(u: Self::Underlying) -> Secp256k1Scalar { + Secp256k1Scalar { + purpose: "from_underlying", + fe: Zeroizing::new(u), + } } } impl PartialEq for Secp256k1Scalar { fn eq(&self, other: &Secp256k1Scalar) -> bool { - self.get_element() == other.get_element() - } -} - -impl PartialEq for Secp256k1Point { - fn eq(&self, other: &Secp256k1Point) -> bool { - self.get_element() == other.get_element() - } -} - -impl Zeroize for Secp256k1Point { - fn zeroize(&mut self) { - unsafe { ptr::write_volatile(self, GE::generator()) }; - atomic::fence(atomic::Ordering::SeqCst); - atomic::compiler_fence(atomic::Ordering::SeqCst); + self.underlying_ref() == other.underlying_ref() } } impl ECPoint for Secp256k1Point { - type SecretKey = SK; - type PublicKey = PK; type Scalar = Secp256k1Scalar; + type Underlying = Option; - fn base_point2() -> Secp256k1Point { - let mut v = vec![4_u8]; - v.extend(BASE_POINT2_X.as_ref()); - v.extend(BASE_POINT2_Y.as_ref()); - Secp256k1Point { - purpose: "random", - ge: PK::from_slice(&v).unwrap(), - } - } - - fn generator() -> Secp256k1Point { - let mut v = vec![4_u8]; - v.extend(GENERATOR_X.as_ref()); - v.extend(GENERATOR_Y.as_ref()); + fn zero() -> Secp256k1Point { Secp256k1Point { - purpose: "base_fe", - ge: PK::from_slice(&v).unwrap(), - } - } - - fn get_element(&self) -> PK { - self.ge - } - - /// to return from BigInt to PK use from_bytes: - /// 1) convert BigInt::to_vec - /// 2) remove first byte [1..33] - /// 3) call from_bytes - fn bytes_compressed_to_big_int(&self) -> BigInt { - let serial = self.ge.serialize(); - BigInt::from_bytes(&serial[0..33]) - } - - fn x_coor(&self) -> Option { - let serialized_pk = PK::serialize_uncompressed(&self.ge); - let x = &serialized_pk[1..serialized_pk.len() / 2 + 1]; - let x_vec = x.to_vec(); - Some(BigInt::from_bytes(&x_vec[..])) - } - - fn y_coor(&self) -> Option { - let serialized_pk = PK::serialize_uncompressed(&self.ge); - let y = &serialized_pk[(serialized_pk.len() - 1) / 2 + 1..serialized_pk.len()]; - let y_vec = y.to_vec(); - Some(BigInt::from_bytes(&y_vec[..])) - } - - fn from_bytes(bytes: &[u8]) -> Result { - let bytes_vec = bytes.to_vec(); - let mut bytes_array_65 = [0u8; 65]; - let mut bytes_array_33 = [0u8; 33]; - - let byte_len = bytes_vec.len(); - match byte_len { - 33..=63 => { - let mut template = vec![0; 64 - bytes_vec.len()]; - template.extend_from_slice(&bytes); - let mut bytes_vec = template; - let mut template: Vec = vec![4]; - template.append(&mut bytes_vec); - let bytes_slice = &template[..]; - - bytes_array_65.copy_from_slice(&bytes_slice[0..65]); - let result = PK::from_slice(&bytes_array_65); - let test = result.map(|pk| Secp256k1Point { - purpose: "random", - ge: pk, - }); - test.map_err(|_err| ErrorKey::InvalidPublicKey) - } - - 0..=32 => { - let mut template = vec![0; 32 - bytes_vec.len()]; - template.extend_from_slice(&bytes); - let mut bytes_vec = template; - let mut template: Vec = vec![2]; - template.append(&mut bytes_vec); - let bytes_slice = &template[..]; - - bytes_array_33.copy_from_slice(&bytes_slice[0..33]); - let result = PK::from_slice(&bytes_array_33); - let test = result.map(|pk| Secp256k1Point { - purpose: "random", - ge: pk, - }); - test.map_err(|_err| ErrorKey::InvalidPublicKey) - } - _ => { - let bytes_slice = &bytes_vec[0..64]; - let mut bytes_vec = bytes_slice.to_vec(); - let mut template: Vec = vec![4]; - template.append(&mut bytes_vec); - let bytes_slice = &template[..]; - - bytes_array_65.copy_from_slice(&bytes_slice[0..65]); - let result = PK::from_slice(&bytes_array_65); - let test = result.map(|pk| Secp256k1Point { - purpose: "random", - ge: pk, - }); - test.map_err(|_err| ErrorKey::InvalidPublicKey) - } + purpose: "zero", + ge: None, } } - fn pk_to_key_slice(&self) -> Vec { - let mut v = vec![4_u8]; - let x_vec = BigInt::to_bytes(&self.x_coor().unwrap()); - let y_vec = BigInt::to_bytes(&self.y_coor().unwrap()); - - let mut raw_x: Vec = Vec::new(); - let mut raw_y: Vec = Vec::new(); - raw_x.extend(vec![0u8; 32 - x_vec.len()]); - raw_x.extend(x_vec); - - raw_y.extend(vec![0u8; 32 - y_vec.len()]); - raw_y.extend(y_vec); - - v.extend(raw_x); - v.extend(raw_y); - v - } - fn scalar_mul(&self, fe: &SK) -> Secp256k1Point { - let mut new_point = *self; - new_point - .ge - .mul_assign(get_context(), &fe[..]) - .expect("Assignment expected"); - new_point + fn is_zero(&self) -> bool { + self.ge.is_none() } - fn add_point(&self, other: &PK) -> Secp256k1Point { - Secp256k1Point { - purpose: "combine", - ge: self.ge.combine(other).unwrap(), - } + fn generator() -> &'static Secp256k1Point { + &GENERATOR } - fn sub_point(&self, other: &PK) -> Secp256k1Point { - let point = Secp256k1Point { - purpose: "sub_point", - ge: *other, - }; - let p: Vec = vec![ - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 254, 255, 255, 252, 47, - ]; - let order = BigInt::from_bytes(&p[..]); - let x = point.x_coor().unwrap(); - let y = point.y_coor().unwrap(); - let minus_y = BigInt::mod_sub(&order, &y, &order); - - let x_vec = BigInt::to_bytes(&x); - let y_vec = BigInt::to_bytes(&minus_y); - - let mut template_x = vec![0; 32 - x_vec.len()]; - template_x.extend_from_slice(&x_vec); - let mut x_vec = template_x; - - let mut template_y = vec![0; 32 - y_vec.len()]; - template_y.extend_from_slice(&y_vec); - let y_vec = template_y; - - x_vec.extend_from_slice(&y_vec); - - let minus_point: GE = ECPoint::from_bytes(&x_vec).unwrap(); - //let minus_point: GE = ECPoint::from_coor(&x, &y_inv); - ECPoint::add_point(self, &minus_point.get_element()) + fn base_point2() -> &'static Secp256k1Point { + &BASE_POINT2 } - fn from_coor(x: &BigInt, y: &BigInt) -> Secp256k1Point { + fn from_coords(x: &BigInt, y: &BigInt) -> Result { let mut vec_x = BigInt::to_bytes(x); let mut vec_y = BigInt::to_bytes(y); let coor_size = (UNCOMPRESSED_PUBLIC_KEY_SIZE - 1) / 2; if vec_x.len() < coor_size { // pad - let mut x_buffer = vec![0; coor_size - vec_x.len()]; - x_buffer.extend_from_slice(&vec_x); - vec_x = x_buffer + let mut x_padded = vec![0; coor_size - vec_x.len()]; + x_padded.extend_from_slice(&vec_x); + vec_x = x_padded } if vec_y.len() < coor_size { // pad - let mut y_buffer = vec![0; coor_size - vec_y.len()]; - y_buffer.extend_from_slice(&vec_y); - vec_y = y_buffer + let mut y_padded = vec![0; coor_size - vec_y.len()]; + y_padded.extend_from_slice(&vec_y); + vec_y = y_padded } assert_eq!(x, &BigInt::from_bytes(vec_x.as_ref())); @@ -472,366 +328,359 @@ impl ECPoint for Secp256k1Point { v.extend(vec_x); v.extend(vec_y); - Secp256k1Point { - purpose: "base_fe", - ge: PK::from_slice(&v).unwrap(), - } + PublicKey::from_slice(&v) + .map(|ge| Secp256k1Point { + purpose: "from_coords", + ge: Some(PK(ge)), + }) + .map_err(|_| NotOnCurve) } -} -static mut CONTEXT: Option> = None; -pub fn get_context() -> &'static Secp256k1 { - static INIT_CONTEXT: Once = Once::new(); - INIT_CONTEXT.call_once(|| unsafe { - CONTEXT = Some(Secp256k1::verification_only()); - }); - unsafe { CONTEXT.as_ref().unwrap() } -} - -#[cfg(feature = "merkle")] -impl Hashable for Secp256k1Point { - fn update_context(&self, context: &mut Sha3) { - let bytes: Vec = self.pk_to_key_slice(); - context.input(&bytes[..]); - } -} - -impl Mul for Secp256k1Point { - type Output = Secp256k1Point; - fn mul(self, other: Secp256k1Scalar) -> Self::Output { - self.scalar_mul(&other.get_element()) + fn x_coord(&self) -> Option { + match &self.ge { + Some(ge) => { + let serialized_pk = ge.serialize_uncompressed(); + let x = &serialized_pk[1..serialized_pk.len() / 2 + 1]; + Some(BigInt::from_bytes(x)) + } + None => None, + } } -} -impl<'o> Mul<&'o Secp256k1Scalar> for Secp256k1Point { - type Output = Secp256k1Point; - fn mul(self, other: &'o Secp256k1Scalar) -> Self::Output { - self.scalar_mul(&other.get_element()) + fn y_coord(&self) -> Option { + match &self.ge { + Some(ge) => { + let serialized_pk = ge.serialize_uncompressed(); + let y = &serialized_pk[(serialized_pk.len() - 1) / 2 + 1..serialized_pk.len()]; + Some(BigInt::from_bytes(y)) + } + None => None, + } } -} -impl<'o> Mul<&'o Secp256k1Scalar> for &'o Secp256k1Point { - type Output = Secp256k1Point; - fn mul(self, other: &'o Secp256k1Scalar) -> Self::Output { - self.scalar_mul(&other.get_element()) + fn coords(&self) -> Option { + match &self.ge { + Some(ge) => { + let serialized_pk = ge.serialize_uncompressed(); + let x = &serialized_pk[1..serialized_pk.len() / 2 + 1]; + let y = &serialized_pk[(serialized_pk.len() - 1) / 2 + 1..serialized_pk.len()]; + Some(PointCoords { + x: BigInt::from_bytes(x), + y: BigInt::from_bytes(y), + }) + } + None => None, + } } -} -impl Add for Secp256k1Point { - type Output = Secp256k1Point; - fn add(self, other: Secp256k1Point) -> Self::Output { - self.add_point(&other.get_element()) + fn serialize(&self, compressed: bool) -> Option> { + let ge = self.ge.as_ref()?; + if compressed { + Some(ge.serialize().to_vec()) + } else { + // TODO: why not using ge.serialize_uncompressed()? + // https://docs.rs/secp256k1/0.20.3/secp256k1/key/struct.PublicKey.html#method.serialize_uncompressed + let mut v = vec![4_u8]; + let x_vec = BigInt::to_bytes( + &self + .x_coord() + .expect("guaranteed by the first line of this function"), + ); + let y_vec = BigInt::to_bytes( + &self + .y_coord() + .expect("guaranteed by the first line of this function"), + ); + + let mut raw_x: Vec = Vec::new(); + let mut raw_y: Vec = Vec::new(); + raw_x.extend(vec![0u8; 32 - x_vec.len()]); + raw_x.extend(x_vec); + + raw_y.extend(vec![0u8; 32 - y_vec.len()]); + raw_y.extend(y_vec); + + v.extend(raw_x); + v.extend(raw_y); + Some(v) + } } -} -impl<'o> Add<&'o Secp256k1Point> for Secp256k1Point { - type Output = Secp256k1Point; - fn add(self, other: &'o Secp256k1Point) -> Self::Output { - self.add_point(&other.get_element()) + fn deserialize(bytes: &[u8]) -> Result { + let pk = PublicKey::from_slice(bytes).map_err(|_| DeserializationError)?; + Ok(Secp256k1Point { + purpose: "from_bytes", + ge: Some(PK(pk)), + }) } -} -impl<'o> Add<&'o Secp256k1Point> for &'o Secp256k1Point { - type Output = Secp256k1Point; - fn add(self, other: &'o Secp256k1Point) -> Self::Output { - self.add_point(&other.get_element()) - } -} + fn scalar_mul(&self, scalar: &Self::Scalar) -> Secp256k1Point { + let mut new_point = match &self.ge { + Some(ge) => *ge, + None => { + // Point is zero => O * a = O + return Secp256k1Point { + purpose: "mul", + ge: None, + }; + } + }; + let scalar = match scalar.fe.deref() { + Some(s) => s, + None => { + // Scalar is zero => p * 0 = O + return Secp256k1Point { + purpose: "mul", + ge: None, + }; + } + }; + let result = new_point.mul_assign(&CONTEXT, &scalar[..]); + if result.is_err() { + // Multiplication resulted into zero point + return Secp256k1Point { + purpose: "mul", + ge: None, + }; + } -impl Serialize for Secp256k1Point { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut state = serializer.serialize_struct("Secp256k1Point", 2)?; - state.serialize_field("x", &self.x_coor().unwrap().to_hex())?; - state.serialize_field("y", &self.y_coor().unwrap().to_hex())?; - state.end() + Secp256k1Point { + purpose: "mul", + ge: Some(new_point), + } } -} -impl<'de> Deserialize<'de> for Secp256k1Point { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let fields = &["x", "y"]; - deserializer.deserialize_struct("Secp256k1Point", fields, Secp256k1PointVisitor) + fn add_point(&self, other: &Self) -> Secp256k1Point { + let ge1 = match &self.ge { + Some(ge) => ge, + None => { + // Point1 is zero => O + p2 = p2 + return Secp256k1Point { + purpose: "add", + ge: other.ge, + }; + } + }; + let ge2 = match &other.ge { + Some(ge) => ge, + None => { + // Point2 is zero => p1 + O = p1 + return Secp256k1Point { + purpose: "add", + ge: Some(*ge1), + }; + } + }; + Secp256k1Point { + purpose: "add", + ge: ge1.combine(ge2).map(PK).ok(), + } } -} -struct Secp256k1PointVisitor; - -impl<'de> Visitor<'de> for Secp256k1PointVisitor { - type Value = Secp256k1Point; + fn sub_point(&self, other: &Self) -> Secp256k1Point { + let mut ge2_negated = match &other.ge { + Some(ge) => *ge, + None => { + // Point2 is zero => p1 - O = p1 + return Secp256k1Point { + purpose: "sub", + ge: self.ge, + }; + } + }; + ge2_negated.negate_assign(&CONTEXT); + + let ge1 = match &self.ge { + Some(ge) => ge, + None => { + // Point1 is zero => O - p2 = -p2 + return Secp256k1Point { + purpose: "sub", + ge: Some(ge2_negated), + }; + } + }; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("Secp256k1Point") + Secp256k1Point { + purpose: "sub", + ge: ge1.combine(&ge2_negated).map(PK).ok(), + } } - fn visit_seq(self, mut seq: V) -> Result - where - V: SeqAccess<'de>, - { - let x = seq - .next_element()? - .ok_or_else(|| V::Error::invalid_length(0, &"a single element"))?; - let y = seq - .next_element()? - .ok_or_else(|| V::Error::invalid_length(0, &"a single element"))?; - - let bx = BigInt::from_hex(x).map_err(V::Error::custom)?; - let by = BigInt::from_hex(y).map_err(V::Error::custom)?; - - Ok(Secp256k1Point::from_coor(&bx, &by)) + fn neg_point(&self) -> Secp256k1Point { + Secp256k1Point { + purpose: "neg", + ge: match self.ge { + Some(mut ge) => { + ge.negate_assign(&CONTEXT); + Some(ge) + } + None => { + // Point is zero => -O = O + None + } + }, + } } - fn visit_map>(self, mut map: E) -> Result { - let mut x = String::new(); - let mut y = String::new(); - - while let Some(ref key) = map.next_key::()? { - let v = map.next_value::()?; - if key == "x" { - x = v - } else if key == "y" { - y = v - } else { - return Err(E::Error::unknown_field(key, &["x", "y"])); + fn scalar_mul_assign(&mut self, scalar: &Self::Scalar) { + let ge = match self.ge.as_mut() { + Some(ge) => ge, + None => { + // Point is zero => O * s = O + self.ge = None; + return; } - } + }; - let bx = BigInt::from_hex(&x).map_err(E::Error::custom)?; - let by = BigInt::from_hex(&y).map_err(E::Error::custom)?; + let fe = match scalar.fe.as_ref() { + Some(fe) => fe, + None => { + // Scalar is zero => p * 0 = O + self.ge = None; + return; + } + }; - Ok(Secp256k1Point::from_coor(&bx, &by)) + if let Err(_) = ge.mul_assign(&CONTEXT, &fe[..]) { + // Multiplication resulted into zero + self.ge = None + } } -} - -#[cfg(test)] -mod tests { - use super::BigInt; - use super::Secp256k1Point; - use super::Secp256k1Scalar; - use crate::arithmetic::traits::*; - use crate::cryptographic_primitives::hashing::hash_sha256::HSha256; - use crate::cryptographic_primitives::hashing::traits::Hash; - use crate::elliptic::curves::traits::ECPoint; - use crate::elliptic::curves::traits::ECScalar; - #[test] - fn serialize_sk() { - let scalar: Secp256k1Scalar = ECScalar::from(&BigInt::from(123456)); - let s = serde_json::to_string(&scalar).expect("Failed in serialization"); - assert_eq!(s, "\"1e240\""); + fn underlying_ref(&self) -> &Self::Underlying { + &self.ge } - - #[test] - fn serialize_rand_pk_verify_pad() { - let vx = BigInt::from_hex( - &"ccaf75ab7960a01eb421c0e2705f6e84585bd0a094eb6af928c892a4a2912508".to_string(), - ) - .unwrap(); - - let vy = BigInt::from_hex( - &"e788e294bd64eee6a73d2fc966897a31eb370b7e8e9393b0d8f4f820b48048df".to_string(), - ) - .unwrap(); - - Secp256k1Point::from_coor(&vx, &vy); // x and y of size 32 - - let x = BigInt::from_hex( - &"5f6853305467a385b56a5d87f382abb52d10835a365ec265ce510e04b3c3366f".to_string(), - ) - .unwrap(); - - let y = BigInt::from_hex( - &"b868891567ca1ee8c44706c0dc190dd7779fe6f9b92ced909ad870800451e3".to_string(), - ) - .unwrap(); - - Secp256k1Point::from_coor(&x, &y); // x and y not of size 32 each - - let r = Secp256k1Point::random_point(); - let r_expected = Secp256k1Point::from_coor(&r.x_coor().unwrap(), &r.y_coor().unwrap()); - - assert_eq!(r.x_coor().unwrap(), r_expected.x_coor().unwrap()); - assert_eq!(r.y_coor().unwrap(), r_expected.y_coor().unwrap()); + fn underlying_mut(&mut self) -> &mut Self::Underlying { + &mut self.ge } - - #[test] - fn deserialize_sk() { - let s = "\"1e240\""; - let dummy: Secp256k1Scalar = serde_json::from_str(s).expect("Failed in serialization"); - - let sk: Secp256k1Scalar = ECScalar::from(&BigInt::from(123456)); - - assert_eq!(dummy, sk); + fn from_underlying(ge: Self::Underlying) -> Secp256k1Point { + Secp256k1Point { + purpose: "from_underlying", + ge, + } } +} - #[test] - fn serialize_pk() { - let pk = Secp256k1Point::generator(); - let x = pk.x_coor().unwrap(); - let y = pk.y_coor().unwrap(); - let s = serde_json::to_string(&pk).expect("Failed in serialization"); - - let expected = format!("{{\"x\":\"{}\",\"y\":\"{}\"}}", x.to_hex(), y.to_hex()); - assert_eq!(s, expected); - - let des_pk: Secp256k1Point = serde_json::from_str(&s).expect("Failed in serialization"); - assert_eq!(des_pk.ge, pk.ge); +impl PartialEq for Secp256k1Point { + fn eq(&self, other: &Secp256k1Point) -> bool { + self.underlying_ref() == other.underlying_ref() } +} - #[test] - fn bincode_pk() { - let pk = Secp256k1Point::generator(); - let bin = bincode::serialize(&pk).unwrap(); - let decoded: Secp256k1Point = bincode::deserialize(bin.as_slice()).unwrap(); - assert_eq!(decoded, pk); +impl Zeroize for Secp256k1Point { + fn zeroize(&mut self) { + self.ge.zeroize() } +} - use crate::elliptic::curves::secp256_k1::{FE, GE}; - use crate::ErrorKey; +#[cfg(test)] +mod test { + use std::iter; - #[test] - fn test_serdes_pk() { - let pk = GE::generator(); - let s = serde_json::to_string(&pk).expect("Failed in serialization"); - let des_pk: GE = serde_json::from_str(&s).expect("Failed in deserialization"); - assert_eq!(des_pk, pk); + use crate::elliptic::curves::traits::*; + use crate::BigInt; - let pk = GE::base_point2(); - let s = serde_json::to_string(&pk).expect("Failed in serialization"); - let des_pk: GE = serde_json::from_str(&s).expect("Failed in deserialization"); - assert_eq!(des_pk, pk); - } + use super::{FE, GE}; #[test] - #[should_panic] - fn test_serdes_bad_pk() { - let pk = GE::generator(); - let s = serde_json::to_string(&pk).expect("Failed in serialization"); - // we make sure that the string encodes invalid point: - let s: String = s.replace("79be", "79bf"); - let des_pk: GE = serde_json::from_str(&s).expect("Failed in deserialization"); - assert_eq!(des_pk, pk); + fn valid_zero_point() { + let zero = GE::zero(); + assert!(zero.is_zero()); + assert_eq!(zero, GE::zero()); } #[test] - fn test_from_bytes() { - let g = Secp256k1Point::generator(); - let hash = HSha256::create_hash(&[&g.bytes_compressed_to_big_int()]); - let hash_vec = BigInt::to_bytes(&hash); - let result = Secp256k1Point::from_bytes(&hash_vec); - assert_eq!(result.unwrap_err(), ErrorKey::InvalidPublicKey) - } + fn zero_point_arithmetic() { + let zero_point = GE::zero(); + let point = GE::generator().scalar_mul(&FE::random()); - #[test] - fn test_from_bytes_3() { - let test_vec = [ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 1, 2, 3, 4, 5, 6, - ]; - let result = Secp256k1Point::from_bytes(&test_vec); - assert!(result.is_ok() | result.is_err()) - } + assert_eq!(zero_point.add_point(&point), point, "O + P = P"); + assert_eq!(point.add_point(&zero_point), point, "P + O = P"); - #[test] - fn test_from_bytes_4() { - let test_vec = [ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, - ]; - let result = Secp256k1Point::from_bytes(&test_vec); - assert!(result.is_ok() | result.is_err()) - } + let point_neg = point.neg_point(); + assert!(point.add_point(&point_neg).is_zero(), "P + (-P) = O"); + assert!(point.sub_point(&point).is_zero(), "P - P = O"); - #[test] - fn test_from_bytes_5() { - let test_vec = [ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, - 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, - 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, - 4, 5, 6, - ]; - let result = Secp256k1Point::from_bytes(&test_vec); - assert!(result.is_ok() | result.is_err()) + let zero_scalar = FE::zero(); + assert!(point.scalar_mul(&zero_scalar).is_zero(), "P * 0 = O"); + let scalar = FE::random(); + assert!(zero_point.scalar_mul(&scalar).is_zero(), "O * s = O") } #[test] - fn test_minus_point() { - let a: FE = ECScalar::new_random(); - let b: FE = ECScalar::new_random(); - let b_bn = b.to_big_int(); - let order = FE::q(); - let minus_b = BigInt::mod_sub(&order, &b_bn, &order); - let a_minus_b = BigInt::mod_add(&a.to_big_int(), &minus_b, &order); - let a_minus_b_fe: FE = ECScalar::from(&a_minus_b); - let base: GE = ECPoint::generator(); - let point_ab1 = base * a_minus_b_fe; - - let point_a = base * a; - let point_b = base * b; - let point_ab2 = point_a.sub_point(&point_b.get_element()); - assert_eq!(point_ab1.get_element(), point_ab2.get_element()); - } + fn scalar_modulo_curve_order() { + let n = FE::curve_order(); + let s = FE::from_bigint(n); + assert!(s.is_zero()); - #[test] - fn test_invert() { - let a: FE = ECScalar::new_random(); - let a_bn = a.to_big_int(); - let a_inv = a.invert(); - let a_inv_bn_1 = BigInt::mod_inv(&a_bn, &FE::q()).unwrap(); - let a_inv_bn_2 = a_inv.to_big_int(); - assert_eq!(a_inv_bn_1, a_inv_bn_2); + let s = FE::from_bigint(&(n + 1)); + assert_eq!(s, FE::from_bigint(&BigInt::from(1))); } #[test] - fn test_scalar_mul_scalar() { - let a: FE = ECScalar::new_random(); - let b: FE = ECScalar::new_random(); - let c1 = a.mul(&b.get_element()); - let c2 = a * b; - assert_eq!(c1.get_element(), c2.get_element()); + fn zero_scalar_arithmetic() { + let s = FE::random(); + let z = FE::zero(); + assert!(s.mul(&z).is_zero()); + assert!(z.mul(&s).is_zero()); + assert_eq!(s.add(&z), s); + assert_eq!(z.add(&s), s); } #[test] - fn test_pk_to_key_slice() { - for _ in 1..200 { - let r = FE::new_random(); - let rg = GE::generator() * r; - let key_slice = rg.pk_to_key_slice(); + fn point_addition_multiplication() { + let point = GE::generator().scalar_mul(&FE::random()); + assert!(!point.is_zero(), "G * s != O"); - assert!(key_slice.len() == 65); - assert!(key_slice[0] == 4); - - let rg_prime: GE = ECPoint::from_bytes(&key_slice[1..65]).unwrap(); - assert_eq!(rg_prime.get_element(), rg.get_element()); - } + let addition = iter::successors(Some(point), |p| Some(p.add_point(&point))) + .take(10) + .collect::>(); + let multiplication = (1..=10) + .map(|i| FE::from_bigint(&BigInt::from(i))) + .map(|s| point.scalar_mul(&s)) + .collect::>(); + assert_eq!(addition, multiplication); } #[test] - fn test_base_point2() { - /* Show that base_point2() is returning a point of unknown discrete logarithm. - It is done by using SHA256 repeatedly as a pseudo-random function, with the generator - as the initial input, until receiving a valid Secp256k1 point. */ - - let base_point2 = Secp256k1Point::base_point2(); - - let g = Secp256k1Point::generator(); - let mut hash = HSha256::create_hash(&[&g.bytes_compressed_to_big_int()]); - hash = HSha256::create_hash(&[&hash]); - hash = HSha256::create_hash(&[&hash]); - - assert_eq!(hash, base_point2.x_coor().unwrap(),); - - // check that base_point2 is indeed on the curve (from_coor() will fail otherwise) - assert_eq!( - Secp256k1Point::from_coor( - &base_point2.x_coor().unwrap(), - &base_point2.y_coor().unwrap() - ), - base_point2 - ); - } + fn serialize_deserialize() { + let point = GE::generator().scalar_mul(&FE::random()); + let bytes = point + .serialize(true) + .expect("point has coordinates => must be serializable"); + let deserialized = GE::deserialize(&bytes).unwrap(); + assert_eq!(point, deserialized); + + let bytes = point + .serialize(false) + .expect("point has coordinates => must be serializable"); + let deserialized = GE::deserialize(&bytes).unwrap(); + assert_eq!(point, deserialized); + } + + // #[test] + // fn test_base_point2() { + // /* Show that base_point2() is returning a point of unknown discrete logarithm. + // It is done by using SHA256 repeatedly as a pseudo-random function, with the generator + // as the initial input, until receiving a valid Secp256k1 point. */ + // + // let base_point2 = GE::base_point2(); + // + // let g = GE::generator(); + // let mut hash = HSha256::create_hash(&[&g.bytes_compressed_to_big_int()]); + // hash = HSha256::create_hash(&[&hash]); + // hash = HSha256::create_hash(&[&hash]); + // + // assert_eq!(hash, base_point2.x_coor().unwrap(),); + // + // // check that base_point2 is indeed on the curve (from_coor() will fail otherwise) + // assert_eq!( + // Secp256k1Point::from_coor( + // &base_point2.x_coor().unwrap(), + // &base_point2.y_coor().unwrap() + // ), + // base_point2 + // ); + // } } diff --git a/src/elliptic/curves/traits.rs b/src/elliptic/curves/traits.rs index ff029c6c..bd57200a 100644 --- a/src/elliptic/curves/traits.rs +++ b/src/elliptic/curves/traits.rs @@ -5,48 +5,212 @@ License MIT: */ -use std::ops::{Add, Mul}; +use std::fmt; use crate::BigInt; -use crate::ErrorKey; +use zeroize::Zeroize; -pub trait ECScalar: Mul + Add + Sized { - type SecretKey; +/// Elliptic curve implementation +/// +/// Refers to according implementation of [ECPoint] and [ECScalar]. +pub trait Curve { + type Point: ECPoint; + type Scalar: ECScalar; - fn new_random() -> Self; + /// Returns canonical name for this curve + fn curve_name() -> &'static str; +} + +/// Scalar value modulus [curve order](Self::curve_order) +/// +/// ## Note +/// This is a low-level trait, you should not use it directly. See wrappers [Point], [PointZ], +/// [Scalar], [ScalarZ]. +/// +/// [Point]: super::wrappers::Point +/// [PointZ]: super::wrappers::PointZ +/// [Scalar]: super::wrappers::Scalar +/// [ScalarZ]: super::wrappers::ScalarZ +/// +/// Trait exposes various methods to manipulate scalars. Scalar can be zero. Scalar must zeroize its +/// value on drop. +pub trait ECScalar: Clone + PartialEq + fmt::Debug { + /// Underlying scalar type that can be retrieved in case of missing methods in this trait + type Underlying; + + /// Samples a random scalar + fn random() -> Self; + + /// Constructs a zero scalar fn zero() -> Self; - fn get_element(&self) -> Self::SecretKey; - fn set_element(&mut self, element: Self::SecretKey); - fn from(n: &BigInt) -> Self; - fn to_big_int(&self) -> BigInt; - fn q() -> BigInt; - fn add(&self, other: &Self::SecretKey) -> Self; - fn mul(&self, other: &Self::SecretKey) -> Self; - fn sub(&self, other: &Self::SecretKey) -> Self; - fn invert(&self) -> Self; + /// Checks if the scalar equals to zero + fn is_zero(&self) -> bool; + + /// Constructs a scalar `n % curve_order` + fn from_bigint(n: &BigInt) -> Self; + /// Converts a scalar to BigInt + fn to_bigint(&self) -> BigInt; + + /// Calculates `(self + other) mod curve_order` + fn add(&self, other: &Self) -> Self; + /// Calculates `(self * other) mod curve_order` + fn mul(&self, other: &Self) -> Self; + /// Calculates `(self - other) mod curve_order` + fn sub(&self, other: &Self) -> Self; + /// Calculates `-self mod curve_order` + fn neg(&self) -> Self; + /// Calculates `self^-1 (mod curve_order)`, returns None if self equals to zero + fn invert(&self) -> Option; + /// Calculates `(self + other) mod curve_order`, and assigns result to `self` + fn add_assign(&mut self, other: &Self) { + *self = self.add(other) + } + /// Calculates `(self * other) mod curve_order`, and assigns result to `self` + fn mul_assign(&mut self, other: &Self) { + *self = self.mul(other) + } + /// Calculates `(self - other) mod curve_order`, and assigns result to `self` + fn sub_assign(&mut self, other: &Self) { + *self = self.sub(other) + } + /// Calculates `-self mod curve_order`, and assigns result to `self` + fn neg_assign(&mut self) { + *self = self.neg() + } + + fn curve_order() -> &'static BigInt; + + /// Returns a reference to underlying scalar value + fn underlying_ref(&self) -> &Self::Underlying; + /// Returns a mutable reference to underlying scalar value + fn underlying_mut(&mut self) -> &mut Self::Underlying; + /// Constructs a scalar from underlying value + fn from_underlying(u: Self::Underlying) -> Self; } -// TODO: add a fn is_point -pub trait ECPoint: - Mul<::Scalar, Output = Self> + Add + PartialEq -where - Self: Sized, -{ - type SecretKey; - type PublicKey; - - type Scalar: ECScalar; - - fn base_point2() -> Self; - fn generator() -> Self; - fn get_element(&self) -> Self::PublicKey; - fn x_coor(&self) -> Option; - fn y_coor(&self) -> Option; - fn bytes_compressed_to_big_int(&self) -> BigInt; - fn from_bytes(bytes: &[u8]) -> Result; - fn pk_to_key_slice(&self) -> Vec; - fn scalar_mul(&self, fe: &Self::SecretKey) -> Self; - fn add_point(&self, other: &Self::PublicKey) -> Self; - fn sub_point(&self, other: &Self::PublicKey) -> Self; - fn from_coor(x: &BigInt, y: &BigInt) -> Self; +/// Point on elliptic curve +/// +/// ## Note +/// This is a low-level trait, you should not use it directly. See [Point], [PointZ], [Scalar], +/// [ScalarZ]. +/// +/// [Point]: super::wrappers::Point +/// [PointZ]: super::wrappers::PointZ +/// [Scalar]: super::wrappers::Scalar +/// [ScalarZ]: super::wrappers::ScalarZ +/// +/// Trait exposes various methods that make elliptic curve arithmetic. The point can +/// be [zero](ECPoint::zero). Unlike [ECScalar], ECPoint isn't required to zeroize its value on drop, +/// but it implementы [Zeroize] trait so you can force zeroizing policy on your own. +pub trait ECPoint: Zeroize + Clone + PartialEq + fmt::Debug { + /// Scalar value the point can be multiplied at + type Scalar: ECScalar; + /// Underlying curve implementation that can be retrieved in case of missing methods in this trait + type Underlying; + + /// Zero point + /// + /// Zero point is usually denoted as O. It's curve neutral element, i.e. `forall A. A + O = A`. + /// Weierstrass and Montgomery curves employ special "point at infinity" to add neutral elements, + /// such points don't have coordinates (i.e. [from_coords], [x_coord], [y_coord] return `None`). + /// Edwards curves' neutral element has coordinates. + /// + /// [from_coords]: Self::from_coords + /// [x_coord]: Self::x_coord + /// [y_coord]: Self::y_coord + fn zero() -> Self; + + /// Returns `true` if point is a neutral element + fn is_zero(&self) -> bool; + + /// Curve generator + /// + /// Returns a static reference at actual value because in most cases reference value is fine. + /// Use `.clone()` if you need to take it by value, i.e. `ECPoint::generator().clone()` + fn generator() -> &'static Self; + /// Curve second generator + /// + /// We provide an alternative generator value and prove that it was picked randomly + fn base_point2() -> &'static Self; + + /// Constructs a curve point from its coordinates + /// + /// Returns error if x, y are not on curve + fn from_coords(x: &BigInt, y: &BigInt) -> Result; + /// Returns `x` coordinate of the point, or `None` if point is at infinity + fn x_coord(&self) -> Option; + /// Returns `y` coordinate of the point, or `None` if point is at infinity + fn y_coord(&self) -> Option; + /// Returns point coordinates (`x` and `y`), or `None` if point is at infinity + fn coords(&self) -> Option; + + /// Serializes point into bytes either in compressed or uncompressed form + /// + /// Returns None if point doesn't have coordinates, ie. it is "at infinity". If point isn't + /// at infinity, serialize always succeeds. + fn serialize(&self, compressed: bool) -> Option>; + /// Deserializes point from bytes + /// + /// Whether point in compressed or uncompressed form will be deducted from its size + fn deserialize(bytes: &[u8]) -> Result; + + /// Multiplies the point at scalar value + fn scalar_mul(&self, scalar: &Self::Scalar) -> Self; + /// Adds two points + fn add_point(&self, other: &Self) -> Self; + /// Substrates `other` from `self` + fn sub_point(&self, other: &Self) -> Self; + /// Negates point + fn neg_point(&self) -> Self; + + /// Multiplies the point at scalar value, assigns result to `self` + fn scalar_mul_assign(&mut self, scalar: &Self::Scalar) { + *self = self.scalar_mul(scalar) + } + /// Adds two points, assigns result to `self` + fn add_point_assign(&mut self, other: &Self) { + *self = self.add_point(other) + } + /// Substrates `other` from `self`, assigns result to `self` + fn sub_point_assign(&mut self, other: &Self) { + *self = self.sub_point(other) + } + /// Negates point, assigns result to `self` + fn neg_point_assign(&mut self) { + *self = self.neg_point() + } + + /// Reference to underlying curve implementation + fn underlying_ref(&self) -> &Self::Underlying; + /// Mutual reference to underlying curve implementation + fn underlying_mut(&mut self) -> &mut Self::Underlying; + /// Construct a point from its underlying representation + fn from_underlying(u: Self::Underlying) -> Self; } + +pub struct PointCoords { + pub x: BigInt, + pub y: BigInt, +} + +#[derive(Debug)] +pub struct DeserializationError; + +impl fmt::Display for DeserializationError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "failed to deserialize the point") + } +} + +impl std::error::Error for DeserializationError {} + +#[derive(Debug)] +pub struct NotOnCurve; + +impl fmt::Display for NotOnCurve { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "point not on the curve") + } +} + +impl std::error::Error for NotOnCurve {} diff --git a/src/elliptic/curves/wrappers.rs b/src/elliptic/curves/wrappers.rs new file mode 100644 index 00000000..5411a941 --- /dev/null +++ b/src/elliptic/curves/wrappers.rs @@ -0,0 +1,1282 @@ +use std::borrow::Cow; +use std::convert::TryFrom; +use std::{fmt, iter, ops}; + +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use super::traits::*; +use crate::arithmetic::{BigInt, Converter}; + +/// Elliptic point that **might be zero** +/// +/// ## Security +/// +/// Mistakenly used zero point might break security of cryptographic algorithm. It's preferred to +/// use [`Point`](Point) that's guaranteed to be non-zero. Use [ensure_nonzero](PointZ::ensure_nonzero) +/// to convert `PointZ` into `Point`. +/// +/// ## Guarantees +/// +/// * Belongs to curve +/// +/// Any instance of `PointZ` is guaranteed to belong to curve `E`, i.e. its coordinates must +/// satisfy curve equations +/// +/// ## Arithmetics +/// +/// You can add, subtract two points, or multiply point at scalar: +/// +/// ```rust +/// # use curv::elliptic::curves::{PointZ, Scalar, Secp256k1}; +/// fn expression( +/// a: PointZ, +/// b: PointZ, +/// c: Scalar, +/// ) -> PointZ { +/// a + b * c +/// } +/// ``` +pub struct PointZ(E::Point); + +impl PointZ { + /// Checks if `self` is not zero and converts it into [`Point`](Point). Returns `None` if + /// it's zero. + pub fn ensure_nonzero(self) -> Option> { + Point::try_from(self).ok() + } + + pub fn zero() -> Self { + Self::from_raw(E::Point::zero()) + } + + pub fn iz_zero(&self) -> bool { + self.0.is_zero() + } + + pub fn coords(&self) -> Option { + self.0.coords() + } + + pub fn x_coord(&self) -> Option { + self.0.x_coord() + } + + pub fn y_coord(&self) -> Option { + self.0.y_coord() + } + + fn from_raw(point: E::Point) -> Self { + Self(point) + } +} + +impl PartialEq for PointZ { + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} + +impl Clone for PointZ { + fn clone(&self) -> Self { + PointZ(self.0.clone()) + } +} + +impl fmt::Debug for PointZ { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +/// Elliptic point that _guaranteed_ to be non zero +/// +/// ## Security +/// Non-zero points are preferred to be used in cryptographic algorithms. Lack of checking whether +/// computation on elliptic points results into zero point might lead to vulnerabilities. Using +/// `Point` ensures you and reviewers that check on point not being zero was made. +/// +/// ## Guarantees +/// +/// * Belongs to curve +/// +/// Any instance of `Point` is guaranteed to belong to curve `E`, i.e. its coordinates must +/// satisfy curve equations +/// * Not a neutral element +/// +/// Any instance of `Point` is restricted not to be zero (neutral element), i.e. for any +/// `a: PointZ ∧ b: Point → a + b ≢ a`. +/// +/// Weierstrass and Montgomery curves represent zero point +/// using special "point at infinity", whereas Edwards curves zero point is a regular point that +/// has coordinates. `Point` cannot be instantiated with neither of these points. +/// +/// Note also that `Point` is guaranteed to have coordinates (only point at infinity doesn't). +/// +/// ## Arithmetics +/// +/// You can add, subtract two points, or multiply point at scalar. +/// +/// Any arithmetic operation on non-zero point might result into zero point, so addition, subtraction, +/// and multiplication operations output [PointZ]. Use [ensure_nonzero](PointZ::ensure_nonzero) method +/// to ensure that computation doesn't produce zero-point: +/// +/// ```rust +/// # use curv::elliptic::curves::{PointZ, Point, Scalar, Secp256k1}; +/// let s = Scalar::::random(); // Non-zero scalar +/// let g = Point::::generator(); // Non-zero point (curve generator) +/// let result: PointZ = s * g; // Multiplication of two non-zero points +/// // might produce zero-point +/// let nonzero_result: Option> = result.ensure_nonzero(); +/// ``` +/// +/// When evaluating complex expressions, you typically need to ensure that none of intermediate +/// results are zero-points: +/// +/// ```rust +/// # use curv::elliptic::curves::{Curve, Point, Scalar}; +/// fn expression(a: Point, b: Point, c: Scalar) -> Option> { +/// (a + (b * c).ensure_nonzero()?).ensure_nonzero() +/// } +/// ``` +pub struct Point(E::Point); + +impl Point { + fn from_raw(point: E::Point) -> Result { + if point.is_zero() { + Err(ZeroPointError(())) + } else { + Ok(Self(point)) + } + } + + /// Curve generator + /// + /// Returns a static reference on actual value because in most cases referenced value is fine. + /// Use [`.to_point_owned()`](PointRef::to_point_owned) if you need to take it by value. + pub fn generator() -> PointRef<'static, E> { + let p = E::Point::generator(); + PointRef::from_raw(p).expect("generator must be non-zero") + } + + /// Curve second generator + /// + /// We provide an alternative generator value and prove that it was picked randomly. + /// + /// Returns a static reference on actual value because in most cases referenced value is fine. + /// Use [`.to_point_owned()`](PointRef::to_point_owned) if you need to take it by value. + pub fn base_point2() -> PointRef<'static, E> { + let p = E::Point::base_point2(); + PointRef::from_raw(p).expect("base_point2 must be non-zero") + } + + /// Constructs a point from coordinates, returns error if x,y don't satisfy curve equation or + /// correspond to zero point + pub fn from_coords(x: &BigInt, y: &BigInt) -> Result { + let p = E::Point::from_coords(x, y) + .map_err(|NotOnCurve { .. }| PointFromCoordsError::PointNotOnCurve)?; + Self::from_raw(p).map_err(|ZeroPointError(())| PointFromCoordsError::ZeroPoint) + } + + /// Tries to parse a point from its (un)compressed form + /// + /// Whether it's a compressed or uncompressed form will be deduced from its length + pub fn from_bytes(bytes: &[u8]) -> Result { + let p = E::Point::deserialize(bytes).map_err(PointFromBytesError::Deserialize)?; + Self::from_raw(p).map_err(|ZeroPointError(())| PointFromBytesError::ZeroPoint) + } + + /// Returns point coordinates (`x` and `y`) + /// + /// Method never fails as Point is guaranteed to have coordinates + pub fn coords(&self) -> PointCoords { + self.as_point_ref().coords() + } + + /// Returns `x` coordinate of point + /// + /// Method never fails as Point is guaranteed to have coordinates + pub fn x_coord(&self) -> BigInt { + self.as_point_ref().x_coord() + } + + /// Returns `y` coordinate of point + /// + /// Method never fails as Point is guaranteed to have coordinates + pub fn y_coord(&self) -> BigInt { + self.as_point_ref().y_coord() + } + + /// Adds two points, returns the result, or `None` if resulting point is zero + pub fn add_checked(&self, point: PointRef) -> Option { + self.as_point_ref().add_checked(point) + } + + /// Substrates two points, returns the result, or `None` if resulting point is zero + pub fn sub_checked(&self, point: PointRef) -> Option { + self.as_point_ref().sub_checked(point) + } + + /// Multiplies a point at scalar, returns the result, or `None` if resulting point is zero + pub fn mul_checked_z(&self, scalar: &ScalarZ) -> Option { + self.as_point_ref().mul_checked_z(scalar) + } + + /// Multiplies a point at nonzero scalar, returns the result, or `None` if resulting point is zero + pub fn mul_checked(&self, scalar: &Scalar) -> Option { + self.as_point_ref().mul_checked(scalar) + } + + /// Serializes point into (un)compressed form + pub fn to_bytes(&self, compressed: bool) -> Vec { + self.as_point_ref().to_bytes(compressed) + } + + /// Creates [PointRef] that holds a reference on `self` + pub fn as_point_ref(&self) -> PointRef { + PointRef(&self.0) + } +} + +impl Clone for Point { + fn clone(&self) -> Self { + Point(self.0.clone()) + } +} + +impl fmt::Debug for Point { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl TryFrom> for Point { + type Error = ZeroPointError; + fn try_from(point: PointZ) -> Result { + Self::from_raw(point.0) + } +} + +/// Reference on elliptic point, _guaranteed_ to be non-zero +/// +/// Holds internally a reference on [`Point`](Point), refer to its documentation to learn +/// more about Point/PointRef guarantees, security notes, and arithmetics. +pub struct PointRef<'p, E: Curve>(&'p E::Point); + +impl<'p, E: Curve> Clone for PointRef<'p, E> { + fn clone(&self) -> Self { + Self(self.0) + } +} + +impl<'p, E: Curve> Copy for PointRef<'p, E> {} + +impl<'p, E: Curve> fmt::Debug for PointRef<'p, E> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl PointRef<'static, E> { + pub fn generator() -> Self { + Self::from_raw(E::Point::generator()).expect("generator must be non-zero") + } + + pub fn base_point2() -> Self { + Self::from_raw(E::Point::base_point2()).expect("base_point2 must be non-zero") + } +} + +impl<'p, E> PointRef<'p, E> +where + E: Curve, +{ + fn from_raw(point: &'p E::Point) -> Option { + if point.is_zero() { + None + } else { + Some(Self(point)) + } + } + + /// Returns point coordinates (`x` and `y`) + /// + /// Method never fails as Point is guaranteed to have coordinates + pub fn coords(&self) -> PointCoords { + self.0 + .coords() + .expect("Point guaranteed to have coordinates") + } + + /// Returns `x` coordinate of point + /// + /// Method never fails as Point is guaranteed to have coordinates + pub fn x_coord(&self) -> BigInt { + self.0 + .x_coord() + .expect("Point guaranteed to have coordinates") + } + + /// Returns `y` coordinate of point + /// + /// Method never fails as Point is guaranteed to have coordinates + pub fn y_coord(&self) -> BigInt { + self.0 + .y_coord() + .expect("Point guaranteed to have coordinates") + } + + /// Adds two points, returns the result, or `None` if resulting point is at infinity + pub fn add_checked(&self, point: Self) -> Option> { + let new_point = self.0.add_point(&point.0); + if new_point.is_zero() { + None + } else { + Some(Point(new_point)) + } + } + + /// Substrates two points, returns the result, or `None` if resulting point is at infinity + pub fn sub_checked(&self, point: Self) -> Option> { + let new_point = self.0.sub_point(&point.0); + if new_point.is_zero() { + None + } else { + Some(Point(new_point)) + } + } + + /// Multiplies a point at scalar, returns the result, or `None` if resulting point is at infinity + pub fn mul_checked_z(&self, scalar: &ScalarZ) -> Option> { + let new_point = self.0.scalar_mul(&scalar.0); + if new_point.is_zero() { + None + } else { + Some(Point(new_point)) + } + } + + /// Multiplies a point at nonzero scalar, returns the result, or `None` if resulting point is at infinity + pub fn mul_checked(&self, scalar: &Scalar) -> Option> { + let new_point = self.0.scalar_mul(&scalar.0); + if new_point.is_zero() { + None + } else { + Some(Point(new_point)) + } + } + + /// Serializes point into (un)compressed form + pub fn to_bytes(&self, compressed: bool) -> Vec { + self.0 + .serialize(compressed) + .expect("non-zero point must always be serializable") + } + + /// Clones the referenced point + pub fn to_point_owned(&self) -> Point { + Point(self.0.clone()) + } +} + +/// Converting PointZ to Point error +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct ZeroPointError(()); + +impl fmt::Display for ZeroPointError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "nonzero check failed: point is zero") + } +} + +impl std::error::Error for ZeroPointError {} + +/// Constructing Point from its coordinates error +#[derive(Debug, Error)] +pub enum PointFromCoordsError { + #[error("x,y correspond to zero point")] + ZeroPoint, + #[error("point is not on the curve")] + PointNotOnCurve, +} + +/// Constructing Point from its (un)compressed representation error +#[derive(Debug, Error)] +pub enum PointFromBytesError { + #[error("deserialized point corresponds to zero point")] + ZeroPoint, + #[error("{0}")] + Deserialize(#[source] DeserializationError), +} + +/// Scalar value in a prime field that **might be zero** +/// +/// ## Security +/// +/// Mistakenly used zero scalar might break security of cryptographic algorithm. It's preferred to +/// use `Scalar`[Scalar] that's guaranteed to be non-zero. Use [ensure_nonzero](ScalarZ::ensure_nonzero) +/// to convert `ScalarZ` into `Scalar`. +/// +/// ## Guarantees +/// +/// * Belongs to the curve prime field +/// +/// Denoting curve modulus as `q`, any instance `s` of `ScalarZ` is guaranteed to be non-negative +/// integer modulo `q`: `0 <= s < q` +/// +/// ## Arithmetics +/// +/// Supported operations: +/// * Unary: you can [invert](Self::invert) and negate a scalar by modulo of prime field +/// * Binary: you can add, subtract, and multiply two points +/// +/// ### Example +/// +/// ```rust +/// # use curv::elliptic::curves::{ScalarZ, Secp256k1}; +/// fn expression( +/// a: ScalarZ, +/// b: ScalarZ, +/// c: ScalarZ +/// ) -> ScalarZ { +/// a + b * c +/// } +/// ``` +#[derive(Serialize, Deserialize)] +#[serde(try_from = "ScalarFormat", into = "ScalarFormat")] +pub struct ScalarZ(E::Scalar); + +impl ScalarZ { + pub fn ensure_nonzero(self) -> Option> { + Scalar::from_raw(self.0) + } + + fn from_raw(scalar: E::Scalar) -> Self { + Self(scalar) + } + + pub fn random() -> Self { + Self::from_raw(E::Scalar::random()) + } + + pub fn zero() -> Self { + Self::from_raw(E::Scalar::zero()) + } + + pub fn is_zero(&self) -> bool { + self.0.is_zero() + } + + pub fn to_bigint(&self) -> BigInt { + self.0.to_bigint() + } + + pub fn from_bigint(n: &BigInt) -> Self { + Self::from_raw(E::Scalar::from_bigint(n)) + } + + pub fn invert(&self) -> Option { + self.0.invert().map(Self::from_raw) + } +} + +impl Clone for ScalarZ { + fn clone(&self) -> Self { + Self::from_raw(self.0.clone()) + } +} + +impl fmt::Debug for ScalarZ { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl PartialEq for ScalarZ { + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} + +impl PartialEq> for ScalarZ { + fn eq(&self, other: &Scalar) -> bool { + self.0.eq(&other.0) + } +} + +impl From> for ScalarZ { + fn from(scalar: Scalar) -> Self { + ScalarZ::from_raw(scalar.0) + } +} + +impl From for ScalarZ { + fn from(n: u16) -> Self { + Self::from(&BigInt::from(n)) + } +} + +impl From for ScalarZ { + fn from(n: u32) -> Self { + Self::from(&BigInt::from(n)) + } +} + +impl From for ScalarZ { + fn from(n: u64) -> Self { + Self::from(&BigInt::from(n)) + } +} + +impl From for ScalarZ { + fn from(n: i32) -> Self { + Self::from(&BigInt::from(n)) + } +} + +impl From<&BigInt> for ScalarZ { + fn from(n: &BigInt) -> Self { + ScalarZ::from_raw(E::Scalar::from_bigint(n)) + } +} + +impl From for ScalarZ { + fn from(n: BigInt) -> Self { + Self::from(&n) + } +} + +/// Scalar value in a prime field that _guaranteed_ to be non zero +/// +/// ## Security +/// +/// Non-zero scalars are preferred to be used in cryptographic algorithms. Lack of checking whether +/// computation on field scalars results into zero scalar might lead to vulnerability. Using `Scalar` +/// ensures you and reviewers that check on scalar not being zero was made. +/// +/// ## Guarantees +/// +/// * Belongs to the curve prime field +/// +/// Denoting curve modulus as `q`, any instance `s` of `Scalar` is guaranteed to be less than `q`: +/// `s < q` +/// * Not a zero +/// +/// Any instance `s` of `Scalar` is guaranteed to be more than zero: `s > 0` +/// +/// Combining two rules above, any instance `s` of `Scalar` is guaranteed to be: `0 < s < q`. +/// +/// ## Arithmetic +/// +/// Supported operations: +/// * Unary: you can [invert](Self::invert) and negate a scalar by modulo of prime field +/// * Binary: you can add, subtract, and multiply two points +/// +/// Addition, subtraction, or multiplication of two (even non-zero) scalars might result into zero +/// scalar, so these operations output [ScalarZ]. Use [ensure_nonzero](ScalarZ::ensure_nonzero) method +/// to ensure that computation doesn't produce zero scalar; +/// +/// ```rust +/// # use curv::elliptic::curves::{ScalarZ, Scalar, Secp256k1}; +/// let a = Scalar::::random(); +/// let b = Scalar::::random(); +/// let result: ScalarZ = a * b; +/// let non_zero_result: Option> = result.ensure_nonzero(); +/// ``` +/// +/// When evaluating complex expressions, you typically need to ensure that none of intermediate +/// results are zero scalars: +/// ```rust +/// # use curv::elliptic::curves::{Scalar, Secp256k1}; +/// fn expression(a: Scalar, b: Scalar, c: Scalar) -> Option> { +/// (a + (b * c).ensure_nonzero()?).ensure_nonzero() +/// } +/// ``` +#[derive(Serialize, Deserialize)] +#[serde(try_from = "ScalarFormat", into = "ScalarFormat")] +pub struct Scalar(E::Scalar); + +impl Scalar { + /// Samples a random non-zero scalar + pub fn random() -> Self { + loop { + if let Some(scalar) = ScalarZ::from_raw(E::Scalar::random()).ensure_nonzero() { + break scalar; + } + } + } + + /// Returns modular multiplicative inverse of the scalar + /// + /// Inverse of non-zero scalar is always defined in a prime field, and inverted scalar is also + /// guaranteed to be non-zero. + pub fn invert(&self) -> Self { + self.0 + .invert() + .map(Self) + .expect("non-zero scalar must have corresponding inversion") + } + + /// Adds two scalars, returns the result by modulo `q`, or `None` if resulting scalar is zero + pub fn add_checked(&self, scalar: &Scalar) -> Option { + let scalar = self.0.add(&scalar.0); + Self::from_raw(scalar) + } + + /// Subtracts two scalars, returns the result by modulo `q`, or `None` if resulting scalar is zero + pub fn sub_checked(&self, scalar: &Scalar) -> Option { + let scalar = self.0.sub(&scalar.0); + Self::from_raw(scalar) + } + + /// Multiplies two scalars, returns the result by modulo `q`, or `None` if resulting scalar is zero + pub fn mul_checked(&self, scalar: &Scalar) -> Option { + let scalar = self.0.mul(&scalar.0); + Self::from_raw(scalar) + } + + fn from_raw(scalar: E::Scalar) -> Option { + if scalar.is_zero() { + None + } else { + Some(Self(scalar)) + } + } +} + +impl Clone for Scalar { + fn clone(&self) -> Self { + Scalar(self.0.clone()) + } +} + +impl fmt::Debug for Scalar { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl PartialEq for Scalar { + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} + +impl PartialEq> for Scalar { + fn eq(&self, other: &ScalarZ) -> bool { + self.0.eq(&other.0) + } +} + +macro_rules! matrix { + ( + trait = $trait:ident, + trait_fn = $trait_fn:ident, + output = $output:ty, + output_new = $output_new:ident, + point_fn = $point_fn:ident, + point_assign_fn = $point_assign_fn:ident, + pairs = {(r_<$($l:lifetime),*> $lhs_ref:ty, $rhs:ty), $($rest:tt)*} + ) => { + impl<$($l,)* E: Curve> ops::$trait<$rhs> for $lhs_ref { + type Output = $output; + fn $trait_fn(self, rhs: $rhs) -> Self::Output { + let p = self.0.$point_fn(&rhs.0); + $output_new(p) + } + } + matrix!{ + trait = $trait, + trait_fn = $trait_fn, + output = $output, + output_new = $output_new, + point_fn = $point_fn, + point_assign_fn = $point_assign_fn, + pairs = {$($rest)*} + } + }; + + ( + trait = $trait:ident, + trait_fn = $trait_fn:ident, + output = $output:ty, + output_new = $output_new:ident, + point_fn = $point_fn:ident, + point_assign_fn = $point_assign_fn:ident, + pairs = {(_r<$($l:lifetime),*> $lhs:ty, $rhs_ref:ty), $($rest:tt)*} + ) => { + impl<$($l,)* E: Curve> ops::$trait<$rhs_ref> for $lhs { + type Output = $output; + fn $trait_fn(self, rhs: $rhs_ref) -> Self::Output { + let p = rhs.0.$point_fn(&self.0); + $output_new(p) + } + } + matrix!{ + trait = $trait, + trait_fn = $trait_fn, + output = $output, + output_new = $output_new, + point_fn = $point_fn, + point_assign_fn = $point_assign_fn, + pairs = {$($rest)*} + } + }; + + ( + trait = $trait:ident, + trait_fn = $trait_fn:ident, + output = $output:ty, + output_new = $output_new:ident, + point_fn = $point_fn:ident, + point_assign_fn = $point_assign_fn:ident, + pairs = {(o_<$($l:lifetime),*> $lhs_owned:ty, $rhs:ty), $($rest:tt)*} + ) => { + impl<$($l,)* E: Curve> ops::$trait<$rhs> for $lhs_owned { + type Output = $output; + fn $trait_fn(mut self, rhs: $rhs) -> Self::Output { + self.0.$point_assign_fn(&rhs.0); + $output_new(self.0) + } + } + matrix!{ + trait = $trait, + trait_fn = $trait_fn, + output = $output, + output_new = $output_new, + point_fn = $point_fn, + point_assign_fn = $point_assign_fn, + pairs = {$($rest)*} + } + }; + + ( + trait = $trait:ident, + trait_fn = $trait_fn:ident, + output = $output:ty, + output_new = $output_new:ident, + point_fn = $point_fn:ident, + point_assign_fn = $point_assign_fn:ident, + pairs = {(_o<$($l:lifetime),*> $lhs:ty, $rhs_owned:ty), $($rest:tt)*} + ) => { + impl<$($l,)* E: Curve> ops::$trait<$rhs_owned> for $lhs { + type Output = $output; + fn $trait_fn(self, mut rhs: $rhs_owned) -> Self::Output { + rhs.0.$point_assign_fn(&self.0); + $output_new(rhs.0) + } + } + matrix!{ + trait = $trait, + trait_fn = $trait_fn, + output = $output, + output_new = $output_new, + point_fn = $point_fn, + point_assign_fn = $point_assign_fn, + pairs = {$($rest)*} + } + }; + + ( + trait = $trait:ident, + trait_fn = $trait_fn:ident, + output = $output:ty, + output_new = $output_new:ident, + point_fn = $point_fn:ident, + point_assign_fn = $point_assign_fn:ident, + pairs = {} + ) => { + // happy termination + }; +} + +matrix! { + trait = Add, + trait_fn = add, + output = PointZ, + output_new = PointZ, + point_fn = add_point, + point_assign_fn = add_point_assign, + pairs = { + (o_<> Point, &Point), (o_<> Point, &PointZ), + (r_<> &Point, &Point), (r_<> &Point, &PointZ), + (o_<> PointZ, &Point), (o_<> PointZ, &PointZ), + (r_<> &PointZ, &Point), (r_<> &PointZ, &PointZ), + (o_<> Point, Point), (o_<> PointZ, PointZ), + (o_<> Point, PointZ), (o_<> PointZ, Point), + (_o<> &Point, Point), (_o<> &Point, PointZ), + (_o<> &PointZ, Point), (_o<> &PointZ, PointZ), + + // The same as above, but replacing &Point with PointRef + (o_<'r> Point, PointRef<'r, E>), + (r_<'a, 'b> PointRef<'a, E>, PointRef<'b, E>), (r_<'r> PointRef<'r, E>, &PointZ), + (o_<'r> PointZ, PointRef<'r, E>), + (r_<'r> &PointZ, PointRef<'r, E>), + (_o<'r> PointRef<'r, E>, Point), (_o<'r> PointRef<'r, E>, PointZ), + + // And define trait between &Point and PointRef + (r_<'r> &Point, PointRef<'r, E>), (r_<'r> PointRef<'r, E>, &Point), + } +} + +matrix! { + trait = Sub, + trait_fn = sub, + output = PointZ, + output_new = PointZ, + point_fn = sub_point, + point_assign_fn = sub_point_assign, + pairs = { + (o_<> Point, &Point), (o_<> Point, &PointZ), + (r_<> &Point, &Point), (r_<> &Point, &PointZ), + (o_<> PointZ, &Point), (o_<> PointZ, &PointZ), + (r_<> &PointZ, &Point), (r_<> &PointZ, &PointZ), + (o_<> Point, Point), (o_<> PointZ, PointZ), + (o_<> Point, PointZ), (o_<> PointZ, Point), + (_o<> &Point, Point), (_o<> &Point, PointZ), + (_o<> &PointZ, Point), (_o<> &PointZ, PointZ), + + // The same as above, but replacing &Point with PointRef + (o_<'r> Point, PointRef<'r, E>), + (r_<'a, 'b> PointRef<'a, E>, PointRef<'b, E>), (r_<'r> PointRef<'r, E>, &PointZ), + (o_<'r> PointZ, PointRef<'r, E>), + (r_<'r> &PointZ, PointRef<'r, E>), + (_o<'r> PointRef<'r, E>, Point), (_o<'r> PointRef<'r, E>, PointZ), + + // And define trait between &Point and PointRef + (r_<'r> &Point, PointRef<'r, E>), (r_<'r> PointRef<'r, E>, &Point), + } +} + +matrix! { + trait = Mul, + trait_fn = mul, + output = PointZ, + output_new = PointZ, + point_fn = scalar_mul, + point_assign_fn = scalar_mul_assign, + pairs = { + (o_<> Point, &Scalar), (o_<> Point, &ScalarZ), + (r_<> &Point, &Scalar), (r_<> &Point, &ScalarZ), + (o_<> PointZ, &Scalar), (o_<> PointZ, &ScalarZ), + (r_<> &PointZ, &Scalar), (r_<> &PointZ, &ScalarZ), + (o_<> Point, Scalar), (o_<> Point, ScalarZ), + (r_<> &Point, Scalar), (r_<> &Point, ScalarZ), + (o_<> PointZ, Scalar), (o_<> PointZ, ScalarZ), + (r_<> &PointZ, Scalar), (r_<> &PointZ, ScalarZ), + + // The same as above but replacing &Point with PointRef + (r_<'p> PointRef<'p, E>, &Scalar), (r_<'p> PointRef<'p, E>, &ScalarZ), + (r_<'p> PointRef<'p, E>, Scalar), (r_<'p> PointRef<'p, E>, ScalarZ), + + // --- And vice-versa --- + + (_o<> &Scalar, Point), (_o<> &ScalarZ, Point), + (_r<> &Scalar, &Point), (_r<> &ScalarZ, &Point), + (_o<> &Scalar, PointZ), (_o<> &ScalarZ, PointZ), + (_r<> &Scalar, &PointZ), (_r<> &ScalarZ, &PointZ), + (_o<> Scalar, Point), (_o<> ScalarZ, Point), + (_r<> Scalar, &Point), (_r<> ScalarZ, &Point), + (_o<> Scalar, PointZ), (_o<> ScalarZ, PointZ), + (_r<> Scalar, &PointZ), (_r<> ScalarZ, &PointZ), + + // The same as above but replacing &Point with PointRef + (_r<'p> &Scalar, PointRef<'p, E>), (_r<'p> &ScalarZ, PointRef<'p, E>), + (_r<'p> Scalar, PointRef<'p, E>), (_r<'p> ScalarZ, PointRef<'p, E>), + } +} + +matrix! { + trait = Add, + trait_fn = add, + output = ScalarZ, + output_new = ScalarZ, + point_fn = add, + point_assign_fn = add_assign, + pairs = { + (o_<> Scalar, Scalar), (o_<> Scalar, ScalarZ), + (o_<> Scalar, &Scalar), (o_<> Scalar, &ScalarZ), + (o_<> ScalarZ, Scalar), (o_<> ScalarZ, ScalarZ), + (o_<> ScalarZ, &Scalar), (o_<> ScalarZ, &ScalarZ), + (_o<> &Scalar, Scalar), (_o<> &Scalar, ScalarZ), + (r_<> &Scalar, &Scalar), (r_<> &Scalar, &ScalarZ), + (_o<> &ScalarZ, Scalar), (_o<> &ScalarZ, ScalarZ), + (r_<> &ScalarZ, &Scalar), (r_<> &ScalarZ, &ScalarZ), + } +} + +matrix! { + trait = Sub, + trait_fn = sub, + output = ScalarZ, + output_new = ScalarZ, + point_fn = sub, + point_assign_fn = sub_assign, + pairs = { + (o_<> Scalar, Scalar), (o_<> Scalar, ScalarZ), + (o_<> Scalar, &Scalar), (o_<> Scalar, &ScalarZ), + (o_<> ScalarZ, Scalar), (o_<> ScalarZ, ScalarZ), + (o_<> ScalarZ, &Scalar), (o_<> ScalarZ, &ScalarZ), + (_o<> &Scalar, Scalar), (_o<> &Scalar, ScalarZ), + (r_<> &Scalar, &Scalar), (r_<> &Scalar, &ScalarZ), + (_o<> &ScalarZ, Scalar), (_o<> &ScalarZ, ScalarZ), + (r_<> &ScalarZ, &Scalar), (r_<> &ScalarZ, &ScalarZ), + } +} + +matrix! { + trait = Mul, + trait_fn = mul, + output = ScalarZ, + output_new = ScalarZ, + point_fn = mul, + point_assign_fn = mul_assign, + pairs = { + (o_<> Scalar, Scalar), (o_<> Scalar, ScalarZ), + (o_<> Scalar, &Scalar), (o_<> Scalar, &ScalarZ), + (o_<> ScalarZ, Scalar), (o_<> ScalarZ, ScalarZ), + (o_<> ScalarZ, &Scalar), (o_<> ScalarZ, &ScalarZ), + (_o<> &Scalar, Scalar), (_o<> &Scalar, ScalarZ), + (r_<> &Scalar, &Scalar), (r_<> &Scalar, &ScalarZ), + (_o<> &ScalarZ, Scalar), (_o<> &ScalarZ, ScalarZ), + (r_<> &ScalarZ, &Scalar), (r_<> &ScalarZ, &ScalarZ), + } +} + +impl ops::Neg for Scalar { + type Output = Scalar; + + fn neg(self) -> Self::Output { + Scalar::from_raw(self.0.neg()).expect("neg must not produce zero point") + } +} + +impl ops::Neg for &Scalar { + type Output = Scalar; + + fn neg(self) -> Self::Output { + Scalar::from_raw(self.0.neg()).expect("neg must not produce zero point") + } +} + +impl ops::Neg for ScalarZ { + type Output = ScalarZ; + + fn neg(self) -> Self::Output { + ScalarZ::from_raw(self.0.neg()) + } +} + +impl ops::Neg for &ScalarZ { + type Output = ScalarZ; + + fn neg(self) -> Self::Output { + ScalarZ::from_raw(self.0.neg()) + } +} + +impl ops::Neg for Point { + type Output = Point; + + fn neg(self) -> Self::Output { + Point::from_raw(self.0.neg_point()).expect("neg must not produce zero point") + } +} + +impl ops::Neg for &Point { + type Output = Point; + + fn neg(self) -> Self::Output { + Point::from_raw(self.0.neg_point()).expect("neg must not produce zero point") + } +} + +impl<'p, E: Curve> ops::Neg for PointRef<'p, E> { + type Output = Point; + + fn neg(self) -> Self::Output { + Point::from_raw(self.0.neg_point()).expect("neg must not produce zero point") + } +} + +impl ops::Neg for PointZ { + type Output = PointZ; + + fn neg(self) -> Self::Output { + PointZ::from_raw(self.0.neg_point()) + } +} + +impl ops::Neg for &PointZ { + type Output = PointZ; + + fn neg(self) -> Self::Output { + PointZ::from_raw(self.0.neg_point()) + } +} + +#[derive(Serialize, Deserialize)] +#[serde(bound = "")] +struct ScalarFormat { + curve_name: Cow<'static, str>, + #[serde(with = "hex")] + scalar: ScalarHex, +} + +impl TryFrom> for ScalarZ { + type Error = ConvertParsedScalarError; + + fn try_from(parsed: ScalarFormat) -> Result { + if parsed.curve_name != E::curve_name() { + return Err(ConvertParsedScalarError::MismatchedCurve { + got: parsed.curve_name, + expected: E::curve_name(), + }); + } + + Ok(ScalarZ::from_raw(parsed.scalar.0)) + } +} + +impl From> for ScalarFormat { + fn from(s: ScalarZ) -> Self { + ScalarFormat { + curve_name: E::curve_name().into(), + scalar: ScalarHex(s.0), + } + } +} + +impl TryFrom> for Scalar { + type Error = ConvertParsedScalarError; + + fn try_from(parsed: ScalarFormat) -> Result { + if parsed.curve_name != E::curve_name() { + return Err(ConvertParsedScalarError::MismatchedCurve { + got: parsed.curve_name, + expected: E::curve_name(), + }); + } + + ScalarZ::from_raw(parsed.scalar.0) + .ensure_nonzero() + .ok_or(ConvertParsedScalarError::ZeroScalar) + } +} + +impl From> for ScalarFormat { + fn from(s: Scalar) -> Self { + ScalarFormat { + curve_name: E::curve_name().into(), + scalar: ScalarHex(s.0), + } + } +} + +#[derive(Debug, Error)] +enum ConvertParsedScalarError { + #[error("scalar must not be zero")] + ZeroScalar, + #[error("expected scalar of curve {expected}, but got scalar of curve {got}")] + MismatchedCurve { + got: Cow<'static, str>, + expected: &'static str, + }, +} + +struct ScalarHex(E::Scalar); + +impl hex::ToHex for &ScalarHex { + fn encode_hex>(&self) -> T { + self.0.to_bigint().to_bytes().encode_hex() + } + + fn encode_hex_upper>(&self) -> T { + self.0.to_bigint().to_bytes().encode_hex_upper() + } +} + +impl hex::FromHex for ScalarHex { + type Error = hex::FromHexError; + + fn from_hex>(hex: T) -> Result { + let bytes = Vec::::from_hex(hex)?; + let big_int = BigInt::from_bytes(&bytes); + Ok(ScalarHex(E::Scalar::from_bigint(&big_int))) + } +} + +#[cfg(test)] +mod test { + use super::*; + + macro_rules! assert_operator_defined_for { + ( + assert_fn = $assert_fn:ident, + lhs = {}, + rhs = {$($rhs:ty),*}, + ) => { + // Corner case + }; + ( + assert_fn = $assert_fn:ident, + lhs = {$lhs:ty $(, $lhs_tail:ty)*}, + rhs = {$($rhs:ty),*}, + ) => { + assert_operator_defined_for! { + assert_fn = $assert_fn, + lhs = $lhs, + rhs = {$($rhs),*}, + } + assert_operator_defined_for! { + assert_fn = $assert_fn, + lhs = {$($lhs_tail),*}, + rhs = {$($rhs),*}, + } + }; + ( + assert_fn = $assert_fn:ident, + lhs = $lhs:ty, + rhs = {$($rhs:ty),*}, + ) => { + $($assert_fn::());* + }; + } + + /// Function asserts that P2 can be added to P1 (ie. P1 + P2) and result is PointZ. + /// If any condition doesn't meet, function won't compile. + #[allow(dead_code)] + fn assert_point_addition_defined() + where + P1: ops::Add>, + E: Curve, + { + // no-op + } + + #[test] + fn test_point_addition_defined() { + fn _curve() { + assert_operator_defined_for! { + assert_fn = assert_point_addition_defined, + lhs = {Point, PointZ, &Point, &PointZ, PointRef}, + rhs = {Point, PointZ, &Point, &PointZ, PointRef}, + } + } + } + + /// Function asserts that P2 can be subtracted from P1 (ie. P1 - P2) and result is PointZ. + /// If any condition doesn't meet, function won't compile. + #[allow(dead_code)] + fn assert_point_subtraction_defined() + where + P1: ops::Sub>, + E: Curve, + { + // no-op + } + + #[test] + fn test_point_subtraction_defined() { + fn _curve() { + assert_operator_defined_for! { + assert_fn = assert_point_subtraction_defined, + lhs = {Point, PointZ, &Point, &PointZ, PointRef}, + rhs = {Point, PointZ, &Point, &PointZ, PointRef}, + } + } + } + + /// Function asserts that M can be multiplied by N (ie. M * N) and result is PointZ. + /// If any condition doesn't meet, function won't compile. + #[allow(dead_code)] + fn assert_point_multiplication_defined() + where + M: ops::Mul>, + E: Curve, + { + // no-op + } + + #[test] + fn test_point_multiplication_defined() { + fn _curve() { + assert_operator_defined_for! { + assert_fn = assert_point_multiplication_defined, + lhs = {Point, PointZ, &Point, &PointZ, PointRef}, + rhs = {Scalar, ScalarZ, &Scalar, &ScalarZ}, + } + assert_operator_defined_for! { + assert_fn = assert_point_multiplication_defined, + lhs = {Scalar, ScalarZ, &Scalar, &ScalarZ}, + rhs = {Point, PointZ, &Point, &PointZ, PointRef}, + } + } + } + + /// Function asserts that S2 can be added to S1 (ie. S1 + S2) and result is ScalarZ. + /// If any condition doesn't meet, function won't compile. + #[allow(dead_code)] + fn assert_scalars_addition_defined() + where + S1: ops::Add>, + E: Curve, + { + // no-op + } + + #[test] + fn test_scalars_addition_defined() { + fn _curve() { + assert_operator_defined_for! { + assert_fn = assert_scalars_addition_defined, + lhs = {Scalar, ScalarZ, &Scalar, &ScalarZ}, + rhs = {Scalar, ScalarZ, &Scalar, &ScalarZ}, + } + } + } + + /// Function asserts that S2 can be added to S1 (ie. S1 + S2) and result is ScalarZ. + /// If any condition doesn't meet, function won't compile. + #[allow(dead_code)] + fn assert_scalars_subtraction_defined() + where + S1: ops::Sub>, + E: Curve, + { + // no-op + } + + #[test] + fn test_scalars_subtraction_defined() { + fn _curve() { + assert_operator_defined_for! { + assert_fn = assert_scalars_subtraction_defined, + lhs = {Scalar, ScalarZ, &Scalar, &ScalarZ}, + rhs = {Scalar, ScalarZ, &Scalar, &ScalarZ}, + } + } + } + + /// Function asserts that S1 can be multiplied by S2 (ie. S1 * S2) and result is ScalarZ. + /// If any condition doesn't meet, function won't compile. + #[allow(dead_code)] + fn assert_scalars_multiplication_defined() + where + S1: ops::Mul>, + E: Curve, + { + // no-op + } + + #[test] + fn test_scalars_multiplication_defined() { + fn _curve() { + assert_operator_defined_for! { + assert_fn = assert_scalars_multiplication_defined, + lhs = {Scalar, ScalarZ, &Scalar, &ScalarZ}, + rhs = {Scalar, ScalarZ, &Scalar, &ScalarZ}, + } + } + } +}