From 02657fab4fa5018e5b67ec2fa87a15ceea434e87 Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Fri, 1 Aug 2025 22:32:56 +0000 Subject: [PATCH 01/14] feat: MDS connections use mTLS --- google/auth/compute_engine/_metadata.py | 60 +++++++-- google/auth/compute_engine/_mtls.py | 109 +++++++++++++++++ google/auth/environment_vars.py | 6 + tests/compute_engine/test__metadata.py | 156 ++++++++++++++++++++---- tests/compute_engine/test__mtls.py | 115 +++++++++++++++++ 5 files changed, 412 insertions(+), 34 deletions(-) create mode 100644 google/auth/compute_engine/_mtls.py create mode 100644 tests/compute_engine/test__mtls.py diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index ddbe8ac2f..9554bcb0a 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -30,6 +30,8 @@ from google.auth import metrics from google.auth import transport from google.auth._exponential_backoff import ExponentialBackoff +from google.auth.compute_engine import _mtls +from google.auth.transport import requests _LOGGER = logging.getLogger(__name__) @@ -42,13 +44,24 @@ _GCE_METADATA_HOST = os.getenv( environment_vars.GCE_METADATA_ROOT, "metadata.google.internal" ) -_METADATA_ROOT = "http://{}/computeMetadata/v1/".format(_GCE_METADATA_HOST) -# This is used to ping the metadata server, it avoids the cost of a DNS -# lookup. -_METADATA_IP_ROOT = "http://{}".format( - os.getenv(environment_vars.GCE_METADATA_IP, "169.254.169.254") -) +GCE_MDS_HOSTS = ["metadata.google.internal", "169.254.169.254"] + + +def _get_metadata_root(use_mtls): + """Returns the metadata server root URL.""" + scheme = "https" if use_mtls else "http" + return "{}://{}/computeMetadata/v1/".format(scheme, _GCE_METADATA_HOST) + + +def _get_metadata_ip_root(use_mtls): + """Returns the metadata server IP root URL.""" + scheme = "https" if use_mtls else "http" + return "{}://{}".format( + scheme, os.getenv(environment_vars.GCE_METADATA_IP, "169.254.169.254") + ) + + _METADATA_FLAVOR_HEADER = "metadata-flavor" _METADATA_FLAVOR_VALUE = "Google" _METADATA_HEADERS = {_METADATA_FLAVOR_HEADER: _METADATA_FLAVOR_VALUE} @@ -102,6 +115,24 @@ def detect_gce_residency_linux(): return content.startswith(_GOOGLE) +def _prepare_request_for_mds(request, use_mtls=False): + """Prepares a request for the metadata server. + + This will check if mTLS should be used and return a new request object if so. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + + Returns: + google.auth.transport.Request: Request + object to use. + """ + if use_mtls: + request = requests.Request(_mtls.create_session()) + return request + + def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): """Checks to see if the metadata server is available. @@ -115,6 +146,8 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): Returns: bool: True if the metadata server is reachable, False otherwise. """ + use_mtls = _mtls.should_use_mds_mtls() + request = _prepare_request_for_mds(request, use_mtls=use_mtls) # NOTE: The explicit ``timeout`` is a workaround. The underlying # issue is that resolving an unknown host on some networks will take # 20-30 seconds; making this timeout short fixes the issue, but @@ -129,7 +162,10 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): for attempt in backoff: try: response = request( - url=_METADATA_IP_ROOT, method="GET", headers=headers, timeout=timeout + url=_get_metadata_ip_root(use_mtls), + method="GET", + headers=headers, + timeout=timeout, ) metadata_flavor = response.headers.get(_METADATA_FLAVOR_HEADER) @@ -153,7 +189,7 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): def get( request, path, - root=_METADATA_ROOT, + root=None, params=None, recursive=False, retry_count=5, @@ -190,6 +226,14 @@ def get( google.auth.exceptions.TransportError: if an error occurred while retrieving metadata. """ + use_mtls = _mtls.should_use_mds_mtls() + # Prepare the request object for mTLS if needed. + # This will create a new request object with the mTLS session. + request = _prepare_request_for_mds(request, use_mtls=use_mtls) + + if root is None: + root = _get_metadata_root(use_mtls) + base_url = urljoin(root, path) query_params = {} if params is None else params diff --git a/google/auth/compute_engine/_mtls.py b/google/auth/compute_engine/_mtls.py new file mode 100644 index 000000000..8909a6ca7 --- /dev/null +++ b/google/auth/compute_engine/_mtls.py @@ -0,0 +1,109 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Mutual TLS for Google Compute Engine metadata server.""" + +from dataclasses import dataclass +import enum +import os +from pathlib import Path +import ssl + +import requests +from requests.adapters import HTTPAdapter + +from google.auth import environment_vars, exceptions + + +@dataclass +class MdsMtlsConfig: + ca_cert_path: str = os.path.join( + Path.home(), "mtls_mds_certificates", "root.crt" + ) # path to CA certificate + client_combined_cert_path: str = os.path.join( + Path.home(), "mtls_mds_certificates", "client_creds.key" + ) # path to file containing client certificate and key + + +class MdsMtlsMode(enum.Enum): + """MDS mTLS mode.""" + + STRICT = "strict" + NONE = "none" + DEFAULT = "default" + + +def _parse_mds_mode(): + """Parses the GCE_METADATA_MTLS_MODE environment variable.""" + mode_str = os.environ.get( + environment_vars.GCE_METADATA_MTLS_MODE, "default" + ).lower() + try: + return MdsMtlsMode(mode_str) + except ValueError: + raise ValueError( + "Invalid value for GCE_METADATA_MTLS_MODE. Must be one of 'strict', 'none', or 'default'." + ) + + +def _certs_exist(mds_mtls_config: MdsMtlsConfig): + """Checks if the mTLS certificates exist.""" + return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists( + mds_mtls_config.client_combined_cert_path + ) + + +class MdsMtlsAdapter(HTTPAdapter): + """An HTTP adapter that uses mTLS for the metadata server.""" + + def __init__(self, mds_mtls_config: MdsMtlsConfig, *args, **kwargs): + self.ssl_context = ssl.create_default_context() + self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path) + self.ssl_context.load_cert_chain( + certfile=mds_mtls_config.client_combined_cert_path + ) + super(MdsMtlsAdapter, self).__init__(*args, **kwargs) + + def init_poolmanager(self, *args, **kwargs): + kwargs["ssl_context"] = self.ssl_context + return super(MdsMtlsAdapter, self).init_poolmanager(*args, **kwargs) + + def proxy_manager_for(self, *args, **kwargs): + kwargs["ssl_context"] = self.ssl_context + return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs) + + +def create_session(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()): + """Creates a requests.Session configured for mTLS.""" + session = requests.Session() + adapter = MdsMtlsAdapter(mds_mtls_config) + session.mount("https://", adapter) + return session + + +def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()): + """Determines if mTLS should be used for the metadata server.""" + mode = _parse_mds_mode() + if mode == MdsMtlsMode.STRICT: + if not _certs_exist(mds_mtls_config): + raise exceptions.MutualTLSChannelError( + "mTLS certificates not found in strict mode." + ) + return True + elif mode == MdsMtlsMode.NONE: + return False + else: # Default mode + return _certs_exist(mds_mtls_config) diff --git a/google/auth/environment_vars.py b/google/auth/environment_vars.py index e5f3598e8..5da3a7382 100644 --- a/google/auth/environment_vars.py +++ b/google/auth/environment_vars.py @@ -60,6 +60,12 @@ """Environment variable providing an alternate ip:port to be used for ip-only GCE metadata requests.""" +GCE_METADATA_MTLS_MODE = "GCE_METADATA_MTLS_MODE" +"""Environment variable controlling the mTLS behavior for GCE metadata requests. + +Can be one of "strict", "none", or "default". +""" + GOOGLE_API_USE_CLIENT_CERTIFICATE = "GOOGLE_API_USE_CLIENT_CERTIFICATE" """Environment variable controlling whether to use client certificate or not. diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index c90bc603a..fc9afb126 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -20,12 +20,14 @@ import mock import pytest # type: ignore +import requests from google.auth import _helpers from google.auth import environment_vars from google.auth import exceptions from google.auth import transport from google.auth.compute_engine import _metadata +from google.auth.transport import requests as google_auth_requests PATH = "instance/service-accounts/default" @@ -104,7 +106,7 @@ def test_ping_success(mock_metrics_header_value): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_IP_ROOT, + url="http://169.254.169.254", headers=MDS_PING_REQUEST_HEADER, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -118,7 +120,7 @@ def test_ping_success_retry(mock_metrics_header_value): request.assert_called_with( method="GET", - url=_metadata._METADATA_IP_ROOT, + url="http://169.254.169.254", headers=MDS_PING_REQUEST_HEADER, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -172,7 +174,7 @@ def test_get_success_json(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -191,7 +193,7 @@ def test_get_success_json_content_type_charset(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -211,7 +213,7 @@ def test_get_success_retry(mock_sleep): request.assert_called_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -227,7 +229,7 @@ def test_get_success_text(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -243,7 +245,9 @@ def test_get_success_params(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "?recursive=true", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "?recursive=true", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -258,7 +262,9 @@ def test_get_success_recursive_and_params(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "?recursive=true", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "?recursive=true", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -273,7 +279,9 @@ def test_get_success_recursive(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "?recursive=true", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "?recursive=true", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -333,7 +341,7 @@ def test_get_failure(mock_sleep): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -346,7 +354,7 @@ def test_get_return_none_for_not_found_error(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -366,7 +374,7 @@ def test_get_failure_connection_failed(mock_sleep): request.assert_called_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -385,7 +393,7 @@ def test_get_too_many_requests_retryable_error_failure(): request.assert_called_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -402,7 +410,7 @@ def test_get_failure_bad_json(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -416,7 +424,7 @@ def test_get_project_id(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "project/project-id", + url="http://metadata.google.internal/computeMetadata/v1/project/project-id", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -432,7 +440,7 @@ def test_get_universe_domain_success(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -446,7 +454,7 @@ def test_get_universe_domain_success_empty_response(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -462,7 +470,7 @@ def test_get_universe_domain_not_found(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -483,7 +491,7 @@ def test_get_universe_domain_retryable_error_failure(): request.assert_called_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -526,13 +534,13 @@ def request(self, *args, **kwargs): request_error.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) request_ok.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -552,7 +560,7 @@ def test_get_universe_domain_other_error(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -574,7 +582,7 @@ def test_get_service_account_token(utcnow, mock_metrics_header_value): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "/token", + url="http://metadata.google.internal/computeMetadata/v1/" + PATH + "/token", headers={ "metadata-flavor": "Google", "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, @@ -601,7 +609,10 @@ def test_get_service_account_token_with_scopes_list(utcnow, mock_metrics_header_ request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "/token" + "?scopes=foo%2Cbar", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "/token" + + "?scopes=foo%2Cbar", headers={ "metadata-flavor": "Google", "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, @@ -630,7 +641,10 @@ def test_get_service_account_token_with_scopes_string( request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "/token" + "?scopes=foo%2Cbar", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "/token" + + "?scopes=foo%2Cbar", headers={ "metadata-flavor": "Google", "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, @@ -651,9 +665,99 @@ def test_get_service_account_info(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "/?recursive=true", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "/?recursive=true", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert info[key] == value + + +def test__get_metadata_root_mtls(): + assert ( + _metadata._get_metadata_root(use_mtls=True) + == "https://metadata.google.internal/computeMetadata/v1/" + ) + + +def test__get_metadata_root_no_mtls(): + assert ( + _metadata._get_metadata_root(use_mtls=False) + == "http://metadata.google.internal/computeMetadata/v1/" + ) + + +def test__get_metadata_ip_root_mtls(): + assert _metadata._get_metadata_ip_root(use_mtls=True) == "https://169.254.169.254" + + +def test__get_metadata_ip_root_no_mtls(): + assert _metadata._get_metadata_ip_root(use_mtls=False) == "http://169.254.169.254" + + +@mock.patch("google.auth.compute_engine._mtls.create_session") +def test__prepare_request_for_mds_mtls(mock_create_session): + request = mock.Mock() + new_request = _metadata._prepare_request_for_mds(request, use_mtls=True) + mock_create_session.assert_called_once() + assert isinstance(new_request, google_auth_requests.Request) + + +def test__prepare_request_for_mds_no_mtls(): + request = mock.Mock() + new_request = _metadata._prepare_request_for_mds(request, use_mtls=False) + assert new_request is request + + +@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) +@mock.patch("google.auth.compute_engine._mtls.create_session") +@mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE) +def test_ping_mtls( + mock_metrics_header_value, mock_create_session, mock_should_use_mtls +): + response = mock.create_autospec(requests.Response, instance=True) + response.status_code = http_client.OK + response.headers = _metadata._METADATA_HEADERS + mock_session = mock.Mock() + mock_session.request.return_value = response + mock_create_session.return_value = mock_session + + initial_request = mock.Mock() + assert _metadata.ping(initial_request) + + mock_should_use_mtls.assert_called_once() + mock_create_session.assert_called_once() + mock_session.request.assert_called_once_with( + "GET", + "https://169.254.169.254", + headers=MDS_PING_REQUEST_HEADER, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, + data=None, + ) + + +@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) +@mock.patch("google.auth.compute_engine._mtls.create_session") +def test_get_mtls(mock_create_session, mock_should_use_mtls): + response = mock.create_autospec(requests.Response, instance=True) + response.status_code = http_client.OK + response.content = _helpers.to_bytes("{}") + response.headers = {"content-type": "application/json"} + mock_session = mock.Mock() + mock_session.request.return_value = response + mock_create_session.return_value = mock_session + + initial_request = mock.Mock() + _metadata.get(initial_request, "some/path") + + mock_should_use_mtls.assert_called_once() + mock_create_session.assert_called_once() + mock_session.request.assert_called_once_with( + "GET", + "https://metadata.google.internal/computeMetadata/v1/some/path", + data=None, + headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, + ) diff --git a/tests/compute_engine/test__mtls.py b/tests/compute_engine/test__mtls.py new file mode 100644 index 000000000..193272815 --- /dev/null +++ b/tests/compute_engine/test__mtls.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import mock +import pytest + +from google.auth import environment_vars, exceptions +from google.auth.compute_engine import _mtls + + +@pytest.fixture +def mock_mds_mtls_config(): + return _mtls.MdsMtlsConfig( + ca_cert_path="/fake/ca.crt", client_combined_cert_path="/fake/client.key" + ) + + +def test__parse_mds_mode_default(monkeypatch): + monkeypatch.delenv(environment_vars.GCE_METADATA_MTLS_MODE, raising=False) + assert _mtls._parse_mds_mode() == _mtls.MdsMtlsMode.DEFAULT + + +@pytest.mark.parametrize( + "mode_str, expected_mode", + [ + ("strict", _mtls.MdsMtlsMode.STRICT), + ("none", _mtls.MdsMtlsMode.NONE), + ("default", _mtls.MdsMtlsMode.DEFAULT), + ("STRICT", _mtls.MdsMtlsMode.STRICT), + ], +) +def test__parse_mds_mode_valid(monkeypatch, mode_str, expected_mode): + monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, mode_str) + assert _mtls._parse_mds_mode() == expected_mode + + +def test__parse_mds_mode_invalid(monkeypatch): + monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, "invalid_mode") + with pytest.raises(ValueError): + _mtls._parse_mds_mode() + + +@mock.patch("os.path.exists") +def test__certs_exist_true(mock_exists, mock_mds_mtls_config): + mock_exists.return_value = True + assert _mtls._certs_exist(mock_mds_mtls_config) is True + + +@mock.patch("os.path.exists") +def test__certs_exist_false(mock_exists, mock_mds_mtls_config): + mock_exists.return_value = False + assert _mtls._certs_exist(mock_mds_mtls_config) is False + + +@pytest.mark.parametrize( + "mtls_mode, certs_exist, expected_result", + [ + ("strict", True, True), + ("strict", False, exceptions.MutualTLSChannelError), + ("none", True, False), + ("none", False, False), + ("default", True, True), + ("default", False, False), + ], +) +@mock.patch("os.path.exists") +def test_should_use_mds_mtls( + mock_exists, monkeypatch, mtls_mode, certs_exist, expected_result +): + monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, mtls_mode) + mock_exists.return_value = certs_exist + + if isinstance(expected_result, type) and issubclass(expected_result, Exception): + with pytest.raises(expected_result): + _mtls.should_use_mds_mtls() + else: + assert _mtls.should_use_mds_mtls() is expected_result + + +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_init(mock_ssl_context, mock_mds_mtls_config): + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + mock_ssl_context.assert_called_once() + adapter.ssl_context.load_verify_locations.assert_called_once_with( + cafile=mock_mds_mtls_config.ca_cert_path + ) + adapter.ssl_context.load_cert_chain.assert_called_once_with( + certfile=mock_mds_mtls_config.client_combined_cert_path + ) + + +@mock.patch("requests.Session") +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +def test_create_session(mock_adapter, mock_session, mock_mds_mtls_config): + session_instance = mock_session.return_value + session = _mtls.create_session(mock_mds_mtls_config) + assert session is session_instance + mock_adapter.assert_called_once_with(mock_mds_mtls_config) + session_instance.mount.assert_called_once_with( + "https://", mock_adapter.return_value + ) From b49aad1d825603909c290ddc0046b2152d7b6c01 Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Tue, 4 Nov 2025 19:44:48 +0000 Subject: [PATCH 02/14] add custom root unit test to increase coverage --- tests/compute_engine/test__metadata.py | 14 ++++++++++++++ tests/compute_engine/test__mtls.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index fc9afb126..ad01a83c0 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -329,6 +329,20 @@ def test_get_success_custom_root_old_variable(): timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) +def test_get_success_custom_root(): + request = make_request("{}", headers={"content-type": "application/json"}) + + fake_root = "http://another.metadata.service" + + _metadata.get(request, PATH, root=fake_root) + + request.assert_called_once_with( + method="GET", + url="{}/{}".format(fake_root, PATH), + headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, + ) + @mock.patch("time.sleep", return_value=None) def test_get_failure(mock_sleep): diff --git a/tests/compute_engine/test__mtls.py b/tests/compute_engine/test__mtls.py index 193272815..9ff1fe651 100644 --- a/tests/compute_engine/test__mtls.py +++ b/tests/compute_engine/test__mtls.py @@ -16,7 +16,7 @@ # import mock -import pytest +import pytest # type: ignore from google.auth import environment_vars, exceptions from google.auth.compute_engine import _mtls From a6a2cc3a0b68ecec82d13b77a7aa5bf8e35f3c61 Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Tue, 4 Nov 2025 20:14:17 +0000 Subject: [PATCH 03/14] add unit test to increase coverage --- tests/compute_engine/test__metadata.py | 1 + tests/compute_engine/test__mtls.py | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index ad01a83c0..c5631308b 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -329,6 +329,7 @@ def test_get_success_custom_root_old_variable(): timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) + def test_get_success_custom_root(): request = make_request("{}", headers={"content-type": "application/json"}) diff --git a/tests/compute_engine/test__mtls.py b/tests/compute_engine/test__mtls.py index 9ff1fe651..e34923ed8 100644 --- a/tests/compute_engine/test__mtls.py +++ b/tests/compute_engine/test__mtls.py @@ -113,3 +113,15 @@ def test_create_session(mock_adapter, mock_session, mock_mds_mtls_config): session_instance.mount.assert_called_once_with( "https://", mock_adapter.return_value ) + + +@mock.patch("ssl.create_default_context") +@mock.patch("requests.adapters.HTTPAdapter.proxy_manager_for") +def test_mds_mtls_adapter_proxy_manager_for( + mock_proxy_manager_for, mock_ssl_context, mock_mds_mtls_config +): + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + adapter.proxy_manager_for("test_proxy") + mock_proxy_manager_for.assert_called_once_with( + "test_proxy", ssl_context=adapter.ssl_context + ) From 6e00ea4ae00be779d7d5adf0ca10dcc55ffab66d Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Tue, 4 Nov 2025 23:49:15 +0000 Subject: [PATCH 04/14] add explanation to mtlsmds modes --- google/auth/compute_engine/_mtls.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/google/auth/compute_engine/_mtls.py b/google/auth/compute_engine/_mtls.py index 8909a6ca7..18e3c305b 100644 --- a/google/auth/compute_engine/_mtls.py +++ b/google/auth/compute_engine/_mtls.py @@ -39,7 +39,12 @@ class MdsMtlsConfig: class MdsMtlsMode(enum.Enum): - """MDS mTLS mode.""" + """MDS mTLS mode. Used to configure connection behavior when connecting to MDS. + + STRICT: Always use HTTPS/mTLS. If certificates are not found locally, an error will be returned. + NONE: Never use mTLS. Requests will use regular HTTP. + DEFAULT: Use mTLS if certificates are found locally, otherwise use regular HTTP. + """ STRICT = "strict" NONE = "none" From df1378e56a2bd7198eebfb54794be6fb62bde021 Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Wed, 5 Nov 2025 19:33:23 +0000 Subject: [PATCH 05/14] Update mds mtls certificate well-known locations --- google/auth/compute_engine/_mtls.py | 33 +++++++++++++++++++++++------ tests/compute_engine/test__mtls.py | 24 +++++++++++++++++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/google/auth/compute_engine/_mtls.py b/google/auth/compute_engine/_mtls.py index 18e3c305b..47834c5f6 100644 --- a/google/auth/compute_engine/_mtls.py +++ b/google/auth/compute_engine/_mtls.py @@ -16,10 +16,9 @@ # """Mutual TLS for Google Compute Engine metadata server.""" -from dataclasses import dataclass +from dataclasses import dataclass, field import enum import os -from pathlib import Path import ssl import requests @@ -27,14 +26,36 @@ from google.auth import environment_vars, exceptions +# MDS mTLS certificate paths based on OS. +# Documentation to well known locations can be found at: +# https://cloud.google.com/compute/docs/metadata/overview#https-mds-certificates + + +def _get_mds_root_crt_path(): + if os.name == "nt": + return os.path.join( + "C:\\", "ProgramData", "Google", "ComputeEngine", "mds-mtls-root.crt" + ) + else: + return os.path.join("/", "run", "google-mds-mtls", "root.crt") + + +def _get_mds_client_combined_cert_path(): + if os.name == "nt": + return os.path.join( + "C:\\", "ProgramData", "Google", "ComputeEngine", "mds-mtls-client.key" + ) + else: + return os.path.join("/", "run", "google-mds-mtls", "client.key") + @dataclass class MdsMtlsConfig: - ca_cert_path: str = os.path.join( - Path.home(), "mtls_mds_certificates", "root.crt" + ca_cert_path: str = field( + default_factory=_get_mds_root_crt_path ) # path to CA certificate - client_combined_cert_path: str = os.path.join( - Path.home(), "mtls_mds_certificates", "client_creds.key" + client_combined_cert_path: str = field( + default_factory=_get_mds_client_combined_cert_path ) # path to file containing client certificate and key diff --git a/tests/compute_engine/test__mtls.py b/tests/compute_engine/test__mtls.py index e34923ed8..f0426f296 100644 --- a/tests/compute_engine/test__mtls.py +++ b/tests/compute_engine/test__mtls.py @@ -15,6 +15,8 @@ # limitations under the License. # +import os + import mock import pytest # type: ignore @@ -29,6 +31,28 @@ def mock_mds_mtls_config(): ) +@mock.patch("os.name", "nt") +def test__MdsMtlsConfig_windows_defaults(): + config = _mtls.MdsMtlsConfig() + assert config.ca_cert_path == os.path.join( + "C:\\", "ProgramData", "Google", "ComputeEngine", "mds-mtls-root.crt" + ) + assert config.client_combined_cert_path == os.path.join( + "C:\\", "ProgramData", "Google", "ComputeEngine", "mds-mtls-client.key" + ) + + +@mock.patch("os.name", "posix") +def test__MdsMtlsConfig_non_windows_defaults(): + config = _mtls.MdsMtlsConfig() + assert config.ca_cert_path == os.path.join( + "/", "run", "google-mds-mtls", "root.crt" + ) + assert config.client_combined_cert_path == os.path.join( + "/", "run", "google-mds-mtls", "client.key" + ) + + def test__parse_mds_mode_default(monkeypatch): monkeypatch.delenv(environment_vars.GCE_METADATA_MTLS_MODE, raising=False) assert _mtls._parse_mds_mode() == _mtls.MdsMtlsMode.DEFAULT From e3a311a79e70bf2a2f19b7cd3579f53bdfc33030 Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Fri, 7 Nov 2025 17:58:12 +0000 Subject: [PATCH 06/14] modify contants --- google/auth/compute_engine/_metadata.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index 9554bcb0a..7637f5ac8 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -45,7 +45,8 @@ environment_vars.GCE_METADATA_ROOT, "metadata.google.internal" ) -GCE_MDS_HOSTS = ["metadata.google.internal", "169.254.169.254"] +_GCE_DEFAULT_MDS_IP = "169.254.169.254" +_GCE_MDS_HOSTS = ["metadata.google.internal", _GCE_DEFAULT_MDS_IP] def _get_metadata_root(use_mtls): @@ -58,7 +59,7 @@ def _get_metadata_ip_root(use_mtls): """Returns the metadata server IP root URL.""" scheme = "https" if use_mtls else "http" return "{}://{}".format( - scheme, os.getenv(environment_vars.GCE_METADATA_IP, "169.254.169.254") + scheme, os.getenv(environment_vars.GCE_METADATA_IP, _GCE_DEFAULT_MDS_IP) ) @@ -123,10 +124,12 @@ def _prepare_request_for_mds(request, use_mtls=False): Args: request (google.auth.transport.Request): A callable used to make HTTP requests. + use_mtls (bool): Whether to use mTLS for the request. Returns: - google.auth.transport.Request: Request - object to use. + google.auth.transport.Request: A request object to use. + If mTLS is enabled, this will be a new request object with mTLS session configured. + Otherwise, it will be the same as the input request. """ if use_mtls: request = requests.Request(_mtls.create_session()) From 5bbac84e383b67d180fe5a5676cf11c961966cc7 Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Fri, 7 Nov 2025 18:41:45 +0000 Subject: [PATCH 07/14] change the mds mtls implementation 1. now we do not create a new request. instead, create an mds mtls adapter and mount it on the request session. 2. added _validate_gce_mds_configured_environment, which ensures if we are using strict, that the host being contacted is default 3. fix unit tests and add new tests --- google/auth/compute_engine/_metadata.py | 46 +++++++-- google/auth/compute_engine/_mtls.py | 53 +++++------ tests/compute_engine/test__metadata.py | 119 +++++++++++++++++------- tests/compute_engine/test__mtls.py | 38 ++++++-- 4 files changed, 173 insertions(+), 83 deletions(-) diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index 7637f5ac8..dbe27ac12 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -24,6 +24,8 @@ import os from urllib.parse import urljoin +import requests + from google.auth import _helpers from google.auth import environment_vars from google.auth import exceptions @@ -31,10 +33,14 @@ from google.auth import transport from google.auth._exponential_backoff import ExponentialBackoff from google.auth.compute_engine import _mtls -from google.auth.transport import requests + _LOGGER = logging.getLogger(__name__) +_GCE_DEFAULT_MDS_IP = "169.254.169.254" +_GCE_DEFAULT_HOST = "metadata.google.internal" +_GCE_DEFAULT_MDS_HOSTS = [_GCE_DEFAULT_HOST, _GCE_DEFAULT_MDS_IP] + # Environment variable GCE_METADATA_HOST is originally named # GCE_METADATA_ROOT. For compatibility reasons, here it checks # the new variable first; if not set, the system falls back @@ -42,20 +48,37 @@ _GCE_METADATA_HOST = os.getenv(environment_vars.GCE_METADATA_HOST, None) if not _GCE_METADATA_HOST: _GCE_METADATA_HOST = os.getenv( - environment_vars.GCE_METADATA_ROOT, "metadata.google.internal" + environment_vars.GCE_METADATA_ROOT, _GCE_DEFAULT_HOST ) -_GCE_DEFAULT_MDS_IP = "169.254.169.254" -_GCE_MDS_HOSTS = ["metadata.google.internal", _GCE_DEFAULT_MDS_IP] +def _validate_gce_mds_configured_environment(): + """Validates the GCE metadata server environment configuration for mTLS. -def _get_metadata_root(use_mtls): + Raises: + google.auth.exceptions.MutualTLSChannelError: if the environment + configuration is invalid for mTLS. + """ + mode = _mtls._parse_mds_mode() + if mode == _mtls.MdsMtlsMode.STRICT: + if _GCE_METADATA_HOST != _GCE_DEFAULT_HOST: + # mTLS is only supported when connecting to the default metadata host. + # Raise an exception if we are in strict mode (which requires mTLS) + # but the metadata host has been overridden. (which means mTLS will fail) + raise exceptions.MutualTLSChannelError( + "Mutual TLS is required, but the metadata host has been overridden. " + "mTLS is only supported when connecting to the default metadata host." + ) + + +def _get_metadata_root(use_mtls: bool): """Returns the metadata server root URL.""" + scheme = "https" if use_mtls else "http" return "{}://{}/computeMetadata/v1/".format(scheme, _GCE_METADATA_HOST) -def _get_metadata_ip_root(use_mtls): +def _get_metadata_ip_root(use_mtls: bool): """Returns the metadata server IP root URL.""" scheme = "https" if use_mtls else "http" return "{}://{}".format( @@ -131,8 +154,14 @@ def _prepare_request_for_mds(request, use_mtls=False): If mTLS is enabled, this will be a new request object with mTLS session configured. Otherwise, it will be the same as the input request. """ - if use_mtls: - request = requests.Request(_mtls.create_session()) + if not use_mtls: + return request + + adapter = _mtls.MdsMtlsAdapter() + if not request.session: + request.session = requests.Session() + for host in _GCE_DEFAULT_MDS_HOSTS: + request.session.mount(f"https://{host}/", adapter) return request @@ -236,6 +265,7 @@ def get( if root is None: root = _get_metadata_root(use_mtls) + _validate_gce_mds_configured_environment() base_url = urljoin(root, path) query_params = {} if params is None else params diff --git a/google/auth/compute_engine/_mtls.py b/google/auth/compute_engine/_mtls.py index 47834c5f6..4cf4c7303 100644 --- a/google/auth/compute_engine/_mtls.py +++ b/google/auth/compute_engine/_mtls.py @@ -21,7 +21,6 @@ import os import ssl -import requests from requests.adapters import HTTPAdapter from google.auth import environment_vars, exceptions @@ -59,6 +58,13 @@ class MdsMtlsConfig: ) # path to file containing client certificate and key +def _certs_exist(mds_mtls_config: MdsMtlsConfig): + """Checks if the mTLS certificates exist.""" + return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists( + mds_mtls_config.client_combined_cert_path + ) + + class MdsMtlsMode(enum.Enum): """MDS mTLS mode. Used to configure connection behavior when connecting to MDS. @@ -85,17 +91,27 @@ def _parse_mds_mode(): ) -def _certs_exist(mds_mtls_config: MdsMtlsConfig): - """Checks if the mTLS certificates exist.""" - return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists( - mds_mtls_config.client_combined_cert_path - ) +def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()): + """Determines if mTLS should be used for the metadata server.""" + mode = _parse_mds_mode() + if mode == MdsMtlsMode.STRICT: + if not _certs_exist(mds_mtls_config): + raise exceptions.MutualTLSChannelError( + "mTLS certificates not found in strict mode." + ) + return True + elif mode == MdsMtlsMode.NONE: + return False + else: # Default mode + return _certs_exist(mds_mtls_config) class MdsMtlsAdapter(HTTPAdapter): """An HTTP adapter that uses mTLS for the metadata server.""" - def __init__(self, mds_mtls_config: MdsMtlsConfig, *args, **kwargs): + def __init__( + self, mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig(), *args, **kwargs + ): self.ssl_context = ssl.create_default_context() self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path) self.ssl_context.load_cert_chain( @@ -110,26 +126,3 @@ def init_poolmanager(self, *args, **kwargs): def proxy_manager_for(self, *args, **kwargs): kwargs["ssl_context"] = self.ssl_context return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs) - - -def create_session(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()): - """Creates a requests.Session configured for mTLS.""" - session = requests.Session() - adapter = MdsMtlsAdapter(mds_mtls_config) - session.mount("https://", adapter) - return session - - -def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()): - """Determines if mTLS should be used for the metadata server.""" - mode = _parse_mds_mode() - if mode == MdsMtlsMode.STRICT: - if not _certs_exist(mds_mtls_config): - raise exceptions.MutualTLSChannelError( - "mTLS certificates not found in strict mode." - ) - return True - elif mode == MdsMtlsMode.NONE: - return False - else: # Default mode - return _certs_exist(mds_mtls_config) diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index c5631308b..cd15ffe51 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -712,12 +712,12 @@ def test__get_metadata_ip_root_no_mtls(): assert _metadata._get_metadata_ip_root(use_mtls=False) == "http://169.254.169.254" -@mock.patch("google.auth.compute_engine._mtls.create_session") -def test__prepare_request_for_mds_mtls(mock_create_session): - request = mock.Mock() - new_request = _metadata._prepare_request_for_mds(request, use_mtls=True) - mock_create_session.assert_called_once() - assert isinstance(new_request, google_auth_requests.Request) +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +def test__prepare_request_for_mds_mtls(mock_mds_mtls_adapter): + request = google_auth_requests.Request(mock.create_autospec(requests.Session)) + _metadata._prepare_request_for_mds(request, use_mtls=True) + mock_mds_mtls_adapter.assert_called_once() + assert request.session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) def test__prepare_request_for_mds_no_mtls(): @@ -726,53 +726,100 @@ def test__prepare_request_for_mds_no_mtls(): assert new_request is request -@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) -@mock.patch("google.auth.compute_engine._mtls.create_session") @mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE) +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) +@mock.patch("google.auth.transport.requests.Request") def test_ping_mtls( - mock_metrics_header_value, mock_create_session, mock_should_use_mtls + mock_request, mock_should_use_mtls, mock_mds_mtls_adapter, mock_metrics_header_value ): - response = mock.create_autospec(requests.Response, instance=True) - response.status_code = http_client.OK + response = mock.create_autospec(transport.Response, instance=True) + response.status = http_client.OK response.headers = _metadata._METADATA_HEADERS - mock_session = mock.Mock() - mock_session.request.return_value = response - mock_create_session.return_value = mock_session + mock_request.return_value = response - initial_request = mock.Mock() - assert _metadata.ping(initial_request) + assert _metadata.ping(mock_request) mock_should_use_mtls.assert_called_once() - mock_create_session.assert_called_once() - mock_session.request.assert_called_once_with( - "GET", - "https://169.254.169.254", + mock_mds_mtls_adapter.assert_called_once() + mock_request.assert_called_once_with( + url="https://169.254.169.254", + method="GET", headers=MDS_PING_REQUEST_HEADER, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, - data=None, ) +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") @mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) -@mock.patch("google.auth.compute_engine._mtls.create_session") -def test_get_mtls(mock_create_session, mock_should_use_mtls): - response = mock.create_autospec(requests.Response, instance=True) - response.status_code = http_client.OK - response.content = _helpers.to_bytes("{}") +@mock.patch("google.auth.transport.requests.Request") +def test_get_mtls(mock_request, mock_should_use_mtls, mock_mds_mtls_adapter): + response = mock.create_autospec(transport.Response, instance=True) + response.status = http_client.OK + response.data = _helpers.to_bytes("{}") response.headers = {"content-type": "application/json"} - mock_session = mock.Mock() - mock_session.request.return_value = response - mock_create_session.return_value = mock_session + mock_request.return_value = response - initial_request = mock.Mock() - _metadata.get(initial_request, "some/path") + _metadata.get(mock_request, "some/path") mock_should_use_mtls.assert_called_once() - mock_create_session.assert_called_once() - mock_session.request.assert_called_once_with( - "GET", - "https://metadata.google.internal/computeMetadata/v1/some/path", - data=None, + mock_mds_mtls_adapter.assert_called_once() + mock_request.assert_called_once_with( + url="https://metadata.google.internal/computeMetadata/v1/some/path", + method="GET", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) + + +@pytest.mark.parametrize( + "mds_mode, metadata_host, expect_exception", + [ + (_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_HOST, False), + (_metadata._mtls.MdsMtlsMode.STRICT, "custom.host", True), + (_metadata._mtls.MdsMtlsMode.NONE, "custom.host", False), + (_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_HOST, False), + ], +) +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +def test_validate_gce_mds_configured_environment( + mock_parse_mds_mode, mds_mode, metadata_host, expect_exception +): + mock_parse_mds_mode.return_value = mds_mode + with mock.patch( + "google.auth.compute_engine._metadata._GCE_METADATA_HOST", new=metadata_host + ): + if expect_exception: + with pytest.raises(exceptions.MutualTLSChannelError): + _metadata._validate_gce_mds_configured_environment() + else: + _metadata._validate_gce_mds_configured_environment() + mock_parse_mds_mode.assert_called_once() + + +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +def test__prepare_request_for_mds_mtls_session_exists(mock_mds_mtls_adapter): + mock_session = mock.create_autospec(requests.Session) + request = google_auth_requests.Request(mock_session) + new_request = _metadata._prepare_request_for_mds(request, use_mtls=True) + + mock_mds_mtls_adapter.assert_called_once() + assert mock_session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) + assert new_request is request + + +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +def test__prepare_request_for_mds_mtls_no_session(mock_mds_mtls_adapter): + request = google_auth_requests.Request(None) + # Explicitly set session to None to avoid a session being created in the Request constructor. + request.session = None + + with mock.patch("requests.Session") as mock_session_class: + new_request = _metadata._prepare_request_for_mds(request, use_mtls=True) + + mock_session_class.assert_called_once() + mock_mds_mtls_adapter.assert_called_once() + assert new_request.session.mount.call_count == len( + _metadata._GCE_DEFAULT_MDS_HOSTS + ) + assert new_request is request diff --git a/tests/compute_engine/test__mtls.py b/tests/compute_engine/test__mtls.py index f0426f296..5c8e39818 100644 --- a/tests/compute_engine/test__mtls.py +++ b/tests/compute_engine/test__mtls.py @@ -19,6 +19,7 @@ import mock import pytest # type: ignore +import requests from google.auth import environment_vars, exceptions from google.auth.compute_engine import _mtls @@ -127,15 +128,14 @@ def test_mds_mtls_adapter_init(mock_ssl_context, mock_mds_mtls_config): ) -@mock.patch("requests.Session") -@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -def test_create_session(mock_adapter, mock_session, mock_mds_mtls_config): - session_instance = mock_session.return_value - session = _mtls.create_session(mock_mds_mtls_config) - assert session is session_instance - mock_adapter.assert_called_once_with(mock_mds_mtls_config) - session_instance.mount.assert_called_once_with( - "https://", mock_adapter.return_value +@mock.patch("ssl.create_default_context") +@mock.patch("requests.adapters.HTTPAdapter.init_poolmanager") +def test_mds_mtls_adapter_init_poolmanager( + mock_init_poolmanager, mock_ssl_context, mock_mds_mtls_config +): + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + mock_init_poolmanager.assert_called_with( + 10, 10, block=False, ssl_context=adapter.ssl_context ) @@ -149,3 +149,23 @@ def test_mds_mtls_adapter_proxy_manager_for( mock_proxy_manager_for.assert_called_once_with( "test_proxy", ssl_context=adapter.ssl_context ) + + +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_session_request(mock_ssl_context, mock_mds_mtls_config): + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + session = requests.Session() + session.mount("https://", adapter) + + # Mock the adapter's send method to avoid actual network requests + adapter.send = mock.Mock() + response = requests.Response() + response.status_code = 200 + adapter.send.return_value = response + + # Make a request + response = session.get("https://example.com") + + # Assert that the request was successful + assert response.status_code == 200 + adapter.send.assert_called_once() From c34a17a0c04a25c4fdd56fcefeeb92e6e4ebb3c9 Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Fri, 14 Nov 2025 23:46:42 +0000 Subject: [PATCH 08/14] add fallback to mds mtls --- google/auth/compute_engine/_metadata.py | 6 +- google/auth/compute_engine/_mtls.py | 50 +++++++++++--- tests/compute_engine/test__mtls.py | 86 +++++++++++++++++++++---- 3 files changed, 116 insertions(+), 26 deletions(-) diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index dbe27ac12..b14932ea3 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -142,7 +142,7 @@ def detect_gce_residency_linux(): def _prepare_request_for_mds(request, use_mtls=False): """Prepares a request for the metadata server. - This will check if mTLS should be used and return a new request object if so. + This will check if mTLS should be used and mount the mTLS adapter if needed. Args: request (google.auth.transport.Request): A callable used to make @@ -151,8 +151,8 @@ def _prepare_request_for_mds(request, use_mtls=False): Returns: google.auth.transport.Request: A request object to use. - If mTLS is enabled, this will be a new request object with mTLS session configured. - Otherwise, it will be the same as the input request. + If mTLS is enabled, the request will have the mTLS adapter mounted. + Otherwise, the original request will be returned unchanged. """ if not use_mtls: return request diff --git a/google/auth/compute_engine/_mtls.py b/google/auth/compute_engine/_mtls.py index 4cf4c7303..53b82e7cb 100644 --- a/google/auth/compute_engine/_mtls.py +++ b/google/auth/compute_engine/_mtls.py @@ -18,34 +18,41 @@ from dataclasses import dataclass, field import enum +import logging import os +from pathlib import Path import ssl +from urllib.parse import urlparse, urlunparse +import requests from requests.adapters import HTTPAdapter from google.auth import environment_vars, exceptions + +_LOGGER = logging.getLogger(__name__) + +_WINDOWS_OS_NAME = "nt" + # MDS mTLS certificate paths based on OS. # Documentation to well known locations can be found at: # https://cloud.google.com/compute/docs/metadata/overview#https-mds-certificates +_WINDOWS_MTLS_COMPONENTS_BASE_PATH = Path("C:/ProgramData/Google/ComputeEngine") +_MTLS_COMPONENTS_BASE_PATH = Path("/run/google-mds-mtls") def _get_mds_root_crt_path(): - if os.name == "nt": - return os.path.join( - "C:\\", "ProgramData", "Google", "ComputeEngine", "mds-mtls-root.crt" - ) + if os.name == _WINDOWS_OS_NAME: + return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt" else: - return os.path.join("/", "run", "google-mds-mtls", "root.crt") + return _MTLS_COMPONENTS_BASE_PATH / "root.crt" def _get_mds_client_combined_cert_path(): - if os.name == "nt": - return os.path.join( - "C:\\", "ProgramData", "Google", "ComputeEngine", "mds-mtls-client.key" - ) + if os.name == _WINDOWS_OS_NAME: + return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-client.key" else: - return os.path.join("/", "run", "google-mds-mtls", "client.key") + return _MTLS_COMPONENTS_BASE_PATH / "client.key" @dataclass @@ -126,3 +133,26 @@ def init_poolmanager(self, *args, **kwargs): def proxy_manager_for(self, *args, **kwargs): kwargs["ssl_context"] = self.ssl_context return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs) + + def send(self, request, **kwargs): + # If we are in strict mode, always use mTLS (no HTTP fallback) + if _parse_mds_mode() == MdsMtlsMode.STRICT: + return super(MdsMtlsAdapter, self).send(request, **kwargs) + + # In default mode, attempt mTLS first, then fallback to HTTP on failure + try: + return super(MdsMtlsAdapter, self).send(request, **kwargs) + except (ssl.SSLError, requests.exceptions.SSLError) as e: + _LOGGER.warning( + "mTLS connection to Compute Engine Metadata server failed. " + "Falling back to standard HTTP. Reason: %s", + e, + ) + # Fallback to standard HTTP + parsed_original_url = urlparse(request.url) + http_fallback_url = urlunparse(parsed_original_url._replace(scheme="http")) + request.url = http_fallback_url + + # Use a standard HTTPAdapter for the fallback + http_adapter = HTTPAdapter() + return http_adapter.send(request, **kwargs) diff --git a/tests/compute_engine/test__mtls.py b/tests/compute_engine/test__mtls.py index 5c8e39818..85ae4d0d6 100644 --- a/tests/compute_engine/test__mtls.py +++ b/tests/compute_engine/test__mtls.py @@ -15,8 +15,6 @@ # limitations under the License. # -import os - import mock import pytest # type: ignore import requests @@ -35,23 +33,21 @@ def mock_mds_mtls_config(): @mock.patch("os.name", "nt") def test__MdsMtlsConfig_windows_defaults(): config = _mtls.MdsMtlsConfig() - assert config.ca_cert_path == os.path.join( - "C:\\", "ProgramData", "Google", "ComputeEngine", "mds-mtls-root.crt" + assert ( + str(config.ca_cert_path) + == "C:/ProgramData/Google/ComputeEngine/mds-mtls-root.crt" ) - assert config.client_combined_cert_path == os.path.join( - "C:\\", "ProgramData", "Google", "ComputeEngine", "mds-mtls-client.key" + assert ( + str(config.client_combined_cert_path) + == "C:/ProgramData/Google/ComputeEngine/mds-mtls-client.key" ) @mock.patch("os.name", "posix") def test__MdsMtlsConfig_non_windows_defaults(): config = _mtls.MdsMtlsConfig() - assert config.ca_cert_path == os.path.join( - "/", "run", "google-mds-mtls", "root.crt" - ) - assert config.client_combined_cert_path == os.path.join( - "/", "run", "google-mds-mtls", "client.key" - ) + assert str(config.ca_cert_path) == "/run/google-mds-mtls/root.crt" + assert str(config.client_combined_cert_path) == "/run/google-mds-mtls/client.key" def test__parse_mds_mode_default(monkeypatch): @@ -157,7 +153,7 @@ def test_mds_mtls_adapter_session_request(mock_ssl_context, mock_mds_mtls_config session = requests.Session() session.mount("https://", adapter) - # Mock the adapter's send method to avoid actual network requests + # Mock the adapter\'s send method to avoid actual network requests adapter.send = mock.Mock() response = requests.Response() response.status_code = 200 @@ -169,3 +165,67 @@ def test_mds_mtls_adapter_session_request(mock_ssl_context, mock_mds_mtls_config # Assert that the request was successful assert response.status_code == 200 adapter.send.assert_called_once() + + +@mock.patch("google.auth.compute_engine._mtls.HTTPAdapter") +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_fallback_default_mode( + mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_class, mock_mds_mtls_config +): + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + mock_fallback_send = mock.Mock() + mock_http_adapter_class.return_value.send = mock_fallback_send + + # Simulate SSLError on the super().send() call + with mock.patch( + "requests.adapters.HTTPAdapter.send", side_effect=requests.exceptions.SSLError + ): + request = requests.Request(method="GET", url="https://example.com").prepare() + adapter.send(request) + + # Check that fallback to HTTPAdapter.send occurred + mock_http_adapter_class.assert_called_once() + mock_fallback_send.assert_called_once() + fallback_request = mock_fallback_send.call_args[0][0] + assert fallback_request.url == "http://example.com/" + + +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_no_fallback_strict_mode( + mock_ssl_context, mock_parse_mds_mode, mock_mds_mtls_config +): + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.STRICT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + # Simulate SSLError on the super().send() call + with mock.patch( + "requests.adapters.HTTPAdapter.send", side_effect=requests.exceptions.SSLError + ): + request = requests.Request(method="GET", url="https://example.com").prepare() + with pytest.raises(requests.exceptions.SSLError): + adapter.send(request) + + +@mock.patch("requests.adapters.HTTPAdapter.send") +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_no_fallback_other_exception( + mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_send, mock_mds_mtls_config +): + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + # Simulate a different exception + with mock.patch( + "requests.adapters.HTTPAdapter.send", + side_effect=requests.exceptions.ConnectionError, + ): + request = requests.Request(method="GET", url="https://example.com").prepare() + with pytest.raises(requests.exceptions.ConnectionError): + adapter.send(request) + + mock_http_adapter_send.assert_not_called() From b1d7021067559e1856206a3c1677e372eeba7d9b Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Mon, 17 Nov 2025 23:09:53 +0000 Subject: [PATCH 09/14] add fallback for http error codes from mds --- google/auth/compute_engine/_mtls.py | 10 ++++++++-- tests/compute_engine/test__mtls.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/google/auth/compute_engine/_mtls.py b/google/auth/compute_engine/_mtls.py index 53b82e7cb..db4766d13 100644 --- a/google/auth/compute_engine/_mtls.py +++ b/google/auth/compute_engine/_mtls.py @@ -141,8 +141,14 @@ def send(self, request, **kwargs): # In default mode, attempt mTLS first, then fallback to HTTP on failure try: - return super(MdsMtlsAdapter, self).send(request, **kwargs) - except (ssl.SSLError, requests.exceptions.SSLError) as e: + response = super(MdsMtlsAdapter, self).send(request, **kwargs) + response.raise_for_status() + return response + except ( + ssl.SSLError, + requests.exceptions.SSLError, + requests.exceptions.HTTPError, + ) as e: _LOGGER.warning( "mTLS connection to Compute Engine Metadata server failed. " "Falling back to standard HTTP. Reason: %s", diff --git a/tests/compute_engine/test__mtls.py b/tests/compute_engine/test__mtls.py index 85ae4d0d6..e4c83a886 100644 --- a/tests/compute_engine/test__mtls.py +++ b/tests/compute_engine/test__mtls.py @@ -193,6 +193,34 @@ def test_mds_mtls_adapter_send_fallback_default_mode( assert fallback_request.url == "http://example.com/" +@mock.patch("google.auth.compute_engine._mtls.HTTPAdapter") +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_fallback_http_error( + mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_class, mock_mds_mtls_config +): + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + mock_fallback_send = mock.Mock() + mock_http_adapter_class.return_value.send = mock_fallback_send + + # Simulate HTTPError on the super().send() call + mock_mtls_response = requests.Response() + mock_mtls_response.status_code = 404 + with mock.patch( + "requests.adapters.HTTPAdapter.send", return_value=mock_mtls_response + ): + request = requests.Request(method="GET", url="https://example.com").prepare() + adapter.send(request) + + # Check that fallback to HTTPAdapter.send occurred + mock_http_adapter_class.assert_called_once() + mock_fallback_send.assert_called_once() + fallback_request = mock_fallback_send.call_args[0][0] + assert fallback_request.url == "http://example.com/" + + @mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") @mock.patch("ssl.create_default_context") def test_mds_mtls_adapter_send_no_fallback_strict_mode( From e01ea09c29684f54a929210daeb17fc72b2485bc Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Tue, 18 Nov 2025 11:43:15 -0800 Subject: [PATCH 10/14] Add more unit tests to address coverage requirements --- tests/compute_engine/test__mtls.py | 84 +++++++++++++++++++----------- 1 file changed, 55 insertions(+), 29 deletions(-) diff --git a/tests/compute_engine/test__mtls.py b/tests/compute_engine/test__mtls.py index e4c83a886..3ef250cfc 100644 --- a/tests/compute_engine/test__mtls.py +++ b/tests/compute_engine/test__mtls.py @@ -147,24 +147,50 @@ def test_mds_mtls_adapter_proxy_manager_for( ) +@mock.patch("requests.adapters.HTTPAdapter.send") # Patch the PARENT class method @mock.patch("ssl.create_default_context") -def test_mds_mtls_adapter_session_request(mock_ssl_context, mock_mds_mtls_config): +def test_mds_mtls_adapter_session_request( + mock_ssl_context, mock_super_send, mock_mds_mtls_config +): adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) session = requests.Session() session.mount("https://", adapter) - # Mock the adapter\'s send method to avoid actual network requests - adapter.send = mock.Mock() + # Setup the parent class send return value response = requests.Response() response.status_code = 200 - adapter.send.return_value = response + mock_super_send.return_value = response - # Make a request - response = session.get("https://example.com") + response = session.get("https://fake-mds.com") # Assert that the request was successful assert response.status_code == 200 - adapter.send.assert_called_once() + mock_super_send.assert_called_once() + + +@mock.patch("requests.adapters.HTTPAdapter.send") +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_success( + mock_ssl_context, mock_parse_mds_mode, mock_super_send, mock_mds_mtls_config +): + """Test the explicit 'happy path' where mTLS succeeds without error.""" + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + # Setup the parent class send return value to be successful (200 OK) + mock_response = requests.Response() + mock_response.status_code = 200 + mock_super_send.return_value = mock_response + + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + + # Call send directly + response = adapter.send(request) + + # Verify we got the response back and no fallback happened + assert response == mock_response + mock_super_send.assert_called_once() @mock.patch("google.auth.compute_engine._mtls.HTTPAdapter") @@ -183,14 +209,14 @@ def test_mds_mtls_adapter_send_fallback_default_mode( with mock.patch( "requests.adapters.HTTPAdapter.send", side_effect=requests.exceptions.SSLError ): - request = requests.Request(method="GET", url="https://example.com").prepare() + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() adapter.send(request) # Check that fallback to HTTPAdapter.send occurred mock_http_adapter_class.assert_called_once() mock_fallback_send.assert_called_once() fallback_request = mock_fallback_send.call_args[0][0] - assert fallback_request.url == "http://example.com/" + assert fallback_request.url == "http://fake-mds.com/" @mock.patch("google.auth.compute_engine._mtls.HTTPAdapter") @@ -211,49 +237,49 @@ def test_mds_mtls_adapter_send_fallback_http_error( with mock.patch( "requests.adapters.HTTPAdapter.send", return_value=mock_mtls_response ): - request = requests.Request(method="GET", url="https://example.com").prepare() + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() adapter.send(request) # Check that fallback to HTTPAdapter.send occurred mock_http_adapter_class.assert_called_once() mock_fallback_send.assert_called_once() fallback_request = mock_fallback_send.call_args[0][0] - assert fallback_request.url == "http://example.com/" + assert fallback_request.url == "http://fake-mds.com/" +@mock.patch("requests.adapters.HTTPAdapter.send") @mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") @mock.patch("ssl.create_default_context") -def test_mds_mtls_adapter_send_no_fallback_strict_mode( - mock_ssl_context, mock_parse_mds_mode, mock_mds_mtls_config +def test_mds_mtls_adapter_send_no_fallback_other_exception( + mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_send, mock_mds_mtls_config ): - mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.STRICT + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) - # Simulate SSLError on the super().send() call + # Simulate HTTP exception with mock.patch( - "requests.adapters.HTTPAdapter.send", side_effect=requests.exceptions.SSLError + "requests.adapters.HTTPAdapter.send", + side_effect=requests.exceptions.ConnectionError, ): - request = requests.Request(method="GET", url="https://example.com").prepare() - with pytest.raises(requests.exceptions.SSLError): + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + with pytest.raises(requests.exceptions.ConnectionError): adapter.send(request) + mock_http_adapter_send.assert_not_called() + -@mock.patch("requests.adapters.HTTPAdapter.send") @mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") @mock.patch("ssl.create_default_context") -def test_mds_mtls_adapter_send_no_fallback_other_exception( - mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_send, mock_mds_mtls_config +def test_mds_mtls_adapter_send_no_fallback_strict_mode( + mock_ssl_context, mock_parse_mds_mode, mock_mds_mtls_config ): - mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.STRICT adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) - # Simulate a different exception + # Simulate SSLError on the super().send() call with mock.patch( - "requests.adapters.HTTPAdapter.send", - side_effect=requests.exceptions.ConnectionError, + "requests.adapters.HTTPAdapter.send", side_effect=requests.exceptions.SSLError ): - request = requests.Request(method="GET", url="https://example.com").prepare() - with pytest.raises(requests.exceptions.ConnectionError): + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + with pytest.raises(requests.exceptions.SSLError): adapter.send(request) - - mock_http_adapter_send.assert_not_called() From c3caa16ea3dc907cc2406d7f4479b31237725425 Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Tue, 18 Nov 2025 11:57:49 -0800 Subject: [PATCH 11/14] update docstrings --- google/auth/compute_engine/_metadata.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index b14932ea3..aa4c8915a 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -55,6 +55,10 @@ def _validate_gce_mds_configured_environment(): """Validates the GCE metadata server environment configuration for mTLS. + mTLS is only supported when connecting to the default metadata host. + If we are in strict mode (which requires mTLS), ensure that the metadata host + has not been overridden (which means mTLS will fail). + Raises: google.auth.exceptions.MutualTLSChannelError: if the environment configuration is invalid for mTLS. @@ -236,7 +240,8 @@ def get( HTTP requests. path (str): The resource to retrieve. For example, ``'instance/service-accounts/default'``. - root (str): The full path to the metadata server root. + root (Optional[str]): The full path to the metadata server root. If not + provided, the default root will be used. params (Optional[Mapping[str, str]]): A mapping of query parameter keys to values. recursive (bool): Whether to do a recursive query of metadata. See @@ -257,6 +262,10 @@ def get( Raises: google.auth.exceptions.TransportError: if an error occurred while retrieving metadata. + google.auth.exceptions.MutualTLSChannelError: if the environment + configuration is invalid for mTLS (for example, the metadata host + has been overridden in strict mTLS mode). + """ use_mtls = _mtls.should_use_mds_mtls() # Prepare the request object for mTLS if needed. @@ -265,6 +274,10 @@ def get( if root is None: root = _get_metadata_root(use_mtls) + + # mTLS is only supported when connecting to the default metadata host. + # If we are in strict mode (which requires mTLS), ensure that the metadata host + # has not been overridden (which means mTLS will fail). _validate_gce_mds_configured_environment() base_url = urljoin(root, path) From e0b31e81dc25c29669885de02304bc78f4fcd6ac Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Tue, 18 Nov 2025 13:32:51 -0800 Subject: [PATCH 12/14] update MdsMtlsConfig fields to be Path type --- google/auth/compute_engine/_mtls.py | 4 ++-- tests/compute_engine/test__mtls.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/google/auth/compute_engine/_mtls.py b/google/auth/compute_engine/_mtls.py index db4766d13..6525dd03e 100644 --- a/google/auth/compute_engine/_mtls.py +++ b/google/auth/compute_engine/_mtls.py @@ -57,10 +57,10 @@ def _get_mds_client_combined_cert_path(): @dataclass class MdsMtlsConfig: - ca_cert_path: str = field( + ca_cert_path: Path = field( default_factory=_get_mds_root_crt_path ) # path to CA certificate - client_combined_cert_path: str = field( + client_combined_cert_path: Path = field( default_factory=_get_mds_client_combined_cert_path ) # path to file containing client certificate and key diff --git a/tests/compute_engine/test__mtls.py b/tests/compute_engine/test__mtls.py index 3ef250cfc..fdd61a07d 100644 --- a/tests/compute_engine/test__mtls.py +++ b/tests/compute_engine/test__mtls.py @@ -15,6 +15,8 @@ # limitations under the License. # +from pathlib import Path + import mock import pytest # type: ignore import requests @@ -26,7 +28,8 @@ @pytest.fixture def mock_mds_mtls_config(): return _mtls.MdsMtlsConfig( - ca_cert_path="/fake/ca.crt", client_combined_cert_path="/fake/client.key" + ca_cert_path=Path("/fake/ca.crt"), + client_combined_cert_path=Path("/fake/client.key"), ) From eaf84359f1ffea47d7e837de17eccd63efe99978 Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Wed, 19 Nov 2025 18:51:13 +0000 Subject: [PATCH 13/14] address feedback --- google/auth/compute_engine/_metadata.py | 37 +++++++++++++------------ tests/compute_engine/test__metadata.py | 16 +++++------ 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index aa4c8915a..969136ed2 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -55,9 +55,9 @@ def _validate_gce_mds_configured_environment(): """Validates the GCE metadata server environment configuration for mTLS. - mTLS is only supported when connecting to the default metadata host. + mTLS is only supported when connecting to the default metadata server hosts. If we are in strict mode (which requires mTLS), ensure that the metadata host - has not been overridden (which means mTLS will fail). + has not been overridden to a custom value (which means mTLS will fail). Raises: google.auth.exceptions.MutualTLSChannelError: if the environment @@ -65,10 +65,10 @@ def _validate_gce_mds_configured_environment(): """ mode = _mtls._parse_mds_mode() if mode == _mtls.MdsMtlsMode.STRICT: - if _GCE_METADATA_HOST != _GCE_DEFAULT_HOST: - # mTLS is only supported when connecting to the default metadata host. - # Raise an exception if we are in strict mode (which requires mTLS) - # but the metadata host has been overridden. (which means mTLS will fail) + # mTLS is only supported when connecting to the default metadata host. + # Raise an exception if we are in strict mode (which requires mTLS) + # but the metadata host has been overridden to a custom MDS. (which means mTLS will fail) + if _GCE_METADATA_HOST not in _GCE_DEFAULT_MDS_HOSTS: raise exceptions.MutualTLSChannelError( "Mutual TLS is required, but the metadata host has been overridden. " "mTLS is only supported when connecting to the default metadata host." @@ -143,7 +143,7 @@ def detect_gce_residency_linux(): return content.startswith(_GOOGLE) -def _prepare_request_for_mds(request, use_mtls=False): +def _prepare_request_for_mds(request, use_mtls=False) -> None: """Prepares a request for the metadata server. This will check if mTLS should be used and mount the mTLS adapter if needed. @@ -158,15 +158,16 @@ def _prepare_request_for_mds(request, use_mtls=False): If mTLS is enabled, the request will have the mTLS adapter mounted. Otherwise, the original request will be returned unchanged. """ - if not use_mtls: - return request + # Only modify the request if mTLS is enabled. + if use_mtls: + # Ensure the request has a session to mount the adapter to. + if not request.session: + request.session = requests.Session() - adapter = _mtls.MdsMtlsAdapter() - if not request.session: - request.session = requests.Session() - for host in _GCE_DEFAULT_MDS_HOSTS: - request.session.mount(f"https://{host}/", adapter) - return request + adapter = _mtls.MdsMtlsAdapter() + # Mount the adapter for all default GCE metadata hosts. + for host in _GCE_DEFAULT_MDS_HOSTS: + request.session.mount(f"https://{host}/", adapter) def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): @@ -183,7 +184,7 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): bool: True if the metadata server is reachable, False otherwise. """ use_mtls = _mtls.should_use_mds_mtls() - request = _prepare_request_for_mds(request, use_mtls=use_mtls) + _prepare_request_for_mds(request, use_mtls=use_mtls) # NOTE: The explicit ``timeout`` is a workaround. The underlying # issue is that resolving an unknown host on some networks will take # 20-30 seconds; making this timeout short fixes the issue, but @@ -270,14 +271,14 @@ def get( use_mtls = _mtls.should_use_mds_mtls() # Prepare the request object for mTLS if needed. # This will create a new request object with the mTLS session. - request = _prepare_request_for_mds(request, use_mtls=use_mtls) + _prepare_request_for_mds(request, use_mtls=use_mtls) if root is None: root = _get_metadata_root(use_mtls) # mTLS is only supported when connecting to the default metadata host. # If we are in strict mode (which requires mTLS), ensure that the metadata host - # has not been overridden (which means mTLS will fail). + # has not been overridden to a non-default host value (which means mTLS will fail). _validate_gce_mds_configured_environment() base_url = urljoin(root, path) diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index cd15ffe51..adb63f667 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -722,8 +722,8 @@ def test__prepare_request_for_mds_mtls(mock_mds_mtls_adapter): def test__prepare_request_for_mds_no_mtls(): request = mock.Mock() - new_request = _metadata._prepare_request_for_mds(request, use_mtls=False) - assert new_request is request + _metadata._prepare_request_for_mds(request, use_mtls=False) + request.session.mount.assert_not_called() @mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE) @@ -776,9 +776,11 @@ def test_get_mtls(mock_request, mock_should_use_mtls, mock_mds_mtls_adapter): "mds_mode, metadata_host, expect_exception", [ (_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_HOST, False), + (_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_MDS_IP, False), (_metadata._mtls.MdsMtlsMode.STRICT, "custom.host", True), (_metadata._mtls.MdsMtlsMode.NONE, "custom.host", False), (_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_HOST, False), + (_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_MDS_IP, False), ], ) @mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") @@ -801,11 +803,10 @@ def test_validate_gce_mds_configured_environment( def test__prepare_request_for_mds_mtls_session_exists(mock_mds_mtls_adapter): mock_session = mock.create_autospec(requests.Session) request = google_auth_requests.Request(mock_session) - new_request = _metadata._prepare_request_for_mds(request, use_mtls=True) + _metadata._prepare_request_for_mds(request, use_mtls=True) mock_mds_mtls_adapter.assert_called_once() assert mock_session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) - assert new_request is request @mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") @@ -815,11 +816,8 @@ def test__prepare_request_for_mds_mtls_no_session(mock_mds_mtls_adapter): request.session = None with mock.patch("requests.Session") as mock_session_class: - new_request = _metadata._prepare_request_for_mds(request, use_mtls=True) + _metadata._prepare_request_for_mds(request, use_mtls=True) mock_session_class.assert_called_once() mock_mds_mtls_adapter.assert_called_once() - assert new_request.session.mount.call_count == len( - _metadata._GCE_DEFAULT_MDS_HOSTS - ) - assert new_request is request + assert request.session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) From 9f4bbaa5880225bf9cb46333b8d2a895bc13ee50 Mon Sep 17 00:00:00 2001 From: Nolan Eastin Date: Wed, 19 Nov 2025 19:20:43 +0000 Subject: [PATCH 14/14] minor modification to docstring --- google/auth/compute_engine/_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index 969136ed2..96f1ff526 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -263,7 +263,7 @@ def get( Raises: google.auth.exceptions.TransportError: if an error occurred while retrieving metadata. - google.auth.exceptions.MutualTLSChannelError: if the environment + google.auth.exceptions.MutualTLSChannelError: if using mtls and the environment configuration is invalid for mTLS (for example, the metadata host has been overridden in strict mTLS mode).