Skip to content

Commit

Permalink
Added Legacy Table ACL grants migration (#1054)
Browse files Browse the repository at this point in the history
## Changes
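<!-- Summary of your changes that are easy to understand. Add
screenshots when necessary -->
`TablesMigrate.migrate_tables` accepts a new optional `acl_strategy` argument (`AclMigrationWhat`). With `AclMigrationWhat.LEGACY_TACL`, the grants from `GrantsCrawler.snapshot()` that match each table are re-issued against the Unity Catalog target after the table is migrated, with workspace group principals replaced by their account-level group names. `AclMigrationWhat.PRINCIPAL` is declared but not implemented yet. The `migrate-tables` workflow tasks now construct `TablesMigrate` with a `GrantsCrawler` and a `GroupManager`, but do not pass `acl_strategy` yet.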

### Linked issues
<!-- DOC: Link issue with a keyword: close, closes, closed, fix, fixes,
fixed, resolve, resolves, resolved. See
https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword
-->

Resolves #340
To be followed up with PRs for #887 and #907.

### Functionality 

- [x] modified existing workflow: `migrate-tables`

### Tests
<!-- How is this tested? Please see the checklist below and also
describe any other relevant tests -->

- [x] manually tested
- [x] added unit tests
- [x] added integration tests
- [x] verified on staging environment (screenshot attached)
  • Loading branch information
nkvuong authored Mar 21, 2024
1 parent 6ab4ae2 commit a1014a4
Show file tree
Hide file tree
Showing 7 changed files with 372 additions and 74 deletions.
3 changes: 2 additions & 1 deletion src/databricks/labs/ucx/hive_metastore/mapping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import re
from collections.abc import Collection
from dataclasses import dataclass
from functools import partial

Expand Down Expand Up @@ -118,7 +119,7 @@ def skip_schema(self, schema: str):
except BadRequest as err:
logger.error(err)

def get_tables_to_migrate(self, tables_crawler: TablesCrawler):
def get_tables_to_migrate(self, tables_crawler: TablesCrawler) -> Collection[TableToMigrate]:
rules = self.load()
# Getting all the source tables from the rules
databases_in_scope = self._get_databases_in_scope({rule.src_schema for rule in rules})
Expand Down
94 changes: 75 additions & 19 deletions src/databricks/labs/ucx/hive_metastore/table_migrate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import datetime
import logging
from collections import defaultdict
Expand All @@ -13,9 +14,17 @@
from databricks.labs.ucx.config import WorkspaceConfig
from databricks.labs.ucx.framework.crawlers import CrawlerBase
from databricks.labs.ucx.framework.utils import escape_sql_identifier
from databricks.labs.ucx.hive_metastore import TablesCrawler
from databricks.labs.ucx.hive_metastore import GrantsCrawler, TablesCrawler
from databricks.labs.ucx.hive_metastore.grants import Grant
from databricks.labs.ucx.hive_metastore.mapping import Rule, TableMapping
from databricks.labs.ucx.hive_metastore.tables import MigrationCount, Table, What
from databricks.labs.ucx.hive_metastore.tables import (
AclMigrationWhat,
MigrationCount,
Table,
What,
)
from databricks.labs.ucx.hive_metastore.udfs import UdfsCrawler
from databricks.labs.ucx.workspace_access.groups import GroupManager, MigratedGroup

logger = logging.getLogger(__name__)

Expand All @@ -33,16 +42,20 @@ class MigrationStatus:
class TablesMigrate:
def __init__(
self,
tables_crawler: TablesCrawler,
table_crawler: TablesCrawler,
grant_crawler: GrantsCrawler,
ws: WorkspaceClient,
backend: SqlBackend,
table_mapping: TableMapping,
group_manager: GroupManager,
migration_status_refresher,
):
self._tc = tables_crawler
self._tc = table_crawler
self._gc = grant_crawler
self._backend = backend
self._ws = ws
self._tm = table_mapping
self._group = group_manager
self._migration_status_refresher = migration_status_refresher
self._seen_tables: dict[str, str] = {}

Expand All @@ -52,33 +65,50 @@ def for_cli(cls, ws: WorkspaceClient, product='ucx'):
config = installation.load(WorkspaceConfig)
sql_backend = StatementExecutionBackend(ws, config.warehouse_id)
table_crawler = TablesCrawler(sql_backend, config.inventory_database)
udfs_crawler = UdfsCrawler(sql_backend, config.inventory_database)
grants_crawler = GrantsCrawler(table_crawler, udfs_crawler)
table_mapping = TableMapping(installation, ws, sql_backend)
group_manager = GroupManager(sql_backend, ws, config.inventory_database)
migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, config.inventory_database, table_crawler)
return cls(table_crawler, ws, sql_backend, table_mapping, migration_status_refresher)
return cls(
table_crawler, grants_crawler, ws, sql_backend, table_mapping, group_manager, migration_status_refresher
)

def migrate_tables(self, *, what: What | None = None):
def migrate_tables(self, *, what: What | None = None, acl_strategy: AclMigrationWhat | None = None):
self._init_seen_tables()
tables_to_migrate = self._tm.get_tables_to_migrate(self._tc)
if acl_strategy is not None:
grants_to_migrate = self._gc.snapshot()
migrated_groups = self._group.snapshot()
tasks = []
for table in tables_to_migrate:
if not what or table.src.what == what:
tasks.append(partial(self._migrate_table, table.src, table.rule))
if what is not None and table.src.what != what:
continue
match acl_strategy:
case None:
tasks.append(partial(self._migrate_table, table.src, table.rule))
case AclMigrationWhat.LEGACY_TACL:
grants = self._match_grants(table.src, grants_to_migrate, migrated_groups)
tasks.append(partial(self._migrate_table, table.src, table.rule, grants))
case AclMigrationWhat.PRINCIPAL:
# TODO: Implement principal-based ACL migration. Until then no task is queued for this strategy, so the table itself is not migrated either.
pass
Threads.strict("migrate tables", tasks)

def _migrate_table(self, src_table: Table, rule: Rule):
def _migrate_table(self, src_table: Table, rule: Rule, grants: list[Grant] | None = None):
if self._table_already_upgraded(rule.as_uc_table_key):
logger.info(f"Table {src_table.key} already upgraded to {rule.as_uc_table_key}")
return True
if src_table.what == What.DBFS_ROOT_DELTA:
return self._migrate_dbfs_root_table(src_table, rule)
return self._migrate_dbfs_root_table(src_table, rule, grants)
if src_table.what == What.EXTERNAL_SYNC:
return self._migrate_external_table(src_table, rule)
return self._migrate_external_table(src_table, rule, grants)
if src_table.what == What.VIEW:
return self._migrate_view(src_table, rule)
return self._migrate_view(src_table, rule, grants)
logger.info(f"Table {src_table.key} is not supported for migration")
return True

def _migrate_external_table(self, src_table: Table, rule: Rule):
def _migrate_external_table(self, src_table: Table, rule: Rule, grants: list[Grant] | None = None):
target_table_key = rule.as_uc_table_key
table_migrate_sql = src_table.sql_migrate_external(target_table_key)
logger.debug(f"Migrating external table {src_table.key} to using SQL query: {table_migrate_sql}")
Expand All @@ -90,24 +120,36 @@ def _migrate_external_table(self, src_table: Table, rule: Rule):
)
return False
self._backend.execute(src_table.sql_alter_from(rule.as_uc_table_key, self._ws.get_workspace_id()))
return True
return self._migrate_acl(src_table, rule, grants)

def _migrate_dbfs_root_table(self, src_table: Table, rule: Rule):
def _migrate_dbfs_root_table(self, src_table: Table, rule: Rule, grants: list[Grant] | None = None):
target_table_key = rule.as_uc_table_key
table_migrate_sql = src_table.sql_migrate_dbfs(target_table_key)
logger.debug(f"Migrating managed table {src_table.key} to using SQL query: {table_migrate_sql}")
self._backend.execute(table_migrate_sql)
self._backend.execute(src_table.sql_alter_to(rule.as_uc_table_key))
self._backend.execute(src_table.sql_alter_from(rule.as_uc_table_key, self._ws.get_workspace_id()))
return True
return self._migrate_acl(src_table, rule, grants)

def _migrate_view(self, src_table: Table, rule: Rule):
def _migrate_view(self, src_table: Table, rule: Rule, grants: list[Grant] | None = None):
target_table_key = rule.as_uc_table_key
table_migrate_sql = src_table.sql_migrate_view(target_table_key)
logger.debug(f"Migrating view {src_table.key} to using SQL query: {table_migrate_sql}")
self._backend.execute(table_migrate_sql)
self._backend.execute(src_table.sql_alter_to(rule.as_uc_table_key))
self._backend.execute(src_table.sql_alter_from(rule.as_uc_table_key, self._ws.get_workspace_id()))
return self._migrate_acl(src_table, rule, grants)

def _migrate_acl(self, src: Table, rule: Rule, grants: list[Grant] | None):
    """Run the grant statement for every supplied grant against the rule's target key.

    Each grant is asked for its SQL via ``uc_grant_sql(src.kind, rule.as_uc_table_key)``.
    When that call returns ``None`` a warning is logged and the grant is skipped;
    otherwise the statement is executed on the SQL backend.

    :param src: source table; only its ``kind`` is read here.
    :param rule: mapping rule; only its ``as_uc_table_key`` is read here.
    :param grants: grants to apply, or ``None`` when there is nothing to apply.
    :return: always ``True``.
    """
    # `None` and an empty list both mean there is nothing to iterate over.
    for grant in grants or ():
        acl_migrate_sql = grant.uc_grant_sql(src.kind, rule.as_uc_table_key)
        if acl_migrate_sql is not None:
            logger.debug(f"Migrating acls on {rule.as_uc_table_key} using SQL query: {acl_migrate_sql}")
            self._backend.execute(acl_migrate_sql)
        else:
            logger.warning(f"Cannot identify UC grant for {src.kind} {rule.as_uc_table_key}. Skipping.")
    return True

def _table_already_upgraded(self, target) -> bool:
Expand Down Expand Up @@ -192,7 +234,7 @@ def print_revert_report(self, *, delete_managed: bool) -> bool | None:
table_sub_header = " |"
for what in list(What):
if len(what.name.split("_")) - 1 < header:
table_sub_header += f"{' '*12}|"
table_sub_header += f"{' ' * 12}|"
continue
table_sub_header += f" {what.name.split('_')[header]:<10} |"
print(table_sub_header)
Expand All @@ -216,6 +258,20 @@ def print_revert_report(self, *, delete_managed: bool) -> bool | None:
def _init_seen_tables(self):
self._seen_tables = self._migration_status_refresher.get_seen_tables()

@staticmethod
def _match_grants(table: Table, grants: Iterable[Grant], migrated_groups: list[MigratedGroup]) -> list[Grant]:
    """Select the grants that apply to ``table`` and rewrite migrated group principals.

    A grant is kept when its ``database`` equals ``table.database`` and
    ``table.name`` equals either the grant's ``table`` or its ``view``.
    If the grant's ``principal`` equals the ``name_in_workspace`` of a
    migrated group, the returned grant is a copy whose ``principal`` is that
    group's ``name_in_account``; other grants are returned unchanged.

    :param table: the table whose grants are wanted.
    :param grants: candidate grants to filter.
    :param migrated_groups: groups providing the workspace-to-account name mapping.
    :return: matching grants, in the order they were encountered.
    """
    # Build the name lookup once instead of rescanning the group list for every
    # grant. `setdefault` keeps the first group seen for a given workspace name,
    # which is the entry the per-grant scan would have picked.
    account_name_by_workspace_name: dict[str, str] = {}
    for group in migrated_groups:
        account_name_by_workspace_name.setdefault(group.name_in_workspace, group.name_in_account)

    matched_grants = []
    for grant in grants:
        if grant.database != table.database:
            continue
        if table.name not in (grant.table, grant.view):
            continue
        if grant.principal in account_name_by_workspace_name:
            grant = dataclasses.replace(grant, principal=account_name_by_workspace_name[grant.principal])
        matched_grants.append(grant)
    return matched_grants


class MigrationStatusRefresher(CrawlerBase[MigrationStatus]):
def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema, table_crawler: TablesCrawler):
Expand All @@ -241,7 +297,7 @@ def get_seen_tables(self) -> dict[str, str]:
return seen_tables

def is_upgraded(self, schema: str, table: str) -> bool:
result = self._backend.fetch(f"SHOW TBLPROPERTIES {escape_sql_identifier(schema+'.'+table)}")
result = self._backend.fetch(f"SHOW TBLPROPERTIES {escape_sql_identifier(schema + '.' + table)}")
for value in result:
if value["key"] == "upgraded_to":
logger.info(f"{schema}.{table} is set as upgraded")
Expand Down
5 changes: 5 additions & 0 deletions src/databricks/labs/ucx/hive_metastore/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ class What(Enum):
UNKNOWN = auto()


class AclMigrationWhat(Enum):
    """Strategies for carrying access-control grants over during table migration."""

    # Re-apply the grants recorded for the legacy table ACLs.
    LEGACY_TACL = auto()
    # Principal-based strategy.
    PRINCIPAL = auto()


@dataclass
class Table:
catalog: str
Expand Down
18 changes: 12 additions & 6 deletions src/databricks/labs/ucx/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,11 +431,14 @@ def migrate_external_tables_sync(
- For AWS: TBD
"""
table_crawler = TablesCrawler(sql_backend, cfg.inventory_database)
udf_crawler = UdfsCrawler(sql_backend, cfg.inventory_database)
grant_crawler = GrantsCrawler(table_crawler, udf_crawler)
table_mapping = TableMapping(install, ws, sql_backend)
migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, cfg.inventory_database, table_crawler)
TablesMigrate(table_crawler, ws, sql_backend, table_mapping, migration_status_refresher).migrate_tables(
what=What.EXTERNAL_SYNC
)
group_manager = GroupManager(sql_backend, ws, cfg.inventory_database)
TablesMigrate(
table_crawler, grant_crawler, ws, sql_backend, table_mapping, group_manager, migration_status_refresher
).migrate_tables(what=What.EXTERNAL_SYNC)


@task("migrate-tables", job_cluster="table_migration")
Expand All @@ -448,11 +451,14 @@ def migrate_dbfs_root_delta_tables(
- For AWS: TBD
"""
table_crawler = TablesCrawler(sql_backend, cfg.inventory_database)
udf_crawler = UdfsCrawler(sql_backend, cfg.inventory_database)
grant_crawler = GrantsCrawler(table_crawler, udf_crawler)
table_mapping = TableMapping(install, ws, sql_backend)
migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, cfg.inventory_database, table_crawler)
TablesMigrate(table_crawler, ws, sql_backend, table_mapping, migration_status_refresher).migrate_tables(
what=What.DBFS_ROOT_DELTA
)
group_manager = GroupManager(sql_backend, ws, cfg.inventory_database)
TablesMigrate(
table_crawler, grant_crawler, ws, sql_backend, table_mapping, group_manager, migration_status_refresher
).migrate_tables(what=What.DBFS_ROOT_DELTA)


def main(*argv):
Expand Down
25 changes: 24 additions & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
AzureServicePrincipalCrawler,
AzureServicePrincipalInfo,
)
from databricks.labs.ucx.hive_metastore import TablesCrawler
from databricks.labs.ucx.hive_metastore import GrantsCrawler, TablesCrawler
from databricks.labs.ucx.hive_metastore.grants import Grant
from databricks.labs.ucx.hive_metastore.mapping import Rule, TableMapping
from databricks.labs.ucx.hive_metastore.tables import Table
from databricks.labs.ucx.hive_metastore.udfs import Udf, UdfsCrawler
Expand Down Expand Up @@ -147,6 +148,28 @@ def snapshot(self) -> list[Udf]:
return self._udfs


class StaticGrantsCrawler(GrantsCrawler):
    """Grants crawler whose ``snapshot()`` returns a fixed list supplied at construction."""

    def __init__(self, tc: TablesCrawler, udf: UdfsCrawler, grants: list[Grant]):
        """Store a field-by-field copy of every grant in ``grants``.

        :param tc: tables crawler forwarded to the parent constructor.
        :param udf: UDFs crawler forwarded to the parent constructor.
        :param grants: grants that ``snapshot()`` will return (as fresh ``Grant`` objects).
        """
        super().__init__(tc, udf)
        copied_grants = []
        for source_grant in grants:
            # Rebuild each entry as a new Grant from the listed fields.
            copied_grants.append(
                Grant(
                    principal=source_grant.principal,
                    action_type=source_grant.action_type,
                    catalog=source_grant.catalog,
                    database=source_grant.database,
                    table=source_grant.table,
                    view=source_grant.view,
                    udf=source_grant.udf,
                    any_file=source_grant.any_file,
                    anonymous_function=source_grant.anonymous_function,
                )
            )
        self._grants = copied_grants

    def snapshot(self) -> list[Grant]:
        """Return the grants stored at construction time."""
        return self._grants


class StaticTableMapping(TableMapping):
def __init__(self, workspace_client: WorkspaceClient, sb: SqlBackend, rules: list[Rule]):
installation = Installation(workspace_client, 'ucx')
Expand Down
Loading

0 comments on commit a1014a4

Please sign in to comment.