Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: set quote_identifiers in qualify, add normalize flag in schema #1701

Merged
Merged 12 commits on May 30, 2023 (base and head branch names not captured).
2 changes: 1 addition & 1 deletion sqlglot/dataframe/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
select_expression = optimize_func(select_expression, identify="always")
select_expression = t.cast(exp.Select, optimize_func(select_expression))
select_expression = df._replace_cte_names_with_hashes(select_expression)
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:
Expand Down
13 changes: 2 additions & 11 deletions sqlglot/optimizer/canonicalize.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,25 @@
from __future__ import annotations

import itertools
import typing as t

from sqlglot import exp
from sqlglot.optimizer.qualify_columns import quote_identifiers

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType


def canonicalize(expression: exp.Expression) -> exp.Expression:
    """Converts a sql expression into a standard form.

    This method relies on annotate_types because many of the
    conversions rely on type inference.

    Args:
        expression: The expression to canonicalize.

    Returns:
        The canonicalized expression.
    """
    # Canonicalize children first (bottom-up), so each rewrite below
    # operates on already-normalized sub-expressions.
    exp.replace_children(expression, canonicalize)

    expression = add_text_to_concat(expression)
    expression = coerce_type(expression)
    expression = remove_redundant_casts(expression)
    expression = ensure_bool_predicates(expression)

    return expression

Expand Down
5 changes: 2 additions & 3 deletions sqlglot/optimizer/normalize_identifiers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType


def normalize_identifiers(
expression: exp.Expression, dialect: DialectType = None
) -> exp.Expression:
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
"""
Normalize all unquoted identifiers to either lower or upper case, depending on
the dialect. This essentially makes those identifiers case-insensitive.
Expand Down
14 changes: 8 additions & 6 deletions sqlglot/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify import qualify
from sqlglot.optimizer.quote_identifiers import quote_identifiers
from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
Expand All @@ -31,6 +32,7 @@
merge_subqueries,
eliminate_joins,
eliminate_ctes,
quote_identifiers,
annotate_types,
canonicalize,
simplify,
Expand All @@ -45,7 +47,7 @@ def optimize(
dialect: DialectType = None,
rules: t.Sequence[t.Callable] = RULES,
**kwargs,
):
) -> exp.Expression:
"""
Rewrite a sqlglot AST into an optimized form.

Expand All @@ -63,11 +65,11 @@ def optimize(
dialect: The dialect to parse the sql string.
rules: sequence of optimizer rules to use.
Many of the rules require tables and columns to be qualified.
Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
what you're doing!
Do not remove `qualify` from the sequence of rules unless you know what you're doing!
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.

Returns:
sqlglot.Expression: optimized expression
The optimized expression.
"""
schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
possible_kwargs = {
Expand All @@ -79,8 +81,8 @@ def optimize(
"quote_identifiers": False, # this happens in canonicalize
**kwargs,
}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)

expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
Expand All @@ -89,4 +91,4 @@ def optimize(
}
expression = rule(expression, **rule_kwargs)

return expression
return t.cast(exp.Expression, expression)
27 changes: 17 additions & 10 deletions sqlglot/optimizer/qualify.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@

from sqlglot import exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.optimizer import qualify_columns
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import (
qualify_columns as qualify_columns_func,
validate_qualify_columns as validate_qualify_columns_func,
)
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.quote_identifiers import (
quote_identifiers as quote_identifiers_func,
)
from sqlglot.schema import Schema, ensure_schema


Expand All @@ -20,6 +26,7 @@ def qualify(
expand_alias_refs: bool = True,
infer_schema: t.Optional[bool] = None,
isolate_tables: bool = False,
qualify_columns: bool = True,
validate_qualify_columns: bool = True,
quote_identifiers: bool = True,
identify: bool = True,
Expand All @@ -44,11 +51,13 @@ def qualify(
expand_alias_refs: Whether or not to expand references to aliases.
infer_schema: Whether or not to infer the schema if missing.
isolate_tables: Whether or not to isolate table selects.
qualify_columns: Whether or not to qualify columns.
validate_qualify_columns: Whether or not to validate columns.
quote_identifiers: Whether or not to run the quote_identifiers step.
This step is necessary to ensure correctness for case sensitive queries.
But this flag is provided in case this step is performed at a later time.
identify: If True, quote all identifiers, else only necessary ones.

Returns:
The qualified expression.
"""
Expand All @@ -59,19 +68,17 @@ def qualify(
if isolate_tables:
expression = isolate_table_selects(expression, schema=schema)

expression = qualify_columns.qualify_columns(
expression,
schema,
expand_alias_refs=expand_alias_refs,
infer_schema=infer_schema,
)
if qualify_columns:
expression = qualify_columns_func(
expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema
)

if quote_identifiers:
expression = expression.transform(
qualify_columns.quote_identifiers, dialect, identify, copy=False
expression = quote_identifiers_func(
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
expression, dialect=dialect, identify=identify, copy=False
)

if validate_qualify_columns:
qualify_columns.validate_qualify_columns(expression)
validate_qualify_columns_func(expression)

return expression
18 changes: 1 addition & 17 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import typing as t

from sqlglot import alias, exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import case_sensitive, seq_get
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
from sqlglot.schema import Schema, ensure_schema

Expand Down Expand Up @@ -414,21 +413,6 @@ def _qualify_outputs(scope):
scope.expression.set("expressions", new_selections)


def quote_identifiers(
expression: exp.Expression, dialect: DialectType, identify: bool
) -> exp.Expression:
"""Makes sure all identifiers that need to be quoted are quoted."""
if isinstance(expression, exp.Identifier):
name = expression.this
expression.set(
"quoted",
identify
or case_sensitive(name, dialect=dialect)
or not exp.SAFE_IDENTIFIER_RE.match(name),
)
return expression


class Resolver:
"""
Helper for resolving columns.
Expand Down
23 changes: 23 additions & 0 deletions sqlglot/optimizer/quote_identifiers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType as DialectType
from sqlglot.helper import case_sensitive


def quote_identifiers(
    expression: E, dialect: DialectType = None, identify: bool = True, copy: bool = True
) -> E:
    """Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: the expression whose identifiers are inspected.
        dialect: dialect used to decide whether a name is case sensitive.
        identify: when True, mark every identifier as quoted unconditionally.
        copy: whether to transform a copy instead of mutating in place.

    Returns:
        The transformed expression.
    """

    def _mark_quoted(node: E) -> E:
        if isinstance(node, exp.Identifier):
            text = node.this
            needs_quotes = (
                identify
                or case_sensitive(text, dialect=dialect)
                or exp.SAFE_IDENTIFIER_RE.match(text) is None
            )
            node.set("quoted", needs_quotes)
        return node

    return expression.transform(_mark_quoted, copy=copy)
5 changes: 4 additions & 1 deletion sqlglot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,19 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
2. {db: {table: set(*cols)}}}
3. {catalog: {db: {table: set(*cols)}}}}
dialect: The dialect to be used for custom type mappings & parsing string arguments.
normalize: Whether to normalize identifier names according to the given dialect or not.
"""

def __init__(
    self,
    schema: t.Optional[t.Dict] = None,
    visible: t.Optional[t.Dict] = None,
    dialect: DialectType = None,
    normalize: bool = True,
) -> None:
    """Initialize the mapping schema.

    Args:
        schema: nested mapping describing tables and their columns
            (see the class docstring for the accepted shapes).
        visible: optional nested mapping of column visibility.
        dialect: dialect used for custom type mappings and parsing string arguments.
        normalize: whether unquoted identifier names are normalized per the dialect.
    """
    self.dialect = dialect
    self.visible = visible or {}
    # Consulted by _normalize_name: when False, names are taken as-is.
    self.normalize = normalize
    # Cache mapping raw type strings to their parsed exp.DataType nodes.
    self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

    # NOTE(review): assumes _normalize produces the nested mapping the
    # base class constructor expects — confirm against AbstractMappingSchema.
    super().__init__(self._normalize(schema or {}))
Expand Down Expand Up @@ -333,7 +336,7 @@ def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = Non

name = identifier.name

if identifier.quoted:
if not self.normalize or identifier.quoted:
return name

return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()
Expand Down
4 changes: 2 additions & 2 deletions tests/fixtures/optimizer/canonicalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ SELECT CAST(1 AS VARCHAR) AS "a" FROM "w" AS "w";
SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w;
SELECT 1 + 3.2 AS "a" FROM "w" AS "w";

SELECT CAST("2022-01-01" AS DATE) + INTERVAL '1' day;
SELECT CAST("2022-01-01" AS DATE) + INTERVAL '1' day AS "_col_0";
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' day;
SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' day AS "_col_0";

--------------------------------------
-- Ensure boolean predicates
Expand Down
37 changes: 16 additions & 21 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ def test_canonicalize(self):
optimize = partial(
optimizer.optimize,
rules=[
optimizer.qualify_tables.qualify_tables,
optimizer.qualify_columns.qualify_columns,
optimizer.qualify.qualify,
optimizer.quote_identifiers.quote_identifiers,
annotate_types,
optimizer.canonicalize.canonicalize,
],
Expand Down Expand Up @@ -699,23 +699,18 @@ def test_quotes(self):
}
}

self.assertEqual(
optimizer.qualify.qualify(
parse_one(
"""
SELECT * FROM example."source"
"""
),
dialect="snowflake",
schema=schema,
).sql(pretty=True),
parse_one(
"""
SELECT
"source"."ID" AS "ID",
"source"."name" AS "name",
"source"."payload" AS "payload"
FROM "EXAMPLE"."source" AS "source"
expected = parse_one(
"""
).sql(pretty=True),
)
SELECT
"source"."ID" AS "ID",
"source"."name" AS "name",
"source"."payload" AS "payload"
FROM "EXAMPLE"."source" AS "source"
""",
read="snowflake",
).sql(pretty=True, dialect="snowflake")

for func in (optimizer.qualify.qualify, optimizer.optimize):
source_query = parse_one('SELECT * FROM example."source"', read="snowflake")
transformed = func(source_query, dialect="snowflake", schema=schema)
self.assertEqual(transformed.sql(pretty=True, dialect="snowflake"), expected)
4 changes: 4 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,7 @@ def test_schema_normalization(self):
# Check that names are normalized to uppercase for Snowflake
schema = MappingSchema(schema={"x": {"foo": "int", '"bLa"': "int"}}, dialect="snowflake")
self.assertEqual(schema.column_names(exp.Table(this="x")), ["FOO", "bLa"])

# Check that switching off the normalization logic works as expected
schema = MappingSchema(schema={"x": {"foo": "int"}}, normalize=False, dialect="snowflake")
self.assertEqual(schema.column_names(exp.Table(this="x")), ["foo"])