feat: prefill chunking #2600

Open · wants to merge 29 commits into main

Commits (29)
7169cba  wip (OlivierDehaene, Sep 20, 2024)
de043b5  rollback (OlivierDehaene, Sep 25, 2024)
838756e  refactor to use prefix/postfix namming + fix all_input_ids_tensor (OlivierDehaene, Sep 25, 2024)
e4f9110  maybe patching vlms? (OlivierDehaene, Sep 25, 2024)
a85f5eb  fix filter and concat (OlivierDehaene, Sep 25, 2024)
962ccfd  wip, no filter, no concat (OlivierDehaene, Sep 26, 2024)
0e31619  current (OlivierDehaene, Sep 30, 2024)
173bc99  add prepare_for_prefill (OlivierDehaene, Sep 30, 2024)
34f5dc5  working (OlivierDehaene, Oct 1, 2024)
7f9abde  load tested (OlivierDehaene, Oct 2, 2024)
4db5e7d  re-create slots (OlivierDehaene, Oct 2, 2024)
b49978f  re-create slots (OlivierDehaene, Oct 2, 2024)
ff4155d  fix slot_filtering_indices (OlivierDehaene, Oct 2, 2024)
c8a033b  feedback loop (OlivierDehaene, Oct 7, 2024)
4ddea01  remove log (OlivierDehaene, Oct 7, 2024)
460e830  fix benchmarker (OlivierDehaene, Oct 7, 2024)
8188dea  fix vlm and seq2seq (OlivierDehaene, Oct 7, 2024)
3924b87  rename to cache and input lengths (OlivierDehaene, Oct 7, 2024)
ea4b739  fix prefill logprobs (OlivierDehaene, Oct 7, 2024)
08953c5  fix launcher (OlivierDehaene, Oct 8, 2024)
3ace1b2  fix logprobs? (OlivierDehaene, Oct 9, 2024)
57f55fe  idk at this point (OlivierDehaene, Oct 9, 2024)
d73c5c6  max input length (OlivierDehaene, Oct 9, 2024)
d361197  omfg (OlivierDehaene, Oct 9, 2024)
f85a308  remove debugging lines (OlivierDehaene, Oct 9, 2024)
b7a1280  fix tests (OlivierDehaene, Oct 10, 2024)
f923a3f  fix mllama (OlivierDehaene, Oct 10, 2024)
df98299  fix cargo tests (OlivierDehaene, Oct 10, 2024)
5e70158  remove support chunking for paged (OlivierDehaene, Oct 11, 2024)
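
Before the file diffs, a minimal standalone sketch of the idea the PR implements may help orient the reader: instead of prefilling a long prompt in a single forward pass, the prompt is split into chunks that each fit within a prefill token budget. Every chunk is described by how many tokens are already in the KV cache and how many new tokens it adds, mirroring the cache_len and chunk_len fields introduced in the diffs below. The helper function and the numbers are illustrative assumptions, not code from this repository.

// Hypothetical illustration (not TGI code): split a prompt's prefill into
// chunks that each fit within a prefill token budget. Each chunk is described
// by (cache_len, chunk_len): tokens already in the KV cache before the chunk,
// and new tokens prefilled by the chunk.
fn chunk_prompt(prompt_len: usize, prefill_token_budget: usize) -> Vec<(usize, usize)> {
    let mut chunks = Vec::new();
    let mut cache_len = 0;
    while cache_len < prompt_len {
        let chunk_len = (prompt_len - cache_len).min(prefill_token_budget);
        chunks.push((cache_len, chunk_len));
        cache_len += chunk_len;
    }
    chunks
}

fn main() {
    // A 10_000-token prompt with a 4_096-token budget runs as three prefill
    // passes instead of one oversized pass.
    assert_eq!(
        chunk_prompt(10_000, 4_096),
        vec![(0, 4_096), (4_096, 4_096), (8_192, 1_808)]
    );
}
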
10 changes: 8 additions & 2 deletions backends/client/src/v3/client.rs
@@ -158,7 +158,8 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
prefix_len: 0,
cache_len: 0,
chunk_len: None,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
@@ -217,8 +218,13 @@ impl Client {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
8 changes: 5 additions & 3 deletions backends/client/src/v3/sharded_client.rs
@@ -134,11 +134,12 @@ impl ShardedClient {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@@ -245,7 +246,8 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
prefix_len: 0,
cache_len: 0,
chunk_len: None,
adapter_id: None,
};
let batch = Batch {
@@ -255,7 +257,7 @@
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch).await?;
self.clone().prefill(batch, None).await?;
Ok(())
}
}
20 changes: 8 additions & 12 deletions backends/v2/src/backend.rs
Expand Up @@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
@@ -36,18 +36,14 @@ impl BackendV2 {
speculate: u32,
) -> Self {
// Infer shared state
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else {
Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
} else {
16
let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
let block_size = match attention.as_str() {
"flashinfer" => 1,
"flashdecoding" => 256,
"paged" => 16,
_ => unreachable!(),
};

let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new());

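
The hunk above replaces the parsed Attention enum with a plain string match on the ATTENTION environment variable. Below is a standalone sketch of the same mapping; returning an error for unknown values, instead of the unreachable!() panic, is an illustrative choice and not what the PR does.

// Standalone sketch of the block-size selection above. The mapping matches the
// diff ("flashinfer" => 1, "flashdecoding" => 256, "paged" => 16); the graceful
// error for unknown values is an illustrative choice only.
fn block_size_for(attention: &str) -> Result<u32, String> {
    match attention {
        "flashinfer" => Ok(1),
        "flashdecoding" => Ok(256),
        "paged" => Ok(16),
        other => Err(format!("unsupported ATTENTION value: `{other}`")),
    }
}

fn main() {
    // Default to "paged" when the variable is unset, as in the diff.
    let attention = std::env::var("ATTENTION").unwrap_or_else(|_| "paged".to_string());
    let block_size = block_size_for(&attention).expect("invalid ATTENTION");
    println!("attention={attention} block_size={block_size}");
}
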
134 changes: 84 additions & 50 deletions backends/v3/src/backend.rs
@@ -1,12 +1,14 @@
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic
use crate::client::{
Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
};
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
@@ -31,27 +33,22 @@ impl BackendV3 {
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
shard_info: InfoResponse,
) -> Self {
let prefix_caching =
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
let attention: String = std::env::var("ATTENTION").expect("attention env var");
if shard_info.support_chunking {
tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
}

let attention: Attention = attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
let block_size = attention.block_size();
let block_size = shard_info.block_size;

let queue = Queue::new(
requires_padding,
shard_info.requires_padding,
block_size,
prefix_caching,
window_size,
speculate,
shard_info.use_prefix_caching,
shard_info.window_size,
shard_info.speculate,
max_batch_total_tokens,
shard_info.support_chunking,
);
let batching_task_notifier = Arc::new(Notify::new());

@@ -63,6 +60,7 @@ impl BackendV3 {
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.support_chunking,
queue.clone(),
batching_task_notifier.clone(),
));
@@ -127,6 +125,7 @@ pub(crate) async fn batching_task(
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
support_chunking: bool,
queue: Queue,
notifier: Arc<Notify>,
) {
@@ -147,7 +146,7 @@ pub(crate) async fn batching_task(
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
@@ -158,60 +157,90 @@
// Get current batch info
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let current_tokens = batch.current_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);

let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

let (min_size, max_size, prefill_token_budget) = if support_chunking {
// Since the next batch will be concatenated with the current batch,
// the current batch tokens must be subtracted from the prefill budget
let prefill_token_budget =
max_batch_prefill_tokens.saturating_sub(current_tokens);
// We can ignore min_size and max_size
// Models that rely on max_size cannot support chunking
// Regarding min_size, chunking allows us to consistently run at the compute
// bound, making min_size useless.
(None, None, prefill_token_budget)
} else {
// Minimum batch size
// TODO: temporarily disable to avoid incorrect deallocation +
// reallocation when using prefix caching.
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
} else {
// Minimum batch size
// TODO: temporarily disable to avoid incorrect deallocation +
// reallocation when using prefix caching.
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};

let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size =
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
let max_size =
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));

(min_size, max_size, max_batch_prefill_tokens)
};

// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
if let Some((new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
let counter = if support_chunking {
metrics::counter!("tgi_batch_concat", "reason" => "chunking")
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
};
counter.increment(1);
}

entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
let cached_batch = if support_chunking {
// Concat current batch to the new one
batches.pop()
} else {
// Requests are waiting only if we don't support chunking
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
None
};
entries.extend(new_entries);

// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
.instrument(span)
.await;
let new_cached_batch =
prefill(&mut client, new_batch, cached_batch, &mut entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
} else if support_chunking {
// New cached batch is empty, no work left
break;
}
}

@@ -244,13 +273,14 @@ pub(crate) async fn batching_task(
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
cached_batch: Option<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);

match client.prefill(batch).await {
match client.prefill(batch, cached_batch).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
@@ -259,6 +289,10 @@ async fn prefill(
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;

if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
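
The core scheduling change in this file is how the limits for the next batch are chosen. A standalone sketch of that selection, extracted into a pure function with illustrative example values (the function itself is not part of the router's API):

// Illustrative extraction of the limit selection inside batching_task above.
// With chunking, the tokens already scheduled in the running batch are
// subtracted from the prefill budget, and the min/max batch sizes are ignored.
fn next_batch_limits(
    support_chunking: bool,
    max_batch_prefill_tokens: u32,
    current_tokens: u32,
    batch_size: u32,
    waiting_served_ratio: f32,
    waiting_tokens: usize,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
) -> (Option<usize>, Option<usize>, u32) {
    if support_chunking {
        let prefill_token_budget = max_batch_prefill_tokens.saturating_sub(current_tokens);
        (None, None, prefill_token_budget)
    } else {
        let min_size = if waiting_tokens >= max_waiting_tokens {
            None
        } else {
            Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
        };
        let max_size = max_batch_size.map(|m| m.saturating_sub(batch_size as usize));
        (min_size, max_size, max_batch_prefill_tokens)
    }
}

fn main() {
    // With chunking enabled, a running batch that already holds 2_048 prompt
    // tokens leaves 2_048 of a 4_096-token prefill budget for the next chunk.
    assert_eq!(
        next_batch_limits(true, 4_096, 2_048, 8, 0.3, 1, 20, None),
        (None, None, 2_048)
    );
}
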
21 changes: 17 additions & 4 deletions backends/v3/src/client/grpc_client.rs
@@ -158,7 +158,8 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
prefix_len: 0,
cache_len: 0,
chunk_len: None,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
@@ -217,13 +218,23 @@ impl Client {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
PrefillTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}

@@ -252,14 +263,16 @@ impl Client {
}

pub struct PrefillTimings {
pub concat: Option<Duration>,
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}

impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(Duration::from_nanos),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
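
The timings change above makes the concat duration optional: it is only reported when a cached batch was concatenated during prefill. A small standalone sketch of the same pattern, with example nanosecond values that are purely illustrative:

use std::time::Duration;

// Mirrors the PrefillTimings change above: concat_ns is optional because a
// concat only happens when a cached batch was passed along with the prefill,
// so callers record the corresponding histogram only when it is present.
struct PrefillTimings {
    concat: Option<Duration>,
    forward: Duration,
    decode: Duration,
    total: Duration,
}

impl PrefillTimings {
    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            concat: concat_ns.map(Duration::from_nanos),
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}

fn main() {
    let timings = PrefillTimings::new(Some(1_500_000), 12_000_000, 3_000_000, 16_500_000);
    if let Some(concat) = timings.concat {
        println!("concat: {concat:?}");
    }
    println!(
        "forward: {:?}, decode: {:?}, total: {:?}",
        timings.forward, timings.decode, timings.total
    );
}
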
9 changes: 0 additions & 9 deletions backends/v3/src/client/mod.rs
@@ -29,15 +29,6 @@ pub trait Health {
async fn model_health(&self) -> Result<()>;
}

#[derive(Debug)]
pub struct ShardInfo {
pub requires_padding: bool,
pub dtype: String,
pub device_type: String,
pub window_size: Option<u32>,
pub speculate: u32,
}

#[derive(Error, Debug, Clone)]
pub enum ClientError {
#[error("Could not connect to Text Generation server: {0}")]