Added assessment for the incompatible RunSubmit API usages (#849)
FastLee authored and nkvuong committed Mar 6, 2024
1 parent 2f09978 commit 0b90300
Showing 25 changed files with 624 additions and 13 deletions.
8 changes: 4 additions & 4 deletions src/databricks/labs/ucx/assessment/clusters.py
@@ -85,7 +85,7 @@ def _check_cluster_init_script(self, init_scripts: list[InitScriptInfo], source: str) -> list[str]:
failures.extend(self.check_init_script(init_script_data, source))
return failures

def check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
def _check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
failures: list[str] = []
for k in INCOMPATIBLE_SPARK_CONFIG_KEYS:
if k in conf:
@@ -98,7 +98,7 @@ def check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
failures.append(f"{AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures

def check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:
def _check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:
failures: list[str] = []

unsupported_cluster_types = [
@@ -110,7 +110,7 @@ def check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:
if support_status != "supported":
failures.append(f"not supported DBR: {cluster.spark_version}")
if cluster.spark_conf is not None:
failures.extend(self.check_spark_conf(cluster.spark_conf, source))
failures.extend(self._check_spark_conf(cluster.spark_conf, source))
# Checking if Azure cluster config is present in cluster policies
if cluster.policy_id is not None:
failures.extend(self._check_cluster_policy(cluster.policy_id, source))
@@ -149,7 +149,7 @@ def _assess_clusters(self, all_clusters):
success=1,
failures="[]",
)
failures = self.check_cluster_failures(cluster, "cluster")
failures = self._check_cluster_failures(cluster, "cluster")
if len(failures) > 0:
cluster_info.success = 0
cluster_info.failures = json.dumps(failures)
228 changes: 226 additions & 2 deletions src/databricks/labs/ucx/assessment/jobs.py
@@ -2,12 +2,27 @@
import logging
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from hashlib import sha256

from databricks.sdk import WorkspaceClient
from databricks.sdk.service import compute
from databricks.sdk.service.compute import ClusterDetails
from databricks.sdk.service.jobs import BaseJob
from databricks.sdk.service.jobs import (
BaseJob,
BaseRun,
DbtTask,
GitSource,
ListRunsRunType,
PythonWheelTask,
RunConditionTask,
RunTask,
SparkJarTask,
SqlTask,
)

from databricks.labs.ucx.assessment.clusters import CheckClusterMixin
from databricks.labs.ucx.assessment.crawlers import spark_version_compatibility
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend

logger = logging.getLogger(__name__)
@@ -63,7 +78,7 @@ def _assess_jobs(self, all_jobs: list[BaseJob], all_clusters_by_id) -> Iterable[JobInfo]:
if not job_id:
continue
cluster_details = ClusterDetails.from_dict(cluster_config.as_dict())
cluster_failures = self.check_cluster_failures(cluster_details, "Job cluster")
cluster_failures = self._check_cluster_failures(cluster_details, "Job cluster")
job_assessment[job_id].update(cluster_failures)

# TODO: next person looking at this - rewrite, as this code makes no sense
@@ -108,3 +123,212 @@ def snapshot(self) -> Iterable[JobInfo]:
def _try_fetch(self) -> Iterable[JobInfo]:
for row in self._fetch(f"SELECT * FROM {self._schema}.{self._table}"):
yield JobInfo(*row)


@dataclass
class SubmitRunInfo:
run_ids: str # JSON-encoded list of run ids
hashed_id: str # a pseudo id that combines all the hashable attributes of the run
failures: str = "[]" # JSON-encoded list of failures
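# Illustrative example only (not part of this commit): a populated row might look like
# SubmitRunInfo(run_ids="[1001, 1002]", hashed_id="9f2c...", failures='["no data security mode specified"]')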


class SubmitRunsCrawler(CrawlerBase[SubmitRunInfo], JobsMixin, CheckClusterMixin):
_FS_LEVEL_CONF_SETTING_PATTERNS = [
"fs.s3a",
"fs.s3n",
"fs.s3",
"fs.azure",
"fs.wasb",
"fs.abfs",
"fs.adl",
]

def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema: str, num_days_history: int):
super().__init__(sbe, "hive_metastore", schema, "submit_runs", SubmitRunInfo)
self._ws = ws
self._num_days_history = num_days_history

def snapshot(self) -> Iterable[SubmitRunInfo]:
return self._snapshot(self._try_fetch, self._crawl)

@staticmethod
def _dt_to_ms(date_time: datetime):
return int(date_time.timestamp() * 1000)

@staticmethod
def _get_current_dttm() -> datetime:
return datetime.now(timezone.utc)

def _crawl(self) -> Iterable[SubmitRunInfo]:
end = self._dt_to_ms(self._get_current_dttm())
start = self._dt_to_ms(self._get_current_dttm() - timedelta(days=self._num_days_history))
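# jobs.list_runs takes epoch-millisecond bounds; only completed SUBMIT_RUN runs started
# within the configured lookback window are fetched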
submit_runs = self._ws.jobs.list_runs(
expand_tasks=True,
completed_only=True,
run_type=ListRunsRunType.SUBMIT_RUN,
start_time_from=start,
start_time_to=end,
)
all_clusters = {c.cluster_id: c for c in self._ws.clusters.list()}
return self._assess_job_runs(submit_runs, all_clusters)

def _try_fetch(self) -> Iterable[SubmitRunInfo]:
for row in self._fetch(f"SELECT * FROM {self._schema}.{self._table}"):
yield SubmitRunInfo(*row)

def _check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
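# flag any Spark conf key with a filesystem-level prefix (s3/azure/abfs/adl/wasb) as
# potentially unsupported, then run the inherited checks (incompatible keys, Azure SP configs)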
failures: list[str] = []
for key in conf.keys():
if any(pattern in key for pattern in self._FS_LEVEL_CONF_SETTING_PATTERNS):
failures.append(f"Potentially unsupported config property: {key}")

failures.extend(super()._check_spark_conf(conf, source))
return failures

def _check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:
failures: list[str] = []
if cluster.aws_attributes and cluster.aws_attributes.instance_profile_arn:
failures.append(f"using instance profile: {cluster.aws_attributes.instance_profile_arn}")

failures.extend(super()._check_cluster_failures(cluster, source))
return failures

@staticmethod
def _needs_compatibility_check(spec: compute.ClusterSpec) -> bool:
"""
A task is recognized as potentially incompatible if:
1. the cluster has no data security mode configured, and
2. the cluster's DBR version is reported as supported (11.3 or above) by spark_version_compatibility.
"""
if not spec.data_security_mode:
compatibility = spark_version_compatibility(spec.spark_version)
return compatibility == "supported"
return False

def _get_hash_from_run(self, run: BaseRun) -> str:
hashable_items = []
all_tasks: list[RunTask] = run.tasks if run.tasks is not None else []
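# sort tasks by task_key so the resulting hash does not depend on task ordering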
for task in sorted(all_tasks, key=lambda x: x.task_key if x.task_key is not None else ""):
hashable_items.extend(self._run_task_values(task))

if run.git_source:
hashable_items.extend(self._git_source_values(run.git_source))

return sha256("|".join(hashable_items).encode("utf-8")).hexdigest()

@classmethod
def _sql_task_values(cls, task: SqlTask) -> list[str]:
hash_values = [
task.file.path if task.file else None,
task.alert.alert_id if task.alert else None,
task.dashboard.dashboard_id if task.dashboard else None,
task.query.query_id if task.query else None,
]
return [str(value) for value in hash_values if value is not None]

@classmethod
def _git_source_values(cls, source: GitSource) -> list[str]:
hash_values = [source.git_url]
return [str(value) for value in hash_values if value is not None]

@classmethod
def _dbt_task_values(cls, dbt_task: DbtTask) -> list[str]:
hash_values = [
dbt_task.schema,
dbt_task.catalog,
dbt_task.warehouse_id,
dbt_task.project_directory,
",".join(sorted(dbt_task.commands)),
]
return [str(value) for value in hash_values if value is not None]

@classmethod
def _jar_task_values(cls, spark_jar_task: SparkJarTask) -> list[str]:
hash_values = [spark_jar_task.jar_uri, spark_jar_task.main_class_name]
return [str(value) for value in hash_values if value is not None]

@classmethod
def _python_wheel_task_values(cls, pw_task: PythonWheelTask) -> list[str]:
hash_values = [pw_task.package_name, pw_task.entry_point]
return [str(value) for value in hash_values if value is not None]

@classmethod
def _run_condition_task_values(cls, c_task: RunConditionTask) -> list[str]:
hash_values = [c_task.op.value if c_task.op else None, c_task.right, c_task.left, c_task.outcome]
return [str(value) for value in hash_values if value is not None]

@classmethod
def _run_task_values(cls, task: RunTask) -> list[str]:
"""
Collect all hashable task attributes into a list of strings, dropping None values.
Task parameters are deliberately ignored because they change from run to run.
"""
hash_values = [
task.notebook_task.notebook_path if task.notebook_task else None,
task.spark_python_task.python_file if task.spark_python_task else None,
(
'|'.join(task.spark_submit_task.parameters)
if (task.spark_submit_task and task.spark_submit_task.parameters)
else None
),
task.pipeline_task.pipeline_id if task.pipeline_task is not None else None,
task.run_job_task.job_id if task.run_job_task else None,
]
hash_lists = [
cls._jar_task_values(task.spark_jar_task) if task.spark_jar_task else None,
(cls._python_wheel_task_values(task.python_wheel_task) if (task.python_wheel_task) else None),
cls._sql_task_values(task.sql_task) if task.sql_task else None,
cls._dbt_task_values(task.dbt_task) if task.dbt_task else None,
cls._run_condition_task_values(task.condition_task) if task.condition_task else None,
cls._git_source_values(task.git_source) if task.git_source else None,
]
# combine the values from every non-None list into a single flat list
hash_values_from_lists = sum([hash_list for hash_list in hash_lists if hash_list], [])
return [str(value) for value in hash_values + hash_values_from_lists]

def _assess_job_runs(self, submit_runs: Iterable[BaseRun], all_clusters_by_id) -> Iterable[SubmitRunInfo]:
"""
Assessment logic:
1. For each submit run, analyze all tasks inside the run.
2. For each task, collect its hashable attributes (see _run_task_values).
3. Combine the attributes of all tasks into a single hash for the submit run.
4. Coalesce all runs that share the same hash into a single pseudo-job.
5. Return the list of pseudo-jobs with their assessment results.
"""
result: dict[str, SubmitRunInfo] = {}
runs_per_hash: dict[str, list[int | None]] = {}

for submit_run in submit_runs:
task_failures = []
# v2.1+ API, with tasks
if submit_run.tasks:
all_tasks: list[RunTask] = submit_run.tasks
for task in sorted(all_tasks, key=lambda x: x.task_key if x.task_key is not None else ""):
_task_key = task.task_key if task.task_key is not None else ""
_cluster_details = None
if task.new_cluster:
_cluster_details = ClusterDetails.from_dict(task.new_cluster.as_dict())
if self._needs_compatibility_check(task.new_cluster):
task_failures.append("no data security mode specified")
if task.existing_cluster_id:
_cluster_details = all_clusters_by_id.get(task.existing_cluster_id, None)
if _cluster_details:
task_failures.extend(self._check_cluster_failures(_cluster_details, _task_key))

# v2.0 API, without tasks
elif submit_run.cluster_spec:
_cluster_details = ClusterDetails.from_dict(submit_run.cluster_spec.as_dict())
task_failures.extend(self._check_cluster_failures(_cluster_details, "root_task"))
hashed_id = self._get_hash_from_run(submit_run)
if hashed_id in runs_per_hash:
runs_per_hash[hashed_id].append(submit_run.run_id)
else:
runs_per_hash[hashed_id] = [submit_run.run_id]
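# one pseudo-job row per unique configuration hash; run_ids accumulates every run that shares it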

result[hashed_id] = SubmitRunInfo(
run_ids=json.dumps(runs_per_hash[hashed_id]),
hashed_id=hashed_id,
failures=json.dumps(list(set(task_failures))),
)

return list(result.values())
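
To make the grouping concrete, here is a minimal, self-contained sketch of the hashing idea; the helper, notebook path and run ids are hypothetical simplifications of _get_hash_from_run and _run_task_values, not code from this commit:

import json
from hashlib import sha256

def pseudo_job_hash(hashable_items: list[str]) -> str:
    # join the stable task attributes and hash them, mirroring the crawler's approach
    return sha256("|".join(hashable_items).encode("utf-8")).hexdigest()

# two hypothetical submit runs that execute the same notebook
runs = [
    {"run_id": 1001, "notebook_path": "/Repos/demo/etl"},
    {"run_id": 1002, "notebook_path": "/Repos/demo/etl"},
]

runs_per_hash: dict[str, list[int]] = {}
for run in runs:
    hashed_id = pseudo_job_hash([run["notebook_path"]])
    runs_per_hash.setdefault(hashed_id, []).append(run["run_id"])

# both runs collapse into a single pseudo-job entry keyed by the shared hash
print({h: json.dumps(ids) for h, ids in runs_per_hash.items()})
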
4 changes: 2 additions & 2 deletions src/databricks/labs/ucx/assessment/pipelines.py
@@ -49,12 +49,12 @@ def _assess_pipelines(self, all_pipelines) -> Iterable[PipelineInfo]:
assert pipeline_response.spec is not None
pipeline_config = pipeline_response.spec.configuration
if pipeline_config:
failures.extend(self.check_spark_conf(pipeline_config, "pipeline"))
failures.extend(self._check_spark_conf(pipeline_config, "pipeline"))
pipeline_cluster = pipeline_response.spec.clusters
if pipeline_cluster:
for cluster in pipeline_cluster:
if cluster.spark_conf:
failures.extend(self.check_spark_conf(cluster.spark_conf, "pipeline cluster"))
failures.extend(self._check_spark_conf(cluster.spark_conf, "pipeline cluster"))
# Checking if cluster config is present in cluster policies
if cluster.policy_id:
failures.extend(self._check_cluster_policy(cluster.policy_id, "pipeline cluster"))
1 change: 1 addition & 0 deletions src/databricks/labs/ucx/config.py
@@ -35,6 +35,7 @@ class WorkspaceConfig: # pylint: disable=too-many-instance-attributes

override_clusters: dict[str, str] | None = None
policy_id: str | None = None
num_days_submit_runs_history: int = 30

def replace_inventory_variable(self, text: str) -> str:
return text.replace("$inventory", f"hive_metastore.{self.inventory_database}")
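
For illustration, a hedged sketch of setting the new lookback knob when constructing the workspace config; the inventory_database value is an assumed example:

from databricks.labs.ucx.config import WorkspaceConfig

cfg = WorkspaceConfig(
    inventory_database="ucx",          # assumed example value
    num_days_submit_runs_history=60,   # scan 60 days of submit-run history instead of the default 30
)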
3 changes: 2 additions & 1 deletion src/databricks/labs/ucx/install.py
@@ -54,7 +54,7 @@
from databricks.labs.ucx.assessment.azure import AzureServicePrincipalInfo
from databricks.labs.ucx.assessment.clusters import ClusterInfo
from databricks.labs.ucx.assessment.init_scripts import GlobalInitScriptInfo
from databricks.labs.ucx.assessment.jobs import JobInfo
from databricks.labs.ucx.assessment.jobs import JobInfo, SubmitRunInfo
from databricks.labs.ucx.assessment.pipelines import PipelineInfo
from databricks.labs.ucx.config import WorkspaceConfig
from databricks.labs.ucx.configure import ConfigureClusterOverrides
@@ -160,6 +160,7 @@ def deploy_schema(sql_backend: SqlBackend, inventory_schema: str):
functools.partial(table, "table_failures", TableError),
functools.partial(table, "workspace_objects", WorkspaceObjectInfo),
functools.partial(table, "permissions", Permissions),
functools.partial(table, "submit_runs", SubmitRunInfo),
],
)
deployer.deploy_view("objects", "queries/views/objects.sql")
@@ -0,0 +1,8 @@
-- viz type=table, name=Submit Runs, columns=hashed_id,failure,run_ids
-- widget title=Incompatible Submit Runs, row=6, col=4, size_x=3, size_y=8
SELECT
hashed_id,
EXPLODE(FROM_JSON(failures, 'array<string>')) AS failure,
FROM_JSON(run_ids, 'array<string>') AS run_ids
FROM $inventory.submit_runs
ORDER BY hashed_id DESC
@@ -0,0 +1,8 @@
-- viz type=table, name=Submit Runs Failures, columns=failure,submit_runs,run_ids
-- widget title=Incompatible Submit Runs Failures, row=6, col=5, size_x=3, size_y=8
SELECT
EXPLODE(FROM_JSON(failures, 'array<string>')) AS failure,
COUNT(DISTINCT hashed_id) AS submit_runs,
COLLECT_LIST(DISTINCT run_ids) AS run_ids
FROM $inventory.submit_runs
GROUP BY 1
2 changes: 2 additions & 0 deletions src/databricks/labs/ucx/queries/views/objects.sql
@@ -4,6 +4,8 @@ SELECT "clusters" AS object_type, cluster_id AS object_id, failures FROM $inventory.clusters
UNION ALL
SELECT "global init scripts" AS object_type, script_id AS object_id, failures FROM $inventory.global_init_scripts
UNION ALL
SELECT "submit_runs" AS object_type, hashed_id AS object_id, failures FROM $inventory.submit_runs
UNION ALL
SELECT "pipelines" AS object_type, pipeline_id AS object_id, failures FROM $inventory.pipelines
UNION ALL
SELECT object_type, object_id, failures FROM (
18 changes: 17 additions & 1 deletion src/databricks/labs/ucx/runtime.py
@@ -7,7 +7,7 @@
from databricks.labs.ucx.assessment.azure import AzureServicePrincipalCrawler
from databricks.labs.ucx.assessment.clusters import ClustersCrawler
from databricks.labs.ucx.assessment.init_scripts import GlobalInitScriptCrawler
from databricks.labs.ucx.assessment.jobs import JobsCrawler
from databricks.labs.ucx.assessment.jobs import JobsCrawler, SubmitRunsCrawler
from databricks.labs.ucx.assessment.pipelines import PipelinesCrawler
from databricks.labs.ucx.config import WorkspaceConfig
from databricks.labs.ucx.framework.crawlers import SqlBackend
@@ -139,6 +139,21 @@ def assess_pipelines(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend):
crawler.snapshot()


@task("assessment")
def assess_incompatible_submit_runs(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend):
"""This module scans all submit runs and identifies those that may become incompatible after
the workspace is attached to Unity Catalog.
It looks for:
- submit runs on DBR >= 11.3 with data_security_mode set to None
It also groups several submit runs under a single pseudo-id based on a hash of the submit run configuration.
The resulting list of incompatible runs, together with their failures, is stored in the
`$inventory.submit_runs` table."""
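# cfg.num_days_submit_runs_history (default 30, see WorkspaceConfig) bounds how far back run history is scanned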
crawler = SubmitRunsCrawler(ws, sql_backend, cfg.inventory_database, cfg.num_days_submit_runs_history)
crawler.snapshot()


@task("assessment", cloud="azure")
def assess_azure_service_principals(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend):
"""This module scans through all the clusters configurations, cluster policies, job cluster configurations,
@@ -220,6 +235,7 @@ def crawl_groups(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend):
crawl_permissions,
guess_external_locations,
assess_jobs,
assess_incompatible_submit_runs,
assess_clusters,
assess_azure_service_principals,
assess_pipelines,