From 04e4f697ad6700aafe1fc630a21a98b4d7a8f4d9 Mon Sep 17 00:00:00 2001 From: oddgrd <29732646+oddgrd@users.noreply.github.com> Date: Fri, 14 Apr 2023 15:49:56 +0200 Subject: [PATCH 1/7] feat: ensure API key is valid --- CONTRIBUTING.md | 6 +- Cargo.lock | 83 ++++++++++++++++++++++++- cargo-shuttle/Cargo.toml | 2 +- cargo-shuttle/src/client.rs | 4 +- cargo-shuttle/src/config.rs | 52 ++++++++-------- cargo-shuttle/src/lib.rs | 5 +- cargo-shuttle/tests/integration/init.rs | 10 +-- common/Cargo.toml | 1 + common/src/lib.rs | 61 +++++++++++++++++- 9 files changed, 181 insertions(+), 43 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2184b106e..d69c339e6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -93,13 +93,15 @@ Before we can login to our local instance of shuttle, we need to create a user. The following command inserts a user into the `auth` state with admin privileges: ```bash -docker compose --file docker-compose.rendered.yml --project-name shuttle-dev exec auth /usr/local/bin/service --state=/var/lib/shuttle-auth init --name admin --key test-key +# the --key needs to be 16 alphanumeric characters +docker compose --file docker-compose.rendered.yml --project-name shuttle-dev exec auth /usr/local/bin/service --state=/var/lib/shuttle-auth init --name admin --key dh9z58jttoes3qvt ``` Login to shuttle service in a new terminal window from the root of the shuttle directory: ```bash -cargo run --bin cargo-shuttle -- login --api-key "test-key" +# the --api-kei should be the same one you inserted in the auth state +cargo run --bin cargo-shuttle -- login --api-key "dh9z58jttoes3qvt" ``` The [shuttle examples](https://github.com/shuttle-hq/examples) are linked to the main repo as a [git submodule](https://git-scm.com/book/en/v2/Git-Tools-Submodules), to initialize it run the following commands: diff --git a/Cargo.lock b/Cargo.lock index 5e2882962..6faa437d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -705,6 +705,21 @@ dependencies = [ "serde", ] +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -1898,9 +1913,9 @@ dependencies = [ [[package]] name = "dialoguer" -version = "0.10.3" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af3c796f3b0b408d9fd581611b47fa850821fcb84aa640b83a3c1a5be2d691f2" +checksum = "59c6f2989294b9a498d3ad5491a79c6deb604617378e1cdc4bfc1c1361fe2f87" dependencies = [ "console", "fuzzy-matcher", @@ -3200,6 +3215,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "libm" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb" + [[package]] name = "libnghttp2-sys" version = "0.1.7+1.45.0" @@ -3580,6 +3601,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -4155,6 +4177,27 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proptest" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29f1b898011ce9595050a68e60f90bad083ff2987a695a42357134c8381fba70" +dependencies = [ + "bit-set", + "bitflags 1.3.2", + "byteorder", + "lazy_static", + "num-traits", + "quick-error 2.0.1", + "rand", + "rand_chacha", + "rand_xorshift", + "regex-syntax", + "rusty-fork", + "tempfile", + "unarray", +] + [[package]] name = "prost" version = "0.11.8" @@ -4241,6 +4284,12 @@ version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" +[[package]] +name = "quick-error" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" + [[package]] name = "quote" version = "1.0.26" @@ -4286,6 +4335,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_xorshift" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" +dependencies = [ + "rand_core", +] + [[package]] name = "rand_xoshiro" version = "0.6.0" @@ -4476,7 +4534,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00" dependencies = [ "hostname", - "quick-error", + "quick-error 1.2.3", ] [[package]] @@ -4689,6 +4747,18 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06" +[[package]] +name = "rusty-fork" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb3dcc6e454c328bb824492db107ab7c0ae8fcffe4ad210136ef014458c1bc4f" +dependencies = [ + "fnv", + "quick-error 1.2.3", + "tempfile", + "wait-timeout", +] + [[package]] name = "ryu" version = "1.0.13" @@ -5092,6 +5162,7 @@ dependencies = [ "opentelemetry-http", "opentelemetry-otlp", "pin-project", + "proptest", "prost-types", "reqwest", "ring", @@ -6421,6 +6492,12 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicase" version = "2.6.0" diff --git a/cargo-shuttle/Cargo.toml b/cargo-shuttle/Cargo.toml index 4cc427f6e..b4fb7e092 100644 --- a/cargo-shuttle/Cargo.toml +++ b/cargo-shuttle/Cargo.toml @@ -18,7 +18,7 @@ clap = { workspace = true, features = ["env"] } clap_complete = "4.1.5" crossbeam-channel = { workspace = true } crossterm = { workspace = true } -dialoguer = { version = "0.10.3", features = ["fuzzy-select"] } +dialoguer = { version = "0.10.4", features = ["fuzzy-select"] } dirs = { workspace = true } dunce = "1.0.3" flate2 = { workspace = true } diff --git a/cargo-shuttle/src/client.rs b/cargo-shuttle/src/client.rs index 38084b8ab..dd3209151 100644 --- a/cargo-shuttle/src/client.rs +++ b/cargo-shuttle/src/client.rs @@ -216,7 +216,7 @@ impl Client { let mut request = url.into_client_request()?; if let Some(ref api_key) = self.api_key { - let auth_header = Authorization::bearer(api_key)?; + let auth_header = Authorization::bearer(api_key.as_ref())?; request.headers_mut().typed_insert(auth_header); } @@ -282,7 +282,7 @@ impl Client { fn set_builder_auth(&self, builder: RequestBuilder) -> RequestBuilder { if let Some(ref api_key) = self.api_key { - builder.bearer_auth(api_key) + builder.bearer_auth(api_key.as_ref()) } else { builder } diff --git a/cargo-shuttle/src/config.rs b/cargo-shuttle/src/config.rs index 5225258de..5ecce04ab 100644 --- a/cargo-shuttle/src/config.rs +++ b/cargo-shuttle/src/config.rs @@ -122,24 +122,24 @@ impl ConfigManager for LocalConfigManager { /// Global client config for things like API keys. #[derive(Deserialize, Serialize, Default)] pub struct GlobalConfig { - pub api_key: Option, + api_key: Option, pub api_url: Option, } impl GlobalConfig { - pub fn api_key(&self) -> Option<&ApiKey> { - self.api_key.as_ref() + pub fn api_key(&self) -> Option> { + self.api_key.as_ref().map(|key| ApiKey::parse(key)) } - pub fn set_api_key(&mut self, api_key: ApiKey) -> Option { - self.api_key.replace(api_key) + pub fn set_api_key(&mut self, api_key: ApiKey) -> Option { + self.api_key.replace(api_key.as_ref().to_string()) } pub fn clear_api_key(&mut self) { self.api_key = None; } - pub fn api_url(&self) -> Option { + pub fn api_url(&self) -> Option { self.api_url.clone() } } @@ -324,24 +324,22 @@ impl RequestContext { /// otherwise from the global configuration. Returns an error if /// an API key is not set. pub fn api_key(&self) -> Result { - std::env::var("SHUTTLE_API_KEY") - .context("environment variable SHUTTLE_API_KEY is not set or invalid") - .or_else(|_| { - self.global - .as_ref() - .unwrap() - .api_key() - .map(|key| key.to_owned()) - .ok_or_else(|| { - anyhow!( - "Configuration file: `{}`", - self.global.manager.path().display() - ) - .context(anyhow!( - "No valid API key found, try logging in first with:\n\tcargo shuttle login" - )) - }) - }) + let api_key = std::env::var("SHUTTLE_API_KEY"); + + if let Ok(key) = api_key { + ApiKey::parse(&key).context("environment variable SHUTTLE_API_KEY is invalid") + } else { + match self.global.as_ref().unwrap().api_key() { + Some(key) => key, + None => Err(anyhow!( + "Configuration file: `{}`", + self.global.manager.path().display() + ) + .context(anyhow!( + "No valid API key found, try logging in first with:\n\tcargo shuttle login" + ))), + } + } } /// Get the current context working directory @@ -358,10 +356,10 @@ impl RequestContext { } /// Set the API key to the global configuration. Will persist the file. - pub fn set_api_key(&mut self, api_key: ApiKey) -> Result> { - let res = self.global.as_mut().unwrap().set_api_key(api_key); + pub fn set_api_key(&mut self, api_key: ApiKey) -> Result<()> { + self.global.as_mut().unwrap().set_api_key(api_key); self.global.save()?; - Ok(res) + Ok(()) } pub fn clear_api_key(&mut self) -> Result<()> { diff --git a/cargo-shuttle/src/lib.rs b/cargo-shuttle/src/lib.rs index 43166ae04..0499f515b 100644 --- a/cargo-shuttle/src/lib.rs +++ b/cargo-shuttle/src/lib.rs @@ -9,7 +9,7 @@ use shuttle_common::models::deployment::get_deployments_table; use shuttle_common::models::project::{State, IDLE_MINUTES}; use shuttle_common::models::resource::get_resources_table; use shuttle_common::project::ProjectName; -use shuttle_common::resource; +use shuttle_common::{resource, ApiKey}; use shuttle_proto::runtime::{self, LoadRequest, StartRequest, SubscribeLogsRequest}; use tokio::task::JoinSet; @@ -261,11 +261,12 @@ impl Shuttle { Password::with_theme(&ColorfulTheme::default()) .with_prompt("API key") + .validate_with(|input: &String| ApiKey::parse(input).map(|_| {})) .interact()? } }; - let api_key = api_key_str.trim().parse()?; + let api_key = ApiKey::parse(&api_key_str)?; self.ctx.set_api_key(api_key)?; diff --git a/cargo-shuttle/tests/integration/init.rs b/cargo-shuttle/tests/integration/init.rs index 223987b5b..b67c7ff9a 100644 --- a/cargo-shuttle/tests/integration/init.rs +++ b/cargo-shuttle/tests/integration/init.rs @@ -18,7 +18,7 @@ async fn non_interactive_basic_init() { "http://shuttle.invalid:80", "init", "--api-key", - "fake-api-key", + "dh9z58jttoes3qvt", "--name", "my-project", "--no-framework", @@ -44,7 +44,7 @@ async fn non_interactive_rocket_init() { "http://shuttle.invalid:80", "init", "--api-key", - "fake-api-key", + "dh9z58jttoes3qvt", "--name", "my-project", "--rocket", @@ -67,7 +67,7 @@ fn interactive_rocket_init() -> Result<(), Box> { "http://shuttle.invalid:80", "init", "--api-key", - "fake-api-key", + "dh9z58jttoes3qvt", ]); let mut session = rexpect::session::spawn_command(command, Some(2000))?; @@ -106,7 +106,7 @@ fn interactive_rocket_init_dont_prompt_framework() -> Result<(), Box Result<(), Box anyhow::Result { + let key = key.trim().to_string(); + + let mut errors = vec![]; + if !key.chars().all(char::is_alphanumeric) { + errors.push("The API key should consist of only alphanumeric characters."); + }; + + if key.len() != 16 { + errors.push("The API key should be exactly 16 characters in length."); + }; + + if !errors.is_empty() { + let message = errors.join("\n"); + bail!("Invalid API key:\n{message}") + } + + Ok(Self(key)) + } +} + +impl AsRef for ApiKey { + fn as_ref(&self) -> &str { + &self.0 + } +} + #[cfg(feature = "error")] /// Errors that can occur when changing types. Especially from prost #[derive(thiserror::Error, Debug)] @@ -138,3 +169,31 @@ impl SecretStore { self.secrets.get(key).map(ToOwned::to_owned) } } + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use crate::ApiKey; + + proptest! { + #[test] + // The API key should be a 16 character alphanumeric string. + fn parses_valid_keys(s in "[a-zA-Z0-9]{16}") { + println!("s: {s}, len: {}", s.len()); + ApiKey::parse(&s).unwrap(); + } + } + + #[test] + #[should_panic(expected = "The API key should be exactly 16 characters in length.")] + fn invalid_length() { + ApiKey::parse("tooshort").unwrap(); + } + + #[test] + #[should_panic(expected = "The API key should consist of only alphanumeric characters.")] + fn non_alphanumeric() { + ApiKey::parse("dh9z58jttoes3qv@").unwrap(); + } +} From af4794ee59101f74d0a9c166f609c3d11a88e9c9 Mon Sep 17 00:00:00 2001 From: oddgrd <29732646+oddgrd@users.noreply.github.com> Date: Sun, 23 Apr 2023 13:00:55 +0200 Subject: [PATCH 2/7] feat: use ApiKey in auth --- Cargo.lock | 2 ++ auth/Cargo.toml | 2 +- auth/src/api/handlers.rs | 2 +- auth/src/lib.rs | 16 +++++----- auth/src/user.rs | 61 +++++++++++++++------------------------ auth/tests/api/auth.rs | 2 +- auth/tests/api/helpers.rs | 2 +- common/Cargo.toml | 7 +++++ common/src/lib.rs | 20 +++++++++---- 9 files changed, 60 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6faa437d9..4592c8fe1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5164,12 +5164,14 @@ dependencies = [ "pin-project", "proptest", "prost-types", + "rand", "reqwest", "ring", "rmp-serde", "rustrict", "serde", "serde_json", + "sqlx", "strum", "thiserror", "tokio", diff --git a/auth/Cargo.toml b/auth/Cargo.toml index 9289f479e..8b3d7171a 100644 --- a/auth/Cargo.toml +++ b/auth/Cargo.toml @@ -32,7 +32,7 @@ tracing-subscriber = { workspace = true } [dependencies.shuttle-common] workspace = true -features = ["backend", "models"] +features = ["backend", "models", "persist"] [dev-dependencies] axum-extra = { version = "0.7.1", features = ["cookie"] } diff --git a/auth/src/api/handlers.rs b/auth/src/api/handlers.rs index b45061bf6..e96325ed5 100644 --- a/auth/src/api/handlers.rs +++ b/auth/src/api/handlers.rs @@ -92,7 +92,7 @@ pub(crate) async fn convert_key( let User { name, account_tier, .. } = user_manager - .get_user_by_key(key.clone()) + .get_user_by_key(key.as_ref().clone()) .await .map_err(|_| StatusCode::UNAUTHORIZED)?; diff --git a/auth/src/lib.rs b/auth/src/lib.rs index 67067f52e..2e3aa333f 100644 --- a/auth/src/lib.rs +++ b/auth/src/lib.rs @@ -7,6 +7,7 @@ mod user; use std::{io, str::FromStr, time::Duration}; use args::StartArgs; +use shuttle_common::ApiKey; use sqlx::{ migrate::Migrator, query, @@ -15,10 +16,7 @@ use sqlx::{ }; use tracing::info; -use crate::{ - api::serve, - user::{AccountTier, Key}, -}; +use crate::{api::serve, user::AccountTier}; pub use api::ApiBuilder; pub use args::{Args, Commands, InitArgs}; @@ -41,8 +39,8 @@ pub async fn start(pool: SqlitePool, args: StartArgs) -> io::Result<()> { pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> { let key = match args.key { - Some(ref key) => Key::from_str(key).unwrap(), - None => Key::new_random(), + Some(ref key) => ApiKey::parse(key).unwrap(), + None => ApiKey::generate(), }; query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)") @@ -53,7 +51,11 @@ pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> { .await .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - println!("`{}` created as super user with key: {key}", args.name); + println!( + "`{}` created as super user with key: {}", + args.name, + key.as_ref() + ); Ok(()) } diff --git a/auth/src/user.rs b/auth/src/user.rs index 80453c7e0..31a3ba8d5 100644 --- a/auth/src/user.rs +++ b/auth/src/user.rs @@ -7,9 +7,8 @@ use axum::{ http::request::Parts, TypedHeader, }; -use rand::distributions::{Alphanumeric, DistString}; use serde::{Deserialize, Deserializer, Serialize}; -use shuttle_common::claims::Scope; +use shuttle_common::{claims::Scope, ApiKey}; use sqlx::{query, Row, SqlitePool}; use tracing::{trace, Span}; @@ -19,7 +18,7 @@ use crate::{api::UserManagerState, error::Error}; pub trait UserManagement: Send + Sync { async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result; async fn get_user(&self, name: AccountName) -> Result; - async fn get_user_by_key(&self, key: Key) -> Result; + async fn get_user_by_key(&self, key: ApiKey) -> Result; } #[derive(Clone)] @@ -30,7 +29,7 @@ pub struct UserManager { #[async_trait] impl UserManagement for UserManager { async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result { - let key = Key::new_random(); + let key = ApiKey::generate(); query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)") .bind(&name) @@ -55,7 +54,7 @@ impl UserManagement for UserManager { .ok_or(Error::UserNotFound) } - async fn get_user_by_key(&self, key: Key) -> Result { + async fn get_user_by_key(&self, key: ApiKey) -> Result { query("SELECT account_name, key, account_tier FROM users WHERE key = ?1") .bind(&key) .fetch_optional(&self.pool) @@ -72,7 +71,7 @@ impl UserManagement for UserManager { #[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)] pub struct User { pub name: AccountName, - pub key: Key, + pub key: ApiKey, pub account_tier: AccountTier, } @@ -81,7 +80,7 @@ impl User { self.account_tier == AccountTier::Admin } - pub fn new(name: AccountName, key: Key, account_tier: AccountTier) -> Self { + pub fn new(name: AccountName, key: ApiKey, account_tier: AccountTier) -> Self { Self { name, key, @@ -104,9 +103,9 @@ where let user_manager: UserManagerState = UserManagerState::from_ref(state); let user = user_manager - .get_user_by_key(key) + .get_user_by_key(key.as_ref().clone()) .await - // Absord any error into `Unauthorized` + // Absorb any error into `Unauthorized` .map_err(|_| Error::Unauthorized)?; // Record current account name for tracing purposes @@ -120,16 +119,21 @@ impl From for shuttle_common::models::user::Response { fn from(user: User) -> Self { Self { name: user.name.to_string(), - key: user.key.to_string(), + key: user.key.as_ref().to_string(), account_tier: user.account_tier.to_string(), } } } -#[derive(Clone, Debug, sqlx::Type, PartialEq, Hash, Eq, Serialize, Deserialize)] -#[serde(transparent)] -#[sqlx(transparent)] -pub struct Key(String); +/// A wrapper around [ApiKey] so we can implement [FromRequestParts] +/// for it. +pub struct Key(ApiKey); + +impl AsRef for Key { + fn as_ref(&self) -> &ApiKey { + &self.0 + } +} #[async_trait] impl FromRequestParts for Key @@ -142,31 +146,14 @@ where let key = TypedHeader::>::from_request_parts(parts, state) .await .map_err(|_| Error::KeyMissing) - .and_then(|TypedHeader(Authorization(bearer))| bearer.token().trim().parse())?; - - trace!(%key, "got bearer key"); - - Ok(key) - } -} - -impl std::fmt::Display for Key { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl FromStr for Key { - type Err = Error; + .and_then(|TypedHeader(Authorization(bearer))| { + let bearer = bearer.token().trim(); + Ok(ApiKey::parse(bearer).map_err(|_| Self::Rejection::Unauthorized)?) + })?; - fn from_str(s: &str) -> Result { - Ok(Self(s.to_string())) - } -} + trace!("got bearer key"); -impl Key { - pub fn new_random() -> Self { - Self(Alphanumeric.sample_string(&mut rand::thread_rng(), 16)) + Ok(Key(key)) } } diff --git a/auth/tests/api/auth.rs b/auth/tests/api/auth.rs index e8a4091fd..a1f060561 100644 --- a/auth/tests/api/auth.rs +++ b/auth/tests/api/auth.rs @@ -26,7 +26,7 @@ async fn convert_api_key_to_jwt() { // GET /auth/key with invalid bearer token. let request = Request::builder() .uri("/auth/key") - .header(AUTHORIZATION, "Bearer notadmin") + .header(AUTHORIZATION, "Bearer ndh9z58jttoefake") .body(Body::empty()) .unwrap(); diff --git a/auth/tests/api/helpers.rs b/auth/tests/api/helpers.rs index 0d3a7b2f1..4be06084f 100644 --- a/auth/tests/api/helpers.rs +++ b/auth/tests/api/helpers.rs @@ -4,7 +4,7 @@ use shuttle_auth::{sqlite_init, ApiBuilder}; use sqlx::query; use tower::ServiceExt; -pub(crate) const ADMIN_KEY: &str = "my-api-key"; +pub(crate) const ADMIN_KEY: &str = "ndh9z58jttoes3qv"; pub(crate) struct TestApp { pub router: Router, diff --git a/common/Cargo.toml b/common/Cargo.toml index a473310f5..36c8468e4 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -26,12 +26,18 @@ opentelemetry-http = { workspace = true, optional = true } opentelemetry-otlp = { version = "0.11.0", optional = true } pin-project = { workspace = true, optional = true } prost-types = { workspace = true, optional = true } +rand = { workspace = true, optional = true } reqwest = { workspace = true, optional = true } rmp-serde = { workspace = true, optional = true } rustrict = { version = "0.7.4", optional = true } serde = { workspace = true, features = ["derive", "std"] } serde_json = { workspace = true } strum = { workspace = true, features = ["derive"] } +sqlx = { workspace = true, optional = true, features = [ + "sqlite", + "json", + "runtime-tokio-native-tls", +] } thiserror = { workspace = true, optional = true } tonic = { workspace = true, optional = true } tower = { workspace = true, optional = true } @@ -72,6 +78,7 @@ claims = [ display = ["chrono/clock", "comfy-table", "crossterm"] error = ["prost-types", "thiserror", "uuid"] models = ["anyhow", "async-trait", "display", "http", "reqwest", "service"] +persist = ["sqlx", "rand"] service = ["chrono/serde", "once_cell", "rustrict", "serde/derive", "uuid"] tracing = [] wasm = [ diff --git a/common/src/lib.rs b/common/src/lib.rs index ac7f08254..a249725c0 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -22,14 +22,15 @@ pub mod wasm; use std::collections::BTreeMap; use anyhow::bail; -use serde::{Deserialize, Serialize}; -#[cfg(feature = "service")] -use uuid::Uuid; - #[cfg(feature = "service")] pub use log::Item as LogItem; #[cfg(feature = "service")] pub use log::STATE_MESSAGE; +#[cfg(feature = "persist")] +use rand::distributions::{Alphanumeric, DistString}; +use serde::{Deserialize, Serialize}; +#[cfg(feature = "service")] +use uuid::Uuid; #[cfg(debug_assertions)] pub const API_URL_DEFAULT: &str = "http://localhost:8001"; @@ -42,7 +43,10 @@ pub type Host = String; #[cfg(feature = "service")] pub type DeploymentId = Uuid; -#[derive(Serialize, Deserialize, Clone)] +#[derive(Clone, Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "persist", derive(sqlx::Type, PartialEq, Hash, Eq))] +#[cfg_attr(feature = "persist", serde(transparent))] +#[cfg_attr(feature = "persist", sqlx(transparent))] pub struct ApiKey(String); impl ApiKey { @@ -65,6 +69,11 @@ impl ApiKey { Ok(Self(key)) } + + #[cfg(feature = "persist")] + pub fn generate() -> Self { + Self(Alphanumeric.sample_string(&mut rand::thread_rng(), 16)) + } } impl AsRef for ApiKey { @@ -180,7 +189,6 @@ mod tests { #[test] // The API key should be a 16 character alphanumeric string. fn parses_valid_keys(s in "[a-zA-Z0-9]{16}") { - println!("s: {s}, len: {}", s.len()); ApiKey::parse(&s).unwrap(); } } From a475913600a21df316d38936f860e5353820d377 Mon Sep 17 00:00:00 2001 From: oddgrd <29732646+oddgrd@users.noreply.github.com> Date: Mon, 1 May 2023 16:07:59 +0200 Subject: [PATCH 3/7] refactor: clean up tests --- common/Cargo.toml | 8 ++------ common/src/lib.rs | 17 ++++++++++++----- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/common/Cargo.toml b/common/Cargo.toml index ede4fdb45..be204fecb 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -33,11 +33,7 @@ rustrict = { version = "0.7.4", optional = true } serde = { workspace = true, features = ["derive", "std"] } serde_json = { workspace = true } strum = { workspace = true, features = ["derive"] } -sqlx = { workspace = true, optional = true, features = [ - "sqlite", - "json", - "runtime-tokio-native-tls", -] } +sqlx = { workspace = true, optional = true, features = ["runtime-tokio-native-tls"] } thiserror = { workspace = true, optional = true } tonic = { workspace = true, optional = true } tower = { workspace = true, optional = true } @@ -80,7 +76,7 @@ display = ["chrono/clock", "comfy-table", "crossterm"] error = ["prost-types", "thiserror", "uuid"] openapi = ["utoipa/chrono", "utoipa/uuid"] models = ["anyhow", "async-trait", "display", "http", "reqwest", "service"] -persist = ["sqlx", "rand"] +persist = ["sqlx/sqlite", "rand"] service = ["chrono/serde", "once_cell", "rustrict", "serde/derive", "uuid"] tracing = [] wasm = [ diff --git a/common/src/lib.rs b/common/src/lib.rs index a249725c0..340c4f9f2 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -26,8 +26,6 @@ use anyhow::bail; pub use log::Item as LogItem; #[cfg(feature = "service")] pub use log::STATE_MESSAGE; -#[cfg(feature = "persist")] -use rand::distributions::{Alphanumeric, DistString}; use serde::{Deserialize, Serialize}; #[cfg(feature = "service")] use uuid::Uuid; @@ -72,6 +70,8 @@ impl ApiKey { #[cfg(feature = "persist")] pub fn generate() -> Self { + use rand::distributions::{Alphanumeric, DistString}; + Self(Alphanumeric.sample_string(&mut rand::thread_rng(), 16)) } } @@ -188,20 +188,27 @@ mod tests { proptest! { #[test] // The API key should be a 16 character alphanumeric string. - fn parses_valid_keys(s in "[a-zA-Z0-9]{16}") { + fn parses_valid_api_keys(s in "[a-zA-Z0-9]{16}") { ApiKey::parse(&s).unwrap(); } } + #[test] + fn generated_api_key_is_valid() { + let key = ApiKey::generate(); + + assert!(ApiKey::parse(key.as_ref()).is_ok()); + } + #[test] #[should_panic(expected = "The API key should be exactly 16 characters in length.")] - fn invalid_length() { + fn invalid_api_key_length() { ApiKey::parse("tooshort").unwrap(); } #[test] #[should_panic(expected = "The API key should consist of only alphanumeric characters.")] - fn non_alphanumeric() { + fn non_alphanumeric_api_key() { ApiKey::parse("dh9z58jttoes3qv@").unwrap(); } } From f11660847ac5e2646bd735f1e2fa109495ef2c85 Mon Sep 17 00:00:00 2001 From: oddgrd <29732646+oddgrd@users.noreply.github.com> Date: Mon, 1 May 2023 16:26:06 +0200 Subject: [PATCH 4/7] refactor: don't allocate in parse unless it succeeds --- common/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/src/lib.rs b/common/src/lib.rs index 340c4f9f2..579f327e4 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -49,7 +49,7 @@ pub struct ApiKey(String); impl ApiKey { pub fn parse(key: &str) -> anyhow::Result { - let key = key.trim().to_string(); + let key = key.trim(); let mut errors = vec![]; if !key.chars().all(char::is_alphanumeric) { @@ -65,7 +65,7 @@ impl ApiKey { bail!("Invalid API key:\n{message}") } - Ok(Self(key)) + Ok(Self(key.to_string())) } #[cfg(feature = "persist")] From 5ad61bada07da7fa6713d004cca7d0018a5138a3 Mon Sep 17 00:00:00 2001 From: oddgrd <29732646+oddgrd@users.noreply.github.com> Date: Mon, 1 May 2023 22:15:15 +0200 Subject: [PATCH 5/7] fix: clippy --- auth/src/user.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auth/src/user.rs b/auth/src/user.rs index 31a3ba8d5..c452ef7f5 100644 --- a/auth/src/user.rs +++ b/auth/src/user.rs @@ -148,7 +148,7 @@ where .map_err(|_| Error::KeyMissing) .and_then(|TypedHeader(Authorization(bearer))| { let bearer = bearer.token().trim(); - Ok(ApiKey::parse(bearer).map_err(|_| Self::Rejection::Unauthorized)?) + ApiKey::parse(bearer).map_err(|_| Self::Rejection::Unauthorized) })?; trace!("got bearer key"); From 795cc0721c1e415ca6080816dd48898cc90b7081 Mon Sep 17 00:00:00 2001 From: oddgrd <29732646+oddgrd@users.noreply.github.com> Date: Mon, 1 May 2023 22:32:58 +0200 Subject: [PATCH 6/7] fix: missing anyhow --- common/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/Cargo.toml b/common/Cargo.toml index be204fecb..adeb41bd3 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -7,7 +7,7 @@ repository.workspace = true description = "Common library for the shuttle platform (https://www.shuttle.rs/)" [dependencies] -anyhow = { workspace = true, optional = true } +anyhow = { workspace = true } async-trait = { workspace = true, optional = true } axum = { workspace = true, optional = true } bytes = { workspace = true, optional = true } @@ -75,7 +75,7 @@ claims = [ display = ["chrono/clock", "comfy-table", "crossterm"] error = ["prost-types", "thiserror", "uuid"] openapi = ["utoipa/chrono", "utoipa/uuid"] -models = ["anyhow", "async-trait", "display", "http", "reqwest", "service"] +models = ["async-trait", "display", "http", "reqwest", "service"] persist = ["sqlx/sqlite", "rand"] service = ["chrono/serde", "once_cell", "rustrict", "serde/derive", "uuid"] tracing = [] From a878c29d6df17e78f3d1079ef8fb77a9edf4d082 Mon Sep 17 00:00:00 2001 From: oddgrd <29732646+oddgrd@users.noreply.github.com> Date: Thu, 4 May 2023 12:59:01 +0200 Subject: [PATCH 7/7] feat: impl debug/display for apikey --- Cargo.lock | 2 +- common/src/lib.rs | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 08578cff7..c511c8d57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4200,7 +4200,7 @@ dependencies = [ "rand", "rand_chacha", "rand_xorshift", - "regex-syntax", + "regex-syntax 0.6.29", "rusty-fork", "tempfile", "unarray", diff --git a/common/src/lib.rs b/common/src/lib.rs index 579f327e4..6b7acd956 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -20,6 +20,8 @@ pub mod tracing; pub mod wasm; use std::collections::BTreeMap; +use std::fmt::Debug; +use std::fmt::Display; use anyhow::bail; #[cfg(feature = "service")] @@ -41,7 +43,7 @@ pub type Host = String; #[cfg(feature = "service")] pub type DeploymentId = Uuid; -#[derive(Clone, Serialize, Deserialize, Debug)] +#[derive(Clone, Serialize, Deserialize)] #[cfg_attr(feature = "persist", derive(sqlx::Type, PartialEq, Hash, Eq))] #[cfg_attr(feature = "persist", serde(transparent))] #[cfg_attr(feature = "persist", sqlx(transparent))] @@ -82,6 +84,20 @@ impl AsRef for ApiKey { } } +// Ensure we can't accidentaly log an ApiKey +impl Debug for ApiKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ApiKey: REDACTED") + } +} + +// Ensure we can't accidentaly log an ApiKey +impl Display for ApiKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + #[cfg(feature = "error")] /// Errors that can occur when changing types. Especially from prost #[derive(thiserror::Error, Debug)]