46 changes: 46 additions & 0 deletions providers/databricks/docs/operators/sql_statements.rst
@@ -55,3 +55,49 @@ An example usage of the ``DatabricksSQLStatementsOperator`` is as follows:
:language: python
:start-after: [START howto_operator_sql_statements]
:end-before: [END howto_operator_sql_statements]


.. _howto/sensor:DatabricksSQLStatementsSensor:

DatabricksSQLStatementsSensor
===============================

Use the :class:`~airflow.providers.databricks.sensors.databricks.DatabricksSQLStatementsSensor` to either submit a
Databricks SQL Statement to Databricks using the
`Databricks SQL Statement Execution API <https://docs.databricks.com/api/workspace/statementexecution>`_, or pass
the Statement ID of an already-submitted statement to the Sensor and wait for that query to finish executing.


Using the Sensor
------------------

The ``DatabricksSQLStatementsSensor`` does one of two things. It can submit a SQL statement to Databricks using
the `/api/2.0/sql/statements/ <https://docs.databricks.com/api/workspace/statementexecution/executestatement>`_
endpoint and then monitor its execution. Alternatively, it can take the Statement ID of an already-submitted SQL
statement and monitor that execution until it terminates.
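
In the second mode, the Statement ID of a previously-submitted statement can be handed to the Sensor, for example by
pulling it from XCom after a ``DatabricksSQLStatementsOperator`` run with ``wait_for_termination=False``. The
following is a minimal sketch, assuming the upstream task pushed the Statement ID to XCom under the
``statement_id`` key; the task IDs, connection ID, and warehouse ID are placeholders:

.. code-block:: python

    from airflow.providers.databricks.sensors.databricks import DatabricksSQLStatementsSensor

    wait_for_statement = DatabricksSQLStatementsSensor(
        task_id="wait_for_statement",
        databricks_conn_id="databricks_default",
        # warehouse_id is required even when only monitoring an existing statement
        warehouse_id="my-warehouse-id",
        statement_id="{{ ti.xcom_pull(task_ids='submit_statement', key='statement_id') }}",
        deferrable=True,
    )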

It supports configurable execution parameters such as warehouse selection, catalog, schema, and parameterized
queries. The Sensor can either synchronously poll for query completion or run in deferrable mode for improved
efficiency.

The only required parameters for using the Sensor are:

* One of ``statement`` or ``statement_id`` - either the SQL statement to execute (which can optionally be
  parameterized; see ``parameters`` below) or the ID of a statement that has already been submitted.
* ``warehouse_id`` - the ID of the SQL warehouse on which to execute the statement.

All other parameters are optional and are described in the documentation for ``DatabricksSQLStatementsSensor``,
including but not limited to the following (a minimal, fully-parameterized sketch follows this list):

* ``catalog``
* ``schema``
* ``parameters``
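
The following is a minimal sketch of a fully-parameterized call; the warehouse ID, catalog, schema, and query are
placeholders, and the ``parameters`` entries use the name/value/type format of the Databricks SQL Statement
Execution API:

.. code-block:: python

    from airflow.providers.databricks.sensors.databricks import DatabricksSQLStatementsSensor

    run_orders_query = DatabricksSQLStatementsSensor(
        task_id="run_orders_query",
        databricks_conn_id="databricks_default",
        warehouse_id="my-warehouse-id",
        statement="SELECT * FROM orders WHERE order_date = :order_date",
        catalog="my_catalog",
        schema="my_schema",
        parameters=[{"name": "order_date", "value": "2024-01-01", "type": "DATE"}],
    )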

Examples
--------

An example usage of the ``DatabricksSQLStatementsSensor`` is as follows:

.. exampleinclude:: /../../databricks/tests/system/databricks/example_databricks_sensors.py
:language: python
:start-after: [START howto_sensor_databricks_sql_statement]
:end-before: [END howto_sensor_databricks_sql_statement]
1 change: 1 addition & 0 deletions providers/databricks/provider.yaml
@@ -139,6 +139,7 @@ triggers:
sensors:
- integration-name: Databricks
python-modules:
- airflow.providers.databricks.sensors.databricks
- airflow.providers.databricks.sensors.databricks_sql
- airflow.providers.databricks.sensors.databricks_partition

@@ -107,6 +107,7 @@ def get_provider_info():
{
"integration-name": "Databricks",
"python-modules": [
"airflow.providers.databricks.sensors.databricks",
"airflow.providers.databricks.sensors.databricks_sql",
"airflow.providers.databricks.sensors.databricks_partition",
],
@@ -34,7 +34,6 @@
DatabricksHook,
RunLifeCycleState,
RunState,
SQLStatementState,
)
from airflow.providers.databricks.operators.databricks_workflow import (
DatabricksWorkflowTaskGroup,
@@ -46,9 +45,9 @@
)
from airflow.providers.databricks.triggers.databricks import (
DatabricksExecutionTrigger,
DatabricksSQLStatementExecutionTrigger,
)
from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin
from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS

if TYPE_CHECKING:
@@ -978,7 +977,7 @@ def on_kill(self) -> None:
self.log.error("Error: Task: %s with invalid run_id was requested to be cancelled.", self.task_id)


class DatabricksSQLStatementsOperator(BaseOperator):
class DatabricksSQLStatementsOperator(DatabricksSQLStatementsMixin, BaseOperator):
"""
Submits a Databricks SQL Statement to Databricks using the api/2.0/sql/statements/ API endpoint.

@@ -1073,59 +1072,6 @@ def _get_hook(self, caller: str) -> DatabricksHook:
caller=caller,
)

def _handle_operator_execution(self) -> None:
end_time = time.time() + self.timeout
while end_time > time.time():
statement_state = self._hook.get_sql_statement_state(self.statement_id)
if statement_state.is_terminal:
if statement_state.is_successful:
self.log.info("%s completed successfully.", self.task_id)
return
error_message = (
f"{self.task_id} failed with terminal state: {statement_state.state} "
f"and with the error code {statement_state.error_code} "
f"and error message {statement_state.error_message}"
)
raise AirflowException(error_message)

self.log.info("%s in run state: %s", self.task_id, statement_state.state)
self.log.info("Sleeping for %s seconds.", self.polling_period_seconds)
time.sleep(self.polling_period_seconds)

self._hook.cancel_sql_statement(self.statement_id)
raise AirflowException(
f"{self.task_id} timed out after {self.timeout} seconds with state: {statement_state.state}",
)

def _handle_deferrable_operator_execution(self) -> None:
statement_state = self._hook.get_sql_statement_state(self.statement_id)
end_time = time.time() + self.timeout
if not statement_state.is_terminal:
if not self.statement_id:
raise AirflowException("Failed to retrieve statement_id after submitting SQL statement.")
self.defer(
trigger=DatabricksSQLStatementExecutionTrigger(
statement_id=self.statement_id,
databricks_conn_id=self.databricks_conn_id,
end_time=end_time,
polling_period_seconds=self.polling_period_seconds,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
),
method_name=DEFER_METHOD_NAME,
)
else:
if statement_state.is_successful:
self.log.info("%s completed successfully.", self.task_id)
else:
error_message = (
f"{self.task_id} failed with terminal state: {statement_state.state} "
f"and with the error code {statement_state.error_code} "
f"and error message {statement_state.error_message}"
)
raise AirflowException(error_message)

def execute(self, context: Context):
json = {
"statement": self.statement,
@@ -1146,34 +1092,9 @@ def execute(self, context: Context):
if not self.wait_for_termination:
return
if self.deferrable:
self._handle_deferrable_operator_execution()
else:
self._handle_operator_execution()

def on_kill(self):
if self.statement_id:
self._hook.cancel_sql_statement(self.statement_id)
self.log.info(
"Task: %s with statement ID: %s was requested to be cancelled.",
self.task_id,
self.statement_id,
)
self._handle_deferrable_execution(defer_method_name=DEFER_METHOD_NAME) # type: ignore[misc]
else:
self.log.error(
"Error: Task: %s with invalid statement_id was requested to be cancelled.", self.task_id
)

def execute_complete(self, context: dict | None, event: dict):
statement_state = SQLStatementState.from_json(event["state"])
error = event["error"]
statement_id = event["statement_id"]

if statement_state.is_successful:
self.log.info("SQL Statement with ID %s completed successfully.", statement_id)
return

error_message = f"SQL Statement execution failed with terminal state: {statement_state} and with the error {error}"
raise AirflowException(error_message)
self._handle_execution() # type: ignore[misc]


class DatabricksTaskBaseOperator(BaseOperator, ABC):
@@ -0,0 +1,162 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import annotations

from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import DatabricksHook, SQLStatementState
from airflow.providers.databricks.operators.databricks import DEFER_METHOD_NAME
from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin
from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseSensorOperator
else:
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
from airflow.utils.context import Context

XCOM_STATEMENT_ID_KEY = "statement_id"


class DatabricksSQLStatementsSensor(DatabricksSQLStatementsMixin, BaseSensorOperator):
"""DatabricksSQLStatementsSensor."""

template_fields: Sequence[str] = (
"databricks_conn_id",
"statement",
"statement_id",
)
template_ext: Sequence[str] = (".json-tpl",)
ui_color = "#1CB1C2"
ui_fgcolor = "#fff"

def __init__(
self,
warehouse_id: str,
*,
statement: str | None = None,
statement_id: str | None = None,
catalog: str | None = None,
schema: str | None = None,
parameters: list[dict[str, Any]] | None = None,
databricks_conn_id: str = "databricks_default",
polling_period_seconds: int = 30,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
databricks_retry_args: dict[Any, Any] | None = None,
do_xcom_push: bool = True,
wait_for_termination: bool = True,
timeout: float = 3600,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
# Handle the scenario where either both statement and statement_id are set/not set
if statement and statement_id:
raise AirflowException("Cannot provide both statement and statement_id.")
if not statement and not statement_id:
raise AirflowException("One of either statement or statement_id must be provided.")

if not warehouse_id:
raise AirflowException("warehouse_id must be provided.")

super().__init__(**kwargs)

self.statement = statement
self.statement_id = statement_id
self.warehouse_id = warehouse_id
self.catalog = catalog
self.schema = schema
self.parameters = parameters
self.databricks_conn_id = databricks_conn_id
self.polling_period_seconds = polling_period_seconds
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
self.timeout = timeout
self.do_xcom_push = do_xcom_push

@cached_property
def _hook(self):
return self._get_hook(caller="DatabricksSQLStatementsSensor")

def _get_hook(self, caller: str) -> DatabricksHook:
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
caller=caller,
)

def execute(self, context: Context):
if not self.statement_id:
# No statement_id was provided, so submit the statement ourselves and record the resulting statement_id
json = {
"statement": self.statement,
"warehouse_id": self.warehouse_id,
"catalog": self.catalog,
"schema": self.schema,
"parameters": self.parameters,
"wait_timeout": "0s",
}

self.statement_id = self._hook.post_sql_statement(json)
self.log.info("SQL Statement submitted with statement_id: %s", self.statement_id)

if self.do_xcom_push and context is not None:
context["ti"].xcom_push(key=XCOM_STATEMENT_ID_KEY, value=self.statement_id)

# If we're not waiting for the query to complete, return early. In that case the DatabricksSQLStatementsOperator
# is the better fit, so log a recommendation to use it instead.
if not self.wait_for_termination:
self.log.info(
"If setting wait_for_termination = False, consider using the DatabricksSQLStatementsOperator instead."
)
return

if self.deferrable:
self._handle_deferrable_execution(defer_method_name=DEFER_METHOD_NAME) # type: ignore[misc]

def poke(self, context: Context):
"""
Handle non-deferrable Sensor execution.

:param context: (Context)
:return: (bool)
"""
# This is going to very closely mirror the execute_complete
statement_state: SQLStatementState = self._hook.get_sql_statement_state(self.statement_id)

if statement_state.is_running:
self.log.info("SQL Statement with ID %s is running", self.statement_id)
return False
if statement_state.is_successful:
self.log.info("SQL Statement with ID %s completed successfully.", self.statement_id)
return True
raise AirflowException(
f"SQL Statement with ID {statement_state} failed with error: {statement_state.error_message}"
)