diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py index cc3ec906ac..13e74559ac 100644 --- a/src/databricks/labs/ucx/cli.py +++ b/src/databricks/labs/ucx/cli.py @@ -26,6 +26,7 @@ from databricks.labs.ucx.hive_metastore.table_migrate import TablesMigrate from databricks.labs.ucx.hive_metastore.table_move import TableMove from databricks.labs.ucx.install import WorkspaceInstallation +from databricks.labs.ucx.installer.workflows import WorkflowsInstallation from databricks.labs.ucx.workspace_access.clusters import ClusterAccess from databricks.labs.ucx.workspace_access.groups import GroupManager @@ -41,7 +42,7 @@ @ucx.command def workflows(w: WorkspaceClient): """Show deployed workflows and their state""" - installation = WorkspaceInstallation.current(w) + installation = WorkflowsInstallation.current(w) logger.info("Fetching deployed jobs...") print(json.dumps(installation.latest_job_status())) @@ -162,7 +163,7 @@ def repair_run(w: WorkspaceClient, step): """Repair Run the Failed Job""" if not step: raise KeyError("You did not specify --step") - installation = WorkspaceInstallation.current(w) + installation = WorkflowsInstallation.current(w) logger.info(f"Repair Running {step} Job") installation.repair_run(step) diff --git a/src/databricks/labs/ucx/install.py b/src/databricks/labs/ucx/install.py index 8f23496535..0f8d88f803 100644 --- a/src/databricks/labs/ucx/install.py +++ b/src/databricks/labs/ucx/install.py @@ -1,20 +1,15 @@ import functools import logging import os -import re -import sys import time import webbrowser from collections.abc import Callable -from dataclasses import replace -from datetime import datetime, timedelta -from pathlib import Path +from datetime import timedelta from typing import Any import databricks.sdk.errors from databricks.labs.blueprint.entrypoint import get_logger from databricks.labs.blueprint.installation import Installation, SerdeError -from databricks.labs.blueprint.installer import InstallState from databricks.labs.blueprint.parallel import ManyError, Threads from databricks.labs.blueprint.tui import Prompts from databricks.labs.blueprint.upgrades import Upgrades @@ -22,32 +17,12 @@ from databricks.labs.lsql.backends import SqlBackend, StatementExecutionBackend from databricks.labs.lsql.deployment import SchemaDeployer from databricks.sdk import WorkspaceClient -from databricks.sdk.errors import ( # pylint: disable=redefined-builtin - Aborted, +from databricks.sdk.errors import ( AlreadyExists, BadRequest, - Cancelled, - DataLoss, - DeadlineExceeded, - InternalError, InvalidParameterValue, NotFound, - NotImplemented, - OperationFailed, - PermissionDenied, - RequestLimitExceeded, - ResourceAlreadyExists, - ResourceConflict, - ResourceDoesNotExist, - ResourceExhausted, - TemporarilyUnavailable, - TooManyRequests, - Unauthenticated, - Unknown, ) -from databricks.sdk.retries import retried -from databricks.sdk.service import compute, jobs -from databricks.sdk.service.jobs import RunLifeCycleState, RunResultState from databricks.sdk.service.sql import ( CreateWarehouseRequestWarehouseType, EndpointInfoWarehouseType, @@ -61,17 +36,16 @@ from databricks.labs.ucx.assessment.jobs import JobInfo, SubmitRunInfo from databricks.labs.ucx.assessment.pipelines import PipelineInfo from databricks.labs.ucx.config import WorkspaceConfig -from databricks.labs.ucx.configure import ConfigureClusterOverrides from databricks.labs.ucx.framework.dashboards import DashboardFromFiles -from databricks.labs.ucx.framework.tasks import _TASKS, Task from 
databricks.labs.ucx.hive_metastore.grants import Grant from databricks.labs.ucx.hive_metastore.locations import ExternalLocation, Mount from databricks.labs.ucx.hive_metastore.table_migrate import MigrationStatus from databricks.labs.ucx.hive_metastore.table_size import TableSize from databricks.labs.ucx.hive_metastore.tables import Table, TableError from databricks.labs.ucx.installer.hms_lineage import HiveMetastoreLineageEnabler +from databricks.labs.ucx.installer.mixins import InstallationMixin from databricks.labs.ucx.installer.policy import ClusterPolicyInstaller -from databricks.labs.ucx.runtime import main +from databricks.labs.ucx.installer.workflows import WorkflowsInstallation from databricks.labs.ucx.workspace_access.base import Permissions from databricks.labs.ucx.workspace_access.generic import WorkspaceObjectInfo from databricks.labs.ucx.workspace_access.groups import ConfigureGroups, MigratedGroup @@ -80,58 +54,6 @@ WAREHOUSE_PREFIX = "Unity Catalog Migration" NUM_USER_ATTEMPTS = 10 # number of attempts user gets at answering a question -EXTRA_TASK_PARAMS = { - "job_id": "{{job_id}}", - "run_id": "{{run_id}}", - "parent_run_id": "{{parent_run_id}}", -} -DEBUG_NOTEBOOK = """# Databricks notebook source -# MAGIC %md -# MAGIC # Debug companion for UCX installation (see [README]({readme_link})) -# MAGIC -# MAGIC Production runs are supposed to be triggered through the following jobs: {job_links} -# MAGIC -# MAGIC **This notebook is overwritten with each UCX update/(re)install.** - -# COMMAND ---------- - -# MAGIC %pip install /Workspace{remote_wheel} -dbutils.library.restartPython() - -# COMMAND ---------- - -import logging -from pathlib import Path -from databricks.labs.blueprint.installation import Installation -from databricks.labs.blueprint.logger import install_logger -from databricks.labs.ucx.__about__ import __version__ -from databricks.labs.ucx.config import WorkspaceConfig -from databricks.sdk import WorkspaceClient - -install_logger() -logging.getLogger("databricks").setLevel("DEBUG") - -cfg = Installation.load_local(WorkspaceConfig, Path("/Workspace{config_file}")) -ws = WorkspaceClient() - -print(__version__) -""" - -TEST_RUNNER_NOTEBOOK = """# Databricks notebook source -# MAGIC %pip install /Workspace{remote_wheel} -dbutils.library.restartPython() - -# COMMAND ---------- - -from databricks.labs.ucx.runtime import main - -main(f'--config=/Workspace{config_file}', - f'--task=' + dbutils.widgets.get('task'), - f'--job_id=' + dbutils.widgets.get('job_id'), - f'--run_id=' + dbutils.widgets.get('run_id'), - f'--parent_run_id=' + dbutils.widgets.get('parent_run_id')) -""" - logger = logging.getLogger(__name__) @@ -202,14 +124,17 @@ def run( sql_backend_factory = self._new_sql_backend if not wheel_builder_factory: wheel_builder_factory = self._new_wheel_builder + wheels = wheel_builder_factory() + workflows_installer = WorkflowsInstallation( + config, self._installation, self._ws, wheels, self._prompts, self._product_info, verify_timeout + ) workspace_installation = WorkspaceInstallation( config, self._installation, sql_backend_factory(config), - wheel_builder_factory(), self._ws, + workflows_installer, self._prompts, - verify_timeout, self._product_info, ) try: @@ -278,26 +203,6 @@ def _configure_new_installation(self) -> WorkspaceConfig: policy_id, instance_profile, spark_conf_dict = self._policy_installer.create(inventory_database) - # Save configurable spark_conf for table migration cluster - # parallelism will not be needed if backlog is fixed in 
https://databricks.atlassian.net/browse/ES-975874 - parallelism = self._prompts.question( - "Parallelism for migrating dbfs root delta tables with deep clone", default="200", valid_number=True - ) - if not spark_conf_dict: - spark_conf_dict = {} - spark_conf_dict.update({'spark.sql.sources.parallelPartitionDiscovery.parallelism': parallelism}) - # mix max workers for auto-scale migration job cluster - min_workers = int( - self._prompts.question( - "Min workers for auto-scale job cluster for table migration", default="1", valid_number=True - ) - ) - max_workers = int( - self._prompts.question( - "Max workers for auto-scale job cluster for table migration", default="10", valid_number=True - ) - ) - # Check if terraform is being used is_terraform_used = self._prompts.confirm("Do you use Terraform to deploy your infrastructure?") @@ -314,8 +219,6 @@ def _configure_new_installation(self) -> WorkspaceConfig: num_threads=num_threads, instance_profile=instance_profile, spark_conf=spark_conf_dict, - min_workers=min_workers, - max_workers=max_workers, policy_id=policy_id, is_terraform_used=is_terraform_used, include_databases=self._select_databases(), @@ -374,28 +277,25 @@ def _check_inventory_database_exists(self, inventory_database: str): continue -class WorkspaceInstallation: +class WorkspaceInstallation(InstallationMixin): def __init__( self, config: WorkspaceConfig, installation: Installation, sql_backend: SqlBackend, - wheels: WheelsV2, ws: WorkspaceClient, + workflows_installer: WorkflowsInstallation, prompts: Prompts, - verify_timeout: timedelta, product_info: ProductInfo, ): self._config = config self._installation = installation self._ws = ws - self._wheels = wheels self._sql_backend = sql_backend + self._workflows_installer = workflows_installer self._prompts = prompts - self._verify_timeout = verify_timeout - self._state = InstallState.from_installation(installation) - self._this_file = Path(__file__) self._product_info = product_info + super().__init__(config, installation, ws) @classmethod def current(cls, ws: WorkspaceClient): @@ -406,7 +306,9 @@ def current(cls, ws: WorkspaceClient): wheels = product_info.wheels(ws) prompts = Prompts() timeout = timedelta(minutes=2) - return WorkspaceInstallation(config, installation, sql_backend, wheels, ws, prompts, timeout, product_info) + workflows_installer = WorkflowsInstallation(config, installation, ws, wheels, prompts, product_info, timeout) + + return cls(config, installation, sql_backend, ws, workflows_installer, prompts, product_info) @property def config(self): @@ -423,10 +325,9 @@ def run(self): [ self._create_dashboards, self._create_database, - self.create_jobs, ], ) - + self._workflows_installer.create_jobs() readme_url = self._create_readme() logger.info(f"Installation completed successfully! 
Please refer to the {readme_url} for the next steps.") @@ -457,7 +358,7 @@ def _create_dashboards(self): local_query_files = find_project_root(__file__) / "src/databricks/labs/ucx/queries" dash = DashboardFromFiles( self._ws, - state=self._state, + state=self._workflows_installer.state, local_folder=local_query_files, remote_folder=f"{self._installation.install_folder()}/queries", name_prefix=self._name("UCX "), @@ -466,205 +367,6 @@ def _create_dashboards(self): ) dash.create_dashboards() - def run_workflow(self, step: str): - job_id = int(self._state.jobs[step]) - logger.debug(f"starting {step} job: {self._ws.config.host}#job/{job_id}") - job_run_waiter = self._ws.jobs.run_now(job_id) - try: - job_run_waiter.result() - except OperationFailed as err: - # currently we don't have any good message from API, so we have to work around it. - job_run = self._ws.jobs.get_run(job_run_waiter.run_id) - raise self._infer_error_from_job_run(job_run) from err - - def _infer_error_from_job_run(self, job_run) -> Exception: - errors: list[Exception] = [] - timeouts: list[DeadlineExceeded] = [] - assert job_run.tasks is not None - for run_task in job_run.tasks: - error = self._infer_error_from_task_run(run_task) - if not error: - continue - if isinstance(error, DeadlineExceeded): - timeouts.append(error) - continue - errors.append(error) - assert job_run.state is not None - assert job_run.state.state_message is not None - if len(errors) == 1: - return errors[0] - all_errors = errors + timeouts - if len(all_errors) == 0: - return Unknown(job_run.state.state_message) - return ManyError(all_errors) - - def _infer_error_from_task_run(self, run_task: jobs.RunTask) -> Exception | None: - if not run_task.state: - return None - if run_task.state.result_state == jobs.RunResultState.TIMEDOUT: - msg = f"{run_task.task_key}: The run was stopped after reaching the timeout" - return DeadlineExceeded(msg) - if run_task.state.result_state != jobs.RunResultState.FAILED: - return None - assert run_task.run_id is not None - run_output = self._ws.jobs.get_run_output(run_task.run_id) - if not run_output: - msg = f'No run output. {run_task.state.state_message}' - return InternalError(msg) - if logger.isEnabledFor(logging.DEBUG): - if run_output.error_trace: - sys.stderr.write(run_output.error_trace) - if not run_output.error: - msg = f'No error in run output. 
{run_task.state.state_message}' - return InternalError(msg) - return self._infer_task_exception(f"{run_task.task_key}: {run_output.error}") - - @staticmethod - def _infer_task_exception(haystack: str) -> Exception: - needles = [ - BadRequest, - Unauthenticated, - PermissionDenied, - NotFound, - ResourceConflict, - TooManyRequests, - Cancelled, - InternalError, - NotImplemented, - TemporarilyUnavailable, - DeadlineExceeded, - InvalidParameterValue, - ResourceDoesNotExist, - Aborted, - AlreadyExists, - ResourceAlreadyExists, - ResourceExhausted, - RequestLimitExceeded, - Unknown, - DataLoss, - ValueError, - KeyError, - ] - constructors: dict[re.Pattern, type[Exception]] = { - re.compile(r".*\[TABLE_OR_VIEW_NOT_FOUND] (.*)"): NotFound, - re.compile(r".*\[SCHEMA_NOT_FOUND] (.*)"): NotFound, - } - for klass in needles: - constructors[re.compile(f".*{klass.__name__}: (.*)")] = klass - for pattern, klass in constructors.items(): - match = pattern.match(haystack) - if match: - return klass(match.group(1)) - return Unknown(haystack) - - @property - def _warehouse_id(self) -> str: - if self._config.warehouse_id is not None: - logger.info("Fetching warehouse_id from a config") - return self._config.warehouse_id - warehouses = [_ for _ in self._ws.warehouses.list() if _.warehouse_type == EndpointInfoWarehouseType.PRO] - warehouse_id = self._config.warehouse_id - if not warehouse_id and not warehouses: - msg = "need either configured warehouse_id or an existing PRO SQL warehouse" - raise ValueError(msg) - if not warehouse_id: - warehouse_id = warehouses[0].id - self._config.warehouse_id = warehouse_id - return warehouse_id - - @property - def _my_username(self): - if not hasattr(self, "_me"): - self._me = self._ws.current_user.me() - is_workspace_admin = any(g.display == "admins" for g in self._me.groups) - if not is_workspace_admin: - msg = "Current user is not a workspace admin" - raise PermissionError(msg) - return self._me.user_name - - @property - def _short_name(self): - if "@" in self._my_username: - username = self._my_username.split("@")[0] - else: - username = self._me.display_name - return username - - @property - def _config_file(self): - return f"{self._installation.install_folder()}/config.yml" - - def _name(self, name: str) -> str: - prefix = os.path.basename(self._installation.install_folder()).removeprefix('.') - return f"[{prefix.upper()}] {name}" - - def _upload_wheel(self): - with self._wheels: - try: - self._wheels.upload_to_dbfs() - except PermissionDenied as err: - if not self._prompts: - raise RuntimeWarning("no Prompts instance found") from err - logger.warning(f"Uploading wheel file to DBFS failed, DBFS is probably write protected. 
{err}") - configure_cluster_overrides = ConfigureClusterOverrides(self._ws, self._prompts.choice_from_dict) - self._config.override_clusters = configure_cluster_overrides.configure() - self._installation.save(self._config) - return self._wheels.upload_to_wsfs() - - def create_jobs(self): - logger.debug(f"Creating jobs from tasks in {main.__name__}") - remote_wheel = self._upload_wheel() - desired_steps = {t.workflow for t in _TASKS.values() if t.cloud_compatible(self._ws.config)} - wheel_runner = None - - if self._config.override_clusters: - wheel_runner = self._upload_wheel_runner(remote_wheel) - for step_name in desired_steps: - settings = self._job_settings(step_name, remote_wheel) - if self._config.override_clusters: - settings = self._apply_cluster_overrides(settings, self._config.override_clusters, wheel_runner) - self._deploy_workflow(step_name, settings) - - for step_name, job_id in self._state.jobs.items(): - if step_name not in desired_steps: - try: - logger.info(f"Removing job_id={job_id}, as it is no longer needed") - self._ws.jobs.delete(job_id) - except InvalidParameterValue: - logger.warning(f"step={step_name} does not exist anymore for some reason") - continue - - self._state.save() - self._create_debug(remote_wheel) - - def _deploy_workflow(self, step_name: str, settings): - if step_name in self._state.jobs: - try: - job_id = int(self._state.jobs[step_name]) - logger.info(f"Updating configuration for step={step_name} job_id={job_id}") - return self._ws.jobs.reset(job_id, jobs.JobSettings(**settings)) - except InvalidParameterValue: - del self._state.jobs[step_name] - logger.warning(f"step={step_name} does not exist anymore for some reason") - return self._deploy_workflow(step_name, settings) - logger.info(f"Creating new job configuration for step={step_name}") - new_job = self._ws.jobs.create(**settings) - assert new_job.job_id is not None - self._state.jobs[step_name] = str(new_job.job_id) - return None - - @staticmethod - def _sorted_tasks() -> list[Task]: - return sorted(_TASKS.values(), key=lambda x: x.task_id) - - @classmethod - def _step_list(cls) -> list[str]: - step_list = [] - for task in cls._sorted_tasks(): - if task.workflow not in step_list: - step_list.append(task.workflow) - return step_list - def _create_readme(self) -> str: debug_notebook_link = self._installation.workspace_markdown_link('debug notebook', 'DEBUG.py') markdown = [ @@ -673,25 +375,29 @@ def _create_readme(self) -> str: "Here are the URLs and descriptions of workflows that trigger various stages of migration.", "All jobs are defined with necessary cluster configurations and DBR versions.\n", ] - for step_name in self._step_list(): - if step_name not in self._state.jobs: + for step_name in self.step_list(): + if step_name not in self._workflows_installer.state.jobs: logger.warning(f"Skipping step '{step_name}' since it was not deployed.") continue - job_id = self._state.jobs[step_name] + job_id = self._workflows_installer.state.jobs[step_name] dashboard_link = "" - dashboards_per_step = [d for d in self._state.dashboards.keys() if d.startswith(step_name)] + dashboards_per_step = [ + d for d in self._workflows_installer.state.dashboards.keys() if d.startswith(step_name) + ] for dash in dashboards_per_step: if len(dashboard_link) == 0: dashboard_link += "Go to the one of the following dashboards after running the job:\n" first, second = dash.replace("_", " ").title().split() - dashboard_url = f"{self._ws.config.host}/sql/dashboards/{self._state.dashboards[dash]}" + dashboard_url = ( + 
f"{self._ws.config.host}/sql/dashboards/{self._workflows_installer.state.dashboards[dash]}" + ) dashboard_link += f" - [{first} ({second}) dashboard]({dashboard_url})\n" job_link = f"[{self._name(step_name)}]({self._ws.config.host}#job/{job_id})" markdown.append("---\n\n") markdown.append(f"## {job_link}\n\n") markdown.append(f"{dashboard_link}") markdown.append("\nThe workflow consists of the following separate tasks:\n\n") - for task in self._sorted_tasks(): + for task in self.sorted_tasks(): if task.workflow != step_name: continue doc = self._config.replace_inventory_variable(task.doc) @@ -709,260 +415,6 @@ def _create_readme(self) -> str: def _replace_inventory_variable(self, text: str) -> str: return text.replace("$inventory", f"hive_metastore.{self._config.inventory_database}") - def _create_debug(self, remote_wheel: str): - readme_link = self._installation.workspace_link('README') - job_links = ", ".join( - f"[{self._name(step_name)}]({self._ws.config.host}#job/{job_id})" - for step_name, job_id in self._state.jobs.items() - ) - content = DEBUG_NOTEBOOK.format( - remote_wheel=remote_wheel, readme_link=readme_link, job_links=job_links, config_file=self._config_file - ).encode("utf8") - self._installation.upload('DEBUG.py', content) - - def _job_settings(self, step_name: str, remote_wheel: str): - email_notifications = None - if not self._config.override_clusters and "@" in self._my_username: - # set email notifications only if we're running the real - # installation and not the integration test. - email_notifications = jobs.JobEmailNotifications( - on_success=[self._my_username], on_failure=[self._my_username] - ) - tasks = sorted( - [t for t in _TASKS.values() if t.workflow == step_name], - key=lambda _: _.name, - ) - version = self._product_info.version() - version = version if not self._ws.config.is_gcp else version.replace("+", "-") - return { - "name": self._name(step_name), - "tags": {"version": f"v{version}"}, - "job_clusters": self._job_clusters({t.job_cluster for t in tasks}), - "email_notifications": email_notifications, - "tasks": [self._job_task(task, remote_wheel) for task in tasks], - } - - def _upload_wheel_runner(self, remote_wheel: str): - # TODO: we have to be doing this workaround until ES-897453 is solved in the platform - code = TEST_RUNNER_NOTEBOOK.format(remote_wheel=remote_wheel, config_file=self._config_file).encode("utf8") - return self._installation.upload(f"wheels/wheel-test-runner-{self._product_info.version()}.py", code) - - @staticmethod - def _apply_cluster_overrides(settings: dict[str, Any], overrides: dict[str, str], wheel_runner: str) -> dict: - settings["job_clusters"] = [_ for _ in settings["job_clusters"] if _.job_cluster_key not in overrides] - for job_task in settings["tasks"]: - if job_task.job_cluster_key is None: - continue - if job_task.job_cluster_key in overrides: - job_task.existing_cluster_id = overrides[job_task.job_cluster_key] - job_task.job_cluster_key = None - job_task.libraries = None - if job_task.python_wheel_task is not None: - job_task.python_wheel_task = None - params = {"task": job_task.task_key} | EXTRA_TASK_PARAMS - job_task.notebook_task = jobs.NotebookTask(notebook_path=wheel_runner, base_parameters=params) - return settings - - def _job_task(self, task: Task, remote_wheel: str) -> jobs.Task: - jobs_task = jobs.Task( - task_key=task.name, - job_cluster_key=task.job_cluster, - depends_on=[jobs.TaskDependency(task_key=d) for d in _TASKS[task.name].dependencies()], - ) - if task.dashboard: - # dashboards are created in 
parallel to wheel uploads, so we'll just retry - retry_on_attribute_error = retried(on=[KeyError], timeout=self._verify_timeout) - retried_job_dashboard_task = retry_on_attribute_error(self._job_dashboard_task) - return retried_job_dashboard_task(jobs_task, task) - if task.notebook: - return self._job_notebook_task(jobs_task, task) - return self._job_wheel_task(jobs_task, task, remote_wheel) - - def _job_dashboard_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task: - assert task.dashboard is not None - dashboard_id = self._state.dashboards[task.dashboard] - return replace( - jobs_task, - job_cluster_key=None, - sql_task=jobs.SqlTask( - warehouse_id=self._warehouse_id, - dashboard=jobs.SqlTaskDashboard(dashboard_id=dashboard_id), - ), - ) - - def _job_notebook_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task: - assert task.notebook is not None - local_notebook = self._this_file.parent / task.notebook - with local_notebook.open("rb") as f: - remote_notebook = self._installation.upload(local_notebook.name, f.read()) - return replace( - jobs_task, - notebook_task=jobs.NotebookTask( - notebook_path=remote_notebook, - # ES-872211: currently, we cannot read WSFS files from Scala context - base_parameters={ - "task": task.name, - "config": f"/Workspace{self._config_file}", - } - | EXTRA_TASK_PARAMS, - ), - ) - - def _job_wheel_task(self, jobs_task: jobs.Task, task: Task, remote_wheel: str) -> jobs.Task: - return replace( - jobs_task, - # TODO: check when we can install wheels from WSFS properly - libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")], - python_wheel_task=jobs.PythonWheelTask( - package_name="databricks_labs_ucx", - entry_point="runtime", # [project.entry-points.databricks] in pyproject.toml - named_parameters={"task": task.name, "config": f"/Workspace{self._config_file}"} | EXTRA_TASK_PARAMS, - ), - ) - - def _job_cluster_spark_conf(self, cluster_key: str): - conf_from_installation = self._config.spark_conf if self._config.spark_conf else {} - if cluster_key == "main": - spark_conf = { - "spark.databricks.cluster.profile": "singleNode", - "spark.master": "local[*]", - } - return spark_conf | conf_from_installation - if cluster_key == "tacl": - return {"spark.databricks.acl.sqlOnly": "true"} | conf_from_installation - if cluster_key == "table_migration": - return {"spark.sql.sources.parallelPartitionDiscovery.parallelism": "200"} | conf_from_installation - return conf_from_installation - - def _job_clusters(self, names: set[str]): - clusters = [] - if "main" in names: - clusters.append( - jobs.JobCluster( - job_cluster_key="main", - new_cluster=compute.ClusterSpec( - data_security_mode=compute.DataSecurityMode.LEGACY_SINGLE_USER, - spark_conf=self._job_cluster_spark_conf("main"), - custom_tags={"ResourceClass": "SingleNode"}, - num_workers=0, - policy_id=self.config.policy_id, - ), - ) - ) - if "tacl" in names: - clusters.append( - jobs.JobCluster( - job_cluster_key="tacl", - new_cluster=compute.ClusterSpec( - data_security_mode=compute.DataSecurityMode.LEGACY_TABLE_ACL, - spark_conf=self._job_cluster_spark_conf("tacl"), - num_workers=1, # ShowPermissionsCommand needs a worker - policy_id=self.config.policy_id, - ), - ) - ) - if "table_migration" in names: - clusters.append( - jobs.JobCluster( - job_cluster_key="table_migration", - new_cluster=compute.ClusterSpec( - data_security_mode=compute.DataSecurityMode.SINGLE_USER, - spark_conf=self._job_cluster_spark_conf("table_migration"), - policy_id=self.config.policy_id, - autoscale=compute.AutoScale( - 
max_workers=self.config.max_workers, - min_workers=self.config.min_workers, - ), - ), - ) - ) - return clusters - - @staticmethod - def _readable_timedelta(epoch): - when = datetime.utcfromtimestamp(epoch) - duration = datetime.now() - when - data = {} - data["days"], remaining = divmod(duration.total_seconds(), 86_400) - data["hours"], remaining = divmod(remaining, 3_600) - data["minutes"], data["seconds"] = divmod(remaining, 60) - - time_parts = ((name, round(value)) for name, value in data.items()) - time_parts = [f"{value} {name[:-1] if value == 1 else name}" for name, value in time_parts if value > 0] - if len(time_parts) > 0: - time_parts.append("ago") - if time_parts: - return " ".join(time_parts) - return "less than 1 second ago" - - def latest_job_status(self) -> list[dict]: - latest_status = [] - for step, job_id in self._state.jobs.items(): - job_state = None - start_time = None - try: - job_runs = list(self._ws.jobs.list_runs(job_id=int(job_id), limit=1)) - except InvalidParameterValue as e: - logger.warning(f"skipping {step}: {e}") - continue - if job_runs: - state = job_runs[0].state - if state and state.result_state: - job_state = state.result_state.name - elif state and state.life_cycle_state: - job_state = state.life_cycle_state.name - if job_runs[0].start_time: - start_time = job_runs[0].start_time / 1000 - latest_status.append( - { - "step": step, - "state": "UNKNOWN" if not (job_runs and job_state) else job_state, - "started": ( - "" if not (job_runs and start_time) else self._readable_timedelta(start_time) - ), - } - ) - return latest_status - - def _get_result_state(self, job_id): - job_runs = list(self._ws.jobs.list_runs(job_id=job_id, limit=1)) - latest_job_run = job_runs[0] - if not latest_job_run.state.result_state: - raise AttributeError("no result state in job run") - job_state = latest_job_run.state.result_state.value - return job_state - - def repair_run(self, workflow): - try: - job_id, run_id = self._repair_workflow(workflow) - run_details = self._ws.jobs.get_run(run_id=run_id, include_history=True) - latest_repair_run_id = run_details.repair_history[-1].id - job_url = f"{self._ws.config.host}#job/{job_id}/run/{run_id}" - logger.debug(f"Repair Running {workflow} job: {job_url}") - self._ws.jobs.repair_run(run_id=run_id, rerun_all_failed_tasks=True, latest_repair_id=latest_repair_run_id) - webbrowser.open(job_url) - except InvalidParameterValue as e: - logger.warning(f"skipping {workflow}: {e}") - except TimeoutError: - logger.warning(f"Skipping the {workflow} due to time out. Please try after sometime") - - def _repair_workflow(self, workflow): - job_id = self._state.jobs.get(workflow) - if not job_id: - raise InvalidParameterValue("job does not exists hence skipping repair") - job_runs = list(self._ws.jobs.list_runs(job_id=job_id, limit=1)) - if not job_runs: - raise InvalidParameterValue("job is not initialized yet. 
Can't trigger repair run now") - latest_job_run = job_runs[0] - retry_on_attribute_error = retried(on=[AttributeError], timeout=self._verify_timeout) - retried_check = retry_on_attribute_error(self._get_result_state) - state_value = retried_check(job_id) - logger.info(f"The status for the latest run is {state_value}") - if state_value != "FAILED": - raise InvalidParameterValue("job is not in FAILED state hence skipping repair") - run_id = latest_job_run.run_id - return job_id, run_id - def uninstall(self): if self._prompts and not self._prompts.confirm( "Do you want to uninstall ucx from the workspace too, this would " @@ -1010,10 +462,10 @@ def _remove_secret_scope(self): def _remove_jobs(self): logger.info("Deleting jobs") - if not self._state.jobs: + if not self._workflows_installer.state.jobs: logger.error("No jobs present or jobs already deleted") return - for step_name, job_id in self._state.jobs.items(): + for step_name, job_id in self._workflows_installer.state.jobs.items(): try: logger.info(f"Deleting {step_name} job_id={job_id}.") self._ws.jobs.delete(job_id) @@ -1030,31 +482,12 @@ def _remove_warehouse(self): except InvalidParameterValue: logger.error("Error accessing warehouse details") - def validate_step(self, step: str) -> bool: - job_id = int(self._state.jobs[step]) - logger.debug(f"Validating {step} workflow: {self._ws.config.host}#job/{job_id}") - current_runs = list(self._ws.jobs.list_runs(completed_only=False, job_id=job_id)) - for run in current_runs: - if run.state and run.state.result_state == RunResultState.SUCCESS: - return True - for run in current_runs: - if ( - run.run_id - and run.state - and run.state.life_cycle_state in (RunLifeCycleState.RUNNING, RunLifeCycleState.PENDING) - ): - logger.info("Identified a run in progress waiting for run completion") - self._ws.jobs.wait_get_run_job_terminated_or_skipped(run_id=run.run_id) - run_new_state = self._ws.jobs.get_run(run_id=run.run_id).state - return run_new_state is not None and run_new_state.result_state == RunResultState.SUCCESS - return False - def validate_and_run(self, step: str): - if not self.validate_step(step): - self.run_workflow(step) + if not self._workflows_installer.validate_step(step): + self._workflows_installer.run_workflow(step) def _trigger_workflow(self, step: str): - job_id = int(self._state.jobs[step]) + job_id = int(self._workflows_installer.state.jobs[step]) job_url = f"{self._ws.config.host}#job/{job_id}" logger.debug(f"triggering {step} job: {self._ws.config.host}#job/{job_id}") self._ws.jobs.run_now(job_id) @@ -1071,5 +504,6 @@ def _trigger_workflow(self, step: str): current = app.current_installation(workspace_client) except NotFound: current = Installation.assume_global(workspace_client, app.product_name()) - installer = WorkspaceInstaller(Prompts(), current, workspace_client, app) + prmpts = Prompts() + installer = WorkspaceInstaller(prmpts, current, workspace_client, app) installer.run() diff --git a/src/databricks/labs/ucx/installer/mixins.py b/src/databricks/labs/ucx/installer/mixins.py new file mode 100644 index 0000000000..a234747287 --- /dev/null +++ b/src/databricks/labs/ucx/installer/mixins.py @@ -0,0 +1,77 @@ +import logging +import os +from datetime import datetime + +from databricks.labs.blueprint.installation import Installation +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.sql import EndpointInfoWarehouseType + +from databricks.labs.ucx.config import WorkspaceConfig +from databricks.labs.ucx.framework.tasks import _TASKS, Task + +logger = 
logging.getLogger(__name__) + + +class InstallationMixin: + def __init__(self, config: WorkspaceConfig, installation: Installation, ws: WorkspaceClient): + self._config = config + self._installation = installation + self._ws = ws + + @staticmethod + def sorted_tasks() -> list[Task]: + return sorted(_TASKS.values(), key=lambda x: x.task_id) + + @classmethod + def step_list(cls) -> list[str]: + step_list = [] + for task in cls.sorted_tasks(): + if task.workflow not in step_list: + step_list.append(task.workflow) + return step_list + + def _name(self, name: str) -> str: + prefix = os.path.basename(self._installation.install_folder()).removeprefix('.') + return f"[{prefix.upper()}] {name}" + + @property + def _my_username(self): + if not hasattr(self, "_me"): + self._me = self._ws.current_user.me() + is_workspace_admin = any(g.display == "admins" for g in self._me.groups) + if not is_workspace_admin: + msg = "Current user is not a workspace admin" + raise PermissionError(msg) + return self._me.user_name + + @staticmethod + def _readable_timedelta(epoch): + when = datetime.utcfromtimestamp(epoch) + duration = datetime.now() - when + data = {} + data["days"], remaining = divmod(duration.total_seconds(), 86_400) + data["hours"], remaining = divmod(remaining, 3_600) + data["minutes"], data["seconds"] = divmod(remaining, 60) + + time_parts = ((name, round(value)) for (name, value) in data.items()) + time_parts = [f"{value} {name[:-1] if value == 1 else name}" for name, value in time_parts if value > 0] + if len(time_parts) > 0: + time_parts.append("ago") + if time_parts: + return " ".join(time_parts) + return "less than 1 second ago" + + @property + def _warehouse_id(self) -> str: + if self._config.warehouse_id is not None: + logger.info("Fetching warehouse_id from a config") + return self._config.warehouse_id + warehouses = [_ for _ in self._ws.warehouses.list() if _.warehouse_type == EndpointInfoWarehouseType.PRO] + warehouse_id = self._config.warehouse_id + if not warehouse_id and not warehouses: + msg = "need either configured warehouse_id or an existing PRO SQL warehouse" + raise ValueError(msg) + if not warehouse_id: + warehouse_id = warehouses[0].id + self._config.warehouse_id = warehouse_id + return warehouse_id diff --git a/src/databricks/labs/ucx/installer/workflows.py b/src/databricks/labs/ucx/installer/workflows.py new file mode 100644 index 0000000000..863c0c7a44 --- /dev/null +++ b/src/databricks/labs/ucx/installer/workflows.py @@ -0,0 +1,545 @@ +import logging +import re +import sys +import webbrowser +from dataclasses import replace +from datetime import timedelta +from pathlib import Path +from typing import Any + +from databricks.labs.blueprint.installation import Installation +from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.parallel import ManyError +from databricks.labs.blueprint.tui import Prompts +from databricks.labs.blueprint.wheels import ProductInfo, WheelsV2 +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import ( + Aborted, + AlreadyExists, + BadRequest, + Cancelled, + DataLoss, + DeadlineExceeded, + InternalError, + InvalidParameterValue, + NotFound, + OperationFailed, + PermissionDenied, + RequestLimitExceeded, + ResourceAlreadyExists, + ResourceConflict, + ResourceDoesNotExist, + ResourceExhausted, + TemporarilyUnavailable, + TooManyRequests, + Unauthenticated, + Unknown, +) +from databricks.sdk.retries import retried +from databricks.sdk.service import compute, jobs +from databricks.sdk.service.jobs 
import RunLifeCycleState, RunResultState + +import databricks +from databricks.labs.ucx.config import WorkspaceConfig +from databricks.labs.ucx.configure import ConfigureClusterOverrides +from databricks.labs.ucx.framework.tasks import _TASKS, Task +from databricks.labs.ucx.installer.mixins import InstallationMixin +from databricks.labs.ucx.runtime import main + +logger = logging.getLogger(__name__) + +EXTRA_TASK_PARAMS = { + "job_id": "{{job_id}}", + "run_id": "{{run_id}}", + "parent_run_id": "{{parent_run_id}}", +} +DEBUG_NOTEBOOK = """# Databricks notebook source +# MAGIC %md +# MAGIC # Debug companion for UCX installation (see [README]({readme_link})) +# MAGIC +# MAGIC Production runs are supposed to be triggered through the following jobs: {job_links} +# MAGIC +# MAGIC **This notebook is overwritten with each UCX update/(re)install.** + +# COMMAND ---------- + +# MAGIC %pip install /Workspace{remote_wheel} +dbutils.library.restartPython() + +# COMMAND ---------- + +import logging +from pathlib import Path +from databricks.labs.blueprint.installation import Installation +from databricks.labs.blueprint.logger import install_logger +from databricks.labs.ucx.__about__ import __version__ +from databricks.labs.ucx.config import WorkspaceConfig +from databricks.sdk import WorkspaceClient + +install_logger() +logging.getLogger("databricks").setLevel("DEBUG") + +cfg = Installation.load_local(WorkspaceConfig, Path("/Workspace{config_file}")) +ws = WorkspaceClient() + +print(__version__) +""" + +TEST_RUNNER_NOTEBOOK = """# Databricks notebook source +# MAGIC %pip install /Workspace{remote_wheel} +dbutils.library.restartPython() + +# COMMAND ---------- + +from databricks.labs.ucx.runtime import main + +main(f'--config=/Workspace{config_file}', + f'--task=' + dbutils.widgets.get('task'), + f'--job_id=' + dbutils.widgets.get('job_id'), + f'--run_id=' + dbutils.widgets.get('run_id'), + f'--parent_run_id=' + dbutils.widgets.get('parent_run_id')) +""" + + +class WorkflowsInstallation(InstallationMixin): + def __init__( + self, + config: WorkspaceConfig, + installation: Installation, + ws: WorkspaceClient, + wheels: WheelsV2, + prompts: Prompts, + product_info: ProductInfo, + verify_timeout: timedelta, + ): + self._config = config + self._installation = installation + self._ws = ws + self._state = InstallState.from_installation(installation) + self._wheels = wheels + self._prompts = prompts + self._product_info = product_info + self._verify_timeout = verify_timeout + self._this_file = Path(__file__) + super().__init__(config, installation, ws) + + @classmethod + def current(cls, ws: WorkspaceClient): + product_info = ProductInfo.from_class(WorkspaceConfig) + installation = product_info.current_installation(ws) + config = installation.load(WorkspaceConfig) + wheels = product_info.wheels(ws) + prompts = Prompts() + timeout = timedelta(minutes=2) + + return cls(config, installation, ws, wheels, prompts, product_info, timeout) + + @property + def state(self): + return self._state + + def run_workflow(self, step: str): + job_id = int(self._state.jobs[step]) + logger.debug(f"starting {step} job: {self._ws.config.host}#job/{job_id}") + job_run_waiter = self._ws.jobs.run_now(job_id) + try: + job_run_waiter.result() + except OperationFailed as err: + # currently we don't have any good message from API, so we have to work around it. 
+ job_run = self._ws.jobs.get_run(job_run_waiter.run_id) + raise self._infer_error_from_job_run(job_run) from err + + def create_jobs(self): + logger.debug(f"Creating jobs from tasks in {main.__name__}") + remote_wheel = self._upload_wheel() + desired_steps = {t.workflow for t in _TASKS.values() if t.cloud_compatible(self._ws.config)} + wheel_runner = None + + if self._config.override_clusters: + wheel_runner = self._upload_wheel_runner(remote_wheel) + for step_name in desired_steps: + settings = self._job_settings(step_name, remote_wheel) + if self._config.override_clusters: + settings = self._apply_cluster_overrides(settings, self._config.override_clusters, wheel_runner) + self._deploy_workflow(step_name, settings) + + for step_name, job_id in self._state.jobs.items(): + if step_name not in desired_steps: + try: + logger.info(f"Removing job_id={job_id}, as it is no longer needed") + self._ws.jobs.delete(job_id) + except InvalidParameterValue: + logger.warning(f"step={step_name} does not exist anymore for some reason") + continue + + self._state.save() + self._create_debug(remote_wheel) + + def repair_run(self, workflow): + try: + job_id, run_id = self._repair_workflow(workflow) + run_details = self._ws.jobs.get_run(run_id=run_id, include_history=True) + latest_repair_run_id = run_details.repair_history[-1].id + job_url = f"{self._ws.config.host}#job/{job_id}/run/{run_id}" + logger.debug(f"Repair Running {workflow} job: {job_url}") + self._ws.jobs.repair_run(run_id=run_id, rerun_all_failed_tasks=True, latest_repair_id=latest_repair_run_id) + webbrowser.open(job_url) + except InvalidParameterValue as e: + logger.warning(f"skipping {workflow}: {e}") + except TimeoutError: + logger.warning(f"Skipping the {workflow} due to time out. Please try after sometime") + + def latest_job_status(self) -> list[dict]: + latest_status = [] + for step, job_id in self._state.jobs.items(): + job_state = None + start_time = None + try: + job_runs = list(self._ws.jobs.list_runs(job_id=int(job_id), limit=1)) + except InvalidParameterValue as e: + logger.warning(f"skipping {step}: {e}") + continue + if job_runs: + state = job_runs[0].state + if state and state.result_state: + job_state = state.result_state.name + elif state and state.life_cycle_state: + job_state = state.life_cycle_state.name + if job_runs[0].start_time: + start_time = job_runs[0].start_time / 1000 + latest_status.append( + { + "step": step, + "state": "UNKNOWN" if not (job_runs and job_state) else job_state, + "started": ( + "" if not (job_runs and start_time) else self._readable_timedelta(start_time) + ), + } + ) + return latest_status + + def validate_step(self, step: str) -> bool: + job_id = int(self.state.jobs[step]) + logger.debug(f"Validating {step} workflow: {self._ws.config.host}#job/{job_id}") + current_runs = list(self._ws.jobs.list_runs(completed_only=False, job_id=job_id)) + for run in current_runs: + if run.state and run.state.result_state == RunResultState.SUCCESS: + return True + for run in current_runs: + if ( + run.run_id + and run.state + and run.state.life_cycle_state in (RunLifeCycleState.RUNNING, RunLifeCycleState.PENDING) + ): + logger.info("Identified a run in progress waiting for run completion") + self._ws.jobs.wait_get_run_job_terminated_or_skipped(run_id=run.run_id) + run_new_state = self._ws.jobs.get_run(run_id=run.run_id).state + return run_new_state is not None and run_new_state.result_state == RunResultState.SUCCESS + return False + + @property + def _config_file(self): + return 
f"{self._installation.install_folder()}/config.yml" + + def _job_cluster_spark_conf(self, cluster_key: str): + conf_from_installation = self._config.spark_conf if self._config.spark_conf else {} + if cluster_key == "main": + spark_conf = { + "spark.databricks.cluster.profile": "singleNode", + "spark.master": "local[*]", + } + return spark_conf | conf_from_installation + if cluster_key == "tacl": + return {"spark.databricks.acl.sqlOnly": "true"} | conf_from_installation + if cluster_key == "table_migration": + return {"spark.sql.sources.parallelPartitionDiscovery.parallelism": "200"} | conf_from_installation + return conf_from_installation + + def _deploy_workflow(self, step_name: str, settings): + if step_name in self._state.jobs: + try: + job_id = int(self._state.jobs[step_name]) + logger.info(f"Updating configuration for step={step_name} job_id={job_id}") + return self._ws.jobs.reset(job_id, jobs.JobSettings(**settings)) + except InvalidParameterValue: + del self._state.jobs[step_name] + logger.warning(f"step={step_name} does not exist anymore for some reason") + return self._deploy_workflow(step_name, settings) + logger.info(f"Creating new job configuration for step={step_name}") + new_job = self._ws.jobs.create(**settings) + assert new_job.job_id is not None + self._state.jobs[step_name] = str(new_job.job_id) + return None + + def _infer_error_from_job_run(self, job_run) -> Exception: + errors: list[Exception] = [] + timeouts: list[DeadlineExceeded] = [] + assert job_run.tasks is not None + for run_task in job_run.tasks: + error = self._infer_error_from_task_run(run_task) + if not error: + continue + if isinstance(error, DeadlineExceeded): + timeouts.append(error) + continue + errors.append(error) + assert job_run.state is not None + assert job_run.state.state_message is not None + if len(errors) == 1: + return errors[0] + all_errors = errors + timeouts + if len(all_errors) == 0: + return Unknown(job_run.state.state_message) + return ManyError(all_errors) + + def _infer_error_from_task_run(self, run_task: jobs.RunTask) -> Exception | None: + if not run_task.state: + return None + if run_task.state.result_state == jobs.RunResultState.TIMEDOUT: + msg = f"{run_task.task_key}: The run was stopped after reaching the timeout" + return DeadlineExceeded(msg) + if run_task.state.result_state != jobs.RunResultState.FAILED: + return None + assert run_task.run_id is not None + run_output = self._ws.jobs.get_run_output(run_task.run_id) + if not run_output: + msg = f'No run output. {run_task.state.state_message}' + return InternalError(msg) + if logger.isEnabledFor(logging.DEBUG): + if run_output.error_trace: + sys.stderr.write(run_output.error_trace) + if not run_output.error: + msg = f'No error in run output. 
{run_task.state.state_message}' + return InternalError(msg) + return self._infer_task_exception(f"{run_task.task_key}: {run_output.error}") + + @staticmethod + def _infer_task_exception(haystack: str) -> Exception: + needles: list[type[Exception]] = [ + BadRequest, + Unauthenticated, + PermissionDenied, + NotFound, + ResourceConflict, + TooManyRequests, + Cancelled, + databricks.sdk.errors.NotImplemented, + InternalError, + TemporarilyUnavailable, + DeadlineExceeded, + InvalidParameterValue, + ResourceDoesNotExist, + Aborted, + AlreadyExists, + ResourceAlreadyExists, + ResourceExhausted, + RequestLimitExceeded, + Unknown, + DataLoss, + ValueError, + KeyError, + ] + constructors: dict[re.Pattern, type[Exception]] = { + re.compile(r".*\[TABLE_OR_VIEW_NOT_FOUND] (.*)"): NotFound, + re.compile(r".*\[SCHEMA_NOT_FOUND] (.*)"): NotFound, + } + for klass in needles: + constructors[re.compile(f".*{klass.__name__}: (.*)")] = klass + for pattern, klass in constructors.items(): + match = pattern.match(haystack) + if match: + return klass(match.group(1)) + return Unknown(haystack) + + def _upload_wheel(self): + with self._wheels: + try: + self._wheels.upload_to_dbfs() + except PermissionDenied as err: + if not self._prompts: + raise RuntimeWarning("no Prompts instance found") from err + logger.warning(f"Uploading wheel file to DBFS failed, DBFS is probably write protected. {err}") + configure_cluster_overrides = ConfigureClusterOverrides(self._ws, self._prompts.choice_from_dict) + self._config.override_clusters = configure_cluster_overrides.configure() + self._installation.save(self._config) + return self._wheels.upload_to_wsfs() + + def _upload_wheel_runner(self, remote_wheel: str): + # TODO: we have to be doing this workaround until ES-897453 is solved in the platform + code = TEST_RUNNER_NOTEBOOK.format(remote_wheel=remote_wheel, config_file=self._config_file).encode("utf8") + return self._installation.upload(f"wheels/wheel-test-runner-{self._product_info.version()}.py", code) + + @staticmethod + def _apply_cluster_overrides(settings: dict[str, Any], overrides: dict[str, str], wheel_runner: str) -> dict: + settings["job_clusters"] = [_ for _ in settings["job_clusters"] if _.job_cluster_key not in overrides] + for job_task in settings["tasks"]: + if job_task.job_cluster_key is None: + continue + if job_task.job_cluster_key in overrides: + job_task.existing_cluster_id = overrides[job_task.job_cluster_key] + job_task.job_cluster_key = None + job_task.libraries = None + if job_task.python_wheel_task is not None: + job_task.python_wheel_task = None + params = {"task": job_task.task_key} | EXTRA_TASK_PARAMS + job_task.notebook_task = jobs.NotebookTask(notebook_path=wheel_runner, base_parameters=params) + return settings + + def _job_settings(self, step_name: str, remote_wheel: str): + email_notifications = None + if not self._config.override_clusters and "@" in self._my_username: + # set email notifications only if we're running the real + # installation and not the integration test. 
+ email_notifications = jobs.JobEmailNotifications( + on_success=[self._my_username], on_failure=[self._my_username] + ) + tasks = sorted( + [t for t in _TASKS.values() if t.workflow == step_name], + key=lambda _: _.name, + ) + version = self._product_info.version() + version = version if not self._ws.config.is_gcp else version.replace("+", "-") + return { + "name": self._name(step_name), + "tags": {"version": f"v{version}"}, + "job_clusters": self._job_clusters({t.job_cluster for t in tasks}), + "email_notifications": email_notifications, + "tasks": [self._job_task(task, remote_wheel) for task in tasks], + } + + def _job_task(self, task: Task, remote_wheel: str) -> jobs.Task: + jobs_task = jobs.Task( + task_key=task.name, + job_cluster_key=task.job_cluster, + depends_on=[jobs.TaskDependency(task_key=d) for d in _TASKS[task.name].dependencies()], + ) + if task.dashboard: + # dashboards are created in parallel to wheel uploads, so we'll just retry + retry_on_attribute_error = retried(on=[KeyError], timeout=self._verify_timeout) + retried_job_dashboard_task = retry_on_attribute_error(self._job_dashboard_task) + return retried_job_dashboard_task(jobs_task, task) + if task.notebook: + return self._job_notebook_task(jobs_task, task) + return self._job_wheel_task(jobs_task, task, remote_wheel) + + def _job_dashboard_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task: + assert task.dashboard is not None + dashboard_id = self._state.dashboards[task.dashboard] + return replace( + jobs_task, + job_cluster_key=None, + sql_task=jobs.SqlTask( + warehouse_id=self._warehouse_id, + dashboard=jobs.SqlTaskDashboard(dashboard_id=dashboard_id), + ), + ) + + def _job_notebook_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task: + assert task.notebook is not None + local_notebook = self._this_file.parent.parent / task.notebook + with local_notebook.open("rb") as f: + remote_notebook = self._installation.upload(local_notebook.name, f.read()) + return replace( + jobs_task, + notebook_task=jobs.NotebookTask( + notebook_path=remote_notebook, + # ES-872211: currently, we cannot read WSFS files from Scala context + base_parameters={ + "task": task.name, + "config": f"/Workspace{self._config_file}", + } + | EXTRA_TASK_PARAMS, + ), + ) + + def _job_wheel_task(self, jobs_task: jobs.Task, task: Task, remote_wheel: str) -> jobs.Task: + return replace( + jobs_task, + # TODO: check when we can install wheels from WSFS properly + libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")], + python_wheel_task=jobs.PythonWheelTask( + package_name="databricks_labs_ucx", + entry_point="runtime", # [project.entry-points.databricks] in pyproject.toml + named_parameters={"task": task.name, "config": f"/Workspace{self._config_file}"} | EXTRA_TASK_PARAMS, + ), + ) + + def _job_clusters(self, names: set[str]): + clusters = [] + if "main" in names: + clusters.append( + jobs.JobCluster( + job_cluster_key="main", + new_cluster=compute.ClusterSpec( + data_security_mode=compute.DataSecurityMode.LEGACY_SINGLE_USER, + spark_conf=self._job_cluster_spark_conf("main"), + custom_tags={"ResourceClass": "SingleNode"}, + num_workers=0, + policy_id=self._config.policy_id, + ), + ) + ) + if "tacl" in names: + clusters.append( + jobs.JobCluster( + job_cluster_key="tacl", + new_cluster=compute.ClusterSpec( + data_security_mode=compute.DataSecurityMode.LEGACY_TABLE_ACL, + spark_conf=self._job_cluster_spark_conf("tacl"), + num_workers=1, # ShowPermissionsCommand needs a worker + policy_id=self._config.policy_id, + ), + ) + ) + if 
"table_migration" in names: + clusters.append( + jobs.JobCluster( + job_cluster_key="table_migration", + new_cluster=compute.ClusterSpec( + data_security_mode=compute.DataSecurityMode.SINGLE_USER, + spark_conf=self._job_cluster_spark_conf("table_migration"), + policy_id=self._config.policy_id, + autoscale=compute.AutoScale( + max_workers=self._config.max_workers, + min_workers=self._config.min_workers, + ), + ), + ) + ) + return clusters + + def _create_debug(self, remote_wheel: str): + readme_link = self._installation.workspace_link('README') + job_links = ", ".join( + f"[{self._name(step_name)}]({self._ws.config.host}#job/{job_id})" + for step_name, job_id in self._state.jobs.items() + ) + content = DEBUG_NOTEBOOK.format( + remote_wheel=remote_wheel, readme_link=readme_link, job_links=job_links, config_file=self._config_file + ).encode("utf8") + self._installation.upload('DEBUG.py', content) + + def _repair_workflow(self, workflow): + job_id = self._state.jobs.get(workflow) + if not job_id: + raise InvalidParameterValue("job does not exists hence skipping repair") + job_runs = list(self._ws.jobs.list_runs(job_id=job_id, limit=1)) + if not job_runs: + raise InvalidParameterValue("job is not initialized yet. Can't trigger repair run now") + latest_job_run = job_runs[0] + retry_on_attribute_error = retried(on=[AttributeError], timeout=self._verify_timeout) + retried_check = retry_on_attribute_error(self._get_result_state) + state_value = retried_check(job_id) + logger.info(f"The status for the latest run is {state_value}") + if state_value != "FAILED": + raise InvalidParameterValue("job is not in FAILED state hence skipping repair") + run_id = latest_job_run.run_id + return job_id, run_id + + def _get_result_state(self, job_id): + job_runs = list(self._ws.jobs.list_runs(job_id=job_id, limit=1)) + latest_job_run = job_runs[0] + if not latest_job_run.state.result_state: + raise AttributeError("no result state in job run") + job_state = latest_job_run.state.result_state.value + return job_state diff --git a/tests/integration/test_installation.py b/tests/integration/test_installation.py index c5c36ddc1b..c9cde041ea 100644 --- a/tests/integration/test_installation.py +++ b/tests/integration/test_installation.py @@ -6,7 +6,6 @@ from dataclasses import replace from datetime import timedelta -import databricks.sdk.errors import pytest # pylint: disable=wrong-import-order from databricks.labs.blueprint.installation import Installation from databricks.labs.blueprint.installer import InstallState, RawState @@ -18,9 +17,11 @@ from databricks.sdk.service import compute, sql from databricks.sdk.service.iam import PermissionLevel +import databricks from databricks.labs.ucx.config import WorkspaceConfig from databricks.labs.ucx.hive_metastore.mapping import Rule from databricks.labs.ucx.install import WorkspaceInstallation, WorkspaceInstaller +from databricks.labs.ucx.installer.workflows import WorkflowsInstallation from databricks.labs.ucx.workspace_access import redash from databricks.labs.ucx.workspace_access.generic import ( GenericPermissionsSupport, @@ -68,6 +69,7 @@ def factory( default_cluster_id = env_or_skip("TEST_DEFAULT_CLUSTER_ID") tacl_cluster_id = env_or_skip("TEST_LEGACY_TABLE_ACL_CLUSTER_ID") + table_migration_cluster_id = env_or_skip("TEST_USER_ISOLATION_CLUSTER_ID") Threads.strict( "ensure clusters running", [ @@ -81,7 +83,7 @@ def factory( installer = WorkspaceInstaller(prompts, installation, ws, product_info, environ) workspace_config = installer.configure() installation = 
product_info.current_installation(ws) - overrides = {"main": default_cluster_id, "tacl": tacl_cluster_id} + overrides = {"main": default_cluster_id, "tacl": tacl_cluster_id, "table_migration": table_migration_cluster_id} workspace_config.override_clusters = overrides if workspace_config.workspace_start_path == '/': @@ -94,19 +96,21 @@ def factory( # TODO: see if we want to move building wheel as a context manager for yield factory, # so that we can shave off couple of seconds and build wheel only once per session # instead of every test + workflows_installation = WorkflowsInstallation( + workspace_config, installation, ws, product_info.wheels(ws), prompts, product_info, timedelta(minutes=3) + ) workspace_installation = WorkspaceInstallation( workspace_config, installation, sql_backend, - product_info.wheels(ws), ws, + workflows_installation, prompts, - timedelta(minutes=2), product_info, ) workspace_installation.run() cleanup.append(workspace_installation) - return workspace_installation + return workspace_installation, workflows_installation yield factory @@ -116,22 +120,22 @@ def factory( @retried(on=[NotFound, TimeoutError], timeout=timedelta(minutes=5)) def test_job_failure_propagates_correct_error_message_and_logs(ws, sql_backend, new_installation): - install = new_installation() + workspace_installation, workflow_installation = new_installation() - sql_backend.execute(f"DROP SCHEMA {install.config.inventory_database} CASCADE") + sql_backend.execute(f"DROP SCHEMA {workspace_installation.config.inventory_database} CASCADE") with pytest.raises(NotFound) as failure: - install.run_workflow("099-destroy-schema") + workflow_installation.run_workflow("099-destroy-schema") assert "cannot be found" in str(failure.value) - workflow_run_logs = list(ws.workspace.list(f"{install.folder}/logs")) + workflow_run_logs = list(ws.workspace.list(f"{workspace_installation.folder}/logs")) assert len(workflow_run_logs) == 1 @retried(on=[NotFound, InvalidParameterValue], timeout=timedelta(minutes=3)) def test_job_cluster_policy(ws, new_installation): - install = new_installation(lambda wc: replace(wc, override_clusters=None)) + install, _ = new_installation(lambda wc: replace(wc, override_clusters=None)) user_name = ws.current_user.me().user_name cluster_policy = ws.cluster_policies.get(policy_id=install.config.policy_id) policy_definition = json.loads(cluster_policy.definition) @@ -183,7 +187,7 @@ def test_running_real_assessment_job( group_name=ws_group_a.display_name, ) - install = new_installation(lambda wc: replace(wc, include_group_names=[ws_group_a.display_name])) + _, install = new_installation(lambda wc: replace(wc, include_group_names=[ws_group_a.display_name])) install.run_workflow("assessment") generic_permissions = GenericPermissionsSupport(ws, []) @@ -212,12 +216,12 @@ def test_running_real_migrate_groups_job( ], ) - install = new_installation(lambda wc: replace(wc, include_group_names=[ws_group_a.display_name])) + install, workflows_install = new_installation(lambda wc: replace(wc, include_group_names=[ws_group_a.display_name])) inventory_database = install.config.inventory_database permission_manager = PermissionManager(sql_backend, inventory_database, [generic_permissions]) permission_manager.inventorize_permissions() - install.run_workflow("migrate-groups") + workflows_install.run_workflow("migrate-groups") found = generic_permissions.load_as_dict("cluster-policies", cluster_policy.policy_id) assert found[acc_group_a.display_name] == PermissionLevel.CAN_USE @@ -242,12 +246,12 @@ def 
test_running_real_validate_groups_permissions_job( [redash.Listing(ws.queries.list, sql.ObjectTypePlural.QUERIES)], ) - install = new_installation(lambda wc: replace(wc, include_group_names=[ws_group_a.display_name])) + install, workflows_install = new_installation(lambda wc: replace(wc, include_group_names=[ws_group_a.display_name])) permission_manager = PermissionManager(sql_backend, install.config.inventory_database, [redash_permissions]) permission_manager.inventorize_permissions() # assert the job does not throw any exception - install.run_workflow("validate-groups-permissions") + workflows_install.run_workflow("validate-groups-permissions") @retried(on=[NotFound], timeout=timedelta(minutes=5)) @@ -270,7 +274,7 @@ def test_running_real_validate_groups_permissions_job_fails( ], ) - install = new_installation(lambda wc: replace(wc, include_group_names=[ws_group_a.display_name])) + install, workflows_install = new_installation(lambda wc: replace(wc, include_group_names=[ws_group_a.display_name])) inventory_database = install.config.inventory_database permission_manager = PermissionManager(sql_backend, inventory_database, [generic_permissions]) permission_manager.inventorize_permissions() @@ -281,14 +285,14 @@ def test_running_real_validate_groups_permissions_job_fails( ) with pytest.raises(ValueError): - install.run_workflow("validate-groups-permissions") + workflows_install.run_workflow("validate-groups-permissions") @retried(on=[NotFound, InvalidParameterValue], timeout=timedelta(minutes=5)) def test_running_real_remove_backup_groups_job(ws, sql_backend, new_installation, make_ucx_group): ws_group_a, _ = make_ucx_group() - install = new_installation(lambda wc: replace(wc, include_group_names=[ws_group_a.display_name])) + install, workflows_install = new_installation(lambda wc: replace(wc, include_group_names=[ws_group_a.display_name])) cfg = install.config group_manager = GroupManager( sql_backend, ws, cfg.inventory_database, cfg.include_group_names, cfg.renamed_group_prefix @@ -297,7 +301,7 @@ def test_running_real_remove_backup_groups_job(ws, sql_backend, new_installation group_manager.rename_groups() group_manager.reflect_account_groups_on_workspace() - install.run_workflow("remove-workspace-local-backup-groups") + workflows_install.run_workflow("remove-workspace-local-backup-groups") with pytest.raises(NotFound): ws.groups.get(ws_group_a.id) @@ -305,15 +309,15 @@ def test_running_real_remove_backup_groups_job(ws, sql_backend, new_installation @retried(on=[NotFound, InvalidParameterValue], timeout=timedelta(minutes=10)) def test_repair_run_workflow_job(ws, mocker, new_installation, sql_backend): - install = new_installation() + install, workflows_install = new_installation() mocker.patch("webbrowser.open") sql_backend.execute(f"DROP SCHEMA {install.config.inventory_database} CASCADE") with pytest.raises(NotFound): - install.run_workflow("099-destroy-schema") + workflows_install.run_workflow("099-destroy-schema") sql_backend.execute(f"CREATE SCHEMA IF NOT EXISTS {install.config.inventory_database}") - install.repair_run("099-destroy-schema") + workflows_install.repair_run("099-destroy-schema") installation = Installation(ws, product=os.path.basename(install.folder), install_folder=install.folder) state = InstallState.from_installation(installation) @@ -327,7 +331,7 @@ def test_repair_run_workflow_job(ws, mocker, new_installation, sql_backend): @retried(on=[NotFound], timeout=timedelta(minutes=5)) def test_uninstallation(ws, sql_backend, new_installation): - install = 
new_installation() + install, _ = new_installation() installation = Installation(ws, product=os.path.basename(install.folder), install_folder=install.folder) state = InstallState.from_installation(installation) assessment_job_id = state.jobs["assessment"] @@ -342,7 +346,7 @@ def test_uninstallation(ws, sql_backend, new_installation): def test_fresh_global_installation(ws, new_installation): product_info = ProductInfo.for_testing(WorkspaceConfig) - global_installation = new_installation( + global_installation, _ = new_installation( product_info=product_info, installation=Installation.assume_global(ws, product_info.product_name()), ) @@ -352,7 +356,7 @@ def test_fresh_global_installation(ws, new_installation): def test_fresh_user_installation(ws, new_installation): product_info = ProductInfo.for_testing(WorkspaceConfig) - user_installation = new_installation( + user_installation, _ = new_installation( product_info=product_info, installation=Installation.assume_user_home(ws, product_info.product_name()), ) @@ -362,12 +366,12 @@ def test_fresh_user_installation(ws, new_installation): def test_global_installation_on_existing_global_install(ws, new_installation): product_info = ProductInfo.for_testing(WorkspaceConfig) - existing_global_installation = new_installation( + existing_global_installation, _ = new_installation( product_info=product_info, installation=Installation.assume_global(ws, product_info.product_name()), ) assert existing_global_installation.folder == f"/Applications/{product_info.product_name()}" - reinstall_global = new_installation( + reinstall_global, _ = new_installation( product_info=product_info, installation=Installation.assume_global(ws, product_info.product_name()), ) @@ -378,7 +382,7 @@ def test_global_installation_on_existing_global_install(ws, new_installation): def test_user_installation_on_existing_global_install(ws, new_installation): # existing install at global level product_info = ProductInfo.for_testing(WorkspaceConfig) - existing_global_installation = new_installation( + existing_global_installation, _ = new_installation( product_info=product_info, installation=Installation.assume_global(ws, product_info.product_name()), ) @@ -396,7 +400,7 @@ def test_user_installation_on_existing_global_install(ws, new_installation): assert err.value.args[0] == "UCX is already installed, but no confirmation" # successful override with confirmation - reinstall_user_force = new_installation( + reinstall_user_force, _ = new_installation( product_info=product_info, installation=Installation.assume_global(ws, product_info.product_name()), environ={'UCX_FORCE_INSTALL': 'user'}, @@ -413,7 +417,7 @@ def test_user_installation_on_existing_global_install(ws, new_installation): def test_global_installation_on_existing_user_install(ws, new_installation): # existing installation at user level product_info = ProductInfo.for_testing(WorkspaceConfig) - existing_user_installation = new_installation( + existing_user_installation, _ = new_installation( product_info=product_info, installation=Installation.assume_user_home(ws, product_info.product_name()) ) assert ( @@ -448,7 +452,7 @@ def test_global_installation_on_existing_user_install(ws, new_installation): def test_check_inventory_database_exists(ws, new_installation): product_info = ProductInfo.for_testing(WorkspaceConfig) - install = new_installation( + install, _ = new_installation( product_info=product_info, installation=Installation.assume_global(ws, product_info.product_name()), ) @@ -483,7 +487,7 @@ def test_table_migration_job( # 
pylint: disable=too-many-locals dst_schema = make_schema(catalog_name=dst_catalog.name, name=src_schema.name) product_info = ProductInfo.from_class(WorkspaceConfig) - install = new_installation( + _, workflows_install = new_installation( product_info=product_info, extend_prompts={ r"Parallelism for migrating.*": "1000", @@ -513,9 +517,9 @@ def test_table_migration_job( # pylint: disable=too-many-locals ] installation.save(migrate_rules, filename='mapping.csv') - install.run_workflow("migrate-tables") + workflows_install.run_workflow("migrate-tables") # assert the workflow is successful - assert install.validate_step("migrate-tables") + assert workflows_install.validate_step("migrate-tables") # assert the tables are migrated assert ws.tables.get(f"{dst_catalog.name}.{dst_schema.name}.{src_managed_table.name}").name assert ws.tables.get(f"{dst_catalog.name}.{dst_schema.name}.{src_external_table.name}").name @@ -545,7 +549,7 @@ def test_table_migration_job_cluster_override( # pylint: disable=too-many-local dst_schema = make_schema(catalog_name=dst_catalog.name, name=src_schema.name) product_info = ProductInfo.from_class(WorkspaceConfig) - install = new_installation( + _, workflows_install = new_installation( lambda wc: replace(wc, override_clusters={"table_migration": env_or_skip("TEST_USER_ISOLATION_CLUSTER_ID")}), product_info=product_info, ) @@ -571,9 +575,9 @@ def test_table_migration_job_cluster_override( # pylint: disable=too-many-local ] installation.save(migrate_rules, filename='mapping.csv') - install.run_workflow("migrate-tables") + workflows_install.run_workflow("migrate-tables") # assert the workflow is successful - assert install.validate_step("migrate-tables") + assert workflows_install.validate_step("migrate-tables") # assert the tables are migrated assert ws.tables.get(f"{dst_catalog.name}.{dst_schema.name}.{src_managed_table.name}").name assert ws.tables.get(f"{dst_catalog.name}.{dst_schema.name}.{src_external_table.name}").name diff --git a/tests/unit/test_install.py b/tests/unit/test_install.py index 27bec1f495..9f977e489a 100644 --- a/tests/unit/test_install.py +++ b/tests/unit/test_install.py @@ -6,7 +6,7 @@ import pytest import yaml from databricks.labs.blueprint.installation import Installation, MockInstallation -from databricks.labs.blueprint.installer import InstallState +from databricks.labs.blueprint.installer import InstallState, RawState from databricks.labs.blueprint.parallel import ManyError from databricks.labs.blueprint.tui import MockPrompts from databricks.labs.blueprint.wheels import ( @@ -51,10 +51,12 @@ ) from databricks.sdk.service.workspace import ObjectInfo +import databricks.labs.ucx.installer.mixins import databricks.labs.ucx.uninstall # noqa from databricks.labs.ucx.config import WorkspaceConfig from databricks.labs.ucx.framework.dashboards import DashboardFromFiles from databricks.labs.ucx.install import WorkspaceInstallation, WorkspaceInstaller +from databricks.labs.ucx.installer.workflows import WorkflowsInstallation PRODUCT_INFO = ProductInfo.from_class(WorkspaceConfig) @@ -168,6 +170,20 @@ def mock_installation_with_jobs(): ) +@pytest.fixture +def mock_installation_extra_jobs(): + return MockInstallation( + { + 'state.json': { + 'resources': { + 'jobs': {"assessment": "123", "extra_job": "123"}, + 'dashboards': {'assessment_main': 'abc', 'assessment_estimates': 'def'}, + } + } + } + ) + + @pytest.fixture def any_prompt(): return MockPrompts({".*": ""}) @@ -182,15 +198,23 @@ def test_create_database(ws, caplog, mock_installation, any_prompt): 
sql_backend = MockBackend( fails_on_first={'CREATE TABLE': '[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable is incorrect'} ) - wheels = create_autospec(WheelsV2) + workflows_installation = WorkflowsInstallation( + WorkspaceConfig(inventory_database="...", policy_id='123'), + mock_installation, + ws, + create_autospec(WheelsV2), + any_prompt, + PRODUCT_INFO, + timedelta(seconds=1), + ) + workspace_installation = WorkspaceInstallation( WorkspaceConfig(inventory_database='ucx'), mock_installation, sql_backend, - wheels, ws, + workflows_installation, any_prompt, - timedelta(seconds=1), PRODUCT_INFO, ) @@ -201,21 +225,15 @@ def test_create_database(ws, caplog, mock_installation, any_prompt): def test_install_cluster_override_jobs(ws, mock_installation, any_prompt): - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) - workspace_installation = WorkspaceInstallation( - WorkspaceConfig( - inventory_database='ucx', - override_clusters={"main": 'one', "tacl": 'two', "table_migration": "three"}, - policy_id='123', - ), + workspace_installation = WorkflowsInstallation( + WorkspaceConfig(inventory_database='ucx', override_clusters={"main": 'one', "tacl": 'two'}, policy_id='123'), mock_installation, - sql_backend, - wheels, ws, + wheels, any_prompt, - timedelta(seconds=1), PRODUCT_INFO, + timedelta(seconds=1), ) workspace_installation.create_jobs() @@ -225,14 +243,9 @@ def test_install_cluster_override_jobs(ws, mock_installation, any_prompt): assert tasks['crawl_grants'].existing_cluster_id == 'two' assert tasks['estimates_report'].sql_task.dashboard.dashboard_id == 'def' - tasks = created_job_tasks(ws, '[MOCK] migrate-tables') - assert tasks['migrate_external_tables_sync'].existing_cluster_id == 'three' - assert tasks['migrate_dbfs_root_delta_tables'].existing_cluster_id == 'three' - def test_write_protected_dbfs(ws, tmp_path, mock_installation): """Simulate write protected DBFS AND override clusters""" - sql_backend = MockBackend() wheels = create_autospec(Wheels) wheels.upload_to_dbfs.side_effect = PermissionDenied(...) 
wheels.upload_to_wsfs.return_value = "/a/b/c" @@ -245,18 +258,17 @@ def test_write_protected_dbfs(ws, tmp_path, mock_installation): } ) - workspace_installation = WorkspaceInstallation( + workflows_installation = WorkflowsInstallation( WorkspaceConfig(inventory_database='ucx', policy_id='123'), mock_installation, - sql_backend, - wheels, ws, + wheels, prompts, - timedelta(seconds=1), PRODUCT_INFO, + timedelta(seconds=1), ) - workspace_installation.create_jobs() + workflows_installation.create_jobs() tasks = created_job_tasks(ws, '[MOCK] assessment') assert tasks['assess_jobs'].existing_cluster_id == "2222-999999-nosecuri" @@ -271,8 +283,8 @@ def test_write_protected_dbfs(ws, tmp_path, mock_installation): 'log_level': 'INFO', 'num_days_submit_runs_history': 30, 'num_threads': 10, - 'max_workers': 10, 'min_workers': 1, + 'max_workers': 10, 'override_clusters': {'main': '2222-999999-nosecuri', 'tacl': '3333-999999-legacytc'}, 'policy_id': '123', 'renamed_group_prefix': 'ucx-renamed-', @@ -283,20 +295,18 @@ def test_write_protected_dbfs(ws, tmp_path, mock_installation): def test_writeable_dbfs(ws, tmp_path, mock_installation, any_prompt): """Ensure configure does not add cluster override for happy path of writable DBFS""" - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) - workspace_installation = WorkspaceInstallation( + workflows_installation = WorkflowsInstallation( WorkspaceConfig(inventory_database='ucx', policy_id='123'), mock_installation, - sql_backend, - wheels, ws, + wheels, any_prompt, - timedelta(seconds=1), PRODUCT_INFO, + timedelta(seconds=1), ) - workspace_installation.create_jobs() + workflows_installation.create_jobs() job = created_job(ws, '[MOCK] assessment') job_clusters = {_.job_cluster_key: _ for _ in job['job_clusters']} @@ -304,11 +314,6 @@ def test_writeable_dbfs(ws, tmp_path, mock_installation, any_prompt): assert 'tacl' in job_clusters assert job_clusters["main"].new_cluster.policy_id == "123" - job = created_job(ws, '[MOCK] migrate-tables') - job_clusters = {_.job_cluster_key: _ for _ in job['job_clusters']} - assert 'table_migration' in job_clusters - assert job_clusters["table_migration"].new_cluster.policy_id == "123" - def test_run_workflow_creates_proper_failure(ws, mocker, any_prompt, mock_installation_with_jobs): def run_now(job_id): @@ -334,17 +339,15 @@ def result(): ], ) ws.jobs.get_run_output.return_value = jobs.RunOutput(error="does not compute", error_trace="# goes to stderr") - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) - installer = WorkspaceInstallation( + installer = WorkflowsInstallation( WorkspaceConfig(inventory_database='ucx'), mock_installation_with_jobs, - sql_backend, - wheels, ws, + wheels, any_prompt, - timedelta(seconds=1), PRODUCT_INFO, + timedelta(seconds=1), ) with pytest.raises(Unknown) as failure: installer.run_workflow("assessment") @@ -380,17 +383,15 @@ def result(): ws.jobs.get_run_output.return_value = jobs.RunOutput( error="something: PermissionDenied: does not compute", error_trace="# goes to stderr" ) - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) - installer = WorkspaceInstallation( + installer = WorkflowsInstallation( WorkspaceConfig(inventory_database='ucx'), mock_installation_with_jobs, - sql_backend, - wheels, ws, + wheels, any_prompt, - timedelta(seconds=1), PRODUCT_INFO, + timedelta(seconds=1), ) with pytest.raises(PermissionDenied) as failure: installer.run_workflow("assessment") @@ -434,17 +435,15 @@ def result(): ws.jobs.get_run_output.return_value = jobs.RunOutput( 
error="something: DataLoss: does not compute", error_trace="# goes to stderr" ) - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) - installer = WorkspaceInstallation( + installer = WorkflowsInstallation( WorkspaceConfig(inventory_database='ucx'), mock_installation_with_jobs, - sql_backend, - wheels, ws, + wheels, any_prompt, - timedelta(seconds=1), PRODUCT_INFO, + timedelta(seconds=1), ) with pytest.raises(ManyError) as failure: installer.run_workflow("assessment") @@ -483,9 +482,8 @@ def test_save_config(ws, mock_installation): 'log_level': 'INFO', 'num_days_submit_runs_history': 30, 'num_threads': 8, - 'spark_conf': {'spark.sql.sources.parallelPartitionDiscovery.parallelism': '200'}, - 'max_workers': 10, 'min_workers': 1, + 'max_workers': 10, 'policy_id': 'foo', 'renamed_group_prefix': 'db-temp-', 'warehouse_id': 'abc', @@ -518,9 +516,8 @@ def test_save_config_strip_group_names(ws, mock_installation): 'log_level': 'INFO', 'num_days_submit_runs_history': 30, 'num_threads': 8, - 'spark_conf': {'spark.sql.sources.parallelPartitionDiscovery.parallelism': '200'}, - 'max_workers': 10, 'min_workers': 1, + 'max_workers': 10, 'policy_id': 'foo', 'renamed_group_prefix': 'db-temp-', 'warehouse_id': 'abc', @@ -561,9 +558,8 @@ def test_create_cluster_policy(ws, mock_installation): 'log_level': 'INFO', 'num_days_submit_runs_history': 30, 'num_threads': 8, - 'spark_conf': {'spark.sql.sources.parallelPartitionDiscovery.parallelism': '200'}, - 'max_workers': 10, 'min_workers': 1, + 'max_workers': 10, 'policy_id': 'foo1', 'renamed_group_prefix': 'db-temp-', 'warehouse_id': 'abc', @@ -582,14 +578,22 @@ def test_main_with_existing_conf_does_not_recreate_config(ws, mocker, mock_insta r".*": "", } ) + workflows_installer = WorkflowsInstallation( + WorkspaceConfig(inventory_database="...", policy_id='123'), + mock_installation, + ws, + create_autospec(WheelsV2), + prompts, + PRODUCT_INFO, + timedelta(seconds=1), + ) workspace_installation = WorkspaceInstallation( WorkspaceConfig(inventory_database="...", policy_id='123'), mock_installation, sql_backend, - create_autospec(WheelsV2), ws, + workflows_installer, prompts, - timedelta(seconds=1), PRODUCT_INFO, ) workspace_installation.run() @@ -604,7 +608,6 @@ def test_query_metadata(ws): def test_remove_database(ws): sql_backend = MockBackend() - wheels = create_autospec(WheelsV2) ws = create_autospec(WorkspaceClient) prompts = MockPrompts( { @@ -614,9 +617,9 @@ def test_remove_database(ws): ) installation = create_autospec(Installation) config = WorkspaceConfig(inventory_database='ucx') - timeout = timedelta(seconds=1) + workflow_installer = create_autospec(WorkflowsInstallation) workspace_installation = WorkspaceInstallation( - config, installation, sql_backend, wheels, ws, prompts, timeout, PRODUCT_INFO + config, installation, sql_backend, ws, workflow_installer, prompts, PRODUCT_INFO ) workspace_installation.uninstall() @@ -626,7 +629,6 @@ def test_remove_database(ws): def test_remove_jobs_no_state(ws): sql_backend = MockBackend() - wheels = create_autospec(WheelsV2) ws = create_autospec(WorkspaceClient) prompts = MockPrompts( { @@ -636,9 +638,11 @@ def test_remove_jobs_no_state(ws): ) installation = create_autospec(Installation) config = WorkspaceConfig(inventory_database='ucx') - timeout = timedelta(seconds=1) + workflows_installer = WorkflowsInstallation( + config, installation, ws, create_autospec(WheelsV2), prompts, PRODUCT_INFO, timedelta(seconds=1) + ) workspace_installation = WorkspaceInstallation( - config, installation, sql_backend, 
wheels, ws, prompts, timeout, PRODUCT_INFO + config, installation, sql_backend, ws, workflows_installer, prompts, PRODUCT_INFO ) workspace_installation.uninstall() @@ -650,7 +654,6 @@ def test_remove_jobs_with_state_missing_job(ws, caplog, mock_installation_with_j ws.jobs.delete.side_effect = InvalidParameterValue("job id 123 not found") sql_backend = MockBackend() - wheels = create_autospec(WheelsV2) prompts = MockPrompts( { r'Do you want to uninstall ucx.*': 'yes', @@ -658,9 +661,12 @@ def test_remove_jobs_with_state_missing_job(ws, caplog, mock_installation_with_j } ) config = WorkspaceConfig(inventory_database='ucx') - timeout = timedelta(seconds=1) + installation = mock_installation_with_jobs + workflows_installer = WorkflowsInstallation( + config, installation, ws, create_autospec(WheelsV2), prompts, PRODUCT_INFO, timedelta(seconds=1) + ) workspace_installation = WorkspaceInstallation( - config, mock_installation_with_jobs, sql_backend, wheels, ws, prompts, timeout, PRODUCT_INFO + config, mock_installation_with_jobs, sql_backend, ws, workflows_installer, prompts, PRODUCT_INFO ) with caplog.at_level('ERROR'): @@ -674,7 +680,6 @@ def test_remove_warehouse(ws): ws.warehouses.get.return_value = sql.GetWarehouseResponse(id="123", name="Unity Catalog Migration 123456") sql_backend = MockBackend() - wheels = create_autospec(WheelsV2) prompts = MockPrompts( { r'Do you want to uninstall ucx.*': 'yes', @@ -683,9 +688,9 @@ def test_remove_warehouse(ws): ) installation = create_autospec(Installation) config = WorkspaceConfig(inventory_database='ucx', warehouse_id="123") - timeout = timedelta(seconds=1) + workflows_installer = create_autospec(WorkflowsInstallation) workspace_installation = WorkspaceInstallation( - config, installation, sql_backend, wheels, ws, prompts, timeout, PRODUCT_INFO + config, installation, sql_backend, ws, workflows_installer, prompts, PRODUCT_INFO ) workspace_installation.uninstall() @@ -697,7 +702,6 @@ def test_not_remove_warehouse_with_a_different_prefix(ws): ws.warehouses.get.return_value = sql.GetWarehouseResponse(id="123", name="Starter Endpoint") sql_backend = MockBackend() - wheels = create_autospec(WheelsV2) prompts = MockPrompts( { r'Do you want to uninstall ucx.*': 'yes', @@ -706,9 +710,9 @@ def test_not_remove_warehouse_with_a_different_prefix(ws): ) installation = create_autospec(Installation) config = WorkspaceConfig(inventory_database='ucx', warehouse_id="123") - timeout = timedelta(seconds=1) + workflows_installer = create_autospec(WorkflowsInstallation) workspace_installation = WorkspaceInstallation( - config, installation, sql_backend, wheels, ws, prompts, timeout, PRODUCT_INFO + config, installation, sql_backend, ws, workflows_installer, prompts, PRODUCT_INFO ) workspace_installation.uninstall() @@ -717,7 +721,6 @@ def test_not_remove_warehouse_with_a_different_prefix(ws): def test_remove_secret_scope(ws, caplog): - wheels = create_autospec(WheelsV2) prompts = MockPrompts( { r'Do you want to uninstall ucx.*': 'yes', @@ -726,17 +729,16 @@ def test_remove_secret_scope(ws, caplog): ) installation = MockInstallation() config = WorkspaceConfig(inventory_database='ucx', uber_spn_id="123") - timeout = timedelta(seconds=1) + workflows_installer = create_autospec(WorkflowsInstallation) # ws.secrets.delete_scope.side_effect = NotFound() workspace_installation = WorkspaceInstallation( - config, installation, MockBackend(), wheels, ws, prompts, timeout, PRODUCT_INFO + config, installation, MockBackend(), ws, workflows_installer, prompts, PRODUCT_INFO ) 
workspace_installation.uninstall() ws.secrets.delete_scope.assert_called_with('ucx') def test_remove_secret_scope_no_scope(ws, caplog): - wheels = create_autospec(WheelsV2) prompts = MockPrompts( { r'Do you want to uninstall ucx.*': 'yes', @@ -745,10 +747,10 @@ def test_remove_secret_scope_no_scope(ws, caplog): ) installation = MockInstallation() config = WorkspaceConfig(inventory_database='ucx', uber_spn_id="123") - timeout = timedelta(seconds=1) + workflows_installer = create_autospec(WorkflowsInstallation) ws.secrets.delete_scope.side_effect = NotFound() workspace_installation = WorkspaceInstallation( - config, installation, MockBackend(), wheels, ws, prompts, timeout, PRODUCT_INFO + config, installation, MockBackend(), ws, workflows_installer, prompts, PRODUCT_INFO ) with caplog.at_level('ERROR'): workspace_installation.uninstall() @@ -757,7 +759,6 @@ def test_remove_secret_scope_no_scope(ws, caplog): def test_remove_cluster_policy_not_exists(ws, caplog): sql_backend = MockBackend() - wheels = create_autospec(WheelsV2) prompts = MockPrompts( { r'Do you want to uninstall ucx.*': 'yes', @@ -766,10 +767,10 @@ def test_remove_cluster_policy_not_exists(ws, caplog): ) installation = create_autospec(Installation) config = WorkspaceConfig(inventory_database='ucx') - timeout = timedelta(seconds=1) ws.cluster_policies.delete.side_effect = NotFound() + workflows_installer = create_autospec(WorkflowsInstallation) workspace_installation = WorkspaceInstallation( - config, installation, sql_backend, wheels, ws, prompts, timeout, PRODUCT_INFO + config, installation, sql_backend, ws, workflows_installer, prompts, PRODUCT_INFO ) with caplog.at_level('ERROR'): @@ -781,7 +782,6 @@ def test_remove_warehouse_not_exists(ws, caplog): ws.warehouses.delete.side_effect = InvalidParameterValue("warehouse id 123 not found") sql_backend = MockBackend() - wheels = create_autospec(WheelsV2) prompts = MockPrompts( { r'Do you want to uninstall ucx.*': 'yes', @@ -790,9 +790,9 @@ def test_remove_warehouse_not_exists(ws, caplog): ) installation = create_autospec(Installation) config = WorkspaceConfig(inventory_database='ucx') - timeout = timedelta(seconds=1) + workflows_installer = create_autospec(WorkflowsInstallation) workspace_installation = WorkspaceInstallation( - config, installation, sql_backend, wheels, ws, prompts, timeout, PRODUCT_INFO + config, installation, sql_backend, ws, workflows_installer, prompts, PRODUCT_INFO ) with caplog.at_level('ERROR'): @@ -816,15 +816,13 @@ def test_repair_run(ws, mocker, any_prompt, mock_installation_with_jobs): ws.jobs.list_runs.return_value = base ws.jobs.list_runs.repair_run = None - sql_backend = MockBackend() - wheels = create_autospec(WheelsV2) config = WorkspaceConfig(inventory_database='ucx') timeout = timedelta(seconds=1) - workspace_installation = WorkspaceInstallation( - config, mock_installation_with_jobs, sql_backend, wheels, ws, any_prompt, timeout, PRODUCT_INFO + workflows_installer = WorkflowsInstallation( + config, mock_installation_with_jobs, ws, create_autospec(WheelsV2), any_prompt, PRODUCT_INFO, timeout ) - workspace_installation.repair_run("assessment") + workflows_installer.repair_run("assessment") def test_repair_run_success(ws, caplog, mock_installation_with_jobs, any_prompt): @@ -842,15 +840,14 @@ def test_repair_run_success(ws, caplog, mock_installation_with_jobs, any_prompt) ws.jobs.list_runs.return_value = base ws.jobs.list_runs.repair_run = None - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) config = 
WorkspaceConfig(inventory_database='ucx') timeout = timedelta(seconds=1) - workspace_installation = WorkspaceInstallation( - config, mock_installation_with_jobs, sql_backend, wheels, ws, any_prompt, timeout, PRODUCT_INFO + workflows_installer = WorkflowsInstallation( + config, mock_installation_with_jobs, ws, wheels, any_prompt, PRODUCT_INFO, timeout ) - workspace_installation.repair_run("assessment") + workflows_installer.repair_run("assessment") assert "job is not in FAILED state" in caplog.text @@ -870,16 +867,15 @@ def test_repair_run_no_job_id(ws, mock_installation, any_prompt, caplog): ws.jobs.list_runs.return_value = base ws.jobs.list_runs.repair_run = None - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) config = WorkspaceConfig(inventory_database='ucx') timeout = timedelta(seconds=1) - workspace_installation = WorkspaceInstallation( - config, mock_installation, sql_backend, wheels, ws, any_prompt, timeout, PRODUCT_INFO + workflows_installer = WorkflowsInstallation( + config, mock_installation, ws, wheels, any_prompt, PRODUCT_INFO, timeout ) with caplog.at_level('WARNING'): - workspace_installation.repair_run("assessment") + workflows_installer.repair_run("assessment") assert 'skipping assessment: job does not exists hence skipping repair' in caplog.messages @@ -887,32 +883,30 @@ def test_repair_run_no_job_run(ws, mock_installation_with_jobs, any_prompt, capl ws.jobs.list_runs.return_value = "" ws.jobs.list_runs.repair_run = None - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) config = WorkspaceConfig(inventory_database='ucx') timeout = timedelta(seconds=1) - workspace_installation = WorkspaceInstallation( - config, mock_installation_with_jobs, sql_backend, wheels, ws, any_prompt, timeout, PRODUCT_INFO + workflows_installer = WorkflowsInstallation( + config, mock_installation_with_jobs, ws, wheels, any_prompt, PRODUCT_INFO, timeout ) with caplog.at_level('WARNING'): - workspace_installation.repair_run("assessment") + workflows_installer.repair_run("assessment") assert "skipping assessment: job is not initialized yet. 
Can't trigger repair run now" in caplog.messages def test_repair_run_exception(ws, mock_installation_with_jobs, any_prompt, caplog): ws.jobs.list_runs.side_effect = InvalidParameterValue("Workflow does not exists") - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) config = WorkspaceConfig(inventory_database='ucx') timeout = timedelta(seconds=1) - workspace_installation = WorkspaceInstallation( - config, mock_installation_with_jobs, sql_backend, wheels, ws, any_prompt, timeout, PRODUCT_INFO + workflows_installer = WorkflowsInstallation( + config, mock_installation_with_jobs, ws, wheels, any_prompt, PRODUCT_INFO, timeout ) with caplog.at_level('WARNING'): - workspace_installation.repair_run("assessment") + workflows_installer.repair_run("assessment") assert "skipping assessment: Workflow does not exists" in caplog.messages @@ -931,15 +925,14 @@ def test_repair_run_result_state(ws, caplog, mock_installation_with_jobs, any_pr ws.jobs.list_runs.return_value = base ws.jobs.list_runs.repair_run = None - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) config = WorkspaceConfig(inventory_database='ucx') timeout = timedelta(seconds=1) - workspace_installation = WorkspaceInstallation( - config, mock_installation_with_jobs, sql_backend, wheels, ws, any_prompt, timeout, PRODUCT_INFO + workflows_installer = WorkflowsInstallation( + config, mock_installation_with_jobs, ws, wheels, any_prompt, PRODUCT_INFO, timeout ) - workspace_installation.repair_run("assessment") + workflows_installer.repair_run("assessment") assert "Please try after sometime" in caplog.text @@ -985,20 +978,19 @@ def test_latest_job_status_states(ws, mock_installation_with_jobs, any_prompt, s start_time=1704114000000, ) ] - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) config = WorkspaceConfig(inventory_database='ucx') timeout = timedelta(seconds=1) - workspace_installation = WorkspaceInstallation( - config, mock_installation_with_jobs, sql_backend, wheels, ws, any_prompt, timeout, PRODUCT_INFO + workflows_installer = WorkflowsInstallation( + config, mock_installation_with_jobs, ws, wheels, any_prompt, PRODUCT_INFO, timeout ) ws.jobs.list_runs.return_value = base - status = workspace_installation.latest_job_status() + status = workflows_installer.latest_job_status() assert len(status) == 1 assert status[0]["state"] == expected -@patch(f"{databricks.labs.ucx.install.__name__}.datetime", wraps=datetime) +@patch(f"{databricks.labs.ucx.installer.mixins.__name__}.datetime", wraps=datetime) @pytest.mark.parametrize( "start_time,expected", [ @@ -1022,17 +1014,16 @@ def test_latest_job_status_success_with_time( start_time=start_time, ) ] - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) config = WorkspaceConfig(inventory_database='ucx') timeout = timedelta(seconds=1) - workspace_installation = WorkspaceInstallation( - config, mock_installation_with_jobs, sql_backend, wheels, ws, any_prompt, timeout, PRODUCT_INFO + workflows_installer = WorkflowsInstallation( + config, mock_installation_with_jobs, ws, wheels, any_prompt, PRODUCT_INFO, timeout ) ws.jobs.list_runs.return_value = base faked_now = datetime(2024, 1, 1, 14, 0, 0) mock_datetime.now.return_value = faked_now - status = workspace_installation.latest_job_status() + status = workflows_installer.latest_job_status() assert status[0]["started"] == expected @@ -1062,16 +1053,13 @@ def test_latest_job_status_list(ws, any_prompt): ], [], # the last job has no runs ] - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) 
config = WorkspaceConfig(inventory_database='ucx') timeout = timedelta(seconds=1) mock_install = MockInstallation({'state.json': {'resources': {'jobs': {"job1": "1", "job2": "2", "job3": "3"}}}}) - workspace_installation = WorkspaceInstallation( - config, mock_install, sql_backend, wheels, ws, any_prompt, timeout, PRODUCT_INFO - ) + workflows_installer = WorkflowsInstallation(config, mock_install, ws, wheels, any_prompt, PRODUCT_INFO, timeout) ws.jobs.list_runs.side_effect = iter(runs) - status = workspace_installation.latest_job_status() + status = workflows_installer.latest_job_status() assert len(status) == 3 assert status[0]["step"] == "job1" assert status[0]["state"] == "RUNNING" @@ -1082,29 +1070,27 @@ def test_latest_job_status_list(ws, any_prompt): def test_latest_job_status_no_job_run(ws, mock_installation_with_jobs, any_prompt): - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) config = WorkspaceConfig(inventory_database='ucx') timeout = timedelta(seconds=1) - workspace_installation = WorkspaceInstallation( - config, mock_installation_with_jobs, sql_backend, wheels, ws, any_prompt, timeout, PRODUCT_INFO + workflows_installer = WorkflowsInstallation( + config, mock_installation_with_jobs, ws, wheels, any_prompt, PRODUCT_INFO, timeout ) ws.jobs.list_runs.return_value = "" - status = workspace_installation.latest_job_status() + status = workflows_installer.latest_job_status() assert len(status) == 1 assert status[0]["step"] == "assessment" def test_latest_job_status_exception(ws, mock_installation_with_jobs, any_prompt): - sql_backend = MockBackend() wheels = create_autospec(WheelsV2) config = WorkspaceConfig(inventory_database='ucx') timeout = timedelta(seconds=1) - workspace_installation = WorkspaceInstallation( - config, mock_installation_with_jobs, sql_backend, wheels, ws, any_prompt, timeout, PRODUCT_INFO + workflows_installer = WorkflowsInstallation( + config, mock_installation_with_jobs, ws, wheels, any_prompt, PRODUCT_INFO, timeout ) ws.jobs.list_runs.side_effect = InvalidParameterValue("Workflow does not exists") - status = workspace_installation.latest_job_status() + status = workflows_installer.latest_job_status() assert len(status) == 0 @@ -1149,9 +1135,8 @@ def test_save_config_should_include_databases(ws, mock_installation): 'inventory_database': 'ucx', 'log_level': 'INFO', 'num_threads': 8, - 'spark_conf': {'spark.sql.sources.parallelPartitionDiscovery.parallelism': '200'}, - 'max_workers': 10, 'min_workers': 1, + 'max_workers': 10, 'policy_id': 'foo', 'renamed_group_prefix': 'db-temp-', 'warehouse_id': 'abc', @@ -1172,14 +1157,19 @@ def test_triggering_assessment_wf(ws, mocker, mock_installation): r"Open assessment Job url that just triggered ?.*": "yes", } ) + config = WorkspaceConfig(inventory_database="ucx", policy_id='123') + wheels = create_autospec(WheelsV2) + installation = mock_installation + workflows_installer = WorkflowsInstallation( + config, installation, ws, wheels, prompts, PRODUCT_INFO, timedelta(seconds=1) + ) workspace_installation = WorkspaceInstallation( - WorkspaceConfig(inventory_database="ucx", policy_id='123'), - mock_installation, + config, + installation, sql_backend, - create_autospec(WheelsV2), ws, + workflows_installer, prompts, - timedelta(seconds=1), PRODUCT_INFO, ) workspace_installation.run() @@ -1201,7 +1191,7 @@ def test_runs_upgrades_on_too_old_version(ws, any_prompt): sql_backend = MockBackend() wheels = create_autospec(WheelsV2) install.run( - verify_timeout=timedelta(seconds=1), + 
verify_timeout=timedelta(seconds=60), sql_backend_factory=lambda _: sql_backend, wheel_builder_factory=lambda: wheels, ) @@ -1226,7 +1216,7 @@ def test_runs_upgrades_on_more_recent_version(ws, any_prompt): wheels = create_autospec(WheelsV2) install.run( - verify_timeout=timedelta(seconds=1), + verify_timeout=timedelta(seconds=10), sql_backend_factory=lambda _: sql_backend, wheel_builder_factory=lambda: wheels, ) @@ -1240,9 +1230,6 @@ def test_fresh_install(ws, mock_installation): r".*PRO or SERVERLESS SQL warehouse.*": "1", r"Choose how to map the workspace groups.*": "2", r"Open config file in.*": "no", - r"Parallelism for migrating.*": "1000", - r"Min workers for auto-scale.*": "2", - r"Max workers for auto-scale.*": "20", r".*": "", } ) @@ -1260,10 +1247,9 @@ def test_fresh_install(ws, mock_installation): 'log_level': 'INFO', 'num_days_submit_runs_history': 30, 'num_threads': 8, - 'spark_conf': {'spark.sql.sources.parallelPartitionDiscovery.parallelism': '1000'}, - 'max_workers': 20, - 'min_workers': 2, 'policy_id': 'foo', + 'min_workers': 1, + 'max_workers': 10, 'renamed_group_prefix': 'db-temp-', 'warehouse_id': 'abc', 'workspace_start_path': '/', @@ -1271,69 +1257,31 @@ def test_fresh_install(ws, mock_installation): ) -def test_install_with_external_hms_conf(ws, mock_installation): - prompts = MockPrompts( - { - r".*PRO or SERVERLESS SQL warehouse.*": "1", - r"Choose how to map the workspace groups.*": "2", - r"Open config file in.*": "no", - r"Parallelism for migrating.*": "1000", - r"Min workers for auto-scale.*": "2", - r"Max workers for auto-scale.*": "20", - r".*We have identified one or more cluster.*": "Yes", - r".*Choose a cluster policy.*": "0", - r".*": "", - } +def test_remove_jobs(ws, caplog, mock_installation_extra_jobs, any_prompt): + sql_backend = MockBackend() + workflows_installation = WorkflowsInstallation( + WorkspaceConfig(inventory_database="...", policy_id='123'), + mock_installation_extra_jobs, + ws, + create_autospec(WheelsV2), + any_prompt, + PRODUCT_INFO, + timedelta(seconds=1), ) - ws.workspace.get_status = not_found - - policy_definition = { - "spark_conf.spark.hadoop.javax.jdo.option.ConnectionURL": {"value": "url"}, - "spark_conf.spark.hadoop.javax.jdo.option.ConnectionUserName": {"value": "user"}, - "spark_conf.spark.hadoop.javax.jdo.option.ConnectionPassword": {"value": "pwd"}, - "spark_conf.spark.hadoop.javax.jdo.option.ConnectionDriverName": {"value": "driver"}, - "spark_conf.spark.sql.hive.metastore.version": {"value": "2.3"}, - "spark_conf.spark.sql.hive.metastore.jars": {"value": "jar"}, - } - ws.cluster_policies.list.return_value = [ - Policy( - policy_id="id1", - name="foo", - definition=json.dumps(policy_definition), - description="Custom cluster policy for Unity Catalog Migration (UCX)", - ) - ] - - install = WorkspaceInstaller(prompts, mock_installation, ws, PRODUCT_INFO) - install.configure() - mock_installation.assert_file_written( - 'config.yml', - { - 'version': 2, - 'default_catalog': 'ucx_default', - 'inventory_database': 'ucx', - 'log_level': 'INFO', - 'num_days_submit_runs_history': 30, - 'num_threads': 8, - 'spark_conf': { - 'spark.hadoop.javax.jdo.option.ConnectionDriverName': 'driver', - 'spark.hadoop.javax.jdo.option.ConnectionPassword': 'pwd', - 'spark.hadoop.javax.jdo.option.ConnectionURL': 'url', - 'spark.hadoop.javax.jdo.option.ConnectionUserName': 'user', - 'spark.sql.hive.metastore.jars': 'jar', - 'spark.sql.hive.metastore.version': '2.3', - 'spark.sql.sources.parallelPartitionDiscovery.parallelism': '1000', - }, - 
'max_workers': 20, - 'min_workers': 2, - 'policy_id': 'foo', - 'renamed_group_prefix': 'db-temp-', - 'warehouse_id': 'abc', - 'workspace_start_path': '/', - }, + workspace_installation = WorkspaceInstallation( + WorkspaceConfig(inventory_database='ucx'), + mock_installation_extra_jobs, + sql_backend, + ws, + workflows_installation, + any_prompt, + PRODUCT_INFO, ) + workspace_installation.run() + ws.jobs.delete.assert_called_with("123") + def test_get_existing_installation_global(ws, mock_installation, mocker): base_prompts = MockPrompts( @@ -1486,3 +1434,66 @@ def test_check_inventory_database_exists(ws, mock_installation): with pytest.raises(AlreadyExists, match="Inventory database 'ucx_exists' already exists in another installation"): install.configure() + + +def test_user_not_admin(ws, mock_installation, any_prompt): + ws.current_user.me = lambda: iam.User(user_name="me@example.com", groups=[iam.ComplexValue(display="group1")]) + wheels = create_autospec(WheelsV2) + workspace_installation = WorkflowsInstallation( + WorkspaceConfig(inventory_database='ucx', policy_id='123'), + mock_installation, + ws, + wheels, + any_prompt, + PRODUCT_INFO, + timedelta(seconds=1), + ) + + with pytest.raises(PermissionError) as failure: + workspace_installation.create_jobs() + assert "Current user is not a workspace admin" in str(failure.value) + + +@pytest.mark.parametrize( + "result_state,expected", + [ + (RunState(result_state=RunResultState.SUCCESS, life_cycle_state=RunLifeCycleState.TERMINATED), True), + (RunState(result_state=RunResultState.FAILED, life_cycle_state=RunLifeCycleState.TERMINATED), False), + ], +) +def test_validate_step(ws, any_prompt, result_state, expected): + installation = create_autospec(Installation) + installation.load.return_value = RawState({'jobs': {'assessment': '123'}}) + workflows_installer = WorkflowsInstallation( + WorkspaceConfig(inventory_database="...", policy_id='123'), + installation, + ws, + create_autospec(WheelsV2), + any_prompt, + PRODUCT_INFO, + timedelta(seconds=1), + ) + ws.jobs.list_runs.return_value = [ + BaseRun( + job_id=123, + run_id=456, + run_name="assessment", + state=RunState(result_state=None, life_cycle_state=RunLifeCycleState.RUNNING), + ) + ] + + ws.jobs.wait_get_run_job_terminated_or_skipped.return_value = BaseRun( + job_id=123, + run_id=456, + run_name="assessment", + state=RunState(result_state=RunResultState.SUCCESS, life_cycle_state=RunLifeCycleState.TERMINATED), + ) + + ws.jobs.get_run.return_value = BaseRun( + job_id=123, + run_id=456, + run_name="assessment", + state=result_state, + ) + + assert workflows_installer.validate_step("assessment") == expected
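
Reviewer note: a minimal, illustrative sketch of the wiring this change introduces, mirroring the constructor calls exercised in the tests above rather than adding any new API. It assumes ws (WorkspaceClient), installation (Installation), sql_backend (SqlBackend) and prompts (Prompts) are already built by the caller, and the two-minute verify timeout is only an example value.

from datetime import timedelta

from databricks.labs.blueprint.wheels import ProductInfo

from databricks.labs.ucx.config import WorkspaceConfig
from databricks.labs.ucx.install import WorkspaceInstallation
from databricks.labs.ucx.installer.workflows import WorkflowsInstallation

# Assumed to exist already: ws, installation, sql_backend, prompts (see note above).
product_info = ProductInfo.from_class(WorkspaceConfig)
config = WorkspaceConfig(inventory_database="ucx", policy_id="123")

# Job deployment, run_workflow, repair_run, latest_job_status and validate_step
# now live here, together with the wheel builder and the verification timeout.
workflows_installation = WorkflowsInstallation(
    config,
    installation,
    ws,
    product_info.wheels(ws),
    prompts,
    product_info,
    timedelta(minutes=2),  # verify timeout, moved out of WorkspaceInstallation
)

# WorkspaceInstallation keeps the SQL backend and delegates workflow work to the
# WorkflowsInstallation passed in place of the old wheels/timeout arguments.
workspace_installation = WorkspaceInstallation(
    config,
    installation,
    sql_backend,
    ws,
    workflows_installation,
    prompts,
    product_info,
)
workspace_installation.run()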