-
Notifications
You must be signed in to change notification settings - Fork 346
feat: MDS connections use mTLS #1856
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
02657fa
b49aad1
a6a2cc3
6e00ea4
df1378e
e3a311a
5bbac84
c34a17a
b1d7021
e01ea09
c3caa16
e0b31e8
eaf8435
9f4bbaa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,31 +24,72 @@ | |
| import os | ||
| from urllib.parse import urljoin | ||
|
|
||
| import requests | ||
|
|
||
| from google.auth import _helpers | ||
| from google.auth import environment_vars | ||
| from google.auth import exceptions | ||
| from google.auth import metrics | ||
| from google.auth import transport | ||
| from google.auth._exponential_backoff import ExponentialBackoff | ||
| from google.auth.compute_engine import _mtls | ||
|
|
||
|
|
||
| _LOGGER = logging.getLogger(__name__) | ||
|
|
||
| _GCE_DEFAULT_MDS_IP = "169.254.169.254" | ||
| _GCE_DEFAULT_HOST = "metadata.google.internal" | ||
| _GCE_DEFAULT_MDS_HOSTS = [_GCE_DEFAULT_HOST, _GCE_DEFAULT_MDS_IP] | ||
|
|
||
| # Environment variable GCE_METADATA_HOST is originally named | ||
| # GCE_METADATA_ROOT. For compatibility reasons, here it checks | ||
| # the new variable first; if not set, the system falls back | ||
| # to the old variable. | ||
| _GCE_METADATA_HOST = os.getenv(environment_vars.GCE_METADATA_HOST, None) | ||
| if not _GCE_METADATA_HOST: | ||
| _GCE_METADATA_HOST = os.getenv( | ||
| environment_vars.GCE_METADATA_ROOT, "metadata.google.internal" | ||
| environment_vars.GCE_METADATA_ROOT, _GCE_DEFAULT_HOST | ||
| ) | ||
|
|
||
|
|
||
def _validate_gce_mds_configured_environment():
    """Checks that the configured metadata host is compatible with strict mTLS.

    mTLS is only supported for the default metadata server hosts. When the
    mTLS mode is strict (mTLS is mandatory), a metadata host overridden to any
    other value means mTLS will fail, so this raises up front. In every other
    mode this function does nothing.

    Raises:
        google.auth.exceptions.MutualTLSChannelError: if strict mode is
            configured and the metadata host is not one of the default hosts.
    """
    if _mtls._parse_mds_mode() != _mtls.MdsMtlsMode.STRICT:
        # mTLS is not mandatory, so an overridden host is acceptable.
        return

    host_is_default = _GCE_METADATA_HOST in _GCE_DEFAULT_MDS_HOSTS
    if not host_is_default:
        raise exceptions.MutualTLSChannelError(
            "Mutual TLS is required, but the metadata host has been overridden. "
            "mTLS is only supported when connecting to the default metadata host."
        )
|
|
||
|
|
||
def _get_metadata_root(use_mtls: bool):
    """Builds the root URL of the metadata server's ``computeMetadata/v1/`` API.

    Args:
        use_mtls (bool): Whether mTLS is in use. Selects the ``https`` scheme
            when true and ``http`` otherwise.

    Returns:
        str: The root URL, built from the configured metadata host.
    """
    if use_mtls:
        protocol = "https"
    else:
        protocol = "http"
    return f"{protocol}://{_GCE_METADATA_HOST}/computeMetadata/v1/"
|
|
||
|
|
||
def _get_metadata_ip_root(use_mtls: bool):
    """Builds the IP-based root URL of the metadata server.

    The address is read from the ``GCE_METADATA_IP`` environment variable and
    falls back to the default metadata server IP when it is not set.

    Args:
        use_mtls (bool): Whether mTLS is in use. Selects the ``https`` scheme
            when true and ``http`` otherwise.

    Returns:
        str: The URL in the form ``<scheme>://<ip>``.
    """
    metadata_ip = os.getenv(environment_vars.GCE_METADATA_IP, _GCE_DEFAULT_MDS_IP)
    protocol = "https" if use_mtls else "http"
    return f"{protocol}://{metadata_ip}"
| _METADATA_ROOT = "http://{}/computeMetadata/v1/".format(_GCE_METADATA_HOST) | ||
|
|
||
| # This is used to ping the metadata server, it avoids the cost of a DNS | ||
| # lookup. | ||
| _METADATA_IP_ROOT = "http://{}".format( | ||
| os.getenv(environment_vars.GCE_METADATA_IP, "169.254.169.254") | ||
| ) | ||
|
|
||
| _METADATA_FLAVOR_HEADER = "metadata-flavor" | ||
| _METADATA_FLAVOR_VALUE = "Google" | ||
| _METADATA_HEADERS = {_METADATA_FLAVOR_HEADER: _METADATA_FLAVOR_VALUE} | ||
|
|
@@ -102,6 +143,33 @@ def detect_gce_residency_linux(): | |
| return content.startswith(_GOOGLE) | ||
|
|
||
|
|
||
def _prepare_request_for_mds(request, use_mtls=False) -> None:
    """Prepares a request for the metadata server.

    When mTLS is enabled, this mounts the mTLS adapter on the request's
    session for every default GCE metadata host, creating a session first if
    the request does not have one. The request is modified in place. When
    mTLS is disabled, the request is left untouched.

    Args:
        request (google.auth.transport.Request): A callable used to make
            HTTP requests. It must expose a ``session`` attribute when
            ``use_mtls`` is true.
        use_mtls (bool): Whether to use mTLS for the request.

    Returns:
        None: The request is updated in place; nothing is returned.
    """
    # Only modify the request if mTLS is enabled.
    if not use_mtls:
        return

    # Ensure the request has a session to mount the adapter to.
    if not request.session:
        request.session = requests.Session()

    adapter = _mtls.MdsMtlsAdapter()
    # Mount the adapter for all default GCE metadata hosts.
    for host in _GCE_DEFAULT_MDS_HOSTS:
        request.session.mount(f"https://{host}/", adapter)
|
|
||
|
|
||
| def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): | ||
| """Checks to see if the metadata server is available. | ||
|
|
||
|
|
@@ -115,6 +183,8 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): | |
| Returns: | ||
| bool: True if the metadata server is reachable, False otherwise. | ||
| """ | ||
| use_mtls = _mtls.should_use_mds_mtls() | ||
| _prepare_request_for_mds(request, use_mtls=use_mtls) | ||
| # NOTE: The explicit ``timeout`` is a workaround. The underlying | ||
| # issue is that resolving an unknown host on some networks will take | ||
| # 20-30 seconds; making this timeout short fixes the issue, but | ||
|
|
@@ -129,7 +199,10 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): | |
| for attempt in backoff: | ||
| try: | ||
| response = request( | ||
| url=_METADATA_IP_ROOT, method="GET", headers=headers, timeout=timeout | ||
| url=_get_metadata_ip_root(use_mtls), | ||
| method="GET", | ||
| headers=headers, | ||
| timeout=timeout, | ||
| ) | ||
|
|
||
| metadata_flavor = response.headers.get(_METADATA_FLAVOR_HEADER) | ||
|
|
@@ -153,7 +226,7 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): | |
| def get( | ||
| request, | ||
| path, | ||
| root=_METADATA_ROOT, | ||
| root=None, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the docstring for this says
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
| params=None, | ||
| recursive=False, | ||
| retry_count=5, | ||
|
|
@@ -168,7 +241,8 @@ def get( | |
| HTTP requests. | ||
| path (str): The resource to retrieve. For example, | ||
| ``'instance/service-accounts/default'``. | ||
| root (str): The full path to the metadata server root. | ||
| root (Optional[str]): The full path to the metadata server root. If not | ||
| provided, the default root will be used. | ||
| params (Optional[Mapping[str, str]]): A mapping of query parameter | ||
| keys to values. | ||
| recursive (bool): Whether to do a recursive query of metadata. See | ||
|
|
@@ -189,7 +263,24 @@ def get( | |
| Raises: | ||
| google.auth.exceptions.TransportError: if an error occurred while | ||
| retrieving metadata. | ||
| google.auth.exceptions.MutualTLSChannelError: if using mtls and the environment | ||
| configuration is invalid for mTLS (for example, the metadata host | ||
| has been overridden in strict mTLS mode). | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: this is only if mtls is requested and the configuration is invalid, right?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes that's right. updated that. |
||
|
|
||
| """ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Raises doesn't mention MutualTLSChannelError
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
| use_mtls = _mtls.should_use_mds_mtls() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we determine whether to use mTLS at credential initialization time and store it in an attribute, instead of determining it on every call? Do we expect the value to change after credential initialization? |
||
| # Prepare the request object for mTLS if needed. | ||
| # This will create a new request object with the mTLS session. | ||
| _prepare_request_for_mds(request, use_mtls=use_mtls) | ||
|
|
||
| if root is None: | ||
| root = _get_metadata_root(use_mtls) | ||
|
|
||
| # mTLS is only supported when connecting to the default metadata host. | ||
| # If we are in strict mode (which requires mTLS), ensure that the metadata host | ||
| # has not been overridden to a non-default host value (which means mTLS will fail). | ||
| _validate_gce_mds_configured_environment() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: maybe add a comment explaining what this does
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
|
||
| base_url = urljoin(root, path) | ||
| query_params = {} if params is None else params | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| # -*- coding: utf-8 -*- | ||
| # | ||
| # Copyright 2024 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
| """Mutual TLS for Google Compute Engine metadata server.""" | ||
|
|
||
from dataclasses import dataclass, field
import enum
import logging
import os
from pathlib import Path
import ssl
from typing import Optional
from urllib.parse import urlparse, urlunparse

import requests
from requests.adapters import HTTPAdapter

from google.auth import environment_vars, exceptions
|
|
||
|
|
||
| _LOGGER = logging.getLogger(__name__) | ||
|
|
||
| _WINDOWS_OS_NAME = "nt" | ||
|
|
||
| # MDS mTLS certificate paths based on OS. | ||
| # Documentation to well known locations can be found at: | ||
| # https://cloud.google.com/compute/docs/metadata/overview#https-mds-certificates | ||
| _WINDOWS_MTLS_COMPONENTS_BASE_PATH = Path("C:/ProgramData/Google/ComputeEngine") | ||
| _MTLS_COMPONENTS_BASE_PATH = Path("/run/google-mds-mtls") | ||
|
|
||
|
|
||
def _get_mds_root_crt_path():
    """Returns the well-known location of the MDS mTLS root certificate.

    Returns:
        pathlib.Path: The Windows location when ``os.name`` is the Windows
            OS name, otherwise the location under the default base path.
    """
    on_windows = os.name == _WINDOWS_OS_NAME
    if on_windows:
        return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt"
    return _MTLS_COMPONENTS_BASE_PATH / "root.crt"
|
|
||
|
|
||
def _get_mds_client_combined_cert_path():
    """Returns the well-known location of the MDS mTLS client credentials file.

    Returns:
        pathlib.Path: The Windows location when ``os.name`` is the Windows
            OS name, otherwise the location under the default base path.
    """
    on_windows = os.name == _WINDOWS_OS_NAME
    if on_windows:
        return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-client.key"
    return _MTLS_COMPONENTS_BASE_PATH / "client.key"
|
|
||
|
|
||
@dataclass
class MdsMtlsConfig:
    """Locations of the certificate files used for mTLS with the metadata server.

    Both fields default to the OS-specific well-known paths, resolved when an
    instance is created.

    Attributes:
        ca_cert_path (pathlib.Path): Path to the CA certificate.
        client_combined_cert_path (pathlib.Path): Path to the file containing
            the client certificate and key.
    """

    ca_cert_path: Path = field(default_factory=_get_mds_root_crt_path)
    client_combined_cert_path: Path = field(
        default_factory=_get_mds_client_combined_cert_path
    )
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these are marked as str, but it looks like they are paths now? Are you running
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes I have been running mypy checks! Not sure why this isn't flagged though.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Either way, updated |
||
|
|
||
|
|
||
def _certs_exist(mds_mtls_config: MdsMtlsConfig):
    """Reports whether every mTLS certificate file in the config is present.

    Args:
        mds_mtls_config (MdsMtlsConfig): The certificate locations to check.

    Returns:
        bool: True only if both the CA certificate and the combined client
            certificate/key file exist on disk.
    """
    required_paths = (
        mds_mtls_config.ca_cert_path,
        mds_mtls_config.client_combined_cert_path,
    )
    return all(os.path.exists(cert_path) for cert_path in required_paths)
|
|
||
|
|
||
class MdsMtlsMode(enum.Enum):
    """Connection policy used when talking to the metadata server (MDS).

    Members:
        STRICT: Always use HTTPS/mTLS. If certificates are not found locally,
            an error is raised.
        NONE: Never use mTLS. Requests use regular HTTP.
        DEFAULT: Use mTLS if certificates are found locally, otherwise use
            regular HTTP.
    """

    STRICT = "strict"
    NONE = "none"
    DEFAULT = "default"
|
|
||
|
|
||
def _parse_mds_mode():
    """Reads the mTLS mode from the GCE_METADATA_MTLS_MODE environment variable.

    The value is matched case-insensitively; when the variable is not set the
    mode is ``default``.

    Returns:
        MdsMtlsMode: The configured mode.

    Raises:
        ValueError: if the variable holds anything other than ``strict``,
            ``none`` or ``default``.
    """
    raw_mode = os.environ.get(environment_vars.GCE_METADATA_MTLS_MODE, "default")
    normalized_mode = raw_mode.lower()
    try:
        return MdsMtlsMode(normalized_mode)
    except ValueError:
        raise ValueError(
            "Invalid value for GCE_METADATA_MTLS_MODE. Must be one of 'strict', 'none', or 'default'."
        )
|
|
||
|
|
||
def should_use_mds_mtls(mds_mtls_config: Optional[MdsMtlsConfig] = None):
    """Determines if mTLS should be used for the metadata server.

    Args:
        mds_mtls_config (Optional[MdsMtlsConfig]): The certificate locations
            to check. When omitted, a new ``MdsMtlsConfig()`` with the default
            paths is created on each call, rather than sharing one instance
            built at import time.

    Returns:
        bool: False in ``none`` mode; True in ``strict`` mode when the
            certificates exist; in ``default`` mode, whether the certificates
            exist.

    Raises:
        google.auth.exceptions.MutualTLSChannelError: if the mode is
            ``strict`` and the certificates are not found.
        ValueError: if GCE_METADATA_MTLS_MODE holds an invalid value.
    """
    if mds_mtls_config is None:
        mds_mtls_config = MdsMtlsConfig()

    mode = _parse_mds_mode()
    if mode == MdsMtlsMode.NONE:
        # mTLS is disabled outright; the certificate files are not consulted.
        return False

    certs_found = _certs_exist(mds_mtls_config)
    if mode == MdsMtlsMode.STRICT and not certs_found:
        raise exceptions.MutualTLSChannelError(
            "mTLS certificates not found in strict mode."
        )
    # Strict mode only reaches this point when the certificates exist, so
    # returning their presence covers both strict and default modes.
    return certs_found
|
|
||
|
|
||
class MdsMtlsAdapter(HTTPAdapter):
    """An HTTP adapter that uses mTLS for the metadata server.

    The adapter builds an SSL context from the configured CA certificate and
    combined client certificate/key file and injects it into the pool and
    proxy managers. Outside strict mode, a failed mTLS request is retried
    over plain HTTP.
    """

    def __init__(
        self, mds_mtls_config: Optional[MdsMtlsConfig] = None, *args, **kwargs
    ):
        """Creates the adapter and loads the mTLS certificates.

        Args:
            mds_mtls_config (Optional[MdsMtlsConfig]): The certificate
                locations to load. When omitted, a new ``MdsMtlsConfig()``
                with the default paths is created for this adapter, rather
                than sharing one instance built at import time.
            *args: Positional arguments forwarded to ``HTTPAdapter``.
            **kwargs: Keyword arguments forwarded to ``HTTPAdapter``.
        """
        if mds_mtls_config is None:
            mds_mtls_config = MdsMtlsConfig()

        # The SSL context must be set before HTTPAdapter.__init__ runs:
        # that constructor calls init_poolmanager(), which reads it.
        self.ssl_context = ssl.create_default_context()
        self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path)
        self.ssl_context.load_cert_chain(
            certfile=mds_mtls_config.client_combined_cert_path
        )
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        """Creates the pool manager with this adapter's mTLS SSL context."""
        kwargs["ssl_context"] = self.ssl_context
        return super().init_poolmanager(*args, **kwargs)

    def proxy_manager_for(self, *args, **kwargs):
        """Creates the proxy manager with this adapter's mTLS SSL context."""
        kwargs["ssl_context"] = self.ssl_context
        return super().proxy_manager_for(*args, **kwargs)

    def send(self, request, **kwargs):
        """Sends the request over mTLS, falling back to HTTP outside strict mode.

        Args:
            request: The prepared request to send. On fallback its ``url`` is
                rewritten in place to use the ``http`` scheme.
            **kwargs: Keyword arguments forwarded to ``HTTPAdapter.send``.

        Returns:
            The response from the mTLS attempt, or from the plain-HTTP retry
            when the mTLS attempt failed outside strict mode.
        """
        # If we are in strict mode, always use mTLS (no HTTP fallback).
        if _parse_mds_mode() == MdsMtlsMode.STRICT:
            return super().send(request, **kwargs)

        # In default mode, attempt mTLS first, then fall back to HTTP on failure.
        try:
            response = super().send(request, **kwargs)
            # An HTTP error status also counts as an mTLS failure here.
            response.raise_for_status()
            return response
        except (
            ssl.SSLError,
            requests.exceptions.SSLError,
            requests.exceptions.HTTPError,
        ) as e:
            _LOGGER.warning(
                "mTLS connection to Compute Engine Metadata server failed. "
                "Falling back to standard HTTP. Reason: %s",
                e,
            )
            # Fallback to standard HTTP: same URL with the scheme swapped.
            parsed_original_url = urlparse(request.url)
            http_fallback_url = urlunparse(parsed_original_url._replace(scheme="http"))
            request.url = http_fallback_url

            # Use a standard HTTPAdapter (no mTLS context) for the fallback.
            http_adapter = HTTPAdapter()
            return http_adapter.send(request, **kwargs)
Uh oh!
There was an error while loading. Please reload this page.