Skip to content

Commit

Permalink
Make Databricks operators' json parameter compatible with XComs, Jinj…
Browse files Browse the repository at this point in the history
…a expression values (apache#40471)
  • Loading branch information
boraberke committed Jul 2, 2024
1 parent db16eeb commit 4fb2140
Show file tree
Hide file tree
Showing 4 changed files with 571 additions and 221 deletions.
184 changes: 93 additions & 91 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
WorkflowRunMetadata,
)
from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event

if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
Expand Down Expand Up @@ -182,6 +182,17 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
raise AirflowException(error_message)


def _handle_overridden_json_params(operator):
for key, value in operator.overridden_json_params.items():
if value is not None:
operator.json[key] = value


def normalise_json_content(operator):
if operator.json:
operator.json = _normalise_json_content(operator.json)


class DatabricksJobRunLink(BaseOperatorLink):
"""Constructs a link to monitor a Databricks Job Run."""

Expand Down Expand Up @@ -285,34 +296,21 @@ def __init__(
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
if name is not None:
self.json["name"] = name
if description is not None:
self.json["description"] = description
if tags is not None:
self.json["tags"] = tags
if tasks is not None:
self.json["tasks"] = tasks
if job_clusters is not None:
self.json["job_clusters"] = job_clusters
if email_notifications is not None:
self.json["email_notifications"] = email_notifications
if webhook_notifications is not None:
self.json["webhook_notifications"] = webhook_notifications
if notification_settings is not None:
self.json["notification_settings"] = notification_settings
if timeout_seconds is not None:
self.json["timeout_seconds"] = timeout_seconds
if schedule is not None:
self.json["schedule"] = schedule
if max_concurrent_runs is not None:
self.json["max_concurrent_runs"] = max_concurrent_runs
if git_source is not None:
self.json["git_source"] = git_source
if access_control_list is not None:
self.json["access_control_list"] = access_control_list
if self.json:
self.json = normalise_json_content(self.json)
self.overridden_json_params = {
"name": name,
"description": description,
"tags": tags,
"tasks": tasks,
"job_clusters": job_clusters,
"email_notifications": email_notifications,
"webhook_notifications": webhook_notifications,
"notification_settings": notification_settings,
"timeout_seconds": timeout_seconds,
"schedule": schedule,
"max_concurrent_runs": max_concurrent_runs,
"git_source": git_source,
"access_control_list": access_control_list,
}

@cached_property
def _hook(self):
Expand All @@ -324,16 +322,24 @@ def _hook(self):
caller="DatabricksCreateJobsOperator",
)

def execute(self, context: Context) -> int:
def _setup_and_validate_json(self):
_handle_overridden_json_params(self)

if "name" not in self.json:
raise AirflowException("Missing required parameter: name")

normalise_json_content(self)

def execute(self, context: Context) -> int:
self._setup_and_validate_json()

job_id = self._hook.find_job_id_by_name(self.json["name"])
if job_id is None:
return self._hook.create_job(self.json)
self._hook.reset_job(str(job_id), self.json)
if (access_control_list := self.json.get("access_control_list")) is not None:
acl_json = {"access_control_list": access_control_list}
self._hook.update_job_permission(job_id, normalise_json_content(acl_json))
self._hook.update_job_permission(job_id, _normalise_json_content(acl_json))

return job_id

Expand Down Expand Up @@ -506,43 +512,23 @@ def __init__(
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
if tasks is not None:
self.json["tasks"] = tasks
if spark_jar_task is not None:
self.json["spark_jar_task"] = spark_jar_task
if notebook_task is not None:
self.json["notebook_task"] = notebook_task
if spark_python_task is not None:
self.json["spark_python_task"] = spark_python_task
if spark_submit_task is not None:
self.json["spark_submit_task"] = spark_submit_task
if pipeline_task is not None:
self.json["pipeline_task"] = pipeline_task
if dbt_task is not None:
self.json["dbt_task"] = dbt_task
if new_cluster is not None:
self.json["new_cluster"] = new_cluster
if existing_cluster_id is not None:
self.json["existing_cluster_id"] = existing_cluster_id
if libraries is not None:
self.json["libraries"] = libraries
if run_name is not None:
self.json["run_name"] = run_name
if timeout_seconds is not None:
self.json["timeout_seconds"] = timeout_seconds
if "run_name" not in self.json:
self.json["run_name"] = run_name or kwargs["task_id"]
if idempotency_token is not None:
self.json["idempotency_token"] = idempotency_token
if access_control_list is not None:
self.json["access_control_list"] = access_control_list
if git_source is not None:
self.json["git_source"] = git_source

if "dbt_task" in self.json and "git_source" not in self.json:
raise AirflowException("git_source is required for dbt_task")
if pipeline_task is not None and "pipeline_id" in pipeline_task and "pipeline_name" in pipeline_task:
raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")
self.overridden_json_params = {
"tasks": tasks,
"spark_jar_task": spark_jar_task,
"notebook_task": notebook_task,
"spark_python_task": spark_python_task,
"spark_submit_task": spark_submit_task,
"pipeline_task": pipeline_task,
"dbt_task": dbt_task,
"new_cluster": new_cluster,
"existing_cluster_id": existing_cluster_id,
"libraries": libraries,
"run_name": run_name,
"timeout_seconds": timeout_seconds,
"idempotency_token": idempotency_token,
"access_control_list": access_control_list,
"git_source": git_source,
}

# This variable will be used in case our task gets killed.
self.run_id: int | None = None
Expand All @@ -561,7 +547,25 @@ def _get_hook(self, caller: str) -> DatabricksHook:
caller=caller,
)

def _setup_and_validate_json(self):
_handle_overridden_json_params(self)

if "run_name" not in self.json or self.json["run_name"] is None:
self.json["run_name"] = self.task_id

if "dbt_task" in self.json and "git_source" not in self.json:
raise AirflowException("git_source is required for dbt_task")
if (
"pipeline_task" in self.json
and "pipeline_id" in self.json["pipeline_task"]
and "pipeline_name" in self.json["pipeline_task"]
):
raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")

normalise_json_content(self)

def execute(self, context: Context):
self._setup_and_validate_json()
if (
"pipeline_task" in self.json
and self.json["pipeline_task"].get("pipeline_id") is None
Expand All @@ -571,7 +575,7 @@ def execute(self, context: Context):
pipeline_name = self.json["pipeline_task"]["pipeline_name"]
self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name)
del self.json["pipeline_task"]["pipeline_name"]
json_normalised = normalise_json_content(self.json)
json_normalised = _normalise_json_content(self.json)
self.run_id = self._hook.submit_run(json_normalised)
if self.deferrable:
_handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context)
Expand Down Expand Up @@ -607,7 +611,7 @@ def __init__(self, *args, **kwargs):

def execute(self, context):
hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
json_normalised = normalise_json_content(self.json)
json_normalised = _normalise_json_content(self.json)
self.run_id = hook.submit_run(json_normalised)
_handle_deferrable_databricks_operator_execution(self, hook, self.log, context)

Expand Down Expand Up @@ -807,27 +811,16 @@ def __init__(
self.deferrable = deferrable
self.repair_run = repair_run
self.cancel_previous_runs = cancel_previous_runs

if job_id is not None:
self.json["job_id"] = job_id
if job_name is not None:
self.json["job_name"] = job_name
if "job_id" in self.json and "job_name" in self.json:
raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
if notebook_params is not None:
self.json["notebook_params"] = notebook_params
if python_params is not None:
self.json["python_params"] = python_params
if python_named_params is not None:
self.json["python_named_params"] = python_named_params
if jar_params is not None:
self.json["jar_params"] = jar_params
if spark_submit_params is not None:
self.json["spark_submit_params"] = spark_submit_params
if idempotency_token is not None:
self.json["idempotency_token"] = idempotency_token
if self.json:
self.json = normalise_json_content(self.json)
self.overridden_json_params = {
"job_id": job_id,
"job_name": job_name,
"notebook_params": notebook_params,
"python_params": python_params,
"python_named_params": python_named_params,
"jar_params": jar_params,
"spark_submit_params": spark_submit_params,
"idempotency_token": idempotency_token,
}
# This variable will be used in case our task gets killed.
self.run_id: int | None = None
self.do_xcom_push = do_xcom_push
Expand All @@ -845,7 +838,16 @@ def _get_hook(self, caller: str) -> DatabricksHook:
caller=caller,
)

def _setup_and_validate_json(self):
_handle_overridden_json_params(self)

if "job_id" in self.json and "job_name" in self.json:
raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")

normalise_json_content(self)

def execute(self, context: Context):
self._setup_and_validate_json()
hook = self._hook
if "job_name" in self.json:
job_id = hook.find_job_id_by_name(self.json["job_name"])
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/databricks/utils/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from airflow.providers.databricks.hooks.databricks import RunState


def normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
def _normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
"""
Normalize content or all values of content if it is a dict to a string.
Expand All @@ -33,7 +33,7 @@ def normalise_json_content(content, json_path: str = "json") -> str | bool | lis
The only one exception is when we have boolean values, they can not be converted
to string type because databricks does not understand 'True' or 'False' values.
"""
normalise = normalise_json_content
normalise = _normalise_json_content
if isinstance(content, (str, bool)):
return content
elif isinstance(content, (int, float)):
Expand Down
Loading

0 comments on commit 4fb2140

Please sign in to comment.