diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 410c12e01..d0ed525c1 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -113,7 +113,10 @@ class Parser(Spark2.Parser):
             "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
             "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
             "TRY_ELEMENT_AT": lambda args: exp.Bracket(
-                this=seq_get(args, 0), expressions=ensure_list(seq_get(args, 1)), safe=True
+                this=seq_get(args, 0),
+                expressions=ensure_list(seq_get(args, 1)),
+                offset=1,
+                safe=True,
             ),
         }
 
@@ -172,7 +175,7 @@ class Generator(Spark2.Generator):
 
     def bracket_sql(self, expression: exp.Bracket) -> str:
         if expression.args.get("safe"):
-            key = seq_get(self.bracket_offset_expressions(expression), 0)
+            key = seq_get(self.bracket_offset_expressions(expression, index_offset=1), 0)
             return self.func("TRY_ELEMENT_AT", expression.this, key)
         return super().bracket_sql(expression)
 
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index bce7c0b51..bf5f5f809 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -2657,11 +2657,14 @@ def between_sql(self, expression: exp.Between) -> str:
         high = self.sql(expression, "high")
         return f"{this} BETWEEN {low} AND {high}"
 
-    def bracket_offset_expressions(self, expression: exp.Bracket) -> t.List[exp.Expression]:
+    def bracket_offset_expressions(
+        self, expression: exp.Bracket, index_offset: t.Optional[int] = None
+    ) -> t.List[exp.Expression]:
         return apply_index_offset(
             expression.this,
             expression.expressions,
-            self.dialect.INDEX_OFFSET - expression.args.get("offset", 0),
+            (self.dialect.INDEX_OFFSET if index_offset is None else index_offset)
+            - expression.args.get("offset", 0),
         )
 
     def bracket_sql(self, expression: exp.Bracket) -> str:
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 4fed68cae..01859c620 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -2,7 +2,6 @@
 
 from sqlglot import exp, parse_one
 from sqlglot.dialects.dialect import Dialects
-from sqlglot.helper import logger as helper_logger
 from tests.dialects.test_dialect import Validator
 
 
@@ -294,19 +293,19 @@ def test_spark(self):
             "SELECT STR_TO_MAP('a:1,b:2,c:3')",
             "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
         )
-
-        with self.assertLogs(helper_logger):
-            self.validate_all(
-                "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
-                read={
-                    "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
-                },
-                write={
-                    "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
-                    "duckdb": "SELECT ([1, 2, 3])[3]",
-                    "spark": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
-                },
-            )
+        self.validate_all(
+            "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
+            read={
+                "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
+                "presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 2)",
+            },
+            write={
+                "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
+                "spark": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
+                "duckdb": "SELECT ([1, 2, 3])[2]",
+                "presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 2)",
+            },
+        )
         self.validate_all(
             "SELECT ARRAY_AGG(x) FILTER (WHERE x = 5) FROM (SELECT 1 UNION ALL SELECT NULL) AS t(x)",