diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index ce0f125fd3540..a46b57cf69646 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -61,11 +61,7 @@ dependencies = [ "python-dateutil>=2.7.0", "psutil>=6.1.0", "structlog>=25.4.0", - "retryhttp>=1.2.0,!=1.3.0", "greenback>=1.2.1", - # Requests is known to introduce breaking changes, so we pin it to a specific range - "requests>=2.31.0,<3", - "types-requests>=2.31.0", "tenacity>=8.3.0", # Start of shared timezones dependencies "pendulum>=3.1.0", diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 0a353271c1025..49da3ecbfbef8 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -30,8 +30,13 @@ import msgspec import structlog from pydantic import BaseModel -from retryhttp import retry, wait_retry_after -from tenacity import before_log, wait_random_exponential +from tenacity import ( + before_log, + retry, + retry_if_exception, + stop_after_attempt, + wait_random_exponential, +) from uuid6 import uuid7 from airflow.configuration import conf @@ -812,6 +817,14 @@ def noop_handler(request: httpx.Request) -> httpx.Response: API_SSL_CERT_PATH = conf.get("api", "ssl_cert") +def _should_retry_api_request(exception: BaseException) -> bool: + """Determine if an API request should be retried based on the exception type.""" + if isinstance(exception, httpx.HTTPStatusError): + return exception.response.status_code >= 500 + + return isinstance(exception, httpx.RequestError) + + class Client(httpx.Client): def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any): if (not base_url) ^ dry_run: @@ -840,21 +853,17 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * **kwargs, ) - _default_wait = wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX) - def _update_auth(self, response: httpx.Response): if new_token := response.headers.get("Refreshed-API-Token"): log.debug("Execution API issued us a refreshed Task token") self.auth = BearerAuth(new_token) @retry( - reraise=True, - max_attempt_number=API_RETRIES, - wait_server_errors=_default_wait, - wait_network_errors=_default_wait, - wait_timeouts=_default_wait, - wait_rate_limited=wait_retry_after(fallback=_default_wait), # No infinite timeout on HTTP 429 + retry=retry_if_exception(_should_retry_api_request), + stop=stop_after_attempt(API_RETRIES), + wait=wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX), before_sleep=before_log(log, logging.WARNING), + reraise=True, ) def request(self, *args, **kwargs): """Implement a convenience for httpx.Client.request with a retry layer.""" diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index cbd7c80bb0178..3931823793092 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -25,12 +25,13 @@ import httpx import pytest +import time_machine import uuid6 from task_sdk import make_client, make_client_w_dry_run, make_client_w_responses from uuid6 import uuid7 from airflow.sdk import timezone -from airflow.sdk.api.client import RemoteValidationError, ServerResponseError +from airflow.sdk.api.client import Client, RemoteValidationError, ServerResponseError from airflow.sdk.api.datamodels._generated import ( AssetEventsResponse, AssetResponse, @@ -39,6 +40,7 @@ DagRunStateResponse, HITLDetailResponse, HITLUser, + TerminalTIState, VariableResponse, XComResponse, ) @@ -52,7 +54,6 @@ RescheduleTask, TaskRescheduleStartDate, ) -from airflow.utils.state import TerminalTIState if TYPE_CHECKING: from time_machine import TimeMachineFixture @@ -99,6 +100,23 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert isinstance(err.value, FileNotFoundError) + @mock.patch("airflow.sdk.api.client.API_TIMEOUT", 60.0) + def test_timeout_configuration(self): + def handle_request(request: httpx.Request) -> httpx.Response: + return httpx.Response(status_code=200) + + client = make_client(httpx.MockTransport(handle_request)) + assert client.timeout == httpx.Timeout(60.0) + + def test_timeout_can_be_overridden(self): + def handle_request(request: httpx.Request) -> httpx.Response: + return httpx.Response(status_code=200) + + client = Client( + base_url="test://server", token="", transport=httpx.MockTransport(handle_request), timeout=120.0 + ) + assert client.timeout == httpx.Timeout(120.0) + def test_error_parsing(self): responses = [ httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err", "type": "required"}]}) @@ -154,76 +172,57 @@ def test_server_response_error_pickling(self): assert unpickled.response.status_code == 404 assert unpickled.request.url == "http://error" - @mock.patch("time.sleep", return_value=None) - def test_retry_handling_unrecoverable_error(self, mock_sleep): - responses: list[httpx.Response] = [ - *[httpx.Response(500, text="Internal Server Error")] * 6, - httpx.Response(200, json={"detail": "Recovered from error - but will fail before"}), - httpx.Response(400, json={"detail": "Should not get here"}), - ] - client = make_client_w_responses(responses) - - with pytest.raises(httpx.HTTPStatusError) as err: - client.get("http://error") - assert not isinstance(err.value, ServerResponseError) - assert len(responses) == 3 - assert mock_sleep.call_count == 4 - - @mock.patch("time.sleep", return_value=None) - def test_retry_handling_recovered(self, mock_sleep): - responses: list[httpx.Response] = [ - *[httpx.Response(500, text="Internal Server Error")] * 2, - httpx.Response(200, json={"detail": "Recovered from error"}), - httpx.Response(400, json={"detail": "Should not get here"}), - ] - client = make_client_w_responses(responses) - - response = client.get("http://error") - assert response.status_code == 200 - assert len(responses) == 1 - assert mock_sleep.call_count == 2 - - @mock.patch("time.sleep", return_value=None) - def test_retry_handling_overload(self, mock_sleep): - responses: list[httpx.Response] = [ - httpx.Response(429, text="I am really busy atm, please back-off", headers={"Retry-After": "37"}), - httpx.Response(200, json={"detail": "Recovered from error"}), - httpx.Response(400, json={"detail": "Should not get here"}), - ] - client = make_client_w_responses(responses) - - response = client.get("http://error") - assert response.status_code == 200 - assert len(responses) == 1 - assert mock_sleep.call_count == 1 - assert mock_sleep.call_args[0][0] == 37 - - @mock.patch("time.sleep", return_value=None) - def test_retry_handling_non_retry_error(self, mock_sleep): - responses: list[httpx.Response] = [ - httpx.Response(422, json={"detail": "Somehow this is a bad request"}), - httpx.Response(400, json={"detail": "Should not get here"}), - ] - client = make_client_w_responses(responses) - - with pytest.raises(ServerResponseError) as err: - client.get("http://error") - assert len(responses) == 1 - assert mock_sleep.call_count == 0 - assert err.value.args == ("Somehow this is a bad request",) - - @mock.patch("time.sleep", return_value=None) - def test_retry_handling_ok(self, mock_sleep): - responses: list[httpx.Response] = [ - httpx.Response(200, json={"detail": "Recovered from error"}), - httpx.Response(400, json={"detail": "Should not get here"}), - ] - client = make_client_w_responses(responses) - - response = client.get("http://error") - assert response.status_code == 200 - assert len(responses) == 1 - assert mock_sleep.call_count == 0 + def test_retry_handling_unrecoverable_error(self): + with time_machine.travel("2023-01-01T00:00:00Z", tick=False): + responses: list[httpx.Response] = [ + *[httpx.Response(500, text="Internal Server Error")] * 6, + httpx.Response(200, json={"detail": "Recovered from error - but will fail before"}), + httpx.Response(400, json={"detail": "Should not get here"}), + ] + client = make_client_w_responses(responses) + + with pytest.raises(httpx.HTTPStatusError) as err: + client.get("http://error") + assert not isinstance(err.value, ServerResponseError) + assert len(responses) == 3 + + def test_retry_handling_recovered(self): + with time_machine.travel("2023-01-01T00:00:00Z", tick=False): + responses: list[httpx.Response] = [ + *[httpx.Response(500, text="Internal Server Error")] * 2, + httpx.Response(200, json={"detail": "Recovered from error"}), + httpx.Response(400, json={"detail": "Should not get here"}), + ] + client = make_client_w_responses(responses) + + response = client.get("http://error") + assert response.status_code == 200 + assert len(responses) == 1 + + def test_retry_handling_non_retry_error(self): + with time_machine.travel("2023-01-01T00:00:00Z", tick=False): + responses: list[httpx.Response] = [ + httpx.Response(422, json={"detail": "Somehow this is a bad request"}), + httpx.Response(400, json={"detail": "Should not get here"}), + ] + client = make_client_w_responses(responses) + + with pytest.raises(ServerResponseError) as err: + client.get("http://error") + assert len(responses) == 1 + assert err.value.args == ("Somehow this is a bad request",) + + def test_retry_handling_ok(self): + with time_machine.travel("2023-01-01T00:00:00Z", tick=False): + responses: list[httpx.Response] = [ + httpx.Response(200, json={"detail": "Recovered from error"}), + httpx.Response(400, json={"detail": "Should not get here"}), + ] + client = make_client_w_responses(responses) + + response = client.get("http://error") + assert response.status_code == 200 + assert len(responses) == 1 def test_token_renewal(self): responses: list[httpx.Response] = [ @@ -269,40 +268,40 @@ class TestTaskInstanceOperations: response parsing. """ - @mock.patch("time.sleep", return_value=None) # To have retries not slowing down tests - def test_task_instance_start(self, mock_sleep, make_ti_context): - # Simulate a successful response from the server that starts a task - ti_id = uuid6.uuid7() - start_date = "2024-10-31T12:00:00Z" - ti_context = make_ti_context( - start_date=start_date, - logical_date="2024-10-31T12:00:00Z", - run_type="manual", - ) - - # ...including a validation that retry really works - call_count = 0 + def test_task_instance_start(self, make_ti_context): + with time_machine.travel("2023-01-01T00:00:00Z", tick=False): + # Simulate a successful response from the server that starts a task + ti_id = uuid6.uuid7() + start_date = "2024-10-31T12:00:00Z" + ti_context = make_ti_context( + start_date=start_date, + logical_date="2024-10-31T12:00:00Z", + run_type="manual", + ) - def handle_request(request: httpx.Request) -> httpx.Response: - nonlocal call_count - call_count += 1 - if call_count < 3: - return httpx.Response(status_code=500, json={"detail": "Internal Server Error"}) - if request.url.path == f"/task-instances/{ti_id}/run": - actual_body = json.loads(request.read()) - assert actual_body["pid"] == 100 - assert actual_body["start_date"] == start_date - assert actual_body["state"] == "running" - return httpx.Response( - status_code=200, - json=ti_context.model_dump(mode="json"), - ) - return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + # ...including a validation that retry really works + call_count = 0 + + def handle_request(request: httpx.Request) -> httpx.Response: + nonlocal call_count + call_count += 1 + if call_count < 3: + return httpx.Response(status_code=500, json={"detail": "Internal Server Error"}) + if request.url.path == f"/task-instances/{ti_id}/run": + actual_body = json.loads(request.read()) + assert actual_body["pid"] == 100 + assert actual_body["start_date"] == start_date + assert actual_body["state"] == "running" + return httpx.Response( + status_code=200, + json=ti_context.model_dump(mode="json"), + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) - client = make_client(transport=httpx.MockTransport(handle_request)) - resp = client.task_instances.start(ti_id, 100, start_date) - assert resp == ti_context - assert call_count == 3 + client = make_client(transport=httpx.MockTransport(handle_request)) + resp = client.task_instances.start(ti_id, 100, start_date) + assert resp == ti_context + assert call_count == 3 @pytest.mark.parametrize( "state", [state for state in TerminalTIState if state != TerminalTIState.SUCCESS] @@ -545,31 +544,31 @@ class TestVariableOperations: response parsing. """ - @mock.patch("time.sleep", return_value=None) # To have retries not slowing down tests - def test_variable_get_success(self, mock_sleep): - # Simulate a successful response from the server with a variable - # ...including a validation that retry really works - call_count = 0 - - def handle_request(request: httpx.Request) -> httpx.Response: - nonlocal call_count - call_count += 1 - if call_count < 2: - return httpx.Response(status_code=500, json={"detail": "Internal Server Error"}) - if request.url.path == "/variables/test_key": - return httpx.Response( - status_code=200, - json={"key": "test_key", "value": "test_value"}, - ) - return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + def test_variable_get_success(self): + with time_machine.travel("2023-01-01T00:00:00Z", tick=False): + # Simulate a successful response from the server with a variable + # ...including a validation that retry really works + call_count = 0 + + def handle_request(request: httpx.Request) -> httpx.Response: + nonlocal call_count + call_count += 1 + if call_count < 2: + return httpx.Response(status_code=500, json={"detail": "Internal Server Error"}) + if request.url.path == "/variables/test_key": + return httpx.Response( + status_code=200, + json={"key": "test_key", "value": "test_value"}, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) - client = make_client(transport=httpx.MockTransport(handle_request)) - result = client.variables.get(key="test_key") + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.variables.get(key="test_key") - assert isinstance(result, VariableResponse) - assert result.key == "test_key" - assert result.value == "test_value" - assert call_count == 2 + assert isinstance(result, VariableResponse) + assert result.key == "test_key" + assert result.value == "test_value" + assert call_count == 2 def test_variable_not_found(self): # Simulate a 404 response from the server @@ -594,26 +593,26 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert resp.error == ErrorType.VARIABLE_NOT_FOUND assert resp.detail == {"key": "non_existent_var"} - @mock.patch("time.sleep", return_value=None) - def test_variable_get_500_error(self, mock_sleep): - # Simulate a response from the server returning a 500 error - def handle_request(request: httpx.Request) -> httpx.Response: - if request.url.path == "/variables/test_key": - return httpx.Response( - status_code=500, - headers=[("content-Type", "application/json")], - json={ - "reason": "internal_server_error", - "message": "Internal Server Error", - }, - ) - return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + def test_variable_get_500_error(self): + with time_machine.travel("2023-01-01T00:00:00Z", tick=False): + # Simulate a response from the server returning a 500 error + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/variables/test_key": + return httpx.Response( + status_code=500, + headers=[("content-Type", "application/json")], + json={ + "reason": "internal_server_error", + "message": "Internal Server Error", + }, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) - client = make_client(transport=httpx.MockTransport(handle_request)) - with pytest.raises(ServerResponseError): - client.variables.get( - key="test_key", - ) + client = make_client(transport=httpx.MockTransport(handle_request)) + with pytest.raises(ServerResponseError): + client.variables.get( + key="test_key", + ) def test_variable_set_success(self): # Simulate a successful response from the server when putting a variable @@ -663,35 +662,35 @@ class TestXCOMOperations: pytest.param({"key": "test_key", "value": {"key2": "value2"}}, id="nested-dict-value"), ], ) - @mock.patch("time.sleep", return_value=None) # To have retries not slowing down tests - def test_xcom_get_success(self, mock_sleep, value): - # Simulate a successful response from the server when getting an xcom - # ...including a validation that retry really works - call_count = 0 - - def handle_request(request: httpx.Request) -> httpx.Response: - nonlocal call_count - call_count += 1 - if call_count < 3: - return httpx.Response(status_code=500, json={"detail": "Internal Server Error"}) - if request.url.path == "/xcoms/dag_id/run_id/task_id/key": - return httpx.Response( - status_code=201, - json={"key": "test_key", "value": value}, - ) - return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + def test_xcom_get_success(self, value): + with time_machine.travel("2023-01-01T00:00:00Z", tick=False): + # Simulate a successful response from the server when getting an xcom + # ...including a validation that retry really works + call_count = 0 + + def handle_request(request: httpx.Request) -> httpx.Response: + nonlocal call_count + call_count += 1 + if call_count < 3: + return httpx.Response(status_code=500, json={"detail": "Internal Server Error"}) + if request.url.path == "/xcoms/dag_id/run_id/task_id/key": + return httpx.Response( + status_code=201, + json={"key": "test_key", "value": value}, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) - client = make_client(transport=httpx.MockTransport(handle_request)) - result = client.xcoms.get( - dag_id="dag_id", - run_id="run_id", - task_id="task_id", - key="key", - ) - assert isinstance(result, XComResponse) - assert result.key == "test_key" - assert result.value == value - assert call_count == 3 + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.xcoms.get( + dag_id="dag_id", + run_id="run_id", + task_id="task_id", + key="key", + ) + assert isinstance(result, XComResponse) + assert result.key == "test_key" + assert result.value == value + assert call_count == 3 def test_xcom_get_success_with_map_index(self): # Simulate a successful response from the server when getting an xcom with map_index passed @@ -742,29 +741,29 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert result.key == "test_key" assert result.value == "test_value" - @mock.patch("time.sleep", return_value=None) - def test_xcom_get_500_error(self, mock_sleep): - # Simulate a successful response from the server returning a 500 error - def handle_request(request: httpx.Request) -> httpx.Response: - if request.url.path == "/xcoms/dag_id/run_id/task_id/key": - return httpx.Response( - status_code=500, - headers=[("content-Type", "application/json")], - json={ - "reason": "invalid_format", - "message": "XCom value is not a valid JSON", - }, - ) - return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + def test_xcom_get_500_error(self): + with time_machine.travel("2023-01-01T00:00:00Z", tick=False): + # Simulate a successful response from the server returning a 500 error + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/xcoms/dag_id/run_id/task_id/key": + return httpx.Response( + status_code=500, + headers=[("content-Type", "application/json")], + json={ + "reason": "invalid_format", + "message": "XCom value is not a valid JSON", + }, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) - client = make_client(transport=httpx.MockTransport(handle_request)) - with pytest.raises(ServerResponseError): - client.xcoms.get( - dag_id="dag_id", - run_id="run_id", - task_id="task_id", - key="key", - ) + client = make_client(transport=httpx.MockTransport(handle_request)) + with pytest.raises(ServerResponseError): + client.xcoms.get( + dag_id="dag_id", + run_id="run_id", + task_id="task_id", + key="key", + ) @pytest.mark.parametrize( "values",