Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,16 @@ def reset_job(self, job_id: str, json: dict) -> None:

:param json: The data used in the new_settings of the request to the ``reset`` endpoint.
"""
access_control_list = json.get("access_control_list", None)
if access_control_list:
self.log.info(
"Updating job permission for Databricks workflow job id %s with access_control_list %s",
job_id,
access_control_list,
)
acl_json = {"access_control_list": access_control_list}
self.update_job_permission(job_id=int(job_id), json=acl_json)

self._do_api_call(RESET_ENDPOINT, {"job_id": job_id, "new_settings": json})

def update_job(self, job_id: str, json: dict) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,6 @@ def execute(self, context: Context) -> int:
if job_id is None:
return self._hook.create_job(self.json)
self._hook.reset_job(str(job_id), self.json)
if (access_control_list := self.json.get("access_control_list")) is not None:
acl_json = {"access_control_list": access_control_list}
self._hook.update_job_permission(job_id, normalise_json_content(acl_json))

return job_id


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,13 @@ def list_spark_versions_endpoint(host):
return f"https://{host}/api/2.0/clusters/spark-versions"


def permissions_endpoint(host, job_id):
"""
Utility function to generate the permissions endpoint given the host
"""
return f"https://{host}/api/2.0/permissions/jobs/{job_id}"


def create_valid_response_mock(content):
response = mock.MagicMock()
response.json.return_value = content
Expand Down Expand Up @@ -474,7 +481,7 @@ def test_create(self, mock_requests):
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_reset(self, mock_requests):
def test_reset_with_no_acl(self, mock_requests):
mock_requests.codes.ok = 200
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock
Expand All @@ -490,6 +497,40 @@ def test_reset(self, mock_requests):
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_reset_with_acl(self, mock_requests):
mock_requests.codes.ok = 200
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock
ACCESS_CONTROL_LIST = [{"permission_level": "CAN_MANAGE", "user_name": "test_user"}]
json = {
"access_control_list": ACCESS_CONTROL_LIST,
"name": "test",
}

self.hook.reset_job(JOB_ID, json)

mock_requests.post.assert_called_once_with(
reset_endpoint(HOST),
json={
"job_id": JOB_ID,
"new_settings": json,
},
params=None,
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
)

mock_requests.patch.assert_called_once_with(
permissions_endpoint(HOST, JOB_ID),
json={"access_control_list": ACCESS_CONTROL_LIST},
params=None,
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_update(self, mock_requests):
mock_requests.codes.ok = 200
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,42 +569,6 @@ def test_exec_reset(self, db_mock_class):
db_mock.reset_job.assert_called_once_with(JOB_ID, expected)
assert return_result == JOB_ID

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_exec_update_job_permission(self, db_mock_class):
"""
Test job permission update.
"""
json = {
"name": JOB_NAME,
"tags": TAGS,
"tasks": TASKS,
"job_clusters": JOB_CLUSTERS,
"email_notifications": EMAIL_NOTIFICATIONS,
"webhook_notifications": WEBHOOK_NOTIFICATIONS,
"timeout_seconds": TIMEOUT_SECONDS,
"schedule": SCHEDULE,
"max_concurrent_runs": MAX_CONCURRENT_RUNS,
"git_source": GIT_SOURCE,
"access_control_list": ACCESS_CONTROL_LIST,
}
op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
db_mock = db_mock_class.return_value
db_mock.find_job_id_by_name.return_value = JOB_ID

op.execute({})

expected = utils.normalise_json_content({"access_control_list": ACCESS_CONTROL_LIST})

db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
retry_delay=op.databricks_retry_delay,
retry_args=None,
caller="DatabricksCreateJobsOperator",
)

db_mock.update_job_permission.assert_called_once_with(JOB_ID, expected)

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_exec_update_job_permission_with_empty_acl(self, db_mock_class):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def test_create_or_reset_job_existing(mock_databricks_hook, context, mock_task_g
operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default")
operator.task_group = mock_task_group
operator._hook.list_jobs.return_value = [{"job_id": 123}]
operator._hook.create_job.return_value = 123

job_id = operator._create_or_reset_job(context)
assert job_id == 123
Expand Down
Loading