Skip to content

Commit 9b0908b

Browse files
committed
Improve speed of first_value by implementing special GroupsAccumulator
1 parent b337fbc commit 9b0908b

File tree

5 files changed

+1048
-80
lines changed

5 files changed

+1048
-80
lines changed

datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs

Lines changed: 145 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::str;
1918
use std::sync::Arc;
2019

2120
use crate::fuzz_cases::aggregation_fuzzer::{
22-
AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, QueryBuilder,
21+
AggregationFuzzerBuilder, ColumnDescr,
22+
DatasetGeneratorConfig, QueryBuilder,
2323
};
2424

2525
use arrow::array::{types::Int64Type, Array, ArrayRef, AsArray, Int64Array, RecordBatch};
@@ -88,6 +88,141 @@ async fn test_min() {
8888
.await;
8989
}
9090

91+
#[tokio::test(flavor = "multi_thread")]
92+
async fn test_first_val() {
93+
let mut data_gen_config: DatasetGeneratorConfig = baseline_config();
94+
95+
// for ele in data_gen_config.columns.iter_mut() {
96+
// if ele.get_max_num_distinct().is_none() {
97+
// ele.with_max_num_distinct(usize::MAX);
98+
// }
99+
// }
100+
101+
for i in 0..data_gen_config.columns.len() {
102+
// data_gen_config.columns[i]
103+
if data_gen_config.columns[i].get_max_num_distinct().is_none() {
104+
data_gen_config.columns[i] = data_gen_config.columns[i]
105+
.clone()
106+
.with_max_num_distinct(usize::MAX);
107+
}
108+
}
109+
110+
// data_gen_config.columns.iter_mut().for_each(|ele| {
111+
// if ele.get_max_num_distinct().is_none() {
112+
// ele.with_max_num_distinct(usize::MAX);
113+
// }
114+
// });
115+
116+
// Queries like SELECT max(a) FROM fuzz_table GROUP BY b
117+
let query_builder = QueryBuilder::new()
118+
.with_table_name("fuzz_table")
119+
.with_aggregate_function("first_value")
120+
// max works on all column types
121+
.with_aggregate_arguments(data_gen_config.all_columns())
122+
.set_group_by_columns(
123+
data_gen_config.all_columns(), // .into_iter()
124+
// .filter(|x| !x.contains("u8")),
125+
);
126+
127+
AggregationFuzzerBuilder::from(data_gen_config)
128+
.add_query_builder(query_builder)
129+
.build()
130+
.run()
131+
.await;
132+
}
133+
134+
// #[tokio::test(flavor = "multi_thread")]
135+
// async fn test_bad_case() {
136+
// let sql="SELECT u8_low, first_value(i16 order by u8_low DESC, float64 DESC, large_binary DESC,utf8_low DESC,interval_year_month ASC) RESPECT NULLS as col2 FROM parquet_table GROUP BY u8_low
137+
138+
// ";
139+
140+
// let (baseline_ctx, random_ctx) =
141+
// gen_ctx(baseline_config(), "/tmp/SYuFrB4.parquet").await;
142+
143+
// // {let explain = "explain ".to_owned() + sql;
144+
// // let baseline_plan = run_sql(&explain, &baseline_ctx).await.unwrap();
145+
// // let plan = run_sql(&explain, &random_ctx).await.unwrap();
146+
// // println!("{}", pretty_format_batches(&baseline_plan).unwrap());
147+
// // println!("{}", pretty_format_batches(&plan).unwrap());
148+
// // }
149+
// let baseline_res = run_sql(sql, &baseline_ctx).await.unwrap();
150+
151+
// println!("<==========================================================>");
152+
// println!("<==========================================================>");
153+
// println!("<==========================================================>");
154+
// println!("<==========================================================>");
155+
156+
// let got = run_sql(sql, &random_ctx).await.unwrap();
157+
158+
// // println!("{}", pretty_format_batches(&baseline_res).unwrap());
159+
160+
// println!("\n*********************\n");
161+
162+
// let baseline_count = baseline_res.iter().map(|x| x.num_rows()).sum::<usize>();
163+
164+
// let got_count = got.iter().map(|x| x.num_rows()).sum::<usize>();
165+
166+
// if baseline_count != got_count {
167+
// panic!(
168+
// "baseline_res.len()!=got.len() {} {}",
169+
// baseline_count, got_count
170+
// );
171+
// }
172+
173+
// check_equality_of_batches(&baseline_res, &got).unwrap();
174+
// }
175+
176+
// #[tokio::test]
177+
// async fn test_dev_first() -> Result<()> {
178+
// let ctx = SessionContext::new();
179+
180+
// ctx.sql(
181+
// "CREATE TABLE array_agg_order_list_table AS VALUES
182+
// ('w', 2, [1,2,3], 10),
183+
// ('w', 1, [9,5,2], 20),
184+
// ('w', 1, [3,2,5], 30),
185+
// ('b', 2, [4,5,6], 20),
186+
// ('b', 1, [7,8,9], 30)
187+
// ;
188+
// ",
189+
// )
190+
// .await
191+
// .unwrap();
192+
193+
// ctx.sql("select column1, first_value(column3 order by column2, column4 desc) from array_agg_order_list_table group by column1;").and_then(|x|async move{x.show().await}).await.map_err(|x|{
194+
// println!("");
195+
// eprintln!("{x}");
196+
// x
197+
// })?;
198+
199+
// // RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --test fuzz -- fuzz_cases::aggregate_fuzz::test_first_val --exact --nocapture
200+
201+
// Ok(())
202+
// }
203+
204+
#[tokio::test]
205+
async fn test_get_backtrace_for_failed_code() -> Result<()> {
206+
let ctx = SessionContext::new();
207+
208+
let sql = "select row_numer() over (partition by a order by a) from (select 1 a);";
209+
210+
match ctx.sql(sql).await {
211+
Ok(result) => result.show().await?,
212+
Err(e) => {
213+
eprintln!("{e}");
214+
}
215+
};
216+
217+
// $ RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --lib -- fuzz_cases::aggregate_fuzz::test_get_backtrace_for_failed_code --exact --nocapture
218+
219+
// RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --test fuzz -- fuzz_cases::aggregate_fuzz::test_dev_first --exact --nocapture
220+
221+
// RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --test fuzz -- fuzz_cases::aggregate_fuzz::test_get_backtrace_for_failed_code --exact --nocapture
222+
223+
Ok(())
224+
}
225+
91226
#[tokio::test(flavor = "multi_thread")]
92227
async fn test_max() {
93228
let data_gen_config = baseline_config();
@@ -216,27 +351,25 @@ fn baseline_config() -> DatasetGeneratorConfig {
216351
"interval_day_time",
217352
DataType::Interval(IntervalUnit::DayTime),
218353
),
219-
ColumnDescr::new(
220-
"interval_month_day_nano",
221-
DataType::Interval(IntervalUnit::MonthDayNano),
222-
),
354+
// ColumnDescr::new(
355+
// "interval_month_day_nano",
356+
// DataType::Interval(IntervalUnit::MonthDayNano),
357+
// ),
223358
// begin decimal columns
224359
ColumnDescr::new("decimal128", {
225360
// Generate valid precision and scale for Decimal128 randomly.
226361
let precision: u8 = rng.gen_range(1..=DECIMAL128_MAX_PRECISION);
227362
// It's safe to cast `precision` to i8 type directly.
228-
let scale: i8 = rng.gen_range(
229-
i8::MIN..=std::cmp::min(precision as i8, DECIMAL128_MAX_SCALE),
230-
);
363+
let scale: i8 =
364+
rng.gen_range(0..=std::cmp::min(precision as i8, DECIMAL128_MAX_SCALE));
231365
DataType::Decimal128(precision, scale)
232366
}),
233367
ColumnDescr::new("decimal256", {
234368
// Generate valid precision and scale for Decimal256 randomly.
235369
let precision: u8 = rng.gen_range(1..=DECIMAL256_MAX_PRECISION);
236370
// It's safe to cast `precision` to i8 type directly.
237-
let scale: i8 = rng.gen_range(
238-
i8::MIN..=std::cmp::min(precision as i8, DECIMAL256_MAX_SCALE),
239-
);
371+
let scale: i8 =
372+
rng.gen_range(0..=std::cmp::min(precision as i8, DECIMAL256_MAX_SCALE));
240373
DataType::Decimal256(precision, scale)
241374
}),
242375
// begin string columns

datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,10 @@ impl ColumnDescr {
228228
}
229229
}
230230

231+
/// Returns the value of this column's `max_num_distinct` setting.
///
/// The result is `None` when no maximum has been stored on this descriptor.
pub fn get_max_num_distinct(&self) -> Option<usize> {
232+
self.max_num_distinct
233+
}
234+
231235
/// set the maximum number of distinct values in this column
232236
///
233237
/// If `None`, the number of distinct values is randomly selected between 1

0 commit comments

Comments
 (0)