Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 101 additions & 10 deletions google/auth/compute_engine/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,72 @@
import os
from urllib.parse import urljoin

import requests

from google.auth import _helpers
from google.auth import environment_vars
from google.auth import exceptions
from google.auth import metrics
from google.auth import transport
from google.auth._exponential_backoff import ExponentialBackoff
from google.auth.compute_engine import _mtls


_LOGGER = logging.getLogger(__name__)

_GCE_DEFAULT_MDS_IP = "169.254.169.254"
_GCE_DEFAULT_HOST = "metadata.google.internal"
_GCE_DEFAULT_MDS_HOSTS = [_GCE_DEFAULT_HOST, _GCE_DEFAULT_MDS_IP]

# Environment variable GCE_METADATA_HOST is originally named
# GCE_METADATA_ROOT. For compatibility reasons, here it checks
# the new variable first; if not set, the system falls back
# to the old variable.
_GCE_METADATA_HOST = os.getenv(environment_vars.GCE_METADATA_HOST, None)
if not _GCE_METADATA_HOST:
_GCE_METADATA_HOST = os.getenv(
environment_vars.GCE_METADATA_ROOT, "metadata.google.internal"
environment_vars.GCE_METADATA_ROOT, _GCE_DEFAULT_HOST
)


def _validate_gce_mds_configured_environment():
"""Validates the GCE metadata server environment configuration for mTLS.

mTLS is only supported when connecting to the default metadata server hosts.
If we are in strict mode (which requires mTLS), ensure that the metadata host
has not been overridden to a custom value (which means mTLS will fail).

Raises:
google.auth.exceptions.MutualTLSChannelError: if the environment
configuration is invalid for mTLS.
"""
mode = _mtls._parse_mds_mode()
if mode == _mtls.MdsMtlsMode.STRICT:
# mTLS is only supported when connecting to the default metadata host.
# Raise an exception if we are in strict mode (which requires mTLS)
# but the metadata host has been overridden to a custom MDS. (which means mTLS will fail)
if _GCE_METADATA_HOST not in _GCE_DEFAULT_MDS_HOSTS:
raise exceptions.MutualTLSChannelError(
"Mutual TLS is required, but the metadata host has been overridden. "
"mTLS is only supported when connecting to the default metadata host."
)


def _get_metadata_root(use_mtls: bool):
"""Returns the metadata server root URL."""

scheme = "https" if use_mtls else "http"
return "{}://{}/computeMetadata/v1/".format(scheme, _GCE_METADATA_HOST)


def _get_metadata_ip_root(use_mtls: bool):
"""Returns the metadata server IP root URL."""
scheme = "https" if use_mtls else "http"
return "{}://{}".format(
scheme, os.getenv(environment_vars.GCE_METADATA_IP, _GCE_DEFAULT_MDS_IP)
)
_METADATA_ROOT = "http://{}/computeMetadata/v1/".format(_GCE_METADATA_HOST)

# This is used to ping the metadata server, it avoids the cost of a DNS
# lookup.
_METADATA_IP_ROOT = "http://{}".format(
os.getenv(environment_vars.GCE_METADATA_IP, "169.254.169.254")
)

_METADATA_FLAVOR_HEADER = "metadata-flavor"
_METADATA_FLAVOR_VALUE = "Google"
_METADATA_HEADERS = {_METADATA_FLAVOR_HEADER: _METADATA_FLAVOR_VALUE}
Expand Down Expand Up @@ -102,6 +143,33 @@ def detect_gce_residency_linux():
return content.startswith(_GOOGLE)


def _prepare_request_for_mds(request, use_mtls=False) -> None:
"""Prepares a request for the metadata server.

This will check if mTLS should be used and mount the mTLS adapter if needed.

Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
use_mtls (bool): Whether to use mTLS for the request.

Returns:
google.auth.transport.Request: A request object to use.
If mTLS is enabled, the request will have the mTLS adapter mounted.
Otherwise, the original request will be returned unchanged.
"""
# Only modify the request if mTLS is enabled.
if use_mtls:
# Ensure the request has a session to mount the adapter to.
if not request.session:
request.session = requests.Session()

adapter = _mtls.MdsMtlsAdapter()
# Mount the adapter for all default GCE metadata hosts.
for host in _GCE_DEFAULT_MDS_HOSTS:
request.session.mount(f"https://{host}/", adapter)


def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
"""Checks to see if the metadata server is available.

Expand All @@ -115,6 +183,8 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
Returns:
bool: True if the metadata server is reachable, False otherwise.
"""
use_mtls = _mtls.should_use_mds_mtls()
_prepare_request_for_mds(request, use_mtls=use_mtls)
# NOTE: The explicit ``timeout`` is a workaround. The underlying
# issue is that resolving an unknown host on some networks will take
# 20-30 seconds; making this timeout short fixes the issue, but
Expand All @@ -129,7 +199,10 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
for attempt in backoff:
try:
response = request(
url=_METADATA_IP_ROOT, method="GET", headers=headers, timeout=timeout
url=_get_metadata_ip_root(use_mtls),
method="GET",
headers=headers,
timeout=timeout,
)

metadata_flavor = response.headers.get(_METADATA_FLAVOR_HEADER)
Expand All @@ -153,7 +226,7 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
def get(
request,
path,
root=_METADATA_ROOT,
root=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the docstring for this says root (str):. It looks like it should be root (Optional[str]):. And the docstring should mention what happens if its left as None

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

params=None,
recursive=False,
retry_count=5,
Expand All @@ -168,7 +241,8 @@ def get(
HTTP requests.
path (str): The resource to retrieve. For example,
``'instance/service-accounts/default'``.
root (str): The full path to the metadata server root.
root (Optional[str]): The full path to the metadata server root. If not
provided, the default root will be used.
params (Optional[Mapping[str, str]]): A mapping of query parameter
keys to values.
recursive (bool): Whether to do a recursive query of metadata. See
Expand All @@ -189,7 +263,24 @@ def get(
Raises:
google.auth.exceptions.TransportError: if an error occurred while
retrieving metadata.
google.auth.exceptions.MutualTLSChannelError: if using mtls and the environment
configuration is invalid for mTLS (for example, the metadata host
has been overridden in strict mTLS mode).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is only if mtls is requested and the configuration is invalid, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes that's right. updated that.


"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Raises doesn't mention MutualTLSChannelError

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

use_mtls = _mtls.should_use_mds_mtls()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we determine whether to use mtls or not at cred initialization time and put it in an attribute instead of determining each time?

Do we expect the value to change after cred initilization?

# Prepare the request object for mTLS if needed.
# This will create a new request object with the mTLS session.
_prepare_request_for_mds(request, use_mtls=use_mtls)

if root is None:
root = _get_metadata_root(use_mtls)

# mTLS is only supported when connecting to the default metadata host.
# If we are in strict mode (which requires mTLS), ensure that the metadata host
# has not been overridden to a non-default host value (which means mTLS will fail).
_validate_gce_mds_configured_environment()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe add a comment explaining what this does

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


base_url = urljoin(root, path)
query_params = {} if params is None else params

Expand Down
164 changes: 164 additions & 0 deletions google/auth/compute_engine/_mtls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
#
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Mutual TLS for Google Compute Engine metadata server."""

from dataclasses import dataclass, field
import enum
import logging
import os
from pathlib import Path
import ssl
from urllib.parse import urlparse, urlunparse

import requests
from requests.adapters import HTTPAdapter

from google.auth import environment_vars, exceptions


_LOGGER = logging.getLogger(__name__)

_WINDOWS_OS_NAME = "nt"

# MDS mTLS certificate paths based on OS.
# Documentation to well known locations can be found at:
# https://cloud.google.com/compute/docs/metadata/overview#https-mds-certificates
_WINDOWS_MTLS_COMPONENTS_BASE_PATH = Path("C:/ProgramData/Google/ComputeEngine")
_MTLS_COMPONENTS_BASE_PATH = Path("/run/google-mds-mtls")


def _get_mds_root_crt_path():
if os.name == _WINDOWS_OS_NAME:
return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt"
else:
return _MTLS_COMPONENTS_BASE_PATH / "root.crt"


def _get_mds_client_combined_cert_path():
if os.name == _WINDOWS_OS_NAME:
return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-client.key"
else:
return _MTLS_COMPONENTS_BASE_PATH / "client.key"


@dataclass
class MdsMtlsConfig:
ca_cert_path: Path = field(
default_factory=_get_mds_root_crt_path
) # path to CA certificate
client_combined_cert_path: Path = field(
default_factory=_get_mds_client_combined_cert_path
) # path to file containing client certificate and key
Copy link
Collaborator

@daniel-sanche daniel-sanche Nov 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are marked as str, but it looks like they are paths now?

Are you running nox -s mypy checks? I'm not sure if they're configured to run in CI

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I have been running mypy checks! Not sure why this isn't flagged though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either way, updated



def _certs_exist(mds_mtls_config: MdsMtlsConfig):
"""Checks if the mTLS certificates exist."""
return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists(
mds_mtls_config.client_combined_cert_path
)


class MdsMtlsMode(enum.Enum):
"""MDS mTLS mode. Used to configure connection behavior when connecting to MDS.

STRICT: Always use HTTPS/mTLS. If certificates are not found locally, an error will be returned.
NONE: Never use mTLS. Requests will use regular HTTP.
DEFAULT: Use mTLS if certificates are found locally, otherwise use regular HTTP.
"""

STRICT = "strict"
NONE = "none"
DEFAULT = "default"


def _parse_mds_mode():
"""Parses the GCE_METADATA_MTLS_MODE environment variable."""
mode_str = os.environ.get(
environment_vars.GCE_METADATA_MTLS_MODE, "default"
).lower()
try:
return MdsMtlsMode(mode_str)
except ValueError:
raise ValueError(
"Invalid value for GCE_METADATA_MTLS_MODE. Must be one of 'strict', 'none', or 'default'."
)


def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()):
"""Determines if mTLS should be used for the metadata server."""
mode = _parse_mds_mode()
if mode == MdsMtlsMode.STRICT:
if not _certs_exist(mds_mtls_config):
raise exceptions.MutualTLSChannelError(
"mTLS certificates not found in strict mode."
)
return True
elif mode == MdsMtlsMode.NONE:
return False
else: # Default mode
return _certs_exist(mds_mtls_config)


class MdsMtlsAdapter(HTTPAdapter):
"""An HTTP adapter that uses mTLS for the metadata server."""

def __init__(
self, mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig(), *args, **kwargs
):
self.ssl_context = ssl.create_default_context()
self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path)
self.ssl_context.load_cert_chain(
certfile=mds_mtls_config.client_combined_cert_path
)
super(MdsMtlsAdapter, self).__init__(*args, **kwargs)

def init_poolmanager(self, *args, **kwargs):
kwargs["ssl_context"] = self.ssl_context
return super(MdsMtlsAdapter, self).init_poolmanager(*args, **kwargs)

def proxy_manager_for(self, *args, **kwargs):
kwargs["ssl_context"] = self.ssl_context
return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs)

def send(self, request, **kwargs):
# If we are in strict mode, always use mTLS (no HTTP fallback)
if _parse_mds_mode() == MdsMtlsMode.STRICT:
return super(MdsMtlsAdapter, self).send(request, **kwargs)

# In default mode, attempt mTLS first, then fallback to HTTP on failure
try:
response = super(MdsMtlsAdapter, self).send(request, **kwargs)
response.raise_for_status()
return response
except (
ssl.SSLError,
requests.exceptions.SSLError,
requests.exceptions.HTTPError,
) as e:
_LOGGER.warning(
"mTLS connection to Compute Engine Metadata server failed. "
"Falling back to standard HTTP. Reason: %s",
e,
)
# Fallback to standard HTTP
parsed_original_url = urlparse(request.url)
http_fallback_url = urlunparse(parsed_original_url._replace(scheme="http"))
request.url = http_fallback_url

# Use a standard HTTPAdapter for the fallback
http_adapter = HTTPAdapter()
return http_adapter.send(request, **kwargs)
6 changes: 6 additions & 0 deletions google/auth/environment_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@
"""Environment variable providing an alternate ip:port to be used for ip-only
GCE metadata requests."""

GCE_METADATA_MTLS_MODE = "GCE_METADATA_MTLS_MODE"
"""Environment variable controlling the mTLS behavior for GCE metadata requests.

Can be one of "strict", "none", or "default".
"""

GOOGLE_API_USE_CLIENT_CERTIFICATE = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
"""Environment variable controlling whether to use client certificate or not.

Expand Down
Loading