From d4f14c0788c217c3f8de687ed48344f366ad1c26 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Mon, 11 Nov 2024 10:38:59 +0200 Subject: [PATCH] fix(duckdb): Fix STRUCT cast generation --- sqlglot/dialects/duckdb.py | 18 ++++++++++++------ tests/dialects/test_duckdb.py | 7 ++++--- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index bf1abe2f1d..a183a883f5 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -156,18 +156,24 @@ def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str: # BigQuery allows inline construction such as "STRUCT('str', 1)" which is # canonicalized to "ROW('str', 1) AS STRUCT(a TEXT, b INT)" in DuckDB - # The transformation to ROW will take place if a cast to STRUCT / ARRAY of STRUCTs is found + # The transformation to ROW will take place if: + # 1. The STRUCT itself does not have proper fields (key := value) as a "proper" STRUCT would + # 2. A cast to STRUCT / ARRAY of STRUCTs is found ancestor_cast = expression.find_ancestor(exp.Cast) - is_struct_cast = ancestor_cast and any( - casted_type.is_type(exp.DataType.Type.STRUCT) - for casted_type in ancestor_cast.find_all(exp.DataType) + is_bq_inline_struct = ( + (expression.find(exp.PropertyEQ) is None) + and ancestor_cast + and any( + casted_type.is_type(exp.DataType.Type.STRUCT) + for casted_type in ancestor_cast.find_all(exp.DataType) + ) ) for i, expr in enumerate(expression.expressions): is_property_eq = isinstance(expr, exp.PropertyEQ) value = expr.expression if is_property_eq else expr - if is_struct_cast: + if is_bq_inline_struct: args.append(self.sql(value)) else: key = expr.name if is_property_eq else f"_{i}" @@ -175,7 +181,7 @@ def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str: csv_args = ", ".join(args) - return f"ROW({csv_args})" if is_struct_cast else f"{{{csv_args}}}" + return f"ROW({csv_args})" if is_bq_inline_struct else f"{{{csv_args}}}" def _datatype_sql(self: DuckDB.Generator, expression: exp.DataType) -> str: diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 5b3b2a4ff4..3d4fe9cc4a 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -1154,6 +1154,7 @@ def test_cast(self): self.validate_identity("CAST(x AS BINARY)", "CAST(x AS BLOB)") self.validate_identity("CAST(x AS VARBINARY)", "CAST(x AS BLOB)") self.validate_identity("CAST(x AS LOGICAL)", "CAST(x AS BOOLEAN)") + self.validate_identity("""CAST({'i': 1, 's': 'foo'} AS STRUCT("s" TEXT, "i" INT))""") self.validate_identity( "CAST(ROW(1, ROW(1)) AS STRUCT(number BIGINT, row STRUCT(number BIGINT)))" ) @@ -1163,11 +1164,11 @@ def test_cast(self): ) self.validate_identity( "CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])", - "CAST([[ROW(1)]] AS STRUCT(a BIGINT)[][])", + "CAST([[{'a': 1}]] AS STRUCT(a BIGINT)[][])", ) self.validate_identity( "CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])", - "CAST([ROW(1)] AS STRUCT(a BIGINT)[])", + "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])", ) self.validate_identity( "STRUCT_PACK(a := 'b')::json", @@ -1175,7 +1176,7 @@ def test_cast(self): ) self.validate_identity( "STRUCT_PACK(a := 'b')::STRUCT(a TEXT)", - "CAST(ROW('b') AS STRUCT(a TEXT))", + "CAST({'a': 'b'} AS STRUCT(a TEXT))", ) self.validate_all(