Skip to content

Commit aa70ca7

Browse files
refactor(tidy3d): FXC-4436-move-typing-only-imports-behind-if-type-checking
1 parent 3e3526a commit aa70ca7

File tree

6 files changed

+98
-18
lines changed

6 files changed

+98
-18
lines changed

tests/config/test_legacy_env.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def test_env_pending_overrides_apply_on_activation(mock_config_dir, config_manag
4343
assert current_manager.profile == "dev"
4444
dev_web = current_manager.get_section("web")
4545
assert dev_web.enable_caching is False
46-
assert dev_web.ssl_version == ssl.TLSVersion.TLSv1_2
46+
assert dev_web.ssl_version == "TLSv1_2"
4747
assert Env.current.enable_caching is False
48-
assert Env.current.ssl_version == ssl.TLSVersion.TLSv1_2
48+
assert Env.current.ssl_version == "TLSv1_2"
4949
finally:
5050
reload_config(profile="default")
5151

tests/config/test_web_config.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from __future__ import annotations
22

3+
import ssl
4+
5+
import pytest
6+
37
from tidy3d.config.sections import WebConfig
48

59

@@ -21,3 +25,24 @@ def test_build_api_url_returns_base_for_empty_path():
2125
def test_build_api_url_without_base_returns_path():
2226
web = WebConfig.model_construct(api_endpoint="")
2327
assert web.build_api_url("/v1/tasks") == "v1/tasks"
28+
29+
30+
@pytest.mark.parametrize(
31+
("value", "expected"),
32+
[
33+
("TLSv1_2", "TLSv1_2"),
34+
("tls1.2", "TLSv1_2"),
35+
("TLS 1.3", "TLSv1_3"),
36+
("1.0", "TLSv1"),
37+
(ssl.TLSVersion.TLSv1_1, "TLSv1_1"),
38+
],
39+
)
40+
def test_web_config_normalizes_ssl_version_aliases(value, expected):
41+
web = WebConfig(ssl_version=value)
42+
assert web.ssl_version == expected
43+
44+
45+
@pytest.mark.parametrize("value", ["", "TLSv2", "SSLv3", "udp1.0"])
46+
def test_web_config_rejects_invalid_ssl_version(value):
47+
with pytest.raises(ValueError):
48+
WebConfig(ssl_version=value)

tests/test_web/test_env.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ def test_tidy3d_env():
1414

1515

1616
def test_set_ssl_version():
17-
Env.set_ssl_version(ssl.TLSVersion.TLSv1_3)
18-
assert Env.current.ssl_version == ssl.TLSVersion.TLSv1_3
17+
Env.set_ssl_version("TLSv1_3")
18+
assert Env.current.ssl_version == "TLSv1_3"
1919

2020
Env.set_ssl_version(ssl.TLSVersion.TLSv1_2)
21-
assert Env.current.ssl_version == ssl.TLSVersion.TLSv1_2
21+
assert Env.current.ssl_version == "TLSv1_2"
2222

2323
Env.set_ssl_version(ssl.TLSVersion.TLSv1_1)
24-
assert Env.current.ssl_version == ssl.TLSVersion.TLSv1_1
24+
assert Env.current.ssl_version == "TLSv1_1"
2525

2626
Env.set_ssl_version(None)
2727
assert Env.current.ssl_version is None

tidy3d/config/legacy.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from __future__ import annotations
88

99
import os
10-
import ssl
1110
import warnings
1211
from pathlib import Path
1312
from typing import Any, Optional
@@ -160,7 +159,7 @@ def __init__(
160159
s3_region: Optional[str] = None,
161160
ssl_verify: Optional[bool] = None,
162161
enable_caching: Optional[bool] = None,
163-
ssl_version: Optional[ssl.TLSVersion] = None,
162+
ssl_version: Optional[str] = None,
164163
env_vars: Optional[dict[str, str]] = None,
165164
environment: Optional[LegacyEnvironment] = None,
166165
) -> None:
@@ -239,11 +238,11 @@ def enable_caching(self, value: Optional[bool]) -> None:
239238
self._set_pending("enable_caching", value)
240239

241240
@property
242-
def ssl_version(self) -> Optional[ssl.TLSVersion]:
241+
def ssl_version(self) -> Optional[str]:
243242
return self._value("ssl_version")
244243

245244
@ssl_version.setter
246-
def ssl_version(self, value: Optional[ssl.TLSVersion]) -> None:
245+
def ssl_version(self, value: Optional[str]) -> None:
247246
self._set_pending("ssl_version", value)
248247

249248
@property
@@ -363,7 +362,7 @@ def enable_caching(self, enable_caching: Optional[bool] = True) -> None:
363362
config.enable_caching = enable_caching
364363
self._sync_to_manager()
365364

366-
def set_ssl_version(self, ssl_version: Optional[ssl.TLSVersion]) -> None:
365+
def set_ssl_version(self, ssl_version: Optional[str]) -> None:
367366
config = self.current
368367
config.ssl_version = ssl_version
369368
self._sync_to_manager()

tidy3d/config/sections.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import os
6-
import ssl
6+
import re
77
from os import PathLike
88
from pathlib import Path
99
from typing import Any, Literal, Optional
@@ -28,6 +28,14 @@
2828
from .registry import get_manager as _get_attached_manager
2929
from .registry import register_handler, register_section
3030

31+
# canonical concrete TLS versions (no MIN/MAX sentinels here)
32+
TLS_VERSION_CHOICES = {"TLSv1", "TLSv1_1", "TLSv1_2", "TLSv1_3"}
33+
34+
_TLS_RE = re.compile(
35+
r"^\s*(?:TLS)?v?(\d+)(?:[._-](\d+))?\s*$",
36+
re.IGNORECASE,
37+
)
38+
3139

3240
class ConfigSection(BaseModel):
3341
"""Base class for configuration sections."""
@@ -328,10 +336,13 @@ class WebConfig(ConfigSection):
328336
le=300,
329337
)
330338

331-
ssl_version: Optional[ssl.TLSVersion] = Field(
339+
ssl_version: Optional[str] = Field(
332340
None,
333341
title="SSL/TLS version",
334-
description="Optional SSL/TLS version to enforce for requests.",
342+
description=(
343+
"Optional TLS version override to enforce for requests. Accepts values such as "
344+
"'TLSv1_2'."
345+
),
335346
)
336347

337348
env_vars: dict[str, str] = Field(
@@ -349,14 +360,55 @@ def to_dict(self, *, mask_secrets: bool = True) -> dict[str, Any]:
349360
secret = data.get("apikey")
350361
if isinstance(secret, SecretStr):
351362
data["apikey"] = secret.get_secret_value()
352-
ssl_version = data.get("ssl_version")
353-
if isinstance(ssl_version, ssl.TLSVersion):
354-
data["ssl_version"] = ssl_version.value
355363
for field in ("api_endpoint", "website_endpoint"):
356364
if field in data and data[field] is not None:
357365
data[field] = str(data[field])
358366
return data
359367

368+
@field_validator("ssl_version", mode="before")
369+
@classmethod
370+
def _normalize_ssl_version_name(cls, value: Any) -> Optional[str]:
371+
"""Normalize TLS version-like inputs to canonical ``TLSvX[_Y]`` names.
372+
373+
Examples accepted:
374+
"TLSv1" -> "TLSv1"
375+
"TLS 1.2" -> "TLSv1_2"
376+
"tls1.3" -> "TLSv1_3"
377+
"1.2" -> "TLSv1_2"
378+
"1.0" -> "TLSv1"
379+
ssl.TLSVersion.TLSv1_2.name -> "TLSv1_2"
380+
"""
381+
if value is None:
382+
return None
383+
384+
candidate = getattr(value, "name", value) # enum.name if present
385+
candidate = str(candidate).strip().replace(" ", "")
386+
387+
if not candidate:
388+
raise ValueError("TLS version cannot be empty")
389+
390+
match = _TLS_RE.match(candidate)
391+
if not match:
392+
allowed = ", ".join(sorted(TLS_VERSION_CHOICES))
393+
raise ValueError(
394+
f"Unrecognized TLS version {candidate!r}. Expected something like "
395+
f"'TLSv1_2', '1.2', 'TLS1.2'. Allowed canonical values: {allowed}"
396+
)
397+
398+
major, minor = match.group(1), match.group(2)
399+
400+
# X.0 → X, because ssl only has TLSv1, not TLSv1_0
401+
if minor in (None, "0"):
402+
canonical = f"TLSv{major}"
403+
else:
404+
canonical = f"TLSv{major}_{minor}"
405+
406+
if canonical not in TLS_VERSION_CHOICES:
407+
allowed = ", ".join(sorted(TLS_VERSION_CHOICES))
408+
raise ValueError(f"TLS version must be one of: {allowed}")
409+
410+
return canonical
411+
360412
@field_validator("api_endpoint", "website_endpoint", mode="before")
361413
@classmethod
362414
def _validate_http_url(cls, value: Any) -> str:

tidy3d/web/core/http_util.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import json
66
import os
7+
import ssl
78
from enum import Enum
89
from functools import wraps
910
from typing import Any, Callable, Optional, TypeAlias
@@ -200,7 +201,10 @@ def wrapper(*args: Any, **kwargs: Any) -> JSONType:
200201

201202
class TLSAdapter(HTTPAdapter):
202203
def init_poolmanager(self, *args: Any, **kwargs: Any) -> None:
203-
context = create_urllib3_context(ssl_version=config.web.ssl_version)
204+
ssl_version = (
205+
ssl.TLSVersion[config.web.ssl_version] if config.web.ssl_version is not None else None
206+
)
207+
context = create_urllib3_context(ssl_version=ssl_version)
204208
kwargs["ssl_context"] = context
205209
return super().init_poolmanager(*args, **kwargs)
206210

0 commit comments

Comments
 (0)