Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1975,6 +1975,15 @@ pub async fn from_substrait_agg_func(

let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;

// Datafusion does not support aggregate functions with no arguments, so
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

// we inject a dummy argument that does not affect the query, but allows
// us to bypass this limitation.
let args = if udaf.name() == "count" && args.is_empty() {
vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
} else {
args
};

Ok(Arc::new(Expr::AggregateFunction(
expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None),
)))
Expand Down Expand Up @@ -2248,11 +2257,19 @@ pub async fn from_window_function(

window_frame.regularize_order_bys(&mut order_by)?;

// Datafusion does not support aggregate functions with no arguments, so
// we inject a dummy argument that does not affect the query, but allows
// us to bypass this limitation.
let args = if fun.name() == "count" && window.arguments.is_empty() {
vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
} else {
from_substrait_func_args(consumer, &window.arguments, input_schema).await?
};

Ok(Expr::WindowFunction(expr::WindowFunction {
fun,
params: WindowFunctionParams {
args: from_substrait_func_args(consumer, &window.arguments, input_schema)
.await?,
args,
partition_by: from_substrait_rex_vec(
consumer,
&window.partitions,
Expand Down Expand Up @@ -3406,4 +3423,31 @@ mod test {

Ok(())
}

#[tokio::test]
async fn window_function_with_count() -> Result<()> {
let substrait = substrait::proto::Expression {
rex_type: Some(substrait::proto::expression::RexType::WindowFunction(
substrait::proto::expression::WindowFunction {
function_reference: 0,
..Default::default()
},
)),
};

let mut consumer = test_consumer();

let mut extensions = Extensions::default();
extensions.register_function("count".to_string());
consumer.extensions = &extensions;

match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? {
Expr::WindowFunction(window_function) => {
assert_eq!(window_function.params.args.len(), 1)
}
_ => panic!("expr was not a WindowFunction"),
};

Ok(())
}
}
66 changes: 53 additions & 13 deletions datafusion/substrait/tests/cases/consumer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ mod tests {

let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?;
let plan = from_substrait_plan(&ctx.state(), &proto).await?;
ctx.state().create_physical_plan(&plan).await?;
Ok(format!("{}", plan))
}

Expand All @@ -50,9 +51,9 @@ mod tests {
let plan_str = tpch_plan_to_string(1).await?;
assert_eq!(
plan_str,
"Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, sum(LINEITEM.L_QUANTITY) AS SUM_QTY, sum(LINEITEM.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS SUM_DISC_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX) AS SUM_CHARGE, avg(LINEITEM.L_QUANTITY) AS AVG_QTY, avg(LINEITEM.L_EXTENDEDPRICE) AS AVG_PRICE, avg(LINEITEM.L_DISCOUNT) AS AVG_DISC, count() AS COUNT_ORDER\
"Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, sum(LINEITEM.L_QUANTITY) AS SUM_QTY, sum(LINEITEM.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS SUM_DISC_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX) AS SUM_CHARGE, avg(LINEITEM.L_QUANTITY) AS AVG_QTY, avg(LINEITEM.L_EXTENDEDPRICE) AS AVG_PRICE, avg(LINEITEM.L_DISCOUNT) AS AVG_DISC, count(Int64(1)) AS COUNT_ORDER\
\n Sort: LINEITEM.L_RETURNFLAG ASC NULLS LAST, LINEITEM.L_LINESTATUS ASC NULLS LAST\
\n Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count()]]\
\n Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count(Int64(1))]]\
\n Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, LINEITEM.L_QUANTITY, LINEITEM.L_EXTENDEDPRICE, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT), LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) * (CAST(Int32(1) AS Decimal128(15, 2)) + LINEITEM.L_TAX), LINEITEM.L_DISCOUNT\
\n Filter: LINEITEM.L_SHIPDATE <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 0, milliseconds: 10368000 }\")\
\n TableScan: LINEITEM"
Expand Down Expand Up @@ -119,9 +120,9 @@ mod tests {
let plan_str = tpch_plan_to_string(4).await?;
assert_eq!(
plan_str,
"Projection: ORDERS.O_ORDERPRIORITY, count() AS ORDER_COUNT\
"Projection: ORDERS.O_ORDERPRIORITY, count(Int64(1)) AS ORDER_COUNT\
\n Sort: ORDERS.O_ORDERPRIORITY ASC NULLS LAST\
\n Aggregate: groupBy=[[ORDERS.O_ORDERPRIORITY]], aggr=[[count()]]\
\n Aggregate: groupBy=[[ORDERS.O_ORDERPRIORITY]], aggr=[[count(Int64(1))]]\
\n Projection: ORDERS.O_ORDERPRIORITY\
\n Filter: ORDERS.O_ORDERDATE >= CAST(Utf8(\"1993-07-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1993-10-01\") AS Date32) AND EXISTS (<subquery>)\
\n Subquery:\
Expand Down Expand Up @@ -269,10 +270,10 @@ mod tests {
let plan_str = tpch_plan_to_string(13).await?;
assert_eq!(
plan_str,
"Projection: count(ORDERS.O_ORDERKEY) AS C_COUNT, count() AS CUSTDIST\
\n Sort: count() DESC NULLS FIRST, count(ORDERS.O_ORDERKEY) DESC NULLS FIRST\
\n Projection: count(ORDERS.O_ORDERKEY), count()\
\n Aggregate: groupBy=[[count(ORDERS.O_ORDERKEY)]], aggr=[[count()]]\
"Projection: count(ORDERS.O_ORDERKEY) AS C_COUNT, count(Int64(1)) AS CUSTDIST\
\n Sort: count(Int64(1)) DESC NULLS FIRST, count(ORDERS.O_ORDERKEY) DESC NULLS FIRST\
\n Projection: count(ORDERS.O_ORDERKEY), count(Int64(1))\
\n Aggregate: groupBy=[[count(ORDERS.O_ORDERKEY)]], aggr=[[count(Int64(1))]]\
\n Projection: count(ORDERS.O_ORDERKEY)\
\n Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY]], aggr=[[count(ORDERS.O_ORDERKEY)]]\
\n Projection: CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY\
Expand Down Expand Up @@ -410,10 +411,10 @@ mod tests {
let plan_str = tpch_plan_to_string(21).await?;
assert_eq!(
plan_str,
"Projection: SUPPLIER.S_NAME, count() AS NUMWAIT\
"Projection: SUPPLIER.S_NAME, count(Int64(1)) AS NUMWAIT\
\n Limit: skip=0, fetch=100\
\n Sort: count() DESC NULLS FIRST, SUPPLIER.S_NAME ASC NULLS LAST\
\n Aggregate: groupBy=[[SUPPLIER.S_NAME]], aggr=[[count()]]\
\n Sort: count(Int64(1)) DESC NULLS FIRST, SUPPLIER.S_NAME ASC NULLS LAST\
\n Aggregate: groupBy=[[SUPPLIER.S_NAME]], aggr=[[count(Int64(1))]]\
\n Projection: SUPPLIER.S_NAME\
\n Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8(\"F\") AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS (<subquery>) AND NOT EXISTS (<subquery>) AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"SAUDI ARABIA\")\
\n Subquery:\
Expand All @@ -438,9 +439,9 @@ mod tests {
let plan_str = tpch_plan_to_string(22).await?;
assert_eq!(
plan_str,
"Projection: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count() AS NUMCUST, sum(CUSTOMER.C_ACCTBAL) AS TOTACCTBAL\
"Projection: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count(Int64(1)) AS NUMCUST, sum(CUSTOMER.C_ACCTBAL) AS TOTACCTBAL\
\n Sort: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) ASC NULLS LAST\
\n Aggregate: groupBy=[[substr(CUSTOMER.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(), sum(CUSTOMER.C_ACCTBAL)]]\
\n Aggregate: groupBy=[[substr(CUSTOMER.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(Int64(1)), sum(CUSTOMER.C_ACCTBAL)]]\
\n Projection: substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)), CUSTOMER.C_ACCTBAL\
\n Filter: (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"13\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"31\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"23\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"29\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"30\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"18\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"17\") AS Utf8)) AND CUSTOMER.C_ACCTBAL > (<subquery>) AND NOT EXISTS (<subquery>)\
\n Subquery:\
Expand All @@ -455,4 +456,43 @@ mod tests {
);
Ok(())
}

async fn test_plan_to_string(name: &str) -> Result<String> {
let path = format!("tests/testdata/test_plans/{name}");
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json");

let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?;
let plan = from_substrait_plan(&ctx.state(), &proto).await?;
ctx.state().create_physical_plan(&plan).await?;
Ok(format!("{}", plan))
}

#[tokio::test]
async fn test_select_count_from_select_1() -> Result<()> {
let plan_str =
test_plan_to_string("select_count_from_select_1.substrait.json").await?;

assert_eq!(
plan_str,
"Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\
\n Values: (Int64(0))"
);
Ok(())
}

#[tokio::test]
async fn test_select_window_count() -> Result<()> {
let plan_str = test_plan_to_string("select_window_count.substrait.json").await?;

assert_eq!(
plan_str,
"Projection: count(Int64(1)) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\
\n WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\
\n TableScan: DATA"
);
Ok(())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
{
"extensionUris": [
{
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml"
}
],
"extensions": [
{
"extensionFunction": {
"functionAnchor": 185,
"name": "count:any"
}
}
],
"relations": [
{
"root": {
"input": {
"aggregate": {
"common": {
"direct": {
}
},
"input": {
"read": {
"common": {
"direct": {
}
},
"baseSchema": {
"names": [
"dummy"
],
"struct": {
"types": [
{
"i64": {
"nullability": "NULLABILITY_REQUIRED"
}
}
],
"nullability": "NULLABILITY_REQUIRED"
}
},
"virtualTable": {
"values": [
{
"fields": [
{
"i64": "0",
"nullable": false
}
]
}
]
}
}
},
"groupings": [
{
"groupingExpressions": [],
"expressionReferences": []
}
],
"measures": [
{
"measure": {
"functionReference": 185,
"args": [],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
"outputType": {
"i64": {
"nullability": "NULLABILITY_REQUIRED"
}
},
"invocation": "AGGREGATION_INVOCATION_ALL",
"arguments": [],
"options": []
}
}
],
"groupingExpressions": []
}
},
"names": [
"count(*)"
]
}
}
]
}
Loading