diff --git a/msal/managed_identity.py b/msal/managed_identity.py index ec032ca7..6f85571d 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -448,7 +448,9 @@ def _obtain_token_on_azure_vm(http_client, managed_identity, resource): } _adjust_param(params, managed_identity) resp = http_client.get( - "http://169.254.169.254/metadata/identity/oauth2/token", + os.getenv( + "AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.254" + ).strip("/") + "/metadata/identity/oauth2/token", params=params, headers={"Metadata": "true"}, ) diff --git a/tests/test_mi.py b/tests/test_mi.py index c5a99ae3..a7c2cb6c 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -121,13 +121,29 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"): class VmTestCase(ClientTestCase): - def test_happy_path(self): + def _test_happy_path(self) -> callable: expires_in = 7890 # We test a bigger than 7200 value here with patch.object(self.app._http_client, "get", return_value=MinimalResponse( status_code=200, text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in, )) as mocked_method: - self._test_happy_path(self.app, mocked_method, expires_in) + super(VmTestCase, self)._test_happy_path(self.app, mocked_method, expires_in) + return mocked_method + + def test_happy_path_of_vm(self): + self._test_happy_path().assert_called_with( + 'http://169.254.169.254/metadata/identity/oauth2/token', + params={'api-version': '2018-02-01', 'resource': 'R'}, + headers={'Metadata': 'true'}, + ) + + @patch.dict(os.environ, {"AZURE_POD_IDENTITY_AUTHORITY_HOST": "http://localhost:1234//"}) + def test_happy_path_of_pod_identity(self): + self._test_happy_path().assert_called_with( + 'http://localhost:1234/metadata/identity/oauth2/token', + params={'api-version': '2018-02-01', 'resource': 'R'}, + headers={'Metadata': 'true'}, + ) def test_vm_error_should_be_returned_as_is(self): raw_error = '{"raw": "error format is undefined"}'