From 9e69fffff1965708e37f7266f83df721795ebc15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 9 Jun 2023 15:33:53 +0200 Subject: [PATCH] [Minor] Cleanup tpch benchmark (#6609) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Cleanup tpch benchmark * Cleanup tpch benchmark --------- Co-authored-by: Daniƫl Heres --- benchmarks/src/tpch.rs | 299 +---------------------------------------- 1 file changed, 1 insertion(+), 298 deletions(-) diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch.rs index 72f5907b075f8..58b9c3637c4e9 100644 --- a/benchmarks/src/tpch.rs +++ b/benchmarks/src/tpch.rs @@ -15,26 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{Fields, SchemaBuilder, SchemaRef}; -use arrow::record_batch::RecordBatch; +use arrow::datatypes::SchemaBuilder; use std::fs; -use std::ops::{Div, Mul}; use std::path::Path; -use std::sync::Arc; use std::time::Instant; -use datafusion::common::cast::{ - as_date32_array, as_decimal128_array, as_float64_array, as_int32_array, - as_int64_array, as_string_array, -}; -use datafusion::common::ScalarValue; -use datafusion::logical_expr::expr::ScalarFunction; -use datafusion::logical_expr::Cast; use datafusion::prelude::*; use datafusion::{ arrow::datatypes::{DataType, Field, Schema}, - datasource::MemTable, error::{DataFusionError, Result}, }; use parquet::basic::Compression; @@ -147,155 +135,6 @@ pub fn get_tpch_table_schema(table: &str) -> Schema { } } -/// Get the expected schema for the results of a query -pub fn get_answer_schema(n: usize) -> Schema { - match n { - 1 => Schema::new(vec![ - Field::new("l_returnflag", DataType::Utf8, true), - Field::new("l_linestatus", DataType::Utf8, true), - Field::new("sum_qty", DataType::Decimal128(15, 2), true), - Field::new("sum_base_price", DataType::Decimal128(15, 2), true), - Field::new("sum_disc_price", DataType::Decimal128(15, 2), true), - Field::new("sum_charge", DataType::Decimal128(15, 2), true), - Field::new("avg_qty", DataType::Decimal128(15, 2), true), - Field::new("avg_price", DataType::Decimal128(15, 2), true), - Field::new("avg_disc", DataType::Decimal128(15, 2), true), - Field::new("count_order", DataType::Int64, true), - ]), - - 2 => Schema::new(vec![ - Field::new("s_acctbal", DataType::Decimal128(15, 2), true), - Field::new("s_name", DataType::Utf8, true), - Field::new("n_name", DataType::Utf8, true), - Field::new("p_partkey", DataType::Int64, true), - Field::new("p_mfgr", DataType::Utf8, true), - Field::new("s_address", DataType::Utf8, true), - Field::new("s_phone", DataType::Utf8, true), - Field::new("s_comment", DataType::Utf8, true), - ]), - - 3 => Schema::new(vec![ - Field::new("l_orderkey", DataType::Int64, true), - Field::new("revenue", DataType::Decimal128(15, 2), true), - Field::new("o_orderdate", DataType::Date32, true), - Field::new("o_shippriority", DataType::Int32, true), - ]), - - 4 => Schema::new(vec![ - Field::new("o_orderpriority", DataType::Utf8, true), - Field::new("order_count", DataType::Int64, true), - ]), - - 5 => Schema::new(vec![ - Field::new("n_name", DataType::Utf8, true), - Field::new("revenue", DataType::Decimal128(15, 2), true), - ]), - - 6 => Schema::new(vec![Field::new( - "revenue", - DataType::Decimal128(15, 2), - true, - )]), - - 7 => Schema::new(vec![ - Field::new("supp_nation", DataType::Utf8, true), - Field::new("cust_nation", DataType::Utf8, true), - Field::new("l_year", DataType::Float64, true), - Field::new("revenue", DataType::Decimal128(15, 2), true), - ]), - - 8 => Schema::new(vec![ - Field::new("o_year", DataType::Float64, true), - Field::new("mkt_share", DataType::Decimal128(15, 2), true), - ]), - - 9 => Schema::new(vec![ - Field::new("nation", DataType::Utf8, true), - Field::new("o_year", DataType::Float64, true), - Field::new("sum_profit", DataType::Decimal128(15, 2), true), - ]), - - 10 => Schema::new(vec![ - Field::new("c_custkey", DataType::Int64, true), - Field::new("c_name", DataType::Utf8, true), - Field::new("revenue", DataType::Decimal128(15, 2), true), - Field::new("c_acctbal", DataType::Decimal128(15, 2), true), - Field::new("n_name", DataType::Utf8, true), - Field::new("c_address", DataType::Utf8, true), - Field::new("c_phone", DataType::Utf8, true), - Field::new("c_comment", DataType::Utf8, true), - ]), - - 11 => Schema::new(vec![ - Field::new("ps_partkey", DataType::Int64, true), - Field::new("value", DataType::Decimal128(15, 2), true), - ]), - - 12 => Schema::new(vec![ - Field::new("l_shipmode", DataType::Utf8, true), - Field::new("high_line_count", DataType::Int64, true), - Field::new("low_line_count", DataType::Int64, true), - ]), - - 13 => Schema::new(vec![ - Field::new("c_count", DataType::Int64, true), - Field::new("custdist", DataType::Int64, true), - ]), - - 14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), - - 15 => Schema::new(vec![ - Field::new("s_suppkey", DataType::Int64, true), - Field::new("s_name", DataType::Utf8, true), - Field::new("s_address", DataType::Utf8, true), - Field::new("s_phone", DataType::Utf8, true), - Field::new("total_revenue", DataType::Decimal128(15, 2), true), - ]), - - 16 => Schema::new(vec![ - Field::new("p_brand", DataType::Utf8, true), - Field::new("p_type", DataType::Utf8, true), - Field::new("p_size", DataType::Int32, true), - Field::new("supplier_cnt", DataType::Int64, true), - ]), - - 17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]), - - 18 => Schema::new(vec![ - Field::new("c_name", DataType::Utf8, true), - Field::new("c_custkey", DataType::Int64, true), - Field::new("o_orderkey", DataType::Int64, true), - Field::new("o_orderdate", DataType::Date32, true), - Field::new("o_totalprice", DataType::Decimal128(15, 2), true), - Field::new("sum_l_quantity", DataType::Decimal128(15, 2), true), - ]), - - 19 => Schema::new(vec![Field::new( - "revenue", - DataType::Decimal128(15, 2), - true, - )]), - - 20 => Schema::new(vec![ - Field::new("s_name", DataType::Utf8, true), - Field::new("s_address", DataType::Utf8, true), - ]), - - 21 => Schema::new(vec![ - Field::new("s_name", DataType::Utf8, true), - Field::new("numwait", DataType::Int64, true), - ]), - - 22 => Schema::new(vec![ - Field::new("cntrycode", DataType::Utf8, true), - Field::new("numcust", DataType::Int64, true), - Field::new("totacctbal", DataType::Decimal128(15, 2), true), - ]), - - _ => unimplemented!(), - } -} - /// Get the SQL statements from the specified query file pub fn get_query_sql(query: usize) -> Result> { if query > 0 && query < 23 { @@ -399,142 +238,6 @@ pub async fn convert_tbl( Ok(()) } -/// Converts the results into a 2d array of strings, `result[row][column]` -/// Special cases nulls to NULL for testing -pub fn result_vec(results: &[RecordBatch]) -> Vec> { - let mut result = vec![]; - for batch in results { - for row_index in 0..batch.num_rows() { - let row_vec = batch - .columns() - .iter() - .map(|column| col_to_scalar(column, row_index)) - .collect(); - result.push(row_vec); - } - } - result -} - -/// convert expected schema to all utf8 so columns can be read as strings to be parsed separately -/// this is due to the fact that the csv parser cannot handle leading/trailing spaces -pub fn string_schema(schema: Schema) -> Schema { - Schema::new( - schema - .fields() - .iter() - .map(|field| { - Field::new( - Field::name(field), - DataType::Utf8, - Field::is_nullable(field), - ) - }) - .collect::>(), - ) -} - -fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue { - if column.is_null(row_index) { - return ScalarValue::Null; - } - match column.data_type() { - DataType::Int32 => { - let array = as_int32_array(column).unwrap(); - ScalarValue::Int32(Some(array.value(row_index))) - } - DataType::Int64 => { - let array = as_int64_array(column).unwrap(); - ScalarValue::Int64(Some(array.value(row_index))) - } - DataType::Float64 => { - let array = as_float64_array(column).unwrap(); - ScalarValue::Float64(Some(array.value(row_index))) - } - DataType::Decimal128(p, s) => { - let array = as_decimal128_array(column).unwrap(); - ScalarValue::Decimal128(Some(array.value(row_index)), *p, *s) - } - DataType::Date32 => { - let array = as_date32_array(column).unwrap(); - ScalarValue::Date32(Some(array.value(row_index))) - } - DataType::Utf8 => { - let array = as_string_array(column).unwrap(); - ScalarValue::Utf8(Some(array.value(row_index).to_string())) - } - other => panic!("unexpected data type in benchmark: {other}"), - } -} - -pub async fn transform_actual_result( - result: Vec, - n: usize, -) -> Result> { - // to compare the recorded answers to the answers we got back from running the query, - // we need to round the decimal columns and trim the Utf8 columns - // we also need to rewrite the batches to use a compatible schema - let ctx = SessionContext::new(); - let fields: Fields = result[0] - .schema() - .fields() - .iter() - .map(|f| { - let simple_name = match f.name().find('.') { - Some(i) => f.name()[i + 1..].to_string(), - _ => f.name().to_string(), - }; - f.as_ref().clone().with_name(simple_name) - }) - .collect(); - let result_schema = SchemaRef::new(Schema::new(fields)); - let result = result - .iter() - .map(|b| { - RecordBatch::try_new(result_schema.clone(), b.columns().to_vec()) - .map_err(|e| e.into()) - }) - .collect::>>()?; - let table = Arc::new(MemTable::try_new(result_schema.clone(), vec![result])?); - let mut df = ctx.read_table(table)?.select( - result_schema - .fields - .iter() - .map(|field| { - match field.data_type() { - DataType::Decimal128(_, _) => { - // if decimal, then round it to 2 decimal places like the answers - // round() doesn't support the second argument for decimal places to round to - // this can be simplified to remove the mul and div when - // https://github.com/apache/arrow-datafusion/issues/2420 is completed - // cast it back to an over-sized Decimal with 2 precision when done rounding - let round = Box::new( - Expr::ScalarFunction(ScalarFunction::new( - datafusion::logical_expr::BuiltinScalarFunction::Round, - vec![col(Field::name(field)).mul(lit(100))], - )) - .div(lit(100)), - ); - Expr::Cast(Cast::new(round, DataType::Decimal128(15, 2))) - .alias(field.name()) - } - DataType::Utf8 => { - // if string, then trim it like the answers got trimmed - trim(col(Field::name(field))).alias(field.name()) - } - _ => col(field.name()), - } - }) - .collect(), - )?; - if let Some(x) = QUERY_LIMIT[n - 1] { - df = df.limit(0, Some(x))?; - } - - let df = df.collect().await?; - Ok(df) -} - pub const QUERY_LIMIT: [Option; 22] = [ None, Some(100),