Skip to content

Commit

Permalink
Fix: avoid concealing dialect module exception in _try_load (#4708)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Feb 5, 2025
1 parent 5f90307 commit 23283ca
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
17 changes: 8 additions & 9 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,15 @@ def classes(cls):

@classmethod
def _try_load(cls, key: str | Dialects) -> None:
try:
if isinstance(key, Dialects):
key = key.value

# This import will lead to a new dialect being loaded, and hence, registered.
# We assert that the key is an actual module to avoid blindly importing files.
assert key in DIALECT_MODULE_NAMES
if isinstance(key, Dialects):
key = key.value

# This import will lead to a new dialect being loaded, and hence, registered.
# We check that the key is an actual sqlglot module to avoid blindly importing
# files. Custom user dialects need to be imported at the top-level package, in
# order for them to be registered as soon as possible.
if key in DIALECT_MODULE_NAMES:
importlib.import_module(f"sqlglot.dialects.{key}")
except Exception:
pass

@classmethod
def __getitem__(cls, key: str) -> t.Type[Dialect]:
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def transform(self, fun: t.Callable, *args: t.Any, copy: bool = True, **kwargs)

if not root:
root = new_node
elif new_node is not node:
elif parent and arg_key and new_node is not node:
parent.set(arg_key, new_node, index)

assert root
Expand Down
8 changes: 5 additions & 3 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3176,9 +3176,11 @@ def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.With]:
last_comments = None
expressions = []
while True:
expressions.append(self._parse_cte())
if last_comments:
expressions[-1].add_comments(last_comments)
cte = self._parse_cte()
if isinstance(cte, exp.CTE):
expressions.append(cte)
if last_comments:
cte.add_comments(last_comments)

if not self._match(TokenType.COMMA) and not self._match(TokenType.WITH):
break
Expand Down

0 comments on commit 23283ca

Please sign in to comment.