diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index d6d76abe8d..3fc923238f 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -299,7 +299,7 @@ def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]: for expression_type, select_expression in select_expressions: select_expression = select_expression.transform(replace_id_value, replacement_mapping) if optimize: - select_expression = optimize_func(select_expression, identify="always") + select_expression = t.cast(exp.Select, optimize_func(select_expression)) select_expression = df._replace_cte_names_with_hashes(select_expression) expression: t.Union[exp.Select, exp.Cache, exp.Drop] if expression_type == exp.Cache: diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index a74fa87460..da2fce8f3c 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -1,18 +1,11 @@ from __future__ import annotations import itertools -import typing as t from sqlglot import exp -from sqlglot.optimizer.qualify_columns import quote_identifiers -if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import DialectType - -def canonicalize( - expression: exp.Expression, identify: bool = True, dialect: DialectType = None -) -> exp.Expression: +def canonicalize(expression: exp.Expression) -> exp.Expression: """Converts a sql expression into a standard form. This method relies on annotate_types because many of the @@ -20,15 +13,13 @@ def canonicalize( Args: expression: The expression to canonicalize. - identify: Whether or not to force identify identifier. """ - exp.replace_children(expression, canonicalize, identify=identify, dialect=dialect) + exp.replace_children(expression, canonicalize) expression = add_text_to_concat(expression) expression = coerce_type(expression) expression = remove_redundant_casts(expression) expression = ensure_bool_predicates(expression) - expression = quote_identifiers(expression, dialect=dialect, identify=identify) return expression diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index bf4e3329de..1e5c104242 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -1,10 +1,9 @@ from sqlglot import exp +from sqlglot._typing import E from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType -def normalize_identifiers( - expression: exp.Expression, dialect: DialectType = None -) -> exp.Expression: +def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: """ Normalize all unquoted identifiers to either lower or upper case, depending on the dialect. This essentially makes those identifiers case-insensitive. diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index d2b1054a73..dbe33a2088 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -16,6 +16,7 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.qualify import qualify +from sqlglot.optimizer.qualify_columns import quote_identifiers from sqlglot.optimizer.simplify import simplify from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.schema import ensure_schema @@ -31,6 +32,7 @@ merge_subqueries, eliminate_joins, eliminate_ctes, + quote_identifiers, annotate_types, canonicalize, simplify, @@ -45,7 +47,7 @@ def optimize( dialect: DialectType = None, rules: t.Sequence[t.Callable] = RULES, **kwargs, -): +) -> exp.Expression: """ Rewrite a sqlglot AST into an optimized form. @@ -63,11 +65,11 @@ def optimize( dialect: The dialect to parse the sql string. rules: sequence of optimizer rules to use. Many of the rules require tables and columns to be qualified. - Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know - what you're doing! + Do not remove `qualify` from the sequence of rules unless you know what you're doing! **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in. + Returns: - sqlglot.Expression: optimized expression + The optimized expression. """ schema = ensure_schema(schema or sqlglot.schema, dialect=dialect) possible_kwargs = { @@ -79,8 +81,8 @@ def optimize( "quote_identifiers": False, # this happens in canonicalize **kwargs, } - expression = exp.maybe_parse(expression, dialect=dialect, copy=True) + expression = exp.maybe_parse(expression, dialect=dialect, copy=True) for rule in rules: # Find any additional rule parameters, beyond `expression` rule_params = rule.__code__.co_varnames @@ -89,4 +91,4 @@ def optimize( } expression = rule(expression, **rule_kwargs) - return expression + return t.cast(exp.Expression, expression) diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py index ea9d4ebccd..5fdbde81cb 100644 --- a/sqlglot/optimizer/qualify.py +++ b/sqlglot/optimizer/qualify.py @@ -4,9 +4,13 @@ from sqlglot import exp from sqlglot.dialects.dialect import DialectType -from sqlglot.optimizer import qualify_columns from sqlglot.optimizer.isolate_table_selects import isolate_table_selects from sqlglot.optimizer.normalize_identifiers import normalize_identifiers +from sqlglot.optimizer.qualify_columns import ( + qualify_columns as qualify_columns_func, + quote_identifiers as quote_identifiers_func, + validate_qualify_columns as validate_qualify_columns_func, +) from sqlglot.optimizer.qualify_tables import qualify_tables from sqlglot.schema import Schema, ensure_schema @@ -20,6 +24,7 @@ def qualify( expand_alias_refs: bool = True, infer_schema: t.Optional[bool] = None, isolate_tables: bool = False, + qualify_columns: bool = True, validate_qualify_columns: bool = True, quote_identifiers: bool = True, identify: bool = True, @@ -44,11 +49,13 @@ def qualify( expand_alias_refs: Whether or not to expand references to aliases. infer_schema: Whether or not to infer the schema if missing. isolate_tables: Whether or not to isolate table selects. + qualify_columns: Whether or not to qualify columns. validate_qualify_columns: Whether or not to validate columns. quote_identifiers: Whether or not to run the quote_identifiers step. This step is necessary to ensure correctness for case sensitive queries. But this flag is provided in case this step is performed at a later time. identify: If True, quote all identifiers, else only necessary ones. + Returns: The qualified expression. """ @@ -59,19 +66,15 @@ def qualify( if isolate_tables: expression = isolate_table_selects(expression, schema=schema) - expression = qualify_columns.qualify_columns( - expression, - schema, - expand_alias_refs=expand_alias_refs, - infer_schema=infer_schema, - ) + if qualify_columns: + expression = qualify_columns_func( + expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema + ) if quote_identifiers: - expression = expression.transform( - qualify_columns.quote_identifiers, dialect, identify, copy=False - ) + expression = quote_identifiers_func(expression, dialect=dialect, identify=identify) if validate_qualify_columns: - qualify_columns.validate_qualify_columns(expression) + validate_qualify_columns_func(expression) return expression diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 799dd0633a..4a311714ba 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -4,6 +4,7 @@ import typing as t from sqlglot import alias, exp +from sqlglot._typing import E from sqlglot.dialects.dialect import DialectType from sqlglot.errors import OptimizeError from sqlglot.helper import case_sensitive, seq_get @@ -414,19 +415,21 @@ def _qualify_outputs(scope): scope.expression.set("expressions", new_selections) -def quote_identifiers( - expression: exp.Expression, dialect: DialectType, identify: bool -) -> exp.Expression: +def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: """Makes sure all identifiers that need to be quoted are quoted.""" - if isinstance(expression, exp.Identifier): - name = expression.this - expression.set( - "quoted", - identify - or case_sensitive(name, dialect=dialect) - or not exp.SAFE_IDENTIFIER_RE.match(name), - ) - return expression + + def _quote(expression: E) -> E: + if isinstance(expression, exp.Identifier): + name = expression.this + expression.set( + "quoted", + identify + or case_sensitive(name, dialect=dialect) + or not exp.SAFE_IDENTIFIER_RE.match(name), + ) + return expression + + return expression.transform(_quote, copy=False) class Resolver: diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 25abaa202e..e7e7d3d4e6 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -177,6 +177,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): 2. {db: {table: set(*cols)}}} 3. {catalog: {db: {table: set(*cols)}}}} dialect: The dialect to be used for custom type mappings & parsing string arguments. + normalize: Whether to normalize identifier names according to the given dialect or not. """ def __init__( @@ -184,9 +185,11 @@ def __init__( schema: t.Optional[t.Dict] = None, visible: t.Optional[t.Dict] = None, dialect: DialectType = None, + normalize: bool = True, ) -> None: self.dialect = dialect self.visible = visible or {} + self.normalize = normalize self._type_mapping_cache: t.Dict[str, exp.DataType] = {} super().__init__(self._normalize(schema or {})) @@ -333,7 +336,7 @@ def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = Non name = identifier.name - if identifier.quoted: + if not self.normalize or identifier.quoted: return name return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower() diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index ccf2f16b7e..1fc44efbe1 100644 --- a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -10,8 +10,8 @@ SELECT CAST(1 AS VARCHAR) AS "a" FROM "w" AS "w"; SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w; SELECT 1 + 3.2 AS "a" FROM "w" AS "w"; -SELECT CAST("2022-01-01" AS DATE) + INTERVAL '1' day; -SELECT CAST("2022-01-01" AS DATE) + INTERVAL '1' day AS "_col_0"; +SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' day; +SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' day AS "_col_0"; -------------------------------------- -- Ensure boolean predicates diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 88c99c28ea..2ae6da993a 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -304,8 +304,8 @@ def test_canonicalize(self): optimize = partial( optimizer.optimize, rules=[ - optimizer.qualify_tables.qualify_tables, - optimizer.qualify_columns.qualify_columns, + optimizer.qualify.qualify, + optimizer.qualify_columns.quote_identifiers, annotate_types, optimizer.canonicalize.canonicalize, ], @@ -699,23 +699,18 @@ def test_quotes(self): } } - self.assertEqual( - optimizer.qualify.qualify( - parse_one( - """ - SELECT * FROM example."source" - """ - ), - dialect="snowflake", - schema=schema, - ).sql(pretty=True), - parse_one( - """ - SELECT - "source"."ID" AS "ID", - "source"."name" AS "name", - "source"."payload" AS "payload" - FROM "EXAMPLE"."source" AS "source" + expected = parse_one( """ - ).sql(pretty=True), - ) + SELECT + "source"."ID" AS "ID", + "source"."name" AS "name", + "source"."payload" AS "payload" + FROM "EXAMPLE"."source" AS "source" + """, + read="snowflake", + ).sql(pretty=True, dialect="snowflake") + + for func in (optimizer.qualify.qualify, optimizer.optimize): + source_query = parse_one('SELECT * FROM example."source"', read="snowflake") + transformed = func(source_query, dialect="snowflake", schema=schema) + self.assertEqual(transformed.sql(pretty=True, dialect="snowflake"), expected) diff --git a/tests/test_schema.py b/tests/test_schema.py index 072a41dbc1..83aad213cf 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -221,3 +221,7 @@ def test_schema_normalization(self): # Check that names are normalized to uppercase for Snowflake schema = MappingSchema(schema={"x": {"foo": "int", '"bLa"': "int"}}, dialect="snowflake") self.assertEqual(schema.column_names(exp.Table(this="x")), ["FOO", "bLa"]) + + # Check that switching off the normalization logic works as expected + schema = MappingSchema(schema={"x": {"foo": "int"}}, normalize=False, dialect="snowflake") + self.assertEqual(schema.column_names(exp.Table(this="x")), ["foo"])