From a751be7ce6b84db742951dabf014687e551d8370 Mon Sep 17 00:00:00 2001 From: Jeongwoo Do <48639483+wjddn279@users.noreply.github.com> Date: Tue, 28 Oct 2025 05:34:46 +0900 Subject: [PATCH] [v3-1-test] Fix memory leak in Client via SSL context creation (#57334) related: https://github.com/apache/airflow/issues/56641 When I performed the same memray inspection on the latest version that includes both fixes, the first issue was clearly resolved, but the second issue still persists. [memray1.html](https://github.com/user-attachments/files/23160226/memray1.html) When a Client object is created, `ctx = ssl.create_default_context(cafile=ca_file)` continues to be executed repeatedly, which accumulates in memory and causes a memory leak. (It appears to be allocated as a C language object and remains in memory regardless of Python object GC) This PR uses caching to prevent the SSL context object from being recreated. Here are the results after running for two hours with this change. The memory usage, which was previously growing to tens of MBs, now stabilizes at approximately 700KB. [memray2.html](https://github.com/user-attachments/files/23160405/memray2.html) (cherry picked from commit 7369e4645c17e29c2c15fe10f9ede3d5afe02a2a) Co-authored-by: Jeongwoo Do <48639483+wjddn279@users.noreply.github.com> --- task-sdk/src/airflow/sdk/api/client.py | 14 +++++++---- task-sdk/tests/task_sdk/api/test_client.py | 27 ++++++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 6d71a7f106b9b..85c6b5e9988bf 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -814,6 +814,15 @@ def noop_handler(request: httpx.Request) -> httpx.Response: class Client(httpx.Client): + @classmethod + @lru_cache() + def _get_ssl_context_cached(cls, ca_file: str, ca_path: str | None = None) -> ssl.SSLContext: + """Cache SSL context to prevent memory growth from repeated context creation.""" + ctx = ssl.create_default_context(cafile=ca_file) + if ca_path: + ctx.load_verify_locations(ca_path) + return ctx + def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any): if (not base_url) ^ dry_run: raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}") @@ -826,10 +835,7 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * kwargs.setdefault("base_url", "dry-run://server") else: kwargs["base_url"] = base_url - ctx = ssl.create_default_context(cafile=certifi.where()) - if API_SSL_CERT_PATH: - ctx.load_verify_locations(API_SSL_CERT_PATH) - kwargs["verify"] = ctx + kwargs["verify"] = self._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH) # Set timeout if not explicitly provided kwargs.setdefault("timeout", API_TIMEOUT) diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 32a709663acb2..d9c87d915c9ba 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING from unittest import mock +import certifi import httpx import pytest import uuid6 @@ -1366,3 +1367,29 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert result.params_input == {} assert result.responded_by_user == HITLUser(id="admin", name="admin") assert result.responded_at == timezone.datetime(2025, 7, 3, 0, 0, 0) + + +class TestSSLContextCaching: + def setup_method(self): + Client._get_ssl_context_cached.cache_clear() + + def teardown_method(self): + Client._get_ssl_context_cached.cache_clear() + + def test_cache_hit_on_same_parameters(self): + ca_file = certifi.where() + ctx1 = Client._get_ssl_context_cached(ca_file, None) + ctx2 = Client._get_ssl_context_cached(ca_file, None) + assert ctx1 is ctx2 + + def test_cache_miss_on_different_parameters(self): + ca_file = certifi.where() + + ctx1 = Client._get_ssl_context_cached(ca_file, None) + ctx2 = Client._get_ssl_context_cached(ca_file, ca_file) + + info = Client._get_ssl_context_cached.cache_info() + + assert ctx1 is not ctx2 + assert info.misses == 2 + assert info.currsize == 2