feat: add payload limit (#2726)
* feat: add payload limit

* update launcher
OlivierDehaene authored Nov 21, 2024
1 parent d5bc6a2 commit ab7ccf5
Showing 8 changed files with 45 additions and 9 deletions.
5 changes: 5 additions & 0 deletions backends/trtllm/src/main.rs
@@ -62,6 +62,8 @@ struct Args {
executor_worker: PathBuf,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}

async fn get_tokenizer(
@@ -217,6 +219,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
auth_token,
executor_worker,
usage_stats,
payload_limit,
} = args;

// Launch Tokio runtime
@@ -287,6 +290,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
tokenizer_name,
tokenizer_config_path,
revision,
false,
hostname,
port,
cors_allow_origin,
@@ -296,6 +300,7 @@
true,
max_client_batch_size,
usage_stats,
payload_limit,
)
.await?;
Ok(())
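The new `payload_limit` field follows the same clap pattern as the other router arguments: `long` derives the `--payload-limit` flag from the field name and `env` derives the `PAYLOAD_LIMIT` environment variable. A minimal standalone sketch of that pattern (not the actual backend code; assumes clap 4 with the `derive` and `env` features enabled):

```rust
use clap::Parser;

/// Hypothetical CLI showing only the new argument.
#[derive(Parser, Debug)]
struct Args {
    /// Payload size limit in bytes (default is 2MB)
    #[clap(default_value = "2000000", long, env)]
    payload_limit: usize,
}

fn main() {
    // Both `--payload-limit 4000000` and `PAYLOAD_LIMIT=4000000` override the default.
    let args = Args::parse();
    println!("payload_limit = {}", args.payload_limit);
}
```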
4 changes: 4 additions & 0 deletions backends/v2/src/main.rs
@@ -70,6 +70,8 @@ struct Args {
max_client_batch_size: usize,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}

#[derive(Debug, Subcommand)]
@@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
max_client_batch_size,
usage_stats,
payload_limit,
} = args;

if let Some(Commands::PrintSchema) = command {
@@ -194,6 +197,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
max_client_batch_size,
usage_stats,
payload_limit,
)
.await?;
Ok(())
4 changes: 4 additions & 0 deletions backends/v3/src/main.rs
@@ -70,6 +70,8 @@ struct Args {
max_client_batch_size: usize,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}

#[derive(Debug, Subcommand)]
@@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
max_client_batch_size,
usage_stats,
payload_limit,
} = args;

if let Some(Commands::PrintSchema) = command {
@@ -210,6 +213,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
max_client_batch_size,
usage_stats,
payload_limit,
)
.await?;
Ok(())
11 changes: 11 additions & 0 deletions docs/source/reference/launcher.md
@@ -456,6 +456,17 @@ Options:
- off: Disables all collection of usage statistics
- no-stack: Doesn't send the error stack trace or error type, but allows sending a crash event

```
## PAYLOAD_LIMIT
```shell
--payload-limit <PAYLOAD_LIMIT>
Payload size limit in bytes

Default is 2MB

[env: PAYLOAD_LIMIT=]
[default: 2000000]

```
## HELP
```shell
8 changes: 8 additions & 0 deletions launcher/src/main.rs
@@ -692,6 +692,12 @@ struct Args {
/// Default is on.
#[clap(default_value = "on", long, env)]
usage_stats: UsageStatsLevel,

/// Payload size limit in bytes
///
/// Default is 2MB
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}

#[derive(Debug)]
@@ -1479,6 +1485,8 @@ fn spawn_webserver(
format!("{}-0", args.shard_uds_path),
"--tokenizer-name".to_string(),
args.model_id,
"--payload-limit".to_string(),
args.payload_limit.to_string(),
];
if let Some(max_input_tokens) = max_input_tokens {
router_args.extend_from_slice(&[
6 changes: 5 additions & 1 deletion router/src/server.rs
@@ -30,7 +30,7 @@ use crate::{
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
use axum::extract::{DefaultBodyLimit, Extension};
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
@@ -1674,6 +1674,7 @@ pub async fn run(
disable_grammar_support: bool,
max_client_batch_size: usize,
usage_stats_level: usage_stats::UsageStatsLevel,
payload_limit: usize,
) -> Result<(), WebServerError> {
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
@@ -1928,6 +1929,7 @@
model_info,
compat_return_full_text,
allow_origin,
payload_limit,
)
.await;

@@ -1987,6 +1989,7 @@ async fn start(
model_info: HubModelInfo,
compat_return_full_text: bool,
allow_origin: Option<AllowOrigin>,
payload_limit: usize,
) -> Result<(), WebServerError> {
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
@@ -2384,6 +2387,7 @@ async fn start(
.layer(Extension(compute_type))
.layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default())
.layer(DefaultBodyLimit::max(payload_limit))
.layer(cors_layer);

tracing::info!("Connected");
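`DefaultBodyLimit::max` is axum's built-in request-body cap; without an explicit layer axum defaults to roughly 2 MB, and any request whose body exceeds the configured value is rejected with `413 Payload Too Large` before the handler runs. A minimal sketch of the same layering outside of TGI (assumes axum 0.7 with tokio; `app`, `generate`, and the port are illustrative names, not the router's real ones):

```rust
use axum::{extract::DefaultBodyLimit, routing::post, Router};

// Toy handler standing in for the real generate route.
async fn generate(body: String) -> String {
    format!("received {} bytes", body.len())
}

fn app(payload_limit: usize) -> Router {
    Router::new()
        .route("/generate", post(generate))
        // Replaces axum's built-in ~2MB default with the configured limit;
        // oversized bodies are rejected with 413 before reaching the handler.
        .layer(DefaultBodyLimit::max(payload_limit))
}

#[tokio::main]
async fn main() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap();
    axum::serve(listener, app(2_000_000)).await.unwrap();
}
```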
11 changes: 5 additions & 6 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -962,9 +962,9 @@ def prepare_for_prefill(self):
self.input_lengths_tensor = torch.tensor(
self.input_lengths, dtype=torch.int32, device=device
)
self.cu_seqlen_prefill = torch.nn.functional.pad(
torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
).to(torch.int32)
cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)
torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
self.cache_lengths_tensor = torch.tensor(
self.cache_lengths, dtype=torch.int32, device=device
)
@@ -2020,9 +2020,8 @@ def generate_token(

# For each member of the batch
# Cumulative length
cu_accepted_ids = torch.nn.functional.pad(
torch.cumsum(accepted_ids, dim=0), (1, 0)
)
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
cumulative_length = 0
for i, (
request,
5 changes: 3 additions & 2 deletions server/text_generation_server/models/metadata_kernels.py
@@ -66,8 +66,9 @@ def block_tables_to_ragged(
)

if has_triton():
cu_seqlen = torch.nn.functional.pad(
torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)
cu_seqlen = input_lengths_tensor.new_zeros(input_lengths_tensor.shape[0] + 1)
torch.cumsum(
input_lengths_tensor + cache_lengths_tensor, out=cu_seqlen[1:], dim=0
)

def grid(meta):
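The two Python changes above replace `torch.nn.functional.pad(torch.cumsum(x, dim=0), (1, 0))` with a preallocated zero buffer that `torch.cumsum(..., out=...)` fills in place, so the leading zero comes for free and the extra pad allocation and copy disappear. A language-agnostic sketch of the same exclusive-prefix-sum pattern, written in Rust purely for illustration (`cu_lengths` is a made-up name, not server code):

```rust
// Build cu = [0, x0, x0+x1, ...] in one pass into a preallocated buffer,
// mirroring the `new_zeros` + `cumsum(out=...)` pattern in the diff above.
fn cu_lengths(lengths: &[u32]) -> Vec<u32> {
    let mut cu = vec![0u32; lengths.len() + 1];
    for (i, &len) in lengths.iter().enumerate() {
        cu[i + 1] = cu[i] + len;
    }
    cu
}

fn main() {
    // e.g. input lengths of three prefill requests
    assert_eq!(cu_lengths(&[3, 5, 2]), vec![0, 3, 8, 10]);
}
```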
