From 124fd0d391b3335a85449f15b774cc4e5fc2ea3a Mon Sep 17 00:00:00 2001 From: skull8888888 Date: Thu, 14 Nov 2024 21:37:40 -0800 Subject: [PATCH 1/3] eval time progression (#210) * initial work to compare evals * remove unnecessary div * design --------- Co-authored-by: Din --- app-server/src/ch/evaluation_scores.rs | 166 +++++++++++------ app-server/src/ch/events.rs | 38 ++-- app-server/src/ch/spans.rs | 147 ++++++++------- app-server/src/ch/utils.rs | 42 +---- app-server/src/main.rs | 4 - app-server/src/routes/evaluations.rs | 140 ++++---------- .../[groupId]/progression/route.ts | 32 ++++ .../[projectId]/evaluation-groups/route.ts | 23 +++ .../projects/[projectId]/evaluations/route.ts | 9 +- .../queues/[queueId]/remove/route.ts | 3 +- frontend/components/evaluation/chart.tsx | 175 ++++++++++-------- .../components/evaluation/compare-chart.tsx | 125 +++++++++++++ frontend/components/evaluation/evaluation.tsx | 79 ++++---- .../evaluations/evaluations-groups-bar.tsx | 52 ++++++ .../components/evaluations/evaluations.tsx | 131 ++++++++----- .../evaluations/progression-chart.tsx | 143 ++++++++++++++ frontend/components/ui/datatable.tsx | 4 +- frontend/components/ui/select.tsx | 6 +- frontend/lib/clickhouse/evaluation-scores.ts | 39 ++++ frontend/lib/clickhouse/utils.ts | 54 ++++++ frontend/lib/db/drizzle.ts | 25 ++- frontend/lib/evaluation/types.ts | 7 + frontend/lib/types.ts | 6 + 23 files changed, 986 insertions(+), 464 deletions(-) create mode 100644 frontend/app/api/projects/[projectId]/evaluation-groups/[groupId]/progression/route.ts create mode 100644 frontend/app/api/projects/[projectId]/evaluation-groups/route.ts create mode 100644 frontend/components/evaluation/compare-chart.tsx create mode 100644 frontend/components/evaluations/evaluations-groups-bar.tsx create mode 100644 frontend/components/evaluations/progression-chart.tsx create mode 100644 frontend/lib/clickhouse/evaluation-scores.ts create mode 100644 frontend/lib/clickhouse/utils.ts diff --git a/app-server/src/ch/evaluation_scores.rs b/app-server/src/ch/evaluation_scores.rs index 5825989a..384a45d1 100644 --- a/app-server/src/ch/evaluation_scores.rs +++ b/app-server/src/ch/evaluation_scores.rs @@ -6,7 +6,7 @@ use uuid::Uuid; use crate::evaluations::utils::EvaluationDatapointResult; -use super::utils::{chrono_to_nanoseconds, execute_query, validate_string_against_injection}; +use super::utils::chrono_to_nanoseconds; fn serialize_timestamp(timestamp: &DateTime, serializer: S) -> Result where @@ -110,28 +110,37 @@ pub async fn get_average_evaluation_score( evaluation_id: Uuid, name: String, ) -> Result { - validate_string_against_injection(&name)?; - - let query = format!( - "SELECT avg(value) as average_value + let row = clickhouse + .query( + "SELECT avg(value) as average_value FROM evaluation_scores - WHERE project_id = '{project_id}' - AND evaluation_id = '{evaluation_id}' - AND name = '{name}'", - ); + WHERE project_id = ? + AND evaluation_id = ? + AND name = ? + ", + ) + .bind(project_id) + .bind(evaluation_id) + .bind(name) + .fetch_one::() + .await?; - let rows: Vec = execute_query(&clickhouse, &query).await?; - Ok(rows[0].average_value) + Ok(row.average_value) } -#[derive(Row, Deserialize)] +#[derive(Row, Deserialize, Clone, Debug)] pub struct EvaluationScoreBucket { pub lower_bound: f64, pub upper_bound: f64, pub height: u64, } -pub async fn get_evaluation_score_buckets_based_on_bounds( +#[derive(Row, Deserialize)] +struct TotalCount { + total_count: u64, +} + +pub async fn get_evaluation_score_single_bucket( clickhouse: clickhouse::Client, project_id: Uuid, evaluation_id: Uuid, @@ -140,53 +149,98 @@ pub async fn get_evaluation_score_buckets_based_on_bounds( upper_bound: f64, bucket_count: u64, ) -> Result> { - validate_string_against_injection(&name)?; + // If the bounds are the same, we only need one bucket. + // We fill in the rest with 0s. + let total_count = clickhouse + .query( + "SELECT COUNT() as total_count + FROM evaluation_scores + WHERE project_id = ? + AND evaluation_id = ? + AND name = ?", + ) + .bind(project_id) + .bind(evaluation_id) + .bind(name) + .fetch_one::() + .await?; + let mut res = vec![ + EvaluationScoreBucket { + lower_bound, + upper_bound, + height: 0, + }; + bucket_count as usize - 1 + ]; + res.push(EvaluationScoreBucket { + lower_bound, + upper_bound, + height: total_count.total_count, + }); + return Ok(res); +} +pub async fn get_evaluation_score_buckets_based_on_bounds( + clickhouse: clickhouse::Client, + project_id: Uuid, + evaluation_id: Uuid, + name: String, + lower_bound: f64, + upper_bound: f64, + bucket_count: u64, +) -> Result> { let step_size = (upper_bound - lower_bound) / bucket_count as f64; - let interval_nums = (1..=bucket_count) - .map(|num| num.to_string()) - .collect::>() - .join(","); - - // This query uses {:?} with the purpose to render floats like 1.0 as 1.0 instead of 1 - let query = format!( - " + let interval_nums = (1..=bucket_count).collect::>(); + + let rows: Vec = clickhouse + .query( + " WITH intervals AS ( SELECT - arrayJoin([{interval_nums}]) AS interval_num, - {:?} + ((interval_num - 1) * {:?}) AS lower_bound, + arrayJoin(?) AS interval_num, + ? + ((interval_num - 1) * ?) AS lower_bound, CASE - WHEN interval_num = {bucket_count} THEN {:?} - ELSE {:?} + (interval_num * {:?}) + WHEN interval_num = ? THEN ? -- to avoid floating point precision issues + ELSE ? + (interval_num * ?) END AS upper_bound ) SELECT - intervals.lower_bound, - intervals.upper_bound, - COUNT(CASE - WHEN value >= intervals.lower_bound AND value < intervals.upper_bound THEN 1 - WHEN intervals.interval_num = {bucket_count} - AND value >= intervals.lower_bound - AND value <= intervals.upper_bound THEN 1 - ELSE NULL + CAST(intervals.lower_bound AS Float64) AS lower_bound, + CAST(intervals.upper_bound AS Float64) AS upper_bound, + SUM(CASE + -- exclusive on upper bound to avoid counting the same value twice + WHEN (value >= intervals.lower_bound AND value < intervals.upper_bound) + OR value = ? THEN 1 + ELSE 0 END) AS height FROM evaluation_scores JOIN intervals ON 1 = 1 -WHERE project_id = '{project_id}' -AND evaluation_id = '{evaluation_id}' -AND name = '{name}' +WHERE project_id = ? +AND evaluation_id = ? +AND name = ? GROUP BY intervals.lower_bound, intervals.upper_bound, intervals.interval_num ORDER BY intervals.interval_num", - lower_bound, step_size, upper_bound, lower_bound, step_size - ); - - let rows: Vec = execute_query(&clickhouse, &query).await?; + ) + .bind(interval_nums) + .bind(lower_bound) + .bind(step_size) + .bind(bucket_count) + .bind(upper_bound) + .bind(lower_bound) + .bind(step_size) + .bind(upper_bound) + .bind(project_id) + .bind(evaluation_id) + .bind(name) + .fetch_all::() + .await?; Ok(rows) } #[derive(Row, Deserialize, Clone)] pub struct ComparedEvaluationScoresBounds { + pub lower_bound: f64, pub upper_bound: f64, } @@ -196,24 +250,22 @@ pub async fn get_global_evaluation_scores_bounds( evaluation_ids: &Vec, name: String, ) -> Result { - validate_string_against_injection(&name)?; - - let evaluation_ids_str = evaluation_ids - .iter() - .map(|id| format!("'{}'", id)) - .collect::>() - .join(","); - - let query = format!( - " + let row = clickhouse + .query( + " SELECT + MIN(value) AS lower_bound, MAX(value) AS upper_bound FROM evaluation_scores -WHERE project_id = '{project_id}' - AND evaluation_id IN ({evaluation_ids_str}) - AND name = '{name}'", - ); +WHERE project_id = ? + AND evaluation_id IN ? + AND name = ?", + ) + .bind(project_id) + .bind(evaluation_ids) + .bind(name) + .fetch_one() + .await?; - let rows: Vec = execute_query(&clickhouse, &query).await?; - Ok(rows[0].clone()) + Ok(row) } diff --git a/app-server/src/ch/events.rs b/app-server/src/ch/events.rs index 2ebd3551..0481bcfb 100644 --- a/app-server/src/ch/events.rs +++ b/app-server/src/ch/events.rs @@ -10,8 +10,7 @@ use crate::db::{self, event_templates::EventTemplate}; use super::{ modifiers::GroupByInterval, utils::{ - chrono_to_nanoseconds, execute_query, group_by_time_absolute_statement, - group_by_time_relative_statement, + chrono_to_nanoseconds, group_by_time_absolute_statement, group_by_time_relative_statement, }, MetricTimeValue, }; @@ -132,14 +131,22 @@ pub async fn get_total_event_count_metrics_relative( COUNT(DISTINCT id) AS value FROM events WHERE - project_id = '{project_id}' - AND template_id = '{template_id}' - AND timestamp >= now() - INTERVAL {past_hours} HOUR + project_id = ? + AND template_id = ? + AND timestamp >= now() - INTERVAL ? HOUR {}", group_by_time_relative_statement(past_hours, group_by_interval), ); - execute_query(&clickhouse, &query_string).await + let rows: Vec> = clickhouse + .query(&query_string) + .bind(project_id) + .bind(template_id) + .bind(past_hours) + .fetch_all::>() + .await?; + + Ok(rows) } pub async fn get_total_event_count_metrics_absolute( @@ -161,13 +168,22 @@ pub async fn get_total_event_count_metrics_absolute( COUNT(DISTINCT id) AS value FROM events WHERE - project_id = '{project_id}' - AND template_id = '{template_id}' - AND timestamp >= fromUnixTimestamp({ch_start_time}) - AND timestamp <= fromUnixTimestamp({ch_end_time}) + project_id = ? + AND template_id = ? + AND timestamp >= fromUnixTimestamp(?) + AND timestamp <= fromUnixTimestamp(?) {}", group_by_time_absolute_statement(start_time, end_time, group_by_interval) ); - execute_query(&clickhouse, &query_string).await + let rows: Vec> = clickhouse + .query(&query_string) + .bind(project_id) + .bind(template_id) + .bind(ch_start_time) + .bind(ch_end_time) + .fetch_all::>() + .await?; + + Ok(rows) } diff --git a/app-server/src/ch/spans.rs b/app-server/src/ch/spans.rs index 72819b8f..f8bb742b 100644 --- a/app-server/src/ch/spans.rs +++ b/app-server/src/ch/spans.rs @@ -12,8 +12,7 @@ use crate::{ use super::{ modifiers::GroupByInterval, utils::{ - chrono_to_nanoseconds, execute_query, group_by_time_absolute_statement, - group_by_time_relative_statement, + chrono_to_nanoseconds, group_by_time_absolute_statement, group_by_time_relative_statement, }, Aggregation, MetricTimeValue, }; @@ -122,30 +121,18 @@ pub async fn get_total_trace_count_metrics_relative( project_id: Uuid, past_hours: i64, ) -> Result>> { - let ch_round_time = group_by_interval.to_ch_truncate_time(); - - let query_string = format!( - " - WITH traces AS ( - SELECT - trace_id, - project_id, - {ch_round_time}(MIN(start_time)) as time - FROM spans - GROUP BY project_id, trace_id - ) - SELECT - time, - COUNT(DISTINCT(trace_id)) as value - FROM traces - WHERE - project_id = '{project_id}' - AND time >= now() - INTERVAL {past_hours} HOUR - {}", - group_by_time_relative_statement(past_hours, group_by_interval) + let query = span_metric_query_relative( + &clickhouse, + project_id, + group_by_interval, + past_hours, + Aggregation::Total, + "COUNT(DISTINCT(trace_id))", ); - execute_query(&clickhouse, &query_string).await + let rows = query.fetch_all().await?; + + Ok(rows) } pub async fn get_total_trace_count_metrics_absolute( @@ -155,34 +142,19 @@ pub async fn get_total_trace_count_metrics_absolute( start_time: DateTime, end_time: DateTime, ) -> Result>> { - let ch_round_time = group_by_interval.to_ch_truncate_time(); - let ch_start_time = start_time.timestamp(); - let ch_end_time = end_time.timestamp(); - - let query_string = format!( - " - WITH traces AS ( - SELECT - trace_id, + let query = span_metric_query_absolute( + &clickhouse, project_id, - {ch_round_time}(MIN(start_time)) as time, - SUM(total_tokens) as value - FROM spans - GROUP BY project_id, trace_id - ) - SELECT - time, - COUNT(DISTINCT(trace_id)) as value - FROM traces - WHERE - project_id = '{project_id}' - AND time >= fromUnixTimestamp({ch_start_time}) - AND time <= fromUnixTimestamp({ch_end_time}) - {}", - group_by_time_absolute_statement(start_time, end_time, group_by_interval) + group_by_interval, + start_time, + end_time, + Aggregation::Total, + "COUNT(DISTINCT(trace_id))", ); - execute_query(&clickhouse, &query_string).await + let rows = query.fetch_all().await?; + + Ok(rows) } pub async fn get_trace_latency_seconds_metrics_relative( @@ -192,7 +164,8 @@ pub async fn get_trace_latency_seconds_metrics_relative( past_hours: i64, aggregation: Aggregation, ) -> Result>> { - let query_string = span_metric_query_relative( + let query = span_metric_query_relative( + &clickhouse, project_id, group_by_interval, past_hours, @@ -200,7 +173,9 @@ pub async fn get_trace_latency_seconds_metrics_relative( "(toUnixTimestamp64Nano(MAX(end_time)) - toUnixTimestamp64Nano(MIN(start_time))) / 1e9", ); - execute_query(&clickhouse, &query_string).await + let res = query.fetch_all::>().await?; + + Ok(res) } pub async fn get_trace_latency_seconds_metrics_absolute( @@ -211,7 +186,8 @@ pub async fn get_trace_latency_seconds_metrics_absolute( end_time: DateTime, aggregation: Aggregation, ) -> Result>> { - let query_string = span_metric_query_absolute( + let query = span_metric_query_absolute( + &clickhouse, project_id, group_by_interval, start_time, @@ -220,7 +196,9 @@ pub async fn get_trace_latency_seconds_metrics_absolute( "(toUnixTimestamp64Nano(MAX(end_time)) - toUnixTimestamp64Nano(MIN(start_time))) / 1e9", ); - execute_query(&clickhouse, &query_string).await + let res = query.fetch_all::>().await?; + + Ok(res) } pub async fn get_total_token_count_metrics_relative( @@ -230,7 +208,8 @@ pub async fn get_total_token_count_metrics_relative( past_hours: i64, aggregation: Aggregation, ) -> Result>> { - let query_string = span_metric_query_relative( + let query = span_metric_query_relative( + &clickhouse, project_id, group_by_interval, past_hours, @@ -238,7 +217,9 @@ pub async fn get_total_token_count_metrics_relative( "SUM(total_tokens)", ); - execute_query(&clickhouse, &query_string).await + let res = query.fetch_all::>().await?; + + Ok(res) } pub async fn get_total_token_count_metrics_absolute( @@ -249,7 +230,8 @@ pub async fn get_total_token_count_metrics_absolute( end_time: DateTime, aggregation: Aggregation, ) -> Result>> { - let query_string = span_metric_query_absolute( + let query = span_metric_query_absolute( + &clickhouse, project_id, group_by_interval, start_time, @@ -258,7 +240,9 @@ pub async fn get_total_token_count_metrics_absolute( "SUM(total_tokens)", ); - execute_query(&clickhouse, &query_string).await + let res = query.fetch_all().await?; + + Ok(res) } pub async fn get_cost_usd_metrics_relative( @@ -268,7 +252,8 @@ pub async fn get_cost_usd_metrics_relative( past_hours: i64, aggregation: Aggregation, ) -> Result>> { - let query_string = span_metric_query_relative( + let query = span_metric_query_relative( + &clickhouse, project_id, group_by_interval, past_hours, @@ -276,7 +261,9 @@ pub async fn get_cost_usd_metrics_relative( "SUM(total_cost)", ); - execute_query(&clickhouse, &query_string).await + let res = query.fetch_all().await?; + + Ok(res) } pub async fn get_cost_usd_metrics_absolute( @@ -287,7 +274,8 @@ pub async fn get_cost_usd_metrics_absolute( end_time: DateTime, aggregation: Aggregation, ) -> Result>> { - let query_string = span_metric_query_absolute( + let query = span_metric_query_absolute( + &clickhouse, project_id, group_by_interval, start_time, @@ -296,20 +284,23 @@ pub async fn get_cost_usd_metrics_absolute( "SUM(total_cost)", ); - execute_query(&clickhouse, &query_string).await + let res = query.fetch_all().await?; + + Ok(res) } fn span_metric_query_relative( + clickhouse: &clickhouse::Client, project_id: Uuid, group_by_interval: GroupByInterval, past_hours: i64, aggregation: Aggregation, metric: &str, -) -> String { +) -> clickhouse::query::Query { let ch_round_time = group_by_interval.to_ch_truncate_time(); let ch_aggregation = aggregation.to_ch_agg_function(); - format!( + let query_string = format!( " WITH traces AS ( SELECT @@ -325,27 +316,33 @@ fn span_metric_query_relative( {ch_aggregation}(value) as value FROM traces WHERE - project_id = '{project_id}' - AND time >= now() - INTERVAL {past_hours} HOUR + project_id = ? + AND time >= now() - INTERVAL ? HOUR {}", group_by_time_relative_statement(past_hours, group_by_interval) - ) + ); + + clickhouse + .query(&query_string) + .bind(project_id) + .bind(past_hours) } fn span_metric_query_absolute( + clickhouse: &clickhouse::Client, project_id: Uuid, group_by_interval: GroupByInterval, start_time: DateTime, end_time: DateTime, aggregation: Aggregation, metric: &str, -) -> String { +) -> clickhouse::query::Query { let ch_round_time = group_by_interval.to_ch_truncate_time(); let ch_start_time = start_time.timestamp(); let ch_end_time = end_time.timestamp(); let ch_aggregation = aggregation.to_ch_agg_function(); - format!( + let query_string = format!( " WITH traces AS ( SELECT @@ -361,10 +358,16 @@ fn span_metric_query_absolute( {ch_aggregation}(value) as value FROM traces WHERE - project_id = '{project_id}' - AND time >= fromUnixTimestamp({ch_start_time}) - AND time <= fromUnixTimestamp({ch_end_time}) + project_id = ? + AND time >= fromUnixTimestamp(?) + AND time <= fromUnixTimestamp(?) {}", group_by_time_absolute_statement(start_time, end_time, group_by_interval) - ) + ); + + clickhouse + .query(&query_string) + .bind(project_id) + .bind(ch_start_time) + .bind(ch_end_time) } diff --git a/app-server/src/ch/utils.rs b/app-server/src/ch/utils.rs index 78ae23ee..c3b2be52 100644 --- a/app-server/src/ch/utils.rs +++ b/app-server/src/ch/utils.rs @@ -114,12 +114,15 @@ async fn get_time_bounds( MAX({column_name}) AS max_time FROM {table_name} - WHERE project_id = '{project_id}'", + WHERE project_id = ?", ); - let mut cursor = clickhouse.query(&query_string).fetch::()?; + let time_bounds = clickhouse + .query(&query_string) + .bind(project_id) + .fetch_one::() + .await?; - let time_bounds = cursor.next().await?.unwrap(); Ok(time_bounds) } @@ -135,36 +138,3 @@ pub async fn get_bounds( nanoseconds_to_chrono(time_bounds.max_time), )) } - -pub async fn execute_query<'de, T>( - clickhouse: &clickhouse::Client, - query_string: &str, -) -> Result> -where - T: Row + Deserialize<'de>, -{ - if !is_feature_enabled(Feature::FullBuild) { - return Ok(Vec::new()); - } - - let mut cursor = clickhouse.query(query_string).fetch::()?; - - let mut res = Vec::new(); - while let Some(row) = cursor.next().await? { - res.push(row); - } - - Ok(res) -} - -/// Trivial SQL injection protection -pub fn validate_string_against_injection(s: &str) -> Result<()> { - let invalid_chars = ["'", "\"", "\\", ";", "*", "/", "--"]; - if invalid_chars.iter().any(|&c| s.contains(c)) - || s.to_lowercase().contains("union") - || s.to_lowercase().contains("select") - { - return Err(anyhow::anyhow!("Invalid characters or SQL keywords")); - } - return Ok(()); -} diff --git a/app-server/src/main.rs b/app-server/src/main.rs index 2c166ddb..5960c7cc 100644 --- a/app-server/src/main.rs +++ b/app-server/src/main.rs @@ -487,8 +487,6 @@ fn main() -> anyhow::Result<()> { .service(routes::api_keys::create_project_api_key) .service(routes::api_keys::get_api_keys_for_project) .service(routes::api_keys::revoke_project_api_key) - .service(routes::evaluations::get_evaluation) - .service(routes::evaluations::delete_evaluation) .service(routes::evaluations::get_evaluation_score_stats) .service( routes::evaluations::get_evaluation_score_distribution, @@ -501,8 +499,6 @@ fn main() -> anyhow::Result<()> { .service(routes::datasets::delete_datapoints) .service(routes::datasets::delete_all_datapoints) .service(routes::datasets::index_dataset) - .service(routes::evaluations::get_evaluations) - .service(routes::evaluations::get_evaluation) .service(routes::traces::get_traces) .service(routes::traces::get_single_trace) .service(routes::traces::get_single_span) diff --git a/app-server/src/routes/evaluations.rs b/app-server/src/routes/evaluations.rs index 75f2df83..8fb93db2 100644 --- a/app-server/src/routes/evaluations.rs +++ b/app-server/src/routes/evaluations.rs @@ -1,16 +1,10 @@ -use actix_web::{delete, get, web, HttpResponse}; +use actix_web::{get, web, HttpResponse}; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::{ - ch::evaluation_scores::{ - get_average_evaluation_score, get_evaluation_score_buckets_based_on_bounds, - get_global_evaluation_scores_bounds, EvaluationScoreBucket, - }, - db::{ - evaluations::{self, Evaluation, EvaluationDatapoint}, - DB, - }, +use crate::ch::evaluation_scores::{ + get_average_evaluation_score, get_evaluation_score_buckets_based_on_bounds, + get_evaluation_score_single_bucket, get_global_evaluation_scores_bounds, EvaluationScoreBucket, }; use super::ResponseResult; @@ -18,84 +12,6 @@ use super::ResponseResult; const DEFAULT_LOWER_BOUND: f64 = 0.0; const DEFAULT_BUCKET_COUNT: u64 = 10; -#[delete("evaluations/{evaluation_id}")] -async fn delete_evaluation(path: web::Path<(Uuid, Uuid)>, db: web::Data) -> ResponseResult { - let (_project_id, evaluation_id) = path.into_inner(); - evaluations::delete_evaluation(&db.pool, &evaluation_id).await?; - - Ok(HttpResponse::Ok().finish()) -} - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct GetEvaluationsQuery { - #[serde(default)] - current_evaluation_id: Option, -} - -#[get("evaluations")] -async fn get_evaluations( - db: web::Data, - path: web::Path, - query: web::Query, -) -> ResponseResult { - let project_id = path.into_inner(); - let query = query.into_inner(); - let current_evaluation_id = query.current_evaluation_id; - - let evaluations = match current_evaluation_id { - Some(current_evaluation_id) => { - // TODO: Currently, this query takes care of filtering out by group id, need to make it more explicit - evaluations::get_evaluations_grouped_by_current_evaluation( - &db.pool, - project_id, - current_evaluation_id, - ) - .await? - } - None => evaluations::get_evaluations(&db.pool, project_id).await?, - }; - - Ok(HttpResponse::Ok().json(evaluations)) -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct GetEvaluationResponse { - evaluation: Evaluation, - results: Vec, -} - -#[get("evaluations/{evaluation_id}")] -async fn get_evaluation(path: web::Path<(Uuid, Uuid)>, db: web::Data) -> ResponseResult { - let (project_id, evaluation_id) = path.into_inner(); - let db = db.into_inner(); - - let db_clone = db.clone(); - let get_evaluation_task = tokio::task::spawn(async move { - evaluations::get_evaluation(db_clone, project_id, evaluation_id).await - }); - - let get_evaluation_results = tokio::task::spawn(async move { - evaluations::get_evaluation_results(&db.pool, evaluation_id).await - }); - - let join_res = tokio::try_join!(get_evaluation_task, get_evaluation_results); - if let Err(e) = join_res { - return Err(anyhow::anyhow!("Error getting evaluation: {}", e).into()); - } - let (evaluation, results) = join_res.unwrap(); - let evaluation = evaluation?; - let results = results?; - - let response = GetEvaluationResponse { - evaluation, - results, - }; - - Ok(HttpResponse::Ok().json(response)) -} - #[derive(Deserialize)] #[serde(rename_all = "camelCase")] pub struct GetEvaluationScoreStatsQuery { @@ -175,31 +91,41 @@ async fn get_evaluation_score_distribution( score_name.clone(), ) .await?; - // TODO: Figure out better way to handle this in both backend and frontend - if global_bounds.upper_bound < DEFAULT_LOWER_BOUND { - return Err(anyhow::anyhow!( - "Upper bound is less than lower bound: {} < {}", - global_bounds.upper_bound, - DEFAULT_LOWER_BOUND - ) - .into()); - } + + let lower_bound = if global_bounds.lower_bound < DEFAULT_LOWER_BOUND { + global_bounds.lower_bound + } else { + DEFAULT_LOWER_BOUND + }; let evaluation_buckets: Vec> = futures::future::try_join_all(evaluation_ids.into_iter().map(|evaluation_id| { let clickhouse = clickhouse.clone(); let score_name = score_name.clone(); async move { - get_evaluation_score_buckets_based_on_bounds( - clickhouse, - project_id, - evaluation_id, - score_name, - DEFAULT_LOWER_BOUND, - global_bounds.upper_bound, - DEFAULT_BUCKET_COUNT, - ) - .await + if global_bounds.lower_bound == global_bounds.upper_bound { + get_evaluation_score_single_bucket( + clickhouse, + project_id, + evaluation_id, + score_name, + global_bounds.lower_bound, + global_bounds.upper_bound, + DEFAULT_BUCKET_COUNT, + ) + .await + } else { + get_evaluation_score_buckets_based_on_bounds( + clickhouse, + project_id, + evaluation_id, + score_name, + lower_bound, + global_bounds.upper_bound, + DEFAULT_BUCKET_COUNT, + ) + .await + } } })) .await?; diff --git a/frontend/app/api/projects/[projectId]/evaluation-groups/[groupId]/progression/route.ts b/frontend/app/api/projects/[projectId]/evaluation-groups/[groupId]/progression/route.ts new file mode 100644 index 00000000..2a94a159 --- /dev/null +++ b/frontend/app/api/projects/[projectId]/evaluation-groups/[groupId]/progression/route.ts @@ -0,0 +1,32 @@ +import { getEvaluationTimeProgression } from "@/lib/clickhouse/evaluation-scores"; +import { NextRequest, NextResponse } from "next/server"; +import { clickhouseClient } from "@/lib/clickhouse/client"; +import { AggregationFunction, TimeRange } from "@/lib/clickhouse/utils"; + + +export const GET = async (request: NextRequest, { params }: { params: { projectId: string, groupId: string } }) => { + const { projectId, groupId } = params; + let timeRange: TimeRange; + if (request.nextUrl.searchParams.get('pastHours')) { + const pastHours = parseInt(request.nextUrl.searchParams.get('pastHours') ?? '0'); + timeRange = { pastHours }; + } else if (request.nextUrl.searchParams.get('startDate') && request.nextUrl.searchParams.get('endDate')) { + const startDate = new Date(request.nextUrl.searchParams.get('startDate') ?? ''); + const endDate = new Date(request.nextUrl.searchParams.get('endDate') ?? ''); + timeRange = { start: startDate, end: endDate }; + } else { + timeRange = { pastHours: 168 }; + } + + const aggregationFunction = (request.nextUrl.searchParams.get('aggregate') ?? 'AVG') as AggregationFunction; + + const progression = await getEvaluationTimeProgression( + clickhouseClient, + projectId, + groupId, + timeRange, + aggregationFunction + ); + + return NextResponse.json(progression); +}; diff --git a/frontend/app/api/projects/[projectId]/evaluation-groups/route.ts b/frontend/app/api/projects/[projectId]/evaluation-groups/route.ts new file mode 100644 index 00000000..5036e29f --- /dev/null +++ b/frontend/app/api/projects/[projectId]/evaluation-groups/route.ts @@ -0,0 +1,23 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { db } from '@/lib/db/drizzle'; +import { eq, sql, desc } from 'drizzle-orm'; +import { evaluations } from '@/lib/db/migrations/schema'; + +export async function GET(request: NextRequest, { params }: { params: { projectId: string } }) { + const projectId = params.projectId; + const groupedEvaluations = db.$with('grouped_evaluations').as( + db.select({ + groupId: evaluations.groupId, + lastEvaluationCreatedAt: sql`MAX(${evaluations.createdAt})`.as('lastEvaluationCreatedAt'), + }).from(evaluations).where(eq(evaluations.projectId, projectId)).groupBy(evaluations.groupId) + ); + const groups = await db + .with(groupedEvaluations) + .select({ + groupId: groupedEvaluations.groupId, + lastEvaluationCreatedAt: groupedEvaluations.lastEvaluationCreatedAt, + }) + .from(groupedEvaluations) + .orderBy(desc(groupedEvaluations.lastEvaluationCreatedAt)); + return NextResponse.json(groups); +} diff --git a/frontend/app/api/projects/[projectId]/evaluations/route.ts b/frontend/app/api/projects/[projectId]/evaluations/route.ts index c0f12c04..a6969335 100644 --- a/frontend/app/api/projects/[projectId]/evaluations/route.ts +++ b/frontend/app/api/projects/[projectId]/evaluations/route.ts @@ -1,6 +1,6 @@ import { db } from '@/lib/db/drizzle'; import { evaluations } from '@/lib/db/migrations/schema'; -import { and, desc, eq, inArray } from 'drizzle-orm'; +import { and, desc, eq, inArray, SQL } from 'drizzle-orm'; import { paginatedGet } from '@/lib/db/utils'; import { Evaluation } from '@/lib/evaluation/types'; import { NextRequest } from 'next/server'; @@ -10,10 +10,15 @@ export async function GET( { params }: { params: { projectId: string } } ): Promise { const projectId = params.projectId; + const groupId = req.nextUrl.searchParams.get('groupId'); + const filters: SQL[] = [eq(evaluations.projectId, projectId)]; + if (groupId) { + filters.push(eq(evaluations.groupId, groupId)); + } const result = await paginatedGet({ table: evaluations, - filters: [eq(evaluations.projectId, projectId)], + filters, orderBy: desc(evaluations.createdAt) }); diff --git a/frontend/app/api/projects/[projectId]/queues/[queueId]/remove/route.ts b/frontend/app/api/projects/[projectId]/queues/[queueId]/remove/route.ts index f61556ba..8c4c0434 100644 --- a/frontend/app/api/projects/[projectId]/queues/[queueId]/remove/route.ts +++ b/frontend/app/api/projects/[projectId]/queues/[queueId]/remove/route.ts @@ -6,6 +6,7 @@ import { isFeatureEnabled } from '@/lib/features/features'; import { Feature } from '@/lib/features/features'; import { clickhouseClient } from '@/lib/clickhouse/client'; import { z } from 'zod'; +import { dateToNanoseconds } from '@/lib/clickhouse/utils'; const NANOS_PER_MILLISECOND = 1_000_000; @@ -89,7 +90,7 @@ export async function POST(request: Request, { params }: { params: { projectId: result_id: resultId, name: value.name, value: value.score, - timestamp: new Date().getTime() * NANOS_PER_MILLISECOND + timestamp: dateToNanoseconds(new Date()) })) }); } diff --git a/frontend/components/evaluation/chart.tsx b/frontend/components/evaluation/chart.tsx index 74f1adf5..c82e9b4a 100644 --- a/frontend/components/evaluation/chart.tsx +++ b/frontend/components/evaluation/chart.tsx @@ -5,67 +5,61 @@ import { ChartTooltipContent } from '@/components/ui/chart'; import { useProjectContext } from '@/contexts/project-context'; -import { cn, swrFetcher } from '@/lib/utils'; -import { Bar, BarChart, CartesianGrid, XAxis } from 'recharts'; -import useSWR from 'swr'; +import { cn } from '@/lib/utils'; +import { Bar, BarChart, CartesianGrid, XAxis, YAxis } from 'recharts'; import { Skeleton } from '../ui/skeleton'; -import React, { useEffect, useState } from 'react'; -import { usePathname, useSearchParams } from 'next/navigation'; +import { useEffect } from 'react'; +import { useState } from 'react'; +import { BucketRow } from '@/lib/types'; +import { Label } from '../ui/label'; -const URL_QUERY_PARAMS = { - COMPARE_EVAL_ID: 'comparedEvaluationId' -}; - -const getEvaluationIdFromPathname = (pathName: string) => { - if (pathName.endsWith('/')) { - pathName = pathName.slice(0, -1); +const getTransformedData = (data: {[scoreName: string]: BucketRow[]}): {index: number, [scoreName: string]: number}[] => { + const res: {[index: number]: {[scoreName: string]: number}} = {}; + for (const [scoreName, rows] of Object.entries(data)) { + rows.forEach((row, index) => { + res[index] = { + ...res[index], + [scoreName]: row.heights[0], + }; + }); } - const pathParts = pathName.split('/'); - return pathParts[pathParts.length - 1]; -}; - -type BucketRow = { - lowerBound: number; - upperBound: number; - heights: number[]; -}; - -const getTransformedData = (data: []) => - data.map((row: BucketRow, index: number) => ({ + return Object.values(res).map((row, index) => ({ index, - height: row.heights[0], - comparedHeight: row.heights.length > 1 ? row.heights[1] : undefined + ...row, })); +}; function renderTick(tickProps: any) { - const { x, y, payload } = tickProps; - const { value, offset } = payload; - // console.log(`x: ${x}, y: ${y}`) - // console.log(`Value: ${value}, ${typeof value}, offset: ${offset}`) + const { x, y, payload: { value, offset } } = tickProps; + const VERTICAL_TICK_OFFSET = 8; + const VERTICAL_TICK_LENGTH = 4; + const FONT_SIZE = 8; + const BUCKET_COUNT = 10; + const PERCENTAGE_STEP = 100 / BUCKET_COUNT; // Value is equal to index starting from 0 // So we calculate percentage ticks/marks by multiplying value by 10 return ( - + - {value * 10}% + {value * PERCENTAGE_STEP}% - {value === 9 && ( + {value === BUCKET_COUNT - 1 && ( <> - + 100% @@ -76,49 +70,39 @@ function renderTick(tickProps: any) { } interface ChartProps { - scoreName: string; + evaluationId: string; + allScoreNames: string[]; className?: string; } -export default function Chart({ scoreName, className }: ChartProps) { - const pathName = usePathname(); - const searchParams = new URLSearchParams(useSearchParams().toString()); +export default function Chart({ evaluationId, allScoreNames, className }: ChartProps) { const { projectId } = useProjectContext(); - - const [evaluationId, setEvaluationId] = useState( - getEvaluationIdFromPathname(pathName) - ); - const [comparedEvaluationId, setComparedEvaluationId] = useState( - searchParams.get(URL_QUERY_PARAMS.COMPARE_EVAL_ID) - ); - - const { data, isLoading, error } = useSWR( - `/api/projects/${projectId}/evaluation-score-distribution?evaluationIds=${evaluationId + (comparedEvaluationId ? `,${comparedEvaluationId}` : '')}&scoreName=${scoreName}`, - swrFetcher - ); + const [data, setData] = useState<{[scoreName: string]: BucketRow[]}>({}); + const [showScores, setShowScores] = useState(allScoreNames); useEffect(() => { - setEvaluationId(getEvaluationIdFromPathname(pathName)); - }, [pathName]); + allScoreNames.forEach((scoreName) => { + fetch(`/api/projects/${projectId}/evaluation-score-distribution?` + + `evaluationIds=${evaluationId}&scoreName=${scoreName}`) + .then((res) => res.json()) + .then((data) => setData((prev) => ({...prev, [scoreName]: data}))); + }); + }, [evaluationId, allScoreNames]); - useEffect(() => { - setComparedEvaluationId(searchParams.get(URL_QUERY_PARAMS.COMPARE_EVAL_ID)); - }, [searchParams]); - - const chartConfig = { - ['index']: { - color: 'hsl(var(--chart-1))' + const chartConfig = Object.fromEntries(allScoreNames.map((scoreName, index) => ([ + scoreName, { + color: `hsl(var(--chart-${index % 5 + 1}))`, + label: scoreName, } - } satisfies ChartConfig; + ]))) satisfies ChartConfig; + + // console.log(getTransformedData(data)); return (
- {/*
- Score distribution: {scoreName} -
*/}
- {isLoading || !data || error ? ( + {Object.keys(data).length === 0 ? ( ) : ( + + } /> - {comparedEvaluationId && ( + {showScores.map((scoreName) => ( - )} - + ))} )} +
+ {Array.from(allScoreNames).map((scoreName) => ( +
{ + let newShowScores = new Set(showScores); + if (newShowScores.has(scoreName)) { + newShowScores.delete(scoreName); + } else { + newShowScores.add(scoreName); + } + setShowScores(Array.from(newShowScores)); + }} + > + +
+ ))} +
); diff --git a/frontend/components/evaluation/compare-chart.tsx b/frontend/components/evaluation/compare-chart.tsx new file mode 100644 index 00000000..52cc7682 --- /dev/null +++ b/frontend/components/evaluation/compare-chart.tsx @@ -0,0 +1,125 @@ +import { + ChartConfig, + ChartContainer, + ChartTooltip, + ChartTooltipContent +} from '@/components/ui/chart'; +import { useProjectContext } from '@/contexts/project-context'; +import { cn, swrFetcher } from '@/lib/utils'; +import { Bar, BarChart, CartesianGrid, XAxis } from 'recharts'; +import useSWR from 'swr'; +import { Skeleton } from '../ui/skeleton'; +import { BucketRow } from '@/lib/types'; + +const getTransformedData = (data: BucketRow[]) => + data.map((row: BucketRow, index: number) => ({ + index, + height: row.heights[0], + comparedHeight: row.heights[1], + })); + +function renderTick(tickProps: any) { + const { x, y, payload: { value, offset } } = tickProps; + const VERTICAL_TICK_OFFSET = 8; + const VERTICAL_TICK_LENGTH = 4; + const FONT_SIZE = 8; + const BUCKET_COUNT = 10; + const PERCENTAGE_STEP = 100 / BUCKET_COUNT; + + // Value is equal to index starting from 0 + // So we calculate percentage ticks/marks by multiplying value by 10 + return ( + + + + {value * PERCENTAGE_STEP}% + + {value === BUCKET_COUNT - 1 && ( + <> + + + 100% + + + )} + + ); +} + +interface CompareChatProps { + evaluationId: string; + comparedEvaluationId: string; + scoreName: string; + className?: string; +} + +export default function CompareChart({ evaluationId, comparedEvaluationId, scoreName, className }: CompareChatProps) { + const { projectId } = useProjectContext(); + + const { data, isLoading, error } = useSWR( + `/api/projects/${projectId}/evaluation-score-distribution?` + + `evaluationIds=${evaluationId},${comparedEvaluationId}&scoreName=${scoreName}`, + swrFetcher + ); + + const chartConfig = { + ['index']: { + color: 'hsl(var(--chart-1))' + } + } satisfies ChartConfig; + + return ( +
+
+ + {isLoading || !data || error ? ( + + ) : ( + + + + } + /> + + + + )} + +
+
+ ); +} diff --git a/frontend/components/evaluation/evaluation.tsx b/frontend/components/evaluation/evaluation.tsx index d34dfdd2..7f745d01 100644 --- a/frontend/components/evaluation/evaluation.tsx +++ b/frontend/components/evaluation/evaluation.tsx @@ -19,14 +19,15 @@ import { SelectValue } from '../ui/select'; import { mergeOriginalWithComparedDatapoints } from '@/lib/evaluation/utils'; -import { ArrowRight, Loader2 } from 'lucide-react'; +import { ArrowRight } from 'lucide-react'; import { Button } from '../ui/button'; import { Resizable } from 're-resizable'; import TraceView from '../traces/trace-view'; -import Chart from './chart'; +import CompareChart from './compare-chart'; import ScoreCard from './score-card'; import { useToast } from '@/lib/hooks/use-toast'; import DownloadButton from '../ui/download-button'; +import Chart from './chart'; const URL_QUERY_PARAMS = { COMPARE_EVAL_ID: 'comparedEvaluationId' @@ -71,7 +72,7 @@ export default function Evaluation({ } } - // This is ok to search for selected datapoint among defaultResults before we have pagination + // TODO: get datapoints paginated. const [selectedDatapoint, setSelectedDatapoint] = useState( defaultResults.find( @@ -180,17 +181,6 @@ export default function Evaluation({ router.push(`${pathName}?${searchParams.toString()}`); }; - // It will reload the page - const handleEvaluationChange = (evaluationId: string) => { - // change last part of pathname - const currentPathName = pathName.endsWith('/') - ? pathName.slice(0, -1) - : pathName; - const pathParts = currentPathName.split('/'); - pathParts[pathParts.length - 1] = evaluationId; - router.push(`${pathParts.join('/')}?${searchParams.toString()}`); - }; - return (
@@ -229,7 +219,9 @@ export default function Evaluation({ - - - - - {Array.from(scoreColumns).map((scoreName) => ( - - {scoreName} - - ))} - - + {comparedEvaluation !== null && ( + + )}
- + {comparedEvaluation === null && ( + + )}
@@ -295,7 +291,18 @@ export default function Evaluation({
- {} + {comparedEvaluation !== null ? ( + + ) : ( + + )}
)} diff --git a/frontend/components/evaluations/evaluations-groups-bar.tsx b/frontend/components/evaluations/evaluations-groups-bar.tsx new file mode 100644 index 00000000..fbe4c28c --- /dev/null +++ b/frontend/components/evaluations/evaluations-groups-bar.tsx @@ -0,0 +1,52 @@ +import { useProjectContext } from "@/contexts/project-context"; +import { cn, swrFetcher } from "@/lib/utils"; +import { ScrollArea } from "../ui/scroll-area"; +import { useRouter, useSearchParams } from "next/navigation"; +import useSWR from "swr"; +import { DataTable } from "../ui/datatable"; +import { ColumnDef } from "@tanstack/react-table"; +import ClientTimestampFormatter from "../client-timestamp-formatter"; + + +export default function EvaluationsGroupsBar() { + const { projectId } = useProjectContext(); + + const router = useRouter(); + const searchParams = useSearchParams(); + + const { data: groups, isLoading } = useSWR<{ groupId: string, lastEvaluationCreatedAt: string }[]>( + `/api/projects/${projectId}/evaluation-groups`, + swrFetcher, + ); + + if (groups && groups.length > 0 && !searchParams.get('groupId')) { + router.push(`/project/${projectId}/evaluations?groupId=${groups[0].groupId}`); + } + + const columns: ColumnDef<{ groupId: string, lastEvaluationCreatedAt: string }>[] = [ + { + header: 'Group', + accessorFn: (row) => row.groupId, + }, + { + header: 'Last Evaluation', + accessorFn: (row) => row.lastEvaluationCreatedAt, + cell: ({ row }) => , + }, + ]; + + const selectedGroupId = searchParams.get('groupId'); + + return
+
Groups
+ row.groupId} + focusedRowId={selectedGroupId} + onRowClick={(row) => { + router.push(`/project/${projectId}/evaluations?groupId=${row.original.groupId}`); + }} + /> +
; +} diff --git a/frontend/components/evaluations/evaluations.tsx b/frontend/components/evaluations/evaluations.tsx index 84d0af7f..54e6fb66 100644 --- a/frontend/components/evaluations/evaluations.tsx +++ b/frontend/components/evaluations/evaluations.tsx @@ -4,11 +4,10 @@ import { useProjectContext } from '@/contexts/project-context'; import { Evaluation } from '@/lib/evaluation/types'; import { ColumnDef } from '@tanstack/react-table'; import ClientTimestampFormatter from '../client-timestamp-formatter'; -import { useRouter } from 'next/navigation'; +import { useRouter, useSearchParams } from 'next/navigation'; import { DataTable } from '../ui/datatable'; import Mono from '../ui/mono'; import Header from '../ui/header'; -import EvalsPagePlaceholder from './page-placeholder'; import { usePostHog } from 'posthog-js/react'; import { useUserContext } from '@/contexts/user-context'; import { Feature, isFeatureEnabled } from '@/lib/features/features'; @@ -29,16 +28,24 @@ import useSWR from 'swr'; import { swrFetcher } from '@/lib/utils'; import { Loader2 } from 'lucide-react'; import { PaginatedResponse } from '@/lib/types'; +import EvaluationsGroupsBar from './evaluations-groups-bar'; +import { Skeleton } from '../ui/skeleton'; +import ProgressionChart from './progression-chart'; +import { AggregationFunction } from '@/lib/clickhouse/utils'; +import { Select, SelectItem, SelectValue, SelectTrigger, SelectContent } from '../ui/select'; export default function Evaluations() { const { projectId } = useProjectContext(); - const { data, mutate, isLoading } = useSWR>( - `/api/projects/${projectId}/evaluations`, + const router = useRouter(); + const searchParams = useSearchParams(); + const { data, mutate } = useSWR>( + `/api/projects/${projectId}/evaluations?groupId=${searchParams.get('groupId')}`, swrFetcher ); const evaluations = data?.items; - const router = useRouter(); + const [aggregationFunction, setAggregationFunction] = useState('AVG'); + const posthog = usePostHog(); const { email } = useUserContext(); @@ -46,6 +53,7 @@ export default function Evaluations() { posthog.identify(email); } + const columns: ColumnDef[] = [ { accessorKey: 'groupId', @@ -113,52 +121,75 @@ export default function Evaluations() { return (
-
-

- Evaluations -

-
-
- { - router.push(`/project/${projectId}/evaluations/${row.original.id}`); - }} - getRowId={(row: Evaluation) => row.id} - paginated - manualPagination - selectionPanel={(selectedRowIds) => ( -
- - - - - - - Delete Evaluations - - Are you sure you want to delete {selectedRowIds.length} evaluation(s)? - This action cannot be undone. - - - - - - - - +
+ +
+
+
+
- )} - /> +
+ +
+ { + router.push(`/project/${projectId}/evaluations/${row.original.id}`); + }} + getRowId={(row: Evaluation) => row.id} + paginated + manualPagination + selectionPanel={(selectedRowIds) => ( +
+ + + + + + + Delete Evaluations + + Are you sure you want to delete {selectedRowIds.length} evaluation(s)? + This action cannot be undone. + + + + + + + + +
+ )} + /> +
+
); diff --git a/frontend/components/evaluations/progression-chart.tsx b/frontend/components/evaluations/progression-chart.tsx new file mode 100644 index 00000000..a2d37d42 --- /dev/null +++ b/frontend/components/evaluations/progression-chart.tsx @@ -0,0 +1,143 @@ +import { useProjectContext } from "@/contexts/project-context"; +import { swrFetcher, cn, formatTimestamp } from "@/lib/utils"; +import { useSearchParams } from "next/navigation"; +import { CartesianGrid, XAxis, LineChart, Line, YAxis } from "recharts"; +import useSWR from "swr"; +import { ChartConfig, ChartContainer, ChartTooltip, ChartTooltipContent } from "../ui/chart"; +import { Skeleton } from "../ui/skeleton"; +import { EvaluationTimeProgression } from "@/lib/evaluation/types"; +import { useEffect, useState } from "react"; +import { Minus } from "lucide-react"; +import { Label } from "../ui/label"; +import { AggregationFunction } from "@/lib/clickhouse/utils"; + +interface ProgressionChartProps { + className?: string; + aggregationFunction: AggregationFunction; +} + +export default function ProgressionChart({ + className, + aggregationFunction, +}: ProgressionChartProps) { + const [showScores, setShowScores] = useState([]); + const [keys, setKeys] = useState>(new Set()); + const searchParams = new URLSearchParams(useSearchParams().toString()); + const groupId = searchParams.get('groupId'); + const { projectId } = useProjectContext(); + + const convertScores = (progression: EvaluationTimeProgression[]) => + progression.map(({ timestamp, evaluationId, names, values }) => ({ + timestamp, + evaluationId, + ...Object.fromEntries(names.map((name, index) => ([name, values[index]]))), + })); + + const { data, isLoading, error } = useSWR( + `/api/projects/${projectId}/evaluation-groups/${groupId}/progression?aggregate=${aggregationFunction}`, + swrFetcher + ); + useEffect(() => { + let newKeys: Set = new Set(); + data?.forEach(({ names }) => { + names.forEach((name) => { + newKeys.add(name); + }); + }); + setKeys(newKeys); + if (showScores.length === 0) { + setShowScores(Array.from(newKeys)); + } + }, [data]); + + const chartConfig = Object.fromEntries(Array.from(keys).map((key, index) => ([ + key, { + color: `hsl(var(--chart-${index % 5 + 1}))`, + label: key, + } + ]))) satisfies ChartConfig; + + const horizontalPadding = Math.max(10 - (data?.length ?? 0), 0) * 50; + + return ( +
+ + {isLoading || !data || error ? ( +
+ +
+ ) : + + formatTimestamp(`${value}Z`)} + height={8} + padding={{ left: horizontalPadding, right: horizontalPadding }} + /> + + } + /> + {Array.from(keys).filter((key) => showScores.includes(key)).map((key) => ( + + ))} + + } +
+
+ {Array.from(keys).map((key) => ( +
{ + let newShowScores = new Set(showScores); + if (newShowScores.has(key)) { + newShowScores.delete(key); + } else { + newShowScores.add(key); + } + setShowScores(Array.from(newShowScores)); + }} + > + + +
+ ))} +
+
+ ); +} diff --git a/frontend/components/ui/datatable.tsx b/frontend/components/ui/datatable.tsx index b42d9bc4..d12888df 100644 --- a/frontend/components/ui/datatable.tsx +++ b/frontend/components/ui/datatable.tsx @@ -328,12 +328,12 @@ export function DataTable({ {children}
)} - +
{content}
{paginated && ( -
+
span]:line-clamp-1', + 'flex h-7 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1', className )} {...props} @@ -82,7 +82,7 @@ const SelectContent = React.forwardRef< className={cn( 'relative z-50 max-h-96 min-w-[8rem] overflow-hidden rounded-md border bg-popover text-popover-foreground shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2', position === 'popper' && - 'data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1', + 'data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1', className )} position={position} @@ -93,7 +93,7 @@ const SelectContent = React.forwardRef< className={cn( 'p-1', position === 'popper' && - 'h-[var(--radix-select-trigger-height)] w-full min-w-[var(--radix-select-trigger-width)]' + 'h-[var(--radix-select-trigger-height)] w-full min-w-[var(--radix-select-trigger-width)]' )} > {children} diff --git a/frontend/lib/clickhouse/evaluation-scores.ts b/frontend/lib/clickhouse/evaluation-scores.ts new file mode 100644 index 00000000..81ccbb53 --- /dev/null +++ b/frontend/lib/clickhouse/evaluation-scores.ts @@ -0,0 +1,39 @@ +import { ClickHouseClient } from "@clickhouse/client"; +import { AggregationFunction, TimeRange, addTimeRangeToQuery, aggregationFunctionToCh } from "./utils"; +import { EvaluationTimeProgression } from "../evaluation/types"; +import { BucketRow } from "../types"; + +const DEFAULT_BUCKET_COUNT = 10; +const DEFAULT_LOWER_BOUND = 0; + +export const getEvaluationTimeProgression = async ( + clickhouseClient: ClickHouseClient, + projectId: string, + groupId: string, + timeRange: TimeRange, + aggregationFunction: AggregationFunction, +): Promise => { + const query = `WITH base AS ( + SELECT + evaluation_id, + timestamp, + name, + ${aggregationFunctionToCh(aggregationFunction)}(value) AS value + FROM evaluation_scores + WHERE project_id = {projectId: UUID} AND group_id = {groupId: String}`; + const queryWithTimeRange = addTimeRangeToQuery(query, timeRange, 'timestamp'); + const finalQuery = `${queryWithTimeRange} GROUP BY evaluation_id, name, timestamp ORDER BY timestamp, name + ) SELECT groupArray(name) names, groupArray(value) values, MIN(timestamp) timestamp, evaluation_id as evaluationId + FROM base + GROUP BY evaluation_id + ORDER BY timestamp`; + const result = await clickhouseClient.query({ + query: finalQuery, + format: 'JSONEachRow', + query_params: { + projectId, + groupId, + }, + }); + return await result.json(); +}; diff --git a/frontend/lib/clickhouse/utils.ts b/frontend/lib/clickhouse/utils.ts new file mode 100644 index 00000000..30ff5763 --- /dev/null +++ b/frontend/lib/clickhouse/utils.ts @@ -0,0 +1,54 @@ +import { ClickHouseClient } from "@clickhouse/client"; + +interface TimeBounds { + minTime: number; + maxTime: number; +} + +const NANOS_PER_MILLISECOND = 1e6; + +export const dateToNanoseconds = (date: Date) => date.getTime() * NANOS_PER_MILLISECOND; + +export const nanosecondsToDate = (nanoseconds: number) => new Date(nanoseconds / NANOS_PER_MILLISECOND); + +const validateSqlString = (str: string) => /^[a-zA-Z0-9_\.]+$/.test(str); + +type AbsoluteTimeRange = { + start: Date; + end: Date; +}; + +type RelativeTimeRange = { + pastHours: number; +}; + +export type AggregationFunction = 'AVG' | 'SUM' | 'MIN' | 'MAX' | 'MEDIAN' | 'p90' | 'p95' | 'p99'; + +export const aggregationFunctionToCh = (f: AggregationFunction) => { + switch (f) { + case 'AVG': return 'avg'; + case 'SUM': return 'sum'; + case 'MIN': return 'min'; + case 'MAX': return 'max'; + case 'MEDIAN': return 'median'; + case 'p90': return 'quantileExact(0.90)'; + case 'p95': return 'quantileExact(0.95)'; + case 'p99': return 'quantileExact(0.99)'; + default: throw new Error(`Invalid aggregation function: ${f}`); + } +}; + +export type TimeRange = AbsoluteTimeRange | RelativeTimeRange; + +export const addTimeRangeToQuery = (query: string, timeRange: TimeRange, column: string) => { + if (!validateSqlString(column)) { + throw new Error(`Invalid column name: ${column}`); + } + if ('start' in timeRange && 'end' in timeRange) { + return `${query} AND ${column} >= ${dateToNanoseconds(timeRange.start)} AND ${column} <= ${dateToNanoseconds(timeRange.end)}`; + } + if ('pastHours' in timeRange) { + return `${query} AND ${column} >= now() - INTERVAL ${timeRange.pastHours} HOUR`; + } + throw new Error('Invalid time range'); +}; diff --git a/frontend/lib/db/drizzle.ts b/frontend/lib/db/drizzle.ts index 91e465ee..90e4afc3 100644 --- a/frontend/lib/db/drizzle.ts +++ b/frontend/lib/db/drizzle.ts @@ -3,8 +3,29 @@ import { drizzle } from 'drizzle-orm/postgres-js'; import postgres from 'postgres'; import * as schema from './migrations/schema'; import * as relations from './migrations/relations'; +import { DrizzleConfig } from "drizzle-orm"; config({ path: ".env" }); // or .env.local -const client = postgres(process.env.DATABASE_URL!, { max: 10 }); -export const db = drizzle(client, { schema: { ...schema, ...relations } }); +// Singleton function to ensure only one db instance is created +function singleton(name: string, value: () => Value): Value { + const globalAny: any = global; + globalAny.__singletons = globalAny.__singletons || {}; + + if (!globalAny.__singletons[name]) { + globalAny.__singletons[name] = value(); + } + + return globalAny.__singletons[name]; +} + +// Function to create the database connection and apply migrations if needed +function createDatabaseConnection() { + const client = postgres(process.env.DATABASE_URL!, { max: 10 }); + + return drizzle(client, { schema: { ...schema, ...relations } }); +} + +const db = singleton('db', createDatabaseConnection); + +export { db }; diff --git a/frontend/lib/evaluation/types.ts b/frontend/lib/evaluation/types.ts index dd9696e7..37d8b4a4 100644 --- a/frontend/lib/evaluation/types.ts +++ b/frontend/lib/evaluation/types.ts @@ -44,3 +44,10 @@ export type EvaluationResultsInfo = { evaluation: Evaluation; results: EvaluationDatapointPreview[]; }; + +export type EvaluationTimeProgression = { + timestamp: string; + evaluationId: string; + names: string[]; + values: string[]; +}; diff --git a/frontend/lib/types.ts b/frontend/lib/types.ts index 7f98e51a..82e8a1b1 100644 --- a/frontend/lib/types.ts +++ b/frontend/lib/types.ts @@ -48,3 +48,9 @@ export type PaginatedResponse = { export type PaginatedGetResponseWithProjectPresenceFlag = PaginatedResponse & { anyInProject: boolean; }; + +export type BucketRow = { + lowerBound: number; + upperBound: number; + heights: number[]; +}; From 8f65322987077538fb5507643e1c9c82f19ef741 Mon Sep 17 00:00:00 2001 From: Devansh Date: Fri, 6 Dec 2024 08:45:53 +0530 Subject: [PATCH 2/3] Refactor datatable component to include filter clearing functionality --- frontend/components/ui/datatable.tsx | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/frontend/components/ui/datatable.tsx b/frontend/components/ui/datatable.tsx index 803c0779..25254ee1 100644 --- a/frontend/components/ui/datatable.tsx +++ b/frontend/components/ui/datatable.tsx @@ -27,7 +27,7 @@ import { DataTablePagination } from './datatable-pagination'; import { Label } from './label'; import { ScrollArea, ScrollBar } from './scroll-area'; import { Skeleton } from './skeleton'; - +import { usePathname, useRouter, useSearchParams } from 'next/navigation'; const DEFAULT_PAGE_SIZE = 50; interface DataTableProps { @@ -92,6 +92,20 @@ export function DataTable({ const [allRowsAcrossAllPagesSelected, setAllRowsAcrossAllPagesSelected] = useState(false); const [expandedRows, setExpandedRows] = useState({}); + const searchParams = new URLSearchParams(useSearchParams().toString()); + const pathName = usePathname(); + const router = useRouter(); + + const clearFilters = () => { + if (searchParams.size > 0) { + // clear all filters + if (searchParams.get('filter') !== null) { + searchParams.delete('filter'); + } + router.push(`${pathName}?${searchParams.toString()}`); + } + }; + useEffect(() => { onSelectedRowsChange?.(Object.keys(rowSelection)); }, [rowSelection]); @@ -284,7 +298,12 @@ export function DataTable({ colSpan={columns.length} className="text-center p-4 text-secondary-foreground" > - No results + {searchParams.get('filter') !== null ? 'Applied filters returned no results. ' : 'No results'} + {searchParams.get('filter') !== null && ( + + )} )) From 92c2a238f793e204e308bb9e302a32e373a3767c Mon Sep 17 00:00:00 2001 From: Devansh Date: Fri, 6 Dec 2024 09:27:39 +0530 Subject: [PATCH 3/3] fix requested changes --- frontend/components/ui/datatable.tsx | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/frontend/components/ui/datatable.tsx b/frontend/components/ui/datatable.tsx index 25254ee1..7f9784d2 100644 --- a/frontend/components/ui/datatable.tsx +++ b/frontend/components/ui/datatable.tsx @@ -97,11 +97,9 @@ export function DataTable({ const router = useRouter(); const clearFilters = () => { - if (searchParams.size > 0) { - // clear all filters - if (searchParams.get('filter') !== null) { - searchParams.delete('filter'); - } + // clear all filters + if (searchParams !== null && searchParams.get('filter') !== null) { + searchParams.delete('filter'); router.push(`${pathName}?${searchParams.toString()}`); } }; @@ -300,9 +298,9 @@ export function DataTable({ > {searchParams.get('filter') !== null ? 'Applied filters returned no results. ' : 'No results'} {searchParams.get('filter') !== null && ( - + )}