diff --git a/src/opaque.rs b/src/opaque.rs index d016d1f7..1fd0e60d 100644 --- a/src/opaque.rs +++ b/src/opaque.rs @@ -577,37 +577,25 @@ impl ClientRegistration { &Vec::new(), password, blinding_factor_rng, - ) - } - - /// Same as ClientRegistration::start, but also accepts a username and server name as input - pub fn start_with_user_and_server_name( - user_name: &[u8], - server_name: &[u8], - password: &[u8], - blinding_factor_rng: &mut R, - ) -> Result<(RegisterFirstMessage, Self), ProtocolError> { - Self::start_with_user_and_server_name_and_postprocessing( - user_name, - server_name, - password, - blinding_factor_rng, + #[cfg(test)] std::convert::identity, ) } - /// Same as ClientRegistration::start, but also accepts a username and server name as input as well as - /// an optional postprocessing function for the blinding factor - pub fn start_with_user_and_server_name_and_postprocessing( + /// Same as ClientRegistration::start, but also accepts a username and + /// server name as input + /// as well as an optional postprocessing function for the blinding factor(used in tests) + pub fn start_with_user_and_server_name( user_name: &[u8], server_name: &[u8], password: &[u8], blinding_factor_rng: &mut R, - postprocess: fn(::Scalar) -> ::Scalar, + #[cfg(test)] postprocess: fn(::Scalar) -> ::Scalar, ) -> Result<(RegisterFirstMessage, Self), ProtocolError> { - let (token, alpha) = oprf::blind_with_postprocessing::( + let (token, alpha) = oprf::blind::( &password, blinding_factor_rng, + #[cfg(test)] postprocess, )?; @@ -1037,35 +1025,31 @@ impl ClientLogin { password: &[u8], rng: &mut R, ) -> Result<(LoginFirstMessage, Self), ProtocolError> { - Self::start_with_user_and_server_name(&Vec::new(), &Vec::new(), password, rng) - } - - /// Same as start, but allows the user to supply a username and server name - pub fn start_with_user_and_server_name( - user_name: &[u8], - server_name: &[u8], - password: &[u8], - rng: &mut R, - ) -> Result<(LoginFirstMessage, Self), ProtocolError> { - Self::start_with_user_and_server_name_and_postprocessing( - user_name, - server_name, + Self::start_with_user_and_server_name( + &Vec::new(), + &Vec::new(), password, rng, + #[cfg(test)] std::convert::identity, ) } - /// Same as start, but allows the user to supply a username and server name and postprocessing function - pub fn start_with_user_and_server_name_and_postprocessing( + /// Same as start, but allows the user to supply a username and server name + /// and, in tests, a postprocessing function + pub fn start_with_user_and_server_name( user_name: &[u8], server_name: &[u8], password: &[u8], rng: &mut R, - postprocess: fn(::Scalar) -> ::Scalar, + #[cfg(test)] postprocess: fn(::Scalar) -> ::Scalar, ) -> Result<(LoginFirstMessage, Self), ProtocolError> { - let (token, alpha) = - oprf::blind_with_postprocessing::(&password, rng, postprocess)?; + let (token, alpha) = oprf::blind::( + &password, + rng, + #[cfg(test)] + postprocess, + )?; let (ke1_state, ke1_message) = CS::KeyExchange::generate_ke1(alpha.to_arr().to_vec(), rng)?; diff --git a/src/oprf.rs b/src/oprf.rs index d189307d..6dae4a5d 100644 --- a/src/oprf.rs +++ b/src/oprf.rs @@ -23,14 +23,18 @@ static STR_VOPRF: &[u8] = b"VOPRF05"; /// message is sent from the client (who holds the input) to the server (who holds the OPRF key). /// The client can also pass in an optional "pepper" string to be mixed in with the input through /// an HKDF computation. -pub(crate) fn blind_with_postprocessing( +pub(crate) fn blind( input: &[u8], blinding_factor_rng: &mut R, - postprocess: fn(G::Scalar) -> G::Scalar, + #[cfg(test)] postprocess: fn(G::Scalar) -> G::Scalar, ) -> Result<(Token, G), InternalPakeError> { let mapped_point = G::map_to_curve(input, Some(STR_VOPRF)); // TODO: add contextString from RFC let blinding_factor = G::random_scalar(blinding_factor_rng); + #[cfg(test)] let blind = postprocess(blinding_factor); + #[cfg(not(test))] + let blind = blinding_factor; + let blind_token = mapped_point * &blind; Ok(( Token { @@ -60,23 +64,34 @@ pub(crate) fn unblind_and_finalize( Ok(prk) } -// Benchmarking shims +//////////////////////// +// Benchmarking shims // +//////////////////////// + #[cfg(feature = "bench")] +#[doc(hidden)] #[inline] pub fn blind_shim( input: &[u8], blinding_factor_rng: &mut R, ) -> Result<(Token, G), InternalPakeError> { - blind_with_postprocessing(input, blinding_factor_rng, std::convert::identity) + blind( + input, + blinding_factor_rng, + #[cfg(test)] + std::convert::identity, + ) } #[cfg(feature = "bench")] +#[doc(hidden)] #[inline] pub fn evaluate_shim(point: G, oprf_key: &G::Scalar) -> Result { evaluate(point, oprf_key) } #[cfg(feature = "bench")] +#[doc(hidden)] #[inline] pub fn unblind_and_finalize_shim( token: &Token, @@ -85,8 +100,10 @@ pub fn unblind_and_finalize_shim( unblind_and_finalize::(token, point) } -// Tests -// ===== +/////////// +// Tests // +// ===== // +/////////// #[cfg(test)] mod tests { @@ -117,11 +134,8 @@ mod tests { fn oprf_retrieval() -> Result<(), InternalPakeError> { let input = b"hunter2"; let mut rng = OsRng; - let (token, alpha) = blind_with_postprocessing::<_, RistrettoPoint>( - &input[..], - &mut rng, - std::convert::identity, - )?; + let (token, alpha) = + blind::<_, RistrettoPoint>(&input[..], &mut rng, std::convert::identity)?; let oprf_key_bytes = arr![ u8; 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, @@ -139,12 +153,8 @@ mod tests { let mut rng = OsRng; let mut input = vec![0u8; 64]; rng.fill_bytes(&mut input); - let (token, alpha) = blind_with_postprocessing::<_, RistrettoPoint>( - &input, - &mut rng, - std::convert::identity, - ) - .unwrap(); + let (token, alpha) = + blind::<_, RistrettoPoint>(&input, &mut rng, std::convert::identity).unwrap(); let res = unblind_and_finalize::(&token, alpha).unwrap(); let (hashed_input, _) = Hkdf::::extract(Some(STR_VOPRF), &input); diff --git a/src/tests/opaque_ke_test.rs b/src/tests/opaque_ke_test.rs index 848afb0f..979beb37 100644 --- a/src/tests/opaque_ke_test.rs +++ b/src/tests/opaque_ke_test.rs @@ -254,6 +254,7 @@ where id_s, password, &mut blinding_factor_registration_rng, + std::convert::identity, ) .unwrap(); let r1_bytes = r1.serialize().to_vec(); @@ -291,6 +292,7 @@ where id_s, password, &mut client_login_start_rng, + std::convert::identity, ) .unwrap(); let l1_bytes = l1.serialize().to_vec(); @@ -362,14 +364,15 @@ fn postprocess_blinding_factor(_: G::Scalar) -> G::Scalar { fn test_r1() -> Result<(), PakeError> { let parameters = populate_test_vectors(&serde_json::from_str(TEST_VECTOR).unwrap()); let mut rng = OsRng; - let (r1, client_registration) = ClientRegistration::::start_with_user_and_server_name_and_postprocessing( - ¶meters.id_u, - ¶meters.id_s, - ¶meters.password, - &mut rng, - postprocess_blinding_factor::<::Group>, - ) - .unwrap(); + let (r1, client_registration) = + ClientRegistration::::start_with_user_and_server_name( + ¶meters.id_u, + ¶meters.id_s, + ¶meters.password, + &mut rng, + postprocess_blinding_factor::<::Group>, + ) + .unwrap(); assert_eq!(hex::encode(¶meters.r1), hex::encode(r1.serialize())); assert_eq!( hex::encode(¶meters.client_registration_state), @@ -452,15 +455,14 @@ fn test_l1() -> Result<(), PakeError> { ] .concat(); let mut client_login_start_rng = CycleRng::new(client_login_start); - let (l1, client_login) = - ClientLogin::::start_with_user_and_server_name_and_postprocessing( - ¶meters.id_u, - ¶meters.id_s, - ¶meters.password, - &mut client_login_start_rng, - postprocess_blinding_factor::<::Group>, - ) - .unwrap(); + let (l1, client_login) = ClientLogin::::start_with_user_and_server_name( + ¶meters.id_u, + ¶meters.id_s, + ¶meters.password, + &mut client_login_start_rng, + postprocess_blinding_factor::<::Group>, + ) + .unwrap(); assert_eq!(hex::encode(¶meters.l1), hex::encode(l1.serialize())); assert_eq!( hex::encode(¶meters.client_login_state),