Skip to content

Commit

Permalink
feat: add support for form-urlencoding in request
Browse files Browse the repository at this point in the history
* refactor: move handlers and types into separate types

Co-authored-by: tronghn <trong.huu.nguyen@nav.no>
Co-authored-by: kimtore <kim.tore.jensen@nav.no>
  • Loading branch information
3 people committed Nov 1, 2024
1 parent 7c21188 commit 7529f95
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 201 deletions.
167 changes: 167 additions & 0 deletions src/handlers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
use axum::extract::{FromRequest, Request};
use axum::{async_trait, Form, RequestExt};

use crate::config::Config;
use crate::identity_provider::*;
use crate::types;
use crate::types::{IdentityProvider, IntrospectRequest, TokenRequest, TokenResponse};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use jsonwebtoken as jwt;
use jsonwebtoken::Algorithm::RS512;
use jsonwebtoken::DecodingKey;
use log::error;
use std::sync::Arc;
use axum::http::header::CONTENT_TYPE;
use thiserror::Error;
use tokio::sync::RwLock;

#[axum::debug_handler]
pub async fn token(State(state): State<HandlerState>, JsonOrForm(request): JsonOrForm<TokenRequest>) -> Result<impl IntoResponse, ApiError> {
let endpoint = state.token_endpoint(&request.identity_provider).await;
let params = state.token_request(&request.identity_provider, request.target).await;

let client = reqwest::Client::new();
let request_builder = client.post(endpoint)
.header("accept", "application/json")
.form(&params);

let response = request_builder
.send()
.await
.map_err(ApiError::UpstreamRequest)?
;

if response.status() >= StatusCode::BAD_REQUEST {
let err: types::ErrorResponse = response.json().await.map_err(ApiError::JSON)?;
return Err(ApiError::Upstream(err));
}

let res: TokenResponse = response
.json().await
.inspect_err(|err| {
error!("Identity provider returned invalid JSON: {:?}", err)
})
.map_err(ApiError::JSON)?
;

Ok((StatusCode::OK, Json(res)))
}

pub async fn introspection(State(state): State<HandlerState>, Json(request): Json<IntrospectRequest>) -> Result<impl IntoResponse, ApiError> {
// Need to decode the token to get the issuer before we actually validate it.
let mut validation = jwt::Validation::new(RS512);
validation.validate_exp = false;
validation.insecure_disable_signature_validation();
let key = DecodingKey::from_secret(&[]);
let token_data = jwt::decode::<Claims>(&request.token, &key, &validation).map_err(ApiError::Validate)?;

let claims = match token_data.claims.iss {
s if s == state.cfg.maskinporten_issuer => state.maskinporten.write().await.introspect(request.token).await,
_ => panic!("Unknown issuer: {}", token_data.claims.iss),
};

Ok((StatusCode::OK, Json(claims)))
}

#[derive(Clone)]
pub struct HandlerState {
pub cfg: Config,
pub maskinporten: Arc<RwLock<Maskinporten>>,
// TODO: other providers
}

impl HandlerState {
async fn token_request(&self, identity_provider: &IdentityProvider, target: String) -> Box<dyn erased_serde::Serialize + Send> {
match identity_provider {
IdentityProvider::EntraID => todo!(),
IdentityProvider::TokenX => todo!(),
IdentityProvider::Maskinporten => {
Box::new(self.maskinporten.read().await.token_request(target))
}
}
}

async fn token_endpoint(&self, identity_provider: &IdentityProvider) -> String {
match identity_provider {
IdentityProvider::EntraID => todo!(),
IdentityProvider::TokenX => todo!(),
IdentityProvider::Maskinporten => {
self.maskinporten.read().await.token_endpoint()
}
}
}
}

#[derive(Debug, Error)]
pub enum ApiError {
#[error("identity provider error: {0}")]
UpstreamRequest(reqwest::Error),

#[error("upstream error: {0}")]
Upstream(types::ErrorResponse),

#[error("invalid JSON in token response: {0}")]
JSON(reqwest::Error),

#[error("invalid token: {0}")]
Validate(jwt::errors::Error),
}

impl IntoResponse for ApiError {
fn into_response(self) -> Response {
match &self {
ApiError::UpstreamRequest(err) => {
(err.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), self.to_string())
}
ApiError::JSON(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
}
ApiError::Upstream(_err) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
}
ApiError::Validate(_) => {
(StatusCode::BAD_REQUEST, self.to_string())
}
}.into_response()
}
}

#[derive(serde::Deserialize)]
struct Claims {
iss: String,
}

pub struct JsonOrForm<T>(T);

#[async_trait]
impl<S, T> FromRequest<S> for JsonOrForm<T>
where
S: Send + Sync,
Json<T>: FromRequest<()>,
Form<T>: FromRequest<()>,
T: 'static,
{
type Rejection = Response;

async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
let content_type_header = req.headers().get(CONTENT_TYPE);
let content_type = content_type_header.and_then(|value| value.to_str().ok());

if let Some(content_type) = content_type {
if content_type.starts_with("application/json") {
let Json(payload) = req.extract().await.map_err(IntoResponse::into_response)?;
return Ok(Self(payload));
}

if content_type.starts_with("application/x-www-form-urlencoded") {
let Form(payload) = req.extract().await.map_err(IntoResponse::into_response)?;
return Ok(Self(payload));
}
}

Err(StatusCode::UNSUPPORTED_MEDIA_TYPE.into_response())
}
}
203 changes: 2 additions & 201 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
pub mod identity_provider;
pub mod jwks;
pub mod handlers;
pub mod types;

use std::sync::Arc;
use crate::config::Config;
Expand Down Expand Up @@ -79,205 +81,4 @@ async fn main() {
axum::serve(listener, app).await.unwrap();
}

pub mod handlers {
use std::sync::Arc;
use crate::config::Config;
use crate::identity_provider::*;
use crate::types;
use crate::types::{IdentityProvider, IntrospectRequest, TokenRequest, TokenResponse};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use jsonwebtoken as jwt;
use jsonwebtoken::Algorithm::RS512;
use jsonwebtoken::DecodingKey;
use log::error;
use thiserror::Error;
use tokio::sync::{RwLock};

#[derive(Debug, Error)]
pub enum ApiError {
#[error("identity provider error: {0}")]
UpstreamRequest(reqwest::Error),

#[error("upstream error: {0}")]
Upstream(types::ErrorResponse),

#[error("invalid JSON in token response: {0}")]
JSON(reqwest::Error),

#[error("invalid token: {0}")]
Validate(jwt::errors::Error),
}

impl IntoResponse for ApiError {
fn into_response(self) -> Response {
match &self {
ApiError::UpstreamRequest(err) => {
(err.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), self.to_string())
}
ApiError::JSON(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
}
ApiError::Upstream(_err) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
}
ApiError::Validate(_) => {
(StatusCode::BAD_REQUEST, self.to_string())
}
}.into_response()
}
}

#[derive(Clone)]
pub struct HandlerState {
pub cfg: Config,
pub maskinporten: Arc<RwLock<Maskinporten>>,
// TODO: other providers
}

impl HandlerState {
async fn token_request(&self, identity_provider: &IdentityProvider, target: String ) -> Box<dyn erased_serde::Serialize + Send> {
match identity_provider {
IdentityProvider::EntraID => todo!(),
IdentityProvider::TokenX => todo!(),
IdentityProvider::Maskinporten => {
Box::new(self.maskinporten.read().await.token_request(target))
},
}
}

async fn token_endpoint(&self, identity_provider: &IdentityProvider) -> String {
match identity_provider {
IdentityProvider::EntraID => todo!(),
IdentityProvider::TokenX => todo!(),
IdentityProvider::Maskinporten => {
self.maskinporten.read().await.token_endpoint()
},
}
}
}

#[axum::debug_handler]
pub async fn token(State(state): State<HandlerState>, Json(request): Json<TokenRequest>) -> Result<impl IntoResponse, ApiError> {
let endpoint = state.token_endpoint(&request.identity_provider).await;
let params = state.token_request(&request.identity_provider, request.target).await;

let client = reqwest::Client::new();
let request_builder = client.post(endpoint)
.header("accept", "application/json")
.form(&params);

let response = request_builder
.send()
.await
.map_err(ApiError::UpstreamRequest)?
;

if response.status() >= StatusCode::BAD_REQUEST {
let err: types::ErrorResponse = response.json().await.map_err(ApiError::JSON)?;
return Err(ApiError::Upstream(err));
}

let res: TokenResponse = response
.json().await
.inspect_err(|err| {
error!("Identity provider returned invalid JSON: {:?}", err)
})
.map_err(ApiError::JSON)?
;

Ok((StatusCode::OK, Json(res)))
}

pub async fn introspection(State(state): State<HandlerState>, Json(request): Json<IntrospectRequest>) -> Result<impl IntoResponse, ApiError> {
// Need to decode the token to get the issuer before we actually validate it.
let mut validation = jwt::Validation::new(RS512);
validation.validate_exp = false;
validation.insecure_disable_signature_validation();
let key = DecodingKey::from_secret(&[]);
let token_data = jwt::decode::<Claims>(&request.token, &key, &validation).map_err(ApiError::Validate)?;

let claims = match token_data.claims.iss {
s if s == state.cfg.maskinporten_issuer => state.maskinporten.write().await.introspect(request.token).await,
_ => panic!("Unknown issuer: {}", token_data.claims.iss),
};

Ok((StatusCode::OK, Json(claims)))
}

#[derive(serde::Deserialize)]
struct Claims {
iss: String,
}
}

pub mod types {
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};

/// This is an upstream RFCXXXX token response.
#[derive(Serialize, Deserialize)]
pub struct TokenResponse {
pub access_token: String,
pub token_type: TokenType,
#[serde(rename = "expires_in")]
pub expires_in_seconds: usize,
}

#[derive(Deserialize, Debug, Clone)]
pub struct ErrorResponse {
pub error: String,
#[serde(rename = "error_description")]
pub description: String,
}

impl Display for ErrorResponse {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}: {}", self.error, self.description)
}
}

/// This is the token request sent to our identity provider.
/// TODO: hard coded parameters that only works with Maskinporten for now.
#[derive(Serialize)]
pub struct ClientTokenRequest {
pub grant_type: String,
pub assertion: String,
}

/// For forwards API compatibility. Token type is always Bearer,
/// but this might change in the future.
#[derive(Deserialize, Serialize)]
pub enum TokenType {
Bearer
}

/// This is a token request that comes from the application we are serving.
#[derive(Deserialize)]
pub struct TokenRequest {
pub target: String, // typically <cluster>:<namespace>:<app>
pub identity_provider: IdentityProvider,
#[serde(skip_serializing_if = "Option::is_none")]
pub user_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub force: Option<bool>,
}

#[derive(Deserialize)]
pub struct IntrospectRequest {
pub token: String,
}

#[derive(Deserialize, Serialize)]
pub enum IdentityProvider {
#[serde(rename = "entra")]
EntraID,
#[serde(rename = "tokenx")]
TokenX,
#[serde(rename = "maskinporten")]
Maskinporten,
}
}

Loading

0 comments on commit 7529f95

Please sign in to comment.