
Commit

feedback loop
OlivierDehaene committed Oct 7, 2024
1 parent 7268e11 commit bdbc652
Showing 16 changed files with 153 additions and 163 deletions.
13 changes: 7 additions & 6 deletions backends/client/src/v3/client.rs
@@ -218,8 +218,13 @@ impl Client {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
@@ -237,11 +242,7 @@ impl Client {
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest {
batch: None,
batches,
})
.inject_context();
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
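The v3 client's prefill now takes an optional CachedBatch alongside the new Batch, and decode drops its separate batch argument. Below is a minimal, self-contained sketch of the new request shapes; the stub structs stand in for the tonic-generated protobuf messages and are illustrative only, not the real generated code.

#[derive(Clone, Debug)]
pub struct Batch {
    pub id: u64,
    pub size: u32,
}

#[derive(Clone, Debug)]
pub struct CachedBatch {
    pub id: u64,
}

#[derive(Debug)]
pub struct PrefillRequest {
    // The batch to prefill.
    pub batch: Option<Batch>,
    // The batch already sitting in the kv-cache, if any, that the shard
    // should concatenate with `batch` (prefill chunking).
    pub cached_batch: Option<CachedBatch>,
}

#[derive(Debug)]
pub struct DecodeRequest {
    // Decode no longer carries a separate `batch`; only cached batches remain.
    pub batches: Vec<CachedBatch>,
}

fn main() {
    // First prefill of a fresh batch: nothing is cached yet.
    let prefill = PrefillRequest {
        batch: Some(Batch { id: 0, size: 8 }),
        cached_batch: None,
    };
    let decode = DecodeRequest {
        batches: vec![CachedBatch { id: 0 }],
    };
    println!("{prefill:?}");
    println!("{decode:?}");
}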
5 changes: 3 additions & 2 deletions backends/client/src/v3/sharded_client.rs
@@ -134,11 +134,12 @@ impl ShardedClient {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@@ -256,7 +257,7 @@ impl Health for ShardedClient {
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch).await?;
self.clone().prefill(batch, None).await?;
Ok(())
}
}
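ShardedClient::prefill clones the batch and the optional cached batch into one future per shard and awaits them together. A toy sketch of that fan-out pattern, with a stub async function standing in for Client::prefill (assumes the tokio and futures crates; not the real client types):

use futures::future::join_all;

// Stand-in for Client::prefill: every shard receives its own clone of the inputs.
async fn shard_prefill(shard: usize, batch: String, cached: Option<String>) -> usize {
    let _ = (batch, cached);
    shard
}

#[tokio::main]
async fn main() {
    let batch = "batch-0".to_string();
    let cached: Option<String> = None;
    let futures: Vec<_> = (0..4)
        .map(|shard| Box::pin(shard_prefill(shard, batch.clone(), cached.clone())))
        .collect();
    // All shards run the same work; the router keeps a single merged result.
    let results = join_all(futures).await;
    assert_eq!(results, vec![0, 1, 2, 3]);
}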
20 changes: 8 additions & 12 deletions backends/v2/src/backend.rs
@@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
@@ -36,18 +36,14 @@ impl BackendV2 {
speculate: u32,
) -> Self {
// Infer shared state
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else {
Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
} else {
16
let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
let block_size = match attention.as_str() {
"flashinfer" => 1,
"flashdecoding" => 256,
"paged" => 16,
_ => unreachable!(),
};

let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new());

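The v2 backend now reads ATTENTION as a plain string, defaulting to "paged" when the variable is unset, and maps it to a block size; any other value panics through unreachable!(). A sketch of the same mapping factored into a helper (a hypothetical refactor for illustration, not code from this commit):

/// Block size per attention kind, mirroring the match in BackendV2::new.
fn block_size_for(attention: &str) -> Option<u32> {
    match attention {
        "flashinfer" => Some(1),
        "flashdecoding" => Some(256),
        "paged" => Some(16),
        _ => None, // BackendV2::new hits unreachable!() here instead
    }
}

fn main() {
    let attention = std::env::var("ATTENTION").unwrap_or_else(|_| "paged".to_string());
    let block_size = block_size_for(&attention)
        .unwrap_or_else(|| panic!("Invalid attention was specified: `{attention}`"));
    println!("block_size = {block_size}");
}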
86 changes: 43 additions & 43 deletions backends/v3/src/backend.rs
@@ -1,12 +1,14 @@
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic
use crate::client::{
Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
};
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
@@ -31,32 +33,22 @@ impl BackendV3 {
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
support_chunking: bool,
shard_info: InfoResponse,
) -> Self {
if support_chunking {
if shard_info.support_chunking {
tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
}

let prefix_caching = std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());

let attention: Attention = attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
let block_size = attention.block_size();
let block_size = shard_info.block_size;

let queue = Queue::new(
requires_padding,
shard_info.requires_padding,
block_size,
prefix_caching,
window_size,
speculate,
shard_info.use_prefix_caching,
shard_info.window_size,
shard_info.speculate,
max_batch_total_tokens,
support_chunking,
shard_info.support_chunking,
);
let batching_task_notifier = Arc::new(Notify::new());

@@ -68,7 +60,7 @@ impl BackendV3 {
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
support_chunking,
shard_info.support_chunking,
queue.clone(),
batching_task_notifier.clone(),
));
@@ -154,7 +146,7 @@ pub(crate) async fn batching_task(
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
@@ -175,7 +167,8 @@
let (min_size, max_size, prefill_token_budget) = if support_chunking {
// Since the next batch will be concatenated with the current batch,
// the current batch tokens must be subtracted from the prefill budget
let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
let prefill_token_budget =
max_batch_prefill_tokens.saturating_sub(current_tokens);
// We can ignore min_size and max_size
// Models that rely on max_size cannot support chunking
// Regarding min_size, chunking allows us to consistently run at the compute
@@ -199,10 +192,8 @@
(min_size, max_size, max_batch_prefill_tokens)
};

let mut additional_batch = None;

// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
if let Some((new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
.await
{
@@ -218,11 +209,11 @@
};
counter.increment(1);
}

if support_chunking {
entries.extend(new_entries);
additional_batch = Some(new_batch);
let cached_batch = if support_chunking {
// Concat current batch to the new one
batches.pop()
} else {
// Requests are waiting only if we don't support chunking
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
@@ -233,18 +224,23 @@
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
None
};
entries.extend(new_entries);

// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
prefill(&mut client, new_batch, cached_batch, &mut entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
batches.push(new_cached_batch);
} else if support_chunking {
// New cached batch is empty, no work left
break;
}
}

@@ -262,7 +258,7 @@ pub(crate) async fn batching_task(
entry.temp_span = Some(entry_batch_span);
});

cached_batch = decode(&mut client, additional_batch, batches, &mut entries)
cached_batch = decode(&mut client, batches, &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
@@ -277,13 +273,14 @@ async fn prefill(
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
cached_batch: Option<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);

match client.prefill(batch).await {
match client.prefill(batch, cached_batch).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
@@ -292,6 +289,10 @@ async fn prefill(
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;

if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
@@ -316,15 +317,14 @@
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batch: Option<Batch>,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);

match client.decode(batch, batches).await {
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
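Two details in batching_task are worth calling out. First, when chunking is supported the current cached batch is popped and fed back into prefill so the shard concatenates the running batch with the new one, which appears to be the feedback loop the commit title refers to. Second, the prefill budget now subtracts the tokens already in flight, and saturating_sub clamps the result at zero instead of underflowing. A minimal sketch of that budget computation (hypothetical helper name; logic lifted from the code above):

fn prefill_token_budget(
    max_batch_prefill_tokens: u32,
    current_tokens: u32,
    support_chunking: bool,
) -> u32 {
    if support_chunking {
        // The next batch is concatenated with the current one, so tokens
        // already in flight count against the prefill budget.
        max_batch_prefill_tokens.saturating_sub(current_tokens)
    } else {
        max_batch_prefill_tokens
    }
}

fn main() {
    assert_eq!(prefill_token_budget(4096, 1024, true), 3072);
    // A plain `-` here would underflow and panic in debug builds.
    assert_eq!(prefill_token_budget(4096, 8192, true), 0);
    assert_eq!(prefill_token_budget(4096, 8192, false), 4096);
}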
21 changes: 16 additions & 5 deletions backends/v3/src/client/grpc_client.rs
@@ -218,13 +218,23 @@ impl Client {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
PrefillTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}

@@ -235,10 +245,9 @@ impl Client {
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batch: Option<Batch>,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches, batch }).inject_context();
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
@@ -254,14 +263,16 @@ impl Client {
}

pub struct PrefillTimings {
pub concat: Option<Duration>,
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}

impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(Duration::from_nanos),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
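PrefillTimings gains an optional concat duration: the shard reports concat_ns only when a cached batch was merged during prefill, and the router records the histogram only in that case. A small self-contained sketch of that Option<u64> to Option<Duration> plumbing (trimmed to two fields; println! stands in for the metrics recorder):

use std::time::Duration;

struct PrefillTimings {
    concat: Option<Duration>,
    forward: Duration,
}

impl PrefillTimings {
    fn new(concat_ns: Option<u64>, forward_ns: u64) -> Self {
        Self {
            // None when prefill ran without merging a cached batch.
            concat: concat_ns.map(Duration::from_nanos),
            forward: Duration::from_nanos(forward_ns),
        }
    }
}

fn main() {
    let timings = PrefillTimings::new(Some(1_500_000), 42_000_000);
    if let Some(concat) = timings.concat {
        // Mirrors the conditional histogram recording in backend.rs above.
        println!("concat: {:.6}s", concat.as_secs_f64());
    }
    println!("forward: {:.6}s", timings.forward.as_secs_f64());
}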
8 changes: 4 additions & 4 deletions backends/v3/src/client/sharded_client.rs
@@ -135,11 +135,12 @@ impl ShardedClient {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@@ -167,13 +168,12 @@ impl ShardedClient {
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batch: Option<Batch>,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batch.clone(), batches.clone())))
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
Expand Down Expand Up @@ -246,7 +246,7 @@ impl Health for ShardedClient {
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch).await?;
self.clone().prefill(batch, None).await?;
Ok(())
}
}