diff --git a/auth/Cargo.toml b/auth/Cargo.toml index 8b3d7171a..9289f479e 100644 --- a/auth/Cargo.toml +++ b/auth/Cargo.toml @@ -32,7 +32,7 @@ tracing-subscriber = { workspace = true } [dependencies.shuttle-common] workspace = true -features = ["backend", "models", "persist"] +features = ["backend", "models"] [dev-dependencies] axum-extra = { version = "0.7.1", features = ["cookie"] } diff --git a/auth/src/api/handlers.rs b/auth/src/api/handlers.rs index e96325ed5..c441015c0 100644 --- a/auth/src/api/handlers.rs +++ b/auth/src/api/handlers.rs @@ -92,7 +92,7 @@ pub(crate) async fn convert_key( let User { name, account_tier, .. } = user_manager - .get_user_by_key(key.as_ref().clone()) + .get_user_by_key(key) .await .map_err(|_| StatusCode::UNAUTHORIZED)?; diff --git a/auth/src/lib.rs b/auth/src/lib.rs index 2e3aa333f..3e73a2471 100644 --- a/auth/src/lib.rs +++ b/auth/src/lib.rs @@ -7,7 +7,6 @@ mod user; use std::{io, str::FromStr, time::Duration}; use args::StartArgs; -use shuttle_common::ApiKey; use sqlx::{ migrate::Migrator, query, @@ -16,7 +15,10 @@ use sqlx::{ }; use tracing::info; -use crate::{api::serve, user::AccountTier}; +use crate::{ + api::serve, + user::{AccountTier, Key}, +}; pub use api::ApiBuilder; pub use args::{Args, Commands, InitArgs}; @@ -39,8 +41,8 @@ pub async fn start(pool: SqlitePool, args: StartArgs) -> io::Result<()> { pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> { let key = match args.key { - Some(ref key) => ApiKey::parse(key).unwrap(), - None => ApiKey::generate(), + Some(ref key) => Key::from_str(key).unwrap(), + None => Key::new_random(), }; query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)") @@ -51,11 +53,8 @@ pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> { .await .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - println!( - "`{}` created as super user with key: {}", - args.name, - key.as_ref() - ); + println!("`{}` created as super user with key: {key}", args.name,); + Ok(()) } diff --git a/auth/src/user.rs b/auth/src/user.rs index fba965bbb..5ba3ec1b6 100644 --- a/auth/src/user.rs +++ b/auth/src/user.rs @@ -7,11 +7,9 @@ use axum::{ http::request::Parts, TypedHeader, }; +use rand::distributions::{Alphanumeric, DistString}; use serde::{Deserialize, Deserializer, Serialize}; -use shuttle_common::{ - claims::{Scope, ScopeBuilder}, - ApiKey, -}; +use shuttle_common::claims::{Scope, ScopeBuilder}; use sqlx::{query, Row, SqlitePool}; use tracing::{trace, Span}; @@ -21,7 +19,7 @@ use crate::{api::UserManagerState, error::Error}; pub trait UserManagement: Send + Sync { async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result; async fn get_user(&self, name: AccountName) -> Result; - async fn get_user_by_key(&self, key: ApiKey) -> Result; + async fn get_user_by_key(&self, key: Key) -> Result; } #[derive(Clone)] @@ -32,7 +30,7 @@ pub struct UserManager { #[async_trait] impl UserManagement for UserManager { async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result { - let key = ApiKey::generate(); + let key = Key::new_random(); query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)") .bind(&name) @@ -57,7 +55,7 @@ impl UserManagement for UserManager { .ok_or(Error::UserNotFound) } - async fn get_user_by_key(&self, key: ApiKey) -> Result { + async fn get_user_by_key(&self, key: Key) -> Result { query("SELECT account_name, key, account_tier FROM users WHERE key = ?1") .bind(&key) .fetch_optional(&self.pool) @@ -71,10 +69,10 @@ impl UserManagement for UserManager { } } -#[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)] +#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] pub struct User { pub name: AccountName, - pub key: ApiKey, + pub key: Key, pub account_tier: AccountTier, } @@ -83,7 +81,7 @@ impl User { self.account_tier == AccountTier::Admin } - pub fn new(name: AccountName, key: ApiKey, account_tier: AccountTier) -> Self { + pub fn new(name: AccountName, key: Key, account_tier: AccountTier) -> Self { Self { name, key, @@ -106,7 +104,7 @@ where let user_manager: UserManagerState = UserManagerState::from_ref(state); let user = user_manager - .get_user_by_key(key.as_ref().clone()) + .get_user_by_key(key) .await // Absorb any error into `Unauthorized` .map_err(|_| Error::Unauthorized)?; @@ -122,21 +120,16 @@ impl From for shuttle_common::models::user::Response { fn from(user: User) -> Self { Self { name: user.name.to_string(), - key: user.key.as_ref().to_string(), + key: user.key.to_string(), account_tier: user.account_tier.to_string(), } } } -/// A wrapper around [ApiKey] so we can implement [FromRequestParts] -/// for it. -pub struct Key(ApiKey); - -impl AsRef for Key { - fn as_ref(&self) -> &ApiKey { - &self.0 - } -} +#[derive(Clone, sqlx::Type, PartialEq, Hash, Eq, Serialize, Deserialize, Debug)] +#[serde(transparent)] +#[sqlx(transparent)] +pub struct Key(String); #[async_trait] impl FromRequestParts for Key @@ -149,14 +142,31 @@ where let key = TypedHeader::>::from_request_parts(parts, state) .await .map_err(|_| Error::KeyMissing) - .and_then(|TypedHeader(Authorization(bearer))| { - let bearer = bearer.token().trim(); - ApiKey::parse(bearer).map_err(|_| Self::Rejection::Unauthorized) - })?; + .and_then(|TypedHeader(Authorization(bearer))| bearer.token().trim().parse())?; trace!("got bearer key"); - Ok(Key(key)) + Ok(key) + } +} + +impl std::fmt::Display for Key { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl FromStr for Key { + type Err = Error; + + fn from_str(s: &str) -> Result { + Ok(Self(s.to_string())) + } +} + +impl Key { + pub fn new_random() -> Self { + Self(Alphanumeric.sample_string(&mut rand::thread_rng(), 16)) } }