From 935677ee239b728196599040275b2679056b7609 Mon Sep 17 00:00:00 2001
From: Jax Liu
Date: Thu, 14 Nov 2024 17:30:39 +0800
Subject: [PATCH] feat(core): register timestamptz type to DataFusion
 timestamp with time zone type (#908)

---
 .../resources/function_list/bigquery.csv      |  15 +-
 .../v3/connector/bigquery/test_functions.py   |   4 +-
 .../v3/connector/bigquery/test_query.py       |  31 +++-
 .../v3/connector/postgres/test_query.py       |  12 +-
 ibis-server/tools/mdl_validation.py           |   6 +-
 ibis-server/tools/query_local_run.py          |   6 +-
 .../optimize/simplify_timestamp.rs            |  43 ++---
 wren-core/core/src/logical_plan/utils.rs      |  20 ++-
 wren-core/core/src/mdl/context.rs             |   4 +-
 wren-core/core/src/mdl/mod.rs                 | 149 +++++++++++++++++-
 10 files changed, 224 insertions(+), 66 deletions(-)

diff --git a/ibis-server/resources/function_list/bigquery.csv b/ibis-server/resources/function_list/bigquery.csv
index 87343a5bc..11e33a675 100644
--- a/ibis-server/resources/function_list/bigquery.csv
+++ b/ibis-server/resources/function_list/bigquery.csv
@@ -45,7 +45,7 @@ scalar,array_concat,ARRAY,"Concatenates multiple arrays into one."
 scalar,array_to_string,STRING,"Converts an array to a single string."
 scalar,generate_array,ARRAY,"Generates an array of values in a range."
 scalar,generate_date_array,ARRAY,"Generates an array of dates in a range."
-scalar,parse_timestamp,TIMESTAMP,"Parses a timestamp from a string."
+scalar,parse_timestamp,TIMESTAMPTZ,"Parses a timestamp from a string."
 scalar,string_to_array,ARRAY,"Splits a string into an array of substrings."
 scalar,safe_divide,FLOAT64,"Divides two numbers, returning NULL if the divisor is zero."
 scalar,safe_multiply,FLOAT64,"Multiplies two numbers, returning NULL if an overflow occurs."
@@ -74,17 +74,18 @@ scalar,substr,STRING,"Returns a substring."
 scalar,cast,ANY,"Converts a value to a different data type."
 scalar,safe_cast,ANY,"Converts a value to a different data type, returning NULL on error."
 scalar,current_date,DATE,"Returns the current date."
+scalar,current_datetime,TIMESTAMP,"Returns the current datetime."
 scalar,date_add,DATE,"Adds a specified interval to a date."
 scalar,date_sub,DATE,"Subtracts a specified interval from a date."
 scalar,date_diff,INT64,"Returns the difference between two dates."
 scalar,date_trunc,DATE,"Truncates a date to a specified granularity."
-scalar,timestamp_add,TIMESTAMP,"Adds a specified interval to a timestamp."
-scalar,timestamp_sub,TIMESTAMP,"Subtracts a specified interval from a timestamp."
+scalar,timestamp_add,TIMESTAMPTZ,"Adds a specified interval to a timestamp."
+scalar,timestamp_sub,TIMESTAMPTZ,"Subtracts a specified interval from a timestamp."
 scalar,timestamp_diff,INT64,"Returns the difference between two timestamps."
-scalar,timestamp_trunc,TIMESTAMP,"Truncates a timestamp to a specified granularity."
-scalar,timestamp_micros,TIMESTAMP,"Converts the number of microseconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP."
-scalar,timestamp_millis,TIMESTAMP,"Converts the number of milliseconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP."
-scalar,timestamp_seconds,TIMESTAMP,"Converts the number of seconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP."
+scalar,timestamp_trunc,TIMESTAMPTZ,"Truncates a timestamp to a specified granularity."
+scalar,timestamp_micros,TIMESTAMPTZ,"Converts the number of microseconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP."
+scalar,timestamp_millis,TIMESTAMPTZ,"Converts the number of milliseconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP."
+scalar,timestamp_seconds,TIMESTAMPTZ,"Converts the number of seconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP." scalar,format_date,STRING,"Formats a date according to the specified format string." scalar,format_timestamp,STRING,"Formats a timestamp according to the specified format string." scalar,parse_date,DATE,"Parses a date from a string." diff --git a/ibis-server/tests/routers/v3/connector/bigquery/test_functions.py b/ibis-server/tests/routers/v3/connector/bigquery/test_functions.py index 6893e7af9..477a8b8c5 100644 --- a/ibis-server/tests/routers/v3/connector/bigquery/test_functions.py +++ b/ibis-server/tests/routers/v3/connector/bigquery/test_functions.py @@ -6,8 +6,8 @@ from app.config import get_config from app.main import app -from tests.routers.v3.connector.bigquery.conftest import base_url, function_list_path from tests.conftest import DATAFUSION_FUNCTION_COUNT +from tests.routers.v3.connector.bigquery.conftest import base_url, function_list_path manifest = { "catalog": "my_catalog", @@ -47,7 +47,7 @@ def test_function_list(): response = client.get(url=f"{base_url}/functions") assert response.status_code == 200 result = response.json() - assert len(result) == DATAFUSION_FUNCTION_COUNT + 33 + assert len(result) == DATAFUSION_FUNCTION_COUNT + 34 the_func = next(filter(lambda x: x["name"] == "abs", result)) assert the_func == { "name": "abs", diff --git a/ibis-server/tests/routers/v3/connector/bigquery/test_query.py b/ibis-server/tests/routers/v3/connector/bigquery/test_query.py index a28d6605d..141a9bd8a 100644 --- a/ibis-server/tests/routers/v3/connector/bigquery/test_query.py +++ b/ibis-server/tests/routers/v3/connector/bigquery/test_query.py @@ -42,17 +42,17 @@ { "name": "timestamptz", "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)", - "type": "timestamp", + "type": "timestamptz", }, { "name": "dst_utc_minus_5", "expression": "cast('2024-01-15 23:00:00 America/New_York' as timestamp with time zone)", - "type": "timestamp", + "type": "timestamptz", }, { "name": "dst_utc_minus_4", "expression": "cast('2024-07-15 23:00:00 America/New_York' as timestamp with time zone)", - "type": "timestamp", + "type": "timestamptz", }, ], "primaryKey": "o_orderkey", @@ -83,9 +83,9 @@ def test_query(manifest_str, connection_info): assert len(result["data"]) == 1 assert result["data"][0] == [ "2024-01-01 23:59:59.000000", - "2024-01-01 23:59:59.000000", - "2024-01-16 04:00:00.000000", # utc-5 - "2024-07-16 03:00:00.000000", # utc-4 + "2024-01-01 23:59:59.000000 UTC", + "2024-01-16 04:00:00.000000 UTC", # utc-5 + "2024-07-16 03:00:00.000000 UTC", # utc-4 "36485_1202", 1202, "1992-06-06", @@ -237,3 +237,22 @@ def test_timestamp_func(manifest_str, connection_info): "micros": "object", "seconds": "object", } + + response = client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT timestamp with time zone '2000-01-01 10:00:00' < current_datetime() as compare", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == 1 + assert len(result["data"]) == 1 + assert result["data"][0] == [ + True, + ] + assert result["dtypes"] == { + "compare": "bool", + } diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_query.py b/ibis-server/tests/routers/v3/connector/postgres/test_query.py index b1131261f..685472fec 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_query.py +++ 
b/ibis-server/tests/routers/v3/connector/postgres/test_query.py @@ -42,17 +42,17 @@ { "name": "timestamptz", "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)", - "type": "timestamp", + "type": "timestamptz", }, { "name": "dst_utc_minus_5", "expression": "cast('2024-01-15 23:00:00 America/New_York' as timestamp with time zone)", - "type": "timestamp", + "type": "timestamptz", }, { "name": "dst_utc_minus_4", "expression": "cast('2024-07-15 23:00:00 America/New_York' as timestamp with time zone)", - "type": "timestamp", + "type": "timestamptz", }, ], "primaryKey": "o_orderkey", @@ -83,9 +83,9 @@ def test_query(manifest_str, connection_info): assert len(result["data"]) == 1 assert result["data"][0] == [ "2024-01-01 23:59:59.000000", - "2024-01-01 23:59:59.000000", - "2024-01-16 04:00:00.000000", # utc-5 - "2024-07-16 03:00:00.000000", # utc-4 + "2024-01-01 23:59:59.000000 UTC", + "2024-01-16 04:00:00.000000 UTC", # utc-5 + "2024-07-16 03:00:00.000000 UTC", # utc-4 "1_370", 370, "1996-01-02", diff --git a/ibis-server/tools/mdl_validation.py b/ibis-server/tools/mdl_validation.py index 5b27e0409..aa708515e 100644 --- a/ibis-server/tools/mdl_validation.py +++ b/ibis-server/tools/mdl_validation.py @@ -12,6 +12,7 @@ import base64 import json +from loguru import logger from wren_core import SessionContext # Set up argument parsing @@ -39,9 +40,10 @@ sql = f"select \"{column['name']}\" from \"{model['name']}\"" try: planned_sql = session_context.transform_sql(sql) - except Exception: + except Exception as e: error_cases.append((model, column)) - print(f"Error transforming {model['name']} {column['name']}") + logger.info(f"Error transforming {model['name']} {column['name']}") + logger.debug(e) if len(error_cases) > 0: raise Exception(f"Error transforming {len(error_cases)} columns") diff --git a/ibis-server/tools/query_local_run.py b/ibis-server/tools/query_local_run.py index 0c934af37..fd07ceed6 100644 --- a/ibis-server/tools/query_local_run.py +++ b/ibis-server/tools/query_local_run.py @@ -42,7 +42,7 @@ print("# Function List Path:", function_list_path) print("# Connection Info Path:", connection_info_path) print("# Data Source:", data_source) -print("# SQL Query:", sql) +print("# SQL Query:\n", sql) print("#") # Read and encode the JSON data @@ -60,11 +60,11 @@ print("#") session_context = SessionContext(encoded_str, function_list_path) planned_sql = session_context.transform_sql(sql) -print("# Planned SQL:", planned_sql) +print("# Planned SQL:\n", planned_sql) # Transpile the planned SQL dialect_sql = sqlglot.transpile(planned_sql, read="trino", write=data_source)[0] -print("# Dialect SQL:", dialect_sql) +print("# Dialect SQL:\n", dialect_sql) print("#") if data_source == "bigquery": diff --git a/wren-core/core/src/logical_plan/optimize/simplify_timestamp.rs b/wren-core/core/src/logical_plan/optimize/simplify_timestamp.rs index d837f64cf..b29c2f517 100644 --- a/wren-core/core/src/logical_plan/optimize/simplify_timestamp.rs +++ b/wren-core/core/src/logical_plan/optimize/simplify_timestamp.rs @@ -17,19 +17,21 @@ * under the License. 
 */

 use datafusion::arrow::datatypes::{DataType, TimeUnit};
-use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
+use datafusion::common::tree_node::{
+    Transformed, TransformedResult, TreeNode, TreeNodeRewriter,
+};
 use datafusion::common::ScalarValue::{
     TimestampMicrosecond, TimestampMillisecond, TimestampSecond,
 };
 use datafusion::common::{DFSchema, DFSchemaRef, Result, ScalarValue};
+use datafusion::config::ConfigOptions;
 use datafusion::execution::context::ExecutionProps;
 use datafusion::logical_expr::expr_rewriter::NamePreserver;
 use datafusion::logical_expr::simplify::SimplifyContext;
 use datafusion::logical_expr::utils::merge_schema;
 use datafusion::logical_expr::{cast, Cast, LogicalPlan, TryCast};
-use datafusion::optimizer::optimizer::ApplyOrder;
 use datafusion::optimizer::simplify_expressions::ExprSimplifier;
-use datafusion::optimizer::{OptimizerConfig, OptimizerRule};
+use datafusion::optimizer::AnalyzerRule;
 use datafusion::prelude::Expr;
 use datafusion::scalar::ScalarValue::TimestampNanosecond;
 use std::sync::Arc;
@@ -46,37 +48,18 @@ impl TimestampSimplify {
     }
 }

-impl OptimizerRule for TimestampSimplify {
-    fn name(&self) -> &str {
-        "simplify_cast_expressions"
-    }
-
-    fn apply_order(&self) -> Option<ApplyOrder> {
-        Some(ApplyOrder::BottomUp)
+impl AnalyzerRule for TimestampSimplify {
+    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
+        Self::analyze_internal(plan).data()
     }

-    fn supports_rewrite(&self) -> bool {
-        true
-    }
-
-    /// if supports_owned returns true, the Optimizer calls
-    /// [`Self::rewrite`] instead of [`Self::try_optimize`]
-    fn rewrite(
-        &self,
-        plan: LogicalPlan,
-        config: &dyn OptimizerConfig,
-    ) -> Result<Transformed<LogicalPlan>> {
-        let mut execution_props = ExecutionProps::new();
-        execution_props.query_execution_start_time = config.query_execution_start_time();
-        Self::optimize_internal(plan, &execution_props)
+    fn name(&self) -> &str {
+        "simplify_timestamp_expressions"
     }
 }

 impl TimestampSimplify {
-    fn optimize_internal(
-        plan: LogicalPlan,
-        execution_props: &ExecutionProps,
-    ) -> Result<Transformed<LogicalPlan>> {
+    fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
         let schema = if !plan.inputs().is_empty() {
             DFSchemaRef::new(merge_schema(&plan.inputs()))
         } else if let LogicalPlan::TableScan(scan) = &plan {
@@ -97,8 +80,8 @@ impl TimestampSimplify {
         } else {
             Arc::new(DFSchema::empty())
         };
-
-        let info = SimplifyContext::new(execution_props).with_schema(schema);
+        let execution_props = ExecutionProps::default();
+        let info = SimplifyContext::new(&execution_props).with_schema(schema);

         // Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer)
         // Just need to rewrite our own expressions
diff --git a/wren-core/core/src/logical_plan/utils.rs b/wren-core/core/src/logical_plan/utils.rs
index 36c241d7e..8d018a5cb 100644
--- a/wren-core/core/src/logical_plan/utils.rs
+++ b/wren-core/core/src/logical_plan/utils.rs
@@ -44,7 +44,7 @@ pub fn map_data_type(data_type: &str) -> DataType {
     }
     match data_type {
         // Wren Definition Types
-        "bool" => DataType::Boolean,
+        "bool" | "boolean" => DataType::Boolean,
         "tinyint" => DataType::Int8,
         "int2" => DataType::Int16,
         "smallint" => DataType::Int16,
@@ -66,8 +66,10 @@ pub fn map_data_type(data_type: &str) -> DataType {
         "float" => DataType::Float32,
         "float8" => DataType::Float64,
         "double" => DataType::Float64,
-        "timestamp" => DataType::Timestamp(TimeUnit::Nanosecond, None), // chose the smallest time unit
-        "timestamptz" => DataType::Timestamp(TimeUnit::Nanosecond, None), // don't care about the time zone
+        "timestamp" | "datetime" => DataType::Timestamp(TimeUnit::Nanosecond, None), // chose the smallest time unit
+        "timestamptz" | "timestamp_with_timezone" | "timestamp_with_time_zone" => {
+            DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
+        }
         "date" => DataType::Date32,
         "interval" => DataType::Interval(IntervalUnit::DayTime),
         "json" => DataType::Utf8, // we don't have a JSON type, so we map it to Utf8
@@ -79,7 +81,6 @@
         // BigQuery Compatible Types
         "bignumeric" => DataType::Decimal128(38, 10), // set the default precision and scale
         "bytes" => DataType::Binary,
-        "datetime" => DataType::Timestamp(TimeUnit::Nanosecond, None), // chose the smallest time unit
         "float64" => DataType::Float64,
         "int64" => DataType::Int64,
         "time" => DataType::Time32(TimeUnit::Nanosecond), // chose the smallest time unit
@@ -252,6 +253,7 @@ mod test {
     pub fn test_map_data_type() -> Result<()> {
         let test_cases = vec![
             ("bool", DataType::Boolean),
+            ("boolean", DataType::Boolean),
             ("tinyint", DataType::Int8),
             ("int2", DataType::Int16),
             ("smallint", DataType::Int16),
@@ -274,7 +276,15 @@ mod test {
             ("timestamp", DataType::Timestamp(TimeUnit::Nanosecond, None)),
             (
                 "timestamptz",
-                DataType::Timestamp(TimeUnit::Nanosecond, None),
+                DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
+            ),
+            (
+                "timestamp_with_timezone",
+                DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
+            ),
+            (
+                "timestamp_with_time_zone",
+                DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
             ),
             ("date", DataType::Date32),
             ("interval", DataType::Interval(IntervalUnit::DayTime)),
diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs
index 5b02241e2..7b4f30305 100644
--- a/wren-core/core/src/mdl/context.rs
+++ b/wren-core/core/src/mdl/context.rs
@@ -140,6 +140,9 @@ fn analyze_rule_for_unparsing(
         Arc::new(InlineTableScan::new()),
         // Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule].
         Arc::new(ExpandWildcardRule::new()),
+        // TimestampSimplify should be placed before TypeCoercion because the simplified timestamp should
+        // be casted to the target type if needed
+        Arc::new(TimestampSimplify::new()),
         // [Expr::Wildcard] should be expanded before [TypeCoercion]
         Arc::new(TypeCoercion::new()),
         // Disable it to avoid generate the alias name, `count(*)` because BigQuery doesn't allow
@@ -180,7 +183,6 @@ fn optimize_rule_for_unparsing() -> Vec<Arc<dyn OptimizerRule + Send + Sync>> {
         Arc::new(SingleDistinctToGroupBy::new()),
         // Disable SimplifyExpressions to avoid apply some function locally
         // Arc::new(SimplifyExpressions::new()),
-        Arc::new(TimestampSimplify::new()),
         Arc::new(UnwrapCastInComparison::new()),
         Arc::new(CommonSubexprEliminate::new()),
         Arc::new(EliminateGroupByConstant::new()),
diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs
index bfc1999c4..a787c4257 100644
--- a/wren-core/core/src/mdl/mod.rs
+++ b/wren-core/core/src/mdl/mod.rs
@@ -426,17 +426,20 @@ impl ColumnReference {

 #[cfg(test)]
 mod test {
+    use std::collections::HashMap;
     use std::fs;
     use std::path::PathBuf;
     use std::sync::Arc;

     use crate::mdl::builder::{ColumnBuilder, ManifestBuilder, ModelBuilder};
+    use crate::mdl::context::create_ctx_with_mdl;
     use crate::mdl::function::RemoteFunction;
     use crate::mdl::manifest::Manifest;
     use crate::mdl::{self, transform_sql_with_ctx, AnalyzedWrenMDL};
     use datafusion::arrow::array::{
         ArrayRef, Int64Array, RecordBatch, StringArray, TimestampNanosecondArray,
     };
+    use datafusion::assert_batches_eq;
     use datafusion::common::not_impl_err;
     use datafusion::common::Result;
     use datafusion::config::ConfigOptions;
@@ -964,7 +967,7 @@ mod test {
                 .build(),
             )
             .column(
-                ColumnBuilder::new("cast_timestamp", "timestamp")
+                ColumnBuilder::new("cast_timestamptz", "timestamptz")
                     .expression(r#"cast("出道時間" as timestamp with time zone)"#)
                     .build(),
             )
@@ -973,7 +976,7 @@ mod test {
         .build();

         let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?);
-        let sql = r#"select count(*) from wren.test.artist where cast(cast_timestamp as timestamp) > timestamp '2011-01-01 21:00:00'"#;
+        let sql = r#"select count(*) from wren.test.artist where cast(cast_timestamptz as timestamp) > timestamp '2011-01-01 21:00:00'"#;
         let actual = transform_sql_with_ctx(
             &SessionContext::new(),
             Arc::clone(&analyzed_mdl),
@@ -982,8 +985,135 @@ mod test {
         )
         .await?;
         assert_eq!(actual,
-            "SELECT count(*) FROM (SELECT artist.cast_timestamp FROM (SELECT CAST(artist.\"出道時間\" AS TIMESTAMP WITH TIME ZONE) AS cast_timestamp \
-            FROM artist) AS artist) AS artist WHERE artist.cast_timestamp > CAST('2011-01-01 21:00:00' AS TIMESTAMP)");
+            "SELECT count(*) FROM (SELECT artist.cast_timestamptz FROM (SELECT CAST(artist.\"出道時間\" AS TIMESTAMP WITH TIME ZONE) AS cast_timestamptz \
+            FROM artist) AS artist) AS artist WHERE CAST(artist.cast_timestamptz AS TIMESTAMP) > CAST('2011-01-01 21:00:00' AS TIMESTAMP)");
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_register_timestamptz() -> Result<()> {
+        let ctx = SessionContext::new();
+        ctx.register_batch("timestamp_table", timestamp_table())?;
+        let provider = ctx
+            .catalog("datafusion")
+            .unwrap()
+            .schema("public")
+            .unwrap()
+            .table("timestamp_table")
+            .await?
+ .unwrap(); + let mut registers = HashMap::new(); + registers.insert( + "datafusion.public.timestamp_table".to_string(), + Arc::clone(&provider), + ); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("timestamp_table") + .table_reference("datafusion.public.timestamp_table") + .column(ColumnBuilder::new("timestamp_col", "timestamp").build()) + .column(ColumnBuilder::new("timestamptz_col", "timestamptz").build()) + .build(), + ) + .build(); + + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, registers)?); + let ctx = create_ctx_with_mdl(&ctx, Arc::clone(&analyzed_mdl), true).await?; + let sql = r#"select arrow_typeof(timestamp_col), arrow_typeof(timestamptz_col) from wren.test.timestamp_table limit 1"#; + let result = ctx.sql(sql).await?.collect().await?; + let expected = vec![ + "+---------------------------------------------+-----------------------------------------------+", + "| arrow_typeof(timestamp_table.timestamp_col) | arrow_typeof(timestamp_table.timestamptz_col) |", + "+---------------------------------------------+-----------------------------------------------+", + "| Timestamp(Nanosecond, None) | Timestamp(Nanosecond, Some(\"UTC\")) |", + "+---------------------------------------------+-----------------------------------------------+", + ]; + assert_batches_eq!(&expected, &result); + Ok(()) + } + + #[tokio::test] + async fn test_coercion_timestamptz() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_batch("timestamp_table", timestamp_table())?; + for timezone_type in [ + "timestamptz", + "timestamp_with_timezone", + "timestamp_with_time_zone", + ] { + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("timestamp_table") + .table_reference("datafusion.public.timestamp_table") + .column(ColumnBuilder::new("timestamp_col", "timestamp").build()) + .column( + ColumnBuilder::new("timestamptz_col", timezone_type).build(), + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let sql = r#"select timestamp_col = timestamptz_col from wren.test.timestamp_table"#; + let actual = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + assert_eq!(actual, + "SELECT CAST(timestamp_table.timestamp_col AS TIMESTAMP WITH TIME ZONE) = timestamp_table.timestamptz_col FROM \ + (SELECT timestamp_table.timestamp_col, timestamp_table.timestamptz_col FROM \ + (SELECT timestamp_table.timestamp_col AS timestamp_col, timestamp_table.timestamptz_col AS timestamptz_col \ + FROM datafusion.public.timestamp_table) AS timestamp_table) AS timestamp_table"); + + let sql = r#"select timestamptz_col > cast('2011-01-01 18:00:00' as TIMESTAMP WITH TIME ZONE) from wren.test.timestamp_table"#; + let actual = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + // assert the simplified literal will be casted to the timestamp tz + assert_eq!(actual, + "SELECT timestamp_table.timestamptz_col > CAST(CAST('2011-01-01 18:00:00' AS TIMESTAMP) AS TIMESTAMP WITH TIME ZONE) \ + FROM (SELECT timestamp_table.timestamptz_col FROM (SELECT timestamp_table.timestamptz_col AS timestamptz_col \ + FROM datafusion.public.timestamp_table) AS timestamp_table) AS timestamp_table"); + + let sql = r#"select timestamptz_col > '2011-01-01 18:00:00' from wren.test.timestamp_table"#; + let actual = transform_sql_with_ctx( + 
&SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + // assert the string literal will be casted to the timestamp tz + assert_eq!(actual, + "SELECT timestamp_table.timestamptz_col > CAST('2011-01-01 18:00:00' AS TIMESTAMP WITH TIME ZONE) \ + FROM (SELECT timestamp_table.timestamptz_col FROM (SELECT timestamp_table.timestamptz_col AS timestamptz_col \ + FROM datafusion.public.timestamp_table) AS timestamp_table) AS timestamp_table"); + + let sql = r#"select timestamp_col > cast('2011-01-01 18:00:00' as TIMESTAMP WITH TIME ZONE) from wren.test.timestamp_table"#; + let actual = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + // assert the simplified literal won't be casted to the timestamp tz + assert_eq!(actual, + "SELECT timestamp_table.timestamp_col > CAST('2011-01-01 18:00:00' AS TIMESTAMP) FROM \ + (SELECT timestamp_table.timestamp_col FROM (SELECT timestamp_table.timestamp_col AS timestamp_col \ + FROM datafusion.public.timestamp_table) AS timestamp_table) AS timestamp_table"); + } Ok(()) } @@ -1039,4 +1169,15 @@ mod test { ]) .unwrap() } + + fn timestamp_table() -> RecordBatch { + let timestamp: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![1, 2, 3])); + let timestamptz: ArrayRef = + Arc::new(TimestampNanosecondArray::from(vec![1, 2, 3]).with_timezone("UTC")); + RecordBatch::try_from_iter(vec![ + ("timestamp_col", timestamp), + ("timestamptz_col", timestamptz), + ]) + .unwrap() + } }
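--
Reviewer note, not part of the patch: the heart of the change is that `timestamptz`
(plus the `timestamp_with_timezone` / `timestamp_with_time_zone` aliases) now maps to
a UTC-zoned Arrow timestamp instead of a naive one, which is what lets DataFusion's
TypeCoercion insert the casts asserted in test_coercion_timestamptz. Below is a
minimal standalone sketch of that mapping, assuming only the datafusion crate; the
function name `map_tz_type`, the lowercase normalization, and the `None` fallback
are illustrative, not the patch's API.

use datafusion::arrow::datatypes::{DataType, TimeUnit};

/// Sketch of the timestamp arms of map_data_type: naive timestamps stay
/// zone-less, while the zoned spellings all normalize to nanoseconds in UTC.
fn map_tz_type(data_type: &str) -> Option<DataType> {
    match data_type.to_lowercase().as_str() {
        "timestamp" | "datetime" => Some(DataType::Timestamp(TimeUnit::Nanosecond, None)),
        "timestamptz" | "timestamp_with_timezone" | "timestamp_with_time_zone" => {
            // `"UTC".into()` builds the Arc<str> time-zone payload, as in the patch
            Some(DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())))
        }
        _ => None, // everything else is handled by the full map_data_type
    }
}

fn main() {
    assert_eq!(
        map_tz_type("timestamptz"),
        Some(DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())))
    );
    assert_eq!(
        map_tz_type("timestamp"),
        Some(DataType::Timestamp(TimeUnit::Nanosecond, None))
    );
}

Because coercion widens the naive side toward the zoned side (never the reverse), a
comparison like `timestamp_col = timestamptz_col` casts the naive column, which is
exactly the expected SQL asserted in the tests above.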