From 506d7a63bb19a76acf895aeec958cd220d23aa12 Mon Sep 17 00:00:00 2001 From: Mayuri Nehate <33225191+mayurinehate@users.noreply.github.com> Date: Tue, 24 Sep 2024 22:00:39 +0530 Subject: [PATCH] fix(ingest): do not cache temporary tables schema resolvers (#11432) --- .../source/bigquery_v2/bigquery_report.py | 5 +++- .../source/bigquery_v2/queries_extractor.py | 6 +++++ .../sql_parsing/sql_parsing_aggregator.py | 21 +++++++++++---- .../datahub/sql_parsing/sqlglot_lineage.py | 27 ++++++++++++++++--- 4 files changed, 50 insertions(+), 9 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py index b333bcf695a46..7f64055f505f4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py @@ -12,7 +12,7 @@ from datahub.ingestion.source_report.ingestion_stage import IngestionStageReport from datahub.ingestion.source_report.time_window import BaseTimeWindowReport from datahub.sql_parsing.sql_parsing_aggregator import SqlAggregatorReport -from datahub.utilities.lossy_collections import LossyDict, LossyList +from datahub.utilities.lossy_collections import LossyDict, LossyList, LossySet from datahub.utilities.perf_timer import PerfTimer from datahub.utilities.stats_collections import TopKDict, int_top_k_dict @@ -69,6 +69,9 @@ class BigQueryQueriesExtractorReport(Report): num_total_queries: int = 0 num_unique_queries: int = 0 + num_discovered_tables: Optional[int] = None + inferred_temp_tables: LossySet[str] = field(default_factory=LossySet) + @dataclass class BigQueryV2Report( diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py index a826f09b9a7c8..d57ec655b1f88 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py @@ -173,6 +173,9 @@ def __init__( format_queries=False, ) self.report.sql_aggregator = self.aggregator.report + self.report.num_discovered_tables = ( + len(self.discovered_tables) if self.discovered_tables else None + ) @functools.cached_property def local_temp_path(self) -> pathlib.Path: @@ -201,6 +204,7 @@ def is_temp_table(self, name: str) -> bool: and self.discovered_tables and str(BigQueryTableRef(table)) not in self.discovered_tables ): + self.report.inferred_temp_tables.add(name) return True except Exception: @@ -264,6 +268,8 @@ def get_workunits_internal( for query in query_instances.values(): if i > 0 and i % 10000 == 0: logger.info(f"Added {i} query log entries to SQL aggregator") + if self.report.sql_aggregator: + logger.info(self.report.sql_aggregator.as_string()) self.aggregator.add(query) i += 1 diff --git a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py index d945e135f0012..f5908753affde 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py @@ -39,10 +39,12 @@ ColumnRef, DownstreamColumnRef, SqlParsingResult, + _sqlglot_lineage_cached, infer_output_schema, sqlglot_lineage, ) from datahub.sql_parsing.sqlglot_utils import ( + _parse_statement, generate_hash, get_query_fingerprint, try_format_query, @@ -222,6 +224,9 @@ class SqlAggregatorReport(Report): sql_parsing_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) sql_fingerprinting_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) sql_formatting_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) + sql_parsing_cache_stats: Optional[dict] = dataclasses.field(default=None) + parse_statement_cache_stats: Optional[dict] = dataclasses.field(default=None) + format_query_cache_stats: Optional[dict] = dataclasses.field(default=None) # Other lineage loading metrics. num_known_query_lineage: int = 0 @@ -239,6 +244,7 @@ class SqlAggregatorReport(Report): queries_with_non_authoritative_session: LossyList[QueryId] = dataclasses.field( default_factory=LossyList ) + make_schema_resolver_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) # Lineage-related. schema_resolver_count: Optional[int] = None @@ -272,6 +278,10 @@ def compute_stats(self) -> None: self.num_temp_sessions = len(self._aggregator._temp_lineage_map) self.num_inferred_temp_schemas = len(self._aggregator._inferred_temp_schemas) + self.sql_parsing_cache_stats = _sqlglot_lineage_cached.cache_info()._asdict() + self.parse_statement_cache_stats = _parse_statement.cache_info()._asdict() + self.format_query_cache_stats = try_format_query.cache_info()._asdict() + return super().compute_stats() @@ -679,11 +689,12 @@ def add_observed_query( # All queries with no session ID are assumed to be part of the same session. session_id = observed.session_id or _MISSING_SESSION_ID - # Load in the temp tables for this session. - schema_resolver: SchemaResolverInterface = ( - self._make_schema_resolver_for_session(session_id) - ) - session_has_temp_tables = schema_resolver.includes_temp_tables() + with self.report.make_schema_resolver_timer: + # Load in the temp tables for this session. + schema_resolver: SchemaResolverInterface = ( + self._make_schema_resolver_for_session(session_id) + ) + session_has_temp_tables = schema_resolver.includes_temp_tables() # Run the SQL parser. parsed = self._run_sql_parser( diff --git a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py index 27d99a14c0520..0806d0ec774fe 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py @@ -1020,9 +1020,8 @@ def _sqlglot_lineage_inner( ) -@functools.lru_cache(maxsize=SQL_PARSE_RESULT_CACHE_SIZE) -def sqlglot_lineage( - sql: str, +def _sqlglot_lineage_nocache( + sql: sqlglot.exp.ExpOrStr, schema_resolver: SchemaResolverInterface, default_db: Optional[str] = None, default_schema: Optional[str] = None, @@ -1091,6 +1090,28 @@ def sqlglot_lineage( return SqlParsingResult.make_from_error(e) +_sqlglot_lineage_cached = functools.lru_cache(maxsize=SQL_PARSE_RESULT_CACHE_SIZE)( + _sqlglot_lineage_nocache +) + + +def sqlglot_lineage( + sql: sqlglot.exp.ExpOrStr, + schema_resolver: SchemaResolverInterface, + default_db: Optional[str] = None, + default_schema: Optional[str] = None, + default_dialect: Optional[str] = None, +) -> SqlParsingResult: + if schema_resolver.includes_temp_tables(): + return _sqlglot_lineage_nocache( + sql, schema_resolver, default_db, default_schema, default_dialect + ) + else: + return _sqlglot_lineage_cached( + sql, schema_resolver, default_db, default_schema, default_dialect + ) + + def create_lineage_sql_parsed_result( query: str, default_db: Optional[str],