Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingest): support ingesting from multiple snowflake dbs #2793

Merged
merged 3 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion metadata-ingestion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,13 @@ source:
username: user
password: pass
host_port: account_name
database: db_name
database_pattern:
allow:
- ^regex$
- ^another_regex$
deny:
- ^SNOWFLAKE$
- ^SNOWFLAKE_SAMPLE_DATA$
warehouse: "COMPUTE_WH" # optional
role: "sysadmin" # optional
include_views: True # whether to include views, defaults to True
Expand Down
1 change: 0 additions & 1 deletion metadata-ingestion/src/datahub/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def datahub(debug: bool) -> None:
logging.getLogger("datahub").setLevel(logging.INFO)
# loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
# print(loggers)
# breakpoint()


@datahub.command()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,6 @@ def from_entry(cls, entry: AuditLogEntry) -> "QueryEvent":
referencedTables = [
BigQueryTableRef.from_spec_obj(spec) for spec in rawRefTables
]
# if job['jobConfiguration']['query']['statementType'] != "SCRIPT" and not referencedTables:
# breakpoint()

queryEvent = QueryEvent(
timestamp=entry.timestamp,
Expand Down
50 changes: 47 additions & 3 deletions metadata-ingestion/src/datahub/ingestion/source/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import logging
from typing import Optional
from typing import Iterable, Optional

import pydantic

# This import verifies that the dependencies are available.
import snowflake.sqlalchemy # noqa: F401
from snowflake.sqlalchemy import custom_types
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import text
from sqlalchemy.sql.elements import quoted_name

from datahub.configuration.common import ConfigModel
from datahub.configuration.common import AllowDenyPattern, ConfigModel

from .sql_common import (
SQLAlchemyConfig,
Expand All @@ -18,6 +24,7 @@
register_custom_type(custom_types.TIMESTAMP_TZ, TimeTypeClass)
register_custom_type(custom_types.TIMESTAMP_LTZ, TimeTypeClass)
register_custom_type(custom_types.TIMESTAMP_NTZ, TimeTypeClass)
register_custom_type(custom_types.VARIANT)

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,7 +60,23 @@ def get_sql_alchemy_url(self, database=None):


class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):
    # Databases are selected by regex via an allow/deny pattern; Snowflake's
    # built-in and sample databases are denied by default so a fresh config
    # does not ingest system metadata.
    database_pattern: AllowDenyPattern = AllowDenyPattern(
        deny=[
            r"^UTIL_DB$",
            r"^SNOWFLAKE$",
            r"^SNOWFLAKE_SAMPLE_DATA$",
        ]
    )

    # Deprecated single-database option; a validator folds it into
    # database_pattern. Default ".*" means "no restriction".
    database: str = ".*"  # deprecated

@pydantic.validator("database")
def note_database_opt_deprecation(cls, v, values, **kwargs):
    """Fold the deprecated `database` option into `database_pattern`.

    The single database name is converted into an exact-match allow regex,
    and the stored `database` value is cleared (the validator returns None).
    Relies on `database_pattern` being declared before `database`, so it is
    already present in `values`.
    """
    # logger.warn is a deprecated alias for logger.warning.
    logger.warning(
        "snowflake's `database` option has been deprecated; use database_pattern instead"
    )
    # `allow` is a *list* of regexes (see the README example). Assigning a
    # bare string would be iterated character-by-character by the pattern
    # matcher, and the "^" fragment alone matches every database name.
    values["database_pattern"].allow = [f"^{v}$"]
    return None

def get_sql_alchemy_url(self):
    """Build the connection URL, scoped to the currently configured database."""
    target_db = self.database
    return super().get_sql_alchemy_url(target_db)
Expand All @@ -64,10 +87,31 @@ def get_identifier(self, schema: str, table: str) -> str:


class SnowflakeSource(SQLAlchemySource):
    """SQLAlchemy-based source that ingests metadata from every Snowflake
    database allowed by ``config.database_pattern``."""

    config: SnowflakeConfig

    def __init__(self, config, ctx):
        super().__init__(config, ctx, "snowflake")

    @classmethod
    def create(cls, config_dict, ctx):
        """Alternate constructor: build the source from a raw config dict."""
        config = SnowflakeConfig.parse_obj(config_dict)
        return cls(config, ctx)

    def get_inspectors(self) -> Iterable[Inspector]:
        """Yield one inspector per allowed database.

        Enumerates databases with ``SHOW DATABASES`` and, for each database
        that passes ``database_pattern``, opens a connection scoped to it.
        """
        url = self.config.get_sql_alchemy_url()
        logger.debug(f"sql_alchemy_url={url}")
        engine = create_engine(url, **self.config.options)

        for db_row in engine.execute(text("SHOW DATABASES")):
            db = db_row.name
            # Filter before connecting: the original opened a connection even
            # for databases that were about to be dropped.
            if not self.config.database_pattern.allowed(db):
                self.report.report_dropped(db)
                continue
            with engine.connect() as conn:
                # TRICKY: As we iterate through this loop, we modify the value of
                # self.config.database so that the get_identifier method can
                # function as intended.
                self.config.database = db
                conn.execute(f'USE DATABASE "{quoted_name(db, True)}"')
                yield inspect(conn)
33 changes: 21 additions & 12 deletions metadata-ingestion/src/datahub/ingestion/source/sql_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,26 +238,35 @@ def __init__(self, config: SQLAlchemyConfig, ctx: PipelineContext, platform: str
self.platform = platform
self.report = SQLSourceReport()

def get_inspectors(self) -> Iterable[Inspector]:
    """Yield a single inspector for the configured engine.

    Subclasses override this to fan out dynamically over multiple
    databases (e.g. one inspector per database).
    """
    sql_url = self.config.get_sql_alchemy_url()
    logger.debug(f"sql_alchemy_url={sql_url}")
    db_engine = create_engine(sql_url, **self.config.options)
    yield inspect(db_engine)

def get_workunits(self) -> Iterable[SqlWorkUnit]:
    """Emit metadata workunits for every allowed schema of every inspector."""
    config = self.config
    if logger.isEnabledFor(logging.DEBUG):
        # If debug logging is enabled, we also want to echo each SQL query issued.
        config.options["echo"] = True

    for inspector in self.get_inspectors():
        for schema in inspector.get_schema_names():
            if not config.schema_pattern.allowed(schema):
                dropped_name = ".".join(
                    config.standardize_schema_table_names(schema, "*")
                )
                self.report.report_dropped(dropped_name)
                continue

            if config.include_tables:
                yield from self.loop_tables(inspector, schema, config)

            if config.include_views:
                yield from self.loop_views(inspector, schema, config)

def loop_tables(
self,
Expand Down
2 changes: 1 addition & 1 deletion metadata-ingestion/tests/unit/test_snowflake_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ def test_snowflake_uri():

assert (
config.get_sql_alchemy_url()
== "snowflake://user:password@acctname/demo?warehouse=COMPUTE_WH&role=sysadmin"
== "snowflake://user:password@acctname/?warehouse=COMPUTE_WH&role=sysadmin"
)