Skip to content

Commit

Permalink
[Minor] Cleanup tpch benchmark (#6609)
Browse files Browse the repository at this point in the history
* Cleanup tpch benchmark

* Cleanup tpch benchmark

---------

Co-authored-by: Daniël Heres <daniel.heres@coralogix.com>
  • Loading branch information
Dandandan and Daniël Heres authored Jun 9, 2023
1 parent 3d36dfe commit 43423be
Showing 1 changed file with 1 addition and 298 deletions.
299 changes: 1 addition & 298 deletions benchmarks/src/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{Fields, SchemaBuilder, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow::datatypes::SchemaBuilder;
use std::fs;
use std::ops::{Div, Mul};
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;

use datafusion::common::cast::{
as_date32_array, as_decimal128_array, as_float64_array, as_int32_array,
as_int64_array, as_string_array,
};
use datafusion::common::ScalarValue;
use datafusion::logical_expr::expr::ScalarFunction;
use datafusion::logical_expr::Cast;
use datafusion::prelude::*;
use datafusion::{
arrow::datatypes::{DataType, Field, Schema},
datasource::MemTable,
error::{DataFusionError, Result},
};
use parquet::basic::Compression;
Expand Down Expand Up @@ -147,155 +135,6 @@ pub fn get_tpch_table_schema(table: &str) -> Schema {
}
}

/// Get the expected schema for the results of a query
pub fn get_answer_schema(n: usize) -> Schema {
match n {
1 => Schema::new(vec![
Field::new("l_returnflag", DataType::Utf8, true),
Field::new("l_linestatus", DataType::Utf8, true),
Field::new("sum_qty", DataType::Decimal128(15, 2), true),
Field::new("sum_base_price", DataType::Decimal128(15, 2), true),
Field::new("sum_disc_price", DataType::Decimal128(15, 2), true),
Field::new("sum_charge", DataType::Decimal128(15, 2), true),
Field::new("avg_qty", DataType::Decimal128(15, 2), true),
Field::new("avg_price", DataType::Decimal128(15, 2), true),
Field::new("avg_disc", DataType::Decimal128(15, 2), true),
Field::new("count_order", DataType::Int64, true),
]),

2 => Schema::new(vec![
Field::new("s_acctbal", DataType::Decimal128(15, 2), true),
Field::new("s_name", DataType::Utf8, true),
Field::new("n_name", DataType::Utf8, true),
Field::new("p_partkey", DataType::Int64, true),
Field::new("p_mfgr", DataType::Utf8, true),
Field::new("s_address", DataType::Utf8, true),
Field::new("s_phone", DataType::Utf8, true),
Field::new("s_comment", DataType::Utf8, true),
]),

3 => Schema::new(vec![
Field::new("l_orderkey", DataType::Int64, true),
Field::new("revenue", DataType::Decimal128(15, 2), true),
Field::new("o_orderdate", DataType::Date32, true),
Field::new("o_shippriority", DataType::Int32, true),
]),

4 => Schema::new(vec![
Field::new("o_orderpriority", DataType::Utf8, true),
Field::new("order_count", DataType::Int64, true),
]),

5 => Schema::new(vec![
Field::new("n_name", DataType::Utf8, true),
Field::new("revenue", DataType::Decimal128(15, 2), true),
]),

6 => Schema::new(vec![Field::new(
"revenue",
DataType::Decimal128(15, 2),
true,
)]),

7 => Schema::new(vec![
Field::new("supp_nation", DataType::Utf8, true),
Field::new("cust_nation", DataType::Utf8, true),
Field::new("l_year", DataType::Float64, true),
Field::new("revenue", DataType::Decimal128(15, 2), true),
]),

8 => Schema::new(vec![
Field::new("o_year", DataType::Float64, true),
Field::new("mkt_share", DataType::Decimal128(15, 2), true),
]),

9 => Schema::new(vec![
Field::new("nation", DataType::Utf8, true),
Field::new("o_year", DataType::Float64, true),
Field::new("sum_profit", DataType::Decimal128(15, 2), true),
]),

10 => Schema::new(vec![
Field::new("c_custkey", DataType::Int64, true),
Field::new("c_name", DataType::Utf8, true),
Field::new("revenue", DataType::Decimal128(15, 2), true),
Field::new("c_acctbal", DataType::Decimal128(15, 2), true),
Field::new("n_name", DataType::Utf8, true),
Field::new("c_address", DataType::Utf8, true),
Field::new("c_phone", DataType::Utf8, true),
Field::new("c_comment", DataType::Utf8, true),
]),

11 => Schema::new(vec![
Field::new("ps_partkey", DataType::Int64, true),
Field::new("value", DataType::Decimal128(15, 2), true),
]),

12 => Schema::new(vec![
Field::new("l_shipmode", DataType::Utf8, true),
Field::new("high_line_count", DataType::Int64, true),
Field::new("low_line_count", DataType::Int64, true),
]),

13 => Schema::new(vec![
Field::new("c_count", DataType::Int64, true),
Field::new("custdist", DataType::Int64, true),
]),

14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]),

15 => Schema::new(vec![
Field::new("s_suppkey", DataType::Int64, true),
Field::new("s_name", DataType::Utf8, true),
Field::new("s_address", DataType::Utf8, true),
Field::new("s_phone", DataType::Utf8, true),
Field::new("total_revenue", DataType::Decimal128(15, 2), true),
]),

16 => Schema::new(vec![
Field::new("p_brand", DataType::Utf8, true),
Field::new("p_type", DataType::Utf8, true),
Field::new("p_size", DataType::Int32, true),
Field::new("supplier_cnt", DataType::Int64, true),
]),

17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]),

18 => Schema::new(vec![
Field::new("c_name", DataType::Utf8, true),
Field::new("c_custkey", DataType::Int64, true),
Field::new("o_orderkey", DataType::Int64, true),
Field::new("o_orderdate", DataType::Date32, true),
Field::new("o_totalprice", DataType::Decimal128(15, 2), true),
Field::new("sum_l_quantity", DataType::Decimal128(15, 2), true),
]),

19 => Schema::new(vec![Field::new(
"revenue",
DataType::Decimal128(15, 2),
true,
)]),

20 => Schema::new(vec![
Field::new("s_name", DataType::Utf8, true),
Field::new("s_address", DataType::Utf8, true),
]),

21 => Schema::new(vec![
Field::new("s_name", DataType::Utf8, true),
Field::new("numwait", DataType::Int64, true),
]),

22 => Schema::new(vec![
Field::new("cntrycode", DataType::Utf8, true),
Field::new("numcust", DataType::Int64, true),
Field::new("totacctbal", DataType::Decimal128(15, 2), true),
]),

_ => unimplemented!(),
}
}

/// Get the SQL statements from the specified query file
pub fn get_query_sql(query: usize) -> Result<Vec<String>> {
if query > 0 && query < 23 {
Expand Down Expand Up @@ -399,142 +238,6 @@ pub async fn convert_tbl(
Ok(())
}

/// Flattens the result batches into a 2d array of scalars, `result[row][column]`.
/// Null slots are special-cased to `ScalarValue::Null` for testing.
pub fn result_vec(results: &[RecordBatch]) -> Vec<Vec<ScalarValue>> {
    results
        .iter()
        .flat_map(|batch| {
            // one output row per batch row, each built column by column
            (0..batch.num_rows()).map(move |row_index| {
                batch
                    .columns()
                    .iter()
                    .map(|column| col_to_scalar(column, row_index))
                    .collect::<Vec<_>>()
            })
        })
        .collect()
}

/// Convert the expected schema to all-`Utf8` (preserving each field's name and
/// nullability) so the answer columns can be read as strings and parsed
/// separately; this is due to the fact that the csv parser cannot handle
/// leading/trailing spaces.
pub fn string_schema(schema: Schema) -> Schema {
    let string_fields: Vec<Field> = schema
        .fields()
        .iter()
        .map(|field| Field::new(field.name(), DataType::Utf8, field.is_nullable()))
        .collect();
    Schema::new(string_fields)
}

/// Convert the value at `row_index` of `column` into a `ScalarValue`.
///
/// Nulls become `ScalarValue::Null`; only the data types that occur in the
/// TPC-H answer sets are handled, and anything else panics.
fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue {
    if column.is_null(row_index) {
        return ScalarValue::Null;
    }
    match column.data_type() {
        DataType::Int32 => {
            ScalarValue::Int32(Some(as_int32_array(column).unwrap().value(row_index)))
        }
        DataType::Int64 => {
            ScalarValue::Int64(Some(as_int64_array(column).unwrap().value(row_index)))
        }
        DataType::Float64 => ScalarValue::Float64(Some(
            as_float64_array(column).unwrap().value(row_index),
        )),
        DataType::Decimal128(precision, scale) => ScalarValue::Decimal128(
            Some(as_decimal128_array(column).unwrap().value(row_index)),
            *precision,
            *scale,
        ),
        DataType::Date32 => {
            ScalarValue::Date32(Some(as_date32_array(column).unwrap().value(row_index)))
        }
        DataType::Utf8 => ScalarValue::Utf8(Some(
            as_string_array(column).unwrap().value(row_index).to_string(),
        )),
        other => panic!("unexpected data type in benchmark: {other}"),
    }
}

/// Massage the actual query results so they can be compared against the
/// recorded answers for TPC-H query `n` (1-based):
///
/// * field names are stripped of any `table.` qualifier prefix
/// * `Decimal128` columns are rounded to 2 decimal places like the answers
/// * `Utf8` columns are trimmed like the answers
/// * the query's row limit, if any, is re-applied (see `QUERY_LIMIT`)
///
/// # Errors
/// Returns any `DataFusionError` raised while rebuilding the batches or
/// running the projection.
pub async fn transform_actual_result(
    result: Vec<RecordBatch>,
    n: usize,
) -> Result<Vec<RecordBatch>> {
    // A query can legitimately produce zero batches; previously `result[0]`
    // below would panic in that case, so return early with nothing to do.
    if result.is_empty() {
        return Ok(result);
    }
    // to compare the recorded answers to the answers we got back from running the query,
    // we need to round the decimal columns and trim the Utf8 columns
    // we also need to rewrite the batches to use a compatible schema
    let ctx = SessionContext::new();
    // strip any qualifier prefix (`table.col` -> `col`) from the field names
    let fields: Fields = result[0]
        .schema()
        .fields()
        .iter()
        .map(|f| {
            let simple_name = match f.name().find('.') {
                Some(i) => f.name()[i + 1..].to_string(),
                _ => f.name().to_string(),
            };
            f.as_ref().clone().with_name(simple_name)
        })
        .collect();
    let result_schema = SchemaRef::new(Schema::new(fields));
    // rewrap the original columns under the simplified schema
    let result = result
        .iter()
        .map(|b| {
            RecordBatch::try_new(result_schema.clone(), b.columns().to_vec())
                .map_err(|e| e.into())
        })
        .collect::<Result<Vec<_>>>()?;
    // register the batches as a table so the normalization can run as a projection
    let table = Arc::new(MemTable::try_new(result_schema.clone(), vec![result])?);
    let mut df = ctx.read_table(table)?.select(
        result_schema
            .fields
            .iter()
            .map(|field| {
                match field.data_type() {
                    DataType::Decimal128(_, _) => {
                        // if decimal, then round it to 2 decimal places like the answers
                        // round() doesn't support the second argument for decimal places to round to
                        // this can be simplified to remove the mul and div when
                        // https://github.com/apache/arrow-datafusion/issues/2420 is completed
                        // cast it back to an over-sized Decimal with 2 precision when done rounding
                        let round = Box::new(
                            Expr::ScalarFunction(ScalarFunction::new(
                                datafusion::logical_expr::BuiltinScalarFunction::Round,
                                vec![col(Field::name(field)).mul(lit(100))],
                            ))
                            .div(lit(100)),
                        );
                        Expr::Cast(Cast::new(round, DataType::Decimal128(15, 2)))
                            .alias(field.name())
                    }
                    DataType::Utf8 => {
                        // if string, then trim it like the answers got trimmed
                        trim(col(Field::name(field))).alias(field.name())
                    }
                    _ => col(field.name()),
                }
            })
            .collect(),
    )?;
    // re-apply the query's LIMIT so row counts match the recorded answers
    // (indexing assumes 1 <= n <= 22, same as the rest of the benchmark)
    if let Some(x) = QUERY_LIMIT[n - 1] {
        df = df.limit(0, Some(x))?;
    }

    let df = df.collect().await?;
    Ok(df)
}

pub const QUERY_LIMIT: [Option<usize>; 22] = [
None,
Some(100),
Expand Down

0 comments on commit 43423be

Please sign in to comment.