Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(spark): Offset TRY_ELEMENT_AT index by one #4183

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sqlglot/dialects/spark.py
Original file line number Diff line number Diff line change
@@ -113,7 +113,10 @@ class Parser(Spark2.Parser):
             "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
             "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
             "TRY_ELEMENT_AT": lambda args: exp.Bracket(
-                this=seq_get(args, 0), expressions=ensure_list(seq_get(args, 1)), safe=True
+                this=seq_get(args, 0),
+                expressions=ensure_list(seq_get(args, 1)),
+                offset=1,
+                safe=True,
             ),
         }

@@ -172,7 +175,7 @@ class Generator(Spark2.Generator):

         def bracket_sql(self, expression: exp.Bracket) -> str:
             if expression.args.get("safe"):
-                key = seq_get(self.bracket_offset_expressions(expression), 0)
+                key = seq_get(self.bracket_offset_expressions(expression, index_offset=1), 0)
                 return self.func("TRY_ELEMENT_AT", expression.this, key)

             return super().bracket_sql(expression)
Expand Down
6 changes: 4 additions & 2 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
@@ -2657,11 +2657,13 @@ def between_sql(self, expression: exp.Between) -> str:
         high = self.sql(expression, "high")
         return f"{this} BETWEEN {low} AND {high}"

-    def bracket_offset_expressions(self, expression: exp.Bracket) -> t.List[exp.Expression]:
+    def bracket_offset_expressions(
+        self, expression: exp.Bracket, index_offset: t.Optional[int] = None
+    ) -> t.List[exp.Expression]:
         return apply_index_offset(
             expression.this,
             expression.expressions,
-            self.dialect.INDEX_OFFSET - expression.args.get("offset", 0),
+            (index_offset or self.dialect.INDEX_OFFSET) - expression.args.get("offset", 0),
         )

     def bracket_sql(self, expression: exp.Bracket) -> str:
Expand Down
27 changes: 13 additions & 14 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@

 from sqlglot import exp, parse_one
 from sqlglot.dialects.dialect import Dialects
-from sqlglot.helper import logger as helper_logger
 from tests.dialects.test_dialect import Validator


@@ -294,19 +293,19 @@ def test_spark(self):
             "SELECT STR_TO_MAP('a:1,b:2,c:3')",
             "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
         )
-
-        with self.assertLogs(helper_logger):
-            self.validate_all(
-                "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
-                read={
-                    "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
-                },
-                write={
-                    "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
-                    "duckdb": "SELECT ([1, 2, 3])[3]",
-                    "spark": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
-                },
-            )
+        self.validate_all(
+            "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
+            read={
+                "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
+                "presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 2)",
+            },
+            write={
+                "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
+                "spark": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
+                "duckdb": "SELECT ([1, 2, 3])[2]",
+                "presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 2)",
+            },
+        )

         self.validate_all(
             "SELECT ARRAY_AGG(x) FILTER (WHERE x = 5) FROM (SELECT 1 UNION ALL SELECT NULL) AS t(x)",
Expand Down
Loading