Skip to content

Commit 89ff42a

Browse files
authored
[MSI v2] - Enable attestation in pop flows (#5496)
* attestation * address pr comments * remove console * pr comments * pr comments
1 parent 087130e commit 89ff42a

File tree

11 files changed

+363
-16
lines changed

11 files changed

+363
-16
lines changed

src/client/Microsoft.Identity.Client/ApiConfig/AcquireTokenForManagedIdentityParameterBuilder.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using System.Threading.Tasks;
99
using Microsoft.Identity.Client.ApiConfig.Executors;
1010
using Microsoft.Identity.Client.ApiConfig.Parameters;
11+
using Microsoft.Identity.Client.ManagedIdentity;
1112
using Microsoft.Identity.Client.TelemetryCore.Internal.Events;
1213
using Microsoft.Identity.Client.Utils;
1314

@@ -80,6 +81,7 @@ public AcquireTokenForManagedIdentityParameterBuilder WithClaims(string claims)
8081
/// <inheritdoc/>
8182
internal override Task<AuthenticationResult> ExecuteInternalAsync(CancellationToken cancellationToken)
8283
{
84+
ApplyMtlsPopAndAttestation(acquireTokenForManagedIdentityParameters: Parameters, acquireTokenCommonParameters: CommonParameters);
8385
return ManagedIdentityApplicationExecutor.ExecuteAsync(CommonParameters, Parameters, cancellationToken);
8486
}
8587

@@ -93,5 +95,29 @@ internal override ApiEvent.ApiIds CalculateApiEventId()
9395

9496
return ApiEvent.ApiIds.AcquireTokenForUserAssignedManagedIdentity;
9597
}
98+
99+
/// <summary>
100+
/// TEST HOOK ONLY: Allows unit tests to inject a fake attestation-token provider
101+
/// so we don't hit the real attestation service. Not part of the public API.
102+
/// </summary>
103+
internal AcquireTokenForManagedIdentityParameterBuilder WithAttestationProviderForTests(
104+
Func<AttestationTokenInput, CancellationToken, Task<AttestationTokenResponse>> provider)
105+
{
106+
if (provider is null)
107+
{
108+
throw new ArgumentNullException(nameof(provider));
109+
}
110+
111+
CommonParameters.AttestationTokenProvider = provider;
112+
return this;
113+
}
114+
115+
private static void ApplyMtlsPopAndAttestation(
116+
AcquireTokenCommonParameters acquireTokenCommonParameters,
117+
AcquireTokenForManagedIdentityParameters acquireTokenForManagedIdentityParameters)
118+
{
119+
acquireTokenForManagedIdentityParameters.IsMtlsPopRequested = acquireTokenCommonParameters.IsMtlsPopRequested;
120+
acquireTokenForManagedIdentityParameters.AttestationTokenProvider ??= acquireTokenCommonParameters.AttestationTokenProvider;
121+
}
96122
}
97123
}

src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
using System.Collections.Generic;
66
using System.Linq;
77
using System.Text;
8+
using System.Threading;
89
using System.Threading.Tasks;
910
using Microsoft.Identity.Client.Core;
11+
using Microsoft.Identity.Client.ManagedIdentity;
1012

1113
namespace Microsoft.Identity.Client.ApiConfig.Parameters
1214
{
@@ -22,6 +24,8 @@ internal class AcquireTokenForManagedIdentityParameters : IAcquireTokenParameter
2224

2325
public bool IsMtlsPopRequested { get; set; }
2426

27+
internal Func<AttestationTokenInput, CancellationToken, Task<AttestationTokenResponse>> AttestationTokenProvider { get; set; }
28+
2529
public void LogParameters(ILoggerAdapter logger)
2630
{
2731
if (logger.IsLoggingEnabled(LogLevel.Info))

src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
using Microsoft.Identity.Client.Internal.Broker;
1818
using Microsoft.Identity.Client.Internal.ClientCredential;
1919
using Microsoft.Identity.Client.Kerberos;
20+
using Microsoft.Identity.Client.ManagedIdentity;
2021
using Microsoft.Identity.Client.ManagedIdentity.V2;
2122
using Microsoft.Identity.Client.PlatformsCommon.Interfaces;
2223
using Microsoft.Identity.Client.UI;

src/client/Microsoft.Identity.Client/AppConfig/ManagedIdentityApplicationBuilder.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using Microsoft.Identity.Client.AppConfig;
1212
using Microsoft.Identity.Client.Extensibility;
1313
using Microsoft.Identity.Client.Internal;
14+
using Microsoft.Identity.Client.ManagedIdentity;
1415
using Microsoft.Identity.Client.TelemetryCore;
1516
using Microsoft.Identity.Client.TelemetryCore.TelemetryClient;
1617
using Microsoft.Identity.Client.Utils;

src/client/Microsoft.Identity.Client/Internal/RequestContext.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
using System.Collections.Generic;
66
using System.Security.Cryptography.X509Certificates;
77
using System.Threading;
8+
using System.Threading.Tasks;
89
using Microsoft.Identity.Client.Core;
910
using Microsoft.Identity.Client.Internal.Logger;
11+
using Microsoft.Identity.Client.ManagedIdentity;
1012
using Microsoft.Identity.Client.TelemetryCore;
1113
using Microsoft.Identity.Client.TelemetryCore.Internal.Events;
1214
using Microsoft.Identity.Client.TelemetryCore.TelemetryClient;
@@ -29,6 +31,8 @@ internal class RequestContext
2931

3032
public X509Certificate2 MtlsCertificate { get; }
3133

34+
internal Func<AttestationTokenInput, CancellationToken, Task<AttestationTokenResponse>> AttestationTokenProvider { get; set; }
35+
3236
public RequestContext(IServiceBundle serviceBundle, Guid correlationId, X509Certificate2 mtlsCertificate, CancellationToken cancellationToken = default)
3337
{
3438
ServiceBundle = serviceBundle ?? throw new ArgumentNullException(nameof(serviceBundle));

src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ private async Task<AuthenticationResult> SendTokenRequestForManagedIdentityAsync
200200

201201
_managedIdentityParameters.IsMtlsPopRequested = AuthenticationRequestParameters.IsMtlsPopRequested;
202202

203+
// Ensure the attestation provider reaches RequestContext for IMDSv2
204+
AuthenticationRequestParameters.RequestContext.AttestationTokenProvider ??=
205+
_managedIdentityParameters.AttestationTokenProvider;
206+
203207
ManagedIdentityResponse managedIdentityResponse =
204208
await _managedIdentityClient
205209
.SendTokenRequestForManagedIdentityAsync(AuthenticationRequestParameters.RequestContext, _managedIdentityParameters, cancellationToken)

src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs

Lines changed: 148 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Linq;
77
using System.Net;
88
using System.Net.Http;
9+
using System.Threading;
910
using System.Threading.Tasks;
1011
using Microsoft.Identity.Client.Core;
1112
using Microsoft.Identity.Client.Http;
@@ -31,7 +32,7 @@ public static async Task<CsrMetadata> GetCsrMetadataAsync(
3132
bool probeMode)
3233
{
3334
#if NET462
34-
requestContext.Logger.Info(() => "[Managed Identity] IMDSv2 flow is not supported on .NET Framework 4.6.2. Cryptographic operations required for managed identity authentication are unavailable on this platform. Skipping IMDSv2 probe.");
35+
requestContext.Logger.Info("[Managed Identity] IMDSv2 flow is not supported on .NET Framework 4.6.2. Cryptographic operations required for managed identity authentication are unavailable on this platform. Skipping IMDSv2 probe.");
3536
return await Task.FromResult<CsrMetadata>(null).ConfigureAwait(false);
3637
#else
3738
var queryParams = ImdsV2QueryParamsHelper(requestContext);
@@ -66,7 +67,7 @@ public static async Task<CsrMetadata> GetCsrMetadataAsync(
6667
{
6768
if (probeMode)
6869
{
69-
requestContext.Logger.Info(() => $"[Managed Identity] IMDSv2 CSR endpoint failure. Exception occurred while sending request to CSR metadata endpoint: ${ex}");
70+
requestContext.Logger.Info($"[Managed Identity] IMDSv2 CSR endpoint failure. Exception occurred while sending request to CSR metadata endpoint: {ex}");
7071
return null;
7172
}
7273
else
@@ -187,7 +188,11 @@ internal ImdsV2ManagedIdentitySource(RequestContext requestContext) :
187188
base(requestContext, ManagedIdentitySource.ImdsV2)
188189
{ }
189190

190-
private async Task<CertificateRequestResponse> ExecuteCertificateRequestAsync(string csr)
191+
private async Task<CertificateRequestResponse> ExecuteCertificateRequestAsync(
192+
string clientId,
193+
string attestationEndpoint,
194+
string csr,
195+
ManagedIdentityKeyInfo managedIdentityKeyInfo)
191196
{
192197
var queryParams = ImdsV2QueryParamsHelper(_requestContext);
193198

@@ -199,10 +204,32 @@ private async Task<CertificateRequestResponse> ExecuteCertificateRequestAsync(st
199204
{ OAuth2Header.XMsCorrelationId, _requestContext.CorrelationId.ToString() }
200205
};
201206

207+
if (_isMtlsPopRequested && managedIdentityKeyInfo.Type != ManagedIdentityKeyType.KeyGuard)
208+
{
209+
throw new MsalClientException(
210+
"mtls_pop_requires_keyguard",
211+
"[ImdsV2] mTLS Proof-of-Possession requires a KeyGuard-backed key. Enable KeyGuard or use a KeyGuard-supported environment.");
212+
}
213+
214+
// TODO: : Normalize and validate attestation endpoint Code needs to be removed
215+
// once IMDS team start returning full URI
216+
Uri normalizedEndpoint = NormalizeAttestationEndpoint(attestationEndpoint, _requestContext.Logger);
217+
218+
// Ask helper for JWT only for KeyGuard keys
219+
string attestationJwt = string.Empty;
220+
if (managedIdentityKeyInfo.Type == ManagedIdentityKeyType.KeyGuard)
221+
{
222+
attestationJwt = await GetAttestationJwtAsync(
223+
clientId,
224+
normalizedEndpoint,
225+
managedIdentityKeyInfo,
226+
_requestContext.UserCancellationToken).ConfigureAwait(false);
227+
}
228+
202229
var certificateRequestBody = new CertificateRequestBody()
203230
{
204231
Csr = csr,
205-
// AttestationToken = "fake_attestation_token" TODO: implement attestation token
232+
AttestationToken = attestationJwt
206233
};
207234

208235
string body = JsonHelper.SerializeToJson(certificateRequestBody);
@@ -257,12 +284,21 @@ protected override async Task<ManagedIdentityRequest> CreateRequestAsync(string
257284
{
258285
var csrMetadata = await GetCsrMetadataAsync(_requestContext, false).ConfigureAwait(false);
259286

260-
var keyInfo = await _requestContext.ServiceBundle.PlatformProxy.ManagedIdentityKeyProvider
261-
.GetOrCreateKeyAsync(_requestContext.Logger, _requestContext.UserCancellationToken).ConfigureAwait(false);
287+
IManagedIdentityKeyProvider keyProvider = _requestContext.ServiceBundle.PlatformProxy.ManagedIdentityKeyProvider;
288+
289+
ManagedIdentityKeyInfo keyInfo = await keyProvider
290+
.GetOrCreateKeyAsync(
291+
_requestContext.Logger,
292+
_requestContext.UserCancellationToken)
293+
.ConfigureAwait(false);
262294

263295
var (csr, privateKey) = _requestContext.ServiceBundle.Config.CsrFactory.Generate(keyInfo.Key, csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId);
264296

265-
var certificateRequestResponse = await ExecuteCertificateRequestAsync(csr).ConfigureAwait(false);
297+
var certificateRequestResponse = await ExecuteCertificateRequestAsync(
298+
csrMetadata.ClientId,
299+
csrMetadata.AttestationEndpoint,
300+
csr,
301+
keyInfo).ConfigureAwait(false);
266302

267303
// transform certificateRequestResponse.Certificate to x509 with private key
268304
var mtlsCertificate = CommonCryptographyManager.AttachPrivateKeyToCert(
@@ -302,12 +338,117 @@ private static string ImdsV2QueryParamsHelper(RequestContext requestContext)
302338
requestContext.ServiceBundle.Config.ManagedIdentityId.IdType,
303339
requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId,
304340
requestContext.Logger);
341+
305342
if (userAssignedIdQueryParam != null)
306343
{
307344
queryParams += $"&{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}";
308345
}
309346

310347
return queryParams;
311348
}
349+
350+
/// <summary>
351+
/// Obtains an attestation JWT for the KeyGuard/CSR payload using the configured
352+
/// attestation provider and normalized endpoint.
353+
/// </summary>
354+
/// <param name="clientId">Client ID to be sent to the attestation provider.</param>
355+
/// <param name="attestationEndpoint">The attestation endpoint.</param>
356+
/// <param name="keyInfo">The key information.</param>
357+
/// <param name="cancellationToken">Cancellation token.</param>
358+
/// <returns>JWT string suitable for the IMDSv2 attested POP flow.</returns>
359+
/// <exception cref="MsalClientException">Wraps client/network failures.</exception>
360+
361+
private async Task<string> GetAttestationJwtAsync(
362+
string clientId,
363+
Uri attestationEndpoint,
364+
ManagedIdentityKeyInfo keyInfo,
365+
CancellationToken cancellationToken)
366+
{
367+
// Provider is a local dependency; missing provider is a client error
368+
var provider = _requestContext.AttestationTokenProvider;
369+
370+
// KeyGuard requires RSACng on Windows
371+
if (keyInfo.Type == ManagedIdentityKeyType.KeyGuard &&
372+
keyInfo.Key is not System.Security.Cryptography.RSACng rsaCng)
373+
{
374+
throw new MsalClientException(
375+
"keyguard_requires_cng",
376+
"[ImdsV2] KeyGuard attestation currently supports only RSA CNG keys on Windows.");
377+
}
378+
379+
// Attestation token input
380+
var input = new AttestationTokenInput
381+
{
382+
ClientId = clientId,
383+
AttestationEndpoint = attestationEndpoint,
384+
KeyHandle = (keyInfo.Key as System.Security.Cryptography.RSACng)?.Key.Handle
385+
};
386+
387+
// response from provider
388+
var response = await provider(input, cancellationToken).ConfigureAwait(false);
389+
390+
// Validate response
391+
if (response == null || string.IsNullOrWhiteSpace(response.AttestationToken))
392+
{
393+
throw new MsalClientException(
394+
"attestation_failed",
395+
"[ImdsV2] Attestation provider failed to return an attestation token.");
396+
}
397+
398+
// Return the JWT
399+
return response.AttestationToken;
400+
}
401+
402+
//To-do : Remove this method once IMDS team start returning full URI
403+
/// <summary>
404+
/// Temporarily normalize attestation endpoint values to a full https:// URI.
405+
/// IMDS team will eventually return a full URI.
406+
/// </summary>
407+
/// <param name="rawEndpoint"></param>
408+
/// <param name="logger"></param>
409+
/// <returns></returns>
410+
private static Uri NormalizeAttestationEndpoint(string rawEndpoint, ILoggerAdapter logger)
411+
{
412+
if (string.IsNullOrWhiteSpace(rawEndpoint))
413+
{
414+
return null;
415+
}
416+
417+
// Trim whitespace
418+
rawEndpoint = rawEndpoint.Trim();
419+
420+
// If it already parses as an absolute URI with https, keep it.
421+
if (Uri.TryCreate(rawEndpoint, UriKind.Absolute, out var absolute) &&
422+
(absolute.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase)))
423+
{
424+
return absolute;
425+
}
426+
427+
// If it has no scheme (common service behavior returning only host)
428+
// prepend https:// and try again.
429+
if (!rawEndpoint.StartsWith("https://", StringComparison.OrdinalIgnoreCase))
430+
{
431+
var candidate = "https://" + rawEndpoint;
432+
if (Uri.TryCreate(candidate, UriKind.Absolute, out var httpsUri))
433+
{
434+
logger.Info(() => $"[Managed Identity] Normalized attestation endpoint '{rawEndpoint}' -> '{httpsUri.ToString()}'.");
435+
return httpsUri;
436+
}
437+
}
438+
439+
// Final attempt: reject http (non‑TLS) or malformed
440+
if (Uri.TryCreate(rawEndpoint, UriKind.Absolute, out var anyUri))
441+
{
442+
if (!anyUri.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase))
443+
{
444+
logger.Warning($"[Managed Identity] Attestation endpoint uses unsupported scheme '{anyUri.Scheme}'. HTTPS is required.");
445+
return null;
446+
}
447+
return anyUri;
448+
}
449+
450+
logger.Warning($"[Managed Identity] Failed to normalize attestation endpoint value '{rawEndpoint}'.");
451+
return null;
452+
}
312453
}
313454
}

0 commit comments

Comments
 (0)