Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(query): make string to int respect behavior like PG #16428

Merged
merged 12 commits into from
Sep 12, 2024
40 changes: 37 additions & 3 deletions src/query/functions/src/scalars/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
use std::ops::BitAnd;
use std::ops::BitOr;
use std::ops::BitXor;
use std::str::FromStr;
use std::sync::Arc;

use databend_common_arrow::arrow::bitmap::Bitmap;
use databend_common_expression::serialize::read_decimal_with_size;
use databend_common_expression::types::decimal::DecimalDomain;
use databend_common_expression::types::decimal::DecimalType;
use databend_common_expression::types::nullable::NullableColumn;
Expand Down Expand Up @@ -878,6 +880,30 @@ fn unary_minus_decimal(
})
}

fn parse_number<T>(
s: &str,
number_datatype: &NumberDataType,
rounding_mode: bool,
) -> Result<T, <T as FromStr>::Err>
where
T: FromStr + num_traits::Num,
{
match s.parse::<T>() {
Ok(v) => Ok(v),
Err(e) => {
if !number_datatype.is_float() {
let decimal_pro = number_datatype.get_decimal_properties().unwrap();
let (res, _) =
read_decimal_with_size::<i128>(s.as_bytes(), decimal_pro, true, rounding_mode)
.map_err(|_| e)?;
format!("{}", res).parse::<T>()
} else {
Err(e)
}
}
}
}

fn register_string_to_number(registry: &mut FunctionRegistry) {
for dest_type in ALL_NUMERICS_TYPES {
with_number_mapped_type!(|DEST_TYPE| match dest_type {
Expand All @@ -889,7 +915,11 @@ fn register_string_to_number(registry: &mut FunctionRegistry) {
|_, _| FunctionDomain::MayThrow,
vectorize_with_builder_1_arg::<StringType, NumberType<DEST_TYPE>>(
move |val, output, ctx| {
match val.parse::<DEST_TYPE>() {
match parse_number::<DEST_TYPE>(
val,
&DEST_TYPE::data_type(),
ctx.func_ctx.rounding_mode,
) {
Ok(new_val) => output.push(new_val),
Err(e) => {
ctx.set_error(output.len(), e.to_string());
Expand All @@ -908,8 +938,12 @@ fn register_string_to_number(registry: &mut FunctionRegistry) {
vectorize_with_builder_1_arg::<
StringType,
NullableType<NumberType<DEST_TYPE>>,
>(|val, output, _| {
if let Ok(new_val) = val.parse::<DEST_TYPE>() {
>(|val, output, ctx| {
if let Ok(new_val) = parse_number::<DEST_TYPE>(
val,
&DEST_TYPE::data_type(),
ctx.func_ctx.rounding_mode,
) {
output.push(new_val);
} else {
output.push_null();
Expand Down
8 changes: 4 additions & 4 deletions src/query/sql/src/evaluator/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ pub fn apply_cse(
count_expressions(expr, &mut cse_counter);
}

let mut cse_candidates: Vec<Expr> = cse_counter
let mut cse_candidates: Vec<&Expr> = cse_counter
.iter()
.filter(|(_, count)| **count > 1)
.map(|(expr, _)| expr.clone())
.map(|(expr, _)| expr)
.collect();

// Make sure the smaller expr goes firstly
Expand All @@ -52,7 +52,7 @@ pub fn apply_cse(
let mut cse_replacements = HashMap::new();

let candidates_nums = cse_candidates.len();
for cse_candidate in &cse_candidates {
for cse_candidate in cse_candidates.iter() {
let temp_var = format!("__temp_cse_{}", temp_var_counter);
let temp_expr = Expr::ColumnRef {
span: None,
Expand All @@ -61,7 +61,7 @@ pub fn apply_cse(
display_name: temp_var.clone(),
};

let mut expr_cloned = cse_candidate.clone();
let mut expr_cloned = (*cse_candidate).clone();
perform_cse_replacement(&mut expr_cloned, &cse_replacements);

debug!(
Expand Down
31 changes: 29 additions & 2 deletions src/query/sql/src/planner/binder/column_binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,47 @@ pub struct ColumnBinding {
}

const DUMMY_INDEX: usize = usize::MAX;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u64)]
pub enum DummyColumnType {
WindowFunction = 1,
AggregateFunction = 2,
Subquery = 3,
UDF = 4,
AsyncFunction = 5,
Other = 6,
}

impl DummyColumnType {
fn type_identifier(&self) -> usize {
DUMMY_INDEX - (*self) as usize
}
}

impl ColumnBinding {
pub fn new_dummy_column(name: String, data_type: Box<DataType>) -> Self {
pub fn new_dummy_column(
name: String,
data_type: Box<DataType>,
dummy_type: DummyColumnType,
) -> Self {
let index = dummy_type.type_identifier();
ColumnBinding {
database_name: None,
table_name: None,
column_position: None,
table_index: None,
column_name: name,
index: DUMMY_INDEX,
index,
data_type,
visibility: Visibility::Visible,
virtual_computed_expr: None,
}
}

pub fn is_dummy(&self) -> bool {
self.index >= DummyColumnType::Other.type_identifier()
}
}

impl ColumnIndex for ColumnBinding {}
Expand Down
1 change: 1 addition & 0 deletions src/query/sql/src/planner/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ pub use binder::Binder;
pub use builders::*;
pub use column_binding::ColumnBinding;
pub use column_binding::ColumnBindingBuilder;
pub use column_binding::DummyColumnType;
pub use copy_into_table::resolve_file_location;
pub use copy_into_table::resolve_stage_location;
pub use explain::ExplainConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@ use std::sync::Arc;
use databend_common_exception::Result;
use databend_common_expression::DataField;
use databend_common_expression::DataSchemaRefExt;
use databend_common_expression::Scalar;
use itertools::Itertools;

use crate::optimizer::extract::Matcher;
use crate::optimizer::rule::constant::is_falsy;
use crate::optimizer::rule::Rule;
use crate::optimizer::rule::RuleID;
use crate::optimizer::rule::TransformResult;
use crate::optimizer::RelExpr;
use crate::optimizer::SExpr;
use crate::plans::ConstantExpr;
use crate::plans::ConstantTableScan;
use crate::plans::Filter;
use crate::plans::Operator;
Expand Down Expand Up @@ -73,18 +72,7 @@ impl Rule for RuleEliminateFilter {
.collect::<Vec<ScalarExpr>>();

// Rewrite false filter to be empty scan
if predicates.iter().any(|predicate| {
matches!(
predicate,
ScalarExpr::ConstantExpr(ConstantExpr {
value: Scalar::Boolean(false),
..
}) | ScalarExpr::ConstantExpr(ConstantExpr {
value: Scalar::Null,
..
})
)
}) {
if predicates.iter().any(is_falsy) {
let output_columns = eval_scalar
.derive_relational_prop(&RelExpr::with_s_expr(s_expr))?
.output_columns
Expand Down
16 changes: 9 additions & 7 deletions src/query/sql/src/planner/optimizer/rule/rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,20 @@ use crate::optimizer::SExpr;

pub static DEFAULT_REWRITE_RULES: LazyLock<Vec<RuleID>> = LazyLock::new(|| {
vec![
RuleID::NormalizeScalarFilter,
RuleID::EliminateFilter,
RuleID::EliminateSort,
RuleID::MergeFilter,
RuleID::MergeEvalScalar,
// Filter
RuleID::EliminateFilter,
RuleID::MergeFilter,
RuleID::NormalizeScalarFilter,
RuleID::PushDownFilterUnion,
RuleID::PushDownFilterAggregate,
RuleID::PushDownFilterWindow,
RuleID::PushDownFilterSort,
RuleID::PushDownFilterEvalScalar,
// Limit
RuleID::PushDownFilterJoin,
RuleID::PushDownFilterProjectSet,
RuleID::PushDownLimit,
RuleID::PushDownLimitUnion,
RuleID::PushDownLimitEvalScalar,
Expand All @@ -42,10 +48,6 @@ pub static DEFAULT_REWRITE_RULES: LazyLock<Vec<RuleID>> = LazyLock::new(|| {
RuleID::PushDownLimitAggregate,
RuleID::PushDownLimitOuterJoin,
RuleID::PushDownLimitScan,
RuleID::PushDownFilterSort,
RuleID::PushDownFilterEvalScalar,
RuleID::PushDownFilterJoin,
RuleID::PushDownFilterProjectSet,
RuleID::SemiToInnerJoin,
RuleID::FoldCountAggregate,
RuleID::TryApplyAggIndex,
Expand Down
10 changes: 8 additions & 2 deletions src/query/sql/src/planner/semantic/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use databend_common_expression::Expr;
use databend_common_expression::RawExpr;
use databend_common_functions::BUILTIN_FUNCTIONS;

use crate::binder::DummyColumnType;
use crate::plans::ScalarExpr;
use crate::ColumnBinding;
use crate::ColumnEntry;
Expand Down Expand Up @@ -198,6 +199,7 @@ impl ScalarExpr {
id: ColumnBinding::new_dummy_column(
win.display_name.clone(),
Box::new(win.func.return_type()),
DummyColumnType::WindowFunction,
),
data_type: win.func.return_type(),
display_name: win.display_name.clone(),
Expand All @@ -207,6 +209,7 @@ impl ScalarExpr {
id: ColumnBinding::new_dummy_column(
agg.display_name.clone(),
Box::new((*agg.return_type).clone()),
DummyColumnType::AggregateFunction,
),
data_type: (*agg.return_type).clone(),
display_name: agg.display_name.clone(),
Expand Down Expand Up @@ -234,17 +237,19 @@ impl ScalarExpr {
ScalarExpr::SubqueryExpr(subquery) => RawExpr::ColumnRef {
span: subquery.span,
id: ColumnBinding::new_dummy_column(
"DUMMY".to_string(),
"DUMMY_SUBQUERY".to_string(),
Box::new(subquery.data_type()),
DummyColumnType::Subquery,
),
data_type: subquery.data_type(),
display_name: "DUMMY".to_string(),
display_name: "DUMMY_SUBQUERY".to_string(),
},
ScalarExpr::UDFCall(udf) => RawExpr::ColumnRef {
span: None,
id: ColumnBinding::new_dummy_column(
udf.display_name.clone(),
Box::new((*udf.return_type).clone()),
DummyColumnType::UDF,
),
data_type: (*udf.return_type).clone(),
display_name: udf.display_name.clone(),
Expand All @@ -260,6 +265,7 @@ impl ScalarExpr {
id: ColumnBinding::new_dummy_column(
async_func.display_name.clone(),
Box::new(async_func.return_type.as_ref().clone()),
DummyColumnType::AsyncFunction,
),
data_type: async_func.return_type.as_ref().clone(),
display_name: async_func.display_name.clone(),
Expand Down
8 changes: 4 additions & 4 deletions tests/sqllogictests/suites/query/functions/cast.test
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ statement ok
set numeric_cast_option = 'truncating'

query T
select CAST(10249.5500000000000000 * POW(10, 2) AS UNSIGNED)
select CAST(10249.5500000000000000 * POW(10, 2) AS UNSIGNED), '29.55'::Int, '29.155'::Int
----
1024954
1024954 29 29

query T
select to_uint64(1024954.98046875::double)
Expand All @@ -101,9 +101,9 @@ statement ok
set numeric_cast_option = 'rounding'

query T
select CAST(10249.5500000000000000 * POW(10, 2) AS UNSIGNED)
select CAST(10249.5500000000000000 * POW(10, 2) AS UNSIGNED), '29.55'::Int, '29.155'::Int
----
1024955
1024955 30 29

query T
select to_uint64(1024954.98046875::double)
Expand Down
Loading