diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py index 52b0eaa1768e0..aaa9fece60c30 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py @@ -298,6 +298,16 @@ def reset_job(self, job_id: str, json: dict) -> None: :param json: The data used in the new_settings of the request to the ``reset`` endpoint. """ + access_control_list = json.get("access_control_list", None) + if access_control_list: + self.log.info( + "Updating job permission for Databricks workflow job id %s with access_control_list %s", + job_id, + access_control_list, + ) + acl_json = {"access_control_list": access_control_list} + self.update_job_permission(job_id=int(job_id), json=acl_json) + self._do_api_call(RESET_ENDPOINT, {"job_id": job_id, "new_settings": json}) def update_job(self, job_id: str, json: dict) -> None: diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index e2596b0e16a65..f1c8a411962a9 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -398,10 +398,6 @@ def execute(self, context: Context) -> int: if job_id is None: return self._hook.create_job(self.json) self._hook.reset_job(str(job_id), self.json) - if (access_control_list := self.json.get("access_control_list")) is not None: - acl_json = {"access_control_list": access_control_list} - self._hook.update_job_permission(job_id, normalise_json_content(acl_json)) - return job_id diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py index f7920d16d2d1b..1564c6dbfbc0f 100644 --- a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py @@ -275,6 +275,13 @@ def list_spark_versions_endpoint(host): return f"https://{host}/api/2.0/clusters/spark-versions" +def permissions_endpoint(host, job_id): + """ + Utility function to generate the permissions endpoint given the host + """ + return f"https://{host}/api/2.0/permissions/jobs/{job_id}" + + def create_valid_response_mock(content): response = mock.MagicMock() response.json.return_value = content @@ -474,7 +481,7 @@ def test_create(self, mock_requests): ) @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") - def test_reset(self, mock_requests): + def test_reset_with_no_acl(self, mock_requests): mock_requests.codes.ok = 200 status_code_mock = mock.PropertyMock(return_value=200) type(mock_requests.post.return_value).status_code = status_code_mock @@ -490,6 +497,40 @@ def test_reset(self, mock_requests): timeout=self.hook.timeout_seconds, ) + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") + def test_reset_with_acl(self, mock_requests): + mock_requests.codes.ok = 200 + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + ACCESS_CONTROL_LIST = [{"permission_level": "CAN_MANAGE", "user_name": "test_user"}] + json = { + "access_control_list": ACCESS_CONTROL_LIST, + "name": "test", + } + + self.hook.reset_job(JOB_ID, json) + + mock_requests.post.assert_called_once_with( + reset_endpoint(HOST), + json={ + "job_id": JOB_ID, + "new_settings": json, + }, + params=None, + auth=HTTPBasicAuth(LOGIN, PASSWORD), + headers=self.hook.user_agent_header, + timeout=self.hook.timeout_seconds, + ) + + mock_requests.patch.assert_called_once_with( + permissions_endpoint(HOST, JOB_ID), + json={"access_control_list": ACCESS_CONTROL_LIST}, + params=None, + auth=HTTPBasicAuth(LOGIN, PASSWORD), + headers=self.hook.user_agent_header, + timeout=self.hook.timeout_seconds, + ) + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") def test_update(self, mock_requests): mock_requests.codes.ok = 200 diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks.py b/providers/databricks/tests/unit/databricks/operators/test_databricks.py index 592b999e07960..715a0d35f55e6 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks.py @@ -569,42 +569,6 @@ def test_exec_reset(self, db_mock_class): db_mock.reset_job.assert_called_once_with(JOB_ID, expected) assert return_result == JOB_ID - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_exec_update_job_permission(self, db_mock_class): - """ - Test job permission update. - """ - json = { - "name": JOB_NAME, - "tags": TAGS, - "tasks": TASKS, - "job_clusters": JOB_CLUSTERS, - "email_notifications": EMAIL_NOTIFICATIONS, - "webhook_notifications": WEBHOOK_NOTIFICATIONS, - "timeout_seconds": TIMEOUT_SECONDS, - "schedule": SCHEDULE, - "max_concurrent_runs": MAX_CONCURRENT_RUNS, - "git_source": GIT_SOURCE, - "access_control_list": ACCESS_CONTROL_LIST, - } - op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) - db_mock = db_mock_class.return_value - db_mock.find_job_id_by_name.return_value = JOB_ID - - op.execute({}) - - expected = utils.normalise_json_content({"access_control_list": ACCESS_CONTROL_LIST}) - - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - retry_args=None, - caller="DatabricksCreateJobsOperator", - ) - - db_mock.update_job_permission.assert_called_once_with(JOB_ID, expected) - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_update_job_permission_with_empty_acl(self, db_mock_class): """ diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py index c8b650a69b48c..36105c47c80d7 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py @@ -104,7 +104,6 @@ def test_create_or_reset_job_existing(mock_databricks_hook, context, mock_task_g operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default") operator.task_group = mock_task_group operator._hook.list_jobs.return_value = [{"job_id": 123}] - operator._hook.create_job.return_value = 123 job_id = operator._create_or_reset_job(context) assert job_id == 123