From a90a82b262c1ac3ff9384e655c778cca8863aa73 Mon Sep 17 00:00:00 2001 From: tobymao Date: Mon, 8 Apr 2024 13:55:26 -0700 Subject: [PATCH] fix!: allow to_column to properly parse quoted column paths, make types simpler --- sqlglot/expressions.py | 70 +++++++++++++++++++++------------------ tests/test_expressions.py | 6 ++-- 2 files changed, 40 insertions(+), 36 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 6e1c2ce05e..37704ffa33 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -6637,17 +6637,9 @@ def to_interval(interval: str | Literal) -> Interval: ) -@t.overload -def to_table(sql_path: str | Table, **kwargs) -> Table: ... - - -@t.overload -def to_table(sql_path: None, **kwargs) -> None: ... - - def to_table( - sql_path: t.Optional[str | Table], dialect: DialectType = None, copy: bool = True, **kwargs -) -> t.Optional[Table]: + sql_path: str | Table, dialect: DialectType = None, copy: bool = True, **kwargs +) -> Table: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. If a table is passed in then that table is returned. @@ -6661,35 +6653,42 @@ def to_table( Returns: A table expression. """ - if sql_path is None or isinstance(sql_path, Table): + if isinstance(sql_path, Table): return maybe_copy(sql_path, copy=copy) - if not isinstance(sql_path, str): - raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") table = maybe_parse(sql_path, into=Table, dialect=dialect) - if table: - for k, v in kwargs.items(): - table.set(k, v) + + for k, v in kwargs.items(): + table.set(k, v) return table -def to_column(sql_path: str | Column, **kwargs) -> Column: +def to_column( + sql_path: str | Column, dialect: DialectType = None, copy: bool = True, **kwargs +) -> Column: """ - Create a column from a `[table].[column]` sql path. Schema is optional. - + Create a column from a `[table].[column]` sql path. Table is optional. If a column is passed in then that column is returned. Args: - sql_path: `[table].[column]` string + sql_path: a `[table].[column]` string. + dialect: the source dialect according to which the column name will be parsed. + copy: Whether to copy a column if it is passed in. + kwargs: the kwargs to instantiate the resulting `Column` expression with. + Returns: - Table: A column expression + A column expression. """ - if sql_path is None or isinstance(sql_path, Column): - return sql_path - if not isinstance(sql_path, str): - raise ValueError(f"Invalid type provided for column: {type(sql_path)}") - return column(*reversed(sql_path.split(".")), **kwargs) # type: ignore + if isinstance(sql_path, Column): + return maybe_copy(sql_path, copy=copy) + + column = maybe_parse(sql_path, into=Column, dialect=dialect) + + for k, v in kwargs.items(): + column.set(k, v) + + return column def alias_( @@ -6948,18 +6947,23 @@ def var(name: t.Optional[ExpOrStr]) -> Var: return Var(this=name) -def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable: +def rename_table( + old_name: str | Table, + new_name: str | Table, + dialect: DialectType = None, +) -> AlterTable: """Build ALTER TABLE... RENAME... expression Args: old_name: The old name of the table new_name: The new name of the table + dialect: The dialect to parse the table. Returns: Alter table expression """ - old_table = to_table(old_name) - new_table = to_table(new_name) + old_table = to_table(old_name, dialect=dialect) + new_table = to_table(new_name, dialect=dialect) return AlterTable( this=old_table, actions=[ @@ -6973,6 +6977,7 @@ def rename_column( old_column_name: str | Column, new_column_name: str | Column, exists: t.Optional[bool] = None, + dialect: DialectType = None, ) -> AlterTable: """Build ALTER TABLE... RENAME COLUMN... expression @@ -6981,13 +6986,14 @@ def rename_column( old_column: The old name of the column new_column: The new name of the column exists: Whether to add the `IF EXISTS` clause + dialect: The dialect to parse the table/column. Returns: Alter table expression """ - table = to_table(table_name) - old_column = to_column(old_column_name) - new_column = to_column(new_column_name) + table = to_table(table_name, dialect=dialect) + old_column = to_column(old_column_name, dialect=dialect) + new_column = to_column(new_column_name, dialect=dialect) return AlterTable( this=table, actions=[ diff --git a/tests/test_expressions.py b/tests/test_expressions.py index ed19ac1a0d..35c27f553e 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -893,8 +893,6 @@ def test_to_table(self): self.assertEqual(catalog_db_and_table.name, "table_name") self.assertEqual(catalog_db_and_table.args.get("db"), exp.to_identifier("db")) self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog")) - with self.assertRaises(ValueError): - exp.to_table(1) def test_to_column(self): column_only = exp.to_column("column_name") @@ -903,8 +901,8 @@ def test_to_column(self): table_and_column = exp.to_column("table_name.column_name") self.assertEqual(table_and_column.name, "column_name") self.assertEqual(table_and_column.args.get("table"), exp.to_identifier("table_name")) - with self.assertRaises(ValueError): - exp.to_column(1) + + self.assertEqual(exp.to_column("`column_name`", dialect="spark").sql(), '"column_name"') def test_union(self): expression = parse_one("SELECT cola, colb UNION SELECT colx, coly")