Skip to content

Commit

Permalink
Authenticated json rpc endpoint (#5250)
Browse files Browse the repository at this point in the history
Co-authored-by: Matthias Seitz <matthias.seitz@outlook.de>
  • Loading branch information
1010adigupta and mattsse committed Nov 8, 2023
1 parent 3c9633b commit d2136f8
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 8 deletions.
2 changes: 1 addition & 1 deletion crates/rpc/rpc-builder/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ where
pub struct AuthServerConfig {
/// Where the server should listen.
pub(crate) socket_addr: SocketAddr,
/// The secrete for the auth layer of the server.
/// The secret for the auth layer of the server.
pub(crate) secret: JwtSecret,
/// Configs for JSON-RPC Http.
pub(crate) server_config: ServerBuilder,
Expand Down
73 changes: 66 additions & 7 deletions crates/rpc/rpc-builder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ use reth_rpc::{
cache::{cache_new_blocks_task, EthStateCache},
gas_oracle::GasPriceOracle,
},
AdminApi, BlockingTaskGuard, BlockingTaskPool, DebugApi, EngineEthApi, EthApi, EthFilter,
EthPubSub, EthSubscriptionIdProvider, NetApi, OtterscanApi, RPCApi, RethApi, TraceApi,
TxPoolApi, Web3Api,
AdminApi, AuthLayer, BlockingTaskGuard, BlockingTaskPool, DebugApi, EngineEthApi, EthApi,
EthFilter, EthPubSub, EthSubscriptionIdProvider, JwtAuthValidator, JwtSecret, NetApi,
OtterscanApi, RPCApi, RethApi, TraceApi, TxPoolApi, Web3Api,
};
use reth_rpc_api::{servers::*, EngineApiServer};
use reth_tasks::{TaskSpawner, TokioTaskExecutor};
Expand Down Expand Up @@ -1141,6 +1141,8 @@ pub struct RpcServerConfig {
ipc_server_config: Option<IpcServerBuilder>,
/// The Endpoint where to launch the ipc server
ipc_endpoint: Option<Endpoint>,
/// JWT secret for authentication
jwt_secret: Option<JwtSecret>,
}

impl fmt::Debug for RpcServerConfig {
Expand All @@ -1153,6 +1155,7 @@ impl fmt::Debug for RpcServerConfig {
.field("ws_addr", &self.ws_addr)
.field("ipc_server_config", &self.ipc_server_config)
.field("ipc_endpoint", &self.ipc_endpoint.as_ref().map(|endpoint| endpoint.path()))
.field("jwt_secret", &self.jwt_secret)
.finish()
}
}
Expand Down Expand Up @@ -1264,6 +1267,12 @@ impl RpcServerConfig {
self
}

/// Configures the JWT secret for authentication
pub fn with_jwt_secret(mut self, secret: JwtSecret) -> Self {
self.jwt_secret = Some(secret);
self
}

/// Returns true if any server is configured.
///
/// If no server is configured, no server will be be launched on [RpcServerConfig::start].
Expand Down Expand Up @@ -1328,6 +1337,8 @@ impl RpcServerConfig {
}
.cloned();

let secret = self.jwt_secret.take();

// we merge this into one server using the http setup
self.ws_server_config.take();

Expand All @@ -1336,6 +1347,7 @@ impl RpcServerConfig {
builder,
http_socket_addr,
cors,
secret,
ServerKind::WsHttp(http_socket_addr),
metrics.clone(),
)
Expand All @@ -1358,6 +1370,7 @@ impl RpcServerConfig {
builder,
ws_socket_addr,
self.ws_cors_domains.take(),
self.jwt_secret.take(),
ServerKind::WS(ws_socket_addr),
metrics.clone(),
)
Expand All @@ -1372,6 +1385,7 @@ impl RpcServerConfig {
builder,
http_socket_addr,
self.http_cors_domains.take(),
self.jwt_secret.take(),
ServerKind::Http(http_socket_addr),
metrics.clone(),
)
Expand Down Expand Up @@ -1667,6 +1681,12 @@ enum WsHttpServerKind {
Plain(Server<Identity, RpcServerMetrics>),
/// Http server with cors
WithCors(Server<Stack<CorsLayer, Identity>, RpcServerMetrics>),
/// Http server with auth
WithAuth(Server<Stack<AuthLayer<JwtAuthValidator>, Identity>, RpcServerMetrics>),
/// Http server with cors and auth
WithCorsAuth(
Server<Stack<AuthLayer<JwtAuthValidator>, Stack<CorsLayer, Identity>>, RpcServerMetrics>,
),
}

// === impl WsHttpServerKind ===
Expand All @@ -1677,30 +1697,69 @@ impl WsHttpServerKind {
match self {
WsHttpServerKind::Plain(server) => server.start(module),
WsHttpServerKind::WithCors(server) => server.start(module),
WsHttpServerKind::WithAuth(server) => server.start(module),
WsHttpServerKind::WithCorsAuth(server) => server.start(module),
}
}

/// Builds
/// Builds the server according to the given config parameters.
///
/// Returns the address of the started server.
async fn build(
builder: ServerBuilder,
socket_addr: SocketAddr,
cors_domains: Option<String>,
auth_secret: Option<JwtSecret>,
server_kind: ServerKind,
metrics: RpcServerMetrics,
) -> Result<(Self, SocketAddr), RpcError> {
if let Some(cors) = cors_domains.as_deref().map(cors::create_cors_layer) {
let cors = cors.map_err(|err| RpcError::Custom(err.to_string()))?;
let middleware = tower::ServiceBuilder::new().layer(cors);

if let Some(secret) = auth_secret {
// stack cors and auth layers
let middleware = tower::ServiceBuilder::new()
.layer(cors)
.layer(AuthLayer::new(JwtAuthValidator::new(secret.clone())));

let server = builder
.set_middleware(middleware)
.set_logger(metrics)
.build(socket_addr)
.await
.map_err(|err| RpcError::from_jsonrpsee_error(err, server_kind))?;
let local_addr = server.local_addr()?;
let server = WsHttpServerKind::WithCorsAuth(server);
Ok((server, local_addr))
} else {
let middleware = tower::ServiceBuilder::new().layer(cors);
let server = builder
.set_middleware(middleware)
.set_logger(metrics)
.build(socket_addr)
.await
.map_err(|err| RpcError::from_jsonrpsee_error(err, server_kind))?;
let local_addr = server.local_addr()?;
let server = WsHttpServerKind::WithCors(server);
Ok((server, local_addr))
}
} else if let Some(secret) = auth_secret {
// jwt auth layered service
let middleware = tower::ServiceBuilder::new()
.layer(AuthLayer::new(JwtAuthValidator::new(secret.clone())));
let server = builder
.set_middleware(middleware)
.set_logger(metrics)
.build(socket_addr)
.await
.map_err(|err| RpcError::from_jsonrpsee_error(err, server_kind))?;
.map_err(|err| {
RpcError::from_jsonrpsee_error(err, ServerKind::Auth(socket_addr))
})?;
let local_addr = server.local_addr()?;
let server = WsHttpServerKind::WithCors(server);
let server = WsHttpServerKind::WithAuth(server);
Ok((server, local_addr))
} else {
// plain server without any middleware
let server = builder
.set_logger(metrics)
.build(socket_addr)
Expand Down

0 comments on commit d2136f8

Please sign in to comment.