Commit 506d7a6

fix(ingest): do not cache temporary tables schema resolvers (#11432)

mayurinehate authored Sep 24, 2024
1 parent 6f95dc0 commit 506d7a6

Showing 4 changed files with 50 additions and 9 deletions.
@@ -12,7 +12,7 @@
 from datahub.ingestion.source_report.ingestion_stage import IngestionStageReport
 from datahub.ingestion.source_report.time_window import BaseTimeWindowReport
 from datahub.sql_parsing.sql_parsing_aggregator import SqlAggregatorReport
-from datahub.utilities.lossy_collections import LossyDict, LossyList
+from datahub.utilities.lossy_collections import LossyDict, LossyList, LossySet
 from datahub.utilities.perf_timer import PerfTimer
 from datahub.utilities.stats_collections import TopKDict, int_top_k_dict
@@ -69,6 +69,9 @@ class BigQueryQueriesExtractorReport(Report):
     num_total_queries: int = 0
     num_unique_queries: int = 0
 
+    num_discovered_tables: Optional[int] = None
+    inferred_temp_tables: LossySet[str] = field(default_factory=LossySet)
+
 
 @dataclass
 class BigQueryV2Report(
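The new `inferred_temp_tables` field is a `LossySet`, one of datahub's lossy collections: report containers that cap memory use by keeping only a bounded sample of the values added to them. A minimal standalone sketch of the idea (the class below is illustrative, not datahub's actual implementation):

```python
from typing import Iterator, Set


class BoundedStrSet:
    """Keeps at most `max_elements` distinct items; further adds are dropped."""

    def __init__(self, max_elements: int = 10) -> None:
        self._items: Set[str] = set()
        self._max = max_elements
        self.sampled = False  # True once at least one add has been dropped

    def add(self, item: str) -> None:
        if item in self._items:
            return
        if len(self._items) < self._max:
            self._items.add(item)
        else:
            self.sampled = True  # drop silently; the set is now a sample

    def __iter__(self) -> Iterator[str]:
        return iter(self._items)
```

This makes it safe to record one entry per inferred temp table even on sources with millions of short-lived tables.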
@@ -173,6 +173,9 @@ def __init__(
             format_queries=False,
         )
         self.report.sql_aggregator = self.aggregator.report
+        self.report.num_discovered_tables = (
+            len(self.discovered_tables) if self.discovered_tables else None
+        )
 
     @functools.cached_property
     def local_temp_path(self) -> pathlib.Path:
@@ -201,6 +204,7 @@ def is_temp_table(self, name: str) -> bool:
                 and self.discovered_tables
                 and str(BigQueryTableRef(table)) not in self.discovered_tables
             ):
+                self.report.inferred_temp_tables.add(name)
                 return True
 
         except Exception:
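The surrounding logic infers that a table is temporary when a query references it but it is absent from the set of tables discovered during ingestion; the added line records each such inference in the report so the heuristic can be audited. The rule, reduced to a self-contained sketch (`discovered_tables` stands in for the extractor's table inventory):

```python
from typing import Optional, Set


def is_probably_temp_table(name: str, discovered_tables: Optional[Set[str]]) -> bool:
    # A table that queries reference but ingestion never discovered is
    # assumed to be temporary. Without an inventory, nothing can be inferred.
    if not discovered_tables:
        return False
    return name not in discovered_tables
```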
@@ -264,6 +268,8 @@ def get_workunits_internal(
         for query in query_instances.values():
             if i > 0 and i % 10000 == 0:
                 logger.info(f"Added {i} query log entries to SQL aggregator")
+                if self.report.sql_aggregator:
+                    logger.info(self.report.sql_aggregator.as_string())
 
             self.aggregator.add(query)
             i += 1
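The loop now also dumps the aggregator report every 10,000 entries, which makes the new cache and timer stats visible while a long extraction is still running. The pattern, as a stand-alone sketch (`aggregator` here is just any object with an `add` method, not the datahub class):

```python
import logging

logger = logging.getLogger(__name__)


def feed_aggregator(aggregator, queries, log_every: int = 10_000) -> None:
    # Periodically log progress while feeding a large query log into the
    # aggregator, mirroring the loop in the hunk above.
    for i, query in enumerate(queries):
        if i > 0 and i % log_every == 0:
            logger.info(f"Added {i} query log entries to SQL aggregator")
        aggregator.add(query)
```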
@@ -39,10 +39,12 @@
     ColumnRef,
     DownstreamColumnRef,
     SqlParsingResult,
+    _sqlglot_lineage_cached,
     infer_output_schema,
     sqlglot_lineage,
 )
 from datahub.sql_parsing.sqlglot_utils import (
+    _parse_statement,
     generate_hash,
     get_query_fingerprint,
     try_format_query,
@@ -222,6 +224,9 @@ class SqlAggregatorReport(Report):
     sql_parsing_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
     sql_fingerprinting_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
     sql_formatting_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
+    sql_parsing_cache_stats: Optional[dict] = dataclasses.field(default=None)
+    parse_statement_cache_stats: Optional[dict] = dataclasses.field(default=None)
+    format_query_cache_stats: Optional[dict] = dataclasses.field(default=None)
 
     # Other lineage loading metrics.
     num_known_query_lineage: int = 0
@@ -239,6 +244,7 @@ class SqlAggregatorReport(Report):
     queries_with_non_authoritative_session: LossyList[QueryId] = dataclasses.field(
         default_factory=LossyList
     )
+    make_schema_resolver_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
 
     # Lineage-related.
     schema_resolver_count: Optional[int] = None
@@ -272,6 +278,10 @@ def compute_stats(self) -> None:
         self.num_temp_sessions = len(self._aggregator._temp_lineage_map)
         self.num_inferred_temp_schemas = len(self._aggregator._inferred_temp_schemas)
 
+        self.sql_parsing_cache_stats = _sqlglot_lineage_cached.cache_info()._asdict()
+        self.parse_statement_cache_stats = _parse_statement.cache_info()._asdict()
+        self.format_query_cache_stats = try_format_query.cache_info()._asdict()
+
         return super().compute_stats()
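The three `*_cache_stats` fields are filled from `functools.lru_cache`'s built-in counters: `cache_info()` returns a `CacheInfo` named tuple, and `._asdict()` turns it into a plain dict that serializes cleanly into the report. For example:

```python
import functools


@functools.lru_cache(maxsize=128)
def square(x: int) -> int:
    return x * x


square(2)
square(2)  # served from the cache
square(3)

print(square.cache_info())            # CacheInfo(hits=1, misses=2, maxsize=128, currsize=2)
print(square.cache_info()._asdict())  # {'hits': 1, 'misses': 2, 'maxsize': 128, 'currsize': 2}
```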


@@ -679,11 +689,12 @@ def add_observed_query(
         # All queries with no session ID are assumed to be part of the same session.
         session_id = observed.session_id or _MISSING_SESSION_ID
 
-        # Load in the temp tables for this session.
-        schema_resolver: SchemaResolverInterface = (
-            self._make_schema_resolver_for_session(session_id)
-        )
-        session_has_temp_tables = schema_resolver.includes_temp_tables()
+        with self.report.make_schema_resolver_timer:
+            # Load in the temp tables for this session.
+            schema_resolver: SchemaResolverInterface = (
+                self._make_schema_resolver_for_session(session_id)
+            )
+            session_has_temp_tables = schema_resolver.includes_temp_tables()
 
         # Run the SQL parser.
         parsed = self._run_sql_parser(
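Wrapping the schema-resolver setup in `with self.report.make_schema_resolver_timer:` attributes its cost to the new dedicated report timer; datahub's `PerfTimer` works as a context manager. A minimal sketch of such a cumulative timer (an assumption about the shape, not `PerfTimer`'s actual code):

```python
import time


class CumulativeTimer:
    # Accumulates wall-clock time across repeated `with` blocks.
    def __init__(self) -> None:
        self.elapsed_seconds = 0.0

    def __enter__(self) -> "CumulativeTimer":
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.elapsed_seconds += time.perf_counter() - self._start


timer = CumulativeTimer()
for chunk in (10_000, 20_000):
    with timer:
        sum(range(chunk))  # stand-in for _make_schema_resolver_for_session()
print(f"{timer.elapsed_seconds:.6f}s total")
```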
27 changes: 24 additions & 3 deletions metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
@@ -1020,9 +1020,8 @@ def _sqlglot_lineage_inner(
     )
 
 
-@functools.lru_cache(maxsize=SQL_PARSE_RESULT_CACHE_SIZE)
-def sqlglot_lineage(
-    sql: str,
+def _sqlglot_lineage_nocache(
+    sql: sqlglot.exp.ExpOrStr,
     schema_resolver: SchemaResolverInterface,
     default_db: Optional[str] = None,
     default_schema: Optional[str] = None,
@@ -1091,6 +1090,28 @@
         return SqlParsingResult.make_from_error(e)
 
 
+_sqlglot_lineage_cached = functools.lru_cache(maxsize=SQL_PARSE_RESULT_CACHE_SIZE)(
+    _sqlglot_lineage_nocache
+)
+
+
+def sqlglot_lineage(
+    sql: sqlglot.exp.ExpOrStr,
+    schema_resolver: SchemaResolverInterface,
+    default_db: Optional[str] = None,
+    default_schema: Optional[str] = None,
+    default_dialect: Optional[str] = None,
+) -> SqlParsingResult:
+    if schema_resolver.includes_temp_tables():
+        return _sqlglot_lineage_nocache(
+            sql, schema_resolver, default_db, default_schema, default_dialect
+        )
+    else:
+        return _sqlglot_lineage_cached(
+            sql, schema_resolver, default_db, default_schema, default_dialect
+        )
+
+
 def create_lineage_sql_parsed_result(
     query: str,
     default_db: Optional[str],
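This hunk is the heart of the fix. Previously `@functools.lru_cache` was applied directly to `sqlglot_lineage`, so a result computed against one session's temp-table schemas could be replayed for the same SQL under another session's state. The refactor keeps an uncached function, derives a cached variant from it, and dispatches per call: cache only when the schema resolver carries no session-local temp-table state. The same pattern, reduced to a runnable sketch (names below are illustrative, not the datahub API):

```python
import functools


def _lineage_nocache(sql: str, resolver_key: str) -> str:
    # Stand-in for the expensive parse; its output is deterministic only
    # when the resolver has no mutable, session-local (temp table) state.
    return f"lineage of {sql!r} under {resolver_key!r}"


# Derive the cached variant instead of decorating the function itself,
# so both versions stay independently callable.
_lineage_cached = functools.lru_cache(maxsize=1000)(_lineage_nocache)


def lineage(sql: str, resolver_key: str, has_temp_tables: bool) -> str:
    if has_temp_tables:
        # Temp-table schemas vary per session: identical arguments can
        # legitimately produce different results, so bypass the cache.
        return _lineage_nocache(sql, resolver_key)
    return _lineage_cached(sql, resolver_key)
```

A convenient side effect is that the cached variant exposes `cache_info()`, which is exactly what `SqlAggregatorReport.compute_stats` now exports as `sql_parsing_cache_stats`.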
