Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add CastError for cast function #8090

Merged
merged 2 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ impl Binder {
Expr::TypedString { data_type, value } => {
let s: ExprImpl = self.bind_string(value)?.into();
s.cast_explicit(bind_data_type(&data_type)?)
.map_err(Into::into)
}
Expr::Row(exprs) => self.bind_row(exprs),
// input ref
Expand Down Expand Up @@ -430,7 +431,7 @@ impl Binder {
return self.bind_array_cast(expr.clone(), data_type);
}
let lhs = self.bind_expr(expr)?;
lhs.cast_explicit(data_type)
lhs.cast_explicit(data_type).map_err(Into::into)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl Binder {
},
)
.into();
return lhs.cast_explicit(ty);
return lhs.cast_explicit(ty).map_err(Into::into);
}
let inner_type = if let DataType::List { datatype } = &ty {
*datatype.clone()
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ impl Binder {
return exprs
.into_iter()
.zip_eq_fast(expected_types)
.map(|(e, t)| e.cast_assign(t.clone()))
.map(|(e, t)| e.cast_assign(t.clone()).map_err(Into::into))
.try_collect();
}
std::cmp::Ordering::Less => "INSERT has more expressions than target columns",
Expand Down
70 changes: 48 additions & 22 deletions src/frontend/src/expr/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

use itertools::Itertools;
use risingwave_common::catalog::Schema;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::error::{ErrorCode, Result as RwResult, RwError};
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::vector_op::cast::literal_parsing;
use thiserror::Error;

use super::{cast_ok, infer_some_all, infer_type, CastContext, Expr, ExprImpl, Literal};
use crate::expr::{ExprDisplay, ExprType};
Expand Down Expand Up @@ -99,7 +100,7 @@ impl FunctionCall {
// The functions listed here are all variadic. Type signatures of functions that take a fixed
// number of arguments are checked
// [elsewhere](crate::expr::type_inference::build_type_derive_map).
pub fn new(func_type: ExprType, mut inputs: Vec<ExprImpl>) -> Result<Self> {
pub fn new(func_type: ExprType, mut inputs: Vec<ExprImpl>) -> RwResult<Self> {
let return_type = infer_type(func_type, &mut inputs)?;
Ok(Self {
func_type,
Expand All @@ -109,12 +110,16 @@ impl FunctionCall {
}

/// Create a cast expr over `child` to `target` type in `allows` context.
pub fn new_cast(child: ExprImpl, target: DataType, allows: CastContext) -> Result<ExprImpl> {
pub fn new_cast(
child: ExprImpl,
target: DataType,
allows: CastContext,
) -> Result<ExprImpl, CastError> {
if is_row_function(&child) {
// Row function will have empty fields in Datatype::Struct at this point. Therefore,
// we will need to take some special care to generate the cast types. For normal struct
// types, they will be handled in `cast_ok`.
return Self::cast_nested(child, target, allows);
return Self::cast_row_expr(child, target, allows);
}
if child.is_unknown() {
// `is_unknown` makes sure `as_literal` and `as_utf8` will never panic.
Expand Down Expand Up @@ -146,46 +151,51 @@ impl FunctionCall {
}
.into())
} else {
Err(ErrorCode::BindError(format!(
Err(CastError(format!(
"cannot cast type \"{}\" to \"{}\" in {:?} context",
source, target, allows
))
.into())
)))
}
}

/// Cast a `ROW` expression to the target type. We intentionally disallow casting arbitrary
/// expressions, like `ROW(1)::STRUCT<i INTEGER>` to `STRUCT<VARCHAR>`, although an integer
/// is castible to VARCHAR. It's to simply the casting rules.
fn cast_nested(expr: ExprImpl, target_type: DataType, allows: CastContext) -> Result<ExprImpl> {
fn cast_row_expr(
expr: ExprImpl,
target_type: DataType,
allows: CastContext,
) -> Result<ExprImpl, CastError> {
let func = *expr.into_function_call().unwrap();
let (fields, field_names) = if let DataType::Struct(t) = &target_type {
(t.fields.clone(), t.field_names.clone())
} else {
return Err(ErrorCode::BindError(format!(
"column is of type '{}' but expression is of type record",
target_type
))
.into());
return Err(CastError(format!(
"cannot cast type \"{}\" to \"{}\" in {:?} context",
func.return_type(),
target_type,
allows
)));
};
let (func_type, inputs, _) = func.decompose();
let msg = match fields.len().cmp(&inputs.len()) {
match fields.len().cmp(&inputs.len()) {
std::cmp::Ordering::Equal => {
let inputs = inputs
.into_iter()
.zip_eq_fast(fields.to_vec())
.map(|(e, t)| Self::new_cast(e, t, allows))
.collect::<Result<Vec<_>>>()?;
.collect::<Result<Vec<_>, CastError>>()?;
let return_type = DataType::new_struct(
inputs.iter().map(|i| i.return_type()).collect_vec(),
field_names,
);
return Ok(FunctionCall::new_unchecked(func_type, inputs, return_type).into());
Ok(FunctionCall::new_unchecked(func_type, inputs, return_type).into())
}
std::cmp::Ordering::Less => "Input has too few columns.",
std::cmp::Ordering::Greater => "Input has too many columns.",
};
Err(ErrorCode::BindError(format!("cannot cast record to {} ({})", target_type, msg)).into())
std::cmp::Ordering::Less => Err(CastError("Input has too few columns.".to_string())),
std::cmp::Ordering::Greater => {
Err(CastError("Input has too many columns.".to_string()))
}
}
}

/// Construct a `FunctionCall` expr directly with the provided `return_type`, bypassing type
Expand All @@ -205,7 +215,7 @@ impl FunctionCall {
pub fn new_binary_op_func(
mut func_types: Vec<ExprType>,
mut inputs: Vec<ExprImpl>,
) -> Result<ExprImpl> {
) -> RwResult<ExprImpl> {
let expr_type = func_types.remove(0);
match expr_type {
ExprType::Some | ExprType::All => {
Expand Down Expand Up @@ -274,7 +284,7 @@ impl FunctionCall {
function_call: &risingwave_pb::expr::FunctionCall,
expr_type: ExprType,
ret_type: DataType,
) -> Result<Self> {
) -> RwResult<Self> {
let inputs: Vec<_> = function_call
.get_children()
.iter()
Expand Down Expand Up @@ -419,3 +429,19 @@ pub fn is_row_function(expr: &ExprImpl) -> bool {
}
false
}

#[derive(Debug, Error)]
#[error("{0}")]
pub struct CastError(String);

impl From<CastError> for ErrorCode {
fn from(value: CastError) -> Self {
ErrorCode::BindError(value.to_string())
}
}

impl From<CastError> for RwError {
fn from(value: CastError) -> Self {
ErrorCode::from(value).into()
}
}
20 changes: 11 additions & 9 deletions src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use enum_as_inner::EnumAsInner;
use fixedbitset::FixedBitSet;
use paste::paste;
use risingwave_common::array::ListValue;
use risingwave_common::error::Result;
use risingwave_common::error::Result as RwResult;
use risingwave_common::types::{DataType, Datum, Scalar};
use risingwave_expr::expr::{build_from_prost, AggKind};
use risingwave_pb::expr::expr_node::RexNode;
Expand Down Expand Up @@ -177,22 +177,22 @@ impl ExprImpl {
}

/// Shorthand to create cast expr to `target` type in implicit context.
pub fn cast_implicit(self, target: DataType) -> Result<ExprImpl> {
pub fn cast_implicit(self, target: DataType) -> Result<ExprImpl, CastError> {
FunctionCall::new_cast(self, target, CastContext::Implicit)
}

/// Shorthand to create cast expr to `target` type in assign context.
pub fn cast_assign(self, target: DataType) -> Result<ExprImpl> {
pub fn cast_assign(self, target: DataType) -> Result<ExprImpl, CastError> {
FunctionCall::new_cast(self, target, CastContext::Assign)
}

/// Shorthand to create cast expr to `target` type in explicit context.
pub fn cast_explicit(self, target: DataType) -> Result<ExprImpl> {
pub fn cast_explicit(self, target: DataType) -> Result<ExprImpl, CastError> {
FunctionCall::new_cast(self, target, CastContext::Explicit)
}

/// Shorthand to enforce implicit cast to boolean
pub fn enforce_bool_clause(self, clause: &str) -> Result<ExprImpl> {
pub fn enforce_bool_clause(self, clause: &str) -> RwResult<ExprImpl> {
if self.is_unknown() {
let inner = self.cast_implicit(DataType::Boolean)?;
return Ok(inner);
Expand All @@ -218,26 +218,27 @@ impl ExprImpl {
/// References in `PostgreSQL`:
/// * [cast](https://github.com/postgres/postgres/blob/a3ff08e0b08dbfeb777ccfa8f13ebaa95d064c04/src/include/catalog/pg_cast.dat#L437-L444)
/// * [impl](https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/backend/utils/adt/bool.c#L204-L209)
pub fn cast_output(self) -> Result<ExprImpl> {
pub fn cast_output(self) -> RwResult<ExprImpl> {
if self.return_type() == DataType::Boolean {
return Ok(FunctionCall::new(ExprType::BoolOut, vec![self])?.into());
}
// Use normal cast for other types. Both `assign` and `explicit` can pass the castability
// check and there is no difference.
self.cast_assign(DataType::Varchar)
.map_err(|err| err.into())
}

/// Evaluate the expression on the given input.
///
/// TODO: This is a naive implementation. We should avoid proto ser/de.
/// Tracking issue: <https://github.com/risingwavelabs/risingwave/issues/3479>
fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
fn eval_row(&self, input: &OwnedRow) -> RwResult<Datum> {
let backend_expr = build_from_prost(&self.to_expr_proto())?;
backend_expr.eval_row(input).map_err(Into::into)
}

/// Evaluate a constant expression.
pub fn eval_row_const(&self) -> Result<Datum> {
pub fn eval_row_const(&self) -> RwResult<Datum> {
assert!(self.is_const());
self.eval_row(&OwnedRow::empty())
}
Expand Down Expand Up @@ -728,7 +729,7 @@ impl ExprImpl {
}
}

pub fn from_expr_proto(proto: &ExprNode) -> Result<Self> {
pub fn from_expr_proto(proto: &ExprNode) -> RwResult<Self> {
let rex_node = proto.get_rex_node()?;
let ret_type = proto.get_return_type()?.into();
let expr_type = proto.get_expr_type()?;
Expand Down Expand Up @@ -892,6 +893,7 @@ use risingwave_common::bail;
use risingwave_common::catalog::Schema;
use risingwave_common::row::OwnedRow;

use self::function_call::CastError;
use crate::binder::BoundSetExpr;
use crate::utils::Condition;

Expand Down
6 changes: 3 additions & 3 deletions src/frontend/src/expr/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use risingwave_common::types::{unnested_list_type, DataType, ScalarImpl};
use risingwave_pb::expr::table_function::Type;
use risingwave_pb::expr::TableFunction as TableFunctionProst;

use super::{Expr, ExprImpl, ExprRewriter, Result};
use super::{Expr, ExprImpl, ExprRewriter, RwResult};

/// A table function takes a row as input and returns a table. It is also known as Set-Returning
/// Function.
Expand Down Expand Up @@ -85,7 +85,7 @@ impl FromStr for TableFunctionType {
impl TableFunction {
/// Create a `TableFunction` expr with the return type inferred from `func_type` and types of
/// `inputs`.
pub fn new(func_type: TableFunctionType, args: Vec<ExprImpl>) -> Result<Self> {
pub fn new(func_type: TableFunctionType, args: Vec<ExprImpl>) -> RwResult<Self> {
// TODO: refactor into sth like FunctionCall::new.
// Current implementation is copied from legacy code.

Expand All @@ -94,7 +94,7 @@ impl TableFunction {
// generate_series ( start timestamp, stop timestamp, step interval ) or
// generate_series ( start i32, stop i32, step i32 )

fn type_check(exprs: &[ExprImpl]) -> Result<DataType> {
fn type_check(exprs: &[ExprImpl]) -> RwResult<DataType> {
let mut exprs = exprs.iter();
let (start, stop, step) = exprs.next_tuple().unwrap();
match (start.return_type(), stop.return_type(), step.return_type()) {
Expand Down
6 changes: 4 additions & 2 deletions src/frontend/src/expr/type_inference/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,11 @@ pub fn align_array_and_element(
.enumerate()
.map(|(idx, input)| {
if idx == array_idx {
input.cast_implicit(array_type.clone())
input.cast_implicit(array_type.clone()).map_err(Into::into)
} else {
input.cast_implicit(common_ele_type.clone())
input
.cast_implicit(common_ele_type.clone())
.map_err(Into::into)
}
})
.try_collect();
Expand Down
8 changes: 4 additions & 4 deletions src/frontend/src/expr/type_inference/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use itertools::Itertools as _;
use num_integer::Integer as _;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_common::types::struct_type::StructType;
use risingwave_common::types::{DataType, DataTypeName, ScalarImpl};
use risingwave_common::util::iter_util::ZipEqFast;
Expand Down Expand Up @@ -48,7 +48,7 @@ pub fn infer_type(func_type: ExprType, inputs: &mut Vec<ExprImpl>) -> Result<Dat
.map(|(expr, t)| {
if DataTypeName::from(expr.return_type()) != *t {
if t.is_scalar() {
return expr.cast_implicit((*t).into());
return expr.cast_implicit((*t).into()).map_err(Into::into);
} else {
return Err(ErrorCode::BindError(format!(
"Cannot implicitly cast '{:?}' to polymorphic type {:?}",
Expand All @@ -59,7 +59,7 @@ pub fn infer_type(func_type: ExprType, inputs: &mut Vec<ExprImpl>) -> Result<Dat
}
Ok(expr)
})
.try_collect()?;
.try_collect::<_, _, RwError>()?;
Ok(sig.ret_type.into())
}

Expand Down Expand Up @@ -318,7 +318,7 @@ fn infer_type_for_special(
.enumerate()
.map(|(i, input)| match i {
// 0-th arg must be string
0 => input.cast_implicit(DataType::Varchar),
0 => input.cast_implicit(DataType::Varchar).map_err(Into::into),
// subsequent can be any type, using the output format
_ => input.cast_output(),
})
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/expr/window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use parse_display::Display;
use risingwave_common::error::ErrorCode;
use risingwave_common::types::DataType;

use super::{Expr, ExprImpl, OrderBy, Result};
use super::{Expr, ExprImpl, OrderBy, RwResult};

/// A window function performs a calculation across a set of table rows that are somehow related to
/// the current row, according to the window spec `OVER (PARTITION BY .. ORDER BY ..)`.
Expand Down Expand Up @@ -79,7 +79,7 @@ impl WindowFunction {
partition_by: Vec<ExprImpl>,
order_by: OrderBy,
args: Vec<ExprImpl>,
) -> Result<Self> {
) -> RwResult<Self> {
if !args.is_empty() {
return Err(ErrorCode::BindError(format!(
"the length of args of {function_type} function should be 0"
Expand Down