Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove wireserver fallback for imds calls #3152

Merged
merged 5 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion azurelinuxagent/common/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def initialize_vminfo_common_parameters(self, protocol):
logger.warn("Failed to get VM info from goal state; will be missing from telemetry: {0}", ustr(e))

try:
imds_client = get_imds_client(protocol.get_endpoint())
imds_client = get_imds_client()
imds_info = imds_client.get_compute()
parameters[CommonTelemetryEventSchema.Location].value = imds_info.location
parameters[CommonTelemetryEventSchema.SubscriptionId].value = imds_info.subscriptionId
Expand Down
11 changes: 4 additions & 7 deletions azurelinuxagent/common/protocol/imds.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
IMDS_INTERNAL_SERVER_ERROR = 3


def get_imds_client(wireserver_endpoint):
return ImdsClient(wireserver_endpoint)
def get_imds_client():
return ImdsClient()


# A *slightly* future proof list of endorsed distros.
Expand Down Expand Up @@ -256,7 +256,7 @@ def image_origin(self):


class ImdsClient(object):
def __init__(self, wireserver_endpoint, version=APIVERSION):
def __init__(self, version=APIVERSION):
self._api_version = version
self._headers = {
'User-Agent': restutil.HTTP_USER_AGENT,
Expand All @@ -268,7 +268,6 @@ def __init__(self, wireserver_endpoint, version=APIVERSION):
}
self._regex_ioerror = re.compile(r".*HTTP Failed. GET http://[^ ]+ -- IOError .*")
self._regex_throttled = re.compile(r".*HTTP Retry. GET http://[^ ]+ -- Status Code 429 .*")
self._wireserver_endpoint = wireserver_endpoint

def _get_metadata_url(self, endpoint, resource_path):
return BASE_METADATA_URI.format(endpoint, resource_path, self._api_version)
Expand Down Expand Up @@ -326,14 +325,12 @@ def get_metadata(self, resource_path, is_health):
endpoint = IMDS_ENDPOINT

status, resp = self._get_metadata_from_endpoint(endpoint, resource_path, headers)
if status == IMDS_CONNECTION_ERROR:
endpoint = self._wireserver_endpoint
status, resp = self._get_metadata_from_endpoint(endpoint, resource_path, headers)

if status == IMDS_RESPONSE_SUCCESS:
return MetadataResult(True, False, resp)
elif status == IMDS_INTERNAL_SERVER_ERROR:
return MetadataResult(False, True, resp)
# else it's a client-side error, e.g. IMDS_CONNECTION_ERROR
return MetadataResult(False, False, resp)

def get_compute(self):
Expand Down
6 changes: 3 additions & 3 deletions azurelinuxagent/ga/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,10 @@ class SendImdsHeartbeat(PeriodicOperation):
Periodic operation to report the IMDS's health. The signal is 'Healthy' when we have successfully called and validated
a response in the last _IMDS_HEALTH_PERIOD.
"""
def __init__(self, protocol_util, health_service):
def __init__(self, health_service):
super(SendImdsHeartbeat, self).__init__(SendImdsHeartbeat._IMDS_HEARTBEAT_PERIOD)
self.health_service = health_service
self.imds_client = get_imds_client(protocol_util.get_wireserver_endpoint())
self.imds_client = get_imds_client()
self.imds_error_state = ErrorState(min_timedelta=SendImdsHeartbeat._IMDS_HEALTH_PERIOD)

_IMDS_HEARTBEAT_PERIOD = datetime.timedelta(minutes=1)
Expand Down Expand Up @@ -298,7 +298,7 @@ def daemon(self):
PollResourceUsage(),
PollSystemWideResourceUsage(),
SendHostPluginHeartbeat(protocol, health_service),
SendImdsHeartbeat(protocol_util, health_service)
SendImdsHeartbeat(health_service)
]

report_network_configuration_changes = ReportNetworkConfigurationChanges()
Expand Down
20 changes: 0 additions & 20 deletions azurelinuxagent/ga/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

from azurelinuxagent.common import conf
from azurelinuxagent.common import logger
from azurelinuxagent.common.protocol.imds import get_imds_client
from azurelinuxagent.common.utils import fileutil, textutil
from azurelinuxagent.common.agent_supported_feature import get_supported_feature_by_name, SupportedFeatureNames, \
get_agent_supported_features_list_for_crp
Expand Down Expand Up @@ -475,25 +474,6 @@ def _wait_for_cloud_init(self):
add_event(op=WALAEventOperation.CloudInit, message=message, is_success=False, log_event=False)
self._cloud_init_completed = True # Mark as completed even on error since we will proceed to execute extensions

def _get_vm_size(self, protocol):
"""
Including VMSize is meant to capture the architecture of the VM (i.e. arm64 VMs will
have arm64 included in their vmsize field and amd64 will have no architecture indicated).
"""
if self._vm_size is None:

imds_client = get_imds_client(protocol.get_endpoint())

try:
imds_info = imds_client.get_compute()
self._vm_size = imds_info.vmSize
except Exception as e:
err_msg = "Attempts to retrieve VM size information from IMDS are failing: {0}".format(textutil.format_exception(e))
logger.periodic_warn(logger.EVERY_SIX_HOURS, "[PERIODIC] {0}".format(err_msg))
return "unknown"

return self._vm_size

def _get_vm_arch(self):
return platform.machine()

Expand Down
155 changes: 71 additions & 84 deletions tests/common/protocol/test_imds.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class TestImds(AgentTestCase):
def test_get(self, mock_http_get):
mock_http_get.return_value = get_mock_compute_response()

test_subject = imds.ImdsClient(restutil.KNOWN_WIRESERVER_IP)
test_subject = imds.ImdsClient()
test_subject.get_compute()

self.assertEqual(1, mock_http_get.call_count)
Expand All @@ -71,21 +71,21 @@ def test_get(self, mock_http_get):
def test_get_bad_request(self, mock_http_get):
mock_http_get.return_value = MockHttpResponse(status=restutil.httpclient.BAD_REQUEST)

test_subject = imds.ImdsClient(restutil.KNOWN_WIRESERVER_IP)
test_subject = imds.ImdsClient()
self.assertRaises(HttpError, test_subject.get_compute)

@patch("azurelinuxagent.common.protocol.imds.restutil.http_get")
def test_get_internal_service_error(self, mock_http_get):
mock_http_get.return_value = MockHttpResponse(status=restutil.httpclient.INTERNAL_SERVER_ERROR)

test_subject = imds.ImdsClient(restutil.KNOWN_WIRESERVER_IP)
test_subject = imds.ImdsClient()
self.assertRaises(HttpError, test_subject.get_compute)

@patch("azurelinuxagent.common.protocol.imds.restutil.http_get")
def test_get_empty_response(self, mock_http_get):
mock_http_get.return_value = MockHttpResponse(status=httpclient.OK, body=''.encode('utf-8'))

test_subject = imds.ImdsClient(restutil.KNOWN_WIRESERVER_IP)
test_subject = imds.ImdsClient()
self.assertRaises(ValueError, test_subject.get_compute)

def test_deserialize_ComputeInfo(self):
Expand Down Expand Up @@ -359,7 +359,7 @@ def _imds_response(f):
return fh.read()

def _assert_validation(self, http_status_code, http_response, expected_valid, expected_response):
test_subject = imds.ImdsClient(restutil.KNOWN_WIRESERVER_IP)
test_subject = imds.ImdsClient()
with patch("azurelinuxagent.common.utils.restutil.http_get") as mock_http_get:
mock_http_get.return_value = MockHttpResponse(status=http_status_code,
reason='reason',
Expand All @@ -386,99 +386,86 @@ def test_endpoint_fallback(self):
# http GET calls and enforces a single GET call (fallback would cause 2) and
# checks the url called.

test_subject = imds.ImdsClient("foo.bar")
test_subject = imds.ImdsClient()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably rename this test and update its comment, since we no longer fall back to a secondary endpoint.


# ensure user-agent gets set correctly
for is_health, expected_useragent in [(False, restutil.HTTP_USER_AGENT), (True, restutil.HTTP_USER_AGENT_HEALTH)]:
# set a different resource path for health query to make debugging unit test easier
resource_path = 'something/health' if is_health else 'something'

for has_primary_ioerror in (False, True):
# secondary endpoint unreachable
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(primary_ioerror=has_primary_ioerror, secondary_ioerror=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success) if has_primary_ioerror else self.assertTrue(result.success) # pylint: disable=expression-not-assigned
self.assertFalse(result.service_error)
if has_primary_ioerror:
self.assertEqual('IMDS error in /metadata/{0}: Unable to connect to endpoint'.format(resource_path), result.response)
else:
self.assertEqual('Mock success response', result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(2 if has_primary_ioerror else 1, test_subject._http_get.call_count)

# IMDS success
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(primary_ioerror=has_primary_ioerror)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertTrue(result.success)
self.assertFalse(result.service_error)
self.assertEqual('Mock success response', result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(2 if has_primary_ioerror else 1, test_subject._http_get.call_count)

# IMDS throttled
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(primary_ioerror=has_primary_ioerror, throttled=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertFalse(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: Throttled'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(2 if has_primary_ioerror else 1, test_subject._http_get.call_count)

# IMDS gone error
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(primary_ioerror=has_primary_ioerror, gone_error=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertTrue(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: HTTP Failed with Status Code 410: Gone'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(2 if has_primary_ioerror else 1, test_subject._http_get.call_count)

# IMDS bad request
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(primary_ioerror=has_primary_ioerror, bad_request=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertFalse(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: [HTTP Failed] [404: reason] Mock not found'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(2 if has_primary_ioerror else 1, test_subject._http_get.call_count)

def _mock_imds_setup(self, primary_ioerror=False, secondary_ioerror=False, gone_error=False, throttled=False, bad_request=False):
self._mock_imds_expect_fallback = primary_ioerror # pylint: disable=attribute-defined-outside-init
self._mock_imds_primary_ioerror = primary_ioerror # pylint: disable=attribute-defined-outside-init
self._mock_imds_secondary_ioerror = secondary_ioerror # pylint: disable=attribute-defined-outside-init
# IMDS success
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup()
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertTrue(result.success)
self.assertFalse(result.service_error)
self.assertEqual('Mock success response', result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(1, test_subject._http_get.call_count)

# Connection error
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(ioerror=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertFalse(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: Unable to connect to endpoint'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(1, test_subject._http_get.call_count)

# IMDS throttled
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(throttled=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertFalse(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: Throttled'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(1, test_subject._http_get.call_count)

# IMDS gone error
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(gone_error=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertTrue(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: HTTP Failed with Status Code 410: Gone'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(1, test_subject._http_get.call_count)

# IMDS bad request
test_subject._http_get = Mock(side_effect=self._mock_http_get)
self._mock_imds_setup(bad_request=True)
result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)
self.assertFalse(result.success)
self.assertFalse(result.service_error)
self.assertEqual('IMDS error in /metadata/{0}: [HTTP Failed] [404: reason] Mock not found'.format(resource_path), result.response)
for _, kwargs in test_subject._http_get.call_args_list:
self.assertTrue('User-Agent' in kwargs['headers'])
self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
self.assertEqual(1, test_subject._http_get.call_count)

def _mock_imds_setup(self, ioerror=False, gone_error=False, throttled=False, bad_request=False):
self._mock_imds_ioerror = ioerror # pylint: disable=attribute-defined-outside-init
self._mock_imds_gone_error = gone_error # pylint: disable=attribute-defined-outside-init
self._mock_imds_throttled = throttled # pylint: disable=attribute-defined-outside-init
self._mock_imds_bad_request = bad_request # pylint: disable=attribute-defined-outside-init

def _mock_http_get(self, *_, **kwargs):
if "foo.bar" == kwargs['endpoint'] and not self._mock_imds_expect_fallback:
raise Exception("Unexpected endpoint called")
if self._mock_imds_primary_ioerror and "169.254.169.254" == kwargs['endpoint']:
raise HttpError("[HTTP Failed] GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made"
.format(kwargs['endpoint'], kwargs['resource_path']))
if self._mock_imds_secondary_ioerror and "foo.bar" == kwargs['endpoint']:
raise HttpError("[HTTP Failed] GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made"
.format(kwargs['endpoint'], kwargs['resource_path']))
if self._mock_imds_ioerror:
raise HttpError("[HTTP Failed] GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made".format(kwargs['endpoint'], kwargs['resource_path']))
if self._mock_imds_gone_error:
raise ResourceGoneError("Resource is gone")
if self._mock_imds_throttled:
raise HttpError("[HTTP Retry] GET http://{0}/metadata/{1} -- Status Code 429 -- 25 attempts made"
.format(kwargs['endpoint'], kwargs['resource_path']))
raise HttpError("[HTTP Retry] GET http://{0}/metadata/{1} -- Status Code 429 -- 25 attempts made".format(kwargs['endpoint'], kwargs['resource_path']))

resp = MagicMock()
resp.reason = 'reason'
Expand Down
6 changes: 3 additions & 3 deletions tests/ga/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from tests.lib.mock_update_handler import mock_update_handler
from tests.lib.mock_wire_protocol import mock_wire_protocol, MockHttpResponse
from tests.lib.wire_protocol_data import DATA_FILE, DATA_FILE_MULTIPLE_EXT, DATA_FILE_VM_SETTINGS
from tests.lib.tools import AgentTestCase, AgentTestCaseWithGetVmSizeMock, data_dir, DEFAULT, patch, load_bin_data, Mock, MagicMock, \
from tests.lib.tools import AgentTestCase, data_dir, DEFAULT, patch, load_bin_data, Mock, MagicMock, \
clear_singleton_instances, is_python_version_26_or_34, skip_if_predicate_true
from tests.lib import wire_protocol_data
from tests.lib.http_request_predicates import HttpRequestPredicates
Expand Down Expand Up @@ -119,7 +119,7 @@ def _get_update_handler(iterations=1, test_data=None, protocol=None, autoupdate_
yield update_handler, protocol


class UpdateTestCase(AgentTestCaseWithGetVmSizeMock):
class UpdateTestCase(AgentTestCase):
_test_suite_tmp_dir = None
_agent_zip_dir = None

Expand Down Expand Up @@ -1928,7 +1928,7 @@ def reload_conf(url, protocol):
@patch('azurelinuxagent.ga.update.get_collect_logs_handler')
@patch('azurelinuxagent.ga.update.get_monitor_handler')
@patch('azurelinuxagent.ga.update.get_env_handler')
class MonitorThreadTest(AgentTestCaseWithGetVmSizeMock):
class MonitorThreadTest(AgentTestCase):
def setUp(self):
super(MonitorThreadTest, self).setUp()
self.event_patch = patch('azurelinuxagent.common.event.add_event')
Expand Down
Loading
Loading