diff --git a/Directory.Packages.props b/Directory.Packages.props index 25bc8e28bc..5de1e62bc8 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -19,11 +19,13 @@ + + @@ -73,7 +75,6 @@ - @@ -82,6 +83,5 @@ - diff --git a/LibsAndSamples.sln b/LibsAndSamples.sln index c82dabb247..ee3f05d15d 100644 --- a/LibsAndSamples.sln +++ b/LibsAndSamples.sln @@ -194,6 +194,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MacMauiAppWithBroker", "tes EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MacConsoleAppWithBroker", "tests\devapps\MacConsoleAppWithBroker\MacConsoleAppWithBroker.csproj", "{DBD18BC8-72E4-47D4-BD79-8DEBD9F2C0D0}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.Identity.Client.MtlsPop", "src\client\Microsoft.Identity.Client.MtlsPop\Microsoft.Identity.Client.MtlsPop.csproj", "{3E1C29E5-6E67-D9B2-28DF-649A609937A2}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug + MobileApps|Any CPU = Debug + MobileApps|Any CPU @@ -1987,6 +1989,48 @@ Global {DBD18BC8-72E4-47D4-BD79-8DEBD9F2C0D0}.Release|x64.Build.0 = Release|Any CPU {DBD18BC8-72E4-47D4-BD79-8DEBD9F2C0D0}.Release|x86.ActiveCfg = Release|Any CPU {DBD18BC8-72E4-47D4-BD79-8DEBD9F2C0D0}.Release|x86.Build.0 = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|Any CPU.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|Any CPU.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|ARM.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|ARM.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|ARM64.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|ARM64.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|iPhone.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|iPhone.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|iPhoneSimulator.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|iPhoneSimulator.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|x64.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|x64.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|x86.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug + MobileApps|x86.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|ARM.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|ARM.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|ARM64.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|ARM64.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|iPhone.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|iPhone.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|iPhoneSimulator.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|iPhoneSimulator.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|x64.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|x64.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|x86.ActiveCfg = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Debug|x86.Build.0 = Debug|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|Any CPU.Build.0 = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|ARM.ActiveCfg = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|ARM.Build.0 = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|ARM64.ActiveCfg = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|ARM64.Build.0 = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|iPhone.ActiveCfg = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|iPhone.Build.0 = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|iPhoneSimulator.ActiveCfg = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|iPhoneSimulator.Build.0 = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|x64.ActiveCfg = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|x64.Build.0 = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|x86.ActiveCfg = Release|Any CPU + {3E1C29E5-6E67-D9B2-28DF-649A609937A2}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -2045,6 +2089,7 @@ Global {97995B86-AA0F-3AF9-DA40-85A6263E4391} = {9B0B5396-4D95-4C15-82ED-DC22B5A3123F} {AEF6BB00-931F-4638-955D-24D735625C34} = {34BE693E-3496-45A4-B1D2-D3A0E068EEDB} {DBD18BC8-72E4-47D4-BD79-8DEBD9F2C0D0} = {34BE693E-3496-45A4-B1D2-D3A0E068EEDB} + {3E1C29E5-6E67-D9B2-28DF-649A609937A2} = {1A37FD75-94E9-4D6F-953A-0DABBD7B49E9} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {020399A9-DC27-4B82-9CAA-EF488665AC27} diff --git a/build/template-pack-and-sign-all-nugets.yaml b/build/template-pack-and-sign-all-nugets.yaml index 7566b7dde0..0b1cfdf5e3 100644 --- a/build/template-pack-and-sign-all-nugets.yaml +++ b/build/template-pack-and-sign-all-nugets.yaml @@ -44,6 +44,13 @@ steps: ProjectRootPath: '$(Build.SourcesDirectory)\$(MsalSourceDir)src\client' AssemblyName: 'Microsoft.Identity.Client.Extensions.Msal' +# Sign binary and pack Microsoft.Identity.Client.MtlsPop +- template: template-pack-and-sign-nuget.yaml + parameters: + BuildConfiguration: ${{ parameters.BuildConfiguration }} + ProjectRootPath: '$(Build.SourcesDirectory)\$(MsalSourceDir)src\client' + AssemblyName: 'Microsoft.Identity.Client.MtlsPop' + # Copy all packages out to staging - task: CopyFiles@2 displayName: 'Copy Files to: $(Build.ArtifactStagingDirectory)\packages' diff --git a/build/template-run-mi-e2e-azurearc.yaml b/build/template-run-mi-e2e-azurearc.yaml index 62c11d9222..82a8dcce0c 100644 --- a/build/template-run-mi-e2e-azurearc.yaml +++ b/build/template-run-mi-e2e-azurearc.yaml @@ -37,4 +37,4 @@ steps: codeCoverageEnabled: false failOnMinTestsNotRun: true minimumExpectedTests: '1' - testFiltercriteria: 'TestCategory=MI_E2E_AzureArc' + testFiltercriteria: '(TestCategory=MI_E2E_AzureArc|TestCategory=MI_E2E_KeyAcquisition_KeyGuard)' diff --git a/build/template-run-mi-e2e-imds.yaml b/build/template-run-mi-e2e-imds.yaml index 3beb42d030..c98eda0ed3 100644 --- a/build/template-run-mi-e2e-imds.yaml +++ b/build/template-run-mi-e2e-imds.yaml @@ -38,4 +38,4 @@ steps: runInParallel: false failOnMinTestsNotRun: true minimumExpectedTests: '1' - testFiltercriteria: 'TestCategory=MI_E2E_Imds' + testFiltercriteria: '(TestCategory=MI_E2E_Imds|TestCategory=MI_E2E_KeyAcquisition_Hardware)' diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationClient.cs b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationClient.cs new file mode 100644 index 0000000000..4b8dd68631 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationClient.cs @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Runtime.InteropServices; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.Identity.Client.MtlsPop.Attestation +{ + /// + /// Managed façade for AttestationClientLib.dll. Holds initialization state, + /// does ref-count hygiene on , and returns a JWT. + /// + internal sealed class AttestationClient : IDisposable + { + private bool _initialized; + + /// + /// AttestationClient constructor. Pro-actively verifies the native DLL. + /// + /// + public AttestationClient() + { + /* step 0 ── ensure the resolver probes all valid locations + (env override → app base → System32/SysWOW64 → PATH) */ + NativeDllResolver.EnsureLoaded(); + + /* step 1 ── optional proactive verification (non-fatal) + Keep the probe for diagnostics, but do NOT throw here; if the DLL + is truly unavailable/mismatched, InitAttestationLib will fail. */ + string dllError = NativeDiagnostics.ProbeNativeDll(); + // intentionally not throwing on dllError to avoid path-specific false negatives + + /* step 2 ── load & initialize (logger is required by native lib) */ + var info = new AttestationClientLib.AttestationLogInfo + { + Log = AttestationLogger.ConsoleLogger, // minimal rooted delegate; works on netstandard2.0 & net8.0 + Ctx = IntPtr.Zero + }; + + _initialized = AttestationClientLib.InitAttestationLib(ref info) == 0; + if (!_initialized) + throw new InvalidOperationException("Failed to initialize AttestationClientLib."); + } + + /// + /// Calls the native AttestKeyGuardImportKey and returns a structured result. + /// + public AttestationResult Attest(string endpoint, + SafeNCryptKeyHandle keyHandle, + string clientId) + { + if (!_initialized) + return new(AttestationStatus.NotInitialized, null, -1, + "Native library not initialized."); + + IntPtr buf = IntPtr.Zero; + bool addRef = false; + + try + { + keyHandle.DangerousAddRef(ref addRef); + + int rc = AttestationClientLib.AttestKeyGuardImportKey( + endpoint, null, null, keyHandle, out buf, clientId); + + if (rc != 0) + return new(AttestationStatus.NativeError, null, rc, null); + + if (buf == IntPtr.Zero) + return new(AttestationStatus.TokenEmpty, null, 0, + "rc==0 but token buffer was null."); + + string jwt = Marshal.PtrToStringAnsi(buf)!; + return new(AttestationStatus.Success, jwt, 0, null); + } + catch (DllNotFoundException ex) + { + return new(AttestationStatus.Exception, null, -1, + $"Native DLL not found: {ex.Message}"); + } + catch (BadImageFormatException ex) + { + return new(AttestationStatus.Exception, null, -1, + $"Architecture mismatch (x86/x64) or corrupted DLL: {ex.Message}"); + } + catch (SEHException ex) + { + return new(AttestationStatus.Exception, null, -1, + $"Native library raised SEHException: {ex.Message}"); + } + catch (Exception ex) + { + return new(AttestationStatus.Exception, null, -1, ex.Message); + } + finally + { + if (buf != IntPtr.Zero) + AttestationClientLib.FreeAttestationToken(buf); + if (addRef) + keyHandle.DangerousRelease(); + } + } + + /// + /// Disposes the client, releasing any resources and un-initializing the native library. + /// + public void Dispose() + { + if (_initialized) + { + AttestationClientLib.UninitAttestationLib(); + _initialized = false; + } + GC.SuppressFinalize(this); + } + } +} diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationClientLib.cs b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationClientLib.cs new file mode 100644 index 0000000000..df84387024 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationClientLib.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.Win32.SafeHandles; +using System; +using System.IO; +using System.Runtime.InteropServices; + +namespace Microsoft.Identity.Client.MtlsPop.Attestation +{ + internal static class AttestationClientLib + { + internal enum LogLevel { Error, Warn, Info, Debug } + + internal delegate void LogFunc( + IntPtr ctx, string tag, LogLevel lvl, string func, int line, string msg); + + [StructLayout(LayoutKind.Sequential)] + internal struct AttestationLogInfo + { + public LogFunc Log; + public IntPtr Ctx; + } + + [DllImport("AttestationClientLib.dll", CallingConvention = CallingConvention.Cdecl, + CharSet = CharSet.Ansi)] + internal static extern int InitAttestationLib(ref AttestationLogInfo info); + + [DllImport("AttestationClientLib.dll", CallingConvention = CallingConvention.Cdecl, + CharSet = CharSet.Ansi)] + internal static extern int AttestKeyGuardImportKey( + string endpoint, + string authToken, + string clientPayload, + SafeNCryptKeyHandle keyHandle, + out IntPtr token, + string clientId); + + [DllImport("AttestationClientLib.dll", CallingConvention = CallingConvention.Cdecl)] + internal static extern void FreeAttestationToken(IntPtr token); + + [DllImport("AttestationClientLib.dll", CallingConvention = CallingConvention.Cdecl)] + internal static extern void UninitAttestationLib(); + } +} diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationErrors.cs b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationErrors.cs new file mode 100644 index 0000000000..0c47ceed76 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationErrors.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.Identity.Client.MtlsPop.Attestation +{ + internal static class AttestationErrors + { + internal static string Describe(AttestationResultErrorCode rc) => rc switch + { + AttestationResultErrorCode.ERRORCURLINITIALIZATION + => "libcurl failed to initialize (DLL missing or version mismatch).", + AttestationResultErrorCode.ERRORHTTPREQUESTFAILED + => "Could not reach the attestation service (network / proxy?).", + AttestationResultErrorCode.ERRORATTESTATIONFAILED + => "The enclave rejected the evidence (key type / PCR policy).", + AttestationResultErrorCode.ERRORJWTDECRYPTIONFAILED + => "The JWT returned by the service could not be decrypted.", + AttestationResultErrorCode.ERRORLOGGERINITIALIZATION + => "Native logger setup failed (rare).", + _ => rc.ToString() // default: enum name + }; + } +} diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationLogger.cs b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationLogger.cs new file mode 100644 index 0000000000..574a1d1821 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationLogger.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Microsoft.Identity.Client.MtlsPop.Attestation +{ + internal static class AttestationLogger + { + /// + /// Attestation Logger + /// + internal static readonly AttestationClientLib.LogFunc ConsoleLogger = (ctx, tag, lvl, func, line, msg) => + { + try + { + string sTag = ToText(tag); + string sFunc = ToText(func); + string sMsg = ToText(msg); + + var lineText = $"[MtlsPop][{lvl}] {sTag} {sFunc}:{line} {sMsg}"; + + // Default: Trace (respects listeners; safe for all app types) + Trace.WriteLine(lineText); + + // Opt-in console mirroring for local debugging + if (Environment.GetEnvironmentVariable("MSAL_MTLSPOP_LOG_TO_CONSOLE") == "1") + { + Console.WriteLine(lineText); + } + } + catch + { + } + }; + + // Converts either string or IntPtr (char*) to text. Works with any LogFunc variant. + private static string ToText(object value) + { + if (value is IntPtr p && p != IntPtr.Zero) + { + try + { return Marshal.PtrToStringAnsi(p) ?? string.Empty; } + catch { return string.Empty; } + } + return value?.ToString() ?? string.Empty; + } + } +} diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationResult.cs b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationResult.cs new file mode 100644 index 0000000000..67e1dfd071 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationResult.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.Identity.Client.MtlsPop.Attestation +{ + /// + /// AttestationResult is the result of an attestation operation. + /// + /// + /// + /// + /// + internal sealed record AttestationResult( + AttestationStatus Status, + string Jwt, + int NativeErrorCode, + string ErrorMessage); +} diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationResultErrorCode.cs b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationResultErrorCode.cs new file mode 100644 index 0000000000..4f02375292 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationResultErrorCode.cs @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.Identity.Client.MtlsPop.Attestation +{ + /// + /// Error codes returned by AttestationClientLib.dll. + /// A value of (0) indicates success; all other + /// values are negative and represent specific failure categories. + /// + internal enum AttestationResultErrorCode + { + /// The operation completed successfully. + SUCCESS = 0, + + /// libcurl could not be initialized inside the native library. + ERRORCURLINITIALIZATION = -1, + + /// The HTTP response body could not be parsed (malformed JSON, invalid JWT, etc.). + ERRORRESPONSEPARSING = -2, + + /// Managed-Identity (MSI) access token could not be obtained. + ERRORMSITOKENNOTFOUND = -3, + + /// The HTTP request exceeded the maximum retry count configured by the native client. + ERRORHTTPREQUESTEXCEEDEDRETRIES = -4, + + /// An HTTP request to the attestation service failed (network error, non-200 status, timeout, etc.). + ERRORHTTPREQUESTFAILED = -5, + + /// The attestation enclave rejected the supplied evidence (policy or signature failure). + ERRORATTESTATIONFAILED = -6, + + /// libcurl reported “couldn’t send” (DNS resolution, TLS handshake, or socket error). + ERRORSENDINGCURLREQUESTFAILED = -7, + + /// One or more input parameters passed to the native API were invalid or null. + ERRORINVALIDINPUTPARAMETER = -8, + + /// Validation of the attestation parameters failed on the client side. + ERRORATTESTATIONPARAMETERSVALIDATIONFAILED = -9, + + /// Native client failed to allocate heap memory. + ERRORFAILEDMEMORYALLOCATION = -10, + + /// Could not retrieve OS build / version information required for the attestation payload. + ERRORFAILEDTOGETOSINFO = -11, + + /// Internal TPM failure while gathering quotes or PCR values. + ERRORTPMINTERNALFAILURE = -12, + + /// TPM operation (e.g., signing the quote) failed. + ERRORTPMOPERATIONFAILURE = -13, + + /// The returned JWT could not be decrypted on the client. + ERRORJWTDECRYPTIONFAILED = -14, + + /// JWT decryption failed due to a TPM error. + ERRORJWTDECRYPTIONTPMERROR = -15, + + /// JSON in the service response was invalid or lacked required fields. + ERRORINVALIDJSONRESPONSE = -16, + + /// The VCEK certificate blob returned from the service was empty. + ERROREMPTYVCEKCERT = -17, + + /// The service response body was empty. + ERROREMPTYRESPONSE = -18, + + /// The HTTP request body generated by the client was empty. + ERROREMPTYREQUESTBODY = -19, + + /// Failed to parse the host-configuration-level (HCL) report. + ERRORHCLREPORTPARSINGFAILURE = -20, + + /// The retrieved HCL report was empty. + ERRORHCLREPORTEMPTY = -21, + + /// Could not extract JWK information from the attestation evidence. + ERROREXTRACTINGJWKINFO = -22, + + /// Failed converting a JWK structure to an RSA public key. + ERRORCONVERTINGJWKTORSAPUB = -23, + + /// EVP initialization for RSA encryption failed (OpenSSL). + ERROREVPPKEYENCRYPTINITFAILED = -24, + + /// EVP encryption failed when building the attestation claim. + ERROREVPPKEYENCRYPTFAILED = -25, + + /// Failed to decrypt data due to a TPM error. + ERRORDATADECRYPTIONTPMERROR = -26, + + /// Parsing DNS information for the attestation service endpoint failed. + ERRORPARSINGDNSINFO = -27, + + /// Failed to parse the attestation response envelope. + ERRORPARSINGATTESTATIONRESPONSE = -28, + + /// Provisioning of the Attestation Key (AK) certificate failed. + ERRORAKCERTPROVISIONINGFAILED = -29, + + /// Initialising the native attestation client failed. + ERRORCLIENTINITFAILED = -30, + + /// The service returned an empty JWT. + ERROREMPTYJWTRESPONSE = -31, + + /// Creating the KeyGuard attestation report failed on the client. + ERRORCREATEKGATTESTATIONREPORT = -32, + + /// Failed to extract the public key from the import-only key. + ERROREXTRACTIMPORTKEYPUB = -33, + + /// An unexpected C++ exception occurred inside the native client. + ERRORUNEXPECTEDEXCEPTION = -34, + + /// Initialising the native logger failed (file I/O / permissions / path issues). + ERRORLOGGERINITIALIZATION = -35 + } +} diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationStatus.cs b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationStatus.cs new file mode 100644 index 0000000000..ff20df8aa9 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/AttestationStatus.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.Identity.Client.MtlsPop.Attestation +{ + /// + /// High-level outcome categories returned by . + /// + internal enum AttestationStatus + { + /// Everything succeeded; is populated. + Success = 0, + + /// Native library returned a non-zero AttestationResultErrorCode. + NativeError = 1, + + /// rc == 0 but the token buffer was null/empty. + TokenEmpty = 2, + + /// could not initialize the native DLL. + NotInitialized = 3, + + /// Any managed exception thrown while attempting the call. + Exception = 4 + } +} diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/NativeDiagnostics.cs b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/NativeDiagnostics.cs new file mode 100644 index 0000000000..9482039c8e --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/NativeDiagnostics.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ComponentModel; +using System.IO; + +namespace Microsoft.Identity.Client.MtlsPop.Attestation +{ + internal static class NativeDiagnostics + { + private const string NativeDll = "AttestationClientLib.dll"; + + internal static string ProbeNativeDll() + { + string path = Path.Combine(AppContext.BaseDirectory, NativeDll); + + if (!File.Exists(path)) + return $"Native DLL not found at: {path}"; + + IntPtr h; + + try + { + h = WindowsDllLoader.Load(path); + } + catch (Win32Exception w32) + { + return w32.NativeErrorCode switch + { + 193 or 216 => $"{NativeDll} is the wrong architecture for this process.", + 126 => $"{NativeDll} found but one of its dependencies is missing (libcurl, OpenSSL, or VC++ runtime).", + _ => $"{NativeDll} could not be loaded (Win32 error 0x{w32.NativeErrorCode:X})." + }; + } + catch (Exception ex) + { + return $"Unable to load {NativeDll}: {ex.Message}"; + } + + // success – unload and return null (meaning “no error”) + WindowsDllLoader.Free(h); + return null; + } + } +} diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/NativeDllResolver.cs b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/NativeDllResolver.cs new file mode 100644 index 0000000000..8a127e461d --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/NativeDllResolver.cs @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Runtime.InteropServices; +using System.Runtime.Versioning; + +namespace Microsoft.Identity.Client.MtlsPop.Attestation +{ + /// + /// Ensures AttestationClientLib.dll is resolved from an override path, the app folder, + /// the system directories (System32/SysWOW64), or the default DLL search order (PATH). + /// + internal static class NativeDllResolver + { + private const string NativeDll = "AttestationClientLib.dll"; + private static IntPtr s_module; + + static NativeDllResolver() + { + // 1) Env override (per-job / per-process) + if (TryLoadFromEnv()) + return; + + // 2) App base directory + if (TryLoadFromAppBase()) + return; + + // 3) System directory (System32 for x64 process, SysWOW64 for x86 process) + if (TryLoadFromSystemDir()) + return; + + // 4) Let Windows search PATH / SxS / Known DLL dirs + s_module = WindowsDllLoader.Load(NativeDll); + } + + /// Touch this method from startup code to trigger the static ctor. + internal static void EnsureLoaded() { } + + private static bool TryLoadFromEnv() + { + var overrideDir = Environment.GetEnvironmentVariable("MSAL_MTLSPOP_NATIVE_PATH"); + if (string.IsNullOrWhiteSpace(overrideDir)) + { + return false; + } + + var candidate = Path.Combine(overrideDir, NativeDll); + if (!File.Exists(candidate)) + { + return false; + } + + s_module = WindowsDllLoader.Load(candidate); + return s_module != IntPtr.Zero; + } + + private static bool TryLoadFromAppBase() + { + var exePath = Path.Combine(AppContext.BaseDirectory, NativeDll); + if (!File.Exists(exePath)) + { + return false; + } + + s_module = WindowsDllLoader.Load(exePath); + return s_module != IntPtr.Zero; + } + + private static bool TryLoadFromSystemDir() + { + var windowsRoot = Environment.GetFolderPath(Environment.SpecialFolder.Windows); + if (string.IsNullOrEmpty(windowsRoot)) + { + return false; + } + + // x64 process -> System32, x86 process -> SysWOW64 + var sysDir = Path.Combine( + windowsRoot, + Environment.Is64BitProcess ? "System32" : "SysWOW64"); + + var sysPath = Path.Combine(sysDir, NativeDll); + if (!File.Exists(sysPath)) + { + return false; + } + + s_module = WindowsDllLoader.Load(sysPath); + return s_module != IntPtr.Zero; + } + } +} diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/WindowsDllLoader.cs b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/WindowsDllLoader.cs new file mode 100644 index 0000000000..aaee9eadb2 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/Attestation/WindowsDllLoader.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; + +namespace Microsoft.Identity.Client.MtlsPop.Attestation +{ + /// + /// Windows‑only helper that loads a native DLL from an absolute path. + /// + internal static class WindowsDllLoader + { + /// + /// Load the DLL and throw when the OS loader fails. + /// + /// Absolute path to AttestationClientLib.dll + /// Module handle (never zero on success). + /// + /// Thrown when kernel32!LoadLibraryW returns NULL. + /// + [DllImport("kernel32", + EntryPoint = "LoadLibraryW", + CharSet = CharSet.Unicode, + SetLastError = true, + ExactSpelling = true)] + private static extern IntPtr LoadLibraryW(string path); + + internal static IntPtr Load(string path) + { + if (string.IsNullOrEmpty(path)) + throw new ArgumentNullException(nameof(path)); + + IntPtr h = LoadLibraryW(path); + + if (h == IntPtr.Zero) + { + // Preserve Win32 error code for diagnosis + int err = Marshal.GetLastWin32Error(); + + throw new MsalClientException( + "attestationmodule_load_failure", + $"Key Attestation Module load failed " + + $"(error={err}, " + + $"Unable to load {path}"); + } + + return h; + } + + /// + /// Optionally expose a Free helper so callers can unload if needed. + /// + [DllImport("kernel32", SetLastError = true)] + private static extern bool FreeLibrary(IntPtr hModule); + + internal static void Free(IntPtr handle) + { + if (handle != IntPtr.Zero) + FreeLibrary(handle); + } + } +} diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/IsExternalInit.cs b/src/client/Microsoft.Identity.Client.MtlsPop/IsExternalInit.cs new file mode 100644 index 0000000000..dfb6a17acc --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/IsExternalInit.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if NETSTANDARD +namespace System.Runtime.CompilerServices +{ + internal static class IsExternalInit + { + } +} +#endif diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/ManagedIdentityPopExtensions.cs b/src/client/Microsoft.Identity.Client.MtlsPop/ManagedIdentityPopExtensions.cs new file mode 100644 index 0000000000..7185026342 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/ManagedIdentityPopExtensions.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.MtlsPop.Attestation; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.Identity.Client.MtlsPop +{ + /// + /// Registers the mTLS PoP attestation runtime (interop) by installing a provider + /// function into MSAL's internal config. + /// + public static class ManagedIdentityPopExtensions + { + /// + /// App-level registration: tells MSAL how to obtain a KeyGuard/CNG handle + /// and perform attestation to get the JWT needed for mTLS PoP. + /// + public static AcquireTokenForManagedIdentityParameterBuilder WithMtlsProofOfPossession( + this AcquireTokenForManagedIdentityParameterBuilder builder) + { + builder.CommonParameters.IsMtlsPopRequested = true; + AddRuntimeSupport(builder); + return builder; + } + + /// + /// Adds the runtime support by registering the attestation function. + /// + /// + /// + private static void AddRuntimeSupport( + AcquireTokenForManagedIdentityParameterBuilder builder) + { + // Register the "runtime" function that PoP operation will invoke. + builder.CommonParameters.AttestationTokenProvider = + async (req, ct) => + { + // 1) Get the caller-provided KeyGuard/CNG handle + SafeHandle keyHandle = req.KeyHandle; + + // 2) Call the native interop via PopKeyAttestor + AttestationResult attestationResult = await PopKeyAttestor.AttestKeyGuardAsync( + req.AttestationEndpoint.AbsoluteUri, // expects string + keyHandle, + req.ClientId ?? string.Empty, + ct).ConfigureAwait(false); + + // 3) Map to MSAL's internal response + if (attestationResult != null && + attestationResult.Status == AttestationStatus.Success && + !string.IsNullOrWhiteSpace(attestationResult.Jwt)) + { + return new ManagedIdentity.AttestationTokenResponse { AttestationToken = attestationResult.Jwt }; + } + + throw new MsalClientException( + "attestation_failure", + $"Key Attestation failed " + + $"(status={attestationResult?.Status}, " + + $"code={attestationResult?.NativeErrorCode}). {attestationResult?.ErrorMessage}"); + }; + } + } +} diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/Microsoft.Identity.Client.MtlsPop.csproj b/src/client/Microsoft.Identity.Client.MtlsPop/Microsoft.Identity.Client.MtlsPop.csproj new file mode 100644 index 0000000000..281eaacf00 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/Microsoft.Identity.Client.MtlsPop.csproj @@ -0,0 +1,50 @@ + + + + + netstandard2.0 + net8.0 + AnyCPU + + + $(TargetFrameworkNetStandard);$(TargetFrameworkNet) + + + Debug;Release + + + + + $(MsalInternalVersion) + + $(MicrosoftIdentityClientVersion)-preview + + MSAL.NET extension for managed identity proof-of-possession flows + + This package contains binaries needed to use managed identity proof-of-possession (MTLS PoP) flows in applications using MSAL.NET. + + Microsoft Authentication Library Managed Identity MSAL Proof-of-Possession + Microsoft Authentication Library + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/PopKeyAttestor.cs b/src/client/Microsoft.Identity.Client.MtlsPop/PopKeyAttestor.cs new file mode 100644 index 0000000000..f855041bce --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/PopKeyAttestor.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Client.MtlsPop.Attestation; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.Identity.Client.MtlsPop +{ + /// + /// Static facade for attesting a KeyGuard/CNG key and getting a JWT back. + /// Key discovery / rotation is the caller's responsibility. + /// + internal static class PopKeyAttestor + { + /// + /// Asynchronously attests a KeyGuard/CNG key with the remote attestation service and returns a JWT. + /// Wraps the synchronous in a Task.Run so callers can + /// avoid blocking. Cancellation only applies before the native call starts. + /// + /// Attestation service endpoint (required). + /// Valid SafeNCryptKeyHandle (must remain valid for duration of call). + /// Optional client identifier (may be null/empty). + /// Cancellation token (cooperative before scheduling / start). + public static Task AttestKeyGuardAsync( + string endpoint, + SafeHandle keyHandle, + string clientId, + CancellationToken cancellationToken = default) + { + if (keyHandle is null) + throw new ArgumentNullException(nameof(keyHandle)); + + if (string.IsNullOrWhiteSpace(endpoint)) + throw new ArgumentNullException(nameof(endpoint)); + + if (keyHandle.IsInvalid) + throw new ArgumentException("keyHandle is invalid", nameof(keyHandle)); + + var safeNCryptKeyHandle = keyHandle as SafeNCryptKeyHandle + ?? throw new ArgumentException("keyHandle must be a SafeNCryptKeyHandle. Only Windows CNG keys are supported.", nameof(keyHandle)); + + cancellationToken.ThrowIfCancellationRequested(); + + return Task.Run(() => + { + try + { + using var client = new AttestationClient(); + return client.Attest(endpoint, safeNCryptKeyHandle, clientId ?? string.Empty); + } + catch (Exception ex) + { + // Map any managed exception to AttestationStatus.Exception for consistency. + return new AttestationResult(AttestationStatus.Exception, string.Empty, -1, ex.Message); + } + }, cancellationToken); + } + } +} diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/PublicApi/net8.0/PublicAPI.Shipped.txt b/src/client/Microsoft.Identity.Client.MtlsPop/PublicApi/net8.0/PublicAPI.Shipped.txt new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/PublicApi/net8.0/PublicAPI.Shipped.txt @@ -0,0 +1 @@ + diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/PublicApi/net8.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client.MtlsPop/PublicApi/net8.0/PublicAPI.Unshipped.txt new file mode 100644 index 0000000000..63fd8c92c0 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/PublicApi/net8.0/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +Microsoft.Identity.Client.MtlsPop.ManagedIdentityPopExtensions +static Microsoft.Identity.Client.MtlsPop.ManagedIdentityPopExtensions.WithMtlsProofOfPossession(this Microsoft.Identity.Client.AcquireTokenForManagedIdentityParameterBuilder builder) -> Microsoft.Identity.Client.AcquireTokenForManagedIdentityParameterBuilder diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/PublicApi/netstandard2.0/PublicAPI.Shipped.txt b/src/client/Microsoft.Identity.Client.MtlsPop/PublicApi/netstandard2.0/PublicAPI.Shipped.txt new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/PublicApi/netstandard2.0/PublicAPI.Shipped.txt @@ -0,0 +1 @@ + diff --git a/src/client/Microsoft.Identity.Client.MtlsPop/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client.MtlsPop/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt new file mode 100644 index 0000000000..63fd8c92c0 --- /dev/null +++ b/src/client/Microsoft.Identity.Client.MtlsPop/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +Microsoft.Identity.Client.MtlsPop.ManagedIdentityPopExtensions +static Microsoft.Identity.Client.MtlsPop.ManagedIdentityPopExtensions.WithMtlsProofOfPossession(this Microsoft.Identity.Client.AcquireTokenForManagedIdentityParameterBuilder builder) -> Microsoft.Identity.Client.AcquireTokenForManagedIdentityParameterBuilder diff --git a/src/client/Microsoft.Identity.Client/ApiConfig/AcquireTokenForClientParameterBuilder.cs b/src/client/Microsoft.Identity.Client/ApiConfig/AcquireTokenForClientParameterBuilder.cs index d23792ff9f..2d802e43b3 100644 --- a/src/client/Microsoft.Identity.Client/ApiConfig/AcquireTokenForClientParameterBuilder.cs +++ b/src/client/Microsoft.Identity.Client/ApiConfig/AcquireTokenForClientParameterBuilder.cs @@ -9,15 +9,12 @@ using Microsoft.Identity.Client.ApiConfig.Executors; using Microsoft.Identity.Client.ApiConfig.Parameters; using Microsoft.Identity.Client.AuthScheme.PoP; +using Microsoft.Identity.Client.Extensibility; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Internal.ClientCredential; -using Microsoft.Identity.Client.TelemetryCore.Internal.Events; -using Microsoft.Identity.Client.Utils; -using Microsoft.Identity.Client.Extensibility; using Microsoft.Identity.Client.OAuth2; -using System.Security.Cryptography.X509Certificates; -using System.Security.Cryptography; -using System.Text; +using Microsoft.Identity.Client.PlatformsCommon.Shared; +using Microsoft.Identity.Client.TelemetryCore.Internal.Events; namespace Microsoft.Identity.Client { @@ -99,6 +96,25 @@ public AcquireTokenForClientParameterBuilder WithSendX5C(bool withSendX5C) /// The current instance of to enable method chaining. public AcquireTokenForClientParameterBuilder WithMtlsProofOfPossession() { + if (ServiceBundle.Config.IsManagedIdentity) + { + void MtlsNotSupportedForManagedIdentity(string message) + { + throw new MsalClientException( + MsalError.MtlsNotSupportedForManagedIdentity, + message); + } + + if (!DesktopOsHelper.IsWindows()) + { + MtlsNotSupportedForManagedIdentity(MsalErrorMessage.MtlsNotSupportedForNonWindowsMessage); + } + +#if NET462 + MtlsNotSupportedForManagedIdentity(MsalErrorMessage.MtlsNotSupportedForManagedIdentityMessage); +#endif + } + if (ServiceBundle.Config.ClientCredential is CertificateClientCredential certificateCredential) { if (certificateCredential.Certificate == null) diff --git a/src/client/Microsoft.Identity.Client/ApiConfig/AcquireTokenForManagedIdentityParameterBuilder.cs b/src/client/Microsoft.Identity.Client/ApiConfig/AcquireTokenForManagedIdentityParameterBuilder.cs index 7cbf69f999..dd9600b1aa 100644 --- a/src/client/Microsoft.Identity.Client/ApiConfig/AcquireTokenForManagedIdentityParameterBuilder.cs +++ b/src/client/Microsoft.Identity.Client/ApiConfig/AcquireTokenForManagedIdentityParameterBuilder.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.Identity.Client.ApiConfig.Executors; using Microsoft.Identity.Client.ApiConfig.Parameters; +using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.TelemetryCore.Internal.Events; using Microsoft.Identity.Client.Utils; @@ -80,6 +81,7 @@ public AcquireTokenForManagedIdentityParameterBuilder WithClaims(string claims) /// internal override Task ExecuteInternalAsync(CancellationToken cancellationToken) { + ApplyMtlsPopAndAttestation(acquireTokenForManagedIdentityParameters: Parameters, acquireTokenCommonParameters: CommonParameters); return ManagedIdentityApplicationExecutor.ExecuteAsync(CommonParameters, Parameters, cancellationToken); } @@ -93,5 +95,29 @@ internal override ApiEvent.ApiIds CalculateApiEventId() return ApiEvent.ApiIds.AcquireTokenForUserAssignedManagedIdentity; } + + /// + /// TEST HOOK ONLY: Allows unit tests to inject a fake attestation-token provider + /// so we don't hit the real attestation service. Not part of the public API. + /// + internal AcquireTokenForManagedIdentityParameterBuilder WithAttestationProviderForTests( + Func> provider) + { + if (provider is null) + { + throw new ArgumentNullException(nameof(provider)); + } + + CommonParameters.AttestationTokenProvider = provider; + return this; + } + + private static void ApplyMtlsPopAndAttestation( + AcquireTokenCommonParameters acquireTokenCommonParameters, + AcquireTokenForManagedIdentityParameters acquireTokenForManagedIdentityParameters) + { + acquireTokenForManagedIdentityParameters.IsMtlsPopRequested = acquireTokenCommonParameters.IsMtlsPopRequested; + acquireTokenForManagedIdentityParameters.AttestationTokenProvider ??= acquireTokenCommonParameters.AttestationTokenProvider; + } } } diff --git a/src/client/Microsoft.Identity.Client/ApiConfig/Executors/ManagedIdentityExecutor.cs b/src/client/Microsoft.Identity.Client/ApiConfig/Executors/ManagedIdentityExecutor.cs index ff1ba983c6..c6d70db147 100644 --- a/src/client/Microsoft.Identity.Client/ApiConfig/Executors/ManagedIdentityExecutor.cs +++ b/src/client/Microsoft.Identity.Client/ApiConfig/Executors/ManagedIdentityExecutor.cs @@ -44,7 +44,8 @@ public async Task ExecuteAsync( var handler = new ManagedIdentityAuthRequest( ServiceBundle, requestParams, - managedIdentityParameters); + managedIdentityParameters, + _managedIdentityApplication.ManagedIdentityClient); return await handler.RunAsync(cancellationToken).ConfigureAwait(false); } diff --git a/src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenCommonParameters.cs b/src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenCommonParameters.cs index e5bd4fdac8..cbb52f4a88 100644 --- a/src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenCommonParameters.cs +++ b/src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenCommonParameters.cs @@ -13,6 +13,7 @@ using Microsoft.Identity.Client.Extensibility; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Internal.ClientCredential; +using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.TelemetryCore.Internal.Events; using Microsoft.Identity.Client.Utils; using static Microsoft.Identity.Client.Extensibility.AbstractConfidentialClientAcquireTokenParameterBuilderExtension; @@ -39,6 +40,7 @@ internal class AcquireTokenCommonParameters public string FmiPathSuffix { get; internal set; } public string ClientAssertionFmiPath { get; internal set; } public bool IsMtlsPopRequested { get; set; } + internal Func> AttestationTokenProvider { get; set; } internal async Task InitMtlsPopParametersAsync(IServiceBundle serviceBundle, CancellationToken ct) { diff --git a/src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs b/src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs index a4450c4268..ca9ab69f92 100644 --- a/src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs +++ b/src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs @@ -5,8 +5,10 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.ManagedIdentity; namespace Microsoft.Identity.Client.ApiConfig.Parameters { @@ -20,6 +22,10 @@ internal class AcquireTokenForManagedIdentityParameters : IAcquireTokenParameter public string RevokedTokenHash { get; set; } + public bool IsMtlsPopRequested { get; set; } + + internal Func> AttestationTokenProvider { get; set; } + public void LogParameters(ILoggerAdapter logger) { if (logger.IsLoggingEnabled(LogLevel.Info)) diff --git a/src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs b/src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs index 6bf57e3161..4e22a2855c 100644 --- a/src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs +++ b/src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs @@ -17,6 +17,8 @@ using Microsoft.Identity.Client.Internal.Broker; using Microsoft.Identity.Client.Internal.ClientCredential; using Microsoft.Identity.Client.Kerberos; +using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.PlatformsCommon.Interfaces; using Microsoft.Identity.Client.UI; using Microsoft.IdentityModel.Abstractions; @@ -127,6 +129,7 @@ public string ClientVersion public Func> AppTokenProvider; internal IRetryPolicyFactory RetryPolicyFactory { get; set; } + internal ICsrFactory CsrFactory { get; set; } #region ClientCredentials diff --git a/src/client/Microsoft.Identity.Client/AppConfig/BaseAbstractApplicationBuilder.cs b/src/client/Microsoft.Identity.Client/AppConfig/BaseAbstractApplicationBuilder.cs index bf083a6101..ea20f7ee20 100644 --- a/src/client/Microsoft.Identity.Client/AppConfig/BaseAbstractApplicationBuilder.cs +++ b/src/client/Microsoft.Identity.Client/AppConfig/BaseAbstractApplicationBuilder.cs @@ -15,6 +15,7 @@ using Microsoft.IdentityModel.Abstractions; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Http.Retry; +using Microsoft.Identity.Client.ManagedIdentity.V2; #if SUPPORTS_SYSTEM_TEXT_JSON using System.Text.Json; @@ -39,6 +40,12 @@ internal BaseAbstractApplicationBuilder(ApplicationConfiguration configuration) { Config.RetryPolicyFactory = new RetryPolicyFactory(); } + + // Ensure the default csr factory is set if the test factory was not provided + if (Config.CsrFactory == null) + { + Config.CsrFactory = new DefaultCsrFactory(); + } } internal ApplicationConfiguration Config { get; } @@ -249,6 +256,17 @@ internal T WithRetryPolicyFactory(IRetryPolicyFactory factory) return (T)this; } + /// + /// Internal only: Allows tests to inject a custom csr factory. + /// + /// The csr factory to use. + /// The builder for chaining. + internal T WithCsrFactory(ICsrFactory factory) + { + Config.CsrFactory = factory; + return (T)this; + } + internal virtual ApplicationConfiguration BuildConfiguration() { ResolveAuthority(); diff --git a/src/client/Microsoft.Identity.Client/AppConfig/ManagedIdentityApplicationBuilder.cs b/src/client/Microsoft.Identity.Client/AppConfig/ManagedIdentityApplicationBuilder.cs index 434d2764ce..3a2b1cf765 100644 --- a/src/client/Microsoft.Identity.Client/AppConfig/ManagedIdentityApplicationBuilder.cs +++ b/src/client/Microsoft.Identity.Client/AppConfig/ManagedIdentityApplicationBuilder.cs @@ -11,6 +11,7 @@ using Microsoft.Identity.Client.AppConfig; using Microsoft.Identity.Client.Extensibility; using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.TelemetryCore; using Microsoft.Identity.Client.TelemetryCore.TelemetryClient; using Microsoft.Identity.Client.Utils; @@ -102,6 +103,34 @@ public ManagedIdentityApplicationBuilder WithClientCapabilities(IEnumerable + /// Sets Extra Query Parameters for the query string in the HTTP authentication request. + /// + /// This parameter will be appended as is to the query string in the HTTP authentication request to the authority + /// as a string of segments of the form key=value separated by an ampersand character. + /// The parameter can be null. + /// The builder to chain the .With methods. + /// This API is experimental and it may change in future versions of the library without a major version increment + [EditorBrowsable(EditorBrowsableState.Never)] + public ManagedIdentityApplicationBuilder WithExtraQueryParameters(IDictionary extraQueryParameters) + { + ValidateUseOfExperimentalFeature(); + + if (Config.ExtraQueryParameters == null) + { + Config.ExtraQueryParameters = extraQueryParameters; + } + else + { + foreach (var kvp in extraQueryParameters) + { + Config.ExtraQueryParameters[kvp.Key] = kvp.Value; // This will overwrite if key exists, or add if new + } + } + + return this; + } + /// /// Builds an instance of /// from the parameters set in the . diff --git a/src/client/Microsoft.Identity.Client/ApplicationBase.cs b/src/client/Microsoft.Identity.Client/ApplicationBase.cs index 5608eea888..ef327a92fe 100644 --- a/src/client/Microsoft.Identity.Client/ApplicationBase.cs +++ b/src/client/Microsoft.Identity.Client/ApplicationBase.cs @@ -6,8 +6,17 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Identity.Client.ApiConfig.Parameters; +using Microsoft.Identity.Client.AuthScheme.PoP; +using Microsoft.Identity.Client.Http; +using Microsoft.Identity.Client.Instance; +using Microsoft.Identity.Client.Instance.Discovery; +using Microsoft.Identity.Client.Instance.Oidc; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Internal.Requests; +using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.OAuth2.Throttling; +using Microsoft.Identity.Client.PlatformsCommon.Shared; +using Microsoft.Identity.Client.Region; namespace Microsoft.Identity.Client { @@ -74,5 +83,24 @@ internal static void GuardMobileFrameworks() "See https://aka.ms/msal-net-confidential-availability and https://aka.ms/msal-net-managed-identity for details."); #endif } + + /// + /// Resets the SDKs internal state, such as static caches, to facilitate testing. + /// This API is meant to be used by other SDKs that build on top of MSAL, and only by test code. + /// + public static void ResetStateForTest() + { + NetworkCacheMetadataProvider.ResetStaticCacheForTest(); + RegionManager.ResetStaticCacheForTest(); + OidcRetrieverWithCache.ResetCacheForTest(); + AuthorityManager.ClearValidationCache(); + SingletonThrottlingManager.GetInstance().ResetCache(); + ManagedIdentityClient.ResetSourceForTest(); + AuthorityManager.ClearValidationCache(); + PoPCryptoProviderFactory.Reset(); + + InMemoryPartitionedAppTokenCacheAccessor.ClearStaticCacheForTest(); + InMemoryPartitionedUserTokenCacheAccessor.ClearStaticCacheForTest(); + } } } diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/CsrMetadataProbeRetryPolicy.cs b/src/client/Microsoft.Identity.Client/Http/Retry/CsrMetadataProbeRetryPolicy.cs new file mode 100644 index 0000000000..71de66726d --- /dev/null +++ b/src/client/Microsoft.Identity.Client/Http/Retry/CsrMetadataProbeRetryPolicy.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; + +namespace Microsoft.Identity.Client.Http.Retry +{ + internal class CsrMetadataProbeRetryPolicy : ImdsRetryPolicy + { + protected override bool ShouldRetry(HttpResponse response, Exception exception) + { + return HttpRetryConditions.CsrMetadataProbe(response, exception); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryCondition.cs b/src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryCondition.cs index be6b1791a0..8b2231cf4a 100644 --- a/src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryCondition.cs +++ b/src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryCondition.cs @@ -62,6 +62,21 @@ public static bool RegionDiscovery(HttpResponse response, Exception exception) return (int)response.StatusCode is not (404 or 408); } + /// + /// Retry policy specific to CSR Metadata Probe. + /// Extends Imds retry policy but excludes 404 status code. + /// + public static bool CsrMetadataProbe(HttpResponse response, Exception exception) + { + if (!Imds(response, exception)) + { + return false; + } + + // If Imds would retry but the status code is 404, don't retry + return (int)response.StatusCode is not 404; + } + /// /// Retry condition for /token and /authorize endpoints /// diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/ImdsRetryPolicy.cs b/src/client/Microsoft.Identity.Client/Http/Retry/ImdsRetryPolicy.cs index 6ed44e0745..8d20fe9f13 100644 --- a/src/client/Microsoft.Identity.Client/Http/Retry/ImdsRetryPolicy.cs +++ b/src/client/Microsoft.Identity.Client/Http/Retry/ImdsRetryPolicy.cs @@ -33,6 +33,11 @@ internal virtual Task DelayAsync(int milliseconds) return Task.Delay(milliseconds); } + protected virtual bool ShouldRetry(HttpResponse response, Exception exception) + { + return HttpRetryConditions.Imds(response, exception); + } + public async Task PauseForRetryAsync(HttpResponse response, Exception exception, int retryCount, ILoggerAdapter logger) { int httpStatusCode = (int)response.StatusCode; @@ -46,7 +51,7 @@ public async Task PauseForRetryAsync(HttpResponse response, Exception exce } // Check if the status code is retriable and if the current retry count is less than max retries - if (HttpRetryConditions.Imds(response, exception) && + if (ShouldRetry(response, exception) && retryCount < _maxRetries) { int retryAfterDelay = httpStatusCode == (int)HttpStatusCode.Gone diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/RetryPolicyFactory.cs b/src/client/Microsoft.Identity.Client/Http/Retry/RetryPolicyFactory.cs index 8b133777b6..e190f1ba4d 100644 --- a/src/client/Microsoft.Identity.Client/Http/Retry/RetryPolicyFactory.cs +++ b/src/client/Microsoft.Identity.Client/Http/Retry/RetryPolicyFactory.cs @@ -18,6 +18,8 @@ public virtual IRetryPolicy GetRetryPolicy(RequestType requestType) return new ImdsRetryPolicy(); case RequestType.RegionDiscovery: return new RegionDiscoveryRetryPolicy(); + case RequestType.CsrMetadataProbe: + return new CsrMetadataProbeRetryPolicy(); default: throw new ArgumentOutOfRangeException(nameof(requestType), requestType, "Unknown request type."); } diff --git a/src/client/Microsoft.Identity.Client/IManagedIdentityApplication.cs b/src/client/Microsoft.Identity.Client/IManagedIdentityApplication.cs index 7cf595ae1c..998cf095f8 100644 --- a/src/client/Microsoft.Identity.Client/IManagedIdentityApplication.cs +++ b/src/client/Microsoft.Identity.Client/IManagedIdentityApplication.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.ComponentModel; using System.Threading.Tasks; +using Microsoft.Identity.Client.ManagedIdentity; namespace Microsoft.Identity.Client { diff --git a/src/client/Microsoft.Identity.Client/Instance/Discovery/INetworkCacheMetadataProvider.cs b/src/client/Microsoft.Identity.Client/Instance/Discovery/INetworkCacheMetadataProvider.cs index 99f37bec7c..f30c5af975 100644 --- a/src/client/Microsoft.Identity.Client/Instance/Discovery/INetworkCacheMetadataProvider.cs +++ b/src/client/Microsoft.Identity.Client/Instance/Discovery/INetworkCacheMetadataProvider.cs @@ -9,6 +9,5 @@ internal interface INetworkCacheMetadataProvider { void AddMetadata(string environment, InstanceDiscoveryMetadataEntry entry); InstanceDiscoveryMetadataEntry GetMetadata(string environment, ILoggerAdapter logger); - void /* for test purposes */ Clear(); } } diff --git a/src/client/Microsoft.Identity.Client/Instance/Discovery/InstanceDiscoveryManager.cs b/src/client/Microsoft.Identity.Client/Instance/Discovery/InstanceDiscoveryManager.cs index a42d8c4103..e24d85b40c 100644 --- a/src/client/Microsoft.Identity.Client/Instance/Discovery/InstanceDiscoveryManager.cs +++ b/src/client/Microsoft.Identity.Client/Instance/Discovery/InstanceDiscoveryManager.cs @@ -37,12 +37,10 @@ internal class InstanceDiscoveryManager : IInstanceDiscoveryManager public InstanceDiscoveryManager( IHttpManager httpManager, - bool /* for test */ shouldClearCaches, InstanceDiscoveryResponse userProvidedInstanceDiscoveryResponse = null, Uri userProvidedInstanceDiscoveryUri = null) : this( httpManager, - shouldClearCaches, userProvidedInstanceDiscoveryResponse != null ? new UserMetadataProvider(userProvidedInstanceDiscoveryResponse) : null, userProvidedInstanceDiscoveryUri, null, null, null, null) @@ -51,7 +49,6 @@ public InstanceDiscoveryManager( public /* public for test */ InstanceDiscoveryManager( IHttpManager httpManager, - bool shouldClearCaches, IUserMetadataProvider userMetadataProvider = null, Uri userProvidedInstanceDiscoveryUri = null, IKnownMetadataProvider knownMetadataProvider = null, @@ -72,12 +69,9 @@ public InstanceDiscoveryManager( userProvidedInstanceDiscoveryUri); _regionDiscoveryProvider = regionDiscoveryProvider ?? - new RegionAndMtlsDiscoveryProvider(_httpManager, shouldClearCaches); + new RegionAndMtlsDiscoveryProvider(_httpManager); - if (shouldClearCaches) - { - _networkCacheMetadataProvider.Clear(); - } + } public InstanceDiscoveryMetadataEntry GetMetadataEntryAvoidNetwork( diff --git a/src/client/Microsoft.Identity.Client/Instance/Discovery/NetworkCacheMetadataProvider.cs b/src/client/Microsoft.Identity.Client/Instance/Discovery/NetworkCacheMetadataProvider.cs index 73b6ce558c..0f82b13554 100644 --- a/src/client/Microsoft.Identity.Client/Instance/Discovery/NetworkCacheMetadataProvider.cs +++ b/src/client/Microsoft.Identity.Client/Instance/Discovery/NetworkCacheMetadataProvider.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.Collections.Concurrent; using Microsoft.Identity.Client.Core; @@ -25,7 +26,7 @@ public void AddMetadata(string environment, InstanceDiscoveryMetadataEntry entry s_cache.AddOrUpdate(environment, entry, (_, _) => entry); } - public void Clear() + internal static void ResetStaticCacheForTest() { s_cache.Clear(); } diff --git a/src/client/Microsoft.Identity.Client/Instance/Discovery/RegionAndMtlsDiscoveryProvider.cs b/src/client/Microsoft.Identity.Client/Instance/Discovery/RegionAndMtlsDiscoveryProvider.cs index dd3dfb2e66..a2ec8962e3 100644 --- a/src/client/Microsoft.Identity.Client/Instance/Discovery/RegionAndMtlsDiscoveryProvider.cs +++ b/src/client/Microsoft.Identity.Client/Instance/Discovery/RegionAndMtlsDiscoveryProvider.cs @@ -16,9 +16,9 @@ internal class RegionAndMtlsDiscoveryProvider : IRegionDiscoveryProvider public const string PublicEnvForRegional = "login.microsoft.com"; public const string PublicEnvForRegionalMtlsAuth = "mtlsauth.microsoft.com"; - public RegionAndMtlsDiscoveryProvider(IHttpManager httpManager, bool clearCache) + public RegionAndMtlsDiscoveryProvider(IHttpManager httpManager) { - _regionManager = new RegionManager(httpManager, shouldClearStaticCache: clearCache); + _regionManager = new RegionManager(httpManager); } public async Task GetMetadataAsync(Uri authority, RequestContext requestContext) @@ -55,6 +55,7 @@ public async Task GetMetadataAsync(Uri authority string regionalEnv = GetRegionalizedEnvironment(authority, region, requestContext); return CreateEntry(authority.Host, regionalEnv); } + private static InstanceDiscoveryMetadataEntry CreateEntry(string originalEnv, string regionalEnv) { @@ -68,7 +69,6 @@ private static InstanceDiscoveryMetadataEntry CreateEntry(string originalEnv, st private static string GetRegionalizedEnvironment(Uri authority, string region, RequestContext requestContext) { - string host = authority.Host; if (KnownMetadataProvider.IsPublicEnvironment(host)) diff --git a/src/client/Microsoft.Identity.Client/Instance/Region/RegionManager.cs b/src/client/Microsoft.Identity.Client/Instance/Region/RegionManager.cs index c7cf0a0ed3..74c8118863 100644 --- a/src/client/Microsoft.Identity.Client/Instance/Region/RegionManager.cs +++ b/src/client/Microsoft.Identity.Client/Instance/Region/RegionManager.cs @@ -49,20 +49,13 @@ public RegionInfo(string region, RegionAutodetectionSource regionSource, string public RegionManager( IHttpManager httpManager, - int imdsCallTimeout = 2000, - bool shouldClearStaticCache = false) // for test + int imdsCallTimeout = 2000) // for test { _httpManager = httpManager; _imdsCallTimeoutMs = imdsCallTimeout; - - if (shouldClearStaticCache) - { - s_failedAutoDiscovery = false; - s_autoDiscoveredRegion = null; - s_regionDiscoveryDetails = null; - } } + public async Task GetAzureRegionAsync(RequestContext requestContext) { string azureRegionConfig = requestContext.ServiceBundle.Config.AzureRegion; @@ -107,6 +100,13 @@ public async Task GetAzureRegionAsync(RequestContext requestContext) return azureRegionConfig; } + internal static void ResetStaticCacheForTest() + { + s_failedAutoDiscovery = false; + s_autoDiscoveredRegion = null; + s_regionDiscoveryDetails = null; + } + private static bool IsAutoDiscoveryRequested(string azureRegionConfig) { return string.Equals(azureRegionConfig, ConfidentialClientApplication.AttemptRegionDiscovery); diff --git a/src/client/Microsoft.Identity.Client/Internal/Constants.cs b/src/client/Microsoft.Identity.Client/Internal/Constants.cs index a7f2a9719a..4dbb0a7502 100644 --- a/src/client/Microsoft.Identity.Client/Internal/Constants.cs +++ b/src/client/Microsoft.Identity.Client/Internal/Constants.cs @@ -76,5 +76,7 @@ public static string FormatAdfsWebFingerUrl(string host, string resource) { return $"https://{host}/.well-known/webfinger?rel={DefaultRealm}&resource={resource}"; } + + public const int RsaKeySize = 2048; } } diff --git a/src/client/Microsoft.Identity.Client/Internal/RequestContext.cs b/src/client/Microsoft.Identity.Client/Internal/RequestContext.cs index 4db15591e0..2fb786a6a6 100644 --- a/src/client/Microsoft.Identity.Client/Internal/RequestContext.cs +++ b/src/client/Microsoft.Identity.Client/Internal/RequestContext.cs @@ -5,8 +5,10 @@ using System.Collections.Generic; using System.Security.Cryptography.X509Certificates; using System.Threading; +using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal.Logger; +using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.TelemetryCore; using Microsoft.Identity.Client.TelemetryCore.Internal.Events; using Microsoft.Identity.Client.TelemetryCore.TelemetryClient; @@ -29,6 +31,8 @@ internal class RequestContext public X509Certificate2 MtlsCertificate { get; } + internal Func> AttestationTokenProvider { get; set; } + public RequestContext(IServiceBundle serviceBundle, Guid correlationId, X509Certificate2 mtlsCertificate, CancellationToken cancellationToken = default) { ServiceBundle = serviceBundle ?? throw new ArgumentNullException(nameof(serviceBundle)); diff --git a/src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs b/src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs index 67e9590e6a..3db3707a9f 100644 --- a/src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs +++ b/src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System.Collections.Generic; -using System.Net; using System.Threading; using System.Threading.Tasks; using Microsoft.Identity.Client.ApiConfig.Parameters; @@ -18,17 +17,22 @@ namespace Microsoft.Identity.Client.Internal.Requests internal class ManagedIdentityAuthRequest : RequestBase { private readonly AcquireTokenForManagedIdentityParameters _managedIdentityParameters; + private readonly ManagedIdentityClient _managedIdentityClient; private static readonly SemaphoreSlim s_semaphoreSlim = new SemaphoreSlim(1, 1); private readonly ICryptographyManager _cryptoManager; + private readonly IManagedIdentityKeyProvider _managedIdentityKeyProvider; public ManagedIdentityAuthRequest( IServiceBundle serviceBundle, AuthenticationRequestParameters authenticationRequestParameters, - AcquireTokenForManagedIdentityParameters managedIdentityParameters) + AcquireTokenForManagedIdentityParameters managedIdentityParameters, + ManagedIdentityClient managedIdentityClient) : base(serviceBundle, authenticationRequestParameters, managedIdentityParameters) { _managedIdentityParameters = managedIdentityParameters; + _managedIdentityClient = managedIdentityClient; _cryptoManager = serviceBundle.PlatformProxy.CryptographyManager; + _managedIdentityKeyProvider = serviceBundle.PlatformProxy.ManagedIdentityKeyProvider; } protected override async Task ExecuteAsync(CancellationToken cancellationToken) @@ -91,7 +95,7 @@ protected override async Task ExecuteAsync(CancellationTok logger.Info("[ManagedIdentityRequest] Access token retrieved from cache."); try - { + { var proactivelyRefresh = SilentRequestHelper.NeedsRefresh(cachedAccessTokenItem); // If needed, refreshes token in the background @@ -137,7 +141,7 @@ protected override async Task ExecuteAsync(CancellationTok } private async Task GetAccessTokenAsync( - CancellationToken cancellationToken, + CancellationToken cancellationToken, ILoggerAdapter logger) { AuthenticationResult authResult; @@ -157,7 +161,7 @@ private async Task GetAccessTokenAsync( // 1) ForceRefresh is requested // 2) Proactive refresh is in effect // 3) Claims are present (revocation flow) - if (_managedIdentityParameters.ForceRefresh || + if (_managedIdentityParameters.ForceRefresh || AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo == CacheRefreshReason.ProactivelyRefreshed || !string.IsNullOrEmpty(_managedIdentityParameters.Claims)) { @@ -194,12 +198,15 @@ private async Task SendTokenRequestForManagedIdentityAsync await ResolveAuthorityAsync().ConfigureAwait(false); - ManagedIdentityClient managedIdentityClient = - new ManagedIdentityClient(AuthenticationRequestParameters.RequestContext); + _managedIdentityParameters.IsMtlsPopRequested = AuthenticationRequestParameters.IsMtlsPopRequested; + + // Ensure the attestation provider reaches RequestContext for IMDSv2 + AuthenticationRequestParameters.RequestContext.AttestationTokenProvider ??= + _managedIdentityParameters.AttestationTokenProvider; ManagedIdentityResponse managedIdentityResponse = - await managedIdentityClient - .SendTokenRequestForManagedIdentityAsync(_managedIdentityParameters, cancellationToken) + await _managedIdentityClient + .SendTokenRequestForManagedIdentityAsync(AuthenticationRequestParameters.RequestContext, _managedIdentityParameters, cancellationToken) .ConfigureAwait(false); var msalTokenResponse = MsalTokenResponse.CreateFromManagedIdentityResponse(managedIdentityResponse); diff --git a/src/client/Microsoft.Identity.Client/Internal/ServiceBundle.cs b/src/client/Microsoft.Identity.Client/Internal/ServiceBundle.cs index 4078092116..9c5d6dac5d 100644 --- a/src/client/Microsoft.Identity.Client/Internal/ServiceBundle.cs +++ b/src/client/Microsoft.Identity.Client/Internal/ServiceBundle.cs @@ -21,8 +21,7 @@ namespace Microsoft.Identity.Client.Internal internal class ServiceBundle : IServiceBundle { internal ServiceBundle( - ApplicationConfiguration config, - bool shouldClearCaches = false) + ApplicationConfiguration config) { Config = config; @@ -38,20 +37,13 @@ internal ServiceBundle( HttpTelemetryManager = new HttpTelemetryManager(); InstanceDiscoveryManager = new InstanceDiscoveryManager( - HttpManager, - shouldClearCaches, + HttpManager, config.CustomInstanceDiscoveryMetadata, config.CustomInstanceDiscoveryMetadataUri); WsTrustWebRequestManager = new WsTrustWebRequestManager(HttpManager); ThrottlingManager = SingletonThrottlingManager.GetInstance(); - DeviceAuthManager = config.DeviceAuthManagerForTest ?? PlatformProxy.CreateDeviceAuthManager(); - - if (shouldClearCaches) // for test - { - AuthorityManager.ClearValidationCache(); - PoPCryptoProviderFactory.Reset(); - } + DeviceAuthManager = config.DeviceAuthManagerForTest ?? PlatformProxy.CreateDeviceAuthManager(); } /// diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs index ad2b9a0c17..b86d617ac7 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs @@ -31,9 +31,11 @@ internal abstract class AbstractManagedIdentity protected readonly RequestContext _requestContext; + protected bool _isMtlsPopRequested; + internal const string TimeoutError = "[Managed Identity] Authentication unavailable. The request to the managed identity endpoint timed out."; internal readonly ManagedIdentitySource _sourceType; - + protected AbstractManagedIdentity(RequestContext requestContext, ManagedIdentitySource sourceType) { _requestContext = requestContext; @@ -55,7 +57,9 @@ public virtual async Task AuthenticateAsync( // Convert the scopes to a resource string. string resource = parameters.Resource; - ManagedIdentityRequest request = CreateRequest(resource); + _isMtlsPopRequested = parameters.IsMtlsPopRequested; + + ManagedIdentityRequest request = await CreateRequestAsync(resource).ConfigureAwait(false); // Automatically add claims / capabilities if this MI source supports them if (_sourceType.SupportsClaimsAndCapabilities()) @@ -66,6 +70,10 @@ public virtual async Task AuthenticateAsync( _requestContext.Logger); } + request.AddExtraQueryParams( + _requestContext.ServiceBundle.Config.ExtraQueryParameters, + _requestContext.Logger); + _requestContext.Logger.Info("[Managed Identity] Sending request to managed identity endpoints."); IRetryPolicy retryPolicy = _requestContext.ServiceBundle.Config.RetryPolicyFactory.GetRetryPolicy(request.RequestType); @@ -82,8 +90,8 @@ public virtual async Task AuthenticateAsync( method: HttpMethod.Get, logger: _requestContext.Logger, doNotThrow: true, - mtlsCertificate: null, - validateServerCertificate: GetValidationCallback(), + mtlsCertificate: request.MtlsCertificate, + validateServerCertificate: GetValidationCallback(), cancellationToken: cancellationToken, retryPolicy: retryPolicy).ConfigureAwait(false); } @@ -97,8 +105,8 @@ public virtual async Task AuthenticateAsync( method: HttpMethod.Post, logger: _requestContext.Logger, doNotThrow: true, - mtlsCertificate: null, - validateServerCertificate: GetValidationCallback(), + mtlsCertificate: request.MtlsCertificate, + validateServerCertificate: GetValidationCallback(), cancellationToken: cancellationToken, retryPolicy: retryPolicy) .ConfigureAwait(false); @@ -149,7 +157,7 @@ protected virtual Task HandleResponseAsync( throw exception; } - protected abstract ManagedIdentityRequest CreateRequest(string resource); + protected abstract Task CreateRequestAsync(string resource); protected ManagedIdentityResponse GetSuccessfulResponse(HttpResponse response) { @@ -172,8 +180,8 @@ protected ManagedIdentityResponse GetSuccessfulResponse(HttpResponse response) throw exception; } - if (managedIdentityResponse == null || - managedIdentityResponse.AccessToken.IsNullOrEmpty() || + if (managedIdentityResponse == null || + managedIdentityResponse.AccessToken.IsNullOrEmpty() || managedIdentityResponse.ExpiresOn.IsNullOrEmpty()) { _requestContext.Logger.Error("[Managed Identity] Response is either null or insufficient for authentication."); diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs index 10fef4610b..6c8cb95f7f 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs @@ -2,12 +2,10 @@ // Licensed under the MIT License. using System; -using System.Collections.Generic; using System.Globalization; -using Microsoft.Identity.Client.ApiConfig.Parameters; +using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; -using Microsoft.Identity.Client.Utils; namespace Microsoft.Identity.Client.ManagedIdentity { @@ -66,7 +64,7 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger return true; } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override Task CreateRequestAsync(string resource) { ManagedIdentityRequest request = new(System.Net.Http.HttpMethod.Get, _endpoint); @@ -92,7 +90,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) break; } - return request; + return Task.FromResult(request); } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AttestationTokenInput.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AttestationTokenInput.cs new file mode 100644 index 0000000000..a956ae3d4f --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AttestationTokenInput.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + internal sealed class AttestationTokenInput + { + public string ClientId { get; set; } + + public Uri AttestationEndpoint { get; set; } + + /// + /// The key handle of the assymetric algorithm to be attested. Currently, only RSA CNG is supported, + /// available on Windows only, i.e. RSACng.Key.Handle. + /// The handle must remain valid for the duration of the attestation call. + /// + public SafeHandle KeyHandle { get; set; } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AttestationTokenResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AttestationTokenResponse.cs new file mode 100644 index 0000000000..aac307aa8d --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AttestationTokenResponse.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + internal sealed class AttestationTokenResponse + { + public string AttestationToken { get; set; } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs index 8071a13944..ae3048e401 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs @@ -79,7 +79,7 @@ private AzureArcManagedIdentitySource(Uri endpoint, RequestContext requestContex } } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override Task CreateRequestAsync(string resource) { ManagedIdentityRequest request = new ManagedIdentityRequest(System.Net.Http.HttpMethod.Get, _endpoint); @@ -87,7 +87,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) request.QueryParameters["api-version"] = ArcApiVersion; request.QueryParameters["resource"] = resource; - return request; + return Task.FromResult(request); } protected override async Task HandleResponseAsync( @@ -119,7 +119,7 @@ protected override async Task HandleResponseAsync( var authHeaderValue = "Basic " + File.ReadAllText(splitChallenge[1]); - ManagedIdentityRequest request = CreateRequest(parameters.Resource); + ManagedIdentityRequest request = await CreateRequestAsync(parameters.Resource).ConfigureAwait(false); _requestContext.Logger.Verbose(() => "[Managed Identity] Adding authorization header to the request."); request.Headers.Add("Authorization", authHeaderValue); diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs index 63a6eb493c..844458cbce 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs @@ -4,7 +4,7 @@ using System; using System.Globalization; using System.Net.Http; -using Microsoft.Identity.Client.ApiConfig.Parameters; +using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; @@ -74,7 +74,7 @@ private CloudShellManagedIdentitySource(Uri endpoint, RequestContext requestCont } } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override Task CreateRequestAsync(string resource) { ManagedIdentityRequest request = new ManagedIdentityRequest(HttpMethod.Post, _endpoint); @@ -83,7 +83,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) request.BodyParameters.Add("resource", resource); - return request; + return Task.FromResult(request); } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/IManagedIdentityKeyProvider.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/IManagedIdentityKeyProvider.cs new file mode 100644 index 0000000000..ed7f64fdb8 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/IManagedIdentityKeyProvider.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Client.Core; + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + /// + /// Provides managed identity keys for authentication scenarios. + /// Implementations of this interface are responsible for obtaining or creating + /// the best available key type (KeyGuard, Hardware, or InMemory) for managed identity authentication. + /// + internal interface IManagedIdentityKeyProvider + { + /// + /// Gets an existing managed identity key or creates a new one if none exists. + /// The method returns the best available key type based on the provider's capabilities + /// and the current environment. + /// + /// Logger adapter for recording operations and diagnostics. + /// Cancellation token to observe while waiting for the task to complete. + /// + /// A task that represents the asynchronous operation. The task result contains + /// a object with the key, its type, and provider message. + /// + /// + /// Thrown when the operation is canceled via the cancellation token. + /// + Task GetOrCreateKeyAsync(ILoggerAdapter logger, CancellationToken ct); + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs index e4c6384103..b26fef740f 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Collections.Generic; using System.Net; using System.Net.Http; using System.Text; @@ -17,10 +18,11 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ImdsManagedIdentitySource : AbstractManagedIdentity { // IMDS constants. Docs for IMDS are available here https://docs.microsoft.com/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http - private static readonly Uri s_imdsEndpoint = new("http://169.254.169.254/metadata/identity/oauth2/token"); - + // used in unit tests as well + public const string DefaultImdsBaseEndpoint= "http://169.254.169.254"; private const string ImdsTokenPath = "/metadata/identity/oauth2/token"; - private const string ImdsApiVersion = "2018-02-01"; + public const string ImdsApiVersion = "2018-02-01"; + private const string DefaultMessage = "[Managed Identity] Service request failed."; internal const string IdentityUnavailableError = "[Managed Identity] Authentication unavailable. " + @@ -32,30 +34,19 @@ internal class ImdsManagedIdentitySource : AbstractManagedIdentity private readonly Uri _imdsEndpoint; + private static string s_cachedBaseEndpoint = null; + internal ImdsManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.Imds) { requestContext.Logger.Info(() => "[Managed Identity] Defaulting to IMDS endpoint for managed identity."); - if (!string.IsNullOrEmpty(EnvironmentVariables.PodIdentityEndpoint)) - { - requestContext.Logger.Verbose(() => "[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: " + EnvironmentVariables.PodIdentityEndpoint); - var builder = new UriBuilder(EnvironmentVariables.PodIdentityEndpoint) - { - Path = ImdsTokenPath - }; - _imdsEndpoint = builder.Uri; - } - else - { - requestContext.Logger.Verbose(() => "[Managed Identity] Unable to find AZURE_POD_IDENTITY_AUTHORITY_HOST environment variable for IMDS, using the default endpoint."); - _imdsEndpoint = s_imdsEndpoint; - } + _imdsEndpoint = GetValidatedEndpoint(requestContext.Logger, ImdsTokenPath); requestContext.Logger.Verbose(() => "[Managed Identity] Creating IMDS managed identity source. Endpoint URI: " + _imdsEndpoint); } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override Task CreateRequestAsync(string resource) { ManagedIdentityRequest request = new(HttpMethod.Get, _imdsEndpoint); @@ -81,9 +72,42 @@ protected override ManagedIdentityRequest CreateRequest(string resource) break; } + var userAssignedIdQueryParam = GetUserAssignedIdQueryParam( + _requestContext.ServiceBundle.Config.ManagedIdentityId.IdType, + _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId, + _requestContext.Logger); + if (userAssignedIdQueryParam != null) + { + request.QueryParameters[userAssignedIdQueryParam.Value.Key] = userAssignedIdQueryParam.Value.Value; + } + request.RequestType = RequestType.Imds; - return request; + return Task.FromResult(request); + } + + public static KeyValuePair? GetUserAssignedIdQueryParam( + AppConfig.ManagedIdentityIdType idType, + string userAssignedId, + ILoggerAdapter logger) + { + switch (idType) + { + case AppConfig.ManagedIdentityIdType.ClientId: + logger?.Info("[Managed Identity] Adding user assigned client id to the request."); + return new KeyValuePair(Constants.ManagedIdentityClientId, userAssignedId); + + case AppConfig.ManagedIdentityIdType.ResourceId: + logger?.Info("[Managed Identity] Adding user assigned resource id to the request."); + return new KeyValuePair(Constants.ManagedIdentityResourceIdImds, userAssignedId); + + case AppConfig.ManagedIdentityIdType.ObjectId: + logger?.Info("[Managed Identity] Adding user assigned object id to the request."); + return new KeyValuePair(Constants.ManagedIdentityObjectId, userAssignedId); + + default: + return null; + } } protected override async Task HandleResponseAsync( @@ -152,5 +176,38 @@ internal static string CreateRequestFailedMessage(HttpResponse response, string return messageBuilder.ToString(); } + + public static Uri GetValidatedEndpoint( + ILoggerAdapter logger, + string subPath, + string queryParams = null + ) + { + if (s_cachedBaseEndpoint == null) + { + if (!string.IsNullOrEmpty(EnvironmentVariables.PodIdentityEndpoint)) + { + logger.Verbose(() => "[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: " + EnvironmentVariables.PodIdentityEndpoint); + s_cachedBaseEndpoint = EnvironmentVariables.PodIdentityEndpoint; + } + else + { + logger.Verbose(() => "[Managed Identity] Unable to find AZURE_POD_IDENTITY_AUTHORITY_HOST environment variable for IMDS, using the default endpoint."); + s_cachedBaseEndpoint = DefaultImdsBaseEndpoint; + } + } + + UriBuilder builder = new UriBuilder(s_cachedBaseEndpoint) + { + Path = subPath + }; + + if (!string.IsNullOrEmpty(queryParams)) + { + builder.Query = queryParams; + } + + return builder.Uri; + } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/KeyProviders/InMemoryManagedIdentityKeyProvider.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/KeyProviders/InMemoryManagedIdentityKeyProvider.cs new file mode 100644 index 0000000000..dcc47a7518 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/KeyProviders/InMemoryManagedIdentityKeyProvider.cs @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Security.Cryptography; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.Internal; +using System.Runtime.InteropServices; + +namespace Microsoft.Identity.Client.ManagedIdentity.KeyProviders +{ + /// + /// In-memory RSA key provider for managed identity authentication. + /// + internal sealed class InMemoryManagedIdentityKeyProvider : IManagedIdentityKeyProvider + { + private static readonly SemaphoreSlim s_once = new (1, 1); + private volatile ManagedIdentityKeyInfo _cachedKey; + + /// + /// Asynchronously retrieves or creates an RSA key pair for managed identity authentication. + /// Uses thread-safe caching to ensure only one key is created per provider instance. + /// + /// Logger adapter for recording key creation operations and diagnostics. + /// Cancellation token to support cooperative cancellation of the key creation process. + /// + /// A task that represents the asynchronous operation. The task result contains a + /// with the RSA key, key type, and provider message. + /// + public async Task GetOrCreateKeyAsync( + ILoggerAdapter logger, + CancellationToken ct) + { + // Return cached if available + if (_cachedKey is not null) + { + logger?.Info("[MI][InMemoryKeyProvider] Returning cached key."); + return _cachedKey; + } + + // Ensure only one creation at a time + logger?.Info(() => "[MI][InMemoryKeyProvider] Waiting on creation semaphore."); + await s_once.WaitAsync(ct).ConfigureAwait(false); + + try + { + if (_cachedKey is not null) + { + logger?.Info(() => "[MI][InMemoryKeyProvider] Cached key created while waiting; returning it."); + return _cachedKey; + } + + if (ct.IsCancellationRequested) + { + logger?.Info(() => "[MI][InMemoryKeyProvider] Cancellation requested after entering critical section."); + ct.ThrowIfCancellationRequested(); + } + + logger?.Info(() => "[MI][InMemoryKeyProvider] Starting RSA key creation."); + RSA rsa = null; + string message; + + try + { + rsa = CreateRsaKeyPair(); + message = "In-memory RSA key created for Managed Identity authentication."; + logger?.Info("[MI][InMemoryKeyProvider] RSA key created (2048)."); + } + catch (Exception ex) + { + message = $"Failed to create in-memory RSA key: {ex.GetType().Name} - {ex.Message}"; + logger?.WarningPii( + $"[MI][InMemoryKeyProvider] Exception during RSA creation: {ex}", + $"[MI][InMemoryKeyProvider] Exception during RSA creation: {ex.GetType().Name}"); + } + + _cachedKey = new ManagedIdentityKeyInfo(rsa, ManagedIdentityKeyType.InMemory, message); + + logger?.Info(() => + $"[MI][InMemoryKeyProvider] Caching key. Success={(rsa != null)}. HasMessage={!string.IsNullOrEmpty(message)}."); + + return _cachedKey; + } + finally + { + s_once.Release(); + } + } + + /// + /// Creates a new RSA key pair with 2048-bit key size for cryptographic operations. + /// Uses platform-specific RSA implementations: RSACng on .NET Framework and RSA.Create() on other platforms. + /// + /// + /// An instance configured with a 2048-bit key size. + /// On .NET Framework, returns ; on other platforms, returns the default RSA implementation. + /// + /// + /// This method is public instead of private because it is used in unit tests + /// + public static RSA CreateRsaKeyPair() + { +#if NET462 || NET472 || NET8_0 + // Windows-only TFMs (Framework or -windows TFMs): compile CNG path + return CreateWindowsPersistedRsa(); + +#else + // netstandard2.0 can run anywhere; pick at runtime + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return CreateWindowsPersistedRsa(); // requires CNG package in csproj + } + return CreatePortableRsa(); + +#endif + } + + private static RSA CreatePortableRsa() + { + var rsa = RSA.Create(); + if (rsa.KeySize < Constants.RsaKeySize) + rsa.KeySize = Constants.RsaKeySize; + return rsa; + } + + private static RSA CreateWindowsPersistedRsa() + { + // Persisted CNG key (non-ephemeral) so Schannel can use it for TLS client auth + var creation = new CngKeyCreationParameters + { + ExportPolicy = CngExportPolicies.AllowExport, + KeyCreationOptions = CngKeyCreationOptions.MachineKey, // try machine store first + Provider = CngProvider.MicrosoftSoftwareKeyStorageProvider + }; + + // Persist key length with the key + creation.Parameters.Add( + new CngProperty("Length", BitConverter.GetBytes(Constants.RsaKeySize), CngPropertyOptions.Persist)); + + // Non-null name => persisted; null would be ephemeral (bad for Schannel) + string keyName = "MSAL-MTLS-" + Guid.NewGuid().ToString("N"); + + try + { + var key = CngKey.Create(CngAlgorithm.Rsa, keyName, creation); + return new RSACng(key); + } + catch (CryptographicException) + { + // Some environments disallow MachineKey. Fall back to user profile. + creation.KeyCreationOptions = CngKeyCreationOptions.None; + var key = CngKey.Create(CngAlgorithm.Rsa, keyName, creation); + return new RSACng(key); + } + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/KeyProviders/WindowsCngKeyOperations.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/KeyProviders/WindowsCngKeyOperations.cs new file mode 100644 index 0000000000..3d2af6c0f8 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/KeyProviders/WindowsCngKeyOperations.cs @@ -0,0 +1,294 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Security.Cryptography; +using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.Internal; + +namespace Microsoft.Identity.Client.ManagedIdentity.KeyProviders +{ + /// + /// Provides CNG-backed cryptographic key operations for Windows platforms, supporting both + /// KeyGuard-protected keys (with VBS/TPM integration) and hardware-backed TPM/KSP keys + /// for managed identity authentication scenarios. + /// + /// + /// This class handles two primary key protection mechanisms: + /// + /// KeyGuard: Requires Virtualization Based Security (VBS) and provides enhanced key protection + /// Hardware TPM/KSP: Uses Platform Crypto Provider (PCP) for TPM-backed keys + /// + /// All operations are performed in user scope with silent key access patterns. + /// + internal static class WindowsCngKeyOperations + { + private const string SoftwareKspName = "Microsoft Software Key Storage Provider"; + private const string KeyGuardKeyName = "KeyGuardRSAKey"; + private const string HardwareKeyName = "HardwareRSAKey"; + private const string KeyGuardVirtualIsoProperty = "Virtual Iso"; + private const string VbsNotAvailable = "VBS key isolation is not available"; + + // KeyGuard + per-boot flags + private const CngKeyCreationOptions NCryptUseVirtualIsolationFlag = (CngKeyCreationOptions)0x00020000; + private const CngKeyCreationOptions NCryptUsePerBootKeyFlag = (CngKeyCreationOptions)0x00040000; + + /// + /// Attempts to get or create a KeyGuard-protected RSA key for managed identity operations. + /// This method first tries to open an existing key, and if not found, creates a fresh KeyGuard-protected key. + /// KeyGuard requires VBS (Virtualization Based Security) to be enabled and supported. + /// + /// Logger adapter for diagnostic messages and error reporting + /// When this method returns , contains the RSA instance with the KeyGuard-protected key; + /// when this method returns , this parameter is set to + /// if a KeyGuard-protected RSA key was successfully obtained or created; + /// if KeyGuard is unavailable, VBS is not supported, or the operation failed + /// + /// This method performs the following operations in sequence: + /// + /// Attempts to open an existing KeyGuard key using the software KSP in user scope + /// If the key doesn't exist, creates a new KeyGuard-protected key + /// Validates that the key is actually KeyGuard-protected + /// If validation fails, recreates the key and re-validates + /// Ensures the RSA key size is at least 2048 bits when possible + /// + /// The method gracefully handles scenarios where VBS is disabled or not supported by returning . + /// + /// Thrown when VBS/Core Isolation is not available on the platform + /// Thrown when cryptographic operations fail during key creation or access + public static bool TryGetOrCreateKeyGuard(ILoggerAdapter logger, out RSA rsa) + { + rsa = default(RSA); + + try + { + // Try open by the known name first (Software KSP, user scope, silent) + CngKey key; + try + { + key = CngKey.Open( + KeyGuardKeyName, + new CngProvider(SoftwareKspName), + CngKeyOpenOptions.UserKey | CngKeyOpenOptions.Silent); + } + catch (CryptographicException) + { + // Not found -> create fresh (helper may return null if VBS unavailable) + logger?.Info(() => "[MI][WinKeyProvider] KeyGuard key not found; creating fresh."); + key = CreateFresh(logger); + } + + // If VBS is unavailable, CreateFresh() returns null. Bail out cleanly. + if (key == null) + { + logger?.Info(() => "[MI][WinKeyProvider] KeyGuard unavailable (VBS off or not supported)."); + return false; + } + + // Ensure actually KeyGuard-protected; recreate if not + if (!IsKeyGuardProtected(key)) + { + logger?.Info(() => "[MI][WinKeyProvider] KeyGuard key found but not protected; recreating."); + key.Dispose(); + key = CreateFresh(logger); + + // Check again after recreate; still null or not protected -> give up KeyGuard path + if (key == null || !IsKeyGuardProtected(key)) + { + key?.Dispose(); + logger?.Info(() => "[MI][WinKeyProvider] Unable to obtain a KeyGuard-protected key."); + return false; + } + } + + rsa = new RSACng(key); + if (rsa.KeySize < Constants.RsaKeySize) + { + try + { rsa.KeySize = Constants.RsaKeySize; } + catch { logger?.Info(() => $"[MI][WinKeyProvider] Unable to extend the size of the KeyGuard key to {Constants.RsaKeySize} bits."); } + } + return true; + } + catch (PlatformNotSupportedException) + { + // VBS/Core Isolation not available => KeyGuard unavailable + logger?.Info(() => "[MI][WinKeyProvider] Exception creating KeyGuard key."); + return false; + } + catch (CryptographicException ex) + { + logger?.Info(() => $"[MI][WinKeyProvider] KeyGuard creation failed due to platform limitation. {ex.GetType().Name}: {ex.Message}"); + return false; + } + } + + /// + /// Attempts to get or create a hardware-backed RSA key using the Platform Crypto Provider (PCP) + /// for TPM-based key storage and operations. + /// + /// Logger adapter for diagnostic messages and error reporting + /// When this method returns , contains the RSA instance backed by hardware (TPM); + /// when this method returns , this parameter is set to + /// if a hardware-backed RSA key was successfully obtained or created; + /// if hardware key operations are not available or the operation failed + /// + /// This method performs the following operations: + /// + /// Checks if a hardware key with the predefined name already exists in user scope + /// Opens the existing key if found, or creates a new hardware-backed key if not found + /// Configures the key with non-exportable policy (standard for TPM keys) + /// Ensures the RSA key size is at least 2048 bits when supported by the provider + /// + /// The created keys are stored in user scope and are non-exportable for security reasons. + /// TPM providers typically ignore post-creation key size changes. + /// + /// Thrown when hardware key creation, opening, or configuration fails. + /// The exception's HResult property provides additional diagnostic information + public static bool TryGetOrCreateHardwareRsa(ILoggerAdapter logger, out RSA rsa) + { + rsa = default(RSA); + + try + { + // PCP (TPM) in USER scope + CngProvider provider = new CngProvider(SoftwareKspName); + CngKeyOpenOptions openOpts = CngKeyOpenOptions.UserKey | CngKeyOpenOptions.Silent; + + CngKey key = CngKey.Exists(HardwareKeyName, provider, openOpts) + ? CngKey.Open(HardwareKeyName, provider, openOpts) + : CreateUserPcpRsa(provider, HardwareKeyName); + + rsa = new RSACng(key); + + if (rsa.KeySize < Constants.RsaKeySize) + { + try + { rsa.KeySize = Constants.RsaKeySize; } + catch { logger?.Info(() => $"[MI][WinKeyProvider] Unable to extend the size of the Hardware key to {Constants.RsaKeySize} bits."); } + } + + logger?.Info("[MI][WinKeyProvider] Using Hardware key (RSA, PCP user)."); + return true; + } + catch (CryptographicException e) + { + // Add HResult to make CI diagnostics actionable + logger?.Info(() => "[MI][WinKeyProvider] Hardware key creation/open failed. " + + $"HR=0x{e.HResult:X8}. {e.GetType().Name}: {e.Message}"); + return false; + } + } + + /// + /// Creates a new RSA key using the Platform Crypto Provider (PCP) in user scope + /// with non-exportable policy suitable for TPM-backed operations. + /// + /// The CNG provider to use for key creation (typically PCP for TPM) + /// The name to assign to the created key for future reference + /// A new instance configured for signing operations with 2048-bit key size + /// + /// The created key has the following characteristics: + /// + /// Algorithm: RSA + /// Key size: 2048 bits + /// Usage: Signing operations + /// Export policy: None (non-exportable) + /// Scope: User scope + /// + /// + private static CngKey CreateUserPcpRsa(CngProvider provider, string name) + { + var ckcParams = new CngKeyCreationParameters + { + Provider = provider, + KeyUsage = CngKeyUsages.Signing, + ExportPolicy = CngExportPolicies.None, // non-exportable (expected for TPM) + KeyCreationOptions = CngKeyCreationOptions.None // USER scope + }; + + ckcParams.Parameters.Add(new CngProperty("Length", BitConverter.GetBytes(Constants.RsaKeySize), CngPropertyOptions.None)); + + return CngKey.Create(CngAlgorithm.Rsa, name, ckcParams); + } + + /// + /// Creates a new RSA-2048 Key Guard key. + /// + /// Logger adapter for recording diagnostic information and warnings. + /// + /// A instance protected by Key Guard if VBS is available; + /// otherwise, null if VBS is not supported on the system. + /// + /// + /// This method attempts to create a cryptographic key with hardware-backed security using + /// Virtualization Based Security (VBS). If VBS is not available, the method logs a warning + /// and returns null, allowing the caller to fall back to software-based key storage. + /// + private static CngKey CreateFresh(ILoggerAdapter logger) + { + var ckcParams = new CngKeyCreationParameters + { + Provider = new CngProvider(SoftwareKspName), + KeyUsage = CngKeyUsages.AllUsages, + ExportPolicy = CngExportPolicies.None, + KeyCreationOptions = + CngKeyCreationOptions.OverwriteExistingKey + | NCryptUseVirtualIsolationFlag + | NCryptUsePerBootKeyFlag + }; + + ckcParams.Parameters.Add(new CngProperty("Length", + BitConverter.GetBytes(Constants.RsaKeySize), + CngPropertyOptions.None)); + + try + { + return CngKey.Create(CngAlgorithm.Rsa, KeyGuardKeyName, ckcParams); + } + catch (CryptographicException ex) + when (IsVbsUnavailable(ex)) + { + logger?.Warning( + $"[MI][KeyGuardHelper] {VbsNotAvailable}; falling back to software keys. " + + "Ensure that Virtualization Based Security (VBS) is enabled on this machine " + + "(e.g. Credential Guard, Hyper-V, or Windows Defender Application Guard). " + + "Inner exception: " + ex.Message); + + return null; + } + } + + /// + /// Determines whether the specified CNG key is protected by Key Guard. + /// + /// The CNG key to check for Key Guard protection. + /// true if the key has the Key Guard flag; otherwise, false. + /// + /// This method checks for the presence of the Virtual Iso property on the key, + /// which indicates that the key is protected by hardware-backed security features. + /// + public static bool IsKeyGuardProtected(CngKey key) + { + if (!key.HasProperty(KeyGuardVirtualIsoProperty, CngPropertyOptions.None)) + return false; + + byte[] val = key.GetProperty(KeyGuardVirtualIsoProperty, CngPropertyOptions.None).GetValue(); + return val?.Length > 0 && val[0] != 0; + } + + /// + /// Determines whether a cryptographic exception indicates that VBS is unavailable. + /// + /// The cryptographic exception to examine. + /// true if the exception indicates VBS is not supported; otherwise, false. + private static bool IsVbsUnavailable(CryptographicException ex) + { + // HResult for “NTE_NOT_SUPPORTED” = 0x80890014 + const int NTE_NOT_SUPPORTED = unchecked((int)0x80890014); + + return ex.HResult == NTE_NOT_SUPPORTED || + ex.Message.Contains(VbsNotAvailable); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/KeyProviders/WindowsManagedIdentityKeyProvider.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/KeyProviders/WindowsManagedIdentityKeyProvider.cs new file mode 100644 index 0000000000..878787a8d5 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/KeyProviders/WindowsManagedIdentityKeyProvider.cs @@ -0,0 +1,166 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Security.Cryptography; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Client.Core; + +namespace Microsoft.Identity.Client.ManagedIdentity.KeyProviders +{ + /// + /// Windows-specific managed identity key provider that implements a hierarchical key selection strategy. + /// Attempts to use the most secure key source available in the following priority order: + /// 1. KeyGuard (CVM/TVM) if available - provides VBS (Virtualization-based Security) isolation + /// 2. Hardware (TPM/KSP via Microsoft Platform Crypto Provider) - hardware-backed keys + /// 3. In-memory fallback - software-based keys stored in memory + /// + /// + /// This provider ensures that only one key creation operation occurs at a time using a semaphore, + /// and caches the created key for subsequent requests to improve performance. + /// + internal sealed class WindowsManagedIdentityKeyProvider : IManagedIdentityKeyProvider + { + private static readonly SemaphoreSlim s_once = new (1, 1); + private volatile ManagedIdentityKeyInfo _cachedKey; + + /// + /// Gets or creates a managed identity key using the best available security mechanism. + /// + /// Logger adapter for recording key creation attempts and results. + /// Cancellation token to cancel the operation if needed. + /// + /// A task that represents the asynchronous key creation operation. + /// The task result contains with the created key and its type. + /// + /// + /// Thrown when the operation is cancelled via the parameter. + /// + /// + /// + /// This method implements a thread-safe, single-creation pattern using a semaphore. + /// If a key has already been created and cached, it returns immediately. + /// + /// + /// The key creation follows this priority order: + /// + /// KeyGuard: Uses VBS isolation for maximum security (RSA-2048) + /// Hardware: Uses TPM or hardware security module (RSA-2048, non-exportable) + /// In-memory: Software fallback when hardware options are unavailable + /// + /// + /// + /// Exceptions during key creation are logged but do not prevent fallback to the next option. + /// Only the final in-memory fallback can throw exceptions that terminate the operation. + /// + /// + public async Task GetOrCreateKeyAsync( + ILoggerAdapter logger, + CancellationToken ct) + { + // Return cached if available + if (_cachedKey != null) + { + logger?.Info("[MI][WinKeyProvider] Returning cached key."); + return _cachedKey; + } + + // Ensure only one creation at a time + logger?.Info(() => "[MI][WinKeyProvider] Waiting on creation semaphore."); + await s_once.WaitAsync(ct).ConfigureAwait(false); + + try + { + if (_cachedKey != null) + { + logger?.Info(() => "[MI][WinKeyProvider] Cached key created while waiting; returning it."); + return _cachedKey; + } + + if (ct.IsCancellationRequested) + { + logger?.Info(() => "[MI][WinKeyProvider] Cancellation requested after entering critical section."); + ct.ThrowIfCancellationRequested(); + } + + var messageBuilder = new StringBuilder(); + + // 1) KeyGuard (RSA-2048 under VBS isolation) + try + { + logger.Info("[MI][WinKeyProvider] Trying KeyGuard key."); + if (WindowsCngKeyOperations.TryGetOrCreateKeyGuard(logger, out RSA kgRsa)) + { + messageBuilder.AppendLine("KeyGuard RSA key created successfully."); + _cachedKey = new ManagedIdentityKeyInfo(kgRsa, ManagedIdentityKeyType.KeyGuard, messageBuilder.ToString()); + logger?.Info("[MI][WinKeyProvider] Using KeyGuard key (RSA)."); + return _cachedKey; + } + else + { + messageBuilder.AppendLine("KeyGuard RSA key creation not available or failed."); + logger?.Info(() => "[MI][WinKeyProvider] KeyGuard key not available."); + } + } + catch (Exception ex) + { + messageBuilder.AppendLine($"KeyGuard RSA key creation threw exception: {ex.GetType().Name}: {ex.Message}"); + logger?.WarningPii( + $"[MI][WinKeyProvider] Exception creating KeyGuard key: {ex}", + $"[MI][WinKeyProvider] Exception creating KeyGuard key: {ex.GetType().Name}"); + } + + // 2) Hardware TPM/KSP (RSA-2048, non-exportable) + try + { + logger?.Info(() => "[MI][WinKeyProvider] Trying Hardware (TPM/KSP) key."); + if (WindowsCngKeyOperations.TryGetOrCreateHardwareRsa(logger, out RSA hwRsa)) + { + messageBuilder.AppendLine("Hardware RSA key created successfully."); + _cachedKey = new ManagedIdentityKeyInfo(hwRsa, ManagedIdentityKeyType.Hardware, messageBuilder.ToString()); + logger?.Info("[MI][WinKeyProvider] Using Hardware key (RSA)."); + return _cachedKey; + } + else + { + messageBuilder.AppendLine("Hardware RSA key creation not available or failed."); + logger?.Info(() => "[MI][WinKeyProvider] Hardware key not available."); + } + } + catch (Exception ex) + { + messageBuilder.AppendLine($"Hardware RSA key creation threw exception: {ex.GetType().Name}: {ex.Message}"); + logger?.WarningPii( + $"[MI][WinKeyProvider] Exception creating Hardware key: {ex}", + $"[MI][WinKeyProvider] Exception creating Hardware key: {ex.GetType().Name}"); + } + + // 3) In-memory fallback (software RSA) + logger?.Info("[MI][WinKeyProvider] Falling back to in-memory RSA key (software)."); + if (ct.IsCancellationRequested) + { + logger?.Info(() => "[MI][WinKeyProvider] Cancellation requested before in-memory fallback."); + ct.ThrowIfCancellationRequested(); + } + + var fallbackIMMIKP = new InMemoryManagedIdentityKeyProvider(); + _cachedKey = await fallbackIMMIKP.GetOrCreateKeyAsync(logger, ct).ConfigureAwait(false); + + if (messageBuilder.Length > 0) + { + logger?.Info(() => "[MI][WinKeyProvider] Fallback reasons:\n" + messageBuilder.ToString().Trim()); + } + + return _cachedKey; + + } + finally + { + s_once.Release(); + } + } + } +} + diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs index e6916fe919..9d2af3cadc 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs @@ -3,7 +3,7 @@ using System; using System.Globalization; -using Microsoft.Identity.Client.ApiConfig.Parameters; +using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; @@ -64,7 +64,7 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger return true; } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override Task CreateRequestAsync(string resource) { ManagedIdentityRequest request = new(System.Net.Http.HttpMethod.Get, _endpoint); @@ -108,7 +108,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) null); // statusCode is null in this case } - return request; + return Task.FromResult(request); } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs index 80a45bb0da..4e8e7fd8ce 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs @@ -9,6 +9,7 @@ using Microsoft.Identity.Client.PlatformsCommon.Shared; using System.IO; using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.ManagedIdentity.V2; namespace Microsoft.Identity.Client.ManagedIdentity { @@ -19,39 +20,72 @@ internal class ManagedIdentityClient { private const string WindowsHimdsFilePath = "%Programfiles%\\AzureConnectedMachineAgent\\himds.exe"; private const string LinuxHimdsFilePath = "/opt/azcmagent/bin/himds"; - private readonly AbstractManagedIdentity _identitySource; + internal static ManagedIdentitySource s_sourceName = ManagedIdentitySource.None; - public ManagedIdentityClient(RequestContext requestContext) + internal static void ResetSourceForTest() { - using (requestContext.Logger.LogMethodDuration()) - { - _identitySource = SelectManagedIdentitySource(requestContext); - } + s_sourceName = ManagedIdentitySource.None; } - internal Task SendTokenRequestForManagedIdentityAsync(AcquireTokenForManagedIdentityParameters parameters, CancellationToken cancellationToken) + internal async Task SendTokenRequestForManagedIdentityAsync( + RequestContext requestContext, + AcquireTokenForManagedIdentityParameters parameters, + CancellationToken cancellationToken) { - return _identitySource.AuthenticateAsync(parameters, cancellationToken); + AbstractManagedIdentity msi = await GetOrSelectManagedIdentitySourceAsync(requestContext).ConfigureAwait(false); + return await msi.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false); } // This method tries to create managed identity source for different sources, if none is created then defaults to IMDS. - private static AbstractManagedIdentity SelectManagedIdentitySource(RequestContext requestContext) + private async Task GetOrSelectManagedIdentitySourceAsync(RequestContext requestContext) + { + using (requestContext.Logger.LogMethodDuration()) + { + requestContext.Logger.Info($"[Managed Identity] Selecting managed identity source if not cached. Cached value is {s_sourceName} "); + + var source = (s_sourceName != ManagedIdentitySource.None) ? s_sourceName : await GetManagedIdentitySourceAsync(requestContext).ConfigureAwait(false); + return source switch + { + ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(requestContext), + ManagedIdentitySource.AppService => AppServiceManagedIdentitySource.Create(requestContext), + ManagedIdentitySource.MachineLearning => MachineLearningManagedIdentitySource.Create(requestContext), + ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(requestContext), + ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext), + ManagedIdentitySource.ImdsV2 => ImdsV2ManagedIdentitySource.Create(requestContext), + _ => new ImdsManagedIdentitySource(requestContext) + }; + } + } + + // Detect managed identity source based on the availability of environment variables and csr metadata probe request. + // This method is perf sensitive any changes should be benchmarked. + internal async Task GetManagedIdentitySourceAsync(RequestContext requestContext) { - return GetManagedIdentitySource(requestContext.Logger) switch + ManagedIdentitySource source = GetManagedIdentitySourceNoImdsV2(requestContext.Logger); + + if (source != ManagedIdentitySource.DefaultToImds) { - ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(requestContext), - ManagedIdentitySource.AppService => AppServiceManagedIdentitySource.Create(requestContext), - ManagedIdentitySource.MachineLearning => MachineLearningManagedIdentitySource.Create(requestContext), - ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(requestContext), - ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext), - _ => new ImdsManagedIdentitySource(requestContext) - }; + return source; + } + + // probe IMDSv2 + var response = await ImdsV2ManagedIdentitySource.GetCsrMetadataAsync(requestContext, probeMode: true).ConfigureAwait(false); + if (response != null) + { + requestContext.Logger.Info("[Managed Identity] ImdsV2 detected."); + s_sourceName = ManagedIdentitySource.ImdsV2; + return s_sourceName; + } + + requestContext.Logger.Info("[Managed Identity] IMDSv2 probe failed. Defaulting to IMDSv1."); + s_sourceName = ManagedIdentitySource.DefaultToImds; + return s_sourceName; } // Detect managed identity source based on the availability of environment variables. // The result of this method is not cached because reading environment variables is cheap. // This method is perf sensitive any changes should be benchmarked. - internal static ManagedIdentitySource GetManagedIdentitySource(ILoggerAdapter logger = null) + internal static ManagedIdentitySource GetManagedIdentitySourceNoImdsV2(ILoggerAdapter logger = null) { string identityEndpoint = EnvironmentVariables.IdentityEndpoint; string identityHeader = EnvironmentVariables.IdentityHeader; diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityKeyInfo.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityKeyInfo.cs new file mode 100644 index 0000000000..fd2684e881 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityKeyInfo.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Security.Cryptography; + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + /// + /// Encapsulates information about a Managed Identity key used for authentication. + /// Provides the best available key and its type for Managed Identity scenarios. + /// The caller does not need to know how the key is sourced. + /// + /// Key types: + /// - : Key sourced from KeyGuard provider. + /// - : Key stored in hardware (e.g., TPM). + /// - : Key stored in memory only. + /// + internal sealed class ManagedIdentityKeyInfo + { + public RSA Key { get; } + public ManagedIdentityKeyType Type { get; } + public string ProviderMessage { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The RSA key instance to be used for cryptographic operations. + /// The type of the Managed Identity key indicating its storage method. + /// A message from the key provider with additional information. + public ManagedIdentityKeyInfo(RSA keyInfo, ManagedIdentityKeyType type, string providerMessage) + { + Key = keyInfo; + Type = type; + ProviderMessage = providerMessage; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityKeyProviderFactory.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityKeyProviderFactory.cs new file mode 100644 index 0000000000..b5c757f6e1 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityKeyProviderFactory.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Threading; +using Microsoft.Identity.Client.ManagedIdentity.KeyProviders; +using Microsoft.Identity.Client.PlatformsCommon.Shared; +using Microsoft.Identity.Client.Core; + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + /// + /// Creates (once) and caches the most suitable Managed Identity key provider for the current platform. + /// Thread-safe, lock-free (uses CompareExchange). + /// + /// + /// This factory class uses a singleton pattern with lazy initialization to ensure only one + /// key provider instance is created per application domain. The implementation is thread-safe + /// using to avoid locking overhead. + /// + /// The factory automatically selects the most appropriate key provider based on the current + /// platform capabilities: + /// + /// Windows: Uses WindowsManagedIdentityKeyProvider with CNG support + /// Non-Windows: Falls back to InMemoryManagedIdentityKeyProvider + /// + /// + internal static class ManagedIdentityKeyProviderFactory + { + // Cached singleton instance of the chosen key provider. + private static IManagedIdentityKeyProvider s_provider; + + /// + /// Returns the cached provider if available; otherwise creates it in a thread-safe manner. + /// + /// + /// Logger adapter for recording operations and diagnostics. Can be null. + /// + /// + /// The singleton instance appropriate for the current platform. + /// + /// + /// This method implements the double-checked locking pattern using atomic operations + /// to ensure thread safety without the overhead of explicit locks. If multiple threads + /// call this method concurrently before initialization, only one provider instance + /// will be created and cached. + /// + internal static IManagedIdentityKeyProvider GetOrCreateProvider(ILoggerAdapter logger) + { + // Fast path: read the field once (Volatile ensures latest published value). + IManagedIdentityKeyProvider existing = Volatile.Read(ref s_provider); + + if (existing != null) + { + logger?.Verbose(() => "[MI][KeyProviderFactory] Returning cached key provider instance."); + return existing; + } + + logger?.Verbose(() => "[MI][KeyProviderFactory] Creating key provider instance (first use)."); + IManagedIdentityKeyProvider created = CreateProviderCore(logger); + + // Publish the created instance only if another thread has not already published one. + // If another thread won the race, discard our newly created instance and use theirs. + IManagedIdentityKeyProvider prior = Interlocked.CompareExchange(ref s_provider, created, null); + + if (prior == null) + { + logger?.Info($"[MI][KeyProviderFactory] Key provider created: {created.GetType().Name}."); + return created; + } + + logger?.Verbose(() => "[MI][KeyProviderFactory] Another thread already created the provider; using existing instance."); + return prior; + } + + /// + /// Chooses an implementation based on compile-time and runtime platform capabilities. + /// + /// + /// Logger adapter for recording platform detection and provider selection. Can be null. + /// + /// + /// A new instance suitable for the detected platform. + /// + /// + /// This method performs platform detection and selects the most appropriate key provider: + /// + /// Windows Platform: + /// + /// Detected using + /// Returns WindowsManagedIdentityKeyProvider with CNG support + /// Provides hardware-backed key storage when available + /// + /// + /// Non-Windows Platforms: + /// + /// Includes Linux, macOS, and other Unix-like systems + /// Returns InMemoryManagedIdentityKeyProvider as fallback + /// Keys are stored in memory for the application lifetime + /// + /// + private static IManagedIdentityKeyProvider CreateProviderCore(ILoggerAdapter logger) + { + if (DesktopOsHelper.IsWindows()) + { + logger?.Info("[MI][KeyProviderFactory] Windows detected with CNG support - using Windows managed identity key provider."); + return new WindowsManagedIdentityKeyProvider(); + } + + // Non-Windows OS - we will fall back to in-memory implementation. + logger?.Info("[MI][KeyProviderFactory] Non-Windows platform (with CNG) - using InMemory provider."); + return new InMemoryManagedIdentityKeyProvider(); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityKeyType.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityKeyType.cs new file mode 100644 index 0000000000..2ab2047eae --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityKeyType.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + /// + /// Specifies the type of key storage mechanism used for managed identity authentication. + /// + internal enum ManagedIdentityKeyType + { + // Represents a key stored using a secure key guard mechanism that provides hardware-level protection. + KeyGuard, + + // Represents a key stored directly in hardware security modules or trusted platform modules. + Hardware, + + // Represents a key stored in memory with software-based protection mechanisms. + InMemory + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs index c5b9af2b73..75c6cf4031 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using System.Net.Http; +using System.Security.Cryptography.X509Certificates; using Microsoft.Identity.Client.ApiConfig.Parameters; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.OAuth2; @@ -26,7 +27,13 @@ internal class ManagedIdentityRequest public RequestType RequestType { get; set; } - public ManagedIdentityRequest(HttpMethod method, Uri endpoint, RequestType requestType = RequestType.ManagedIdentityDefault) + public X509Certificate2 MtlsCertificate { get; set; } + + public ManagedIdentityRequest( + HttpMethod method, + Uri endpoint, + RequestType requestType = RequestType.ManagedIdentityDefault, + X509Certificate2 mtlsCertificate = null) { Method = method; _baseEndpoint = endpoint; @@ -34,6 +41,7 @@ public ManagedIdentityRequest(HttpMethod method, Uri endpoint, RequestType reque BodyParameters = new Dictionary(); QueryParameters = new Dictionary(); RequestType = requestType; + MtlsCertificate = mtlsCertificate; } public Uri ComputeUri() @@ -64,5 +72,24 @@ internal void AddClaimsAndCapabilities( logger.Info("[Managed Identity] Passing SHA-256 of the 'revoked' token to Managed Identity endpoint."); } } + + /// + /// Adds extra query parameters to the Managed Identity request. + /// + /// Dictionary containing additional query parameters to append to the request. + /// The parameter can be null. + /// Logger instance for recording the operation. + internal void AddExtraQueryParams(IDictionary extraQueryParameters, ILoggerAdapter logger) + { + if (extraQueryParameters != null) + { + foreach (var kvp in extraQueryParameters) + { + QueryParameters[kvp.Key] = kvp.Value; + } + + logger.Info($"[Managed Identity] Adding {extraQueryParameters.Count} extra query parameters to Managed Identity request."); + } + } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityResponse.cs index 4eddced093..d1fccfaba9 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityResponse.cs @@ -5,6 +5,7 @@ #if SUPPORTS_SYSTEM_TEXT_JSON using Microsoft.Identity.Client.Platforms.net; using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute; +using JsonIgnore = System.Text.Json.Serialization.JsonIgnoreAttribute; #else using Microsoft.Identity.Json; #endif @@ -29,8 +30,22 @@ internal class ManagedIdentityResponse /// /// The date is represented as the number of seconds from "1970-01-01T0:0:0Z UTC" /// (corresponds to the token's exp claim). + [JsonIgnore] + public string ExpiresOn { get; set; } // The actual property consumers use + [JsonProperty("expires_on")] - public string ExpiresOn { get; set; } + public string ExpiresOnRaw // Proxy for "expires_on" JSON field + { + get => ExpiresOn; // When serializing, return ExpiresOn value + set => ExpiresOn = value; // When deserializing, store in ExpiresOn + } + + [JsonProperty("expires_in")] + public string ExpiresInRaw // Proxy for "expires_in" JSON field + { + get => null; // Never serialize this (return null) + set => ExpiresOn = value; // When deserializing, store in ExpiresOn + } /// /// The resource the access token was requested for. diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs index 69e3471bdf..0b687fe7bb 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs @@ -53,6 +53,11 @@ public enum ManagedIdentitySource /// /// The source to acquire token for managed identity is Machine Learning Service. /// - MachineLearning + MachineLearning, + + /// + /// The source to acquire token for managed identity is IMDSV2. + /// + ImdsV2, } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs index 3224a8f3fe..55b6b28690 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs @@ -6,7 +6,7 @@ using System.Net.Http; using System.Net.Security; using System.Security.Cryptography.X509Certificates; -using Microsoft.Identity.Client.ApiConfig.Parameters; +using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; @@ -75,7 +75,7 @@ private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri en } } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override Task CreateRequestAsync(string resource) { ManagedIdentityRequest request = new ManagedIdentityRequest(HttpMethod.Get, _endpoint); @@ -102,7 +102,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) break; } - return request; + return Task.FromResult(request); } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/AttestationTokenMemoryCache.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/AttestationTokenMemoryCache.cs new file mode 100644 index 0000000000..14928c5904 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/AttestationTokenMemoryCache.cs @@ -0,0 +1,270 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Concurrent; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + /// + /// Phase 1: process-local in-memory cache for attestation tokens. + /// - Key: KeyHandle pointer value + /// - TTL (Time to live): 8 hours (until provider exposes an explicit expiry) + /// - Background refresh: kicks off at half-time (4h) without blocking callers + /// - Thread-safe across callers; no cross-process guarantees (by design for Phase 1) + /// + /// Phase 2 (hand-off notes for persistent cache): + /// - Add an IAttestationTokenCache interface to the provider input + /// - Add a persistent cache implementation + /// - Use a named OS mutex + /// - Persist using the same key (KeyHandle pointer value) for simplicity + /// - needs logging + /// - details around background refresh and process exit needs some discussion + /// + internal static class AttestationTokenMemoryCache + { + // Today MAA does not give expiry info; assume 8h TTL for now. + // We have manually validated this with MAA tokens. + private static readonly TimeSpan s_defaultTtl = TimeSpan.FromHours(8); // provider has no expiry yet + private static readonly TimeSpan s_halfTime = TimeSpan.FromHours(4); // background refresh point + private static readonly TimeSpan s_expirySkew = TimeSpan.FromMinutes(2); + private static readonly TimeSpan s_bgRetryBackoff = TimeSpan.FromMinutes(15); + + // One Entry per key handle value + private static readonly ConcurrentDictionary s_entries = + new ConcurrentDictionary(); + + /// + /// Returns a valid token. If missing/expired, mints via and caches it. + /// If past half-time, returns the current token and schedules a background refresh. + /// + internal static async Task GetOrCreateAsync( + AttestationTokenInput input, + Func> provider, + CancellationToken ct) + { + if (input == null) + throw new ArgumentNullException(nameof(input)); + if (provider == null) + throw new ArgumentNullException(nameof(provider)); + + long key = GetHandleValue(input); + var entry = s_entries.GetOrAdd(key, k => new Entry(k)); + + // Gate all mutations per key + await entry.Gate.WaitAsync(ct).ConfigureAwait(false); + try + { + var now = DateTimeOffset.UtcNow; + + // Happy path: valid token in memory + if (!string.IsNullOrEmpty(entry.Token) && now + s_expirySkew < entry.ExpiresOnUtc) + { + // Past refresh time? Kick a non-blocking background refresh. + if (now >= entry.RefreshOnUtc) + { + KickBackgroundRefresh(entry, input, provider); + } + + return new AttestationTokenResponse { AttestationToken = entry.Token }; + } + + // Miss / expired -> mint synchronously and update cache + var minted = await provider(input, ct).ConfigureAwait(false); + if (minted == null || string.IsNullOrEmpty(minted.AttestationToken)) + { + throw new MsalClientException("attestation_failed", "Attestation provider returned no token."); + } + + var now2 = DateTimeOffset.UtcNow; + entry.Token = minted.AttestationToken; + entry.ExpiresOnUtc = now2 + s_defaultTtl; + entry.RefreshOnUtc = now2 + s_halfTime; + + // Store the refresh factory so background timer can re-mint without caller context. + entry.Mint = ctk => provider(input, ctk); + + // (Re)schedule the per-key timer to fire at RefreshOnUtc + ScheduleTimer(entry); + + return minted; + } + finally + { + entry.Gate.Release(); + } + } + + // ---------------- internals ---------------- + + private static long GetHandleValue(AttestationTokenInput input) + { + try + { + if (input.KeyHandle != null && !input.KeyHandle.IsInvalid) + { + return input.KeyHandle.DangerousGetHandle().ToInt64(); + } + } + catch { /* ignore */ } + return 0L; + } + + private static void KickBackgroundRefresh( + Entry entry, + AttestationTokenInput lastInput, + Func> provider) + { + // Background: do not block the caller thread; dedupe via Gate.TryEnter + Task.Run(async () => + { + if (!entry.Gate.Wait(0)) + return; // another refresh in progress + try + { + // Freshen only if still past refresh (re-check) + var now = DateTimeOffset.UtcNow; + if (string.IsNullOrEmpty(entry.Token) || now < entry.RefreshOnUtc) + { + return; + } + + // Prefer stored Mint; if null (first call), mint with the last input/provider + var mint = entry.Mint ?? (ct => provider(lastInput, ct)); + + var minted = await mint(CancellationToken.None).ConfigureAwait(false); + if (minted != null && !string.IsNullOrEmpty(minted.AttestationToken)) + { + var now2 = DateTimeOffset.UtcNow; + entry.Token = minted.AttestationToken; + entry.ExpiresOnUtc = now2 + s_defaultTtl; + entry.RefreshOnUtc = now2 + s_halfTime; + ScheduleTimer(entry); // push next half-time + } + else + { + // Best-effort retry before expiry + ScheduleRetry(entry, s_bgRetryBackoff); + } + } + catch + { + // Swallow background errors; keep current token; try again later + ScheduleRetry(entry, s_bgRetryBackoff); + } + finally + { + entry.Gate.Release(); + } + }); + } + + private static void ScheduleTimer(Entry entry) + { + var due = entry.RefreshOnUtc - DateTimeOffset.UtcNow; + if (due < TimeSpan.Zero) + due = TimeSpan.Zero; + + int dueMs = SafeMs(due); + if (entry.RefreshTimer == null) + { + entry.RefreshTimer = new Timer(TimerCallback, entry, dueMs, Timeout.Infinite); + } + else + { + entry.RefreshTimer.Change(dueMs, Timeout.Infinite); + } + } + + private static void ScheduleRetry(Entry entry, TimeSpan delay) + { + int dueMs = SafeMs(delay); + if (entry.RefreshTimer == null) + { + entry.RefreshTimer = new Timer(TimerCallback, entry, dueMs, Timeout.Infinite); + } + else + { + entry.RefreshTimer.Change(dueMs, Timeout.Infinite); + } + } + + private static int SafeMs(TimeSpan ts) + { + if (ts <= TimeSpan.Zero) + return 0; + double ms = ts.TotalMilliseconds; + if (ms > int.MaxValue) + return int.MaxValue; + return (int)ms; + } + + private static void TimerCallback(object state) + { + var entry = (Entry)state; + // We only schedule; actual minting happens in KickBackgroundRefresh semantics: + // Acquire lock, check refresh condition again, then mint. + // Using stored Mint delegate to avoid needing caller context. + if (entry.Mint == null) + return; // no way to mint yet + Task.Run(async () => + { + if (!entry.Gate.Wait(0)) + return; + try + { + var now = DateTimeOffset.UtcNow; + if (now < entry.RefreshOnUtc) + return; // not due anymore (rescheduled) + var minted = await entry.Mint(CancellationToken.None).ConfigureAwait(false); + if (minted != null && !string.IsNullOrEmpty(minted.AttestationToken)) + { + var now2 = DateTimeOffset.UtcNow; + entry.Token = minted.AttestationToken; + entry.ExpiresOnUtc = now2 + s_defaultTtl; + entry.RefreshOnUtc = now2 + s_halfTime; + ScheduleTimer(entry); + } + else + { + ScheduleRetry(entry, s_bgRetryBackoff); + } + } + catch + { + ScheduleRetry(entry, s_bgRetryBackoff); + } + finally + { + entry.Gate.Release(); + } + }); + } + + // Per-key state + private sealed class Entry : IDisposable + { + internal Entry(long key) { Key = key; Gate = new SemaphoreSlim(1, 1); } + internal long Key; + internal string Token; // opaque JWT (never parsed) + internal DateTimeOffset ExpiresOnUtc; + internal DateTimeOffset RefreshOnUtc; + internal SemaphoreSlim Gate; + internal Timer RefreshTimer; + internal Func> Mint; // stored mint delegate + + public void Dispose() + { + try + { RefreshTimer?.Dispose(); } + catch { } + try + { Gate?.Dispose(); } + catch { } + } + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequest.cs new file mode 100644 index 0000000000..7a1de63551 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequest.cs @@ -0,0 +1,243 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !NET7_0_OR_GREATER +using System; +using System.Collections.ObjectModel; +using System.Formats.Asn1; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Text; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + /// + /// Downlevel polyfill for System.Security.Cryptography.X509Certificates.CertificateRequest + /// that provides OtherRequestAttributes support for frameworks prior to .NET 7.0. + /// This file is conditionally included only for net462, net472, and netstandard2.0. + /// For .NET 8.0+, the built-in CertificateRequest class is used instead. + /// + internal class CertificateRequest + { + private X500DistinguishedName _subjectName; + private RSA _rsa; + private HashAlgorithmName _hashAlgorithmName; + private RSASignaturePadding _rsaPadding; + + internal CertificateRequest( + X500DistinguishedName subjectName, + RSA key, + HashAlgorithmName hashAlgorithm, + RSASignaturePadding padding) + { + _subjectName = subjectName; + _rsa = key; + _hashAlgorithmName = hashAlgorithm; + _rsaPadding = padding; + } + + internal Collection OtherRequestAttributes { get; } = new Collection(); + + private static string MakePem(byte[] ber, string header) + { + const int LineLength = 64; + + string base64 = Convert.ToBase64String(ber); + int offset = 0; + + StringBuilder builder = new StringBuilder("-----BEGIN "); + builder.Append(header); + builder.AppendLine("-----"); + + while (offset < base64.Length) + { + int lineEnd = Math.Min(offset + LineLength, base64.Length); + builder.AppendLine(base64.Substring(offset, lineEnd - offset)); + offset = lineEnd; + } + + builder.Append("-----END "); + builder.Append(header); + builder.AppendLine("-----"); + + return builder.ToString(); + } + + internal string CreateSigningRequestPem() + { + byte[] csr = CreateSigningRequest(); + return MakePem(csr, "CERTIFICATE REQUEST"); + } + + internal byte[] CreateSigningRequest() + { + if (_hashAlgorithmName != HashAlgorithmName.SHA256) + { + throw new NotSupportedException("Signature Processing has only been written for SHA256"); + } + + AsnWriter writer = new AsnWriter(AsnEncodingRules.DER); + + // RSAPublicKey ::= SEQUENCE { + // modulus INTEGER, -- n + // publicExponent INTEGER -- e + // } + + using (writer.PushSequence()) + { + RSAParameters rsaParameters = _rsa.ExportParameters(false); + writer.WriteIntegerUnsigned(rsaParameters.Modulus); + writer.WriteIntegerUnsigned(rsaParameters.Exponent); + } + + byte[] publicKey = writer.Encode(); + writer.Reset(); + + // CertificationRequestInfo ::= SEQUENCE { + // version INTEGER { v1(0) } (v1,...), + // subject Name, + // subjectPKInfo SubjectPublicKeyInfo{{ PKInfoAlgorithms }}, + // attributes [0] Attributes{{ CRIAttributes }} + // } + // + // SubjectPublicKeyInfo { ALGORITHM: IOSet} ::= SEQUENCE { + // algorithm AlgorithmIdentifier { { IOSet} }, + // subjectPublicKey BIT STRING + // } + // + // Attributes { ATTRIBUTE:IOSet } ::= SET OF Attribute{{ IOSet }} + // + // Attribute { ATTRIBUTE:IOSet } ::= SEQUENCE { + // type ATTRIBUTE.&id({IOSet}), + // values SET SIZE(1..MAX) OF ATTRIBUTE.&Type({IOSet}{@type}) + // } + + using (writer.PushSequence()) + { + writer.WriteInteger(0); + writer.WriteEncodedValue(_subjectName.RawData); + + // subjectPKInfo + using (writer.PushSequence()) + { + // algorithm + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier("1.2.840.113549.1.1.1"); + // RSA uses an explicit NULL value for parameters + writer.WriteNull(); + } + + writer.WriteBitString(publicKey); + } + + if (OtherRequestAttributes.Count > 0) + { + // attributes + using (writer.PushSetOf(new Asn1Tag(TagClass.ContextSpecific, 0))) + { + foreach (AsnEncodedData attribute in OtherRequestAttributes) + { + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier(attribute.Oid.Value); + + using (writer.PushSetOf()) + { + writer.WriteEncodedValue(attribute.RawData); + } + } + } + } + } + } + + byte[] certReqInfo = writer.Encode(); + writer.Reset(); + + // CertificationRequest ::= SEQUENCE { + // certificationRequestInfo CertificationRequestInfo, + // signatureAlgorithm AlgorithmIdentifier{{ SignatureAlgorithms }}, + // signature BIT STRING + // } + + using (writer.PushSequence()) + { + writer.WriteEncodedValue(certReqInfo); + + // signatureAlgorithm + using (writer.PushSequence()) + { + if (_rsaPadding == RSASignaturePadding.Pss) + { + if (_hashAlgorithmName != HashAlgorithmName.SHA256) + { + throw new NotSupportedException("Only SHA256 is supported with PSS padding."); + } + + writer.WriteObjectIdentifier("1.2.840.113549.1.1.10"); + + // RSASSA-PSS-params ::= SEQUENCE { + // hashAlgorithm [0] HashAlgorithm DEFAULT sha1, + // maskGenAlgorithm [1] MaskGenAlgorithm DEFAULT mgf1SHA1, + // saltLength [2] INTEGER DEFAULT 20, + // trailerField [3] TrailerField DEFAULT trailerFieldBC + // } + + using (writer.PushSequence()) + { + string digestOid = "2.16.840.1.101.3.4.2.1"; + + // hashAlgorithm + using (writer.PushSequence(new Asn1Tag(TagClass.ContextSpecific, 0))) + { + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier(digestOid); + } + } + + using (writer.PushSequence(new Asn1Tag(TagClass.ContextSpecific, 1))) + { + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier("1.2.840.113549.1.1.8"); + + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier(digestOid); + } + } + } + + // saltLength (SHA256.Length, 32 bytes) + using (writer.PushSequence(new Asn1Tag(TagClass.ContextSpecific, 2))) + { + writer.WriteInteger(32); + } + + // trailerField 1, which is trailerFieldBC, which is the DEFAULT, + // so don't write it down. + } + } + else if (_rsaPadding == RSASignaturePadding.Pkcs1) + { + writer.WriteObjectIdentifier("1.2.840.113549.1.1.11"); + // RSA PKCS1 uses an explicit NULL value for parameters + writer.WriteNull(); + } + else + { + throw new NotSupportedException("Unsupported RSA padding."); + } + } + + byte[] signature = _rsa.SignData(certReqInfo, _hashAlgorithmName, _rsaPadding); + writer.WriteBitString(signature); + } + + return writer.Encode(); + } + } +} +#endif diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestBody.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestBody.cs new file mode 100644 index 0000000000..64b27ccc45 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestBody.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if SUPPORTS_SYSTEM_TEXT_JSON + using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute; +#else + using Microsoft.Identity.Json; +#endif + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + internal class CertificateRequestBody + { + [JsonProperty("csr")] + public string Csr { get; set; } + + [JsonProperty("attestation_token")] + public string AttestationToken { get; set; } + + public static bool IsNullOrEmpty(CertificateRequestBody certificateRequestBody) + { + return certificateRequestBody == null || + (string.IsNullOrEmpty(certificateRequestBody.Csr) && string.IsNullOrEmpty(certificateRequestBody.AttestationToken)); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs new file mode 100644 index 0000000000..5e84000054 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Buffers.Text; +using System.Net; +#if SUPPORTS_SYSTEM_TEXT_JSON + using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute; +#else +using Microsoft.Identity.Json; +#endif + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + /// + /// Represents the response for a Managed Identity CSR request. + /// + internal class CertificateRequestResponse + { + [JsonProperty("client_id")] + public string ClientId { get; set; } // client_id of the Managed Identity  + + [JsonProperty("tenant_id")] + public string TenantId { get; set; } // AAD Tenant of the Managed Identity  + + [JsonProperty("certificate")] + public string Certificate { get; set; } // Base64 encoded X509certificate + + [JsonProperty("identity_type")] + public string IdentityType { get; set; } // SAMI or UAMI + + [JsonProperty("mtls_authentication_endpoint")] + public string MtlsAuthenticationEndpoint { get; set; } // Regional STS mTLS endpoint + + public CertificateRequestResponse() { } + + public static void Validate(CertificateRequestResponse certificateRequestResponse) + { + if (string.IsNullOrEmpty(certificateRequestResponse.ClientId) || + string.IsNullOrEmpty(certificateRequestResponse.TenantId) || + string.IsNullOrEmpty(certificateRequestResponse.Certificate) || + string.IsNullOrEmpty(certificateRequestResponse.IdentityType) || + string.IsNullOrEmpty(certificateRequestResponse.MtlsAuthenticationEndpoint)) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed because the certificate request response is malformed. Status code: 200", + null, + ManagedIdentitySource.ImdsV2, + (int)HttpStatusCode.OK); + } + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs new file mode 100644 index 0000000000..52cd10354e --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Formats.Asn1; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Client.Utils; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + internal class Csr + { + internal static (string csrPem, RSA privateKey) Generate(RSA rsa, string clientId, string tenantId, CuidInfo cuid) + { + // Use custom polyfill for downlevel frameworks (net462, net472, netstandard2.0) + // See CertificateRequest.cs + var req = new CertificateRequest( + new X500DistinguishedName($"CN={clientId}, DC={tenantId}"), + rsa, + HashAlgorithmName.SHA256, + RSASignaturePadding.Pss); + + AsnWriter writer = new AsnWriter(AsnEncodingRules.DER); + writer.WriteCharacterString(UniversalTagNumber.UTF8String, JsonHelper.SerializeToJson(cuid)); + + req.OtherRequestAttributes.Add( + new AsnEncodedData( + "1.3.6.1.4.1.311.90.2.10", + writer.Encode())); + + string pemCsr = req.CreateSigningRequestPem(); + + // Remove PEM headers and format as single line + string rawCsr = pemCsr + .Replace("-----BEGIN CERTIFICATE REQUEST-----", "") + .Replace("-----END CERTIFICATE REQUEST-----", "") + .Replace("\r", "") + .Replace("\n", "") + .Trim(); + + return (rawCsr, rsa); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs new file mode 100644 index 0000000000..94be9c72cc --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if SUPPORTS_SYSTEM_TEXT_JSON + using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute; +#else + using Microsoft.Identity.Json; +#endif + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + /// + /// Represents VM unique Ids for CSR metadata. + /// + internal class CuidInfo + { + [JsonProperty("vmId")] + public string VmId { get; set; } + + [JsonProperty("vmssId")] + public string VmssId { get; set; } + + public static bool IsNullOrEmpty(CuidInfo cuidInfo) + { + return cuidInfo == null || + (string.IsNullOrEmpty(cuidInfo.VmId) && string.IsNullOrEmpty(cuidInfo.VmssId)); + } + } + + /// + /// Represents metadata required for Certificate Signing Request (CSR) operations. + /// + internal class CsrMetadata + { + /// + /// VM unique Id + /// + [JsonProperty("cuId")] + public CuidInfo CuId { get; set; } + + /// + /// client_id of the Managed Identity + /// + [JsonProperty("clientId")] + public string ClientId { get; set; } + + /// + /// AAD Tenant of the Managed Identity + /// + [JsonProperty("tenantId")] + public string TenantId { get; set; } + + /// + /// MAA Regional / Custom Endpoint for attestation purposes. + /// + [JsonProperty("attestationEndpoint")] + public string AttestationEndpoint { get; set; } + + // Parameterless constructor for deserialization + public CsrMetadata() { } + + /// + /// Validates a JSON decoded CsrMetadata instance. + /// + /// The CsrMetadata object. + /// false if any field is null or empty + public static bool ValidateCsrMetadata(CsrMetadata csrMetadata) + { + if (csrMetadata == null || + CuidInfo.IsNullOrEmpty(csrMetadata.CuId) || + string.IsNullOrEmpty(csrMetadata.ClientId) || + string.IsNullOrEmpty(csrMetadata.TenantId) || + string.IsNullOrEmpty(csrMetadata.AttestationEndpoint)) + { + return false; + } + + return true; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/DefaultCsrFactory.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/DefaultCsrFactory.cs new file mode 100644 index 0000000000..cdc2b9e526 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/DefaultCsrFactory.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Security.Cryptography; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + internal class DefaultCsrFactory : ICsrFactory + { + public (string csrPem, RSA privateKey) Generate(RSA rsa, string clientId, string tenantId, CuidInfo cuid) + { + return Csr.Generate(rsa, clientId, tenantId, cuid); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ICsrFactory.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ICsrFactory.cs new file mode 100644 index 0000000000..69f71f8079 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ICsrFactory.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Security.Cryptography; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + internal interface ICsrFactory + { + (string csrPem, RSA privateKey) Generate(RSA rsa, string clientId, string tenantId, CuidInfo cuid); + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs new file mode 100644 index 0000000000..6c3449d0b1 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -0,0 +1,455 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.Http; +using Microsoft.Identity.Client.Http.Retry; +using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.OAuth2; +using Microsoft.Identity.Client.OAuth2.Throttling; +using Microsoft.Identity.Client.PlatformsCommon.Shared; +using Microsoft.Identity.Client.Utils; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity + { + // used in unit tests + public const string ImdsV2ApiVersion = "2.0"; + public const string CsrMetadataPath = "/metadata/identity/getplatformmetadata"; + public const string CertificateRequestPath = "/metadata/identity/issuecredential"; + public const string AcquireEntraTokenPath = "/oauth2/v2.0/token"; + + public static async Task GetCsrMetadataAsync( + RequestContext requestContext, + bool probeMode) + { +#if NET462 + requestContext.Logger.Info("[Managed Identity] IMDSv2 flow is not supported on .NET Framework 4.6.2. Cryptographic operations required for managed identity authentication are unavailable on this platform. Skipping IMDSv2 probe."); + return await Task.FromResult(null).ConfigureAwait(false); +#else + var queryParams = ImdsV2QueryParamsHelper(requestContext); + + var headers = new Dictionary + { + { "Metadata", "true" }, + { OAuth2Header.XMsCorrelationId, requestContext.CorrelationId.ToString() } + }; + + IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.CsrMetadataProbe); + + HttpResponse response = null; + + try + { + response = await requestContext.ServiceBundle.HttpManager.SendRequestAsync( + ImdsManagedIdentitySource.GetValidatedEndpoint(requestContext.Logger, CsrMetadataPath, queryParams), + headers, + body: null, + method: HttpMethod.Get, + logger: requestContext.Logger, + doNotThrow: false, + mtlsCertificate: null, + validateServerCertificate: null, + cancellationToken: requestContext.UserCancellationToken, + retryPolicy: retryPolicy) + .ConfigureAwait(false); + } + catch (Exception ex) + { + if (probeMode) + { + requestContext.Logger.Info($"[Managed Identity] IMDSv2 CSR endpoint failure. Exception occurred while sending request to CSR metadata endpoint: {ex}"); + return null; + } + else + { + ThrowProbeFailedException( + "ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed.", + ex); + } + } + + if (response.StatusCode != HttpStatusCode.OK) + { + if (probeMode) + { + requestContext.Logger.Info(() => $"[Managed Identity] IMDSv2 managed identity is not available. Status code: {response.StatusCode}, Body: {response.Body}"); + return null; + } + else + { + ThrowProbeFailedException( + $"ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed due to HTTP error. Status code: {response.StatusCode} Body: {response.Body}", + null, + (int)response.StatusCode); + } + } + + if (!ValidateCsrMetadataResponse(response, requestContext.Logger, probeMode)) + { + return null; + } + + return TryCreateCsrMetadata(response, requestContext.Logger, probeMode); +#endif + } + + private static void ThrowProbeFailedException( + String errorMessage, + Exception ex = null, + int? statusCode = null) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] {errorMessage}", + ex, + ManagedIdentitySource.ImdsV2, + statusCode); + } + + private static bool ValidateCsrMetadataResponse( + HttpResponse response, + ILoggerAdapter logger, + bool probeMode) + { + string serverHeader = response.HeadersAsDictionary + .FirstOrDefault((kvp) => { + return string.Equals(kvp.Key, "server", StringComparison.OrdinalIgnoreCase); + }).Value; + + if (serverHeader == null) + { + if (probeMode) + { + logger.Info(() => $"[Managed Identity] IMDSv2 managed identity is not available. 'server' header is missing from the CSR metadata response. Body: {response.Body}"); + return false; + } + else + { + ThrowProbeFailedException( + $"ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because response doesn't have server header. Status code: {response.StatusCode} Body: {response.Body}", + null, + (int)response.StatusCode); + } + } + + if (!serverHeader.Contains("IMDS", StringComparison.OrdinalIgnoreCase)) + { + if (probeMode) + { + logger.Info(() => $"[Managed Identity] IMDSv2 managed identity is not available. The 'server' header format is invalid. Extracted server header: {serverHeader}"); + return false; + } + else + { + ThrowProbeFailedException( + $"ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because the 'server' header format is invalid. Extracted server header: {serverHeader}. Status code: {response.StatusCode} Body: {response.Body}", + null, + (int)response.StatusCode); + } + } + + return true; + } + + private static CsrMetadata TryCreateCsrMetadata( + HttpResponse response, + ILoggerAdapter logger, + bool probeMode) + { + CsrMetadata csrMetadata = JsonHelper.DeserializeFromJson(response.Body); + if (!CsrMetadata.ValidateCsrMetadata(csrMetadata)) + { + ThrowProbeFailedException( + $"ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because the CsrMetadata response is invalid. Status code: {response.StatusCode} Body: {response.Body}", + null, + (int)response.StatusCode); + } + + logger.Info(() => "[Managed Identity] IMDSv2 managed identity is available."); + return csrMetadata; + } + + public static AbstractManagedIdentity Create(RequestContext requestContext) + { + return new ImdsV2ManagedIdentitySource(requestContext); + } + + internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : + base(requestContext, ManagedIdentitySource.ImdsV2) + { } + + private async Task ExecuteCertificateRequestAsync( + string clientId, + string attestationEndpoint, + string csr, + ManagedIdentityKeyInfo managedIdentityKeyInfo) + { + var queryParams = ImdsV2QueryParamsHelper(_requestContext); + + // TODO: add bypass_cache query param in case of token revocation. Boolean: true/false + + var headers = new Dictionary + { + { "Metadata", "true" }, + { OAuth2Header.XMsCorrelationId, _requestContext.CorrelationId.ToString() } + }; + + if (_isMtlsPopRequested && managedIdentityKeyInfo.Type != ManagedIdentityKeyType.KeyGuard) + { + throw new MsalClientException( + "mtls_pop_requires_keyguard", + "[ImdsV2] mTLS Proof-of-Possession requires a KeyGuard-backed key. Enable KeyGuard or use a KeyGuard-supported environment."); + } + + // TODO: : Normalize and validate attestation endpoint Code needs to be removed + // once IMDS team start returning full URI + Uri normalizedEndpoint = NormalizeAttestationEndpoint(attestationEndpoint, _requestContext.Logger); + + // Ask helper for JWT only for KeyGuard keys + string attestationJwt = string.Empty; + if (managedIdentityKeyInfo.Type == ManagedIdentityKeyType.KeyGuard) + { + attestationJwt = await GetAttestationJwtAsync( + clientId, + normalizedEndpoint, + managedIdentityKeyInfo, + _requestContext.UserCancellationToken).ConfigureAwait(false); + } + + var certificateRequestBody = new CertificateRequestBody() + { + Csr = csr, + AttestationToken = attestationJwt + }; + + string body = JsonHelper.SerializeToJson(certificateRequestBody); + + IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.Imds); + + HttpResponse response = null; + + try + { + response = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync( + ImdsManagedIdentitySource.GetValidatedEndpoint(_requestContext.Logger, CertificateRequestPath, queryParams), + headers, + body: new StringContent(body, System.Text.Encoding.UTF8, "application/json"), + method: HttpMethod.Post, + logger: _requestContext.Logger, + doNotThrow: false, + mtlsCertificate: null, + validateServerCertificate: null, + cancellationToken: _requestContext.UserCancellationToken, + retryPolicy: retryPolicy) + .ConfigureAwait(false); + } + catch (Exception ex) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed.", + ex, + ManagedIdentitySource.ImdsV2, + (int)response.StatusCode); + } + + if (response.StatusCode != HttpStatusCode.OK) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed due to HTTP error. Status code: {response.StatusCode} Body: {response.Body}", + null, + ManagedIdentitySource.ImdsV2, + (int)response.StatusCode); + } + + var certificateRequestResponse = JsonHelper.DeserializeFromJson(response.Body); + CertificateRequestResponse.Validate(certificateRequestResponse); + + return certificateRequestResponse; + } + + protected override async Task CreateRequestAsync(string resource) + { + var csrMetadata = await GetCsrMetadataAsync(_requestContext, false).ConfigureAwait(false); + + IManagedIdentityKeyProvider keyProvider = _requestContext.ServiceBundle.PlatformProxy.ManagedIdentityKeyProvider; + + ManagedIdentityKeyInfo keyInfo = await keyProvider + .GetOrCreateKeyAsync( + _requestContext.Logger, + _requestContext.UserCancellationToken) + .ConfigureAwait(false); + + var (csr, privateKey) = _requestContext.ServiceBundle.Config.CsrFactory.Generate(keyInfo.Key, csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); + + var certificateRequestResponse = await ExecuteCertificateRequestAsync( + csrMetadata.ClientId, + csrMetadata.AttestationEndpoint, + csr, + keyInfo).ConfigureAwait(false); + + // transform certificateRequestResponse.Certificate to x509 with private key + var mtlsCertificate = CommonCryptographyManager.AttachPrivateKeyToCert( + certificateRequestResponse.Certificate, + privateKey); + + ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{certificateRequestResponse.MtlsAuthenticationEndpoint}/{certificateRequestResponse.TenantId}{AcquireEntraTokenPath}")); + + var idParams = MsalIdHelper.GetMsalIdParameters(_requestContext.Logger); + foreach (var idParam in idParams) + { + request.Headers[idParam.Key] = idParam.Value; + } + request.Headers.Add(OAuth2Header.XMsCorrelationId, _requestContext.CorrelationId.ToString()); + request.Headers.Add(ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue); + request.Headers.Add(OAuth2Header.RequestCorrelationIdInResponse, "true"); + + var tokenType = _isMtlsPopRequested ? "mtls_pop" : "bearer"; + + request.BodyParameters.Add("client_id", certificateRequestResponse.ClientId); + request.BodyParameters.Add("grant_type", OAuth2GrantType.ClientCredentials); + request.BodyParameters.Add("scope", resource.TrimEnd('/') + "/.default"); + request.BodyParameters.Add("token_type", tokenType); + + request.RequestType = RequestType.STS; + + request.MtlsCertificate = mtlsCertificate; + + return request; + } + + private static string ImdsV2QueryParamsHelper(RequestContext requestContext) + { + var queryParams = $"cred-api-version={ImdsV2ApiVersion}"; + + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam( + requestContext.ServiceBundle.Config.ManagedIdentityId.IdType, + requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId, + requestContext.Logger); + + if (userAssignedIdQueryParam != null) + { + queryParams += $"&{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}"; + } + + return queryParams; + } + + /// + /// Obtains an attestation JWT for the KeyGuard/CSR payload using the configured + /// attestation provider and normalized endpoint. + /// + /// Client ID to be sent to the attestation provider. + /// The attestation endpoint. + /// The key information. + /// Cancellation token. + /// JWT string suitable for the IMDSv2 attested POP flow. + /// Wraps client/network failures. + + /// + /// Obtains an attestation JWT for the KeyGuard/CSR payload using the configured + /// attestation provider and normalized endpoint. Now uses AttestationTokenMemoryCache. + /// + private async Task GetAttestationJwtAsync( + string clientId, + Uri attestationEndpoint, + ManagedIdentityKeyInfo keyInfo, + CancellationToken cancellationToken) + { + var provider = _requestContext.AttestationTokenProvider; + + if (keyInfo.Type == ManagedIdentityKeyType.KeyGuard && + keyInfo.Key is not System.Security.Cryptography.RSACng rsaCng) + { + throw new MsalClientException( + "keyguard_requires_cng", + "[ImdsV2] KeyGuard attestation currently supports only RSA CNG keys on Windows."); + } + + var input = new AttestationTokenInput + { + ClientId = clientId, + AttestationEndpoint = attestationEndpoint, + KeyHandle = (keyInfo.Key as System.Security.Cryptography.RSACng)?.Key.Handle + }; + + // Use in-memory cache (phase 1). Caches per key handle (or 0 if unavailable). + var cached = await AttestationTokenMemoryCache + .GetOrCreateAsync(input, provider, cancellationToken) + .ConfigureAwait(false); + + if (cached == null || string.IsNullOrWhiteSpace(cached.AttestationToken)) + { + throw new MsalClientException( + "attestation_failed", + "[ImdsV2] Attestation provider failed to return an attestation token."); + } + + return cached.AttestationToken; + } + + //To-do : Remove this method once IMDS team start returning full URI + /// + /// Temporarily normalize attestation endpoint values to a full https:// URI. + /// IMDS team will eventually return a full URI. + /// + /// + /// + /// + private static Uri NormalizeAttestationEndpoint(string rawEndpoint, ILoggerAdapter logger) + { + if (string.IsNullOrWhiteSpace(rawEndpoint)) + { + return null; + } + + // Trim whitespace + rawEndpoint = rawEndpoint.Trim(); + + // If it already parses as an absolute URI with https, keep it. + if (Uri.TryCreate(rawEndpoint, UriKind.Absolute, out var absolute) && + (absolute.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase))) + { + return absolute; + } + + // If it has no scheme (common service behavior returning only host) + // prepend https:// and try again. + if (!rawEndpoint.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + { + var candidate = "https://" + rawEndpoint; + if (Uri.TryCreate(candidate, UriKind.Absolute, out var httpsUri)) + { + logger.Info(() => $"[Managed Identity] Normalized attestation endpoint '{rawEndpoint}' -> '{httpsUri.ToString()}'."); + return httpsUri; + } + } + + // Final attempt: reject http (non‑TLS) or malformed + if (Uri.TryCreate(rawEndpoint, UriKind.Absolute, out var anyUri)) + { + if (!anyUri.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase)) + { + logger.Warning($"[Managed Identity] Attestation endpoint uses unsupported scheme '{anyUri.Scheme}'. HTTPS is required."); + return null; + } + return anyUri; + } + + logger.Warning($"[Managed Identity] Failed to normalize attestation endpoint value '{rawEndpoint}'."); + return null; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs b/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs index eded64dc91..2144402c10 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs @@ -2,15 +2,11 @@ // Licensed under the MIT License. using System; -using System.Collections.Generic; -using System.ComponentModel; using System.Threading; using System.Threading.Tasks; using Microsoft.Identity.Client.ApiConfig.Executors; -using Microsoft.Identity.Client.ApiConfig.Parameters; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; -using Microsoft.Identity.Client.Internal.Requests; using Microsoft.Identity.Client.ManagedIdentity; namespace Microsoft.Identity.Client @@ -28,6 +24,8 @@ public sealed class ManagedIdentityApplication : ApplicationBase, IManagedIdentityApplication { + internal ManagedIdentityClient ManagedIdentityClient { get; } + internal ManagedIdentityApplication( ApplicationConfiguration configuration) : base(configuration) @@ -37,6 +35,8 @@ internal ManagedIdentityApplication( AppTokenCacheInternal = configuration.AppTokenCacheInternalForTest ?? new TokenCache(ServiceBundle, true); this.ServiceBundle.ApplicationLogger.Verbose(()=>$"ManagedIdentityApplication {configuration.GetHashCode()} created"); + + ManagedIdentityClient = new ManagedIdentityClient(); } // Stores all app tokens @@ -55,13 +55,28 @@ public AcquireTokenForManagedIdentityParameterBuilder AcquireTokenForManagedIden resource); } + /// + public async Task GetManagedIdentitySourceAsync() + { + if (ManagedIdentityClient.s_sourceName != ManagedIdentitySource.None) + { + return ManagedIdentityClient.s_sourceName; + } + + // Create a temporary RequestContext for the CSR metadata probe request. + var csrMetadataProbeRequestContext = new RequestContext(this.ServiceBundle, Guid.NewGuid(), null, CancellationToken.None); + + return await ManagedIdentityClient.GetManagedIdentitySourceAsync(csrMetadataProbeRequestContext).ConfigureAwait(false); + } + /// /// Detects and returns the managed identity source available on the environment. /// /// Managed identity source detected on the environment if any. + [Obsolete("Use GetManagedIdentitySourceAsync() instead. \"ManagedIdentityApplication mi = miBuilder.Build() as ManagedIdentityApplication;\"")] public static ManagedIdentitySource GetManagedIdentitySource() { - return ManagedIdentityClient.GetManagedIdentitySource(); + return ManagedIdentityClient.GetManagedIdentitySourceNoImdsV2(); } } } diff --git a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj index 2fbcd27067..3279f0338a 100644 --- a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj +++ b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj @@ -92,8 +92,10 @@ - + + + @@ -102,7 +104,7 @@ - + @@ -112,6 +114,7 @@ + @@ -157,4 +160,5 @@ + diff --git a/src/client/Microsoft.Identity.Client/MsalError.cs b/src/client/Microsoft.Identity.Client/MsalError.cs index 4bcc576a42..23846c19b1 100644 --- a/src/client/Microsoft.Identity.Client/MsalError.cs +++ b/src/client/Microsoft.Identity.Client/MsalError.cs @@ -1196,6 +1196,11 @@ public static class MsalError /// public const string RegionRequiredForMtlsPop = "region_required_for_mtls_pop"; + /// + /// What happened? mTLS is not supported for managed identity authentication. + /// + public const string MtlsNotSupportedForManagedIdentity = "mtls_not_supported_for_managed_identity"; + /// /// What happened? The operation attempted to force a token refresh while also using a token hash. /// These two options are incompatible because forcing a refresh bypasses token caching, @@ -1206,5 +1211,10 @@ public static class MsalError /// - If token hashing is required, allow the cached token to be used instead of forcing a refresh. /// public const string ForceRefreshNotCompatibleWithTokenHash = "force_refresh_and_token_hash_not_compatible"; + + /// + /// The certificate received from the Imds server is invalid. + /// + public const string InvalidCertificate = "invalid_certificate"; } } diff --git a/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs b/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs index 85bb9d74cc..92d4e06840 100644 --- a/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs +++ b/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs @@ -441,9 +441,12 @@ public static string InvalidTokenProviderResponseValue(string invalidValueName) public const string MtlsCertificateNotProvidedMessage = "mTLS Proof‑of‑Possession requires a certificate for this request. Either configure the application with .WithCertificate(...) or pass a certificate‑bound client‑assertion and chain .WithMtlsProofOfPossession() on the request builder. See https://aka.ms/msal-net-pop for details."; public const string MtlsInvalidAuthorityTypeMessage = "mTLS PoP is only supported for AAD authority type. See https://aka.ms/msal-net-pop for details."; public const string MtlsNonTenantedAuthorityNotAllowedMessage = "mTLS authentication requires a tenanted authority. Using 'common', 'organizations', or similar non-tenanted authorities is not allowed. Please provide an authority with a specific tenant ID (e.g., 'https://login.microsoftonline.com/{tenantId}'). See https://aka.ms/msal-net-pop for details."; + public const string MtlsNotSupportedForManagedIdentityMessage = "IMDSv2 flow is not supported on .NET Framework 4.6.2. Cryptographic operations required for managed identity authentication are unavailable on this platform."; + public const string MtlsNotSupportedForNonWindowsMessage = "mTLS PoP with Managed Identity is not supported on this OS. See https://aka.ms/msal-net-pop."; public const string RegionRequiredForMtlsPopMessage = "Regional auto-detect failed. mTLS Proof-of-Possession requires a region to be specified, as there is no global endpoint for mTLS. See https://aka.ms/msal-net-pop for details."; public const string ForceRefreshAndTokenHasNotCompatible = "Cannot specify ForceRefresh and AccessTokenSha256ToRefresh in the same request."; public const string RequestTimeOut = "Request to the endpoint timed out."; public const string MalformedOidcAuthorityFormat = "Possible cause: When using Entra External ID, you didn't append /v2.0, for example {0}/v2.0\""; + public const string InvalidCertificate = "The certificate received from the Imds server is invalid."; } } diff --git a/src/client/Microsoft.Identity.Client/OAuth2/OAuthConstants.cs b/src/client/Microsoft.Identity.Client/OAuth2/OAuthConstants.cs index 7aa76e3cd1..a6003809d9 100644 --- a/src/client/Microsoft.Identity.Client/OAuth2/OAuthConstants.cs +++ b/src/client/Microsoft.Identity.Client/OAuth2/OAuthConstants.cs @@ -77,6 +77,7 @@ internal static class OAuth2RequestedTokenUse internal static class OAuth2Header { public const string CorrelationId = "client-request-id"; + public const string XMsCorrelationId = $"x-ms-{CorrelationId}"; public const string RequestCorrelationIdInResponse = "return-client-request-id"; public const string AppName = "x-app-name"; public const string AppVer = "x-app-ver"; diff --git a/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs b/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs index d36af0533d..ea9a757a33 100644 --- a/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs +++ b/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs @@ -11,6 +11,7 @@ using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Kerberos; using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.OAuth2; using Microsoft.Identity.Client.Region; using Microsoft.Identity.Client.WsTrust; @@ -37,6 +38,10 @@ namespace Microsoft.Identity.Client.Platforms.net [JsonSerializable(typeof(ManagedIdentityResponse))] [JsonSerializable(typeof(ManagedIdentityErrorResponse))] [JsonSerializable(typeof(OidcMetadata))] + [JsonSerializable(typeof(CsrMetadata))] + [JsonSerializable(typeof(CuidInfo))] + [JsonSerializable(typeof(CertificateRequestBody))] + [JsonSerializable(typeof(CertificateRequestResponse))] [JsonSourceGenerationOptions] internal partial class MsalJsonSerializerContext : JsonSerializerContext { @@ -51,6 +56,7 @@ public static MsalJsonSerializerContext Custom { NumberHandling = JsonNumberHandling.AllowReadingFromString, AllowTrailingCommas = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, Converters = { new JsonStringConverter(), diff --git a/src/client/Microsoft.Identity.Client/PlatformsCommon/Interfaces/IPlatformProxy.cs b/src/client/Microsoft.Identity.Client/PlatformsCommon/Interfaces/IPlatformProxy.cs index 00b1a446bb..43ddeb032a 100644 --- a/src/client/Microsoft.Identity.Client/PlatformsCommon/Interfaces/IPlatformProxy.cs +++ b/src/client/Microsoft.Identity.Client/PlatformsCommon/Interfaces/IPlatformProxy.cs @@ -5,6 +5,7 @@ using Microsoft.Identity.Client.AuthScheme.PoP; using Microsoft.Identity.Client.Cache; using Microsoft.Identity.Client.Internal.Broker; +using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.TelemetryCore.OpenTelemetry; using Microsoft.Identity.Client.UI; @@ -110,5 +111,7 @@ internal interface IPlatformProxy bool BrokerSupportsWamAccounts { get; } IMsalHttpClientFactory CreateDefaultHttpClientFactory(); + + IManagedIdentityKeyProvider ManagedIdentityKeyProvider { get; } } } diff --git a/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/AbstractPlatformProxy.cs b/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/AbstractPlatformProxy.cs index 8f2301896c..d6a2b488a2 100644 --- a/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/AbstractPlatformProxy.cs +++ b/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/AbstractPlatformProxy.cs @@ -8,6 +8,8 @@ using Microsoft.Identity.Client.Cache; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal.Broker; +using Microsoft.Identity.Client.ManagedIdentity; + #if SUPPORTS_OTEL using Microsoft.Identity.Client.Platforms.Features.OpenTelemetry; #endif @@ -34,6 +36,7 @@ internal abstract class AbstractPlatformProxy : IPlatformProxy private readonly Lazy _productName; private readonly Lazy _runtimeVersion; private readonly Lazy _otelInstrumentation; + private readonly Lazy _miKeyProvider; protected AbstractPlatformProxy(ILoggerAdapter logger) { @@ -49,6 +52,7 @@ protected AbstractPlatformProxy(ILoggerAdapter logger) _platformLogger = new Lazy(InternalGetPlatformLogger); _runtimeVersion = new Lazy(InternalGetRuntimeVersion); _otelInstrumentation = new Lazy(InternalGetOtelInstrumentation); + _miKeyProvider = new Lazy(GetManagedIdentityKeyProvider); } private IOtelInstrumentation InternalGetOtelInstrumentation() @@ -229,10 +233,17 @@ public virtual IMsalHttpClientFactory CreateDefaultHttpClientFactory() return new SimpleHttpClientFactory(); } + internal virtual IManagedIdentityKeyProvider GetManagedIdentityKeyProvider() + { + return ManagedIdentityKeyProviderFactory.GetOrCreateProvider(Logger); + } + /// /// On Android and iOS, MSAL will save the legacy ADAL cache in a known location. /// On other platforms, the app developer must use the serialization callbacks /// public virtual bool LegacyCacheRequiresSerialization => true; + + public IManagedIdentityKeyProvider ManagedIdentityKeyProvider => _miKeyProvider.Value; } } diff --git a/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/CommonCryptographyManager.cs b/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/CommonCryptographyManager.cs index 20fc279fc4..bbb368bc92 100644 --- a/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/CommonCryptographyManager.cs +++ b/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/CommonCryptographyManager.cs @@ -102,14 +102,90 @@ public virtual byte[] SignWithCertificate(string message, X509Certificate2 certi } byte[] SignDataAndCacheProvider(string message) - { + { // CodeQL [SM03799] PKCS1 padding is for Identity Providers not supporting PSS (older ADFS, dSTS) var signedData = rsa.SignData(Encoding.UTF8.GetBytes(message), HashAlgorithmName.SHA256, signaturePadding); - + // Cache only valid RSA crypto providers, which are able to sign data successfully s_certificateToRsaMap[certificate.Thumbprint] = rsa; return signedData; } } + + /// + /// Attaches a private key to a certificate for use in mTLS authentication. + /// + /// The certificate received from the Imds server + /// The RSA private key to attach + /// An X509Certificate2 with the private key attached + /// Thrown when rawCertificate or privateKey is null + /// Thrown when rawCertificate is empty, invalid, and cannot be parsed + internal static X509Certificate2 AttachPrivateKeyToCert(string rawCertificate, RSA privateKey) + { + if (string.IsNullOrEmpty(rawCertificate)) + throw new MsalServiceException(MsalError.InvalidCertificate, MsalErrorMessage.InvalidCertificate); + if (privateKey == null) + throw new ArgumentNullException(nameof(privateKey)); + + X509Certificate2 certificate = null; + + try + { + byte[] certBytes = Convert.FromBase64String(rawCertificate); + certificate = new X509Certificate2(certBytes); + } + catch (FormatException ex) + { + throw new MsalServiceException(MsalError.InvalidCertificate, MsalErrorMessage.InvalidCertificate, ex); + } + + try + { +#if NET8_0_OR_GREATER + // Attach the private key and return a new certificate instance + return certificate.CopyWithPrivateKey(privateKey); +#else + // .NET Framework 4.7.2 and .NET Standard 2.0 - manual private key attachment + return AttachPrivateKeyToOlderFrameworks(certificate, privateKey); +#endif + } + catch (Exception ex) + { + throw new MsalServiceException(MsalError.InvalidCertificate, MsalErrorMessage.InvalidCertificate, ex); + } + } + +#if !NET8_0_OR_GREATER + /// + /// Attaches a private key to a certificate for older .NET Framework versions. + /// This method uses the older RSACng approach for .NET Framework 4.7.2 and .NET Standard 2.0. + /// + /// The certificate without private key + /// The RSA private key to attach + /// An X509Certificate2 with the private key attached + /// Thrown when private key attachment fails + private static X509Certificate2 AttachPrivateKeyToOlderFrameworks(X509Certificate2 certificate, RSA privateKey) + { + // For older frameworks, we need to use the legacy approach with RSACryptoServiceProvider + // First, export the RSA parameters from the provided private key + var parameters = privateKey.ExportParameters(includePrivateParameters: true); + + // Create a new RSACryptoServiceProvider with the correct key size + int keySize = parameters.Modulus.Length * 8; + using (var rsaProvider = new RSACryptoServiceProvider(keySize)) + { + // Import the parameters into the new provider + rsaProvider.ImportParameters(parameters); + + // Create a new certificate instance from the raw data + var certWithPrivateKey = new X509Certificate2(certificate.RawData); + + // Assign the private key using the legacy property + certWithPrivateKey.PrivateKey = rsaProvider; + + return certWithPrivateKey; + } + } +#endif } } diff --git a/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/InMemoryPartitionedAppTokenCacheAccessor.cs b/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/InMemoryPartitionedAppTokenCacheAccessor.cs index 1a732eb856..9f97393bfe 100644 --- a/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/InMemoryPartitionedAppTokenCacheAccessor.cs +++ b/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/InMemoryPartitionedAppTokenCacheAccessor.cs @@ -256,5 +256,11 @@ private ref int GetEntryCountRef() return ref _tokenCacheAccessorOptions.UseSharedCache ? ref s_entryCount : ref _entryCount; } + public static void ClearStaticCacheForTest() + { + s_accessTokenCacheDictionary.Clear(); + s_appMetadataDictionary.Clear(); + Interlocked.Exchange(ref s_entryCount, 0); + } } } diff --git a/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/InMemoryPartitionedUserTokenCacheAccessor.cs b/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/InMemoryPartitionedUserTokenCacheAccessor.cs index 0d23372704..94011b1e65 100644 --- a/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/InMemoryPartitionedUserTokenCacheAccessor.cs +++ b/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/InMemoryPartitionedUserTokenCacheAccessor.cs @@ -5,6 +5,7 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; +using System.Threading; using Microsoft.Identity.Client.Cache; using Microsoft.Identity.Client.Cache.Items; using Microsoft.Identity.Client.Cache.Keys; @@ -168,7 +169,7 @@ public void DeleteAccessToken(MsalAccessTokenCacheItem item) if (AccessTokenCacheDictionary.TryGetValue(partitionKey, out var partition)) { - bool removed = partition.TryRemove(item.CacheKey, out _); + bool removed = partition.TryRemove(item.CacheKey, out _); if (removed) { System.Threading.Interlocked.Decrement(ref GetEntryCountRef()); @@ -365,5 +366,15 @@ private ref int GetEntryCountRef() { return ref _tokenCacheAccessorOptions.UseSharedCache ? ref s_entryCount : ref _entryCount; } + + public static void ClearStaticCacheForTest() + { + s_accessTokenCacheDictionary.Clear(); + s_refreshTokenCacheDictionary.Clear(); + s_idTokenCacheDictionary.Clear(); + s_accountCacheDictionary.Clear(); + s_appMetadataDictionary.Clear(); + Interlocked.Exchange(ref s_entryCount, 0); + } } } diff --git a/src/client/Microsoft.Identity.Client/Properties/InternalsVisibleTo.cs b/src/client/Microsoft.Identity.Client/Properties/InternalsVisibleTo.cs index 769ebfcbf0..d6c67f8270 100644 --- a/src/client/Microsoft.Identity.Client/Properties/InternalsVisibleTo.cs +++ b/src/client/Microsoft.Identity.Client/Properties/InternalsVisibleTo.cs @@ -7,6 +7,7 @@ [assembly: InternalsVisibleTo("Microsoft.Identity.Client.Desktop" + KeyTokens.MSAL)] [assembly: InternalsVisibleTo("Microsoft.Identity.Client.Desktop.WinUI3" + KeyTokens.MSAL)] [assembly: InternalsVisibleTo("Microsoft.Identity.Client.Broker" + KeyTokens.MSAL)] +[assembly: InternalsVisibleTo("Microsoft.Identity.Client.MtlsPop" + KeyTokens.MSAL)] [assembly: InternalsVisibleTo("Microsoft.Identity.Test.Unit" + KeyTokens.MSAL)] [assembly: InternalsVisibleTo("Microsoft.Identity.Test.Common" + KeyTokens.MSAL)] diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt index e69de29bb2..a7f08728c1 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt @@ -0,0 +1,6 @@ +const Microsoft.Identity.Client.MsalError.InvalidCertificate = "invalid_certificate" -> string +const Microsoft.Identity.Client.MsalError.MtlsNotSupportedForManagedIdentity = "mtls_not_supported_for_managed_identity" -> string +Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource +Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.WithExtraQueryParameters(System.Collections.Generic.IDictionary extraQueryParameters) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder +static Microsoft.Identity.Client.ApplicationBase.ResetStateForTest() -> void diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt index e69de29bb2..a7f08728c1 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt @@ -0,0 +1,6 @@ +const Microsoft.Identity.Client.MsalError.InvalidCertificate = "invalid_certificate" -> string +const Microsoft.Identity.Client.MsalError.MtlsNotSupportedForManagedIdentity = "mtls_not_supported_for_managed_identity" -> string +Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource +Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.WithExtraQueryParameters(System.Collections.Generic.IDictionary extraQueryParameters) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder +static Microsoft.Identity.Client.ApplicationBase.ResetStateForTest() -> void diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt index e69de29bb2..a7f08728c1 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt @@ -0,0 +1,6 @@ +const Microsoft.Identity.Client.MsalError.InvalidCertificate = "invalid_certificate" -> string +const Microsoft.Identity.Client.MsalError.MtlsNotSupportedForManagedIdentity = "mtls_not_supported_for_managed_identity" -> string +Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource +Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.WithExtraQueryParameters(System.Collections.Generic.IDictionary extraQueryParameters) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder +static Microsoft.Identity.Client.ApplicationBase.ResetStateForTest() -> void diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt index e69de29bb2..a7f08728c1 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt @@ -0,0 +1,6 @@ +const Microsoft.Identity.Client.MsalError.InvalidCertificate = "invalid_certificate" -> string +const Microsoft.Identity.Client.MsalError.MtlsNotSupportedForManagedIdentity = "mtls_not_supported_for_managed_identity" -> string +Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource +Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.WithExtraQueryParameters(System.Collections.Generic.IDictionary extraQueryParameters) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder +static Microsoft.Identity.Client.ApplicationBase.ResetStateForTest() -> void diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt index e69de29bb2..a7f08728c1 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt @@ -0,0 +1,6 @@ +const Microsoft.Identity.Client.MsalError.InvalidCertificate = "invalid_certificate" -> string +const Microsoft.Identity.Client.MsalError.MtlsNotSupportedForManagedIdentity = "mtls_not_supported_for_managed_identity" -> string +Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource +Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.WithExtraQueryParameters(System.Collections.Generic.IDictionary extraQueryParameters) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder +static Microsoft.Identity.Client.ApplicationBase.ResetStateForTest() -> void diff --git a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt index e69de29bb2..a7f08728c1 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt @@ -0,0 +1,6 @@ +const Microsoft.Identity.Client.MsalError.InvalidCertificate = "invalid_certificate" -> string +const Microsoft.Identity.Client.MsalError.MtlsNotSupportedForManagedIdentity = "mtls_not_supported_for_managed_identity" -> string +Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource +Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.WithExtraQueryParameters(System.Collections.Generic.IDictionary extraQueryParameters) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder +static Microsoft.Identity.Client.ApplicationBase.ResetStateForTest() -> void diff --git a/src/client/Microsoft.Identity.Client/RequestType.cs b/src/client/Microsoft.Identity.Client/RequestType.cs index 107067943a..272bcfa5c9 100644 --- a/src/client/Microsoft.Identity.Client/RequestType.cs +++ b/src/client/Microsoft.Identity.Client/RequestType.cs @@ -26,6 +26,11 @@ internal enum RequestType /// /// Region Discovery request, used for region discovery operations with exponential backoff retry strategy. /// - RegionDiscovery + RegionDiscovery, + + /// + /// CSR Metadata Probe request, used to probe an IMDSv2 managed identity for metadata to be used in acquiring a token. + /// + CsrMetadataProbe } } diff --git a/src/client/Microsoft.Identity.Client/Utils/DateTimeHelpers.cs b/src/client/Microsoft.Identity.Client/Utils/DateTimeHelpers.cs index bd421b935d..60b4ea25cf 100644 --- a/src/client/Microsoft.Identity.Client/Utils/DateTimeHelpers.cs +++ b/src/client/Microsoft.Identity.Client/Utils/DateTimeHelpers.cs @@ -85,7 +85,15 @@ public static long GetDurationFromManagedIdentityTimestamp(string dateTimeStamp) // Example: "1697490590" (Unix timestamp representing seconds since 1970-01-01) if (long.TryParse(dateTimeStamp, out long expiresOnUnixTimestamp)) { - return expiresOnUnixTimestamp - DateTimeHelpers.CurrDateTimeInUnixTimestamp(); + var timestamp = expiresOnUnixTimestamp - DateTimeHelpers.CurrDateTimeInUnixTimestamp(); + + // If the timestamp is negative, return the original expiresOnUnixTimestamp. Its format is "seconds from now". + if (timestamp < 0) + { + return expiresOnUnixTimestamp; + } + + return timestamp; } // Try parsing as ISO 8601 diff --git a/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs b/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs index 1283c66ad5..eb6fd842ef 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs @@ -42,6 +42,7 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity break; case ManagedIdentitySource.Imds: + case ManagedIdentitySource.ImdsV2: Environment.SetEnvironmentVariable("AZURE_POD_IDENTITY_AUTHORITY_HOST", endpoint); break; @@ -59,11 +60,15 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity Environment.SetEnvironmentVariable("IDENTITY_HEADER", secret); Environment.SetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT", thumbprint); break; + case ManagedIdentitySource.MachineLearning: Environment.SetEnvironmentVariable("MSI_ENDPOINT", endpoint); Environment.SetEnvironmentVariable("MSI_SECRET", secret); Environment.SetEnvironmentVariable("DEFAULT_IDENTITY_CLIENT_ID", "fake_DEFAULT_IDENTITY_CLIENT_ID"); break; + + default: + throw new NotImplementedException($"Setting environment variables for {managedIdentitySource} is not implemented."); } } @@ -123,7 +128,7 @@ public static ManagedIdentityApplicationBuilder CreateMIABuilder(string userAssi break; } - // Disabling shared cache options to avoid cross test pollution. + builder.Config.AccessorOptions = null; return builder; diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index 7050ae13e5..9a1e143029 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -2,16 +2,25 @@ // Licensed under the MIT License. using System; +using System.Collections.Generic; using System.Globalization; using System.IO; using System.Net; using System.Net.Http; using System.Net.Http.Headers; -using Microsoft.Identity.Client.Utils; -using Microsoft.Identity.Test.Unit; +using Castle.Core.Logging; using Microsoft.Identity.Client; -using Microsoft.Identity.Client.OAuth2; +using Microsoft.Identity.Client.AppConfig; +using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.Internal.Logger; using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.V2; +using Microsoft.Identity.Client.OAuth2; +using Microsoft.Identity.Client.OAuth2.Throttling; +using Microsoft.Identity.Client.Utils; +using Microsoft.Identity.Test.Unit; +using Microsoft.VisualStudio.TestTools.UnitTesting.Logging; +using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil; namespace Microsoft.Identity.Test.Common.Core.Mocks { @@ -70,12 +79,12 @@ public static string GetTokenResponseWithNoOidClaim() public static string GetDefaultTokenResponse(string accessToken = TestConstants.ATSecret, string refreshToken = TestConstants.RTSecret) { - return - "{\"token_type\":\"Bearer\",\"expires_in\":\"3599\",\"refresh_in\":\"2400\",\"scope\":" + - "\"r1/scope1 r1/scope2\",\"access_token\":\"" + accessToken + "\"" + - ",\"refresh_token\":\"" + refreshToken + "\",\"client_info\"" + - ":\"" + CreateClientInfo() + "\",\"id_token\"" + - ":\"" + CreateIdToken(TestConstants.UniqueId, TestConstants.DisplayableId) + "\"}"; + return + "{\"token_type\":\"Bearer\",\"expires_in\":\"3599\",\"refresh_in\":\"2400\",\"scope\":" + + "\"r1/scope1 r1/scope2\",\"access_token\":\"" + accessToken + "\"" + + ",\"refresh_token\":\"" + refreshToken + "\",\"client_info\"" + + ":\"" + CreateClientInfo() + "\",\"id_token\"" + + ":\"" + CreateIdToken(TestConstants.UniqueId, TestConstants.DisplayableId) + "\"}"; } public static string GetPopTokenResponse() @@ -113,24 +122,30 @@ public static string GetBridgedHybridSpaTokenResponse(string spaAccountId) ",\"id_token_expires_in\":\"3600\"}"; } - public static string GetMsiSuccessfulResponse(int expiresInHours = 1, bool useIsoFormat = false) + public static string GetMsiSuccessfulResponse( + int expiresInHours = 1, + bool useIsoFormat = false, + bool mTLSPop = false, + bool imdsV2 = false) { - string expiresOn; - + var expiresOnKey = imdsV2 ? "expires_in" : "expires_on"; + string expiresOnValue; if (useIsoFormat) { // Return ISO 8601 format - expiresOn = DateTime.UtcNow.AddHours(expiresInHours).ToString("o", CultureInfo.InvariantCulture); + expiresOnValue = DateTime.UtcNow.AddHours(expiresInHours).ToString("o", CultureInfo.InvariantCulture); } else { // Return Unix timestamp format - expiresOn = DateTimeHelpers.DateTimeToUnixTimestamp(DateTime.UtcNow.AddHours(expiresInHours)); + expiresOnValue = DateTimeHelpers.DateTimeToUnixTimestamp(DateTime.UtcNow.AddHours(expiresInHours)); } + var tokenType = mTLSPop ? "mtls_pop" : "Bearer"; + return - "{\"access_token\":\"" + TestConstants.ATSecret + "\",\"expires_on\":\"" + expiresOn + "\",\"resource\":\"https://management.azure.com/\",\"token_type\":" + - "\"Bearer\",\"client_id\":\"client_id\"}"; + "{\"access_token\":\"" + TestConstants.ATSecret + "\",\"" + expiresOnKey + "\":\"" + expiresOnValue + "\",\"resource\":\"https://management.azure.com/\"," + + "\"token_type\":\"" + tokenType + "\",\"client_id\":\"client_id\"}"; } public static string GetMsiErrorBadJson() @@ -396,7 +411,7 @@ public static string CreateSuccessTokenResponseString(string uniqueId, idToken + (foci ? "\",\"foci\":\"1" : "") + "\",\"id_token_expires_in\":\"3600\",\"client_info\":\"" + CreateClientInfo(uniqueId, utid) + "\"}"; - + return stringContent; } @@ -583,5 +598,144 @@ public static MsalTokenResponse CreateMsalRunTimeBrokerTokenResponse(string acce TokenSource = TokenSource.Broker }; } + + public static MockHttpMessageHandler MockCsrResponse( + HttpStatusCode statusCode = HttpStatusCode.OK, + string responseServerHeader = "IMDS/150.870.65.1854", + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null) + { + IDictionary expectedQueryParams = new Dictionary(); + IDictionary expectedRequestHeaders = new Dictionary(); + IList presentRequestHeaders = new List + { + OAuth2Header.XMsCorrelationId + }; + + if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) + { + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null); + expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value); + } + expectedQueryParams.Add("cred-api-version", "2.0"); + expectedRequestHeaders.Add("Metadata", "true"); + + string content = + "{" + + "\"cuId\": { \"vmId\": \"fake_vmId\" }," + + "\"clientId\": \"" + TestConstants.ClientId + "\"," + + "\"tenantId\": \"" + TestConstants.TenantId + "\"," + + "\"attestationEndpoint\": \"fake_attestation_endpoint\"" + + "}"; + + var handler = new MockHttpMessageHandler() + { + ExpectedUrl = $"{ImdsManagedIdentitySource.DefaultImdsBaseEndpoint}{ImdsV2ManagedIdentitySource.CsrMetadataPath}", + ExpectedMethod = HttpMethod.Get, + ExpectedQueryParams = expectedQueryParams, + ExpectedRequestHeaders = expectedRequestHeaders, + PresentRequestHeaders = presentRequestHeaders, + ResponseMessage = new HttpResponseMessage(statusCode) + { + Content = new StringContent(content), + } + }; + + if (responseServerHeader != null) + handler.ResponseMessage.Headers.TryAddWithoutValidation("server", responseServerHeader); + + return handler; + } + + // used for unit tests in ManagedIdentityTests.cs + public static MockHttpMessageHandler MockCsrResponseFailure() + { + // 400 doesn't trigger the retry policy + return MockCsrResponse(HttpStatusCode.BadRequest); + } + + public static MockHttpMessageHandler MockCertificateRequestResponse( + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + string certificate = TestConstants.ValidRawCertificate) + { + IDictionary expectedQueryParams = new Dictionary(); + IDictionary expectedRequestHeaders = new Dictionary(); + IList presentRequestHeaders = new List + { + OAuth2Header.XMsCorrelationId + }; + + if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) + { + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)userAssignedIdentityId, userAssignedId, null); + expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value); + } + expectedQueryParams.Add("cred-api-version", ImdsV2ManagedIdentitySource.ImdsV2ApiVersion); + expectedRequestHeaders.Add("Metadata", "true"); + + string content = + "{" + + "\"client_id\": \"" + TestConstants.ClientId + "\"," + + "\"tenant_id\": \"" + TestConstants.TenantId + "\"," + + "\"certificate\": \"" + certificate + "\"," + + "\"identity_type\": \"fake_identity_type\"," + // "SystemAssigned" or "UserAssigned", it doesn't matter for these tests + "\"mtls_authentication_endpoint\": \"" + TestConstants.MtlsAuthenticationEndpoint + "\"" + + "}"; + + var handler = new MockHttpMessageHandler() + { + ExpectedUrl = $"{ImdsManagedIdentitySource.DefaultImdsBaseEndpoint}{ImdsV2ManagedIdentitySource.CertificateRequestPath}", + ExpectedMethod = HttpMethod.Post, + ExpectedQueryParams = expectedQueryParams, + ExpectedRequestHeaders = expectedRequestHeaders, + PresentRequestHeaders = presentRequestHeaders, + ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(content), + } + }; + + return handler; + } + + public static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponse( + IdentityLoggerAdapter identityLoggerAdapter, + bool mTLSPop = false) + { + IDictionary expectedPostData = new Dictionary(); + IDictionary expectedRequestHeaders = new Dictionary + { + { ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue } + }; + IList presentRequestHeaders = new List + { + OAuth2Header.XMsCorrelationId + }; + + var idParams = MsalIdHelper.GetMsalIdParameters(identityLoggerAdapter); + foreach (var idParam in idParams) + { + expectedRequestHeaders[idParam.Key] = idParam.Value; + } + + var tokenType = mTLSPop ? "mtls_pop" : "bearer"; + expectedPostData.Add("token_type", tokenType); + + var handler = new MockHttpMessageHandler() + { + ExpectedUrl = $"{TestConstants.MtlsAuthenticationEndpoint}/{TestConstants.TenantId}{ImdsV2ManagedIdentitySource.AcquireEntraTokenPath}", + ExpectedMethod = HttpMethod.Post, + ExpectedPostData = expectedPostData, + ExpectedRequestHeaders = expectedRequestHeaders, + PresentRequestHeaders = presentRequestHeaders, + ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(GetMsiSuccessfulResponse(mTLSPop: mTLSPop, imdsV2: true)), + } + }; + + return handler; + } } } diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs index 565ca72e68..c8b63a208d 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs @@ -374,7 +374,8 @@ public static MockHttpMessageHandler AddManagedIdentityMockHandler( HttpStatusCode statusCode = HttpStatusCode.OK, string retryAfterHeader = null, // A number of seconds (e.g., "120"), or an HTTP-date in RFC1123 format (e.g., "Fri, 19 Apr 2025 15:00:00 GMT") bool capabilityEnabled = false, - bool claimsEnabled = false + bool claimsEnabled = false, + IDictionary extraQueryParameters = null ) { HttpResponseMessage responseMessage = new HttpResponseMessage(statusCode) @@ -393,6 +394,15 @@ public static MockHttpMessageHandler AddManagedIdentityMockHandler( capabilityEnabled, claimsEnabled); + // Add extra query parameters if provided + if (extraQueryParameters != null) + { + foreach (var kvp in extraQueryParameters) + { + httpMessageHandler.ExpectedQueryParams[kvp.Key] = kvp.Value; + } + } + if (managedIdentitySourceType == ManagedIdentitySource.MachineLearning) { // For Machine Learning (App Service 2017), the client id param is "clientid" diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs index e067d640b8..cdfbb5432b 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs @@ -37,6 +37,8 @@ internal class MockHttpMessageHandler : HttpClientHandler public HttpRequestMessage ActualRequestMessage { get; private set; } public Dictionary ActualRequestPostData { get; private set; } public HttpRequestHeaders ActualRequestHeaders { get; private set; } + public IList PresentRequestHeaders { get; set; } + public X509Certificate2 ExpectedMtlsBindingCertificate { get; set; } protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) @@ -174,6 +176,15 @@ private void ValidateNotExpectedPostData() private void ValidateHeaders(HttpRequestMessage request) { + if (PresentRequestHeaders != null) + { + foreach (var headerName in PresentRequestHeaders) + { + Assert.IsTrue(request.Headers.Contains(headerName), + $"Expected request header to be present: {headerName}."); + } + } + ActualRequestHeaders = request.Headers; if (ExpectedRequestHeaders != null) { diff --git a/tests/Microsoft.Identity.Test.Common/TestCommon.cs b/tests/Microsoft.Identity.Test.Common/TestCommon.cs index 8258df4383..0fd0163719 100644 --- a/tests/Microsoft.Identity.Test.Common/TestCommon.cs +++ b/tests/Microsoft.Identity.Test.Common/TestCommon.cs @@ -6,63 +6,31 @@ using System.Globalization; using System.Linq; using System.Net; -using System.Net.Http; using System.Text; using System.Threading; -using System.Threading.Tasks; using Microsoft.Identity.Client; using Microsoft.Identity.Client.ApiConfig.Parameters; using Microsoft.Identity.Client.AppConfig; using Microsoft.Identity.Client.Cache; using Microsoft.Identity.Client.Cache.Items; using Microsoft.Identity.Client.Http; +using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Client.Instance; -using Microsoft.Identity.Client.Instance.Discovery; -using Microsoft.Identity.Client.Instance.Oidc; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Internal.Requests; using Microsoft.Identity.Client.Kerberos; -using Microsoft.Identity.Client.OAuth2.Throttling; using Microsoft.Identity.Client.PlatformsCommon.Factories; using Microsoft.Identity.Client.PlatformsCommon.Interfaces; using Microsoft.Identity.Client.PlatformsCommon.Shared; +using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.Identity.Test.Unit; using Microsoft.VisualStudio.TestTools.UnitTesting; -using Microsoft.Identity.Test.Common.Core.Mocks; -using NSubstitute; using static Microsoft.Identity.Client.TelemetryCore.Internal.Events.ApiEvent; -using Microsoft.Identity.Client.Http.Retry; namespace Microsoft.Identity.Test.Common { internal static class TestCommon { - public static void ResetInternalStaticCaches() - { - // This initializes the classes so that the statics inside them are fully initialized, and clears any cached content in them. - new InstanceDiscoveryManager( - Substitute.For(), - true, null, null); - OidcRetrieverWithCache.ResetCacheForTest(); - AuthorityManager.ClearValidationCache(); - SingletonThrottlingManager.GetInstance().ResetCache(); - } - - public static object GetPropValue(object src, string propName) - { - object result = null; - try - { - result = src.GetType().GetProperty(propName).GetValue(src, null); - } - catch - { - Console.WriteLine($"Property with name {propName}"); - } - - return result; - } - public static IServiceBundle CreateServiceBundleWithCustomHttpManager( IHttpManager httpManager, LogCallback logCallback = null, @@ -94,7 +62,11 @@ public static IServiceBundle CreateServiceBundleWithCustomHttpManager( PlatformProxy = platformProxy, RetryPolicyFactory = new RetryPolicyFactory() }; - return new ServiceBundle(appConfig, clearCaches); + + if (clearCaches) + ApplicationBase.ResetStateForTest(); + + return new ServiceBundle(appConfig); } public static IServiceBundle CreateDefaultServiceBundle() diff --git a/tests/Microsoft.Identity.Test.Common/TestConstants.cs b/tests/Microsoft.Identity.Test.Common/TestConstants.cs index 8a59bc1d48..bf31b71dc3 100644 --- a/tests/Microsoft.Identity.Test.Common/TestConstants.cs +++ b/tests/Microsoft.Identity.Test.Common/TestConstants.cs @@ -156,6 +156,9 @@ public static HashSet s_scope public const string IdentityProvider = "my-idp"; public const string Name = "First Last"; public const string MiResourceId = "/subscriptions/ffa4aaa2-4444-4444-5555-e3ccedd3d046/resourcegroups/UAMI_group/providers/Microsoft.ManagedIdentityClient/userAssignedIdentities/UAMI"; + public const string VmId = "test-vm-id"; + public const string VmssId = "test-vmss-id"; + public const string MtlsAuthenticationEndpoint = "http://fake_mtls_authentication_endpoint"; public const string Claims = @"{""userinfo"":{""given_name"":{""essential"":true},""nickname"":null,""email"":{""essential"":true},""email_verified"":{""essential"":true},""picture"":null,""http://example.info/claims/groups"":null},""id_token"":{""auth_time"":{""essential"":true},""acr"":{""values"":[""urn:mace:incommon:iap:silver""]}}}"; public static readonly string[] ClientCapabilities = new[] { "cp1", "cp2" }; @@ -585,6 +588,29 @@ public static MsalTokenResponse CreateAadTestTokenResponseWithFoci() internal const string UserAccessToken = "flMpQIKiCoiPK6qISSjmF9dGhKe47KFGPwe82BDBxBCVfYI4UiKYbBuShsjf8oGTsjN5ODeaO6k0cmZJYuNNbLyOr8JGqoxQRW9bI8j5ETpbTNf6tYpAWde9PIYj2wEBnbughVgtJsh2QxIrahie5leMpsGb1yoFzADD5gyoJq8etNUSgZwe5qkfaE9UBCUKrznKjKbsG5hBJXut5GD0QdQy3wo2PnocewrptlMzd5SsHCzUUBGA4q7ks7IfrLiQH11JyBnjBhypOX3XvuqBz4JKkpftVYfvwPWE3f5Onku6FkZJFFESyGQP9YnJVx5dQCpHH9l6ShTqOLSQduf7wxoyeAgxwPrM9Y8Kvj31IrXqiwP52x4hBsctLCqOXOZ3wMXnozMXyHpNvKMJaNgDgvBgMYhiyORkb3qKYw0gAP4659I8dK1esxJoD8I3EreDftGfNMFCgn7kFfauUQphkqx8ukqzw068R7g5TOUci1pgPcVXCAMxj0P3fTiKe1doVuF6znKYh3m7pjyzyaqb5K9VFIh4A8TXOO0MqjaVkoSWJXARTy4T0kAZBVPbO6U2BWku23yLIt43MhQTc9uf7inuirwaIgh5u7noDxYG4QZLB1CJl04Zq2gbh9GW7dqweAaC9efYTEDwhxDTPHeGTQs44e8cnWerIyZA7mq8sFuzihIiCfgZ6nNBPcx2lXKyarUtQGmjjRyOEAhs66atv3SgMhNBhontPoUhR1QEnTKeYzfaavlnf5qMZA41hijGazHyxy5FgLD5aLEpZTHN5MPQLeaEXzDMX5Wtdvq7nokiItRfLkKZtXkuSiFVltmRPcKqzGbjNRH96OQzuxLE1Mv25FYFR3PAwv6np69yScVOpNFL8CqJdT310dGnRPUKSrEqTPuMsHqVRr36j2ZUaGs6YBtcrxIxKHuPrv23FQg5fC0FgxZvKqve0hf68AocJ1HqKRy01CGQobmYpTwBByftOZYGC4KOfGd13l78kZaKLuk2gxfFuTQyr11A0L4n5tXfjlikJtr3wlTGt0KCGGXmNK1xsSoRC0VcXDOgQUu3FHblhiaYjbSvPRF09xn9tRPnUkznbsT1kPMiJ8v89ZOCtVWpvkoiy9VUVcSUpZNQwRh3wHidZAkp1xyjyVc2pIHPg6XhzJnlt77zHNiBkPxWbYt7hXBQf3QeYoMF4s0Qi1y5N72DdoSNJ3iaTwx3esAz6TeyxSh36PIz35mR5jGyGMssyaNg6lIewLPbjnizgC6xssi6mKOheDqWqBv89nIvSBOXEkKcUYsBlhBBK6BgxOIha1NAeP93RRKfyjrF7LtIoSOk3DJUx75rUJ9oyuuTt4FdSnp7ZdrIciO8vlNslPrfa7UjBdOtVHiaz9Ef91dctdADVFcwXXmcu2ypyKB1YvMbkPP7mc12TF1a8X6t0mU4s4J4IpA3SHmT5JvbQBEzOIs6ex38X3UtXSItxpaS2gKozAhAmvjt6NKMe3Jysm4bafH1kb8eB1vdwTQu3jIOGozqHC3rvqEVAt26NNKOuNYAoYYamQOSb2w8PUCuDDWs1ffLvvfyvRndZztV5C4HGGR1Tg82N291Sb7rSUYmA1rdGyJ4kPtSaiPOwMyPUs9FuZNef5Ib83D3gTcgS1gMxto5UkfSxtCDKLXtGKArOdACrRzHiiMSn3owQfyVtSXZPdeofoCzuPWcZzFLBUJR0iKWBpUkxd0N17vw45uMQpQUNGgGoyvyboKkAFlOGsEIAmrnooC3CJGVA4jHPYJnVG4xTJ37U6QL5sX95qWtjbvuD5KoT2GyWec0o62CNr09tCQsiALLC1QrfCiCGsullefbsgBB5tsOY1Kyiy4uf84qBMu20GbsJ01R8xxpJ5bh6HFRaStEK3WIy7TMJym42YMbxB3AGsGFGhNYljtuqgeUjXn1UuWskkB6QqdepFHCof6CHg0LlV0o4Iz9QKu5cfoi8jk5HKbvIGyDqCgZaC2LdugNgQ0X"; internal const string RefreshToken = "mhJDJ8wtjA3KxpRtuPAreZnMcJ2yKC2JUbpOGbRTdOCImLyQ2B4EIhv8AiA2cCEylZZfZsOsZrNsMBZZAAU9TQYYEO72QcdfnIWpAOeKkud5W2L8nMq6i9dx1EVIl09zFXhOJ79BdFbU0Eb5aUHlcqPCQjec62UKBLkZJmtMnoAa8cjvgIuxTdVM8FNdghe5nlCNTEVooKleTTEHNl2BrdyitLaWTKSP0lRqnFxriG0xWcJoSMsdS7Vt6HZd1TkwHIXycNMlCcCdUh5tOgqx1M8y8uoXK4OJ1LQmtkZvcQWcycvOCPACYakKM1pUQqwTxI6Y4HrL38sqQaSNxpF9OcFxOQWpuGodRekCbxXVbWclttIpvSOLaBhZ2ZBpcCBEeEMSmhqqYgajNwwwe9w88u0UsYKe6PBbaI48ENr02u2qBeLsIQ2HUyKlN3iVmX7u7MhgDWA3NNavMtlLmWd63NfuDgXpLI0O4cLhjAx8uoBIK8LntXPHPTxJ28o0yrszvD4gf7RdhuTq5VE15zne6iAJgIGfy7latGFzxuDMcML9OoXURHnNEHBgS9ZQCfNzYZ2O9flF1UjGpcBLEi7hHVHnrQb4y7c98dz9p62cvEMhorGx9kCwSIkOae5LheXPQkFIbsGyomNEwz3HZvR131VGAwdfmUUodvPr6LAAtmjl4sZ72PRqAo8EdQ0IFsWoypXVv51IooR87tO3uiG2DkxhIAwumOQdaJNxw1a0WS9mpQOmwFlvfbZkaIoUKgagHc8fVa1aHZntLGwH0S1iYixJiIrMnPYAeRdSp9mlHllrMX8xUIznobcZ5i8MpUYCKlUXMZ82S3XUJ5dJxARNRPxXlLJ5LPYBhUNkBLQen9Qmq3VZEV1RDJyhbGp6GAo14KsMtVAVYNmYPIgo85pCZgOwVEOBUycszu4AD3p4PT2ella4LVoqmTTMSA5GEWoeWb5JvEo222Z0oKr7UK8dGwpWRSbg8TNeODihJaTUDfErvbgaZnjIRpqfgtM5i1HfQbD7Yyft5PqyygUra7GYy7pjRrEvq95XQD8sAZ32ku9AqCo5qOB584iX881WErOoheQZokt1txqwuIMUyhVuMKNEXy70CeNTsb30ghQMZpZcXIkrLYyQCZ0gNmARhMKagCSdrpUtxudLk44yfmuwSQzBN3ifWfLZiFpU53qdPLZoTw5"; internal const string IdToken = "6GwdM7f6hHXfivavPozhaRqrbxvEysfXSMQyEKBwVgivPZTtmowsmYygchhIuxjeFFeq1ZPHjhxKFnulrvoY6TDerZY5xyOlg45bToI9Bu95qFvUrrt5r17UJcXdw4YkvEt10CcDDcLcEYw704RpVefvbpjbF24pOgIuafcAkDnbDA0Qea4ePuSC45Lw7zpJhbo9Gh8IfMX597fayBvMs3fh7frrm9KpWMCeKY3h99YSaCYjZFKp1ppvXXPE9bc4sh4pRDOfnv0Yr9J8u4elZevEE4qGddfgd3hYb18XPGRjPEMlWsh7tnwxwUm6OSZlMTHYuvwBENNMx7SUQmMeg4rCfgnbcNDkWpXCiSDVt1lLLv8F2GjYnM6De3v1Ks5lhBWx3grLggcN9LnXz92eJ1l5lTB2v0y9MgmFZ4gY43oIOW5n8G5HOx3bGOyjTw0TKKbyVa3mDj0A3QqW8eLTUJz42BNiGOf5m9prMSlpAW59CHCMJLatsj3IvGeCITsGAr3sUZEytORWUdxCfuIPwecQgU6bO7pNqNvZc1tJHHNwJlfS23ZkiFuEXqEThHYfxBCFxAzMDlzO0TOdWhvrb8hlNeAOcNhoAKxu7HXsePajKs4fU1rcdSxzNKwtASEla3p6jfJnnDtKf38RJZPaRRYMviqqWEMhjmqIvBm7sMaf8RyNNuYl7otZwmwNVCR1hzzmaTAy4kQce67FJqFba7uizrgwp9zsvK8muCHKKPvNthy7fHsxKmrBIm0bLcoePKK3wAID4kFvNQcxXp6rAOr8bLFF3bLEoYdzmF2QJz1frVZZHHPy90Cmlhw48EQN8NE2OllpdaykKt5k4rPcZQyitayNNhism30qh7eCBhcA7mm5Ja0S8X4VPlkwvgwg0mQuul6gakmja8xpnTrwiOdtao320GDmJaJA6zf3UTpNZTq9tdfBtUrjAD8RS0tNUBT3Ko8N2Lfh9ry8y9ESmRVIhch3rKY7UeefFAnkiwH2WwC57ZEsHtMP0SwKYtYKHZW9HkERCCyqOT1Mw0IavsLGFvchzMAvTnz4RwRBk6IrWgANvqT3F3Vexc2K0poKb71XZ4aMXxjqAzydGQAKpKJEJcqEvX9RD8nL76TF2LZIepiaZ3dbQImkqSjbF7aaY2JFoN9ZWlcSQKe8zdO8TIG16bF8W9R4ldDyzV39L33KcweG"; + + #region Test Certificate and Private Key (ExpiredRawCertificate, ValidRawCertificate & XmlPrivateKey) + /// + /// Test (expired and valid) PEM-encoded X.509 certificate and their matching RSA private key. + /// These are used together in unit tests that require both a certificate and its private key. + /// The / and are a matched pair: + /// - is an expired PEM-encoded certificate. The certificate is valid for 1 day and was created on September 8 2025, ensuring it will always be expired. + /// - is a valid PEM-encoded certificate. The certificate is valid for 100 years and expires on August 4, 2125, ensuring it will not expire during the lifetime of the tests. + /// - is their corresponding RSA private key in XML format. + /// + internal const string ExpiredRawCertificate = "MIIC/zCCAeegAwIBAgIUGSVU23Wc0+QtCbUTjsyPOrc0XpEwDQYJKoZIhvcNAQELBQAwDzENMAsGA1UEAwwEVGVzdDAeFw0yNTA5MDgyMjAxMTdaFw0yNTA5MDkyMjAxMTdaMA8xDTALBgNVBAMMBFRlc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC5XNEuk3cIEChkZd2P/bljUaVqNVh4mbXdWHYAgbdK48U6rG0FLq1NAfSnZO0EPbK8Zo4psRh2lBcqW29/WsKiHUEHLkLyFI+frEIfc8wskd+WxkKfL8G52uRpYQCG87FIv8uZBBlDG7kDdOV36CUkK1N+V2fHbkEgx+YfWg6+pLi3KQx6Pf/b2YqLD36hj8WRrVYzL6yXVUBiyRd+cQ9y5V/MRtoiX1Sv8WEFYtzIG0TUGi9pR7WWhgHNQk6DFDzutMV62ZEBNPIQvdO2EwXGr1FUIOL6zmj6bArPhY+hCXGrAAwCXodZhgZ95BxTwsQWtjCha2hT6ed8zmoE72FdAgMBAAGjUzBRMB0GA1UdDgQWBBQPYq0Efzuv1diVcgxBxTnVA4wLMjAfBgNVHSMEGDAWgBQPYq0Efzuv1diVcgxBxTnVA4wLMjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCXAD7cjWmmTqP0NX4MqwO0AHtO+KGVtfxF8aI21Ty/nHh2SAODzsemP3NBBvoEvllwtcVyutPqvUiAflMLNbp0ucTu+aWE14s1V9Bnt6++5g7gtXItsNV3F/ymYKsyfhDvJbWCOv5qYeJMQ+jtODHN9qnATODT5voULTwEVSYQXtutwRxR8e70Cvok+F+4I6Ni49DJ8DmcYzvB94uthqpDsygY1vYzpRbB5hpW0/D7kgVVWyWoOWiE1mV7Fry7tUWQw7EqnX89kMLMy4g6UfOv4gtam8RBa9dLyMW1rCHRxOulP47joI10g9JoJ9DssiQTUojJgQXOSBBXdD20H+zl"; + internal const string ValidRawCertificate = "MIIDATCCAemgAwIBAgIUSfjghyQB4FIS41rWfNcZHTLE/R4wDQYJKoZIhvcNAQELBQAwDzENMAsGA1UEAwwEVGVzdDAgFw0yNTA4MjgyMDIxMDBaGA8yMTI1MDgwNDIwMjEwMFowDzENMAsGA1UEAwwEVGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALlc0S6TdwgQKGRl3Y/9uWNRpWo1WHiZtd1YdgCBt0rjxTqsbQUurU0B9Kdk7QQ9srxmjimxGHaUFypbb39awqIdQQcuQvIUj5+sQh9zzCyR35bGQp8vwbna5GlhAIbzsUi/y5kEGUMbuQN05XfoJSQrU35XZ8duQSDH5h9aDr6kuLcpDHo9/9vZiosPfqGPxZGtVjMvrJdVQGLJF35xD3LlX8xG2iJfVK/xYQVi3MgbRNQaL2lHtZaGAc1CToMUPO60xXrZkQE08hC907YTBcavUVQg4vrOaPpsCs+Fj6EJcasADAJeh1mGBn3kHFPCxBa2MKFraFPp53zOagTvYV0CAwEAAaNTMFEwHQYDVR0OBBYEFA9irQR/O6/V2JVyDEHFOdUDjAsyMB8GA1UdIwQYMBaAFA9irQR/O6/V2JVyDEHFOdUDjAsyMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAAOxtgYjtkUDVvWzq/lkjLTdcLjPvmH0hF34A3uvX4zcjmqF845lfvszTuhc1mx5J6YLEzKfr4TrO3D3g2BnDLvhupok0wEmJ9yVwbt1laim7zP09gZqnUqYM9hYKDhwgLZAaG3zGNocxDEAU7jazMGOGF7TweB7LdNuVI6CqgDOBQ8Cy2ObuZvzCI5Y7f+HucXpiJOu1xNa2ZZpMpQycYEvi5TD+CL5CBv2fcKQRn/+u5B3ZXCD2C9jT/RZ7rH46mIG7nC7dS4J2o4JjmlJIUAe2U6tRay5GvEmc/nZK8hd9y4BICzrykp9ENAoy9i+uaE1GGWeNgO+irrcrAcLwto="; + internal const string XmlPrivateKey = @" + uVzRLpN3CBAoZGXdj/25Y1GlajVYeJm13Vh2AIG3SuPFOqxtBS6tTQH0p2TtBD2yvGaOKbEYdpQXKltvf1rCoh1BBy5C8hSPn6xCH3PMLJHflsZCny/BudrkaWEAhvOxSL/LmQQZQxu5A3Tld+glJCtTfldnx25BIMfmH1oOvqS4tykMej3/29mKiw9+oY/Fka1WMy+sl1VAYskXfnEPcuVfzEbaIl9Ur/FhBWLcyBtE1BovaUe1loYBzUJOgxQ87rTFetmRATTyEL3TthMFxq9RVCDi+s5o+mwKz4WPoQlxqwAMAl6HWYYGfeQcU8LEFrYwoWtoU+nnfM5qBO9hXQ== + AQAB +

3pGBJXfhILNTsbRLHmUy7YVvD75HpvMCey2aaN4gU9Jvi1s2vQFU15a8p75Yt8UYHZDr+Yqwl1Jd4J+UtWsGqGBGNB1Ae4V1dwR8zUDKxXXee7G/dCDnIu4xpkZbPD+brcULcpF/Tdq/WsTbpCNhPgjHuo8hQY3vFv1NMla8mr0=

+ 1TSgE9DfTeqk0qybQM1r83M5ZwWKV0mPQBZl1VMs+VplB6E/6JAYWCKiq9ewgocOaktK94jtEtsaDhYeyojZFBlukt1lKp4kmkUwUSEmi3EFsprNakg+Bm6t85tEm5he5mG1ivHlE3M5lBWJ2A0r1g3jWSjYJlkk2nOwFE8bmyE= + UIcU0xmsusgnYAR7qWO0KXw90tRl2GHUY/z8ATVdPPbGpQU7qObya45+c7LLJrKJJyloN8GWYynKDZuvknRG1GUBAZoT2p1PAuD8xsbKlucuuFJ3kuzUtC66iA6ss//Ps++3VJyQEvsygQT480pZxLgoi7d9sNpJx2eeprf7RYE= + zwIZqyPSrUR2ZFdTJshNWEM4KN8oQzgY7pDQrx/jOviZv57A/n1qJaj7aP4zU4juZiZU06MPDI/P7H1tyBi3LNzEj7SG1apWv7MOBre5RQqoDZJggCFEl9o+65iGNMzs16NnMVFMqmXmMfH3tN6VAXDanWca96D2N2S8QfvNQgE= + Uoxh1dskd3C0N7SQ1nJXW7FyjB+J54R5yAcd8Zk0ukunhtuzsziQH4ZoMhBuzwxRwOaw0Umj77EcdEevuvFHn6LAK/solK2lkRcuKY2QTgkbYyYOxZNa1pJJaAfgzSGsBiwiGtHXl2eFLb2jfYDa4V/SV2B6BPOVheSUQGZlyYM= + Lkq21wnu7S2T2NbzyVUVKm+mfurJqHzCxX+lIKVEkEhn5ipPo76vew7k+bUj2C5MZ+64zEK1GFANpP9mzghtmSzzI4bzIx/tanQLo2047VyU2UO0Oaskl3TKHGMkTY+ok8GKaDF02aSfxPQ5poNsWycS1/eeLFklnLkviF7mVcfCoStSHAb+8dQzxO22Mu+oN2rXHinoNDSmFzUTx8cJapQhgji+GADRKF77Sfa5tHk/hCzVUXGBHgBs1jJM9cin2BBij8PngOaAAlby4gr07/r8SZU2uuXoxEDhpxf6mRTET5Wr2hxAyhu3bpZeCc0LokckNkzJPGUG6JaXXdUcgQ== +
"; + #endregion } internal static class Adfs2019LabConstants diff --git a/tests/Microsoft.Identity.Test.E2e/KeyGuardAttestationTests.cs b/tests/Microsoft.Identity.Test.E2e/KeyGuardAttestationTests.cs new file mode 100644 index 0000000000..a50fcd376b --- /dev/null +++ b/tests/Microsoft.Identity.Test.E2e/KeyGuardAttestationTests.cs @@ -0,0 +1,228 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/* +WHY THESE TESTS ONLY RUN ON A SPECIFIC (AZURE ARC) MACHINE +--------------------------------------------------------- +KeyGuard attestation requires: + 1. A KeyGuard / Virtualization-Based Security (VBS) capable environment. + 2. The ability to create a CNG RSA key with: + - Virtual Isolation (NCRYPT_USE_VIRTUAL_ISOLATION_FLAG) + - Per-boot scope (NCRYPT_USE_PER_BOOT_KEY_FLAG) + 3. A native KeyGuard attestation stack (deployed via the MtlsPop package) capable of: + - Accessing the key handle + - Interacting with the VBS services to produce an attestation + +Most hosted build agents (including standard Azure DevOps Microsoft-hosted pools) do NOT expose: + - Virtualization-based key isolation + - The necessary kernel components for KeyGuard property retrieval + - The proper security context to create KeyGuard-protected keys + +We therefore run these tests ONLY on a dedicated Azure Arc–connected VM (custom self-hosted agent) that: + - Is provisioned with VBS + KeyGuard enabled + - Has the Microsoft Software Key Storage Provider configured to honor Virtual Isolation + per-boot flags + - Has an identity/endpoint (TOKEN_ATTESTATION_ENDPOINT) capable of accepting and validating a KeyGuard attestation + - Is allowed in the pipeline via filtering on the TestCategory MI_E2E_AzureArc (and infra chooses that agent) + +If any prerequisite is missing (e.g., VBS off, endpoint unset, native DLL absent, or key not actually KeyGuard-protected), +the test exits early with Assert.Inconclusive instead of failing the overall build. +*/ + +using Microsoft.Identity.Client.MtlsPop.Attestation; +using Microsoft.Identity.Test.Common.Core.Helpers; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Runtime.InteropServices; +using System.Security.Cryptography; +using Microsoft.Identity.Client.MtlsPop; +using System.Threading.Tasks; +using System.Threading; + +namespace Microsoft.Identity.Test.E2E +{ + [TestClass] + public class KeyGuardAttestationTests + { + /* + Creates a KeyGuard-capable RSA key (2048-bit) using the Microsoft Software Key Storage Provider. + Flags: + - NCRYPT_USE_VIRTUAL_ISOLATION_FLAG: Requests KeyGuard / Virtual Isolation (backed by VBS). + - NCRYPT_USE_PER_BOOT_KEY_FLAG: Key material only valid for the current boot (expected scenario for attestation). + On machines without KeyGuard/VBS support the provider may silently ignore the flags; we detect that later via IsKeyGuardProtected. + IMPORTANT: This must run on the Azure Arc custom agent where VBS + KeyGuard is enabled. + */ + private static CngKey CreateKeyGuardKey(string keyName) + { + const string ProviderName = "Microsoft Software Key Storage Provider"; + const int NCRYPT_USE_VIRTUAL_ISOLATION_FLAG = 0x00020000; + const int NCRYPT_USE_PER_BOOT_KEY_FLAG = 0x00040000; + + var p = new CngKeyCreationParameters + { + Provider = new CngProvider(ProviderName), + ExportPolicy = CngExportPolicies.None, // No export allowed; expected for attested keys. + KeyUsage = CngKeyUsages.AllUsages, // Broad usage; attestation library only needs signing. + KeyCreationOptions = + CngKeyCreationOptions.OverwriteExistingKey | + (CngKeyCreationOptions)NCRYPT_USE_VIRTUAL_ISOLATION_FLAG | + (CngKeyCreationOptions)NCRYPT_USE_PER_BOOT_KEY_FLAG, + }; + + // Set 2048-bit RSA length (current attestation native lib expects RSA; adjust only with platform guidance). + p.Parameters.Add(new CngProperty( + "Length", + BitConverter.GetBytes(2048), + CngPropertyOptions.None)); + + return CngKey.Create(CngAlgorithm.Rsa, keyName, p); + } + + /* + Determines whether the key actually received KeyGuard Virtual Isolation backing. + Some environments will accept the creation flags but produce a normal (non-KeyGuard) key; + those runs should be marked Inconclusive rather than Fail to avoid noisy pipeline failures. + This mirrors the logic used in other internal tracking (ref #5448). + */ + private static bool IsKeyGuardProtected(CngKey key) + { + try + { + // KeyGuard exposes a "Virtual Iso" property that is non-zero when protected. + // Same check used in #5448. :contentReference[oaicite:1]{index=1} + var prop = key.GetProperty("Virtual Iso", CngPropertyOptions.None); + var bytes = prop.GetValue(); + return bytes != null && bytes.Length >= 4 && BitConverter.ToInt32(bytes, 0) != 0; + } + catch + { + return false; + } + } + + /* + Synchronous attestation path. + Restricted to Azure Arc (MI_E2E_AzureArc) because: + - Needs a machine with KeyGuard + VBS + - Needs TOKEN_ATTESTATION_ENDPOINT env var (injected by pipeline/agent config) + - Uses AttestationClient which depends on a native DLL deployed only on that custom agent + Fails fast with Assert.Inconclusive when prerequisites are missing. + */ + [TestCategory("MI_E2E_AzureArc")] + [RunOnAzureDevOps] + [TestMethod] + public void Attest_KeyGuardKey_OnAzureArc_Succeeds() + { + // Endpoint is provisioned only on the Azure Arc agent (backed by MSI / identity service). + var endpoint = Environment.GetEnvironmentVariable("TOKEN_ATTESTATION_ENDPOINT"); + if (string.IsNullOrWhiteSpace(endpoint)) + { + Assert.Inconclusive($"Set {"TOKEN_ATTESTATION_ENDPOINT"} on the Azure Arc agent to run this test."); + } + + // Placeholder logical client ID used by the attestation endpoint (matches agent configuration). + var clientId = "MSI_CLIENT_ID"; + string keyName = "MsalE2E_Keyguard"; + + CngKey key = null; + try + { + key = CreateKeyGuardKey(keyName); + + if (!IsKeyGuardProtected(key)) + { + // Indicates environment does not truly support KeyGuard (e.g., VBS disabled) — do not treat as test failure. + Assert.Inconclusive("Key was created but not KeyGuard-protected. Is KeyGuard/VBS enabled on this machine?"); + } + + // Use the new public AttestationClient from the MtlsPop package. :contentReference[oaicite:2]{index=2} + using var client = new AttestationClient(); + var result = client.Attest(endpoint, key.Handle, clientId); + + // Validate success + JWT shape (3 parts). + Assert.AreEqual(AttestationStatus.Success, result.Status, + $"Attestation failed: status={result.Status}, nativeRc={result.NativeErrorCode}, msg={result.ErrorMessage}"); + Assert.IsFalse(string.IsNullOrEmpty(result.Jwt), "Expected a non-empty attestation JWT."); + + var parts = result.Jwt.Split('.'); + Assert.AreEqual(3, parts.Length, "Expected a JWT (3 parts)."); + } + catch (CryptographicException ex) + { + // Common when provider flags unsupported or isolation services absent. + Assert.Inconclusive("CNG/KeyGuard is not available or access is denied on this machine: " + ex.Message); + } + catch (InvalidOperationException ex) + { + // Thrown by AttestationClient when the native DLL cannot be found/initialized (not deployed outside Azure Arc agent). + Assert.Inconclusive("Attestation native lib not available on this runner: " + ex.Message); + } + finally + { + try { key?.Delete(); } catch { /* best-effort cleanup */ } + } + } + + /* + Async attestation path. + Demonstrates PopKeyAttestor.AttestKeyGuardAsync which wraps the native synchronous call. + Same environmental constraints as the synchronous test; still limited to the Azure Arc agent. + */ + [TestCategory("MI_E2E_AzureArc")] + [RunOnAzureDevOps] + [TestMethod] + public async Task Attest_KeyGuardKey_OnAzureArc_Async_Succeeds() + { + var endpoint = Environment.GetEnvironmentVariable("TOKEN_ATTESTATION_ENDPOINT"); + if (string.IsNullOrWhiteSpace(endpoint)) + { + Assert.Inconclusive($"Set {"TOKEN_ATTESTATION_ENDPOINT"} on the Azure Arc agent to run this test."); + } + + var clientId = "MSI_CLIENT_ID"; + string keyName = "MsalE2E_Keyguard_Async"; + + CngKey key = null; + try + { + key = CreateKeyGuardKey(keyName); + + if (!IsKeyGuardProtected(key)) + { + Assert.Inconclusive("Key was created but not KeyGuard-protected. Is KeyGuard/VBS enabled on this machine?"); + } + + // Exercise the async facade (PopKeyAttestor) which wraps the synchronous native call in Task.Run. + var result = await PopKeyAttestor.AttestKeyGuardAsync( + endpoint, + key.Handle, + clientId: clientId, + cancellationToken: CancellationToken.None).ConfigureAwait(false); + + Assert.AreEqual(AttestationStatus.Success, result.Status, + $"Async attestation failed: status={result.Status}, nativeRc={result.NativeErrorCode}, msg={result.ErrorMessage}"); + Assert.IsFalse(string.IsNullOrEmpty(result.Jwt), "Expected a non-empty attestation JWT from async path."); + + var parts = result.Jwt.Split('.'); + Assert.AreEqual(3, parts.Length, "Expected a JWT (3 parts) from async path."); + } + catch (CryptographicException ex) + { + Assert.Inconclusive("CNG/KeyGuard is not available or access is denied on this machine: " + ex.Message); + } + catch (InvalidOperationException ex) + { + // Could originate from native initialization inside PopKeyAttestor (AttestationClient constructor). + Assert.Inconclusive("Attestation native lib not available on this runner (async path): " + ex.Message); + } + catch (ArgumentException ex) + { + // Defensive: invalid handle or parameters — treat as environment/setup issue for this scenario. + Assert.Inconclusive("Handle or parameters invalid for async attestation path: " + ex.Message); + } + finally + { + try { key?.Delete(); } catch { /* best-effort cleanup */ } + } + } + } +} diff --git a/tests/Microsoft.Identity.Test.E2e/ManagedIdentityKeyAcquisitionTests.cs b/tests/Microsoft.Identity.Test.E2e/ManagedIdentityKeyAcquisitionTests.cs new file mode 100644 index 0000000000..22a3b4b047 --- /dev/null +++ b/tests/Microsoft.Identity.Test.E2e/ManagedIdentityKeyAcquisitionTests.cs @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Security.Cryptography; +using Microsoft.Identity.Client.ManagedIdentity.KeyProviders; +using Microsoft.Identity.Test.Common.Core.Helpers; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.Identity.Test.E2E +{ + [TestClass] + public class ManagedIdentityKeyAcquisitionTests + { + private const string SoftwareKspName = "Microsoft Software Key Storage Provider"; + + // Runs on the AzureArc agent: must obtain a VBS/KeyGuard key. + [TestMethod] + [TestCategory("MI_E2E_KeyAcquisition_KeyGuard")] + [RunOnAzureDevOps] + public void KeyAcquisition_Fetches_KeyGuard_Key() + { + if (!OperatingSystem.IsWindows()) + { + Assert.Inconclusive("This test runs on Windows agents only."); + } + + bool ok = WindowsCngKeyOperations.TryGetOrCreateKeyGuard(logger: null, out RSA rsa); + Assert.IsTrue(ok, "Expected KeyGuard key on AzureArc agent."); + + using (rsa) + { + var rsacng = rsa as RSACng; + Assert.IsNotNull(rsacng, "Expected RSACng for KeyGuard."); + Assert.IsTrue( + WindowsCngKeyOperations.IsKeyGuardProtected(rsacng.Key), + "Expected KeyGuard (VBS) protected key on AzureArc agent."); + } + } + + // Runs on the IMDS agent: must obtain a TPM/PCP hardware key (user scope). + [TestMethod] + [TestCategory("MI_E2E_KeyAcquisition_Hardware")] + [RunOnAzureDevOps] + public void KeyAcquisition_Fetches_Hardware_Key() + { + if (!OperatingSystem.IsWindows()) + { + Assert.Inconclusive("This test runs on Windows agents only."); + } + + bool ok = WindowsCngKeyOperations.TryGetOrCreateHardwareRsa(logger: null, out RSA rsa); + Assert.IsTrue(ok, "Expected TPM hardware key on IMDS agent."); + + using (rsa) + { + var rsacng = rsa as RSACng; + Assert.IsNotNull(rsacng, "Expected RSACng for hardware key."); + + Assert.AreEqual( + SoftwareKspName, + rsacng.Key.Provider.Provider, + "Expected TPM-backed key via Microsoft Software Key Storage Provider."); + + // TPM keys created with ExportPolicy=None should not allow private export. + bool privateExportable = true; + try + { _ = rsacng.ExportParameters(true); } + catch (CryptographicException) { privateExportable = false; } + catch (NotSupportedException) { privateExportable = false; } + + Assert.IsFalse(privateExportable, "Hardware (TPM) key should be non-exportable."); + } + } + } +} diff --git a/tests/Microsoft.Identity.Test.E2e/Microsoft.Identity.Test.E2E.MSI.csproj b/tests/Microsoft.Identity.Test.E2e/Microsoft.Identity.Test.E2E.MSI.csproj index 6b85de84ce..ae7c12399c 100644 --- a/tests/Microsoft.Identity.Test.E2e/Microsoft.Identity.Test.E2E.MSI.csproj +++ b/tests/Microsoft.Identity.Test.E2e/Microsoft.Identity.Test.E2E.MSI.csproj @@ -8,6 +8,7 @@ + diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ClientCredentialsMtlsPopTests.cs b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ClientCredentialsMtlsPopTests.cs index c5bd5f0ab6..a67478cdec 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ClientCredentialsMtlsPopTests.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ClientCredentialsMtlsPopTests.cs @@ -15,14 +15,14 @@ namespace Microsoft.Identity.Test.Integration.HeadlessTests { // Tests in this class will run on .NET Core [TestClass] - public class ClientCredentialsMtlsPopTests + public class ClientCredentialsMtlsPopTests { private const string MsiAllowListedAppIdforSNI = "163ffef9-a313-45b4-ab2f-c7e2f5e0e23e"; [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); } [DoNotRunOnLinux] // POP is not supported on Linux diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ClientCredentialsTests.NetFwk.cs b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ClientCredentialsTests.NetFwk.cs index 750707399d..056840313d 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ClientCredentialsTests.NetFwk.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ClientCredentialsTests.NetFwk.cs @@ -48,7 +48,7 @@ private enum CredentialType [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); } // regression test based on SAL introducing a new SKU value and making ESTS not issue the refresh_in value diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ClientCredentialsTests.WithRegion.cs b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ClientCredentialsTests.WithRegion.cs index 069f0978e7..2caa7b60a2 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ClientCredentialsTests.WithRegion.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ClientCredentialsTests.WithRegion.cs @@ -38,7 +38,7 @@ public class RegionalAuthIntegrationTests [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); if (_keyVault == null) { diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/LongRunningOnBehalfOfTests.cs b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/LongRunningOnBehalfOfTests.cs index 1aba270fc7..4dd18d49ad 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/LongRunningOnBehalfOfTests.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/LongRunningOnBehalfOfTests.cs @@ -33,7 +33,7 @@ public class LongRunningOnBehalfOfTests [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); if (string.IsNullOrEmpty(_confidentialClientSecret)) { _confidentialClientSecret = _keyVault.GetSecretByName(TestConstants.MsalOBOKeyVaultSecretName).Value; diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ManagedIdentityTests.NetFwk.cs b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ManagedIdentityTests.NetFwk.cs index cfb1c04af3..83f3451209 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ManagedIdentityTests.NetFwk.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/ManagedIdentityTests.NetFwk.cs @@ -491,7 +491,7 @@ private IManagedIdentityApplication CreateMIAWithProxy(string url, string userAs break; } - // Disabling shared cache options to avoid cross test pollution. + builder.Config.AccessorOptions = null; IManagedIdentityApplication mia = builder.WithClientCapabilities(new[] { "cp1" }) diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/OnBehalfOfTests.cs b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/OnBehalfOfTests.cs index d99e9a82c1..c88b338af2 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/OnBehalfOfTests.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/OnBehalfOfTests.cs @@ -38,7 +38,7 @@ public class OnBehalfOfTests [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); if (string.IsNullOrEmpty(_confidentialClientSecret)) { _confidentialClientSecret = _keyVault.GetSecretByName(TestConstants.MsalOBOKeyVaultSecretName).Value; diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/PoPTests.NetFwk.cs b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/PoPTests.NetFwk.cs index e306e13ec4..a0c49289fb 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/PoPTests.NetFwk.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/PoPTests.NetFwk.cs @@ -53,7 +53,7 @@ public class PoPTests [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); } [RunOn(TargetFrameworks.NetCore)] diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/UsernamePasswordIntegrationTests.NetFwk.cs b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/UsernamePasswordIntegrationTests.NetFwk.cs index ed313d39dd..67e56b3d8b 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/UsernamePasswordIntegrationTests.NetFwk.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/HeadlessTests/UsernamePasswordIntegrationTests.NetFwk.cs @@ -42,7 +42,7 @@ public class UsernamePasswordIntegrationTests [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); } #region Happy Path Tests @@ -406,7 +406,6 @@ private async Task GetAuthenticationResultWithAssertAsync( .ConfigureAwait(false); } - Assert.IsNotNull(authResult); Assert.AreEqual(TokenSource.IdentityProvider, authResult.AuthenticationResultMetadata.TokenSource); Assert.IsNotNull(authResult.AccessToken); diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/ConfidentialClientAuthorizationTests.cs b/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/ConfidentialClientAuthorizationTests.cs index 35a9bba068..02fe8d8534 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/ConfidentialClientAuthorizationTests.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/ConfidentialClientAuthorizationTests.cs @@ -53,7 +53,7 @@ public static void ClassInitialize(TestContext context) [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); } #endregion diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/FociTests.cs b/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/FociTests.cs index 0b9710e9ca..ebd72d706f 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/FociTests.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/FociTests.cs @@ -30,7 +30,7 @@ public class FociTests [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); } #endregion diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/InteractiveFlowTests.NetFwk.cs b/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/InteractiveFlowTests.NetFwk.cs index 8430e572b0..e424bc42c0 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/InteractiveFlowTests.NetFwk.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/InteractiveFlowTests.NetFwk.cs @@ -34,7 +34,7 @@ public class InteractiveFlowTests [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); } #endregion MSTest Hooks diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/SeleniumInfrastructureTests.NetFwk.cs b/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/SeleniumInfrastructureTests.NetFwk.cs index 1e7110db1f..48bbefc6de 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/SeleniumInfrastructureTests.NetFwk.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/SeleniumTests/SeleniumInfrastructureTests.NetFwk.cs @@ -30,7 +30,7 @@ public class InfrastructureTests [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); } #endregion diff --git a/tests/Microsoft.Identity.Test.Unit/ApiConfigTests/AcquireTokenInteractiveBuilderTests.cs b/tests/Microsoft.Identity.Test.Unit/ApiConfigTests/AcquireTokenInteractiveBuilderTests.cs index d2ccbe677c..633ac1afbe 100644 --- a/tests/Microsoft.Identity.Test.Unit/ApiConfigTests/AcquireTokenInteractiveBuilderTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ApiConfigTests/AcquireTokenInteractiveBuilderTests.cs @@ -26,7 +26,7 @@ public class AcquireTokenInteractiveBuilderTests [TestInitialize] public async Task TestInitializeAsync() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); _harness = new AcquireTokenInteractiveBuilderHarness(); await _harness.SetupAsync() .ConfigureAwait(false); diff --git a/tests/Microsoft.Identity.Test.Unit/ApiConfigTests/AuthorityTests.cs b/tests/Microsoft.Identity.Test.Unit/ApiConfigTests/AuthorityTests.cs index d368bb8bb0..17821aae1b 100644 --- a/tests/Microsoft.Identity.Test.Unit/ApiConfigTests/AuthorityTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ApiConfigTests/AuthorityTests.cs @@ -35,7 +35,7 @@ public class AuthorityTests : TestBase [TestInitialize] public override void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); base.TestInitialize(); _harness = base.CreateTestHarness(); _testRequestContext = new RequestContext( diff --git a/tests/Microsoft.Identity.Test.Unit/AppConfigTests/ConfidentialClientApplicationBuilderTests.cs b/tests/Microsoft.Identity.Test.Unit/AppConfigTests/ConfidentialClientApplicationBuilderTests.cs index 2c38a3d194..1bab2b0ab3 100644 --- a/tests/Microsoft.Identity.Test.Unit/AppConfigTests/ConfidentialClientApplicationBuilderTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/AppConfigTests/ConfidentialClientApplicationBuilderTests.cs @@ -25,7 +25,7 @@ public class ConfidentialClientApplicationBuilderTests [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); } [TestMethod] diff --git a/tests/Microsoft.Identity.Test.Unit/AppConfigTests/ManagedIdentityApplicationBuilderTests.cs b/tests/Microsoft.Identity.Test.Unit/AppConfigTests/ManagedIdentityApplicationBuilderTests.cs index 01e26fc413..6897f314b0 100644 --- a/tests/Microsoft.Identity.Test.Unit/AppConfigTests/ManagedIdentityApplicationBuilderTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/AppConfigTests/ManagedIdentityApplicationBuilderTests.cs @@ -24,7 +24,7 @@ public class ManagedIdentityApplicationBuilderTests [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); } [TestMethod] diff --git a/tests/Microsoft.Identity.Test.Unit/CacheTests/CacheFallbackOperationsTests.cs b/tests/Microsoft.Identity.Test.Unit/CacheTests/CacheFallbackOperationsTests.cs index c04a02fc9e..b5981956b4 100644 --- a/tests/Microsoft.Identity.Test.Unit/CacheTests/CacheFallbackOperationsTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CacheTests/CacheFallbackOperationsTests.cs @@ -17,7 +17,7 @@ namespace Microsoft.Identity.Test.Unit.CacheTests { [TestClass] - public class CacheFallbackOperationsTests + public class CacheFallbackOperationsTests { private InMemoryLegacyCachePersistence _legacyCachePersistence; private ILoggerAdapter _logger; @@ -25,7 +25,7 @@ public class CacheFallbackOperationsTests [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); // Methods in CacheFallbackOperations silently catch all exceptions and log them; // By setting this to null, logging will fail, making the test fail. diff --git a/tests/Microsoft.Identity.Test.Unit/CacheTests/CacheSerializationTests.cs b/tests/Microsoft.Identity.Test.Unit/CacheTests/CacheSerializationTests.cs index edc494bc9d..5505ee114e 100644 --- a/tests/Microsoft.Identity.Test.Unit/CacheTests/CacheSerializationTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CacheTests/CacheSerializationTests.cs @@ -30,14 +30,8 @@ namespace Microsoft.Identity.Test.Unit.CacheTests { [TestClass] - public class CacheSerializationTests + public class CacheSerializationTests : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - private static readonly IEnumerable s_appMetadataKeys = new[] { StorageJsonKeys.ClientId , diff --git a/tests/Microsoft.Identity.Test.Unit/CacheTests/LoadingProjectsTests.cs b/tests/Microsoft.Identity.Test.Unit/CacheTests/LoadingProjectsTests.cs index c22bda6088..5f0384e943 100644 --- a/tests/Microsoft.Identity.Test.Unit/CacheTests/LoadingProjectsTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CacheTests/LoadingProjectsTests.cs @@ -9,14 +9,8 @@ namespace Microsoft.Identity.Test.Unit.CacheTests { [TestClass] - public class LoadingProjectsTests + public class LoadingProjectsTests : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - [TestMethod] public void CanDeserializeTokenCache() { diff --git a/tests/Microsoft.Identity.Test.Unit/CacheTests/MsalTokenCacheKeysTests.cs b/tests/Microsoft.Identity.Test.Unit/CacheTests/MsalTokenCacheKeysTests.cs index 2828fe01bd..4c83fcb6c6 100644 --- a/tests/Microsoft.Identity.Test.Unit/CacheTests/MsalTokenCacheKeysTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CacheTests/MsalTokenCacheKeysTests.cs @@ -12,14 +12,8 @@ namespace Microsoft.Identity.Test.Unit.CacheTests { [TestClass] - public class MsalTokenCacheKeysTests + public class MsalTokenCacheKeysTests : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - [TestMethod] public void MsalAccessTokenCacheKey() { diff --git a/tests/Microsoft.Identity.Test.Unit/CacheTests/UnifiedCacheFormatTests.cs b/tests/Microsoft.Identity.Test.Unit/CacheTests/UnifiedCacheFormatTests.cs index 7f1f46a7bf..4d3efbea57 100644 --- a/tests/Microsoft.Identity.Test.Unit/CacheTests/UnifiedCacheFormatTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CacheTests/UnifiedCacheFormatTests.cs @@ -150,7 +150,7 @@ public void B2C_NoTenantId_CacheFormatValidationTest() { using (var harness = CreateTestHarness()) { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); IntitTestData(ResourceHelper.GetTestResourceRelativePath("B2CNoTenantIdTestData.txt")); RunCacheFormatValidation(harness); } diff --git a/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/HttpClientFactoryTests.cs b/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/HttpClientFactoryTests.cs index 8e2f79ddf7..23c268f7c4 100644 --- a/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/HttpClientFactoryTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/HttpClientFactoryTests.cs @@ -13,14 +13,8 @@ namespace Microsoft.Identity.Test.Unit.CoreTests.HttpTests { [TestClass] - public class HttpClientFactoryTests + public class HttpClientFactoryTests : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - // You might need to add a method to clear the HttpClient cache in SimpleHttpClientFactory - } [TestMethod] public void TestGetHttpClientWithCustomCallback() diff --git a/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/HttpManagerTests.cs b/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/HttpManagerTests.cs index 561e9dcb23..0c72b0d5f3 100644 --- a/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/HttpManagerTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/HttpManagerTests.cs @@ -24,16 +24,10 @@ namespace Microsoft.Identity.Test.Unit.CoreTests.HttpTests { [TestClass] - public class HttpManagerTests + public class HttpManagerTests : TestBase { private readonly TestDefaultRetryPolicy _stsRetryPolicy = new TestDefaultRetryPolicy(RequestType.STS); - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - [TestMethod] public async Task MtlsCertAsync() { diff --git a/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/RedirectUriHelperTests.cs b/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/RedirectUriHelperTests.cs index ba18de0449..49a32e395c 100644 --- a/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/RedirectUriHelperTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/RedirectUriHelperTests.cs @@ -13,14 +13,8 @@ namespace Microsoft.Identity.Test.Unit.CoreTests.HttpTests { [TestClass] - public class RedirectUriHelperTests + public class RedirectUriHelperTests : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - [TestMethod] public void ValidateRedirectUri_Throws() { diff --git a/tests/Microsoft.Identity.Test.Unit/CoreTests/InstanceTests/InstanceDiscoveryManagerTests.cs b/tests/Microsoft.Identity.Test.Unit/CoreTests/InstanceTests/InstanceDiscoveryManagerTests.cs index 89a910abd1..68ff2e638f 100644 --- a/tests/Microsoft.Identity.Test.Unit/CoreTests/InstanceTests/InstanceDiscoveryManagerTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CoreTests/InstanceTests/InstanceDiscoveryManagerTests.cs @@ -62,7 +62,6 @@ private void InitializeTestObjects(bool isInstanceDiscoveryEnabled = true) _testRequestContext = new RequestContext(_harness.ServiceBundle, Guid.NewGuid(), null); _discoveryManager = new InstanceDiscoveryManager( _harness.HttpManager, - false, null, null, _knownMetadataProvider, @@ -103,7 +102,6 @@ public async Task NetworkCacheProvider_IsUsedFirst_Async() _discoveryManager = new InstanceDiscoveryManager( _harness.HttpManager, - false, null, null, _knownMetadataProvider, @@ -143,7 +141,6 @@ public async Task InstanceDiscoveryDisabled_Async() _discoveryManager = new InstanceDiscoveryManager( _harness.HttpManager, - false, null, null, _knownMetadataProvider, @@ -285,7 +282,6 @@ public async Task NetworkProviderIsCalledLastAsync() // Arrange _discoveryManager = new InstanceDiscoveryManager( _harness.HttpManager, - false, null, null, _knownMetadataProvider, @@ -323,7 +319,6 @@ public async Task UserProvider_TakesPrecedence_OverNetworkProvider_Async() // Arrange _discoveryManager = new InstanceDiscoveryManager( _harness.HttpManager, - false, _userMetadataProvider, null, _knownMetadataProvider, @@ -363,7 +358,6 @@ public async Task CustomDiscoveryEndpoint_Async() _discoveryManager = new InstanceDiscoveryManager( _harness.HttpManager, - false, null, customDiscoveryEndpoint, _knownMetadataProvider, diff --git a/tests/Microsoft.Identity.Test.Unit/CoreTests/InstanceTests/InstanceProviderTests.cs b/tests/Microsoft.Identity.Test.Unit/CoreTests/InstanceTests/InstanceProviderTests.cs index 1922f31505..f00d9eb1dd 100644 --- a/tests/Microsoft.Identity.Test.Unit/CoreTests/InstanceTests/InstanceProviderTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CoreTests/InstanceTests/InstanceProviderTests.cs @@ -28,7 +28,7 @@ public void StaticProviderPreservesStateAcrossInstances() // Act InstanceDiscoveryMetadataEntry result = staticMetadataProvider2.GetMetadata("env", _logger); - staticMetadataProvider2.Clear(); + NetworkCacheMetadataProvider.ResetStaticCacheForTest(); InstanceDiscoveryMetadataEntry result2 = staticMetadataProvider2.GetMetadata("env", _logger); // Assert diff --git a/tests/Microsoft.Identity.Test.Unit/CoreTests/RegionDiscoveryProviderTests.cs b/tests/Microsoft.Identity.Test.Unit/CoreTests/RegionDiscoveryProviderTests.cs index c549604ae4..f96d24c4be 100644 --- a/tests/Microsoft.Identity.Test.Unit/CoreTests/RegionDiscoveryProviderTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CoreTests/RegionDiscoveryProviderTests.cs @@ -50,7 +50,7 @@ public override void TestInitialize() _apiEvent = new ApiEvent(Guid.NewGuid()); _apiEvent.ApiId = ApiEvent.ApiIds.AcquireTokenForClient; _testRequestContext.ApiEvent = _apiEvent; - _regionDiscoveryProvider = new RegionAndMtlsDiscoveryProvider(_httpManager, true); + _regionDiscoveryProvider = new RegionAndMtlsDiscoveryProvider(_httpManager); } [TestCleanup] @@ -58,7 +58,7 @@ public override void TestCleanup() { Environment.SetEnvironmentVariable(TestConstants.RegionName, ""); _harness?.Dispose(); - _regionDiscoveryProvider = new RegionAndMtlsDiscoveryProvider(_httpManager, true); + _regionDiscoveryProvider = new RegionAndMtlsDiscoveryProvider(_httpManager); _httpManager.Dispose(); base.TestCleanup(); } @@ -164,8 +164,8 @@ public async Task SuccessfulResponseFromUserProvidedRegionAsync( } _testRequestContext.ServiceBundle.Config.AzureRegion = TestConstants.Region; - - IRegionDiscoveryProvider regionDiscoveryProvider = new RegionAndMtlsDiscoveryProvider(_httpManager, true); + RegionManager.ResetStaticCacheForTest(); + IRegionDiscoveryProvider regionDiscoveryProvider = new RegionAndMtlsDiscoveryProvider(_httpManager); InstanceDiscoveryMetadataEntry regionalMetadata = await regionDiscoveryProvider.GetMetadataAsync( new Uri("https://login.microsoftonline.com/common/"), _testRequestContext).ConfigureAwait(false); diff --git a/tests/Microsoft.Identity.Test.Unit/CoreTests/WsTrustTests/WsTrustEndpointTests.cs b/tests/Microsoft.Identity.Test.Unit/CoreTests/WsTrustTests/WsTrustEndpointTests.cs index f7fd65aa6d..9e428de69d 100644 --- a/tests/Microsoft.Identity.Test.Unit/CoreTests/WsTrustTests/WsTrustEndpointTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CoreTests/WsTrustTests/WsTrustEndpointTests.cs @@ -11,14 +11,8 @@ namespace Microsoft.Identity.Test.Unit.CoreTests.WsTrustTests { [TestClass] - public class WsTrustEndpointTests + public class WsTrustEndpointTests : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - private readonly Uri _uri = new Uri("https://windowsorusernamepasswordendpointurl"); private readonly string _cloudAudienceUri = "https://cloudAudienceUrn"; diff --git a/tests/Microsoft.Identity.Test.Unit/CryptographyTests.cs b/tests/Microsoft.Identity.Test.Unit/CryptographyTests.cs index ac919fbc89..fa69507108 100644 --- a/tests/Microsoft.Identity.Test.Unit/CryptographyTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CryptographyTests.cs @@ -14,14 +14,8 @@ namespace Microsoft.Identity.Test.Unit { [TestClass] [DeploymentItem(@"Resources\testCert.crtfile")] - public class CryptographyTests + public class CryptographyTests : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - [TestMethod] [TestCategory("CryptographyTests")] [System.Diagnostics.CodeAnalysis.SuppressMessage("Internal.Analyzers", "IA5352:DoNotMisuseCryptographicApi", Justification = "Suppressing RoslynAnalyzers: Rule: IA5352 - Do Not Misuse Cryptographic APIs in test only code")] diff --git a/tests/Microsoft.Identity.Test.Unit/DeviceCodeResponseTests.cs b/tests/Microsoft.Identity.Test.Unit/DeviceCodeResponseTests.cs index 9bfa6ecdb9..c4f0c511d5 100644 --- a/tests/Microsoft.Identity.Test.Unit/DeviceCodeResponseTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/DeviceCodeResponseTests.cs @@ -8,14 +8,8 @@ namespace Microsoft.Identity.Test.Unit { [TestClass] - public class DeviceCodeResponseTests + public class DeviceCodeResponseTests : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - private const string VerificationUrl = "http://verification.url"; private const string VerificationUri = "http://verification.uri"; diff --git a/tests/Microsoft.Identity.Test.Unit/Helpers/TestCsrFactory.cs b/tests/Microsoft.Identity.Test.Unit/Helpers/TestCsrFactory.cs new file mode 100644 index 0000000000..f407550b07 --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/Helpers/TestCsrFactory.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Security.Cryptography; +using Microsoft.Identity.Client.ManagedIdentity.V2; + +namespace Microsoft.Identity.Test.Unit.Helpers +{ + internal class TestCsrFactory : ICsrFactory + { + public (string csrPem, RSA privateKey) Generate(RSA rsa, string clientId, string tenantId, CuidInfo cuId) + { + // we don't care about the RSA that's passed in, we will always return the same mock private key + return ("mock-csr", CreateMockRsa()); + } + + /// + /// Creates a mock private key for testing purposes by loading key parameters from an XML string. + /// The XML format is used because it allows all necessary RSA parameters to be embedded directly in the code, + /// enabling deterministic and repeatable test runs. This method returns an object rather than a string, + /// as cryptographic operations in tests require a usable key instance, not just its serialized representation. + /// + public static RSA CreateMockRsa() + { + RSA rsa = null; + +#if NET462 || NET472 + // .NET Framework runs only on Windows, so RSACng (Windows-only) is always available + rsa = new RSACng(); +#else + // Cross-platform .NET - RSA.Create() returns appropriate PSS-capable implementation + rsa = RSA.Create(); +#endif + rsa.FromXmlString(TestConstants.XmlPrivateKey); + return rsa; + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicies.cs b/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicies.cs index a327e882c2..07a2ba9a11 100644 --- a/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicies.cs +++ b/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicies.cs @@ -39,4 +39,15 @@ internal override Task DelayAsync(int milliseconds) return Task.CompletedTask; } } + + internal class TestCsrMetadataProbeRetryPolicy : CsrMetadataProbeRetryPolicy + { + public TestCsrMetadataProbeRetryPolicy() : base() { } + + internal override Task DelayAsync(int milliseconds) + { + // No delay for tests + return Task.CompletedTask; + } + } } diff --git a/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicyFactory.cs b/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicyFactory.cs index f08e426961..2ed0c98f0d 100644 --- a/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicyFactory.cs +++ b/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicyFactory.cs @@ -20,6 +20,8 @@ public virtual IRetryPolicy GetRetryPolicy(RequestType requestType) return new TestImdsRetryPolicy(); case RequestType.RegionDiscovery: return new TestRegionDiscoveryRetryPolicy(); + case RequestType.CsrMetadataProbe: + return new TestCsrMetadataProbeRetryPolicy(); default: throw new ArgumentOutOfRangeException(nameof(requestType), requestType, "Unknown request type."); } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AppServiceTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AppServiceTests.cs index 8ed55e6b4d..c78924c248 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AppServiceTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AppServiceTests.cs @@ -1,14 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System; using System.Globalization; -using System.Net; using System.Threading.Tasks; using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; using Microsoft.Identity.Client.ManagedIdentity; -using Microsoft.Identity.Test.Common; using Microsoft.Identity.Test.Common.Core.Helpers; using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -34,8 +31,8 @@ public async Task AppServiceInvalidEndpointAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); @@ -55,16 +52,25 @@ await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) [DataRow("http://127.0.0.1:41564/msi/token/", ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)] [DataRow(AppServiceEndpoint, ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)] [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, ManagedIdentitySource.MachineLearning)] - public void TestAppServiceUpgradeScenario( + public async Task TestAppServiceUpgradeScenario( string endpoint, ManagedIdentitySource managedIdentitySource, ManagedIdentitySource expectedManagedIdentitySource) { using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) { SetUpgradeScenarioEnvironmentVariables(managedIdentitySource, endpoint); - Assert.AreEqual(expectedManagedIdentitySource, ManagedIdentityApplication.GetManagedIdentitySource()); + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager); + + + + + ManagedIdentityApplication mi = miBuilder.Build() as ManagedIdentityApplication; + + Assert.AreEqual(expectedManagedIdentitySource, await mi.GetManagedIdentitySourceAsync().ConfigureAwait(false)); } } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AzureArcTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AzureArcTests.cs index fdc1bde4f9..2a172aff72 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AzureArcTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AzureArcTests.cs @@ -61,8 +61,8 @@ public async Task AzureArcAuthHeaderMissingAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); @@ -93,8 +93,8 @@ public async Task AzureArcAuthHeaderInvalidAsync(string filename, string errorMe var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); @@ -122,8 +122,8 @@ public async Task AzureArcInvalidEndpointAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CloudShellTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CloudShellTests.cs index 8463839fb4..90dd8da152 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CloudShellTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CloudShellTests.cs @@ -59,8 +59,8 @@ public async Task CloudShellInvalidEndpointAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs new file mode 100644 index 0000000000..5af70fc059 --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Formats.Asn1; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.ManagedIdentity.V2; +using Microsoft.Identity.Client.Utils; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests +{ + /// + /// Helper class for parsing and validating Certificate Signing Request (CSR) content and structure. + /// + internal static class CsrValidator + { + /// + /// Parses a raw CSR and returns the DER bytes. + /// + public static byte[] ParseRawCsr(string rawCsr) + { + if (string.IsNullOrWhiteSpace(rawCsr)) + throw new MsalServiceException(MsalError.InvalidCertificate, MsalErrorMessage.InvalidCertificate); + + try + { + return Convert.FromBase64String(rawCsr); + } + catch (Exception ex) + { + throw new MsalServiceException(MsalError.InvalidCertificate, MsalErrorMessage.InvalidCertificate, ex); + } + } + + /// + /// Validates the content of a CSR string against expected values. + /// + public static void ValidateCsrContent(string rawCsr, string expectedClientId, string expectedTenantId, CuidInfo expectedCuid) + { + byte[] csrBytes = ParseRawCsr(rawCsr); + + // Parse the CSR using AsnReader + var reader = new AsnReader(csrBytes, AsnEncodingRules.DER); + var csrSequence = reader.ReadSequence(); + + // certificationRequestInfo + var certReqInfoBytes = csrSequence.PeekEncodedValue().ToArray(); + var certReqInfoReader = new AsnReader(csrSequence.ReadEncodedValue().ToArray(), AsnEncodingRules.DER); + var certReqInfoSeq = certReqInfoReader.ReadSequence(); + + // version + int version = (int)certReqInfoSeq.ReadInteger(); + Assert.AreEqual(0, version, "CSR version should be 0"); + + // subject + var subjectBytes = certReqInfoSeq.PeekEncodedValue().ToArray(); + var subject = new X500DistinguishedName(certReqInfoSeq.ReadEncodedValue().ToArray()); + string subjectString = subject.Name; + + Assert.IsTrue(subjectString.Contains($"CN={expectedClientId}"), "Client ID (CN) not found in subject"); + Assert.IsTrue(subjectString.Contains($"DC={expectedTenantId}"), "Tenant ID (DC) not found in subject"); + + // subjectPKInfo + var pkInfoReader = new AsnReader(certReqInfoSeq.ReadEncodedValue().ToArray(), AsnEncodingRules.DER); + var pkInfoSeq = pkInfoReader.ReadSequence(); + + // algorithm + var algIdSeq = pkInfoSeq.ReadSequence(); + string algOid = algIdSeq.ReadObjectIdentifier(); + Assert.AreEqual("1.2.840.113549.1.1.1", algOid, "Public key algorithm is not RSA"); + if (algIdSeq.HasData) + { + algIdSeq.ReadNull(); + } + + // subjectPublicKey BIT STRING + var publicKeyBitString = pkInfoSeq.ReadBitString(out _); + + // Parse the RSAPublicKey structure from the BIT STRING (SEQUENCE of modulus, exponent) + var rsaKeyReader = new AsnReader(publicKeyBitString, AsnEncodingRules.DER); + var rsaKeySeq = rsaKeyReader.ReadSequence(); + byte[] modulus = rsaKeySeq.ReadIntegerBytes().ToArray(); + byte[] exponent = rsaKeySeq.ReadIntegerBytes().ToArray(); + + // Validate modulus length (2048 bits = 256 bytes, may have leading zero) + Assert.IsTrue(modulus.Length == 256 || modulus.Length == 257, $"RSA modulus should be 2048 bits, got {modulus.Length * 8} bits"); + + // Validate exponent (commonly 65537 = 0x010001) + Assert.IsTrue(exponent.Length >= 1 && exponent.Length <= 4, "RSA exponent has invalid length"); + + // attributes [0] (optional) + if (certReqInfoSeq.HasData) + { + var attrTag = new Asn1Tag(TagClass.ContextSpecific, 0); + if (certReqInfoSeq.PeekTag().HasSameClassAndValue(attrTag)) + { + var attrSetReader = certReqInfoSeq.ReadSetOf(attrTag); + bool foundCuid = false; + while (attrSetReader.HasData) + { + var attrSeq = attrSetReader.ReadSequence(); + string oid = attrSeq.ReadObjectIdentifier(); + if (oid == "1.3.6.1.4.1.311.90.2.10") // challengePassword + { + var valueSet = attrSeq.ReadSetOf(); + while (valueSet.HasData) + { + string cuidJson = valueSet.ReadCharacterString(UniversalTagNumber.UTF8String); + string expectedCuidJson = JsonHelper.SerializeToJson(expectedCuid); + Assert.AreEqual(expectedCuidJson, cuidJson, "CUID attribute JSON value does not match expected"); + foundCuid = true; + } + } + } + Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); + } + } + + // signatureAlgorithm + var sigAlgSeq = csrSequence.ReadSequence(); + string sigAlgOid = sigAlgSeq.ReadObjectIdentifier(); + Assert.AreEqual("1.2.840.113549.1.1.10", sigAlgOid, "Signature algorithm is not RSASSA-PSS (SHA256withRSA/PSS)"); + + // signature + csrSequence.ReadBitString(out _); + + Assert.IsFalse(csrSequence.HasData, "Extra data found after CSR structure"); + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/DefaultRetryPolicyTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/DefaultRetryPolicyTests.cs index f563a34fe9..0126093fb4 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/DefaultRetryPolicyTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/DefaultRetryPolicyTests.cs @@ -47,7 +47,7 @@ public async Task UAMIFails500OnceThenSucceeds200Async( .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); @@ -104,7 +104,7 @@ public async Task UAMIFails500PermanentlyAsync( .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); @@ -171,7 +171,7 @@ public async Task SAMIFails500OnceWithVariousRetryAfterHeaderValuesThenSucceeds2 .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); @@ -223,7 +223,7 @@ public async Task SAMIFails500Permanently( .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); @@ -277,7 +277,7 @@ public async Task SAMIFails400WhichIsNonRetriableAndRetryPolicyIsNotTriggeredAsy .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); @@ -327,7 +327,7 @@ public async Task SAMIFails500AndRetryPolicyIsDisabledAndNotTriggeredAsync( .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs index 13e22314af..9853d54283 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs @@ -40,7 +40,7 @@ public async Task ImdsFails404TwiceThenSucceeds200Async( .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); @@ -99,7 +99,7 @@ public async Task ImdsFails410FourTimesThenSucceeds200Async( .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); @@ -158,7 +158,7 @@ public async Task ImdsFails410PermanentlyAsync( .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); @@ -214,7 +214,7 @@ public async Task ImdsFails504PermanentlyAsync( .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); @@ -270,7 +270,7 @@ public async Task ImdsFails400WhichIsNonRetriableAndRetryPolicyIsNotTriggeredAsy .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); @@ -322,7 +322,7 @@ public async Task ImdsFails500AndRetryPolicyIsDisabledAndNotTriggeredAsync( .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); @@ -368,7 +368,7 @@ public async Task ImdsRetryPolicyLifeTimeIsPerRequestAsync() .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; + IManagedIdentityApplication mi = miBuilder.Build(); diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs new file mode 100644 index 0000000000..61316f08d7 --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -0,0 +1,747 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.AppConfig; +using Microsoft.Identity.Client.Internal.Logger; +using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.KeyProviders; +using Microsoft.Identity.Client.ManagedIdentity.V2; +using Microsoft.Identity.Client.MtlsPop; +using Microsoft.Identity.Client.PlatformsCommon.Interfaces; +using Microsoft.Identity.Client.PlatformsCommon.Shared; +using Microsoft.Identity.Test.Common.Core.Helpers; +using Microsoft.Identity.Test.Common.Core.Mocks; +using Microsoft.Identity.Test.Unit.Helpers; +using Microsoft.Identity.Test.Unit.PublicApiTests; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NSubstitute; +using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil; + +namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests +{ + [TestClass] + public class ImdsV2Tests : TestBase + { + private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); + private readonly TestCsrFactory _testCsrFactory = new TestCsrFactory(); + private readonly IdentityLoggerAdapter _identityLoggerAdapter = new IdentityLoggerAdapter( + new TestIdentityLogger(), + Guid.Empty, + "TestClient", + "1.0.0", + enablePiiLogging: false + ); + + // Fake attestation provider used by mTLS PoP tests so we never hit the real service + private static readonly Func> + s_fakeAttestationProvider = + (input, ct) => Task.FromResult(new AttestationTokenResponse + { + AttestationToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.fake.attestation.sig" + }); + + public const string Bearer = "Bearer"; + public const string MTLSPoP = "mtls_pop"; + + private void AddMocksToGetEntraToken( + MockHttpManager httpManager, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + string certificateRequestCertificate = TestConstants.ValidRawCertificate, + bool mTLSPop = false) + { + if (userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null) + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(userAssignedIdentityId: userAssignedIdentityId, userAssignedId: userAssignedId)); + httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse(userAssignedIdentityId, userAssignedId, certificateRequestCertificate)); + } + else + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse(certificate: certificateRequestCertificate)); + } + + httpManager.AddMockHandler(MockHelpers.MockImdsV2EntraTokenRequestResponse(_identityLoggerAdapter, mTLSPop)); + } + + private async Task CreateManagedIdentityAsync( + MockHttpManager httpManager, + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, + string userAssignedId = null, + bool addProbeMock = true, + bool addSourceCheck = true, + ManagedIdentityKeyType managedIdentityKeyType = ManagedIdentityKeyType.InMemory) + { + ManagedIdentityApplicationBuilder miBuilder = null; + + var uami = userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null; + if (uami) + { + miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId); + } + else + { + miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned); + } + + miBuilder + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory) + .WithCsrFactory(_testCsrFactory); + + var managedIdentityApp = miBuilder.Build(); + + if (addProbeMock) + { + if (uami) + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(userAssignedIdentityId: userAssignedIdentityId, userAssignedId: userAssignedId)); + } + else + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + } + } + + if (addSourceCheck) + { + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource); + } + + // Choose deterministic key source for tests. + IManagedIdentityKeyProvider managedIdentityKeyProvider = null; + if (managedIdentityKeyType == ManagedIdentityKeyType.KeyGuard) + { + // Force KeyGuard keys to deterministically exercise the attestation path. + managedIdentityKeyProvider = new TestKeyGuardManagedIdentityKeyProvider(); + } + else if (managedIdentityKeyType == ManagedIdentityKeyType.InMemory) + { + // Default for bearer tests: no attestation. + managedIdentityKeyProvider = new InMemoryManagedIdentityKeyProvider(); + } + + // Inject a test platform proxy that provides the chosen key provider + if (managedIdentityKeyProvider != null) + { + var platformProxy = Substitute.For(); + platformProxy.ManagedIdentityKeyProvider.Returns(managedIdentityKeyProvider); + + (managedIdentityApp as ManagedIdentityApplication) + .ServiceBundle.SetPlatformProxyForTest(platformProxy); + } + + return managedIdentityApp; + } + + #region Acceptance Tests + #region Bearer Token Tests + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task BearerTokenHappyPath( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.InMemory).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId, $"{TestConstants.ClientId}-2")] + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId, $"{TestConstants.MiResourceId}-2")] + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId, $"{TestConstants.ObjectId}-2")] + public async Task BearerTokenIsPerIdentity( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId, + string userAssignedId2) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + #region Identity 1 + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + #endregion Identity 1 + + #region Identity 2 + UserAssignedIdentityId userAssignedIdentityId2 = userAssignedIdentityId; // keep the same type, that's the most common scenario + var managedIdentityApp2 = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId2, userAssignedId2, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); // source is already cached + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId2, userAssignedId2); + + var result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result2); + Assert.IsNotNull(result2.AccessToken); + Assert.AreEqual(result2.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result2.AuthenticationResultMetadata.TokenSource); + + result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result2); + Assert.IsNotNull(result2.AccessToken); + Assert.AreEqual(result2.TokenType, Bearer); + Assert.AreEqual(TokenSource.Cache, result2.AuthenticationResultMetadata.TokenSource); + #endregion Identity 2 + + // TODO: Assert.AreEqual(CertificateCache.Count, 2); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task BearerTokenIsReAcquiredWhenCertificatIsExpired( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, TestConstants.ExpiredRawCertificate); // cert will be expired on second request + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + // TODO: Add functionality to check cert expiration in the cache + /** + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); + + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, Bearer); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + Assert.AreEqual(CertificateCache.Count, 1); // expired cert was removed from the cache + */ + } + } + #endregion Bearer Token Tests + + #region mTLS PoP Token Tests + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task mTLSPopTokenHappyPath( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, mTLSPop: true); + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, MTLSPoP); + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + // TODO: broken until Gladwin's PR is merged in + /*result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, MTLSPoP); + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);*/ + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId, $"{TestConstants.ClientId}-2")] + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId, $"{TestConstants.MiResourceId}-2")] + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId, $"{TestConstants.ObjectId}-2")] + public async Task mTLSPopTokenIsPerIdentity( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId, + string userAssignedId2) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + #region Identity 1 + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, mTLSPop: true); + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, MTLSPoP); + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + // TODO: broken until Gladwin's PR is merged in + /*result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, MTLSPoP); + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);*/ + #endregion Identity 1 + + #region Identity 2 + UserAssignedIdentityId userAssignedIdentityId2 = userAssignedIdentityId; // keep the same type, that's the most common scenario + var managedIdentityApp2 = await CreateManagedIdentityAsync( + httpManager, + userAssignedIdentityId2, + userAssignedId2, + addProbeMock: false, + addSourceCheck: false, + managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); // source is already cached + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId2, userAssignedId2, mTLSPop: true); + + var result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result2); + Assert.IsNotNull(result2.AccessToken); + Assert.AreEqual(result2.TokenType, MTLSPoP); + // Assert.IsNotNull(result2.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate + Assert.AreEqual(TokenSource.IdentityProvider, result2.AuthenticationResultMetadata.TokenSource); + + // TODO: broken until Gladwin's PR is merged in + /*result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result2); + Assert.IsNotNull(result2.AccessToken); + Assert.AreEqual(result2.TokenType, MTLSPoP); + // Assert.IsNotNull(result2.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate + Assert.AreEqual(TokenSource.Cache, result2.AuthenticationResultMetadata.TokenSource);*/ + #endregion Identity 2 + + // TODO: Assert.AreEqual(CertificateCache.Count, 2); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task mTLSPopTokenIsReAcquiredWhenCertificatIsExpired( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, TestConstants.ExpiredRawCertificate, mTLSPop: true); + + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, MTLSPoP); + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + // TODO: Add functionality to check cert expiration in the cache + /** + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, mTLSPop: true); + + result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(s_fakeAttestationProvider) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(result.TokenType, MTLSPoP); + // Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + Assert.AreEqual(CertificateCache.Count, 1); // expired cert was removed from the cache + */ + } + } + #endregion mTLS Pop Token Tests + #endregion Acceptance Tests + + [TestMethod] + public async Task GetCsrMetadataAsyncSucceeds() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + var handler = httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + + await CreateManagedIdentityAsync(httpManager, addProbeMock: false).ConfigureAwait(false); + } + } + + [TestMethod] + public async Task GetCsrMetadataAsyncSucceedsAfterRetry() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + // First attempt fails with INTERNAL_SERVER_ERROR (500) + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.InternalServerError)); + + // Second attempt succeeds (defined inside of CreateSAMIAsync) + await CreateManagedIdentityAsync(httpManager).ConfigureAwait(false); + } + } + + [TestMethod] + public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: null)); + + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); + + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); + } + } + + [TestMethod] + public async Task GetCsrMetadataAsyncFailsWithInvalidFormat() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: "I_MDS/150.870.65.1854")); + + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); + + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); + } + } + + [TestMethod] + public async Task GetCsrMetadataAsyncFailsAfterMaxRetries() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + const int Num500Errors = 1 + TestCsrMetadataProbeRetryPolicy.ExponentialStrategyNumRetries; + for (int i = 0; i < Num500Errors; i++) + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.InternalServerError)); + } + + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); + + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); + } + } + + [TestMethod] + public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIsNotTriggeredAsync() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.NotFound)); + + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); + + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); + } + } + + #region Cuid Tests + [TestMethod] + public void TestCsrGeneration_OnlyVmId() + { + var cuid = new CuidInfo + { + VmId = TestConstants.VmId + }; + + var rsa = InMemoryManagedIdentityKeyProvider.CreateRsaKeyPair(); + var (csr, _) = Csr.Generate(rsa, TestConstants.ClientId, TestConstants.TenantId, cuid); + CsrValidator.ValidateCsrContent(csr, TestConstants.ClientId, TestConstants.TenantId, cuid); + } + + [TestMethod] + public void TestCsrGeneration_VmIdAndVmssId() + { + var cuid = new CuidInfo + { + VmId = TestConstants.VmId, + VmssId = TestConstants.VmssId + }; + + var rsa = InMemoryManagedIdentityKeyProvider.CreateRsaKeyPair(); + var (csr, _) = Csr.Generate(rsa, TestConstants.ClientId, TestConstants.TenantId, cuid); + CsrValidator.ValidateCsrContent(csr, TestConstants.ClientId, TestConstants.TenantId, cuid); + } + #endregion + + [DataTestMethod] + [DataRow("Invalid@#$%Certificate!")] + [DataRow("")] + [DataRow(null)] + public void TestCsrGeneration_BadCert_ThrowsMsalServiceException(string badCert) + { + Assert.ThrowsException(() => + CsrValidator.ParseRawCsr(badCert)); + } + + #region AttachPrivateKeyToCert Tests + [TestMethod] + public void AttachPrivateKeyToCert_ValidInputs_ReturnsValidCertificate() + { + using (RSA rsa = RSA.Create()) + { + X509Certificate2 certificate = CommonCryptographyManager.AttachPrivateKeyToCert(TestConstants.ValidRawCertificate, TestCsrFactory.CreateMockRsa()); + Assert.IsNotNull(certificate); + } + } + + [DataTestMethod] + [DataRow("Invalid@#$%Certificate!")] + [DataRow("")] + [DataRow(null)] + public void AttachPrivateKeyToCert_BadContent_ThrowsMsalServiceException(string badCert) + { + using (RSA rsa = RSA.Create()) + { + Assert.ThrowsException(() => + CommonCryptographyManager.AttachPrivateKeyToCert(badCert, rsa)); + } + } + + [TestMethod] + public void AttachPrivateKeyToCert_NullPrivateKey_ThrowsArgumentNullException() + { + Assert.ThrowsException(() => + CommonCryptographyManager.AttachPrivateKeyToCert(TestConstants.ValidRawCertificate, null)); + } + #endregion + + #region Attestation Tests + [TestMethod] + public async Task MtlsPop_AttestationProviderMissing_ThrowsClientException() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + var mi = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + + // CreateManagedIdentityAsync does a probe; Add one more CSR response for the actual acquire. + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + + var ex = await Assert.ThrowsExceptionAsync(async () => + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + // Intentionally DO NOT call .WithAttestationProviderForTests(...) + .ExecuteAsync().ConfigureAwait(false) + ).ConfigureAwait(false); + + Assert.AreEqual("attestation_failure", ex.ErrorCode); + } + } + + [TestMethod] + public async Task MtlsPop_AttestationProviderReturnsNull_ThrowsClientException() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + var mi = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + + // CreateManagedIdentityAsync does a probe; Add one more CSR response for the actual acquire. + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + + var nullProvider = new Func>( + (input, ct) => Task.FromResult(null)); + + var ex = await Assert.ThrowsExceptionAsync(async () => + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(nullProvider) + .ExecuteAsync().ConfigureAwait(false) + ).ConfigureAwait(false); + + Assert.AreEqual("attestation_failed", ex.ErrorCode); + } + } + + [TestMethod] + public async Task MtlsPop_AttestationProviderReturnsEmptyToken_ThrowsClientException() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + var mi = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); + + // CreateManagedIdentityAsync does a probe; Add one more CSR response for the actual acquire. + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + + var emptyProvider = new Func>( + (input, ct) => Task.FromResult(new AttestationTokenResponse { AttestationToken = " " })); + + var ex = await Assert.ThrowsExceptionAsync(async () => + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationProviderForTests(emptyProvider) + .ExecuteAsync().ConfigureAwait(false) + ).ConfigureAwait(false); + + Assert.AreEqual("attestation_failed", ex.ErrorCode); + } + } + + [TestMethod] + public async Task mTLSPop_RequestedWithoutKeyGuard_ThrowsClientException() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + + // Force in-memory keys (i.e., not KeyGuard) + var managedIdentityApp = await CreateManagedIdentityAsync( + httpManager, + managedIdentityKeyType: ManagedIdentityKeyType.InMemory + ).ConfigureAwait(false); + + // CreateManagedIdentityAsync does a probe; Add one more CSR response for the actual acquire. + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + + var ex = await Assert.ThrowsExceptionAsync(async () => + await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() // request PoP on a non-KeyGuard env + .ExecuteAsync().ConfigureAwait(false) + ).ConfigureAwait(false); + + Assert.AreEqual("mtls_pop_requires_keyguard", ex.ErrorCode); + } + } + #endregion + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/InMemoryManagedIdentityKeyProviderTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/InMemoryManagedIdentityKeyProviderTests.cs new file mode 100644 index 0000000000..ba501fce15 --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/InMemoryManagedIdentityKeyProviderTests.cs @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Linq; +using System.Security.Cryptography; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.KeyProviders; +using Microsoft.Identity.Client.Core; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NSubstitute; // For Substitute.For() + +namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests +{ + [TestClass] + public class InMemoryManagedIdentityKeyProviderTests + { + private static (InMemoryManagedIdentityKeyProvider keyProvider, ILoggerAdapter logger) CreateKeyProviderAndLogger() + { + return (new InMemoryManagedIdentityKeyProvider(), Substitute.For()); + } + + [TestMethod] + public async Task ReturnsRsa2048_AndCaches_Success() + { + var (keyProvider, logger) = CreateKeyProviderAndLogger(); + + ManagedIdentityKeyInfo k1 = await keyProvider.GetOrCreateKeyAsync(logger, CancellationToken.None).ConfigureAwait(false); + ManagedIdentityKeyInfo k2 = await keyProvider.GetOrCreateKeyAsync(logger, CancellationToken.None).ConfigureAwait(false); + + Assert.IsNotNull(k1); + Assert.AreSame(k1, k2, "Provider should cache the same ManagedIdentityKeyInfo instance per process."); + Assert.IsInstanceOfType(k1.Key, typeof(RSA)); + Assert.IsTrue(k1.Key.KeySize >= Constants.RsaKeySize); + Assert.AreEqual(ManagedIdentityKeyType.InMemory, k1.Type); + } + + [TestMethod] + public async Task Concurrency_SingleCreation() + { + var (keyProvider, logger) = CreateKeyProviderAndLogger(); + + var tasks = Enumerable.Range(0, 32) + .Select(_ => keyProvider.GetOrCreateKeyAsync(logger, CancellationToken.None)) + .ToArray(); + + await Task.WhenAll(tasks).ConfigureAwait(false); + + var first = tasks[0].Result; + foreach (var task in tasks) + { + Assert.AreSame(first, task.Result, "All concurrent calls should return the same cached ManagedIdentityKeyInfo."); + } + } + + [TestMethod] + public async Task Rsa_SignsAndVerifies() + { + var (keyProvider, logger) = CreateKeyProviderAndLogger(); + + var managedIdentityApp = await keyProvider.GetOrCreateKeyAsync(logger, CancellationToken.None).ConfigureAwait(false); + + byte[] data = Encoding.UTF8.GetBytes("ping"); + byte[] signature = managedIdentityApp.Key.SignData(data, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + bool isSignatureValid = managedIdentityApp.Key.VerifyData(data, signature, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + + Assert.IsTrue(isSignatureValid); + } + + [TestMethod] + public async Task Cancellation_BeforeCreation_Throws_And_SubsequentCallSucceeds() + { + var (keyProvider, logger) = CreateKeyProviderAndLogger(); + + using (var cts = new CancellationTokenSource()) + { + cts.Cancel(); // Pre-cancel so WaitAsync throws TaskCanceledException. + + await Assert.ThrowsExceptionAsync( + () => keyProvider.GetOrCreateKeyAsync(logger, cts.Token)).ConfigureAwait(false); + } + + // Subsequent non-cancelled call should create and cache the key. + var keyInfo = await keyProvider.GetOrCreateKeyAsync(logger, CancellationToken.None).ConfigureAwait(false); + Assert.IsNotNull(keyInfo); + Assert.IsNotNull(keyInfo.Key); + Assert.AreEqual(ManagedIdentityKeyType.InMemory, keyInfo.Type); + } + + [TestMethod] + public async Task Cancellation_AfterCache_ReturnsCachedKey_IgnoringCancellation() + { + var (keyProvider, logger) = CreateKeyProviderAndLogger(); + + ManagedIdentityKeyInfo k1 = await keyProvider.GetOrCreateKeyAsync(logger, CancellationToken.None).ConfigureAwait(false); + + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Cached path should not throw. + ManagedIdentityKeyInfo k2 = await keyProvider.GetOrCreateKeyAsync(logger, cts.Token).ConfigureAwait(false); + + Assert.AreSame(k1, k2); + Assert.IsNotNull(k2.Key); + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs index 137be702a3..40304cebcd 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs @@ -40,8 +40,8 @@ public async Task MachineLearningUserAssignedHappyPathAndHasCorrectClientIdQuery var miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); @@ -83,8 +83,8 @@ public async Task MachineLearningUserAssignedNonClientIdThrowsAsync( var miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); @@ -109,8 +109,8 @@ public async Task MachineLearningTestsInvalidEndpointAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs index e7e42ba6ae..7c119315da 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Collections.Generic; using System.Diagnostics; using System.Linq; using System.Net; @@ -41,6 +42,16 @@ public class ManagedIdentityTests : TestBase private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); + private void AddImdsV2CsrMockHandlerIfNeeded( + ManagedIdentitySource managedIdentitySource, + MockHttpManager httpManager) + { + if (managedIdentitySource == ManagedIdentitySource.Imds) + { + httpManager.AddMockHandler(MockHelpers.MockCsrResponseFailure()); + } + } + [DataTestMethod] [DataRow("http://127.0.0.1:41564/msi/token/", ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)] [DataRow(AppServiceEndpoint, ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)] @@ -50,16 +61,26 @@ public class ManagedIdentityTests : TestBase [DataRow(CloudShellEndpoint, ManagedIdentitySource.CloudShell, ManagedIdentitySource.CloudShell)] [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, ManagedIdentitySource.ServiceFabric)] [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, ManagedIdentitySource.MachineLearning)] - public void GetManagedIdentityTests( + public async Task GetManagedIdentityTests( string endpoint, ManagedIdentitySource managedIdentitySource, ManagedIdentitySource expectedManagedIdentitySource) { using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); - Assert.AreEqual(expectedManagedIdentitySource, ManagedIdentityApplication.GetManagedIdentitySource()); + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager); + + + + + ManagedIdentityApplication mi = miBuilder.Build() as ManagedIdentityApplication; + + Assert.AreEqual(expectedManagedIdentitySource, await mi.GetManagedIdentitySourceAsync().ConfigureAwait(false)); } } @@ -77,7 +98,7 @@ public void GetManagedIdentityTests( [DataRow(ServiceFabricEndpoint, ResourceDefaultSuffix, ManagedIdentitySource.ServiceFabric)] [DataRow(MachineLearningEndpoint, Resource, ManagedIdentitySource.MachineLearning)] [DataRow(MachineLearningEndpoint, ResourceDefaultSuffix, ManagedIdentitySource.MachineLearning)] - public async Task ManagedIdentityHappyPathAsync( + public async Task SAMIHappyPathAsync( string endpoint, string scope, ManagedIdentitySource managedIdentitySource) @@ -85,21 +106,19 @@ public async Task ManagedIdentityHappyPathAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - - var mi = miBuilder.Build(); + var mi = miBuilder.Build(); + httpManager.AddManagedIdentityMockHandler( - endpoint, - Resource, - MockHelpers.GetMsiSuccessfulResponse(), - managedIdentitySource); + endpoint, + Resource, + MockHelpers.GetMsiSuccessfulResponse(), + managedIdentitySource); var result = await mi.AcquireTokenForManagedIdentity(scope).ExecuteAsync().ConfigureAwait(false); @@ -122,12 +141,12 @@ public async Task ManagedIdentityHappyPathAsync( [DataRow(AppServiceEndpoint, ManagedIdentitySource.AppService, TestConstants.ObjectId, UserAssignedIdentityId.ObjectId)] [DataRow(ImdsEndpoint, ManagedIdentitySource.Imds, TestConstants.ClientId, UserAssignedIdentityId.ClientId)] [DataRow(ImdsEndpoint, ManagedIdentitySource.Imds, TestConstants.MiResourceId, UserAssignedIdentityId.ResourceId)] - [DataRow(ImdsEndpoint, ManagedIdentitySource.Imds, TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)] + [DataRow(ImdsEndpoint, ManagedIdentitySource.Imds, TestConstants.ObjectId, UserAssignedIdentityId.ObjectId)] [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.ClientId, UserAssignedIdentityId.ClientId)] - [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.MiResourceId, UserAssignedIdentityId .ResourceId)] - [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)] + [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.MiResourceId, UserAssignedIdentityId.ResourceId)] + [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.ObjectId, UserAssignedIdentityId.ObjectId)] [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.ClientId, UserAssignedIdentityId.ClientId)] - public async Task ManagedIdentityUserAssignedHappyPathAsync( + public async Task UAMIHappyPathAsync( string endpoint, ManagedIdentitySource managedIdentitySource, string userAssignedId, @@ -136,13 +155,14 @@ public async Task ManagedIdentityUserAssignedHappyPathAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); - ManagedIdentityApplicationBuilder miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId); + var miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId); miBuilder.WithHttpManager(httpManager); - IManagedIdentityApplication mi = miBuilder.Build(); + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( endpoint, @@ -183,14 +203,12 @@ public async Task ManagedIdentityDifferentScopesTestAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -243,14 +261,12 @@ public async Task ManagedIdentityForceRefreshTestAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -304,15 +320,13 @@ public async Task ManagedIdentityWithClaimsAndCapabilitiesTestAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithClientCapabilities(TestConstants.ClientCapabilities) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -370,14 +384,12 @@ public async Task ManagedIdentityWithClaimsTestAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -444,14 +456,12 @@ public async Task ManagedIdentityTestWrongScopeAsync(string resource, ManagedIde using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler(endpoint, resource, MockHelpers.GetMsiErrorResponse(managedIdentitySource), @@ -491,9 +501,7 @@ public async Task ManagedIdentityTestErrorResponseParsing(string errorResponse, var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler(AppServiceEndpoint, Resource, errorResponse, @@ -552,14 +560,12 @@ public async Task ManagedIdentityErrorResponseNoPayloadTestAsync(ManagedIdentity using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler(endpoint, "scope", "", @@ -594,14 +600,12 @@ public async Task ManagedIdentityNullResponseAsync(ManagedIdentitySource managed using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -634,14 +638,12 @@ public async Task ManagedIdentityUnreachableNetworkAsync(ManagedIdentitySource m using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); httpManager.AddFailingRequest(new HttpRequestException("A socket operation was attempted to an unreachable network.", @@ -669,10 +671,7 @@ public async Task SystemAssignedManagedIdentityApiIdTestAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -702,10 +701,7 @@ public async Task UserAssignedManagedIdentityApiIdTestAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.WithUserAssignedClientId(TestConstants.ClientId)) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -738,8 +734,8 @@ public async Task ManagedIdentityCacheTestAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.BuildConcrete(); @@ -783,10 +779,7 @@ public async Task ManagedIdentityExpiresOnTestAsync(int expiresInHours, bool ref var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -819,10 +812,7 @@ public async Task ManagedIdentityInvalidRefreshOnThrowsAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -850,8 +840,8 @@ public async Task ManagedIdentityIsProActivelyRefreshedAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.BuildConcrete(); @@ -918,8 +908,8 @@ public async Task ProactiveRefresh_CancelsSuccessfully_Async() .WithLogging(LocalLogCallback) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.BuildConcrete(); @@ -976,8 +966,8 @@ public async Task ParallelRequests_CallTokenEndpointOnceAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.BuildConcrete(); @@ -1046,15 +1036,13 @@ public async Task InvalidJsonResponseHandling(ManagedIdentitySource managedIdent using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder .Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -1083,26 +1071,24 @@ await mi.AcquireTokenForManagedIdentity("scope") public async Task ManagedIdentityRequestTokensForDifferentScopesTestAsync( string initialResource, string newResource, - ManagedIdentitySource source, + ManagedIdentitySource managedIdentitySource, string endpoint) { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - SetEnvironmentVariables(source, endpoint); + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); + SetEnvironmentVariables(managedIdentitySource, endpoint); - ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder + var miBuilder = ManagedIdentityApplicationBuilder .Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - - IManagedIdentityApplication mi = miBuilder.Build(); + + var mi = miBuilder.Build(); // Mock handler for the initial resource request httpManager.AddManagedIdentityMockHandler(endpoint, initialResource, - MockHelpers.GetMsiSuccessfulResponse(), source); + MockHelpers.GetMsiSuccessfulResponse(), managedIdentitySource); // Request token for initial resource AuthenticationResult result = await mi.AcquireTokenForManagedIdentity(initialResource).ExecuteAsync().ConfigureAwait(false); @@ -1111,7 +1097,7 @@ public async Task ManagedIdentityRequestTokensForDifferentScopesTestAsync( // Mock handler for the new resource request httpManager.AddManagedIdentityMockHandler(endpoint, newResource, - MockHelpers.GetMsiSuccessfulResponse(), source); + MockHelpers.GetMsiSuccessfulResponse(), managedIdentitySource); // Request token for new resource result = await mi.AcquireTokenForManagedIdentity(newResource).ExecuteAsync().ConfigureAwait(false); @@ -1133,10 +1119,6 @@ public async Task ManagedIdentityRequestTokensForDifferentScopesTestAsync( [DataTestMethod] [DataRow(ManagedIdentitySource.AppService)] [DataRow(ManagedIdentitySource.Imds)] - [DataRow(ManagedIdentitySource.AzureArc)] - [DataRow(ManagedIdentitySource.CloudShell)] - [DataRow(ManagedIdentitySource.ServiceFabric)] - [DataRow(ManagedIdentitySource.MachineLearning)] public async Task UnsupportedManagedIdentitySource_ThrowsExceptionDuringTokenAcquisitionAsync( ManagedIdentitySource managedIdentitySource) { @@ -1144,22 +1126,18 @@ public async Task UnsupportedManagedIdentitySource_ThrowsExceptionDuringTokenAcq using (new EnvVariableContext()) { - // Set unsupported environment variable SetEnvironmentVariables(managedIdentitySource, UnsupportedEndpoint); // Create the Managed Identity Application var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned); - // Build the application var mi = miBuilder.Build(); - // Attempt to acquire a token and verify an exception is thrown MsalServiceException ex = await Assert.ThrowsExceptionAsync(async () => await mi.AcquireTokenForManagedIdentity("https://management.azure.com") .ExecuteAsync() .ConfigureAwait(false)).ConfigureAwait(false); - // Verify the exception details Assert.IsNotNull(ex); Assert.AreEqual(MsalError.ManagedIdentityRequestFailed, ex.ErrorCode); } @@ -1181,7 +1159,7 @@ public async Task MixedUserAndSystemAssignedManagedIdentityTestAsync() var userAssignedBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.WithUserAssignedClientId(UserAssignedClientId)) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. + userAssignedBuilder.Config.AccessorOptions = null; var userAssignedMI = userAssignedBuilder.BuildConcrete(); @@ -1269,10 +1247,7 @@ public async Task ManagedIdentityRetryPolicyLifeTimeIsPerRequestAsync( var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager) .WithRetryPolicyFactory(_testRetryPolicyFactory); - - // Disable cache to avoid pollution - miBuilder.Config.AccessorOptions = null; - + var mi = miBuilder.Build(); // Simulate permanent errors (to trigger the maximum number of retries) @@ -1357,15 +1332,13 @@ public async Task ManagedIdentityWithCapabilitiesTestAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithClientCapabilities(TestConstants.ClientCapabilities) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; - var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -1401,7 +1374,7 @@ public void ValidateServerCertificate_OnlySetForServiceFabric() // Test all managed identity sources foreach (ManagedIdentitySource sourceType in Enum.GetValues(typeof(ManagedIdentitySource)) .Cast() - .Where(s => s != ManagedIdentitySource.None && s != ManagedIdentitySource.DefaultToImds)) + .Where(s => s != ManagedIdentitySource.None && s != ManagedIdentitySource.DefaultToImds && s != ManagedIdentitySource.ImdsV2)) { // Create a managed identity source for each type AbstractManagedIdentity managedIdentity = CreateManagedIdentitySource(sourceType, httpManager); @@ -1475,5 +1448,70 @@ private AbstractManagedIdentity CreateManagedIdentitySource(ManagedIdentitySourc return managedIdentity; } + + [TestMethod] + public async Task ManagedIdentityWithExtraQueryParametersTestAsync() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.AppService, AppServiceEndpoint); + + var extraQueryParameters = new Dictionary + { + { "param1", "value1" }, + { "param2", "value2" }, + { "custom_param", "custom_value" } + }; + + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithExperimentalFeatures(true) + .WithExtraQueryParameters(extraQueryParameters) + .WithHttpManager(httpManager); + + var mi = miBuilder.Build(); + + httpManager.AddManagedIdentityMockHandler( + AppServiceEndpoint, + Resource, + MockHelpers.GetMsiSuccessfulResponse(), + ManagedIdentitySource.AppService, + extraQueryParameters: extraQueryParameters); + + var result = await mi.AcquireTokenForManagedIdentity(Resource).ExecuteAsync().ConfigureAwait(false); + } + } + + [TestMethod] + public void WithExtraQueryParameters_MultipleCallsMergeValues() + { + var firstParams = new Dictionary + { + { "param1", "value1" }, + { "param2", "value2" } + }; + + var secondParams = new Dictionary + { + { "param3", "value3" }, + { "param4", "value4" }, + { "param1", "newvalue1" } // This should overwrite the first param1 + }; + + var miBuilder = ManagedIdentityApplicationBuilder + .Create(ManagedIdentityId.SystemAssigned) + .WithExperimentalFeatures(true) + .WithExtraQueryParameters(firstParams) + .WithExtraQueryParameters(secondParams); + + // Verify that parameters are merged + Assert.AreEqual(4, miBuilder.Config.ExtraQueryParameters.Count); + + // Verify merged values + Assert.AreEqual("newvalue1", miBuilder.Config.ExtraQueryParameters["param1"]); + Assert.AreEqual("value2", miBuilder.Config.ExtraQueryParameters["param2"]); + Assert.AreEqual("value3", miBuilder.Config.ExtraQueryParameters["param3"]); + Assert.AreEqual("value4", miBuilder.Config.ExtraQueryParameters["param4"]); + } } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ServiceFabricTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ServiceFabricTests.cs index 8dc79d1e99..48c8e6d37a 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ServiceFabricTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ServiceFabricTests.cs @@ -20,16 +20,10 @@ namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests { [TestClass] - public class ServiceFabricTests + public class ServiceFabricTests : TestBase { private const string Resource = "https://management.azure.com"; - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - [TestMethod] public async Task ServiceFabricInvalidEndpointAsync() { @@ -42,7 +36,7 @@ public async Task ServiceFabricInvalidEndpointAsync() .WithHttpManager(httpManager); // Disabling the shared cache to avoid the test to pass because of the cache - miBuilder.Config.AccessorOptions = null; + var mi = miBuilder.Build(); @@ -77,7 +71,7 @@ public void ValidateServerCertificateCallback_ServerCertificateValidationCallbac .WithHttpManager(httpManager); // Disabling the shared cache to avoid the test to pass because of the cache - miBuilder.Config.AccessorOptions = null; + var mi = miBuilder.BuildConcrete(); @@ -102,7 +96,7 @@ public async Task SFThrowsWhenGetHttpClientWithValidationIsNotImplementedAsync() .WithHttpClientFactory(new MsalSFFactoryNotImplementedException()); // Disabling the shared cache to avoid the test to pass because of the cache - miBuilder.Config.AccessorOptions = null; + var mi = miBuilder.BuildConcrete(); MsalServiceException ex = await Assert.ThrowsExceptionAsync(async () => diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/TestKeyGuardManagedIdentityKeyProvider.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/TestKeyGuardManagedIdentityKeyProvider.cs new file mode 100644 index 0000000000..ccd522e1fe --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/TestKeyGuardManagedIdentityKeyProvider.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Security.Cryptography; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.KeyProviders; + +namespace Microsoft.Identity.Test.Common.Core.Mocks +{ + /// + /// Returns a KeyGuard key (Type = KeyGuard). On Windows, attempts to use RSACng so the + /// production check in GetAttestationJwtAsync passes; elsewhere, RSA is fine (the RSACng + /// requirement is compiled only for Windows/NETFX). + /// + internal sealed class TestKeyGuardManagedIdentityKeyProvider : IManagedIdentityKeyProvider + { + public Task GetOrCreateKeyAsync(ILoggerAdapter logger, CancellationToken cancellationToken) + { + var rsacng = new RSACng(2048); + return Task.FromResult(new ManagedIdentityKeyInfo(rsacng, ManagedIdentityKeyType.KeyGuard, "Test KeyGuard Provider")); + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/Microsoft.Identity.Test.Unit.csproj b/tests/Microsoft.Identity.Test.Unit/Microsoft.Identity.Test.Unit.csproj index eb6a842e0f..a6f295c6d2 100644 --- a/tests/Microsoft.Identity.Test.Unit/Microsoft.Identity.Test.Unit.csproj +++ b/tests/Microsoft.Identity.Test.Unit/Microsoft.Identity.Test.Unit.csproj @@ -16,6 +16,7 @@ + {3433eb33-114a-4db7-bc57-14f17f55da3c} Microsoft.Identity.Client diff --git a/tests/Microsoft.Identity.Test.Unit/PlatformProxyPerformanceTests.cs b/tests/Microsoft.Identity.Test.Unit/PlatformProxyPerformanceTests.cs index 21de4a6ad3..4ac63b1bd9 100644 --- a/tests/Microsoft.Identity.Test.Unit/PlatformProxyPerformanceTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/PlatformProxyPerformanceTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using Microsoft.Identity.Client; using Microsoft.Identity.Client.PlatformsCommon.Factories; using Microsoft.Identity.Client.PlatformsCommon.Interfaces; using Microsoft.Identity.Test.Common; @@ -16,7 +17,7 @@ public class PlatformProxyPerformanceTests [TestInitialize] public void TestInitialize() { - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); } [TestMethod] diff --git a/tests/Microsoft.Identity.Test.Unit/PublicApiTests/AccountIdTest.cs b/tests/Microsoft.Identity.Test.Unit/PublicApiTests/AccountIdTest.cs index 8c0176c439..7e4c222ea9 100644 --- a/tests/Microsoft.Identity.Test.Unit/PublicApiTests/AccountIdTest.cs +++ b/tests/Microsoft.Identity.Test.Unit/PublicApiTests/AccountIdTest.cs @@ -8,14 +8,8 @@ namespace Microsoft.Identity.Test.Unit.PublicApiTests { [TestClass] - public class AccountIdTest + public class AccountIdTest : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - [TestMethod] public void EqualityTest() { diff --git a/tests/Microsoft.Identity.Test.Unit/PublicApiTests/AccountTests.cs b/tests/Microsoft.Identity.Test.Unit/PublicApiTests/AccountTests.cs index 07a1c666b4..f6611db027 100644 --- a/tests/Microsoft.Identity.Test.Unit/PublicApiTests/AccountTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/PublicApiTests/AccountTests.cs @@ -23,17 +23,11 @@ namespace Microsoft.Identity.Test.Unit.PublicApiTests { [TestClass] - public class AccountTests + public class AccountTests : TestBase { // Some tests load the TokenCache from a file and use this clientId private const string ClientIdInFile = "0615b6ca-88d4-4884-8729-b178178f7c27"; - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - [TestMethod] public void Constructor_IdIsNotRequired() { diff --git a/tests/Microsoft.Identity.Test.Unit/PublicApiTests/LoggerTests.cs b/tests/Microsoft.Identity.Test.Unit/PublicApiTests/LoggerTests.cs index 7a31a199a4..b6237043de 100644 --- a/tests/Microsoft.Identity.Test.Unit/PublicApiTests/LoggerTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/PublicApiTests/LoggerTests.cs @@ -20,15 +20,13 @@ namespace Microsoft.Identity.Test.Unit.PublicApiTests { [TestClass] - public class LoggerTests + public class LoggerTests : TestBase { private LogCallback _callback; [TestInitialize] public void TestInit() { - TestCommon.ResetInternalStaticCaches(); - _callback = Substitute.For(); } diff --git a/tests/Microsoft.Identity.Test.Unit/PublicApiTests/PromptTests.cs b/tests/Microsoft.Identity.Test.Unit/PublicApiTests/PromptTests.cs index 9bcc583dde..25736c9a05 100644 --- a/tests/Microsoft.Identity.Test.Unit/PublicApiTests/PromptTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/PublicApiTests/PromptTests.cs @@ -8,13 +8,8 @@ namespace Microsoft.Identity.Test.Unit.PublicApiTests { [TestClass] - public class PromptTests + public class PromptTests : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } [TestMethod()] [TestCategory(TestCategories.PromptTests)] diff --git a/tests/Microsoft.Identity.Test.Unit/PublicApiTests/TelemetryTests.cs b/tests/Microsoft.Identity.Test.Unit/PublicApiTests/TelemetryTests.cs index d3fc805d78..7c4678c2a8 100644 --- a/tests/Microsoft.Identity.Test.Unit/PublicApiTests/TelemetryTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/PublicApiTests/TelemetryTests.cs @@ -27,8 +27,7 @@ public class TelemetryTests : TestBase [TestInitialize] public void Initialize() - { - TestCommon.ResetInternalStaticCaches(); + { _serviceBundle = TestCommon.CreateServiceBundleWithCustomHttpManager(null, clientId: ClientId); _logger = _serviceBundle.ApplicationLogger; _platformProxy = _serviceBundle.PlatformProxy; diff --git a/tests/Microsoft.Identity.Test.Unit/PublicApiTests/TenantIdTests.cs b/tests/Microsoft.Identity.Test.Unit/PublicApiTests/TenantIdTests.cs index 6409baff3e..e750c01e7d 100644 --- a/tests/Microsoft.Identity.Test.Unit/PublicApiTests/TenantIdTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/PublicApiTests/TenantIdTests.cs @@ -10,14 +10,8 @@ namespace Microsoft.Identity.Test.Unit.PublicApiTests { [TestClass] - public class TenantIdTests + public class TenantIdTests : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - [DataTestMethod] [DataRow(TestConstants.AuthorityCommonTenant, TestConstants.Common, DisplayName = "Common endpoint")] [DataRow(TestConstants.AuthorityNotKnownCommon, TestConstants.Common, DisplayName = "Common endpoint")] diff --git a/tests/Microsoft.Identity.Test.Unit/RequestsTests/IntegratedWindowsAuthUsernamePasswordTests.cs b/tests/Microsoft.Identity.Test.Unit/RequestsTests/IntegratedWindowsAuthUsernamePasswordTests.cs index c2945a87b0..e2b824ea72 100644 --- a/tests/Microsoft.Identity.Test.Unit/RequestsTests/IntegratedWindowsAuthUsernamePasswordTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/RequestsTests/IntegratedWindowsAuthUsernamePasswordTests.cs @@ -25,16 +25,10 @@ namespace Microsoft.Identity.Test.Unit.RequestsTests { [TestClass] - public class IntegratedWindowsAuthAndUsernamePasswordTests + public class IntegratedWindowsAuthAndUsernamePasswordTests : TestBase { private string _password = "x"; - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - private MockHttpMessageHandler AddMockHandlerDefaultUserRealmDiscovery(MockHttpManager httpManager) { var handler = new MockHttpMessageHandler diff --git a/tests/Microsoft.Identity.Test.Unit/RequestsTests/LongRunningOnBehalfOfTests.cs b/tests/Microsoft.Identity.Test.Unit/RequestsTests/LongRunningOnBehalfOfTests.cs index afbc34c203..670a5624de 100644 --- a/tests/Microsoft.Identity.Test.Unit/RequestsTests/LongRunningOnBehalfOfTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/RequestsTests/LongRunningOnBehalfOfTests.cs @@ -21,13 +21,8 @@ namespace Microsoft.Identity.Test.Unit.RequestsTests { [TestClass] - public class LongRunningOnBehalfOfTests + public class LongRunningOnBehalfOfTests : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } [TestMethod] public async Task LongRunningObo_RunsSuccessfully_TestAsync() diff --git a/tests/Microsoft.Identity.Test.Unit/RequestsTests/OnBehalfOfTests.cs b/tests/Microsoft.Identity.Test.Unit/RequestsTests/OnBehalfOfTests.cs index d523ac21db..5e37287697 100644 --- a/tests/Microsoft.Identity.Test.Unit/RequestsTests/OnBehalfOfTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/RequestsTests/OnBehalfOfTests.cs @@ -19,14 +19,8 @@ namespace Microsoft.Identity.Test.Unit.RequestsTests { [TestClass] - public class OnBehalfOfTests + public class OnBehalfOfTests : TestBase { - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - private MockHttpMessageHandler AddMockHandlerAadSuccess( MockHttpManager httpManager, string authority = TestConstants.AuthorityCommonTenant, diff --git a/tests/Microsoft.Identity.Test.Unit/TelemetryTests/OTelInstrumentationTests.cs b/tests/Microsoft.Identity.Test.Unit/TelemetryTests/OTelInstrumentationTests.cs index 2fe90904f8..30d2d42b0e 100644 --- a/tests/Microsoft.Identity.Test.Unit/TelemetryTests/OTelInstrumentationTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/TelemetryTests/OTelInstrumentationTests.cs @@ -162,8 +162,8 @@ public async Task ProactiveTokenRefresh_ValidResponse_MSI_Async() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - // Disabling shared cache options to avoid cross test pollution. - miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.BuildConcrete(); diff --git a/tests/Microsoft.Identity.Test.Unit/TestBase.cs b/tests/Microsoft.Identity.Test.Unit/TestBase.cs index 483f32f32e..460b6123d9 100644 --- a/tests/Microsoft.Identity.Test.Unit/TestBase.cs +++ b/tests/Microsoft.Identity.Test.Unit/TestBase.cs @@ -42,7 +42,7 @@ public virtual void TestInitialize() Trace.WriteLine("Framework: .NET "); #endif Trace.WriteLine("Test started " + TestContext.TestName); - TestCommon.ResetInternalStaticCaches(); + ApplicationBase.ResetStateForTest(); } [TestCleanup] diff --git a/tests/Microsoft.Identity.Test.Unit/UtilTests/ScopeHelperTests.cs b/tests/Microsoft.Identity.Test.Unit/UtilTests/ScopeHelperTests.cs index d92b47b08c..31567023a7 100644 --- a/tests/Microsoft.Identity.Test.Unit/UtilTests/ScopeHelperTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/UtilTests/ScopeHelperTests.cs @@ -13,16 +13,10 @@ namespace Microsoft.Identity.Test.Unit.UtilTests { [TestClass] - public class ScopeHelperTests + public class ScopeHelperTests : TestBase { private const string LotsOfScopes = "Agreement.Read.All Agreement.ReadWrite.All AgreementAcceptance.Read AgreementAcceptance.Read.All AllSites.FullControl AllSites.Manage AllSites.Read AllSites.Write AppCatalog.ReadWrite.All AuditLog.Read.All Bookings.Manage.All Bookings.Read.All Bookings.ReadWrite.All BookingsAppointment.ReadWrite.All Calendars.Read Calendars.Read.All Calendars.Read.Shared Calendars.ReadWrite Calendars.ReadWrite.All Calendars.ReadWrite.Shared Contacts.Read Contacts.Read.All Contacts.Read.Shared Contacts.ReadWrite Contacts.ReadWrite.All Contacts.ReadWrite.Shared Device.Command Device.Read DeviceManagementApps.Read.All DeviceManagementApps.ReadWrite.All DeviceManagementConfiguration.Read.All DeviceManagementConfiguration.ReadWrite.All DeviceManagementManagedDevices.PrivilegedOperations.All DeviceManagementManagedDevices.Read.All DeviceManagementManagedDevices.ReadWrite.All DeviceManagementRBAC.Read.All DeviceManagementRBAC.ReadWrite.All DeviceManagementServiceConfig.Read.All DeviceManagementServiceConfig.ReadWrite.All Directory.AccessAsUser.All Directory.Read.All Directory.ReadWrite.All EAS.AccessAsUser.All EduAdministration.Read EduAdministration.ReadWrite EduAssignments.Read EduAssignments.ReadBasic EduAssignments.ReadWrite EduAssignments.ReadWriteBasic EduRoster.Read EduRoster.ReadBasic EduRoster.ReadWrite email EWS.AccessAsUser.All Exchange.Manage Files.Read Files.Read.All Files.Read.Selected Files.ReadWrite Files.ReadWrite.All Files.ReadWrite.AppFolder Files.ReadWrite.Selected Financials.ReadWrite.All Group.Read.All Group.ReadWrite.All IdentityProvider.Read.All IdentityProvider.ReadWrite.All IdentityRiskEvent.Read.All Mail.Read Mail.Read.All Mail.Read.Shared Mail.ReadWrite Mail.ReadWrite.All Mail.ReadWrite.Shared Mail.Send Mail.Send.All Mail.Send.Shared MailboxSettings.Read MailboxSettings.ReadWrite Member.Read.Hidden MyFiles.Read MyFiles.Write Notes.Create Notes.Read Notes.Read.All Notes.ReadWrite Notes.ReadWrite.All Notes.ReadWrite.CreatedByApp offline_access openid People.Read People.Read.All People.ReadWrite PrivilegedAccess.ReadWrite.AzureAD PrivilegedAccess.ReadWrite.AzureResources profile Reports.Read.All SecurityEvents.Read.All SecurityEvents.ReadWrite.All Sites.FullControl.All Sites.Manage.All Sites.Read.All Sites.ReadWrite.All Sites.Search.All Subscription.Read.All Tasks.Read Tasks.Read.Shared Tasks.ReadWrite Tasks.ReadWrite.Shared TermStore.Read.All TermStore.ReadWrite.All User.Export.All User.Invite.All User.Read User.Read.All User.ReadBasic.All User.ReadWrite User.ReadWrite.All UserActivity.ReadWrite.CreatedByApp"; - [TestInitialize] - public void TestInitialize() - { - TestCommon.ResetInternalStaticCaches(); - } - [TestMethod] public void ScopeHelperPerf() { diff --git a/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/ManagedIdentityAppVM.csproj b/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/ManagedIdentityAppVM.csproj index 150d7f869c..ea3a7b6aec 100644 --- a/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/ManagedIdentityAppVM.csproj +++ b/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/ManagedIdentityAppVM.csproj @@ -8,6 +8,7 @@ + diff --git a/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/Program.cs b/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/Program.cs index 427b7ca149..f9f72091a9 100644 --- a/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/Program.cs +++ b/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/Program.cs @@ -4,6 +4,7 @@ using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; using Microsoft.IdentityModel.Abstractions; +using Microsoft.Identity.Client.MtlsPop; IIdentityLogger identityLogger = new IdentityLogger(); @@ -20,6 +21,7 @@ try { var result = await mi.AcquireTokenForManagedIdentity(scope) + .WithMtlsProofOfPossession() .ExecuteAsync().ConfigureAwait(false); Console.WriteLine("Success");