Skip to content

Commit

Permalink
api changes
Browse files Browse the repository at this point in the history
  • Loading branch information
joyqvq committed Jul 24, 2023
1 parent 549aea0 commit 726477c
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 17 deletions.
43 changes: 43 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions fastcrypto-zkp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93"
once_cell = "1.16"
poseidon-ark = { git = "https://github.com/arnaucube/poseidon-ark.git", rev = "bf96de3b946e8b343c6b65412bae92f8d32251ad" }
im = "15"

[dev-dependencies]
ark-bls12-377 = "0.4.0"
Expand Down
16 changes: 9 additions & 7 deletions fastcrypto-zkp/src/bn254/unit_tests/zk_login_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::bn254::{
zk_login::{AuxInputs, OAuthProvider, OAuthProviderContent, PublicInputs, ZkLoginProof},
zk_login_api::verify_zk_login,
};
use std::collections::HashMap;
use im::hashmap::HashMap as ImHashMap;

#[test]
fn test_verify_groth16_in_bytes_google() {
Expand Down Expand Up @@ -41,16 +41,17 @@ fn test_verify_groth16_in_bytes_google() {
public_inputs.get_all_inputs_hash().unwrap()
);

let mut map = HashMap::new();
map.insert((TEST_KID, Google.get_config().0), OAuthProviderContent {
let mut map = ImHashMap::new();
map.insert((TEST_KID.to_string(), Google.get_config().0.to_string()), OAuthProviderContent {
kty: "RSA".to_string(),
kid: TEST_KID.to_string(),
e: "AQAB".to_string(),
n: "whYOFK2Ocbbpb_zVypi9SeKiNUqKQH0zTKN1-6fpCTu6ZalGI82s7XK3tan4dJt90ptUPKD2zvxqTzFNfx4HHHsrYCf2-FMLn1VTJfQazA2BvJqAwcpW1bqRUEty8tS_Yv4hRvWfQPcc2Gc3-_fQOOW57zVy-rNoJc744kb30NjQxdGp03J2S3GLQu7oKtSDDPooQHD38PEMNnITf0pj-KgDPjymkMGoJlO3aKppsjfbt_AH6GGdRghYRLOUwQU-h-ofWHR3lbYiKtXPn5dN24kiHy61e3VAQ9_YAZlwXC_99GGtw_NpghFAuM4P1JDn0DppJldy3PGFC0GfBCZASw".to_string(),
alg: "RS256".to_string(),
});
let proof = ZkLoginProof::from_json("{\"pi_a\":[\"21079899190337156604543197959052999786745784780153100922098887555507822163222\",\"4490261504756339299022091724663793329121338007571218596828748539529998991610\",\"1\"],\"pi_b\":[[\"9379167206161123715528853149920855132656754699464636503784643891913740439869\",\"15902897771112804794883785114808675393618430194414793328415185511364403970347\"],[\"16152736996630746506267683507223054358516992879195296708243566008238438281201\",\"15230917601041350929970534508991793588662911174494137634522926575255163535339\"],[\"1\",\"0\"]],\"pi_c\":[\"8242734018052567627683363270753907648903210541694662698981939667442011573249\",\"1775496841914332445297048246214170486364407018954976081505164205395286250461\",\"1\"],\"protocol\":\"groth16\"}");
assert!(proof.is_ok());
let res = verify_zk_login(&proof.unwrap(), &public_inputs, &aux_inputs, 1, map);
let res = verify_zk_login(&proof.unwrap(), &public_inputs, &aux_inputs, map);
assert!(res.is_ok());
}

Expand All @@ -77,17 +78,18 @@ fn test_verify_groth16_in_bytes_twitch() {
public_inputs.get_all_inputs_hash().unwrap()
);

let mut map = HashMap::new();
map.insert(("1", Twitch.get_config().0), OAuthProviderContent {
let mut map = ImHashMap::new();
map.insert(("1".to_string(), Twitch.get_config().0.to_string()), OAuthProviderContent {
kty: "RSA".to_string(),
kid: "1".to_string(),
e: "AQAB".to_string(),
n: "6lq9MQ-q6hcxr7kOUp-tHlHtdcDsVLwVIw13iXUCvuDOeCi0VSuxCCUY6UmMjy53dX00ih2E4Y4UvlrmmurK0eG26b-HMNNAvCGsVXHU3RcRhVoHDaOwHwU72j7bpHn9XbP3Q3jebX6KIfNbei2MiR0Wyb8RZHE-aZhRYO8_-k9G2GycTpvc-2GBsP8VHLUKKfAs2B6sW3q3ymU6M0L-cFXkZ9fHkn9ejs-sqZPhMJxtBPBxoUIUQFTgv4VXTSv914f_YkNw-EjuwbgwXMvpyr06EyfImxHoxsZkFYB-qBYHtaMxTnFsZBr6fn8Ha2JqT1hoP7Z5r5wxDu3GQhKkHw".to_string(),
alg: "RS256".to_string(),
});
let proof = ZkLoginProof::from_json("{ \"pi_a\": [ \"14609816250208775088998769033922823275418989011294962335042447516759468155261\", \"20377558696931353568738668428784363385404286135420274775798451001900237387711\", \"1\" ], \"pi_b\": [ [ \"13205564493500587952133306511249429194738679332267485407336676345714082870630\", \"20796060045071998078434479958974217243296767801927986923760870304883706846959\" ], [ \"18144611315874106283809557225033182618356564976139850467162456490949482704538\", \"4318715074202832054732474611176035084202678394565328538059624195976255391002\" ], [ \"1\", \"0\" ] ], \"pi_c\": [ \"4215643272645108456341625420022677634747189283615115637991603989161283548307\", \"5549730540188640204480179088531560793048476496379683802205245590402338452458\", \"1\" ], \"protocol\": \"groth16\"}");
assert!(proof.is_ok());

let res = verify_zk_login(&proof.unwrap(), &public_inputs, &aux_inputs, 1, map);
let res = verify_zk_login(&proof.unwrap(), &public_inputs, &aux_inputs, map);
assert!(res.is_ok());
}

Expand Down
67 changes: 67 additions & 0 deletions fastcrypto-zkp/src/bn254/zk_login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::str::FromStr;

type ParsedJWKs = Vec<((String, String), OAuthProviderContent)>;
#[cfg(test)]
#[path = "unit_tests/zk_login_tests.rs"]
mod zk_login_tests;
Expand Down Expand Up @@ -77,6 +78,72 @@ pub struct OAuthProviderContent {
pub n: String,
/// Algorithm parameter, https://datatracker.ietf.org/doc/html/rfc7517#section-4.4
pub alg: String,
/// kid
kid: String,
}

/// Reader struct to parse all fields.
#[derive(Debug, Clone, PartialEq, Eq, JsonSchema, Hash, Serialize, Deserialize)]
pub struct OAuthProviderContentReader {
e: String,
n: String,
#[serde(rename = "use")]
my_use: String,
kid: String,
kty: String,
alg: String,
}

impl OAuthProviderContent {
/// Get the kid string.
pub fn kid(&self) -> &str {
&self.kid
}

/// Parse OAuthProviderContent from the reader struct.
pub fn from_reader(reader: OAuthProviderContentReader) -> Self {
Self {
kty: reader.kty,
kid: reader.kid,
e: trim(reader.e),
n: trim(reader.n),
alg: reader.alg,
}
}
}

/// Trim trailing '=' so that it is considered a valid base64 url encoding string by base64ct library.
fn trim(str: String) -> String {
str.trim_end_matches(|c: char| c == '=').to_owned()
}

/// Parse the JWK bytes received from the oauth provider keys endpoint into a map from kid to
/// OAuthProviderContent.
pub fn parse_jwks(json_bytes: &[u8]) -> Result<ParsedJWKs, FastCryptoError> {
let json_str = String::from_utf8_lossy(json_bytes);
let parsed_list: Result<serde_json::Value, serde_json::Error> = serde_json::from_str(&json_str);
if let Ok(parsed_list) = parsed_list {
if let Some(keys) = parsed_list["keys"].as_array() {
let mut ret = Vec::new();
for k in keys {
let parsed: OAuthProviderContentReader = serde_json::from_value(k.clone())
.map_err(|_| FastCryptoError::GeneralError("Parse error".to_string()))?;

if parsed.alg == "RS256" && parsed.my_use == "sig" && parsed.kty == "RSA" {
// Default to Google for iss
ret.push((
(
parsed.kid.clone(),
OAuthProvider::Google.get_config().0.to_owned(),
),
OAuthProviderContent::from_reader(parsed),
));
}
}
return Ok(ret);
}
}
Err(FastCryptoError::GeneralError("JWK not found".to_string()))
}

impl OAuthProvider {
Expand Down
17 changes: 7 additions & 10 deletions fastcrypto-zkp/src/bn254/zk_login_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

use ark_crypto_primitives::snark::SNARK;
use fastcrypto::rsa::{Base64UrlUnpadded, Encoding};
use im::hashmap::HashMap as ImHashMap;
use num_bigint::BigUint;
use std::collections::HashMap;

Expand Down Expand Up @@ -135,24 +136,20 @@ pub fn verify_zk_login(
proof: &ZkLoginProof,
public_inputs: &PublicInputs,
aux_inputs: &AuxInputs,
curr_epoch: u64,
all_jwk: HashMap<(&str, &str), OAuthProviderContent>,
all_jwk: ImHashMap<(String, String), OAuthProviderContent>,
) -> Result<(), FastCryptoError> {
if !is_claim_supported(aux_inputs.get_key_claim_name()) {
return Err(FastCryptoError::GeneralError(
"Unsupported claim found".to_string(),
));
}
// Verify the max epoch in aux inputs is <= the current epoch of authority.
if aux_inputs.get_max_epoch() <= curr_epoch {
return Err(FastCryptoError::GeneralError(
"Invalid max epoch".to_string(),
));
}

let jwk = all_jwk
.get(&(aux_inputs.get_kid(), aux_inputs.get_iss()))
.ok_or_else(|| FastCryptoError::GeneralError("kid not found".to_string()))?;
.get(&(
aux_inputs.get_kid().to_string(),
aux_inputs.get_iss().to_string(),
))
.ok_or_else(|| FastCryptoError::GeneralError("JWK not found".to_string()))?;
let jwk_modulus =
BigUint::from_bytes_be(&Base64UrlUnpadded::decode_vec(&jwk.n).map_err(|_| {
FastCryptoError::GeneralError("Invalid Base64 encoded jwk.n".to_string())
Expand Down

0 comments on commit 726477c

Please sign in to comment.