diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index f4fae7298..30cd3d43b 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -99,7 +99,7 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): return False -def get(request, path, root=_METADATA_ROOT, recursive=False): +def get(request, path, root=_METADATA_ROOT, recursive=False, retry_count=5): """Fetch a resource from the metadata server. Args: @@ -111,6 +111,8 @@ def get(request, path, root=_METADATA_ROOT, recursive=False): recursive (bool): Whether to do a recursive query of metadata. See https://cloud.google.com/compute/docs/metadata#aggcontents for more details. + retry_count (int): How many times to attempt connecting to metadata + server using above timeout. Returns: Union[Mapping, str]: If the metadata server returns JSON, a mapping of @@ -129,7 +131,24 @@ def get(request, path, root=_METADATA_ROOT, recursive=False): url = _helpers.update_query(base_url, query_params) - response = request(url=url, method="GET", headers=_METADATA_HEADERS) + retries = 0 + while retries < retry_count: + try: + response = request(url=url, method="GET", headers=_METADATA_HEADERS) + break + + except exceptions.TransportError: + _LOGGER.info( + "Compute Engine Metadata server unavailable on" "attempt %s of %s", + retries + 1, + retry_count, + ) + retries += 1 + else: + raise exceptions.TransportError( + "Failed to retrieve {} from the Google Compute Engine" + "metadata service. Compute Engine Metadata server unavailable".format(url) + ) if response.status == http_client.OK: content = _helpers.from_bytes(response.data) diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index bd06b7402..0898e1f4e 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -30,14 +30,17 @@ PATH = "instance/service-accounts/default" -def make_request(data, status=http_client.OK, headers=None): +def make_request(data, status=http_client.OK, headers=None, retry=False): response = mock.create_autospec(transport.Response, instance=True) response.status = status response.data = _helpers.to_bytes(data) response.headers = headers or {} request = mock.create_autospec(transport.Request) - request.return_value = response + if retry: + request.side_effect = [exceptions.TransportError(), response] + else: + request.return_value = response return request @@ -55,6 +58,20 @@ def test_ping_success(): ) +def test_ping_success_retry(): + request = make_request("", headers=_metadata._METADATA_HEADERS, retry=True) + + assert _metadata.ping(request) + + request.assert_called_with( + method="GET", + url=_metadata._METADATA_IP_ROOT, + headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, + ) + assert request.call_count == 2 + + def test_ping_failure_bad_flavor(): request = make_request("", headers={_metadata._METADATA_FLAVOR_HEADER: "meep"}) @@ -105,6 +122,25 @@ def test_get_success_json(): assert result[key] == value +def test_get_success_retry(): + key, value = "foo", "bar" + + data = json.dumps({key: value}) + request = make_request( + data, headers={"content-type": "application/json"}, retry=True + ) + + result = _metadata.get(request, PATH) + + request.assert_called_with( + method="GET", + url=_metadata._METADATA_ROOT + PATH, + headers=_metadata._METADATA_HEADERS, + ) + assert request.call_count == 2 + assert result[key] == value + + def test_get_success_text(): data = "foobar" request = make_request(data, headers={"content-type": "text/plain"}) @@ -154,6 +190,23 @@ def test_get_failure(): ) +def test_get_failure_connection_failed(): + request = make_request("") + request.side_effect = exceptions.TransportError() + + with pytest.raises(exceptions.TransportError) as excinfo: + _metadata.get(request, PATH) + + assert excinfo.match(r"Compute Engine Metadata server unavailable") + + request.assert_called_with( + method="GET", + url=_metadata._METADATA_ROOT + PATH, + headers=_metadata._METADATA_HEADERS, + ) + assert request.call_count == 5 + + def test_get_failure_bad_json(): request = make_request("{", headers={"content-type": "application/json"})