Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat!: transpile map retrieval to duckdb, transpile TRY_ELEMENT_AT #3277

Merged
merged 1 commit into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
build_formatted_time,
filter_array_using_unnest,
if_sql,
inline_array_sql,
inline_array_unless_query,
max_or_greatest,
min_or_least,
no_ilike_sql,
Expand Down Expand Up @@ -576,6 +576,7 @@ class Generator(generator.Generator):
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
exp.Array: inline_array_unless_query,
exp.ArrayContains: _array_contains_sql,
exp.ArrayFilter: filter_array_using_unnest,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
Expand Down Expand Up @@ -843,13 +844,6 @@ def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
def trycast_sql(self, expression: exp.TryCast) -> str:
    """Render a TryCast by delegating to ``cast_sql`` with the ``SAFE_`` prefix."""
    safe_cast = self.cast_sql(expression, safe_prefix="SAFE_")
    return safe_cast

def array_sql(self, expression: exp.Array) -> str:
    """Generate SQL for an Array expression.

    If the first element is an ``exp.Query``, the result is ``ARRAY`` followed
    by the wrapped SQL of that query; any other array is delegated to
    ``inline_array_sql``.
    """
    head = seq_get(expression.expressions, 0)
    if not isinstance(head, exp.Query):
        return inline_array_sql(self, expression)

    wrapped_query = self.wrap(self.sql(head))
    return f"ARRAY{wrapped_query}"

def bracket_sql(self, expression: exp.Bracket) -> str:
this = expression.this
expressions = expression.expressions
Expand Down
7 changes: 7 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,13 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:
return f"[{self.expressions(expression, flat=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    """Generate an inline array unless its first element contains a query.

    The first element is inspected: when it is an expression in which
    ``find(exp.Query)`` locates a query, the output is ``ARRAY(<first element>)``
    built through ``self.func``. Otherwise the array is rendered by
    ``inline_array_sql``.
    """
    first = seq_get(expression.expressions, 0)
    holds_query = isinstance(first, exp.Expression) and first.find(exp.Query)
    if not holds_query:
        return inline_array_sql(self, expression)

    return self.func("ARRAY", first)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
Expand Down
35 changes: 26 additions & 9 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
datestrtodate_sql,
encode_decode_sql,
build_formatted_time,
inline_array_sql,
inline_array_unless_query,
no_comment_column_constraint_sql,
no_safe_divide_sql,
no_timestamp_sql,
Expand Down Expand Up @@ -312,6 +312,15 @@ class Parser(parser.Parser):
),
}

def _parse_bracket(
    self, this: t.Optional[exp.Expression] = None
) -> t.Optional[exp.Expression]:
    """Parse a bracket via the parent parser and tag ``exp.Bracket`` results.

    The parent's result is returned as-is, except that an ``exp.Bracket`` node
    first gets its ``returns_list_for_maps`` arg set to True.
    """
    parsed = super()._parse_bracket(this)
    if not isinstance(parsed, exp.Bracket):
        return parsed

    parsed.set("returns_list_for_maps", True)
    return parsed

def _parse_map(self) -> exp.ToMap | exp.Map:
if self._match(TokenType.L_BRACE, advance=False):
return self.expression(exp.ToMap, this=self._parse_bracket())
Expand Down Expand Up @@ -370,11 +379,7 @@ class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: (
self.func("ARRAY", e.expressions[0])
if e.expressions and e.expressions[0].find(exp.Select)
else inline_array_sql(self, e)
),
exp.Array: inline_array_unless_query,
exp.ArrayFilter: rename_func("LIST_FILTER"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
Expand Down Expand Up @@ -593,7 +598,19 @@ def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
return super().generateseries_sql(expression)

def bracket_sql(self, expression: exp.Bracket) -> str:
    """Generate SQL for a Bracket (subscript) expression.

    An Array operand is wrapped in parentheses before the parent generator
    renders the subscript. When the node is not flagged with
    ``returns_list_for_maps`` and the operand is MAP-typed, the rendered
    subscript is additionally wrapped as ``(<bracket>)[1]``.
    NOTE(review): presumably this unwraps the list a map lookup yields in this
    dialect, as the flag name suggests; confirm against the dialect docs.
    """
    this = expression.this
    if isinstance(this, exp.Array):
        # The tree now holds the Paren; the local `this` still refers to the Array node.
        this.replace(exp.paren(this))

    bracket = super().bracket_sql(expression)

    if not expression.args.get("returns_list_for_maps"):
        if not this.type:
            # Imported locally, as in the original; only needed when the operand is untyped.
            from sqlglot.optimizer.annotate_types import annotate_types

            this = annotate_types(this)

        if this.is_type(exp.DataType.Type.MAP):
            bracket = f"({bracket})[1]"

    # Single exit point: an earlier unconditional return would discard the MAP handling above.
    return bracket
12 changes: 11 additions & 1 deletion sqlglot/dialects/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlglot.dialects.dialect import rename_func, unit_to_var
from sqlglot.dialects.hive import _build_with_ignore_nulls
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
from sqlglot.helper import seq_get
from sqlglot.helper import ensure_list, seq_get
from sqlglot.transforms import (
ctas_with_tmp_tables_to_create_tmp_view,
remove_unique_constraints,
Expand Down Expand Up @@ -63,6 +63,9 @@ class Parser(Spark2.Parser):
**Spark2.Parser.FUNCTIONS,
"ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
"DATEDIFF": _build_datediff,
"TRY_ELEMENT_AT": lambda args: exp.Bracket(
this=seq_get(args, 0), expressions=ensure_list(seq_get(args, 1)), safe=True
),
}

def _parse_generated_as_identity(
Expand Down Expand Up @@ -112,6 +115,13 @@ class Generator(Spark2.Generator):
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)

def bracket_sql(self, expression: exp.Bracket) -> str:
    """Generate SQL for a Bracket, using ``TRY_ELEMENT_AT`` for safe access.

    Brackets without a truthy ``safe`` arg are rendered by the parent
    generator. Safe brackets become ``TRY_ELEMENT_AT(<this>, <key>)``, where
    the key is the first of the offset-adjusted subscript expressions.
    """
    if not expression.args.get("safe"):
        return super().bracket_sql(expression)

    offset_keys = self.bracket_offset_expressions(expression)
    return self.func("TRY_ELEMENT_AT", expression.this, seq_get(offset_keys, 0))

def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
    """Render a computed-column constraint as ``GENERATED ALWAYS AS (<this>)``."""
    computed = self.sql(expression, "this")
    return f"GENERATED ALWAYS AS ({computed})"

Expand Down
8 changes: 7 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4395,7 +4395,13 @@ class Between(Predicate):

class Bracket(Condition):
# https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#array_subscript_operator
arg_types = {"this": True, "expressions": True, "offset": False, "safe": False}
arg_types = {
"this": True,
"expressions": True,
"offset": False,
"safe": False,
"returns_list_for_maps": False,
}

@property
def output_name(self) -> str:
Expand Down
7 changes: 5 additions & 2 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2412,12 +2412,15 @@ def between_sql(self, expression: exp.Between) -> str:
high = self.sql(expression, "high")
return f"{this} BETWEEN {low} AND {high}"

def bracket_sql(self, expression: exp.Bracket) -> str:
expressions = apply_index_offset(
def bracket_offset_expressions(self, expression: exp.Bracket) -> t.List[exp.Expression]:
    """Return the bracket's subscript expressions adjusted for index offset.

    The adjustment passed to ``apply_index_offset`` is the target dialect's
    ``INDEX_OFFSET`` minus the bracket's own ``offset`` arg (0 when unset).
    """
    # `.get("offset", 0)` only falls back when the key is absent; an explicit
    # offset=None would reach the subtraction and raise TypeError.
    source_offset = expression.args.get("offset") or 0
    return apply_index_offset(
        expression.this,
        expression.expressions,
        self.dialect.INDEX_OFFSET - source_offset,
    )

def bracket_sql(self, expression: exp.Bracket) -> str:
    """Render a Bracket as ``<this>[<subscripts>]``.

    The subscripts come from ``bracket_offset_expressions`` and are joined
    with ", " after each one is rendered through ``self.sql``.
    """
    rendered = [self.sql(item) for item in self.bracket_offset_expressions(expression)]
    subscript = ", ".join(rendered)
    target = self.sql(expression, "this")
    return f"{target}[{subscript}]"

Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def test_duckdb(self):

self.validate_identity("SELECT MAP(['key1', 'key2', 'key3'], [10, 20, 30])")
self.validate_identity("SELECT MAP {'x': 1}")
self.validate_identity("SELECT (MAP {'x': 1})['x']")
self.validate_identity("SELECT df1.*, df2.* FROM df1 POSITIONAL JOIN df2")
self.validate_identity("MAKE_TIMESTAMP(1992, 9, 20, 13, 34, 27.123456)")
self.validate_identity("MAKE_TIMESTAMP(1667810584123456)")
Expand Down
36 changes: 30 additions & 6 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from sqlglot import exp, parse_one
from sqlglot.dialects.dialect import Dialects
from sqlglot.helper import logger as helper_logger
from tests.dialects.test_dialect import Validator


Expand Down Expand Up @@ -223,17 +224,16 @@ def test_hint(self, logger):
)

def test_spark(self):
self.validate_identity("any_value(col, true)", "ANY_VALUE(col) IGNORE NULLS")
self.validate_identity("first(col, true)", "FIRST(col) IGNORE NULLS")
self.validate_identity("first_value(col, true)", "FIRST_VALUE(col) IGNORE NULLS")
self.validate_identity("last(col, true)", "LAST(col) IGNORE NULLS")
self.validate_identity("last_value(col, true)", "LAST_VALUE(col) IGNORE NULLS")

self.assertEqual(
parse_one("REFRESH TABLE t", read="spark").assert_is(exp.Refresh).sql(dialect="spark"),
"REFRESH TABLE t",
)

self.validate_identity("any_value(col, true)", "ANY_VALUE(col) IGNORE NULLS")
self.validate_identity("first(col, true)", "FIRST(col) IGNORE NULLS")
self.validate_identity("first_value(col, true)", "FIRST_VALUE(col) IGNORE NULLS")
self.validate_identity("last(col, true)", "LAST(col) IGNORE NULLS")
self.validate_identity("last_value(col, true)", "LAST_VALUE(col) IGNORE NULLS")
self.validate_identity("DESCRIBE EXTENDED db.table")
self.validate_identity("SELECT * FROM test TABLESAMPLE (50 PERCENT)")
self.validate_identity("SELECT * FROM test TABLESAMPLE (5 ROWS)")
Expand Down Expand Up @@ -284,6 +284,30 @@ def test_spark(self):
"SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
)

with self.assertLogs(helper_logger):
self.validate_all(
"SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
read={
"databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
},
write={
"databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
"duckdb": "SELECT ([1, 2, 3])[3]",
"spark": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)",
},
)

self.validate_all(
"SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)",
read={
"databricks": "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)",
},
write={
"databricks": "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)",
"duckdb": "SELECT (MAP([1, 2], ['a', 'b'])[2])[1]",
"spark": "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)",
},
)
self.validate_all(
"SELECT SPLIT('123|789', '\\\\|')",
read={
Expand Down
Loading