Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor] Cleanup tpch benchmark #6609

Merged
merged 2 commits into from
Jun 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 1 addition & 298 deletions benchmarks/src/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{Fields, SchemaBuilder, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow::datatypes::SchemaBuilder;
use std::fs;
use std::ops::{Div, Mul};
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;

use datafusion::common::cast::{
as_date32_array, as_decimal128_array, as_float64_array, as_int32_array,
as_int64_array, as_string_array,
};
use datafusion::common::ScalarValue;
use datafusion::logical_expr::expr::ScalarFunction;
use datafusion::logical_expr::Cast;
use datafusion::prelude::*;
use datafusion::{
arrow::datatypes::{DataType, Field, Schema},
datasource::MemTable,
error::{DataFusionError, Result},
};
use parquet::basic::Compression;
Expand Down Expand Up @@ -147,155 +135,6 @@ pub fn get_tpch_table_schema(table: &str) -> Schema {
}
}

/// Get the expected schema for the results of a query
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code appears to be unused. However, clippy didn't flag it as dead code because it is declared `pub`.

https://github.com/search?q=repo%3Aapache%2Farrow-datafusion%20get_answer_schema&type=code

pub fn get_answer_schema(n: usize) -> Schema {
match n {
1 => Schema::new(vec![
Field::new("l_returnflag", DataType::Utf8, true),
Field::new("l_linestatus", DataType::Utf8, true),
Field::new("sum_qty", DataType::Decimal128(15, 2), true),
Field::new("sum_base_price", DataType::Decimal128(15, 2), true),
Field::new("sum_disc_price", DataType::Decimal128(15, 2), true),
Field::new("sum_charge", DataType::Decimal128(15, 2), true),
Field::new("avg_qty", DataType::Decimal128(15, 2), true),
Field::new("avg_price", DataType::Decimal128(15, 2), true),
Field::new("avg_disc", DataType::Decimal128(15, 2), true),
Field::new("count_order", DataType::Int64, true),
]),

2 => Schema::new(vec![
Field::new("s_acctbal", DataType::Decimal128(15, 2), true),
Field::new("s_name", DataType::Utf8, true),
Field::new("n_name", DataType::Utf8, true),
Field::new("p_partkey", DataType::Int64, true),
Field::new("p_mfgr", DataType::Utf8, true),
Field::new("s_address", DataType::Utf8, true),
Field::new("s_phone", DataType::Utf8, true),
Field::new("s_comment", DataType::Utf8, true),
]),

3 => Schema::new(vec![
Field::new("l_orderkey", DataType::Int64, true),
Field::new("revenue", DataType::Decimal128(15, 2), true),
Field::new("o_orderdate", DataType::Date32, true),
Field::new("o_shippriority", DataType::Int32, true),
]),

4 => Schema::new(vec![
Field::new("o_orderpriority", DataType::Utf8, true),
Field::new("order_count", DataType::Int64, true),
]),

5 => Schema::new(vec![
Field::new("n_name", DataType::Utf8, true),
Field::new("revenue", DataType::Decimal128(15, 2), true),
]),

6 => Schema::new(vec![Field::new(
"revenue",
DataType::Decimal128(15, 2),
true,
)]),

7 => Schema::new(vec![
Field::new("supp_nation", DataType::Utf8, true),
Field::new("cust_nation", DataType::Utf8, true),
Field::new("l_year", DataType::Float64, true),
Field::new("revenue", DataType::Decimal128(15, 2), true),
]),

8 => Schema::new(vec![
Field::new("o_year", DataType::Float64, true),
Field::new("mkt_share", DataType::Decimal128(15, 2), true),
]),

9 => Schema::new(vec![
Field::new("nation", DataType::Utf8, true),
Field::new("o_year", DataType::Float64, true),
Field::new("sum_profit", DataType::Decimal128(15, 2), true),
]),

10 => Schema::new(vec![
Field::new("c_custkey", DataType::Int64, true),
Field::new("c_name", DataType::Utf8, true),
Field::new("revenue", DataType::Decimal128(15, 2), true),
Field::new("c_acctbal", DataType::Decimal128(15, 2), true),
Field::new("n_name", DataType::Utf8, true),
Field::new("c_address", DataType::Utf8, true),
Field::new("c_phone", DataType::Utf8, true),
Field::new("c_comment", DataType::Utf8, true),
]),

11 => Schema::new(vec![
Field::new("ps_partkey", DataType::Int64, true),
Field::new("value", DataType::Decimal128(15, 2), true),
]),

12 => Schema::new(vec![
Field::new("l_shipmode", DataType::Utf8, true),
Field::new("high_line_count", DataType::Int64, true),
Field::new("low_line_count", DataType::Int64, true),
]),

13 => Schema::new(vec![
Field::new("c_count", DataType::Int64, true),
Field::new("custdist", DataType::Int64, true),
]),

14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]),

15 => Schema::new(vec![
Field::new("s_suppkey", DataType::Int64, true),
Field::new("s_name", DataType::Utf8, true),
Field::new("s_address", DataType::Utf8, true),
Field::new("s_phone", DataType::Utf8, true),
Field::new("total_revenue", DataType::Decimal128(15, 2), true),
]),

16 => Schema::new(vec![
Field::new("p_brand", DataType::Utf8, true),
Field::new("p_type", DataType::Utf8, true),
Field::new("p_size", DataType::Int32, true),
Field::new("supplier_cnt", DataType::Int64, true),
]),

17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]),

18 => Schema::new(vec![
Field::new("c_name", DataType::Utf8, true),
Field::new("c_custkey", DataType::Int64, true),
Field::new("o_orderkey", DataType::Int64, true),
Field::new("o_orderdate", DataType::Date32, true),
Field::new("o_totalprice", DataType::Decimal128(15, 2), true),
Field::new("sum_l_quantity", DataType::Decimal128(15, 2), true),
]),

19 => Schema::new(vec![Field::new(
"revenue",
DataType::Decimal128(15, 2),
true,
)]),

20 => Schema::new(vec![
Field::new("s_name", DataType::Utf8, true),
Field::new("s_address", DataType::Utf8, true),
]),

21 => Schema::new(vec![
Field::new("s_name", DataType::Utf8, true),
Field::new("numwait", DataType::Int64, true),
]),

22 => Schema::new(vec![
Field::new("cntrycode", DataType::Utf8, true),
Field::new("numcust", DataType::Int64, true),
Field::new("totacctbal", DataType::Decimal128(15, 2), true),
]),

_ => unimplemented!(),
}
}

/// Get the SQL statements from the specified query file
pub fn get_query_sql(query: usize) -> Result<Vec<String>> {
if query > 0 && query < 23 {
Expand Down Expand Up @@ -399,142 +238,6 @@ pub async fn convert_tbl(
Ok(())
}

/// Flattens a slice of record batches into a row-major table of scalar
/// values, addressed as `result[row][column]`.
///
/// Rows appear in batch order. Each cell is produced by `col_to_scalar`, so
/// a null cell is represented as `ScalarValue::Null`.
pub fn result_vec(results: &[RecordBatch]) -> Vec<Vec<ScalarValue>> {
    results
        .iter()
        .flat_map(|batch| {
            // One output row per row of the batch, reading across its columns.
            (0..batch.num_rows()).map(move |row| {
                batch
                    .columns()
                    .iter()
                    .map(|array| col_to_scalar(array, row))
                    .collect::<Vec<ScalarValue>>()
            })
        })
        .collect()
}

/// Builds a copy of `schema` in which every field has type `Utf8`, keeping
/// each field's name and nullability.
///
/// This lets the expected-answer columns be read as plain strings and parsed
/// separately, because the csv parser cannot handle leading/trailing spaces.
pub fn string_schema(schema: Schema) -> Schema {
    let source_fields = schema.fields();
    let mut utf8_fields: Vec<Field> = Vec::with_capacity(source_fields.len());
    for field in source_fields.iter() {
        utf8_fields.push(Field::new(
            field.name(),
            DataType::Utf8,
            field.is_nullable(),
        ));
    }
    Schema::new(utf8_fields)
}

/// Reads the value at `row_index` of `column` as a `ScalarValue`.
///
/// A null slot yields `ScalarValue::Null` regardless of the column type.
/// Otherwise the column is downcast according to its data type and the value
/// is wrapped in the matching `ScalarValue` variant.
///
/// # Panics
///
/// Panics if the column's data type is not one of `Int32`, `Int64`,
/// `Float64`, `Decimal128`, `Date32` or `Utf8`.
fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue {
    match column.data_type() {
        // Nulls are handled first so that no downcast is attempted for them.
        _ if column.is_null(row_index) => ScalarValue::Null,
        DataType::Int32 => {
            ScalarValue::Int32(Some(as_int32_array(column).unwrap().value(row_index)))
        }
        DataType::Int64 => {
            ScalarValue::Int64(Some(as_int64_array(column).unwrap().value(row_index)))
        }
        DataType::Float64 => {
            ScalarValue::Float64(Some(as_float64_array(column).unwrap().value(row_index)))
        }
        DataType::Decimal128(precision, scale) => {
            let value = as_decimal128_array(column).unwrap().value(row_index);
            ScalarValue::Decimal128(Some(value), *precision, *scale)
        }
        DataType::Date32 => {
            ScalarValue::Date32(Some(as_date32_array(column).unwrap().value(row_index)))
        }
        DataType::Utf8 => {
            let text = as_string_array(column).unwrap().value(row_index);
            ScalarValue::Utf8(Some(text.to_string()))
        }
        other => panic!("unexpected data type in benchmark: {other}"),
    }
}

pub async fn transform_actual_result(
result: Vec<RecordBatch>,
n: usize,
) -> Result<Vec<RecordBatch>> {
// to compare the recorded answers to the answers we got back from running the query,
// we need to round the decimal columns and trim the Utf8 columns
// we also need to rewrite the batches to use a compatible schema
let ctx = SessionContext::new();
let fields: Fields = result[0]
.schema()
.fields()
.iter()
.map(|f| {
let simple_name = match f.name().find('.') {
Some(i) => f.name()[i + 1..].to_string(),
_ => f.name().to_string(),
};
f.as_ref().clone().with_name(simple_name)
})
.collect();
let result_schema = SchemaRef::new(Schema::new(fields));
let result = result
.iter()
.map(|b| {
RecordBatch::try_new(result_schema.clone(), b.columns().to_vec())
.map_err(|e| e.into())
})
.collect::<Result<Vec<_>>>()?;
let table = Arc::new(MemTable::try_new(result_schema.clone(), vec![result])?);
let mut df = ctx.read_table(table)?.select(
result_schema
.fields
.iter()
.map(|field| {
match field.data_type() {
DataType::Decimal128(_, _) => {
// if decimal, then round it to 2 decimal places like the answers
// round() doesn't support the second argument for decimal places to round to
// this can be simplified to remove the mul and div when
// https://github.com/apache/arrow-datafusion/issues/2420 is completed
// cast it back to an over-sized Decimal with 2 precision when done rounding
let round = Box::new(
Expr::ScalarFunction(ScalarFunction::new(
datafusion::logical_expr::BuiltinScalarFunction::Round,
vec![col(Field::name(field)).mul(lit(100))],
))
.div(lit(100)),
);
Expr::Cast(Cast::new(round, DataType::Decimal128(15, 2)))
.alias(field.name())
}
DataType::Utf8 => {
// if string, then trim it like the answers got trimmed
trim(col(Field::name(field))).alias(field.name())
}
_ => col(field.name()),
}
})
.collect(),
)?;
if let Some(x) = QUERY_LIMIT[n - 1] {
df = df.limit(0, Some(x))?;
}

let df = df.collect().await?;
Ok(df)
}

pub const QUERY_LIMIT: [Option<usize>; 22] = [
None,
Some(100),
Expand Down