Merge pull request #264 from lyliyu/view-sort
Sort views topologically before importing
gregwood-db committed Jul 6, 2023
2 parents d08e7a0 + 80b42ab commit 8d242b7
Showing 7 changed files with 209 additions and 19 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -339,7 +339,7 @@ usage: import_db.py [-h] [--users] [--workspace] [--workspace-top-level]
[--no-ssl-verification] [--silent] [--debug]
[--set-export-dir SET_EXPORT_DIR] [--pause-all-jobs]
[--unpause-all-jobs] [--import-pause-status]
[--delete-all-jobs] [--last-session]
[--delete-all-jobs] [--last-session] [--sort-views]
Import full workspace artifacts into Databricks
@@ -391,6 +391,9 @@ optional arguments:
--delete-all-jobs Delete all jobs
--last-session
The session to compare against. If set, the script compares the current session with the last session and only imports updated and new notebooks.
--sort-views
Sort all views topologically based upon their dependencies before importing.
e.g. if view A is created from view B, B will be imported before A. This handles cases where views are created from other views.
```
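For example, an import run that also sorts views by dependency might look like the following (illustrative only; combine `--sort-views` with whatever other import flags your migration already uses):

```bash
python import_db.py --set-export-dir /path/to/export/ --sort-views
```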

---
2 changes: 1 addition & 1 deletion data/notebooks/Import_Table_ACLs_delta.py
@@ -135,7 +135,7 @@ def execute_sql_statements(sqls):
if sql:
print(f"{sql};")
try:
# spark.sql(sql)
spark.sql(sql)
num_sucessfully_executed = num_sucessfully_executed+1
except:
error_causing_sqls.append({'sql': sql, 'error': sys.exc_info()})
63 changes: 47 additions & 16 deletions dbclient/HiveClient.py
@@ -10,6 +10,8 @@
import logging_utils
import re
from dbclient import *
from collections import defaultdict
from dbclient.common.ViewSort import create_dependency_graph, sort_views_topology, unpack_view_db_name


class HiveClient(ClustersClient):
@@ -361,7 +363,7 @@ def move_table_view(self, db_name, tbl_name, local_table_ddl, views_dir='metasto
return False

def import_hive_metastore(self, cluster_name=None, metastore_dir='metastore/', views_dir='metastore_views/',
has_unicode=False, should_repair_table=False):
has_unicode=False, should_repair_table=False, sort_views=False):
metastore_local_dir = self.get_export_dir() + metastore_dir
metastore_view_dir = self.get_export_dir() + views_dir
error_logger = logging_utils.get_error_logger(
@@ -413,21 +415,50 @@ def import_hive_metastore(self, cluster_name=None, metastore_dir='metastore/', v
logging.error("Error: Only databases should exist at this level: {0}".format(db_name))
self.delete_dir_if_empty(metastore_view_dir + db_name)
views_db_list = self.listdir(metastore_view_dir)
for db_name in views_db_list:
local_view_db_path = metastore_view_dir + db_name
database_attributes = all_db_details_json.get(db_name, '')
db_path = database_attributes.get('Location')
if os.path.isdir(local_view_db_path):
views = self.listdir(local_view_db_path)
for view_name in views:
full_view_name = f'{db_name}.{view_name}'
if not checkpoint_metastore_set.contains(full_view_name):
logging.info(f"Importing view {full_view_name}")
local_view_ddl = metastore_view_dir + db_name + '/' + view_name
resp = self.apply_table_ddl(local_view_ddl, ec_id, cid, db_path, has_unicode)
if logging_utils.log_response_error(error_logger, resp):
checkpoint_metastore_set.write(full_view_name)
logging.info(resp)

if sort_views:
# To sort views, we will scan and get all the views first
all_view_set = set()
for db_name in views_db_list:
local_view_db_path = metastore_view_dir + db_name
if os.path.isdir(local_view_db_path):
views = self.listdir(local_view_db_path)
for v in views:
all_view_set.add(f"{db_name}.{v}")
logging.info(f"all views: {all_view_set}")
# Build dependency graph of the views
view_parents_dct = create_dependency_graph(metastore_view_dir, all_view_set)
# Sort the views using the dependency graph
logging.info(f"view graph: {view_parents_dct}")
sorted_views = sort_views_topology(view_parents_dct)
logging.info(f"Importing order of views: {sorted_views}")
# Import views in the sorted order
for full_view_name in sorted_views:
if not checkpoint_metastore_set.contains(full_view_name):
logging.info(f"Importing view {full_view_name}")
db_name, view_name = unpack_view_db_name(full_view_name)
database_attributes = all_db_details_json.get(db_name, '')
db_path = database_attributes.get('Location')
local_view_ddl = metastore_view_dir + db_name + '/' + view_name
resp = self.apply_table_ddl(local_view_ddl, ec_id, cid, db_path, has_unicode)
if logging_utils.log_response_error(error_logger, resp):
checkpoint_metastore_set.write(full_view_name)
logging.info(resp)

else:
for db_name in views_db_list:
local_view_db_path = metastore_view_dir + db_name
database_attributes = all_db_details_json.get(db_name, '')
db_path = database_attributes.get('Location')
if os.path.isdir(local_view_db_path):
views = self.listdir(local_view_db_path)
for view_name in views:
full_view_name = f'{db_name}.{view_name}'
if not checkpoint_metastore_set.contains(full_view_name):
logging.info(f"Importing view {full_view_name}")
local_view_ddl = metastore_view_dir + db_name + '/' + view_name
resp = self.apply_table_ddl(local_view_ddl, ec_id, cid, db_path, has_unicode)
if logging_utils.log_response_error(error_logger, resp):
checkpoint_metastore_set.write(full_view_name)
logging.info(resp)

# repair legacy tables
if should_repair_table:
68 changes: 68 additions & 0 deletions dbclient/common/ViewSort.py
@@ -0,0 +1,68 @@
import os
import re
from collections import defaultdict, deque
from typing import Set


def extract_source_tables(ddl_query: str, all_valid_names: Set[str]):
"""
    Extracts source table and view names from a SQL query, including nested FROM clauses.
    Returns the set of unique table names referenced in FROM/JOIN/UNION clauses.
"""
sql_query = ddl_query.replace("`", "")
table_names = set()
regex = r'\b(?:FROM|JOIN|UNION)\b\s+([\w.]+)'
matches = re.findall(regex, sql_query)
for match in matches:
table_name = match.lower()
if ((all_valid_names and table_name in all_valid_names) or (not all_valid_names)) \
and table_name not in table_names:
table_names.add(table_name)
return table_names

def unpack_view_db_name(full_view_name: str):
parts = full_view_name.split(".")
assert len(parts) == 2, f"{full_view_name} is not formatted correctly."
return parts[0], parts[1]

def get_view_dependencies(metastore_view_dir: str, full_view_name: str, all_views: Set[str]):
print(f"processing dependencies of {full_view_name}")
db_name, vw = unpack_view_db_name(full_view_name)
# ddl_query = spark.sql(f"show create table {view_name}").collect()[0][0]
ddl_full_path = os.path.join(metastore_view_dir, db_name, vw)
dep_set = set()
with open(ddl_full_path, "r") as f:
ddl_query = f.read()
identifiers = extract_source_tables(ddl_query, all_views)
for token in identifiers:
if full_view_name.lower() in token.lower():
continue
dep_set.add(token)
print(f"dependencies: {dep_set}")
return dep_set

def create_dependency_graph(metastore_view_dir: str, all_views: Set[str]):
view_parents_dct = dict()
for view_name in all_views:
dep_views = get_view_dependencies(metastore_view_dir, view_name, all_views)
view_parents_dct[view_name] = dep_views
return view_parents_dct

def sort_views_topology(view_parents_dct):
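    """
    Topologically sort views using Kahn's algorithm: views with no remaining
    unimported parents are emitted first, so every view appears after the
    views it depends on. Note: view_parents_dct is modified in place.
    """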
view_children_dct = defaultdict(set)
q = deque([])
for view, parents in view_parents_dct.items():
for pview in parents:
view_children_dct[pview].add(view)
if not parents:
q.append(view)
sorted_views = []
while q:
cur_view = q.popleft()
sorted_views.append(cur_view)
for child_view in view_children_dct[cur_view]:
view_parents_dct[child_view].remove(cur_view)
if not view_parents_dct[child_view]:
q.append(child_view)
return sorted_views
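
A minimal sketch of how the new sorting helper behaves (the view names below are illustrative; note that `sort_views_topology` consumes the parent sets it is given):

```python
from dbclient.common.ViewSort import sort_views_topology

# view1 is built from view2, which is built from view3; view4 is independent
parents = {
    "db.view1": {"db.view2"},
    "db.view2": {"db.view3"},
    "db.view3": set(),
    "db.view4": set(),
}
order = sort_views_topology(parents)
# view3 is imported before view2, which is imported before view1
assert order.index("db.view3") < order.index("db.view2") < order.index("db.view1")
```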
6 changes: 6 additions & 0 deletions dbclient/parser.py
@@ -390,6 +390,9 @@ def get_import_parser():

parser.add_argument('--retry-backoff', type=float, default=1.0, help='Backoff factor to apply between retry attempts when making calls to Databricks API')

parser.add_argument('--sort-views', action='store_true', default=False,
help='If True, the views will be sorted based upon dependencies before importing.')

return parser


@@ -568,4 +571,7 @@ def get_pipeline_parser() -> argparse.ArgumentParser:
parser.add_argument('--last-session', action='store', default='',
help='If set, the script compares the current session with the last session and only imports updated and new notebooks.')

parser.add_argument('--sort-views', action='store_true', default=False,
help='If True, the views will be sorted based upon dependencies before importing.')

return parser
3 changes: 2 additions & 1 deletion tasks/tasks.py
@@ -315,7 +315,8 @@ def run(self):
# log job configs
hive_c.import_hive_metastore(cluster_name=self.args.cluster_name,
has_unicode=self.args.metastore_unicode,
should_repair_table=self.args.repair_metastore_tables)
should_repair_table=self.args.repair_metastore_tables,
sort_views=self.args.sort_views)


class MetastoreTableACLExportTask(AbstractTask):
81 changes: 81 additions & 0 deletions test/view_sort_test.py
@@ -0,0 +1,81 @@
import unittest
from unittest.mock import MagicMock
from dbclient import HiveClient
from dbclient.test.TestUtils import TEST_CONFIG
from io import StringIO
from dbclient.common.ViewSort import sort_views_topology, get_view_dependencies
from unittest import mock

class TestViews(unittest.TestCase):
def test_sort_views_topology(self):
view_parents_graph = {
"view1": {"view2", "view3"},
"view3": {"view4"},
"view2": {},
"view4": {"view5", "view6"},
"view5": {},
"view6": {},
"view7": {}
}
views = sort_views_topology(view_parents_graph)
assert views.index("view1") > views.index("view2") and views.index("view1") > views.index("view3") \
and views.index("view3") > views.index("view4") \
and views.index("view4") > views.index("view5") and views.index("view4") > views.index("view6")

def test_get_view_dependencies(self):
view_ddl = """
CREATE VIEW `default`.`test_view` (
first_name,
middle_name,
last_name,
relationship_type_cd,
receipt_number)
TBLPROPERTIES (
'transient_lastDdlTime' = '1674499157')
AS SELECT
p.first_name AS first_name,
p.middle_name AS middle_name,
p.last_name AS last_name,
pc.role_id AS relationship_type_cd,
pc.receipt_number AS receipt_number
FROM `db1`.`persons` pc
JOIN `db2`.`person` p
ON pc.person_id = p.person_id
AND pc.svr_ctr_cd = p.svr_ctr_cd
WHERE
pc.role_id = 11
AND (p.first_name is not null or p.middle_name is not null or p.first_name is not null )
"""
mock_open = mock.mock_open(read_data=view_ddl)
with mock.patch("builtins.open", mock_open):
deps = get_view_dependencies("/tmp/metastore_view", "default.test_view", {})
assert deps == set(["db1.persons", "db2.person"])


def test_get_view_deps_nested(self):
view_ddl = """
CREATE VIEW test.view1 (
step_rank,
same_step_instance,
id,
t_cd)
AS SELECT ROW_NUMBER() OVER (PARTITION BYID ORDER BY st_cd_start_date) AS step_rank,
ROW_NUMBER() OVER (PARTITION BY id, st_cd ORDER BY st_cd_start_date) AS same_step_instance,
id,
st_cd,
st_cd_start_date,
st_cd_end_date,
datediff(st_cd_end_date, st_cd_start_date) AS step_duration
FROM (
SELECT id, st_cd, st_cd_start_date
FROM (
SELECT id, NVL(st_cd, 'Null') AS st_cd
FROM test.view2 ch
) aa
WHERE Is_Boundry = 1) bb
WHERE st_cd_start_date IS NOT NULL
"""
mock_open = mock.mock_open(read_data=view_ddl)
with mock.patch("builtins.open", mock_open):
deps = get_view_dependencies("/tmp/metastore_view", "tdss.case_actn_hist_st_cd_instances", {})
assert len(deps) == 1 and next(iter(deps)) == "test.view2"
