From 82d17f927169846aa00aaac8cf11d76179f84fe2 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 14 Jun 2024 13:23:07 -0700 Subject: [PATCH] feat(ingest/snowflake): refactor + parallel schema extraction (#10653) --- metadata-ingestion/setup.py | 2 + .../datahub/ingestion/sink/datahub_rest.py | 52 +- .../source/snowflake/snowflake_report.py | 13 +- .../source/snowflake/snowflake_schema.py | 40 +- .../source/snowflake/snowflake_schema_gen.py | 1077 +++++++++++++++++ .../source/snowflake/snowflake_summary.py | 18 +- .../source/snowflake/snowflake_utils.py | 11 + .../source/snowflake/snowflake_v2.py | 1015 +--------------- .../datahub/utilities/serialized_lru_cache.py | 98 ++ .../integration/snowflake/test_snowflake.py | 12 +- .../tests/unit/test_serialized_lru_cache.py | 92 ++ 11 files changed, 1394 insertions(+), 1036 deletions(-) create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py create mode 100644 metadata-ingestion/src/datahub/utilities/serialized_lru_cache.py create mode 100644 metadata-ingestion/tests/unit/test_serialized_lru_cache.py diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index cd8c9d4541c1d6..f51908999ec158 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -198,6 +198,7 @@ "pandas", "cryptography", "msal", + "cachetools", } | classification_lib trino = { @@ -403,6 +404,7 @@ "sagemaker": aws_common, "salesforce": {"simple-salesforce"}, "snowflake": snowflake_common | usage_common | sqlglot_lib, + "snowflake-summary": snowflake_common | usage_common | sqlglot_lib, "sqlalchemy": sql_common, "sql-queries": usage_common | sqlglot_lib, "slack": slack, diff --git a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py index 8572b2378a3bb0..dab8e99b797fe9 100644 --- a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py +++ b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py @@ -3,6 +3,8 @@ import dataclasses import functools import logging +import os +import threading import uuid from enum import auto from typing import Optional, Union @@ -14,7 +16,7 @@ OperationalError, ) from datahub.emitter.mcp import MetadataChangeProposalWrapper -from datahub.emitter.rest_emitter import DatahubRestEmitter +from datahub.emitter.rest_emitter import DataHubRestEmitter from datahub.ingestion.api.common import RecordEnvelope, WorkUnit from datahub.ingestion.api.sink import ( NoopWriteCallback, @@ -34,6 +36,10 @@ logger = logging.getLogger(__name__) +DEFAULT_REST_SINK_MAX_THREADS = int( + os.getenv("DATAHUB_REST_SINK_DEFAULT_MAX_THREADS", 15) +) + class SyncOrAsync(ConfigEnum): SYNC = auto() @@ -44,7 +50,7 @@ class DatahubRestSinkConfig(DatahubClientConfig): mode: SyncOrAsync = SyncOrAsync.ASYNC # These only apply in async mode. 
- max_threads: int = 15 + max_threads: int = DEFAULT_REST_SINK_MAX_THREADS max_pending_requests: int = 2000 @@ -82,22 +88,12 @@ def _get_partition_key(record_envelope: RecordEnvelope) -> str: class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]): - emitter: DatahubRestEmitter + _emitter_thread_local: threading.local treat_errors_as_warnings: bool = False def __post_init__(self) -> None: - self.emitter = DatahubRestEmitter( - self.config.server, - self.config.token, - connect_timeout_sec=self.config.timeout_sec, # reuse timeout_sec for connect timeout - read_timeout_sec=self.config.timeout_sec, - retry_status_codes=self.config.retry_status_codes, - retry_max_times=self.config.retry_max_times, - extra_headers=self.config.extra_headers, - ca_certificate_path=self.config.ca_certificate_path, - client_certificate_path=self.config.client_certificate_path, - disable_ssl_verification=self.config.disable_ssl_verification, - ) + self._emitter_thread_local = threading.local() + try: gms_config = self.emitter.get_server_config() except Exception as exc: @@ -120,6 +116,32 @@ def __post_init__(self) -> None: max_pending=self.config.max_pending_requests, ) + @classmethod + def _make_emitter(cls, config: DatahubRestSinkConfig) -> DataHubRestEmitter: + return DataHubRestEmitter( + config.server, + config.token, + connect_timeout_sec=config.timeout_sec, # reuse timeout_sec for connect timeout + read_timeout_sec=config.timeout_sec, + retry_status_codes=config.retry_status_codes, + retry_max_times=config.retry_max_times, + extra_headers=config.extra_headers, + ca_certificate_path=config.ca_certificate_path, + client_certificate_path=config.client_certificate_path, + disable_ssl_verification=config.disable_ssl_verification, + ) + + @property + def emitter(self) -> DataHubRestEmitter: + # While this is a property, it actually uses one emitter per thread. + # Since emitter is one-to-one with request sessions, using a separate + # emitter per thread should improve correctness and performance. 
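For readers unfamiliar with the pattern being introduced here, the following is a minimal standalone sketch of the same thread-local idea, with `requests.Session` standing in for `DataHubRestEmitter`. The helper names (`get_session`, `_worker`) are illustrative only and are not part of the DataHub codebase.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

import requests

_local = threading.local()


def get_session() -> requests.Session:
    # Lazily create one Session per thread. requests.Session is not
    # guaranteed to be thread-safe, so sharing a single instance across
    # threads is risky (see the requests issue linked in the sink code).
    if not hasattr(_local, "session"):
        _local.session = requests.Session()
    return _local.session


def _worker(_: int) -> int:
    # Each worker thread sees its own Session object.
    return id(get_session())


if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=4) as pool:
        distinct_sessions = set(pool.map(_worker, range(16)))
    print(len(distinct_sessions))  # at most 4: one Session per worker thread
```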
+ # https://github.com/psf/requests/issues/1871#issuecomment-32751346 + thread_local = self._emitter_thread_local + if not hasattr(thread_local, "emitter"): + thread_local.emitter = DatahubRestSink._make_emitter(self.config) + return thread_local.emitter + def handle_work_unit_start(self, workunit: WorkUnit) -> None: if isinstance(workunit, MetadataWorkUnit): self.treat_errors_as_warnings = workunit.treat_errors_as_warnings diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py index d79ed384d755b0..db2095da01134d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import Dict, List, MutableSet, Optional +from typing import TYPE_CHECKING, Dict, List, MutableSet, Optional from datahub.ingestion.api.report import Report from datahub.ingestion.glossary.classification_mixin import ClassificationReportMixin @@ -14,6 +14,11 @@ from datahub.sql_parsing.sql_parsing_aggregator import SqlAggregatorReport from datahub.utilities.perf_timer import PerfTimer +if TYPE_CHECKING: + from datahub.ingestion.source.snowflake.snowflake_schema import ( + SnowflakeDataDictionary, + ) + @dataclass class SnowflakeUsageAggregationReport(Report): @@ -106,11 +111,7 @@ class SnowflakeV2Report( num_tables_with_known_upstreams: int = 0 num_upstream_lineage_edge_parsing_failed: int = 0 - # Reports how many times we reset in-memory `functools.lru_cache` caches of data, - # which occurs when we occur a different database / schema. - # Should not be more than the number of databases / schemas scanned. - # Maps (function name) -> (stat_name) -> (stat_value) - lru_cache_info: Dict[str, Dict[str, int]] = field(default_factory=dict) + data_dictionary_cache: Optional["SnowflakeDataDictionary"] = None # These will be non-zero if snowflake information_schema queries fail with error - # "Information schema query returned too much data. 
Please repeat query with more selective predicates."" diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py index 292c57494632c5..3e26d2acd78e1c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py @@ -1,19 +1,23 @@ import logging +import os from collections import defaultdict from dataclasses import dataclass, field from datetime import datetime -from functools import lru_cache -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional from snowflake.connector import SnowflakeConnection +from datahub.ingestion.api.report import SupportsAsObj from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeQueryMixin from datahub.ingestion.source.sql.sql_generic import BaseColumn, BaseTable, BaseView +from datahub.utilities.serialized_lru_cache import serialized_lru_cache logger: logging.Logger = logging.getLogger(__name__) +SCHEMA_PARALLELISM = int(os.getenv("DATAHUB_SNOWFLAKE_SCHEMA_PARALLELISM", 20)) + @dataclass class SnowflakePK: @@ -176,7 +180,7 @@ def get_column_tags_for_table( ) -class SnowflakeDataDictionary(SnowflakeQueryMixin): +class SnowflakeDataDictionary(SnowflakeQueryMixin, SupportsAsObj): def __init__(self) -> None: self.logger = logger self.connection: Optional[SnowflakeConnection] = None @@ -189,6 +193,26 @@ def get_connection(self) -> SnowflakeConnection: assert self.connection is not None return self.connection + def as_obj(self) -> Dict[str, Dict[str, int]]: + # TODO: Move this into a proper report type that gets computed. + + # Reports how many times we reset in-memory `functools.lru_cache` caches of data, + # which occurs when we occur a different database / schema. + # Should not be more than the number of databases / schemas scanned. 
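The `serialized_lru_cache` decorator used below is added as a new utility in this patch (`datahub/utilities/serialized_lru_cache.py`), but its body is not part of this excerpt. As a rough, hypothetical sketch of the idea — an LRU cache that is safe to call from multiple threads and runs the underlying query at most once per key — it might look roughly like the following, using the `cachetools` dependency added in setup.py above. The real implementation (including the `cache_info()` stats that `as_obj()` reads) may differ.

```python
import functools
import threading
from typing import Any, Callable, Dict, Hashable, TypeVar

import cachetools
import cachetools.keys

_F = TypeVar("_F", bound=Callable)


def serialized_lru_cache_sketch(maxsize: int) -> Callable[[_F], _F]:
    """Hypothetical stand-in for datahub.utilities.serialized_lru_cache."""

    def decorator(func: _F) -> _F:
        cache = cachetools.LRUCache(maxsize=maxsize)
        cache_lock = threading.Lock()  # guards the cache mapping itself
        key_locks: Dict[Hashable, threading.Lock] = {}  # one lock per argument key

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            key = cachetools.keys.hashkey(*args, **kwargs)
            with cache_lock:
                key_lock = key_locks.setdefault(key, threading.Lock())
            # Concurrent calls with the same arguments are serialized so the
            # expensive underlying query runs at most once per key; calls with
            # different keys can still proceed in parallel.
            with key_lock:
                with cache_lock:
                    if key in cache:
                        return cache[key]
                result = func(*args, **kwargs)
                with cache_lock:
                    cache[key] = result
                return result

        return wrapper  # type: ignore[return-value]

    return decorator
```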
+ # Maps (function name) -> (stat_name) -> (stat_value) + lru_cache_functions: List[Callable] = [ + self.get_tables_for_database, + self.get_views_for_database, + self.get_columns_for_schema, + self.get_pk_constraints_for_schema, + self.get_fk_constraints_for_schema, + ] + + report = {} + for func in lru_cache_functions: + report[func.__name__] = func.cache_info()._asdict() # type: ignore + return report + def show_databases(self) -> List[SnowflakeDatabase]: databases: List[SnowflakeDatabase] = [] @@ -241,7 +265,7 @@ def get_schemas_for_database(self, db_name: str) -> List[SnowflakeSchema]: snowflake_schemas.append(snowflake_schema) return snowflake_schemas - @lru_cache(maxsize=1) + @serialized_lru_cache(maxsize=1) def get_tables_for_database( self, db_name: str ) -> Optional[Dict[str, List[SnowflakeTable]]]: @@ -299,7 +323,7 @@ def get_tables_for_schema( ) return tables - @lru_cache(maxsize=1) + @serialized_lru_cache(maxsize=1) def get_views_for_database( self, db_name: str ) -> Optional[Dict[str, List[SnowflakeView]]]: @@ -349,7 +373,7 @@ def get_views_for_schema( ) return views - @lru_cache(maxsize=1) + @serialized_lru_cache(maxsize=SCHEMA_PARALLELISM) def get_columns_for_schema( self, schema_name: str, db_name: str ) -> Optional[Dict[str, List[SnowflakeColumn]]]: @@ -405,7 +429,7 @@ def get_columns_for_table( ) return columns - @lru_cache(maxsize=1) + @serialized_lru_cache(maxsize=SCHEMA_PARALLELISM) def get_pk_constraints_for_schema( self, schema_name: str, db_name: str ) -> Dict[str, SnowflakePK]: @@ -422,7 +446,7 @@ def get_pk_constraints_for_schema( constraints[row["table_name"]].column_names.append(row["column_name"]) return constraints - @lru_cache(maxsize=1) + @serialized_lru_cache(maxsize=SCHEMA_PARALLELISM) def get_fk_constraints_for_schema( self, schema_name: str, db_name: str ) -> Dict[str, List[SnowflakeFK]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py new file mode 100644 index 00000000000000..5a4e37078dd75f --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py @@ -0,0 +1,1077 @@ +import concurrent.futures +import logging +import queue +from typing import Dict, Iterable, List, Optional, Union + +from snowflake.connector import SnowflakeConnection + +from datahub.configuration.pattern_utils import is_schema_allowed +from datahub.emitter.mce_builder import ( + make_data_platform_urn, + make_dataset_urn_with_platform_instance, + make_schema_field_urn, + make_tag_urn, +) +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.glossary.classification_mixin import ( + ClassificationHandler, + classification_workunit_processor, +) +from datahub.ingestion.source.common.subtypes import ( + DatasetContainerSubTypes, + DatasetSubTypes, +) +from datahub.ingestion.source.snowflake.constants import ( + GENERIC_PERMISSION_ERROR_KEY, + SNOWFLAKE_DATABASE, + SnowflakeObjectDomain, +) +from datahub.ingestion.source.snowflake.snowflake_config import ( + SnowflakeV2Config, + TagOption, +) +from datahub.ingestion.source.snowflake.snowflake_data_reader import SnowflakeDataReader +from datahub.ingestion.source.snowflake.snowflake_profiler import SnowflakeProfiler +from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report +from datahub.ingestion.source.snowflake.snowflake_schema import ( + 
SCHEMA_PARALLELISM, + SnowflakeColumn, + SnowflakeDatabase, + SnowflakeDataDictionary, + SnowflakeFK, + SnowflakePK, + SnowflakeSchema, + SnowflakeTable, + SnowflakeTag, + SnowflakeView, +) +from datahub.ingestion.source.snowflake.snowflake_tag import SnowflakeTagExtractor +from datahub.ingestion.source.snowflake.snowflake_utils import ( + SnowflakeCommonMixin, + SnowflakeCommonProtocol, + SnowflakeConnectionMixin, + SnowflakePermissionError, + SnowflakeQueryMixin, +) +from datahub.ingestion.source.sql.sql_utils import ( + add_table_to_schema_container, + gen_database_container, + gen_database_key, + gen_schema_container, + gen_schema_key, + get_dataplatform_instance_aspect, + get_domain_wu, +) +from datahub.ingestion.source_report.ingestion_stage import ( + METADATA_EXTRACTION, + PROFILING, +) +from datahub.metadata.com.linkedin.pegasus2avro.common import ( + GlobalTags, + Status, + SubTypes, + TagAssociation, + TimeStamp, +) +from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( + DatasetProperties, + ViewProperties, +) +from datahub.metadata.com.linkedin.pegasus2avro.schema import ( + ArrayType, + BooleanType, + BytesType, + DateType, + ForeignKeyConstraint, + MySqlDDL, + NullType, + NumberType, + RecordType, + SchemaField, + SchemaFieldDataType, + SchemaMetadata, + StringType, + TimeType, +) +from datahub.metadata.com.linkedin.pegasus2avro.tag import TagProperties +from datahub.sql_parsing.sql_parsing_aggregator import SqlParsingAggregator +from datahub.utilities.registries.domain_registry import DomainRegistry + +logger = logging.getLogger(__name__) + +# https://docs.snowflake.com/en/sql-reference/intro-summary-data-types.html +SNOWFLAKE_FIELD_TYPE_MAPPINGS = { + "DATE": DateType, + "BIGINT": NumberType, + "BINARY": BytesType, + # 'BIT': BIT, + "BOOLEAN": BooleanType, + "CHAR": NullType, + "CHARACTER": NullType, + "DATETIME": TimeType, + "DEC": NumberType, + "DECIMAL": NumberType, + "DOUBLE": NumberType, + "FIXED": NumberType, + "FLOAT": NumberType, + "INT": NumberType, + "INTEGER": NumberType, + "NUMBER": NumberType, + # 'OBJECT': ? 
+ "REAL": NumberType, + "BYTEINT": NumberType, + "SMALLINT": NumberType, + "STRING": StringType, + "TEXT": StringType, + "TIME": TimeType, + "TIMESTAMP": TimeType, + "TIMESTAMP_TZ": TimeType, + "TIMESTAMP_LTZ": TimeType, + "TIMESTAMP_NTZ": TimeType, + "TINYINT": NumberType, + "VARBINARY": BytesType, + "VARCHAR": StringType, + "VARIANT": RecordType, + "OBJECT": NullType, + "ARRAY": ArrayType, + "GEOGRAPHY": NullType, +} + + +class SnowflakeSchemaGenerator( + SnowflakeQueryMixin, + SnowflakeConnectionMixin, + SnowflakeCommonMixin, + SnowflakeCommonProtocol, +): + def __init__( + self, + config: SnowflakeV2Config, + report: SnowflakeV2Report, + connection: SnowflakeConnection, + domain_registry: Optional[DomainRegistry], + profiler: Optional[SnowflakeProfiler], + aggregator: Optional[SqlParsingAggregator], + snowsight_base_url: Optional[str], + ) -> None: + self.config: SnowflakeV2Config = config + self.report: SnowflakeV2Report = report + self.connection: SnowflakeConnection = connection + self.logger = logger + + self.data_dictionary: SnowflakeDataDictionary = SnowflakeDataDictionary() + self.data_dictionary.set_connection(self.connection) + self.report.data_dictionary_cache = self.data_dictionary + + self.domain_registry: Optional[DomainRegistry] = domain_registry + self.classification_handler = ClassificationHandler(self.config, self.report) + self.tag_extractor = SnowflakeTagExtractor( + config, self.data_dictionary, self.report + ) + self.profiler: Optional[SnowflakeProfiler] = profiler + self.snowsight_base_url: Optional[str] = snowsight_base_url + + # These are populated as side-effects of get_workunits_internal. + self.databases: List[SnowflakeDatabase] = [] + self.aggregator: Optional[SqlParsingAggregator] = aggregator + + def get_connection(self) -> SnowflakeConnection: + return self.connection + + def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: + self.databases = [] + for database in self.get_databases() or []: + self.report.report_entity_scanned(database.name, "database") + if not self.config.database_pattern.allowed(database.name): + self.report.report_dropped(f"{database.name}.*") + else: + self.databases.append(database) + + if len(self.databases) == 0: + return + + try: + for snowflake_db in self.databases: + self.report.set_ingestion_stage(snowflake_db.name, METADATA_EXTRACTION) + yield from self._process_database(snowflake_db) + + except SnowflakePermissionError as e: + self.report_error(GENERIC_PERMISSION_ERROR_KEY, str(e)) + return + + def get_databases(self) -> Optional[List[SnowflakeDatabase]]: + try: + # `show databases` is required only to get one of the databases + # whose information_schema can be queried to start with. + databases = self.data_dictionary.show_databases() + except Exception as e: + logger.debug(f"Failed to list databases due to error {e}", exc_info=e) + self.report_error( + "list-databases", + f"Failed to list databases due to error {e}", + ) + return None + else: + ischema_databases: List[ + SnowflakeDatabase + ] = self.get_databases_from_ischema(databases) + + if len(ischema_databases) == 0: + self.report_error( + GENERIC_PERMISSION_ERROR_KEY, + "No databases found. 
Please check permissions.", + ) + return ischema_databases + + def get_databases_from_ischema( + self, databases: List[SnowflakeDatabase] + ) -> List[SnowflakeDatabase]: + ischema_databases: List[SnowflakeDatabase] = [] + for database in databases: + try: + ischema_databases = self.data_dictionary.get_databases(database.name) + break + except Exception: + # query fails if "USAGE" access is not granted for database + # This is okay, because `show databases` query lists all databases irrespective of permission, + # if role has `MANAGE GRANTS` privilege. (not advisable) + logger.debug( + f"Failed to list databases {database.name} information_schema" + ) + # SNOWFLAKE database always shows up even if permissions are missing + if database == SNOWFLAKE_DATABASE: + continue + logger.info( + f"The role {self.report.role} has `MANAGE GRANTS` privilege. This is not advisable and also not required." + ) + + return ischema_databases + + def _process_database( + self, snowflake_db: SnowflakeDatabase + ) -> Iterable[MetadataWorkUnit]: + db_name = snowflake_db.name + + try: + pass + # self.query(SnowflakeQuery.use_database(db_name)) + except Exception as e: + if isinstance(e, SnowflakePermissionError): + # This may happen if REFERENCE_USAGE permissions are set + # We can not run show queries on database in such case. + # This need not be a failure case. + self.report_warning( + "Insufficient privileges to operate on database, skipping. Please grant USAGE permissions on database to extract its metadata.", + db_name, + ) + else: + logger.debug( + f"Failed to use database {db_name} due to error {e}", + exc_info=e, + ) + self.report_warning( + "Failed to get schemas for database", + db_name, + ) + return + + if self.config.extract_tags != TagOption.skip: + snowflake_db.tags = self.tag_extractor.get_tags_on_object( + domain="database", db_name=db_name + ) + + if self.config.include_technical_schema: + yield from self.gen_database_containers(snowflake_db) + + self.fetch_schemas_for_database(snowflake_db, db_name) + + if self.config.include_technical_schema and snowflake_db.tags: + for tag in snowflake_db.tags: + yield from self._process_tag(tag) + + # Caches tables for a single database. Consider moving to disk or S3 when possible. + db_tables: Dict[str, List[SnowflakeTable]] = {} + yield from self._process_db_schemas(snowflake_db, db_tables) + + if self.profiler and db_tables: + self.report.set_ingestion_stage(snowflake_db.name, PROFILING) + yield from self.profiler.get_workunits(snowflake_db, db_tables) + + def _process_db_schemas( + self, + snowflake_db: SnowflakeDatabase, + db_tables: Dict[str, List[SnowflakeTable]], + ) -> Iterable[MetadataWorkUnit]: + q: "queue.Queue[MetadataWorkUnit]" = queue.Queue(maxsize=100) + + def _process_schema_worker(snowflake_schema: SnowflakeSchema) -> None: + for wu in self._process_schema( + snowflake_schema, snowflake_db.name, db_tables + ): + q.put(wu) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=SCHEMA_PARALLELISM + ) as executor: + futures = [] + for snowflake_schema in snowflake_db.schemas: + f = executor.submit(_process_schema_worker, snowflake_schema) + futures.append(f) + + # Read from the queue and yield the work units until all futures are done. + while True: + if q.empty(): + while not q.empty(): + yield q.get_nowait() + else: + try: + yield q.get(timeout=0.2) + except queue.Empty: + pass + + # Filter out the done futures. + futures = [f for f in futures if not f.done()] + if not futures: + break + + # Yield the remaining work units. 
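For reference, the fan-out/drain pattern used in `_process_db_schemas` can be expressed as a small standalone generator. `parallel_yield` and `_square_worker` below are illustrative names, not DataHub APIs; this is a sketch of the technique, not the source's code.

```python
import concurrent.futures
import queue
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def parallel_yield(
    items: Iterable[T],
    produce: Callable[[T, "queue.Queue[R]"], None],
    max_workers: int = 4,
    max_pending: int = 100,
) -> Iterator[R]:
    # Workers push results onto a bounded queue; the caller drains the queue
    # and yields items as they arrive, instead of waiting for all workers.
    q: "queue.Queue[R]" = queue.Queue(maxsize=max_pending)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(produce, item, q) for item in items]
        # Drain the queue until every worker future has completed.
        while True:
            try:
                yield q.get(timeout=0.2)
            except queue.Empty:
                pass
            futures = [f for f in futures if not f.done()]
            if not futures:
                break
        # Drain anything produced just before the last future finished.
        while not q.empty():
            yield q.get_nowait()


def _square_worker(n: int, out: "queue.Queue[int]") -> None:
    out.put(n * n)


if __name__ == "__main__":
    print(sorted(parallel_yield(range(10), _square_worker)))
```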
This theoretically should not happen, but adding it just in case. + while not q.empty(): + yield q.get_nowait() + + def fetch_schemas_for_database( + self, snowflake_db: SnowflakeDatabase, db_name: str + ) -> None: + schemas: List[SnowflakeSchema] = [] + try: + for schema in self.data_dictionary.get_schemas_for_database(db_name): + self.report.report_entity_scanned(schema.name, "schema") + if not is_schema_allowed( + self.config.schema_pattern, + schema.name, + db_name, + self.config.match_fully_qualified_names, + ): + self.report.report_dropped(f"{db_name}.{schema.name}.*") + else: + schemas.append(schema) + except Exception as e: + if isinstance(e, SnowflakePermissionError): + error_msg = f"Failed to get schemas for database {db_name}. Please check permissions." + # Ideal implementation would use PEP 678 – Enriching Exceptions with Notes + raise SnowflakePermissionError(error_msg) from e.__cause__ + else: + logger.debug( + f"Failed to get schemas for database {db_name} due to error {e}", + exc_info=e, + ) + self.report_warning( + "Failed to get schemas for database", + db_name, + ) + + if not schemas: + self.report_warning( + "No schemas found in database. If schemas exist, please grant USAGE permissions on them.", + db_name, + ) + else: + snowflake_db.schemas = schemas + + def _process_schema( + self, + snowflake_schema: SnowflakeSchema, + db_name: str, + db_tables: Dict[str, List[SnowflakeTable]], + ) -> Iterable[MetadataWorkUnit]: + schema_name = snowflake_schema.name + + if self.config.extract_tags != TagOption.skip: + snowflake_schema.tags = self.tag_extractor.get_tags_on_object( + schema_name=schema_name, db_name=db_name, domain="schema" + ) + + if self.config.include_technical_schema: + yield from self.gen_schema_containers(snowflake_schema, db_name) + + if self.config.include_tables: + tables = self.fetch_tables_for_schema( + snowflake_schema, db_name, schema_name + ) + db_tables[schema_name] = tables + + if self.config.include_technical_schema: + data_reader = self.make_data_reader() + for table in tables: + table_wu_generator = self._process_table( + table, schema_name, db_name + ) + yield from classification_workunit_processor( + table_wu_generator, + self.classification_handler, + data_reader, + [db_name, schema_name, table.name], + ) + + if self.config.include_views: + views = self.fetch_views_for_schema(snowflake_schema, db_name, schema_name) + if ( + self.aggregator + and self.config.include_view_lineage + and self.config.parse_view_ddl + ): + for view in views: + view_identifier = self.get_dataset_identifier( + view.name, schema_name, db_name + ) + if view.view_definition: + self.aggregator.add_view_definition( + view_urn=self.gen_dataset_urn(view_identifier), + view_definition=view.view_definition, + default_db=db_name, + default_schema=schema_name, + ) + + if self.config.include_technical_schema: + for view in views: + yield from self._process_view(view, schema_name, db_name) + + if self.config.include_technical_schema and snowflake_schema.tags: + for tag in snowflake_schema.tags: + yield from self._process_tag(tag) + + if not snowflake_schema.views and not snowflake_schema.tables: + self.report_warning( + "No tables/views found in schema. 
If tables exist, please grant REFERENCES or SELECT permissions on them.", + f"{db_name}.{schema_name}", + ) + + def fetch_views_for_schema( + self, snowflake_schema: SnowflakeSchema, db_name: str, schema_name: str + ) -> List[SnowflakeView]: + try: + views: List[SnowflakeView] = [] + for view in self.get_views_for_schema(schema_name, db_name): + view_name = self.get_dataset_identifier(view.name, schema_name, db_name) + + self.report.report_entity_scanned(view_name, "view") + + if not self.config.view_pattern.allowed(view_name): + self.report.report_dropped(view_name) + else: + views.append(view) + snowflake_schema.views = [view.name for view in views] + return views + except Exception as e: + if isinstance(e, SnowflakePermissionError): + # Ideal implementation would use PEP 678 – Enriching Exceptions with Notes + error_msg = f"Failed to get views for schema {db_name}.{schema_name}. Please check permissions." + + raise SnowflakePermissionError(error_msg) from e.__cause__ + else: + logger.debug( + f"Failed to get views for schema {db_name}.{schema_name} due to error {e}", + exc_info=e, + ) + self.report_warning( + "Failed to get views for schema", + f"{db_name}.{schema_name}", + ) + return [] + + def fetch_tables_for_schema( + self, snowflake_schema: SnowflakeSchema, db_name: str, schema_name: str + ) -> List[SnowflakeTable]: + try: + tables: List[SnowflakeTable] = [] + for table in self.get_tables_for_schema(schema_name, db_name): + table_identifier = self.get_dataset_identifier( + table.name, schema_name, db_name + ) + self.report.report_entity_scanned(table_identifier) + if not self.config.table_pattern.allowed(table_identifier): + self.report.report_dropped(table_identifier) + else: + tables.append(table) + snowflake_schema.tables = [table.name for table in tables] + return tables + except Exception as e: + if isinstance(e, SnowflakePermissionError): + # Ideal implementation would use PEP 678 – Enriching Exceptions with Notes + error_msg = f"Failed to get tables for schema {db_name}.{schema_name}. Please check permissions." 
+ raise SnowflakePermissionError(error_msg) from e.__cause__ + else: + logger.debug( + f"Failed to get tables for schema {db_name}.{schema_name} due to error {e}", + exc_info=e, + ) + self.report_warning( + "Failed to get tables for schema", + f"{db_name}.{schema_name}", + ) + return [] + + def make_data_reader(self) -> Optional[SnowflakeDataReader]: + if self.classification_handler.is_classification_enabled() and self.connection: + return SnowflakeDataReader.create( + self.connection, self.snowflake_identifier + ) + + return None + + def _process_table( + self, + table: SnowflakeTable, + schema_name: str, + db_name: str, + ) -> Iterable[MetadataWorkUnit]: + table_identifier = self.get_dataset_identifier(table.name, schema_name, db_name) + + self.fetch_columns_for_table(table, schema_name, db_name, table_identifier) + + self.fetch_pk_for_table(table, schema_name, db_name, table_identifier) + + self.fetch_foreign_keys_for_table(table, schema_name, db_name, table_identifier) + + if self.config.extract_tags != TagOption.skip: + table.tags = self.tag_extractor.get_tags_on_object( + table_name=table.name, + schema_name=schema_name, + db_name=db_name, + domain="table", + ) + + if self.config.include_technical_schema: + if table.tags: + for tag in table.tags: + yield from self._process_tag(tag) + for column_name in table.column_tags: + for tag in table.column_tags[column_name]: + yield from self._process_tag(tag) + + yield from self.gen_dataset_workunits(table, schema_name, db_name) + + def fetch_foreign_keys_for_table( + self, + table: SnowflakeTable, + schema_name: str, + db_name: str, + table_identifier: str, + ) -> None: + try: + table.foreign_keys = self.get_fk_constraints_for_table( + table.name, schema_name, db_name + ) + except Exception as e: + logger.debug( + f"Failed to get foreign key for table {table_identifier} due to error {e}", + exc_info=e, + ) + self.report_warning("Failed to get foreign key for table", table_identifier) + + def fetch_pk_for_table( + self, + table: SnowflakeTable, + schema_name: str, + db_name: str, + table_identifier: str, + ) -> None: + try: + table.pk = self.get_pk_constraints_for_table( + table.name, schema_name, db_name + ) + except Exception as e: + logger.debug( + f"Failed to get primary key for table {table_identifier} due to error {e}", + exc_info=e, + ) + self.report_warning("Failed to get primary key for table", table_identifier) + + def fetch_columns_for_table( + self, + table: SnowflakeTable, + schema_name: str, + db_name: str, + table_identifier: str, + ) -> None: + try: + table.columns = self.get_columns_for_table(table.name, schema_name, db_name) + table.column_count = len(table.columns) + if self.config.extract_tags != TagOption.skip: + table.column_tags = self.tag_extractor.get_column_tags_for_table( + table.name, schema_name, db_name + ) + except Exception as e: + logger.debug( + f"Failed to get columns for table {table_identifier} due to error {e}", + exc_info=e, + ) + self.report_warning("Failed to get columns for table", table_identifier) + + def _process_view( + self, + view: SnowflakeView, + schema_name: str, + db_name: str, + ) -> Iterable[MetadataWorkUnit]: + view_name = self.get_dataset_identifier(view.name, schema_name, db_name) + + try: + view.columns = self.get_columns_for_table(view.name, schema_name, db_name) + if self.config.extract_tags != TagOption.skip: + view.column_tags = self.tag_extractor.get_column_tags_for_table( + view.name, schema_name, db_name + ) + except Exception as e: + logger.debug( + f"Failed to get columns for 
view {view_name} due to error {e}", + exc_info=e, + ) + self.report_warning("Failed to get columns for view", view_name) + + if self.config.extract_tags != TagOption.skip: + view.tags = self.tag_extractor.get_tags_on_object( + table_name=view.name, + schema_name=schema_name, + db_name=db_name, + domain="table", + ) + + if self.config.include_technical_schema: + if view.tags: + for tag in view.tags: + yield from self._process_tag(tag) + for column_name in view.column_tags: + for tag in view.column_tags[column_name]: + yield from self._process_tag(tag) + + yield from self.gen_dataset_workunits(view, schema_name, db_name) + + def _process_tag(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]: + tag_identifier = tag.identifier() + + if self.report.is_tag_processed(tag_identifier): + return + + self.report.report_tag_processed(tag_identifier) + + yield from self.gen_tag_workunits(tag) + + def gen_dataset_workunits( + self, + table: Union[SnowflakeTable, SnowflakeView], + schema_name: str, + db_name: str, + ) -> Iterable[MetadataWorkUnit]: + dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name) + dataset_urn = self.gen_dataset_urn(dataset_name) + + status = Status(removed=False) + yield MetadataChangeProposalWrapper( + entityUrn=dataset_urn, aspect=status + ).as_workunit() + + schema_metadata = self.gen_schema_metadata(table, schema_name, db_name) + + yield MetadataChangeProposalWrapper( + entityUrn=dataset_urn, aspect=schema_metadata + ).as_workunit() + + dataset_properties = self.get_dataset_properties(table, schema_name, db_name) + + yield MetadataChangeProposalWrapper( + entityUrn=dataset_urn, aspect=dataset_properties + ).as_workunit() + + schema_container_key = gen_schema_key( + db_name=self.snowflake_identifier(db_name), + schema=self.snowflake_identifier(schema_name), + platform=self.platform, + platform_instance=self.config.platform_instance, + env=self.config.env, + ) + + yield from add_table_to_schema_container( + dataset_urn=dataset_urn, + parent_container_key=schema_container_key, + ) + dpi_aspect = get_dataplatform_instance_aspect( + dataset_urn=dataset_urn, + platform=self.platform, + platform_instance=self.config.platform_instance, + ) + if dpi_aspect: + yield dpi_aspect + + subTypes = SubTypes( + typeNames=( + [DatasetSubTypes.VIEW] + if isinstance(table, SnowflakeView) + else [DatasetSubTypes.TABLE] + ) + ) + + yield MetadataChangeProposalWrapper( + entityUrn=dataset_urn, aspect=subTypes + ).as_workunit() + + if self.domain_registry: + yield from get_domain_wu( + dataset_name=dataset_name, + entity_urn=dataset_urn, + domain_config=self.config.domain, + domain_registry=self.domain_registry, + ) + + if table.tags: + tag_associations = [ + TagAssociation( + tag=make_tag_urn(self.snowflake_identifier(tag.identifier())) + ) + for tag in table.tags + ] + global_tags = GlobalTags(tag_associations) + yield MetadataChangeProposalWrapper( + entityUrn=dataset_urn, aspect=global_tags + ).as_workunit() + + if isinstance(table, SnowflakeView) and table.view_definition is not None: + view_properties_aspect = ViewProperties( + materialized=table.materialized, + viewLanguage="SQL", + viewLogic=table.view_definition, + ) + + yield MetadataChangeProposalWrapper( + entityUrn=dataset_urn, aspect=view_properties_aspect + ).as_workunit() + + def get_dataset_properties( + self, + table: Union[SnowflakeTable, SnowflakeView], + schema_name: str, + db_name: str, + ) -> DatasetProperties: + return DatasetProperties( + name=table.name, + created=( + 
TimeStamp(time=int(table.created.timestamp() * 1000)) + if table.created is not None + else None + ), + lastModified=( + TimeStamp(time=int(table.last_altered.timestamp() * 1000)) + if table.last_altered is not None + else ( + TimeStamp(time=int(table.created.timestamp() * 1000)) + if table.created is not None + else None + ) + ), + description=table.comment, + qualifiedName=f"{db_name}.{schema_name}.{table.name}", + customProperties={}, + externalUrl=( + self.get_external_url_for_table( + table.name, + schema_name, + db_name, + ( + SnowflakeObjectDomain.TABLE + if isinstance(table, SnowflakeTable) + else SnowflakeObjectDomain.VIEW + ), + ) + if self.config.include_external_url + else None + ), + ) + + def gen_tag_workunits(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]: + tag_urn = make_tag_urn(self.snowflake_identifier(tag.identifier())) + + tag_properties_aspect = TagProperties( + name=tag.display_name(), + description=f"Represents the Snowflake tag `{tag._id_prefix_as_str()}` with value `{tag.value}`.", + ) + + yield MetadataChangeProposalWrapper( + entityUrn=tag_urn, aspect=tag_properties_aspect + ).as_workunit() + + def gen_schema_metadata( + self, + table: Union[SnowflakeTable, SnowflakeView], + schema_name: str, + db_name: str, + ) -> SchemaMetadata: + dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name) + dataset_urn = self.gen_dataset_urn(dataset_name) + + foreign_keys: Optional[List[ForeignKeyConstraint]] = None + if isinstance(table, SnowflakeTable) and len(table.foreign_keys) > 0: + foreign_keys = self.build_foreign_keys(table, dataset_urn) + + schema_metadata = SchemaMetadata( + schemaName=dataset_name, + platform=make_data_platform_urn(self.platform), + version=0, + hash="", + platformSchema=MySqlDDL(tableSchema=""), + fields=[ + SchemaField( + fieldPath=self.snowflake_identifier(col.name), + type=SchemaFieldDataType( + SNOWFLAKE_FIELD_TYPE_MAPPINGS.get(col.data_type, NullType)() + ), + # NOTE: nativeDataType will not be in sync with older connector + nativeDataType=col.get_precise_native_type(), + description=col.comment, + nullable=col.is_nullable, + isPartOfKey=( + col.name in table.pk.column_names + if isinstance(table, SnowflakeTable) and table.pk is not None + else None + ), + globalTags=( + GlobalTags( + [ + TagAssociation( + make_tag_urn( + self.snowflake_identifier(tag.identifier()) + ) + ) + for tag in table.column_tags[col.name] + ] + ) + if col.name in table.column_tags + else None + ), + ) + for col in table.columns + ], + foreignKeys=foreign_keys, + ) + + if self.aggregator: + self.aggregator.register_schema(urn=dataset_urn, schema=schema_metadata) + + return schema_metadata + + def build_foreign_keys( + self, table: SnowflakeTable, dataset_urn: str + ) -> List[ForeignKeyConstraint]: + foreign_keys = [] + for fk in table.foreign_keys: + foreign_dataset = make_dataset_urn_with_platform_instance( + platform=self.platform, + name=self.get_dataset_identifier( + fk.referred_table, fk.referred_schema, fk.referred_database + ), + env=self.config.env, + platform_instance=self.config.platform_instance, + ) + foreign_keys.append( + ForeignKeyConstraint( + name=fk.name, + foreignDataset=foreign_dataset, + foreignFields=[ + make_schema_field_urn( + foreign_dataset, + self.snowflake_identifier(col), + ) + for col in fk.referred_column_names + ], + sourceFields=[ + make_schema_field_urn( + dataset_urn, + self.snowflake_identifier(col), + ) + for col in fk.column_names + ], + ) + ) + return foreign_keys + + def gen_database_containers( + self, 
database: SnowflakeDatabase + ) -> Iterable[MetadataWorkUnit]: + database_container_key = gen_database_key( + self.snowflake_identifier(database.name), + platform=self.platform, + platform_instance=self.config.platform_instance, + env=self.config.env, + ) + + yield from gen_database_container( + name=database.name, + database=self.snowflake_identifier(database.name), + database_container_key=database_container_key, + sub_types=[DatasetContainerSubTypes.DATABASE], + domain_registry=self.domain_registry, + domain_config=self.config.domain, + external_url=( + self.get_external_url_for_database(database.name) + if self.config.include_external_url + else None + ), + description=database.comment, + created=( + int(database.created.timestamp() * 1000) + if database.created is not None + else None + ), + last_modified=( + int(database.last_altered.timestamp() * 1000) + if database.last_altered is not None + else ( + int(database.created.timestamp() * 1000) + if database.created is not None + else None + ) + ), + tags=( + [self.snowflake_identifier(tag.identifier()) for tag in database.tags] + if database.tags + else None + ), + ) + + def gen_schema_containers( + self, schema: SnowflakeSchema, db_name: str + ) -> Iterable[MetadataWorkUnit]: + schema_name = self.snowflake_identifier(schema.name) + database_container_key = gen_database_key( + database=self.snowflake_identifier(db_name), + platform=self.platform, + platform_instance=self.config.platform_instance, + env=self.config.env, + ) + + schema_container_key = gen_schema_key( + db_name=self.snowflake_identifier(db_name), + schema=schema_name, + platform=self.platform, + platform_instance=self.config.platform_instance, + env=self.config.env, + ) + + yield from gen_schema_container( + name=schema.name, + schema=self.snowflake_identifier(schema.name), + database=self.snowflake_identifier(db_name), + database_container_key=database_container_key, + domain_config=self.config.domain, + schema_container_key=schema_container_key, + sub_types=[DatasetContainerSubTypes.SCHEMA], + domain_registry=self.domain_registry, + description=schema.comment, + external_url=( + self.get_external_url_for_schema(schema.name, db_name) + if self.config.include_external_url + else None + ), + created=( + int(schema.created.timestamp() * 1000) + if schema.created is not None + else None + ), + last_modified=( + int(schema.last_altered.timestamp() * 1000) + if schema.last_altered is not None + else ( + int(schema.created.timestamp() * 1000) + if schema.created is not None + else None + ) + ), + tags=( + [self.snowflake_identifier(tag.identifier()) for tag in schema.tags] + if schema.tags + else None + ), + ) + + def get_tables_for_schema( + self, schema_name: str, db_name: str + ) -> List[SnowflakeTable]: + tables = self.data_dictionary.get_tables_for_database(db_name) + + # get all tables for database failed, + # falling back to get tables for schema + if tables is None: + self.report.num_get_tables_for_schema_queries += 1 + return self.data_dictionary.get_tables_for_schema(schema_name, db_name) + + # Some schema may not have any table + return tables.get(schema_name, []) + + def get_views_for_schema( + self, schema_name: str, db_name: str + ) -> List[SnowflakeView]: + views = self.data_dictionary.get_views_for_database(db_name) + + # get all views for database failed, + # falling back to get views for schema + if views is None: + self.report.num_get_views_for_schema_queries += 1 + return self.data_dictionary.get_views_for_schema(schema_name, db_name) + + # Some schema may 
not have any table + return views.get(schema_name, []) + + def get_columns_for_table( + self, table_name: str, schema_name: str, db_name: str + ) -> List[SnowflakeColumn]: + columns = self.data_dictionary.get_columns_for_schema(schema_name, db_name) + + # get all columns for schema failed, + # falling back to get columns for table + if columns is None: + self.report.num_get_columns_for_table_queries += 1 + return self.data_dictionary.get_columns_for_table( + table_name, schema_name, db_name + ) + + # Access to table but none of its columns - is this possible ? + return columns.get(table_name, []) + + def get_pk_constraints_for_table( + self, table_name: str, schema_name: str, db_name: str + ) -> Optional[SnowflakePK]: + constraints = self.data_dictionary.get_pk_constraints_for_schema( + schema_name, db_name + ) + + # Access to table but none of its constraints - is this possible ? + return constraints.get(table_name) + + def get_fk_constraints_for_table( + self, table_name: str, schema_name: str, db_name: str + ) -> List[SnowflakeFK]: + constraints = self.data_dictionary.get_fk_constraints_for_schema( + schema_name, db_name + ) + + # Access to table but none of its constraints - is this possible ? + return constraints.get(table_name, []) + + # domain is either "view" or "table" + def get_external_url_for_table( + self, table_name: str, schema_name: str, db_name: str, domain: str + ) -> Optional[str]: + if self.snowsight_base_url is not None: + return f"{self.snowsight_base_url}#/data/databases/{db_name}/schemas/{schema_name}/{domain}/{table_name}/" + return None + + def get_external_url_for_schema( + self, schema_name: str, db_name: str + ) -> Optional[str]: + if self.snowsight_base_url is not None: + return f"{self.snowsight_base_url}#/data/databases/{db_name}/schemas/{schema_name}/" + return None + + def get_external_url_for_database(self, db_name: str) -> Optional[str]: + if self.snowsight_base_url is not None: + return f"{self.snowsight_base_url}#/data/databases/{db_name}/" + return None diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py index ef08866ccd3ede..cd6f17092e810a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py @@ -17,12 +17,14 @@ SnowflakeDatabase, SnowflakeDataDictionary, ) +from datahub.ingestion.source.snowflake.snowflake_schema_gen import ( + SnowflakeSchemaGenerator, +) from datahub.ingestion.source.snowflake.snowflake_utils import ( SnowflakeCommonMixin, SnowflakeConnectionMixin, SnowflakeQueryMixin, ) -from datahub.ingestion.source.snowflake.snowflake_v2 import SnowflakeV2Source from datahub.ingestion.source_config.sql.snowflake import BaseSnowflakeConfig from datahub.ingestion.source_report.time_window import BaseTimeWindowReport from datahub.utilities.lossy_collections import LossyList @@ -167,13 +169,13 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: # This is a bit of a hack, but lets us reuse the code from the main ingestion source. # Mypy doesn't really know how to deal with it though, which is why we have all these # type ignore comments. 
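The "hack" referred to in snowflake_summary.py is plain attribute assignment: functions defined on one class are attached to another class that happens to provide the attributes they use. A tiny, self-contained illustration follows; the class names are made up for the example, and mypy would flag this just as it does in the real code.

```python
class Extractor:
    def describe(self) -> str:
        # Relies on duck typing: works on any object with .name and .count.
        return f"{self.name} ({self.count} items)"


class Summary:
    def __init__(self, name: str, count: int) -> None:
        self.name = name
        self.count = count

    # Reuse Extractor.describe without inheriting from Extractor.
    describe = Extractor.describe


print(Summary("snowflake", 3).describe())  # -> "snowflake (3 items)"
```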
- get_databases = SnowflakeV2Source.get_databases - get_databases_from_ischema = SnowflakeV2Source.get_databases_from_ischema - fetch_schemas_for_database = SnowflakeV2Source.fetch_schemas_for_database - fetch_tables_for_schema = SnowflakeV2Source.fetch_tables_for_schema - fetch_views_for_schema = SnowflakeV2Source.fetch_views_for_schema - get_tables_for_schema = SnowflakeV2Source.get_tables_for_schema - get_views_for_schema = SnowflakeV2Source.get_views_for_schema + get_databases = SnowflakeSchemaGenerator.get_databases + get_databases_from_ischema = SnowflakeSchemaGenerator.get_databases_from_ischema + fetch_schemas_for_database = SnowflakeSchemaGenerator.fetch_schemas_for_database + fetch_tables_for_schema = SnowflakeSchemaGenerator.fetch_tables_for_schema + fetch_views_for_schema = SnowflakeSchemaGenerator.fetch_views_for_schema + get_tables_for_schema = SnowflakeSchemaGenerator.get_tables_for_schema + get_views_for_schema = SnowflakeSchemaGenerator.get_views_for_schema def get_report(self) -> SnowflakeSummaryReport: return self.report diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py index adcc4ba09d8c9e..02942556093f9d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py @@ -7,6 +7,7 @@ from datahub.configuration.common import MetaError from datahub.configuration.pattern_utils import is_schema_allowed +from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance from datahub.ingestion.source.snowflake.constants import ( GENERIC_PERMISSION_ERROR_KEY, SNOWFLAKE_REGION_CLOUD_REGION_MAPPING, @@ -48,6 +49,8 @@ def query(self: SnowflakeQueryProtocol, query: str) -> Any: class SnowflakeCommonProtocol(SnowflakeLoggingProtocol, Protocol): + platform: str = "snowflake" + config: SnowflakeV2Config report: SnowflakeV2Report @@ -178,6 +181,14 @@ def snowflake_identifier(self: SnowflakeCommonProtocol, identifier: str) -> str: return identifier.lower() return identifier + def gen_dataset_urn(self: SnowflakeCommonProtocol, dataset_identifier: str) -> str: + return make_dataset_urn_with_platform_instance( + platform=self.platform, + name=dataset_identifier, + platform_instance=self.config.platform_instance, + env=self.config.env, + ) + @staticmethod def get_quoted_identifier_for_database(db_name): return f'"{db_name}"' diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index f155ac24fea3fc..06d7042e02456c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -5,18 +5,10 @@ import os.path import platform from dataclasses import dataclass -from typing import Callable, Dict, Iterable, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Union from snowflake.connector import SnowflakeConnection -from datahub.configuration.pattern_utils import is_schema_allowed -from datahub.emitter.mce_builder import ( - make_data_platform_urn, - make_dataset_urn_with_platform_instance, - make_schema_field_urn, - make_tag_urn, -) -from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.decorators import ( SupportStatus, @@ -36,65 
+28,35 @@ TestConnectionReport, ) from datahub.ingestion.api.workunit import MetadataWorkUnit -from datahub.ingestion.glossary.classification_mixin import ( - ClassificationHandler, - classification_workunit_processor, -) -from datahub.ingestion.source.common.subtypes import ( - DatasetContainerSubTypes, - DatasetSubTypes, -) from datahub.ingestion.source.snowflake.constants import ( GENERIC_PERMISSION_ERROR_KEY, - SNOWFLAKE_DATABASE, SnowflakeEdition, - SnowflakeObjectDomain, ) from datahub.ingestion.source.snowflake.snowflake_assertion import ( SnowflakeAssertionsHandler, ) -from datahub.ingestion.source.snowflake.snowflake_config import ( - SnowflakeV2Config, - TagOption, -) -from datahub.ingestion.source.snowflake.snowflake_data_reader import SnowflakeDataReader +from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config from datahub.ingestion.source.snowflake.snowflake_lineage_v2 import ( SnowflakeLineageExtractor, ) from datahub.ingestion.source.snowflake.snowflake_profiler import SnowflakeProfiler from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report from datahub.ingestion.source.snowflake.snowflake_schema import ( - SnowflakeColumn, - SnowflakeDatabase, SnowflakeDataDictionary, - SnowflakeFK, - SnowflakePK, SnowflakeQuery, - SnowflakeSchema, - SnowflakeTable, - SnowflakeTag, - SnowflakeView, +) +from datahub.ingestion.source.snowflake.snowflake_schema_gen import ( + SnowflakeSchemaGenerator, ) from datahub.ingestion.source.snowflake.snowflake_shares import SnowflakeSharesHandler -from datahub.ingestion.source.snowflake.snowflake_tag import SnowflakeTagExtractor from datahub.ingestion.source.snowflake.snowflake_usage_v2 import ( SnowflakeUsageExtractor, ) from datahub.ingestion.source.snowflake.snowflake_utils import ( SnowflakeCommonMixin, SnowflakeConnectionMixin, - SnowflakePermissionError, SnowflakeQueryMixin, ) -from datahub.ingestion.source.sql.sql_utils import ( - add_table_to_schema_container, - gen_database_container, - gen_database_key, - gen_schema_container, - gen_schema_key, - get_dataplatform_instance_aspect, - get_domain_wu, -) from datahub.ingestion.source.state.profiling_state_handler import ProfilingHandler from datahub.ingestion.source.state.redundant_run_skip_handler import ( RedundantLineageRunSkipHandler, @@ -110,79 +72,12 @@ from datahub.ingestion.source_report.ingestion_stage import ( LINEAGE_EXTRACTION, METADATA_EXTRACTION, - PROFILING, -) -from datahub.metadata.com.linkedin.pegasus2avro.common import ( - GlobalTags, - Status, - SubTypes, - TagAssociation, - TimeStamp, ) -from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( - DatasetProperties, - ViewProperties, -) -from datahub.metadata.com.linkedin.pegasus2avro.schema import ( - ArrayType, - BooleanType, - BytesType, - DateType, - ForeignKeyConstraint, - MySqlDDL, - NullType, - NumberType, - RecordType, - SchemaField, - SchemaFieldDataType, - SchemaMetadata, - StringType, - TimeType, -) -from datahub.metadata.com.linkedin.pegasus2avro.tag import TagProperties from datahub.sql_parsing.sql_parsing_aggregator import SqlParsingAggregator from datahub.utilities.registries.domain_registry import DomainRegistry logger: logging.Logger = logging.getLogger(__name__) -# https://docs.snowflake.com/en/sql-reference/intro-summary-data-types.html -SNOWFLAKE_FIELD_TYPE_MAPPINGS = { - "DATE": DateType, - "BIGINT": NumberType, - "BINARY": BytesType, - # 'BIT': BIT, - "BOOLEAN": BooleanType, - "CHAR": NullType, - "CHARACTER": NullType, - "DATETIME": TimeType, - 
"DEC": NumberType, - "DECIMAL": NumberType, - "DOUBLE": NumberType, - "FIXED": NumberType, - "FLOAT": NumberType, - "INT": NumberType, - "INTEGER": NumberType, - "NUMBER": NumberType, - # 'OBJECT': ? - "REAL": NumberType, - "BYTEINT": NumberType, - "SMALLINT": NumberType, - "STRING": StringType, - "TEXT": StringType, - "TIME": TimeType, - "TIMESTAMP": TimeType, - "TIMESTAMP_TZ": TimeType, - "TIMESTAMP_LTZ": TimeType, - "TIMESTAMP_NTZ": TimeType, - "TINYINT": NumberType, - "VARBINARY": BytesType, - "VARCHAR": StringType, - "VARIANT": RecordType, - "OBJECT": NullType, - "ARRAY": ArrayType, - "GEOGRAPHY": NullType, -} - @platform_name("Snowflake", doc_order=1) @config_class(SnowflakeV2Config) @@ -235,7 +130,6 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): self.config: SnowflakeV2Config = config self.report: SnowflakeV2Report = SnowflakeV2Report() self.logger = logger - self.snowsight_base_url: Optional[str] = None self.connection: Optional[SnowflakeConnection] = None self.domain_registry: Optional[DomainRegistry] = None @@ -309,10 +203,6 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): redundant_run_skip_handler=redundant_usage_run_skip_handler, ) - self.tag_extractor = SnowflakeTagExtractor( - config, self.data_dictionary, self.report - ) - self.profiling_state_handler: Optional[ProfilingHandler] = None if self.config.enable_stateful_profiling: self.profiling_state_handler = ProfilingHandler( @@ -322,16 +212,13 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): run_id=self.ctx.run_id, ) + # For profiling + self.profiler: Optional[SnowflakeProfiler] = None if config.is_profiling_enabled(): - # For profiling self.profiler = SnowflakeProfiler( config, self.report, self.profiling_state_handler ) - self.classification_handler = ClassificationHandler(self.config, self.report) - - # Caches tables for a single database. Consider moving to disk or S3 when possible. 
- self.db_tables: Dict[str, List[SnowflakeTable]] = {} self.add_config_to_report() @classmethod @@ -543,41 +430,31 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: self.inspect_session_metadata() + snowsight_base_url = None if self.config.include_external_url: - self.snowsight_base_url = self.get_snowsight_base_url() + snowsight_base_url = self.get_snowsight_base_url() if self.report.default_warehouse is None: self.report_warehouse_failure() return - self.data_dictionary.set_connection(self.connection) - databases: List[SnowflakeDatabase] = [] - - for database in self.get_databases() or []: - self.report.report_entity_scanned(database.name, "database") - if not self.config.database_pattern.allowed(database.name): - self.report.report_dropped(f"{database.name}.*") - else: - databases.append(database) - - if len(databases) == 0: - return + schema_extractor = SnowflakeSchemaGenerator( + config=self.config, + report=self.report, + connection=self.connection, + domain_registry=self.domain_registry, + profiler=self.profiler, + aggregator=self.aggregator, + snowsight_base_url=snowsight_base_url, + ) - for snowflake_db in databases: - try: - self.report.set_ingestion_stage(snowflake_db.name, METADATA_EXTRACTION) - yield from self._process_database(snowflake_db) + self.report.set_ingestion_stage("*", METADATA_EXTRACTION) + yield from schema_extractor.get_workunits_internal() - except SnowflakePermissionError as e: - # FIXME - This may break stateful ingestion if new tables than previous run are emitted above - # and stateful ingestion is enabled - self.report_error(GENERIC_PERMISSION_ERROR_KEY, str(e)) - return + databases = schema_extractor.databases self.connection.close() - self.report_cache_info() - # TODO: The checkpoint state for stale entity detection can be committed here. if self.config.shares: @@ -624,17 +501,6 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: self.config, self.report, self.gen_dataset_urn ).get_assertion_workunits(discovered_datasets) - def report_cache_info(self) -> None: - lru_cache_functions: List[Callable] = [ - self.data_dictionary.get_tables_for_database, - self.data_dictionary.get_views_for_database, - self.data_dictionary.get_columns_for_schema, - self.data_dictionary.get_pk_constraints_for_schema, - self.data_dictionary.get_fk_constraints_for_schema, - ] - for func in lru_cache_functions: - self.report.lru_cache_info[func.__name__] = func.cache_info()._asdict() # type: ignore - def report_warehouse_failure(self) -> None: if self.config.warehouse is not None: self.report_error( @@ -647,828 +513,9 @@ def report_warehouse_failure(self) -> None: "No default warehouse set for user. Either set default warehouse for user or configure warehouse in recipe.", ) - def get_databases(self) -> Optional[List[SnowflakeDatabase]]: - try: - # `show databases` is required only to get one of the databases - # whose information_schema can be queried to start with. - databases = self.data_dictionary.show_databases() - except Exception as e: - logger.debug(f"Failed to list databases due to error {e}", exc_info=e) - self.report_error( - "list-databases", - f"Failed to list databases due to error {e}", - ) - return None - else: - ischema_databases: List[ - SnowflakeDatabase - ] = self.get_databases_from_ischema(databases) - - if len(ischema_databases) == 0: - self.report_error( - GENERIC_PERMISSION_ERROR_KEY, - "No databases found. 
Please check permissions.", - ) - return ischema_databases - - def get_databases_from_ischema( - self, databases: List[SnowflakeDatabase] - ) -> List[SnowflakeDatabase]: - ischema_databases: List[SnowflakeDatabase] = [] - for database in databases: - try: - ischema_databases = self.data_dictionary.get_databases(database.name) - break - except Exception: - # query fails if "USAGE" access is not granted for database - # This is okay, because `show databases` query lists all databases irrespective of permission, - # if role has `MANAGE GRANTS` privilege. (not advisable) - logger.debug( - f"Failed to list databases {database.name} information_schema" - ) - # SNOWFLAKE database always shows up even if permissions are missing - if database == SNOWFLAKE_DATABASE: - continue - logger.info( - f"The role {self.report.role} has `MANAGE GRANTS` privilege. This is not advisable and also not required." - ) - - return ischema_databases - - def _process_database( - self, snowflake_db: SnowflakeDatabase - ) -> Iterable[MetadataWorkUnit]: - db_name = snowflake_db.name - - try: - self.query(SnowflakeQuery.use_database(db_name)) - except Exception as e: - if isinstance(e, SnowflakePermissionError): - # This may happen if REFERENCE_USAGE permissions are set - # We can not run show queries on database in such case. - # This need not be a failure case. - self.report_warning( - "Insufficient privileges to operate on database, skipping. Please grant USAGE permissions on database to extract its metadata.", - db_name, - ) - else: - logger.debug( - f"Failed to use database {db_name} due to error {e}", - exc_info=e, - ) - self.report_warning( - "Failed to get schemas for database", - db_name, - ) - return - - if self.config.extract_tags != TagOption.skip: - snowflake_db.tags = self.tag_extractor.get_tags_on_object( - domain="database", db_name=db_name - ) - - if self.config.include_technical_schema: - yield from self.gen_database_containers(snowflake_db) - - self.fetch_schemas_for_database(snowflake_db, db_name) - - if self.config.include_technical_schema and snowflake_db.tags: - for tag in snowflake_db.tags: - yield from self._process_tag(tag) - - self.db_tables = {} - for snowflake_schema in snowflake_db.schemas: - yield from self._process_schema(snowflake_schema, db_name) - - if self.config.is_profiling_enabled() and self.db_tables: - self.report.set_ingestion_stage(snowflake_db.name, PROFILING) - yield from self.profiler.get_workunits(snowflake_db, self.db_tables) - - def fetch_schemas_for_database( - self, snowflake_db: SnowflakeDatabase, db_name: str - ) -> None: - schemas: List[SnowflakeSchema] = [] - try: - for schema in self.data_dictionary.get_schemas_for_database(db_name): - self.report.report_entity_scanned(schema.name, "schema") - if not is_schema_allowed( - self.config.schema_pattern, - schema.name, - db_name, - self.config.match_fully_qualified_names, - ): - self.report.report_dropped(f"{db_name}.{schema.name}.*") - else: - schemas.append(schema) - except Exception as e: - if isinstance(e, SnowflakePermissionError): - error_msg = f"Failed to get schemas for database {db_name}. Please check permissions." - # Ideal implementation would use PEP 678 – Enriching Exceptions with Notes - raise SnowflakePermissionError(error_msg) from e.__cause__ - else: - logger.debug( - f"Failed to get schemas for database {db_name} due to error {e}", - exc_info=e, - ) - self.report_warning( - "Failed to get schemas for database", - db_name, - ) - - if not schemas: - self.report_warning( - "No schemas found in database. 
If schemas exist, please grant USAGE permissions on them.", - db_name, - ) - else: - snowflake_db.schemas = schemas - - def _process_schema( - self, snowflake_schema: SnowflakeSchema, db_name: str - ) -> Iterable[MetadataWorkUnit]: - schema_name = snowflake_schema.name - - if self.config.extract_tags != TagOption.skip: - snowflake_schema.tags = self.tag_extractor.get_tags_on_object( - schema_name=schema_name, db_name=db_name, domain="schema" - ) - - if self.config.include_technical_schema: - yield from self.gen_schema_containers(snowflake_schema, db_name) - - if self.config.include_tables: - tables = self.fetch_tables_for_schema( - snowflake_schema, db_name, schema_name - ) - self.db_tables[schema_name] = tables - - if self.config.include_technical_schema: - data_reader = self.make_data_reader() - for table in tables: - table_wu_generator = self._process_table( - table, schema_name, db_name - ) - yield from classification_workunit_processor( - table_wu_generator, - self.classification_handler, - data_reader, - [db_name, schema_name, table.name], - ) - - if self.config.include_views: - views = self.fetch_views_for_schema(snowflake_schema, db_name, schema_name) - if ( - self.aggregator - and self.config.include_view_lineage - and self.config.parse_view_ddl - ): - for view in views: - view_identifier = self.get_dataset_identifier( - view.name, schema_name, db_name - ) - if view.view_definition: - self.aggregator.add_view_definition( - view_urn=self.gen_dataset_urn(view_identifier), - view_definition=view.view_definition, - default_db=db_name, - default_schema=schema_name, - ) - - if self.config.include_technical_schema: - for view in views: - yield from self._process_view(view, schema_name, db_name) - - if self.config.include_technical_schema and snowflake_schema.tags: - for tag in snowflake_schema.tags: - yield from self._process_tag(tag) - - if not snowflake_schema.views and not snowflake_schema.tables: - self.report_warning( - "No tables/views found in schema. If tables exist, please grant REFERENCES or SELECT permissions on them.", - f"{db_name}.{schema_name}", - ) - - def fetch_views_for_schema( - self, snowflake_schema: SnowflakeSchema, db_name: str, schema_name: str - ) -> List[SnowflakeView]: - try: - views: List[SnowflakeView] = [] - for view in self.get_views_for_schema(schema_name, db_name): - view_name = self.get_dataset_identifier(view.name, schema_name, db_name) - - self.report.report_entity_scanned(view_name, "view") - - if not self.config.view_pattern.allowed(view_name): - self.report.report_dropped(view_name) - else: - views.append(view) - snowflake_schema.views = [view.name for view in views] - return views - except Exception as e: - if isinstance(e, SnowflakePermissionError): - # Ideal implementation would use PEP 678 – Enriching Exceptions with Notes - error_msg = f"Failed to get views for schema {db_name}.{schema_name}. Please check permissions." 
- - raise SnowflakePermissionError(error_msg) from e.__cause__ - else: - logger.debug( - f"Failed to get views for schema {db_name}.{schema_name} due to error {e}", - exc_info=e, - ) - self.report_warning( - "Failed to get views for schema", - f"{db_name}.{schema_name}", - ) - return [] - - def fetch_tables_for_schema( - self, snowflake_schema: SnowflakeSchema, db_name: str, schema_name: str - ) -> List[SnowflakeTable]: - try: - tables: List[SnowflakeTable] = [] - for table in self.get_tables_for_schema(schema_name, db_name): - table_identifier = self.get_dataset_identifier( - table.name, schema_name, db_name - ) - self.report.report_entity_scanned(table_identifier) - if not self.config.table_pattern.allowed(table_identifier): - self.report.report_dropped(table_identifier) - else: - tables.append(table) - snowflake_schema.tables = [table.name for table in tables] - return tables - except Exception as e: - if isinstance(e, SnowflakePermissionError): - # Ideal implementation would use PEP 678 – Enriching Exceptions with Notes - error_msg = f"Failed to get tables for schema {db_name}.{schema_name}. Please check permissions." - raise SnowflakePermissionError(error_msg) from e.__cause__ - else: - logger.debug( - f"Failed to get tables for schema {db_name}.{schema_name} due to error {e}", - exc_info=e, - ) - self.report_warning( - "Failed to get tables for schema", - f"{db_name}.{schema_name}", - ) - return [] - - def make_data_reader(self) -> Optional[SnowflakeDataReader]: - if self.classification_handler.is_classification_enabled() and self.connection: - return SnowflakeDataReader.create( - self.connection, self.snowflake_identifier - ) - - return None - - def _process_table( - self, - table: SnowflakeTable, - schema_name: str, - db_name: str, - ) -> Iterable[MetadataWorkUnit]: - table_identifier = self.get_dataset_identifier(table.name, schema_name, db_name) - - self.fetch_columns_for_table(table, schema_name, db_name, table_identifier) - - self.fetch_pk_for_table(table, schema_name, db_name, table_identifier) - - self.fetch_foreign_keys_for_table(table, schema_name, db_name, table_identifier) - - if self.config.extract_tags != TagOption.skip: - table.tags = self.tag_extractor.get_tags_on_object( - table_name=table.name, - schema_name=schema_name, - db_name=db_name, - domain="table", - ) - - if self.config.include_technical_schema: - if table.tags: - for tag in table.tags: - yield from self._process_tag(tag) - for column_name in table.column_tags: - for tag in table.column_tags[column_name]: - yield from self._process_tag(tag) - - yield from self.gen_dataset_workunits(table, schema_name, db_name) - - def fetch_foreign_keys_for_table( - self, - table: SnowflakeTable, - schema_name: str, - db_name: str, - table_identifier: str, - ) -> None: - try: - table.foreign_keys = self.get_fk_constraints_for_table( - table.name, schema_name, db_name - ) - except Exception as e: - logger.debug( - f"Failed to get foreign key for table {table_identifier} due to error {e}", - exc_info=e, - ) - self.report_warning("Failed to get foreign key for table", table_identifier) - - def fetch_pk_for_table( - self, - table: SnowflakeTable, - schema_name: str, - db_name: str, - table_identifier: str, - ) -> None: - try: - table.pk = self.get_pk_constraints_for_table( - table.name, schema_name, db_name - ) - except Exception as e: - logger.debug( - f"Failed to get primary key for table {table_identifier} due to error {e}", - exc_info=e, - ) - self.report_warning("Failed to get primary key for table", table_identifier) - - 
def fetch_columns_for_table( - self, - table: SnowflakeTable, - schema_name: str, - db_name: str, - table_identifier: str, - ) -> None: - try: - table.columns = self.get_columns_for_table(table.name, schema_name, db_name) - table.column_count = len(table.columns) - if self.config.extract_tags != TagOption.skip: - table.column_tags = self.tag_extractor.get_column_tags_for_table( - table.name, schema_name, db_name - ) - except Exception as e: - logger.debug( - f"Failed to get columns for table {table_identifier} due to error {e}", - exc_info=e, - ) - self.report_warning("Failed to get columns for table", table_identifier) - - def _process_view( - self, - view: SnowflakeView, - schema_name: str, - db_name: str, - ) -> Iterable[MetadataWorkUnit]: - view_name = self.get_dataset_identifier(view.name, schema_name, db_name) - - try: - view.columns = self.get_columns_for_table(view.name, schema_name, db_name) - if self.config.extract_tags != TagOption.skip: - view.column_tags = self.tag_extractor.get_column_tags_for_table( - view.name, schema_name, db_name - ) - except Exception as e: - logger.debug( - f"Failed to get columns for view {view_name} due to error {e}", - exc_info=e, - ) - self.report_warning("Failed to get columns for view", view_name) - - if self.config.extract_tags != TagOption.skip: - view.tags = self.tag_extractor.get_tags_on_object( - table_name=view.name, - schema_name=schema_name, - db_name=db_name, - domain="table", - ) - - if self.config.include_technical_schema: - if view.tags: - for tag in view.tags: - yield from self._process_tag(tag) - for column_name in view.column_tags: - for tag in view.column_tags[column_name]: - yield from self._process_tag(tag) - - yield from self.gen_dataset_workunits(view, schema_name, db_name) - - def _process_tag(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]: - tag_identifier = tag.identifier() - - if self.report.is_tag_processed(tag_identifier): - return - - self.report.report_tag_processed(tag_identifier) - - yield from self.gen_tag_workunits(tag) - - def gen_dataset_urn(self, dataset_identifier: str) -> str: - return make_dataset_urn_with_platform_instance( - platform=self.platform, - name=dataset_identifier, - platform_instance=self.config.platform_instance, - env=self.config.env, - ) - - def gen_dataset_workunits( - self, - table: Union[SnowflakeTable, SnowflakeView], - schema_name: str, - db_name: str, - ) -> Iterable[MetadataWorkUnit]: - dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name) - dataset_urn = self.gen_dataset_urn(dataset_name) - - status = Status(removed=False) - yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=status - ).as_workunit() - - schema_metadata = self.gen_schema_metadata(table, schema_name, db_name) - - yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=schema_metadata - ).as_workunit() - - dataset_properties = self.get_dataset_properties(table, schema_name, db_name) - - yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=dataset_properties - ).as_workunit() - - schema_container_key = gen_schema_key( - db_name=self.snowflake_identifier(db_name), - schema=self.snowflake_identifier(schema_name), - platform=self.platform, - platform_instance=self.config.platform_instance, - env=self.config.env, - ) - - yield from add_table_to_schema_container( - dataset_urn=dataset_urn, - parent_container_key=schema_container_key, - ) - dpi_aspect = get_dataplatform_instance_aspect( - dataset_urn=dataset_urn, - platform=self.platform, - 
platform_instance=self.config.platform_instance, - ) - if dpi_aspect: - yield dpi_aspect - - subTypes = SubTypes( - typeNames=( - [DatasetSubTypes.VIEW] - if isinstance(table, SnowflakeView) - else [DatasetSubTypes.TABLE] - ) - ) - - yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=subTypes - ).as_workunit() - - if self.domain_registry: - yield from get_domain_wu( - dataset_name=dataset_name, - entity_urn=dataset_urn, - domain_config=self.config.domain, - domain_registry=self.domain_registry, - ) - - if table.tags: - tag_associations = [ - TagAssociation( - tag=make_tag_urn(self.snowflake_identifier(tag.identifier())) - ) - for tag in table.tags - ] - global_tags = GlobalTags(tag_associations) - yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=global_tags - ).as_workunit() - - if isinstance(table, SnowflakeView) and table.view_definition is not None: - view_properties_aspect = ViewProperties( - materialized=table.materialized, - viewLanguage="SQL", - viewLogic=table.view_definition, - ) - - yield MetadataChangeProposalWrapper( - entityUrn=dataset_urn, aspect=view_properties_aspect - ).as_workunit() - - def get_dataset_properties( - self, - table: Union[SnowflakeTable, SnowflakeView], - schema_name: str, - db_name: str, - ) -> DatasetProperties: - return DatasetProperties( - name=table.name, - created=( - TimeStamp(time=int(table.created.timestamp() * 1000)) - if table.created is not None - else None - ), - lastModified=( - TimeStamp(time=int(table.last_altered.timestamp() * 1000)) - if table.last_altered is not None - else ( - TimeStamp(time=int(table.created.timestamp() * 1000)) - if table.created is not None - else None - ) - ), - description=table.comment, - qualifiedName=f"{db_name}.{schema_name}.{table.name}", - customProperties={}, - externalUrl=( - self.get_external_url_for_table( - table.name, - schema_name, - db_name, - ( - SnowflakeObjectDomain.TABLE - if isinstance(table, SnowflakeTable) - else SnowflakeObjectDomain.VIEW - ), - ) - if self.config.include_external_url - else None - ), - ) - - def gen_tag_workunits(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]: - tag_urn = make_tag_urn(self.snowflake_identifier(tag.identifier())) - - tag_properties_aspect = TagProperties( - name=tag.display_name(), - description=f"Represents the Snowflake tag `{tag._id_prefix_as_str()}` with value `{tag.value}`.", - ) - - yield MetadataChangeProposalWrapper( - entityUrn=tag_urn, aspect=tag_properties_aspect - ).as_workunit() - - def gen_schema_metadata( - self, - table: Union[SnowflakeTable, SnowflakeView], - schema_name: str, - db_name: str, - ) -> SchemaMetadata: - dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name) - dataset_urn = self.gen_dataset_urn(dataset_name) - - foreign_keys: Optional[List[ForeignKeyConstraint]] = None - if isinstance(table, SnowflakeTable) and len(table.foreign_keys) > 0: - foreign_keys = self.build_foreign_keys(table, dataset_urn) - - schema_metadata = SchemaMetadata( - schemaName=dataset_name, - platform=make_data_platform_urn(self.platform), - version=0, - hash="", - platformSchema=MySqlDDL(tableSchema=""), - fields=[ - SchemaField( - fieldPath=self.snowflake_identifier(col.name), - type=SchemaFieldDataType( - SNOWFLAKE_FIELD_TYPE_MAPPINGS.get(col.data_type, NullType)() - ), - # NOTE: nativeDataType will not be in sync with older connector - nativeDataType=col.get_precise_native_type(), - description=col.comment, - nullable=col.is_nullable, - isPartOfKey=( - col.name in table.pk.column_names - 
if isinstance(table, SnowflakeTable) and table.pk is not None - else None - ), - globalTags=( - GlobalTags( - [ - TagAssociation( - make_tag_urn( - self.snowflake_identifier(tag.identifier()) - ) - ) - for tag in table.column_tags[col.name] - ] - ) - if col.name in table.column_tags - else None - ), - ) - for col in table.columns - ], - foreignKeys=foreign_keys, - ) - - if self.aggregator: - self.aggregator.register_schema(urn=dataset_urn, schema=schema_metadata) - - return schema_metadata - - def build_foreign_keys( - self, table: SnowflakeTable, dataset_urn: str - ) -> List[ForeignKeyConstraint]: - foreign_keys = [] - for fk in table.foreign_keys: - foreign_dataset = make_dataset_urn_with_platform_instance( - platform=self.platform, - name=self.get_dataset_identifier( - fk.referred_table, fk.referred_schema, fk.referred_database - ), - env=self.config.env, - platform_instance=self.config.platform_instance, - ) - foreign_keys.append( - ForeignKeyConstraint( - name=fk.name, - foreignDataset=foreign_dataset, - foreignFields=[ - make_schema_field_urn( - foreign_dataset, - self.snowflake_identifier(col), - ) - for col in fk.referred_column_names - ], - sourceFields=[ - make_schema_field_urn( - dataset_urn, - self.snowflake_identifier(col), - ) - for col in fk.column_names - ], - ) - ) - return foreign_keys - def get_report(self) -> SourceReport: return self.report - def gen_database_containers( - self, database: SnowflakeDatabase - ) -> Iterable[MetadataWorkUnit]: - database_container_key = gen_database_key( - self.snowflake_identifier(database.name), - platform=self.platform, - platform_instance=self.config.platform_instance, - env=self.config.env, - ) - - yield from gen_database_container( - name=database.name, - database=self.snowflake_identifier(database.name), - database_container_key=database_container_key, - sub_types=[DatasetContainerSubTypes.DATABASE], - domain_registry=self.domain_registry, - domain_config=self.config.domain, - external_url=( - self.get_external_url_for_database(database.name) - if self.config.include_external_url - else None - ), - description=database.comment, - created=( - int(database.created.timestamp() * 1000) - if database.created is not None - else None - ), - last_modified=( - int(database.last_altered.timestamp() * 1000) - if database.last_altered is not None - else ( - int(database.created.timestamp() * 1000) - if database.created is not None - else None - ) - ), - tags=( - [self.snowflake_identifier(tag.identifier()) for tag in database.tags] - if database.tags - else None - ), - ) - - def gen_schema_containers( - self, schema: SnowflakeSchema, db_name: str - ) -> Iterable[MetadataWorkUnit]: - schema_name = self.snowflake_identifier(schema.name) - database_container_key = gen_database_key( - database=self.snowflake_identifier(db_name), - platform=self.platform, - platform_instance=self.config.platform_instance, - env=self.config.env, - ) - - schema_container_key = gen_schema_key( - db_name=self.snowflake_identifier(db_name), - schema=schema_name, - platform=self.platform, - platform_instance=self.config.platform_instance, - env=self.config.env, - ) - - yield from gen_schema_container( - name=schema.name, - schema=self.snowflake_identifier(schema.name), - database=self.snowflake_identifier(db_name), - database_container_key=database_container_key, - domain_config=self.config.domain, - schema_container_key=schema_container_key, - sub_types=[DatasetContainerSubTypes.SCHEMA], - domain_registry=self.domain_registry, - description=schema.comment, - 
external_url=( - self.get_external_url_for_schema(schema.name, db_name) - if self.config.include_external_url - else None - ), - created=( - int(schema.created.timestamp() * 1000) - if schema.created is not None - else None - ), - last_modified=( - int(schema.last_altered.timestamp() * 1000) - if schema.last_altered is not None - else ( - int(schema.created.timestamp() * 1000) - if schema.created is not None - else None - ) - ), - tags=( - [self.snowflake_identifier(tag.identifier()) for tag in schema.tags] - if schema.tags - else None - ), - ) - - def get_tables_for_schema( - self, schema_name: str, db_name: str - ) -> List[SnowflakeTable]: - tables = self.data_dictionary.get_tables_for_database(db_name) - - # get all tables for database failed, - # falling back to get tables for schema - if tables is None: - self.report.num_get_tables_for_schema_queries += 1 - return self.data_dictionary.get_tables_for_schema(schema_name, db_name) - - # Some schema may not have any table - return tables.get(schema_name, []) - - def get_views_for_schema( - self, schema_name: str, db_name: str - ) -> List[SnowflakeView]: - views = self.data_dictionary.get_views_for_database(db_name) - - # get all views for database failed, - # falling back to get views for schema - if views is None: - self.report.num_get_views_for_schema_queries += 1 - return self.data_dictionary.get_views_for_schema(schema_name, db_name) - - # Some schema may not have any table - return views.get(schema_name, []) - - def get_columns_for_table( - self, table_name: str, schema_name: str, db_name: str - ) -> List[SnowflakeColumn]: - columns = self.data_dictionary.get_columns_for_schema(schema_name, db_name) - - # get all columns for schema failed, - # falling back to get columns for table - if columns is None: - self.report.num_get_columns_for_table_queries += 1 - return self.data_dictionary.get_columns_for_table( - table_name, schema_name, db_name - ) - - # Access to table but none of its columns - is this possible ? - return columns.get(table_name, []) - - def get_pk_constraints_for_table( - self, table_name: str, schema_name: str, db_name: str - ) -> Optional[SnowflakePK]: - constraints = self.data_dictionary.get_pk_constraints_for_schema( - schema_name, db_name - ) - - # Access to table but none of its constraints - is this possible ? - return constraints.get(table_name) - - def get_fk_constraints_for_table( - self, table_name: str, schema_name: str, db_name: str - ) -> List[SnowflakeFK]: - constraints = self.data_dictionary.get_fk_constraints_for_schema( - schema_name, db_name - ) - - # Access to table but none of its constraints - is this possible ? 
- return constraints.get(table_name, []) - def add_config_to_report(self) -> None: self.report.cleaned_account_id = self.config.get_account() self.report.ignore_start_time_lineage = self.config.ignore_start_time_lineage @@ -1517,26 +564,6 @@ def inspect_session_metadata(self) -> None: except Exception: self.report.edition = None - # domain is either "view" or "table" - def get_external_url_for_table( - self, table_name: str, schema_name: str, db_name: str, domain: str - ) -> Optional[str]: - if self.snowsight_base_url is not None: - return f"{self.snowsight_base_url}#/data/databases/{db_name}/schemas/{schema_name}/{domain}/{table_name}/" - return None - - def get_external_url_for_schema( - self, schema_name: str, db_name: str - ) -> Optional[str]: - if self.snowsight_base_url is not None: - return f"{self.snowsight_base_url}#/data/databases/{db_name}/schemas/{schema_name}/" - return None - - def get_external_url_for_database(self, db_name: str) -> Optional[str]: - if self.snowsight_base_url is not None: - return f"{self.snowsight_base_url}#/data/databases/{db_name}/" - return None - def get_snowsight_base_url(self) -> Optional[str]: try: # See https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#finding-the-region-and-locator-for-an-account diff --git a/metadata-ingestion/src/datahub/utilities/serialized_lru_cache.py b/metadata-ingestion/src/datahub/utilities/serialized_lru_cache.py new file mode 100644 index 00000000000000..23523501ee0b49 --- /dev/null +++ b/metadata-ingestion/src/datahub/utilities/serialized_lru_cache.py @@ -0,0 +1,98 @@ +import functools +import threading +from typing import Callable, Dict, Hashable, Tuple, TypeVar + +import cachetools +import cachetools.keys +from typing_extensions import ParamSpec + +_Key = Tuple[Hashable, ...] +_F = ParamSpec("_F") +_T = TypeVar("_T") + + +def serialized_lru_cache( + maxsize: int, +) -> Callable[[Callable[_F, _T]], Callable[_F, _T]]: + """Similar to `lru_cache`, but ensures multiple calls with the same parameters are serialized. + + Calls with different parameters are allowed to proceed in parallel. + + Args: + maxsize (int): Maximum number of entries to keep in the cache. + + Returns: + Callable[[Callable[F, T]], Callable[F, T]]: Decorator for the function to be wrapped. + """ + + UNSET = object() + + def decorator(func: Callable[_F, _T]) -> Callable[_F, _T]: + hits = 0 + misses = 0 + + cache_lock = threading.Lock() + cache: "cachetools.LRUCache[_Key, _T]" = cachetools.LRUCache(maxsize=maxsize) + + key_locks_lock = threading.Lock() + key_locks: Dict[_Key, threading.Lock] = {} + key_waiters: Dict[_Key, int] = {} + + def wrapper(*args: _F.args, **kwargs: _F.kwargs) -> _T: + # We need a type ignore here because there's no way for us to require that + # the args and kwargs are hashable while using ParamSpec. + key: _Key = cachetools.keys.hashkey(*args, **kwargs) # type: ignore + + with cache_lock: + if key in cache: + nonlocal hits + hits += 1 + return cache[key] + + with key_locks_lock: + if key not in key_locks: + key_locks[key] = threading.Lock() + key_waiters[key] = 0 + lock = key_locks[key] + key_waiters[key] += 1 + + try: + with lock: + # Check the cache again, in case the cache was updated by another thread. 
+ result = UNSET + with cache_lock: + if key in cache: + hits += 1 + return cache[key] + + nonlocal misses + misses += 1 + result = func(*args, **kwargs) + + with cache_lock: + cache[key] = result + return result + + finally: + with key_locks_lock: + key_waiters[key] -= 1 + if key_waiters[key] == 0: + del key_locks[key] + del key_waiters[key] + + def cache_info() -> functools._CacheInfo: + return functools._CacheInfo( + hits=hits, + misses=misses, + maxsize=maxsize, + currsize=len(cache), + ) + + # Add some extra attributes to the wrapper function. This makes it mostly compatible + # with functools.lru_cache. + wrapper.cache = cache # type: ignore + wrapper.cache_info = cache_info # type: ignore + + return functools.update_wrapper(wrapper, func) + + return decorator diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake.py index e9f6190c464f94..ca694b02cff010 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake.py @@ -177,11 +177,13 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph): ], ) report = cast(SnowflakeV2Report, pipeline.source.get_report()) - assert report.lru_cache_info["get_tables_for_database"]["misses"] == 1 - assert report.lru_cache_info["get_views_for_database"]["misses"] == 1 - assert report.lru_cache_info["get_columns_for_schema"]["misses"] == 1 - assert report.lru_cache_info["get_pk_constraints_for_schema"]["misses"] == 1 - assert report.lru_cache_info["get_fk_constraints_for_schema"]["misses"] == 1 + assert report.data_dictionary_cache is not None + cache_info = report.data_dictionary_cache.as_obj() + assert cache_info["get_tables_for_database"]["misses"] == 1 + assert cache_info["get_views_for_database"]["misses"] == 1 + assert cache_info["get_columns_for_schema"]["misses"] == 1 + assert cache_info["get_pk_constraints_for_schema"]["misses"] == 1 + assert cache_info["get_fk_constraints_for_schema"]["misses"] == 1 @freeze_time(FROZEN_TIME) diff --git a/metadata-ingestion/tests/unit/test_serialized_lru_cache.py b/metadata-ingestion/tests/unit/test_serialized_lru_cache.py new file mode 100644 index 00000000000000..2b937e700b4371 --- /dev/null +++ b/metadata-ingestion/tests/unit/test_serialized_lru_cache.py @@ -0,0 +1,92 @@ +import threading +import time + +from datahub.utilities.perf_timer import PerfTimer +from datahub.utilities.serialized_lru_cache import serialized_lru_cache + + +def test_cache_hit() -> None: + @serialized_lru_cache(maxsize=2) + def fetch_data(x): + return x * 2 + + assert fetch_data(1) == 2 # Cache miss + assert fetch_data(1) == 2 # Cache hit + assert fetch_data.cache_info().hits == 1 # type: ignore + assert fetch_data.cache_info().misses == 1 # type: ignore + + +def test_cache_eviction() -> None: + @serialized_lru_cache(maxsize=2) + def compute(x): + return x * 2 + + compute(1) + compute(2) + compute(3) # Should evict the first entry (1) + assert compute.cache_info().currsize == 2 # type: ignore + assert compute.cache_info().misses == 3 # type: ignore + assert compute(1) == 2 # Cache miss, since it was evicted + assert compute.cache_info().misses == 4 # type: ignore + + +def test_thread_safety() -> None: + @serialized_lru_cache(maxsize=5) + def compute(x): + time.sleep(0.2) # Simulate some delay + return x * 2 + + threads = [] + results = [None] * 10 + + def thread_func(index, arg): + results[index] = compute(arg) + + with PerfTimer() as 
timer: + for i in range(10): + thread = threading.Thread(target=thread_func, args=(i, i % 5)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert len(set(results)) == 5 # Only 5 unique results should be there + assert compute.cache_info().currsize <= 5 # type: ignore + # Only 5 unique calls should miss the cache + assert compute.cache_info().misses == 5 # type: ignore + + # Should take less than 1 second. If not, it means all calls were run serially. + assert timer.elapsed_seconds() < 1 + + +def test_concurrent_access_to_same_key() -> None: + @serialized_lru_cache(maxsize=3) + def compute(x: int) -> int: + time.sleep(0.2) # Simulate some delay + return x * 2 + + threads = [] + results = [] + + def thread_func(): + results.append(compute(1)) + + with PerfTimer() as timer: + for _ in range(10): + thread = threading.Thread(target=thread_func) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert all(result == 2 for result in results) # All should compute the same result + + # 9 hits, as the first one is a miss + assert compute.cache_info().hits == 9 # type: ignore + # Only the first call is a miss + assert compute.cache_info().misses == 1 # type: ignore + + # Should take less than 1 second. If not, it means all calls were run serially. + assert timer.elapsed_seconds() < 1
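
Editor's illustrative sketch (not part of the patch): the unit tests above exercise `serialized_lru_cache` directly, so the snippet below shows how the decorator might be applied in a data-dictionary-style client so that concurrent per-schema lookups share one bulk query per database. `ExampleDataDictionary` and its fabricated query result are assumptions for the example only; they are not the connector's real `SnowflakeDataDictionary` implementation, although the decorated method name mirrors the `get_tables_for_database` cache reported in `test_snowflake_basic`.

import threading
from typing import Dict, List

from datahub.utilities.serialized_lru_cache import serialized_lru_cache


class ExampleDataDictionary:
    """Hypothetical stand-in for a Snowflake data dictionary client."""

    @serialized_lru_cache(maxsize=1)
    def get_tables_for_database(self, db_name: str) -> Dict[str, List[str]]:
        # In the real connector this would run a single bulk information_schema
        # query; here we fabricate a result to keep the sketch self-contained.
        return {"PUBLIC": [f"{db_name}_TABLE_1", f"{db_name}_TABLE_2"]}


def main() -> None:
    dictionary = ExampleDataDictionary()

    def worker() -> None:
        # Every thread asks for the same database: the first call runs the
        # "query", the others block on the per-key lock and then hit the cache.
        dictionary.get_tables_for_database("DEMO_DB")

    threads = [threading.Thread(target=worker) for _ in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # cache_info() is the lru_cache-compatible counter added by the decorator.
    print(dictionary.get_tables_for_database.cache_info())  # expect misses=1, hits=7


if __name__ == "__main__":
    main()

Under these assumptions, the per-key locks serialize only duplicate calls while lookups for distinct databases proceed in parallel, which is the behavior the thread-safety tests above verify.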