Skip to content

Commit

Permalink
fix(spark): Support DB's TIMESTAMP_DIFF (#4373)
Browse files Browse the repository at this point in the history
  • Loading branch information
VaggelisD authored Nov 11, 2024
1 parent 702fe31 commit 4d3904e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
4 changes: 0 additions & 4 deletions sqlglot/dialects/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
date_delta_sql,
build_date_delta,
timestamptrunc_sql,
timestampdiff_sql,
)
from sqlglot.dialects.spark import Spark
from sqlglot.tokens import TokenType
Expand Down Expand Up @@ -46,7 +45,6 @@ class Parser(Spark.Parser):
"DATE_ADD": build_date_delta(exp.DateAdd),
"DATEDIFF": build_date_delta(exp.DateDiff),
"DATE_DIFF": build_date_delta(exp.DateDiff),
"TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff),
"GET_JSON_OBJECT": _build_json_extract,
}

Expand Down Expand Up @@ -75,8 +73,6 @@ class Generator(Spark.Generator):
exp.Mul(this=e.expression, expression=exp.Literal.number(-1)),
e.this,
),
exp.DatetimeDiff: timestampdiff_sql,
exp.TimestampDiff: timestampdiff_sql,
exp.DatetimeTrunc: timestamptrunc_sql(),
exp.Select: transforms.preprocess(
[
Expand Down
5 changes: 4 additions & 1 deletion sqlglot/dialects/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import rename_func, unit_to_var
from sqlglot.dialects.dialect import rename_func, unit_to_var, timestampdiff_sql, build_date_delta
from sqlglot.dialects.hive import _build_with_ignore_nulls
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider, _build_as_cast
from sqlglot.helper import ensure_list, seq_get
Expand Down Expand Up @@ -108,6 +108,7 @@ class Parser(Spark2.Parser):
"DATE_ADD": _build_dateadd,
"DATEADD": _build_dateadd,
"TIMESTAMPADD": _build_dateadd,
"TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff),
"DATEDIFF": _build_datediff,
"DATE_DIFF": _build_datediff,
"TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
Expand Down Expand Up @@ -167,6 +168,8 @@ class Generator(Spark2.Generator):
exp.StartsWith: rename_func("STARTSWITH"),
exp.TsOrDsAdd: _dateadd_sql,
exp.TimestampAdd: _dateadd_sql,
exp.DatetimeDiff: timestampdiff_sql,
exp.TimestampDiff: timestampdiff_sql,
exp.TryCast: lambda self, e: (
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
),
Expand Down
11 changes: 11 additions & 0 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,17 @@ def test_spark(self):
},
)

self.validate_all(
"SELECT TIMESTAMPDIFF(MONTH, foo, bar)",
read={
"databricks": "SELECT TIMESTAMPDIFF(MONTH, foo, bar)",
},
write={
"spark": "SELECT TIMESTAMPDIFF(MONTH, foo, bar)",
"databricks": "SELECT TIMESTAMPDIFF(MONTH, foo, bar)",
},
)

def test_bool_or(self):
self.validate_all(
"SELECT a, LOGICAL_OR(b) FROM table GROUP BY a",
Expand Down

0 comments on commit 4d3904e

Please sign in to comment.