Make Databricks operators' json parameter compatible with XComs, Jinja expression values #40471

Merged: 9 commits, Jul 2, 2024
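The gist of the change: the convenience kwargs are no longer merged into `json` and normalised in `__init__`; that now happens in `execute()`, after Jinja templating has rendered the `json` field. An illustrative snippet of what this enables (the values are hypothetical, not taken from the PR):

```python
# Hypothetical usage: because normalisation is deferred to execute(),
# the templated `json` field can carry Jinja expressions and XCom pulls,
# which are rendered before the payload is validated and submitted.
from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

run_job = DatabricksRunNowOperator(
    task_id="run_job",
    json={
        "job_id": "{{ ti.xcom_pull(task_ids='create_job') }}",  # resolved at run time
        "notebook_params": {"run_date": "{{ ds }}"},
    },
)
```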
184 changes: 93 additions & 91 deletions airflow/providers/databricks/operators/databricks.py
@@ -36,7 +36,7 @@
WorkflowRunMetadata,
)
from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event

if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -182,6 +182,17 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
raise AirflowException(error_message)


def _handle_overridden_json_params(operator):
for key, value in operator.overridden_json_params.items():
if value is not None:
operator.json[key] = value


def normalise_json_content(operator):
if operator.json:
operator.json = _normalise_json_content(operator.json)
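A side note on the two new module-level helpers: explicit constructor kwargs override keys already present in `json`, while kwargs left as `None` leave `json` untouched. A standalone sketch of that merge rule (a stub class, not the real operator):

```python
# Standalone illustration of _handle_overridden_json_params' merge rule,
# using a stub in place of a real operator.
class _StubOperator:
    def __init__(self, json=None, name=None, tags=None):
        self.json = json or {}
        self.overridden_json_params = {"name": name, "tags": tags}

op = _StubOperator(json={"name": "from-json"}, name="explicit", tags=None)
for key, value in op.overridden_json_params.items():
    if value is not None:
        op.json[key] = value           # explicit kwarg overrides the json key

assert op.json == {"name": "explicit"}  # tags=None left json untouched
```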


class DatabricksJobRunLink(BaseOperatorLink):
"""Constructs a link to monitor a Databricks Job Run."""

@@ -284,34 +295,21 @@ def __init__(
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
if name is not None:
self.json["name"] = name
if description is not None:
self.json["description"] = description
if tags is not None:
self.json["tags"] = tags
if tasks is not None:
self.json["tasks"] = tasks
if job_clusters is not None:
self.json["job_clusters"] = job_clusters
if email_notifications is not None:
self.json["email_notifications"] = email_notifications
if webhook_notifications is not None:
self.json["webhook_notifications"] = webhook_notifications
if notification_settings is not None:
self.json["notification_settings"] = notification_settings
if timeout_seconds is not None:
self.json["timeout_seconds"] = timeout_seconds
if schedule is not None:
self.json["schedule"] = schedule
if max_concurrent_runs is not None:
self.json["max_concurrent_runs"] = max_concurrent_runs
if git_source is not None:
self.json["git_source"] = git_source
if access_control_list is not None:
self.json["access_control_list"] = access_control_list
if self.json:
self.json = normalise_json_content(self.json)
self.overridden_json_params = {
"name": name,
"description": description,
"tags": tags,
"tasks": tasks,
"job_clusters": job_clusters,
"email_notifications": email_notifications,
"webhook_notifications": webhook_notifications,
"notification_settings": notification_settings,
"timeout_seconds": timeout_seconds,
"schedule": schedule,
"max_concurrent_runs": max_concurrent_runs,
"git_source": git_source,
"access_control_list": access_control_list,
}

@cached_property
def _hook(self):
@@ -323,16 +321,24 @@ def _hook(self):
caller="DatabricksCreateJobsOperator",
)

def execute(self, context: Context) -> int:
def _setup_and_validate_json(self):
_handle_overridden_json_params(self)

if "name" not in self.json:
raise AirflowException("Missing required parameter: name")

normalise_json_content(self)

def execute(self, context: Context) -> int:
self._setup_and_validate_json()

job_id = self._hook.find_job_id_by_name(self.json["name"])
if job_id is None:
return self._hook.create_job(self.json)
self._hook.reset_job(str(job_id), self.json)
if (access_control_list := self.json.get("access_control_list")) is not None:
acl_json = {"access_control_list": access_control_list}
self._hook.update_job_permission(job_id, normalise_json_content(acl_json))
self._hook.update_job_permission(job_id, _normalise_json_content(acl_json))

return job_id
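For readers skimming the hunk above, the execute path is effectively an upsert. A condensed sketch with a stubbed hook (the stub is hypothetical; only the branch logic mirrors the operator):

```python
# Condensed restatement of the execute() flow above, with a stubbed hook
# in place of the real DatabricksHook, to make the create-vs-reset branch explicit.
class StubHook:
    def __init__(self, existing_jobs):
        self.existing_jobs = existing_jobs      # name -> job_id
    def find_job_id_by_name(self, name):
        return self.existing_jobs.get(name)
    def create_job(self, json):
        return 42                               # pretend-new job id
    def reset_job(self, job_id, json):
        pass                                    # overwrite job settings in place
    def update_job_permission(self, job_id, acl_json):
        pass

def create_or_reset(hook, json):
    job_id = hook.find_job_id_by_name(json["name"])
    if job_id is None:
        return hook.create_job(json)            # first run: create the job
    hook.reset_job(str(job_id), json)           # later runs: reset it in place
    if (acl := json.get("access_control_list")) is not None:
        hook.update_job_permission(job_id, {"access_control_list": acl})
    return job_id

assert create_or_reset(StubHook({}), {"name": "etl"}) == 42
assert create_or_reset(StubHook({"etl": 7}), {"name": "etl"}) == 7
```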

@@ -505,43 +511,23 @@ def __init__(
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
if tasks is not None:
self.json["tasks"] = tasks
if spark_jar_task is not None:
self.json["spark_jar_task"] = spark_jar_task
if notebook_task is not None:
self.json["notebook_task"] = notebook_task
if spark_python_task is not None:
self.json["spark_python_task"] = spark_python_task
if spark_submit_task is not None:
self.json["spark_submit_task"] = spark_submit_task
if pipeline_task is not None:
self.json["pipeline_task"] = pipeline_task
if dbt_task is not None:
self.json["dbt_task"] = dbt_task
if new_cluster is not None:
self.json["new_cluster"] = new_cluster
if existing_cluster_id is not None:
self.json["existing_cluster_id"] = existing_cluster_id
if libraries is not None:
self.json["libraries"] = libraries
if run_name is not None:
self.json["run_name"] = run_name
if timeout_seconds is not None:
self.json["timeout_seconds"] = timeout_seconds
if "run_name" not in self.json:
self.json["run_name"] = run_name or kwargs["task_id"]
if idempotency_token is not None:
self.json["idempotency_token"] = idempotency_token
if access_control_list is not None:
self.json["access_control_list"] = access_control_list
if git_source is not None:
self.json["git_source"] = git_source

if "dbt_task" in self.json and "git_source" not in self.json:
raise AirflowException("git_source is required for dbt_task")
if pipeline_task is not None and "pipeline_id" in pipeline_task and "pipeline_name" in pipeline_task:
raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")
self.overridden_json_params = {
"tasks": tasks,
"spark_jar_task": spark_jar_task,
"notebook_task": notebook_task,
"spark_python_task": spark_python_task,
"spark_submit_task": spark_submit_task,
"pipeline_task": pipeline_task,
"dbt_task": dbt_task,
"new_cluster": new_cluster,
"existing_cluster_id": existing_cluster_id,
"libraries": libraries,
"run_name": run_name,
"timeout_seconds": timeout_seconds,
"idempotency_token": idempotency_token,
"access_control_list": access_control_list,
"git_source": git_source,
}

# This variable will be used in case our task gets killed.
self.run_id: int | None = None
@@ -560,7 +546,25 @@ def _get_hook(self, caller: str) -> DatabricksHook:
caller=caller,
)

def _setup_and_validate_json(self):
_handle_overridden_json_params(self)

if "run_name" not in self.json or self.json["run_name"] is None:
self.json["run_name"] = self.task_id

if "dbt_task" in self.json and "git_source" not in self.json:
raise AirflowException("git_source is required for dbt_task")
if (
"pipeline_task" in self.json
and "pipeline_id" in self.json["pipeline_task"]
and "pipeline_name" in self.json["pipeline_task"]
):
raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")

normalise_json_content(self)
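These checks now run at execute time rather than at DAG parse time. A standalone restatement of the rules (ValueError swapped in for AirflowException so the snippet runs without Airflow installed):

```python
# Standalone restatement of the checks in _setup_and_validate_json above.
def validate_submit_run_json(json: dict, task_id: str) -> dict:
    if json.get("run_name") is None:
        json["run_name"] = task_id               # default run_name to the task id
    if "dbt_task" in json and "git_source" not in json:
        raise ValueError("git_source is required for dbt_task")
    pipeline_task = json.get("pipeline_task") or {}
    if "pipeline_id" in pipeline_task and "pipeline_name" in pipeline_task:
        raise ValueError("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")
    return json

assert validate_submit_run_json({}, "my_task")["run_name"] == "my_task"
```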

def execute(self, context: Context):
self._setup_and_validate_json()
if (
"pipeline_task" in self.json
and self.json["pipeline_task"].get("pipeline_id") is None
@@ -570,7 +574,7 @@ def execute(self, context: Context):
pipeline_name = self.json["pipeline_task"]["pipeline_name"]
self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name)
del self.json["pipeline_task"]["pipeline_name"]
json_normalised = normalise_json_content(self.json)
json_normalised = _normalise_json_content(self.json)
self.run_id = self._hook.submit_run(json_normalised)
if self.deferrable:
_handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context)
@@ -606,7 +610,7 @@ def __init__(self, *args, **kwargs):

def execute(self, context):
hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
json_normalised = normalise_json_content(self.json)
json_normalised = _normalise_json_content(self.json)
self.run_id = hook.submit_run(json_normalised)
_handle_deferrable_databricks_operator_execution(self, hook, self.log, context)

@@ -806,27 +810,16 @@ def __init__(
self.deferrable = deferrable
self.repair_run = repair_run
self.cancel_previous_runs = cancel_previous_runs

if job_id is not None:
self.json["job_id"] = job_id
if job_name is not None:
self.json["job_name"] = job_name
if "job_id" in self.json and "job_name" in self.json:
raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
if notebook_params is not None:
self.json["notebook_params"] = notebook_params
if python_params is not None:
self.json["python_params"] = python_params
if python_named_params is not None:
self.json["python_named_params"] = python_named_params
if jar_params is not None:
self.json["jar_params"] = jar_params
if spark_submit_params is not None:
self.json["spark_submit_params"] = spark_submit_params
if idempotency_token is not None:
self.json["idempotency_token"] = idempotency_token
if self.json:
self.json = normalise_json_content(self.json)
self.overridden_json_params = {
"job_id": job_id,
"job_name": job_name,
"notebook_params": notebook_params,
"python_params": python_params,
"python_named_params": python_named_params,
"jar_params": jar_params,
"spark_submit_params": spark_submit_params,
"idempotency_token": idempotency_token,
}
# This variable will be used in case our task gets killed.
self.run_id: int | None = None
self.do_xcom_push = do_xcom_push
@@ -844,7 +837,16 @@ def _get_hook(self, caller: str) -> DatabricksHook:
caller=caller,
)

def _setup_and_validate_json(self):
_handle_overridden_json_params(self)

if "job_id" in self.json and "job_name" in self.json:
raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")

normalise_json_content(self)

def execute(self, context: Context):
self._setup_and_validate_json()
hook = self._hook
if "job_name" in self.json:
job_id = hook.find_job_id_by_name(self.json["job_name"])
4 changes: 2 additions & 2 deletions airflow/providers/databricks/utils/databricks.py
@@ -21,7 +21,7 @@
from airflow.providers.databricks.hooks.databricks import RunState


def normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
def _normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
"""
Normalize content, or all values of content if it is a dict, to a string.

@@ -33,7 +33,7 @@ def normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
The only exception is boolean values: they cannot be converted to string type
because Databricks does not understand 'True' or 'False' values.
"""
normalise = normalise_json_content
normalise = _normalise_json_content
if isinstance(content, (str, bool)):
return content
elif isinstance(content, (int, float)):
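For context, a sketch of the normalisation behaviour the docstring describes. The tail of the real function is collapsed in this diff, so the recursion and the error handling below are assumptions based on the docstring, not the actual implementation:

```python
# Sketch of the normalisation rules: strings and booleans pass through,
# numbers are stringified, lists and dicts are normalised recursively.
def normalise(content):
    if isinstance(content, (str, bool)):
        return content
    if isinstance(content, (int, float)):
        return str(content)
    if isinstance(content, list):
        return [normalise(item) for item in content]
    if isinstance(content, dict):
        return {key: normalise(value) for key, value in content.items()}
    raise TypeError(f"Type {type(content)} is not supported for the json parameter")

assert normalise({"max_retries": 3, "notify": True}) == {"max_retries": "3", "notify": True}
```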