diff --git a/providers/databricks/docs/operators/sql_statements.rst b/providers/databricks/docs/operators/sql_statements.rst new file mode 100644 index 0000000000000..73b7948a14465 --- /dev/null +++ b/providers/databricks/docs/operators/sql_statements.rst @@ -0,0 +1,57 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/operator:DatabricksSQLStatementsOperator: + + +DatabricksSQLStatementsOperator +=============================== + +Use the :class:`~airflow.providers.databricks.operators.databricks.DatabricksSQLStatementsOperator` to submit a +Databricks SQL Statement to Databricks using the +`Databricks SQL Statement Execution API `_. + + +Using the Operator +------------------ + +The ``DatabricksSQLStatementsOperator`` submits SQL statements to Databricks using the +`/api/2.0/sql/statements/ `_ endpoint. +It supports configurable execution parameters such as warehouse selection, catalog, schema, and parameterized queries. +The operator can either synchronously poll for query completion or run in a deferrable mode for improved efficiency. + +The only required parameters for using the operator are: + +* ``statement`` - The SQL statement to execute. The statement can optionally be parameterized, see parameters. +* ``warehouse_id`` - Warehouse upon which to execute a statement. + +All other parameters are optional and described in the documentation for ``DatabricksSQLStatementsOperator`` including +but not limited to: + +* ``catalog`` +* ``schema`` +* ``parameters`` + +Examples +-------- + +An example usage of the ``DatabricksSQLStatementsOperator`` is as follows: + +.. 
exampleinclude:: /../../databricks/tests/system/databricks/example_databricks.py + :language: python + :start-after: [START howto_operator_sql_statements] + :end-before: [END howto_operator_sql_statements] diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py index 7143d59a1f1a9..f9363aabd7f52 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py @@ -63,6 +63,7 @@ WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status") SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions") +SQL_STATEMENTS_ENDPOINT = "api/2.0/sql/statements" class RunLifeCycleState(Enum): @@ -189,6 +190,67 @@ def from_json(cls, data: str) -> ClusterState: return ClusterState(**json.loads(data)) +class SQLStatementState: + """Utility class for the SQL statement state concept of Databricks statements.""" + + SQL_STATEMENT_LIFE_CYCLE_STATES = [ + "PENDING", + "RUNNING", + "SUCCEEDED", + "FAILED", + "CANCELED", + "CLOSED", + ] + + def __init__( + self, state: str = "", error_code: str = "", error_message: str = "", *args, **kwargs + ) -> None: + if state not in self.SQL_STATEMENT_LIFE_CYCLE_STATES: + raise AirflowException( + f"Unexpected SQL statement life cycle state: {state}: If the state has " + "been introduced recently, please check the Databricks user " + "guide for troubleshooting information" + ) + + self.state = state + self.error_code = error_code + self.error_message = error_message + + @property + def is_terminal(self) -> bool: + """True if the current state is a terminal state.""" + return self.state in ("SUCCEEDED", "FAILED", "CANCELED", "CLOSED") + + @property + def is_running(self) -> bool: + """True if the current state is running.""" + return self.state in ("PENDING", "RUNNING") + + @property + def is_successful(self) -> bool: + """True if the state is SUCCEEDED.""" + return self.state == "SUCCEEDED" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SQLStatementState): + return NotImplemented + return ( + self.state == other.state + and self.error_code == other.error_code + and self.error_message == other.error_message + ) + + def __repr__(self) -> str: + return str(self.__dict__) + + def to_json(self) -> str: + return json.dumps(self.__dict__) + + @classmethod + def from_json(cls, data: str) -> SQLStatementState: + return SQLStatementState(**json.loads(data)) + + class DatabricksHook(BaseDatabricksHook): """ Interact with Databricks. @@ -709,6 +771,54 @@ def update_job_permission(self, job_id: int, json: dict[str, Any]) -> dict: """ return self._do_api_call(("PATCH", f"api/2.0/permissions/jobs/{job_id}"), json) + def post_sql_statement(self, json: dict[str, Any]) -> str: + """ + Submit a SQL statement to the Databricks SQL Statements endpoint. + + :param json: The data used in the body of the request to the SQL Statements endpoint. + :return: The statement_id as a string. + """ + response = self._do_api_call(("POST", f"{SQL_STATEMENTS_ENDPOINT}"), json) + return response["statement_id"] + + def get_sql_statement_state(self, statement_id: str) -> SQLStatementState: + """ + Retrieve run state of the SQL statement. + + :param statement_id: ID of the SQL statement. + :return: state of the SQL statement. 
+ """ + get_statement_endpoint = ("GET", f"{SQL_STATEMENTS_ENDPOINT}/{statement_id}") + response = self._do_api_call(get_statement_endpoint) + state = response["status"]["state"] + error_code = response["status"].get("error", {}).get("error_code", "") + error_message = response["status"].get("error", {}).get("message", "") + return SQLStatementState(state, error_code, error_message) + + async def a_get_sql_statement_state(self, statement_id: str) -> SQLStatementState: + """ + Async version of `get_sql_statement_state`. + + :param statement_id: ID of the SQL statement + :return: state of the SQL statement + """ + get_sql_statement_endpoint = ("GET", f"{SQL_STATEMENTS_ENDPOINT}/{statement_id}") + response = await self._a_do_api_call(get_sql_statement_endpoint) + state = response["status"]["state"] + error_code = response["status"].get("error", {}).get("error_code", "") + error_message = response["status"].get("error", {}).get("message", "") + return SQLStatementState(state, error_code, error_message) + + def cancel_sql_statement(self, statement_id: str) -> None: + """ + Cancel the SQL statement. + + :param statement_id: ID of the SQL statement + """ + self.log.info("Canceling SQL statement with ID: %s", statement_id) + cancel_sql_statement_endpoint = ("POST", f"{SQL_STATEMENTS_ENDPOINT}/{statement_id}/cancel") + self._do_api_call(cancel_sql_statement_endpoint) + def test_connection(self) -> tuple[bool, str]: """Test the Databricks connectivity from UI.""" hook = DatabricksHook(databricks_conn_id=self.databricks_conn_id) diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index 2de7de67b3e94..3265607bbd854 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -30,7 +30,12 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import BaseOperator -from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState, RunState +from airflow.providers.databricks.hooks.databricks import ( + DatabricksHook, + RunLifeCycleState, + RunState, + SQLStatementState, +) from airflow.providers.databricks.operators.databricks_workflow import ( DatabricksWorkflowTaskGroup, WorkflowRunMetadata, @@ -39,7 +44,10 @@ WorkflowJobRepairSingleTaskLink, WorkflowJobRunLink, ) -from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger +from airflow.providers.databricks.triggers.databricks import ( + DatabricksExecutionTrigger, + DatabricksSQLStatementExecutionTrigger, +) from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS @@ -59,6 +67,7 @@ XCOM_RUN_ID_KEY = "run_id" XCOM_JOB_ID_KEY = "job_id" XCOM_RUN_PAGE_URL_KEY = "run_page_url" +XCOM_STATEMENT_ID_KEY = "statement_id" def _handle_databricks_operator_execution(operator, hook, log, context) -> None: @@ -969,6 +978,204 @@ def on_kill(self) -> None: self.log.error("Error: Task: %s with invalid run_id was requested to be cancelled.", self.task_id) +class DatabricksSQLStatementsOperator(BaseOperator): + """ + Submits a Databricks SQL Statement to Databricks using the api/2.0/sql/statements/ API endpoint. + + See: https://docs.databricks.com/api/workspace/statementexecution + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DatabricksSQLStatementsOperator` + + :param statement: The SQL statement to execute. The statement can optionally be parameterized, see parameters. + :param warehouse_id: Warehouse upon which to execute a statement. + :param catalog: Sets default catalog for statement execution, similar to USE CATALOG in SQL. + :param schema: Sets default schema for statement execution, similar to USE SCHEMA in SQL. + :param parameters: A list of parameters to pass into a SQL statement containing parameter markers. + + .. seealso:: + https://docs.databricks.com/api/workspace/statementexecution/executestatement#parameters + :param wait_for_termination: if we should wait for termination of the statement execution. ``True`` by default. + :param databricks_conn_id: Reference to the :ref:`Databricks connection `. + By default and in the common case this will be ``databricks_default``. To use + token based authentication, provide the key ``token`` in the extra field for the + connection and create the key ``host`` and leave the ``host`` field empty. (templated) + :param polling_period_seconds: Controls the rate which we poll for the result of + this statement. By default the operator will poll every 30 seconds. + :param databricks_retry_limit: Amount of times retry if the Databricks backend is + unreachable. Its value must be greater than or equal to 1. + :param databricks_retry_delay: Number of seconds to wait between retries (it + might be a floating point number). + :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. + :param do_xcom_push: Whether we should push statement_id to xcom.: + :param timeout: The timeout for the Airflow task executing the SQL statement. By default a value of 3600 seconds is used. + :param deferrable: Run operator in the deferrable mode. + """ + + # Used in airflow.models.BaseOperator + template_fields: Sequence[str] = ("databricks_conn_id",) + template_ext: Sequence[str] = (".json-tpl",) + # Databricks brand color (blue) under white text + ui_color = "#1CB1C2" + ui_fgcolor = "#fff" + + def __init__( + self, + statement: str, + warehouse_id: str, + *, + catalog: str | None = None, + schema: str | None = None, + parameters: list[dict[str, Any]] | None = None, + databricks_conn_id: str = "databricks_default", + polling_period_seconds: int = 30, + databricks_retry_limit: int = 3, + databricks_retry_delay: int = 1, + databricks_retry_args: dict[Any, Any] | None = None, + do_xcom_push: bool = True, + wait_for_termination: bool = True, + timeout: float = 3600, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ) -> None: + """Create a new ``DatabricksSubmitRunOperator``.""" + super().__init__(**kwargs) + self.statement = statement + self.warehouse_id = warehouse_id + self.catalog = catalog + self.schema = schema + self.parameters = parameters + self.databricks_conn_id = databricks_conn_id + self.polling_period_seconds = polling_period_seconds + self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay + self.databricks_retry_args = databricks_retry_args + self.wait_for_termination = wait_for_termination + self.deferrable = deferrable + + # This variable will be used in case our task gets killed. 
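+        # on_kill() and the synchronous timeout handling use this ID to cancel the statement on the Databricks side.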
+ self.statement_id: str | None = None + + self.timeout = timeout + self.do_xcom_push = do_xcom_push + + @cached_property + def _hook(self): + return self._get_hook(caller="DatabricksSQLStatementsOperator") + + def _get_hook(self, caller: str) -> DatabricksHook: + return DatabricksHook( + self.databricks_conn_id, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=caller, + ) + + def _handle_operator_execution(self) -> None: + end_time = time.time() + self.timeout + while end_time > time.time(): + statement_state = self._hook.get_sql_statement_state(self.statement_id) + if statement_state.is_terminal: + if statement_state.is_successful: + self.log.info("%s completed successfully.", self.task_id) + return + error_message = ( + f"{self.task_id} failed with terminal state: {statement_state.state} " + f"and with the error code {statement_state.error_code} " + f"and error message {statement_state.error_message}" + ) + raise AirflowException(error_message) + + self.log.info("%s in run state: %s", self.task_id, statement_state.state) + self.log.info("Sleeping for %s seconds.", self.polling_period_seconds) + time.sleep(self.polling_period_seconds) + + self._hook.cancel_sql_statement(self.statement_id) + raise AirflowException( + f"{self.task_id} timed out after {self.timeout} seconds with state: {statement_state.state}", + ) + + def _handle_deferrable_operator_execution(self) -> None: + statement_state = self._hook.get_sql_statement_state(self.statement_id) + end_time = time.time() + self.timeout + if not statement_state.is_terminal: + if not self.statement_id: + raise AirflowException("Failed to retrieve statement_id after submitting SQL statement.") + self.defer( + trigger=DatabricksSQLStatementExecutionTrigger( + statement_id=self.statement_id, + databricks_conn_id=self.databricks_conn_id, + end_time=end_time, + polling_period_seconds=self.polling_period_seconds, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + ), + method_name=DEFER_METHOD_NAME, + ) + else: + if statement_state.is_successful: + self.log.info("%s completed successfully.", self.task_id) + else: + error_message = ( + f"{self.task_id} failed with terminal state: {statement_state.state} " + f"and with the error code {statement_state.error_code} " + f"and error message {statement_state.error_message}" + ) + raise AirflowException(error_message) + + def execute(self, context: Context): + json = { + "statement": self.statement, + "warehouse_id": self.warehouse_id, + "catalog": self.catalog, + "schema": self.schema, + "parameters": self.parameters, + # We set the wait timeout to 0s as that seems the appropriate way for our deferrable version + # support of the operator. For synchronous version, we still poll on the statement + # execution state. 
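+            # With a 0s wait the API call returns immediately with a statement_id, and polling is done
+            # afterwards via get_sql_statement_state (synchronous) or the trigger (deferrable).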
+ "wait_timeout": "0s", + } + self.statement_id = self._hook.post_sql_statement(json) + if self.do_xcom_push and context is not None: + context["ti"].xcom_push(key=XCOM_STATEMENT_ID_KEY, value=self.statement_id) + + self.log.info("SQL Statement submitted with statement_id: %s", self.statement_id) + if not self.wait_for_termination: + return + if self.deferrable: + self._handle_deferrable_operator_execution() + else: + self._handle_operator_execution() + + def on_kill(self): + if self.statement_id: + self._hook.cancel_sql_statement(self.statement_id) + self.log.info( + "Task: %s with statement ID: %s was requested to be cancelled.", + self.task_id, + self.statement_id, + ) + else: + self.log.error( + "Error: Task: %s with invalid statement_id was requested to be cancelled.", self.task_id + ) + + def execute_complete(self, context: dict | None, event: dict): + statement_state = SQLStatementState.from_json(event["state"]) + error = event["error"] + statement_id = event["statement_id"] + + if statement_state.is_successful: + self.log.info("SQL Statement with ID %s completed successfully.", statement_id) + return + + error_message = f"SQL Statement execution failed with terminal state: {statement_state} and with the error {error}" + raise AirflowException(error_message) + + class DatabricksTaskBaseOperator(BaseOperator, ABC): """ Base class for operators that are run as Databricks job tasks or tasks within a Databricks workflow. diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index 55845fc6f7c57..d985d7ca68ab2 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -18,6 +18,7 @@ from __future__ import annotations import asyncio +import time from typing import Any from airflow.providers.databricks.hooks.databricks import DatabricksHook @@ -119,3 +120,102 @@ async def run(self): } ) return + + +class DatabricksSQLStatementExecutionTrigger(BaseTrigger): + """ + The trigger handles the logic of async communication with DataBricks SQL Statements API. + + :param statement_id: ID of the SQL statement. + :param databricks_conn_id: Reference to the :ref:`Databricks connection `. + :param end_time: The end time (set based on timeout supplied for the operator) for the SQL statement execution. + :param polling_period_seconds: Controls the rate of the poll for the result of this run. + By default, the trigger will poll every 30 seconds. + :param retry_limit: The number of times to retry the connection in case of service outages. + :param retry_delay: The number of seconds to wait between retries. + :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. 
+ """ + + def __init__( + self, + statement_id: str, + databricks_conn_id: str, + end_time: float, + polling_period_seconds: int = 30, + retry_limit: int = 3, + retry_delay: int = 10, + retry_args: dict[Any, Any] | None = None, + caller: str = "DatabricksSQLStatementExecutionTrigger", + ) -> None: + super().__init__() + self.statement_id = statement_id + self.databricks_conn_id = databricks_conn_id + self.end_time = end_time + self.polling_period_seconds = polling_period_seconds + self.retry_limit = retry_limit + self.retry_delay = retry_delay + self.retry_args = retry_args + self.hook = DatabricksHook( + databricks_conn_id, + retry_limit=self.retry_limit, + retry_delay=self.retry_delay, + retry_args=retry_args, + caller=caller, + ) + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.databricks.triggers.databricks.DatabricksSQLStatementExecutionTrigger", + { + "statement_id": self.statement_id, + "databricks_conn_id": self.databricks_conn_id, + "polling_period_seconds": self.polling_period_seconds, + "end_time": self.end_time, + "retry_limit": self.retry_limit, + "retry_delay": self.retry_delay, + "retry_args": self.retry_args, + }, + ) + + async def run(self): + async with self.hook: + while self.end_time > time.time(): + statement_state = await self.hook.a_get_sql_statement_state(self.statement_id) + if not statement_state.is_terminal: + self.log.info( + "Statement ID %s is in state %s. sleeping for %s seconds", + self.statement_id, + statement_state, + self.polling_period_seconds, + ) + await asyncio.sleep(self.polling_period_seconds) + continue + + error = {} + if statement_state.error_code: + error = { + "error_code": statement_state.error_code, + "error_message": statement_state.error_message, + } + yield TriggerEvent( + { + "statement_id": self.statement_id, + "state": statement_state.to_json(), + "error": error, + } + ) + return + + # If we reach here, it means the statement should be timed out as per the end_time. 
+ self.hook.cancel_sql_statement(self.statement_id) + yield TriggerEvent( + { + "statement_id": self.statement_id, + "state": statement_state.to_json(), + "error": { + "error_code": "TIMEOUT", + "error_message": f"Statement ID {self.statement_id} timed out after set end time {self.end_time}", + }, + } + ) + return diff --git a/providers/databricks/tests/system/databricks/example_databricks.py b/providers/databricks/tests/system/databricks/example_databricks.py index 999cebb674292..39aaa2891d6ad 100644 --- a/providers/databricks/tests/system/databricks/example_databricks.py +++ b/providers/databricks/tests/system/databricks/example_databricks.py @@ -41,6 +41,7 @@ DatabricksCreateJobsOperator, DatabricksNotebookOperator, DatabricksRunNowOperator, + DatabricksSQLStatementsOperator, DatabricksSubmitRunOperator, DatabricksTaskOperator, ) @@ -152,6 +153,16 @@ # [END howto_operator_databricks_named] notebook_task >> spark_jar_task + # [START howto_operator_sql_statements] + sql_statement = DatabricksSQLStatementsOperator( + task_id="sql_statement", + databricks_conn_id="databricks_default", + statement="select * from default.my_airflow_table", + warehouse_id=WAREHOUSE_ID, + # deferrable=True, # For using the operator in deferrable mode + ) + # [END howto_operator_sql_statements] + # [START howto_operator_databricks_notebook_new_cluster] new_cluster_spec = { "cluster_name": "", diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py index 4eaeddf972edf..d0c437ad02940 100644 --- a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py @@ -43,6 +43,7 @@ ClusterState, DatabricksHook, RunState, + SQLStatementState, ) from airflow.providers.databricks.hooks.databricks_base import ( AZURE_MANAGEMENT_ENDPOINT, @@ -65,6 +66,9 @@ JOB_NAME = "job-name" PIPELINE_NAME = "some pipeline name" PIPELINE_ID = "its-a-pipeline-id" +STATEMENT_ID = "statement_id" +STATEMENT_STATE = "SUCCEEDED" +WAREHOUSE_ID = "warehouse_id" DEFAULT_RETRY_NUMBER = 3 DEFAULT_RETRY_ARGS = dict( wait=tenacity.wait_none(), @@ -90,6 +94,7 @@ CLUSTER_STATE = "TERMINATED" CLUSTER_STATE_MESSAGE = "Inactive cluster terminated (inactive for 120 minutes)." 
GET_CLUSTER_RESPONSE = {"state": CLUSTER_STATE, "state_message": CLUSTER_STATE_MESSAGE} +GET_SQL_STATEMENT_RESPONSE = {"statement_id": STATEMENT_ID, "status": {"state": STATEMENT_STATE}} NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"} JAR_PARAMS = ["param1", "param2"] RESULT_STATE = "" @@ -273,6 +278,11 @@ def create_valid_response_mock(content): return response +def sql_statements_endpoint(host): + """Utility function to generate the sql statements endpoint given the host.""" + return f"https://{host}/api/2.0/sql/statements" + + def create_successful_response_mock(content): response = mock.MagicMock() response.json.return_value = content @@ -1149,6 +1159,62 @@ def test_list_pipelines_raise_exception_with_duplicates(self, mock_requests): timeout=self.hook.timeout_seconds, ) + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") + def test_post_sql_statement(self, mock_requests): + mock_requests.post.return_value.json.return_value = { + "statement_id": "01f00ed2-04e2-15bd-a944-a8ae011dac69" + } + json = { + "statement": "select * from test.test;", + "warehouse_id": WAREHOUSE_ID, + "catalog": "", + "schema": "", + "parameters": {}, + "wait_timeout": "0s", + } + self.hook.post_sql_statement(json) + + mock_requests.post.assert_called_once_with( + sql_statements_endpoint(HOST), + json=json, + params=None, + auth=HTTPBasicAuth(LOGIN, PASSWORD), + headers=self.hook.user_agent_header, + timeout=self.hook.timeout_seconds, + ) + + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") + def test_get_sql_statement_state(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.get.return_value.json.return_value = GET_SQL_STATEMENT_RESPONSE + + sql_statement_state = self.hook.get_sql_statement_state(STATEMENT_ID) + + assert sql_statement_state == SQLStatementState(STATEMENT_STATE) + mock_requests.get.assert_called_once_with( + f"{sql_statements_endpoint(HOST)}/{STATEMENT_ID}", + json=None, + params=None, + auth=HTTPBasicAuth(LOGIN, PASSWORD), + headers=self.hook.user_agent_header, + timeout=self.hook.timeout_seconds, + ) + + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") + def test_cancel_sql_statement(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.get.return_value.json.return_value = GET_SQL_STATEMENT_RESPONSE + + self.hook.cancel_sql_statement(STATEMENT_ID) + mock_requests.post.assert_called_once_with( + f"{sql_statements_endpoint(HOST)}/{STATEMENT_ID}/cancel", + json=None, + params=None, + auth=HTTPBasicAuth(LOGIN, PASSWORD), + headers=self.hook.user_agent_header, + timeout=self.hook.timeout_seconds, + ) + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") def test_connection_success(self, mock_requests): mock_requests.codes.ok = 200 @@ -2068,3 +2134,59 @@ async def test_get_run_state(self, mock_post, mock_get): headers=self.hook.user_agent_header, timeout=self.hook.timeout_seconds, ) + + +class TestSQLStatementState: + def test_sqlstatementstate_initialization_valid_states(self): + valid_states = ["PENDING", "RUNNING", "SUCCEEDED", "FAILED", "CANCELED", "CLOSED"] + for state in valid_states: + obj = SQLStatementState(state=state) + assert obj.state == state + + def test_sqlstatementstate_initialization_invalid_state(self): + with pytest.raises(AirflowException, match="Unexpected SQL statement life cycle state: UNKNOWN"): + SQLStatementState(state="UNKNOWN") + + def test_sqlstatementstate_terminal_states(self): + terminal_states = 
["SUCCEEDED", "FAILED", "CANCELED", "CLOSED"] + for state in terminal_states: + obj = SQLStatementState(state=state) + assert obj.is_terminal is True + + def test_sqlstatementstate_running_states(self): + running_states = ["PENDING", "RUNNING"] + for state in running_states: + obj = SQLStatementState(state=state) + assert obj.is_running is True + + def test_sqlstatementstate_successful_state(self): + obj = SQLStatementState(state="SUCCEEDED") + assert obj.is_successful is True + + def test_sqlstatementstate_equality(self): + obj1 = SQLStatementState(state="FAILED", error_code="123", error_message="Error occurred") + obj2 = SQLStatementState(state="FAILED", error_code="123", error_message="Error occurred") + obj3 = SQLStatementState(state="SUCCEEDED") + assert obj1 == obj2 + assert obj1 != obj3 + + def test_sqlstatementstate_repr(self): + obj = SQLStatementState(state="FAILED", error_code="123", error_message="Error occurred") + assert "'state': 'FAILED'" in repr(obj) + assert "'error_code': '123'" in repr(obj) + assert "'error_message': 'Error occurred'" in repr(obj) + + def test_sqlstatementstate_to_json(self): + obj = SQLStatementState(state="FAILED", error_code="123", error_message="Error occurred") + json_data = obj.to_json() + expected_data = json.dumps( + {"state": "FAILED", "error_code": "123", "error_message": "Error occurred"} + ) + assert json.loads(json_data) == json.loads(expected_data) + + def test_sqlstatementstate_from_json(self): + json_data = json.dumps({"state": "FAILED", "error_code": "123", "error_message": "Error occurred"}) + obj = SQLStatementState.from_json(json_data) + assert obj.state == "FAILED" + assert obj.error_code == "123" + assert obj.error_message == "Error occurred" diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks.py b/providers/databricks/tests/unit/databricks/operators/test_databricks.py index 298c2aecd47ac..8f2d9f4cdd03a 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks.py @@ -27,16 +27,20 @@ from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG -from airflow.providers.databricks.hooks.databricks import RunState +from airflow.providers.databricks.hooks.databricks import RunState, SQLStatementState from airflow.providers.databricks.operators.databricks import ( DatabricksCreateJobsOperator, DatabricksNotebookOperator, DatabricksRunNowOperator, + DatabricksSQLStatementsOperator, DatabricksSubmitRunOperator, DatabricksTaskBaseOperator, DatabricksTaskOperator, ) -from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger +from airflow.providers.databricks.triggers.databricks import ( + DatabricksExecutionTrigger, + DatabricksSQLStatementExecutionTrigger, +) from airflow.providers.databricks.utils import databricks as utils pytestmark = pytest.mark.db_test @@ -64,6 +68,8 @@ RUN_NAME = "run-name" RUN_ID = 1 RUN_PAGE_URL = "run-page-url" +STATEMENT_ID = "statement_id" +WAREHOUSE_ID = "warehouse_id" JOB_ID = "42" JOB_NAME = "job-name" JOB_DESCRIPTION = "job-description" @@ -1948,6 +1954,171 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_ assert not mock_defer.called +class TestDatabricksSQLStatementsOperator: + def test_init(self): + """ + Test the initializer. 
+ """ + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, statement="select * from test.test;", warehouse_id=WAREHOUSE_ID + ) + + assert op.statement == "select * from test.test;" + assert op.warehouse_id == WAREHOUSE_ID + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_success(self, db_mock_class): + """ + Test the execute function in case where the statement is successful. + """ + expected_json = { + "statement": "select * from test.test;", + "warehouse_id": WAREHOUSE_ID, + "catalog": None, + "schema": None, + "parameters": None, + "wait_timeout": "0s", + } + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, statement="select * from test.test;", warehouse_id=WAREHOUSE_ID + ) + db_mock = db_mock_class.return_value + db_mock.post_sql_statement.return_value = STATEMENT_ID + + op.execute(None) + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksSQLStatementsOperator", + ) + + db_mock.post_sql_statement.assert_called_once_with(expected_json) + db_mock.get_sql_statement_state.assert_called_once_with(STATEMENT_ID) + assert op.statement_id == STATEMENT_ID + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_failure(self, db_mock_class): + """ + Test the execute function in case where the statement failed. + """ + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, statement="select * from test.test;", warehouse_id=WAREHOUSE_ID + ) + db_mock = db_mock_class.return_value + db_mock.post_sql_statement.return_value = STATEMENT_ID + db_mock.get_sql_statement_state.return_value = SQLStatementState( + state="FAILED", error_code="500", error_message="Something went wrong" + ) + + with pytest.raises(AirflowException): + op.execute(None) + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksSQLStatementsOperator", + ) + db_mock.get_sql_statement_state.assert_called_once_with(STATEMENT_ID) + assert op.statement_id == STATEMENT_ID + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_on_kill(self, db_mock_class): + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, statement="select * from test.test;", warehouse_id=WAREHOUSE_ID + ) + db_mock = db_mock_class.return_value + op.statement_id = STATEMENT_ID + + op.on_kill() + + db_mock.cancel_sql_statement.assert_called_once_with(STATEMENT_ID) + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_wait_for_termination_is_default(self, db_mock_class): + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, statement="select * from test.test;", warehouse_id=WAREHOUSE_ID + ) + + assert op.wait_for_termination + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_no_wait_for_termination(self, db_mock_class): + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, + statement="select * from test.test;", + warehouse_id=WAREHOUSE_ID, + wait_for_termination=False, + ) + db_mock = db_mock_class.return_value + + assert not op.wait_for_termination + + op.execute(None) + + db_mock.get_sql_statement_state.assert_not_called() + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_execute_task_deferred(self, db_mock_class): + op = 
DatabricksSQLStatementsOperator( + task_id=TASK_ID, + statement="select * from test.test;", + warehouse_id=WAREHOUSE_ID, + deferrable=True, + ) + db_mock = db_mock_class.return_value + db_mock.get_sql_statement_state.return_value = SQLStatementState("RUNNING") + + with pytest.raises(TaskDeferred) as exc: + op.execute(None) + assert isinstance(exc.value.trigger, DatabricksSQLStatementExecutionTrigger) + assert exc.value.method_name == "execute_complete" + + def test_execute_complete_success(self): + """ + Test `execute_complete` function in case the Trigger has returned a successful completion event. + """ + event = { + "statement_id": STATEMENT_ID, + "state": SQLStatementState("SUCCEEDED").to_json(), + "error": {}, + } + + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, + statement="select * from test.test;", + warehouse_id=WAREHOUSE_ID, + deferrable=True, + ) + assert op.execute_complete(context=None, event=event) is None + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_execute_complete_failure(self, db_mock_class): + """ + Test `execute_complete` function in case the Trigger has returned a failure completion event. + """ + event = { + "statement_id": STATEMENT_ID, + "state": SQLStatementState("FAILED").to_json(), + "error": SQLStatementState( + state="FAILED", error_code="500", error_message="Something Went Wrong" + ).to_json(), + } + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, + statement="select * from test.test;", + warehouse_id=WAREHOUSE_ID, + deferrable=True, + ) + + with pytest.raises(AirflowException, match="^SQL Statement execution failed with terminal state: .*"): + op.execute_complete(context=None, event=event) + + class TestDatabricksNotebookOperator: def test_is_instance_of_databricks_task_base_operator(self): operator = DatabricksNotebookOperator( diff --git a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py index b4bcbc133c0ce..bb143fe5c9b0c 100644 --- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -17,13 +17,17 @@ # under the License. 
from __future__ import annotations +import time from unittest import mock import pytest from airflow.models import Connection -from airflow.providers.databricks.hooks.databricks import RunState -from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger +from airflow.providers.databricks.hooks.databricks import RunState, SQLStatementState +from airflow.providers.databricks.triggers.databricks import ( + DatabricksExecutionTrigger, + DatabricksSQLStatementExecutionTrigger, +) from airflow.triggers.base import TriggerEvent from airflow.utils.session import provide_session @@ -38,6 +42,7 @@ RETRY_DELAY = 10 RETRY_LIMIT = 3 RUN_ID = 1 +STATEMENT_ID = "statement_id" TASK_RUN_ID1 = 11 TASK_RUN_ID1_KEY = "first_task" TASK_RUN_ID2 = 22 @@ -250,3 +255,102 @@ async def test_sleep_between_retries( ) mock_sleep.assert_called_once() mock_sleep.assert_called_with(POLLING_INTERVAL_SECONDS) + + +class TestDatabricksSQLStatementExecutionTrigger: + @provide_session + def setup_method(self, method, session=None): + self.end_time = time.time() + 60 + conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() + conn.host = HOST + conn.login = LOGIN + conn.password = PASSWORD + conn.extra = None + session.commit() + + self.trigger = DatabricksSQLStatementExecutionTrigger( + statement_id=STATEMENT_ID, + databricks_conn_id=DEFAULT_CONN_ID, + polling_period_seconds=POLLING_INTERVAL_SECONDS, + end_time=self.end_time, + ) + + def test_serialize(self): + assert self.trigger.serialize() == ( + "airflow.providers.databricks.triggers.databricks.DatabricksSQLStatementExecutionTrigger", + { + "statement_id": STATEMENT_ID, + "databricks_conn_id": DEFAULT_CONN_ID, + "end_time": self.end_time, + "polling_period_seconds": POLLING_INTERVAL_SECONDS, + "retry_delay": 10, + "retry_limit": 3, + "retry_args": None, + }, + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_sql_statement_state") + async def test_run_return_success(self, mock_a_get_sql_statement_state): + mock_a_get_sql_statement_state.return_value = SQLStatementState(state="SUCCEEDED") + + trigger_event = self.trigger.run() + async for event in trigger_event: + assert event == TriggerEvent( + { + "statement_id": STATEMENT_ID, + "state": SQLStatementState(state="SUCCEEDED").to_json(), + "error": {}, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_sql_statement_state") + async def test_run_return_failure(self, mock_a_get_sql_statement_state): + mock_a_get_sql_statement_state.return_value = SQLStatementState( + state="FAILED", + error_code="500", + error_message="Something went wrong", + ) + + trigger_event = self.trigger.run() + async for event in trigger_event: + assert event == TriggerEvent( + { + "statement_id": STATEMENT_ID, + "state": SQLStatementState( + state="FAILED", + error_code="500", + error_message="Something went wrong", + ).to_json(), + "error": { + "error_code": "500", + "error_message": "Something went wrong", + }, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_sql_statement_state") + async def test_sleep_between_retries(self, mock_a_get_sql_statement_state, mock_sleep): + mock_a_get_sql_statement_state.side_effect = [ + SQLStatementState( + state="PENDING", + ), + SQLStatementState( + state="SUCCEEDED", + ), + ] + + 
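+        # The first poll returns PENDING, so the trigger sleeps exactly once before the SUCCEEDED poll ends the run.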
trigger_event = self.trigger.run() + async for event in trigger_event: + assert event == TriggerEvent( + { + "statement_id": STATEMENT_ID, + "state": SQLStatementState(state="SUCCEEDED").to_json(), + "error": {}, + } + ) + mock_sleep.assert_called_once() + mock_sleep.assert_called_with(POLLING_INTERVAL_SECONDS)
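For reference, a parameterized statement would be passed to the operator roughly as sketched below. This is an illustrative snippet only, not part of the change set: the warehouse ID and parameter values are placeholders, and the parameter format follows the Databricks Statement Execution API (named ``:marker`` placeholders in the statement plus a list of name/value/type mappings).

from airflow.providers.databricks.operators.databricks import DatabricksSQLStatementsOperator

# Hypothetical usage: select recent rows from a table using a named parameter marker.
parameterized_sql_statement = DatabricksSQLStatementsOperator(
    task_id="parameterized_sql_statement",
    databricks_conn_id="databricks_default",
    warehouse_id="my-warehouse-id",  # placeholder warehouse ID
    statement="SELECT * FROM default.my_airflow_table WHERE created_at > :earliest",
    parameters=[{"name": "earliest", "value": "2025-01-01", "type": "DATE"}],
    deferrable=True,  # optional: hand polling off to the triggerer instead of blocking a worker slot
)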