Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ internal class EnvironmentVariables
public static string MsiEndpoint => Environment.GetEnvironmentVariable("MSI_ENDPOINT");
public static string MsiSecret => Environment.GetEnvironmentVariable("MSI_SECRET");
public static string IdentityServerThumbprint => Environment.GetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT");
public static string MachineLearningDefaultClientId => Environment.GetEnvironmentVariable("DEFAULT_IDENTITY_CLIENT_ID");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ namespace Microsoft.Identity.Client.ManagedIdentity
{
internal class MachineLearningManagedIdentitySource : AbstractManagedIdentity
{
private const string MachineLearning = "Machine Learning";

private const string MachineLearningMsiApiVersion = "2017-09-01";
private const string SecretHeaderName = "secret";

private readonly Uri _endpoint;
private readonly string _secret;

public const string UnsupportedIdTypeError = "Only client id is supported for user-assigned managed identity in Machine Learning."; // referenced in unit test

public static AbstractManagedIdentity Create(RequestContext requestContext)
{
requestContext.Logger.Info(() => "[Managed Identity] Machine learning managed identity is available.");
Expand Down Expand Up @@ -47,15 +51,12 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger
MsalErrorMessage.ManagedIdentityEndpointInvalidUriError,
"MSI_ENDPOINT", msiEndpoint, "Machine learning");

// Use the factory to create and throw the exception
var exception = MsalServiceExceptionFactory.CreateManagedIdentityException(
throw MsalServiceExceptionFactory.CreateManagedIdentityException(
MsalError.InvalidManagedIdentityEndpoint,
errorMessage,
ex,
ManagedIdentitySource.MachineLearning,
null); // statusCode is null in this case

throw exception;
}

logger.Info($"[Managed Identity] Environment variables validation passed for machine learning managed identity. Endpoint URI: {endpointUri}. Creating machine learning managed identity.");
Expand All @@ -73,21 +74,37 @@ protected override ManagedIdentityRequest CreateRequest(string resource)

switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
{
case AppConfig.ManagedIdentityIdType.SystemAssigned:
_requestContext.Logger.Info("[Managed Identity] Adding system assigned client id to the request.");

// this environment variable is always set in an Azure Machine Learning source, but check if null just in case
if (EnvironmentVariables.MachineLearningDefaultClientId == null)
{
throw MsalServiceExceptionFactory.CreateManagedIdentityException(
MsalError.InvalidManagedIdentityIdType,
"The DEFAULT_IDENTITY_CLIENT_ID environment variable is null.",
null, // configuration error
ManagedIdentitySource.MachineLearning,
null); // statusCode is null in this case
}

// Use the new 2017 constant for older ML-based environment
request.QueryParameters[Constants.ManagedIdentityClientId2017] = EnvironmentVariables.MachineLearningDefaultClientId;
break;

case AppConfig.ManagedIdentityIdType.ClientId:
_requestContext.Logger.Info("[Managed Identity] Adding user assigned client id to the request.");
// Use the new 2017 constant for older ML-based environment
request.QueryParameters[Constants.ManagedIdentityClientId2017] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId;
break;

case AppConfig.ManagedIdentityIdType.ResourceId:
_requestContext.Logger.Info("[Managed Identity] Adding user assigned resource id to the request.");
request.QueryParameters[Constants.ManagedIdentityResourceId] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId;
break;

case AppConfig.ManagedIdentityIdType.ObjectId:
_requestContext.Logger.Info("[Managed Identity] Adding user assigned object id to the request.");
request.QueryParameters[Constants.ManagedIdentityObjectId] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId;
break;
default:
throw MsalServiceExceptionFactory.CreateManagedIdentityException(
MsalError.InvalidManagedIdentityIdType,
UnsupportedIdTypeError,
null, // configuration error
ManagedIdentitySource.MachineLearning,
null); // statusCode is null in this case
}

return request;
Expand Down
10 changes: 10 additions & 0 deletions src/client/Microsoft.Identity.Client/MsalError.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,16 @@ public static class MsalError
/// </summary>
public const string InvalidManagedIdentityResponse = "invalid_managed_identity_response";

/// <summary>
/// The managed identity's source does not select a specific id type.
/// </summary>
public const string InvalidManagedIdentityIdType = "invalid_managed_identity_id_type";

/// <summary>
/// The managed identity is missing a required environment variable.
/// </summary>
public const string MissingManagedIdentityEnvVar = "missing_managed_identity_env_var";

/// <summary>
/// Managed Identity error response was received.
/// </summary>
Expand Down
1 change: 1 addition & 0 deletions src/client/Microsoft.Identity.Client/MsalErrorMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ public static string InvalidTokenProviderResponseValue(string invalidValueName)

public const string ManagedIdentityNoResponseReceived = "[Managed Identity] Authentication unavailable. No response received from the managed identity endpoint.";
public const string ManagedIdentityInvalidResponse = "[Managed Identity] Invalid response, the authentication response received did not contain the expected fields.";
public const string ManagedIdentityInvalidIdType = "Only {0} supported for user-assigned managed identity in {1}";
public const string ManagedIdentityJsonParseFailure = "[Managed Identity] MSI returned 200 OK, but the response could not be parsed.";
public const string ManagedIdentityUnexpectedResponse = "[Managed Identity] Unexpected exception occurred when parsing the response. See the inner exception for details.";
public const string ManagedIdentityExactlyOneScopeExpected = "[Managed Identity] To acquire token for managed identity, exactly one scope must be passed.";
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@

const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@

const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@

const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity
case ManagedIdentitySource.MachineLearning:
Environment.SetEnvironmentVariable("MSI_ENDPOINT", endpoint);
Environment.SetEnvironmentVariable("MSI_SECRET", secret);
Environment.SetEnvironmentVariable("DEFAULT_IDENTITY_CLIENT_ID", "fake_DEFAULT_IDENTITY_CLIENT_ID");
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ public static void AddRegionDiscoveryMockHandler(
});
}

public static void AddManagedIdentityMockHandler(
public static MockHttpMessageHandler AddManagedIdentityMockHandler(
this MockHttpManager httpManager,
string expectedUrl,
string resource,
Expand All @@ -383,37 +383,42 @@ public static void AddManagedIdentityMockHandler(

MockHttpMessageHandler httpMessageHandler = BuildMockHandlerForManagedIdentitySource(managedIdentitySourceType, resource);

if (userAssignedIdentityId == UserAssignedIdentityId.ClientId)
if (managedIdentitySourceType == ManagedIdentitySource.MachineLearning)
{
if (managedIdentitySourceType == ManagedIdentitySource.MachineLearning)
{
// For Machine Learning (App Service 2017), the param is "clientid"
httpMessageHandler.ExpectedQueryParams.Add(Constants.ManagedIdentityClientId2017, userAssignedId);
}
else
{
// For App Service 2019, Azure Arc, IMDS, etc., the param is "client_id"
httpMessageHandler.ExpectedQueryParams.Add(Constants.ManagedIdentityClientId, userAssignedId);
}
// For Machine Learning (App Service 2017), the client id param is "clientid"
// it will always be a query parameter, no matter the source type
// use env var for SAMI, passed-in userAssignedId for UAMI
httpMessageHandler.ExpectedQueryParams.Add(
Constants.ManagedIdentityClientId2017,
userAssignedId ?? EnvironmentVariables.MachineLearningDefaultClientId);
}

if (userAssignedIdentityId == UserAssignedIdentityId.ResourceId)
else if (userAssignedIdentityId == UserAssignedIdentityId.ClientId)
{
// For App Service 2019, Azure Arc, IMDS, etc., the param is "client_id"
httpMessageHandler.ExpectedQueryParams.Add(
Constants.ManagedIdentityClientId,
userAssignedId);
}
else if (userAssignedIdentityId == UserAssignedIdentityId.ResourceId)
{
httpMessageHandler.ExpectedQueryParams.Add(
managedIdentitySourceType == ManagedIdentitySource.Imds ?
Constants.ManagedIdentityResourceIdImds : Constants.ManagedIdentityResourceId,
userAssignedId);
}

if (userAssignedIdentityId == UserAssignedIdentityId.ObjectId)
else if (userAssignedIdentityId == UserAssignedIdentityId.ObjectId)
{
httpMessageHandler.ExpectedQueryParams.Add(Constants.ManagedIdentityObjectId, userAssignedId);
httpMessageHandler.ExpectedQueryParams.Add(
Constants.ManagedIdentityObjectId,
userAssignedId);
}

httpMessageHandler.ResponseMessage = responseMessage;
httpMessageHandler.ExpectedUrl = expectedUrl;

httpManager.AddMockHandler(httpMessageHandler);

return httpMessageHandler;
}

private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(ManagedIdentitySource managedIdentitySourceType, string resource)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

using System;
using System.Globalization;
using System.Net;
using System.Threading.Tasks;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.AppConfig;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.ManagedIdentity;
using Microsoft.Identity.Test.Common;
using Microsoft.Identity.Test.Common.Core.Helpers;
using Microsoft.Identity.Test.Common.Core.Mocks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
Expand All @@ -20,6 +19,84 @@ namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests
public class MachineLearningTests : TestBase
{
private const string MachineLearning = "Machine learning";
private const string MachineLearningEndpoint = "http://localhost:7071/msi/token";
internal const string Resource = "https://management.azure.com";

[DataTestMethod]
[DataRow(null, null)] // SAMI
[DataRow(TestConstants.ClientId, UserAssignedIdentityId.ClientId)] // UAMI
public async Task MachineLearningUserAssignedHappyPathAndHasCorrectClientIdQueryParameterAsync(
string userAssignedId,
UserAssignedIdentityId userAssignedIdentityId)
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.MachineLearning, MachineLearningEndpoint);

ManagedIdentityId managedIdentityId = userAssignedId == null
? ManagedIdentityId.SystemAssigned
: ManagedIdentityId.WithUserAssignedClientId(userAssignedId);
var miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId)
.WithHttpManager(httpManager);

// Disabling shared cache options to avoid cross test pollution.
miBuilder.Config.AccessorOptions = null;

var mi = miBuilder.Build();

MockHttpMessageHandler mockHandler = httpManager.AddManagedIdentityMockHandler(
MachineLearningEndpoint,
Resource,
MockHelpers.GetMsiSuccessfulResponse(),
ManagedIdentitySource.MachineLearning,
userAssignedId: userAssignedId,
userAssignedIdentityId);

AuthenticationResult result = await mi.AcquireTokenForManagedIdentity(Resource).ExecuteAsync().ConfigureAwait(false);

Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);

// Verify query parameter is "clientid" and not "client_id"
Assert.IsTrue(mockHandler.ExpectedQueryParams.ContainsKey(Constants.ManagedIdentityClientId2017), "Query parameter should use 'clientid' and not 'client_id'");

// Verify the clientid value based on identity type
string expectedClientId = userAssignedId ?? EnvironmentVariables.MachineLearningDefaultClientId;
Assert.AreEqual(expectedClientId, mockHandler.ExpectedQueryParams[Constants.ManagedIdentityClientId2017],
"Clientid value should match the provided user assigned ID for UAMI or environment variable for SAMI");
}
}

[DataTestMethod]
[DataRow(TestConstants.MiResourceId, UserAssignedIdentityId.ResourceId)]
[DataRow(TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)]
public async Task MachineLearningUserAssignedNonClientIdThrowsAsync(
string userAssignedId,
UserAssignedIdentityId userAssignedIdentityId)
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.MachineLearning, MachineLearningEndpoint);

var miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId)
.WithHttpManager(httpManager);

// Disabling shared cache options to avoid cross test pollution.
miBuilder.Config.AccessorOptions = null;

var mi = miBuilder.Build();

MsalServiceException ex = await Assert.ThrowsExceptionAsync<MsalServiceException>(async () =>
await mi.AcquireTokenForManagedIdentity(Resource)
.ExecuteAsync().ConfigureAwait(false)).ConfigureAwait(false);

Assert.IsNotNull(ex);
Assert.AreEqual(ManagedIdentitySource.MachineLearning.ToString(), ex.AdditionalExceptionData[MsalException.ManagedIdentitySource]);
Assert.AreEqual(MsalError.InvalidManagedIdentityIdType, ex.ErrorCode);
}
}

[TestMethod]
public async Task MachineLearningTestsInvalidEndpointAsync()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ public async Task ManagedIdentityHappyPathAsync(
[DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.MiResourceId, UserAssignedIdentityId .ResourceId)]
[DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)]
[DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.ClientId, UserAssignedIdentityId.ClientId)]
[DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.MiResourceId, UserAssignedIdentityId.ResourceId)]
[DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)]
public async Task ManagedIdentityUserAssignedHappyPathAsync(
string endpoint,
ManagedIdentitySource managedIdentitySource,
Expand Down
Loading