Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,15 @@ def _should_retry_api_request(exception: BaseException) -> bool:


class Client(httpx.Client):
@classmethod
@lru_cache()
def _get_ssl_context_cached(cls, ca_file: str, ca_path: str | None = None) -> ssl.SSLContext:
"""Cache SSL context to prevent memory growth from repeated context creation."""
ctx = ssl.create_default_context(cafile=ca_file)
if ca_path:
ctx.load_verify_locations(ca_path)
return ctx

def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any):
if (not base_url) ^ dry_run:
raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}")
Expand All @@ -863,10 +872,7 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, *
kwargs.setdefault("base_url", "dry-run://server")
else:
kwargs["base_url"] = base_url
ctx = ssl.create_default_context(cafile=certifi.where())
if API_SSL_CERT_PATH:
ctx.load_verify_locations(API_SSL_CERT_PATH)
kwargs["verify"] = ctx
kwargs["verify"] = self._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH)

# Set timeout if not explicitly provided
kwargs.setdefault("timeout", API_TIMEOUT)
Expand Down
27 changes: 27 additions & 0 deletions task-sdk/tests/task_sdk/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import TYPE_CHECKING
from unittest import mock

import certifi
import httpx
import pytest
import time_machine
Expand Down Expand Up @@ -1366,3 +1367,29 @@ def handle_request(request: httpx.Request) -> httpx.Response:
assert result.params_input == {}
assert result.responded_by_user == HITLUser(id="admin", name="admin")
assert result.responded_at == timezone.datetime(2025, 7, 3, 0, 0, 0)


class TestSSLContextCaching:
def setup_method(self):
Client._get_ssl_context_cached.cache_clear()

def teardown_method(self):
Client._get_ssl_context_cached.cache_clear()

def test_cache_hit_on_same_parameters(self):
ca_file = certifi.where()
ctx1 = Client._get_ssl_context_cached(ca_file, None)
ctx2 = Client._get_ssl_context_cached(ca_file, None)
assert ctx1 is ctx2

def test_cache_miss_on_different_parameters(self):
ca_file = certifi.where()

ctx1 = Client._get_ssl_context_cached(ca_file, None)
ctx2 = Client._get_ssl_context_cached(ca_file, ca_file)

info = Client._get_ssl_context_cached.cache_info()

assert ctx1 is not ctx2
assert info.misses == 2
assert info.currsize == 2