feat(cli): cache sql parsing intermediates (datahub-project#10399)
hsheth2 authored and sleeperdeep committed Jun 25, 2024
1 parent 90f1ee7 commit 812136d
Showing 2 changed files with 23 additions and 7 deletions.
14 changes: 8 additions & 6 deletions metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
@@ -406,10 +406,11 @@ def _schema_aware_fuzzy_column_resolve(
             return default_col_name
 
     # Optimize the statement + qualify column references.
-    logger.debug(
-        "Prior to column qualification sql %s",
-        statement.sql(pretty=True, dialect=dialect),
-    )
+    if logger.isEnabledFor(logging.DEBUG):
+        logger.debug(
+            "Prior to column qualification sql %s",
+            statement.sql(pretty=True, dialect=dialect),
+        )
     try:
         # Second time running qualify, this time with:
         # - the select instead of the full outer statement
@@ -434,7 +435,8 @@ def _schema_aware_fuzzy_column_resolve(
         raise SqlUnderstandingError(
             f"sqlglot failed to map columns to their source tables; likely missing/outdated table schema info: {e}"
         ) from e
-    logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect))
+    if logger.isEnabledFor(logging.DEBUG):
+        logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect))
 
     # Handle the create DDL case.
     if is_create_ddl:
@@ -805,7 +807,7 @@ def _sqlglot_lineage_inner(
     logger.debug("Parsing lineage from sql statement: %s", sql)
     statement = parse_statement(sql, dialect=dialect)
 
-    original_statement = statement.copy()
+    original_statement, statement = statement, statement.copy()
     # logger.debug(
     #     "Formatted sql statement: %s",
     #     original_statement.sql(pretty=True, dialect=dialect),
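Note on the hunks above: arguments to logger.debug() are evaluated eagerly in Python, so statement.sql(pretty=True, dialect=dialect) was being rendered even when debug logging was disabled. Guarding the call with logger.isEnabledFor(logging.DEBUG) makes that rendering cost conditional. A minimal standalone sketch of the pattern, where render_pretty_sql() is a hypothetical stand-in for the expensive statement.sql(...) call and is not part of this commit:

import logging

logger = logging.getLogger(__name__)


def render_pretty_sql() -> str:
    # Hypothetical stand-in for statement.sql(pretty=True, dialect=dialect),
    # which can be expensive for large statements.
    return "SELECT *\nFROM example_table"


# Eager form: render_pretty_sql() runs even when DEBUG is disabled, because
# the argument is computed before logger.debug() decides whether to emit it.
logger.debug("Qualified sql %s", render_pretty_sql())

# Guarded form: the rendering cost is only paid when DEBUG is actually enabled.
if logger.isEnabledFor(logging.DEBUG):
    logger.debug("Qualified sql %s", render_pretty_sql())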
16 changes: 15 additions & 1 deletion metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py
@@ -1,3 +1,4 @@
+import functools
 import hashlib
 import logging
 from typing import Dict, Iterable, Optional, Tuple, Union
@@ -7,6 +8,7 @@
 
 logger = logging.getLogger(__name__)
 DialectOrStr = Union[sqlglot.Dialect, str]
+SQL_PARSE_CACHE_SIZE = 1000
 
 
 def _get_dialect_str(platform: str) -> str:
@@ -55,7 +57,8 @@ def is_dialect_instance(
     return False
 
 
-def parse_statement(
+@functools.lru_cache(maxsize=SQL_PARSE_CACHE_SIZE)
+def _parse_statement(
     sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
 ) -> sqlglot.Expression:
     statement: sqlglot.Expression = sqlglot.maybe_parse(
@@ -64,6 +67,16 @@ def parse_statement(
     return statement
 
 
+def parse_statement(
+    sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
+) -> sqlglot.Expression:
+    # Parsing is significantly more expensive than copying the expression.
+    # Because the expressions are mutable, we don't want to allow the caller
+    # to modify the parsed expression that sits in the cache. We keep
+    # the cached versions pristine by returning a copy on each call.
+    return _parse_statement(sql, dialect).copy()
+
+
 def parse_statements_and_pick(sql: str, platform: DialectOrStr) -> sqlglot.Expression:
     dialect = get_dialect(platform)
     statements = [
@@ -277,4 +290,5 @@ def replace_cte_refs(node: sqlglot.exp.Expression) -> sqlglot.exp.Expression:
         else:
             return node
 
+    statement = statement.copy()
     return statement.transform(replace_cte_refs, copy=False)
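The new parse_statement wrapper pairs functools.lru_cache with a defensive .copy(): repeated parses of identical SQL are served from the cache, while each caller gets a private expression tree it can mutate. A minimal sketch of the same cache-then-copy pattern, using a toy Expr class and parse() function that are illustrative only and not DataHub or sqlglot APIs:

import functools
from dataclasses import dataclass, field
from typing import List


@dataclass
class Expr:
    # Toy stand-in for a mutable parsed-expression tree.
    tokens: List[str] = field(default_factory=list)

    def copy(self) -> "Expr":
        return Expr(tokens=list(self.tokens))


@functools.lru_cache(maxsize=1000)
def _parse(sql: str) -> Expr:
    # Pretend this is the expensive parsing step we want to cache.
    return Expr(tokens=sql.split())


def parse(sql: str) -> Expr:
    # Hand each caller a copy so mutations never reach the cached original.
    return _parse(sql).copy()


first = parse("select 1")
first.tokens.append("-- mutated by the caller")

# The cached expression is untouched, so the next parse is still pristine.
assert parse("select 1").tokens == ["select", "1"]

The statement = statement.copy() added before statement.transform(..., copy=False) in the last hunk appears to follow the same reasoning: the in-place transform should not mutate an expression that may now be shared through the parse cache.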
