Skip to content

Commit

Permalink
Added baseline for getting Azure Resource Role Assignments (#764)
Browse files Browse the repository at this point in the history
This pull request adds a new command `save-azure-storage-accounts` to
the UCX CLI tool. This command identifies all storage accounts used by
tables in a HMS metastore, identifies the corresponding service
principals and their permissions on each storage account, and saves the
data in a CSV file in the workspace. The new method
`AzureResourcePermissions.save_spn_permissions` performs this
functionality.

Additionally, the pull request adds the `AzureResources` class, which is
used to fetch information about Azure resources such as subscriptions,
storage accounts, and containers. The `AzureRoleAssignment` and
`Principal` classes are also added to represent role assignments and
their associated principals.

Changes related to tests include:

* Test cases are added for the new `save-azure-storage-accounts`
command, including tests for valid and invalid subscription IDs, as well
as tests for cases where there are no external tables, no Azure storage
accounts, or no valid Azure storage accounts.
* The test case for `test_azure_spn_info_without_secret` is updated to
use the `create_autospec` function to create a mock `WorkspaceClient`
object instead of creating a mock object manually.
* The `test_move` function is updated to use the `patch` decorator to
patch the `TableMove.move_tables` method instead of using a mock object.
* The `test_save_azure_storage_accounts_no_ucx` test case is added to
test the behavior when UCX is not installed.
* The `test_save_azure_storage_accounts_not_azure` test case is added to
test the behavior when the workspace is not on Azure.
* The `test_save_azure_storage_accounts_no_azure_cli` test case is added
to test the behavior when the Azure CLI authentication method is not
used.
* The `test_save_azure_storage_accounts_no_subscription_id` test case is
added to test the behavior when the subscription ID is not provided.

---------

Co-authored-by: Hari Selvarajan <hari.selvarajan@databricks.com>
  • Loading branch information
nfx and HariGS-DB authored Jan 17, 2024
1 parent 369893a commit 306e8aa
Show file tree
Hide file tree
Showing 6 changed files with 796 additions and 4 deletions.
8 changes: 7 additions & 1 deletion labs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ commands:
- name: delete_managed
description: Revert and delete managed tables


- name: move
description: move tables across schema/catalog within a UC metastore
flags:
Expand All @@ -80,3 +79,10 @@ commands:
description: target catalog to migrate schema to
- name: to-schema
description: target schema to migrate tables to

- name: save-azure-storage-accounts
description: Identifies all storage accounts used by tables, and identifies each SPN and its permissions on each storage account
flags:
- name: subscription-id
description: Subscription to scan storage accounts in

309 changes: 309 additions & 0 deletions src/databricks/labs/ucx/assessment/azure.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
import base64
import csv
import dataclasses
import io
import json
import re
from collections.abc import Iterable
from dataclasses import dataclass

from databricks.sdk import WorkspaceClient
from databricks.sdk.core import (
ApiClient,
AzureCliTokenSource,
Config,
credentials_provider,
)
from databricks.sdk.errors import NotFound
from databricks.sdk.service.catalog import Privilege
from databricks.sdk.service.compute import ClusterSource, Policy
from databricks.sdk.service.workspace import ImportFormat

from databricks.labs.ucx.assessment.crawlers import (
_CLIENT_ENDPOINT_LENGTH,
Expand All @@ -17,6 +28,7 @@
logger,
)
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
from databricks.labs.ucx.hive_metastore.locations import ExternalLocations


@dataclass
Expand Down Expand Up @@ -249,3 +261,300 @@ def snapshot(self) -> Iterable[AzureServicePrincipalInfo]:
def _try_fetch(self) -> Iterable[AzureServicePrincipalInfo]:
    """Yield previously crawled service-principal rows from the inventory table.

    Each row of the `{schema}.{table}` inventory table is unpacked positionally
    into an AzureServicePrincipalInfo, so column order must match the dataclass
    field order.
    """
    for row in self._fetch(f"SELECT * FROM {self._schema}.{self._table}"):
        yield AzureServicePrincipalInfo(*row)


@dataclass
class AzureSubscription:
    """An Azure subscription as returned by the Resource Manager subscriptions API."""

    name: str  # human-readable display name ("displayName")
    subscription_id: str  # subscription GUID ("subscriptionId")
    tenant_id: str  # AAD tenant GUID that owns the subscription ("tenantId")


class AzureResource:
    """A parsed Azure resource ID.

    Azure resource IDs are slash-separated paths of alternating key/value
    segments, e.g.::

        /subscriptions/<id>/resourceGroups/<rg>/providers/Microsoft.Storage/storageAccounts/<name>

    The well-known segments are exposed as properties; any segment that is
    absent yields ``None``. Equality and hashing are based on the raw
    resource ID string.
    """

    def __init__(self, resource_id: str):
        """Parse *resource_id* into key/value pairs.

        Raises:
            ValueError: if the path does not consist of an even number of
                segments (i.e. it is not a list of key/value pairs).
        """
        self._resource_id = resource_id
        segments = resource_id.lstrip("/").split("/")
        if len(segments) % 2 != 0:
            msg = f"not a list of pairs: {resource_id}"
            raise ValueError(msg)
        # consecutive segments form key/value pairs; a later duplicate key wins,
        # matching the behavior of inserting pairs into a dict in path order
        self._pairs = dict(zip(segments[::2], segments[1::2]))

    @property
    def subscription_id(self):
        """Subscription GUID, or None if not present in the path."""
        return self._pairs.get("subscriptions")

    @property
    def resource_group(self):
        """Resource group name, or None if not present in the path."""
        return self._pairs.get("resourceGroups")

    @property
    def storage_account(self):
        """Storage account name, or None if not present in the path."""
        return self._pairs.get("storageAccounts")

    @property
    def container(self):
        """Blob container name, or None if not present in the path."""
        return self._pairs.get("containers")

    def __eq__(self, other):
        if not isinstance(other, AzureResource):
            return NotImplemented
        return self._resource_id == other._resource_id

    def __hash__(self):
        # consistent with __eq__, which compares the raw resource ID; without
        # this, defining __eq__ would make instances unhashable
        return hash(self._resource_id)

    def __repr__(self):
        properties = ["subscription_id", "resource_group", "storage_account", "container"]
        pairs = [f"{_}={getattr(self, _)}" for _ in properties]
        return f'AzureResource<{", ".join(pairs)}>'

    def __str__(self):
        # the original, unmodified resource ID
        return self._resource_id


@dataclass
class Principal:
    """An AAD directory object referenced by a role assignment."""

    client_id: str  # application (client) ID — "appId" in the MS Graph response
    display_name: str  # "displayName" in the MS Graph response
    object_id: str  # directory object ID — "id" in the MS Graph response


@dataclass
class AzureRoleAssignment:
    """A single Azure RBAC role assignment, resolved to its principal and role name."""

    resource: AzureResource  # the resource the assignment was listed for
    scope: AzureResource  # scope the role applies at (may be broader than the resource)
    principal: Principal  # the assigned service principal
    role_name: str  # human-readable role, e.g. "Storage Blob Data Reader"


class AzureResources:
    """Read-only client for Azure Resource Manager (ARM) and Microsoft Graph.

    Authenticates with Azure CLI tokens (one token source per endpoint) and is
    used to enumerate subscriptions, storage accounts, blob containers, and
    role assignments, resolving role-assignment principals via MS Graph.
    """

    def __init__(self, ws: WorkspaceClient, *, include_subscriptions=None):
        # NOTE(review): subscriptions() only yields subscriptions whose ID is in
        # include_subscriptions, so an empty list means *no* subscriptions are
        # scanned (not "all") — confirm callers always pass the IDs they want.
        if not include_subscriptions:
            include_subscriptions = []
        rm_host = ws.config.arm_environment.resource_manager_endpoint
        # management-plane client: subscriptions, resources, role assignments
        self._resource_manager = ApiClient(
            Config(
                host=rm_host,
                credentials_provider=self._provider_for(ws.config.arm_environment.service_management_endpoint),
            )
        )
        # MS Graph client: resolves principal IDs to directory objects
        self._graph = ApiClient(
            Config(
                host="https://graph.microsoft.com",
                credentials_provider=self._provider_for("https://graph.microsoft.com"),
            )
        )
        # used only to read the tenant ID ("tid") out of the CLI token's JWT claims
        self._token_source = AzureCliTokenSource(rm_host)
        self._include_subscriptions = include_subscriptions
        # cache: roleDefinitionId -> role name, to avoid re-fetching definitions
        self._role_definitions = {}  # type: dict[str, str]
        # cache: principalId -> Principal; None records a NotFound (negative cache)
        self._principals: dict[str, Principal | None] = {}

    def _provider_for(self, endpoint: str):
        """Build an SDK credentials provider that bearer-authenticates with an Azure CLI token for *endpoint*."""

        @credentials_provider("azure-cli", ["host"])
        def _credentials(_: Config):
            token_source = AzureCliTokenSource(endpoint)

            def inner() -> dict[str, str]:
                # fetched per request, so the CLI token can be refreshed when it expires
                token = token_source.token()
                return {"Authorization": f"{token.token_type} {token.access_token}"}

            return inner

        return _credentials

    def _get_subscriptions(self) -> Iterable[AzureSubscription]:
        """Yield every subscription visible to the CLI credentials (unfiltered)."""
        for subscription in self._get_resource("/subscriptions", api_version="2022-12-01").get("value", []):
            yield AzureSubscription(
                name=subscription["displayName"],
                subscription_id=subscription["subscriptionId"],
                tenant_id=subscription["tenantId"],
            )

    def _tenant_id(self):
        """Return the AAD tenant ID ("tid" claim) of the current CLI token."""
        token = self._token_source.token()
        return token.jwt_claims().get("tid")

    def subscriptions(self):
        """Yield subscriptions that belong to the current tenant AND were explicitly included."""
        tenant_id = self._tenant_id()
        for subscription in self._get_subscriptions():
            if subscription.tenant_id != tenant_id:
                continue
            # membership check against the include-list; empty list yields nothing
            if subscription.subscription_id not in self._include_subscriptions:
                continue
            yield subscription

    def _get_resource(self, path: str, api_version: str):
        """GET an ARM resource at *path* with the given api-version, returning parsed JSON."""
        headers = {"Accept": "application/json"}
        query = {"api-version": api_version}
        return self._resource_manager.do("GET", path, query=query, headers=headers)

    def storage_accounts(self) -> Iterable[AzureResource]:
        """Yield every storage account in each included subscription."""
        for subscription in self.subscriptions():
            logger.info(f"Checking in subscription {subscription.name} for storage accounts")
            path = f"/subscriptions/{subscription.subscription_id}/providers/Microsoft.Storage/storageAccounts"
            for storage in self._get_resource(path, "2023-01-01").get("value", []):
                resource_id = storage.get("id")
                if not resource_id:
                    continue
                yield AzureResource(resource_id)

    def containers(self, storage: AzureResource):
        """Yield every blob container under *storage* (its default blob service)."""
        for raw in self._get_resource(f"{storage}/blobServices/default/containers", "2023-01-01").get("value", []):
            resource_id = raw.get("id")
            if not resource_id:
                continue
            yield AzureResource(resource_id)

    def _get_principal(self, principal_id: str) -> Principal | None:
        """Resolve a directory object via MS Graph, caching hits and misses.

        Returns None (and caches None) when the object is not found, so the
        same missing principal is never looked up twice.
        """
        if principal_id in self._principals:
            return self._principals[principal_id]
        try:
            path = f"/v1.0/directoryObjects/{principal_id}"
            raw: dict[str, str] = self._graph.do("GET", path)  # type: ignore[assignment]
            client_id = raw.get("appId")
            display_name = raw.get("displayName")
            object_id = raw.get("id")
            # all three fields are required to build a Principal
            assert client_id is not None
            assert display_name is not None
            assert object_id is not None
            self._principals[principal_id] = Principal(client_id, display_name, object_id)
            return self._principals[principal_id]
        except NotFound:
            # don't load principals from external directories twice
            self._principals[principal_id] = None
            return self._principals[principal_id]

    def role_assignments(
        self, resource_id: str, *, principal_types: list[str] | None = None
    ) -> Iterable[AzureRoleAssignment]:
        """See https://learn.microsoft.com/en-us/rest/api/authorization/role-assignments/list-for-resource

        Yields role assignments on *resource_id* whose principal type is in
        *principal_types* (default: service principals only), skipping entries
        with missing fields or unresolvable principals/role definitions.
        """
        if not principal_types:
            principal_types = ["ServicePrincipal"]
        result = self._get_resource(f"{resource_id}/providers/Microsoft.Authorization/roleAssignments", "2022-04-01")
        for role_assignment in result.get("value", []):
            assignment_properties = role_assignment.get("properties", {})
            principal_type = assignment_properties.get("principalType")
            if not principal_type:
                continue
            if principal_type not in principal_types:
                continue
            principal_id = assignment_properties.get("principalId")
            if not principal_id:
                continue
            role_definition_id = assignment_properties.get("roleDefinitionId")
            if not role_definition_id:
                continue
            scope = assignment_properties.get("scope")
            if not scope:
                continue
            # resolve and cache the role definition's display name
            if role_definition_id not in self._role_definitions:
                role_definition = self._get_resource(role_definition_id, "2022-04-01")
                definition_properties = role_definition.get("properties", {})
                role_name: str = definition_properties.get("roleName")
                if not role_name:
                    continue
                self._role_definitions[role_definition_id] = role_name
            principal = self._get_principal(principal_id)
            if not principal:
                continue
            role_name = self._role_definitions[role_definition_id]
            # normalize a root ("/") scope to the resource that was queried
            if scope == "/":
                scope = resource_id
            yield AzureRoleAssignment(
                resource=AzureResource(resource_id),
                scope=AzureResource(scope),
                principal=principal,
                role_name=role_name,
            )


@dataclass
class StoragePermissionMapping:
    """One CSV row mapping a container URL prefix to a service principal's privilege."""

    prefix: str  # abfss:// URL prefix of the container
    client_id: str  # application (client) ID of the service principal
    principal: str  # display name of the service principal
    privilege: str  # Privilege value, e.g. "WRITE_FILES" or "READ_FILES"


class AzureResourcePermissions:
    """Maps storage accounts used by external tables to service-principal permissions.

    Finds the storage accounts referenced by abfss:// external locations in the
    HMS metastore, looks up RBAC role assignments on their containers via
    AzureResources, and saves the resulting mapping as a CSV file in the
    workspace.
    """

    def __init__(self, ws: WorkspaceClient, azurerm: AzureResources, lc: ExternalLocations, folder: str | None = None):
        self._locations = lc
        self._azurerm = azurerm
        self._ws = ws
        # CSV header, derived from the dataclass so the two never drift apart
        self._field_names = [_.name for _ in dataclasses.fields(StoragePermissionMapping)]
        if not folder:
            folder = f"/Users/{ws.current_user.me().user_name}/.ucx"
        self._folder = folder
        # role name -> UC privilege; several roles map to the same privilege
        self._levels = {
            "Storage Blob Data Contributor": Privilege.WRITE_FILES,
            "Storage Blob Data Owner": Privilege.WRITE_FILES,
            "Storage Blob Data Reader": Privilege.READ_FILES,
        }

    def _map_storage(self, storage: AzureResource) -> list[StoragePermissionMapping]:
        """Return permission mappings for every container of *storage*.

        Only role assignments whose role maps to a known UC privilege are kept.
        """
        logger.info(f"Fetching role assignment for {storage.storage_account}")
        out = []
        for container in self._azurerm.containers(storage):
            for role_assignment in self._azurerm.role_assignments(str(container)):
                # one principal may be assigned multiple roles with overlapping dataActions, hence appearing
                # here in duplicates. hence, role name -> permission level is not enough for the perfect scenario.
                if role_assignment.role_name not in self._levels:
                    continue
                privilege = self._levels[role_assignment.role_name].value
                out.append(
                    StoragePermissionMapping(
                        prefix=f"abfss://{container.container}@{container.storage_account}.dfs.core.windows.net/",
                        client_id=role_assignment.principal.client_id,
                        principal=role_assignment.principal.display_name,
                        privilege=privilege,
                    )
                )
        return out

    def save_spn_permissions(self) -> str | None:
        """Crawl permissions for all storage accounts in use and save them as CSV.

        Returns the workspace path of the saved file, or None when there is
        nothing to save (no external tables, or no matching role assignments).
        """
        used_storage_accounts = self._get_storage_accounts()
        if len(used_storage_accounts) == 0:
            logger.warning(
                "There are no external table present with azure storage account. "
                "Please check if assessment job is run"
            )
            return None
        storage_account_infos = []
        for storage in self._azurerm.storage_accounts():
            if storage.storage_account not in used_storage_accounts:
                continue
            storage_account_infos.extend(self._map_storage(storage))
        if len(storage_account_infos) == 0:
            logger.error("No storage account found in current tenant with spn permission")
            return None
        return self._save(storage_account_infos)

    def _save(self, storage_infos: list[StoragePermissionMapping]) -> str:
        """Serialize the mappings to CSV (header + one row each) and upload to the workspace."""
        buffer = io.StringIO()
        writer = csv.DictWriter(buffer, self._field_names)
        writer.writeheader()
        for storage_info in storage_infos:
            writer.writerow(dataclasses.asdict(storage_info))
        buffer.seek(0)
        return self._overwrite_mapping(buffer)

    def _overwrite_mapping(self, buffer) -> str:
        """Upload *buffer* to the fixed CSV path under the UCX folder, replacing any existing file."""
        path = f"{self._folder}/azure_storage_account_info.csv"
        self._ws.workspace.upload(path, buffer, overwrite=True, format=ImportFormat.AUTO)
        return path

    def _get_storage_accounts(self) -> list[str]:
        """Return the distinct storage account names used by abfss:// external locations.

        Order follows first occurrence in the snapshot. Malformed abfss URLs
        (missing '@' or '.dfs.core.windows.net') are skipped instead of raising
        ValueError as str.index previously did.
        """
        external_locations = self._locations.snapshot()
        storage_accounts: list[str] = []
        for location in external_locations:
            url = location.location
            if not url.startswith("abfss://"):
                continue
            # abfss://<container>@<account>.dfs.core.windows.net/<path>
            start = url.find("@")
            end = url.find(".dfs.core.windows.net")
            if start == -1 or end == -1 or end <= start:
                continue  # malformed location - cannot extract an account name
            storage_acct = url[start + 1 : end]
            if storage_acct not in storage_accounts:
                storage_accounts.append(storage_acct)
        return storage_accounts
30 changes: 30 additions & 0 deletions src/databricks/labs/ucx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from databricks.sdk import AccountClient, WorkspaceClient

from databricks.labs.ucx.account import AccountWorkspaces, WorkspaceInfo
from databricks.labs.ucx.assessment.azure import (
AzureResourcePermissions,
AzureResources,
)
from databricks.labs.ucx.config import AccountConfig, ConnectConfig
from databricks.labs.ucx.framework.crawlers import StatementExecutionBackend
from databricks.labs.ucx.hive_metastore import ExternalLocations, TablesCrawler
Expand Down Expand Up @@ -198,5 +202,31 @@ def move(
tables.move_tables(from_catalog, from_schema, from_table, to_catalog, to_schema, del_table)


@ucx.command
def save_azure_storage_accounts(w: WorkspaceClient, subscription_id: str):
    """Identifies all Azure storage accounts used by external tables,
    identifies every SPN with storage blob reader/contributor/owner access,
    and saves the mapping as a CSV file in the workspace UCX folder.

    Requires a UCX installation, an Azure workspace, azure-cli authentication,
    and a subscription ID to scan.
    """
    installation_manager = InstallationManager(w)
    installation = installation_manager.for_user(w.current_user.me())
    if not installation:
        logger.error(CANT_FIND_UCX_MSG)
        return
    if not w.config.is_azure:
        logger.error("Workspace is not on azure, please run this command on azure databricks workspaces.")
        return
    if w.config.auth_type != "azure_cli":
        logger.error("In order to obtain AAD token, Please run azure cli to authenticate.")
        return
    if subscription_id == "":
        logger.error("Please enter subscription id to scan storage account in.")
        return
    sql_backend = StatementExecutionBackend(w, installation.config.warehouse_id)
    location = ExternalLocations(w, sql_backend, installation.config.inventory_database)
    # Fix: subscription_id was validated but never used. AzureResources only
    # scans subscriptions listed in include_subscriptions (an empty list scans
    # nothing), so the ID must be passed through for any accounts to be found.
    azurerm = AzureResources(w, include_subscriptions=[subscription_id])
    azure_resource_permissions = AzureResourcePermissions(w, azurerm, location)
    logger.info("Generating azure storage accounts and service principal permission info")
    azure_resource_permissions.save_spn_permissions()


if "__main__" == __name__:
ucx()
Loading

0 comments on commit 306e8aa

Please sign in to comment.