diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index de782f7e02110..2f315f6491673 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -40,6 +40,7 @@ impl Binder { Expr::TypedString { data_type, value } => { let s: ExprImpl = self.bind_string(value)?.into(); s.cast_explicit(bind_data_type(&data_type)?) + .map_err(Into::into) } Expr::Row(exprs) => self.bind_row(exprs), // input ref @@ -430,7 +431,7 @@ impl Binder { return self.bind_array_cast(expr.clone(), data_type); } let lhs = self.bind_expr(expr)?; - lhs.cast_explicit(data_type) + lhs.cast_explicit(data_type).map_err(Into::into) } } diff --git a/src/frontend/src/binder/expr/value.rs b/src/frontend/src/binder/expr/value.rs index 514a848ee17fe..d2ec9e1daf504 100644 --- a/src/frontend/src/binder/expr/value.rs +++ b/src/frontend/src/binder/expr/value.rs @@ -130,7 +130,7 @@ impl Binder { }, ) .into(); - return lhs.cast_explicit(ty); + return lhs.cast_explicit(ty).map_err(Into::into); } let inner_type = if let DataType::List { datatype } = &ty { *datatype.clone() diff --git a/src/frontend/src/binder/insert.rs b/src/frontend/src/binder/insert.rs index 6b849bbd9b2bb..07637fd9123bf 100644 --- a/src/frontend/src/binder/insert.rs +++ b/src/frontend/src/binder/insert.rs @@ -247,7 +247,7 @@ impl Binder { return exprs .into_iter() .zip_eq_fast(expected_types) - .map(|(e, t)| e.cast_assign(t.clone())) + .map(|(e, t)| e.cast_assign(t.clone()).map_err(Into::into)) .try_collect(); } std::cmp::Ordering::Less => "INSERT has more expressions than target columns", diff --git a/src/frontend/src/expr/function_call.rs b/src/frontend/src/expr/function_call.rs index fa3d6a98d33ac..10897199c8e69 100644 --- a/src/frontend/src/expr/function_call.rs +++ b/src/frontend/src/expr/function_call.rs @@ -14,10 +14,11 @@ use itertools::Itertools; use risingwave_common::catalog::Schema; -use risingwave_common::error::{ErrorCode, Result}; +use risingwave_common::error::{ErrorCode, Result as RwResult, RwError}; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_expr::vector_op::cast::literal_parsing; +use thiserror::Error; use super::{cast_ok, infer_some_all, infer_type, CastContext, Expr, ExprImpl, Literal}; use crate::expr::{ExprDisplay, ExprType}; @@ -99,7 +100,7 @@ impl FunctionCall { // The functions listed here are all variadic. Type signatures of functions that take a fixed // number of arguments are checked // [elsewhere](crate::expr::type_inference::build_type_derive_map). - pub fn new(func_type: ExprType, mut inputs: Vec) -> Result { + pub fn new(func_type: ExprType, mut inputs: Vec) -> RwResult { let return_type = infer_type(func_type, &mut inputs)?; Ok(Self { func_type, @@ -109,12 +110,16 @@ impl FunctionCall { } /// Create a cast expr over `child` to `target` type in `allows` context. - pub fn new_cast(child: ExprImpl, target: DataType, allows: CastContext) -> Result { + pub fn new_cast( + child: ExprImpl, + target: DataType, + allows: CastContext, + ) -> Result { if is_row_function(&child) { // Row function will have empty fields in Datatype::Struct at this point. Therefore, // we will need to take some special care to generate the cast types. For normal struct // types, they will be handled in `cast_ok`. - return Self::cast_nested(child, target, allows); + return Self::cast_row_expr(child, target, allows); } if child.is_unknown() { // `is_unknown` makes sure `as_literal` and `as_utf8` will never panic. @@ -146,46 +151,51 @@ impl FunctionCall { } .into()) } else { - Err(ErrorCode::BindError(format!( + Err(CastError(format!( "cannot cast type \"{}\" to \"{}\" in {:?} context", source, target, allows - )) - .into()) + ))) } } /// Cast a `ROW` expression to the target type. We intentionally disallow casting arbitrary /// expressions, like `ROW(1)::STRUCT` to `STRUCT`, although an integer /// is castible to VARCHAR. It's to simply the casting rules. - fn cast_nested(expr: ExprImpl, target_type: DataType, allows: CastContext) -> Result { + fn cast_row_expr( + expr: ExprImpl, + target_type: DataType, + allows: CastContext, + ) -> Result { let func = *expr.into_function_call().unwrap(); let (fields, field_names) = if let DataType::Struct(t) = &target_type { (t.fields.clone(), t.field_names.clone()) } else { - return Err(ErrorCode::BindError(format!( - "column is of type '{}' but expression is of type record", - target_type - )) - .into()); + return Err(CastError(format!( + "cannot cast type \"{}\" to \"{}\" in {:?} context", + func.return_type(), + target_type, + allows + ))); }; let (func_type, inputs, _) = func.decompose(); - let msg = match fields.len().cmp(&inputs.len()) { + match fields.len().cmp(&inputs.len()) { std::cmp::Ordering::Equal => { let inputs = inputs .into_iter() .zip_eq_fast(fields.to_vec()) .map(|(e, t)| Self::new_cast(e, t, allows)) - .collect::>>()?; + .collect::, CastError>>()?; let return_type = DataType::new_struct( inputs.iter().map(|i| i.return_type()).collect_vec(), field_names, ); - return Ok(FunctionCall::new_unchecked(func_type, inputs, return_type).into()); + Ok(FunctionCall::new_unchecked(func_type, inputs, return_type).into()) } - std::cmp::Ordering::Less => "Input has too few columns.", - std::cmp::Ordering::Greater => "Input has too many columns.", - }; - Err(ErrorCode::BindError(format!("cannot cast record to {} ({})", target_type, msg)).into()) + std::cmp::Ordering::Less => Err(CastError("Input has too few columns.".to_string())), + std::cmp::Ordering::Greater => { + Err(CastError("Input has too many columns.".to_string())) + } + } } /// Construct a `FunctionCall` expr directly with the provided `return_type`, bypassing type @@ -205,7 +215,7 @@ impl FunctionCall { pub fn new_binary_op_func( mut func_types: Vec, mut inputs: Vec, - ) -> Result { + ) -> RwResult { let expr_type = func_types.remove(0); match expr_type { ExprType::Some | ExprType::All => { @@ -274,7 +284,7 @@ impl FunctionCall { function_call: &risingwave_pb::expr::FunctionCall, expr_type: ExprType, ret_type: DataType, - ) -> Result { + ) -> RwResult { let inputs: Vec<_> = function_call .get_children() .iter() @@ -419,3 +429,19 @@ pub fn is_row_function(expr: &ExprImpl) -> bool { } false } + +#[derive(Debug, Error)] +#[error("{0}")] +pub struct CastError(String); + +impl From for ErrorCode { + fn from(value: CastError) -> Self { + ErrorCode::BindError(value.to_string()) + } +} + +impl From for RwError { + fn from(value: CastError) -> Self { + ErrorCode::from(value).into() + } +} diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index 427c1565fa0eb..cd217b7eab50b 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -16,7 +16,7 @@ use enum_as_inner::EnumAsInner; use fixedbitset::FixedBitSet; use paste::paste; use risingwave_common::array::ListValue; -use risingwave_common::error::Result; +use risingwave_common::error::Result as RwResult; use risingwave_common::types::{DataType, Datum, Scalar}; use risingwave_expr::expr::{build_from_prost, AggKind}; use risingwave_pb::expr::expr_node::RexNode; @@ -177,22 +177,22 @@ impl ExprImpl { } /// Shorthand to create cast expr to `target` type in implicit context. - pub fn cast_implicit(self, target: DataType) -> Result { + pub fn cast_implicit(self, target: DataType) -> Result { FunctionCall::new_cast(self, target, CastContext::Implicit) } /// Shorthand to create cast expr to `target` type in assign context. - pub fn cast_assign(self, target: DataType) -> Result { + pub fn cast_assign(self, target: DataType) -> Result { FunctionCall::new_cast(self, target, CastContext::Assign) } /// Shorthand to create cast expr to `target` type in explicit context. - pub fn cast_explicit(self, target: DataType) -> Result { + pub fn cast_explicit(self, target: DataType) -> Result { FunctionCall::new_cast(self, target, CastContext::Explicit) } /// Shorthand to enforce implicit cast to boolean - pub fn enforce_bool_clause(self, clause: &str) -> Result { + pub fn enforce_bool_clause(self, clause: &str) -> RwResult { if self.is_unknown() { let inner = self.cast_implicit(DataType::Boolean)?; return Ok(inner); @@ -218,26 +218,27 @@ impl ExprImpl { /// References in `PostgreSQL`: /// * [cast](https://github.com/postgres/postgres/blob/a3ff08e0b08dbfeb777ccfa8f13ebaa95d064c04/src/include/catalog/pg_cast.dat#L437-L444) /// * [impl](https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/backend/utils/adt/bool.c#L204-L209) - pub fn cast_output(self) -> Result { + pub fn cast_output(self) -> RwResult { if self.return_type() == DataType::Boolean { return Ok(FunctionCall::new(ExprType::BoolOut, vec![self])?.into()); } // Use normal cast for other types. Both `assign` and `explicit` can pass the castability // check and there is no difference. self.cast_assign(DataType::Varchar) + .map_err(|err| err.into()) } /// Evaluate the expression on the given input. /// /// TODO: This is a naive implementation. We should avoid proto ser/de. /// Tracking issue: - fn eval_row(&self, input: &OwnedRow) -> Result { + fn eval_row(&self, input: &OwnedRow) -> RwResult { let backend_expr = build_from_prost(&self.to_expr_proto())?; backend_expr.eval_row(input).map_err(Into::into) } /// Evaluate a constant expression. - pub fn eval_row_const(&self) -> Result { + pub fn eval_row_const(&self) -> RwResult { assert!(self.is_const()); self.eval_row(&OwnedRow::empty()) } @@ -728,7 +729,7 @@ impl ExprImpl { } } - pub fn from_expr_proto(proto: &ExprNode) -> Result { + pub fn from_expr_proto(proto: &ExprNode) -> RwResult { let rex_node = proto.get_rex_node()?; let ret_type = proto.get_return_type()?.into(); let expr_type = proto.get_expr_type()?; @@ -892,6 +893,7 @@ use risingwave_common::bail; use risingwave_common::catalog::Schema; use risingwave_common::row::OwnedRow; +use self::function_call::CastError; use crate::binder::BoundSetExpr; use crate::utils::Condition; diff --git a/src/frontend/src/expr/table_function.rs b/src/frontend/src/expr/table_function.rs index f85a3989bd8c9..d84cb5e6087bb 100644 --- a/src/frontend/src/expr/table_function.rs +++ b/src/frontend/src/expr/table_function.rs @@ -20,7 +20,7 @@ use risingwave_common::types::{unnested_list_type, DataType, ScalarImpl}; use risingwave_pb::expr::table_function::Type; use risingwave_pb::expr::TableFunction as TableFunctionProst; -use super::{Expr, ExprImpl, ExprRewriter, Result}; +use super::{Expr, ExprImpl, ExprRewriter, RwResult}; /// A table function takes a row as input and returns a table. It is also known as Set-Returning /// Function. @@ -85,7 +85,7 @@ impl FromStr for TableFunctionType { impl TableFunction { /// Create a `TableFunction` expr with the return type inferred from `func_type` and types of /// `inputs`. - pub fn new(func_type: TableFunctionType, args: Vec) -> Result { + pub fn new(func_type: TableFunctionType, args: Vec) -> RwResult { // TODO: refactor into sth like FunctionCall::new. // Current implementation is copied from legacy code. @@ -94,7 +94,7 @@ impl TableFunction { // generate_series ( start timestamp, stop timestamp, step interval ) or // generate_series ( start i32, stop i32, step i32 ) - fn type_check(exprs: &[ExprImpl]) -> Result { + fn type_check(exprs: &[ExprImpl]) -> RwResult { let mut exprs = exprs.iter(); let (start, stop, step) = exprs.next_tuple().unwrap(); match (start.return_type(), stop.return_type(), step.return_type()) { diff --git a/src/frontend/src/expr/type_inference/cast.rs b/src/frontend/src/expr/type_inference/cast.rs index 549f5b275b825..45901c41824ae 100644 --- a/src/frontend/src/expr/type_inference/cast.rs +++ b/src/frontend/src/expr/type_inference/cast.rs @@ -115,9 +115,11 @@ pub fn align_array_and_element( .enumerate() .map(|(idx, input)| { if idx == array_idx { - input.cast_implicit(array_type.clone()) + input.cast_implicit(array_type.clone()).map_err(Into::into) } else { - input.cast_implicit(common_ele_type.clone()) + input + .cast_implicit(common_ele_type.clone()) + .map_err(Into::into) } }) .try_collect(); diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index e7b03a64affcd..685fc65616f43 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -14,7 +14,7 @@ use itertools::Itertools as _; use num_integer::Integer as _; -use risingwave_common::error::{ErrorCode, Result}; +use risingwave_common::error::{ErrorCode, Result, RwError}; use risingwave_common::types::struct_type::StructType; use risingwave_common::types::{DataType, DataTypeName, ScalarImpl}; use risingwave_common::util::iter_util::ZipEqFast; @@ -48,7 +48,7 @@ pub fn infer_type(func_type: ExprType, inputs: &mut Vec) -> Result) -> Result()?; Ok(sig.ret_type.into()) } @@ -318,7 +318,7 @@ fn infer_type_for_special( .enumerate() .map(|(i, input)| match i { // 0-th arg must be string - 0 => input.cast_implicit(DataType::Varchar), + 0 => input.cast_implicit(DataType::Varchar).map_err(Into::into), // subsequent can be any type, using the output format _ => input.cast_output(), }) diff --git a/src/frontend/src/expr/window_function.rs b/src/frontend/src/expr/window_function.rs index 7303a15557ede..351c552722e35 100644 --- a/src/frontend/src/expr/window_function.rs +++ b/src/frontend/src/expr/window_function.rs @@ -19,7 +19,7 @@ use parse_display::Display; use risingwave_common::error::ErrorCode; use risingwave_common::types::DataType; -use super::{Expr, ExprImpl, OrderBy, Result}; +use super::{Expr, ExprImpl, OrderBy, RwResult}; /// A window function performs a calculation across a set of table rows that are somehow related to /// the current row, according to the window spec `OVER (PARTITION BY .. ORDER BY ..)`. @@ -79,7 +79,7 @@ impl WindowFunction { partition_by: Vec, order_by: OrderBy, args: Vec, - ) -> Result { + ) -> RwResult { if !args.is_empty() { return Err(ErrorCode::BindError(format!( "the length of args of {function_type} function should be 0"