Skip to content

Commit

Permalink
feat(ingest/tableau): support platform instance mapping based off dat…
Browse files Browse the repository at this point in the history
…abase server hostname (#10254)

Co-authored-by: Richie Chen <richie.chen@hulu.com>
Co-authored-by: Gabe Lyons <itsgabelyons@gmail.com>
  • Loading branch information
3 people authored May 15, 2024
1 parent 66473db commit 8e5f17b
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 1 deletion.
42 changes: 41 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/source/tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Union,
cast,
)
from urllib.parse import urlparse

import dateutil.parser as dp
import tableauserverclient as TSC
Expand Down Expand Up @@ -86,6 +87,7 @@
clean_query,
custom_sql_graphql_query,
dashboard_graphql_query,
database_servers_graphql_query,
database_tables_graphql_query,
embedded_datasource_graphql_query,
get_filter_pages,
Expand Down Expand Up @@ -345,6 +347,11 @@ class TableauConfig(
description="Mappings to change generated dataset urns. Use only if you really know what you are doing.",
)

database_hostname_to_platform_instance_map: Optional[Dict[str, str]] = Field(
default=None,
description="Mappings to change platform instance in generated dataset urns based on database. Use only if you really know what you are doing.",
)

extract_usage_stats: bool = Field(
default=False,
description="[experimental] Extract usage statistics for dashboards and charts.",
Expand Down Expand Up @@ -537,6 +544,8 @@ def __init__(
self.workbook_project_map: Dict[str, str] = {}
self.datasource_project_map: Dict[str, str] = {}

# This map keeps track of the database server connection hostnames.
self.database_server_hostname_map: Dict[str, str] = {}
# This list keeps track of sheets in workbooks so that we retrieve those
# when emitting sheets.
self.sheet_ids: List[str] = []
Expand Down Expand Up @@ -609,6 +618,24 @@ def _populate_usage_stat_registry(self) -> None:
self.tableau_stat_registry[view.id] = UsageStat(view_count=view.total_views)
logger.debug("Tableau stats %s", self.tableau_stat_registry)

def _populate_database_server_hostname_map(self) -> None:
def maybe_parse_hostname():
# If the connection string is a URL instead of a hostname, parse it
# and extract the hostname, otherwise just return the connection string.
parsed_host_name = urlparse(server_connection).hostname
if parsed_host_name:
return parsed_host_name
return server_connection

for database_server in self.get_connection_objects(
database_servers_graphql_query, c.DATABASE_SERVERS_CONNECTION
):
database_server_id = database_server.get(c.ID)
server_connection = database_server.get(c.HOST_NAME)
host_name = maybe_parse_hostname()
if host_name:
self.database_server_hostname_map[str(database_server_id)] = host_name

def _get_all_project(self) -> Dict[str, TableauProject]:
all_project_map: Dict[str, TableauProject] = {}

Expand Down Expand Up @@ -864,7 +891,7 @@ def get_connection_objects(
self,
query: str,
connection_type: str,
query_filter: dict,
query_filter: dict = {},
page_size_override: Optional[int] = None,
) -> Iterable[dict]:
# Calls the get_connection_object_page function to get the objects,
Expand Down Expand Up @@ -1142,6 +1169,8 @@ def get_upstream_tables(
self.config.env,
self.config.platform_instance_map,
self.config.lineage_overrides,
self.config.database_hostname_to_platform_instance_map,
self.database_server_hostname_map,
)
table_id_to_urn[table[c.ID]] = table_urn

Expand Down Expand Up @@ -1708,8 +1737,11 @@ def parse_custom_sql(
[
str,
Optional[str],
Optional[str],
Optional[Dict[str, str]],
Optional[TableauLineageOverrides],
Optional[Dict[str, str]],
Optional[Dict[str, str]],
],
Tuple[Optional[str], Optional[str], str, str],
]
Expand Down Expand Up @@ -1753,8 +1785,11 @@ def parse_custom_sql(
upstream_db, platform_instance, platform, _ = func_overridden_info(
database_info[c.CONNECTION_TYPE],
database_info.get(c.NAME),
database_info.get(c.ID),
self.config.platform_instance_map,
self.config.lineage_overrides,
self.config.database_hostname_to_platform_instance_map,
self.database_server_hostname_map,
)

logger.debug(
Expand Down Expand Up @@ -2759,6 +2794,11 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
if self.config.extract_usage_stats:
self._populate_usage_stat_registry()

# Populate the map of database names and database hostnames to be used later to map
# databases to platform instances.
if self.config.database_hostname_to_platform_instance_map:
self._populate_database_server_hostname_map()

self._populate_projects_registry()
yield from self.emit_project_containers()
yield from self.emit_workbooks()
Expand Down
37 changes: 37 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/source/tableau_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class MetadataQueryException(Exception):
name
database {
name
id
}
schema
fullName
Expand Down Expand Up @@ -290,6 +291,7 @@ class MetadataQueryException(Exception):
name
database {
name
id
}
schema
fullName
Expand All @@ -315,6 +317,7 @@ class MetadataQueryException(Exception):
name
database {
name
id
}
schema
fullName
Expand All @@ -327,6 +330,7 @@ class MetadataQueryException(Exception):
connectionType
database{
name
id
connectionType
}
}
Expand All @@ -347,6 +351,7 @@ class MetadataQueryException(Exception):
name
database {
name
id
}
schema
fullName
Expand Down Expand Up @@ -418,6 +423,16 @@ class MetadataQueryException(Exception):
}
"""

database_servers_graphql_query = """
{
name
id
connectionType
extendedConnectionType
hostName
}
"""

# https://referencesource.microsoft.com/#system.data/System/Data/OleDb/OLEDB_Enum.cs,364
FIELD_TYPE_MAPPING = {
"INTEGER": NumberTypeClass,
Expand Down Expand Up @@ -592,6 +607,7 @@ def get_fully_qualified_table_name(
@dataclass
class TableauUpstreamReference:
database: Optional[str]
database_id: Optional[str]
schema: Optional[str]
table: str

Expand All @@ -603,6 +619,7 @@ def create(
) -> "TableauUpstreamReference":
# Values directly from `table` object from Tableau
database = t_database = d.get(c.DATABASE, {}).get(c.NAME)
database_id = d.get(c.DATABASE, {}).get(c.ID)
schema = t_schema = d.get(c.SCHEMA)
table = t_table = d.get(c.NAME) or ""
t_full_name = d.get(c.FULL_NAME)
Expand Down Expand Up @@ -654,6 +671,7 @@ def create(

return cls(
database=database,
database_id=database_id,
schema=schema,
table=table,
connection_type=t_connection_type,
Expand All @@ -679,6 +697,8 @@ def make_dataset_urn(
env: str,
platform_instance_map: Optional[Dict[str, str]],
lineage_overrides: Optional[TableauLineageOverrides] = None,
database_hostname_to_platform_instance_map: Optional[Dict[str, str]] = None,
database_server_hostname_map: Optional[Dict[str, str]] = None,
) -> str:
(
upstream_db,
Expand All @@ -688,8 +708,11 @@ def make_dataset_urn(
) = get_overridden_info(
connection_type=self.connection_type,
upstream_db=self.database,
upstream_db_id=self.database_id,
lineage_overrides=lineage_overrides,
platform_instance_map=platform_instance_map,
database_hostname_to_platform_instance_map=database_hostname_to_platform_instance_map,
database_server_hostname_map=database_server_hostname_map,
)

table_name = get_fully_qualified_table_name(
Expand All @@ -707,8 +730,11 @@ def make_dataset_urn(
def get_overridden_info(
connection_type: Optional[str],
upstream_db: Optional[str],
upstream_db_id: Optional[str],
platform_instance_map: Optional[Dict[str, str]],
lineage_overrides: Optional[TableauLineageOverrides] = None,
database_hostname_to_platform_instance_map: Optional[Dict[str, str]] = None,
database_server_hostname_map: Optional[Dict[str, str]] = None,
) -> Tuple[Optional[str], Optional[str], str, str]:
original_platform = platform = get_platform(connection_type)
if (
Expand All @@ -729,6 +755,17 @@ def get_overridden_info(
platform_instance = (
platform_instance_map.get(original_platform) if platform_instance_map else None
)
if (
database_server_hostname_map is not None
and upstream_db_id is not None
and upstream_db_id in database_server_hostname_map
):
hostname = database_server_hostname_map.get(upstream_db_id)
if (
database_hostname_to_platform_instance_map is not None
and hostname in database_hostname_to_platform_instance_map
):
platform_instance = database_hostname_to_platform_instance_map.get(hostname)

if original_platform in ("athena", "hive", "mysql"): # Two tier databases
upstream_db = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
CUSTOM_SQL_TABLE = "CustomSQLTable"
UPSTREAM_TABLES = "upstreamTables"
DATABASE_TABLES_CONNECTION = "databaseTablesConnection"
DATABASE_SERVERS_CONNECTION = "databaseServersConnection"
HOST_NAME = "hostName"
FIELDS = "fields"
UPSTREAM_DATA_SOURCES = "upstreamDatasources"
COLUMNS = "columns"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ def test_lineage_overrides():
assert (
TableauUpstreamReference(
"presto_catalog",
"test-database-id",
"test-schema",
"test-table",
"presto",
Expand All @@ -586,6 +587,7 @@ def test_lineage_overrides():
assert (
TableauUpstreamReference(
"presto_catalog",
"test-database-id",
"test-schema",
"test-table",
"presto",
Expand All @@ -602,6 +604,7 @@ def test_lineage_overrides():
# transform hive urn to presto urn
assert (
TableauUpstreamReference(
None,
None,
"test-schema",
"test-table",
Expand All @@ -617,6 +620,40 @@ def test_lineage_overrides():
)


def test_database_hostname_to_platform_instance_map():
enable_logging()
# Simple - snowflake table
assert (
TableauUpstreamReference(
"test-database-name",
"test-database-id",
"test-schema",
"test-table",
"snowflake",
).make_dataset_urn(env=DEFAULT_ENV, platform_instance_map={})
== "urn:li:dataset:(urn:li:dataPlatform:snowflake,test-database-name.test-schema.test-table,PROD)"
)

# Finding platform instance based off hostname to platform instance mappings
assert (
TableauUpstreamReference(
"test-database-name",
"test-database-id",
"test-schema",
"test-table",
"snowflake",
).make_dataset_urn(
env=DEFAULT_ENV,
platform_instance_map={},
database_hostname_to_platform_instance_map={
"test-hostname": "test-platform-instance"
},
database_server_hostname_map={"test-database-id": "test-hostname"},
)
== "urn:li:dataset:(urn:li:dataPlatform:snowflake,test-platform-instance.test-database-name.test-schema.test-table,PROD)"
)


@freeze_time(FROZEN_TIME)
def test_tableau_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
output_file_name: str = "tableau_mces.json"
Expand Down

0 comments on commit 8e5f17b

Please sign in to comment.