Skip to content

Commit

Permalink
Fix incorrect quoting of identifiers with _ as initial character.
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jvasquezrojas committed Jan 10, 2025
1 parent b9b26e5 commit 99740da
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ Source code is also available at:

# Release Notes
- (Unreleased)
- Fix return value of snowflake get_table_names
- Fix return value of snowflake get_table_names.
- Fix incorrect quoting of identifiers with `_` as initial character.
- Added `force_div_is_floordiv` flag to override `div_is_floordiv` new default value `False` in `SnowflakeDialect`.
- With the flag in `False`, the `/` division operator will be treated as a float division and `//` as a floor division.
- This flag is added to maintain backward compatibility with the previous behavior of Snowflake Dialect division.
Expand Down
18 changes: 17 additions & 1 deletion src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@
r"\s*(?:UPDATE|INSERT|DELETE|MERGE|COPY)", re.I | re.UNICODE
)
# used for quoting identifiers ie. table names, column names, etc.
ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"_", "$"}))
ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"$"}))


# used for quoting identifiers ie. table names, column names, etc.
ILLEGAL_IDENTIFIERS = frozenset({d for d in string.digits}.union({"_"}))

"""
Overwrite methods to handle Snowflake BCR change:
Expand Down Expand Up @@ -443,6 +447,7 @@ def _join_left_to_right(
class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = {x.lower() for x in RESERVED_WORDS}
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
illegal_identifiers = ILLEGAL_IDENTIFIERS

def __init__(self, dialect, **kw):
quote = '"'
Expand Down Expand Up @@ -471,6 +476,17 @@ def format_label(self, label, name=None):

return self.quote_identifier(s) if n.quote else s

def _requires_quotes(self, value: str) -> bool:
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
return (
lc_value in self.reserved_words
or lc_value in self.illegal_identifiers
or value[0] in self.illegal_initial_characters
or not self.legal_characters.match(str(value))
or (lc_value != value)
)

def _split_schema_by_dot(self, schema):
ret = []
idx = 0
Expand Down
15 changes: 15 additions & 0 deletions tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ def test_underscore_as_valid_identifier(self):
dialect="snowflake",
)

def test_underscore_as_initial_character_as_non_quoted_identifier(self):
_table = table(
"table_1745924",
column("ca", Integer),
column("cb", String),
column("_identifier", String),
)

stmt = insert(_table).values(ca=1, cb="test", _identifier="test_")
self.assert_compile(
stmt,
"INSERT INTO table_1745924 (ca, cb, _identifier) VALUES (%(ca)s, %(cb)s, %(_identifier)s)",
dialect="snowflake",
)

def test_multi_table_delete(self):
statement = table1.delete().where(table1.c.id == table2.c.id)
self.assert_compile(
Expand Down

0 comments on commit 99740da

Please sign in to comment.