-
Notifications
You must be signed in to change notification settings - Fork 127
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #264 from lyliyu/view-sort
Sort views topologically before importing
- Loading branch information
Showing
7 changed files
with
209 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from collections import deque | ||
import sqlparse | ||
from typing import Set, List | ||
from collections import defaultdict | ||
import os, re | ||
|
||
|
||
def extract_source_tables(ddl_query: str, all_valid_names: Set[str]):
    """
    Collect the names that directly follow a FROM, JOIN or UNION keyword in a SQL statement.

    Backticks are removed before scanning, so `db`.`tbl` is reported as "db.tbl".
    Keywords are matched case-insensitively and every reported name is lower-cased.
    A keyword followed by an opening parenthesis (a sub-query) yields no name itself;
    names inside the sub-query are still found because the whole text is scanned.

    :param ddl_query: SQL text to scan, e.g. a CREATE VIEW statement.
    :param all_valid_names: when non-empty, only names contained in this collection
        are returned (compared against the lower-cased name); when empty, every
        name found is returned.
    :return: an unordered set of lower-cased names.
    """
    sql_query = ddl_query.replace("`", "")
    pattern = r'\b(?:FROM|JOIN|UNION)\b\s+([\w.]+)'
    # SQL keywords are case-insensitive, so "from db.tbl" must match like "FROM db.tbl".
    found = {name.lower() for name in re.findall(pattern, sql_query, flags=re.IGNORECASE)}
    # "UNION ALL|DISTINCT|SELECT ..." makes the pattern capture the keyword that
    # follows UNION; those words are never table names.
    found -= {"select", "all", "distinct"}
    if all_valid_names:
        return {name for name in found if name in all_valid_names}
    return found
|
||
def unpack_view_db_name(full_view_name: str):
    """
    Split a "<database>.<view>" name into its two components.

    :param full_view_name: name made of exactly two dot-separated parts.
    :return: tuple ``(database_name, view_name)``.
    :raises AssertionError: if the name does not consist of exactly two parts.
    """
    parts = full_view_name.split(".")
    if len(parts) != 2:
        # Raised explicitly instead of through an `assert` statement so the check
        # is not stripped when Python runs with -O; type and message are unchanged.
        raise AssertionError(f"{full_view_name} is not formatted correctly.")
    db_name, view_name = parts
    return db_name, view_name
|
||
def get_view_dependencies(metastore_view_dir: str, full_view_name: str, all_views: Set[str]):
    """
    Read a view's DDL file and return the names it selects from.

    The DDL is read from ``<metastore_view_dir>/<database>/<view>`` and scanned with
    ``extract_source_tables``; references to the view itself are left out.

    :param metastore_view_dir: directory holding one sub-directory per database.
    :param full_view_name: view name in "<database>.<view>" form.
    :param all_views: passed to ``extract_source_tables`` as the set of valid names.
    :return: set of dependency names.
    """
    print(f"processing dependencies of {full_view_name}")
    db_name, vw = unpack_view_db_name(full_view_name)
    # The DDL comes from the exported file rather than a live `show create table` query.
    ddl_full_path = os.path.join(metastore_view_dir, db_name, vw)
    with open(ddl_full_path, "r") as f:
        ddl_query = f.read()
    own_name = full_view_name.lower()
    dep_set = set()
    for token in extract_source_tables(ddl_query, all_views):
        candidate = token.lower()
        # Skip only genuine self-references (optionally qualified by a leading
        # catalog part). A substring test would also discard unrelated names such
        # as "db.sales_raw" while processing "db.sales".
        if candidate == own_name or candidate.endswith("." + own_name):
            continue
        dep_set.add(token)
    print(f"dependencies: {dep_set}")
    return dep_set
|
||
def create_dependency_graph(metastore_view_dir: str, all_views: Set[str]):
    """
    Map every view in ``all_views`` to the set returned by ``get_view_dependencies``.

    :param metastore_view_dir: directory forwarded to ``get_view_dependencies``.
    :param all_views: names of the views to process; also forwarded as the set of
        valid dependency names.
    :return: dict of view name -> set of the names that view depends on.
    """
    return {
        name: get_view_dependencies(metastore_view_dir, name, all_views)
        for name in all_views
    }
|
||
def sort_views_topology(view_parents_dct):
    """
    Order views so that each view appears after all of the views it depends on.

    A breadth-first pass starts from the views that have no parents and releases a
    view once every one of its parents has been emitted. Views that never become
    ready (they are part of a dependency cycle, or depend on a name that is not a
    key of the mapping) are left out of the result.

    :param view_parents_dct: mapping of view name -> collection of parent view
        names. The mapping and its collections are not modified.
    :return: list of view names in an order that is safe to create them in.
    """
    # Work on private copies so the caller's graph is left untouched.
    pending_parents = {view: set(parents) for view, parents in view_parents_dct.items()}
    view_children_dct = defaultdict(set)
    q = deque()
    for view, parents in pending_parents.items():
        for pview in parents:
            view_children_dct[pview].add(view)
        if not parents:
            q.append(view)
    sorted_views = []
    while q:
        cur_view = q.popleft()
        sorted_views.append(cur_view)
        for child_view in view_children_dct[cur_view]:
            remaining = pending_parents[child_view]
            remaining.remove(cur_view)
            # The child is ready as soon as its last outstanding parent is emitted.
            if not remaining:
                q.append(child_view)
    return sorted_views
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import unittest | ||
from unittest.mock import MagicMock | ||
from dbclient import HiveClient | ||
from dbclient.test.TestUtils import TEST_CONFIG | ||
from io import StringIO | ||
from dbclient.common.ViewSort import sort_views_topology, get_view_dependencies | ||
from unittest import mock | ||
|
||
class TestViews(unittest.TestCase):
    """Tests for the view dependency extraction and topological ordering helpers."""

    def test_sort_views_topology(self):
        """Every view is returned, and each one comes after all of its parents."""
        view_parents_graph = {
            "view1": {"view2", "view3"},
            "view3": {"view4"},
            "view2": set(),
            "view4": {"view5", "view6"},
            "view5": set(),
            "view6": set(),
            "view7": set(),
        }
        expected_names = list(view_parents_graph)
        views = sort_views_topology(view_parents_graph)
        # No view may be lost or duplicated by the sort.
        self.assertCountEqual(views, expected_names)
        for child, parent in [
            ("view1", "view2"),
            ("view1", "view3"),
            ("view3", "view4"),
            ("view4", "view5"),
            ("view4", "view6"),
        ]:
            self.assertGreater(views.index(child), views.index(parent))

    def test_get_view_dependencies(self):
        """Back-quoted FROM and JOIN sources are both reported as dependencies."""
        view_ddl = """
        CREATE VIEW `default`.`test_view` (
        first_name,
        middle_name,
        last_name,
        relationship_type_cd,
        receipt_number)
        TBLPROPERTIES (
        'transient_lastDdlTime' = '1674499157')
        AS SELECT
        p.first_name AS first_name,
        p.middle_name AS middle_name,
        p.last_name AS last_name,
        pc.role_id AS relationship_type_cd,
        pc.receipt_number AS receipt_number
        FROM `db1`.`persons` pc
        JOIN `db2`.`person` p
        ON pc.person_id = p.person_id
        AND pc.svr_ctr_cd = p.svr_ctr_cd
        WHERE
        pc.role_id = 11
        AND (p.first_name is not null or p.middle_name is not null or p.first_name is not null )
        """
        # The DDL is served from memory, so no file has to exist on disk.
        fake_open = mock.mock_open(read_data=view_ddl)
        with mock.patch("builtins.open", fake_open):
            deps = get_view_dependencies("/tmp/metastore_view", "default.test_view", set())
        self.assertEqual(deps, {"db1.persons", "db2.person"})

    def test_get_view_deps_nested(self):
        """A source referenced only inside nested sub-queries is still found."""
        view_ddl = """
        CREATE VIEW test.view1 (
        step_rank,
        same_step_instance,
        id,
        t_cd)
        AS SELECT ROW_NUMBER() OVER (PARTITION BYID ORDER BY st_cd_start_date) AS step_rank,
        ROW_NUMBER() OVER (PARTITION BY id, st_cd ORDER BY st_cd_start_date) AS same_step_instance,
        id,
        st_cd,
        st_cd_start_date,
        st_cd_end_date,
        datediff(st_cd_end_date, st_cd_start_date) AS step_duration
        FROM (
        SELECT id, st_cd, st_cd_start_date
        FROM (
        SELECT id, NVL(st_cd, 'Null') AS st_cd
        FROM test.view2 ch
        ) aa
        WHERE Is_Boundry = 1) bb
        WHERE st_cd_start_date IS NOT NULL
        """
        fake_open = mock.mock_open(read_data=view_ddl)
        with mock.patch("builtins.open", fake_open):
            deps = get_view_dependencies("/tmp/metastore_view", "tdss.case_actn_hist_st_cd_instances", set())
        self.assertEqual(deps, {"test.view2"})