diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py index 378d3a0ab7b12..80a7921e2f989 100644 --- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py +++ b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py @@ -559,7 +559,7 @@ async def _do_api_call_async( url, json=data if self.method in ("POST", "PATCH") else None, params=data if self.method == "GET" else None, - headers=headers, + headers=_headers or None, auth=auth, **extra_options, ) @@ -629,7 +629,10 @@ async def get_batch_state(self, session_id: int | str) -> Any: """ self._validate_session_id(session_id) self.log.info("Fetching info for batch session %s", session_id) - result = await self.run_method(endpoint=f"{self.endpoint_prefix}/batches/{session_id}/state") + result = await self.run_method( + endpoint=f"{self.endpoint_prefix}/batches/{session_id}/state", + headers=self.extra_headers, + ) if result["status"] == "error": self.log.info(result) return {"batch_state": "error", "response": result, "status": "error"} @@ -665,7 +668,9 @@ async def get_batch_logs( self._validate_session_id(session_id) log_params = {"from": log_start_position, "size": log_batch_size} result = await self.run_method( - endpoint=f"{self.endpoint_prefix}/batches/{session_id}/log", data=log_params + endpoint=f"{self.endpoint_prefix}/batches/{session_id}/log", + data=log_params, + headers=self.extra_headers, ) if result["status"] == "error": self.log.info(result) diff --git a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py index c794c42431eae..1c6df11454dee 100644 --- a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py +++ b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py @@ -905,6 +905,24 @@ async def test_get_batch_state_with_endpoint_prefix(self, mock_run_method): } mock_run_method.assert_called_once_with( endpoint=f"/livy/batches/{BATCH_ID}/state", + headers={}, + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method") + async def test_get_batch_state_with_extra_headers(self, mock_run_method): + headers = {"X-Requested-By": "user"} + mock_run_method.return_value = {"status": "success", "response": {"state": BatchState.RUNNING}} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID, extra_headers=headers) + state = await hook.get_batch_state(BATCH_ID) + assert state == { + "batch_state": BatchState.RUNNING, + "response": "successfully fetched the batch state.", + "status": "success", + } + mock_run_method.assert_called_once_with( + endpoint=f"/batches/{BATCH_ID}/state", + headers=headers, ) @pytest.mark.asyncio @@ -918,4 +936,20 @@ async def test_get_batch_logs_with_endpoint_prefix(self, mock_run_method): mock_run_method.assert_called_once_with( endpoint=f"/livy/batches/{BATCH_ID}/log", data={"from": 0, "size": 100}, + headers={}, + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method") + async def test_get_batch_logs_with_extra_headers(self, mock_run_method): + headers = {"X-Requested-By": "user"} + mock_run_method.return_value = {"status": "success", "response": {}} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID, extra_headers=headers) + state = await hook.get_batch_logs(BATCH_ID, 0, 100) + assert state["status"] == "success" + + mock_run_method.assert_called_once_with( + endpoint=f"/batches/{BATCH_ID}/log", + data={"from": 0, "size": 100}, + headers=headers, )