diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a23aad3d7237..61f3379735c7 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1975,6 +1975,15 @@ pub async fn from_substrait_agg_func( let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; + // Datafusion does not support aggregate functions with no arguments, so + // we inject a dummy argument that does not affect the query, but allows + // us to bypass this limitation. + let args = if udaf.name() == "count" && args.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)))] + } else { + args + }; + Ok(Arc::new(Expr::AggregateFunction( expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None), ))) @@ -2248,11 +2257,19 @@ pub async fn from_window_function( window_frame.regularize_order_bys(&mut order_by)?; + // Datafusion does not support aggregate functions with no arguments, so + // we inject a dummy argument that does not affect the query, but allows + // us to bypass this limitation. + let args = if fun.name() == "count" && window.arguments.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)))] + } else { + from_substrait_func_args(consumer, &window.arguments, input_schema).await? + }; + Ok(Expr::WindowFunction(expr::WindowFunction { fun, params: WindowFunctionParams { - args: from_substrait_func_args(consumer, &window.arguments, input_schema) - .await?, + args, partition_by: from_substrait_rex_vec( consumer, &window.partitions, @@ -3406,4 +3423,31 @@ mod test { Ok(()) } + + #[tokio::test] + async fn window_function_with_count() -> Result<()> { + let substrait = substrait::proto::Expression { + rex_type: Some(substrait::proto::expression::RexType::WindowFunction( + substrait::proto::expression::WindowFunction { + function_reference: 0, + ..Default::default() + }, + )), + }; + + let mut consumer = test_consumer(); + + let mut extensions = Extensions::default(); + extensions.register_function("count".to_string()); + consumer.extensions = &extensions; + + match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { + Expr::WindowFunction(window_function) => { + assert_eq!(window_function.params.args.len(), 1) + } + _ => panic!("expr was not a WindowFunction"), + }; + + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 1b2b570063a2..1f1a15abb837 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -42,6 +42,7 @@ mod tests { let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; let plan = from_substrait_plan(&ctx.state(), &proto).await?; + ctx.state().create_physical_plan(&plan).await?; Ok(format!("{}", plan)) } @@ -50,9 +51,9 @@ mod tests { let plan_str = tpch_plan_to_string(1).await?; assert_eq!( plan_str, - "Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, sum(LINEITEM.L_QUANTITY) AS SUM_QTY, sum(LINEITEM.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS SUM_DISC_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX) AS SUM_CHARGE, avg(LINEITEM.L_QUANTITY) AS AVG_QTY, avg(LINEITEM.L_EXTENDEDPRICE) AS AVG_PRICE, avg(LINEITEM.L_DISCOUNT) AS AVG_DISC, count() AS COUNT_ORDER\ + "Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, sum(LINEITEM.L_QUANTITY) AS SUM_QTY, sum(LINEITEM.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS SUM_DISC_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX) AS SUM_CHARGE, avg(LINEITEM.L_QUANTITY) AS AVG_QTY, avg(LINEITEM.L_EXTENDEDPRICE) AS AVG_PRICE, avg(LINEITEM.L_DISCOUNT) AS AVG_DISC, count(Int64(1)) AS COUNT_ORDER\ \n Sort: LINEITEM.L_RETURNFLAG ASC NULLS LAST, LINEITEM.L_LINESTATUS ASC NULLS LAST\ - \n Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count()]]\ + \n Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count(Int64(1))]]\ \n Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, LINEITEM.L_QUANTITY, LINEITEM.L_EXTENDEDPRICE, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT), LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) * (CAST(Int32(1) AS Decimal128(15, 2)) + LINEITEM.L_TAX), LINEITEM.L_DISCOUNT\ \n Filter: LINEITEM.L_SHIPDATE <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 0, milliseconds: 10368000 }\")\ \n TableScan: LINEITEM" @@ -119,9 +120,9 @@ mod tests { let plan_str = tpch_plan_to_string(4).await?; assert_eq!( plan_str, - "Projection: ORDERS.O_ORDERPRIORITY, count() AS ORDER_COUNT\ + "Projection: ORDERS.O_ORDERPRIORITY, count(Int64(1)) AS ORDER_COUNT\ \n Sort: ORDERS.O_ORDERPRIORITY ASC NULLS LAST\ - \n Aggregate: groupBy=[[ORDERS.O_ORDERPRIORITY]], aggr=[[count()]]\ + \n Aggregate: groupBy=[[ORDERS.O_ORDERPRIORITY]], aggr=[[count(Int64(1))]]\ \n Projection: ORDERS.O_ORDERPRIORITY\ \n Filter: ORDERS.O_ORDERDATE >= CAST(Utf8(\"1993-07-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1993-10-01\") AS Date32) AND EXISTS ()\ \n Subquery:\ @@ -269,10 +270,10 @@ mod tests { let plan_str = tpch_plan_to_string(13).await?; assert_eq!( plan_str, - "Projection: count(ORDERS.O_ORDERKEY) AS C_COUNT, count() AS CUSTDIST\ - \n Sort: count() DESC NULLS FIRST, count(ORDERS.O_ORDERKEY) DESC NULLS FIRST\ - \n Projection: count(ORDERS.O_ORDERKEY), count()\ - \n Aggregate: groupBy=[[count(ORDERS.O_ORDERKEY)]], aggr=[[count()]]\ + "Projection: count(ORDERS.O_ORDERKEY) AS C_COUNT, count(Int64(1)) AS CUSTDIST\ + \n Sort: count(Int64(1)) DESC NULLS FIRST, count(ORDERS.O_ORDERKEY) DESC NULLS FIRST\ + \n Projection: count(ORDERS.O_ORDERKEY), count(Int64(1))\ + \n Aggregate: groupBy=[[count(ORDERS.O_ORDERKEY)]], aggr=[[count(Int64(1))]]\ \n Projection: count(ORDERS.O_ORDERKEY)\ \n Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY]], aggr=[[count(ORDERS.O_ORDERKEY)]]\ \n Projection: CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY\ @@ -410,10 +411,10 @@ mod tests { let plan_str = tpch_plan_to_string(21).await?; assert_eq!( plan_str, - "Projection: SUPPLIER.S_NAME, count() AS NUMWAIT\ + "Projection: SUPPLIER.S_NAME, count(Int64(1)) AS NUMWAIT\ \n Limit: skip=0, fetch=100\ - \n Sort: count() DESC NULLS FIRST, SUPPLIER.S_NAME ASC NULLS LAST\ - \n Aggregate: groupBy=[[SUPPLIER.S_NAME]], aggr=[[count()]]\ + \n Sort: count(Int64(1)) DESC NULLS FIRST, SUPPLIER.S_NAME ASC NULLS LAST\ + \n Aggregate: groupBy=[[SUPPLIER.S_NAME]], aggr=[[count(Int64(1))]]\ \n Projection: SUPPLIER.S_NAME\ \n Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8(\"F\") AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS () AND NOT EXISTS () AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"SAUDI ARABIA\")\ \n Subquery:\ @@ -438,9 +439,9 @@ mod tests { let plan_str = tpch_plan_to_string(22).await?; assert_eq!( plan_str, - "Projection: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count() AS NUMCUST, sum(CUSTOMER.C_ACCTBAL) AS TOTACCTBAL\ + "Projection: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count(Int64(1)) AS NUMCUST, sum(CUSTOMER.C_ACCTBAL) AS TOTACCTBAL\ \n Sort: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) ASC NULLS LAST\ - \n Aggregate: groupBy=[[substr(CUSTOMER.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(), sum(CUSTOMER.C_ACCTBAL)]]\ + \n Aggregate: groupBy=[[substr(CUSTOMER.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(Int64(1)), sum(CUSTOMER.C_ACCTBAL)]]\ \n Projection: substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)), CUSTOMER.C_ACCTBAL\ \n Filter: (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"13\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"31\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"23\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"29\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"30\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"18\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"17\") AS Utf8)) AND CUSTOMER.C_ACCTBAL > () AND NOT EXISTS ()\ \n Subquery:\ @@ -455,4 +456,43 @@ mod tests { ); Ok(()) } + + async fn test_plan_to_string(name: &str) -> Result { + let path = format!("tests/testdata/test_plans/{name}"); + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; + ctx.state().create_physical_plan(&plan).await?; + Ok(format!("{}", plan)) + } + + #[tokio::test] + async fn test_select_count_from_select_1() -> Result<()> { + let plan_str = + test_plan_to_string("select_count_from_select_1.substrait.json").await?; + + assert_eq!( + plan_str, + "Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\ + \n Values: (Int64(0))" + ); + Ok(()) + } + + #[tokio::test] + async fn test_select_window_count() -> Result<()> { + let plan_str = test_plan_to_string("select_window_count.substrait.json").await?; + + assert_eq!( + plan_str, + "Projection: count(Int64(1)) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\ + \n WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n TableScan: DATA" + ); + Ok(()) + } } diff --git a/datafusion/substrait/tests/testdata/test_plans/select_count_from_select_1.substrait.json b/datafusion/substrait/tests/testdata/test_plans/select_count_from_select_1.substrait.json new file mode 100644 index 000000000000..e9f679588018 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/select_count_from_select_1.substrait.json @@ -0,0 +1,92 @@ +{ + "extensionUris": [ + { + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 185, + "name": "count:any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "dummy" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [ + { + "fields": [ + { + "i64": "0", + "nullable": false + } + ] + } + ] + } + } + }, + "groupings": [ + { + "groupingExpressions": [], + "expressionReferences": [] + } + ], + "measures": [ + { + "measure": { + "functionReference": 185, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [], + "options": [] + } + } + ], + "groupingExpressions": [] + } + }, + "names": [ + "count(*)" + ] + } + } + ] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/select_window_count.substrait.json b/datafusion/substrait/tests/testdata/test_plans/select_window_count.substrait.json new file mode 100644 index 000000000000..5b50145e13d6 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/select_window_count.substrait.json @@ -0,0 +1,137 @@ +{ + "extensionUris": [ + { + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 185, + "name": "count:any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 3 + ] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "D", + "PART", + "ORD" + ], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "DATA" + ] + } + } + }, + "expressions": [ + { + "windowFunction": { + "functionReference": 185, + "partitions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + ], + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + } + ], + "upperBound": { + "unbounded": { + } + }, + "lowerBound": { + "preceding": { + "offset": "1" + } + }, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "args": [], + "arguments": [], + "invocation": "AGGREGATION_INVOCATION_ALL", + "options": [], + "boundsType": "BOUNDS_TYPE_ROWS" + } + } + ] + } + }, + "names": [ + "LEAD_EXPR" + ] + } + } + ], + "expectedTypeUrls": [] +}