Skip to content

Commit

Permalink
feat: add more api to handle client_ids of AsyncClient
Browse files Browse the repository at this point in the history
  • Loading branch information
caojen committed Feb 3, 2024
1 parent 7add4cc commit 8063d4e
Showing 1 changed file with 54 additions and 8 deletions.
62 changes: 54 additions & 8 deletions src/async_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::time::{Duration, Instant};
use lazy_static::lazy_static;
use log::debug;
use async_rwlock::RwLock;
use crate::{DEFAULT_TIMEOUT, GOOGLE_OAUTH_V3_USER_INFO_API, GOOGLE_SA_CERTS_URL, GoogleAccessTokenPayload, GooglePayload, MyResult, utils};
use crate::{DEFAULT_TIMEOUT, Error, GOOGLE_OAUTH_V3_USER_INFO_API, GOOGLE_SA_CERTS_URL, GoogleAccessTokenPayload, GooglePayload, IDTokenClientIDNotFoundError, MyResult, utils};
use crate::certs::{Cert, Certs};
use crate::jwt_parser::JwtParser;
use crate::validate::id_token;
Expand All @@ -18,7 +18,7 @@ lazy_static! {
/// AsyncClient is an async client to do verification.
#[derive(Debug, Clone)]
pub struct AsyncClient {
client_ids: Vec<String>,
client_ids: Arc<RwLock<Vec<String>>>,
timeout: Duration,
cached_certs: Arc<RwLock<Certs>>,
}
Expand All @@ -30,22 +30,50 @@ impl AsyncClient {
Self::new_with_vec(&[client_id])
}

/// Create a new async client, with multiple client ids.
pub fn new_with_vec<T, V>(client_ids: T) -> Self
where
T: AsRef<[V]>,
V: AsRef<str>,
{
Self {
client_ids: client_ids
.as_ref()
.iter()
.map(|c| c.as_ref().to_string())
.collect(),
client_ids: Arc::new(RwLock::new(
client_ids
.as_ref()
.iter()
.map(|c| c.as_ref())
.filter(|c| !c.is_empty())
.map(|c| c.to_string())
.collect()
)),
timeout: Duration::from_secs(DEFAULT_TIMEOUT),
cached_certs: Arc::default(),
}
}

/// Add a new client_id for future validating.
///
/// Note: this function is thread safe.
pub async fn add_client_id<T: ToString>(&mut self, client_id: T) {
let client_id = client_id.to_string();

if !client_id.is_empty() {
self.client_ids.write().await.push(client_id)
}
}

/// Remove a client_id, if it exists.
///
/// Note: this function is thread safe.
pub async fn remove_client_id<T: AsRef<str>>(&mut self, client_id: T) {
let to_delete = client_id.as_ref();

if !to_delete.is_empty() {
let mut client_ids = self.client_ids.write().await;
client_ids.retain(|id| id != to_delete)
}
}

/// Set the timeout (used in fetching google certs).
/// Default timeout is 5 seconds. Zero timeout will be ignored.
pub fn timeout(mut self, d: Duration) -> Self {
Expand All @@ -60,11 +88,23 @@ impl AsyncClient {
pub async fn validate_id_token<S>(&self, token: S) -> MyResult<GooglePayload>
where S: AsRef<str>
{
// fast check:
// if there is no given client id, simple return without communicating with Google server.

let client_ids = self.client_ids.read().await;

if client_ids.is_empty() {
return Err(Error::IDTokenClientIDNotFoundError(IDTokenClientIDNotFoundError {
get: token.as_ref().to_string(),
expected: Default::default(),
}))
}

let token = token.as_ref();

let parser: JwtParser<GooglePayload> = JwtParser::parse(token)?;

id_token::validate_info(&self.client_ids, &parser)?;
id_token::validate_info(&*client_ids, &parser)?;

let cert = self.get_cert(parser.header.alg.as_str(), parser.header.kid.as_str()).await?;

Expand Down Expand Up @@ -123,3 +163,9 @@ impl AsyncClient {
Ok(payload)
}
}

impl Default for AsyncClient {
fn default() -> Self {
Self::new_with_vec::<&[_; 0], &'static str>(&[])
}
}

0 comments on commit 8063d4e

Please sign in to comment.