Skip to content

Commit

Permalink
feat(transport-http): JWT auth layer (alloy-rs#1314)
Browse files Browse the repository at this point in the history
  • Loading branch information
yash-atreya authored and lwedge99 committed Oct 8, 2024
1 parent 59fcff1 commit 4009f53
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 4 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ thiserror-no-std = "2.0.2"
url = "2.5"
derive_more = { version = "1.0.0", default-features = false }
http = "1.1.0"
jsonwebtoken = "9.3.0"

## serde
serde = { version = "1.0", default-features = false, features = [
Expand Down
3 changes: 2 additions & 1 deletion crates/provider/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ alloy-rlp.workspace = true
alloy-sol-types.workspace = true
alloy-signer.workspace = true
alloy-signer-local.workspace = true
alloy-transport-http = { workspace = true, features = ["reqwest"] }
alloy-transport-http = { workspace = true, features = ["reqwest", "jwt-auth"] }

itertools.workspace = true
reqwest.workspace = true
Expand All @@ -82,6 +82,7 @@ tower-http = { workspace = true, features = [
] }
http-body-util.workspace = true
http.workspace = true
alloy-rpc-types-engine = { workspace = true, features = ["jwt"] }

[features]
default = ["reqwest", "reqwest-default-tls"]
Expand Down
37 changes: 37 additions & 0 deletions crates/provider/src/provider/trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,43 @@ mod tests {
assert_eq!(0, num);
}

#[cfg(all(feature = "hyper", not(windows)))]
#[tokio::test]
async fn test_auth_layer_transport() {
use alloy_node_bindings::Reth;
use alloy_rpc_types_engine::JwtSecret;
use alloy_transport_http::{AuthLayer, AuthService, Http, HyperClient};

init_tracing();
let secret = JwtSecret::random();

let reth = Reth::new().arg("--rpc.jwtsecret").arg(hex::encode(secret.as_bytes())).spawn();

let hyper_client = Client::builder(TokioExecutor::new()).build_http::<Full<HyperBytes>>();

let service =
tower::ServiceBuilder::new().layer(AuthLayer::new(secret)).service(hyper_client);

let layer_transport: HyperClient<
Full<HyperBytes>,
AuthService<
Client<
alloy_transport_http::hyper_util::client::legacy::connect::HttpConnector,
Full<HyperBytes>,
>,
>,
> = HyperClient::with_service(service);

let http_hyper = Http::with_client(layer_transport, reth.endpoint_url());

let rpc_client = alloy_rpc_client::RpcClient::new(http_hyper, true);

let provider = RootProvider::<_, Ethereum>::new(rpc_client);

let num = provider.get_block_number().await.unwrap();
assert_eq!(0, num);
}

#[tokio::test]
async fn test_builder_helper_fn_any_network() {
init_tracing();
Expand Down
2 changes: 1 addition & 1 deletion crates/rpc-types-engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jsonrpsee-types = { version = "0.24", optional = true }

# jwt
rand = { workspace = true, optional = true }
jsonwebtoken = { version = "9.3.0", optional = true }
jsonwebtoken = { workspace = true, optional = true }

[features]
default = ["jwt", "std", "serde"]
Expand Down
5 changes: 5 additions & 0 deletions crates/rpc-types-engine/src/jwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ impl JwtSecret {
let algo = jsonwebtoken::Header::new(Algorithm::HS256);
jsonwebtoken::encode(&algo, claims, &key)
}

/// Returns the secret key as a byte slice.
pub const fn as_bytes(&self) -> &[u8] {
&self.0
}
}

impl core::fmt::Debug for JwtSecret {
Expand Down
11 changes: 11 additions & 0 deletions crates/transport-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ http-body-util = { workspace = true, optional = true }
hyper = { workspace = true, default-features = false, optional = true }
hyper-util = { workspace = true, features = ["full"], optional = true }

# auth layer
alloy-rpc-types-engine = { workspace = true, optional = true }
jsonwebtoken = { workspace = true, optional = true }

[features]
default = ["reqwest", "reqwest-default-tls"]
reqwest = [
Expand All @@ -52,6 +56,13 @@ hyper = [
"dep:tower",
"dep:tracing",
]
jwt-auth = [
"hyper",
"dep:alloy-rpc-types-engine",
"alloy-rpc-types-engine/jwt",
"alloy-rpc-types-engine/serde",
"dep:jsonwebtoken",
]
reqwest-default-tls = ["reqwest?/default-tls"]
reqwest-native-tls = ["reqwest?/native-tls"]
reqwest-rustls-tls = ["reqwest?/rustls-tls"]
7 changes: 5 additions & 2 deletions crates/transport-http/src/hyper_transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,15 @@ impl<B, S> HyperClient<B, S> {
}
}

impl<B, S> Http<HyperClient<B, S>>
impl<B, S, ResBody> Http<HyperClient<B, S>>
where
S: Service<Request<B>, Response = HyperResponse> + Clone + Send + Sync + 'static,
S: Service<Request<B>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
S::Future: Send,
S::Error: std::error::Error + Send + Sync + 'static,
B: From<Vec<u8>> + Send + 'static + Clone,
ResBody: BodyExt + Send + 'static,
ResBody::Error: std::error::Error + Send + Sync + 'static,
ResBody::Data: Send,
{
/// Make a request to the server using the given service.
fn request_hyper(&self, req: RequestPacket) -> TransportFut<'static> {
Expand Down
142 changes: 142 additions & 0 deletions crates/transport-http/src/layers/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
use crate::hyper::{header::AUTHORIZATION, Request, Response};
use alloy_rpc_types_engine::{Claims, JwtSecret};
use alloy_transport::{TransportError, TransportErrorKind};
use hyper::header::HeaderValue;
use jsonwebtoken::get_current_timestamp;
use std::{
future::Future,
pin::Pin,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use tower::{Layer, Service};

/// The [`AuthLayer`] uses the provided [`JwtSecret`] to generate and validate the jwt token
/// in the requests.
///
/// The generated token is inserted into the [`AUTHORIZATION`] header of the request.
#[derive(Clone, Debug)]
pub struct AuthLayer {
secret: JwtSecret,
latency_buffer: u64,
}

impl AuthLayer {
/// Create a new [`AuthLayer`].
pub const fn new(secret: JwtSecret) -> Self {
Self { secret, latency_buffer: 5000 }
}

/// We use this buffer to perfom an extra check on the `iat` field to prevent sending any
/// requests with tokens that are valid now but may not be upon reaching the server.
///
/// In milliseconds. Default is 5s.
pub const fn with_latency_buffer(self, latency_buffer: u64) -> Self {
Self { latency_buffer, ..self }
}
}

impl<S> Layer<S> for AuthLayer {
type Service = AuthService<S>;

fn layer(&self, inner: S) -> Self::Service {
AuthService::new(inner, self.secret, self.latency_buffer)
}
}

/// A service that generates and validates the jwt token in the requests using the provided secret.
#[derive(Clone, Debug)]
pub struct AuthService<S> {
inner: S,
secret: JwtSecret,
/// In milliseconds.
latency_buffer: u64,
most_recent_claim: Option<Claims>,
}

impl<S> AuthService<S> {
/// Create a new [`AuthService`] with the given inner service.
pub const fn new(inner: S, secret: JwtSecret, latency_buffer: u64) -> Self {
Self { inner, secret, latency_buffer, most_recent_claim: None }
}

/// Validate the token in the request headers.
///
/// Returns `true` if the token is still valid and `iat` is beyond the grace buffer.
fn validate(&self) -> bool {
if let Some(claim) = self.most_recent_claim.as_ref() {
let curr_secs = get_current_timestamp();
if claim.iat.abs_diff(curr_secs) * 1000 > self.latency_buffer {
return true;
}
}

false
}

/// Create a new token from the secret.
///
/// Updates the most_recent_claim with the new claim.
fn create_token_from_secret(&mut self) -> Result<String, jsonwebtoken::errors::Error> {
let claims = Claims {
iat: (SystemTime::now().duration_since(UNIX_EPOCH).unwrap() + Duration::from_secs(60))
.as_secs(),
exp: None,
};

self.most_recent_claim = Some(claims);

let token = self.secret.encode(&claims)?;

Ok(format!("Bearer {}", token))
}
}

impl<S, B, ResBody> Service<Request<B>> for AuthService<S>
where
S: Service<hyper::Request<B>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
S::Future: Send,
S::Error: std::error::Error + Send + Sync + 'static,
B: From<Vec<u8>> + Send + 'static + Clone + Sync,
ResBody: hyper::body::Body + Send + 'static,
ResBody::Error: std::error::Error + Send + Sync + 'static,
ResBody::Data: Send,
{
type Response = Response<ResBody>;
type Error = TransportError;
type Future =
Pin<Box<dyn Future<Output = Result<Response<ResBody>, Self::Error>> + Send + 'static>>;

fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(TransportErrorKind::custom)
}

fn call(&mut self, req: Request<B>) -> Self::Future {
let mut req = req;
let res = if self.validate() {
// Encodes the most recent claim into a token.
self.secret.encode(self.most_recent_claim.as_ref().unwrap())
} else {
// Creates a new Claim and encodes it into a token.
self.create_token_from_secret()
};

match res {
Ok(token) => {
req.headers_mut().insert(AUTHORIZATION, HeaderValue::from_str(&token).unwrap());

let mut this = self.clone();

Box::pin(
async move { this.inner.call(req).await.map_err(TransportErrorKind::custom) },
)
}
Err(e) => {
let e = TransportErrorKind::custom(e);
Box::pin(async move { Err(e) })
}
}
}
}
7 changes: 7 additions & 0 deletions crates/transport-http/src/layers/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//! tower http like layer implementations that work over the http::Request type.
#![cfg(all(not(target_arch = "wasm32"), feature = "hyper"))]

#[cfg(feature = "jwt-auth")]
mod auth;
#[cfg(feature = "jwt-auth")]
pub use auth::{AuthLayer, AuthService};
5 changes: 5 additions & 0 deletions crates/transport-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ pub use hyper;
#[cfg(all(not(target_arch = "wasm32"), feature = "hyper"))]
pub use hyper_util;

#[cfg(all(not(target_arch = "wasm32"), feature = "hyper", feature = "jwt-auth"))]
mod layers;
#[cfg(all(not(target_arch = "wasm32"), feature = "hyper", feature = "jwt-auth"))]
pub use layers::{AuthLayer, AuthService};

#[cfg(all(not(target_arch = "wasm32"), feature = "hyper"))]
mod hyper_transport;
#[cfg(all(not(target_arch = "wasm32"), feature = "hyper"))]
Expand Down

0 comments on commit 4009f53

Please sign in to comment.