Skip to content

Commit

Permalink
infer return type for udaf
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Chien <stdrc@outlook.com>
  • Loading branch information
stdrc committed Aug 22, 2024
1 parent 2fa8034 commit bde6578
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/frontend/src/expr/window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ use itertools::Itertools;
use risingwave_common::bail_not_implemented;
use risingwave_common::types::DataType;
use risingwave_expr::aggregate::AggKind;
use risingwave_expr::sig::FUNCTION_REGISTRY;
use risingwave_expr::window_function::{Frame, WindowFuncKind};

use super::{Expr, ExprImpl, OrderBy, RwResult};
use crate::error::{ErrorCode, RwError};
use crate::expr::infer_type;

/// A window function performs a calculation across a set of table rows that are somehow related to
/// the current row, according to the window spec `OVER (PARTITION BY .. ORDER BY ..)`.
Expand All @@ -45,10 +45,10 @@ impl WindowFunction {
kind: WindowFuncKind,
partition_by: Vec<ExprImpl>,
order_by: OrderBy,
args: Vec<ExprImpl>,
mut args: Vec<ExprImpl>,
frame: Option<Frame>,
) -> RwResult<Self> {
let return_type = Self::infer_return_type(&kind, &args)?;
let return_type = Self::infer_return_type(&kind, &mut args)?;
Ok(Self {
kind,
args,
Expand All @@ -59,7 +59,7 @@ impl WindowFunction {
})
}

fn infer_return_type(kind: &WindowFuncKind, args: &[ExprImpl]) -> RwResult<DataType> {
fn infer_return_type(kind: &WindowFuncKind, args: &mut [ExprImpl]) -> RwResult<DataType> {
use WindowFuncKind::*;
match (kind, args) {
(RowNumber, []) => Ok(DataType::Int64),
Expand Down Expand Up @@ -87,13 +87,13 @@ impl WindowFunction {
);
}

(Aggregate(AggKind::Builtin(agg_kind)), args) => {
let arg_types = args.iter().map(ExprImpl::return_type).collect::<Vec<_>>();
let return_type = FUNCTION_REGISTRY.get_return_type(*agg_kind, &arg_types)?;
Ok(return_type)
}
(Aggregate(agg_kind), args) => Ok(match agg_kind {
AggKind::Builtin(kind) => infer_type((*kind).into(), args)?,
AggKind::UserDefined(udf) => udf.return_type.as_ref().unwrap().into(),
AggKind::WrapScalar(expr) => expr.return_type.as_ref().unwrap().into(),
}),

_ => {
(_, args) => {
let args = args
.iter()
.map(|e| format!("{}", e.return_type()))
Expand Down

0 comments on commit bde6578

Please sign in to comment.