diff --git a/tests/config/test_legacy_env.py b/tests/config/test_legacy_env.py index 7375a4b83f..cb9cba06a9 100644 --- a/tests/config/test_legacy_env.py +++ b/tests/config/test_legacy_env.py @@ -43,9 +43,9 @@ def test_env_pending_overrides_apply_on_activation(mock_config_dir, config_manag assert current_manager.profile == "dev" dev_web = current_manager.get_section("web") assert dev_web.enable_caching is False - assert dev_web.ssl_version == ssl.TLSVersion.TLSv1_2 + assert dev_web.ssl_version == "TLSv1_2" assert Env.current.enable_caching is False - assert Env.current.ssl_version == ssl.TLSVersion.TLSv1_2 + assert Env.current.ssl_version == "TLSv1_2" finally: reload_config(profile="default") diff --git a/tests/config/test_web_config.py b/tests/config/test_web_config.py index 2c0d15b4cc..c0cf3816eb 100644 --- a/tests/config/test_web_config.py +++ b/tests/config/test_web_config.py @@ -1,5 +1,9 @@ from __future__ import annotations +import ssl + +import pytest + from tidy3d.config.sections import WebConfig @@ -21,3 +25,21 @@ def test_build_api_url_returns_base_for_empty_path(): def test_build_api_url_without_base_returns_path(): web = WebConfig.model_construct(api_endpoint="") assert web.build_api_url("/v1/tasks") == "v1/tasks" + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (ssl.TLSVersion.TLSv1, "TLSv1"), + (ssl.TLSVersion.TLSv1_1, "TLSv1_1"), + ], +) +def test_web_config_normalizes_ssl_version_aliases(value, expected): + web = WebConfig(ssl_version=value) + assert web.ssl_version == expected + + +@pytest.mark.parametrize("value", ["", "TLSv2", "SSLv3", "udp1.0"]) +def test_web_config_rejects_invalid_ssl_version(value): + with pytest.raises(ValueError): + WebConfig(ssl_version=value) diff --git a/tests/test_web/test_env.py b/tests/test_web/test_env.py index 62e8566b29..d9746e4f6d 100644 --- a/tests/test_web/test_env.py +++ b/tests/test_web/test_env.py @@ -14,14 +14,14 @@ def test_tidy3d_env(): def test_set_ssl_version(): - Env.set_ssl_version(ssl.TLSVersion.TLSv1_3) - assert Env.current.ssl_version == ssl.TLSVersion.TLSv1_3 + Env.set_ssl_version("TLSv1_3") + assert Env.current.ssl_version == "TLSv1_3" Env.set_ssl_version(ssl.TLSVersion.TLSv1_2) - assert Env.current.ssl_version == ssl.TLSVersion.TLSv1_2 + assert Env.current.ssl_version == "TLSv1_2" Env.set_ssl_version(ssl.TLSVersion.TLSv1_1) - assert Env.current.ssl_version == ssl.TLSVersion.TLSv1_1 + assert Env.current.ssl_version == "TLSv1_1" Env.set_ssl_version(None) assert Env.current.ssl_version is None diff --git a/tidy3d/config/legacy.py b/tidy3d/config/legacy.py index 4fb73ddeb4..aff6fd688e 100644 --- a/tidy3d/config/legacy.py +++ b/tidy3d/config/legacy.py @@ -7,7 +7,6 @@ from __future__ import annotations import os -import ssl import warnings from pathlib import Path from typing import Any, Optional @@ -160,7 +159,7 @@ def __init__( s3_region: Optional[str] = None, ssl_verify: Optional[bool] = None, enable_caching: Optional[bool] = None, - ssl_version: Optional[ssl.TLSVersion] = None, + ssl_version: Optional[str] = None, env_vars: Optional[dict[str, str]] = None, environment: Optional[LegacyEnvironment] = None, ) -> None: @@ -239,11 +238,11 @@ def enable_caching(self, value: Optional[bool]) -> None: self._set_pending("enable_caching", value) @property - def ssl_version(self) -> Optional[ssl.TLSVersion]: + def ssl_version(self) -> Optional[str]: return self._value("ssl_version") @ssl_version.setter - def ssl_version(self, value: Optional[ssl.TLSVersion]) -> None: + def ssl_version(self, value: Optional[str]) -> None: self._set_pending("ssl_version", value) @property @@ -363,7 +362,7 @@ def enable_caching(self, enable_caching: Optional[bool] = True) -> None: config.enable_caching = enable_caching self._sync_to_manager() - def set_ssl_version(self, ssl_version: Optional[ssl.TLSVersion]) -> None: + def set_ssl_version(self, ssl_version: Optional[str]) -> None: config = self.current config.ssl_version = ssl_version self._sync_to_manager() diff --git a/tidy3d/config/sections.py b/tidy3d/config/sections.py index 0e9bd04a4d..9894207d27 100644 --- a/tidy3d/config/sections.py +++ b/tidy3d/config/sections.py @@ -3,7 +3,6 @@ from __future__ import annotations import os -import ssl from os import PathLike from pathlib import Path from typing import Any, Literal, Optional @@ -28,6 +27,8 @@ from .registry import get_manager as _get_attached_manager from .registry import register_handler, register_section +TLS_VERSION_CHOICES = {"TLSv1", "TLSv1_1", "TLSv1_2", "TLSv1_3"} + class ConfigSection(BaseModel): """Base class for configuration sections.""" @@ -328,10 +329,13 @@ class WebConfig(ConfigSection): le=300, ) - ssl_version: Optional[ssl.TLSVersion] = Field( + ssl_version: Optional[str] = Field( None, title="SSL/TLS version", - description="Optional SSL/TLS version to enforce for requests.", + description=( + "Optional TLS version override to enforce for requests. Accepts values such as " + "'TLSv1_2'." + ), ) env_vars: dict[str, str] = Field( @@ -349,14 +353,34 @@ def to_dict(self, *, mask_secrets: bool = True) -> dict[str, Any]: secret = data.get("apikey") if isinstance(secret, SecretStr): data["apikey"] = secret.get_secret_value() - ssl_version = data.get("ssl_version") - if isinstance(ssl_version, ssl.TLSVersion): - data["ssl_version"] = ssl_version.value for field in ("api_endpoint", "website_endpoint"): if field in data and data[field] is not None: data[field] = str(data[field]) return data + @field_validator("ssl_version", mode="before") + @classmethod + def _convert_and_check_ssl_version_name(cls, value: Any) -> Optional[str]: + """Convert SSL enum to string and check if valid. + + Accepted examples: + "TLSv1" + "TLSv1_2" + ssl.TLSVersion.TLSv1_2.name -> "TLSv1_2" + """ + if value is None: + return None + + # Prefer enum.name if present, otherwise raw string + candidate = getattr(value, "name", value) + candidate = str(candidate).strip() + + if candidate not in TLS_VERSION_CHOICES: + allowed = ", ".join(sorted(TLS_VERSION_CHOICES)) + raise ValueError(f"Invalid TLS version {candidate!r}. Must be one of: {allowed}") + + return candidate + @field_validator("api_endpoint", "website_endpoint", mode="before") @classmethod def _validate_http_url(cls, value: Any) -> str: diff --git a/tidy3d/web/core/http_util.py b/tidy3d/web/core/http_util.py index 9e5c9f3f9f..35c73b3cbb 100644 --- a/tidy3d/web/core/http_util.py +++ b/tidy3d/web/core/http_util.py @@ -4,6 +4,7 @@ import json import os +import ssl from enum import Enum from functools import wraps from typing import Any, Callable, Optional, TypeAlias @@ -12,6 +13,7 @@ from requests.adapters import HTTPAdapter from urllib3.util.ssl_ import create_urllib3_context +from tidy3d import log from tidy3d.config import config from . import core_config @@ -200,7 +202,16 @@ def wrapper(*args: Any, **kwargs: Any) -> JSONType: class TLSAdapter(HTTPAdapter): def init_poolmanager(self, *args: Any, **kwargs: Any) -> None: - context = create_urllib3_context(ssl_version=config.web.ssl_version) + try: + ssl_version = ( + ssl.TLSVersion[config.web.ssl_version] + if config.web.ssl_version is not None + else None + ) + except KeyError: + log.warning(f"Invalid SSL/TLS version '{config.web.ssl_version}', using default") + ssl_version = None + context = create_urllib3_context(ssl_version=ssl_version) kwargs["ssl_context"] = context return super().init_poolmanager(*args, **kwargs)