Skip to content

Commit

Permalink
Fix!: explode to unnest with multiple explosions
Browse files Browse the repository at this point in the history
also remove with ordinality with just offset for consistency
closes #2227
  • Loading branch information
tobymao committed Sep 16, 2023
1 parent e07f843 commit 02bb54b
Show file tree
Hide file tree
Showing 13 changed files with 226 additions and 109 deletions.
6 changes: 5 additions & 1 deletion sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ class Generator(generator.Generator):
RENAME_TABLE_WITH_DB = False
ESCAPE_LINE_BREAK = True
NVL2_SUPPORTED = False
UNNEST_WITH_ORDINALITY = False

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand All @@ -434,6 +435,9 @@ class Generator(generator.Generator):
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.Hex: rename_func("TO_HEX"),
exp.If: lambda self, e: self.func(
"IF", e.this, e.args.get("true"), e.args.get("false") or "NULL"
),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
Expand All @@ -455,7 +459,7 @@ class Generator(generator.Generator):
exp.ReturnsProperty: _returnsproperty_sql,
exp.Select: transforms.preprocess(
[
transforms.explode_to_unnest,
transforms.explode_to_unnest(),
_unqualify_unnest,
transforms.eliminate_distinct_on,
_alias_ordered_group,
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> s
this=exp.Unnest(
expressions=[expression.this.this],
alias=expression.args.get("alias"),
ordinality=isinstance(expression.this, exp.Posexplode),
offset=isinstance(expression.this, exp.Posexplode),
),
kind="cross",
)
Expand Down Expand Up @@ -331,7 +331,7 @@ class Generator(generator.Generator):
[
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.explode_to_unnest,
transforms.explode_to_unnest(1),
]
),
exp.SortArray: _no_sort_array,
Expand Down
1 change: 0 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2569,7 +2569,6 @@ class Intersect(Union):
class Unnest(UDTF):
arg_types = {
"expressions": True,
"ordinality": False,
"alias": False,
"offset": False,
}
Expand Down
29 changes: 24 additions & 5 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ class Generator:
# Whether or not the word COLUMN is included when adding a column with ALTER TABLE
ALTER_TABLE_ADD_COLUMN_KEYWORD = True

# UNNEST WITH ORDINALITY (presto) instead of UNNEST WITH OFFSET (bigquery)
UNNEST_WITH_ORDINALITY = True

TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
Expand Down Expand Up @@ -1858,17 +1861,33 @@ def union_op(self, expression: exp.Union) -> str:

def unnest_sql(self, expression: exp.Unnest) -> str:
args = self.expressions(expression, flat=True)

alias = expression.args.get("alias")
offset = expression.args.get("offset")

if self.UNNEST_WITH_ORDINALITY:
if alias and isinstance(offset, exp.Expression):
alias = alias.copy()
alias.append("columns", offset.copy())

if alias and self.UNNEST_COLUMN_ONLY:
columns = alias.columns
alias = self.sql(columns[0]) if columns else ""
else:
alias = self.sql(expression, "alias")
alias = self.sql(alias)

alias = f" AS {alias}" if alias else alias
ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
offset = expression.args.get("offset")
offset = f" WITH OFFSET AS {self.sql(offset)}" if offset else ""
return f"UNNEST({args}){ordinality}{alias}{offset}"
if self.UNNEST_WITH_ORDINALITY:
suffix = f" WITH ORDINALITY{alias}" if offset else alias
else:
if isinstance(offset, exp.Expression):
suffix = f"{alias} WITH OFFSET AS {self.sql(offset)}"
elif offset:
suffix = f"{alias} WITH OFFSET"
else:
suffix = alias

return f"UNNEST({args}){suffix}"

def where_sql(self, expression: exp.Where) -> str:
this = self.indent(self.sql(expression, "this"))
Expand Down
24 changes: 13 additions & 11 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2634,25 +2634,27 @@ def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
return None

expressions = self._parse_wrapped_csv(self._parse_type)
ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)

alias = self._parse_table_alias() if with_alias else None

if alias and self.UNNEST_COLUMN_ONLY:
if alias.args.get("columns"):
self.raise_error("Unexpected extra column alias in unnest.")
if alias:
if self.UNNEST_COLUMN_ONLY:
if alias.args.get("columns"):
self.raise_error("Unexpected extra column alias in unnest.")

alias.set("columns", [alias.this])
alias.set("this", None)

alias.set("columns", [alias.this])
alias.set("this", None)
columns = alias.args.get("columns") or []
if offset and len(expressions) < len(columns):
offset = columns.pop()

offset = None
if self._match_pair(TokenType.WITH, TokenType.OFFSET):
if not offset and self._match_pair(TokenType.WITH, TokenType.OFFSET):
self._match(TokenType.ALIAS)
offset = self._parse_id_var() or exp.to_identifier("offset")

return self.expression(
exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias, offset=offset
)
return self.expression(exp.Unnest, expressions=expressions, alias=alias, offset=offset)

def _parse_derived_table_values(self) -> t.Optional[exp.Values]:
is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
Expand Down
154 changes: 109 additions & 45 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:

if isinstance(unnest, exp.Unnest):
alias = unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

expression.args["joins"].remove(join)

Expand All @@ -163,65 +163,129 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
return expression


def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
"""Convert explode/posexplode into unnest (used in hive -> presto)."""
if isinstance(expression, exp.Select):
from sqlglot.optimizer.scope import Scope
def explode_to_unnest(index_offset: int = 0) -> t.Callable:
def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
"""Convert explode/posexplode into unnest (used in hive -> presto)."""
if isinstance(expression, exp.Select):
from sqlglot.optimizer.scope import Scope

taken_select_names = set(expression.named_selects)
taken_source_names = {name for name, _ in Scope(expression).references}
taken_select_names = set(expression.named_selects)
taken_source_names = {name for name, _ in Scope(expression).references}

for select in expression.selects:
to_replace = select
def new_name(names: t.Set[str], name: str) -> str:
name = find_new_name(names, name)
names.add(name)
return name

arrays: t.List[exp.Condition] = []
series_alias = new_name(taken_select_names, "pos")
series = exp.alias_(
exp.Unnest(
expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
),
new_name(taken_source_names, "_u"),
table=[series_alias],
)

for select in list(expression.selects):
to_replace = select
pos_alias = ""
explode_alias = ""

pos_alias = ""
explode_alias = ""
if isinstance(select, exp.Alias):
explode_alias = select.alias
select = select.this
elif isinstance(select, exp.Aliases):
pos_alias = select.aliases[0].name
explode_alias = select.aliases[1].name
select = select.this

if isinstance(select, exp.Alias):
explode_alias = select.alias
select = select.this
elif isinstance(select, exp.Aliases):
pos_alias = select.aliases[0].name
explode_alias = select.aliases[1].name
select = select.this
if isinstance(select, (exp.Explode, exp.Posexplode)):
is_posexplode = isinstance(select, exp.Posexplode)
explode_arg = select.this

if isinstance(select, (exp.Explode, exp.Posexplode)):
is_posexplode = isinstance(select, exp.Posexplode)
# This ensures that we won't use [POS]EXPLODE's argument as a new selection
if isinstance(explode_arg, exp.Column):
taken_select_names.add(explode_arg.output_name)

explode_arg = select.this
unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)
unnest_source_alias = new_name(taken_source_names, "_u")

# This ensures that we won't use [POS]EXPLODE's argument as a new selection
if isinstance(explode_arg, exp.Column):
taken_select_names.add(explode_arg.output_name)
if not explode_alias:
explode_alias = new_name(taken_select_names, "col")

unnest_source_alias = find_new_name(taken_source_names, "_u")
taken_source_names.add(unnest_source_alias)
if is_posexplode:
pos_alias = new_name(taken_select_names, "pos")

if not explode_alias:
explode_alias = find_new_name(taken_select_names, "col")
taken_select_names.add(explode_alias)
if not pos_alias:
pos_alias = new_name(taken_select_names, "pos")

column = exp.If(
this=exp.column(series_alias).eq(exp.column(pos_alias)),
true=exp.column(explode_alias),
).as_(explode_alias)

if is_posexplode:
pos_alias = find_new_name(taken_select_names, "pos")
taken_select_names.add(pos_alias)
expressions = expression.expressions
index = expressions.index(to_replace)
expressions.pop(index)
expressions.insert(index, column)
expressions.insert(
index + 1,
exp.If(
this=exp.column(series_alias).eq(exp.column(pos_alias)),
true=exp.column(pos_alias),
).as_(pos_alias),
)
expression.set("expressions", expressions)
else:
to_replace.replace(column)

if not arrays:
if expression.args.get("from"):
expression.join(series, copy=False)
else:
expression.from_(series, copy=False)

size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
arrays.append(size)

# trino doesn't support left join unnest with on conditions
# if it did, this would be much simpler
expression.join(
exp.alias_(
exp.Unnest(
expressions=[explode_arg.copy()],
offset=exp.to_identifier(pos_alias),
),
unnest_source_alias,
table=[explode_alias],
),
join_type="CROSS",
copy=False,
)

if is_posexplode:
column_names = [explode_alias, pos_alias]
to_replace.pop()
expression.select(pos_alias, explode_alias, copy=False)
else:
column_names = [explode_alias]
to_replace.replace(exp.column(explode_alias))
if index_offset != 1:
size = size - 1

unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)
expression.where(
exp.column(series_alias)
.eq(exp.column(pos_alias))
.or_(
(exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
),
copy=False,
)

if not expression.args.get("from"):
expression.from_(unnest, copy=False)
else:
expression.join(unnest, join_type="CROSS", copy=False)
if arrays:
end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

return expression
if index_offset != 1:
end = end - (1 - index_offset)
series.expressions[0].set("end", end)

return expression

return _explode_to_unnest


PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_bigquery(self):
},
)
self.validate_all(
"WITH cte AS (SELECT [1, 2, 3] AS arr) SELECT col FROM cte CROSS JOIN UNNEST(arr) AS col",
"WITH cte AS (SELECT [1, 2, 3] AS arr) SELECT IF(pos = pos_2, col, NULL) AS col FROM cte, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(arr)) - 1)) AS pos CROSS JOIN UNNEST(arr) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(arr) - 1) AND pos_2 = (ARRAY_LENGTH(arr) - 1))",
read={
"spark": "WITH cte AS (SELECT ARRAY(1, 2, 3) AS arr) SELECT EXPLODE(arr) FROM cte"
},
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ def test_operators(self):
self.validate_all(
"SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) as foo FROM baz",
write={
"bigquery": "SELECT CASE WHEN COALESCE(bar, 0) = 1 THEN TRUE ELSE FALSE END AS foo FROM baz",
"bigquery": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz",
"duckdb": "SELECT CASE WHEN COALESCE(bar, 0) = 1 THEN TRUE ELSE FALSE END AS foo FROM baz",
"presto": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz",
"hive": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz",
Expand Down
16 changes: 16 additions & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,22 @@ class TestDuckDB(Validator):
dialect = "duckdb"

def test_duckdb(self):
self.validate_all(
"SELECT UNNEST(ARRAY[1, 2, 3]), UNNEST(ARRAY[4, 5]), UNNEST(ARRAY[6])",
write={
"bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_3, col_2, NULL) AS col_2, IF(pos = pos_4, col_3, NULL) AS col_3 FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2, 3]), ARRAY_LENGTH([4, 5]), ARRAY_LENGTH([6])) - 1)) AS pos CROSS JOIN UNNEST([1, 2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5]) AS col_2 WITH OFFSET AS pos_3 CROSS JOIN UNNEST([6]) AS col_3 WITH OFFSET AS pos_4 WHERE ((pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5]) - 1)))) AND (pos = pos_4 OR (pos > (ARRAY_LENGTH([6]) - 1) AND pos_4 = (ARRAY_LENGTH([6]) - 1)))",
"presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_3, col_2) AS col_2, IF(pos = pos_4, col_3) AS col_3 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2, 3]) AND pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5]) AND pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (pos = pos_4 OR (pos > CARDINALITY(ARRAY[6]) AND pos_4 = CARDINALITY(ARRAY[6])))",
},
)

self.validate_all(
"SELECT UNNEST(ARRAY[1, 2, 3]), UNNEST(ARRAY[4, 5]), UNNEST(ARRAY[6]) FROM x",
write={
"bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_3, col_2, NULL) AS col_2, IF(pos = pos_4, col_3, NULL) AS col_3 FROM x, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2, 3]), ARRAY_LENGTH([4, 5]), ARRAY_LENGTH([6])) - 1)) AS pos CROSS JOIN UNNEST([1, 2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5]) AS col_2 WITH OFFSET AS pos_3 CROSS JOIN UNNEST([6]) AS col_3 WITH OFFSET AS pos_4 WHERE ((pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5]) - 1)))) AND (pos = pos_4 OR (pos > (ARRAY_LENGTH([6]) - 1) AND pos_4 = (ARRAY_LENGTH([6]) - 1)))",
"presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_3, col_2) AS col_2, IF(pos = pos_4, col_3) AS col_3 FROM x, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2, 3]) AND pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5]) AND pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (pos = pos_4 OR (pos > CARDINALITY(ARRAY[6]) AND pos_4 = CARDINALITY(ARRAY[6])))",
},
)

self.validate_identity("[x.STRING_SPLIT(' ')[1] FOR x IN ['1', '2', 3] IF x.CONTAINS('1')]")
self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y")
self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x")
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,15 @@ def test_unnest(self):
write={
"hive": "SELECT EXPLODE(c) FROM t",
"postgres": "SELECT UNNEST(c) FROM t",
"presto": "SELECT col FROM t CROSS JOIN UNNEST(c) AS _u(col)",
"presto": "SELECT IF(pos = pos_2, col) AS col FROM t, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(c)))) AS _u(pos) CROSS JOIN UNNEST(c) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(c) AND pos_2 = CARDINALITY(c))",
},
)
self.validate_all(
"SELECT UNNEST(ARRAY[1])",
write={
"hive": "SELECT EXPLODE(ARRAY(1))",
"postgres": "SELECT UNNEST(ARRAY[1])",
"presto": "SELECT col FROM UNNEST(ARRAY[1]) AS _u(col)",
"presto": "SELECT IF(pos = pos_2, col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[1]) AND pos_2 = CARDINALITY(ARRAY[1]))",
},
)

Expand Down
Loading

0 comments on commit 02bb54b

Please sign in to comment.