Deprecate databricks async operator (#30761)
pankajastro authored Apr 22, 2023
1 parent 5efcb99 commit 7d02277
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions airflow/providers/databricks/operators/databricks.py
@@ -19,6 +19,7 @@
 from __future__ import annotations
 
 import time
+import warnings
 from logging import Logger
 from typing import TYPE_CHECKING, Any, Sequence
 
@@ -267,6 +268,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
     :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
     :param git_source: Optional specification of a remote git repository from which
         supported task types are retrieved.
+    :param deferrable: Run operator in the deferrable mode.
 
     .. seealso::
         https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit
@@ -306,6 +308,7 @@ def __init__(
         access_control_list: list[dict[str, str]] | None = None,
         wait_for_termination: bool = True,
         git_source: dict[str, str] | None = None,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         """Creates a new ``DatabricksSubmitRunOperator``."""
@@ -317,6 +320,7 @@ def __init__(
         self.databricks_retry_delay = databricks_retry_delay
         self.databricks_retry_args = databricks_retry_args
         self.wait_for_termination = wait_for_termination
+        self.deferrable = deferrable
         if tasks is not None:
             self.json["tasks"] = tasks
         if spark_jar_task is not None:
@@ -373,7 +377,10 @@ def _get_hook(self, caller: str) -> DatabricksHook:
     def execute(self, context: Context):
         json_normalised = normalise_json_content(self.json)
         self.run_id = self._hook.submit_run(json_normalised)
-        _handle_databricks_operator_execution(self, self._hook, self.log, context)
+        if self.deferrable:
+            _handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context)
+        else:
+            _handle_databricks_operator_execution(self, self._hook, self.log, context)
 
     def on_kill(self):
         if self.run_id:
@@ -384,10 +391,23 @@ def on_kill(self):
         else:
             self.log.error("Error: Task: %s with invalid run_id was requested to be cancelled.", self.task_id)
 
+    def execute_complete(self, context: dict | None, event: dict):
+        _handle_deferrable_databricks_operator_completion(event, self.log)
+
 
 class DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator):
     """Deferrable version of ``DatabricksSubmitRunOperator``"""
 
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "`DatabricksSubmitRunDeferrableOperator` has been deprecated. "
+            "Please use `airflow.providers.databricks.operators.DatabricksSubmitRunOperator` with "
+            "`deferrable=True` instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(deferrable=True, *args, **kwargs)
+
     def execute(self, context):
         hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
         json_normalised = normalise_json_content(self.json)
@@ -549,6 +569,7 @@ class DatabricksRunNowOperator(BaseOperator):
     :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
     :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
     :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
+    :param deferrable: Run operator in the deferrable mode.
     """
 
     # Used in airflow.models.BaseOperator
@@ -578,6 +599,7 @@ def __init__(
         databricks_retry_args: dict[Any, Any] | None = None,
         do_xcom_push: bool = True,
         wait_for_termination: bool = True,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         """Creates a new ``DatabricksRunNowOperator``."""
@@ -589,6 +611,7 @@ def __init__(
         self.databricks_retry_delay = databricks_retry_delay
         self.databricks_retry_args = databricks_retry_args
         self.wait_for_termination = wait_for_termination
+        self.deferrable = deferrable
 
         if job_id is not None:
             self.json["job_id"] = job_id
@@ -636,7 +659,10 @@ def execute(self, context: Context):
             self.json["job_id"] = job_id
             del self.json["job_name"]
         self.run_id = hook.run_now(self.json)
-        _handle_databricks_operator_execution(self, hook, self.log, context)
+        if self.deferrable:
+            _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
+        else:
+            _handle_databricks_operator_execution(self, hook, self.log, context)
 
     def on_kill(self):
         if self.run_id:
@@ -651,6 +677,16 @@ def on_kill(self):
 class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
     """Deferrable version of ``DatabricksRunNowOperator``"""
 
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "`DatabricksRunNowDeferrableOperator` has been deprecated. "
+            "Please use `airflow.providers.databricks.operators.DatabricksRunNowOperator` with "
+            "`deferrable=True` instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(deferrable=True, *args, **kwargs)
+
     def execute(self, context):
         hook = self._get_hook(caller="DatabricksRunNowDeferrableOperator")
         self.run_id = hook.run_now(self.json)

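For DAG authors, the practical effect of this commit is that the deprecated DatabricksSubmitRunDeferrableOperator is no longer needed: the regular operator now takes deferrable=True, while the old class keeps working but emits a DeprecationWarning and forces deferrable=True, as the diff shows. Below is a minimal migration sketch, not taken from the repository; the DAG id, connection id, cluster spec, and notebook path are illustrative placeholders.

# Hypothetical migration sketch: switch from DatabricksSubmitRunDeferrableOperator
# to DatabricksSubmitRunOperator(deferrable=True). All values are placeholders.
from __future__ import annotations

import pendulum

from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator

with DAG(
    dag_id="example_databricks_submit_run",  # placeholder DAG id
    start_date=pendulum.datetime(2023, 5, 1, tz="UTC"),
    schedule=None,
    catchup=False,
) as dag:
    # Previously: DatabricksSubmitRunDeferrableOperator(...)
    # Now: the regular operator with deferrable=True, so the task defers while the
    # Databricks run is in progress instead of holding a worker slot.
    submit_run = DatabricksSubmitRunOperator(
        task_id="submit_run",
        databricks_conn_id="databricks_default",
        new_cluster={  # placeholder cluster spec
            "spark_version": "12.2.x-scala2.12",
            "node_type_id": "i3.xlarge",
            "num_workers": 2,
        },
        notebook_task={"notebook_path": "/Shared/example_notebook"},  # placeholder path
        deferrable=True,
    )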

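The same pattern applies to DatabricksRunNowOperator, which this commit favors over the deprecated DatabricksRunNowDeferrableOperator. Again a sketch under assumed values; the DAG id, job_id, and notebook_params are placeholders.

# Hypothetical migration sketch: switch from DatabricksRunNowDeferrableOperator
# to DatabricksRunNowOperator(deferrable=True). All values are placeholders.
from __future__ import annotations

import pendulum

from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

with DAG(
    dag_id="example_databricks_run_now",  # placeholder DAG id
    start_date=pendulum.datetime(2023, 5, 1, tz="UTC"),
    schedule=None,
    catchup=False,
) as dag:
    # Previously: DatabricksRunNowDeferrableOperator(...)
    # Now: the regular operator with deferrable=True.
    run_now = DatabricksRunNowOperator(
        task_id="run_now",
        databricks_conn_id="databricks_default",
        job_id=123,  # placeholder: ID of an existing Databricks job
        notebook_params={"run_date": "{{ ds }}"},
        deferrable=True,
    )

In both sketches the task defers to Airflow's triggerer while the Databricks run is in progress, so a triggerer process must be running for the deferred task to resume.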