diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 00f860ce..1ec1e072 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,7 +9,8 @@ Source code is also available at: # Release Notes - (Unreleased) - - Fix return value of snowflake get_table_names + - Fix return value of snowflake get_table_names. + - Fix incorrect quoting of identifiers with `_` as initial character. - Added `force_div_is_floordiv` flag to override `div_is_floordiv` new default value `False` in `SnowflakeDialect`. - With the flag in `False`, the `/` division operator will be treated as a float division and `//` as a floor division. - This flag is added to maintain backward compatibility with the previous behavior of Snowflake Dialect division. diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 0226d37d..dc624949 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -117,7 +117,11 @@ r"\s*(?:UPDATE|INSERT|DELETE|MERGE|COPY)", re.I | re.UNICODE ) # used for quoting identifiers ie. table names, column names, etc. -ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"_", "$"})) +ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"$"})) + + +# used for quoting identifiers ie. table names, column names, etc. +ILLEGAL_IDENTIFIERS = frozenset({d for d in string.digits}.union({"_"})) """ Overwrite methods to handle Snowflake BCR change: @@ -443,6 +447,7 @@ def _join_left_to_right( class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = {x.lower() for x in RESERVED_WORDS} illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + illegal_identifiers = ILLEGAL_IDENTIFIERS def __init__(self, dialect, **kw): quote = '"' @@ -471,6 +476,17 @@ def format_label(self, label, name=None): return self.quote_identifier(s) if n.quote else s + def _requires_quotes(self, value: str) -> bool: + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return ( + lc_value in self.reserved_words + or lc_value in self.illegal_identifiers + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(str(value)) + or (lc_value != value) + ) + def _split_schema_by_dot(self, schema): ret = [] idx = 0 diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 0eea4607..cb9632a4 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -50,6 +50,21 @@ def test_underscore_as_valid_identifier(self): dialect="snowflake", ) + def test_underscore_as_initial_character_as_non_quoted_identifier(self): + _table = table( + "table_1745924", + column("ca", Integer), + column("cb", String), + column("_identifier", String), + ) + + stmt = insert(_table).values(ca=1, cb="test", _identifier="test_") + self.assert_compile( + stmt, + "INSERT INTO table_1745924 (ca, cb, _identifier) VALUES (%(ca)s, %(cb)s, %(_identifier)s)", + dialect="snowflake", + ) + def test_multi_table_delete(self): statement = table1.delete().where(table1.c.id == table2.c.id) self.assert_compile(