feat(core): register timestamptz type to DataFusion timestamp with time zone type (#908)
goldmedal committed Nov 25, 2024 · 1 parent 0a6557e · commit 935677e
Showing 10 changed files with 224 additions and 66 deletions.
15 changes: 8 additions & 7 deletions ibis-server/resources/function_list/bigquery.csv
@@ -45,7 +45,7 @@ scalar,array_concat,ARRAY,"Concatenates multiple arrays into one."
 scalar,array_to_string,STRING,"Converts an array to a single string."
 scalar,generate_array,ARRAY,"Generates an array of values in a range."
 scalar,generate_date_array,ARRAY,"Generates an array of dates in a range."
-scalar,parse_timestamp,TIMESTAMP,"Parses a timestamp from a string."
+scalar,parse_timestamp,TIMESTAMPTZ,"Parses a timestamp from a string."
 scalar,string_to_array,ARRAY,"Splits a string into an array of substrings."
 scalar,safe_divide,FLOAT64,"Divides two numbers, returning NULL if the divisor is zero."
 scalar,safe_multiply,FLOAT64,"Multiplies two numbers, returning NULL if an overflow occurs."
@@ -74,17 +74,18 @@ scalar,substr,STRING,"Returns a substring."
 scalar,cast,ANY,"Converts a value to a different data type."
 scalar,safe_cast,ANY,"Converts a value to a different data type, returning NULL on error."
 scalar,current_date,DATE,"Returns the current date."
+scalar,current_datetime,TIMESTAMP,"Returns the current datetime."
 scalar,date_add,DATE,"Adds a specified interval to a date."
 scalar,date_sub,DATE,"Subtracts a specified interval from a date."
 scalar,date_diff,INT64,"Returns the difference between two dates."
 scalar,date_trunc,DATE,"Truncates a date to a specified granularity."
-scalar,timestamp_add,TIMESTAMP,"Adds a specified interval to a timestamp."
-scalar,timestamp_sub,TIMESTAMP,"Subtracts a specified interval from a timestamp."
+scalar,timestamp_add,TIMESTAMPTZ,"Adds a specified interval to a timestamp."
+scalar,timestamp_sub,TIMESTAMPTZ,"Subtracts a specified interval from a timestamp."
 scalar,timestamp_diff,INT64,"Returns the difference between two timestamps."
-scalar,timestamp_trunc,TIMESTAMP,"Truncates a timestamp to a specified granularity."
-scalar,timestamp_micros,TIMESTAMP,"Converts the number of microseconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP."
-scalar,timestamp_millis,TIMESTAMP,"Converts the number of milliseconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP."
-scalar,timestamp_seconds,TIMESTAMP,"Converts the number of seconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP."
+scalar,timestamp_trunc,TIMESTAMPTZ,"Truncates a timestamp to a specified granularity."
+scalar,timestamp_micros,TIMESTAMPTZ,"Converts the number of microseconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP."
+scalar,timestamp_millis,TIMESTAMPTZ,"Converts the number of milliseconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP."
+scalar,timestamp_seconds,TIMESTAMPTZ,"Converts the number of seconds since 1970-01-01 00:00:00 UTC to a TIMESTAMP."
 scalar,format_date,STRING,"Formats a date according to the specified format string."
 scalar,format_timestamp,STRING,"Formats a timestamp according to the specified format string."
 scalar,parse_date,DATE,"Parses a date from a string."
@@ -6,8 +6,8 @@

 from app.config import get_config
 from app.main import app
-from tests.routers.v3.connector.bigquery.conftest import base_url, function_list_path
 from tests.conftest import DATAFUSION_FUNCTION_COUNT
+from tests.routers.v3.connector.bigquery.conftest import base_url, function_list_path
 
 manifest = {
     "catalog": "my_catalog",
@@ -47,7 +47,7 @@ def test_function_list():
     response = client.get(url=f"{base_url}/functions")
     assert response.status_code == 200
     result = response.json()
-    assert len(result) == DATAFUSION_FUNCTION_COUNT + 33
+    assert len(result) == DATAFUSION_FUNCTION_COUNT + 34
     the_func = next(filter(lambda x: x["name"] == "abs", result))
     assert the_func == {
         "name": "abs",
31 changes: 25 additions & 6 deletions ibis-server/tests/routers/v3/connector/bigquery/test_query.py
@@ -42,17 +42,17 @@
                 {
                     "name": "timestamptz",
                     "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)",
-                    "type": "timestamp",
+                    "type": "timestamptz",
                 },
                 {
                     "name": "dst_utc_minus_5",
                     "expression": "cast('2024-01-15 23:00:00 America/New_York' as timestamp with time zone)",
-                    "type": "timestamp",
+                    "type": "timestamptz",
                 },
                 {
                     "name": "dst_utc_minus_4",
                     "expression": "cast('2024-07-15 23:00:00 America/New_York' as timestamp with time zone)",
-                    "type": "timestamp",
+                    "type": "timestamptz",
                 },
             ],
             "primaryKey": "o_orderkey",
@@ -83,9 +83,9 @@ def test_query(manifest_str, connection_info):
     assert len(result["data"]) == 1
     assert result["data"][0] == [
         "2024-01-01 23:59:59.000000",
-        "2024-01-01 23:59:59.000000",
-        "2024-01-16 04:00:00.000000",  # utc-5
-        "2024-07-16 03:00:00.000000",  # utc-4
+        "2024-01-01 23:59:59.000000 UTC",
+        "2024-01-16 04:00:00.000000 UTC",  # utc-5
+        "2024-07-16 03:00:00.000000 UTC",  # utc-4
         "36485_1202",
         1202,
         "1992-06-06",
@@ -237,3 +237,22 @@ def test_timestamp_func(manifest_str, connection_info):
"micros": "object",
"seconds": "object",
}

response = client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT timestamp with time zone '2000-01-01 10:00:00' < current_datetime() as compare",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["columns"]) == 1
assert len(result["data"]) == 1
assert result["data"][0] == [
True,
]
assert result["dtypes"] == {
"compare": "bool",
}
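
The new assertion drives a comparison between a timestamptz literal and current_datetime(). For intuition, a minimal standalone sketch of the same coercion in plain DataFusion (note this is DataFusion's own SessionContext, not the wren_core wrapper used elsewhere in this repo; the tokio dependency is assumed):

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // now() is zoned (UTC); the naive literal is coerced to the same
    // timestamp type before the comparison, yielding a boolean column.
    let df = ctx
        .sql("SELECT TIMESTAMP '2000-01-01 10:00:00' < now() AS compare")
        .await?;
    df.show().await?; // one row: compare = true
    Ok(())
}
```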
12 changes: 6 additions & 6 deletions ibis-server/tests/routers/v3/connector/postgres/test_query.py
@@ -42,17 +42,17 @@
                 {
                     "name": "timestamptz",
                     "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)",
-                    "type": "timestamp",
+                    "type": "timestamptz",
                 },
                 {
                     "name": "dst_utc_minus_5",
                     "expression": "cast('2024-01-15 23:00:00 America/New_York' as timestamp with time zone)",
-                    "type": "timestamp",
+                    "type": "timestamptz",
                 },
                 {
                     "name": "dst_utc_minus_4",
                     "expression": "cast('2024-07-15 23:00:00 America/New_York' as timestamp with time zone)",
-                    "type": "timestamp",
+                    "type": "timestamptz",
                 },
             ],
             "primaryKey": "o_orderkey",
@@ -83,9 +83,9 @@ def test_query(manifest_str, connection_info):
     assert len(result["data"]) == 1
     assert result["data"][0] == [
         "2024-01-01 23:59:59.000000",
-        "2024-01-01 23:59:59.000000",
-        "2024-01-16 04:00:00.000000",  # utc-5
-        "2024-07-16 03:00:00.000000",  # utc-4
+        "2024-01-01 23:59:59.000000 UTC",
+        "2024-01-16 04:00:00.000000 UTC",  # utc-5
+        "2024-07-16 03:00:00.000000 UTC",  # utc-4
         "1_370",
         370,
         "1996-01-02",
6 changes: 4 additions & 2 deletions ibis-server/tools/mdl_validation.py
@@ -12,6 +12,7 @@
 import base64
 import json
 
+from loguru import logger
 from wren_core import SessionContext
 
 # Set up argument parsing
@@ -39,9 +40,10 @@
sql = f"select \"{column['name']}\" from \"{model['name']}\""
try:
planned_sql = session_context.transform_sql(sql)
except Exception:
except Exception as e:
error_cases.append((model, column))
print(f"Error transforming {model['name']} {column['name']}")
logger.info(f"Error transforming {model['name']} {column['name']}")
logger.debug(e)

if len(error_cases) > 0:
raise Exception(f"Error transforming {len(error_cases)} columns")
6 changes: 3 additions & 3 deletions ibis-server/tools/query_local_run.py
@@ -42,7 +42,7 @@
print("# Function List Path:", function_list_path)
print("# Connection Info Path:", connection_info_path)
print("# Data Source:", data_source)
print("# SQL Query:", sql)
print("# SQL Query:\n", sql)
print("#")

# Read and encode the JSON data
@@ -60,11 +60,11 @@
print("#")
session_context = SessionContext(encoded_str, function_list_path)
planned_sql = session_context.transform_sql(sql)
print("# Planned SQL:", planned_sql)
print("# Planned SQL:\n", planned_sql)

# Transpile the planned SQL
dialect_sql = sqlglot.transpile(planned_sql, read="trino", write=data_source)[0]
print("# Dialect SQL:", dialect_sql)
print("# Dialect SQL:\n", dialect_sql)
print("#")

if data_source == "bigquery":
43 changes: 13 additions & 30 deletions wren-core/core/src/logical_plan/optimize/simplify_timestamp.rs
@@ -17,19 +17,21 @@
  * under the License.
  */
 use datafusion::arrow::datatypes::{DataType, TimeUnit};
-use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
+use datafusion::common::tree_node::{
+    Transformed, TransformedResult, TreeNode, TreeNodeRewriter,
+};
 use datafusion::common::ScalarValue::{
     TimestampMicrosecond, TimestampMillisecond, TimestampSecond,
 };
 use datafusion::common::{DFSchema, DFSchemaRef, Result, ScalarValue};
+use datafusion::config::ConfigOptions;
 use datafusion::execution::context::ExecutionProps;
 use datafusion::logical_expr::expr_rewriter::NamePreserver;
 use datafusion::logical_expr::simplify::SimplifyContext;
 use datafusion::logical_expr::utils::merge_schema;
 use datafusion::logical_expr::{cast, Cast, LogicalPlan, TryCast};
-use datafusion::optimizer::optimizer::ApplyOrder;
 use datafusion::optimizer::simplify_expressions::ExprSimplifier;
-use datafusion::optimizer::{OptimizerConfig, OptimizerRule};
+use datafusion::optimizer::AnalyzerRule;
 use datafusion::prelude::Expr;
 use datafusion::scalar::ScalarValue::TimestampNanosecond;
 use std::sync::Arc;
@@ -46,37 +48,18 @@ impl TimestampSimplify {
     }
 }
 
-impl OptimizerRule for TimestampSimplify {
-    fn name(&self) -> &str {
-        "simplify_cast_expressions"
-    }
-
-    fn apply_order(&self) -> Option<ApplyOrder> {
-        Some(ApplyOrder::BottomUp)
+impl AnalyzerRule for TimestampSimplify {
+    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
+        Self::analyze_internal(plan).data()
     }
 
-    fn supports_rewrite(&self) -> bool {
-        true
-    }
-
-    /// if supports_owned returns true, the Optimizer calls
-    /// [`Self::rewrite`] instead of [`Self::try_optimize`]
-    fn rewrite(
-        &self,
-        plan: LogicalPlan,
-        config: &dyn OptimizerConfig,
-    ) -> Result<Transformed<LogicalPlan>> {
-        let mut execution_props = ExecutionProps::new();
-        execution_props.query_execution_start_time = config.query_execution_start_time();
-        Self::optimize_internal(plan, &execution_props)
+    fn name(&self) -> &str {
+        "simplify_timestamp_expressions"
     }
 }
 
 impl TimestampSimplify {
-    fn optimize_internal(
-        plan: LogicalPlan,
-        execution_props: &ExecutionProps,
-    ) -> Result<Transformed<LogicalPlan>> {
+    fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
         let schema = if !plan.inputs().is_empty() {
             DFSchemaRef::new(merge_schema(&plan.inputs()))
         } else if let LogicalPlan::TableScan(scan) = &plan {
@@ -97,8 +80,8 @@ impl TimestampSimplify {
         } else {
             Arc::new(DFSchema::empty())
         };
-
-        let info = SimplifyContext::new(execution_props).with_schema(schema);
+        let execution_props = ExecutionProps::default();
+        let info = SimplifyContext::new(&execution_props).with_schema(schema);
 
         // Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer)
         // Just need to rewrite our own expressions
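
For reference, a minimal skeleton of a DataFusion AnalyzerRule, the trait TimestampSimplify now implements instead of OptimizerRule. An analyzer rule runs exactly once before optimization, which is why the apply_order/supports_rewrite/rewrite plumbing above can be dropped (NoopRule is a made-up name for illustration):

```rust
use datafusion::common::Result;
use datafusion::config::ConfigOptions;
use datafusion::logical_expr::LogicalPlan;
use datafusion::optimizer::AnalyzerRule;

#[derive(Debug, Default)]
struct NoopRule;

impl AnalyzerRule for NoopRule {
    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
        // A real rule rewrites the plan here; TimestampSimplify folds
        // constant timestamp casts into literals. This sketch is a no-op.
        Ok(plan)
    }

    fn name(&self) -> &str {
        "noop_rule"
    }
}
```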
20 changes: 15 additions & 5 deletions wren-core/core/src/logical_plan/utils.rs
@@ -44,7 +44,7 @@ pub fn map_data_type(data_type: &str) -> DataType {
     }
     match data_type {
         // Wren Definition Types
-        "bool" => DataType::Boolean,
+        "bool" | "boolean" => DataType::Boolean,
         "tinyint" => DataType::Int8,
         "int2" => DataType::Int16,
         "smallint" => DataType::Int16,
@@ -66,8 +66,10 @@
"float" => DataType::Float32,
"float8" => DataType::Float64,
"double" => DataType::Float64,
"timestamp" => DataType::Timestamp(TimeUnit::Nanosecond, None), // chose the smallest time unit
"timestamptz" => DataType::Timestamp(TimeUnit::Nanosecond, None), // don't care about the time zone
"timestamp" | "datetime" => DataType::Timestamp(TimeUnit::Nanosecond, None), // chose the smallest time unit
"timestamptz" | "timestamp_with_timezone" | "timestamp_with_time_zone" => {
DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
}
"date" => DataType::Date32,
"interval" => DataType::Interval(IntervalUnit::DayTime),
"json" => DataType::Utf8, // we don't have a JSON type, so we map it to Utf8
@@ -79,7 +81,6 @@
         // BigQuery Compatible Types
         "bignumeric" => DataType::Decimal128(38, 10), // set the default precision and scale
         "bytes" => DataType::Binary,
-        "datetime" => DataType::Timestamp(TimeUnit::Nanosecond, None), // chose the smallest time unit
         "float64" => DataType::Float64,
         "int64" => DataType::Int64,
         "time" => DataType::Time32(TimeUnit::Nanosecond), // chose the smallest time unit
@@ -252,6 +253,7 @@ mod test {
     pub fn test_map_data_type() -> Result<()> {
         let test_cases = vec![
             ("bool", DataType::Boolean),
+            ("boolean", DataType::Boolean),
             ("tinyint", DataType::Int8),
             ("int2", DataType::Int16),
             ("smallint", DataType::Int16),
@@ -274,7 +276,15 @@
("timestamp", DataType::Timestamp(TimeUnit::Nanosecond, None)),
(
"timestamptz",
DataType::Timestamp(TimeUnit::Nanosecond, None),
DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
),
(
"timestamp_with_timezone",
DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
),
(
"timestamp_with_time_zone",
DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
),
("date", DataType::Date32),
("interval", DataType::Interval(IntervalUnit::DayTime)),
4 changes: 3 additions & 1 deletion wren-core/core/src/mdl/context.rs
@@ -140,6 +140,9 @@ fn analyze_rule_for_unparsing(
         Arc::new(InlineTableScan::new()),
         // Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule].
         Arc::new(ExpandWildcardRule::new()),
+        // TimestampSimplify should be placed before TypeCoercion because the simplified timestamp should
+        // be casted to the target type if needed
+        Arc::new(TimestampSimplify::new()),
         // [Expr::Wildcard] should be expanded before [TypeCoercion]
         Arc::new(TypeCoercion::new()),
         // Disable it to avoid generate the alias name, `count(*)` because BigQuery doesn't allow
@@ -180,7 +183,6 @@ fn optimize_rule_for_unparsing() -> Vec<Arc<dyn OptimizerRule + Send + Sync>> {
         Arc::new(SingleDistinctToGroupBy::new()),
         // Disable SimplifyExpressions to avoid apply some function locally
         // Arc::new(SimplifyExpressions::new()),
-        Arc::new(TimestampSimplify::new()),
         Arc::new(UnwrapCastInComparison::new()),
         Arc::new(CommonSubexprEliminate::new()),
         Arc::new(EliminateGroupByConstant::new()),
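
Reading the two hunks together: TimestampSimplify moves from the optimizer list into the analyzer list, positioned before TypeCoercion. A sketch of the resulting analyzer ordering (module paths follow DataFusion's optimizer::analyzer layout; TimestampSimplify is wren-core's own rule and is assumed to be in scope):

```rust
use std::sync::Arc;

use datafusion::optimizer::analyzer::expand_wildcard_rule::ExpandWildcardRule;
use datafusion::optimizer::analyzer::inline_table_scan::InlineTableScan;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::optimizer::AnalyzerRule;

fn analyzer_rules() -> Vec<Arc<dyn AnalyzerRule + Send + Sync>> {
    vec![
        Arc::new(InlineTableScan::new()),
        Arc::new(ExpandWildcardRule::new()),
        // wren-core rule: must precede TypeCoercion so the simplified
        // timestamp literal can still be cast to the column's target type.
        Arc::new(TimestampSimplify::new()),
        Arc::new(TypeCoercion::new()),
    ]
}
```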