diff --git a/Cargo.lock b/Cargo.lock index 9fed9832af..6b867412b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2315,6 +2315,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util 0.7.10", + "tower", "tower-http", "unindent", ] @@ -3555,6 +3556,7 @@ dependencies = [ "pin-project", "pin-project-lite", "tokio", + "tokio-util 0.7.10", "tower-layer", "tower-service", "tracing", diff --git a/Cargo.toml b/Cargo.toml index d8f9cad265..cfeeb941e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,6 +71,7 @@ tower-http = { version = "0.4.0", features = ["compression-br", "compression-gzi rdkafka = { version = "0.33.2" } axum-jrpc = { version = "0.5.1", features = ["serde_json", "anyhow_error"] } ord-kafka-macros = { path = "ord-kafka-macros" } +tower = { version = "0.4.13", features = ["limit"] } [dev-dependencies] diff --git a/src/subcommand/server.rs b/src/subcommand/server.rs index de4cdba68f..691ba94cf7 100644 --- a/src/subcommand/server.rs +++ b/src/subcommand/server.rs @@ -41,6 +41,7 @@ use { }, std::{cmp::Ordering, str, sync::Arc}, tokio_stream::StreamExt, + tower::limit::concurrency::ConcurrencyLimitLayer, tower_http::{ compression::CompressionLayer, cors::{Any, CorsLayer}, @@ -172,11 +173,17 @@ pub(crate) struct Server { help = "Timeout requests after seconds. Default: 30 seconds." )] timeout: Option, + #[clap(long, help = "Set max concurrent connections. Default: 1024")] + max_connections: Option, } impl Server { pub(crate) fn run(self, options: Options, index: Arc, handle: Handle) -> SubcommandResult { Runtime::new()?.block_on(async { + log::debug!( + "Starting server with {} max connections", + self.max_connections.unwrap_or(1024) + ); let index_clone = index.clone(); let index_thread = thread::spawn(move || loop { @@ -276,7 +283,14 @@ impl Server { .route("/tx/:txid", get(Self::transaction)) // API routes - .route("/rpc/v1", post(rpc::handler)) + .route("/rpc/v1", post(rpc::handler) + .route_layer(TimeoutLayer::new(Duration::from_secs(self.timeout.unwrap_or(30)))) + .route_layer( + ConcurrencyLimitLayer::new( + self.max_connections.unwrap_or(1024), + ) + ) + ) .layer(axum::middleware::from_fn(middleware::tracing_layer)) .layer(Extension(index)) .layer(Extension(page_config)) @@ -295,7 +309,6 @@ impl Server { .allow_origin(Any), ) .layer(CompressionLayer::new()) - .layer(TimeoutLayer::new(Duration::from_secs(self.timeout.unwrap_or(30)))) .with_state(server_config); match (self.http_port(), self.https_port()) {