fix(ingest/snowflake): add more reporting for usage aggregation, handle lineage errors (datahub-project#10279)

Co-authored-by: Harshal Sheth <hsheth2@gmail.com>
2 people authored and sleeperdeep committed Jun 25, 2024
1 parent 5189d6e commit 2c4833d
Showing 14 changed files with 443 additions and 104 deletions.
2 changes: 1 addition & 1 deletion metadata-ingestion/setup.py
@@ -99,7 +99,7 @@
sqlglot_lib = {
# Using an Acryl fork of sqlglot.
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:hsheth?expand=1
"acryl-sqlglot==23.11.2.dev2",
"acryl-sqlglot[rs]==23.11.2.dev2",
}

classification_lib = {
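The `[rs]` suffix added above is a standard extras marker in a pip requirement specifier, presumably enabling sqlglot's optional Rust tokenizer. A minimal sketch of how such a specifier is interpreted, using the third-party packaging library (an assumption for illustration, not part of this commit):

# Sketch: parsing a requirement string that carries an extras marker.
# Assumes the third-party `packaging` library is installed.
from packaging.requirements import Requirement

req = Requirement("acryl-sqlglot[rs]==23.11.2.dev2")
print(req.name)            # acryl-sqlglot
print(req.extras)          # {'rs'} -- the optional extra requested at install time
print(str(req.specifier))  # ==23.11.2.dev2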
@@ -369,7 +369,9 @@ def _fetch_upstream_lineages_for_tables(self) -> Iterable[UpstreamLineageEdge]:
)
try:
for db_row in self.query(query):
yield UpstreamLineageEdge.parse_obj(db_row)
edge = self._process_upstream_lineage_row(db_row)
if edge:
yield edge
except Exception as e:
if isinstance(e, SnowflakePermissionError):
error_msg = "Failed to get table/view to table lineage. Please grant imported privileges on SNOWFLAKE database. "
@@ -382,6 +384,19 @@ def _fetch_upstream_lineages_for_tables(self) -> Iterable[UpstreamLineageEdge]:
)
self.report_status(TABLE_LINEAGE, False)

def _process_upstream_lineage_row(
self, db_row: dict
) -> Optional[UpstreamLineageEdge]:
try:
return UpstreamLineageEdge.parse_obj(db_row)
except Exception as e:
self.report.num_upstream_lineage_edge_parsing_failed += 1
self.report_warning(
f"Parsing lineage edge failed due to error {e}",
db_row.get("DOWNSTREAM_TABLE_NAME") or "",
)
return None

def map_query_result_upstreams(
self, upstream_tables: Optional[List[UpstreamTableNode]], query_id: str
) -> List[UrnStr]:
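A minimal sketch of the error-isolation pattern introduced in _process_upstream_lineage_row above: each row is parsed independently, and a parse failure increments a counter and logs a warning instead of aborting the whole lineage pull. All names below are simplified stand-ins, not DataHub's actual classes:

# Sketch of per-row parse error isolation (hypothetical names, not DataHub's API).
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, Iterator, List, Optional

logger = logging.getLogger(__name__)

@dataclass
class LineageReport:
    num_rows_failed_parsing: int = 0

@dataclass
class Edge:
    downstream: str
    upstreams: List[str] = field(default_factory=list)

    @classmethod
    def parse_obj(cls, row: Dict[str, Any]) -> "Edge":
        # Raises on malformed rows, mimicking pydantic's parse_obj.
        return cls(downstream=row["DOWNSTREAM_TABLE_NAME"], upstreams=list(row["UPSTREAM_TABLES"]))

def parse_row(row: Dict[str, Any], report: LineageReport) -> Optional[Edge]:
    try:
        return Edge.parse_obj(row)
    except Exception as e:
        report.num_rows_failed_parsing += 1
        logger.warning("Failed to parse lineage row %s: %s", row.get("DOWNSTREAM_TABLE_NAME"), e)
        return None

def iter_edges(rows: Iterable[Dict[str, Any]], report: LineageReport) -> Iterator[Edge]:
    for row in rows:
        edge = parse_row(row, report)
        if edge:  # a single bad row no longer aborts the whole extraction
            yield edge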
@@ -558,13 +558,20 @@ def usage_per_object_per_time_bucket_for_time_window(
include_top_n_queries: bool,
email_domain: Optional[str],
email_filter: AllowDenyPattern,
table_deny_pattern: List[str] = DEFAULT_TABLES_DENY_LIST,
) -> str:
if not include_top_n_queries:
top_n_queries = 0
assert (
time_bucket_size == BucketDuration.DAY
or time_bucket_size == BucketDuration.HOUR
)

temp_table_filter = create_deny_regex_sql_filter(
table_deny_pattern,
["object_name"],
)

objects_column = (
"BASE_OBJECTS_ACCESSED" if use_base_objects else "DIRECT_OBJECTS_ACCESSED"
)
@@ -604,6 +611,7 @@
)
t,
lateral flatten(input => t.{objects_column}) object
{("where " + temp_table_filter) if temp_table_filter else ""}
)
,
field_access_history AS
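The new table_deny_pattern parameter above feeds create_deny_regex_sql_filter, whose output is spliced into the access-history query as an optional WHERE clause so temporary and staging tables are dropped before aggregation. A rough sketch of that pattern (the helper and patterns below are hypothetical stand-ins; the real create_deny_regex_sql_filter and DEFAULT_TABLES_DENY_LIST may differ):

# Hypothetical sketch of turning deny regexes into a Snowflake-side filter.
from typing import List

def deny_regex_sql_filter(deny_patterns: List[str], columns: List[str]) -> str:
    # One "NOT RLIKE(col, pattern, 'i')" clause per (column, pattern) pair,
    # AND-ed together; empty input yields an empty string so the caller can
    # omit the WHERE clause entirely.
    clauses = [
        f"NOT RLIKE({col},'{pattern}','i')"
        for col in columns
        for pattern in deny_patterns
    ]
    return " AND ".join(clauses)

# Illustrative deny patterns, not DataHub's actual defaults.
temp_table_filter = deny_regex_sql_filter(
    [r".*\.FIVETRAN_.*_STAGING\..*", r".*__DBT_TMP$"], ["object_name"]
)
where_clause = ("where " + temp_table_filter) if temp_table_filter else ""
print(where_clause)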
@@ -2,6 +2,7 @@
from datetime import datetime
from typing import Dict, List, MutableSet, Optional

from datahub.ingestion.api.report import Report
from datahub.ingestion.glossary.classification_mixin import ClassificationReportMixin
from datahub.ingestion.source.snowflake.constants import SnowflakeEdition
from datahub.ingestion.source.sql.sql_generic_profiler import ProfilingSqlReport
@@ -11,6 +12,19 @@
from datahub.ingestion.source_report.ingestion_stage import IngestionStageReport
from datahub.ingestion.source_report.time_window import BaseTimeWindowReport
from datahub.sql_parsing.sql_parsing_aggregator import SqlAggregatorReport
from datahub.utilities.perf_timer import PerfTimer


@dataclass
class SnowflakeUsageAggregationReport(Report):
query_secs: float = -1
query_row_count: int = -1
result_fetch_timer: PerfTimer = field(default_factory=PerfTimer)
result_skip_timer: PerfTimer = field(default_factory=PerfTimer)
result_map_timer: PerfTimer = field(default_factory=PerfTimer)
users_map_timer: PerfTimer = field(default_factory=PerfTimer)
queries_map_timer: PerfTimer = field(default_factory=PerfTimer)
fields_map_timer: PerfTimer = field(default_factory=PerfTimer)


@dataclass
@@ -31,6 +45,10 @@ class SnowflakeUsageReport:
usage_end_time: Optional[datetime] = None
stateful_usage_ingestion_enabled: bool = False

usage_aggregation: SnowflakeUsageAggregationReport = (
SnowflakeUsageAggregationReport()
)


@dataclass
class SnowflakeReport(ProfilingSqlReport, BaseTimeWindowReport):
@@ -83,12 +101,10 @@ class SnowflakeV2Report(
include_technical_schema: bool = False
include_column_lineage: bool = False

usage_aggregation_query_secs: float = -1
table_lineage_query_secs: float = -1
# view_lineage_parse_secs: float = -1
# view_upstream_lineage_query_secs: float = -1
# view_downstream_lineage_query_secs: float = -1
external_lineage_queries_secs: float = -1
num_tables_with_known_upstreams: int = 0
num_upstream_lineage_edge_parsing_failed: int = 0

# Reports how many times we reset in-memory `functools.lru_cache` caches of data,
# which occurs when we encounter a different database / schema.
@@ -115,14 +131,6 @@ class SnowflakeV2Report(

edition: Optional[SnowflakeEdition] = None

# num_tables_with_external_upstreams_only: int = 0
num_tables_with_known_upstreams: int = 0
# num_views_with_upstreams: int = 0

# num_view_definitions_parsed: int = 0
# num_view_definitions_failed_parsing: int = 0
# num_view_definitions_failed_column_parsing: int = 0

def report_entity_scanned(self, name: str, ent_type: str = "table") -> None:
"""
Entity could be a view or a table or a schema or a database
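The SnowflakeUsageAggregationReport above attaches one PerfTimer per phase of usage aggregation so the report shows where time went. A simplified sketch of that idea, with a minimal stand-in timer rather than DataHub's datahub.utilities.perf_timer.PerfTimer:

# Simplified stand-in for a per-phase timing report; DataHub's PerfTimer is richer.
import time
from dataclasses import dataclass, field

class SimpleTimer:
    def __init__(self) -> None:
        self.elapsed = 0.0
        self._start = 0.0

    def __enter__(self) -> "SimpleTimer":
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc) -> None:
        self.elapsed += time.perf_counter() - self._start

@dataclass
class UsageAggregationTimings:
    query_secs: float = -1
    result_fetch_timer: SimpleTimer = field(default_factory=SimpleTimer)
    result_map_timer: SimpleTimer = field(default_factory=SimpleTimer)

timings = UsageAggregationTimings()
with timings.result_fetch_timer:
    rows = list(range(1000))          # stand-in for cursor fetches
with timings.result_map_timer:
    mapped = [r * 2 for r in rows]    # stand-in for row mapping
print(timings.result_fetch_timer.elapsed, timings.result_map_timer.elapsed)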
@@ -36,8 +36,9 @@
)
from datahub.metadata.com.linkedin.pegasus2avro.timeseries import TimeWindowSize
from datahub.metadata.schema_classes import OperationClass, OperationTypeClass
from datahub.sql_parsing.sqlglot_utils import try_format_query
from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.sql_formatter import format_sql_query, trim_query
from datahub.utilities.sql_formatter import trim_query

logger: logging.Logger = logging.getLogger(__name__)

@@ -216,6 +217,7 @@ def _get_workunits_internal(
include_top_n_queries=self.config.include_top_n_queries,
email_domain=self.config.email_domain,
email_filter=self.config.user_email_pattern,
table_deny_pattern=self.config.temporary_tables_pattern,
),
)
except Exception as e:
@@ -227,29 +229,46 @@
self.report_status(USAGE_EXTRACTION_USAGE_AGGREGATION, False)
return

self.report.usage_aggregation_query_secs = timer.elapsed_seconds()
self.report.usage_aggregation.query_secs = timer.elapsed_seconds()
self.report.usage_aggregation.query_row_count = results.rowcount

for row in results:
if not self._is_dataset_pattern_allowed(
row["OBJECT_NAME"],
row["OBJECT_DOMAIN"],
):
continue

dataset_identifier = self.get_dataset_identifier_from_qualified_name(
row["OBJECT_NAME"]
)
if dataset_identifier not in discovered_datasets:
logger.debug(
f"Skipping usage for table {dataset_identifier}, as table schema is not accessible or not allowed by recipe."
)
continue
with self.report.usage_aggregation.result_fetch_timer as fetch_timer:
for row in results:
with fetch_timer.pause(), self.report.usage_aggregation.result_skip_timer as skip_timer:
if results.rownumber is not None and results.rownumber % 1000 == 0:
logger.debug(f"Processing usage row number {results.rownumber}")
logger.debug(self.report.usage_aggregation.as_string())

yield from self.build_usage_statistics_for_dataset(dataset_identifier, row)
if not self._is_dataset_pattern_allowed(
row["OBJECT_NAME"],
row["OBJECT_DOMAIN"],
):
logger.debug(
f"Skipping usage for {row['OBJECT_DOMAIN']} {row['OBJECT_NAME']}, as table is not allowed by recipe."
)
continue

dataset_identifier = (
self.get_dataset_identifier_from_qualified_name(
row["OBJECT_NAME"]
)
)
if dataset_identifier not in discovered_datasets:
logger.debug(
f"Skipping usage for {row['OBJECT_DOMAIN']} {dataset_identifier}, as table is not accessible."
)
continue
with skip_timer.pause(), self.report.usage_aggregation.result_map_timer as map_timer:
wu = self.build_usage_statistics_for_dataset(
dataset_identifier, row
)
if wu:
with map_timer.pause():
yield wu
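
The reworked loop above interleaves three timers: the fetch timer runs while rows are pulled from the cursor, pauses while a row is filtered (skip timer), and the skip timer in turn pauses while the row is mapped and emitted (map timer). A rough sketch of that pause-while-nested pattern, using a hypothetical timer that only approximates PerfTimer's pause() semantics as used in the diff:

# Sketch of the fetch/skip/map nested-timer pattern (hypothetical PausableTimer).
import time
from contextlib import contextmanager

class PausableTimer:
    def __init__(self) -> None:
        self.elapsed = 0.0
        self._start = 0.0

    def __enter__(self) -> "PausableTimer":
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc) -> None:
        self.elapsed += time.perf_counter() - self._start

    @contextmanager
    def pause(self):
        # Stop accumulating while the nested block runs, then resume.
        self.elapsed += time.perf_counter() - self._start
        try:
            yield
        finally:
            self._start = time.perf_counter()

fetch_timer, map_timer = PausableTimer(), PausableTimer()
with fetch_timer:
    for row in range(5):          # stand-in for iterating a database cursor
        with fetch_timer.pause(), map_timer:
            time.sleep(0.01)      # stand-in for mapping the row to a work unit
# fetch_timer.elapsed now excludes the mapping time captured by map_timer.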

def build_usage_statistics_for_dataset(
self, dataset_identifier: str, row: dict
) -> Iterable[MetadataWorkUnit]:
) -> Optional[MetadataWorkUnit]:
try:
stats = DatasetUsageStatistics(
timestampMillis=int(row["BUCKET_START_TIME"].timestamp() * 1000),
Expand All @@ -258,18 +277,15 @@ def build_usage_statistics_for_dataset(
),
totalSqlQueries=row["TOTAL_QUERIES"],
uniqueUserCount=row["TOTAL_USERS"],
topSqlQueries=self._map_top_sql_queries(
json.loads(row["TOP_SQL_QUERIES"])
)
if self.config.include_top_n_queries
else None,
userCounts=self._map_user_counts(
json.loads(row["USER_COUNTS"]),
topSqlQueries=(
self._map_top_sql_queries(row["TOP_SQL_QUERIES"])
if self.config.include_top_n_queries
else None
),
fieldCounts=self._map_field_counts(json.loads(row["FIELD_COUNTS"])),
userCounts=self._map_user_counts(row["USER_COUNTS"]),
fieldCounts=self._map_field_counts(row["FIELD_COUNTS"]),
)

yield MetadataChangeProposalWrapper(
return MetadataChangeProposalWrapper(
entityUrn=self.dataset_urn_builder(dataset_identifier), aspect=stats
).as_workunit()
except Exception as e:
@@ -281,61 +297,79 @@ def build_usage_statistics_for_dataset(
"Failed to parse usage statistics for dataset", dataset_identifier
)

def _map_top_sql_queries(self, top_sql_queries: Dict) -> List[str]:
budget_per_query: int = int(
self.config.queries_character_limit / self.config.top_n_queries
)
return sorted(
[
trim_query(format_sql_query(query), budget_per_query)
if self.config.format_sql_queries
else trim_query(query, budget_per_query)
for query in top_sql_queries
]
)
return None

def _map_top_sql_queries(self, top_sql_queries_str: str) -> List[str]:
with self.report.usage_aggregation.queries_map_timer:
top_sql_queries = json.loads(top_sql_queries_str)
budget_per_query: int = int(
self.config.queries_character_limit / self.config.top_n_queries
)
return sorted(
[
(
trim_query(
try_format_query(query, self.platform), budget_per_query
)
if self.config.format_sql_queries
else trim_query(query, budget_per_query)
)
for query in top_sql_queries
]
)

def _map_user_counts(
self,
user_counts: Dict,
user_counts_str: str,
) -> List[DatasetUserUsageCounts]:
filtered_user_counts = []
for user_count in user_counts:
user_email = user_count.get("email")
if not user_email and self.config.email_domain and user_count["user_name"]:
user_email = "{0}@{1}".format(
user_count["user_name"], self.config.email_domain
).lower()
if not user_email or not self.config.user_email_pattern.allowed(user_email):
continue

filtered_user_counts.append(
DatasetUserUsageCounts(
user=make_user_urn(
self.get_user_identifier(
user_count["user_name"],
user_email,
self.config.email_as_user_identifier,
)
),
count=user_count["total"],
# NOTE: Generated emails may be incorrect, as email may be different than
# username@email_domain
userEmail=user_email,
with self.report.usage_aggregation.users_map_timer:
user_counts = json.loads(user_counts_str)
filtered_user_counts = []
for user_count in user_counts:
user_email = user_count.get("email")
if (
not user_email
and self.config.email_domain
and user_count["user_name"]
):
user_email = "{0}@{1}".format(
user_count["user_name"], self.config.email_domain
).lower()
if not user_email or not self.config.user_email_pattern.allowed(
user_email
):
continue

filtered_user_counts.append(
DatasetUserUsageCounts(
user=make_user_urn(
self.get_user_identifier(
user_count["user_name"],
user_email,
self.config.email_as_user_identifier,
)
),
count=user_count["total"],
# NOTE: Generated emails may be incorrect, as email may be different than
# username@email_domain
userEmail=user_email,
)
)
return sorted(filtered_user_counts, key=lambda v: v.user)

def _map_field_counts(self, field_counts_str: str) -> List[DatasetFieldUsageCounts]:
with self.report.usage_aggregation.fields_map_timer:
field_counts = json.loads(field_counts_str)
return sorted(
[
DatasetFieldUsageCounts(
fieldPath=self.snowflake_identifier(field_count["col"]),
count=field_count["total"],
)
for field_count in field_counts
],
key=lambda v: v.fieldPath,
)
return sorted(filtered_user_counts, key=lambda v: v.user)

def _map_field_counts(self, field_counts: Dict) -> List[DatasetFieldUsageCounts]:
return sorted(
[
DatasetFieldUsageCounts(
fieldPath=self.snowflake_identifier(field_count["col"]),
count=field_count["total"],
)
for field_count in field_counts
],
key=lambda v: v.fieldPath,
)
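
The mapping helpers above now receive the raw JSON strings from the query result and call json.loads inside their own timers, so parsing cost is attributed to the right phase. A condensed sketch of the user-count mapping logic — deriving an email from user_name plus a configured email_domain when the row has none, then dropping users that fail an allow pattern (the config, pattern handling, and urn construction are simplified assumptions, not DataHub's AllowDenyPattern or make_user_urn):

# Sketch of user-count mapping with email fallback and filtering.
import json
import re
from typing import List, Optional

def map_user_counts(
    user_counts_str: str,
    email_domain: Optional[str],
    email_allow_regex: str = r".*",
) -> List[dict]:
    user_counts = json.loads(user_counts_str)  # parse inside the timed scope
    out = []
    for uc in user_counts:
        email = uc.get("email")
        if not email and email_domain and uc.get("user_name"):
            # NOTE: derived emails may be wrong if the real address differs
            # from user_name@email_domain.
            email = f"{uc['user_name']}@{email_domain}".lower()
        if not email or not re.fullmatch(email_allow_regex, email):
            continue
        out.append({"user": f"urn:li:corpuser:{uc['user_name']}", "count": uc["total"], "email": email})
    return sorted(out, key=lambda v: v["user"])

rows = json.dumps([{"user_name": "ALICE", "total": 7}, {"user_name": "BOT", "email": "bot@vendor.io", "total": 3}])
print(map_user_counts(rows, email_domain="acme.com", email_allow_regex=r".*@acme\.com"))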

def _get_snowflake_history(self) -> Iterable[SnowflakeJoinedAccessEvent]:
logger.info("Getting access history")
@@ -438,9 +472,11 @@ def _get_operation_aspect_work_unit(
lastUpdatedTimestamp=last_updated_timestamp,
actor=user_urn,
operationType=operation_type,
customOperationType=query_type
if operation_type is OperationTypeClass.CUSTOM
else None,
customOperationType=(
query_type
if operation_type is OperationTypeClass.CUSTOM
else None
),
)
mcp = MetadataChangeProposalWrapper(
entityUrn=self.dataset_urn_builder(dataset_identifier),
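The last hunk only reflows a conditional: when the mapped operation type is CUSTOM, the raw Snowflake query type is carried through as customOperationType. A small sketch of that fallback (the query-type-to-operation mapping shown is illustrative, not DataHub's actual table):

# Illustrative mapping of a raw query type to an operation type,
# with a CUSTOM fallback that preserves the original string.
from typing import Optional, Tuple

OPERATION_MAP = {  # illustrative, not DataHub's actual mapping
    "INSERT": "INSERT",
    "UPDATE": "UPDATE",
    "DELETE": "DELETE",
}

def map_operation(query_type: str) -> Tuple[str, Optional[str]]:
    operation_type = OPERATION_MAP.get(query_type, "CUSTOM")
    custom_operation_type = query_type if operation_type == "CUSTOM" else None
    return operation_type, custom_operation_type

print(map_operation("INSERT"))                  # ('INSERT', None)
print(map_operation("CREATE_TABLE_AS_SELECT"))  # ('CUSTOM', 'CREATE_TABLE_AS_SELECT')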
