diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 6d71a7f106b9b..85c6b5e9988bf 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -814,6 +814,15 @@ def noop_handler(request: httpx.Request) -> httpx.Response: class Client(httpx.Client): + @classmethod + @lru_cache() + def _get_ssl_context_cached(cls, ca_file: str, ca_path: str | None = None) -> ssl.SSLContext: + """Cache SSL context to prevent memory growth from repeated context creation.""" + ctx = ssl.create_default_context(cafile=ca_file) + if ca_path: + ctx.load_verify_locations(ca_path) + return ctx + def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any): if (not base_url) ^ dry_run: raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}") @@ -826,10 +835,7 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * kwargs.setdefault("base_url", "dry-run://server") else: kwargs["base_url"] = base_url - ctx = ssl.create_default_context(cafile=certifi.where()) - if API_SSL_CERT_PATH: - ctx.load_verify_locations(API_SSL_CERT_PATH) - kwargs["verify"] = ctx + kwargs["verify"] = self._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH) # Set timeout if not explicitly provided kwargs.setdefault("timeout", API_TIMEOUT) diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 32a709663acb2..d9c87d915c9ba 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING from unittest import mock +import certifi import httpx import pytest import uuid6 @@ -1366,3 +1367,29 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert result.params_input == {} assert result.responded_by_user == HITLUser(id="admin", name="admin") assert result.responded_at == timezone.datetime(2025, 7, 3, 0, 0, 0) + + +class TestSSLContextCaching: + def setup_method(self): + Client._get_ssl_context_cached.cache_clear() + + def teardown_method(self): + Client._get_ssl_context_cached.cache_clear() + + def test_cache_hit_on_same_parameters(self): + ca_file = certifi.where() + ctx1 = Client._get_ssl_context_cached(ca_file, None) + ctx2 = Client._get_ssl_context_cached(ca_file, None) + assert ctx1 is ctx2 + + def test_cache_miss_on_different_parameters(self): + ca_file = certifi.where() + + ctx1 = Client._get_ssl_context_cached(ca_file, None) + ctx2 = Client._get_ssl_context_cached(ca_file, ca_file) + + info = Client._get_ssl_context_cached.cache_info() + + assert ctx1 is not ctx2 + assert info.misses == 2 + assert info.currsize == 2