Skip to content

Commit

Permalink
chore(query): make string to int respect behavior like PG (#16428)
Browse files Browse the repository at this point in the history
* chore(query): reorder optimize rules

* chore(query): reorder optimize rules

* chore(query): reorder optimize rules

* chore(query): make string to int respect behavior like PG

* update

* revert reorder

* update

* update

* update
  • Loading branch information
sundy-li authored Sep 12, 2024
1 parent 4d14d75 commit 2ca0036
Show file tree
Hide file tree
Showing 13 changed files with 119 additions and 49 deletions.
1 change: 1 addition & 0 deletions benchmark/clickbench/internal/queries/04.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select (number::string)::Int from numbers(100000000) ignore_result;
8 changes: 8 additions & 0 deletions src/query/expression/src/converts/arrow/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ impl DataBlock {
) -> Result<(Self, DataSchema)> {
assert_eq!(schema.num_fields(), batch.num_columns());

if schema.fields().len() != batch.num_columns() {
return Err(ErrorCode::Internal(format!(
"conversion from RecordBatch to DataBlock failed, schema fields len: {}, RecordBatch columns len: {}",
schema.fields().len(),
batch.num_columns()
)));
}

if batch.num_columns() == 0 {
return Ok((DataBlock::new(vec![], batch.num_rows()), schema.clone()));
}
Expand Down
20 changes: 10 additions & 10 deletions src/query/expression/src/utils/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
// limitations under the License.

use std::cmp::Ordering;
use std::result::Result;

use chrono::Datelike;
use chrono::NaiveDate;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;

use crate::types::decimal::Decimal;
use crate::types::decimal::DecimalSize;
Expand All @@ -29,12 +28,13 @@ pub fn uniform_date(date: NaiveDate) -> i32 {
date.num_days_from_ce() - EPOCH_DAYS_FROM_CE
}

// Used in function, so we don't want to return ErrorCode with backtrace
pub fn read_decimal_with_size<T: Decimal>(
buf: &[u8],
size: DecimalSize,
exact: bool,
rounding_mode: bool,
) -> Result<(T, usize)> {
) -> Result<(T, usize), String> {
// Read one more digit for round
let (n, d, e, n_read) =
read_decimal::<T>(buf, (size.precision + 1) as u32, size.scale as _, exact)?;
Expand Down Expand Up @@ -91,7 +91,7 @@ pub fn read_decimal<T: Decimal>(
max_digits: u32,
mut max_scales: u32,
exact: bool,
) -> Result<(T, u8, i32, usize)> {
) -> Result<(T, u8, i32, usize), String> {
if buf.is_empty() {
return Err(decimal_parse_error("empty"));
}
Expand Down Expand Up @@ -302,7 +302,7 @@ pub fn read_decimal<T: Decimal>(
pub fn read_decimal_from_json<T: Decimal>(
value: &serde_json::Value,
size: DecimalSize,
) -> Result<T> {
) -> Result<T, String> {
match value {
serde_json::Value::Number(n) => {
if n.is_i64() {
Expand All @@ -323,14 +323,14 @@ pub fn read_decimal_from_json<T: Decimal>(
let (n, _) = read_decimal_with_size::<T>(s.as_bytes(), size, true, true)?;
Ok(n)
}
_ => Err(ErrorCode::from("Incorrect json value for decimal")),
_ => Err("Incorrect json value for decimal".to_string()),
}
}

fn decimal_parse_error(msg: &str) -> ErrorCode {
ErrorCode::BadArguments(format!("bad decimal literal: {msg}"))
fn decimal_parse_error(msg: &str) -> String {
format!("bad decimal literal: {msg}")
}

fn decimal_overflow_error() -> ErrorCode {
ErrorCode::Overflow("Decimal overflow")
fn decimal_overflow_error() -> String {
"Decimal overflow".to_string()
}
43 changes: 40 additions & 3 deletions src/query/functions/src/scalars/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
use std::ops::BitAnd;
use std::ops::BitOr;
use std::ops::BitXor;
use std::str::FromStr;
use std::sync::Arc;

use databend_common_arrow::arrow::bitmap::Bitmap;
use databend_common_expression::serialize::read_decimal_with_size;
use databend_common_expression::types::decimal::DecimalDomain;
use databend_common_expression::types::decimal::DecimalType;
use databend_common_expression::types::nullable::NullableColumn;
Expand Down Expand Up @@ -878,18 +880,48 @@ fn unary_minus_decimal(
})
}

#[inline]
fn parse_number<T>(
s: &str,
number_datatype: &NumberDataType,
rounding_mode: bool,
) -> Result<T, <T as FromStr>::Err>
where
T: FromStr + num_traits::Num,
{
match s.parse::<T>() {
Ok(v) => Ok(v),
Err(e) => {
if !number_datatype.is_float() {
let decimal_pro = number_datatype.get_decimal_properties().unwrap();
let (res, _) =
read_decimal_with_size::<i128>(s.as_bytes(), decimal_pro, true, rounding_mode)
.map_err(|_| e)?;
format!("{}", res).parse::<T>()
} else {
Err(e)
}
}
}
}

fn register_string_to_number(registry: &mut FunctionRegistry) {
for dest_type in ALL_NUMERICS_TYPES {
with_number_mapped_type!(|DEST_TYPE| match dest_type {
NumberDataType::DEST_TYPE => {
let name = format!("to_{dest_type}").to_lowercase();
let data_type = DEST_TYPE::data_type();
registry
.register_passthrough_nullable_1_arg::<StringType, NumberType<DEST_TYPE>, _, _>(
&name,
|_, _| FunctionDomain::MayThrow,
vectorize_with_builder_1_arg::<StringType, NumberType<DEST_TYPE>>(
move |val, output, ctx| {
match val.parse::<DEST_TYPE>() {
match parse_number::<DEST_TYPE>(
val,
&data_type,
ctx.func_ctx.rounding_mode,
) {
Ok(new_val) => output.push(new_val),
Err(e) => {
ctx.set_error(output.len(), e.to_string());
Expand All @@ -901,15 +933,20 @@ fn register_string_to_number(registry: &mut FunctionRegistry) {
);

let name = format!("try_to_{dest_type}").to_lowercase();
let data_type = DEST_TYPE::data_type();
registry
.register_combine_nullable_1_arg::<StringType, NumberType<DEST_TYPE>, _, _>(
&name,
|_, _| FunctionDomain::Full,
vectorize_with_builder_1_arg::<
StringType,
NullableType<NumberType<DEST_TYPE>>,
>(|val, output, _| {
if let Ok(new_val) = val.parse::<DEST_TYPE>() {
>(move |val, output, ctx| {
if let Ok(new_val) = parse_number::<DEST_TYPE>(
val,
&data_type,
ctx.func_ctx.rounding_mode,
) {
output.push(new_val);
} else {
output.push_null();
Expand Down
2 changes: 1 addition & 1 deletion src/query/functions/src/scalars/decimal/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ where
{
Ok((d, _)) => d,
Err(e) => {
ctx.set_error(builder.len(), e.message());
ctx.set_error(builder.len(), e);
T::zero()
}
};
Expand Down
8 changes: 4 additions & 4 deletions src/query/sql/src/evaluator/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ pub fn apply_cse(
count_expressions(expr, &mut cse_counter);
}

let mut cse_candidates: Vec<Expr> = cse_counter
let mut cse_candidates: Vec<&Expr> = cse_counter
.iter()
.filter(|(_, count)| **count > 1)
.map(|(expr, _)| expr.clone())
.map(|(expr, _)| expr)
.collect();

// Make sure the smaller expr goes firstly
Expand All @@ -52,7 +52,7 @@ pub fn apply_cse(
let mut cse_replacements = HashMap::new();

let candidates_nums = cse_candidates.len();
for cse_candidate in &cse_candidates {
for cse_candidate in cse_candidates.iter() {
let temp_var = format!("__temp_cse_{}", temp_var_counter);
let temp_expr = Expr::ColumnRef {
span: None,
Expand All @@ -61,7 +61,7 @@ pub fn apply_cse(
display_name: temp_var.clone(),
};

let mut expr_cloned = cse_candidate.clone();
let mut expr_cloned = (*cse_candidate).clone();
perform_cse_replacement(&mut expr_cloned, &cse_replacements);

debug!(
Expand Down
31 changes: 29 additions & 2 deletions src/query/sql/src/planner/binder/column_binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,47 @@ pub struct ColumnBinding {
}

const DUMMY_INDEX: usize = usize::MAX;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u64)]
pub enum DummyColumnType {
WindowFunction = 1,
AggregateFunction = 2,
Subquery = 3,
UDF = 4,
AsyncFunction = 5,
Other = 6,
}

impl DummyColumnType {
fn type_identifier(&self) -> usize {
DUMMY_INDEX - (*self) as usize
}
}

impl ColumnBinding {
pub fn new_dummy_column(name: String, data_type: Box<DataType>) -> Self {
pub fn new_dummy_column(
name: String,
data_type: Box<DataType>,
dummy_type: DummyColumnType,
) -> Self {
let index = dummy_type.type_identifier();
ColumnBinding {
database_name: None,
table_name: None,
column_position: None,
table_index: None,
column_name: name,
index: DUMMY_INDEX,
index,
data_type,
visibility: Visibility::Visible,
virtual_computed_expr: None,
}
}

pub fn is_dummy(&self) -> bool {
self.index >= DummyColumnType::Other.type_identifier()
}
}

impl ColumnIndex for ColumnBinding {}
Expand Down
1 change: 1 addition & 0 deletions src/query/sql/src/planner/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ pub use binder::Binder;
pub use builders::*;
pub use column_binding::ColumnBinding;
pub use column_binding::ColumnBindingBuilder;
pub use column_binding::DummyColumnType;
pub use copy_into_table::resolve_file_location;
pub use copy_into_table::resolve_stage_location;
pub use explain::ExplainConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@ use std::sync::Arc;
use databend_common_exception::Result;
use databend_common_expression::DataField;
use databend_common_expression::DataSchemaRefExt;
use databend_common_expression::Scalar;
use itertools::Itertools;

use crate::optimizer::extract::Matcher;
use crate::optimizer::rule::constant::is_falsy;
use crate::optimizer::rule::Rule;
use crate::optimizer::rule::RuleID;
use crate::optimizer::rule::TransformResult;
use crate::optimizer::RelExpr;
use crate::optimizer::SExpr;
use crate::plans::ConstantExpr;
use crate::plans::ConstantTableScan;
use crate::plans::Filter;
use crate::plans::Operator;
Expand Down Expand Up @@ -73,18 +72,7 @@ impl Rule for RuleEliminateFilter {
.collect::<Vec<ScalarExpr>>();

// Rewrite false filter to be empty scan
if predicates.iter().any(|predicate| {
matches!(
predicate,
ScalarExpr::ConstantExpr(ConstantExpr {
value: Scalar::Boolean(false),
..
}) | ScalarExpr::ConstantExpr(ConstantExpr {
value: Scalar::Null,
..
})
)
}) {
if predicates.iter().any(is_falsy) {
let output_columns = eval_scalar
.derive_relational_prop(&RelExpr::with_s_expr(s_expr))?
.output_columns
Expand Down
16 changes: 9 additions & 7 deletions src/query/sql/src/planner/optimizer/rule/rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,20 @@ use crate::optimizer::SExpr;

pub static DEFAULT_REWRITE_RULES: LazyLock<Vec<RuleID>> = LazyLock::new(|| {
vec![
RuleID::NormalizeScalarFilter,
RuleID::EliminateFilter,
RuleID::EliminateSort,
RuleID::MergeFilter,
RuleID::MergeEvalScalar,
// Filter
RuleID::EliminateFilter,
RuleID::MergeFilter,
RuleID::NormalizeScalarFilter,
RuleID::PushDownFilterUnion,
RuleID::PushDownFilterAggregate,
RuleID::PushDownFilterWindow,
RuleID::PushDownFilterSort,
RuleID::PushDownFilterEvalScalar,
// Limit
RuleID::PushDownFilterJoin,
RuleID::PushDownFilterProjectSet,
RuleID::PushDownLimit,
RuleID::PushDownLimitUnion,
RuleID::PushDownLimitEvalScalar,
Expand All @@ -42,10 +48,6 @@ pub static DEFAULT_REWRITE_RULES: LazyLock<Vec<RuleID>> = LazyLock::new(|| {
RuleID::PushDownLimitAggregate,
RuleID::PushDownLimitOuterJoin,
RuleID::PushDownLimitScan,
RuleID::PushDownFilterSort,
RuleID::PushDownFilterEvalScalar,
RuleID::PushDownFilterJoin,
RuleID::PushDownFilterProjectSet,
RuleID::SemiToInnerJoin,
RuleID::FoldCountAggregate,
RuleID::TryApplyAggIndex,
Expand Down
10 changes: 8 additions & 2 deletions src/query/sql/src/planner/semantic/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use databend_common_expression::Expr;
use databend_common_expression::RawExpr;
use databend_common_functions::BUILTIN_FUNCTIONS;

use crate::binder::DummyColumnType;
use crate::plans::ScalarExpr;
use crate::ColumnBinding;
use crate::ColumnEntry;
Expand Down Expand Up @@ -198,6 +199,7 @@ impl ScalarExpr {
id: ColumnBinding::new_dummy_column(
win.display_name.clone(),
Box::new(win.func.return_type()),
DummyColumnType::WindowFunction,
),
data_type: win.func.return_type(),
display_name: win.display_name.clone(),
Expand All @@ -207,6 +209,7 @@ impl ScalarExpr {
id: ColumnBinding::new_dummy_column(
agg.display_name.clone(),
Box::new((*agg.return_type).clone()),
DummyColumnType::AggregateFunction,
),
data_type: (*agg.return_type).clone(),
display_name: agg.display_name.clone(),
Expand Down Expand Up @@ -234,17 +237,19 @@ impl ScalarExpr {
ScalarExpr::SubqueryExpr(subquery) => RawExpr::ColumnRef {
span: subquery.span,
id: ColumnBinding::new_dummy_column(
"DUMMY".to_string(),
"DUMMY_SUBQUERY".to_string(),
Box::new(subquery.data_type()),
DummyColumnType::Subquery,
),
data_type: subquery.data_type(),
display_name: "DUMMY".to_string(),
display_name: "DUMMY_SUBQUERY".to_string(),
},
ScalarExpr::UDFCall(udf) => RawExpr::ColumnRef {
span: None,
id: ColumnBinding::new_dummy_column(
udf.display_name.clone(),
Box::new((*udf.return_type).clone()),
DummyColumnType::UDF,
),
data_type: (*udf.return_type).clone(),
display_name: udf.display_name.clone(),
Expand All @@ -260,6 +265,7 @@ impl ScalarExpr {
id: ColumnBinding::new_dummy_column(
async_func.display_name.clone(),
Box::new(async_func.return_type.as_ref().clone()),
DummyColumnType::AsyncFunction,
),
data_type: async_func.return_type.as_ref().clone(),
display_name: async_func.display_name.clone(),
Expand Down
Loading

0 comments on commit 2ca0036

Please sign in to comment.