diff --git a/libsql-server/src/auth/errors.rs b/libsql-server/src/auth/errors.rs index 5275153315..f1267c0c44 100644 --- a/libsql-server/src/auth/errors.rs +++ b/libsql-server/src/auth/errors.rs @@ -26,6 +26,10 @@ pub enum AuthError { AuthStringMalformed, #[error("Expected authorization header but none given")] AuthHeaderNotFound, + #[error("Expected authorization proxy header but none given")] + AuthProxyHeaderNotFound, + #[error("Failed to parse auth proxy header")] + AuthProxyHeaderInvalid, #[error("Non-ASCII auth header")] AuthHeaderNonAscii, #[error("Authentication failed")] @@ -47,6 +51,8 @@ impl AuthError { Self::JwtImmature => "AUTH_JWT_IMMATURE", Self::AuthStringMalformed => "AUTH_HEADER_MALFORMED", Self::AuthHeaderNotFound => "AUTH_HEADER_NOT_FOUND", + Self::AuthProxyHeaderNotFound => "AUTH_PROXY_HEADER_NOT_FOUND", + Self::AuthProxyHeaderInvalid => "AUTH_PROXY_HEADER_INVALID", Self::AuthHeaderNonAscii => "AUTH_HEADER_MALFORMED", Self::Other => "AUTH_FAILED", } diff --git a/libsql-server/src/auth/mod.rs b/libsql-server/src/auth/mod.rs index 09468f4b3a..365044e3ce 100644 --- a/libsql-server/src/auth/mod.rs +++ b/libsql-server/src/auth/mod.rs @@ -13,24 +13,23 @@ pub use authorized::Authorized; pub use errors::AuthError; pub use parsers::{parse_http_auth_header, parse_http_basic_auth_arg, parse_jwt_key}; pub use permission::Permission; -pub use user_auth_strategies::{Disabled, HttpBasic, Jwt, UserAuthContext, UserAuthStrategy}; +pub use user_auth_strategies::{ + Disabled, HttpBasic, Jwt, ProxyGrpc, UserAuthContext, UserAuthStrategy, +}; #[derive(Clone)] pub struct Auth { - pub user_strategy: Arc, + pub strategy: Arc, } impl Auth { - pub fn new(user_strategy: impl UserAuthStrategy + Send + Sync + 'static) -> Self { + pub fn new(strategy: impl UserAuthStrategy + Send + Sync + 'static) -> Self { Self { - user_strategy: Arc::new(user_strategy), + strategy: Arc::new(strategy), } } - pub fn authenticate( - &self, - context: Result, - ) -> Result { - self.user_strategy.authenticate(context) + pub fn authenticate(&self, context: UserAuthContext) -> Result { + self.strategy.authenticate(context) } } diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index 643a3fee1b..dbae78cf17 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -1,4 +1,4 @@ -use crate::auth::{constants::GRPC_AUTH_HEADER, AuthError}; +use crate::auth::AuthError; use anyhow::{bail, Context as _, Result}; use axum::http::HeaderValue; @@ -36,12 +36,20 @@ pub fn parse_jwt_key(data: &str) -> Result { } } -pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Result { - metadata - .get(GRPC_AUTH_HEADER) - .ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)) +pub(crate) fn parse_grpc_auth_header( + metadata: &MetadataMap, + required_fields: &Vec, +) -> UserAuthContext { + let mut context = UserAuthContext::empty(); + for field in required_fields.iter() { + metadata + .get(field) + .map(|header| header.to_str().ok()) + .and_then(|r| r) + .map(|v| context.add_field(field.into(), v.into())); + } + + context } pub fn parse_http_auth_header<'a>( @@ -79,40 +87,26 @@ mod tests { #[test] fn parse_grpc_auth_header_returns_valid_context() { let mut map = tonic::metadata::MetadataMap::new(); - map.insert("x-authorization", "bearer 123".parse().unwrap()); - let context = parse_grpc_auth_header(&map).unwrap(); - assert_eq!(context.scheme().as_ref().unwrap(), "bearer"); - assert_eq!(context.token().as_ref().unwrap(), "123"); - } - - #[test] - fn parse_grpc_auth_header_error_no_header() { - let map = tonic::metadata::MetadataMap::new(); - let result = parse_grpc_auth_header(&map); + map.insert( + crate::auth::constants::GRPC_AUTH_HEADER, + "bearer 123".parse().unwrap(), + ); + let required_fields = vec!["x-authorization".into()]; + let context = parse_grpc_auth_header(&map, &required_fields); assert_eq!( - result.unwrap_err().to_string(), - "Expected authorization header but none given" + context.custom_fields.get("x-authorization"), + Some(&"bearer 123".to_string()) ); } - #[test] - fn parse_grpc_auth_header_error_non_ascii() { - let mut map = tonic::metadata::MetadataMap::new(); - map.insert("x-authorization", "bearer I❤NY".parse().unwrap()); - let result = parse_grpc_auth_header(&map); - assert_eq!(result.unwrap_err().to_string(), "Non-ASCII auth header") - } - - #[test] - fn parse_grpc_auth_header_error_malformed_auth_str() { - let mut map = tonic::metadata::MetadataMap::new(); - map.insert("x-authorization", "bearer123".parse().unwrap()); - let result = parse_grpc_auth_header(&map); - assert_eq!( - result.unwrap_err().to_string(), - "Auth string does not conform to ' ' form" - ) - } + // #[test] TODO rewrite + // fn parse_grpc_auth_header_error_non_ascii() { + // let mut map = tonic::metadata::MetadataMap::new(); + // map.insert("x-authorization", "bearer I❤NY".parse().unwrap()); + // let required_fields = Vec::new(); + // let result = parse_grpc_auth_header(&map, &required_fields); + // assert_eq!(result.unwrap_err().to_string(), "Non-ASCII auth header") + // } #[test] fn parse_http_auth_header_returns_auth_header_param_when_valid() { diff --git a/libsql-server/src/auth/user_auth_strategies/disabled.rs b/libsql-server/src/auth/user_auth_strategies/disabled.rs index b95d52c061..ef9aae9062 100644 --- a/libsql-server/src/auth/user_auth_strategies/disabled.rs +++ b/libsql-server/src/auth/user_auth_strategies/disabled.rs @@ -4,10 +4,7 @@ use crate::auth::{AuthError, Authenticated}; pub struct Disabled {} impl UserAuthStrategy for Disabled { - fn authenticate( - &self, - _context: Result, - ) -> Result { + fn authenticate(&self, _context: UserAuthContext) -> Result { tracing::trace!("executing disabled auth"); Ok(Authenticated::FullAccess) } @@ -26,7 +23,7 @@ mod tests { #[test] fn authenticates() { let strategy = Disabled::new(); - let context = Ok(UserAuthContext::empty()); + let context = UserAuthContext::empty(); assert!(matches!( strategy.authenticate(context).unwrap(), diff --git a/libsql-server/src/auth/user_auth_strategies/http_basic.rs b/libsql-server/src/auth/user_auth_strategies/http_basic.rs index fbb45d0912..2310c7821b 100644 --- a/libsql-server/src/auth/user_auth_strategies/http_basic.rs +++ b/libsql-server/src/auth/user_auth_strategies/http_basic.rs @@ -7,27 +7,31 @@ pub struct HttpBasic { } impl UserAuthStrategy for HttpBasic { - fn authenticate( - &self, - context: Result, - ) -> Result { + fn authenticate(&self, ctx: UserAuthContext) -> Result { tracing::trace!("executing http basic auth"); + let auth_str = None + .or_else(|| ctx.custom_fields.get("authorization")) + .or_else(|| ctx.custom_fields.get("x-authorization")); + + let (_, token) = auth_str + .ok_or(AuthError::AuthHeaderNotFound) + .map(|s| s.split_once(' ').ok_or(AuthError::AuthStringMalformed)) + .and_then(|o| o)?; // NOTE: this naive comparison may leak information about the `expected_value` // using a timing attack let expected_value = self.credential.trim_end_matches('='); - - let creds_match = match context?.token { - Some(s) => s.contains(expected_value), - None => expected_value.is_empty(), - }; - + let creds_match = token.contains(expected_value); if creds_match { return Ok(Authenticated::FullAccess); } Err(AuthError::BasicRejected) } + + fn required_fields(&self) -> Vec { + vec!["authorization".to_string(), "x-authorization".to_string()] + } } impl HttpBasic { @@ -48,7 +52,7 @@ mod tests { #[test] fn authenticates_with_valid_credential() { - let context = Ok(UserAuthContext::basic(CREDENTIAL)); + let context = UserAuthContext::basic(CREDENTIAL); assert!(matches!( strategy().authenticate(context).unwrap(), @@ -59,7 +63,7 @@ mod tests { #[test] fn authenticates_with_valid_trimmed_credential() { let credential = CREDENTIAL.trim_end_matches('='); - let context = Ok(UserAuthContext::basic(credential)); + let context = UserAuthContext::basic(credential); assert!(matches!( strategy().authenticate(context).unwrap(), @@ -69,7 +73,7 @@ mod tests { #[test] fn errors_when_credentials_do_not_match() { - let context = Ok(UserAuthContext::basic("abc")); + let context = UserAuthContext::basic("abc"); assert_eq!( strategy().authenticate(context).unwrap_err(), diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index da68e91df0..6fd504ba88 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -12,21 +12,16 @@ pub struct Jwt { } impl UserAuthStrategy for Jwt { - fn authenticate( - &self, - context: Result, - ) -> Result { + fn authenticate(&self, ctx: UserAuthContext) -> Result { tracing::trace!("executing jwt auth"); + let auth_str = None + .or_else(|| ctx.custom_fields.get("authorization")) + .or_else(|| ctx.custom_fields.get("x-authorization")) + .ok_or_else(|| AuthError::AuthHeaderNotFound)?; - let ctx = context?; - - let UserAuthContext { - scheme: Some(scheme), - token: Some(token), - } = ctx - else { - return Err(AuthError::HttpAuthHeaderInvalid); - }; + let (scheme, token) = auth_str + .split_once(' ') + .ok_or(AuthError::AuthStringMalformed)?; if !scheme.eq_ignore_ascii_case("bearer") { return Err(AuthError::HttpAuthHeaderUnsupportedScheme); @@ -34,6 +29,10 @@ impl UserAuthStrategy for Jwt { return validate_jwt(&self.key, &token); } + + fn required_fields(&self) -> Vec { + vec!["authentication".to_string()] + } } impl Jwt { @@ -155,7 +154,7 @@ mod tests { }; let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); assert!(matches!( strategy(dec).authenticate(context).unwrap(), @@ -177,8 +176,7 @@ mod tests { }; let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); - + let context = UserAuthContext::bearer(token.as_str()); let Authenticated::Legacy(a) = strategy(dec).authenticate(context).unwrap() else { panic!() }; @@ -190,7 +188,7 @@ mod tests { #[test] fn errors_when_jwt_token_invalid() { let (_enc, dec) = key_pair(); - let context = Ok(UserAuthContext::bearer("abc")); + let context = UserAuthContext::bearer("abc"); assert_eq!( strategy(dec).authenticate(context).unwrap_err(), @@ -210,7 +208,7 @@ mod tests { let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); assert_eq!( strategy(dec).authenticate(context).unwrap_err(), @@ -232,7 +230,7 @@ mod tests { let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else { panic!() diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 4f0f2ef786..119e57ef7e 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -1,60 +1,43 @@ pub mod disabled; pub mod http_basic; pub mod jwt; +pub mod proxy_grpc; pub use disabled::Disabled; +use hashbrown::HashMap; pub use http_basic::HttpBasic; pub use jwt::Jwt; +pub use proxy_grpc::ProxyGrpc; use super::{AuthError, Authenticated}; #[derive(Debug)] pub struct UserAuthContext { - scheme: Option, - token: Option, + pub custom_fields: HashMap, String>, } impl UserAuthContext { - pub fn scheme(&self) -> &Option { - &self.scheme - } - - pub fn token(&self) -> &Option { - &self.token - } - pub fn empty() -> UserAuthContext { UserAuthContext { - scheme: None, - token: None, + custom_fields: HashMap::new(), } } pub fn basic(creds: &str) -> UserAuthContext { UserAuthContext { - scheme: Some("Basic".into()), - token: Some(creds.into()), + custom_fields: HashMap::from([("authorization".into(), format!("Basic {creds}"))]), } } pub fn bearer(token: &str) -> UserAuthContext { UserAuthContext { - scheme: Some("Bearer".into()), - token: Some(token.into()), - } - } - - pub fn bearer_opt(token: Option) -> UserAuthContext { - UserAuthContext { - scheme: Some("Bearer".into()), - token: token, + custom_fields: HashMap::from([("authorization".into(), format!("Bearer {token}"))]), } } pub fn new(scheme: &str, token: &str) -> UserAuthContext { UserAuthContext { - scheme: Some(scheme.into()), - token: Some(token.into()), + custom_fields: HashMap::from([("authorization".into(), format!("{scheme} {token}"))]), } } @@ -64,11 +47,16 @@ impl UserAuthContext { .ok_or(AuthError::AuthStringMalformed)?; Ok(UserAuthContext::new(scheme, token)) } + + pub fn add_field(&mut self, key: String, value: String) { + self.custom_fields.insert(key.into(), value.into()); + } } pub trait UserAuthStrategy: Sync + Send { - fn authenticate( - &self, - context: Result, - ) -> Result; + fn required_fields(&self) -> Vec { + vec![] + } + + fn authenticate(&self, context: UserAuthContext) -> Result; } diff --git a/libsql-server/src/auth/user_auth_strategies/proxy_grpc.rs b/libsql-server/src/auth/user_auth_strategies/proxy_grpc.rs new file mode 100644 index 0000000000..c6c8e39151 --- /dev/null +++ b/libsql-server/src/auth/user_auth_strategies/proxy_grpc.rs @@ -0,0 +1,33 @@ +use crate::auth::{AuthError, Authenticated}; + +use super::{UserAuthContext, UserAuthStrategy}; + +pub struct ProxyGrpc {} + +impl UserAuthStrategy for ProxyGrpc { + fn authenticate(&self, ctx: UserAuthContext) -> Result { + tracing::trace!("executing proxy grpc auth"); + let auth_str = None + .or_else(|| ctx.custom_fields.get("proxy-authorization")) + .or_else(|| ctx.custom_fields.get("x-proxy-authorization")) + .ok_or_else(|| AuthError::AuthProxyHeaderNotFound)?; + + serde_json::from_str::(&auth_str) + .map_err(|_| AuthError::AuthProxyHeaderInvalid) + } + + fn required_fields(&self) -> Vec { + vec![ + "authorization".to_string(), + "x-proxy-authorization".to_string(), + ] + } +} + +impl ProxyGrpc { + pub fn new() -> Self { + Self {} + } +} + +// todo tests diff --git a/libsql-server/src/hrana/ws/conn.rs b/libsql-server/src/hrana/ws/conn.rs index e34e00e6bf..a5581cbc71 100644 --- a/libsql-server/src/hrana/ws/conn.rs +++ b/libsql-server/src/hrana/ws/conn.rs @@ -206,16 +206,18 @@ async fn handle_client_msg(conn: &mut Conn, client_msg: proto::ClientMsg) -> Res } async fn handle_hello_msg(conn: &mut Conn, jwt: Option) -> Result { - let hello_res = match conn.session.as_mut() { - None => { - session::handle_initial_hello(&conn.server, conn.version, jwt, conn.namespace.clone()) - .await - .map(|session| conn.session = Some(session)) - } - Some(session) => { - session::handle_repeated_hello(&conn.server, session, jwt, conn.namespace.clone()).await - } - }; + let auth = session::handle_hello(&conn.server, jwt, conn.namespace.clone()).await; + + let hello_res = auth + .map(|a| { + if let Some(sess) = conn.session.as_mut() { + sess.update_auth(a) + } else { + conn.session = Some(session::Session::new(a, conn.version)); + Ok(()) + } + }) + .and_then(|o| o); match hello_res { Ok(_) => { diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index aef1f63574..4cf4d59407 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; -use anyhow::{anyhow, bail, Result}; +use anyhow::{anyhow, bail, Error, Result}; use futures::future::BoxFuture; use tokio::sync::{mpsc, oneshot}; @@ -22,6 +22,29 @@ pub struct Session { cursors: HashMap, } +impl Session { + pub fn new(auth: Authenticated, version: Version) -> Self { + Self { + auth, + version, + streams: HashMap::new(), + sqls: HashMap::new(), + cursors: HashMap::new(), + } + } + + pub fn update_auth(&mut self, auth: Authenticated) -> Result<(), Error> { + if self.version < Version::Hrana2 { + bail!(ProtocolError::NotSupported { + what: "Repeated hello message", + min_version: Version::Hrana2, + }) + } + self.auth = auth; + Ok(()) + } +} + struct StreamHandle { job_tx: mpsc::Sender, cursor_id: Option, @@ -65,60 +88,34 @@ pub enum ResponseError { Batch(batch::BatchError), } -pub(super) async fn handle_initial_hello( +pub(super) async fn handle_hello( server: &Server, - version: Version, jwt: Option, namespace: NamespaceName, -) -> Result { - // todo dupe #auth +) -> Result { let namespace_jwt_key = server .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) .await??; - let auth = namespace_jwt_key + let auth_strategy = namespace_jwt_key .map(Jwt::new) .map(Auth::new) - .unwrap_or(server.user_auth_strategy.clone()) - .authenticate(Ok(UserAuthContext::bearer_opt(jwt))) - .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; - - Ok(Session { - auth, - version, - streams: HashMap::new(), - sqls: HashMap::new(), - cursors: HashMap::new(), - }) -} + .unwrap_or_else(|| server.user_auth_strategy.clone()); -pub(super) async fn handle_repeated_hello( - server: &Server, - session: &mut Session, - jwt: Option, - namespace: NamespaceName, -) -> Result<()> { - if session.version < Version::Hrana2 { - bail!(ProtocolError::NotSupported { - what: "Repeated hello message", - min_version: Version::Hrana2, - }) - } - // todo dupe #auth - let namespace_jwt_key = server - .namespaces - .with(namespace.clone(), |ns| ns.jwt_key()) - .await??; + let context: UserAuthContext = build_context(jwt, &auth_strategy.strategy.required_fields()); - session.auth = namespace_jwt_key - .map(Jwt::new) - .map(Auth::new) - .unwrap_or_else(|| server.user_auth_strategy.clone()) - .authenticate(Ok(UserAuthContext::bearer_opt(jwt))) - .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; + auth_strategy + .authenticate(context) + .map_err(|err| anyhow!(ResponseError::Auth { source: err })) +} - Ok(()) +fn build_context(jwt: Option, required_fields: &Vec) -> UserAuthContext { + let mut ctx = UserAuthContext::empty(); + if required_fields.contains(&"authorization".into()) && jwt.is_some() { + ctx.add_field("authorization".into(), jwt.unwrap()); + } + ctx } pub(super) async fn handle_request( diff --git a/libsql-server/src/http/user/db_factory.rs b/libsql-server/src/http/user/db_factory.rs index 257d8811c1..794a1f25ac 100644 --- a/libsql-server/src/http/user/db_factory.rs +++ b/libsql-server/src/http/user/db_factory.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use axum::extract::{FromRequestParts, Path}; use hyper::http::request::Parts; use hyper::HeaderMap; +use libsql_replication::rpc::replication::NAMESPACE_METADATA_KEY; use crate::auth::Authenticated; use crate::connection::MakeConnection; @@ -46,18 +47,25 @@ pub fn namespace_from_headers( return Ok(NamespaceName::default()); } - let host = headers - .get("host") - .ok_or_else(|| Error::InvalidHost("missing host header".into()))? - .as_bytes(); - let host_str = std::str::from_utf8(host) - .map_err(|_| Error::InvalidHost("host header is not valid UTF-8".into()))?; + headers + .get(NAMESPACE_METADATA_KEY) + .ok_or(Error::InvalidNamespace) + .and_then(|h| h.to_str().map_err(|_| Error::InvalidNamespace)) + .and_then(|n| NamespaceName::from_string(n.into())) + .or_else(|_| { + let host = headers + .get("host") + .ok_or_else(|| Error::InvalidHost("missing host header".into()))? + .as_bytes(); + let host_str = std::str::from_utf8(host) + .map_err(|_| Error::InvalidHost("host header is not valid UTF-8".into()))?; - match split_namespace(host_str) { - Ok(ns) => Ok(ns), - Err(_) if !disable_default_namespace => Ok(NamespaceName::default()), - Err(e) => Err(e), - } + match split_namespace(host_str) { + Ok(ns) => Ok(ns), + Err(_) if !disable_default_namespace => Ok(NamespaceName::default()), + Err(e) => Err(e), + } + }) } pub struct MakeConnectionExtractorPath(pub Arc>); diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index b850e76ff3..80badb18fe 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -1,7 +1,7 @@ use axum::extract::FromRequestParts; use crate::{ - auth::{Auth, AuthError, Jwt, UserAuthContext}, + auth::{Auth, Jwt}, connection::RequestContext, }; @@ -24,25 +24,18 @@ impl FromRequestParts for RequestContext { let namespace_jwt_key = state .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) - .await??; - - let context = parts - .headers - .get(hyper::header::AUTHORIZATION) - .ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)); - - let authenticated = namespace_jwt_key - .map(Jwt::new) - .map(Auth::new) - .unwrap_or_else(|| state.user_auth_strategy.clone()) - .authenticate(context)?; - - Ok(Self::new( - authenticated, - namespace, - state.namespaces.meta_store().clone(), - )) + .await + .and_then(|o|o)?; + + let auth = namespace_jwt_key + .map(|key|Auth::new(Jwt::new(key))) + .unwrap_or_else(|| state.user_auth_strategy.clone()); + + let context = super::build_context(&parts.headers, &auth.strategy.required_fields()); + + auth.authenticate(context) + .map(|a| Self::new(a, namespace, state.namespaces.meta_store().clone())) + .map_err(|e|e.into()) + } } diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 3e394fc579..9137c4e6a9 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -468,20 +468,36 @@ impl FromRequestParts for Authenticated { .with(ns.clone(), |ns| ns.jwt_key()) .await??; - let context = parts - .headers - .get(hyper::header::AUTHORIZATION) - .ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)); - - let authenticated = namespace_jwt_key - .map(Jwt::new) - .map(Auth::new) - .unwrap_or_else(|| state.user_auth_strategy.clone()) - .authenticate(context)?; - Ok(authenticated) + let auth = namespace_jwt_key + .map(|key|Auth::new(Jwt::new(key))) + .unwrap_or_else(|| state.user_auth_strategy.clone()); + + let context = build_context(&parts.headers, &auth.strategy.required_fields()); + + auth.authenticate(context) + .map_err(|e|e.into()) + } +} + +fn build_context( + headers: &hyper::HeaderMap, + required_fields: &Vec, +) -> UserAuthContext { + let mut ctx = headers + .get(hyper::header::AUTHORIZATION) + .ok_or(AuthError::AuthHeaderNotFound) + .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) + .and_then(|t| UserAuthContext::from_auth_str(t)) + .unwrap_or(UserAuthContext::empty()); + + for field in required_fields.iter() { + headers + .get(field) + .map(|h| h.to_str().ok()) + .and_then(|t| t.map(|s| ctx.add_field(field.into(), s.into()))); } + + ctx } impl FromRef for Auth { diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index d7f96d68a7..e39717cbaf 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -19,7 +19,7 @@ use tokio::time::Duration; use uuid::Uuid; use crate::auth::parsers::parse_grpc_auth_header; -use crate::auth::{Auth, Authenticated, Jwt}; +use crate::auth::{Auth, Jwt, ProxyGrpc}; use crate::connection::{Connection as _, RequestContext}; use crate::database::Connection; use crate::namespace::NamespaceStore; @@ -311,41 +311,25 @@ impl ProxyService { &self, req: &mut tonic::Request, ) -> Result { - let namespace = super::extract_namespace(self.disable_namespaces, req)?; + let ns = super::extract_namespace(self.disable_namespaces, req)?; // todo dupe #auth let namespace_jwt_key = self .namespaces - .with(namespace.clone(), |ns| ns.jwt_key()) - .await; - - let auth = match namespace_jwt_key { - Ok(Ok(Some(key))) => Some(Auth::new(Jwt::new(key))), - Ok(Ok(None)) => self.user_auth_strategy.clone(), - Err(e) => match e.as_ref() { - crate::error::Error::NamespaceDoesntExist(_) => None, - _ => Err(tonic::Status::internal(format!( - "Error fetching jwt key for a namespace: {}", - e - )))?, - }, - Ok(Err(e)) => Err(tonic::Status::internal(format!( - "Error fetching jwt key for a namespace: {}", - e - )))?, - }; + .with(ns.clone(), |ns| ns.jwt_key()) + .await + .and_then(|o|o) + .map_err(|e|tonic::Status::internal(format!("Error fetching jwt key for a namespace: {}",e)))?; - let auth = if let Some(auth) = auth { - let context = parse_grpc_auth_header(req.metadata()); - auth.authenticate(context)? - } else { - Authenticated::from_proxy_grpc_request(req)? - }; + let auth = namespace_jwt_key + .map(|key|Auth::new(Jwt::new(key))) + .or_else(||self.user_auth_strategy.clone()) + .unwrap_or_else(|| Auth::new(ProxyGrpc::new())); + + let context = parse_grpc_auth_header(req.metadata(), &auth.strategy.required_fields()); - Ok(RequestContext::new( - auth, - namespace, - self.namespaces.meta_store().clone(), - )) + auth.authenticate(context) + .map(|a| RequestContext::new(a, ns, self.namespaces.meta_store().clone())) + .map_err(|e|e.into()) } } diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index 6945cfffed..faf4be950e 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -36,30 +36,22 @@ impl ReplicaProxyService { async fn do_auth(&self, req: &mut Request) -> Result<(), Status> { let namespace = super::extract_namespace(self.disable_namespaces, req)?; - + // todo dupe #auth let jwt_result = self .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) - .await; - - let namespace_jwt_key = jwt_result.and_then(|s| s); + .await + .and_then(|s| s) + .map_err(|e|Status::internal(format!("Can't fetch jwt key for a namespace: {}",e)))?; - let auth_strategy = match namespace_jwt_key { - Ok(Some(key)) => Ok(Auth::new(Jwt::new(key))), - Ok(None) | Err(crate::error::Error::NamespaceDoesntExist(_)) => { - Ok(self.user_auth_strategy.clone()) - } - Err(e) => Err(Status::internal(format!( - "Can't fetch jwt key for a namespace: {}", - e - ))), - }?; + let auth = jwt_result + .map(|key|Auth::new(Jwt::new(key))) + .unwrap_or_else(|| self.user_auth_strategy.clone()); - let auth_context = parse_grpc_auth_header(req.metadata()); - auth_strategy - .authenticate(auth_context)? - .upgrade_grpc_request(req); + let auth_context = parse_grpc_auth_header(req.metadata(), &auth.strategy.required_fields()); + auth.authenticate(auth_context)?.upgrade_grpc_request(req); + return Ok(()); } } diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index 5a93b3331e..b9ff3c8035 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -75,27 +75,17 @@ impl ReplicationLogService { let namespace_jwt_key = self .namespaces .with(namespace.clone(), |ns| ns.jwt_key()) - .await; - - let auth = match namespace_jwt_key { - Ok(Ok(Some(key))) => Some(Auth::new(Jwt::new(key))), - Ok(Ok(None)) => self.user_auth_strategy.clone(), - Err(e) => match e.as_ref() { - crate::error::Error::NamespaceDoesntExist(_) => self.user_auth_strategy.clone(), - _ => Err(Status::internal(format!( - "Error fetching jwt key for a namespace: {}", - e - )))?, - }, - Ok(Err(e)) => Err(Status::internal(format!( - "Error fetching jwt key for a namespace: {}", - e - )))?, - }; + .await + .and_then(|o|o) + .map_err(|e|Status::internal(format!("Error fetching jwt key for a namespace: {}",e)))?; + + let auth = namespace_jwt_key + .map(|key|Auth::new(Jwt::new(key))) + .or_else(|| self.user_auth_strategy.clone()); if let Some(auth) = auth { - let user_credential = parse_grpc_auth_header(req.metadata()); - auth.authenticate(user_credential)?; + let context = parse_grpc_auth_header(req.metadata(), &auth.strategy.required_fields()); + auth.authenticate(context)?; } Ok(()) diff --git a/libsql/examples/remote_sync.rs b/libsql/examples/remote_sync.rs index ccc31eb033..d42ccd1f75 100644 --- a/libsql/examples/remote_sync.rs +++ b/libsql/examples/remote_sync.rs @@ -7,40 +7,42 @@ async fn main() { tracing_subscriber::fmt::init(); // The local database path where the data will be stored. - let db_path = match std::env::var("LIBSQL_DB_PATH") { - Ok(path) => path, - Err(_) => { + let db_path = std::env::var("LIBSQL_DB_PATH") + .map_err(|_| { eprintln!( "Please set the LIBSQL_DB_PATH environment variable to set to local database path." - ); - return; - } - }; + ) + }) + .unwrap(); // The remote sync URL to use. - let sync_url = match std::env::var("LIBSQL_SYNC_URL") { - Ok(url) => url, - Err(_) => { + let sync_url = std::env::var("LIBSQL_SYNC_URL") + .map_err(|_| { eprintln!( "Please set the LIBSQL_SYNC_URL environment variable to set to remote sync URL." - ); - return; - } - }; + ) + }) + .unwrap(); + + let namespace = std::env::var("LIBSQL_NAMESPACE").ok(); // The authentication token to use. let auth_token = std::env::var("LIBSQL_AUTH_TOKEN").unwrap_or("".to_string()); - let db = match Builder::new_remote_replica(db_path, sync_url, auth_token) - .build() - .await - { + let db_builder = if let Some(ns) = namespace { + Builder::new_remote_replica(db_path, sync_url, auth_token).namespace(&ns) + } else { + Builder::new_remote_replica(db_path, sync_url, auth_token) + }; + + let db = match db_builder.build().await { Ok(db) => db, Err(error) => { eprintln!("Error connecting to remote sync server: {}", error); return; } }; + let conn = db.connect().unwrap(); print!("Syncing with remote database..."); diff --git a/libsql/src/database.rs b/libsql/src/database.rs index dd5ebc671f..6659b9ba19 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -184,7 +184,7 @@ cfg_replication! { None, OpenFlags::default(), encryption_config.clone(), - None + None, ).await?; Ok(Database { @@ -309,6 +309,7 @@ cfg_replication! { read_your_writes, encryption_config.clone(), sync_interval, + None, None ).await?; diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs index 065b50588c..46fe7125f9 100644 --- a/libsql/src/database/builder.rs +++ b/libsql/src/database/builder.rs @@ -60,7 +60,8 @@ impl Builder<()> { encryption_config: None, read_your_writes: true, sync_interval: None, - http_request_callback: None + http_request_callback: None, + namespace: None }, } } @@ -165,6 +166,7 @@ cfg_replication! { read_your_writes: bool, sync_interval: Option, http_request_callback: Option, + namespace: Option, } /// Local replica configuration type in [`Builder`]. @@ -226,6 +228,11 @@ cfg_replication! { } + pub fn namespace(mut self, namespace: &str) -> Builder { + self.inner.namespace = Some(namespace.into()); + self + } + #[doc(hidden)] pub fn version(mut self, version: String) -> Builder { self.inner.remote = self.inner.remote.version(version); @@ -246,7 +253,8 @@ cfg_replication! { encryption_config, read_your_writes, sync_interval, - http_request_callback + http_request_callback, + namespace } = self.inner; let connector = if let Some(connector) = connector { @@ -273,7 +281,8 @@ cfg_replication! { read_your_writes, encryption_config.clone(), sync_interval, - http_request_callback + http_request_callback, + namespace, ) .await?; @@ -339,7 +348,7 @@ cfg_replication! { version, flags, encryption_config.clone(), - http_request_callback + http_request_callback, ) .await? } else { diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs index 1ffdc49b12..712c9e095a 100644 --- a/libsql/src/local/database.rs +++ b/libsql/src/local/database.rs @@ -65,6 +65,7 @@ impl Database { encryption_config, sync_interval, None, + None, ) .await } @@ -81,6 +82,7 @@ impl Database { encryption_config: Option, sync_interval: Option, http_request_callback: Option, + namespace: Option ) -> Result { use std::path::PathBuf; @@ -95,6 +97,7 @@ impl Database { auth_token, version.as_deref(), http_request_callback, + namespace, ) .unwrap(); let path = PathBuf::from(db_path); @@ -166,6 +169,7 @@ impl Database { auth_token, version.as_deref(), http_request_callback, + None, ) .unwrap(); diff --git a/libsql/src/replication/client.rs b/libsql/src/replication/client.rs index 8ef5edaf04..16227d4f5a 100644 --- a/libsql/src/replication/client.rs +++ b/libsql/src/replication/client.rs @@ -47,6 +47,7 @@ impl Client { auth_token: impl AsRef, version: Option<&str>, http_request_callback: Option, + maybe_namespace: Option, ) -> anyhow::Result { let ver = version.unwrap_or(env!("CARGO_PKG_VERSION")); @@ -58,7 +59,12 @@ impl Client { .try_into() .context("Invalid auth token must be ascii")?; - let ns = split_namespace(origin.host().unwrap()).unwrap_or_else(|_| "default".to_string()); + + let ns = maybe_namespace.unwrap_or_else(|| + split_namespace(origin.host().unwrap()) + .unwrap_or_else(|_| "default".to_string()) + ); + let namespace = BinaryMetadataValue::from_bytes(ns.as_bytes()); let channel = GrpcChannel::new(connector, http_request_callback);