From 059e89178e8395b25d3ce0d20285835aa6b04f01 Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Mon, 27 Oct 2025 17:41:17 +0900 Subject: [PATCH 1/2] add ssl context cache --- task-sdk/src/airflow/sdk/api/client.py | 14 +++++++--- task-sdk/tests/task_sdk/api/test_client.py | 30 ++++++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index f2fbf0d30af65..149ac71f00c34 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -851,6 +851,15 @@ def _should_retry_api_request(exception: BaseException) -> bool: class Client(httpx.Client): + @classmethod + @lru_cache() + def _get_ssl_context_cached(cls, ca_file: str, ca_path: str | None) -> ssl.SSLContext: + """Cache SSL context to prevent memory growth from repeated context creation.""" + ctx = ssl.create_default_context(cafile=ca_file) + if ca_path: + ctx.load_verify_locations(ca_path) + return ctx + def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any): if (not base_url) ^ dry_run: raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}") @@ -863,10 +872,7 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * kwargs.setdefault("base_url", "dry-run://server") else: kwargs["base_url"] = base_url - ctx = ssl.create_default_context(cafile=certifi.where()) - if API_SSL_CERT_PATH: - ctx.load_verify_locations(API_SSL_CERT_PATH) - kwargs["verify"] = ctx + kwargs["verify"] = self._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH) # Set timeout if not explicitly provided kwargs.setdefault("timeout", API_TIMEOUT) diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 5c960eb2aa07b..83a0b6169e31b 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -1366,3 +1366,33 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert result.params_input == {} assert result.responded_by_user == HITLUser(id="admin", name="admin") assert result.responded_at == timezone.datetime(2025, 7, 3, 0, 0, 0) + + +class TestSSLContextCaching: + def setup_method(self): + Client._get_ssl_context_cached.cache_clear() + + def teardown_method(self): + Client._get_ssl_context_cached.cache_clear() + + def test_cache_hit_on_same_parameters(self): + import certifi + + ca_file = certifi.where() + ctx1 = Client._get_ssl_context_cached(ca_file, "") + ctx2 = Client._get_ssl_context_cached(ca_file, "") + assert ctx1 is ctx2 + + def test_cache_miss_on_different_parameters(self): + import certifi + + ca_file = certifi.where() + + ctx1 = Client._get_ssl_context_cached(ca_file, "") + ctx2 = Client._get_ssl_context_cached(ca_file, ca_file) + + info = Client._get_ssl_context_cached.cache_info() + + assert ctx1 is not ctx2 + assert info.misses == 2 + assert info.currsize == 2 From 7ca1a48a6386aff806073e46578fbffd909212d8 Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Mon, 27 Oct 2025 21:49:13 +0900 Subject: [PATCH 2/2] fix logic --- task-sdk/src/airflow/sdk/api/client.py | 2 +- task-sdk/tests/task_sdk/api/test_client.py | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 149ac71f00c34..d760c2f65df51 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -853,7 +853,7 @@ def _should_retry_api_request(exception: BaseException) -> bool: class Client(httpx.Client): @classmethod @lru_cache() - def _get_ssl_context_cached(cls, ca_file: str, ca_path: str | None) -> ssl.SSLContext: + def _get_ssl_context_cached(cls, ca_file: str, ca_path: str | None = None) -> ssl.SSLContext: """Cache SSL context to prevent memory growth from repeated context creation.""" ctx = ssl.create_default_context(cafile=ca_file) if ca_path: diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 83a0b6169e31b..abeda2acd31ea 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING from unittest import mock +import certifi import httpx import pytest import time_machine @@ -1376,19 +1377,15 @@ def teardown_method(self): Client._get_ssl_context_cached.cache_clear() def test_cache_hit_on_same_parameters(self): - import certifi - ca_file = certifi.where() - ctx1 = Client._get_ssl_context_cached(ca_file, "") - ctx2 = Client._get_ssl_context_cached(ca_file, "") + ctx1 = Client._get_ssl_context_cached(ca_file, None) + ctx2 = Client._get_ssl_context_cached(ca_file, None) assert ctx1 is ctx2 def test_cache_miss_on_different_parameters(self): - import certifi - ca_file = certifi.where() - ctx1 = Client._get_ssl_context_cached(ca_file, "") + ctx1 = Client._get_ssl_context_cached(ca_file, None) ctx2 = Client._get_ssl_context_cached(ca_file, ca_file) info = Client._get_ssl_context_cached.cache_info()