Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle JWT tokens with oidc providers #1882

Merged
merged 30 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
876 changes: 698 additions & 178 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ indicatif = "0.16"
insta = { version = "1.21.0", features = ["toml"] }
is-terminal = "0.4"
itertools = "0.12"
jsonwebtoken = { version = "8.1.0" }
jsonwebtoken = { git = "https://github.com/jsdt/jsonwebtoken.git", rev = "7f0cef63c74f58dfa912e88da844932ba4c71562"}
lazy_static = "1.4.0"
log = "0.4.17"
mimalloc = "0.1.39"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
source: crates/bindings/tests/deps.rs
expression: "cargo tree -p spacetimedb -f {lib} -e no-dev"
---
total crates: 64
total crates: 67
spacetimedb
├── bytemuck
├── derive_more
Expand Down Expand Up @@ -48,21 +48,20 @@ spacetimedb
│ │ ├── itertools
│ │ │ └── either
│ │ └── nohash_hasher
│ ├── spacetimedb_sql_parser
│ │ ├── derive_more (*)
│ │ ├── sqlparser
│ │ │ └── log
│ │ └── thiserror
│ │ └── thiserror_impl
│ │ ├── proc_macro2 (*)
│ │ ├── quote (*)
│ │ └── syn (*)
│ └── syn (*)
├── spacetimedb_bindings_sys
│ └── spacetimedb_primitives (*)
├── spacetimedb_lib
│ ├── anyhow
│ ├── bitflags
│ ├── blake3
│ │ ├── arrayref
│ │ ├── arrayvec
│ │ ├── cfg_if
│ │ └── constant_time_eq
│ │ [build-dependencies]
│ │ └── cc
│ │ └── shlex
│ ├── derive_more (*)
│ ├── enum_as_inner
│ │ ├── heck
Expand All @@ -73,18 +72,21 @@ spacetimedb
│ ├── itertools (*)
│ ├── spacetimedb_bindings_macro (*)
│ ├── spacetimedb_data_structures
│ │ ├── ahash
│ │ │ ├── cfg_if
│ │ │ ├── getrandom (*)
│ │ │ ├── once_cell
│ │ │ └── zerocopy (*)
│ │ │ [build-dependencies]
│ │ │ └── version_check
│ │ ├── hashbrown
│ │ │ └── equivalent
│ │ │ ├── ahash
│ │ │ │ ├── cfg_if
│ │ │ │ ├── once_cell
│ │ │ │ └── zerocopy (*)
│ │ │ │ [build-dependencies]
│ │ │ │ └── version_check
│ │ │ └── allocator_api2
│ │ ├── nohash_hasher
│ │ ├── smallvec
│ │ └── thiserror (*)
│ │ └── thiserror
│ │ └── thiserror_impl
│ │ ├── proc_macro2 (*)
│ │ ├── quote (*)
│ │ └── syn (*)
│ ├── spacetimedb_primitives (*)
│ ├── spacetimedb_sats
│ │ ├── arrayvec
Expand Down
7 changes: 7 additions & 0 deletions crates/client-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,11 @@ bytestring = "1"
tokio-tungstenite.workspace = true
itoa = "1.0.9"
derive_more = "0.99.17"
uuid.workspace = true
blake3.workspace = true
jsonwebtoken.workspace = true
scopeguard.workspace = true

[dev-dependencies]
jsonwebkey = { version = "0.3.5", features = ["generate","jwt-convert"] }
jsonwebtoken.workspace = true
113 changes: 101 additions & 12 deletions crates/client-api/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
use std::time::Duration;
use std::time::{Duration, SystemTime};

use axum::extract::{Query, Request, State};
use axum::middleware::Next;
use axum::response::IntoResponse;
use axum_extra::typed_header::TypedHeader;
use headers::{authorization, HeaderMapExt};
use http::{request, HeaderValue, StatusCode};
use rand::Rng;
use serde::Deserialize;
use spacetimedb::auth::identity::SpacetimeIdentityClaims2;
use spacetimedb::auth::identity::{
decode_token, encode_token, DecodingKey, EncodingKey, JwtError, JwtErrorKind, SpacetimeIdentityClaims,
};
use spacetimedb::auth::token_validation::{TokenValidationError, validate_token};
use spacetimedb::energy::EnergyQuanta;
use spacetimedb::identity::Identity;
use uuid::Uuid;

use crate::{log_and_500, ControlStateDelegate, NodeDelegate};

Expand All @@ -33,6 +35,7 @@ pub struct SpacetimeCreds {
token: String,
}

pub const LOCALHOST: &str = "localhost";
const TOKEN_USERNAME: &str = "token";
impl authorization::Credentials for SpacetimeCreds {
const SCHEME: &'static str = authorization::Basic::SCHEME;
Expand All @@ -58,6 +61,9 @@ impl SpacetimeCreds {
pub fn decode_token(&self, public_key: &DecodingKey) -> Result<SpacetimeIdentityClaims, JwtError> {
decode_token(public_key, self.token()).map(|x| x.claims)
}
fn from_signed_token(token: String) -> Self {
Self { token }
}
/// Mint a new credentials JWT for an identity.
pub fn encode_token(private_key: &EncodingKey, identity: Identity) -> Result<Self, JwtError> {
let token = encode_token(private_key, identity)?;
Expand Down Expand Up @@ -94,19 +100,51 @@ pub struct SpacetimeAuth {
pub identity: Identity,
}

use jsonwebtoken;

struct TokenClaims {
pub issuer: String,
pub subject: String,
pub audience: Vec<String>,
}

impl TokenClaims {
// Compute the id from the issuer and subject.
fn id(&self) -> Identity {
Identity::from_claims(&self.issuer, &self.subject)
}

fn encode_and_sign(&self, private_key: &EncodingKey) -> Result<String, JwtError> {
let claims = SpacetimeIdentityClaims2 {
identity: self.id(),
subject: self.subject.clone(),
issuer: self.issuer.clone(),
audience: self.audience.clone(),
iat: SystemTime::now(),
exp: None,
};
let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::ES256);
jsonwebtoken::encode(&header, &claims, private_key)
}
}

impl SpacetimeAuth {
/// Allocate a new identity, and mint a new token for it.
pub async fn alloc(ctx: &(impl NodeDelegate + ControlStateDelegate + ?Sized)) -> axum::response::Result<Self> {
// TODO: I'm just sticking in a random string until we change how identities are generated.
let identity = {
let mut rng = rand::thread_rng();
let mut random_bytes = [0u8; 16]; // Example: 16 random bytes
rng.fill(&mut random_bytes);

let preimg = [b"clockworklabs:", &random_bytes[..]].concat();
Identity::from_hashing_bytes(preimg)
// Generate claims with a random subject.
let claims = TokenClaims {
issuer: ctx.local_issuer(),
subject: Uuid::new_v4().to_string(),
// Placeholder audience.
audience: vec!["spacetimedb".to_string()],
};
let creds = SpacetimeCreds::encode_token(ctx.private_key(), identity).map_err(log_and_500)?;

let identity = claims.id();
let creds = {
let token = claims.encode_and_sign(ctx.private_key()).map_err(log_and_500)?;
SpacetimeCreds::from_signed_token(token)
};

Ok(Self { creds, identity })
}

Expand All @@ -120,6 +158,51 @@ impl SpacetimeAuth {
}
}

#[cfg(test)]
mod tests {
use crate::auth::TokenClaims;
use anyhow::Ok;
use jsonwebkey as jwk;
use jsonwebtoken::{DecodingKey, EncodingKey};
use spacetimedb::auth::identity;

// TODO: this keypair stuff is duplicated. We should create a test-only crate with helpers.
struct KeyPair {
pub public_key: DecodingKey,
pub private_key: EncodingKey,
}

fn new_keypair() -> anyhow::Result<KeyPair> {
let mut my_jwk = jwk::JsonWebKey::new(jwk::Key::generate_p256());

my_jwk.set_algorithm(jwk::Algorithm::ES256).unwrap();
let public_key = jsonwebtoken::DecodingKey::from_ec_pem(&my_jwk.key.to_public().unwrap().to_pem().as_bytes())?;
let private_key = jsonwebtoken::EncodingKey::from_ec_pem(&my_jwk.key.try_to_pem()?.as_bytes())?;
Ok(KeyPair {
public_key,
private_key,
})
}

// Make sure that when we encode TokenClaims, we can decode to get the expected identity.
#[test]
fn decode_encoded_token() -> Result<(), anyhow::Error> {
let kp = new_keypair()?;

let claims = TokenClaims {
issuer: "localhost".to_string(),
subject: "test-subject".to_string(),
audience: vec!["spacetimedb".to_string()],
};
let id = claims.id();
let token = claims.encode_and_sign(&kp.private_key)?;

let decoded = identity::decode_token(&kp.public_key, &token)?;
assert_eq!(decoded.claims.identity, id);
Ok(())
}
}

pub struct SpacetimeAuthHeader {
auth: Option<SpacetimeAuth>,
}
Expand All @@ -131,7 +214,11 @@ impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for Space
let Some(creds) = SpacetimeCreds::from_request_parts(parts)? else {
return Ok(Self { auth: None });
};
let claims = creds.decode_token(state.public_key())?;

let claims = validate_token(state.public_key().clone(), &state.local_issuer(), &creds.token)
.await
.map_err(AuthorizationRejection::Custom)?;

let auth = SpacetimeAuth {
creds,
identity: claims.identity,
Expand All @@ -145,6 +232,7 @@ impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for Space
pub enum AuthorizationRejection {
Jwt(JwtError),
Header(headers::Error),
Custom(TokenValidationError),
Required,
}

Expand All @@ -165,6 +253,7 @@ impl IntoResponse for AuthorizationRejection {
match self {
AuthorizationRejection::Jwt(e) if *e.kind() == JwtErrorKind::InvalidSignature => ROTATED.into_response(),
AuthorizationRejection::Jwt(_) | AuthorizationRejection::Header(_) => INVALID.into_response(),
AuthorizationRejection::Custom(msg) => (StatusCode::UNAUTHORIZED, format!("{:?}", msg)).into_response(),
AuthorizationRejection::Required => REQUIRED.into_response(),
}
}
Expand Down
7 changes: 7 additions & 0 deletions crates/client-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ pub trait NodeDelegate: Send + Sync {
/// Return a JWT decoding key for verifying credentials.
fn public_key(&self) -> &DecodingKey;

// The issuer to use when signing JWTs.
fn local_issuer(&self) -> String;

/// Return the public key used to verify JWTs, as the bytes of a PEM public key file.
///
/// The `/identity/public-key` route calls this method to return the public key to callers.
Expand Down Expand Up @@ -231,6 +234,10 @@ impl<T: NodeDelegate + ?Sized> NodeDelegate for Arc<T> {
(**self).gather_metrics()
}

fn local_issuer(&self) -> String {
(**self).local_issuer()
}

fn host_controller(&self) -> &HostController {
(**self).host_controller()
}
Expand Down
7 changes: 7 additions & 0 deletions crates/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ url.workspace = true
urlencoding.workspace = true
uuid.workspace = true
wasmtime.workspace = true
jwks = { git = "https://github.com/jsdt/jwks.git", rev = "acb4241f3768ff89515a0c12d927e280a604fff3"}
async_cache = "0.3.1"
faststr = "0.2.23"

[features]
# Print a warning when doing an unindexed `iter_by_col_range` on a large table.
Expand All @@ -121,3 +124,7 @@ proptest-derive.workspace = true
rand.workspace = true
env_logger.workspace = true
pretty_assertions.workspace = true
jsonwebkey = { version = "0.3.5", features = ["generate", "jwt-convert"] }
jsonwebtoken.workspace = true
axum-test = "16.2.0"
axum.workspace = true
Loading