Skip to content

Commit

Permalink
refactor: generic assertion claims for providers
Browse files Browse the repository at this point in the history
Co-authored-by: Trong Huu Nguyen <trong.huu.nguyen@nav.no>
Co-authored-by: Tommy Trøen <tommy.troen@nav.no>
  • Loading branch information
3 people committed Nov 4, 2024
1 parent 1573c8c commit 2a32ed2
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 107 deletions.
21 changes: 12 additions & 9 deletions src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ pub async fn token(
JsonOrForm(request): JsonOrForm<TokenRequest>,
) -> impl IntoResponse {
match &request.identity_provider {
IdentityProvider::AzureAD => state.azure_ad_cc.read().await.get_token(state.clone(), request).await.into_response(),
IdentityProvider::Maskinporten => state.maskinporten.read().await.get_token(state.clone(), request).await.into_response(),
IdentityProvider::AzureAD => state.azure_ad_cc.read().await.get_token(request).await.into_response(),
IdentityProvider::Maskinporten => state.maskinporten.read().await.get_token(request).await.into_response(),
IdentityProvider::TokenX => (StatusCode::BAD_REQUEST, "TokenX does not support machine-to-machine tokens".to_string()).into_response(),
}
}
Expand All @@ -36,9 +36,9 @@ pub async fn token_exchange(
JsonOrForm(request): JsonOrForm<TokenExchangeRequest>,
) -> impl IntoResponse {
match &request.identity_provider {
IdentityProvider::AzureAD => state.azure_ad_obo.read().await.exchange_token(state.clone(), request.into()).await.into_response(),
IdentityProvider::AzureAD => state.azure_ad_obo.read().await.exchange_token(request.into()).await.into_response(),
IdentityProvider::Maskinporten => (StatusCode::BAD_REQUEST, "Maskinporten does not support token exchange".to_string()).into_response(),
IdentityProvider::TokenX => state.token_x.read().await.exchange_token(state.clone(), request.into()).await.into_response(),
IdentityProvider::TokenX => state.token_x.read().await.exchange_token(request.into()).await.into_response(),
}
}

Expand Down Expand Up @@ -100,11 +100,10 @@ impl Claims {
#[derive(Clone)]
pub struct HandlerState {
pub cfg: Config,
pub maskinporten: Arc<RwLock<Provider<MaskinportenTokenRequest>>>,
pub azure_ad_obo: Arc<RwLock<Provider<AzureADOnBehalfOfTokenRequest>>>,
pub azure_ad_cc: Arc<RwLock<Provider<AzureADClientCredentialsTokenRequest>>>,
pub token_x: Arc<RwLock<Provider<TokenXTokenRequest>>>,
// TODO: other providers
pub maskinporten: Arc<RwLock<Provider<MaskinportenTokenRequest, JWTBearerAssertionClaims>>>,
pub azure_ad_obo: Arc<RwLock<Provider<AzureADOnBehalfOfTokenRequest, ClientAssertionClaims>>>,
pub azure_ad_cc: Arc<RwLock<Provider<AzureADClientCredentialsTokenRequest, ClientAssertionClaims>>>,
pub token_x: Arc<RwLock<Provider<TokenXTokenRequest, ClientAssertionClaims>>>,
}

#[derive(Debug, Error)]
Expand All @@ -118,6 +117,9 @@ pub enum ApiError {
#[error("invalid JSON in token response: {0}")]
JSON(reqwest::Error),

#[error("cannot sign JWT claims")]
Sign,

#[error("invalid token: {0}")]
Validate(jwt::errors::Error),
}
Expand All @@ -132,6 +134,7 @@ impl IntoResponse for ApiError {
ApiError::JSON(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
ApiError::Upstream(_err) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
ApiError::Validate(_) => (StatusCode::BAD_REQUEST, self.to_string()),
ApiError::Sign => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
}
.into_response()
}
Expand Down
179 changes: 86 additions & 93 deletions src/identity_provider.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use crate::{jwks, types};
use jsonwebkey as jwk;
use jsonwebtoken as jwt;
use serde::Serialize;
use serde::{Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::marker::PhantomData;
use axum::Json;
use axum::response::IntoResponse;
use log::error;
use reqwest::StatusCode;
use crate::handlers::{ApiError, HandlerState};
use crate::types::{IdentityProvider, TokenExchangeRequest, TokenRequest, TokenResponse};
use crate::handlers::{ApiError};
use crate::types::{TokenExchangeRequest, TokenRequest, TokenResponse};

pub trait TokenRequestFactory {
fn token_request(config: TokenRequestConfig) -> Option<Self>
Expand All @@ -26,14 +26,16 @@ pub struct TokenRequestConfig {
}

#[derive(Clone)]
pub struct Provider<T: Serialize> {
issuer: String, // unused for now; maskinporten might require this as `aud` in client_assertion
pub struct Provider<T: Serialize, U: Serialize> {
#[allow(dead_code)]
issuer: String, // FIXME: unused for now; maskinporten might require this as `aud` in client_assertion
client_id: String,
pub token_endpoint: String,
private_jwk: jwt::EncodingKey,
client_assertion_header: jwt::Header,
upstream_jwks: jwks::Jwks,
_fake: PhantomData<T>,
_fake2: PhantomData<U>,
}

#[derive(Serialize)]
Expand Down Expand Up @@ -83,6 +85,7 @@ impl TokenRequestFactory for AzureADClientCredentialsTokenRequest {
})
}
}

impl TokenRequestFactory for AzureADOnBehalfOfTokenRequest {
fn token_request(config: TokenRequestConfig) -> Option<Self> {
Some(Self {
Expand Down Expand Up @@ -122,9 +125,10 @@ impl TokenRequestFactory for TokenXTokenRequest {
}

//impl<T> Provider<T> where T: TokenRequestFactory<T> + Serialize
impl<T> Provider<T>
impl<T, U> Provider<T, U>
where
T: Serialize + TokenRequestFactory,
U: Serialize + ClientAssertion,
{
pub fn new(
issuer: String,
Expand All @@ -146,6 +150,7 @@ where
upstream_jwks,
private_jwk: client_private_jwk.key.to_encoding_key(),
_fake: Default::default(),
_fake2: Default::default(),
})
}

Expand All @@ -165,31 +170,9 @@ where
})
}

fn create_assertion(&self, ass: AssertionClaimType) -> Result<String, jwt::errors::Error> {
AssertionClaims::new(
self.token_endpoint.clone(),
self.client_id.clone(),
ass,
).serialize(&self.client_assertion_header, &self.private_jwk)
}

pub async fn get_token(
&self,
_state: HandlerState,
request: TokenRequest,
async fn get_token_with_config(&self, config: TokenRequestConfig,
) -> Result<impl IntoResponse, ApiError> {
let assertion = match request.identity_provider {
IdentityProvider::AzureAD => self.create_assertion(AssertionClaimType::WithSub(self.client_id.clone())).unwrap(),
IdentityProvider::TokenX => self.create_assertion(AssertionClaimType::WithSub(self.client_id.clone())).unwrap(),
IdentityProvider::Maskinporten => self.create_assertion(AssertionClaimType::WithScope(request.target.clone())).unwrap()
};

let params = T::token_request(TokenRequestConfig {
target: request.target,
assertion,
client_id: Some(self.client_id.clone()),
user_token: None,
}).unwrap();
let params = T::token_request(config).ok_or(ApiError::Sign)?;

let client = reqwest::Client::new();
let response = client
Expand All @@ -214,99 +197,109 @@ where
Ok((StatusCode::OK, Json(res)))
}

pub async fn exchange_token(
fn create_assertion(&self, target: String) -> String {
let assertion = U::new(self.token_endpoint.clone(), self.client_id.clone(), target);
serialize_claims(assertion, &self.client_assertion_header, &self.private_jwk).unwrap()
}

pub async fn get_token(
&self,
_state: HandlerState,
request: TokenExchangeRequest,
request: TokenRequest,
) -> Result<impl IntoResponse, ApiError> {
let assertion = match request.identity_provider {
IdentityProvider::AzureAD => self.create_assertion(AssertionClaimType::WithSub(self.client_id.clone())).unwrap(),
IdentityProvider::TokenX => self.create_assertion(AssertionClaimType::WithSub(self.client_id.clone())).unwrap(),
IdentityProvider::Maskinporten => self.create_assertion(AssertionClaimType::WithScope(request.target.clone())).unwrap()
let token_request = TokenRequestConfig {
target: request.target.clone(),
assertion: self.create_assertion(request.target.clone()),
client_id: Some(self.client_id.clone()),
user_token: None,
};
self.get_token_with_config(token_request).await
}

let params = T::token_request(TokenRequestConfig {
target: request.target,
assertion,
pub async fn exchange_token(
&self,
request: TokenExchangeRequest,
) -> Result<impl IntoResponse, ApiError> {
let token_request = TokenRequestConfig {
target: request.target.clone(),
assertion: self.create_assertion(request.target.clone()),
client_id: Some(self.client_id.clone()),
user_token: Some(request.user_token),
}).unwrap();

let client = reqwest::Client::new();
let response = client
.post(self.token_endpoint.clone())
.header("accept", "application/json")
.form(&params)
.send()
.await
.map_err(ApiError::UpstreamRequest)?;

if response.status() >= StatusCode::BAD_REQUEST {
let err: types::ErrorResponse = response.json().await.map_err(ApiError::JSON)?;
return Err(ApiError::Upstream(err));
}
};
self.get_token_with_config(token_request).await
}
}

let res: TokenResponse = response
.json()
.await
.inspect_err(|err| error!("Identity provider returned invalid JSON: {:?}", err))
.map_err(ApiError::JSON)?;
pub trait ClientAssertion {
fn new(token_endpoint: String, client_id: String, target: String) -> Self;
}

Ok((StatusCode::OK, Json(res)))
}
#[derive(Serialize)]
pub struct ClientAssertionClaims {
exp: usize,
iat: usize,
nbf: usize,
jti: String,
sub: String,
iss: String,
aud: String,
}

// FIXME: split into client_assertion and jwt_bearer types
#[derive(serde::Serialize, serde::Deserialize)]
struct AssertionClaims {
#[derive(Serialize)]
pub struct JWTBearerAssertionClaims {
exp: usize,
iat: usize,
nbf: usize,
jti: String,
#[serde(skip_serializing_if = "Option::is_none")]
scope: Option<String>,
scope: String,
iss: String,
aud: String,
#[serde(skip_serializing_if = "Option::is_none")]
sub: Option<String>,
}

enum AssertionClaimType {
WithScope(String),
#[allow(dead_code)]
WithSub(String),
fn epoch_now_secs() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs()
}

impl AssertionClaims {
fn new(token_endpoint: String, client_id: String, ass: AssertionClaimType) -> Self {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
let jti = uuid::Uuid::new_v4();
fn serialize_claims<T: Serialize>(
claims: T,
client_assertion_header: &jwt::Header,
key: &jwt::EncodingKey,
) -> Result<String, jsonwebtoken::errors::Error> {
jwt::encode(client_assertion_header, &claims, key)
}

let (scope, sub) = match ass {
AssertionClaimType::WithScope(scope) => (Some(scope), None),
AssertionClaimType::WithSub(sub) => (None, Some(sub)),
};
impl ClientAssertion for JWTBearerAssertionClaims {
fn new(token_endpoint: String, client_id: String, target: String) -> Self {
let now = epoch_now_secs();
let jti = uuid::Uuid::new_v4();

AssertionClaims {
Self {
exp: (now + 30) as usize,
iat: now as usize,
nbf: now as usize,
jti: jti.to_string(),
scope,
sub,
iss: client_id, // issuer of the token is the client itself
aud: token_endpoint, // audience of the token is the issuer
scope: target,
}
}
}

fn serialize(
&self,
client_assertion_header: &jwt::Header,
key: &jwt::EncodingKey,
) -> Result<String, jsonwebtoken::errors::Error> {
jwt::encode(client_assertion_header, &self, key)
impl ClientAssertion for ClientAssertionClaims {
fn new(token_endpoint: String, client_id: String, _target: String) -> Self {
let now = epoch_now_secs();
let jti = uuid::Uuid::new_v4();

Self {
exp: (now + 30) as usize,
iat: now as usize,
nbf: now as usize,
jti: jti.to_string(),
iss: client_id.clone(), // issuer of the token is the client itself
aud: token_endpoint, // audience of the token is the issuer
sub: client_id,
}
}
}
16 changes: 11 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use log::{info, LevelFilter};
use std::sync::Arc;
use tokio::sync::RwLock;
use identity_provider::Provider;
use crate::identity_provider::{AzureADClientCredentialsTokenRequest, AzureADOnBehalfOfTokenRequest, MaskinportenTokenRequest, TokenXTokenRequest};
use crate::identity_provider::{AzureADClientCredentialsTokenRequest, AzureADOnBehalfOfTokenRequest, ClientAssertionClaims, JWTBearerAssertionClaims, MaskinportenTokenRequest, TokenXTokenRequest};

pub mod config {
use clap::Parser;
Expand Down Expand Up @@ -83,11 +83,14 @@ async fn main() {

print_texas_logo();

info!("Starting up");

let _ = dotenv(); // load .env if present

let cfg = Config::parse();

let maskinporten: Provider<MaskinportenTokenRequest> = Provider::new(
info!("Fetch JWKS for Maskinporten...");
let maskinporten: Provider<MaskinportenTokenRequest, JWTBearerAssertionClaims> = Provider::new(
cfg.maskinporten_issuer.clone(),
cfg.maskinporten_client_id.clone(),
cfg.maskinporten_token_endpoint.clone(),
Expand All @@ -97,7 +100,8 @@ async fn main() {
.unwrap(),
).unwrap();

let azure_ad_obo: Provider<AzureADOnBehalfOfTokenRequest> = Provider::new(
info!("Fetch JWKS for Azure AD (on behalf of)...");
let azure_ad_obo: Provider<AzureADOnBehalfOfTokenRequest, ClientAssertionClaims> = Provider::new(
cfg.azure_ad_issuer.clone(),
cfg.azure_ad_client_id.clone(),
cfg.azure_ad_token_endpoint.clone(),
Expand All @@ -107,7 +111,8 @@ async fn main() {
.unwrap(),
).unwrap();

let azure_ad_cc: Provider<AzureADClientCredentialsTokenRequest> = Provider::new(
info!("Fetch JWKS for Azure AD (client credentials)...");
let azure_ad_cc: Provider<AzureADClientCredentialsTokenRequest, ClientAssertionClaims> = Provider::new(
cfg.azure_ad_issuer.clone(),
cfg.azure_ad_client_id.clone(),
cfg.azure_ad_token_endpoint.clone(),
Expand All @@ -117,7 +122,8 @@ async fn main() {
.unwrap(),
).unwrap();

let token_x: Provider<TokenXTokenRequest> = Provider::new(
info!("Fetch JWKS for TokenX...");
let token_x: Provider<TokenXTokenRequest, ClientAssertionClaims> = Provider::new(
cfg.token_x_issuer.clone(),
cfg.token_x_client_id.clone(),
cfg.token_x_token_endpoint.clone(),
Expand Down

0 comments on commit 2a32ed2

Please sign in to comment.