Skip to content

Commit

Permalink
fix!: allow to_column to properly parse quoted column paths, make typ…
Browse files Browse the repository at this point in the history
…es simpler
  • Loading branch information
tobymao committed Apr 8, 2024
1 parent a37d231 commit a90a82b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 36 deletions.
70 changes: 38 additions & 32 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6637,17 +6637,9 @@ def to_interval(interval: str | Literal) -> Interval:
)


@t.overload
def to_table(sql_path: str | Table, **kwargs) -> Table: ...


@t.overload
def to_table(sql_path: None, **kwargs) -> None: ...


def to_table(
sql_path: t.Optional[str | Table], dialect: DialectType = None, copy: bool = True, **kwargs
) -> t.Optional[Table]:
sql_path: str | Table, dialect: DialectType = None, copy: bool = True, **kwargs
) -> Table:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
If a table is passed in then that table is returned.
Expand All @@ -6661,35 +6653,42 @@ def to_table(
Returns:
A table expression.
"""
if sql_path is None or isinstance(sql_path, Table):
if isinstance(sql_path, Table):
return maybe_copy(sql_path, copy=copy)
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")

table = maybe_parse(sql_path, into=Table, dialect=dialect)
if table:
for k, v in kwargs.items():
table.set(k, v)

for k, v in kwargs.items():
table.set(k, v)

return table


def to_column(sql_path: str | Column, **kwargs) -> Column:
def to_column(
sql_path: str | Column, dialect: DialectType = None, copy: bool = True, **kwargs
) -> Column:
"""
Create a column from a `[table].[column]` sql path. Schema is optional.
Create a column from a `[table].[column]` sql path. Table is optional.
If a column is passed in then that column is returned.
Args:
sql_path: `[table].[column]` string
sql_path: a `[table].[column]` string.
dialect: the source dialect according to which the column name will be parsed.
copy: Whether to copy a column if it is passed in.
kwargs: the kwargs to instantiate the resulting `Column` expression with.
Returns:
Table: A column expression
A column expression.
"""
if sql_path is None or isinstance(sql_path, Column):
return sql_path
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
return column(*reversed(sql_path.split(".")), **kwargs) # type: ignore
if isinstance(sql_path, Column):
return maybe_copy(sql_path, copy=copy)

column = maybe_parse(sql_path, into=Column, dialect=dialect)

for k, v in kwargs.items():
column.set(k, v)

return column


def alias_(
Expand Down Expand Up @@ -6948,18 +6947,23 @@ def var(name: t.Optional[ExpOrStr]) -> Var:
return Var(this=name)


def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
def rename_table(
old_name: str | Table,
new_name: str | Table,
dialect: DialectType = None,
) -> AlterTable:
"""Build ALTER TABLE... RENAME... expression
Args:
old_name: The old name of the table
new_name: The new name of the table
dialect: The dialect to parse the table.
Returns:
Alter table expression
"""
old_table = to_table(old_name)
new_table = to_table(new_name)
old_table = to_table(old_name, dialect=dialect)
new_table = to_table(new_name, dialect=dialect)
return AlterTable(
this=old_table,
actions=[
Expand All @@ -6973,6 +6977,7 @@ def rename_column(
old_column_name: str | Column,
new_column_name: str | Column,
exists: t.Optional[bool] = None,
dialect: DialectType = None,
) -> AlterTable:
"""Build ALTER TABLE... RENAME COLUMN... expression
Expand All @@ -6981,13 +6986,14 @@ def rename_column(
old_column: The old name of the column
new_column: The new name of the column
exists: Whether to add the `IF EXISTS` clause
dialect: The dialect to parse the table/column.
Returns:
Alter table expression
"""
table = to_table(table_name)
old_column = to_column(old_column_name)
new_column = to_column(new_column_name)
table = to_table(table_name, dialect=dialect)
old_column = to_column(old_column_name, dialect=dialect)
new_column = to_column(new_column_name, dialect=dialect)
return AlterTable(
this=table,
actions=[
Expand Down
6 changes: 2 additions & 4 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,8 +893,6 @@ def test_to_table(self):
self.assertEqual(catalog_db_and_table.name, "table_name")
self.assertEqual(catalog_db_and_table.args.get("db"), exp.to_identifier("db"))
self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog"))
with self.assertRaises(ValueError):
exp.to_table(1)

def test_to_column(self):
column_only = exp.to_column("column_name")
Expand All @@ -903,8 +901,8 @@ def test_to_column(self):
table_and_column = exp.to_column("table_name.column_name")
self.assertEqual(table_and_column.name, "column_name")
self.assertEqual(table_and_column.args.get("table"), exp.to_identifier("table_name"))
with self.assertRaises(ValueError):
exp.to_column(1)

self.assertEqual(exp.to_column("`column_name`", dialect="spark").sql(), '"column_name"')

def test_union(self):
expression = parse_one("SELECT cola, colb UNION SELECT colx, coly")
Expand Down

0 comments on commit a90a82b

Please sign in to comment.