diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index ddbe8ac2f..96f1ff526 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -24,15 +24,23 @@ import os from urllib.parse import urljoin +import requests + from google.auth import _helpers from google.auth import environment_vars from google.auth import exceptions from google.auth import metrics from google.auth import transport from google.auth._exponential_backoff import ExponentialBackoff +from google.auth.compute_engine import _mtls + _LOGGER = logging.getLogger(__name__) +_GCE_DEFAULT_MDS_IP = "169.254.169.254" +_GCE_DEFAULT_HOST = "metadata.google.internal" +_GCE_DEFAULT_MDS_HOSTS = [_GCE_DEFAULT_HOST, _GCE_DEFAULT_MDS_IP] + # Environment variable GCE_METADATA_HOST is originally named # GCE_METADATA_ROOT. For compatibility reasons, here it checks # the new variable first; if not set, the system falls back @@ -40,15 +48,48 @@ _GCE_METADATA_HOST = os.getenv(environment_vars.GCE_METADATA_HOST, None) if not _GCE_METADATA_HOST: _GCE_METADATA_HOST = os.getenv( - environment_vars.GCE_METADATA_ROOT, "metadata.google.internal" + environment_vars.GCE_METADATA_ROOT, _GCE_DEFAULT_HOST + ) + + +def _validate_gce_mds_configured_environment(): + """Validates the GCE metadata server environment configuration for mTLS. + + mTLS is only supported when connecting to the default metadata server hosts. + If we are in strict mode (which requires mTLS), ensure that the metadata host + has not been overridden to a custom value (which means mTLS will fail). + + Raises: + google.auth.exceptions.MutualTLSChannelError: if the environment + configuration is invalid for mTLS. + """ + mode = _mtls._parse_mds_mode() + if mode == _mtls.MdsMtlsMode.STRICT: + # mTLS is only supported when connecting to the default metadata host. + # Raise an exception if we are in strict mode (which requires mTLS) + # but the metadata host has been overridden to a custom MDS. (which means mTLS will fail) + if _GCE_METADATA_HOST not in _GCE_DEFAULT_MDS_HOSTS: + raise exceptions.MutualTLSChannelError( + "Mutual TLS is required, but the metadata host has been overridden. " + "mTLS is only supported when connecting to the default metadata host." + ) + + +def _get_metadata_root(use_mtls: bool): + """Returns the metadata server root URL.""" + + scheme = "https" if use_mtls else "http" + return "{}://{}/computeMetadata/v1/".format(scheme, _GCE_METADATA_HOST) + + +def _get_metadata_ip_root(use_mtls: bool): + """Returns the metadata server IP root URL.""" + scheme = "https" if use_mtls else "http" + return "{}://{}".format( + scheme, os.getenv(environment_vars.GCE_METADATA_IP, _GCE_DEFAULT_MDS_IP) ) -_METADATA_ROOT = "http://{}/computeMetadata/v1/".format(_GCE_METADATA_HOST) -# This is used to ping the metadata server, it avoids the cost of a DNS -# lookup. -_METADATA_IP_ROOT = "http://{}".format( - os.getenv(environment_vars.GCE_METADATA_IP, "169.254.169.254") -) + _METADATA_FLAVOR_HEADER = "metadata-flavor" _METADATA_FLAVOR_VALUE = "Google" _METADATA_HEADERS = {_METADATA_FLAVOR_HEADER: _METADATA_FLAVOR_VALUE} @@ -102,6 +143,33 @@ def detect_gce_residency_linux(): return content.startswith(_GOOGLE) +def _prepare_request_for_mds(request, use_mtls=False) -> None: + """Prepares a request for the metadata server. + + This will check if mTLS should be used and mount the mTLS adapter if needed. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + use_mtls (bool): Whether to use mTLS for the request. + + Returns: + google.auth.transport.Request: A request object to use. + If mTLS is enabled, the request will have the mTLS adapter mounted. + Otherwise, the original request will be returned unchanged. + """ + # Only modify the request if mTLS is enabled. + if use_mtls: + # Ensure the request has a session to mount the adapter to. + if not request.session: + request.session = requests.Session() + + adapter = _mtls.MdsMtlsAdapter() + # Mount the adapter for all default GCE metadata hosts. + for host in _GCE_DEFAULT_MDS_HOSTS: + request.session.mount(f"https://{host}/", adapter) + + def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): """Checks to see if the metadata server is available. @@ -115,6 +183,8 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): Returns: bool: True if the metadata server is reachable, False otherwise. """ + use_mtls = _mtls.should_use_mds_mtls() + _prepare_request_for_mds(request, use_mtls=use_mtls) # NOTE: The explicit ``timeout`` is a workaround. The underlying # issue is that resolving an unknown host on some networks will take # 20-30 seconds; making this timeout short fixes the issue, but @@ -129,7 +199,10 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): for attempt in backoff: try: response = request( - url=_METADATA_IP_ROOT, method="GET", headers=headers, timeout=timeout + url=_get_metadata_ip_root(use_mtls), + method="GET", + headers=headers, + timeout=timeout, ) metadata_flavor = response.headers.get(_METADATA_FLAVOR_HEADER) @@ -153,7 +226,7 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): def get( request, path, - root=_METADATA_ROOT, + root=None, params=None, recursive=False, retry_count=5, @@ -168,7 +241,8 @@ def get( HTTP requests. path (str): The resource to retrieve. For example, ``'instance/service-accounts/default'``. - root (str): The full path to the metadata server root. + root (Optional[str]): The full path to the metadata server root. If not + provided, the default root will be used. params (Optional[Mapping[str, str]]): A mapping of query parameter keys to values. recursive (bool): Whether to do a recursive query of metadata. See @@ -189,7 +263,24 @@ def get( Raises: google.auth.exceptions.TransportError: if an error occurred while retrieving metadata. + google.auth.exceptions.MutualTLSChannelError: if using mtls and the environment + configuration is invalid for mTLS (for example, the metadata host + has been overridden in strict mTLS mode). + """ + use_mtls = _mtls.should_use_mds_mtls() + # Prepare the request object for mTLS if needed. + # This will create a new request object with the mTLS session. + _prepare_request_for_mds(request, use_mtls=use_mtls) + + if root is None: + root = _get_metadata_root(use_mtls) + + # mTLS is only supported when connecting to the default metadata host. + # If we are in strict mode (which requires mTLS), ensure that the metadata host + # has not been overridden to a non-default host value (which means mTLS will fail). + _validate_gce_mds_configured_environment() + base_url = urljoin(root, path) query_params = {} if params is None else params diff --git a/google/auth/compute_engine/_mtls.py b/google/auth/compute_engine/_mtls.py new file mode 100644 index 000000000..6525dd03e --- /dev/null +++ b/google/auth/compute_engine/_mtls.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Mutual TLS for Google Compute Engine metadata server.""" + +from dataclasses import dataclass, field +import enum +import logging +import os +from pathlib import Path +import ssl +from urllib.parse import urlparse, urlunparse + +import requests +from requests.adapters import HTTPAdapter + +from google.auth import environment_vars, exceptions + + +_LOGGER = logging.getLogger(__name__) + +_WINDOWS_OS_NAME = "nt" + +# MDS mTLS certificate paths based on OS. +# Documentation to well known locations can be found at: +# https://cloud.google.com/compute/docs/metadata/overview#https-mds-certificates +_WINDOWS_MTLS_COMPONENTS_BASE_PATH = Path("C:/ProgramData/Google/ComputeEngine") +_MTLS_COMPONENTS_BASE_PATH = Path("/run/google-mds-mtls") + + +def _get_mds_root_crt_path(): + if os.name == _WINDOWS_OS_NAME: + return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt" + else: + return _MTLS_COMPONENTS_BASE_PATH / "root.crt" + + +def _get_mds_client_combined_cert_path(): + if os.name == _WINDOWS_OS_NAME: + return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-client.key" + else: + return _MTLS_COMPONENTS_BASE_PATH / "client.key" + + +@dataclass +class MdsMtlsConfig: + ca_cert_path: Path = field( + default_factory=_get_mds_root_crt_path + ) # path to CA certificate + client_combined_cert_path: Path = field( + default_factory=_get_mds_client_combined_cert_path + ) # path to file containing client certificate and key + + +def _certs_exist(mds_mtls_config: MdsMtlsConfig): + """Checks if the mTLS certificates exist.""" + return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists( + mds_mtls_config.client_combined_cert_path + ) + + +class MdsMtlsMode(enum.Enum): + """MDS mTLS mode. Used to configure connection behavior when connecting to MDS. + + STRICT: Always use HTTPS/mTLS. If certificates are not found locally, an error will be returned. + NONE: Never use mTLS. Requests will use regular HTTP. + DEFAULT: Use mTLS if certificates are found locally, otherwise use regular HTTP. + """ + + STRICT = "strict" + NONE = "none" + DEFAULT = "default" + + +def _parse_mds_mode(): + """Parses the GCE_METADATA_MTLS_MODE environment variable.""" + mode_str = os.environ.get( + environment_vars.GCE_METADATA_MTLS_MODE, "default" + ).lower() + try: + return MdsMtlsMode(mode_str) + except ValueError: + raise ValueError( + "Invalid value for GCE_METADATA_MTLS_MODE. Must be one of 'strict', 'none', or 'default'." + ) + + +def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()): + """Determines if mTLS should be used for the metadata server.""" + mode = _parse_mds_mode() + if mode == MdsMtlsMode.STRICT: + if not _certs_exist(mds_mtls_config): + raise exceptions.MutualTLSChannelError( + "mTLS certificates not found in strict mode." + ) + return True + elif mode == MdsMtlsMode.NONE: + return False + else: # Default mode + return _certs_exist(mds_mtls_config) + + +class MdsMtlsAdapter(HTTPAdapter): + """An HTTP adapter that uses mTLS for the metadata server.""" + + def __init__( + self, mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig(), *args, **kwargs + ): + self.ssl_context = ssl.create_default_context() + self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path) + self.ssl_context.load_cert_chain( + certfile=mds_mtls_config.client_combined_cert_path + ) + super(MdsMtlsAdapter, self).__init__(*args, **kwargs) + + def init_poolmanager(self, *args, **kwargs): + kwargs["ssl_context"] = self.ssl_context + return super(MdsMtlsAdapter, self).init_poolmanager(*args, **kwargs) + + def proxy_manager_for(self, *args, **kwargs): + kwargs["ssl_context"] = self.ssl_context + return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs) + + def send(self, request, **kwargs): + # If we are in strict mode, always use mTLS (no HTTP fallback) + if _parse_mds_mode() == MdsMtlsMode.STRICT: + return super(MdsMtlsAdapter, self).send(request, **kwargs) + + # In default mode, attempt mTLS first, then fallback to HTTP on failure + try: + response = super(MdsMtlsAdapter, self).send(request, **kwargs) + response.raise_for_status() + return response + except ( + ssl.SSLError, + requests.exceptions.SSLError, + requests.exceptions.HTTPError, + ) as e: + _LOGGER.warning( + "mTLS connection to Compute Engine Metadata server failed. " + "Falling back to standard HTTP. Reason: %s", + e, + ) + # Fallback to standard HTTP + parsed_original_url = urlparse(request.url) + http_fallback_url = urlunparse(parsed_original_url._replace(scheme="http")) + request.url = http_fallback_url + + # Use a standard HTTPAdapter for the fallback + http_adapter = HTTPAdapter() + return http_adapter.send(request, **kwargs) diff --git a/google/auth/environment_vars.py b/google/auth/environment_vars.py index e5f3598e8..5da3a7382 100644 --- a/google/auth/environment_vars.py +++ b/google/auth/environment_vars.py @@ -60,6 +60,12 @@ """Environment variable providing an alternate ip:port to be used for ip-only GCE metadata requests.""" +GCE_METADATA_MTLS_MODE = "GCE_METADATA_MTLS_MODE" +"""Environment variable controlling the mTLS behavior for GCE metadata requests. + +Can be one of "strict", "none", or "default". +""" + GOOGLE_API_USE_CLIENT_CERTIFICATE = "GOOGLE_API_USE_CLIENT_CERTIFICATE" """Environment variable controlling whether to use client certificate or not. diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index c90bc603a..adb63f667 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -20,12 +20,14 @@ import mock import pytest # type: ignore +import requests from google.auth import _helpers from google.auth import environment_vars from google.auth import exceptions from google.auth import transport from google.auth.compute_engine import _metadata +from google.auth.transport import requests as google_auth_requests PATH = "instance/service-accounts/default" @@ -104,7 +106,7 @@ def test_ping_success(mock_metrics_header_value): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_IP_ROOT, + url="http://169.254.169.254", headers=MDS_PING_REQUEST_HEADER, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -118,7 +120,7 @@ def test_ping_success_retry(mock_metrics_header_value): request.assert_called_with( method="GET", - url=_metadata._METADATA_IP_ROOT, + url="http://169.254.169.254", headers=MDS_PING_REQUEST_HEADER, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -172,7 +174,7 @@ def test_get_success_json(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -191,7 +193,7 @@ def test_get_success_json_content_type_charset(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -211,7 +213,7 @@ def test_get_success_retry(mock_sleep): request.assert_called_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -227,7 +229,7 @@ def test_get_success_text(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -243,7 +245,9 @@ def test_get_success_params(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "?recursive=true", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "?recursive=true", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -258,7 +262,9 @@ def test_get_success_recursive_and_params(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "?recursive=true", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "?recursive=true", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -273,7 +279,9 @@ def test_get_success_recursive(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "?recursive=true", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "?recursive=true", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -322,6 +330,21 @@ def test_get_success_custom_root_old_variable(): ) +def test_get_success_custom_root(): + request = make_request("{}", headers={"content-type": "application/json"}) + + fake_root = "http://another.metadata.service" + + _metadata.get(request, PATH, root=fake_root) + + request.assert_called_once_with( + method="GET", + url="{}/{}".format(fake_root, PATH), + headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, + ) + + @mock.patch("time.sleep", return_value=None) def test_get_failure(mock_sleep): request = make_request("Metadata error", status=http_client.NOT_FOUND) @@ -333,7 +356,7 @@ def test_get_failure(mock_sleep): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -346,7 +369,7 @@ def test_get_return_none_for_not_found_error(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -366,7 +389,7 @@ def test_get_failure_connection_failed(mock_sleep): request.assert_called_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -385,7 +408,7 @@ def test_get_too_many_requests_retryable_error_failure(): request.assert_called_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -402,7 +425,7 @@ def test_get_failure_bad_json(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -416,7 +439,7 @@ def test_get_project_id(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "project/project-id", + url="http://metadata.google.internal/computeMetadata/v1/project/project-id", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -432,7 +455,7 @@ def test_get_universe_domain_success(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -446,7 +469,7 @@ def test_get_universe_domain_success_empty_response(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -462,7 +485,7 @@ def test_get_universe_domain_not_found(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -483,7 +506,7 @@ def test_get_universe_domain_retryable_error_failure(): request.assert_called_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -526,13 +549,13 @@ def request(self, *args, **kwargs): request_error.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) request_ok.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -552,7 +575,7 @@ def test_get_universe_domain_other_error(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -574,7 +597,7 @@ def test_get_service_account_token(utcnow, mock_metrics_header_value): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "/token", + url="http://metadata.google.internal/computeMetadata/v1/" + PATH + "/token", headers={ "metadata-flavor": "Google", "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, @@ -601,7 +624,10 @@ def test_get_service_account_token_with_scopes_list(utcnow, mock_metrics_header_ request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "/token" + "?scopes=foo%2Cbar", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "/token" + + "?scopes=foo%2Cbar", headers={ "metadata-flavor": "Google", "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, @@ -630,7 +656,10 @@ def test_get_service_account_token_with_scopes_string( request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "/token" + "?scopes=foo%2Cbar", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "/token" + + "?scopes=foo%2Cbar", headers={ "metadata-flavor": "Google", "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, @@ -651,9 +680,144 @@ def test_get_service_account_info(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "/?recursive=true", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "/?recursive=true", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert info[key] == value + + +def test__get_metadata_root_mtls(): + assert ( + _metadata._get_metadata_root(use_mtls=True) + == "https://metadata.google.internal/computeMetadata/v1/" + ) + + +def test__get_metadata_root_no_mtls(): + assert ( + _metadata._get_metadata_root(use_mtls=False) + == "http://metadata.google.internal/computeMetadata/v1/" + ) + + +def test__get_metadata_ip_root_mtls(): + assert _metadata._get_metadata_ip_root(use_mtls=True) == "https://169.254.169.254" + + +def test__get_metadata_ip_root_no_mtls(): + assert _metadata._get_metadata_ip_root(use_mtls=False) == "http://169.254.169.254" + + +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +def test__prepare_request_for_mds_mtls(mock_mds_mtls_adapter): + request = google_auth_requests.Request(mock.create_autospec(requests.Session)) + _metadata._prepare_request_for_mds(request, use_mtls=True) + mock_mds_mtls_adapter.assert_called_once() + assert request.session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) + + +def test__prepare_request_for_mds_no_mtls(): + request = mock.Mock() + _metadata._prepare_request_for_mds(request, use_mtls=False) + request.session.mount.assert_not_called() + + +@mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE) +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) +@mock.patch("google.auth.transport.requests.Request") +def test_ping_mtls( + mock_request, mock_should_use_mtls, mock_mds_mtls_adapter, mock_metrics_header_value +): + response = mock.create_autospec(transport.Response, instance=True) + response.status = http_client.OK + response.headers = _metadata._METADATA_HEADERS + mock_request.return_value = response + + assert _metadata.ping(mock_request) + + mock_should_use_mtls.assert_called_once() + mock_mds_mtls_adapter.assert_called_once() + mock_request.assert_called_once_with( + url="https://169.254.169.254", + method="GET", + headers=MDS_PING_REQUEST_HEADER, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, + ) + + +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) +@mock.patch("google.auth.transport.requests.Request") +def test_get_mtls(mock_request, mock_should_use_mtls, mock_mds_mtls_adapter): + response = mock.create_autospec(transport.Response, instance=True) + response.status = http_client.OK + response.data = _helpers.to_bytes("{}") + response.headers = {"content-type": "application/json"} + mock_request.return_value = response + + _metadata.get(mock_request, "some/path") + + mock_should_use_mtls.assert_called_once() + mock_mds_mtls_adapter.assert_called_once() + mock_request.assert_called_once_with( + url="https://metadata.google.internal/computeMetadata/v1/some/path", + method="GET", + headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, + ) + + +@pytest.mark.parametrize( + "mds_mode, metadata_host, expect_exception", + [ + (_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_HOST, False), + (_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_MDS_IP, False), + (_metadata._mtls.MdsMtlsMode.STRICT, "custom.host", True), + (_metadata._mtls.MdsMtlsMode.NONE, "custom.host", False), + (_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_HOST, False), + (_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_MDS_IP, False), + ], +) +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +def test_validate_gce_mds_configured_environment( + mock_parse_mds_mode, mds_mode, metadata_host, expect_exception +): + mock_parse_mds_mode.return_value = mds_mode + with mock.patch( + "google.auth.compute_engine._metadata._GCE_METADATA_HOST", new=metadata_host + ): + if expect_exception: + with pytest.raises(exceptions.MutualTLSChannelError): + _metadata._validate_gce_mds_configured_environment() + else: + _metadata._validate_gce_mds_configured_environment() + mock_parse_mds_mode.assert_called_once() + + +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +def test__prepare_request_for_mds_mtls_session_exists(mock_mds_mtls_adapter): + mock_session = mock.create_autospec(requests.Session) + request = google_auth_requests.Request(mock_session) + _metadata._prepare_request_for_mds(request, use_mtls=True) + + mock_mds_mtls_adapter.assert_called_once() + assert mock_session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) + + +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +def test__prepare_request_for_mds_mtls_no_session(mock_mds_mtls_adapter): + request = google_auth_requests.Request(None) + # Explicitly set session to None to avoid a session being created in the Request constructor. + request.session = None + + with mock.patch("requests.Session") as mock_session_class: + _metadata._prepare_request_for_mds(request, use_mtls=True) + + mock_session_class.assert_called_once() + mock_mds_mtls_adapter.assert_called_once() + assert request.session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) diff --git a/tests/compute_engine/test__mtls.py b/tests/compute_engine/test__mtls.py new file mode 100644 index 000000000..fdd61a07d --- /dev/null +++ b/tests/compute_engine/test__mtls.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pathlib import Path + +import mock +import pytest # type: ignore +import requests + +from google.auth import environment_vars, exceptions +from google.auth.compute_engine import _mtls + + +@pytest.fixture +def mock_mds_mtls_config(): + return _mtls.MdsMtlsConfig( + ca_cert_path=Path("/fake/ca.crt"), + client_combined_cert_path=Path("/fake/client.key"), + ) + + +@mock.patch("os.name", "nt") +def test__MdsMtlsConfig_windows_defaults(): + config = _mtls.MdsMtlsConfig() + assert ( + str(config.ca_cert_path) + == "C:/ProgramData/Google/ComputeEngine/mds-mtls-root.crt" + ) + assert ( + str(config.client_combined_cert_path) + == "C:/ProgramData/Google/ComputeEngine/mds-mtls-client.key" + ) + + +@mock.patch("os.name", "posix") +def test__MdsMtlsConfig_non_windows_defaults(): + config = _mtls.MdsMtlsConfig() + assert str(config.ca_cert_path) == "/run/google-mds-mtls/root.crt" + assert str(config.client_combined_cert_path) == "/run/google-mds-mtls/client.key" + + +def test__parse_mds_mode_default(monkeypatch): + monkeypatch.delenv(environment_vars.GCE_METADATA_MTLS_MODE, raising=False) + assert _mtls._parse_mds_mode() == _mtls.MdsMtlsMode.DEFAULT + + +@pytest.mark.parametrize( + "mode_str, expected_mode", + [ + ("strict", _mtls.MdsMtlsMode.STRICT), + ("none", _mtls.MdsMtlsMode.NONE), + ("default", _mtls.MdsMtlsMode.DEFAULT), + ("STRICT", _mtls.MdsMtlsMode.STRICT), + ], +) +def test__parse_mds_mode_valid(monkeypatch, mode_str, expected_mode): + monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, mode_str) + assert _mtls._parse_mds_mode() == expected_mode + + +def test__parse_mds_mode_invalid(monkeypatch): + monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, "invalid_mode") + with pytest.raises(ValueError): + _mtls._parse_mds_mode() + + +@mock.patch("os.path.exists") +def test__certs_exist_true(mock_exists, mock_mds_mtls_config): + mock_exists.return_value = True + assert _mtls._certs_exist(mock_mds_mtls_config) is True + + +@mock.patch("os.path.exists") +def test__certs_exist_false(mock_exists, mock_mds_mtls_config): + mock_exists.return_value = False + assert _mtls._certs_exist(mock_mds_mtls_config) is False + + +@pytest.mark.parametrize( + "mtls_mode, certs_exist, expected_result", + [ + ("strict", True, True), + ("strict", False, exceptions.MutualTLSChannelError), + ("none", True, False), + ("none", False, False), + ("default", True, True), + ("default", False, False), + ], +) +@mock.patch("os.path.exists") +def test_should_use_mds_mtls( + mock_exists, monkeypatch, mtls_mode, certs_exist, expected_result +): + monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, mtls_mode) + mock_exists.return_value = certs_exist + + if isinstance(expected_result, type) and issubclass(expected_result, Exception): + with pytest.raises(expected_result): + _mtls.should_use_mds_mtls() + else: + assert _mtls.should_use_mds_mtls() is expected_result + + +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_init(mock_ssl_context, mock_mds_mtls_config): + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + mock_ssl_context.assert_called_once() + adapter.ssl_context.load_verify_locations.assert_called_once_with( + cafile=mock_mds_mtls_config.ca_cert_path + ) + adapter.ssl_context.load_cert_chain.assert_called_once_with( + certfile=mock_mds_mtls_config.client_combined_cert_path + ) + + +@mock.patch("ssl.create_default_context") +@mock.patch("requests.adapters.HTTPAdapter.init_poolmanager") +def test_mds_mtls_adapter_init_poolmanager( + mock_init_poolmanager, mock_ssl_context, mock_mds_mtls_config +): + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + mock_init_poolmanager.assert_called_with( + 10, 10, block=False, ssl_context=adapter.ssl_context + ) + + +@mock.patch("ssl.create_default_context") +@mock.patch("requests.adapters.HTTPAdapter.proxy_manager_for") +def test_mds_mtls_adapter_proxy_manager_for( + mock_proxy_manager_for, mock_ssl_context, mock_mds_mtls_config +): + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + adapter.proxy_manager_for("test_proxy") + mock_proxy_manager_for.assert_called_once_with( + "test_proxy", ssl_context=adapter.ssl_context + ) + + +@mock.patch("requests.adapters.HTTPAdapter.send") # Patch the PARENT class method +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_session_request( + mock_ssl_context, mock_super_send, mock_mds_mtls_config +): + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + session = requests.Session() + session.mount("https://", adapter) + + # Setup the parent class send return value + response = requests.Response() + response.status_code = 200 + mock_super_send.return_value = response + + response = session.get("https://fake-mds.com") + + # Assert that the request was successful + assert response.status_code == 200 + mock_super_send.assert_called_once() + + +@mock.patch("requests.adapters.HTTPAdapter.send") +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_success( + mock_ssl_context, mock_parse_mds_mode, mock_super_send, mock_mds_mtls_config +): + """Test the explicit 'happy path' where mTLS succeeds without error.""" + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + # Setup the parent class send return value to be successful (200 OK) + mock_response = requests.Response() + mock_response.status_code = 200 + mock_super_send.return_value = mock_response + + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + + # Call send directly + response = adapter.send(request) + + # Verify we got the response back and no fallback happened + assert response == mock_response + mock_super_send.assert_called_once() + + +@mock.patch("google.auth.compute_engine._mtls.HTTPAdapter") +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_fallback_default_mode( + mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_class, mock_mds_mtls_config +): + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + mock_fallback_send = mock.Mock() + mock_http_adapter_class.return_value.send = mock_fallback_send + + # Simulate SSLError on the super().send() call + with mock.patch( + "requests.adapters.HTTPAdapter.send", side_effect=requests.exceptions.SSLError + ): + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + adapter.send(request) + + # Check that fallback to HTTPAdapter.send occurred + mock_http_adapter_class.assert_called_once() + mock_fallback_send.assert_called_once() + fallback_request = mock_fallback_send.call_args[0][0] + assert fallback_request.url == "http://fake-mds.com/" + + +@mock.patch("google.auth.compute_engine._mtls.HTTPAdapter") +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_fallback_http_error( + mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_class, mock_mds_mtls_config +): + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + mock_fallback_send = mock.Mock() + mock_http_adapter_class.return_value.send = mock_fallback_send + + # Simulate HTTPError on the super().send() call + mock_mtls_response = requests.Response() + mock_mtls_response.status_code = 404 + with mock.patch( + "requests.adapters.HTTPAdapter.send", return_value=mock_mtls_response + ): + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + adapter.send(request) + + # Check that fallback to HTTPAdapter.send occurred + mock_http_adapter_class.assert_called_once() + mock_fallback_send.assert_called_once() + fallback_request = mock_fallback_send.call_args[0][0] + assert fallback_request.url == "http://fake-mds.com/" + + +@mock.patch("requests.adapters.HTTPAdapter.send") +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_no_fallback_other_exception( + mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_send, mock_mds_mtls_config +): + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + # Simulate HTTP exception + with mock.patch( + "requests.adapters.HTTPAdapter.send", + side_effect=requests.exceptions.ConnectionError, + ): + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + with pytest.raises(requests.exceptions.ConnectionError): + adapter.send(request) + + mock_http_adapter_send.assert_not_called() + + +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_no_fallback_strict_mode( + mock_ssl_context, mock_parse_mds_mode, mock_mds_mtls_config +): + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.STRICT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + # Simulate SSLError on the super().send() call + with mock.patch( + "requests.adapters.HTTPAdapter.send", side_effect=requests.exceptions.SSLError + ): + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + with pytest.raises(requests.exceptions.SSLError): + adapter.send(request)