feat: track performance #32

Merged
merged 2 commits on Nov 18, 2024
Changes from all commits
59 changes: 55 additions & 4 deletions atoma-proxy/src/server/chat_completions.rs
@@ -1,5 +1,5 @@
use std::sync::Arc;
use std::time::Duration;
use std::time::{Duration, Instant};

use atoma_state::types::AtomaAtomaStateManagerEvent;
use axum::body::Body;
@@ -99,12 +99,19 @@ pub async fn chat_completions_handler(
"include_usage": true
});
}
let (node_address, signature, selected_stack_small_id, headers, estimated_total_tokens) =
authenticate_and_process(&state, headers, &payload).await?;
let (
node_address,
node_id,
signature,
selected_stack_small_id,
headers,
estimated_total_tokens,
) = authenticate_and_process(&state, headers, &payload).await?;
if is_streaming {
handle_streaming_response(
state,
node_address,
node_id,
signature,
selected_stack_small_id,
headers,
@@ -116,6 +123,7 @@ pub async fn chat_completions_handler(
handle_non_streaming_response(
state,
node_address,
node_id,
signature,
selected_stack_small_id,
headers,
@@ -177,16 +185,19 @@ pub async fn chat_completions_handler(
payload_hash
)
)]
#[allow(clippy::too_many_arguments)]
async fn handle_non_streaming_response(
state: ProxyState,
node_address: String,
node_id: i64,
signature: String,
selected_stack_small_id: i64,
headers: HeaderMap,
payload: Value,
estimated_total_tokens: i64,
) -> Result<Response<Body>, StatusCode> {
let client = reqwest::Client::new();
let time = Instant::now();
let response = client
.post(format!("{}{}", node_address, CHAT_COMPLETIONS_PATH))
.headers(headers)
@@ -216,6 +227,20 @@ async fn handle_non_streaming_response(
.map(|n| n as i64)
.unwrap_or(0);

let input_tokens = response
.get("usage")
.and_then(|usage| usage.get("prompt_tokens"))
.and_then(|input_tokens| input_tokens.as_u64())
.map(|n| n as i64)
.unwrap_or(0);

let output_tokens = response
.get("usage")
.and_then(|usage| usage.get("completion_tokens"))
.and_then(|output_tokens| output_tokens.as_u64())
.map(|n| n as i64)
.unwrap_or(0);

// NOTE: We need to update the stack num tokens, because the inference response might have produced
// fewer tokens than we initially estimated in the middleware.
if let Err(e) = utils::update_state_manager(
@@ -230,6 +255,21 @@
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}

state
.state_manager_sender
.send(
AtomaAtomaStateManagerEvent::UpdateNodeThroughputPerformance {
node_small_id: node_id,
input_tokens,
output_tokens,
time: time.elapsed().as_secs_f64(),
},
)
.map_err(|err| {
error!("Error updating node throughput performance: {}", err);
StatusCode::INTERNAL_SERVER_ERROR
})?;

Ok(response.into_response())
}

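The non-streaming handler times the whole request with a single `Instant` and ships raw token counts plus elapsed seconds to the state manager, leaving any rate computation to the consumer of the event. A minimal sketch of how tokens-per-second could be derived from those fields (this helper is illustrative and not part of the PR):

```rust
/// Illustrative only: derive tokens/sec from the fields carried by
/// `UpdateNodeThroughputPerformance`. Field names mirror the event;
/// the helper itself does not exist in the codebase.
fn tokens_per_second(input_tokens: i64, output_tokens: i64, time: f64) -> f64 {
    if time <= 0.0 {
        return 0.0; // guard against division by zero on degenerate timings
    }
    (input_tokens + output_tokens) as f64 / time
}
```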
@@ -278,9 +318,11 @@ async fn handle_non_streaming_response(
payload_hash
)
)]
#[allow(clippy::too_many_arguments)]
async fn handle_streaming_response(
state: ProxyState,
node_address: String,
node_id: i64,
signature: String,
selected_stack_small_id: i64,
headers: HeaderMap,
@@ -292,6 +334,7 @@ async fn handle_streaming_response(
// that were processed for this request.

let client = reqwest::Client::new();
let start = Instant::now();
let response = client
.post(format!("{}{}", node_address, CHAT_COMPLETIONS_PATH))
.headers(headers)
@@ -319,6 +362,8 @@ async fn handle_streaming_response(
state.state_manager_sender,
selected_stack_small_id,
estimated_total_tokens,
start,
node_id,
))
.keep_alive(
axum::response::sse::KeepAlive::new()
@@ -345,7 +390,7 @@ async fn authenticate_and_process(
state: &ProxyState,
headers: HeaderMap,
payload: &Value,
) -> Result<(String, String, i64, HeaderMap, u64), StatusCode> {
) -> Result<(String, i64, String, i64, HeaderMap, u64), StatusCode> {
// Authentication and payload extraction
let (model, max_tokens, messages, tokenizer_index) =
authenticate_and_extract(state, &headers, payload).await?;
@@ -354,6 +399,8 @@
let total_tokens =
get_token_estimate(&messages, max_tokens, &state.tokenizers[tokenizer_index]).await?;

dbg!(&state.state_manager_sender);
dbg!(state.state_manager_sender.is_disconnected());
// Get node selection
let (selected_stack_small_id, selected_node_id) = get_selected_node(
&model,
@@ -408,6 +455,7 @@

Ok((
node_address,
selected_node_id,
signature,
selected_stack_small_id,
headers,
@@ -594,6 +642,8 @@ async fn get_selected_node(
) -> Result<(i64, i64), StatusCode> {
let (result_sender, result_receiver) = oneshot::channel();

dbg!(&model);
dbg!(total_tokens);
state_manager_sender
.send(AtomaAtomaStateManagerEvent::GetStacksForModel {
model: model.to_string(),
@@ -616,6 +666,7 @@
StatusCode::INTERNAL_SERVER_ERROR
})?;

dbg!(&stacks);
if stacks.is_empty() {
let (result_sender, result_receiver) = oneshot::channel();
state_manager_sender
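`get_selected_node` queries the state manager by sending an event that carries a `oneshot` sender and then awaiting the reply. A stripped-down sketch of that request-reply pattern, assuming a tokio unbounded mpsc channel as a stand-in for the proxy's actual sender type (`StateEvent` and `Stack` are placeholders, not real types from this codebase):

```rust
use tokio::sync::{mpsc, oneshot};

// Placeholder types standing in for AtomaAtomaStateManagerEvent and the
// real stack record; illustrative only.
struct Stack;
enum StateEvent {
    GetStacksForModel {
        model: String,
        result_sender: oneshot::Sender<Vec<Stack>>,
    },
}

// Send a query event carrying a oneshot sender, then await the reply on
// the paired receiver, mirroring the pattern in get_selected_node.
async fn request_stacks(
    tx: &mpsc::UnboundedSender<StateEvent>,
    model: &str,
) -> Result<Vec<Stack>, Box<dyn std::error::Error>> {
    let (result_sender, result_receiver) = oneshot::channel();
    tx.send(StateEvent::GetStacksForModel {
        model: model.to_string(),
        result_sender,
    })?;
    Ok(result_receiver.await?)
}
```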
57 changes: 57 additions & 0 deletions atoma-proxy/src/server/streamer.rs
@@ -1,6 +1,7 @@
use std::{
pin::Pin,
task::{Context, Poll},
time::Instant,
};

use atoma_state::types::AtomaAtomaStateManagerEvent;
@@ -35,6 +36,12 @@ pub struct Streamer {
stack_small_id: i64,
/// State manager sender
state_manager_sender: Sender<AtomaAtomaStateManagerEvent>,
/// Start time of the request
start: Instant,
/// Id of the node running this request
node_id: i64,
/// Whether the node's latency has already been reported for this stream
updated_latency: bool,
}

/// Represents the various states of a streaming process
@@ -57,13 +64,18 @@ impl Streamer {
state_manager_sender: Sender<AtomaAtomaStateManagerEvent>,
stack_small_id: i64,
estimated_total_tokens: i64,
start: Instant,
node_id: i64,
) -> Self {
Self {
stream: Box::pin(stream),
status: StreamStatus::NotStarted,
estimated_total_tokens,
stack_small_id,
state_manager_sender,
start,
node_id,
updated_latency: false,
}
}

@@ -106,6 +118,22 @@ impl Streamer {
)
)]
fn handle_final_chunk(&mut self, usage: &Value) -> Result<(), Error> {
// Get input tokens
let input_tokens = usage
.get("prompt_tokens")
.and_then(|t| t.as_i64())
.ok_or_else(|| {
error!("Error getting prompt tokens from usage");
Error::new("Error getting prompt tokens from usage")
})?;
// Get output tokens
let output_tokens = usage
.get("completion_tokens")
.and_then(|t| t.as_i64())
.ok_or_else(|| {
error!("Error getting completion tokens from usage");
Error::new("Error getting completion tokens from usage")
})?;
// Get total tokens
let total_tokens = usage
.get("total_tokens")
@@ -130,6 +158,21 @@
e
)));
}
// Update the node's throughput performance
if let Err(e) = self.state_manager_sender.send(
AtomaAtomaStateManagerEvent::UpdateNodeThroughputPerformance {
node_small_id: self.node_id,
input_tokens,
output_tokens,
time: self.start.elapsed().as_secs_f64(),
},
) {
error!("Error updating node throughput performance: {}", e);
return Err(Error::new(format!(
"Error updating node throughput performance: {}",
e
)));
}

Ok(())
}
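`handle_final_chunk` expects the last streamed chunk to carry an OpenAI-style `usage` object. A small sketch of the shape it reads, with made-up values (`serde_json` is already in use in this file):

```rust
use serde_json::{json, Value};

fn main() {
    // Hypothetical final-chunk usage payload; the three keys below are
    // exactly the ones handle_final_chunk reads.
    let usage: Value = json!({
        "prompt_tokens": 128,
        "completion_tokens": 512,
        "total_tokens": 640
    });
    let input_tokens = usage.get("prompt_tokens").and_then(|t| t.as_i64());
    assert_eq!(input_tokens, Some(128));
}
```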
@@ -185,6 +228,20 @@ impl Stream for Streamer {
}
};

if !self.updated_latency {
self.updated_latency = true;
let latency = self.start.elapsed().as_secs_f64();
self.state_manager_sender
.send(AtomaAtomaStateManagerEvent::UpdateNodeLatencyPerformance {
node_small_id: self.node_id,
latency,
})
.map_err(|e| {
error!("Error updating node latency performance: {}", e);
Error::new(format!("Error updating node latency performance: {}", e))
})?;
}

if choices.is_empty() {
if let Some(usage) = chunk.get(USAGE) {
self.status = StreamStatus::Completed;
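The `updated_latency` flag makes the latency report fire exactly once, on the first chunk polled from the stream, so the recorded value is effectively time-to-first-token rather than total stream duration. The record-once pattern in isolation (a sketch, not code from this PR):

```rust
use std::time::Instant;

/// Sketch of the record-once latency pattern the Streamer uses.
struct FirstChunkTimer {
    start: Instant,
    reported: bool,
}

impl FirstChunkTimer {
    fn new() -> Self {
        Self { start: Instant::now(), reported: false }
    }

    /// Returns the elapsed seconds on the first call only.
    fn report_once(&mut self) -> Option<f64> {
        if self.reported {
            None
        } else {
            self.reported = true;
            Some(self.start.elapsed().as_secs_f64())
        }
    }
}
```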
25 changes: 25 additions & 0 deletions atoma-state/src/handlers.rs
@@ -758,6 +758,31 @@ pub(crate) async fn handle_state_manager_event(
} => {
handle_stack_created_event(state_manager, event, already_computed_units).await?;
}
AtomaAtomaStateManagerEvent::UpdateNodeThroughputPerformance {
node_small_id,
input_tokens,
output_tokens,
time,
} => {
state_manager
.state
.update_node_throughput_performance(
node_small_id,
input_tokens,
output_tokens,
time,
)
.await?;
}
AtomaAtomaStateManagerEvent::UpdateNodeLatencyPerformance {
node_small_id,
latency,
} => {
state_manager
.state
.update_node_latency_performance(node_small_id, latency)
.await?;
}
}
Ok(())
}
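For reference, the two new event variants as implied by the match arms above. The actual declaration lives in `atoma_state::types` and is not shown in this diff, so the field types here are inferred from the call sites (node ids as `i64`, durations as `f64` seconds):

```rust
// Inferred sketch, not the diffed source.
pub enum AtomaAtomaStateManagerEvent {
    // ... existing variants elided ...
    UpdateNodeThroughputPerformance {
        node_small_id: i64,
        input_tokens: i64,
        output_tokens: i64,
        time: f64,
    },
    UpdateNodeLatencyPerformance {
        node_small_id: i64,
        latency: f64,
    },
}
```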