feat(tableau): ability to force extraction of table/column level lineage from SQL queries (datahub-project#9838)
alexs-101 authored and sleeperdeep committed Jun 25, 2024
1 parent 7469598 commit 32c0a24
Showing 9 changed files with 508 additions and 56 deletions.
278 changes: 226 additions & 52 deletions metadata-ingestion/src/datahub/ingestion/source/tableau.py

Large diffs are not rendered by default.

25 changes: 22 additions & 3 deletions metadata-ingestion/src/datahub/ingestion/source/tableau_common.py
@@ -2,7 +2,7 @@
import logging
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from pydantic.fields import Field
from tableauserverclient import Server
@@ -762,8 +762,19 @@ def make_upstream_class(


def make_fine_grained_lineage_class(
parsed_result: Optional[SqlParsingResult], dataset_urn: str
parsed_result: Optional[SqlParsingResult],
dataset_urn: str,
out_columns: List[Dict[Any, Any]],
) -> List[FineGrainedLineage]:
# 1) fine-grained lineage links are case-sensitive
# 2) parsed-out columns are always lowercased
# 3) the corresponding Custom SQL output columns can be in any case (lower/upper/mixed)
#
# we need a map between 2) and 3) to use when building column-level lineage links (see below)
out_columns_map = {
col.get(c.NAME, "").lower(): col.get(c.NAME, "") for col in out_columns
}

fine_grained_lineages: List[FineGrainedLineage] = []

if parsed_result is None:
@@ -775,7 +786,15 @@ def make_fine_grained_lineage_class(

for cll_info in cll:
downstream = (
[builder.make_schema_field_urn(dataset_urn, cll_info.downstream.column)]
[
builder.make_schema_field_urn(
dataset_urn,
out_columns_map.get(
cll_info.downstream.column.lower(),
cll_info.downstream.column,
),
)
]
if cll_info.downstream is not None
and cll_info.downstream.column is not None
else []
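The comment block above is the crux of this hunk: schema-field URNs are case-sensitive, the SQL parser reports downstream columns lowercased, and Tableau's Custom SQL output columns keep whatever casing the user typed. A minimal, self-contained sketch of the mapping (hypothetical column names; the real code reads the key via the c.NAME constant rather than a literal "name"):

# Hypothetical Tableau "output columns" payload; in the source these dicts come
# from the Tableau metadata API and are keyed by the c.NAME constant.
out_columns = [{"name": "Customer_ID"}, {"name": "Order_Total"}]

# Map the lowercased spelling (what the SQL parser emits) back to Tableau's casing.
out_columns_map = {
    col.get("name", "").lower(): col.get("name", "") for col in out_columns
}

# A parsed downstream column arrives lowercased...
parsed_downstream = "customer_id"

# ...and is resolved back to the original casing before the schema-field URN is
# built, falling back to the parsed name if Tableau did not report the column.
resolved = out_columns_map.get(parsed_downstream.lower(), parsed_downstream)
assert resolved == "Customer_ID"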
23 changes: 23 additions & 0 deletions metadata-ingestion/src/datahub/sql_parsing/sql_parsing_result_utils.py
@@ -0,0 +1,23 @@
from typing import Dict, Set

from datahub.sql_parsing.sqlglot_lineage import SqlParsingResult, Urn


def transform_parsing_result_to_in_tables_schemas(
parsing_result: SqlParsingResult,
) -> Dict[Urn, Set[str]]:
table_urn_to_schema_map: Dict[str, Set[str]] = (
{it: set() for it in parsing_result.in_tables}
if parsing_result.in_tables
else {}
)

if parsing_result.column_lineage:
for cli in parsing_result.column_lineage:
for upstream in cli.upstreams:
if upstream.table in table_urn_to_schema_map:
table_urn_to_schema_map[upstream.table].add(upstream.column)
else:
table_urn_to_schema_map[upstream.table] = {upstream.column}

return table_urn_to_schema_map
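A short usage sketch for the helper above. The SqlParsingResult here is built by hand for illustration; in the Tableau source it would presumably come from an initial, schema-unaware parse of the Custom SQL (the tableau.py diff carrying that call is not rendered on this page):

from datahub.sql_parsing.sql_parsing_result_utils import (
    transform_parsing_result_to_in_tables_schemas,
)
from datahub.sql_parsing.sqlglot_lineage import (
    ColumnLineageInfo,
    ColumnRef,
    DownstreamColumnRef,
    SqlParsingResult,
)

orders_urn = "urn:li:dataset:(urn:li:dataPlatform:bigquery,demo.orders,PROD)"

# Hypothetical result of parsing: SELECT total FROM orders
parsing_result = SqlParsingResult(
    in_tables=[orders_urn],
    out_tables=[],
    column_lineage=[
        ColumnLineageInfo(
            downstream=DownstreamColumnRef(column="total"),
            upstreams=[ColumnRef(table=orders_urn, column="total")],
        )
    ],
)

# -> {orders_urn: {"total"}}: for each upstream table, the set of columns the
# query actually references.
print(transform_parsing_result_to_in_tables_schemas(parsing_result))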
7 changes: 6 additions & 1 deletion metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
@@ -529,6 +529,9 @@ def _schema_aware_fuzzy_column_resolve(

# Parse the column name out of the node name.
# Sqlglot calls .sql(), so we have to do the inverse.
if node.name == "*":
continue

normalized_col = sqlglot.parse_one(node.name).this.name
if node.subfield:
normalized_col = f"{normalized_col}.{node.subfield}"
@@ -834,6 +837,7 @@ def _sqlglot_lineage_inner(
# Fetch schema info for the relevant tables.
table_name_urn_mapping: Dict[_TableName, str] = {}
table_name_schema_mapping: Dict[_TableName, SchemaInfo] = {}

for table in tables | modified:
# For select statements, qualification will be a no-op. For other statements, this
# is where the qualification actually happens.
@@ -1016,8 +1020,9 @@ def create_lineage_sql_parsed_result(
env: str,
default_schema: Optional[str] = None,
graph: Optional[DataHubGraph] = None,
schema_aware: bool = True,
) -> SqlParsingResult:
if graph:
if graph and schema_aware:
needs_close = False
schema_resolver = graph._make_schema_resolver(
platform=platform,
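Two things change in this file: bare "*" column nodes are skipped during fuzzy column resolution (there is no single identifier to normalize), and create_lineage_sql_parsed_result grows a schema_aware flag so callers can bypass DataHub-backed schema resolution even when a graph client exists. A hedged usage sketch of the flag follows; only the trailing parameters (env, default_schema, graph, schema_aware) are visible in the hunk above, so the leading ones (query, default_db, platform, platform_instance) are assumed from the surrounding source and may differ:

from datahub.sql_parsing.sqlglot_lineage import create_lineage_sql_parsed_result

# Parse a Custom SQL query without consulting DataHub for table schemas.
# NOTE: the leading parameter names/order are an assumption; only the trailing
# ones appear in the diff above.
result = create_lineage_sql_parsed_result(
    query="SELECT o.id, o.total FROM bigquery_demo.orders o",
    default_db="demo-custom-323403",
    platform="bigquery",
    platform_instance=None,
    env="PROD",
    graph=None,          # with schema_aware=False a provided graph is not used for schema lookups
    schema_aware=False,  # new: skip schema resolution even if a graph is supplied
)
print(result.in_tables, result.column_lineage)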
6 changes: 6 additions & 0 deletions metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py
@@ -16,6 +16,11 @@ def _get_dialect_str(platform: str) -> str:
return "tsql"
elif platform == "athena":
return "trino"
# TODO: define a Salesforce SOQL dialect
# Temporary workaround: treat SOQL as the Databricks dialect.
# At least it allows parsing simple SQL queries and building lineage for them.
elif platform == "salesforce":
return "databricks"
elif platform in {"mysql", "mariadb"}:
# In sqlglot v20+, MySQL is now case-sensitive by default, which is the
# default behavior on Linux. However, MySQL's default case sensitivity
@@ -31,6 +36,7 @@ def _get_dialect_str(platform: str) -> str:
def get_dialect(platform: DialectOrStr) -> sqlglot.Dialect:
if isinstance(platform, sqlglot.Dialect):
return platform

return sqlglot.Dialect.get_or_raise(_get_dialect_str(platform))


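A quick sketch of what the Salesforce workaround buys: a simple SOQL-style query now parses through sqlglot under the Databricks dialect, which is enough to build basic table/column lineage (illustrative query; SOQL-only constructs such as relationship subqueries would still fail):

import sqlglot

from datahub.sql_parsing.sqlglot_utils import get_dialect

# "salesforce" now resolves to the Databricks dialect per the workaround above.
dialect = get_dialect("salesforce")

# A simple SOQL-like query parses fine under that dialect.
parsed = sqlglot.parse_one(
    "SELECT Id, Name FROM Account WHERE AnnualRevenue > 100000",
    dialect=dialect,
)
print(parsed.sql())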
@@ -42870,6 +42870,38 @@
"lastRunId": "no-run-id-provided"
}
},
{
"entityType": "dataset",
"entityUrn": "urn:li:dataset:(urn:li:dataPlatform:bigquery,demo-custom-323403.bigquery_demo.order_items,PROD)",
"changeType": "UPSERT",
"aspectName": "status",
"aspect": {
"json": {
"removed": false
}
},
"systemMetadata": {
"lastObserved": 1638860400000,
"runId": "tableau-test",
"lastRunId": "no-run-id-provided"
}
},
{
"entityType": "dataset",
"entityUrn": "urn:li:dataset:(urn:li:dataPlatform:bigquery,demo-custom-323403.bigquery_demo.sellers,PROD)",
"changeType": "UPSERT",
"aspectName": "status",
"aspect": {
"json": {
"removed": false
}
},
"systemMetadata": {
"lastObserved": 1638860400000,
"runId": "tableau-test",
"lastRunId": "no-run-id-provided"
}
},
{
"entityType": "dataset",
"entityUrn": "urn:li:dataset:(urn:li:dataPlatform:external,sample - superstore%2C %28new%29.xls.orders,PROD)",
@@ -319,6 +319,8 @@ def test_tableau_cll_ingest(pytestconfig, tmp_path, mock_datahub_graph):
new_pipeline_config: Dict[Any, Any] = {
**config_source_default,
"extract_lineage_from_unsupported_custom_sql_queries": True,
"force_extraction_of_lineage_from_custom_sql_queries": False,
"sql_parsing_disable_schema_awareness": False,
"extract_column_level_lineage": True,
}

@@ -834,6 +836,7 @@ def test_tableau_unsupported_csql(mock_datahub_graph):
"connectionType": "bigquery",
},
},
out_columns=[],
)

mcp = cast(MetadataChangeProposalClass, next(iter(lineage)).metadata)
@@ -0,0 +1,67 @@
@@ -0,0 +1,67 @@
from datahub.sql_parsing.sql_parsing_result_utils import (
transform_parsing_result_to_in_tables_schemas,
)
from datahub.sql_parsing.sqlglot_lineage import (
ColumnLineageInfo,
ColumnRef,
DownstreamColumnRef,
SqlParsingResult,
)


def test_transform_parsing_result_to_in_tables_schemas__empty_parsing_result():
parsing_result = SqlParsingResult(in_tables=[], out_tables=[], column_lineage=None)

in_tables_schema = transform_parsing_result_to_in_tables_schemas(parsing_result)
assert not in_tables_schema


def test_transform_parsing_result_to_in_tables_schemas__in_tables_only():
parsing_result = SqlParsingResult(
in_tables=["table_urn1", "table_urn2", "table_urn3"],
out_tables=[],
column_lineage=None,
)

in_tables_schema = transform_parsing_result_to_in_tables_schemas(parsing_result)
assert in_tables_schema == {
"table_urn1": set(),
"table_urn2": set(),
"table_urn3": set(),
}


def test_transform_parsing_result_to_in_tables_schemas__in_tables_and_column_lineage():
parsing_result = SqlParsingResult(
in_tables=["table_urn1", "table_urn2", "table_urn3"],
out_tables=[],
column_lineage=[
ColumnLineageInfo(
downstream=DownstreamColumnRef(column="out_col1"),
upstreams=[
ColumnRef(table="table_urn1", column="col11"),
],
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(column="out_col2"),
upstreams=[
ColumnRef(table="table_urn2", column="col21"),
ColumnRef(table="table_urn2", column="col22"),
],
),
ColumnLineageInfo(
downstream=DownstreamColumnRef(column="out_col3"),
upstreams=[
ColumnRef(table="table_urn1", column="col12"),
ColumnRef(table="table_urn2", column="col23"),
],
),
],
)

in_tables_schema = transform_parsing_result_to_in_tables_schemas(parsing_result)
assert in_tables_schema == {
"table_urn1": {"col11", "col12"},
"table_urn2": {"col21", "col22", "col23"},
"table_urn3": set(),
}
123 changes: 123 additions & 0 deletions metadata-ingestion/tests/unit/test_tableau_source.py
@@ -0,0 +1,123 @@
import pytest

from datahub.ingestion.source.tableau import TableauSource


def test_tableau_source_unescapes_lt():
res = TableauSource._clean_tableau_query_parameters(
"select * from t where c1 << 135"
)

assert res == "select * from t where c1 < 135"


def test_tableau_source_unescapes_gt():
res = TableauSource._clean_tableau_query_parameters(
"select * from t where c1 >> 135"
)

assert res == "select * from t where c1 > 135"


def test_tableau_source_unescapes_gte():
res = TableauSource._clean_tableau_query_parameters(
"select * from t where c1 >>= 135"
)

assert res == "select * from t where c1 >= 135"


def test_tableau_source_unescapes_lte():
res = TableauSource._clean_tableau_query_parameters(
"select * from t where c1 <<= 135"
)

assert res == "select * from t where c1 <= 135"


def test_tableau_source_doesnt_touch_not_escaped():
res = TableauSource._clean_tableau_query_parameters(
"select * from t where c1 < 135 and c2 > 15"
)

assert res == "select * from t where c1 < 135 and c2 > 15"


TABLEAU_PARAMS = [
"<Parameters.MyParam>",
"<Parameters.MyParam_1>",
"<Parameters.My Param _ 1>",
"<Parameters.My Param 1 !@\"',.#$%^:;&*()-_+={}|\\ /<>",
"<[Parameters].MyParam>",
"<[Parameters].MyParam_1>",
"<[Parameters].My Param _ 1>",
"<[Parameters].My Param 1 !@\"',.#$%^:;&*()-_+={}|\\ /<>",
"<Parameters.[MyParam]>",
"<Parameters.[MyParam_1]>",
"<Parameters.[My Param _ 1]>",
"<Parameters.[My Param 1 !@\"',.#$%^:;&*()-_+={}|\\ /<]>",
"<[Parameters].[MyParam]>",
"<[Parameters].[MyParam_1]>",
"<[Parameters].[My Param _ 1]>",
"<[Parameters].[My Param 1 !@\"',.#$%^:;&*()-_+={}|\\ /<]>",
"<Parameters.[My Param 1 !@\"',.#$%^:;&*()-_+={}|\\ /<>]>",
"<[Parameters].[My Param 1 !@\"',.#$%^:;&*()-_+={}|\\ /<>]>",
]


@pytest.mark.parametrize("p", TABLEAU_PARAMS)
def test_tableau_source_cleanups_tableau_parameters_in_equi_predicates(p):
assert (
TableauSource._clean_tableau_query_parameters(
f"select * from t where c1 = {p} and c2 = {p} and c3 = 7"
)
== "select * from t where c1 = 1 and c2 = 1 and c3 = 7"
)


@pytest.mark.parametrize("p", TABLEAU_PARAMS)
def test_tableau_source_cleanups_tableau_parameters_in_lt_gt_predicates(p):
assert (
TableauSource._clean_tableau_query_parameters(
f"select * from t where c1 << {p} and c2<<{p} and c3 >> {p} and c4>>{p} or {p} >> c1 and {p}>>c2 and {p} << c3 and {p}<<c4"
)
== "select * from t where c1 < 1 and c2<1 and c3 > 1 and c4>1 or 1 > c1 and 1>c2 and 1 < c3 and 1<c4"
)


@pytest.mark.parametrize("p", TABLEAU_PARAMS)
def test_tableau_source_cleanups_tableau_parameters_in_lte_gte_predicates(p):
assert (
TableauSource._clean_tableau_query_parameters(
f"select * from t where c1 <<= {p} and c2<<={p} and c3 >>= {p} and c4>>={p} or {p} >>= c1 and {p}>>=c2 and {p} <<= c3 and {p}<<=c4"
)
== "select * from t where c1 <= 1 and c2<=1 and c3 >= 1 and c4>=1 or 1 >= c1 and 1>=c2 and 1 <= c3 and 1<=c4"
)


@pytest.mark.parametrize("p", TABLEAU_PARAMS)
def test_tableau_source_cleanups_tableau_parameters_in_join_predicate(p):
assert (
TableauSource._clean_tableau_query_parameters(
f"select * from t1 inner join t2 on t1.id = t2.id and t2.c21 = {p} and t1.c11 = 123 + {p}"
)
== "select * from t1 inner join t2 on t1.id = t2.id and t2.c21 = 1 and t1.c11 = 123 + 1"
)


@pytest.mark.parametrize("p", TABLEAU_PARAMS)
def test_tableau_source_cleanups_tableau_parameters_in_complex_expressions(p):
assert (
TableauSource._clean_tableau_query_parameters(
f"select myudf1(c1, {p}, c2) / myudf2({p}) > ({p} + 3 * {p} * c5) * {p} - c4"
)
== "select myudf1(c1, 1, c2) / myudf2(1) > (1 + 3 * 1 * c5) * 1 - c4"
)


@pytest.mark.parametrize("p", TABLEAU_PARAMS)
def test_tableau_source_cleanups_tableau_parameters_in_udfs(p):
assert (
TableauSource._clean_tableau_query_parameters(f"select myudf({p}) from t")
== "select myudf(1) from t"
)
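
The tests above pin down the observable behavior of TableauSource._clean_tableau_query_parameters; the method body itself is in the tableau.py diff, which is not rendered on this page. A simplified stand-in, purely for illustration and not the actual implementation (it covers the straightforward cases exercised above, not every spelling in TABLEAU_PARAMS):

import re


def clean_tableau_query_parameters_sketch(query: str) -> str:
    # 1) Replace Tableau parameter references such as <Parameters.MyParam> or
    #    <[Parameters].[My Param]> with a literal 1 so the query stays parseable.
    query = re.sub(r"<\[?[Pp]arameters\]?\.[^>]+>", "1", query)
    # 2) Unescape the doubled comparison operators Tableau emits (<<=, >>=, <<, >>).
    for escaped, plain in (("<<=", "<="), (">>=", ">="), ("<<", "<"), (">>", ">")):
        query = query.replace(escaped, plain)
    return query


print(clean_tableau_query_parameters_sketch("select * from t where c1 << <Parameters.MyParam>"))
# -> select * from t where c1 < 1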
