feat(ingest/snowflake): refactor + parallel schema extraction (datahu…

hsheth2 authored and sleeperdeep committed Jun 25, 2024
1 parent d6395dc commit 82d17f9
Showing 11 changed files with 1,394 additions and 1,036 deletions.
2 changes: 2 additions & 0 deletions metadata-ingestion/setup.py
@@ -198,6 +198,7 @@
"pandas",
"cryptography",
"msal",
"cachetools",
} | classification_lib

trino = {
@@ -403,6 +404,7 @@
"sagemaker": aws_common,
"salesforce": {"simple-salesforce"},
"snowflake": snowflake_common | usage_common | sqlglot_lib,
"snowflake-summary": snowflake_common | usage_common | sqlglot_lib,
"sqlalchemy": sql_common,
"sql-queries": usage_common | sqlglot_lib,
"slack": slack,
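For context on the new `snowflake-summary` extra: plugin extras in this setup.py are plain Python sets of pip requirements composed with set union, and the new entry reuses the same dependency closure as `snowflake`. A minimal sketch of the pattern, using illustrative requirement sets rather than the real ones:

```python
# Illustrative requirement sets only; the real definitions in setup.py are larger.
snowflake_common = {"snowflake-connector-python", "pandas", "cryptography", "msal", "cachetools"}
usage_common = {"sqlparse"}
sqlglot_lib = {"acryl-sqlglot"}

plugins = {
    "snowflake": snowflake_common | usage_common | sqlglot_lib,
    # Added by this commit: same dependencies, separate plugin name.
    "snowflake-summary": snowflake_common | usage_common | sqlglot_lib,
}
```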
52 changes: 37 additions & 15 deletions metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py
@@ -3,6 +3,8 @@
import dataclasses
import functools
import logging
import os
import threading
import uuid
from enum import auto
from typing import Optional, Union
@@ -14,7 +16,7 @@
OperationalError,
)
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.emitter.rest_emitter import DataHubRestEmitter
from datahub.ingestion.api.common import RecordEnvelope, WorkUnit
from datahub.ingestion.api.sink import (
NoopWriteCallback,
@@ -34,6 +36,10 @@

logger = logging.getLogger(__name__)

DEFAULT_REST_SINK_MAX_THREADS = int(
os.getenv("DATAHUB_REST_SINK_DEFAULT_MAX_THREADS", 15)
)


class SyncOrAsync(ConfigEnum):
SYNC = auto()
@@ -44,7 +50,7 @@ class DatahubRestSinkConfig(DatahubClientConfig):
mode: SyncOrAsync = SyncOrAsync.ASYNC

# These only apply in async mode.
max_threads: int = 15
max_threads: int = DEFAULT_REST_SINK_MAX_THREADS
max_pending_requests: int = 2000


@@ -82,22 +88,12 @@ def _get_partition_key(record_envelope: RecordEnvelope) -> str:


class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
emitter: DatahubRestEmitter
_emitter_thread_local: threading.local
treat_errors_as_warnings: bool = False

def __post_init__(self) -> None:
self.emitter = DatahubRestEmitter(
self.config.server,
self.config.token,
connect_timeout_sec=self.config.timeout_sec, # reuse timeout_sec for connect timeout
read_timeout_sec=self.config.timeout_sec,
retry_status_codes=self.config.retry_status_codes,
retry_max_times=self.config.retry_max_times,
extra_headers=self.config.extra_headers,
ca_certificate_path=self.config.ca_certificate_path,
client_certificate_path=self.config.client_certificate_path,
disable_ssl_verification=self.config.disable_ssl_verification,
)
self._emitter_thread_local = threading.local()

try:
gms_config = self.emitter.get_server_config()
except Exception as exc:
@@ -120,6 +116,32 @@ def __post_init__(self) -> None:
max_pending=self.config.max_pending_requests,
)

@classmethod
def _make_emitter(cls, config: DatahubRestSinkConfig) -> DataHubRestEmitter:
return DataHubRestEmitter(
config.server,
config.token,
connect_timeout_sec=config.timeout_sec, # reuse timeout_sec for connect timeout
read_timeout_sec=config.timeout_sec,
retry_status_codes=config.retry_status_codes,
retry_max_times=config.retry_max_times,
extra_headers=config.extra_headers,
ca_certificate_path=config.ca_certificate_path,
client_certificate_path=config.client_certificate_path,
disable_ssl_verification=config.disable_ssl_verification,
)

@property
def emitter(self) -> DataHubRestEmitter:
# While this is a property, it actually uses one emitter per thread.
# Since emitter is one-to-one with request sessions, using a separate
# emitter per thread should improve correctness and performance.
# https://github.com/psf/requests/issues/1871#issuecomment-32751346
thread_local = self._emitter_thread_local
if not hasattr(thread_local, "emitter"):
thread_local.emitter = DatahubRestSink._make_emitter(self.config)
return thread_local.emitter

def handle_work_unit_start(self, workunit: WorkUnit) -> None:
if isinstance(workunit, MetadataWorkUnit):
self.treat_errors_as_warnings = workunit.treat_errors_as_warnings
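The refactor above replaces the single shared emitter with a lazily created emitter per thread, since the underlying requests session is not safely shareable across threads. A minimal standalone sketch of that thread-local pattern (class and variable names here are illustrative, not part of DataHub):

```python
import threading

import requests


class ThreadLocalSessionPool:
    """Illustrative helper: one requests.Session per worker thread.

    Sharing a single Session across threads is not guaranteed to be safe
    (see psf/requests#1871), so each thread lazily creates its own session
    and reuses it on subsequent calls, mirroring the emitter property above.
    """

    def __init__(self) -> None:
        self._local = threading.local()

    @property
    def session(self) -> requests.Session:
        if not hasattr(self._local, "session"):
            self._local.session = requests.Session()
        return self._local.session


pool = ThreadLocalSessionPool()
# Each thread that touches `pool.session` gets its own Session; repeated
# accesses from the same thread return the cached instance.
```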
metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, MutableSet, Optional
from typing import TYPE_CHECKING, Dict, List, MutableSet, Optional

from datahub.ingestion.api.report import Report
from datahub.ingestion.glossary.classification_mixin import ClassificationReportMixin
@@ -14,6 +14,11 @@
from datahub.sql_parsing.sql_parsing_aggregator import SqlAggregatorReport
from datahub.utilities.perf_timer import PerfTimer

if TYPE_CHECKING:
from datahub.ingestion.source.snowflake.snowflake_schema import (
SnowflakeDataDictionary,
)


@dataclass
class SnowflakeUsageAggregationReport(Report):
@@ -106,11 +111,7 @@ class SnowflakeV2Report(
num_tables_with_known_upstreams: int = 0
num_upstream_lineage_edge_parsing_failed: int = 0

# Reports how many times we reset in-memory `functools.lru_cache` caches of data,
# which occurs when we move on to a different database / schema.
# Should not be more than the number of databases / schemas scanned.
# Maps (function name) -> (stat_name) -> (stat_value)
lru_cache_info: Dict[str, Dict[str, int]] = field(default_factory=dict)
data_dictionary_cache: Optional["SnowflakeDataDictionary"] = None

# These will be non-zero if snowflake information_schema queries fail with error -
# "Information schema query returned too much data. Please repeat query with more selective predicates.""
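Rather than copying LRU cache statistics into the report during ingestion, the report now holds an optional reference to the live `SnowflakeDataDictionary`, and fresh cache stats are pulled through its `as_obj()` method (added in the next file) whenever the report is serialized. A rough sketch of that pattern, using illustrative class names rather than DataHub's actual report machinery:

```python
from dataclasses import dataclass
from typing import Dict, Optional


class DictionaryStub:
    """Illustrative stand-in for SnowflakeDataDictionary's reporting role."""

    def as_obj(self) -> Dict[str, Dict[str, int]]:
        # The real object derives this from the cache_info() counters of its
        # serialized_lru_cache-decorated methods.
        return {"get_tables_for_database": {"hits": 3, "misses": 1, "currsize": 1}}


@dataclass
class ReportSketch:
    data_dictionary_cache: Optional[DictionaryStub] = None

    def as_obj(self) -> Dict[str, object]:
        # Serializing the report pulls up-to-date stats from the object itself,
        # so nothing has to be copied back as extraction progresses.
        return {
            "data_dictionary_cache": (
                self.data_dictionary_cache.as_obj()
                if self.data_dictionary_cache is not None
                else None
            )
        }
```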
metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py
@@ -1,19 +1,23 @@
import logging
import os
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from functools import lru_cache
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional

from snowflake.connector import SnowflakeConnection

from datahub.ingestion.api.report import SupportsAsObj
from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeQueryMixin
from datahub.ingestion.source.sql.sql_generic import BaseColumn, BaseTable, BaseView
from datahub.utilities.serialized_lru_cache import serialized_lru_cache

logger: logging.Logger = logging.getLogger(__name__)

SCHEMA_PARALLELISM = int(os.getenv("DATAHUB_SNOWFLAKE_SCHEMA_PARALLELISM", 20))


@dataclass
class SnowflakePK:
@@ -176,7 +180,7 @@ def get_column_tags_for_table(
)


class SnowflakeDataDictionary(SnowflakeQueryMixin):
class SnowflakeDataDictionary(SnowflakeQueryMixin, SupportsAsObj):
def __init__(self) -> None:
self.logger = logger
self.connection: Optional[SnowflakeConnection] = None
@@ -189,6 +193,26 @@ def get_connection(self) -> SnowflakeConnection:
assert self.connection is not None
return self.connection

def as_obj(self) -> Dict[str, Dict[str, int]]:
# TODO: Move this into a proper report type that gets computed.

# Reports how many times we reset the in-memory LRU caches of data,
# which occurs when we move on to a different database / schema.
# Should not be more than the number of databases / schemas scanned.
# Maps (function name) -> (stat_name) -> (stat_value)
lru_cache_functions: List[Callable] = [
self.get_tables_for_database,
self.get_views_for_database,
self.get_columns_for_schema,
self.get_pk_constraints_for_schema,
self.get_fk_constraints_for_schema,
]

report = {}
for func in lru_cache_functions:
report[func.__name__] = func.cache_info()._asdict() # type: ignore
return report

def show_databases(self) -> List[SnowflakeDatabase]:
databases: List[SnowflakeDatabase] = []

@@ -241,7 +265,7 @@ def get_schemas_for_database(self, db_name: str) -> List[SnowflakeSchema]:
snowflake_schemas.append(snowflake_schema)
return snowflake_schemas

@lru_cache(maxsize=1)
@serialized_lru_cache(maxsize=1)
def get_tables_for_database(
self, db_name: str
) -> Optional[Dict[str, List[SnowflakeTable]]]:
@@ -299,7 +323,7 @@ def get_tables_for_schema(
)
return tables

@lru_cache(maxsize=1)
@serialized_lru_cache(maxsize=1)
def get_views_for_database(
self, db_name: str
) -> Optional[Dict[str, List[SnowflakeView]]]:
@@ -349,7 +373,7 @@ def get_views_for_schema(
)
return views

@lru_cache(maxsize=1)
@serialized_lru_cache(maxsize=SCHEMA_PARALLELISM)
def get_columns_for_schema(
self, schema_name: str, db_name: str
) -> Optional[Dict[str, List[SnowflakeColumn]]]:
@@ -405,7 +429,7 @@ def get_columns_for_table(
)
return columns

@lru_cache(maxsize=1)
@serialized_lru_cache(maxsize=SCHEMA_PARALLELISM)
def get_pk_constraints_for_schema(
self, schema_name: str, db_name: str
) -> Dict[str, SnowflakePK]:
@@ -422,7 +446,7 @@ def get_pk_constraints_for_schema(
constraints[row["table_name"]].column_names.append(row["column_name"])
return constraints

@lru_cache(maxsize=1)
@serialized_lru_cache(maxsize=SCHEMA_PARALLELISM)
def get_fk_constraints_for_schema(
self, schema_name: str, db_name: str
) -> Dict[str, List[SnowflakeFK]]:
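`serialized_lru_cache` (imported from `datahub.utilities.serialized_lru_cache`) replaces `functools.lru_cache` on the lookups above so they can be cached safely while schemas are extracted in parallel; the per-schema methods presumably use `maxsize=SCHEMA_PARALLELISM` so each concurrently processed schema keeps its result cached. Below is a hypothetical sketch of such a decorator, assuming `cachetools` (added to setup.py in this commit) for the LRU bookkeeping; the real implementation may differ in detail and also exposes `cache_info()`, which the `as_obj()` method above reads.

```python
import functools
import threading
from typing import Any, Callable, Dict, Hashable, TypeVar

from cachetools import LRUCache
from cachetools.keys import hashkey

F = TypeVar("F", bound=Callable[..., Any])


def serialized_lru_cache_sketch(maxsize: int) -> Callable[[F], F]:
    """Hypothetical thread-safe LRU cache decorator (not DataHub's actual code).

    Concurrent calls with the same arguments are serialized behind a per-key
    lock so the expensive Snowflake query runs only once, while calls with
    different arguments can proceed in parallel.
    """

    def decorator(func: F) -> F:
        cache: LRUCache = LRUCache(maxsize=maxsize)
        cache_lock = threading.Lock()  # guards `cache` and `key_locks`
        key_locks: Dict[Hashable, threading.Lock] = {}

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            key = hashkey(*args, **kwargs)
            with cache_lock:
                key_lock = key_locks.setdefault(key, threading.Lock())
            with key_lock:  # one computation per key at a time
                with cache_lock:
                    if key in cache:
                        return cache[key]
                result = func(*args, **kwargs)
                with cache_lock:
                    cache[key] = result
                return result

        return wrapper  # type: ignore[return-value]

    return decorator
```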