Fix named parameters templating in Databricks operators (apache#40864)
This PR fixes templating of the named parameters that broke with apache#40471: those parameters were merged into overridden_json_params inside __init__, before Jinja rendering runs, so templates passed through named arguments were never rendered. Each parameter is now stored as an instance attribute, listed in template_fields, and only merged into overridden_json_params in _setup_and_validate_json at execution time (see the usage sketch below).

The following operators are affected:

DatabricksCreateJobsOperator
DatabricksSubmitRunOperator
DatabricksRunNowOperator

closes: apache#40788
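
For illustration only (not part of this commit), a task like the following depends on the fixed behavior; the connection id, job id, and run_date key are hypothetical, while notebook_params and the {{ ds }} macro are a real operator argument and a built-in Airflow template variable:

    from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

    run_now = DatabricksRunNowOperator(
        task_id="run_now",
        databricks_conn_id="databricks_default",  # hypothetical connection id
        job_id=42,  # hypothetical job id
        # Templated named parameter: before this fix the value was left
        # unrendered and reached Databricks as the literal "{{ ds }}";
        # with the fix it is rendered at runtime again.
        notebook_params={"run_date": "{{ ds }}"},
    )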
boraberke authored Jul 18, 2024
1 parent a4e3fbe commit cfe1d53
Showing 2 changed files with 284 additions and 45 deletions.
173 changes: 128 additions & 45 deletions airflow/providers/databricks/operators/databricks.py
@@ -263,7 +263,23 @@ class DatabricksCreateJobsOperator(BaseOperator):
     """
 
     # Used in airflow.models.BaseOperator
-    template_fields: Sequence[str] = ("json", "databricks_conn_id")
+    template_fields: Sequence[str] = (
+        "json",
+        "databricks_conn_id",
+        "name",
+        "description",
+        "tags",
+        "tasks",
+        "job_clusters",
+        "email_notifications",
+        "webhook_notifications",
+        "notification_settings",
+        "timeout_seconds",
+        "schedule",
+        "max_concurrent_runs",
+        "git_source",
+        "access_control_list",
+    )
     # Databricks brand color (blue) under white text
     ui_color = "#1CB1C2"
     ui_fgcolor = "#fff"
@@ -300,21 +316,19 @@ def __init__(
         self.databricks_retry_limit = databricks_retry_limit
         self.databricks_retry_delay = databricks_retry_delay
         self.databricks_retry_args = databricks_retry_args
-        self.overridden_json_params = {
-            "name": name,
-            "description": description,
-            "tags": tags,
-            "tasks": tasks,
-            "job_clusters": job_clusters,
-            "email_notifications": email_notifications,
-            "webhook_notifications": webhook_notifications,
-            "notification_settings": notification_settings,
-            "timeout_seconds": timeout_seconds,
-            "schedule": schedule,
-            "max_concurrent_runs": max_concurrent_runs,
-            "git_source": git_source,
-            "access_control_list": access_control_list,
-        }
+        self.name = name
+        self.description = description
+        self.tags = tags
+        self.tasks = tasks
+        self.job_clusters = job_clusters
+        self.email_notifications = email_notifications
+        self.webhook_notifications = webhook_notifications
+        self.notification_settings = notification_settings
+        self.timeout_seconds = timeout_seconds
+        self.schedule = schedule
+        self.max_concurrent_runs = max_concurrent_runs
+        self.git_source = git_source
+        self.access_control_list = access_control_list
 
     @cached_property
     def _hook(self):
@@ -327,6 +341,22 @@ def _hook(self):
         )
 
     def _setup_and_validate_json(self):
+        self.overridden_json_params = {
+            "name": self.name,
+            "description": self.description,
+            "tags": self.tags,
+            "tasks": self.tasks,
+            "job_clusters": self.job_clusters,
+            "email_notifications": self.email_notifications,
+            "webhook_notifications": self.webhook_notifications,
+            "notification_settings": self.notification_settings,
+            "timeout_seconds": self.timeout_seconds,
+            "schedule": self.schedule,
+            "max_concurrent_runs": self.max_concurrent_runs,
+            "git_source": self.git_source,
+            "access_control_list": self.access_control_list,
+        }
+
         _handle_overridden_json_params(self)
 
         if "name" not in self.json:
@@ -470,7 +500,25 @@ class DatabricksSubmitRunOperator(BaseOperator):
     """
 
     # Used in airflow.models.BaseOperator
-    template_fields: Sequence[str] = ("json", "databricks_conn_id")
+    template_fields: Sequence[str] = (
+        "json",
+        "databricks_conn_id",
+        "tasks",
+        "spark_jar_task",
+        "notebook_task",
+        "spark_python_task",
+        "spark_submit_task",
+        "pipeline_task",
+        "dbt_task",
+        "new_cluster",
+        "existing_cluster_id",
+        "libraries",
+        "run_name",
+        "timeout_seconds",
+        "idempotency_token",
+        "access_control_list",
+        "git_source",
+    )
     template_ext: Sequence[str] = (".json-tpl",)
     # Databricks brand color (blue) under white text
     ui_color = "#1CB1C2"
@@ -516,23 +564,21 @@ def __init__(
         self.databricks_retry_args = databricks_retry_args
         self.wait_for_termination = wait_for_termination
         self.deferrable = deferrable
-        self.overridden_json_params = {
-            "tasks": tasks,
-            "spark_jar_task": spark_jar_task,
-            "notebook_task": notebook_task,
-            "spark_python_task": spark_python_task,
-            "spark_submit_task": spark_submit_task,
-            "pipeline_task": pipeline_task,
-            "dbt_task": dbt_task,
-            "new_cluster": new_cluster,
-            "existing_cluster_id": existing_cluster_id,
-            "libraries": libraries,
-            "run_name": run_name,
-            "timeout_seconds": timeout_seconds,
-            "idempotency_token": idempotency_token,
-            "access_control_list": access_control_list,
-            "git_source": git_source,
-        }
+        self.tasks = tasks
+        self.spark_jar_task = spark_jar_task
+        self.notebook_task = notebook_task
+        self.spark_python_task = spark_python_task
+        self.spark_submit_task = spark_submit_task
+        self.pipeline_task = pipeline_task
+        self.dbt_task = dbt_task
+        self.new_cluster = new_cluster
+        self.existing_cluster_id = existing_cluster_id
+        self.libraries = libraries
+        self.run_name = run_name
+        self.timeout_seconds = timeout_seconds
+        self.idempotency_token = idempotency_token
+        self.access_control_list = access_control_list
+        self.git_source = git_source
 
         # This variable will be used in case our task gets killed.
         self.run_id: int | None = None
@@ -552,6 +598,24 @@ def _get_hook(self, caller: str) -> DatabricksHook:
         )
 
     def _setup_and_validate_json(self):
+        self.overridden_json_params = {
+            "tasks": self.tasks,
+            "spark_jar_task": self.spark_jar_task,
+            "notebook_task": self.notebook_task,
+            "spark_python_task": self.spark_python_task,
+            "spark_submit_task": self.spark_submit_task,
+            "pipeline_task": self.pipeline_task,
+            "dbt_task": self.dbt_task,
+            "new_cluster": self.new_cluster,
+            "existing_cluster_id": self.existing_cluster_id,
+            "libraries": self.libraries,
+            "run_name": self.run_name,
+            "timeout_seconds": self.timeout_seconds,
+            "idempotency_token": self.idempotency_token,
+            "access_control_list": self.access_control_list,
+            "git_source": self.git_source,
+        }
+
         _handle_overridden_json_params(self)
 
         if "run_name" not in self.json or self.json["run_name"] is None:
@@ -772,7 +836,18 @@ class DatabricksRunNowOperator(BaseOperator):
     """
 
    # Used in airflow.models.BaseOperator
-    template_fields: Sequence[str] = ("json", "databricks_conn_id")
+    template_fields: Sequence[str] = (
+        "json",
+        "databricks_conn_id",
+        "job_id",
+        "job_name",
+        "notebook_params",
+        "python_params",
+        "python_named_params",
+        "jar_params",
+        "spark_submit_params",
+        "idempotency_token",
+    )
     template_ext: Sequence[str] = (".json-tpl",)
     # Databricks brand color (blue) under white text
     ui_color = "#1CB1C2"
@@ -815,16 +890,14 @@ def __init__(
         self.deferrable = deferrable
         self.repair_run = repair_run
         self.cancel_previous_runs = cancel_previous_runs
-        self.overridden_json_params = {
-            "job_id": job_id,
-            "job_name": job_name,
-            "notebook_params": notebook_params,
-            "python_params": python_params,
-            "python_named_params": python_named_params,
-            "jar_params": jar_params,
-            "spark_submit_params": spark_submit_params,
-            "idempotency_token": idempotency_token,
-        }
+        self.job_id = job_id
+        self.job_name = job_name
+        self.notebook_params = notebook_params
+        self.python_params = python_params
+        self.python_named_params = python_named_params
+        self.jar_params = jar_params
+        self.spark_submit_params = spark_submit_params
+        self.idempotency_token = idempotency_token
         # This variable will be used in case our task gets killed.
         self.run_id: int | None = None
         self.do_xcom_push = do_xcom_push
@@ -843,6 +916,16 @@ def _get_hook(self, caller: str) -> DatabricksHook:
         )
 
     def _setup_and_validate_json(self):
+        self.overridden_json_params = {
+            "job_id": self.job_id,
+            "job_name": self.job_name,
+            "notebook_params": self.notebook_params,
+            "python_params": self.python_params,
+            "python_named_params": self.python_named_params,
+            "jar_params": self.jar_params,
+            "spark_submit_params": self.spark_submit_params,
+            "idempotency_token": self.idempotency_token,
+        }
         _handle_overridden_json_params(self)
 
         if "job_id" in self.json and "job_name" in self.json: