diff --git a/Cargo.toml b/Cargo.toml index 0c1265aa6bc..3217430bd9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -139,6 +139,7 @@ thiserror-no-std = "2.0.2" url = "2.5" derive_more = { version = "1.0.0", default-features = false } http = "1.1.0" +jsonwebtoken = "9.3.0" ## serde serde = { version = "1.0", default-features = false, features = [ diff --git a/crates/provider/Cargo.toml b/crates/provider/Cargo.toml index 12444152f42..0b8587ddeb1 100644 --- a/crates/provider/Cargo.toml +++ b/crates/provider/Cargo.toml @@ -68,7 +68,7 @@ alloy-rlp.workspace = true alloy-sol-types.workspace = true alloy-signer.workspace = true alloy-signer-local.workspace = true -alloy-transport-http = { workspace = true, features = ["reqwest"] } +alloy-transport-http = { workspace = true, features = ["reqwest", "jwt-auth"] } itertools.workspace = true reqwest.workspace = true @@ -82,6 +82,7 @@ tower-http = { workspace = true, features = [ ] } http-body-util.workspace = true http.workspace = true +alloy-rpc-types-engine = { workspace = true, features = ["jwt"] } [features] default = ["reqwest", "reqwest-default-tls"] diff --git a/crates/provider/src/provider/trait.rs b/crates/provider/src/provider/trait.rs index 9956b15ca9c..df6b84dc952 100644 --- a/crates/provider/src/provider/trait.rs +++ b/crates/provider/src/provider/trait.rs @@ -1186,6 +1186,43 @@ mod tests { assert_eq!(0, num); } + #[cfg(all(feature = "hyper", not(windows)))] + #[tokio::test] + async fn test_auth_layer_transport() { + use alloy_node_bindings::Reth; + use alloy_rpc_types_engine::JwtSecret; + use alloy_transport_http::{AuthLayer, AuthService, Http, HyperClient}; + + init_tracing(); + let secret = JwtSecret::random(); + + let reth = Reth::new().arg("--rpc.jwtsecret").arg(hex::encode(secret.as_bytes())).spawn(); + + let hyper_client = Client::builder(TokioExecutor::new()).build_http::>(); + + let service = + tower::ServiceBuilder::new().layer(AuthLayer::new(secret)).service(hyper_client); + + let layer_transport: HyperClient< + Full, + AuthService< + Client< + alloy_transport_http::hyper_util::client::legacy::connect::HttpConnector, + Full, + >, + >, + > = HyperClient::with_service(service); + + let http_hyper = Http::with_client(layer_transport, reth.endpoint_url()); + + let rpc_client = alloy_rpc_client::RpcClient::new(http_hyper, true); + + let provider = RootProvider::<_, Ethereum>::new(rpc_client); + + let num = provider.get_block_number().await.unwrap(); + assert_eq!(0, num); + } + #[tokio::test] async fn test_builder_helper_fn_any_network() { init_tracing(); diff --git a/crates/rpc-types-engine/Cargo.toml b/crates/rpc-types-engine/Cargo.toml index d4b68d91764..9b58bc0973f 100644 --- a/crates/rpc-types-engine/Cargo.toml +++ b/crates/rpc-types-engine/Cargo.toml @@ -41,7 +41,7 @@ jsonrpsee-types = { version = "0.24", optional = true } # jwt rand = { workspace = true, optional = true } -jsonwebtoken = { version = "9.3.0", optional = true } +jsonwebtoken = { workspace = true, optional = true } [features] default = ["jwt", "std", "serde"] diff --git a/crates/rpc-types-engine/src/jwt.rs b/crates/rpc-types-engine/src/jwt.rs index 768ee132da1..b9b8e68862e 100644 --- a/crates/rpc-types-engine/src/jwt.rs +++ b/crates/rpc-types-engine/src/jwt.rs @@ -258,6 +258,11 @@ impl JwtSecret { let algo = jsonwebtoken::Header::new(Algorithm::HS256); jsonwebtoken::encode(&algo, claims, &key) } + + /// Returns the secret key as a byte slice. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } } impl core::fmt::Debug for JwtSecret { diff --git a/crates/transport-http/Cargo.toml b/crates/transport-http/Cargo.toml index 35f757ad3c2..8ed30d905f6 100644 --- a/crates/transport-http/Cargo.toml +++ b/crates/transport-http/Cargo.toml @@ -34,6 +34,10 @@ http-body-util = { workspace = true, optional = true } hyper = { workspace = true, default-features = false, optional = true } hyper-util = { workspace = true, features = ["full"], optional = true } +# auth layer +alloy-rpc-types-engine = { workspace = true, optional = true } +jsonwebtoken = { workspace = true, optional = true } + [features] default = ["reqwest", "reqwest-default-tls"] reqwest = [ @@ -52,6 +56,13 @@ hyper = [ "dep:tower", "dep:tracing", ] +jwt-auth = [ + "hyper", + "dep:alloy-rpc-types-engine", + "alloy-rpc-types-engine/jwt", + "alloy-rpc-types-engine/serde", + "dep:jsonwebtoken", +] reqwest-default-tls = ["reqwest?/default-tls"] reqwest-native-tls = ["reqwest?/native-tls"] reqwest-rustls-tls = ["reqwest?/rustls-tls"] diff --git a/crates/transport-http/src/hyper_transport.rs b/crates/transport-http/src/hyper_transport.rs index 389d0289c03..6346f66a023 100644 --- a/crates/transport-http/src/hyper_transport.rs +++ b/crates/transport-http/src/hyper_transport.rs @@ -69,12 +69,15 @@ impl HyperClient { } } -impl Http> +impl Http> where - S: Service, Response = HyperResponse> + Clone + Send + Sync + 'static, + S: Service, Response = Response> + Clone + Send + Sync + 'static, S::Future: Send, S::Error: std::error::Error + Send + Sync + 'static, B: From> + Send + 'static + Clone, + ResBody: BodyExt + Send + 'static, + ResBody::Error: std::error::Error + Send + Sync + 'static, + ResBody::Data: Send, { /// Make a request to the server using the given service. fn request_hyper(&self, req: RequestPacket) -> TransportFut<'static> { diff --git a/crates/transport-http/src/layers/auth.rs b/crates/transport-http/src/layers/auth.rs new file mode 100644 index 00000000000..bf68757bbc1 --- /dev/null +++ b/crates/transport-http/src/layers/auth.rs @@ -0,0 +1,142 @@ +use crate::hyper::{header::AUTHORIZATION, Request, Response}; +use alloy_rpc_types_engine::{Claims, JwtSecret}; +use alloy_transport::{TransportError, TransportErrorKind}; +use hyper::header::HeaderValue; +use jsonwebtoken::get_current_timestamp; +use std::{ + future::Future, + pin::Pin, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; +use tower::{Layer, Service}; + +/// The [`AuthLayer`] uses the provided [`JwtSecret`] to generate and validate the jwt token +/// in the requests. +/// +/// The generated token is inserted into the [`AUTHORIZATION`] header of the request. +#[derive(Clone, Debug)] +pub struct AuthLayer { + secret: JwtSecret, + latency_buffer: u64, +} + +impl AuthLayer { + /// Create a new [`AuthLayer`]. + pub const fn new(secret: JwtSecret) -> Self { + Self { secret, latency_buffer: 5000 } + } + + /// We use this buffer to perfom an extra check on the `iat` field to prevent sending any + /// requests with tokens that are valid now but may not be upon reaching the server. + /// + /// In milliseconds. Default is 5s. + pub const fn with_latency_buffer(self, latency_buffer: u64) -> Self { + Self { latency_buffer, ..self } + } +} + +impl Layer for AuthLayer { + type Service = AuthService; + + fn layer(&self, inner: S) -> Self::Service { + AuthService::new(inner, self.secret, self.latency_buffer) + } +} + +/// A service that generates and validates the jwt token in the requests using the provided secret. +#[derive(Clone, Debug)] +pub struct AuthService { + inner: S, + secret: JwtSecret, + /// In milliseconds. + latency_buffer: u64, + most_recent_claim: Option, +} + +impl AuthService { + /// Create a new [`AuthService`] with the given inner service. + pub const fn new(inner: S, secret: JwtSecret, latency_buffer: u64) -> Self { + Self { inner, secret, latency_buffer, most_recent_claim: None } + } + + /// Validate the token in the request headers. + /// + /// Returns `true` if the token is still valid and `iat` is beyond the grace buffer. + fn validate(&self) -> bool { + if let Some(claim) = self.most_recent_claim.as_ref() { + let curr_secs = get_current_timestamp(); + if claim.iat.abs_diff(curr_secs) * 1000 > self.latency_buffer { + return true; + } + } + + false + } + + /// Create a new token from the secret. + /// + /// Updates the most_recent_claim with the new claim. + fn create_token_from_secret(&mut self) -> Result { + let claims = Claims { + iat: (SystemTime::now().duration_since(UNIX_EPOCH).unwrap() + Duration::from_secs(60)) + .as_secs(), + exp: None, + }; + + self.most_recent_claim = Some(claims); + + let token = self.secret.encode(&claims)?; + + Ok(format!("Bearer {}", token)) + } +} + +impl Service> for AuthService +where + S: Service, Response = Response> + Clone + Send + Sync + 'static, + S::Future: Send, + S::Error: std::error::Error + Send + Sync + 'static, + B: From> + Send + 'static + Clone + Sync, + ResBody: hyper::body::Body + Send + 'static, + ResBody::Error: std::error::Error + Send + Sync + 'static, + ResBody::Data: Send, +{ + type Response = Response; + type Error = TransportError; + type Future = + Pin, Self::Error>> + Send + 'static>>; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx).map_err(TransportErrorKind::custom) + } + + fn call(&mut self, req: Request) -> Self::Future { + let mut req = req; + let res = if self.validate() { + // Encodes the most recent claim into a token. + self.secret.encode(self.most_recent_claim.as_ref().unwrap()) + } else { + // Creates a new Claim and encodes it into a token. + self.create_token_from_secret() + }; + + match res { + Ok(token) => { + req.headers_mut().insert(AUTHORIZATION, HeaderValue::from_str(&token).unwrap()); + + let mut this = self.clone(); + + Box::pin( + async move { this.inner.call(req).await.map_err(TransportErrorKind::custom) }, + ) + } + Err(e) => { + let e = TransportErrorKind::custom(e); + Box::pin(async move { Err(e) }) + } + } + } +} diff --git a/crates/transport-http/src/layers/mod.rs b/crates/transport-http/src/layers/mod.rs new file mode 100644 index 00000000000..26f80220765 --- /dev/null +++ b/crates/transport-http/src/layers/mod.rs @@ -0,0 +1,7 @@ +//! tower http like layer implementations that work over the http::Request type. +#![cfg(all(not(target_arch = "wasm32"), feature = "hyper"))] + +#[cfg(feature = "jwt-auth")] +mod auth; +#[cfg(feature = "jwt-auth")] +pub use auth::{AuthLayer, AuthService}; diff --git a/crates/transport-http/src/lib.rs b/crates/transport-http/src/lib.rs index d1881fefaae..d81685399c3 100644 --- a/crates/transport-http/src/lib.rs +++ b/crates/transport-http/src/lib.rs @@ -20,6 +20,11 @@ pub use hyper; #[cfg(all(not(target_arch = "wasm32"), feature = "hyper"))] pub use hyper_util; +#[cfg(all(not(target_arch = "wasm32"), feature = "hyper", feature = "jwt-auth"))] +mod layers; +#[cfg(all(not(target_arch = "wasm32"), feature = "hyper", feature = "jwt-auth"))] +pub use layers::{AuthLayer, AuthService}; + #[cfg(all(not(target_arch = "wasm32"), feature = "hyper"))] mod hyper_transport; #[cfg(all(not(target_arch = "wasm32"), feature = "hyper"))]