diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index f2fbf0d30af65..d760c2f65df51 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -851,6 +851,15 @@ def _should_retry_api_request(exception: BaseException) -> bool: class Client(httpx.Client): + @classmethod + @lru_cache() + def _get_ssl_context_cached(cls, ca_file: str, ca_path: str | None = None) -> ssl.SSLContext: + """Cache SSL context to prevent memory growth from repeated context creation.""" + ctx = ssl.create_default_context(cafile=ca_file) + if ca_path: + ctx.load_verify_locations(ca_path) + return ctx + def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any): if (not base_url) ^ dry_run: raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}") @@ -863,10 +872,7 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * kwargs.setdefault("base_url", "dry-run://server") else: kwargs["base_url"] = base_url - ctx = ssl.create_default_context(cafile=certifi.where()) - if API_SSL_CERT_PATH: - ctx.load_verify_locations(API_SSL_CERT_PATH) - kwargs["verify"] = ctx + kwargs["verify"] = self._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH) # Set timeout if not explicitly provided kwargs.setdefault("timeout", API_TIMEOUT) diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 5c960eb2aa07b..abeda2acd31ea 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING from unittest import mock +import certifi import httpx import pytest import time_machine @@ -1366,3 +1367,29 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert result.params_input == {} assert result.responded_by_user == HITLUser(id="admin", name="admin") assert result.responded_at == timezone.datetime(2025, 7, 3, 0, 0, 0) + + +class TestSSLContextCaching: + def setup_method(self): + Client._get_ssl_context_cached.cache_clear() + + def teardown_method(self): + Client._get_ssl_context_cached.cache_clear() + + def test_cache_hit_on_same_parameters(self): + ca_file = certifi.where() + ctx1 = Client._get_ssl_context_cached(ca_file, None) + ctx2 = Client._get_ssl_context_cached(ca_file, None) + assert ctx1 is ctx2 + + def test_cache_miss_on_different_parameters(self): + ca_file = certifi.where() + + ctx1 = Client._get_ssl_context_cached(ca_file, None) + ctx2 = Client._get_ssl_context_cached(ca_file, ca_file) + + info = Client._get_ssl_context_cached.cache_info() + + assert ctx1 is not ctx2 + assert info.misses == 2 + assert info.currsize == 2