diff --git a/README.md b/README.md
index c53c0207..95c33b1d 100644
--- a/README.md
+++ b/README.md
@@ -339,7 +339,7 @@ usage: import_db.py [-h] [--users] [--workspace] [--workspace-top-level]
                      [--no-ssl-verification] [--silent] [--debug]
                      [--set-export-dir SET_EXPORT_DIR] [--pause-all-jobs]
                      [--unpause-all-jobs] [--import-pause-status]
-                     [--delete-all-jobs] [--last-session]
+                     [--delete-all-jobs] [--last-session] [--sort-views]
 
 Import full workspace artifacts into Databricks
 
@@ -391,6 +391,9 @@ optional arguments:
   --delete-all-jobs     Delete all jobs
   --last-session        The session to compare against. If set, the script compares current
                         sesssion with the last session and only import updated and new notebooks.
+  --sort-views          Sort all views topologically based on their dependencies before importing,
+                        e.g. if view A is created from view B, B will be imported before A. This
+                        handles the case where views are defined on top of other views.
 ```

---

diff --git a/data/notebooks/Import_Table_ACLs_delta.py b/data/notebooks/Import_Table_ACLs_delta.py
index bb906e03..66d0422a 100644
--- a/data/notebooks/Import_Table_ACLs_delta.py
+++ b/data/notebooks/Import_Table_ACLs_delta.py
@@ -135,7 +135,7 @@ def execute_sql_statements(sqls):
     if sql:
       print(f"{sql};")
       try:
-# spark.sql(sql)
+        spark.sql(sql)
         num_sucessfully_executed = num_sucessfully_executed+1
       except:
         error_causing_sqls.append({'sql': sql, 'error': sys.exc_info()})
diff --git a/dbclient/HiveClient.py b/dbclient/HiveClient.py
index fa0adec2..a65ccd1f 100644
--- a/dbclient/HiveClient.py
+++ b/dbclient/HiveClient.py
@@ -10,6 +10,8 @@ import logging_utils
 import re
 
 from dbclient import *
+from collections import defaultdict
+from dbclient.common.ViewSort import create_dependency_graph, sort_views_topology, unpack_view_db_name
 
 
 class HiveClient(ClustersClient):
@@ -361,7 +363,7 @@ def move_table_view(self, db_name, tbl_name, local_table_ddl, views_dir='metasto
         return False
 
     def import_hive_metastore(self, cluster_name=None, metastore_dir='metastore/', views_dir='metastore_views/',
-                              has_unicode=False, should_repair_table=False):
+                              has_unicode=False, should_repair_table=False, sort_views=False):
         metastore_local_dir = self.get_export_dir() + metastore_dir
         metastore_view_dir = self.get_export_dir() + views_dir
         error_logger = logging_utils.get_error_logger(
@@ -413,21 +415,52 @@
                 logging.error("Error: Only databases should exist at this level: {0}".format(db_name))
             self.delete_dir_if_empty(metastore_view_dir + db_name)
         views_db_list = self.listdir(metastore_view_dir)
-        for db_name in views_db_list:
-            local_view_db_path = metastore_view_dir + db_name
-            database_attributes = all_db_details_json.get(db_name, '')
-            db_path = database_attributes.get('Location')
-            if os.path.isdir(local_view_db_path):
-                views = self.listdir(local_view_db_path)
-                for view_name in views:
-                    full_view_name = f'{db_name}.{view_name}'
-                    if not checkpoint_metastore_set.contains(full_view_name):
-                        logging.info(f"Importing view {full_view_name}")
-                        local_view_ddl = metastore_view_dir + db_name + '/' + view_name
-                        resp = self.apply_table_ddl(local_view_ddl, ec_id, cid, db_path, has_unicode)
-                        if logging_utils.log_response_error(error_logger, resp):
-                            checkpoint_metastore_set.write(full_view_name)
-                        logging.info(resp)
+
+        if sort_views:
+            # To sort views, scan and collect all the views first
+            all_view_set = set()
+            for db_name in views_db_list:
+                local_view_db_path = metastore_view_dir + db_name
+                if os.path.isdir(local_view_db_path):
+                    views = self.listdir(local_view_db_path)
+                    for v in views:
+                        all_view_set.add(f"{db_name}.{v}")
+            logging.info(f"all views: {all_view_set}")
+            # Build dependency graph of the views
+            view_parents_dct = create_dependency_graph(metastore_view_dir, all_view_set)
+            logging.info(f"view graph: {view_parents_dct}")
+            # Sort the views using the dependency graph
+            sorted_views = sort_views_topology(view_parents_dct)
+            logging.info(f"Importing order of views: {sorted_views}")
+            # Import views in the sorted order
+            for full_view_name in sorted_views:
+                if not checkpoint_metastore_set.contains(full_view_name):
+                    logging.info(f"Importing view {full_view_name}")
+                    db_name, view_name = unpack_view_db_name(full_view_name)
+                    database_attributes = all_db_details_json.get(db_name, '')
+                    db_path = database_attributes.get('Location')
+                    local_view_ddl = metastore_view_dir + db_name + '/' + view_name
+                    resp = self.apply_table_ddl(local_view_ddl, ec_id, cid, db_path, has_unicode)
+                    if logging_utils.log_response_error(error_logger, resp):
+                        checkpoint_metastore_set.write(full_view_name)
+                    logging.info(resp)
+
+        else:
+            for db_name in views_db_list:
+                local_view_db_path = metastore_view_dir + db_name
+                database_attributes = all_db_details_json.get(db_name, '')
+                db_path = database_attributes.get('Location')
+                if os.path.isdir(local_view_db_path):
+                    views = self.listdir(local_view_db_path)
+                    for view_name in views:
+                        full_view_name = f'{db_name}.{view_name}'
+                        if not checkpoint_metastore_set.contains(full_view_name):
+                            logging.info(f"Importing view {full_view_name}")
+                            local_view_ddl = metastore_view_dir + db_name + '/' + view_name
+                            resp = self.apply_table_ddl(local_view_ddl, ec_id, cid, db_path, has_unicode)
+                            if logging_utils.log_response_error(error_logger, resp):
+                                checkpoint_metastore_set.write(full_view_name)
+                            logging.info(resp)
 
         # repair legacy tables
         if should_repair_table:
diff --git a/dbclient/common/ViewSort.py b/dbclient/common/ViewSort.py
new file mode 100644
index 00000000..b31af381
--- /dev/null
+++ b/dbclient/common/ViewSort.py
@@ -0,0 +1,68 @@
+from collections import deque
+import sqlparse
+from typing import Set, List
+from collections import defaultdict
+import os, re
+
+
+def extract_source_tables(ddl_query: str, all_valid_names: Set[str]):
+    """
+    Extracts source table names from a SQL query, including those in nested FROM clauses.
+    Returns the set of unique table names found in the query.
+    """
+    sql_query = ddl_query.replace("`", "")
+    table_names = set()
+    regex = r'\b(?:FROM|JOIN|UNION)\b\s+([\w.]+)'
+    matches = re.findall(regex, sql_query)
+    for match in matches:
+        table_name = match.lower()
+        if ((all_valid_names and table_name in all_valid_names) or (not all_valid_names)) \
+                and table_name not in table_names:
+            table_names.add(table_name)
+    return table_names
+
+def unpack_view_db_name(full_view_name: str):
+    parts = full_view_name.split(".")
+    assert len(parts) == 2, f"{full_view_name} is not formatted correctly."
+    return parts[0], parts[1]
+
+def get_view_dependencies(metastore_view_dir: str, full_view_name: str, all_views: Set[str]):
+    print(f"processing dependencies of {full_view_name}")
+    db_name, vw = unpack_view_db_name(full_view_name)
+    # ddl_query = spark.sql(f"show create table {view_name}").collect()[0][0]
+    ddl_full_path = os.path.join(metastore_view_dir, db_name, vw)
+    dep_set = set()
+    with open(ddl_full_path, "r") as f:
+        ddl_query = f.read()
+        identifiers = extract_source_tables(ddl_query, all_views)
+        for token in identifiers:
+            if full_view_name.lower() in token.lower():
+                continue
+            dep_set.add(token)
+    print(f"dependencies: {dep_set}")
+    return dep_set
+
+def create_dependency_graph(metastore_view_dir: str, all_views: Set[str]):
+    view_parents_dct = dict()
+    for view_name in all_views:
+        dep_views = get_view_dependencies(metastore_view_dir, view_name, all_views)
+        view_parents_dct[view_name] = dep_views
+    return view_parents_dct
+
+def sort_views_topology(view_parents_dct):
+    view_children_dct = defaultdict(set)
+    q = deque([])
+    for view, parents in view_parents_dct.items():
+        for pview in parents:
+            view_children_dct[pview].add(view)
+        if not parents:
+            q.append(view)
+    sorted_views = []
+    while q:
+        cur_view = q.popleft()
+        sorted_views.append(cur_view)
+        for child_view in view_children_dct[cur_view]:
+            view_parents_dct[child_view].remove(cur_view)
+            if not view_parents_dct[child_view]:
+                q.append(child_view)
+    return sorted_views
\ No newline at end of file
diff --git a/dbclient/parser.py b/dbclient/parser.py
index c478e717..c4071822 100644
--- a/dbclient/parser.py
+++ b/dbclient/parser.py
@@ -390,6 +390,9 @@ def get_import_parser():
     parser.add_argument('--retry-backoff', type=float, default=1.0,
                         help='Backoff factor to apply between retry attempts when making calls to Databricks API')
 
+    parser.add_argument('--sort-views', action='store_true', default=False,
+                        help='If set, views will be sorted based on their dependencies before importing.')
+
     return parser
 
 
@@ -568,4 +571,7 @@ def get_pipeline_parser() -> argparse.ArgumentParser:
     parser.add_argument('--last-session', action='store', default='',
                         help='If set, the script compares current sesssion with the last session and only import updated and new notebooks.')
 
+    parser.add_argument('--sort-views', action='store_true', default=False,
+                        help='If set, views will be sorted based on their dependencies before importing.')
+
     return parser
diff --git a/tasks/tasks.py b/tasks/tasks.py
index f91caa9e..593b10ef 100644
--- a/tasks/tasks.py
+++ b/tasks/tasks.py
@@ -315,7 +315,8 @@ def run(self):
         # log job configs
         hive_c.import_hive_metastore(cluster_name=self.args.cluster_name,
                                      has_unicode=self.args.metastore_unicode,
-                                     should_repair_table=self.args.repair_metastore_tables)
+                                     should_repair_table=self.args.repair_metastore_tables,
+                                     sort_views=self.args.sort_views)
 
 
 class MetastoreTableACLExportTask(AbstractTask):
diff --git a/test/view_sort_test.py b/test/view_sort_test.py
new file mode 100644
index 00000000..03216e00
--- /dev/null
+++ b/test/view_sort_test.py
@@ -0,0 +1,81 @@
+import unittest
+from unittest.mock import MagicMock
+from dbclient import HiveClient
+from dbclient.test.TestUtils import TEST_CONFIG
+from io import StringIO
+from dbclient.common.ViewSort import sort_views_topology, get_view_dependencies
+from unittest import mock
+
+class TestViews(unittest.TestCase):
+    def test_sort_views_topology(self):
+        view_parents_graph = {
+            "view1": {"view2", "view3"},
+            "view3": {"view4"},
+            "view2": {},
+            "view4": {"view5", "view6"},
"view6"}, + "view5": {}, + "view6": {}, + "view7": {} + } + views = sort_views_topology(view_parents_graph) + assert views.index("view1") > views.index("view2") and views.index("view1") > views.index("view3") \ + and views.index("view3") > views.index("view4") \ + and views.index("view4") > views.index("view5") and views.index("view4") > views.index("view6") + + def test_get_view_dependencies(self): + view_ddl = """ + CREATE VIEW `default`.`test_view` ( + first_name, + middle_name, + last_name, + relationship_type_cd, + receipt_number) +TBLPROPERTIES ( + 'transient_lastDdlTime' = '1674499157') +AS SELECT + p.first_name AS first_name, + p.middle_name AS middle_name, + p.last_name AS last_name, + pc.role_id AS relationship_type_cd, + pc.receipt_number AS receipt_number + FROM `db1`.`persons` pc + JOIN `db2`.`person` p + ON pc.person_id = p.person_id + AND pc.svr_ctr_cd = p.svr_ctr_cd + WHERE + pc.role_id = 11 + AND (p.first_name is not null or p.middle_name is not null or p.first_name is not null ) + """ + mock_open = mock.mock_open(read_data=view_ddl) + with mock.patch("builtins.open", mock_open): + deps = get_view_dependencies("/tmp/metastore_view", "default.test_view", {}) + assert deps == set(["db1.persons", "db2.person"]) + + + def test_get_view_deps_nested(self): + view_ddl = """ + CREATE VIEW test.view1 ( + step_rank, + same_step_instance, + id, + t_cd) +AS SELECT ROW_NUMBER() OVER (PARTITION BYID ORDER BY st_cd_start_date) AS step_rank, + ROW_NUMBER() OVER (PARTITION BY id, st_cd ORDER BY st_cd_start_date) AS same_step_instance, + id, + st_cd, + st_cd_start_date, + st_cd_end_date, + datediff(st_cd_end_date, st_cd_start_date) AS step_duration + FROM ( + SELECT id, st_cd, st_cd_start_date + FROM ( + SELECT id, NVL(st_cd, 'Null') AS st_cd + FROM test.view2 ch + ) aa + WHERE Is_Boundry = 1) bb +WHERE st_cd_start_date IS NOT NULL + """ + mock_open = mock.mock_open(read_data=view_ddl) + with mock.patch("builtins.open", mock_open): + deps = get_view_dependencies("/tmp/metastore_view", "tdss.case_actn_hist_st_cd_instances", {}) + assert len(deps) == 1 and next(iter(deps)) == "test.view2" \ No newline at end of file