diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index fe821418e..94e4ffbf0 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -108,7 +108,9 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): return False -def get(request, path, root=_METADATA_ROOT, recursive=False, retry_count=5): +def get( + request, path, root=_METADATA_ROOT, params=None, recursive=False, retry_count=5 +): """Fetch a resource from the metadata server. Args: @@ -117,6 +119,8 @@ def get(request, path, root=_METADATA_ROOT, recursive=False, retry_count=5): path (str): The resource to retrieve. For example, ``'instance/service-accounts/default'``. root (str): The full path to the metadata server root. + params (Optional[Mapping[str, str]]): A mapping of query parameter + keys to values. recursive (bool): Whether to do a recursive query of metadata. See https://cloud.google.com/compute/docs/metadata#aggcontents for more details. @@ -133,7 +137,7 @@ def get(request, path, root=_METADATA_ROOT, recursive=False, retry_count=5): retrieving metadata. """ base_url = urlparse.urljoin(root, path) - query_params = {} + query_params = {} if params is None else params if recursive: query_params["recursive"] = "true" @@ -224,11 +228,10 @@ def get_service_account_info(request, service_account="default"): google.auth.exceptions.TransportError: if an error occurred while retrieving metadata. """ - return get( - request, - "instance/service-accounts/{0}/".format(service_account), - recursive=True, - ) + path = "instance/service-accounts/{0}/".format(service_account) + # See https://cloud.google.com/compute/docs/metadata#aggcontents + # for more on the use of 'recursive'. + return get(request, path, params={"recursive": "true"}) def get_service_account_token(request, service_account="default"): diff --git a/google/auth/compute_engine/credentials.py b/google/auth/compute_engine/credentials.py index b7fca1832..8a41ffcc0 100644 --- a/google/auth/compute_engine/credentials.py +++ b/google/auth/compute_engine/credentials.py @@ -323,12 +323,9 @@ def _call_metadata_identity_endpoint(self, request): ValueError: If extracting expiry from the obtained ID token fails. """ try: - id_token = _metadata.get( - request, - "instance/service-accounts/default/identity?audience={}&format=full".format( - self._target_audience - ), - ) + path = "instance/service-accounts/default/identity" + params = {"audience": self._target_audience, "format": "full"} + id_token = _metadata.get(request, path, params=params) except exceptions.TransportError as caught_exc: new_exc = exceptions.RefreshError(caught_exc) six.raise_from(new_exc, caught_exc) diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index d9b039a32..d05337263 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -155,6 +155,49 @@ def test_get_success_text(): assert result == data +def test_get_success_params(): + data = "foobar" + request = make_request(data, headers={"content-type": "text/plain"}) + params = {"recursive": "true"} + + result = _metadata.get(request, PATH, params=params) + + request.assert_called_once_with( + method="GET", + url=_metadata._METADATA_ROOT + PATH + "?recursive=true", + headers=_metadata._METADATA_HEADERS, + ) + assert result == data + + +def test_get_success_recursive_and_params(): + data = "foobar" + request = make_request(data, headers={"content-type": "text/plain"}) + params = {"recursive": "false"} + result = _metadata.get(request, PATH, recursive=True, params=params) + + request.assert_called_once_with( + method="GET", + url=_metadata._METADATA_ROOT + PATH + "?recursive=true", + headers=_metadata._METADATA_HEADERS, + ) + assert result == data + + +def test_get_success_recursive(): + data = "foobar" + request = make_request(data, headers={"content-type": "text/plain"}) + + result = _metadata.get(request, PATH, recursive=True) + + request.assert_called_once_with( + method="GET", + url=_metadata._METADATA_ROOT + PATH + "?recursive=true", + headers=_metadata._METADATA_HEADERS, + ) + assert result == data + + def test_get_success_custom_root_new_variable(): request = make_request("{}", headers={"content-type": "application/json"})