Skip to content

Commit

Permalink
fix(core): make aggregate functions and window functions support zero…
Browse files Browse the repository at this point in the history
… argument (#943)
  • Loading branch information
grieve54706 authored Nov 26, 2024
1 parent c5a9d41 commit 90082c7
Showing 1 changed file with 52 additions and 2 deletions.
54 changes: 52 additions & 2 deletions wren-core/core/src/mdl/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,13 @@ impl ByPassAggregateUDF {
Self {
name: name.to_string(),
return_type,
signature: Signature::variadic_any(Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::VariadicAny,
TypeSignature::Uniform(0, vec![]),
],
Volatility::Volatile,
),
}
}
}
Expand Down Expand Up @@ -158,7 +164,13 @@ impl ByPassWindowFunction {
Self {
name: name.to_string(),
return_type,
signature: Signature::variadic_any(Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::VariadicAny,
TypeSignature::Uniform(0, vec![]),
],
Volatility::Volatile,
),
}
}
}
Expand Down Expand Up @@ -214,6 +226,14 @@ mod test {
.into_unoptimized_plan();
let expected = "Projection: date_diff(Int64(1), Int64(2))\n EmptyRelation";
assert_eq!(format!("{plan}"), expected);

ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new(
"today",
DataType::Utf8,
)));
let plan_2 = ctx.sql("SELECT today()").await?.into_unoptimized_plan();
assert_eq!(format!("{plan_2}"), "Projection: today()\n EmptyRelation");

Ok(())
}

Expand All @@ -230,6 +250,23 @@ mod test {
\n Projection: column1 AS c1, column2 AS c2\
\n Values: (Int64(1), Int64(2)), (Int64(2), Int64(3)), (Int64(3), Int64(4))";
assert_eq!(format!("{plan}"), expected);

ctx.register_udaf(AggregateUDF::new_from_impl(ByPassAggregateUDF::new(
"total_count",
DataType::Int64,
)));
let plan_2 = ctx
.sql("SELECT total_count() AS total_count FROM (VALUES (1), (2), (3)) AS val(x)")
.await?
.into_unoptimized_plan();
assert_eq!(
format!("{plan_2}"),
"Projection: total_count() AS total_count\
\n Aggregate: groupBy=[[]], aggr=[[total_count()]]\
\n SubqueryAlias: val\n Projection: column1 AS x\
\n Values: (Int64(1)), (Int64(2)), (Int64(3))"
);

Ok(())
}

Expand All @@ -247,6 +284,19 @@ mod test {
\n WindowAggr: windowExpr=[[custom_window(Int64(1), Int64(2)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
\n EmptyRelation";
assert_eq!(format!("{plan}"), expected);

ctx.register_udwf(WindowUDF::new_from_impl(ByPassWindowFunction::new(
"cume_dist",
DataType::Int64,
)));
let plan_2 = ctx
.sql("SELECT cume_dist() OVER ()")
.await?
.into_unoptimized_plan();
assert_eq!(format!("{plan_2}"), "Projection: cume_dist() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\
\n WindowAggr: windowExpr=[[cume_dist() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
\n EmptyRelation");

Ok(())
}
}

0 comments on commit 90082c7

Please sign in to comment.