Skip to content

Commit

Permalink
Feat-330/allow anonymous (#349)
Browse files Browse the repository at this point in the history
## Changes

### Breaking

- `UserId` now has a new variant: `Anonymous` to represent the explicit seal of approval for Anonymous users
- Update `ProxyAuthLayer` to support anonymous variant
- Update layer in function of the Authorization header to support anonymous variant
  • Loading branch information
calebbourg authored Nov 7, 2024
1 parent 6b2532f commit 5efa6e7
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 15 deletions.
157 changes: 145 additions & 12 deletions rama-http/src/layer/auth/require_authorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,39 @@ use crate::{
};
use rama_core::Context;

use rama_net::user::UserId;

const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;

impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> {
impl<C> ValidateRequestHeaderLayer<AuthorizeContext<C>> {
/// Allow anonymous requests.
pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
self.validate.allow_anonymous = allow_anonymous;
self
}

/// Allow anonymous requests.
pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
self.validate.allow_anonymous = allow_anonymous;
self
}
}

impl<S, C> ValidateRequestHeader<S, AuthorizeContext<C>> {
/// Allow anonymous requests.
pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
self.validate.allow_anonymous = allow_anonymous;
self
}

/// Allow anonymous requests.
pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
self.validate.allow_anonymous = allow_anonymous;
self
}
}

impl<S, ResBody> ValidateRequestHeader<S, AuthorizeContext<Basic<ResBody>>> {
/// Authorize requests using a username and password pair.
///
/// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
Expand All @@ -78,11 +108,11 @@ impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> {
where
ResBody: Default,
{
Self::custom(inner, Basic::new(username, value))
Self::custom(inner, AuthorizeContext::new(Basic::new(username, value)))
}
}

impl<ResBody> ValidateRequestHeaderLayer<Basic<ResBody>> {
impl<ResBody> ValidateRequestHeaderLayer<AuthorizeContext<Basic<ResBody>>> {
/// Authorize requests using a username and password pair.
///
/// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
Expand All @@ -94,11 +124,11 @@ impl<ResBody> ValidateRequestHeaderLayer<Basic<ResBody>> {
where
ResBody: Default,
{
Self::custom(Basic::new(username, password))
Self::custom(AuthorizeContext::new(Basic::new(username, password)))
}
}

impl<S, ResBody> ValidateRequestHeader<S, Bearer<ResBody>> {
impl<S, ResBody> ValidateRequestHeader<S, AuthorizeContext<Bearer<ResBody>>> {
/// Authorize requests using a "bearer token". Commonly used for OAuth 2.
///
/// The `Authorization` header is required to be `Bearer {token}`.
Expand All @@ -110,11 +140,11 @@ impl<S, ResBody> ValidateRequestHeader<S, Bearer<ResBody>> {
where
ResBody: Default,
{
Self::custom(inner, Bearer::new(token))
Self::custom(inner, AuthorizeContext::new(Bearer::new(token)))
}
}

impl<ResBody> ValidateRequestHeaderLayer<Bearer<ResBody>> {
impl<ResBody> ValidateRequestHeaderLayer<AuthorizeContext<Bearer<ResBody>>> {
/// Authorize requests using a "bearer token". Commonly used for OAuth 2.
///
/// The `Authorization` header is required to be `Bearer {token}`.
Expand All @@ -126,7 +156,7 @@ impl<ResBody> ValidateRequestHeaderLayer<Bearer<ResBody>> {
where
ResBody: Default,
{
Self::custom(Bearer::new(token))
Self::custom(AuthorizeContext::new(Bearer::new(token)))
}
}

Expand Down Expand Up @@ -169,7 +199,7 @@ impl<ResBody> fmt::Debug for Bearer<ResBody> {
}
}

impl<S, B, ResBody> ValidateRequest<S, B> for Bearer<ResBody>
impl<S, B, ResBody> ValidateRequest<S, B> for AuthorizeContext<Bearer<ResBody>>
where
ResBody: Default + Send + 'static,
B: Send + 'static,
Expand All @@ -183,7 +213,12 @@ where
request: Request<B>,
) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
match request.headers().get(header::AUTHORIZATION) {
Some(actual) if actual == self.header_value => Ok((ctx, request)),
Some(actual) if actual == self.credential.header_value => Ok((ctx, request)),
None if self.allow_anonymous => {
let mut ctx = ctx;
ctx.insert(UserId::Anonymous);
Ok((ctx, request))
}
_ => {
let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
Expand Down Expand Up @@ -232,7 +267,7 @@ impl<ResBody> fmt::Debug for Basic<ResBody> {
}
}

impl<S, B, ResBody> ValidateRequest<S, B> for Basic<ResBody>
impl<S, B, ResBody> ValidateRequest<S, B> for AuthorizeContext<Basic<ResBody>>
where
ResBody: Default + Send + 'static,
B: Send + 'static,
Expand All @@ -246,7 +281,12 @@ where
request: Request<B>,
) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
match request.headers().get(header::AUTHORIZATION) {
Some(actual) if actual == self.header_value => Ok((ctx, request)),
Some(actual) if actual == self.credential.header_value => Ok((ctx, request)),
None if self.allow_anonymous => {
let mut ctx = ctx;
ctx.insert(UserId::Anonymous);
Ok((ctx, request))
}
_ => {
let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
Expand All @@ -258,6 +298,38 @@ where
}
}

pub struct AuthorizeContext<C> {
credential: C,
allow_anonymous: bool,
}

impl<C> AuthorizeContext<C> {
pub(crate) fn new(credential: C) -> Self {
Self {
credential,
allow_anonymous: false,
}
}
}

impl<C: Clone> Clone for AuthorizeContext<C> {
fn clone(&self) -> Self {
Self {
credential: self.credential.clone(),
allow_anonymous: self.allow_anonymous,
}
}
}

impl<C: fmt::Debug> fmt::Debug for AuthorizeContext<C> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AuthorizeContext")
.field("credential", &self.credential)
.field("allow_anonymous", &self.allow_anonymous)
.finish()
}
}

#[cfg(test)]
mod tests {
#[allow(unused_imports)]
Expand Down Expand Up @@ -399,4 +471,65 @@ mod tests {
async fn echo<Body>(req: Request<Body>) -> Result<Response<Body>, BoxError> {
Ok(Response::new(req.into_body()))
}

#[tokio::test]
async fn basic_allows_anonymous_if_header_is_missing() {
let service = ValidateRequestHeaderLayer::basic("foo", "bar")
.with_allow_anonymous(true)
.layer(service_fn(echo));

let request = Request::get("/").body(Body::empty()).unwrap();

let res = service.serve(Context::default(), request).await.unwrap();

assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn basic_fails_if_allow_anonymous_and_credentials_are_invalid() {
let service = ValidateRequestHeaderLayer::basic("foo", "bar")
.with_allow_anonymous(true)
.layer(service_fn(echo));

let request = Request::get("/")
.header(
header::AUTHORIZATION,
format!("Basic {}", BASE64.encode("wrong:credentials")),
)
.body(Body::empty())
.unwrap();

let res = service.serve(Context::default(), request).await.unwrap();

assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn bearer_allows_anonymous_if_header_is_missing() {
let service = ValidateRequestHeaderLayer::bearer("foobar")
.with_allow_anonymous(true)
.layer(service_fn(echo));

let request = Request::get("/").body(Body::empty()).unwrap();

let res = service.serve(Context::default(), request).await.unwrap();

assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn bearer_fails_if_allow_anonymous_and_credentials_are_invalid() {
let service = ValidateRequestHeaderLayer::bearer("foobar")
.with_allow_anonymous(true)
.layer(service_fn(echo));

let request = Request::get("/")
.header(header::AUTHORIZATION, "Bearer wrong")
.body(Body::empty())
.unwrap();

let res = service.serve(Context::default(), request).await.unwrap();

assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
}
39 changes: 38 additions & 1 deletion rama-http/src/layer/proxy_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::header::PROXY_AUTHENTICATE;
use crate::headers::{authorization::Credentials, HeaderMapExt, ProxyAuthorization};
use crate::{Request, Response, StatusCode};
use rama_core::{Context, Layer, Service};
use rama_net::user::auth::Authority;
use rama_net::user::{auth::Authority, UserId};
use rama_utils::macros::define_inner_service_accessors;
use std::fmt;
use std::marker::PhantomData;
Expand All @@ -16,6 +16,7 @@ use std::marker::PhantomData;
/// See the [module docs](super) for an example.
pub struct ProxyAuthLayer<A, C, L = ()> {
proxy_auth: A,
allow_anonymous: bool,
_phantom: PhantomData<fn(C, L) -> ()>,
}

Expand All @@ -35,6 +36,7 @@ impl<A: Clone, C, L> Clone for ProxyAuthLayer<A, C, L> {
fn clone(&self) -> Self {
Self {
proxy_auth: self.proxy_auth.clone(),
allow_anonymous: self.allow_anonymous,
_phantom: PhantomData,
}
}
Expand All @@ -45,9 +47,22 @@ impl<A, C> ProxyAuthLayer<A, C, ()> {
pub const fn new(proxy_auth: A) -> Self {
ProxyAuthLayer {
proxy_auth,
allow_anonymous: false,
_phantom: PhantomData,
}
}

/// Allow anonymous requests.
pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
self.allow_anonymous = allow_anonymous;
self
}

/// Allow anonymous requests.
pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
self.allow_anonymous = allow_anonymous;
self
}
}

impl<A, C, L> ProxyAuthLayer<A, C, L> {
Expand All @@ -63,6 +78,7 @@ impl<A, C, L> ProxyAuthLayer<A, C, L> {
pub fn with_labels<L2>(self) -> ProxyAuthLayer<A, C, L2> {
ProxyAuthLayer {
proxy_auth: self.proxy_auth,
allow_anonymous: self.allow_anonymous,
_phantom: PhantomData,
}
}
Expand All @@ -83,10 +99,13 @@ where
/// Middleware that validates if a request has the appropriate Proxy Authorisation.
///
/// If the request is not authorized a `407 Proxy Authentication Required` response will be sent.
/// If `allow_anonymous` is set to `true` then requests without a Proxy Authorization header will be
/// allowed and the user will be authoized as [`UserId::Anonymous`].
///
/// See the [module docs](self) for an example.
pub struct ProxyAuthService<A, C, S, L = ()> {
proxy_auth: A,
allow_anonymous: bool,
inner: S,
_phantom: PhantomData<fn(C, L) -> ()>,
}
Expand All @@ -96,18 +115,32 @@ impl<A, C, S, L> ProxyAuthService<A, C, S, L> {
pub const fn new(proxy_auth: A, inner: S) -> Self {
Self {
proxy_auth,
allow_anonymous: false,
inner,
_phantom: PhantomData,
}
}

/// Allow anonymous requests.
pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
self.allow_anonymous = allow_anonymous;
self
}

/// Allow anonymous requests.
pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
self.allow_anonymous = allow_anonymous;
self
}

define_inner_service_accessors!();
}

impl<A: fmt::Debug, C, S: fmt::Debug, L> fmt::Debug for ProxyAuthService<A, C, S, L> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ProxyAuthService")
.field("proxy_auth", &self.proxy_auth)
.field("allow_anonymous", &self.allow_anonymous)
.field("inner", &self.inner)
.field(
"_phantom",
Expand All @@ -121,6 +154,7 @@ impl<A: Clone, C, S: Clone, L> Clone for ProxyAuthService<A, C, S, L> {
fn clone(&self) -> Self {
ProxyAuthService {
proxy_auth: self.proxy_auth.clone(),
allow_anonymous: self.allow_anonymous,
inner: self.inner.clone(),
_phantom: PhantomData,
}
Expand Down Expand Up @@ -162,6 +196,9 @@ where
.body(Default::default())
.unwrap())
}
} else if self.allow_anonymous {
ctx.insert(UserId::Anonymous);
self.inner.serve(ctx, req).await
} else {
Ok(Response::builder()
.status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::fmt;
///
/// See the [module docs](crate::layer::validate_request) for an example.
pub struct ValidateRequestHeaderLayer<T> {
validate: T,
pub(crate) validate: T,
}

impl<T: fmt::Debug> fmt::Debug for ValidateRequestHeaderLayer<T> {
Expand Down Expand Up @@ -90,7 +90,7 @@ where
/// See the [module docs](crate::layer::validate_request) for an example.
pub struct ValidateRequestHeader<S, T> {
inner: S,
validate: T,
pub(crate) validate: T,
}

impl<S: fmt::Debug, T: fmt::Debug> fmt::Debug for ValidateRequestHeader<S, T> {
Expand Down
Loading

0 comments on commit 5efa6e7

Please sign in to comment.