Skip to content

Commit

Permalink
fix: PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sagar-salvi-apptware committed Dec 9, 2024
1 parent 4427621 commit c1d7596
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 88 deletions.
11 changes: 1 addition & 10 deletions metadata-ingestion-modules/gx-plugin/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@ def get_long_description():

rest_common = {"requests", "requests_file"}

sqlglot_lib = {
# We heavily monkeypatch sqlglot.
# Prior to the patching, we originally maintained an acryl-sqlglot fork:
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main?expand=1
"sqlglot[rs]==25.26.0",
"patchy==2.8.0",
}

_version: str = package_metadata["__version__"]
_self_pin = (
f"=={_version}"
Expand All @@ -42,8 +34,7 @@ def get_long_description():
# https://github.com/ipython/traitlets/issues/741
"traitlets<5.2.2",
*rest_common,
*sqlglot_lib,
f"acryl-datahub[datahub-rest]{_self_pin}",
f"acryl-datahub[datahub-rest,sql-parser]{_self_pin}",
}

mypy_stubs = {
Expand Down
2 changes: 2 additions & 0 deletions metadata-ingestion/docs/sources/redash/redash.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Note! The integration can use a SQL parser to try to parse the tables that a chart depends on. This parsing is disabled by default,
but can be enabled by setting `parse_table_names_from_sql: true`. The parser is based on the [`sqlglot`](https://pypi.org/project/sqlglot/) package.
12 changes: 4 additions & 8 deletions metadata-ingestion/src/datahub/ingestion/source/unity/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def _get_workunits_internal(
) -> Iterable[MetadataWorkUnit]:
table_map = defaultdict(list)
query_hashes = set()
print("table_refs", table_refs)
for ref in table_refs:
table_map[ref.table].append(ref)
table_map[f"{ref.schema}.{ref.table}"].append(ref)
Expand Down Expand Up @@ -175,6 +176,7 @@ def _parse_query(
self, query: Query, table_map: TableMap
) -> Optional[QueryTableInfo]:
with self.report.usage_perf_report.sql_parsing_timer:
breakpoint()
table_info = self._parse_query_via_sqlglot(query.query_text)
if table_info is None and query.statement_type == QueryStatementType.SELECT:
with self.report.usage_perf_report.spark_sql_parsing_timer:
Expand Down Expand Up @@ -218,20 +220,14 @@ def _parse_query_via_sqlglot(self, query: str) -> Optional[StringTableInfo]:
return None

@staticmethod
def _parse_sqllineage_table(sqllineage_table: str) -> str:
full_table_name = str(sqllineage_table)
def _parse_sqlglot_table(table_urn: str) -> str:
full_table_name = DatasetUrn.from_string(table_urn).name
default_schema = "<default>."
if full_table_name.startswith(default_schema):
return full_table_name[len(default_schema) :]
else:
return full_table_name

@staticmethod
def _parse_sqlglot_table(table_urn: str) -> str:
return UnityCatalogUsageExtractor._parse_sqllineage_table(
DatasetUrn.from_string(table_urn).name
)

def _parse_query_via_spark_sql_plan(self, query: str) -> Optional[StringTableInfo]:
"""Parse query source tables via Spark SQL plan. This is a fallback option."""
# Would be more effective if we upgrade pyspark
Expand Down
130 changes: 60 additions & 70 deletions metadata-ingestion/tests/unit/utilities/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,55 @@
import doctest
import re
from typing import List

from datahub.sql_parsing.schema_resolver import SchemaResolver
from datahub.sql_parsing.sqlglot_lineage import SqlParsingDebugInfo, sqlglot_lineage
from datahub.sql_parsing.sqlglot_lineage import sqlglot_lineage
from datahub.utilities.delayed_iter import delayed_iter
from datahub.utilities.is_pytest import is_pytest_running
from datahub.utilities.urns.dataset_urn import DatasetUrn


class SqlglotSQLParser:
class SqlLineageSQLParser:
"""
It uses `sqlglot_lineage` to extract tables and columns, serving as a replacement for the `sqllineage` implementation, similar to BigQuery.
Reference: [BigQuery SQL Lineage Test](https://github.com/datahub-project/datahub/blob/master/metadata-ingestion/tests/unit/bigquery/test_bigquery_sql_lineage.py#L8).
"""

_MYVIEW_SQL_TABLE_NAME_TOKEN = "__my_view__.__sql_table_name__"
_MYVIEW_LOOKER_TOKEN = "my_view.SQL_TABLE_NAME"

def __init__(self, sql_query: str, platform: str = "bigquery") -> None:
self.result = sqlglot_lineage(sql_query, SchemaResolver(platform=platform))
        # SqlLineageParser lowercases table names, so we need to replace the Looker-specific token, which should be uppercased
sql_query = re.sub(
rf"(\${{{self._MYVIEW_LOOKER_TOKEN}}})",
rf"{self._MYVIEW_SQL_TABLE_NAME_TOKEN}",
sql_query,
)
self.sql_query = sql_query
self.schema_resolver = SchemaResolver(platform=platform)
self.result = sqlglot_lineage(sql_query, self.schema_resolver)

def get_tables(self) -> List[str]:
ans = []
for urn in self.result.in_tables:
table_ref = DatasetUrn.from_string(urn)
ans.append(str(table_ref.name))
return ans

result = [
self._MYVIEW_LOOKER_TOKEN if c == self._MYVIEW_SQL_TABLE_NAME_TOKEN else c
for c in ans
]
# Sort tables to make the list deterministic
result.sort()

return result

def get_columns(self) -> List[str]:
ans = set()
ans = []
for col_info in self.result.column_lineage or []:
for col_ref in col_info.upstreams:
ans.add(col_ref.column)
return list(ans)

def get_downstream_columns(self) -> List[str]:
ans = set()
for col_info in self.result.column_lineage or []:
ans.add(col_info.downstream.column)
return list(ans)

def debug_info(self) -> SqlParsingDebugInfo:
return self.result.debug_info
ans.append(col_ref.column)
return ans


def test_delayed_iter():
Expand Down Expand Up @@ -73,7 +89,7 @@ def maker(n):
def test_sqllineage_sql_parser_get_tables_from_simple_query():
sql_query = "SELECT foo.a, foo.b, bar.c FROM foo JOIN bar ON (foo.a == bar.b);"

tables_list = SqlglotSQLParser(sql_query).get_tables()
tables_list = SqlLineageSQLParser(sql_query).get_tables()
tables_list.sort()
assert tables_list == ["bar", "foo"]

Expand Down Expand Up @@ -126,31 +142,33 @@ def test_sqllineage_sql_parser_get_tables_from_complex_query():
5)
"""

tables_list = SqlglotSQLParser(sql_query).get_tables()
tables_list = SqlLineageSQLParser(sql_query).get_tables()
tables_list.sort()
assert tables_list == ["schema1.foo", "schema2.bar"]


def test_sqllineage_sql_parser_get_columns_with_join():
sql_query = "SELECT foo.a, foo.b, bar.c FROM foo JOIN bar ON (foo.a == bar.b);"

columns_list = SqlglotSQLParser(sql_query).get_columns()
columns_list = SqlLineageSQLParser(sql_query).get_columns()
columns_list.sort()
assert columns_list == ["a", "b", "c"]


def test_sqllineage_sql_parser_get_columns_from_simple_query():
sql_query = "SELECT foo.a, foo.b FROM foo;"

parser = SqlglotSQLParser(sql_query)
assert sorted(parser.get_columns()) == ["a", "b"]
columns_list = SqlLineageSQLParser(sql_query).get_columns()
columns_list.sort()
assert columns_list == ["a", "b"]


def test_sqllineage_sql_parser_get_columns_with_alias_and_count_star():
sql_query = "SELECT foo.a, foo.b, bar.c as test, count(*) as count FROM foo JOIN bar ON (foo.a == bar.b);"
parser = SqlglotSQLParser(sql_query)
assert sorted(parser.get_columns()) == ["a", "b", "c"]
assert sorted(parser.get_downstream_columns()) == ["a", "b", "count", "test"]

columns_list = SqlLineageSQLParser(sql_query).get_columns()
columns_list.sort()
assert columns_list == ["a", "b", "c"]


def test_sqllineage_sql_parser_get_columns_with_more_complex_join():
Expand All @@ -171,9 +189,10 @@ def test_sqllineage_sql_parser_get_columns_with_more_complex_join():
WHERE
fp.dt = '2018-01-01'
"""
parser = SqlglotSQLParser(sql_query)
assert sorted(parser.get_columns()) == ["bs", "pi", "tt", "v"]
assert sorted(parser.get_downstream_columns()) == ["bs", "pi", "pt", "pu", "v"]

columns_list = SqlLineageSQLParser(sql_query).get_columns()
columns_list.sort()
assert columns_list == ["bs", "pi", "tt", "tt", "v"]


def test_sqllineage_sql_parser_get_columns_complex_query_with_union():
Expand Down Expand Up @@ -223,17 +242,10 @@ def test_sqllineage_sql_parser_get_columns_complex_query_with_union():
4,
5)
"""
parser = SqlglotSQLParser(sql_query)
columns_list = parser.get_columns()
assert sorted(columns_list) == ["c", "e", "u", "x"]
assert sorted(parser.get_downstream_columns()) == [
"c",
"count(*)",
"date",
"e",
"u",
"x",
]

columns_list = SqlLineageSQLParser(sql_query).get_columns()
columns_list.sort()
assert columns_list == ["c", "c", "e", "e", "e", "e", "u", "u", "x", "x"]


def test_sqllineage_sql_parser_get_tables_from_templated_query():
Expand All @@ -246,11 +258,9 @@ def test_sqllineage_sql_parser_get_tables_from_templated_query():
FROM
${my_view.SQL_TABLE_NAME} AS my_view
"""
parser = SqlglotSQLParser(sql_query)
tables_list = parser.get_tables()
tables_list = SqlLineageSQLParser(sql_query).get_tables()
tables_list.sort()
assert tables_list == []
assert parser.debug_info().table_error is None
assert tables_list == ["my_view.SQL_TABLE_NAME"]


def test_sqllineage_sql_parser_get_columns_from_templated_query():
Expand All @@ -263,15 +273,9 @@ def test_sqllineage_sql_parser_get_columns_from_templated_query():
FROM
${my_view.SQL_TABLE_NAME} AS my_view
"""
parser = SqlglotSQLParser(sql_query)
assert sorted(parser.get_columns()) == []
assert sorted(parser.get_downstream_columns()) == [
"city",
"country",
"measurement",
"timestamp",
]
assert parser.debug_info().column_error is None
columns_list = SqlLineageSQLParser(sql_query).get_columns()
columns_list.sort()
assert columns_list == ["city", "country", "measurement", "timestamp"]


def test_sqllineage_sql_parser_with_weird_lookml_query():
Expand All @@ -280,14 +284,9 @@ def test_sqllineage_sql_parser_with_weird_lookml_query():
platform VARCHAR(20) AS aliased_platform,
country VARCHAR(20) FROM fragment_derived_view'
"""
parser = SqlglotSQLParser(sql_query)
columns_list = parser.get_columns()
columns_list = SqlLineageSQLParser(sql_query).get_columns()
columns_list.sort()
assert columns_list == []
assert (
str(parser.debug_info().table_error)
== "Error tokenizing 'untry VARCHAR(20) FROM fragment_derived_view'\n ': Missing ' from 5:143"
)


def test_sqllineage_sql_parser_tables_from_redash_query():
Expand All @@ -302,7 +301,7 @@ def test_sqllineage_sql_parser_tables_from_redash_query():
GROUP BY
name,
year(order_date)"""
table_list = SqlglotSQLParser(sql_query).get_tables()
table_list = SqlLineageSQLParser(sql_query).get_tables()
table_list.sort()
assert table_list == ["order_items", "orders", "staffs"]

Expand All @@ -324,18 +323,9 @@ def test_sqllineage_sql_parser_tables_with_special_names():
"hour-table",
"timestamp-table",
]
expected_columns = [
"column-admin",
"column-data",
"column-date",
"column-hour",
"column-timestamp",
]
assert sorted(SqlglotSQLParser(sql_query).get_tables()) == expected_tables
assert sorted(SqlglotSQLParser(sql_query).get_columns()) == []
assert (
sorted(SqlglotSQLParser(sql_query).get_downstream_columns()) == expected_columns
)
expected_columns: List[str] = []
assert sorted(SqlLineageSQLParser(sql_query).get_tables()) == expected_tables
assert sorted(SqlLineageSQLParser(sql_query).get_columns()) == expected_columns


def test_logging_name_extraction():
Expand Down

0 comments on commit c1d7596

Please sign in to comment.