From 19087a40ecbaaf4f52ad0f0748245df232950cad Mon Sep 17 00:00:00 2001
From: George Sittas
Date: Thu, 4 Apr 2024 23:40:04 +0300
Subject: [PATCH] Feat!: transpile map retrieval to duckdb, transpile
 TRY_ELEMENT_AT

---
 sqlglot/dialects/bigquery.py  | 10 ++--------
 sqlglot/dialects/dialect.py   |  7 +++++++
 sqlglot/dialects/duckdb.py    | 35 +++++++++++++++++++++++++---------
 sqlglot/dialects/spark.py     | 12 +++++++++++-
 sqlglot/expressions.py        |  8 +++++++-
 sqlglot/generator.py          |  7 +++++--
 tests/dialects/test_duckdb.py |  1 +
 tests/dialects/test_spark.py  | 36 +++++++++++++++++++++++++++++------
 8 files changed, 89 insertions(+), 27 deletions(-)

diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 2167ba29a1..dbe90b045f 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -15,7 +15,7 @@
     build_formatted_time,
     filter_array_using_unnest,
     if_sql,
-    inline_array_sql,
+    inline_array_unless_query,
     max_or_greatest,
     min_or_least,
     no_ilike_sql,
@@ -576,6 +576,7 @@ class Generator(generator.Generator):
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
             exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
             exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
+            exp.Array: inline_array_unless_query,
            exp.ArrayContains: _array_contains_sql,
             exp.ArrayFilter: filter_array_using_unnest,
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -843,13 +844,6 @@ def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
         def trycast_sql(self, expression: exp.TryCast) -> str:
             return self.cast_sql(expression, safe_prefix="SAFE_")
 
-        def array_sql(self, expression: exp.Array) -> str:
-            first_arg = seq_get(expression.expressions, 0)
-            if isinstance(first_arg, exp.Query):
-                return f"ARRAY{self.wrap(self.sql(first_arg))}"
-
-            return inline_array_sql(self, expression)
-
         def bracket_sql(self, expression: exp.Bracket) -> str:
             this = expression.this
             expressions = expression.expressions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 81057c2695..1e4cfeba09 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -571,6 +571,13 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:
     return f"[{self.expressions(expression, flat=True)}]"
 
 
+def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
+    elem = seq_get(expression.expressions, 0)
+    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
+        return self.func("ARRAY", elem)
+    return inline_array_sql(self, expression)
+
+
 def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
     return self.like_sql(
         exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 6a1d07a18d..3b0651f859 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -15,7 +15,7 @@
     datestrtodate_sql,
     encode_decode_sql,
     build_formatted_time,
-    inline_array_sql,
+    inline_array_unless_query,
     no_comment_column_constraint_sql,
     no_safe_divide_sql,
     no_timestamp_sql,
@@ -312,6 +312,15 @@ class Parser(parser.Parser):
             ),
         }
 
+        def _parse_bracket(
+            self, this: t.Optional[exp.Expression] = None
+        ) -> t.Optional[exp.Expression]:
+            bracket = super()._parse_bracket(this)
+            if isinstance(bracket, exp.Bracket):
+                bracket.set("returns_list_for_maps", True)
+
+            return bracket
+
         def _parse_map(self) -> exp.ToMap | exp.Map:
             if self._match(TokenType.L_BRACE, advance=False):
                 return self.expression(exp.ToMap, this=self._parse_bracket())
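Reviewer note: the new DuckDB `_parse_bracket` tags every bracket parsed from DuckDB SQL with `returns_list_for_maps`, because DuckDB's native `map[key]` already returns a LIST; the generator change in the next hunk uses that flag to skip the extra `[1]` unwrapping. A minimal round-trip sketch of the intended behavior, matching the `validate_identity` case this patch adds to `test_duckdb.py` (`sqlglot.transpile` is the public API, not part of this diff):

    import sqlglot

    # A bracket parsed from DuckDB keeps DuckDB's list-returning map
    # semantics, so the map subscript round-trips unchanged:
    sql = sqlglot.transpile("SELECT (MAP {'x': 1})['x']", read="duckdb", write="duckdb")[0]
    print(sql)  # SELECT (MAP {'x': 1})['x']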
@@ -370,11 +379,7 @@ class Generator(generator.Generator):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
-            exp.Array: lambda self, e: (
-                self.func("ARRAY", e.expressions[0])
-                if e.expressions and e.expressions[0].find(exp.Select)
-                else inline_array_sql(self, e)
-            ),
+            exp.Array: inline_array_unless_query,
             exp.ArrayFilter: rename_func("LIST_FILTER"),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
@@ -593,7 +598,19 @@ def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
             return super().generateseries_sql(expression)
 
         def bracket_sql(self, expression: exp.Bracket) -> str:
-            if isinstance(expression.this, exp.Array):
-                expression.this.replace(exp.paren(expression.this))
+            this = expression.this
+            if isinstance(this, exp.Array):
+                this.replace(exp.paren(this))
+
+            bracket = super().bracket_sql(expression)
+
+            if not expression.args.get("returns_list_for_maps"):
+                if not this.type:
+                    from sqlglot.optimizer.annotate_types import annotate_types
+
+                    this = annotate_types(this)
+
+                if this.is_type(exp.DataType.Type.MAP):
+                    bracket = f"({bracket})[1]"
 
-            return super().bracket_sql(expression)
+            return bracket
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 88b5ddc4a9..9bb9a5cb97 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -6,7 +6,7 @@
 from sqlglot.dialects.dialect import rename_func, unit_to_var
 from sqlglot.dialects.hive import _build_with_ignore_nulls
 from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
-from sqlglot.helper import seq_get
+from sqlglot.helper import ensure_list, seq_get
 from sqlglot.transforms import (
     ctas_with_tmp_tables_to_create_tmp_view,
     remove_unique_constraints,
@@ -63,6 +63,9 @@ class Parser(Spark2.Parser):
             **Spark2.Parser.FUNCTIONS,
             "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
             "DATEDIFF": _build_datediff,
+            "TRY_ELEMENT_AT": lambda args: exp.Bracket(
+                this=seq_get(args, 0), expressions=ensure_list(seq_get(args, 1)), safe=True
+            ),
         }
 
         def _parse_generated_as_identity(
@@ -112,6 +115,13 @@ class Generator(Spark2.Generator):
         TRANSFORMS.pop(exp.DateDiff)
         TRANSFORMS.pop(exp.Group)
 
+        def bracket_sql(self, expression: exp.Bracket) -> str:
+            if expression.args.get("safe"):
+                key = seq_get(self.bracket_offset_expressions(expression), 0)
+                return self.func("TRY_ELEMENT_AT", expression.this, key)
+
+            return super().bracket_sql(expression)
+
         def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
             return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})"
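Reviewer note: Spark's `TRY_ELEMENT_AT` now parses into an `exp.Bracket` with `safe=True`, and the new `Spark.Generator.bracket_sql` reverses that mapping, so Spark and Databricks round-trip the function while DuckDB falls back to bracket syntax; for maps, the DuckDB output additionally extracts element `[1]` of the list that `map[key]` returns. A sketch of the expected behavior, with the output string taken from the `test_spark.py` case added further down:

    import sqlglot

    # Spark's scalar map lookup becomes a DuckDB map bracket (which yields a
    # LIST) followed by a [1] extraction to recover the scalar value:
    sql = sqlglot.transpile(
        "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)", read="spark", write="duckdb"
    )[0]
    print(sql)  # SELECT (MAP([1, 2], ['a', 'b'])[2])[1]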
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index e79c04bd8d..38bfc91a03 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -4395,7 +4395,13 @@ class Between(Predicate):
 
 class Bracket(Condition):
     # https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#array_subscript_operator
-    arg_types = {"this": True, "expressions": True, "offset": False, "safe": False}
+    arg_types = {
+        "this": True,
+        "expressions": True,
+        "offset": False,
+        "safe": False,
+        "returns_list_for_maps": False,
+    }
 
     @property
     def output_name(self) -> str:
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 76d9b5d65f..df0929655b 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -2412,12 +2412,15 @@ def between_sql(self, expression: exp.Between) -> str:
         high = self.sql(expression, "high")
         return f"{this} BETWEEN {low} AND {high}"
 
-    def bracket_sql(self, expression: exp.Bracket) -> str:
-        expressions = apply_index_offset(
+    def bracket_offset_expressions(self, expression: exp.Bracket) -> t.List[exp.Expression]:
+        return apply_index_offset(
             expression.this,
             expression.expressions,
             self.dialect.INDEX_OFFSET - expression.args.get("offset", 0),
         )
+
+    def bracket_sql(self, expression: exp.Bracket) -> str:
+        expressions = self.bracket_offset_expressions(expression)
         expressions_sql = ", ".join(self.sql(e) for e in expressions)
         return f"{self.sql(expression, 'this')}[{expressions_sql}]"
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index 5a7e93e1fc..0b13a7042a 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -240,6 +240,7 @@ def test_duckdb(self):
         self.validate_identity("SELECT MAP(['key1', 'key2', 'key3'], [10, 20, 30])")
         self.validate_identity("SELECT MAP {'x': 1}")
+        self.validate_identity("SELECT (MAP {'x': 1})['x']")
         self.validate_identity("SELECT df1.*, df2.* FROM df1 POSITIONAL JOIN df2")
         self.validate_identity("MAKE_TIMESTAMP(1992, 9, 20, 13, 34, 27.123456)")
         self.validate_identity("MAKE_TIMESTAMP(1667810584123456)")
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 18f1fb732a..d2285e0565 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -2,6 +2,7 @@
 
 from sqlglot import exp, parse_one
 from sqlglot.dialects.dialect import Dialects
+from sqlglot.helper import logger as helper_logger
 from tests.dialects.test_dialect import Validator
 
 
@@ -223,17 +224,16 @@ def test_hint(self, logger):
         )
 
     def test_spark(self):
-        self.validate_identity("any_value(col, true)", "ANY_VALUE(col) IGNORE NULLS")
-        self.validate_identity("first(col, true)", "FIRST(col) IGNORE NULLS")
-        self.validate_identity("first_value(col, true)", "FIRST_VALUE(col) IGNORE NULLS")
-        self.validate_identity("last(col, true)", "LAST(col) IGNORE NULLS")
-        self.validate_identity("last_value(col, true)", "LAST_VALUE(col) IGNORE NULLS")
-
         self.assertEqual(
             parse_one("REFRESH TABLE t", read="spark").assert_is(exp.Refresh).sql(dialect="spark"),
             "REFRESH TABLE t",
         )
 
+        self.validate_identity("any_value(col, true)", "ANY_VALUE(col) IGNORE NULLS")
+        self.validate_identity("first(col, true)", "FIRST(col) IGNORE NULLS")
+        self.validate_identity("first_value(col, true)", "FIRST_VALUE(col) IGNORE NULLS")
+        self.validate_identity("last(col, true)", "LAST(col) IGNORE NULLS")
+        self.validate_identity("last_value(col, true)", "LAST_VALUE(col) IGNORE NULLS")
         self.validate_identity("DESCRIBE EXTENDED db.table")
         self.validate_identity("SELECT * FROM test TABLESAMPLE (50 PERCENT)")
         self.validate_identity("SELECT * FROM test TABLESAMPLE (5 ROWS)")
@@ -284,6 +284,30 @@ def test_spark(self):
             "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
         )
 
+        with self.assertLogs(helper_logger):
+            self.validate_all(
+                "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
+                read={
+                    "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
+                },
+                write={
+                    "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
+                    "duckdb": "SELECT ([1, 2, 3])[3]",
+                    "spark": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
+                },
+            )
+
+        self.validate_all(
+            "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)",
+            read={
+                "databricks": "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)",
+            },
+            write={
+                "databricks": "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)",
+                "duckdb": "SELECT (MAP([1, 2], ['a', 'b'])[2])[1]",
+                "spark": "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)",
+            },
+        )
         self.validate_all(
             "SELECT SPLIT('123|789', '\\\\|')",
             read={
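Reviewer note: the shared `inline_array_unless_query` helper consolidates BigQuery's removed `array_sql` override and DuckDB's removed `exp.Array` lambda into one implementation: keep `ARRAY(<query>)` in function form (a subquery cannot be inlined into `[...]` literal syntax) and inline everything else. A quick sanity sketch of that intent; these two cases are not asserted by the tests in this patch, so treat the expected outputs as assumptions:

    import sqlglot

    # An array built from a subquery should stay in function form:
    print(sqlglot.transpile("SELECT ARRAY(SELECT 1)", read="bigquery", write="bigquery")[0])
    # expected: SELECT ARRAY(SELECT 1)

    # A literal array should still be inlined:
    print(sqlglot.transpile("SELECT [1, 2, 3]", read="bigquery", write="bigquery")[0])
    # expected: SELECT [1, 2, 3]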