Skip to content

Commit

Permalink
Fix(spark)!: avoid redundant casts in FROM/TO_UTC_TIMESTAMP (#4725)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Feb 10, 2025
1 parent 7fe40c8 commit a790e41
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
16 changes: 12 additions & 4 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,12 @@ class Parser(Hive.Parser):
"DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DOUBLE": _build_as_cast("double"),
"FLOAT": _build_as_cast("float"),
"FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone(
this=exp.cast(seq_get(args, 0) or exp.Var(this=""), exp.DataType.Type.TIMESTAMP),
"FROM_UTC_TIMESTAMP": lambda args, dialect: exp.AtTimeZone(
this=exp.cast(
seq_get(args, 0) or exp.Var(this=""),
exp.DataType.Type.TIMESTAMP,
dialect=dialect,
),
zone=seq_get(args, 1),
),
"INT": _build_as_cast("int"),
Expand All @@ -202,8 +206,12 @@ class Parser(Hive.Parser):
else build_formatted_time(exp.StrToTime, "spark")(args)
),
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone(
this=exp.cast(seq_get(args, 0) or exp.Var(this=""), exp.DataType.Type.TIMESTAMP),
"TO_UTC_TIMESTAMP": lambda args, dialect: exp.FromTimeZone(
this=exp.cast(
seq_get(args, 0) or exp.Var(this=""),
exp.DataType.Type.TIMESTAMP,
dialect=dialect,
),
zone=seq_get(args, 1),
),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
Expand Down
12 changes: 8 additions & 4 deletions tests/dialects/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ def test_databricks(self):
self.validate_identity(
"CREATE TABLE IF NOT EXISTS db.table (a TIMESTAMP, b BOOLEAN GENERATED ALWAYS AS (NOT a IS NULL)) USING DELTA"
)
self.validate_identity(
"SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
"SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
)
self.validate_identity(
"SELECT * FROM sales UNPIVOT INCLUDE NULLS (sales FOR quarter IN (q1 AS `Jan-Mar`))"
)
Expand All @@ -54,6 +50,10 @@ def test_databricks(self):
self.validate_identity(
"COPY INTO target FROM `s3://link` FILEFORMAT = AVRO VALIDATE = ALL FILES = ('file1', 'file2') FORMAT_OPTIONS ('opt1'='true', 'opt2'='test') COPY_OPTIONS ('mergeSchema'='true')"
)
self.validate_identity(
"SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
"SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
)
self.validate_identity(
"DATE_DIFF(day, created_at, current_date())",
"DATEDIFF(DAY, created_at, CURRENT_DATE)",
Expand All @@ -62,6 +62,10 @@ def test_databricks(self):
r'SELECT r"\\foo.bar\"',
r"SELECT '\\\\foo.bar\\'",
)
self.validate_identity(
"FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)",
"FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
)

self.validate_all(
"CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))",
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def test_time(self):
},
)
self.validate_all(
"SELECT AT_TIMEZONE(CAST(CAST('2012-10-31 00:00' AS TIMESTAMP WITH TIME ZONE) AS TIMESTAMP), 'America/Sao_Paulo')",
"SELECT AT_TIMEZONE(CAST('2012-10-31 00:00' AS TIMESTAMP WITH TIME ZONE), 'America/Sao_Paulo')",
read={
"spark": "SELECT FROM_UTC_TIMESTAMP(TIMESTAMP '2012-10-31 00:00', 'America/Sao_Paulo')",
},
Expand Down

0 comments on commit a790e41

Please sign in to comment.