From 5498968b64d291e6fc73645e17e949b84a9cf41b Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 6 Aug 2025 12:12:25 -0400 Subject: [PATCH 01/29] Initial commit. 2 TODOs --- .../ManagedIdentity/CsrRequest.cs | 41 ++++++++++ .../ManagedIdentity/CsrRequestResponse.cs | 53 +++++++++++++ .../ImdsV2ManagedIdentitySource.cs | 74 ++++++++++++++++++- .../ManagedIdentityTests/ImdsV2Tests.cs | 2 + 4 files changed, 167 insertions(+), 3 deletions(-) create mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs create mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs new file mode 100644 index 0000000000..eda86d9325 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + internal class CsrRequest + { + public string Pem { get; } + + public CsrRequest(string pem) + { + Pem = pem ?? throw new ArgumentNullException(nameof(pem)); + } + + /// + /// Generates a CSR for the given client, tenant, and CUID info. + /// + /// Managed Identity client_id. + /// AAD tenant_id. + /// CuidInfo object containing VMID and VMSSID. + /// CsrRequest containing the PEM CSR. + public static CsrRequest Generate(string clientId, string tenantId, CuidInfo cuid) + { + if (string.IsNullOrWhiteSpace(clientId)) + throw new ArgumentException("clientId must not be null or empty.", nameof(clientId)); + if (string.IsNullOrWhiteSpace(tenantId)) + throw new ArgumentException("tenantId must not be null or empty.", nameof(tenantId)); + if (cuid == null) + throw new ArgumentNullException(nameof(cuid)); + if (string.IsNullOrWhiteSpace(cuid.Vmid)) + throw new ArgumentException("cuid.Vmid must not be null or empty.", nameof(cuid.Vmid)); + if (string.IsNullOrWhiteSpace(cuid.Vmssid)) + throw new ArgumentException("cuid.Vmssid must not be null or empty.", nameof(cuid.Vmssid)); + + // TODO: Implement the actual CSR generation logic. + return new CsrRequest("pem"); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs new file mode 100644 index 0000000000..10274e48ba --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if SUPPORTS_SYSTEM_TEXT_JSON + using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute; +#else +using Microsoft.Identity.Client.Utils; +using Microsoft.Identity.Json; +#endif + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + /// + /// Represents the response for a Managed Identity CSR request. + /// + internal class CsrRequestResponse + { + [JsonProperty("client_id")] + public string ClientId { get; } + + [JsonProperty("tenant_id")] + public string TenantId { get; } + + [JsonProperty("client_credential")] + public string ClientCredential { get; } + + [JsonProperty("regional_token_url")] + public string RegionalTokenUrl { get; } + + [JsonProperty("expires_in")] + public int ExpiresIn { get; } + + [JsonProperty("refresh_in")] + public int RefreshIn { get; } + + public CsrRequestResponse() { } + + public static bool ValidateCsrRequestResponse(CsrRequestResponse csrRequestResponse) + { + if (string.IsNullOrEmpty(csrRequestResponse.ClientId) || + string.IsNullOrEmpty(csrRequestResponse.TenantId) || + string.IsNullOrEmpty(csrRequestResponse.ClientCredential) || + string.IsNullOrEmpty(csrRequestResponse.RegionalTokenUrl) || + csrRequestResponse.ExpiresIn <= 0 || + csrRequestResponse.RefreshIn <= 0) + { + return false; + } + + return true; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 9db03cc298..6fef52849b 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Net; +using System.Net.Http; using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Http; @@ -16,6 +17,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; + private const string CsrRequestPath = "/metadata/identity/issuecredential"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, @@ -29,7 +31,7 @@ public static async Task GetCsrMetadataAsync( requestContext.Logger); if (userAssignedIdQueryParam != null) { - queryParams += $"{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}"; + queryParams += $"&{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}"; } var headers = new Dictionary @@ -41,7 +43,6 @@ public static async Task GetCsrMetadataAsync( IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.CsrMetadataProbe); - // CSR metadata GET request HttpResponse response = null; try @@ -50,7 +51,7 @@ public static async Task GetCsrMetadataAsync( ImdsManagedIdentitySource.GetValidatedEndpoint(requestContext.Logger, CsrMetadataPath, queryParams), headers, body: null, - method: System.Net.Http.HttpMethod.Get, + method: HttpMethod.Get, logger: requestContext.Logger, doNotThrow: false, mtlsCertificate: null, @@ -194,8 +195,75 @@ public static AbstractManagedIdentity Create(RequestContext requestContext) internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.ImdsV2) { } + private async Task ExecuteCsrRequestAsync( + RequestContext requestContext, + string queryParams, + string pem) + { + var headers = new Dictionary + { + { "Metadata", "true" }, + { "x-ms-client-request-id", requestContext.CorrelationId.ToString() } + }; + + IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.Imds); + + HttpResponse response = null; + + try + { + response = await requestContext.ServiceBundle.HttpManager.SendRequestAsync( + ImdsManagedIdentitySource.GetValidatedEndpoint(requestContext.Logger, CsrRequestPath, queryParams), + headers, + body: new StringContent($"{{\"pem\":\"{pem}\"}}", System.Text.Encoding.UTF8, "application/json"), + method: HttpMethod.Post, + logger: requestContext.Logger, + doNotThrow: false, + mtlsCertificate: null, + validateServerCertificate: null, + cancellationToken: requestContext.UserCancellationToken, + retryPolicy: retryPolicy) + .ConfigureAwait(false); + } + catch (Exception ex) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCsrRequest failed.", + ex, + ManagedIdentitySource.ImdsV2, + (int)response.StatusCode); + } + + var csrRequestResponse = JsonHelper.DeserializeFromJson(response.Body); + if (!CsrRequestResponse.ValidateCsrRequestResponse(csrRequestResponse)) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because the CsrMetadata response is invalid. Status code: {response.StatusCode} Body: {response.Body}", + null, + ManagedIdentitySource.ImdsV2, + (int)response.StatusCode); + } + + return csrRequestResponse; + } + protected override ManagedIdentityRequest CreateRequest(string resource) { + var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); + var csrRequest = CsrRequest.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); + + var queryParams = $"cid={csrMetadata.Cuid}"; + if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) + { + queryParams += $"&uaid{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; + } + queryParams += $"&api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; + + var csrRequestResponse = ExecuteCsrRequestAsync(_requestContext, queryParams, csrRequest.Pem); + throw new NotImplementedException(); } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index e1aea27aa4..4bf53c7a42 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -130,5 +130,7 @@ public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIs Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } } + + // TODO: Create CSR generation unit tests } } From 6bc21644d3fca39d43d4e04b4640161b98c4c3aa Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 6 Aug 2025 16:50:16 -0400 Subject: [PATCH 02/29] Implemented CSR generator --- .../ManagedIdentity/CsrRequest.cs | 402 +++++++++++++++++- .../ManagedIdentityTests/ImdsV2Tests.cs | 21 +- 2 files changed, 420 insertions(+), 3 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs index eda86d9325..aa692a5b0f 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs @@ -2,6 +2,8 @@ // Licensed under the MIT License. using System; +using System.Security.Cryptography; +using System.Text; namespace Microsoft.Identity.Client.ManagedIdentity { @@ -34,8 +36,404 @@ public static CsrRequest Generate(string clientId, string tenantId, CuidInfo cui if (string.IsNullOrWhiteSpace(cuid.Vmssid)) throw new ArgumentException("cuid.Vmssid must not be null or empty.", nameof(cuid.Vmssid)); - // TODO: Implement the actual CSR generation logic. - return new CsrRequest("pem"); + string pemCsr = GeneratePkcs10Csr(clientId, tenantId, cuid); + return new CsrRequest(pemCsr); } + + /// + /// Generates a PKCS#10 Certificate Signing Request in PEM format. + /// + private static string GeneratePkcs10Csr(string clientId, string tenantId, CuidInfo cuid) + { + // Generate RSA key pair (2048-bit) + RSA rsa = CreateRsaKeyPair(); + + try + { + // Build the CSR components + byte[] certificationRequestInfo = BuildCertificationRequestInfo(clientId, tenantId, cuid, rsa); + byte[] signatureAlgorithm = BuildSignatureAlgorithmIdentifier(); + byte[] signature = SignCertificationRequestInfo(certificationRequestInfo, rsa); + + // Combine into final CSR structure + byte[] csrBytes = BuildFinalCsr(certificationRequestInfo, signatureAlgorithm, signature); + + // Convert to PEM format + return ConvertToPem(csrBytes); + } + finally + { + rsa?.Dispose(); + } + } + + /// + /// Creates a 2048-bit RSA key pair compatible with all target frameworks. + /// + private static RSA CreateRsaKeyPair() + { +#if NET462 || NET472 + var rsa = new RSACryptoServiceProvider(2048); + return rsa; +#else + var rsa = RSA.Create(); + rsa.KeySize = 2048; + return rsa; +#endif + } + + /// + /// Builds the CertificationRequestInfo structure containing subject, public key, and attributes. + /// + private static byte[] BuildCertificationRequestInfo(string clientId, string tenantId, CuidInfo cuid, RSA rsa) + { + var components = new System.Collections.Generic.List(); + + // Version (INTEGER 0) + components.Add(EncodeAsn1Integer(new byte[] { 0x00 })); + + // Subject: CN=, DC= + components.Add(BuildSubjectName(clientId, tenantId)); + + // SubjectPublicKeyInfo + components.Add(BuildSubjectPublicKeyInfo(rsa)); + + // Attributes (including CUID) + components.Add(BuildAttributes(cuid)); + + return EncodeAsn1Sequence(components.ToArray()); + } + + /// + /// Builds the X.500 Distinguished Name for the subject. + /// + private static byte[] BuildSubjectName(string clientId, string tenantId) + { + var rdnSequence = new System.Collections.Generic.List(); + + // CN= + byte[] cnOid = EncodeAsn1ObjectIdentifier(new int[] { 2, 5, 4, 3 }); // commonName OID + byte[] cnValue = EncodeAsn1Utf8String(clientId); + byte[] cnAttributeValue = EncodeAsn1Sequence(new[] { cnOid, cnValue }); + rdnSequence.Add(EncodeAsn1Set(new[] { cnAttributeValue })); + + // DC= + byte[] dcOid = EncodeAsn1ObjectIdentifier(new int[] { 0, 9, 2342, 19200300, 100, 1, 25 }); // domainComponent OID + byte[] dcValue = EncodeAsn1Utf8String(tenantId); + byte[] dcAttributeValue = EncodeAsn1Sequence(new[] { dcOid, dcValue }); + rdnSequence.Add(EncodeAsn1Set(new[] { dcAttributeValue })); + + return EncodeAsn1Sequence(rdnSequence.ToArray()); + } + + /// + /// Builds the SubjectPublicKeyInfo structure containing the RSA public key. + /// + private static byte[] BuildSubjectPublicKeyInfo(RSA rsa) + { + RSAParameters rsaParams = rsa.ExportParameters(false); + + // RSA Public Key structure + byte[] modulus = EncodeAsn1Integer(rsaParams.Modulus); + byte[] exponent = EncodeAsn1Integer(rsaParams.Exponent); + byte[] rsaPublicKey = EncodeAsn1Sequence(new[] { modulus, exponent }); + + // Algorithm identifier for RSA encryption + byte[] rsaOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 1 }); // RSA encryption OID + byte[] nullParam = EncodeAsn1Null(); + byte[] algorithmIdentifier = EncodeAsn1Sequence(new[] { rsaOid, nullParam }); + + // SubjectPublicKeyInfo + byte[] publicKeyBitString = EncodeAsn1BitString(rsaPublicKey); + return EncodeAsn1Sequence(new[] { algorithmIdentifier, publicKeyBitString }); + } + + /// + /// Builds the attributes section including the CUID extension. + /// + private static byte[] BuildAttributes(CuidInfo cuid) + { + var attributes = new System.Collections.Generic.List(); + + // CUID attribute (OID 1.2.840.113549.1.9.7) + byte[] cuidOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 9, 7 }); + string cuidValue = $"{cuid.Vmid}:{cuid.Vmssid}"; + byte[] cuidData = EncodeAsn1PrintableString(cuidValue); + byte[] cuidAttributeValue = EncodeAsn1Set(new[] { cuidData }); + byte[] cuidAttribute = EncodeAsn1Sequence(new[] { cuidOid, cuidAttributeValue }); + attributes.Add(cuidAttribute); + + return EncodeAsn1ContextSpecific(0, EncodeAsn1SequenceRaw(attributes.ToArray())); + } + + /// + /// Builds the signature algorithm identifier for SHA256withRSA. + /// + private static byte[] BuildSignatureAlgorithmIdentifier() + { + byte[] sha256WithRsaOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 11 }); // SHA256withRSA OID + byte[] nullParam = EncodeAsn1Null(); + return EncodeAsn1Sequence(new[] { sha256WithRsaOid, nullParam }); + } + + /// + /// Signs the CertificationRequestInfo with SHA256withRSA. + /// + private static byte[] SignCertificationRequestInfo(byte[] certificationRequestInfo, RSA rsa) + { +#if NET462 || NET472 + using (var sha256 = SHA256.Create()) + { + byte[] hash = sha256.ComputeHash(certificationRequestInfo); + var formatter = new RSAPKCS1SignatureFormatter(rsa); + formatter.SetHashAlgorithm("SHA256"); + return formatter.CreateSignature(hash); + } +#else + return rsa.SignData(certificationRequestInfo, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); +#endif + } + + /// + /// Combines all components into the final CSR structure. + /// + private static byte[] BuildFinalCsr(byte[] certificationRequestInfo, byte[] signatureAlgorithm, byte[] signature) + { + byte[] signatureBitString = EncodeAsn1BitString(signature); + return EncodeAsn1Sequence(new[] { certificationRequestInfo, signatureAlgorithm, signatureBitString }); + } + + /// + /// Converts DER-encoded bytes to PEM format. + /// + private static string ConvertToPem(byte[] derBytes) + { + string base64 = Convert.ToBase64String(derBytes); + var sb = new StringBuilder(); + sb.AppendLine("-----BEGIN CERTIFICATE REQUEST-----"); + + // Split into 64-character lines + for (int i = 0; i < base64.Length; i += 64) + { + int length = Math.Min(64, base64.Length - i); + sb.AppendLine(base64.Substring(i, length)); + } + + sb.AppendLine("-----END CERTIFICATE REQUEST-----"); + return sb.ToString(); + } + + #region ASN.1 Encoding Helpers + + /// + /// Encodes an ASN.1 SEQUENCE. + /// + private static byte[] EncodeAsn1Sequence(byte[][] components) + { + return EncodeAsn1Tag(0x30, ConcatenateByteArrays(components)); + } + + /// + /// Encodes an ASN.1 SEQUENCE without the outer tag (for raw concatenation). + /// + private static byte[] EncodeAsn1SequenceRaw(byte[][] components) + { + return ConcatenateByteArrays(components); + } + + /// + /// Encodes an ASN.1 SET. + /// + private static byte[] EncodeAsn1Set(byte[][] components) + { + return EncodeAsn1Tag(0x31, ConcatenateByteArrays(components)); + } + + /// + /// Encodes an ASN.1 INTEGER. + /// + private static byte[] EncodeAsn1Integer(byte[] value) + { + // Ensure positive integer (prepend 0x00 if high bit is set) + if (value != null && value.Length > 0 && (value[0] & 0x80) != 0) + { + byte[] paddedValue = new byte[value.Length + 1]; + paddedValue[0] = 0x00; + Array.Copy(value, 0, paddedValue, 1, value.Length); + value = paddedValue; + } + return EncodeAsn1Tag(0x02, value ?? new byte[0]); + } + + /// + /// Encodes an ASN.1 INTEGER from an integer value. + /// + private static byte[] EncodeAsn1Integer(int value) + { + if (value == 0) + return EncodeAsn1Tag(0x02, new byte[] { 0x00 }); + + var bytes = new System.Collections.Generic.List(); + int temp = value; + while (temp > 0) + { + bytes.Insert(0, (byte)(temp & 0xFF)); + temp >>= 8; + } + + return EncodeAsn1Integer(bytes.ToArray()); + } + + /// + /// Encodes an ASN.1 BIT STRING. + /// + private static byte[] EncodeAsn1BitString(byte[] value) + { + byte[] bitStringValue = new byte[value.Length + 1]; + bitStringValue[0] = 0x00; // No unused bits + Array.Copy(value, 0, bitStringValue, 1, value.Length); + return EncodeAsn1Tag(0x03, bitStringValue); + } + + /// + /// Encodes an ASN.1 UTF8String. + /// + private static byte[] EncodeAsn1Utf8String(string value) + { + byte[] utf8Bytes = Encoding.UTF8.GetBytes(value); + return EncodeAsn1Tag(0x0C, utf8Bytes); + } + + /// + /// Encodes an ASN.1 PrintableString. + /// + private static byte[] EncodeAsn1PrintableString(string value) + { + byte[] asciiBytes = Encoding.ASCII.GetBytes(value); + return EncodeAsn1Tag(0x13, asciiBytes); + } + + /// + /// Encodes an ASN.1 NULL. + /// + private static byte[] EncodeAsn1Null() + { + return new byte[] { 0x05, 0x00 }; + } + + /// + /// Encodes an ASN.1 OBJECT IDENTIFIER. + /// + private static byte[] EncodeAsn1ObjectIdentifier(int[] oid) + { + if (oid == null || oid.Length < 2) + throw new ArgumentException("OID must have at least 2 components"); + + var bytes = new System.Collections.Generic.List(); + + // First two components are encoded as (first * 40 + second) + bytes.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); + + // Remaining components + for (int i = 2; i < oid.Length; i++) + { + bytes.AddRange(EncodeOidComponent(oid[i])); + } + + return EncodeAsn1Tag(0x06, bytes.ToArray()); + } + + /// + /// Encodes an ASN.1 context-specific tag. + /// + private static byte[] EncodeAsn1ContextSpecific(int tagNumber, byte[] content) + { + byte tag = (byte)(0xA0 | tagNumber); // Context-specific, constructed + return EncodeAsn1Tag(tag, content); + } + + /// + /// Encodes an ASN.1 tag with length and content. + /// + private static byte[] EncodeAsn1Tag(byte tag, byte[] content) + { + byte[] lengthBytes = EncodeAsn1Length(content.Length); + byte[] result = new byte[1 + lengthBytes.Length + content.Length]; + result[0] = tag; + Array.Copy(lengthBytes, 0, result, 1, lengthBytes.Length); + Array.Copy(content, 0, result, 1 + lengthBytes.Length, content.Length); + return result; + } + + /// + /// Encodes ASN.1 length field. + /// + private static byte[] EncodeAsn1Length(int length) + { + if (length < 0x80) + { + return new byte[] { (byte)length }; + } + + var lengthBytes = new System.Collections.Generic.List(); + int temp = length; + while (temp > 0) + { + lengthBytes.Insert(0, (byte)(temp & 0xFF)); + temp >>= 8; + } + + byte[] result = new byte[lengthBytes.Count + 1]; + result[0] = (byte)(0x80 | lengthBytes.Count); + lengthBytes.CopyTo(result, 1); + return result; + } + + /// + /// Encodes a single OID component using variable-length encoding. + /// + private static byte[] EncodeOidComponent(int value) + { + if (value == 0) + return new byte[] { 0x00 }; + + var bytes = new System.Collections.Generic.List(); + int temp = value; + + bytes.Insert(0, (byte)(temp & 0x7F)); + temp >>= 7; + + while (temp > 0) + { + bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); + temp >>= 7; + } + + return bytes.ToArray(); + } + + /// + /// Concatenates multiple byte arrays. + /// + private static byte[] ConcatenateByteArrays(byte[][] arrays) + { + int totalLength = 0; + foreach (byte[] array in arrays) + { + totalLength += array.Length; + } + + byte[] result = new byte[totalLength]; + int offset = 0; + foreach (byte[] array in arrays) + { + Array.Copy(array, 0, result, offset, array.Length); + offset += array.Length; + } + + return result; + } + + #endregion } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 4bf53c7a42..830082c140 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.Net; using System.Threading.Tasks; using Microsoft.Identity.Client; @@ -131,6 +132,24 @@ public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIs } } - // TODO: Create CSR generation unit tests + [TestMethod] + public void TestCsrGeneration() + { + var cuid = new CuidInfo + { + Vmid = "test-vm-id-12345", + Vmssid = "test-vmss-id-67890" + }; + + string clientId = "12345678-1234-1234-1234-123456789012"; + string tenantId = "87654321-4321-4321-4321-210987654321"; + + // Generate CSR + var csrRequest = CsrRequest.Generate(clientId, tenantId, cuid); + + // Output the generated CSR for analysis + System.Console.WriteLine("Generated CSR:"); + System.Console.WriteLine(csrRequest.Pem); + } } } From 762ccdfbcce6ec145375354b9d588742b3b3ba2a Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 6 Aug 2025 17:00:50 -0400 Subject: [PATCH 03/29] first pass at improved unit tests --- .../ManagedIdentityTests/ImdsV2Tests.cs | 486 ++++++++++++++++++ 1 file changed, 486 insertions(+) diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 830082c140..f5c8c2ed0a 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -150,6 +150,492 @@ public void TestCsrGeneration() // Output the generated CSR for analysis System.Console.WriteLine("Generated CSR:"); System.Console.WriteLine(csrRequest.Pem); + + // Validate the CSR contents + ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); + } + + [TestMethod] + public void TestCsrGeneration_InvalidClientId() + { + var cuid = new CuidInfo + { + Vmid = "test-vm-id-12345", + Vmssid = "test-vmss-id-67890" + }; + + string tenantId = "87654321-4321-4321-4321-210987654321"; + + // Test with null client ID + Assert.ThrowsException(() => + CsrRequest.Generate(null, tenantId, cuid)); + + // Test with empty client ID + Assert.ThrowsException(() => + CsrRequest.Generate("", tenantId, cuid)); + + // Test with whitespace client ID + Assert.ThrowsException(() => + CsrRequest.Generate(" ", tenantId, cuid)); + } + + [TestMethod] + public void TestCsrGeneration_InvalidTenantId() + { + var cuid = new CuidInfo + { + Vmid = "test-vm-id-12345", + Vmssid = "test-vmss-id-67890" + }; + + string clientId = "12345678-1234-1234-1234-123456789012"; + + // Test with null tenant ID + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, null, cuid)); + + // Test with empty tenant ID + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, "", cuid)); + + // Test with whitespace tenant ID + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, " ", cuid)); + } + + [TestMethod] + public void TestCsrGeneration_InvalidCuid() + { + string clientId = "12345678-1234-1234-1234-123456789012"; + string tenantId = "87654321-4321-4321-4321-210987654321"; + + // Test with null CUID + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, tenantId, null)); + + // Test with null VMID + var cuidWithNullVmid = new CuidInfo + { + Vmid = null, + Vmssid = "test-vmss-id-67890" + }; + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, tenantId, cuidWithNullVmid)); + + // Test with empty VMID + var cuidWithEmptyVmid = new CuidInfo + { + Vmid = "", + Vmssid = "test-vmss-id-67890" + }; + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, tenantId, cuidWithEmptyVmid)); + + // Test with null VMSSID + var cuidWithNullVmssid = new CuidInfo + { + Vmid = "test-vm-id-12345", + Vmssid = null + }; + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, tenantId, cuidWithNullVmssid)); + + // Test with empty VMSSID + var cuidWithEmptyVmssid = new CuidInfo + { + Vmid = "test-vm-id-12345", + Vmssid = "" + }; + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, tenantId, cuidWithEmptyVmssid)); + } + + [TestMethod] + public void TestCsrGeneration_MalformedPem() + { + // Test parsing malformed PEM with invalid Base64 characters + string malformedPem = "-----BEGIN CERTIFICATE REQUEST-----\nInvalid@#$%Base64Content!\n-----END CERTIFICATE REQUEST-----"; + + Assert.ThrowsException(() => + ParseCsrFromPem(malformedPem)); + + // Test with wrong headers + string wrongHeaders = "-----BEGIN CERTIFICATE-----\nTUlJQzNqQ0NBY1lDQVFBd1pURT0K\n-----END CERTIFICATE-----"; + + Assert.ThrowsException(() => + ParseCsrFromPem(wrongHeaders)); + } + + #region CSR Validation Helper Methods + + /// + /// Validates the content of a CSR PEM string against expected values. + /// + private void ValidateCsrContent(string pemCsr, string expectedClientId, string expectedTenantId, CuidInfo expectedCuid) + { + // Parse the CSR from PEM format + var csrData = ParseCsrFromPem(pemCsr); + + // Parse the PKCS#10 structure + var csrInfo = ParsePkcs10Structure(csrData); + + // Validate subject name + ValidateSubjectName(csrInfo.Subject, expectedClientId, expectedTenantId); + + // Validate public key + ValidatePublicKey(csrInfo.PublicKey); + + // Validate CUID attribute + ValidateCuidAttribute(csrInfo.Attributes, expectedCuid); + + // Validate signature algorithm + ValidateSignatureAlgorithm(csrInfo.SignatureAlgorithm); + } + + /// + /// Parses a PEM-formatted CSR and returns the DER bytes. + /// + private byte[] ParseCsrFromPem(string pemCsr) + { + if (string.IsNullOrWhiteSpace(pemCsr)) + throw new ArgumentException("PEM CSR cannot be null or empty"); + + const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; + const string endMarker = "-----END CERTIFICATE REQUEST-----"; + + if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) + throw new ArgumentException("Invalid PEM format - missing CSR headers"); + + int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; + int endIndex = pemCsr.IndexOf(endMarker); + + if (beginIndex >= endIndex) + throw new ArgumentException("Invalid PEM format - malformed headers"); + + string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) + .Replace("\r", "").Replace("\n", "").Replace(" ", ""); + + try + { + return Convert.FromBase64String(base64Content); + } + catch (FormatException) + { + throw new FormatException("Invalid Base64 content in PEM CSR"); + } + } + + /// + /// Represents parsed PKCS#10 CSR information. + /// + private class CsrInfo + { + public byte[] Subject { get; set; } + public byte[] PublicKey { get; set; } + public byte[] Attributes { get; set; } + public byte[] SignatureAlgorithm { get; set; } + } + + /// + /// Parses the PKCS#10 ASN.1 structure and extracts key components. + /// + private CsrInfo ParsePkcs10Structure(byte[] derBytes) + { + int offset = 0; + + // Parse outer SEQUENCE (CertificationRequest) + var outerSequence = ParseAsn1Tag(derBytes, ref offset, 0x30); + + // Reset offset to parse the CertificationRequestInfo within the outer sequence + int infoOffset = 0; + var certRequestInfo = ParseAsn1Tag(outerSequence, ref infoOffset, 0x30); + + // Parse version (should be 0) + int versionOffset = 0; + var version = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x02); + if (version.Length != 1 || version[0] != 0x00) + throw new ArgumentException("Invalid CSR version"); + + // Parse subject + var subject = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); + + // Parse SubjectPublicKeyInfo + var publicKey = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); + + // Parse attributes (context-specific [0]) + var attributes = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0xA0); + + return new CsrInfo + { + Subject = subject, + PublicKey = publicKey, + Attributes = attributes, + SignatureAlgorithm = new byte[0] // Simplified for this test + }; + } + + /// + /// Parses an ASN.1 tag and returns its content. + /// + private byte[] ParseAsn1Tag(byte[] data, ref int offset, byte expectedTag) + { + if (offset >= data.Length) + throw new ArgumentException("Unexpected end of data"); + + // Check tag (if expectedTag is -1, accept any tag) + if (expectedTag != 255 && data[offset] != expectedTag) + throw new ArgumentException($"Expected tag 0x{expectedTag:X2}, got 0x{data[offset]:X2}"); + + offset++; + + // Parse length + int length = ParseAsn1Length(data, ref offset); + + // Extract content + if (offset + length > data.Length) + throw new ArgumentException("Invalid ASN.1 length"); + + byte[] content = new byte[length]; + Array.Copy(data, offset, content, 0, length); + offset += length; + + return content; + } + + /// + /// Parses ASN.1 length encoding. + /// + private int ParseAsn1Length(byte[] data, ref int offset) + { + if (offset >= data.Length) + throw new ArgumentException("Unexpected end of data in length"); + + byte firstByte = data[offset++]; + + // Short form + if ((firstByte & 0x80) == 0) + return firstByte; + + // Long form + int lengthBytes = firstByte & 0x7F; + if (lengthBytes == 0) + throw new ArgumentException("Indefinite length not supported"); + + if (offset + lengthBytes > data.Length) + throw new ArgumentException("Invalid length encoding"); + + int length = 0; + for (int i = 0; i < lengthBytes; i++) + { + length = (length << 8) | data[offset++]; + } + + return length; + } + + /// + /// Validates the subject name contains the expected client ID and tenant ID. + /// + private void ValidateSubjectName(byte[] subjectBytes, string expectedClientId, string expectedTenantId) + { + // Subject is already a SEQUENCE of RDNs + int offset = 0; + bool foundClientId = false; + bool foundTenantId = false; + + // Parse each RDN (Relative Distinguished Name) directly from subjectBytes + while (offset < subjectBytes.Length) + { + var rdnSet = ParseAsn1Tag(subjectBytes, ref offset, 0x31); // SET + + int rdnOffset = 0; + var rdnSequence = ParseAsn1Tag(rdnSet, ref rdnOffset, 0x30); // SEQUENCE + + // Parse OID and value + int attrOffset = 0; + var oid = ParseAsn1Tag(rdnSequence, ref attrOffset, 0x06); // OID + var value = ParseAsn1Tag(rdnSequence, ref attrOffset, 255); // Any string type + + string stringValue = System.Text.Encoding.UTF8.GetString(value); + + // Check for CN (commonName) OID: 2.5.4.3 + if (IsOid(oid, new int[] { 2, 5, 4, 3 })) + { + Assert.AreEqual(expectedClientId, stringValue, "Client ID in subject CN does not match"); + foundClientId = true; + } + // Check for DC (domainComponent) OID: 0.9.2342.19200300.100.1.25 + else if (IsOid(oid, new int[] { 0, 9, 2342, 19200300, 100, 1, 25 })) + { + Assert.AreEqual(expectedTenantId, stringValue, "Tenant ID in subject DC does not match"); + foundTenantId = true; + } + } + + Assert.IsTrue(foundClientId, "Client ID (CN) not found in subject"); + Assert.IsTrue(foundTenantId, "Tenant ID (DC) not found in subject"); + } + + /// + /// Validates the public key is a valid RSA key. + /// + private void ValidatePublicKey(byte[] publicKeyBytes) + { + // publicKeyBytes is already the SubjectPublicKeyInfo SEQUENCE content + int offset = 0; + + // Parse algorithm identifier + var algorithmId = ParseAsn1Tag(publicKeyBytes, ref offset, 0x30); + + // Parse public key bit string + var publicKeyBitString = ParseAsn1Tag(publicKeyBytes, ref offset, 0x03); + + // Validate algorithm is RSA (1.2.840.113549.1.1.1) + int algOffset = 0; + var algorithmOid = ParseAsn1Tag(algorithmId, ref algOffset, 0x06); + Assert.IsTrue(IsOid(algorithmOid, new int[] { 1, 2, 840, 113549, 1, 1, 1 }), + "Public key algorithm is not RSA"); + + // Skip the unused bits byte in bit string + if (publicKeyBitString.Length < 2 || publicKeyBitString[0] != 0x00) + throw new ArgumentException("Invalid public key bit string"); + + // Parse RSA public key (skip unused bits byte) + byte[] rsaKeyBytes = new byte[publicKeyBitString.Length - 1]; + Array.Copy(publicKeyBitString, 1, rsaKeyBytes, 0, rsaKeyBytes.Length); + + int rsaOffset = 0; + var rsaSequence = ParseAsn1Tag(rsaKeyBytes, ref rsaOffset, 0x30); + + rsaOffset = 0; + var modulus = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); + var exponent = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); + + // Validate key size (should be 2048 bits = 256 bytes, plus potential leading zero) + Assert.IsTrue(modulus.Length >= 256 && modulus.Length <= 257, + $"RSA modulus should be 2048 bits, got {modulus.Length * 8} bits"); + + // Validate exponent (commonly 65537 = 0x010001) + Assert.IsTrue(exponent.Length >= 1 && exponent.Length <= 4, "RSA exponent has invalid length"); + } + + /// + /// Validates the CUID attribute contains the expected VM and VMSS IDs. + /// + private void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) + { + // Attributes is a SET of attributes + // We expect one attribute with challengePassword OID (1.2.840.113549.1.9.7) + + int offset = 0; + bool foundCuid = false; + + // Parse each attribute in the SET + while (offset < attributesBytes.Length) + { + var attributeSequence = ParseAsn1Tag(attributesBytes, ref offset, 0x30); + + int attrOffset = 0; + var oid = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x06); + var valueSet = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x31); // SET of values + + // Check for challengePassword OID: 1.2.840.113549.1.9.7 + if (IsOid(oid, new int[] { 1, 2, 840, 113549, 1, 9, 7 })) + { + // Parse the value from the SET (should be one value) + int valueOffset = 0; + var value = ParseAsn1Tag(valueSet, ref valueOffset, 255); // Any string type + + string cuidValue = System.Text.Encoding.ASCII.GetString(value); + string expectedCuidValue = $"{expectedCuid.Vmid}:{expectedCuid.Vmssid}"; + + Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute value does not match expected"); + foundCuid = true; + break; + } + } + + Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); + } + + /// + /// Validates the signature algorithm is SHA256withRSA. + /// + private void ValidateSignatureAlgorithm(byte[] signatureAlgBytes) + { + // For this test, we'll just verify that signature algorithm exists + // Full validation would require parsing the outer CSR structure + // which is more complex for this unit test scenario + Assert.IsNotNull(signatureAlgBytes, "Signature algorithm should be present"); } + + /// + /// Checks if the given OID bytes match the expected OID components. + /// + private bool IsOid(byte[] oidBytes, int[] expectedOid) + { + if (expectedOid.Length < 2) + return false; + + var expectedBytes = EncodeOid(expectedOid); + + if (oidBytes.Length != expectedBytes.Length) + return false; + + for (int i = 0; i < oidBytes.Length; i++) + { + if (oidBytes[i] != expectedBytes[i]) + return false; + } + + return true; + } + + /// + /// Encodes an OID from integer components to bytes (simplified version). + /// + private byte[] EncodeOid(int[] oid) + { + if (oid.Length < 2) + throw new ArgumentException("OID must have at least 2 components"); + + var result = new System.Collections.Generic.List(); + + // First two components are encoded as (first * 40 + second) + result.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); + + // Remaining components + for (int i = 2; i < oid.Length; i++) + { + result.AddRange(EncodeOidComponent(oid[i])); + } + + return result.ToArray(); + } + + /// + /// Encodes a single OID component using variable-length encoding. + /// + private byte[] EncodeOidComponent(int value) + { + if (value == 0) + return new byte[] { 0x00 }; + + var bytes = new System.Collections.Generic.List(); + int temp = value; + + bytes.Insert(0, (byte)(temp & 0x7F)); + temp >>= 7; + + while (temp > 0) + { + bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); + temp >>= 7; + } + + return bytes.ToArray(); + } + + #endregion } } From 4ea6c09af1ef8dc267cb6547522f175a838775d5 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 6 Aug 2025 18:06:47 -0400 Subject: [PATCH 04/29] Finished improving unit tests --- .../ImdsV2ManagedIdentitySource.cs | 2 +- .../ManagedIdentityTests/CsrValidator.cs | 419 +++++++++++++++ .../ManagedIdentityTests/ImdsV2Tests.cs | 485 ++---------------- 3 files changed, 455 insertions(+), 451 deletions(-) create mode 100644 tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 6fef52849b..08bbc9caf4 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -91,7 +91,7 @@ public static async Task GetCsrMetadataAsync( } } - if (!probeMode && !ValidateCsrMetadataResponse(response, requestContext.Logger, probeMode)) + if (!ValidateCsrMetadataResponse(response, requestContext.Logger, probeMode)) { return null; } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs new file mode 100644 index 0000000000..b4a4e99dd5 --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs @@ -0,0 +1,419 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests +{ + /// + /// Test helper to expose CsrValidator methods for testing malformed PEM. + /// + internal static class TestCsrValidator + { + public static byte[] ParseCsrFromPem(string pemCsr) + { + if (string.IsNullOrWhiteSpace(pemCsr)) + throw new ArgumentException("PEM CSR cannot be null or empty"); + + const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; + const string endMarker = "-----END CERTIFICATE REQUEST-----"; + + if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) + throw new ArgumentException("Invalid PEM format - missing CSR headers"); + + int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; + int endIndex = pemCsr.IndexOf(endMarker); + + if (beginIndex >= endIndex) + throw new ArgumentException("Invalid PEM format - malformed headers"); + + string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) + .Replace("\r", "").Replace("\n", "").Replace(" ", ""); + + try + { + return Convert.FromBase64String(base64Content); + } + catch (FormatException) + { + throw new FormatException("Invalid Base64 content in PEM CSR"); + } + } + } + + /// + /// Helper class for validating Certificate Signing Request (CSR) content and structure. + /// + internal static class CsrValidator + { + /// + /// Validates the content of a CSR PEM string against expected values. + /// + public static void ValidateCsrContent(string pemCsr, string expectedClientId, string expectedTenantId, CuidInfo expectedCuid) + { + // Parse the CSR from PEM format + var csrData = ParseCsrFromPem(pemCsr); + + // Parse the PKCS#10 structure + var csrInfo = ParsePkcs10Structure(csrData); + + // Validate subject name + ValidateSubjectName(csrInfo.Subject, expectedClientId, expectedTenantId); + + // Validate public key + ValidatePublicKey(csrInfo.PublicKey); + + // Validate CUID attribute + ValidateCuidAttribute(csrInfo.Attributes, expectedCuid); + + // Validate signature algorithm + ValidateSignatureAlgorithm(csrInfo.SignatureAlgorithm); + } + + /// + /// Parses a PEM-formatted CSR and returns the DER bytes. + /// + private static byte[] ParseCsrFromPem(string pemCsr) + { + if (string.IsNullOrWhiteSpace(pemCsr)) + throw new ArgumentException("PEM CSR cannot be null or empty"); + + const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; + const string endMarker = "-----END CERTIFICATE REQUEST-----"; + + if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) + throw new ArgumentException("Invalid PEM format - missing CSR headers"); + + int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; + int endIndex = pemCsr.IndexOf(endMarker); + + if (beginIndex >= endIndex) + throw new ArgumentException("Invalid PEM format - malformed headers"); + + string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) + .Replace("\r", "").Replace("\n", "").Replace(" ", ""); + + try + { + return Convert.FromBase64String(base64Content); + } + catch (FormatException) + { + throw new FormatException("Invalid Base64 content in PEM CSR"); + } + } + + /// + /// Represents parsed PKCS#10 CSR information. + /// + private class CsrInfo + { + public byte[] Subject { get; set; } + public byte[] PublicKey { get; set; } + public byte[] Attributes { get; set; } + public byte[] SignatureAlgorithm { get; set; } + } + + /// + /// Parses the PKCS#10 ASN.1 structure and extracts key components. + /// + private static CsrInfo ParsePkcs10Structure(byte[] derBytes) + { + int offset = 0; + + // Parse outer SEQUENCE (CertificationRequest) + var outerSequence = ParseAsn1Tag(derBytes, ref offset, 0x30); + + // Reset offset to parse the CertificationRequestInfo within the outer sequence + int infoOffset = 0; + var certRequestInfo = ParseAsn1Tag(outerSequence, ref infoOffset, 0x30); + + // Parse version (should be 0) + int versionOffset = 0; + var version = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x02); + if (version.Length != 1 || version[0] != 0x00) + throw new ArgumentException("Invalid CSR version"); + + // Parse subject + var subject = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); + + // Parse SubjectPublicKeyInfo + var publicKey = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); + + // Parse attributes (context-specific [0]) + var attributes = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0xA0); + + return new CsrInfo + { + Subject = subject, + PublicKey = publicKey, + Attributes = attributes, + SignatureAlgorithm = new byte[0] // Simplified for this test + }; + } + + /// + /// Parses an ASN.1 tag and returns its content. + /// + private static byte[] ParseAsn1Tag(byte[] data, ref int offset, byte expectedTag) + { + if (offset >= data.Length) + throw new ArgumentException("Unexpected end of data"); + + // Check tag (if expectedTag is -1, accept any tag) + if (expectedTag != 255 && data[offset] != expectedTag) + throw new ArgumentException($"Expected tag 0x{expectedTag:X2}, got 0x{data[offset]:X2}"); + + offset++; + + // Parse length + int length = ParseAsn1Length(data, ref offset); + + // Extract content + if (offset + length > data.Length) + throw new ArgumentException("Invalid ASN.1 length"); + + byte[] content = new byte[length]; + Array.Copy(data, offset, content, 0, length); + offset += length; + + return content; + } + + /// + /// Parses ASN.1 length encoding. + /// + private static int ParseAsn1Length(byte[] data, ref int offset) + { + if (offset >= data.Length) + throw new ArgumentException("Unexpected end of data in length"); + + byte firstByte = data[offset++]; + + // Short form + if ((firstByte & 0x80) == 0) + return firstByte; + + // Long form + int lengthBytes = firstByte & 0x7F; + if (lengthBytes == 0) + throw new ArgumentException("Indefinite length not supported"); + + if (offset + lengthBytes > data.Length) + throw new ArgumentException("Invalid length encoding"); + + int length = 0; + for (int i = 0; i < lengthBytes; i++) + { + length = (length << 8) | data[offset++]; + } + + return length; + } + + /// + /// Validates the subject name contains the expected client ID and tenant ID. + /// + private static void ValidateSubjectName(byte[] subjectBytes, string expectedClientId, string expectedTenantId) + { + // Subject is already a SEQUENCE of RDNs + int offset = 0; + bool foundClientId = false; + bool foundTenantId = false; + + // Parse each RDN (Relative Distinguished Name) directly from subjectBytes + while (offset < subjectBytes.Length) + { + var rdnSet = ParseAsn1Tag(subjectBytes, ref offset, 0x31); // SET + + int rdnOffset = 0; + var rdnSequence = ParseAsn1Tag(rdnSet, ref rdnOffset, 0x30); // SEQUENCE + + // Parse OID and value + int attrOffset = 0; + var oid = ParseAsn1Tag(rdnSequence, ref attrOffset, 0x06); // OID + var value = ParseAsn1Tag(rdnSequence, ref attrOffset, 255); // Any string type + + string stringValue = System.Text.Encoding.UTF8.GetString(value); + + // Check for CN (commonName) OID: 2.5.4.3 + if (IsOid(oid, new int[] { 2, 5, 4, 3 })) + { + Assert.AreEqual(expectedClientId, stringValue, "Client ID in subject CN does not match"); + foundClientId = true; + } + // Check for DC (domainComponent) OID: 0.9.2342.19200300.100.1.25 + else if (IsOid(oid, new int[] { 0, 9, 2342, 19200300, 100, 1, 25 })) + { + Assert.AreEqual(expectedTenantId, stringValue, "Tenant ID in subject DC does not match"); + foundTenantId = true; + } + } + + Assert.IsTrue(foundClientId, "Client ID (CN) not found in subject"); + Assert.IsTrue(foundTenantId, "Tenant ID (DC) not found in subject"); + } + + /// + /// Validates the public key is a valid RSA key. + /// + private static void ValidatePublicKey(byte[] publicKeyBytes) + { + // publicKeyBytes is already the SubjectPublicKeyInfo SEQUENCE content + int offset = 0; + + // Parse algorithm identifier + var algorithmId = ParseAsn1Tag(publicKeyBytes, ref offset, 0x30); + + // Parse public key bit string + var publicKeyBitString = ParseAsn1Tag(publicKeyBytes, ref offset, 0x03); + + // Validate algorithm is RSA (1.2.840.113549.1.1.1) + int algOffset = 0; + var algorithmOid = ParseAsn1Tag(algorithmId, ref algOffset, 0x06); + Assert.IsTrue(IsOid(algorithmOid, new int[] { 1, 2, 840, 113549, 1, 1, 1 }), + "Public key algorithm is not RSA"); + + // Skip the unused bits byte in bit string + if (publicKeyBitString.Length < 2 || publicKeyBitString[0] != 0x00) + throw new ArgumentException("Invalid public key bit string"); + + // Parse RSA public key (skip unused bits byte) + byte[] rsaKeyBytes = new byte[publicKeyBitString.Length - 1]; + Array.Copy(publicKeyBitString, 1, rsaKeyBytes, 0, rsaKeyBytes.Length); + + int rsaOffset = 0; + var rsaSequence = ParseAsn1Tag(rsaKeyBytes, ref rsaOffset, 0x30); + + rsaOffset = 0; + var modulus = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); + var exponent = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); + + // Validate key size (should be 2048 bits = 256 bytes, plus potential leading zero) + Assert.IsTrue(modulus.Length >= 256 && modulus.Length <= 257, + $"RSA modulus should be 2048 bits, got {modulus.Length * 8} bits"); + + // Validate exponent (commonly 65537 = 0x010001) + Assert.IsTrue(exponent.Length >= 1 && exponent.Length <= 4, "RSA exponent has invalid length"); + } + + /// + /// Validates the CUID attribute contains the expected VM and VMSS IDs. + /// + private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) + { + // Attributes is a SET of attributes + // We expect one attribute with challengePassword OID (1.2.840.113549.1.9.7) + + int offset = 0; + bool foundCuid = false; + + // Parse each attribute in the SET + while (offset < attributesBytes.Length) + { + var attributeSequence = ParseAsn1Tag(attributesBytes, ref offset, 0x30); + + int attrOffset = 0; + var oid = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x06); + var valueSet = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x31); // SET of values + + // Check for challengePassword OID: 1.2.840.113549.1.9.7 + if (IsOid(oid, new int[] { 1, 2, 840, 113549, 1, 9, 7 })) + { + // Parse the value from the SET (should be one value) + int valueOffset = 0; + var value = ParseAsn1Tag(valueSet, ref valueOffset, 255); // Any string type + + string cuidValue = System.Text.Encoding.ASCII.GetString(value); + string expectedCuidValue = $"{expectedCuid.Vmid}:{expectedCuid.Vmssid}"; + + Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute value does not match expected"); + foundCuid = true; + break; + } + } + + Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); + } + + /// + /// Validates the signature algorithm is SHA256withRSA. + /// + private static void ValidateSignatureAlgorithm(byte[] signatureAlgBytes) + { + // For this test, we'll just verify that signature algorithm exists + // Full validation would require parsing the outer CSR structure + // which is more complex for this unit test scenario + Assert.IsNotNull(signatureAlgBytes, "Signature algorithm should be present"); + } + + /// + /// Checks if the given OID bytes match the expected OID components. + /// + private static bool IsOid(byte[] oidBytes, int[] expectedOid) + { + if (expectedOid.Length < 2) + return false; + + var expectedBytes = EncodeOid(expectedOid); + + if (oidBytes.Length != expectedBytes.Length) + return false; + + for (int i = 0; i < oidBytes.Length; i++) + { + if (oidBytes[i] != expectedBytes[i]) + return false; + } + + return true; + } + + /// + /// Encodes an OID from integer components to bytes (simplified version). + /// + private static byte[] EncodeOid(int[] oid) + { + if (oid.Length < 2) + throw new ArgumentException("OID must have at least 2 components"); + + var result = new System.Collections.Generic.List(); + + // First two components are encoded as (first * 40 + second) + result.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); + + // Remaining components + for (int i = 2; i < oid.Length; i++) + { + result.AddRange(EncodeOidComponent(oid[i])); + } + + return result.ToArray(); + } + + /// + /// Encodes a single OID component using variable-length encoding. + /// + private static byte[] EncodeOidComponent(int value) + { + if (value == 0) + return new byte[] { 0x00 }; + + var bytes = new System.Collections.Generic.List(); + int temp = value; + + bytes.Insert(0, (byte)(temp & 0x7F)); + temp >>= 7; + + while (temp > 0) + { + bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); + temp >>= 7; + } + + return bytes.ToArray(); + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index f5c8c2ed0a..b9ee9b2268 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -151,12 +151,18 @@ public void TestCsrGeneration() System.Console.WriteLine("Generated CSR:"); System.Console.WriteLine(csrRequest.Pem); - // Validate the CSR contents - ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); + // Validate the CSR contents using the helper + CsrValidator.ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); } - [TestMethod] - public void TestCsrGeneration_InvalidClientId() + [DataTestMethod] + [DataRow(null, "87654321-4321-4321-4321-210987654321", DisplayName = "Null ClientId")] + [DataRow("", "87654321-4321-4321-4321-210987654321", DisplayName = "Empty ClientId")] + [DataRow(" ", "87654321-4321-4321-4321-210987654321", DisplayName = "Whitespace ClientId")] + [DataRow("12345678-1234-1234-1234-123456789012", null, DisplayName = "Null TenantId")] + [DataRow("12345678-1234-1234-1234-123456789012", "", DisplayName = "Empty TenantId")] + [DataRow("12345678-1234-1234-1234-123456789012", " ", DisplayName = "Whitespace TenantId")] + public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId) { var cuid = new CuidInfo { @@ -164,47 +170,12 @@ public void TestCsrGeneration_InvalidClientId() Vmssid = "test-vmss-id-67890" }; - string tenantId = "87654321-4321-4321-4321-210987654321"; - - // Test with null client ID - Assert.ThrowsException(() => - CsrRequest.Generate(null, tenantId, cuid)); - - // Test with empty client ID - Assert.ThrowsException(() => - CsrRequest.Generate("", tenantId, cuid)); - - // Test with whitespace client ID Assert.ThrowsException(() => - CsrRequest.Generate(" ", tenantId, cuid)); + CsrRequest.Generate(clientId, tenantId, cuid)); } [TestMethod] - public void TestCsrGeneration_InvalidTenantId() - { - var cuid = new CuidInfo - { - Vmid = "test-vm-id-12345", - Vmssid = "test-vmss-id-67890" - }; - - string clientId = "12345678-1234-1234-1234-123456789012"; - - // Test with null tenant ID - Assert.ThrowsException(() => - CsrRequest.Generate(clientId, null, cuid)); - - // Test with empty tenant ID - Assert.ThrowsException(() => - CsrRequest.Generate(clientId, "", cuid)); - - // Test with whitespace tenant ID - Assert.ThrowsException(() => - CsrRequest.Generate(clientId, " ", cuid)); - } - - [TestMethod] - public void TestCsrGeneration_InvalidCuid() + public void TestCsrGeneration_NullCuid() { string clientId = "12345678-1234-1234-1234-123456789012"; string tenantId = "87654321-4321-4321-4321-210987654321"; @@ -212,430 +183,44 @@ public void TestCsrGeneration_InvalidCuid() // Test with null CUID Assert.ThrowsException(() => CsrRequest.Generate(clientId, tenantId, null)); + } - // Test with null VMID - var cuidWithNullVmid = new CuidInfo - { - Vmid = null, - Vmssid = "test-vmss-id-67890" - }; - Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, cuidWithNullVmid)); - - // Test with empty VMID - var cuidWithEmptyVmid = new CuidInfo - { - Vmid = "", - Vmssid = "test-vmss-id-67890" - }; - Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, cuidWithEmptyVmid)); + [DataTestMethod] + [DataRow(null, "test-vmss-id-67890", DisplayName = "Null VMID")] + [DataRow("", "test-vmss-id-67890", DisplayName = "Empty VMID")] + [DataRow("test-vm-id-12345", null, DisplayName = "Null VMSSID")] + [DataRow("test-vm-id-12345", "", DisplayName = "Empty VMSSID")] + public void TestCsrGeneration_InvalidCuidProperties(string vmid, string vmssid) + { + string clientId = "12345678-1234-1234-1234-123456789012"; + string tenantId = "87654321-4321-4321-4321-210987654321"; - // Test with null VMSSID - var cuidWithNullVmssid = new CuidInfo + var cuid = new CuidInfo { - Vmid = "test-vm-id-12345", - Vmssid = null + Vmid = vmid, + Vmssid = vmssid }; - Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, cuidWithNullVmssid)); - // Test with empty VMSSID - var cuidWithEmptyVmssid = new CuidInfo - { - Vmid = "test-vm-id-12345", - Vmssid = "" - }; Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, cuidWithEmptyVmssid)); + CsrRequest.Generate(clientId, tenantId, cuid)); } [TestMethod] - public void TestCsrGeneration_MalformedPem() + public void TestCsrGeneration_MalformedPem_FormatException() { - // Test parsing malformed PEM with invalid Base64 characters string malformedPem = "-----BEGIN CERTIFICATE REQUEST-----\nInvalid@#$%Base64Content!\n-----END CERTIFICATE REQUEST-----"; - Assert.ThrowsException(() => - ParseCsrFromPem(malformedPem)); - - // Test with wrong headers - string wrongHeaders = "-----BEGIN CERTIFICATE-----\nTUlJQzNqQ0NBY1lDQVFBd1pURT0K\n-----END CERTIFICATE-----"; - - Assert.ThrowsException(() => - ParseCsrFromPem(wrongHeaders)); - } - - #region CSR Validation Helper Methods - - /// - /// Validates the content of a CSR PEM string against expected values. - /// - private void ValidateCsrContent(string pemCsr, string expectedClientId, string expectedTenantId, CuidInfo expectedCuid) - { - // Parse the CSR from PEM format - var csrData = ParseCsrFromPem(pemCsr); - - // Parse the PKCS#10 structure - var csrInfo = ParsePkcs10Structure(csrData); - - // Validate subject name - ValidateSubjectName(csrInfo.Subject, expectedClientId, expectedTenantId); - - // Validate public key - ValidatePublicKey(csrInfo.PublicKey); - - // Validate CUID attribute - ValidateCuidAttribute(csrInfo.Attributes, expectedCuid); - - // Validate signature algorithm - ValidateSignatureAlgorithm(csrInfo.SignatureAlgorithm); - } - - /// - /// Parses a PEM-formatted CSR and returns the DER bytes. - /// - private byte[] ParseCsrFromPem(string pemCsr) - { - if (string.IsNullOrWhiteSpace(pemCsr)) - throw new ArgumentException("PEM CSR cannot be null or empty"); - - const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; - const string endMarker = "-----END CERTIFICATE REQUEST-----"; - - if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) - throw new ArgumentException("Invalid PEM format - missing CSR headers"); - - int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; - int endIndex = pemCsr.IndexOf(endMarker); - - if (beginIndex >= endIndex) - throw new ArgumentException("Invalid PEM format - malformed headers"); - - string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) - .Replace("\r", "").Replace("\n", "").Replace(" ", ""); - - try - { - return Convert.FromBase64String(base64Content); - } - catch (FormatException) - { - throw new FormatException("Invalid Base64 content in PEM CSR"); - } - } - - /// - /// Represents parsed PKCS#10 CSR information. - /// - private class CsrInfo - { - public byte[] Subject { get; set; } - public byte[] PublicKey { get; set; } - public byte[] Attributes { get; set; } - public byte[] SignatureAlgorithm { get; set; } + TestCsrValidator.ParseCsrFromPem(malformedPem)); } - /// - /// Parses the PKCS#10 ASN.1 structure and extracts key components. - /// - private CsrInfo ParsePkcs10Structure(byte[] derBytes) + [DataTestMethod] + [DataRow("-----BEGIN CERTIFICATE-----\nTUlJQzNqQ0NBY1lDQVFBd1pURT0K\n-----END CERTIFICATE REQUEST-----", DisplayName = "Wrong Headers")] + [DataRow("", DisplayName = "Empty PEM")] + [DataRow(null, DisplayName = "Null PEM")] + public void TestCsrGeneration_MalformedPem_ArgumentException(string malformedPem) { - int offset = 0; - - // Parse outer SEQUENCE (CertificationRequest) - var outerSequence = ParseAsn1Tag(derBytes, ref offset, 0x30); - - // Reset offset to parse the CertificationRequestInfo within the outer sequence - int infoOffset = 0; - var certRequestInfo = ParseAsn1Tag(outerSequence, ref infoOffset, 0x30); - - // Parse version (should be 0) - int versionOffset = 0; - var version = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x02); - if (version.Length != 1 || version[0] != 0x00) - throw new ArgumentException("Invalid CSR version"); - - // Parse subject - var subject = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); - - // Parse SubjectPublicKeyInfo - var publicKey = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); - - // Parse attributes (context-specific [0]) - var attributes = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0xA0); - - return new CsrInfo - { - Subject = subject, - PublicKey = publicKey, - Attributes = attributes, - SignatureAlgorithm = new byte[0] // Simplified for this test - }; - } - - /// - /// Parses an ASN.1 tag and returns its content. - /// - private byte[] ParseAsn1Tag(byte[] data, ref int offset, byte expectedTag) - { - if (offset >= data.Length) - throw new ArgumentException("Unexpected end of data"); - - // Check tag (if expectedTag is -1, accept any tag) - if (expectedTag != 255 && data[offset] != expectedTag) - throw new ArgumentException($"Expected tag 0x{expectedTag:X2}, got 0x{data[offset]:X2}"); - - offset++; - - // Parse length - int length = ParseAsn1Length(data, ref offset); - - // Extract content - if (offset + length > data.Length) - throw new ArgumentException("Invalid ASN.1 length"); - - byte[] content = new byte[length]; - Array.Copy(data, offset, content, 0, length); - offset += length; - - return content; - } - - /// - /// Parses ASN.1 length encoding. - /// - private int ParseAsn1Length(byte[] data, ref int offset) - { - if (offset >= data.Length) - throw new ArgumentException("Unexpected end of data in length"); - - byte firstByte = data[offset++]; - - // Short form - if ((firstByte & 0x80) == 0) - return firstByte; - - // Long form - int lengthBytes = firstByte & 0x7F; - if (lengthBytes == 0) - throw new ArgumentException("Indefinite length not supported"); - - if (offset + lengthBytes > data.Length) - throw new ArgumentException("Invalid length encoding"); - - int length = 0; - for (int i = 0; i < lengthBytes; i++) - { - length = (length << 8) | data[offset++]; - } - - return length; - } - - /// - /// Validates the subject name contains the expected client ID and tenant ID. - /// - private void ValidateSubjectName(byte[] subjectBytes, string expectedClientId, string expectedTenantId) - { - // Subject is already a SEQUENCE of RDNs - int offset = 0; - bool foundClientId = false; - bool foundTenantId = false; - - // Parse each RDN (Relative Distinguished Name) directly from subjectBytes - while (offset < subjectBytes.Length) - { - var rdnSet = ParseAsn1Tag(subjectBytes, ref offset, 0x31); // SET - - int rdnOffset = 0; - var rdnSequence = ParseAsn1Tag(rdnSet, ref rdnOffset, 0x30); // SEQUENCE - - // Parse OID and value - int attrOffset = 0; - var oid = ParseAsn1Tag(rdnSequence, ref attrOffset, 0x06); // OID - var value = ParseAsn1Tag(rdnSequence, ref attrOffset, 255); // Any string type - - string stringValue = System.Text.Encoding.UTF8.GetString(value); - - // Check for CN (commonName) OID: 2.5.4.3 - if (IsOid(oid, new int[] { 2, 5, 4, 3 })) - { - Assert.AreEqual(expectedClientId, stringValue, "Client ID in subject CN does not match"); - foundClientId = true; - } - // Check for DC (domainComponent) OID: 0.9.2342.19200300.100.1.25 - else if (IsOid(oid, new int[] { 0, 9, 2342, 19200300, 100, 1, 25 })) - { - Assert.AreEqual(expectedTenantId, stringValue, "Tenant ID in subject DC does not match"); - foundTenantId = true; - } - } - - Assert.IsTrue(foundClientId, "Client ID (CN) not found in subject"); - Assert.IsTrue(foundTenantId, "Tenant ID (DC) not found in subject"); - } - - /// - /// Validates the public key is a valid RSA key. - /// - private void ValidatePublicKey(byte[] publicKeyBytes) - { - // publicKeyBytes is already the SubjectPublicKeyInfo SEQUENCE content - int offset = 0; - - // Parse algorithm identifier - var algorithmId = ParseAsn1Tag(publicKeyBytes, ref offset, 0x30); - - // Parse public key bit string - var publicKeyBitString = ParseAsn1Tag(publicKeyBytes, ref offset, 0x03); - - // Validate algorithm is RSA (1.2.840.113549.1.1.1) - int algOffset = 0; - var algorithmOid = ParseAsn1Tag(algorithmId, ref algOffset, 0x06); - Assert.IsTrue(IsOid(algorithmOid, new int[] { 1, 2, 840, 113549, 1, 1, 1 }), - "Public key algorithm is not RSA"); - - // Skip the unused bits byte in bit string - if (publicKeyBitString.Length < 2 || publicKeyBitString[0] != 0x00) - throw new ArgumentException("Invalid public key bit string"); - - // Parse RSA public key (skip unused bits byte) - byte[] rsaKeyBytes = new byte[publicKeyBitString.Length - 1]; - Array.Copy(publicKeyBitString, 1, rsaKeyBytes, 0, rsaKeyBytes.Length); - - int rsaOffset = 0; - var rsaSequence = ParseAsn1Tag(rsaKeyBytes, ref rsaOffset, 0x30); - - rsaOffset = 0; - var modulus = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); - var exponent = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); - - // Validate key size (should be 2048 bits = 256 bytes, plus potential leading zero) - Assert.IsTrue(modulus.Length >= 256 && modulus.Length <= 257, - $"RSA modulus should be 2048 bits, got {modulus.Length * 8} bits"); - - // Validate exponent (commonly 65537 = 0x010001) - Assert.IsTrue(exponent.Length >= 1 && exponent.Length <= 4, "RSA exponent has invalid length"); - } - - /// - /// Validates the CUID attribute contains the expected VM and VMSS IDs. - /// - private void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) - { - // Attributes is a SET of attributes - // We expect one attribute with challengePassword OID (1.2.840.113549.1.9.7) - - int offset = 0; - bool foundCuid = false; - - // Parse each attribute in the SET - while (offset < attributesBytes.Length) - { - var attributeSequence = ParseAsn1Tag(attributesBytes, ref offset, 0x30); - - int attrOffset = 0; - var oid = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x06); - var valueSet = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x31); // SET of values - - // Check for challengePassword OID: 1.2.840.113549.1.9.7 - if (IsOid(oid, new int[] { 1, 2, 840, 113549, 1, 9, 7 })) - { - // Parse the value from the SET (should be one value) - int valueOffset = 0; - var value = ParseAsn1Tag(valueSet, ref valueOffset, 255); // Any string type - - string cuidValue = System.Text.Encoding.ASCII.GetString(value); - string expectedCuidValue = $"{expectedCuid.Vmid}:{expectedCuid.Vmssid}"; - - Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute value does not match expected"); - foundCuid = true; - break; - } - } - - Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); - } - - /// - /// Validates the signature algorithm is SHA256withRSA. - /// - private void ValidateSignatureAlgorithm(byte[] signatureAlgBytes) - { - // For this test, we'll just verify that signature algorithm exists - // Full validation would require parsing the outer CSR structure - // which is more complex for this unit test scenario - Assert.IsNotNull(signatureAlgBytes, "Signature algorithm should be present"); - } - - /// - /// Checks if the given OID bytes match the expected OID components. - /// - private bool IsOid(byte[] oidBytes, int[] expectedOid) - { - if (expectedOid.Length < 2) - return false; - - var expectedBytes = EncodeOid(expectedOid); - - if (oidBytes.Length != expectedBytes.Length) - return false; - - for (int i = 0; i < oidBytes.Length; i++) - { - if (oidBytes[i] != expectedBytes[i]) - return false; - } - - return true; - } - - /// - /// Encodes an OID from integer components to bytes (simplified version). - /// - private byte[] EncodeOid(int[] oid) - { - if (oid.Length < 2) - throw new ArgumentException("OID must have at least 2 components"); - - var result = new System.Collections.Generic.List(); - - // First two components are encoded as (first * 40 + second) - result.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); - - // Remaining components - for (int i = 2; i < oid.Length; i++) - { - result.AddRange(EncodeOidComponent(oid[i])); - } - - return result.ToArray(); - } - - /// - /// Encodes a single OID component using variable-length encoding. - /// - private byte[] EncodeOidComponent(int value) - { - if (value == 0) - return new byte[] { 0x00 }; - - var bytes = new System.Collections.Generic.List(); - int temp = value; - - bytes.Insert(0, (byte)(temp & 0x7F)); - temp >>= 7; - - while (temp > 0) - { - bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); - temp >>= 7; - } - - return bytes.ToArray(); + Assert.ThrowsException(() => + TestCsrValidator.ParseCsrFromPem(malformedPem)); } - - #endregion } } From 009f948a9b7267d00a0183439f0912cd24f338ad Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 7 Aug 2025 12:02:32 -0400 Subject: [PATCH 05/29] Updates to CUID --- .../ManagedIdentity/CsrMetadata.cs | 3 +- .../ManagedIdentity/CsrRequest.cs | 8 ++--- .../ManagedIdentityTests/CsrValidator.cs | 18 +++++++++-- .../ManagedIdentityTests/ImdsV2Tests.cs | 32 +++++++++++++++---- 4 files changed, 45 insertions(+), 16 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs index 04a9e06baf..a831d02c7a 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs @@ -57,13 +57,12 @@ public CsrMetadata() { } /// Validates a JSON decoded CsrMetadata instance. /// /// The CsrMetadata object. - /// false if any field is null. + /// false if any required field is null. Note: Vmid is required, Vmssid is optional. public static bool ValidateCsrMetadata(CsrMetadata csrMetadata) { if (csrMetadata == null || csrMetadata.Cuid == null || string.IsNullOrEmpty(csrMetadata.Cuid.Vmid) || - string.IsNullOrEmpty(csrMetadata.Cuid.Vmssid) || string.IsNullOrEmpty(csrMetadata.ClientId) || string.IsNullOrEmpty(csrMetadata.TenantId) || string.IsNullOrEmpty(csrMetadata.AttestationEndpoint)) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs index aa692a5b0f..2604a8c31e 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs @@ -4,6 +4,7 @@ using System; using System.Security.Cryptography; using System.Text; +using Microsoft.Identity.Client.Utils; namespace Microsoft.Identity.Client.ManagedIdentity { @@ -21,7 +22,7 @@ public CsrRequest(string pem) /// /// Managed Identity client_id. /// AAD tenant_id. - /// CuidInfo object containing VMID and VMSSID. + /// CuidInfo object containing required VMID and optional VMSSID. /// CsrRequest containing the PEM CSR. public static CsrRequest Generate(string clientId, string tenantId, CuidInfo cuid) { @@ -33,8 +34,6 @@ public static CsrRequest Generate(string clientId, string tenantId, CuidInfo cui throw new ArgumentNullException(nameof(cuid)); if (string.IsNullOrWhiteSpace(cuid.Vmid)) throw new ArgumentException("cuid.Vmid must not be null or empty.", nameof(cuid.Vmid)); - if (string.IsNullOrWhiteSpace(cuid.Vmssid)) - throw new ArgumentException("cuid.Vmssid must not be null or empty.", nameof(cuid.Vmssid)); string pemCsr = GeneratePkcs10Csr(clientId, tenantId, cuid); return new CsrRequest(pemCsr); @@ -156,8 +155,9 @@ private static byte[] BuildAttributes(CuidInfo cuid) var attributes = new System.Collections.Generic.List(); // CUID attribute (OID 1.2.840.113549.1.9.7) + // Serialize CuidInfo as JSON object string using existing JSON serialization byte[] cuidOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 9, 7 }); - string cuidValue = $"{cuid.Vmid}:{cuid.Vmssid}"; + string cuidValue = JsonHelper.SerializeToJson(cuid); byte[] cuidData = EncodeAsn1PrintableString(cuidValue); byte[] cuidAttributeValue = EncodeAsn1Set(new[] { cuidData }); byte[] cuidAttribute = EncodeAsn1Sequence(new[] { cuidOid, cuidAttributeValue }); diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs index b4a4e99dd5..671700c100 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs @@ -3,6 +3,7 @@ using System; using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.Utils; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests @@ -300,7 +301,8 @@ private static void ValidatePublicKey(byte[] publicKeyBytes) } /// - /// Validates the CUID attribute contains the expected VM and VMSS IDs. + /// Validates the CUID attribute contains the expected VM and VMSS IDs as JSON. + /// Note: Vmid is required, Vmssid is optional and will be omitted if null/empty. /// private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) { @@ -327,9 +329,11 @@ private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expec var value = ParseAsn1Tag(valueSet, ref valueOffset, 255); // Any string type string cuidValue = System.Text.Encoding.ASCII.GetString(value); - string expectedCuidValue = $"{expectedCuid.Vmid}:{expectedCuid.Vmssid}"; - Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute value does not match expected"); + // Build expected CUID value as JSON + string expectedCuidValue = BuildExpectedCuidJson(expectedCuid); + + Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute JSON value does not match expected"); foundCuid = true; break; } @@ -338,6 +342,14 @@ private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expec Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); } + /// + /// Builds the expected CUID JSON string for validation using JsonHelper. + /// + private static string BuildExpectedCuidJson(CuidInfo expectedCuid) + { + return JsonHelper.SerializeToJson(expectedCuid); + } + /// /// Validates the signature algorithm is SHA256withRSA. /// diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index b9ee9b2268..a3b9fd7df8 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -147,10 +147,6 @@ public void TestCsrGeneration() // Generate CSR var csrRequest = CsrRequest.Generate(clientId, tenantId, cuid); - // Output the generated CSR for analysis - System.Console.WriteLine("Generated CSR:"); - System.Console.WriteLine(csrRequest.Pem); - // Validate the CSR contents using the helper CsrValidator.ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); } @@ -188,9 +184,7 @@ public void TestCsrGeneration_NullCuid() [DataTestMethod] [DataRow(null, "test-vmss-id-67890", DisplayName = "Null VMID")] [DataRow("", "test-vmss-id-67890", DisplayName = "Empty VMID")] - [DataRow("test-vm-id-12345", null, DisplayName = "Null VMSSID")] - [DataRow("test-vm-id-12345", "", DisplayName = "Empty VMSSID")] - public void TestCsrGeneration_InvalidCuidProperties(string vmid, string vmssid) + public void TestCsrGeneration_InvalidVmid(string vmid, string vmssid) { string clientId = "12345678-1234-1234-1234-123456789012"; string tenantId = "87654321-4321-4321-4321-210987654321"; @@ -201,10 +195,34 @@ public void TestCsrGeneration_InvalidCuidProperties(string vmid, string vmssid) Vmssid = vmssid }; + // Should throw ArgumentException since Vmid is required Assert.ThrowsException(() => CsrRequest.Generate(clientId, tenantId, cuid)); } + [DataTestMethod] + [DataRow("test-vm-id-12345", null, DisplayName = "Null VMSSID")] + [DataRow("test-vm-id-12345", "", DisplayName = "Empty VMSSID")] + public void TestCsrGeneration_OptionalVmssid(string vmid, string vmssid) + { + string clientId = "12345678-1234-1234-1234-123456789012"; + string tenantId = "87654321-4321-4321-4321-210987654321"; + + var cuid = new CuidInfo + { + Vmid = vmid, + Vmssid = vmssid + }; + + // Should succeed since Vmssid is optional (Vmid is provided and valid) + var csrRequest = CsrRequest.Generate(clientId, tenantId, cuid); + Assert.IsNotNull(csrRequest); + Assert.IsFalse(string.IsNullOrWhiteSpace(csrRequest.Pem)); + + // Validate the CSR contents - this should handle null/empty VMSSID gracefully + CsrValidator.ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); + } + [TestMethod] public void TestCsrGeneration_MalformedPem_FormatException() { From 21d4ef3663cad4c52ec024da257e7ca276309a5f Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 7 Aug 2025 13:57:26 -0400 Subject: [PATCH 06/29] Unit test improvements --- .../TestConstants.cs | 2 + .../ManagedIdentityTests/ImdsV2Tests.cs | 56 +++++++------------ 2 files changed, 23 insertions(+), 35 deletions(-) diff --git a/tests/Microsoft.Identity.Test.Common/TestConstants.cs b/tests/Microsoft.Identity.Test.Common/TestConstants.cs index 3d89cc1bbe..d4a63354c0 100644 --- a/tests/Microsoft.Identity.Test.Common/TestConstants.cs +++ b/tests/Microsoft.Identity.Test.Common/TestConstants.cs @@ -154,6 +154,8 @@ public static HashSet s_scope public const string IdentityProvider = "my-idp"; public const string Name = "First Last"; public const string MiResourceId = "/subscriptions/ffa4aaa2-4444-4444-5555-e3ccedd3d046/resourcegroups/UAMI_group/providers/Microsoft.ManagedIdentityClient/userAssignedIdentities/UAMI"; + public const string Vmid = "test-vm-id"; + public const string Vmssid = "test-vmss-id"; public const string Claims = @"{""userinfo"":{""given_name"":{""essential"":true},""nickname"":null,""email"":{""essential"":true},""email_verified"":{""essential"":true},""picture"":null,""http://example.info/claims/groups"":null},""id_token"":{""auth_time"":{""essential"":true},""acr"":{""values"":[""urn:mace:incommon:iap:silver""]}}}"; public static readonly string[] ClientCapabilities = new[] { "cp1", "cp2" }; diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index a3b9fd7df8..2e09ace05e 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -137,33 +137,28 @@ public void TestCsrGeneration() { var cuid = new CuidInfo { - Vmid = "test-vm-id-12345", - Vmssid = "test-vmss-id-67890" + Vmid = TestConstants.Vmid, + Vmssid = TestConstants.Vmssid }; - string clientId = "12345678-1234-1234-1234-123456789012"; - string tenantId = "87654321-4321-4321-4321-210987654321"; - // Generate CSR - var csrRequest = CsrRequest.Generate(clientId, tenantId, cuid); + var csrRequest = CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); // Validate the CSR contents using the helper - CsrValidator.ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); + CsrValidator.ValidateCsrContent(csrRequest.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); } [DataTestMethod] - [DataRow(null, "87654321-4321-4321-4321-210987654321", DisplayName = "Null ClientId")] - [DataRow("", "87654321-4321-4321-4321-210987654321", DisplayName = "Empty ClientId")] - [DataRow(" ", "87654321-4321-4321-4321-210987654321", DisplayName = "Whitespace ClientId")] - [DataRow("12345678-1234-1234-1234-123456789012", null, DisplayName = "Null TenantId")] - [DataRow("12345678-1234-1234-1234-123456789012", "", DisplayName = "Empty TenantId")] - [DataRow("12345678-1234-1234-1234-123456789012", " ", DisplayName = "Whitespace TenantId")] + [DataRow(null, TestConstants.TenantId)] + [DataRow("", TestConstants.TenantId)] + [DataRow(TestConstants.ClientId, null)] + [DataRow(TestConstants.ClientId, "")] public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId) { var cuid = new CuidInfo { - Vmid = "test-vm-id-12345", - Vmssid = "test-vmss-id-67890" + Vmid = TestConstants.Vmid, + Vmssid = TestConstants.Vmssid }; Assert.ThrowsException(() => @@ -173,22 +168,16 @@ public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId [TestMethod] public void TestCsrGeneration_NullCuid() { - string clientId = "12345678-1234-1234-1234-123456789012"; - string tenantId = "87654321-4321-4321-4321-210987654321"; - // Test with null CUID Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, null)); + CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, null)); } [DataTestMethod] - [DataRow(null, "test-vmss-id-67890", DisplayName = "Null VMID")] - [DataRow("", "test-vmss-id-67890", DisplayName = "Empty VMID")] + [DataRow(null, TestConstants.Vmssid)] + [DataRow("", TestConstants.Vmssid)] public void TestCsrGeneration_InvalidVmid(string vmid, string vmssid) { - string clientId = "12345678-1234-1234-1234-123456789012"; - string tenantId = "87654321-4321-4321-4321-210987654321"; - var cuid = new CuidInfo { Vmid = vmid, @@ -197,17 +186,14 @@ public void TestCsrGeneration_InvalidVmid(string vmid, string vmssid) // Should throw ArgumentException since Vmid is required Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, cuid)); + CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid)); } [DataTestMethod] - [DataRow("test-vm-id-12345", null, DisplayName = "Null VMSSID")] - [DataRow("test-vm-id-12345", "", DisplayName = "Empty VMSSID")] + [DataRow(TestConstants.Vmid, null)] + [DataRow(TestConstants.Vmid, "")] public void TestCsrGeneration_OptionalVmssid(string vmid, string vmssid) { - string clientId = "12345678-1234-1234-1234-123456789012"; - string tenantId = "87654321-4321-4321-4321-210987654321"; - var cuid = new CuidInfo { Vmid = vmid, @@ -215,12 +201,12 @@ public void TestCsrGeneration_OptionalVmssid(string vmid, string vmssid) }; // Should succeed since Vmssid is optional (Vmid is provided and valid) - var csrRequest = CsrRequest.Generate(clientId, tenantId, cuid); + var csrRequest = CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); Assert.IsNotNull(csrRequest); Assert.IsFalse(string.IsNullOrWhiteSpace(csrRequest.Pem)); // Validate the CSR contents - this should handle null/empty VMSSID gracefully - CsrValidator.ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); + CsrValidator.ValidateCsrContent(csrRequest.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); } [TestMethod] @@ -232,9 +218,9 @@ public void TestCsrGeneration_MalformedPem_FormatException() } [DataTestMethod] - [DataRow("-----BEGIN CERTIFICATE-----\nTUlJQzNqQ0NBY1lDQVFBd1pURT0K\n-----END CERTIFICATE REQUEST-----", DisplayName = "Wrong Headers")] - [DataRow("", DisplayName = "Empty PEM")] - [DataRow(null, DisplayName = "Null PEM")] + [DataRow("-----BEGIN CERTIFICATE-----\nTUlJQzNqQ0NBY1lDQVFBd1pURT0K\n-----END CERTIFICATE REQUEST-----")] + [DataRow("")] + [DataRow(null)] public void TestCsrGeneration_MalformedPem_ArgumentException(string malformedPem) { Assert.ThrowsException(() => From cd013a33c09d4c81b13206edd625d9941c593855 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 7 Aug 2025 14:15:05 -0400 Subject: [PATCH 07/29] Implemented Feedback --- ...e.cs => ClientCredentialRequestResponse.cs} | 18 +++++++++--------- .../ManagedIdentity/CsrRequest.cs | 14 +++++++------- .../ImdsV2ManagedIdentitySource.cs | 16 ++++++++-------- .../Microsoft.Identity.Client.csproj | 5 +++++ .../ManagedIdentityTests/ImdsV2Tests.cs | 12 ++++++------ 5 files changed, 35 insertions(+), 30 deletions(-) rename src/client/Microsoft.Identity.Client/ManagedIdentity/{CsrRequestResponse.cs => ClientCredentialRequestResponse.cs} (60%) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs similarity index 60% rename from src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs rename to src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs index 10274e48ba..924dd75ad1 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs @@ -13,7 +13,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity /// /// Represents the response for a Managed Identity CSR request. /// - internal class CsrRequestResponse + internal class ClientCredentialRequestResponse { [JsonProperty("client_id")] public string ClientId { get; } @@ -33,16 +33,16 @@ internal class CsrRequestResponse [JsonProperty("refresh_in")] public int RefreshIn { get; } - public CsrRequestResponse() { } + public ClientCredentialRequestResponse() { } - public static bool ValidateCsrRequestResponse(CsrRequestResponse csrRequestResponse) + public static bool ValidateCsrRequestResponse(ClientCredentialRequestResponse clientCredentialRequestResponse) { - if (string.IsNullOrEmpty(csrRequestResponse.ClientId) || - string.IsNullOrEmpty(csrRequestResponse.TenantId) || - string.IsNullOrEmpty(csrRequestResponse.ClientCredential) || - string.IsNullOrEmpty(csrRequestResponse.RegionalTokenUrl) || - csrRequestResponse.ExpiresIn <= 0 || - csrRequestResponse.RefreshIn <= 0) + if (string.IsNullOrEmpty(clientCredentialRequestResponse.ClientId) || + string.IsNullOrEmpty(clientCredentialRequestResponse.TenantId) || + string.IsNullOrEmpty(clientCredentialRequestResponse.ClientCredential) || + string.IsNullOrEmpty(clientCredentialRequestResponse.RegionalTokenUrl) || + clientCredentialRequestResponse.ExpiresIn <= 0 || + clientCredentialRequestResponse.RefreshIn <= 0) { return false; } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs index 2604a8c31e..c3b05ec34e 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs @@ -8,11 +8,11 @@ namespace Microsoft.Identity.Client.ManagedIdentity { - internal class CsrRequest + internal class Csr { public string Pem { get; } - public CsrRequest(string pem) + public Csr(string pem) { Pem = pem ?? throw new ArgumentNullException(nameof(pem)); } @@ -24,19 +24,19 @@ public CsrRequest(string pem) /// AAD tenant_id. /// CuidInfo object containing required VMID and optional VMSSID. /// CsrRequest containing the PEM CSR. - public static CsrRequest Generate(string clientId, string tenantId, CuidInfo cuid) + public static Csr Generate(string clientId, string tenantId, CuidInfo cuid) { - if (string.IsNullOrWhiteSpace(clientId)) + if (string.IsNullOrEmpty(clientId)) throw new ArgumentException("clientId must not be null or empty.", nameof(clientId)); - if (string.IsNullOrWhiteSpace(tenantId)) + if (string.IsNullOrEmpty(tenantId)) throw new ArgumentException("tenantId must not be null or empty.", nameof(tenantId)); if (cuid == null) throw new ArgumentNullException(nameof(cuid)); - if (string.IsNullOrWhiteSpace(cuid.Vmid)) + if (string.IsNullOrEmpty(cuid.Vmid)) throw new ArgumentException("cuid.Vmid must not be null or empty.", nameof(cuid.Vmid)); string pemCsr = GeneratePkcs10Csr(clientId, tenantId, cuid); - return new CsrRequest(pemCsr); + return new Csr(pemCsr); } /// diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 08bbc9caf4..6787b9dfc9 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -17,7 +17,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; - private const string CsrRequestPath = "/metadata/identity/issuecredential"; + private const string ClientCredentialRequestPath = "/metadata/identity/issuecredential"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, @@ -195,7 +195,7 @@ public static AbstractManagedIdentity Create(RequestContext requestContext) internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.ImdsV2) { } - private async Task ExecuteCsrRequestAsync( + private async Task ExecuteClientCredentialRequestAsync( RequestContext requestContext, string queryParams, string pem) @@ -214,7 +214,7 @@ private async Task ExecuteCsrRequestAsync( try { response = await requestContext.ServiceBundle.HttpManager.SendRequestAsync( - ImdsManagedIdentitySource.GetValidatedEndpoint(requestContext.Logger, CsrRequestPath, queryParams), + ImdsManagedIdentitySource.GetValidatedEndpoint(requestContext.Logger, ClientCredentialRequestPath, queryParams), headers, body: new StringContent($"{{\"pem\":\"{pem}\"}}", System.Text.Encoding.UTF8, "application/json"), method: HttpMethod.Post, @@ -236,8 +236,8 @@ private async Task ExecuteCsrRequestAsync( (int)response.StatusCode); } - var csrRequestResponse = JsonHelper.DeserializeFromJson(response.Body); - if (!CsrRequestResponse.ValidateCsrRequestResponse(csrRequestResponse)) + var clientCredentialRequestResponse = JsonHelper.DeserializeFromJson(response.Body); + if (!ClientCredentialRequestResponse.ValidateCsrRequestResponse(clientCredentialRequestResponse)) { throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, @@ -247,13 +247,13 @@ private async Task ExecuteCsrRequestAsync( (int)response.StatusCode); } - return csrRequestResponse; + return clientCredentialRequestResponse; } protected override ManagedIdentityRequest CreateRequest(string resource) { var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); - var csrRequest = CsrRequest.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); + var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); var queryParams = $"cid={csrMetadata.Cuid}"; if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) @@ -262,7 +262,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) } queryParams += $"&api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; - var csrRequestResponse = ExecuteCsrRequestAsync(_requestContext, queryParams, csrRequest.Pem); + var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(_requestContext, queryParams, csr.Pem); throw new NotImplementedException(); } diff --git a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj index 578bb27e45..6c52e6dded 100644 --- a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj +++ b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj @@ -82,6 +82,7 @@ + @@ -163,4 +164,8 @@ + + + + diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 2e09ace05e..6851b425e3 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -142,10 +142,10 @@ public void TestCsrGeneration() }; // Generate CSR - var csrRequest = CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); + var csr = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); // Validate the CSR contents using the helper - CsrValidator.ValidateCsrContent(csrRequest.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); + CsrValidator.ValidateCsrContent(csr.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); } [DataTestMethod] @@ -162,7 +162,7 @@ public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId }; Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, cuid)); + Csr.Generate(clientId, tenantId, cuid)); } [TestMethod] @@ -170,7 +170,7 @@ public void TestCsrGeneration_NullCuid() { // Test with null CUID Assert.ThrowsException(() => - CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, null)); + Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, null)); } [DataTestMethod] @@ -186,7 +186,7 @@ public void TestCsrGeneration_InvalidVmid(string vmid, string vmssid) // Should throw ArgumentException since Vmid is required Assert.ThrowsException(() => - CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid)); + Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid)); } [DataTestMethod] @@ -201,7 +201,7 @@ public void TestCsrGeneration_OptionalVmssid(string vmid, string vmssid) }; // Should succeed since Vmssid is optional (Vmid is provided and valid) - var csrRequest = CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); + var csrRequest = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); Assert.IsNotNull(csrRequest); Assert.IsFalse(string.IsNullOrWhiteSpace(csrRequest.Pem)); From 480ae9ea4174cfa118ef1489a548cefe5f302c2f Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 7 Aug 2025 14:16:57 -0400 Subject: [PATCH 08/29] renamed file --- .../ManagedIdentity/{CsrRequest.cs => Csr.cs} | 0 .../Microsoft.Identity.Client/Microsoft.Identity.Client.csproj | 2 ++ 2 files changed, 2 insertions(+) rename src/client/Microsoft.Identity.Client/ManagedIdentity/{CsrRequest.cs => Csr.cs} (100%) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs similarity index 100% rename from src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs rename to src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs diff --git a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj index 6c52e6dded..468292d402 100644 --- a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj +++ b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj @@ -83,6 +83,7 @@ + @@ -167,5 +168,6 @@ + From 0aa869281000cd7b933fae3781a0d120168e61a6 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 16:36:12 -0400 Subject: [PATCH 09/29] small improvement --- .../ImdsV2ManagedIdentitySource.cs | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 6787b9dfc9..fe49b73266 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -196,33 +196,39 @@ internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.ImdsV2) { } private async Task ExecuteClientCredentialRequestAsync( - RequestContext requestContext, - string queryParams, + CuidInfo Cuid, string pem) { + var queryParams = $"cid={Cuid}"; + if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) + { + queryParams += $"&uaid{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; + } + queryParams += $"&api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; + var headers = new Dictionary { { "Metadata", "true" }, - { "x-ms-client-request-id", requestContext.CorrelationId.ToString() } + { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } }; - IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.Imds); HttpResponse response = null; try { - response = await requestContext.ServiceBundle.HttpManager.SendRequestAsync( - ImdsManagedIdentitySource.GetValidatedEndpoint(requestContext.Logger, ClientCredentialRequestPath, queryParams), + response = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync( + ImdsManagedIdentitySource.GetValidatedEndpoint(_requestContext.Logger, ClientCredentialRequestPath, queryParams), headers, body: new StringContent($"{{\"pem\":\"{pem}\"}}", System.Text.Encoding.UTF8, "application/json"), method: HttpMethod.Post, - logger: requestContext.Logger, + logger: _requestContext.Logger, doNotThrow: false, mtlsCertificate: null, validateServerCertificate: null, - cancellationToken: requestContext.UserCancellationToken, + cancellationToken: _requestContext.UserCancellationToken, retryPolicy: retryPolicy) .ConfigureAwait(false); } @@ -255,14 +261,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); - var queryParams = $"cid={csrMetadata.Cuid}"; - if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) - { - queryParams += $"&uaid{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; - } - queryParams += $"&api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; - - var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(_requestContext, queryParams, csr.Pem); + var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(csrMetadata.Cuid, csr.Pem); throw new NotImplementedException(); } From 621c5662b3b5b26b8e1189bc8e6f8ed62798a3e4 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 17:38:56 -0400 Subject: [PATCH 10/29] added missing awaitor for async method --- global.json | 2 +- .../ManagedIdentity/ImdsV2ManagedIdentitySource.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/global.json b/global.json index 66e4a5c8a7..e5135e9ff3 100644 --- a/global.json +++ b/global.json @@ -1,6 +1,6 @@ { "sdk": { - "version": "8.0.404", + "version": "9.0.0", "rollForward": "latestFeature" } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index fe49b73266..c2151a3b66 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -261,7 +261,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); - var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(csrMetadata.Cuid, csr.Pem); + var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(csrMetadata.Cuid, csr.Pem).GetAwaiter().GetResult(); throw new NotImplementedException(); } From 068461b344757548f01a734ce8fa88f1e4cabefb Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 19:17:19 -0400 Subject: [PATCH 11/29] Fixed bugs discovered from unit testing in child branch --- .../ManagedIdentity/ImdsV2ManagedIdentitySource.cs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index c2151a3b66..081e353f48 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -18,6 +18,7 @@ internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; private const string ClientCredentialRequestPath = "/metadata/identity/issuecredential"; + private const string AcquireEntraTokenPath = "/oauth2/v2.0/token"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, @@ -199,7 +200,7 @@ private async Task ExecuteClientCredentialReque CuidInfo Cuid, string pem) { - var queryParams = $"cid={Cuid}"; + var queryParams = $"cid={JsonHelper.SerializeToJson(Cuid)}"; if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) { queryParams += $"&uaid{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; @@ -212,6 +213,12 @@ private async Task ExecuteClientCredentialReque { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } }; + var payload = new + { + pem = pem + }; + var body = JsonHelper.SerializeToJson(payload); + IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.Imds); @@ -222,7 +229,7 @@ private async Task ExecuteClientCredentialReque response = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync( ImdsManagedIdentitySource.GetValidatedEndpoint(_requestContext.Logger, ClientCredentialRequestPath, queryParams), headers, - body: new StringContent($"{{\"pem\":\"{pem}\"}}", System.Text.Encoding.UTF8, "application/json"), + body: new StringContent(body, System.Text.Encoding.UTF8, "application/json"), method: HttpMethod.Post, logger: _requestContext.Logger, doNotThrow: false, From 2034b25487a33eb8b42baae6547d900e7064ecef Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 19:20:36 -0400 Subject: [PATCH 12/29] undid changes to .proj --- .../Microsoft.Identity.Client.csproj | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj index 468292d402..578bb27e45 100644 --- a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj +++ b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj @@ -82,8 +82,6 @@ - - @@ -165,9 +163,4 @@ - - - - - From 2b7486a5d5060e93405bf619cab32c4356bcb90e Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 19:26:09 -0400 Subject: [PATCH 13/29] undid change to global.json --- global.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/global.json b/global.json index e5135e9ff3..66e4a5c8a7 100644 --- a/global.json +++ b/global.json @@ -1,6 +1,6 @@ { "sdk": { - "version": "9.0.0", + "version": "8.0.404", "rollForward": "latestFeature" } } From 189ff9e9f79db3b0177af10ae3d337aad2f28d8d Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 19:32:26 -0400 Subject: [PATCH 14/29] added missing sets --- .../ClientCredentialRequestResponse.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs index 924dd75ad1..efec6a1487 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs @@ -16,22 +16,22 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ClientCredentialRequestResponse { [JsonProperty("client_id")] - public string ClientId { get; } + public string ClientId { get; set; } [JsonProperty("tenant_id")] - public string TenantId { get; } + public string TenantId { get; set; } [JsonProperty("client_credential")] - public string ClientCredential { get; } + public string ClientCredential { get; set; } [JsonProperty("regional_token_url")] - public string RegionalTokenUrl { get; } + public string RegionalTokenUrl { get; set; } [JsonProperty("expires_in")] - public int ExpiresIn { get; } + public int ExpiresIn { get; set; } [JsonProperty("refresh_in")] - public int RefreshIn { get; } + public int RefreshIn { get; set; } public ClientCredentialRequestResponse() { } From 92b325fd897a86a9ee674cd443e7dda56b56d9c9 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Mon, 11 Aug 2025 15:40:39 -0400 Subject: [PATCH 15/29] Inplemented some feedback --- .../ManagedIdentity/ClientCredentialRequestResponse.cs | 2 +- .../ManagedIdentity/ImdsV2ManagedIdentitySource.cs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs index efec6a1487..22d92566af 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs @@ -35,7 +35,7 @@ internal class ClientCredentialRequestResponse public ClientCredentialRequestResponse() { } - public static bool ValidateCsrRequestResponse(ClientCredentialRequestResponse clientCredentialRequestResponse) + public static bool IsValid(ClientCredentialRequestResponse clientCredentialRequestResponse) { if (string.IsNullOrEmpty(clientCredentialRequestResponse.ClientId) || string.IsNullOrEmpty(clientCredentialRequestResponse.TenantId) || diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 081e353f48..1e54224700 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -18,7 +18,6 @@ internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; private const string ClientCredentialRequestPath = "/metadata/identity/issuecredential"; - private const string AcquireEntraTokenPath = "/oauth2/v2.0/token"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, @@ -250,7 +249,7 @@ private async Task ExecuteClientCredentialReque } var clientCredentialRequestResponse = JsonHelper.DeserializeFromJson(response.Body); - if (!ClientCredentialRequestResponse.ValidateCsrRequestResponse(clientCredentialRequestResponse)) + if (!ClientCredentialRequestResponse.IsValid(clientCredentialRequestResponse)) { throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, From 067c83c58efde7ae49a00dc986607ea5212f97a5 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 14 Aug 2025 14:30:01 -0400 Subject: [PATCH 16/29] Implemented some feedback --- ...ponse.cs => CertificateRequestResponse.cs} | 21 +++++++++---------- .../ImdsV2ManagedIdentitySource.cs | 14 ++++++------- 2 files changed, 17 insertions(+), 18 deletions(-) rename src/client/Microsoft.Identity.Client/ManagedIdentity/{ClientCredentialRequestResponse.cs => CertificateRequestResponse.cs} (57%) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs similarity index 57% rename from src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs rename to src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs index 22d92566af..4391fba4be 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs @@ -4,8 +4,7 @@ #if SUPPORTS_SYSTEM_TEXT_JSON using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute; #else -using Microsoft.Identity.Client.Utils; -using Microsoft.Identity.Json; + using Microsoft.Identity.Json; #endif namespace Microsoft.Identity.Client.ManagedIdentity @@ -13,7 +12,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity /// /// Represents the response for a Managed Identity CSR request. /// - internal class ClientCredentialRequestResponse + internal class CertificateRequestResponse { [JsonProperty("client_id")] public string ClientId { get; set; } @@ -33,16 +32,16 @@ internal class ClientCredentialRequestResponse [JsonProperty("refresh_in")] public int RefreshIn { get; set; } - public ClientCredentialRequestResponse() { } + public CertificateRequestResponse() { } - public static bool IsValid(ClientCredentialRequestResponse clientCredentialRequestResponse) + public static bool IsValid(CertificateRequestResponse certificateRequestResponse) { - if (string.IsNullOrEmpty(clientCredentialRequestResponse.ClientId) || - string.IsNullOrEmpty(clientCredentialRequestResponse.TenantId) || - string.IsNullOrEmpty(clientCredentialRequestResponse.ClientCredential) || - string.IsNullOrEmpty(clientCredentialRequestResponse.RegionalTokenUrl) || - clientCredentialRequestResponse.ExpiresIn <= 0 || - clientCredentialRequestResponse.RefreshIn <= 0) + if (string.IsNullOrEmpty(certificateRequestResponse.ClientId) || + string.IsNullOrEmpty(certificateRequestResponse.TenantId) || + string.IsNullOrEmpty(certificateRequestResponse.ClientCredential) || + string.IsNullOrEmpty(certificateRequestResponse.RegionalTokenUrl) || + certificateRequestResponse.ExpiresIn <= 0 || + certificateRequestResponse.RefreshIn <= 0) { return false; } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 1e54224700..08ec8313cf 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -17,7 +17,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; - private const string ClientCredentialRequestPath = "/metadata/identity/issuecredential"; + private const string CertificateRequestPath = "/metadata/identity/issuecredential"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, @@ -195,7 +195,7 @@ public static AbstractManagedIdentity Create(RequestContext requestContext) internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.ImdsV2) { } - private async Task ExecuteClientCredentialRequestAsync( + private async Task ExecuteCertificateRequestAsync( CuidInfo Cuid, string pem) { @@ -226,7 +226,7 @@ private async Task ExecuteClientCredentialReque try { response = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync( - ImdsManagedIdentitySource.GetValidatedEndpoint(_requestContext.Logger, ClientCredentialRequestPath, queryParams), + ImdsManagedIdentitySource.GetValidatedEndpoint(_requestContext.Logger, CertificateRequestPath, queryParams), headers, body: new StringContent(body, System.Text.Encoding.UTF8, "application/json"), method: HttpMethod.Post, @@ -248,8 +248,8 @@ private async Task ExecuteClientCredentialReque (int)response.StatusCode); } - var clientCredentialRequestResponse = JsonHelper.DeserializeFromJson(response.Body); - if (!ClientCredentialRequestResponse.IsValid(clientCredentialRequestResponse)) + var certificateRequestResponse = JsonHelper.DeserializeFromJson(response.Body); + if (!CertificateRequestResponse.IsValid(certificateRequestResponse)) { throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, @@ -259,7 +259,7 @@ private async Task ExecuteClientCredentialReque (int)response.StatusCode); } - return clientCredentialRequestResponse; + return certificateRequestResponse; } protected override ManagedIdentityRequest CreateRequest(string resource) @@ -267,7 +267,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); - var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(csrMetadata.Cuid, csr.Pem).GetAwaiter().GetResult(); + var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.Cuid, csr.Pem).GetAwaiter().GetResult(); throw new NotImplementedException(); } From f7d6f881386099c1f776e6362783fb6f76e0c4fe Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 15 Aug 2025 13:54:41 -0400 Subject: [PATCH 17/29] PKCS1 -> Pss padding --- .../ManagedIdentity/Csr.cs | 78 ++++++++++++++----- 1 file changed, 57 insertions(+), 21 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs index c3b05ec34e..5599231411 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs @@ -67,18 +67,29 @@ private static string GeneratePkcs10Csr(string clientId, string tenantId, CuidIn } /// - /// Creates a 2048-bit RSA key pair compatible with all target frameworks. + /// Creates a 2048-bit RSA key pair that supports PSS padding across all target frameworks. /// + /// + /// On .NET Framework 4.6.2/4.7.2 (Windows-only), explicitly uses RSACng (also Windows-only) + /// to ensure PSS padding support, as RSA.Create() may return RSACryptoServiceProvider + /// which doesn't support PSS. + /// On .NET Standard 2.0 and .NET 8.0+ (cross-platform), uses RSA.Create() which returns + /// modern implementations that support PSS: RSACng on Windows, OpenSSL-based on Linux/macOS. + /// + /// An RSA instance configured for 2048-bit keys with PSS padding capability. private static RSA CreateRsaKeyPair() { + RSA rsa = null; + #if NET462 || NET472 - var rsa = new RSACryptoServiceProvider(2048); - return rsa; + // .NET Framework runs only on Windows, so RSACng (Windows-only) is always available + rsa = new System.Security.Cryptography.RSACng(); #else - var rsa = RSA.Create(); + // Cross-platform .NET - RSA.Create() returns appropriate PSS-capable implementation + rsa = RSA.Create(); +#endif rsa.KeySize = 2048; return rsa; -#endif } /// @@ -167,31 +178,56 @@ private static byte[] BuildAttributes(CuidInfo cuid) } /// - /// Builds the signature algorithm identifier for SHA256withRSA. + /// Builds the signature algorithm identifier for RSASSA-PSS with SHA256. /// private static byte[] BuildSignatureAlgorithmIdentifier() { - byte[] sha256WithRsaOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 11 }); // SHA256withRSA OID - byte[] nullParam = EncodeAsn1Null(); - return EncodeAsn1Sequence(new[] { sha256WithRsaOid, nullParam }); + byte[] rsassaPssOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 10 }); // RSASSA-PSS OID + byte[] pssParams = BuildPssParameters(); + return EncodeAsn1Sequence(new[] { rsassaPssOid, pssParams }); } /// - /// Signs the CertificationRequestInfo with SHA256withRSA. + /// Builds the RSASSA-PSS parameters for SHA256 with MGF1. + /// + private static byte[] BuildPssParameters() + { + var parameters = new System.Collections.Generic.List(); + + // hashAlgorithm [0] AlgorithmIdentifier DEFAULT sha1 + // We explicitly specify SHA256 since default is SHA1 + byte[] sha256Oid = EncodeAsn1ObjectIdentifier(new int[] { 2, 16, 840, 1, 101, 3, 4, 2, 1 }); // SHA256 OID + byte[] sha256Null = EncodeAsn1Null(); + byte[] hashAlgorithm = EncodeAsn1Sequence(new[] { sha256Oid, sha256Null }); + byte[] hashAlgorithmParam = EncodeAsn1ContextSpecific(0, hashAlgorithm); + parameters.Add(hashAlgorithmParam); + + // maskGenAlgorithm [1] AlgorithmIdentifier DEFAULT mgf1SHA1 + // We explicitly specify MGF1 with SHA256 + byte[] mgf1Oid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 8 }); // MGF1 OID + byte[] mgf1HashAlgorithm = EncodeAsn1Sequence(new[] { sha256Oid, sha256Null }); // MGF1 uses SHA256 + byte[] maskGenAlgorithm = EncodeAsn1Sequence(new[] { mgf1Oid, mgf1HashAlgorithm }); + byte[] maskGenAlgorithmParam = EncodeAsn1ContextSpecific(1, maskGenAlgorithm); + parameters.Add(maskGenAlgorithmParam); + + // saltLength [2] INTEGER DEFAULT 20 + // We explicitly specify 32 for SHA256 (hash length) + byte[] saltLength = EncodeAsn1Integer(32); + byte[] saltLengthParam = EncodeAsn1ContextSpecific(2, saltLength); + parameters.Add(saltLengthParam); + + // trailerField [3] INTEGER DEFAULT 1 + // Default value is 1 (0xBC), so we omit this parameter + + return EncodeAsn1Sequence(parameters.ToArray()); + } + + /// + /// Signs the CertificationRequestInfo with SHA256withRSA-PSS. /// private static byte[] SignCertificationRequestInfo(byte[] certificationRequestInfo, RSA rsa) { -#if NET462 || NET472 - using (var sha256 = SHA256.Create()) - { - byte[] hash = sha256.ComputeHash(certificationRequestInfo); - var formatter = new RSAPKCS1SignatureFormatter(rsa); - formatter.SetHashAlgorithm("SHA256"); - return formatter.CreateSignature(hash); - } -#else - return rsa.SignData(certificationRequestInfo, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); -#endif + return rsa.SignData(certificationRequestInfo, HashAlgorithmName.SHA256, RSASignaturePadding.Pss); } /// From 74e8e606d014439c750156b65fd628d07e2e1ae7 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 15 Aug 2025 15:13:04 -0400 Subject: [PATCH 18/29] re-used imports --- .../ManagedIdentity/Csr.cs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs index 5599231411..c35a93e4a9 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Collections.Generic; using System.Security.Cryptography; using System.Text; using Microsoft.Identity.Client.Utils; @@ -83,7 +84,7 @@ private static RSA CreateRsaKeyPair() #if NET462 || NET472 // .NET Framework runs only on Windows, so RSACng (Windows-only) is always available - rsa = new System.Security.Cryptography.RSACng(); + rsa = new RSACng(); #else // Cross-platform .NET - RSA.Create() returns appropriate PSS-capable implementation rsa = RSA.Create(); @@ -97,7 +98,7 @@ private static RSA CreateRsaKeyPair() /// private static byte[] BuildCertificationRequestInfo(string clientId, string tenantId, CuidInfo cuid, RSA rsa) { - var components = new System.Collections.Generic.List(); + var components = new List(); // Version (INTEGER 0) components.Add(EncodeAsn1Integer(new byte[] { 0x00 })); @@ -119,7 +120,7 @@ private static byte[] BuildCertificationRequestInfo(string clientId, string tena /// private static byte[] BuildSubjectName(string clientId, string tenantId) { - var rdnSequence = new System.Collections.Generic.List(); + var rdnSequence = new List(); // CN= byte[] cnOid = EncodeAsn1ObjectIdentifier(new int[] { 2, 5, 4, 3 }); // commonName OID @@ -163,7 +164,7 @@ private static byte[] BuildSubjectPublicKeyInfo(RSA rsa) /// private static byte[] BuildAttributes(CuidInfo cuid) { - var attributes = new System.Collections.Generic.List(); + var attributes = new List(); // CUID attribute (OID 1.2.840.113549.1.9.7) // Serialize CuidInfo as JSON object string using existing JSON serialization @@ -192,7 +193,7 @@ private static byte[] BuildSignatureAlgorithmIdentifier() /// private static byte[] BuildPssParameters() { - var parameters = new System.Collections.Generic.List(); + var parameters = new List(); // hashAlgorithm [0] AlgorithmIdentifier DEFAULT sha1 // We explicitly specify SHA256 since default is SHA1 @@ -309,7 +310,7 @@ private static byte[] EncodeAsn1Integer(int value) if (value == 0) return EncodeAsn1Tag(0x02, new byte[] { 0x00 }); - var bytes = new System.Collections.Generic.List(); + var bytes = new List(); int temp = value; while (temp > 0) { @@ -365,7 +366,7 @@ private static byte[] EncodeAsn1ObjectIdentifier(int[] oid) if (oid == null || oid.Length < 2) throw new ArgumentException("OID must have at least 2 components"); - var bytes = new System.Collections.Generic.List(); + var bytes = new List(); // First two components are encoded as (first * 40 + second) bytes.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); @@ -411,7 +412,7 @@ private static byte[] EncodeAsn1Length(int length) return new byte[] { (byte)length }; } - var lengthBytes = new System.Collections.Generic.List(); + var lengthBytes = new List(); int temp = length; while (temp > 0) { @@ -433,7 +434,7 @@ private static byte[] EncodeOidComponent(int value) if (value == 0) return new byte[] { 0x00 }; - var bytes = new System.Collections.Generic.List(); + var bytes = new List(); int temp = value; bytes.Insert(0, (byte)(temp & 0x7F)); From 152f396704046f77d37e8fca55b7c82993156504 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 15 Aug 2025 15:51:12 -0400 Subject: [PATCH 19/29] Implemented feedback --- .../ManagedIdentity/Csr.cs | 1 + .../ManagedIdentity/ImdsV2ManagedIdentitySource.cs | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs index c35a93e4a9..fdc8584cd3 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs @@ -80,6 +80,7 @@ private static string GeneratePkcs10Csr(string clientId, string tenantId, CuidIn /// An RSA instance configured for 2048-bit keys with PSS padding capability. private static RSA CreateRsaKeyPair() { + // TODO: use the strongest key on the machine i.e. a TPM key RSA rsa = null; #if NET462 || NET472 diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 08ec8313cf..4d5354dcc6 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -242,18 +242,28 @@ private async Task ExecuteCertificateRequestAsync( { throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, - $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCsrRequest failed.", + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed.", ex, ManagedIdentitySource.ImdsV2, (int)response.StatusCode); } + if (response.StatusCode != HttpStatusCode.OK) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed due to HTTP error. Status code: {response.StatusCode} Body: {response.Body}", + null, + ManagedIdentitySource.ImdsV2, + (int)response.StatusCode); + } + var certificateRequestResponse = JsonHelper.DeserializeFromJson(response.Body); if (!CertificateRequestResponse.IsValid(certificateRequestResponse)) { throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, - $"[ImdsV2] ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because the CsrMetadata response is invalid. Status code: {response.StatusCode} Body: {response.Body}", + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed because the certificate request response is malformed. Status code: {response.StatusCode}", null, ManagedIdentitySource.ImdsV2, (int)response.StatusCode); From d46c853ba31e2d903d32aca5026ff566c41573f2 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Tue, 19 Aug 2025 14:36:28 -0400 Subject: [PATCH 20/29] Changes from manual testing. --- .../ManagedIdentity/Csr.cs | 11 +++-- .../ManagedIdentity/CsrMetadata.cs | 18 ++++---- .../ImdsV2ManagedIdentitySource.cs | 31 ++++++------- .../net/MsalJsonSerializerContext.cs | 5 ++ .../Core/Mocks/MockHelpers.cs | 6 +-- .../TestConstants.cs | 4 +- .../ManagedIdentityTests/CsrValidator.cs | 2 +- .../ManagedIdentityTests/ImdsV2Tests.cs | 46 ++++++++++--------- 8 files changed, 68 insertions(+), 55 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs index fdc8584cd3..899304b544 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs @@ -9,6 +9,11 @@ namespace Microsoft.Identity.Client.ManagedIdentity { + internal class PemPayload + { + public string pem { get; set; } + } + internal class Csr { public string Pem { get; } @@ -23,7 +28,7 @@ public Csr(string pem) /// /// Managed Identity client_id. /// AAD tenant_id. - /// CuidInfo object containing required VMID and optional VMSSID. + /// CuidInfo object containing required vmId and optional vmssId. /// CsrRequest containing the PEM CSR. public static Csr Generate(string clientId, string tenantId, CuidInfo cuid) { @@ -33,8 +38,8 @@ public static Csr Generate(string clientId, string tenantId, CuidInfo cuid) throw new ArgumentException("tenantId must not be null or empty.", nameof(tenantId)); if (cuid == null) throw new ArgumentNullException(nameof(cuid)); - if (string.IsNullOrEmpty(cuid.Vmid)) - throw new ArgumentException("cuid.Vmid must not be null or empty.", nameof(cuid.Vmid)); + if (string.IsNullOrEmpty(cuid.VmId)) + throw new ArgumentException("cuid.VmId must not be null or empty.", nameof(cuid.VmId)); string pemCsr = GeneratePkcs10Csr(clientId, tenantId, cuid); return new Csr(pemCsr); diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs index a831d02c7a..5de9a5e490 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs @@ -14,11 +14,11 @@ namespace Microsoft.Identity.Client.ManagedIdentity /// internal class CuidInfo { - [JsonProperty("vmid")] - public string Vmid { get; set; } + [JsonProperty("vmId")] + public string VmId { get; set; } - [JsonProperty("vmssid")] - public string Vmssid { get; set; } + [JsonProperty("vmssId")] + public string VmssId { get; set; } } /// @@ -29,8 +29,8 @@ internal class CsrMetadata /// /// VM unique Id /// - [JsonProperty("cuid")] - public CuidInfo Cuid { get; set; } + [JsonProperty("cuId")] + public CuidInfo CuId { get; set; } /// /// client_id of the Managed Identity @@ -57,12 +57,12 @@ public CsrMetadata() { } /// Validates a JSON decoded CsrMetadata instance. /// /// The CsrMetadata object. - /// false if any required field is null. Note: Vmid is required, Vmssid is optional. + /// false if any required field is null. Note: VmId is required, VmssId is optional. public static bool ValidateCsrMetadata(CsrMetadata csrMetadata) { if (csrMetadata == null || - csrMetadata.Cuid == null || - string.IsNullOrEmpty(csrMetadata.Cuid.Vmid) || + csrMetadata.CuId == null || + string.IsNullOrEmpty(csrMetadata.CuId.VmId) || string.IsNullOrEmpty(csrMetadata.ClientId) || string.IsNullOrEmpty(csrMetadata.TenantId) || string.IsNullOrEmpty(csrMetadata.AttestationEndpoint)) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 4d5354dcc6..bd0d2c2e4a 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -16,15 +16,16 @@ namespace Microsoft.Identity.Client.ManagedIdentity { internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { - private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; + public const string ImdsV2ApiVersion = "2.0"; + private const string CsrMetadataPath = "/metadata/identity/getplatformmetadata"; private const string CertificateRequestPath = "/metadata/identity/issuecredential"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, bool probeMode) { - string queryParams = $"cred-api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; - + string queryParams = $"cred-api-version={ImdsV2ApiVersion}"; + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam( requestContext.ServiceBundle.Config.ManagedIdentityId.IdType, requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId, @@ -129,12 +130,14 @@ private static bool ValidateCsrMetadataResponse( * "1556" // index 1: captured group (\d+) * ] */ - string serverHeader = response.HeadersAsDictionary.TryGetValue("server", out var value) ? value : null; + // Imds bug: headers are missing + // TODO: uncomment this when the bug is fixed + /*string serverHeader = response.HeadersAsDictionary.TryGetValue("server", out var value) ? value : null; if (serverHeader == null) { if (probeMode) { - logger.Info(() => "[Managed Identity] IMDSv2 managed identity is not available. 'server' header is missing from the CSR metadata response."); + logger.Info(() => $"[Managed Identity] IMDSv2 managed identity is not available. 'server' header is missing from the CSR metadata response. Body: {response.Body}"); return false; } else @@ -164,7 +167,7 @@ private static bool ValidateCsrMetadataResponse( null, (int)response.StatusCode); } - } + }*/ return true; } @@ -196,26 +199,22 @@ internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.ImdsV2) { } private async Task ExecuteCertificateRequestAsync( - CuidInfo Cuid, + CuidInfo cuid, string pem) { - var queryParams = $"cid={JsonHelper.SerializeToJson(Cuid)}"; + var queryParams = $"cuid={JsonHelper.SerializeToJson(cuid)}&cred-api-version={ImdsV2ApiVersion}"; if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) { queryParams += $"&uaid{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; } - queryParams += $"&api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; var headers = new Dictionary { { "Metadata", "true" }, { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } }; - - var payload = new - { - pem = pem - }; + + var payload = new PemPayload { pem = pem }; var body = JsonHelper.SerializeToJson(payload); IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; @@ -275,9 +274,9 @@ private async Task ExecuteCertificateRequestAsync( protected override ManagedIdentityRequest CreateRequest(string resource) { var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); - var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); + var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); - var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.Cuid, csr.Pem).GetAwaiter().GetResult(); + var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.CuId, csr.Pem).GetAwaiter().GetResult(); throw new NotImplementedException(); } diff --git a/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs b/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs index d36f036282..6d6a6cb7f2 100644 --- a/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs +++ b/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs @@ -40,6 +40,10 @@ namespace Microsoft.Identity.Client.Platforms.net [JsonSerializable(typeof(ManagedIdentityResponse))] [JsonSerializable(typeof(ManagedIdentityErrorResponse))] [JsonSerializable(typeof(OidcMetadata))] + [JsonSerializable(typeof(CsrMetadata))] + [JsonSerializable(typeof(CuidInfo))] + [JsonSerializable(typeof(CertificateRequestResponse))] + [JsonSerializable(typeof(PemPayload))] [JsonSourceGenerationOptions] internal partial class MsalJsonSerializerContext : JsonSerializerContext { @@ -54,6 +58,7 @@ public static MsalJsonSerializerContext Custom { NumberHandling = JsonNumberHandling.AllowReadingFromString, AllowTrailingCommas = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, Converters = { new JsonStringConverter(), diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index c0c293840a..3cf5e7a64f 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -590,12 +590,12 @@ public static MockHttpMessageHandler MockCsrResponse( { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); - expectedQueryParams.Add("cred-api-version", "2018-02-01"); + expectedQueryParams.Add("cred-api-version", "2.0"); expectedRequestHeaders.Add("Metadata", "true"); string content = "{" + - "\"cuid\": { \"vmid\": \"fake_vmid\", \"vmssid\": \"fake_vmssid\" }," + + "\"cuid\": { \"vmId\": \"fake_vmId\", \"vmssId\": \"fake_vmssId\" }," + "\"clientId\": \"fake_client_id\"," + "\"tenantId\": \"fake_tenant_id\"," + "\"attestationEndpoint\": \"fake_attestation_endpoint\"" + @@ -603,7 +603,7 @@ public static MockHttpMessageHandler MockCsrResponse( var handler = new MockHttpMessageHandler() { - ExpectedUrl = "http://169.254.169.254/metadata/identity/getPlatformMetadata", + ExpectedUrl = "http://169.254.169.254/metadata/identity/getplatformmetadata", ExpectedMethod = HttpMethod.Get, ExpectedQueryParams = expectedQueryParams, ExpectedRequestHeaders = expectedRequestHeaders, diff --git a/tests/Microsoft.Identity.Test.Common/TestConstants.cs b/tests/Microsoft.Identity.Test.Common/TestConstants.cs index d4a63354c0..5a2ea2986a 100644 --- a/tests/Microsoft.Identity.Test.Common/TestConstants.cs +++ b/tests/Microsoft.Identity.Test.Common/TestConstants.cs @@ -154,8 +154,8 @@ public static HashSet s_scope public const string IdentityProvider = "my-idp"; public const string Name = "First Last"; public const string MiResourceId = "/subscriptions/ffa4aaa2-4444-4444-5555-e3ccedd3d046/resourcegroups/UAMI_group/providers/Microsoft.ManagedIdentityClient/userAssignedIdentities/UAMI"; - public const string Vmid = "test-vm-id"; - public const string Vmssid = "test-vmss-id"; + public const string VmId = "test-vm-id"; + public const string VmssId = "test-vmss-id"; public const string Claims = @"{""userinfo"":{""given_name"":{""essential"":true},""nickname"":null,""email"":{""essential"":true},""email_verified"":{""essential"":true},""picture"":null,""http://example.info/claims/groups"":null},""id_token"":{""auth_time"":{""essential"":true},""acr"":{""values"":[""urn:mace:incommon:iap:silver""]}}}"; public static readonly string[] ClientCapabilities = new[] { "cp1", "cp2" }; diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs index 671700c100..adbc2e298d 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs @@ -302,7 +302,7 @@ private static void ValidatePublicKey(byte[] publicKeyBytes) /// /// Validates the CUID attribute contains the expected VM and VMSS IDs as JSON. - /// Note: Vmid is required, Vmssid is optional and will be omitted if null/empty. + /// Note: VmId is required, VmssId is optional and will be omitted if null/empty. /// private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) { diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 6851b425e3..312d43ec74 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -60,7 +60,9 @@ public async Task GetCsrMetadataAsyncSucceedsAfterRetry() } } - [TestMethod] + // Imds bug: headers are missing + // TODO: uncomment this when the bug is fixed + /*[TestMethod] public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() { using (var httpManager = new MockHttpManager()) @@ -75,9 +77,11 @@ public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } - } + }*/ - [TestMethod] + // Imds bug: headers are missing + // TODO: uncomment this when the bug is fixed + /*[TestMethod] public async Task GetCsrMetadataAsyncFailsWithInvalidVersion() { using (var httpManager = new MockHttpManager()) @@ -92,7 +96,7 @@ public async Task GetCsrMetadataAsyncFailsWithInvalidVersion() var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } - } + }*/ [TestMethod] public async Task GetCsrMetadataAsyncFailsAfterMaxRetries() @@ -137,8 +141,8 @@ public void TestCsrGeneration() { var cuid = new CuidInfo { - Vmid = TestConstants.Vmid, - Vmssid = TestConstants.Vmssid + VmId = TestConstants.VmId, + VmssId = TestConstants.VmssId }; // Generate CSR @@ -157,8 +161,8 @@ public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId { var cuid = new CuidInfo { - Vmid = TestConstants.Vmid, - Vmssid = TestConstants.Vmssid + VmId = TestConstants.VmId, + //VmssId = TestConstants.VmssId }; Assert.ThrowsException(() => @@ -174,38 +178,38 @@ public void TestCsrGeneration_NullCuid() } [DataTestMethod] - [DataRow(null, TestConstants.Vmssid)] - [DataRow("", TestConstants.Vmssid)] - public void TestCsrGeneration_InvalidVmid(string vmid, string vmssid) + [DataRow(null, TestConstants.VmssId)] + [DataRow("", TestConstants.VmssId)] + public void TestCsrGeneration_InvalidVmId(string vmId, string vmssId) { var cuid = new CuidInfo { - Vmid = vmid, - Vmssid = vmssid + VmId = vmId, + //VmssId = vmssId }; - // Should throw ArgumentException since Vmid is required + // Should throw ArgumentException since VmId is required Assert.ThrowsException(() => Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid)); } [DataTestMethod] - [DataRow(TestConstants.Vmid, null)] - [DataRow(TestConstants.Vmid, "")] - public void TestCsrGeneration_OptionalVmssid(string vmid, string vmssid) + [DataRow(TestConstants.VmId, null)] + [DataRow(TestConstants.VmId, "")] + public void TestCsrGeneration_OptionalVmssId(string vmId, string vmssId) { var cuid = new CuidInfo { - Vmid = vmid, - Vmssid = vmssid + VmId = vmId, + //VmssId = vmssId }; - // Should succeed since Vmssid is optional (Vmid is provided and valid) + // Should succeed since VmssId is optional (VmId is provided and valid) var csrRequest = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); Assert.IsNotNull(csrRequest); Assert.IsFalse(string.IsNullOrWhiteSpace(csrRequest.Pem)); - // Validate the CSR contents - this should handle null/empty VMSSID gracefully + // Validate the CSR contents - this should handle null/empty vmssId gracefully CsrValidator.ValidateCsrContent(csrRequest.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); } From 3f75e3ad3c767246306c68aa66e89f4b6d2dc878 Mon Sep 17 00:00:00 2001 From: Robbie-Microsoft <87724641+Robbie-Microsoft@users.noreply.github.com> Date: Fri, 22 Aug 2025 16:27:25 -0400 Subject: [PATCH 21/29] ImdsV2: Reworked Custom ASN1 Encoder to use System.Formats.Asn1 Nuget Package (#5449) --- Directory.Packages.props | 2 +- .../ManagedIdentity/Csr.cs | 482 ------------------ .../ManagedIdentity/ManagedIdentityClient.cs | 1 + .../ManagedIdentity/V2/CertificateRequest.cs | 235 +++++++++ .../{ => V2}/CertificateRequestResponse.cs | 2 +- .../ManagedIdentity/V2/Csr.cs | 51 ++ .../ManagedIdentity/{ => V2}/CsrMetadata.cs | 2 +- .../{ => V2}/ImdsV2ManagedIdentitySource.cs | 11 +- .../Microsoft.Identity.Client.csproj | 2 + .../net/MsalJsonSerializerContext.cs | 2 +- .../ManagedIdentityTests/CsrValidator.cs | 457 +++-------------- .../ManagedIdentityTests/ImdsV2Tests.cs | 70 +-- 12 files changed, 388 insertions(+), 929 deletions(-) delete mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs create mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequest.cs rename src/client/Microsoft.Identity.Client/ManagedIdentity/{ => V2}/CertificateRequestResponse.cs (96%) create mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs rename src/client/Microsoft.Identity.Client/ManagedIdentity/{ => V2}/CsrMetadata.cs (97%) rename src/client/Microsoft.Identity.Client/ManagedIdentity/{ => V2}/ImdsV2ManagedIdentitySource.cs (97%) diff --git a/Directory.Packages.props b/Directory.Packages.props index f222636fe1..d8f368dfc9 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -17,6 +17,7 @@ + @@ -80,6 +81,5 @@ - diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs deleted file mode 100644 index 899304b544..0000000000 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs +++ /dev/null @@ -1,482 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.Collections.Generic; -using System.Security.Cryptography; -using System.Text; -using Microsoft.Identity.Client.Utils; - -namespace Microsoft.Identity.Client.ManagedIdentity -{ - internal class PemPayload - { - public string pem { get; set; } - } - - internal class Csr - { - public string Pem { get; } - - public Csr(string pem) - { - Pem = pem ?? throw new ArgumentNullException(nameof(pem)); - } - - /// - /// Generates a CSR for the given client, tenant, and CUID info. - /// - /// Managed Identity client_id. - /// AAD tenant_id. - /// CuidInfo object containing required vmId and optional vmssId. - /// CsrRequest containing the PEM CSR. - public static Csr Generate(string clientId, string tenantId, CuidInfo cuid) - { - if (string.IsNullOrEmpty(clientId)) - throw new ArgumentException("clientId must not be null or empty.", nameof(clientId)); - if (string.IsNullOrEmpty(tenantId)) - throw new ArgumentException("tenantId must not be null or empty.", nameof(tenantId)); - if (cuid == null) - throw new ArgumentNullException(nameof(cuid)); - if (string.IsNullOrEmpty(cuid.VmId)) - throw new ArgumentException("cuid.VmId must not be null or empty.", nameof(cuid.VmId)); - - string pemCsr = GeneratePkcs10Csr(clientId, tenantId, cuid); - return new Csr(pemCsr); - } - - /// - /// Generates a PKCS#10 Certificate Signing Request in PEM format. - /// - private static string GeneratePkcs10Csr(string clientId, string tenantId, CuidInfo cuid) - { - // Generate RSA key pair (2048-bit) - RSA rsa = CreateRsaKeyPair(); - - try - { - // Build the CSR components - byte[] certificationRequestInfo = BuildCertificationRequestInfo(clientId, tenantId, cuid, rsa); - byte[] signatureAlgorithm = BuildSignatureAlgorithmIdentifier(); - byte[] signature = SignCertificationRequestInfo(certificationRequestInfo, rsa); - - // Combine into final CSR structure - byte[] csrBytes = BuildFinalCsr(certificationRequestInfo, signatureAlgorithm, signature); - - // Convert to PEM format - return ConvertToPem(csrBytes); - } - finally - { - rsa?.Dispose(); - } - } - - /// - /// Creates a 2048-bit RSA key pair that supports PSS padding across all target frameworks. - /// - /// - /// On .NET Framework 4.6.2/4.7.2 (Windows-only), explicitly uses RSACng (also Windows-only) - /// to ensure PSS padding support, as RSA.Create() may return RSACryptoServiceProvider - /// which doesn't support PSS. - /// On .NET Standard 2.0 and .NET 8.0+ (cross-platform), uses RSA.Create() which returns - /// modern implementations that support PSS: RSACng on Windows, OpenSSL-based on Linux/macOS. - /// - /// An RSA instance configured for 2048-bit keys with PSS padding capability. - private static RSA CreateRsaKeyPair() - { - // TODO: use the strongest key on the machine i.e. a TPM key - RSA rsa = null; - -#if NET462 || NET472 - // .NET Framework runs only on Windows, so RSACng (Windows-only) is always available - rsa = new RSACng(); -#else - // Cross-platform .NET - RSA.Create() returns appropriate PSS-capable implementation - rsa = RSA.Create(); -#endif - rsa.KeySize = 2048; - return rsa; - } - - /// - /// Builds the CertificationRequestInfo structure containing subject, public key, and attributes. - /// - private static byte[] BuildCertificationRequestInfo(string clientId, string tenantId, CuidInfo cuid, RSA rsa) - { - var components = new List(); - - // Version (INTEGER 0) - components.Add(EncodeAsn1Integer(new byte[] { 0x00 })); - - // Subject: CN=, DC= - components.Add(BuildSubjectName(clientId, tenantId)); - - // SubjectPublicKeyInfo - components.Add(BuildSubjectPublicKeyInfo(rsa)); - - // Attributes (including CUID) - components.Add(BuildAttributes(cuid)); - - return EncodeAsn1Sequence(components.ToArray()); - } - - /// - /// Builds the X.500 Distinguished Name for the subject. - /// - private static byte[] BuildSubjectName(string clientId, string tenantId) - { - var rdnSequence = new List(); - - // CN= - byte[] cnOid = EncodeAsn1ObjectIdentifier(new int[] { 2, 5, 4, 3 }); // commonName OID - byte[] cnValue = EncodeAsn1Utf8String(clientId); - byte[] cnAttributeValue = EncodeAsn1Sequence(new[] { cnOid, cnValue }); - rdnSequence.Add(EncodeAsn1Set(new[] { cnAttributeValue })); - - // DC= - byte[] dcOid = EncodeAsn1ObjectIdentifier(new int[] { 0, 9, 2342, 19200300, 100, 1, 25 }); // domainComponent OID - byte[] dcValue = EncodeAsn1Utf8String(tenantId); - byte[] dcAttributeValue = EncodeAsn1Sequence(new[] { dcOid, dcValue }); - rdnSequence.Add(EncodeAsn1Set(new[] { dcAttributeValue })); - - return EncodeAsn1Sequence(rdnSequence.ToArray()); - } - - /// - /// Builds the SubjectPublicKeyInfo structure containing the RSA public key. - /// - private static byte[] BuildSubjectPublicKeyInfo(RSA rsa) - { - RSAParameters rsaParams = rsa.ExportParameters(false); - - // RSA Public Key structure - byte[] modulus = EncodeAsn1Integer(rsaParams.Modulus); - byte[] exponent = EncodeAsn1Integer(rsaParams.Exponent); - byte[] rsaPublicKey = EncodeAsn1Sequence(new[] { modulus, exponent }); - - // Algorithm identifier for RSA encryption - byte[] rsaOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 1 }); // RSA encryption OID - byte[] nullParam = EncodeAsn1Null(); - byte[] algorithmIdentifier = EncodeAsn1Sequence(new[] { rsaOid, nullParam }); - - // SubjectPublicKeyInfo - byte[] publicKeyBitString = EncodeAsn1BitString(rsaPublicKey); - return EncodeAsn1Sequence(new[] { algorithmIdentifier, publicKeyBitString }); - } - - /// - /// Builds the attributes section including the CUID extension. - /// - private static byte[] BuildAttributes(CuidInfo cuid) - { - var attributes = new List(); - - // CUID attribute (OID 1.2.840.113549.1.9.7) - // Serialize CuidInfo as JSON object string using existing JSON serialization - byte[] cuidOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 9, 7 }); - string cuidValue = JsonHelper.SerializeToJson(cuid); - byte[] cuidData = EncodeAsn1PrintableString(cuidValue); - byte[] cuidAttributeValue = EncodeAsn1Set(new[] { cuidData }); - byte[] cuidAttribute = EncodeAsn1Sequence(new[] { cuidOid, cuidAttributeValue }); - attributes.Add(cuidAttribute); - - return EncodeAsn1ContextSpecific(0, EncodeAsn1SequenceRaw(attributes.ToArray())); - } - - /// - /// Builds the signature algorithm identifier for RSASSA-PSS with SHA256. - /// - private static byte[] BuildSignatureAlgorithmIdentifier() - { - byte[] rsassaPssOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 10 }); // RSASSA-PSS OID - byte[] pssParams = BuildPssParameters(); - return EncodeAsn1Sequence(new[] { rsassaPssOid, pssParams }); - } - - /// - /// Builds the RSASSA-PSS parameters for SHA256 with MGF1. - /// - private static byte[] BuildPssParameters() - { - var parameters = new List(); - - // hashAlgorithm [0] AlgorithmIdentifier DEFAULT sha1 - // We explicitly specify SHA256 since default is SHA1 - byte[] sha256Oid = EncodeAsn1ObjectIdentifier(new int[] { 2, 16, 840, 1, 101, 3, 4, 2, 1 }); // SHA256 OID - byte[] sha256Null = EncodeAsn1Null(); - byte[] hashAlgorithm = EncodeAsn1Sequence(new[] { sha256Oid, sha256Null }); - byte[] hashAlgorithmParam = EncodeAsn1ContextSpecific(0, hashAlgorithm); - parameters.Add(hashAlgorithmParam); - - // maskGenAlgorithm [1] AlgorithmIdentifier DEFAULT mgf1SHA1 - // We explicitly specify MGF1 with SHA256 - byte[] mgf1Oid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 8 }); // MGF1 OID - byte[] mgf1HashAlgorithm = EncodeAsn1Sequence(new[] { sha256Oid, sha256Null }); // MGF1 uses SHA256 - byte[] maskGenAlgorithm = EncodeAsn1Sequence(new[] { mgf1Oid, mgf1HashAlgorithm }); - byte[] maskGenAlgorithmParam = EncodeAsn1ContextSpecific(1, maskGenAlgorithm); - parameters.Add(maskGenAlgorithmParam); - - // saltLength [2] INTEGER DEFAULT 20 - // We explicitly specify 32 for SHA256 (hash length) - byte[] saltLength = EncodeAsn1Integer(32); - byte[] saltLengthParam = EncodeAsn1ContextSpecific(2, saltLength); - parameters.Add(saltLengthParam); - - // trailerField [3] INTEGER DEFAULT 1 - // Default value is 1 (0xBC), so we omit this parameter - - return EncodeAsn1Sequence(parameters.ToArray()); - } - - /// - /// Signs the CertificationRequestInfo with SHA256withRSA-PSS. - /// - private static byte[] SignCertificationRequestInfo(byte[] certificationRequestInfo, RSA rsa) - { - return rsa.SignData(certificationRequestInfo, HashAlgorithmName.SHA256, RSASignaturePadding.Pss); - } - - /// - /// Combines all components into the final CSR structure. - /// - private static byte[] BuildFinalCsr(byte[] certificationRequestInfo, byte[] signatureAlgorithm, byte[] signature) - { - byte[] signatureBitString = EncodeAsn1BitString(signature); - return EncodeAsn1Sequence(new[] { certificationRequestInfo, signatureAlgorithm, signatureBitString }); - } - - /// - /// Converts DER-encoded bytes to PEM format. - /// - private static string ConvertToPem(byte[] derBytes) - { - string base64 = Convert.ToBase64String(derBytes); - var sb = new StringBuilder(); - sb.AppendLine("-----BEGIN CERTIFICATE REQUEST-----"); - - // Split into 64-character lines - for (int i = 0; i < base64.Length; i += 64) - { - int length = Math.Min(64, base64.Length - i); - sb.AppendLine(base64.Substring(i, length)); - } - - sb.AppendLine("-----END CERTIFICATE REQUEST-----"); - return sb.ToString(); - } - - #region ASN.1 Encoding Helpers - - /// - /// Encodes an ASN.1 SEQUENCE. - /// - private static byte[] EncodeAsn1Sequence(byte[][] components) - { - return EncodeAsn1Tag(0x30, ConcatenateByteArrays(components)); - } - - /// - /// Encodes an ASN.1 SEQUENCE without the outer tag (for raw concatenation). - /// - private static byte[] EncodeAsn1SequenceRaw(byte[][] components) - { - return ConcatenateByteArrays(components); - } - - /// - /// Encodes an ASN.1 SET. - /// - private static byte[] EncodeAsn1Set(byte[][] components) - { - return EncodeAsn1Tag(0x31, ConcatenateByteArrays(components)); - } - - /// - /// Encodes an ASN.1 INTEGER. - /// - private static byte[] EncodeAsn1Integer(byte[] value) - { - // Ensure positive integer (prepend 0x00 if high bit is set) - if (value != null && value.Length > 0 && (value[0] & 0x80) != 0) - { - byte[] paddedValue = new byte[value.Length + 1]; - paddedValue[0] = 0x00; - Array.Copy(value, 0, paddedValue, 1, value.Length); - value = paddedValue; - } - return EncodeAsn1Tag(0x02, value ?? new byte[0]); - } - - /// - /// Encodes an ASN.1 INTEGER from an integer value. - /// - private static byte[] EncodeAsn1Integer(int value) - { - if (value == 0) - return EncodeAsn1Tag(0x02, new byte[] { 0x00 }); - - var bytes = new List(); - int temp = value; - while (temp > 0) - { - bytes.Insert(0, (byte)(temp & 0xFF)); - temp >>= 8; - } - - return EncodeAsn1Integer(bytes.ToArray()); - } - - /// - /// Encodes an ASN.1 BIT STRING. - /// - private static byte[] EncodeAsn1BitString(byte[] value) - { - byte[] bitStringValue = new byte[value.Length + 1]; - bitStringValue[0] = 0x00; // No unused bits - Array.Copy(value, 0, bitStringValue, 1, value.Length); - return EncodeAsn1Tag(0x03, bitStringValue); - } - - /// - /// Encodes an ASN.1 UTF8String. - /// - private static byte[] EncodeAsn1Utf8String(string value) - { - byte[] utf8Bytes = Encoding.UTF8.GetBytes(value); - return EncodeAsn1Tag(0x0C, utf8Bytes); - } - - /// - /// Encodes an ASN.1 PrintableString. - /// - private static byte[] EncodeAsn1PrintableString(string value) - { - byte[] asciiBytes = Encoding.ASCII.GetBytes(value); - return EncodeAsn1Tag(0x13, asciiBytes); - } - - /// - /// Encodes an ASN.1 NULL. - /// - private static byte[] EncodeAsn1Null() - { - return new byte[] { 0x05, 0x00 }; - } - - /// - /// Encodes an ASN.1 OBJECT IDENTIFIER. - /// - private static byte[] EncodeAsn1ObjectIdentifier(int[] oid) - { - if (oid == null || oid.Length < 2) - throw new ArgumentException("OID must have at least 2 components"); - - var bytes = new List(); - - // First two components are encoded as (first * 40 + second) - bytes.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); - - // Remaining components - for (int i = 2; i < oid.Length; i++) - { - bytes.AddRange(EncodeOidComponent(oid[i])); - } - - return EncodeAsn1Tag(0x06, bytes.ToArray()); - } - - /// - /// Encodes an ASN.1 context-specific tag. - /// - private static byte[] EncodeAsn1ContextSpecific(int tagNumber, byte[] content) - { - byte tag = (byte)(0xA0 | tagNumber); // Context-specific, constructed - return EncodeAsn1Tag(tag, content); - } - - /// - /// Encodes an ASN.1 tag with length and content. - /// - private static byte[] EncodeAsn1Tag(byte tag, byte[] content) - { - byte[] lengthBytes = EncodeAsn1Length(content.Length); - byte[] result = new byte[1 + lengthBytes.Length + content.Length]; - result[0] = tag; - Array.Copy(lengthBytes, 0, result, 1, lengthBytes.Length); - Array.Copy(content, 0, result, 1 + lengthBytes.Length, content.Length); - return result; - } - - /// - /// Encodes ASN.1 length field. - /// - private static byte[] EncodeAsn1Length(int length) - { - if (length < 0x80) - { - return new byte[] { (byte)length }; - } - - var lengthBytes = new List(); - int temp = length; - while (temp > 0) - { - lengthBytes.Insert(0, (byte)(temp & 0xFF)); - temp >>= 8; - } - - byte[] result = new byte[lengthBytes.Count + 1]; - result[0] = (byte)(0x80 | lengthBytes.Count); - lengthBytes.CopyTo(result, 1); - return result; - } - - /// - /// Encodes a single OID component using variable-length encoding. - /// - private static byte[] EncodeOidComponent(int value) - { - if (value == 0) - return new byte[] { 0x00 }; - - var bytes = new List(); - int temp = value; - - bytes.Insert(0, (byte)(temp & 0x7F)); - temp >>= 7; - - while (temp > 0) - { - bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); - temp >>= 7; - } - - return bytes.ToArray(); - } - - /// - /// Concatenates multiple byte arrays. - /// - private static byte[] ConcatenateByteArrays(byte[][] arrays) - { - int totalLength = 0; - foreach (byte[] array in arrays) - { - totalLength += array.Length; - } - - byte[] result = new byte[totalLength]; - int offset = 0; - foreach (byte[] array in arrays) - { - Array.Copy(array, 0, result, offset, array.Length); - offset += array.Length; - } - - return result; - } - - #endregion - } -} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs index 35f334d4bf..4e8e7fd8ce 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs @@ -9,6 +9,7 @@ using Microsoft.Identity.Client.PlatformsCommon.Shared; using System.IO; using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.ManagedIdentity.V2; namespace Microsoft.Identity.Client.ManagedIdentity { diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequest.cs new file mode 100644 index 0000000000..81a2de8d30 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequest.cs @@ -0,0 +1,235 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.ObjectModel; +using System.Formats.Asn1; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Text; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + internal class CertificateRequest + { + private X500DistinguishedName _subjectName; + private RSA _rsa; + private HashAlgorithmName _hashAlgorithmName; + private RSASignaturePadding _rsaPadding; + + internal CertificateRequest( + X500DistinguishedName subjectName, + RSA key, + HashAlgorithmName hashAlgorithm, + RSASignaturePadding padding) + { + _subjectName = subjectName; + _rsa = key; + _hashAlgorithmName = hashAlgorithm; + _rsaPadding = padding; + } + + internal Collection OtherRequestAttributes { get; } = new Collection(); + + private static string MakePem(byte[] ber, string header) + { + const int LineLength = 64; + + string base64 = Convert.ToBase64String(ber); + int offset = 0; + + StringBuilder builder = new StringBuilder("-----BEGIN "); + builder.Append(header); + builder.AppendLine("-----"); + + while (offset < base64.Length) + { + int lineEnd = Math.Min(offset + LineLength, base64.Length); + builder.AppendLine(base64.Substring(offset, lineEnd - offset)); + offset = lineEnd; + } + + builder.Append("-----END "); + builder.Append(header); + builder.AppendLine("-----"); + + return builder.ToString(); + } + + internal string CreateSigningRequestPem() + { + byte[] csr = CreateSigningRequest(); + return MakePem(csr, "CERTIFICATE REQUEST"); + } + + internal byte[] CreateSigningRequest() + { + if (_hashAlgorithmName != HashAlgorithmName.SHA256) + { + throw new NotSupportedException("Signature Processing has only been written for SHA256"); + } + + AsnWriter writer = new AsnWriter(AsnEncodingRules.DER); + + // RSAPublicKey ::= SEQUENCE { + // modulus INTEGER, -- n + // publicExponent INTEGER -- e + // } + + using (writer.PushSequence()) + { + RSAParameters rsaParameters = _rsa.ExportParameters(false); + writer.WriteIntegerUnsigned(rsaParameters.Modulus); + writer.WriteIntegerUnsigned(rsaParameters.Exponent); + } + + byte[] publicKey = writer.Encode(); + writer.Reset(); + + // CertificationRequestInfo ::= SEQUENCE { + // version INTEGER { v1(0) } (v1,...), + // subject Name, + // subjectPKInfo SubjectPublicKeyInfo{{ PKInfoAlgorithms }}, + // attributes [0] Attributes{{ CRIAttributes }} + // } + // + // SubjectPublicKeyInfo { ALGORITHM: IOSet} ::= SEQUENCE { + // algorithm AlgorithmIdentifier { { IOSet} }, + // subjectPublicKey BIT STRING + // } + // + // Attributes { ATTRIBUTE:IOSet } ::= SET OF Attribute{{ IOSet }} + // + // Attribute { ATTRIBUTE:IOSet } ::= SEQUENCE { + // type ATTRIBUTE.&id({IOSet}), + // values SET SIZE(1..MAX) OF ATTRIBUTE.&Type({IOSet}{@type}) + // } + + using (writer.PushSequence()) + { + writer.WriteInteger(0); + writer.WriteEncodedValue(_subjectName.RawData); + + // subjectPKInfo + using (writer.PushSequence()) + { + // algorithm + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier("1.2.840.113549.1.1.1"); + // RSA uses an explicit NULL value for parameters + writer.WriteNull(); + } + + writer.WriteBitString(publicKey); + } + + if (OtherRequestAttributes.Count > 0) + { + // attributes + using (writer.PushSetOf(new Asn1Tag(TagClass.ContextSpecific, 0))) + { + foreach (AsnEncodedData attribute in OtherRequestAttributes) + { + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier(attribute.Oid.Value); + + using (writer.PushSetOf()) + { + writer.WriteEncodedValue(attribute.RawData); + } + } + } + } + } + } + + byte[] certReqInfo = writer.Encode(); + writer.Reset(); + + // CertificationRequest ::= SEQUENCE { + // certificationRequestInfo CertificationRequestInfo, + // signatureAlgorithm AlgorithmIdentifier{{ SignatureAlgorithms }}, + // signature BIT STRING + // } + + using (writer.PushSequence()) + { + writer.WriteEncodedValue(certReqInfo); + + // signatureAlgorithm + using (writer.PushSequence()) + { + if (_rsaPadding == RSASignaturePadding.Pss) + { + if (_hashAlgorithmName != HashAlgorithmName.SHA256) + { + throw new NotSupportedException("Only SHA256 is supported with PSS padding."); + } + + writer.WriteObjectIdentifier("1.2.840.113549.1.1.10"); + + // RSASSA-PSS-params ::= SEQUENCE { + // hashAlgorithm [0] HashAlgorithm DEFAULT sha1, + // maskGenAlgorithm [1] MaskGenAlgorithm DEFAULT mgf1SHA1, + // saltLength [2] INTEGER DEFAULT 20, + // trailerField [3] TrailerField DEFAULT trailerFieldBC + // } + + using (writer.PushSequence()) + { + string digestOid = "2.16.840.1.101.3.4.2.1"; + + // hashAlgorithm + using (writer.PushSequence(new Asn1Tag(TagClass.ContextSpecific, 0))) + { + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier(digestOid); + } + } + + using (writer.PushSequence(new Asn1Tag(TagClass.ContextSpecific, 1))) + { + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier("1.2.840.113549.1.1.8"); + + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier(digestOid); + } + } + } + + // saltLength (SHA256.Length, 32 bytes) + using (writer.PushSequence(new Asn1Tag(TagClass.ContextSpecific, 2))) + { + writer.WriteInteger(32); + } + + // trailerField 1, which is trailerFieldBC, which is the DEFAULT, + // so don't write it down. + } + } + else if (_rsaPadding == RSASignaturePadding.Pkcs1) + { + writer.WriteObjectIdentifier("1.2.840.113549.1.1.11"); + // RSA PKCS1 uses an explicit NULL value for parameters + writer.WriteNull(); + } + else + { + throw new NotSupportedException("Unsupported RSA padding."); + } + } + + byte[] signature = _rsa.SignData(certReqInfo, _hashAlgorithmName, _rsaPadding); + writer.WriteBitString(signature); + } + + return writer.Encode(); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs similarity index 96% rename from src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs rename to src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs index 4391fba4be..a000ea008c 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs @@ -7,7 +7,7 @@ using Microsoft.Identity.Json; #endif -namespace Microsoft.Identity.Client.ManagedIdentity +namespace Microsoft.Identity.Client.ManagedIdentity.V2 { /// /// Represents the response for a Managed Identity CSR request. diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs new file mode 100644 index 0000000000..35d6465d99 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Formats.Asn1; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Client.Utils; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + internal class Csr + { + internal static string Generate(string clientId, string tenantId, CuidInfo cuid) + { + using (RSA rsa = CreateRsaKeyPair()) + { + CertificateRequest req = new CertificateRequest( + new X500DistinguishedName($"CN={clientId}, DC={tenantId}"), + rsa, + HashAlgorithmName.SHA256, + RSASignaturePadding.Pss); + + AsnWriter writer = new AsnWriter(AsnEncodingRules.DER); + writer.WriteCharacterString(UniversalTagNumber.UTF8String, JsonHelper.SerializeToJson(cuid)); + + req.OtherRequestAttributes.Add( + new AsnEncodedData( + "1.2.840.113549.1.9.7", + writer.Encode())); + + return req.CreateSigningRequestPem(); + } + } + + private static RSA CreateRsaKeyPair() + { + // TODO: use the strongest key on the machine i.e. a TPM key + RSA rsa = null; + +#if NET462 || NET472 + // .NET Framework runs only on Windows, so RSACng (Windows-only) is always available + rsa = new RSACng(); +#else + // Cross-platform .NET - RSA.Create() returns appropriate PSS-capable implementation + rsa = RSA.Create(); +#endif + rsa.KeySize = 2048; + return rsa; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs similarity index 97% rename from src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs rename to src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs index 5de9a5e490..6281fcec14 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs @@ -7,7 +7,7 @@ using Microsoft.Identity.Json; #endif -namespace Microsoft.Identity.Client.ManagedIdentity +namespace Microsoft.Identity.Client.ManagedIdentity.V2 { /// /// Represents VM unique Ids for CSR metadata. diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs similarity index 97% rename from src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs rename to src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index bd0d2c2e4a..dc65aa72ef 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -12,7 +12,7 @@ using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Utils; -namespace Microsoft.Identity.Client.ManagedIdentity +namespace Microsoft.Identity.Client.ManagedIdentity.V2 { internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { @@ -200,7 +200,7 @@ internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : private async Task ExecuteCertificateRequestAsync( CuidInfo cuid, - string pem) + string csrPem) { var queryParams = $"cuid={JsonHelper.SerializeToJson(cuid)}&cred-api-version={ImdsV2ApiVersion}"; if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) @@ -214,8 +214,7 @@ private async Task ExecuteCertificateRequestAsync( { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } }; - var payload = new PemPayload { pem = pem }; - var body = JsonHelper.SerializeToJson(payload); + var body = $"{{\"pem\":\"{csrPem}\"}}"; IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.Imds); @@ -274,9 +273,9 @@ private async Task ExecuteCertificateRequestAsync( protected override ManagedIdentityRequest CreateRequest(string resource) { var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); - var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); + var csrPem = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); - var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.CuId, csr.Pem).GetAwaiter().GetResult(); + var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.CuId, csrPem).GetAwaiter().GetResult(); throw new NotImplementedException(); } diff --git a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj index 578bb27e45..8342355663 100644 --- a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj +++ b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj @@ -94,6 +94,7 @@ + @@ -118,6 +119,7 @@ + diff --git a/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs b/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs index 6d6a6cb7f2..5933d95e58 100644 --- a/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs +++ b/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs @@ -13,6 +13,7 @@ using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Kerberos; using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.OAuth2; using Microsoft.Identity.Client.Region; using Microsoft.Identity.Client.WsTrust; @@ -43,7 +44,6 @@ namespace Microsoft.Identity.Client.Platforms.net [JsonSerializable(typeof(CsrMetadata))] [JsonSerializable(typeof(CuidInfo))] [JsonSerializable(typeof(CertificateRequestResponse))] - [JsonSerializable(typeof(PemPayload))] [JsonSourceGenerationOptions] internal partial class MsalJsonSerializerContext : JsonSerializerContext { diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs index adbc2e298d..667c5a6050 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs @@ -2,17 +2,22 @@ // Licensed under the MIT License. using System; -using Microsoft.Identity.Client.ManagedIdentity; +using System.Formats.Asn1; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.Utils; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests { /// - /// Test helper to expose CsrValidator methods for testing malformed PEM. + /// Helper class for parsing and validating Certificate Signing Request (CSR) content and structure. /// - internal static class TestCsrValidator + internal static class CsrValidator { + /// + /// Parses a PEM-encoded CSR and returns the DER bytes. + /// public static byte[] ParseCsrFromPem(string pemCsr) { if (string.IsNullOrWhiteSpace(pemCsr)) @@ -21,15 +26,13 @@ public static byte[] ParseCsrFromPem(string pemCsr) const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; const string endMarker = "-----END CERTIFICATE REQUEST-----"; - if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) - throw new ArgumentException("Invalid PEM format - missing CSR headers"); + int beginIndex = pemCsr.IndexOf(beginMarker, StringComparison.Ordinal); + int endIndex = pemCsr.IndexOf(endMarker, StringComparison.Ordinal); - int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; - int endIndex = pemCsr.IndexOf(endMarker); - - if (beginIndex >= endIndex) - throw new ArgumentException("Invalid PEM format - malformed headers"); + if (beginIndex < 0 || endIndex < 0) + throw new ArgumentException("Invalid PEM format - missing CSR headers"); + beginIndex += beginMarker.Length; string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) .Replace("\r", "").Replace("\n", "").Replace(" ", ""); @@ -42,390 +45,100 @@ public static byte[] ParseCsrFromPem(string pemCsr) throw new FormatException("Invalid Base64 content in PEM CSR"); } } - } - /// - /// Helper class for validating Certificate Signing Request (CSR) content and structure. - /// - internal static class CsrValidator - { /// /// Validates the content of a CSR PEM string against expected values. /// public static void ValidateCsrContent(string pemCsr, string expectedClientId, string expectedTenantId, CuidInfo expectedCuid) { - // Parse the CSR from PEM format - var csrData = ParseCsrFromPem(pemCsr); - - // Parse the PKCS#10 structure - var csrInfo = ParsePkcs10Structure(csrData); - - // Validate subject name - ValidateSubjectName(csrInfo.Subject, expectedClientId, expectedTenantId); - - // Validate public key - ValidatePublicKey(csrInfo.PublicKey); - - // Validate CUID attribute - ValidateCuidAttribute(csrInfo.Attributes, expectedCuid); - - // Validate signature algorithm - ValidateSignatureAlgorithm(csrInfo.SignatureAlgorithm); - } - - /// - /// Parses a PEM-formatted CSR and returns the DER bytes. - /// - private static byte[] ParseCsrFromPem(string pemCsr) - { - if (string.IsNullOrWhiteSpace(pemCsr)) - throw new ArgumentException("PEM CSR cannot be null or empty"); - - const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; - const string endMarker = "-----END CERTIFICATE REQUEST-----"; - - if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) - throw new ArgumentException("Invalid PEM format - missing CSR headers"); - - int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; - int endIndex = pemCsr.IndexOf(endMarker); - - if (beginIndex >= endIndex) - throw new ArgumentException("Invalid PEM format - malformed headers"); - - string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) - .Replace("\r", "").Replace("\n", "").Replace(" ", ""); - - try + byte[] csrBytes = ParseCsrFromPem(pemCsr); + + // Parse the CSR using AsnReader + var reader = new AsnReader(csrBytes, AsnEncodingRules.DER); + var csrSequence = reader.ReadSequence(); + + // certificationRequestInfo + var certReqInfoBytes = csrSequence.PeekEncodedValue().ToArray(); + var certReqInfoReader = new AsnReader(csrSequence.ReadEncodedValue().ToArray(), AsnEncodingRules.DER); + var certReqInfoSeq = certReqInfoReader.ReadSequence(); + + // version + int version = (int)certReqInfoSeq.ReadInteger(); + Assert.AreEqual(0, version, "CSR version should be 0"); + + // subject + var subjectBytes = certReqInfoSeq.PeekEncodedValue().ToArray(); + var subject = new X500DistinguishedName(certReqInfoSeq.ReadEncodedValue().ToArray()); + string subjectString = subject.Name; + + Assert.IsTrue(subjectString.Contains($"CN={expectedClientId}"), "Client ID (CN) not found in subject"); + Assert.IsTrue(subjectString.Contains($"DC={expectedTenantId}"), "Tenant ID (DC) not found in subject"); + + // subjectPKInfo + var pkInfoReader = new AsnReader(certReqInfoSeq.ReadEncodedValue().ToArray(), AsnEncodingRules.DER); + var pkInfoSeq = pkInfoReader.ReadSequence(); + + // algorithm + var algIdSeq = pkInfoSeq.ReadSequence(); + string algOid = algIdSeq.ReadObjectIdentifier(); + Assert.AreEqual("1.2.840.113549.1.1.1", algOid, "Public key algorithm is not RSA"); + if (algIdSeq.HasData) { - return Convert.FromBase64String(base64Content); + algIdSeq.ReadNull(); } - catch (FormatException) - { - throw new FormatException("Invalid Base64 content in PEM CSR"); - } - } - /// - /// Represents parsed PKCS#10 CSR information. - /// - private class CsrInfo - { - public byte[] Subject { get; set; } - public byte[] PublicKey { get; set; } - public byte[] Attributes { get; set; } - public byte[] SignatureAlgorithm { get; set; } - } - - /// - /// Parses the PKCS#10 ASN.1 structure and extracts key components. - /// - private static CsrInfo ParsePkcs10Structure(byte[] derBytes) - { - int offset = 0; - - // Parse outer SEQUENCE (CertificationRequest) - var outerSequence = ParseAsn1Tag(derBytes, ref offset, 0x30); - - // Reset offset to parse the CertificationRequestInfo within the outer sequence - int infoOffset = 0; - var certRequestInfo = ParseAsn1Tag(outerSequence, ref infoOffset, 0x30); - - // Parse version (should be 0) - int versionOffset = 0; - var version = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x02); - if (version.Length != 1 || version[0] != 0x00) - throw new ArgumentException("Invalid CSR version"); - - // Parse subject - var subject = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); - - // Parse SubjectPublicKeyInfo - var publicKey = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); - - // Parse attributes (context-specific [0]) - var attributes = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0xA0); - - return new CsrInfo - { - Subject = subject, - PublicKey = publicKey, - Attributes = attributes, - SignatureAlgorithm = new byte[0] // Simplified for this test - }; - } - - /// - /// Parses an ASN.1 tag and returns its content. - /// - private static byte[] ParseAsn1Tag(byte[] data, ref int offset, byte expectedTag) - { - if (offset >= data.Length) - throw new ArgumentException("Unexpected end of data"); - - // Check tag (if expectedTag is -1, accept any tag) - if (expectedTag != 255 && data[offset] != expectedTag) - throw new ArgumentException($"Expected tag 0x{expectedTag:X2}, got 0x{data[offset]:X2}"); - - offset++; - - // Parse length - int length = ParseAsn1Length(data, ref offset); - - // Extract content - if (offset + length > data.Length) - throw new ArgumentException("Invalid ASN.1 length"); - - byte[] content = new byte[length]; - Array.Copy(data, offset, content, 0, length); - offset += length; - - return content; - } + // subjectPublicKey BIT STRING + var publicKeyBitString = pkInfoSeq.ReadBitString(out _); - /// - /// Parses ASN.1 length encoding. - /// - private static int ParseAsn1Length(byte[] data, ref int offset) - { - if (offset >= data.Length) - throw new ArgumentException("Unexpected end of data in length"); - - byte firstByte = data[offset++]; - - // Short form - if ((firstByte & 0x80) == 0) - return firstByte; - - // Long form - int lengthBytes = firstByte & 0x7F; - if (lengthBytes == 0) - throw new ArgumentException("Indefinite length not supported"); - - if (offset + lengthBytes > data.Length) - throw new ArgumentException("Invalid length encoding"); - - int length = 0; - for (int i = 0; i < lengthBytes; i++) - { - length = (length << 8) | data[offset++]; - } - - return length; - } + // Parse the RSAPublicKey structure from the BIT STRING (SEQUENCE of modulus, exponent) + var rsaKeyReader = new AsnReader(publicKeyBitString, AsnEncodingRules.DER); + var rsaKeySeq = rsaKeyReader.ReadSequence(); + byte[] modulus = rsaKeySeq.ReadIntegerBytes().ToArray(); + byte[] exponent = rsaKeySeq.ReadIntegerBytes().ToArray(); - /// - /// Validates the subject name contains the expected client ID and tenant ID. - /// - private static void ValidateSubjectName(byte[] subjectBytes, string expectedClientId, string expectedTenantId) - { - // Subject is already a SEQUENCE of RDNs - int offset = 0; - bool foundClientId = false; - bool foundTenantId = false; - - // Parse each RDN (Relative Distinguished Name) directly from subjectBytes - while (offset < subjectBytes.Length) - { - var rdnSet = ParseAsn1Tag(subjectBytes, ref offset, 0x31); // SET - - int rdnOffset = 0; - var rdnSequence = ParseAsn1Tag(rdnSet, ref rdnOffset, 0x30); // SEQUENCE - - // Parse OID and value - int attrOffset = 0; - var oid = ParseAsn1Tag(rdnSequence, ref attrOffset, 0x06); // OID - var value = ParseAsn1Tag(rdnSequence, ref attrOffset, 255); // Any string type - - string stringValue = System.Text.Encoding.UTF8.GetString(value); - - // Check for CN (commonName) OID: 2.5.4.3 - if (IsOid(oid, new int[] { 2, 5, 4, 3 })) - { - Assert.AreEqual(expectedClientId, stringValue, "Client ID in subject CN does not match"); - foundClientId = true; - } - // Check for DC (domainComponent) OID: 0.9.2342.19200300.100.1.25 - else if (IsOid(oid, new int[] { 0, 9, 2342, 19200300, 100, 1, 25 })) - { - Assert.AreEqual(expectedTenantId, stringValue, "Tenant ID in subject DC does not match"); - foundTenantId = true; - } - } - - Assert.IsTrue(foundClientId, "Client ID (CN) not found in subject"); - Assert.IsTrue(foundTenantId, "Tenant ID (DC) not found in subject"); - } + // Validate modulus length (2048 bits = 256 bytes, may have leading zero) + Assert.IsTrue(modulus.Length == 256 || modulus.Length == 257, $"RSA modulus should be 2048 bits, got {modulus.Length * 8} bits"); - /// - /// Validates the public key is a valid RSA key. - /// - private static void ValidatePublicKey(byte[] publicKeyBytes) - { - // publicKeyBytes is already the SubjectPublicKeyInfo SEQUENCE content - int offset = 0; - - // Parse algorithm identifier - var algorithmId = ParseAsn1Tag(publicKeyBytes, ref offset, 0x30); - - // Parse public key bit string - var publicKeyBitString = ParseAsn1Tag(publicKeyBytes, ref offset, 0x03); - - // Validate algorithm is RSA (1.2.840.113549.1.1.1) - int algOffset = 0; - var algorithmOid = ParseAsn1Tag(algorithmId, ref algOffset, 0x06); - Assert.IsTrue(IsOid(algorithmOid, new int[] { 1, 2, 840, 113549, 1, 1, 1 }), - "Public key algorithm is not RSA"); - - // Skip the unused bits byte in bit string - if (publicKeyBitString.Length < 2 || publicKeyBitString[0] != 0x00) - throw new ArgumentException("Invalid public key bit string"); - - // Parse RSA public key (skip unused bits byte) - byte[] rsaKeyBytes = new byte[publicKeyBitString.Length - 1]; - Array.Copy(publicKeyBitString, 1, rsaKeyBytes, 0, rsaKeyBytes.Length); - - int rsaOffset = 0; - var rsaSequence = ParseAsn1Tag(rsaKeyBytes, ref rsaOffset, 0x30); - - rsaOffset = 0; - var modulus = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); - var exponent = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); - - // Validate key size (should be 2048 bits = 256 bytes, plus potential leading zero) - Assert.IsTrue(modulus.Length >= 256 && modulus.Length <= 257, - $"RSA modulus should be 2048 bits, got {modulus.Length * 8} bits"); - // Validate exponent (commonly 65537 = 0x010001) Assert.IsTrue(exponent.Length >= 1 && exponent.Length <= 4, "RSA exponent has invalid length"); - } - /// - /// Validates the CUID attribute contains the expected VM and VMSS IDs as JSON. - /// Note: VmId is required, VmssId is optional and will be omitted if null/empty. - /// - private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) - { - // Attributes is a SET of attributes - // We expect one attribute with challengePassword OID (1.2.840.113549.1.9.7) - - int offset = 0; - bool foundCuid = false; - - // Parse each attribute in the SET - while (offset < attributesBytes.Length) + // attributes [0] (optional) + if (certReqInfoSeq.HasData) { - var attributeSequence = ParseAsn1Tag(attributesBytes, ref offset, 0x30); - - int attrOffset = 0; - var oid = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x06); - var valueSet = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x31); // SET of values - - // Check for challengePassword OID: 1.2.840.113549.1.9.7 - if (IsOid(oid, new int[] { 1, 2, 840, 113549, 1, 9, 7 })) + var attrTag = new Asn1Tag(TagClass.ContextSpecific, 0); + if (certReqInfoSeq.PeekTag().HasSameClassAndValue(attrTag)) { - // Parse the value from the SET (should be one value) - int valueOffset = 0; - var value = ParseAsn1Tag(valueSet, ref valueOffset, 255); // Any string type - - string cuidValue = System.Text.Encoding.ASCII.GetString(value); - - // Build expected CUID value as JSON - string expectedCuidValue = BuildExpectedCuidJson(expectedCuid); - - Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute JSON value does not match expected"); - foundCuid = true; - break; + var attrSetReader = certReqInfoSeq.ReadSetOf(attrTag); + bool foundCuid = false; + while (attrSetReader.HasData) + { + var attrSeq = attrSetReader.ReadSequence(); + string oid = attrSeq.ReadObjectIdentifier(); + if (oid == "1.2.840.113549.1.9.7") // challengePassword + { + var valueSet = attrSeq.ReadSetOf(); + while (valueSet.HasData) + { + string cuidJson = valueSet.ReadCharacterString(UniversalTagNumber.UTF8String); + string expectedCuidJson = JsonHelper.SerializeToJson(expectedCuid); + Assert.AreEqual(expectedCuidJson, cuidJson, "CUID attribute JSON value does not match expected"); + foundCuid = true; + } + } + } + Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); } } - - Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); - } - /// - /// Builds the expected CUID JSON string for validation using JsonHelper. - /// - private static string BuildExpectedCuidJson(CuidInfo expectedCuid) - { - return JsonHelper.SerializeToJson(expectedCuid); - } - - /// - /// Validates the signature algorithm is SHA256withRSA. - /// - private static void ValidateSignatureAlgorithm(byte[] signatureAlgBytes) - { - // For this test, we'll just verify that signature algorithm exists - // Full validation would require parsing the outer CSR structure - // which is more complex for this unit test scenario - Assert.IsNotNull(signatureAlgBytes, "Signature algorithm should be present"); - } + // signatureAlgorithm + var sigAlgSeq = csrSequence.ReadSequence(); + string sigAlgOid = sigAlgSeq.ReadObjectIdentifier(); + Assert.AreEqual("1.2.840.113549.1.1.10", sigAlgOid, "Signature algorithm is not RSASSA-PSS (SHA256withRSA/PSS)"); - /// - /// Checks if the given OID bytes match the expected OID components. - /// - private static bool IsOid(byte[] oidBytes, int[] expectedOid) - { - if (expectedOid.Length < 2) - return false; - - var expectedBytes = EncodeOid(expectedOid); - - if (oidBytes.Length != expectedBytes.Length) - return false; - - for (int i = 0; i < oidBytes.Length; i++) - { - if (oidBytes[i] != expectedBytes[i]) - return false; - } - - return true; - } + // signature + csrSequence.ReadBitString(out _); - /// - /// Encodes an OID from integer components to bytes (simplified version). - /// - private static byte[] EncodeOid(int[] oid) - { - if (oid.Length < 2) - throw new ArgumentException("OID must have at least 2 components"); - - var result = new System.Collections.Generic.List(); - - // First two components are encoded as (first * 40 + second) - result.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); - - // Remaining components - for (int i = 2; i < oid.Length; i++) - { - result.AddRange(EncodeOidComponent(oid[i])); - } - - return result.ToArray(); - } - - /// - /// Encodes a single OID component using variable-length encoding. - /// - private static byte[] EncodeOidComponent(int value) - { - if (value == 0) - return new byte[] { 0x00 }; - - var bytes = new System.Collections.Generic.List(); - int temp = value; - - bytes.Insert(0, (byte)(temp & 0x7F)); - temp >>= 7; - - while (temp > 0) - { - bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); - temp >>= 7; - } - - return bytes.ToArray(); + Assert.IsFalse(csrSequence.HasData, "Extra data found after CSR structure"); } } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 312d43ec74..afd52214f9 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -7,6 +7,7 @@ using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.Identity.Test.Unit.Helpers; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -146,71 +147,10 @@ public void TestCsrGeneration() }; // Generate CSR - var csr = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); + var csrPem = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); // Validate the CSR contents using the helper - CsrValidator.ValidateCsrContent(csr.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); - } - - [DataTestMethod] - [DataRow(null, TestConstants.TenantId)] - [DataRow("", TestConstants.TenantId)] - [DataRow(TestConstants.ClientId, null)] - [DataRow(TestConstants.ClientId, "")] - public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId) - { - var cuid = new CuidInfo - { - VmId = TestConstants.VmId, - //VmssId = TestConstants.VmssId - }; - - Assert.ThrowsException(() => - Csr.Generate(clientId, tenantId, cuid)); - } - - [TestMethod] - public void TestCsrGeneration_NullCuid() - { - // Test with null CUID - Assert.ThrowsException(() => - Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, null)); - } - - [DataTestMethod] - [DataRow(null, TestConstants.VmssId)] - [DataRow("", TestConstants.VmssId)] - public void TestCsrGeneration_InvalidVmId(string vmId, string vmssId) - { - var cuid = new CuidInfo - { - VmId = vmId, - //VmssId = vmssId - }; - - // Should throw ArgumentException since VmId is required - Assert.ThrowsException(() => - Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid)); - } - - [DataTestMethod] - [DataRow(TestConstants.VmId, null)] - [DataRow(TestConstants.VmId, "")] - public void TestCsrGeneration_OptionalVmssId(string vmId, string vmssId) - { - var cuid = new CuidInfo - { - VmId = vmId, - //VmssId = vmssId - }; - - // Should succeed since VmssId is optional (VmId is provided and valid) - var csrRequest = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); - Assert.IsNotNull(csrRequest); - Assert.IsFalse(string.IsNullOrWhiteSpace(csrRequest.Pem)); - - // Validate the CSR contents - this should handle null/empty vmssId gracefully - CsrValidator.ValidateCsrContent(csrRequest.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); + CsrValidator.ValidateCsrContent(csrPem, TestConstants.ClientId, TestConstants.TenantId, cuid); } [TestMethod] @@ -218,7 +158,7 @@ public void TestCsrGeneration_MalformedPem_FormatException() { string malformedPem = "-----BEGIN CERTIFICATE REQUEST-----\nInvalid@#$%Base64Content!\n-----END CERTIFICATE REQUEST-----"; Assert.ThrowsException(() => - TestCsrValidator.ParseCsrFromPem(malformedPem)); + CsrValidator.ParseCsrFromPem(malformedPem)); } [DataTestMethod] @@ -228,7 +168,7 @@ public void TestCsrGeneration_MalformedPem_FormatException() public void TestCsrGeneration_MalformedPem_ArgumentException(string malformedPem) { Assert.ThrowsException(() => - TestCsrValidator.ParseCsrFromPem(malformedPem)); + CsrValidator.ParseCsrFromPem(malformedPem)); } } } From 3481c395ced132890e57c90e6f1cf04b6bd4fcb5 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Mon, 25 Aug 2025 10:53:57 -0400 Subject: [PATCH 22/29] Implemented feedback --- .../V2/CertificateRequestResponse.cs | 14 +++++++++----- .../V2/ImdsV2ManagedIdentitySource.cs | 10 +--------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs index a000ea008c..77d32c9ff0 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs @@ -1,10 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.Net; #if SUPPORTS_SYSTEM_TEXT_JSON using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute; #else - using Microsoft.Identity.Json; +using Microsoft.Identity.Json; #endif namespace Microsoft.Identity.Client.ManagedIdentity.V2 @@ -34,7 +35,7 @@ internal class CertificateRequestResponse public CertificateRequestResponse() { } - public static bool IsValid(CertificateRequestResponse certificateRequestResponse) + public static void Validate(CertificateRequestResponse certificateRequestResponse) { if (string.IsNullOrEmpty(certificateRequestResponse.ClientId) || string.IsNullOrEmpty(certificateRequestResponse.TenantId) || @@ -43,10 +44,13 @@ public static bool IsValid(CertificateRequestResponse certificateRequestResponse certificateRequestResponse.ExpiresIn <= 0 || certificateRequestResponse.RefreshIn <= 0) { - return false; + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed because the certificate request response is malformed. Status code: 200", + null, + ManagedIdentitySource.ImdsV2, + (int)HttpStatusCode.OK); } - - return true; } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index dc65aa72ef..658c0f27a7 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -257,15 +257,7 @@ private async Task ExecuteCertificateRequestAsync( } var certificateRequestResponse = JsonHelper.DeserializeFromJson(response.Body); - if (!CertificateRequestResponse.IsValid(certificateRequestResponse)) - { - throw MsalServiceExceptionFactory.CreateManagedIdentityException( - MsalError.ManagedIdentityRequestFailed, - $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed because the certificate request response is malformed. Status code: {response.StatusCode}", - null, - ManagedIdentitySource.ImdsV2, - (int)response.StatusCode); - } + CertificateRequestResponse.Validate(certificateRequestResponse); return certificateRequestResponse; } From 92158bb9fc6caba3df76b7b81411b2286de936cd Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Mon, 25 Aug 2025 15:22:15 -0400 Subject: [PATCH 23/29] Small rework due to spec changes --- .../V2/CertificateRequestResponse.cs | 27 +++++++++---------- .../V2/ImdsV2ManagedIdentitySource.cs | 14 +++++----- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs index 77d32c9ff0..5e84000054 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.Buffers.Text; using System.Net; #if SUPPORTS_SYSTEM_TEXT_JSON using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute; @@ -16,22 +17,19 @@ namespace Microsoft.Identity.Client.ManagedIdentity.V2 internal class CertificateRequestResponse { [JsonProperty("client_id")] - public string ClientId { get; set; } + public string ClientId { get; set; } // client_id of the Managed Identity  [JsonProperty("tenant_id")] - public string TenantId { get; set; } + public string TenantId { get; set; } // AAD Tenant of the Managed Identity  - [JsonProperty("client_credential")] - public string ClientCredential { get; set; } + [JsonProperty("certificate")] + public string Certificate { get; set; } // Base64 encoded X509certificate - [JsonProperty("regional_token_url")] - public string RegionalTokenUrl { get; set; } + [JsonProperty("identity_type")] + public string IdentityType { get; set; } // SAMI or UAMI - [JsonProperty("expires_in")] - public int ExpiresIn { get; set; } - - [JsonProperty("refresh_in")] - public int RefreshIn { get; set; } + [JsonProperty("mtls_authentication_endpoint")] + public string MtlsAuthenticationEndpoint { get; set; } // Regional STS mTLS endpoint public CertificateRequestResponse() { } @@ -39,10 +37,9 @@ public static void Validate(CertificateRequestResponse certificateRequestRespons { if (string.IsNullOrEmpty(certificateRequestResponse.ClientId) || string.IsNullOrEmpty(certificateRequestResponse.TenantId) || - string.IsNullOrEmpty(certificateRequestResponse.ClientCredential) || - string.IsNullOrEmpty(certificateRequestResponse.RegionalTokenUrl) || - certificateRequestResponse.ExpiresIn <= 0 || - certificateRequestResponse.RefreshIn <= 0) + string.IsNullOrEmpty(certificateRequestResponse.Certificate) || + string.IsNullOrEmpty(certificateRequestResponse.IdentityType) || + string.IsNullOrEmpty(certificateRequestResponse.MtlsAuthenticationEndpoint)) { throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index 658c0f27a7..e0975dbfef 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -198,14 +198,12 @@ public static AbstractManagedIdentity Create(RequestContext requestContext) internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.ImdsV2) { } - private async Task ExecuteCertificateRequestAsync( - CuidInfo cuid, - string csrPem) + private async Task ExecuteCertificateRequestAsync(string csr) { - var queryParams = $"cuid={JsonHelper.SerializeToJson(cuid)}&cred-api-version={ImdsV2ApiVersion}"; + var queryParams = $"cred-api-version={ImdsV2ApiVersion}"; if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) { - queryParams += $"&uaid{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; + queryParams += $"&client_id{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; } var headers = new Dictionary @@ -214,7 +212,7 @@ private async Task ExecuteCertificateRequestAsync( { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } }; - var body = $"{{\"pem\":\"{csrPem}\"}}"; + var body = $"{{\"csr\":\"{csr}\"}}"; IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.Imds); @@ -265,9 +263,9 @@ private async Task ExecuteCertificateRequestAsync( protected override ManagedIdentityRequest CreateRequest(string resource) { var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); - var csrPem = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); + var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); - var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.CuId, csrPem).GetAwaiter().GetResult(); + var certificateRequestResponse = ExecuteCertificateRequestAsync(csr).GetAwaiter().GetResult(); throw new NotImplementedException(); } From 729a56a69443c6ba2168502b76443699914c5d11 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Mon, 25 Aug 2025 15:51:40 -0400 Subject: [PATCH 24/29] Additional rework due to spec changes --- .../V2/ImdsV2ManagedIdentitySource.cs | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index e0975dbfef..fbcdf877a2 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -24,16 +24,7 @@ public static async Task GetCsrMetadataAsync( RequestContext requestContext, bool probeMode) { - string queryParams = $"cred-api-version={ImdsV2ApiVersion}"; - - var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam( - requestContext.ServiceBundle.Config.ManagedIdentityId.IdType, - requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId, - requestContext.Logger); - if (userAssignedIdQueryParam != null) - { - queryParams += $"&{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}"; - } + var queryParams = ImdsV2QueryParamsHelper(requestContext); var headers = new Dictionary { @@ -200,11 +191,9 @@ internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : private async Task ExecuteCertificateRequestAsync(string csr) { - var queryParams = $"cred-api-version={ImdsV2ApiVersion}"; - if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) - { - queryParams += $"&client_id{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; - } + var queryParams = ImdsV2QueryParamsHelper(_requestContext); + + // TODO: add bypass_cache query param in case of token revocation. Boolean: true/false var headers = new Dictionary { @@ -269,5 +258,21 @@ protected override ManagedIdentityRequest CreateRequest(string resource) throw new NotImplementedException(); } + + private static string ImdsV2QueryParamsHelper(RequestContext requestContext) + { + var queryParams = $"cred-api-version={ImdsV2ApiVersion}"; + + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam( + requestContext.ServiceBundle.Config.ManagedIdentityId.IdType, + requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId, + requestContext.Logger); + if (userAssignedIdQueryParam != null) + { + queryParams += $"&{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}"; + } + + return queryParams; + } } } From 3027392af6ca3951da59cd1e19735094b24c2a97 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Mon, 25 Aug 2025 17:03:38 -0400 Subject: [PATCH 25/29] Implemented feedback --- .../ManagedIdentity/V2/Csr.cs | 2 +- .../ManagedIdentityTests/CsrValidator.cs | 2 +- .../ManagedIdentityTests/ImdsV2Tests.cs | 17 +++++++++++++---- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs index 35d6465d99..3f3b1175a3 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs @@ -25,7 +25,7 @@ internal static string Generate(string clientId, string tenantId, CuidInfo cuid) req.OtherRequestAttributes.Add( new AsnEncodedData( - "1.2.840.113549.1.9.7", + "1.3.6.1.4.1.311.90.2.10", writer.Encode())); return req.CreateSigningRequestPem(); diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs index 667c5a6050..23b80e7303 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs @@ -114,7 +114,7 @@ public static void ValidateCsrContent(string pemCsr, string expectedClientId, st { var attrSeq = attrSetReader.ReadSequence(); string oid = attrSeq.ReadObjectIdentifier(); - if (oid == "1.2.840.113549.1.9.7") // challengePassword + if (oid == "1.3.6.1.4.1.311.90.2.10") // challengePassword { var valueSet = attrSeq.ReadSetOf(); while (valueSet.HasData) diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index afd52214f9..7568814a28 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -138,7 +138,19 @@ public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIs } [TestMethod] - public void TestCsrGeneration() + public void TestCsrGeneration_OnlyVmId() + { + var cuid = new CuidInfo + { + VmId = TestConstants.VmId + }; + + var csrPem = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); + CsrValidator.ValidateCsrContent(csrPem, TestConstants.ClientId, TestConstants.TenantId, cuid); + } + + [TestMethod] + public void TestCsrGeneration_VmIdAndVmssId() { var cuid = new CuidInfo { @@ -146,10 +158,7 @@ public void TestCsrGeneration() VmssId = TestConstants.VmssId }; - // Generate CSR var csrPem = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); - - // Validate the CSR contents using the helper CsrValidator.ValidateCsrContent(csrPem, TestConstants.ClientId, TestConstants.TenantId, cuid); } From 3c3dcdf8d77332c3fb2d2fe5601adcba19674c0c Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Mon, 25 Aug 2025 18:24:47 -0400 Subject: [PATCH 26/29] Removed null check on vmId. Created CuidInfo.IsNullOrEmpty --- .../ManagedIdentity/V2/CsrMetadata.cs | 9 +++++++-- .../Core/Mocks/MockHelpers.cs | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs index 6281fcec14..87ca519e12 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs @@ -19,6 +19,12 @@ internal class CuidInfo [JsonProperty("vmssId")] public string VmssId { get; set; } + + public static bool IsNullOrEmpty(CuidInfo cuidInfo) + { + return cuidInfo == null || + (string.IsNullOrEmpty(cuidInfo.VmId) && string.IsNullOrEmpty(cuidInfo.VmssId)); + } } /// @@ -61,8 +67,7 @@ public CsrMetadata() { } public static bool ValidateCsrMetadata(CsrMetadata csrMetadata) { if (csrMetadata == null || - csrMetadata.CuId == null || - string.IsNullOrEmpty(csrMetadata.CuId.VmId) || + CuidInfo.IsNullOrEmpty(csrMetadata.CuId) || string.IsNullOrEmpty(csrMetadata.ClientId) || string.IsNullOrEmpty(csrMetadata.TenantId) || string.IsNullOrEmpty(csrMetadata.AttestationEndpoint)) diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index 21ad2c3695..81143817b0 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -595,7 +595,7 @@ public static MockHttpMessageHandler MockCsrResponse( string content = "{" + - "\"cuid\": { \"vmId\": \"fake_vmId\", \"vmssId\": \"fake_vmssId\" }," + + "\"cuId\": { \"vmId\": \"fake_vmId\" }," + "\"clientId\": \"fake_client_id\"," + "\"tenantId\": \"fake_tenant_id\"," + "\"attestationEndpoint\": \"fake_attestation_endpoint\"" + From f51cdf9ce9e0ca28cf100412ef537e0cbe09c509 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Tue, 26 Aug 2025 14:42:35 -0400 Subject: [PATCH 27/29] Implemented feedback --- .../ManagedIdentity/AbstractManagedIdentity.cs | 4 ++-- .../AppServiceManagedIdentitySource.cs | 8 +++----- .../AzureArcManagedIdentitySource.cs | 6 +++--- .../CloudShellManagedIdentitySource.cs | 6 +++--- .../ManagedIdentity/ImdsManagedIdentitySource.cs | 4 ++-- .../MachineLearningManagedIdentitySource.cs | 6 +++--- .../ServiceFabricManagedIdentitySource.cs | 6 +++--- .../V2/ImdsV2ManagedIdentitySource.cs | 12 +++++------- .../ManagedIdentityTests/ImdsV2Tests.cs | 14 +++++--------- 9 files changed, 29 insertions(+), 37 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs index ad2b9a0c17..276fe67c78 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs @@ -55,7 +55,7 @@ public virtual async Task AuthenticateAsync( // Convert the scopes to a resource string. string resource = parameters.Resource; - ManagedIdentityRequest request = CreateRequest(resource); + ManagedIdentityRequest request = await CreateRequestAsync(resource).ConfigureAwait(false); // Automatically add claims / capabilities if this MI source supports them if (_sourceType.SupportsClaimsAndCapabilities()) @@ -149,7 +149,7 @@ protected virtual Task HandleResponseAsync( throw exception; } - protected abstract ManagedIdentityRequest CreateRequest(string resource); + protected abstract Task CreateRequestAsync(string resource); protected ManagedIdentityResponse GetSuccessfulResponse(HttpResponse response) { diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs index 10fef4610b..6c8cb95f7f 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs @@ -2,12 +2,10 @@ // Licensed under the MIT License. using System; -using System.Collections.Generic; using System.Globalization; -using Microsoft.Identity.Client.ApiConfig.Parameters; +using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; -using Microsoft.Identity.Client.Utils; namespace Microsoft.Identity.Client.ManagedIdentity { @@ -66,7 +64,7 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger return true; } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override Task CreateRequestAsync(string resource) { ManagedIdentityRequest request = new(System.Net.Http.HttpMethod.Get, _endpoint); @@ -92,7 +90,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) break; } - return request; + return Task.FromResult(request); } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs index 8071a13944..ae3048e401 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs @@ -79,7 +79,7 @@ private AzureArcManagedIdentitySource(Uri endpoint, RequestContext requestContex } } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override Task CreateRequestAsync(string resource) { ManagedIdentityRequest request = new ManagedIdentityRequest(System.Net.Http.HttpMethod.Get, _endpoint); @@ -87,7 +87,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) request.QueryParameters["api-version"] = ArcApiVersion; request.QueryParameters["resource"] = resource; - return request; + return Task.FromResult(request); } protected override async Task HandleResponseAsync( @@ -119,7 +119,7 @@ protected override async Task HandleResponseAsync( var authHeaderValue = "Basic " + File.ReadAllText(splitChallenge[1]); - ManagedIdentityRequest request = CreateRequest(parameters.Resource); + ManagedIdentityRequest request = await CreateRequestAsync(parameters.Resource).ConfigureAwait(false); _requestContext.Logger.Verbose(() => "[Managed Identity] Adding authorization header to the request."); request.Headers.Add("Authorization", authHeaderValue); diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs index 63a6eb493c..844458cbce 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs @@ -4,7 +4,7 @@ using System; using System.Globalization; using System.Net.Http; -using Microsoft.Identity.Client.ApiConfig.Parameters; +using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; @@ -74,7 +74,7 @@ private CloudShellManagedIdentitySource(Uri endpoint, RequestContext requestCont } } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override Task CreateRequestAsync(string resource) { ManagedIdentityRequest request = new ManagedIdentityRequest(HttpMethod.Post, _endpoint); @@ -83,7 +83,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) request.BodyParameters.Add("resource", resource); - return request; + return Task.FromResult(request); } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs index af6be6cf81..fdbb1d44b8 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs @@ -43,7 +43,7 @@ internal ImdsManagedIdentitySource(RequestContext requestContext) : requestContext.Logger.Verbose(() => "[Managed Identity] Creating IMDS managed identity source. Endpoint URI: " + _imdsEndpoint); } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override Task CreateRequestAsync(string resource) { ManagedIdentityRequest request = new(HttpMethod.Get, _imdsEndpoint); @@ -80,7 +80,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) request.RequestType = RequestType.Imds; - return request; + return Task.FromResult(request); } public static KeyValuePair? GetUserAssignedIdQueryParam( diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs index e6916fe919..9d2af3cadc 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs @@ -3,7 +3,7 @@ using System; using System.Globalization; -using Microsoft.Identity.Client.ApiConfig.Parameters; +using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; @@ -64,7 +64,7 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger return true; } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override Task CreateRequestAsync(string resource) { ManagedIdentityRequest request = new(System.Net.Http.HttpMethod.Get, _endpoint); @@ -108,7 +108,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) null); // statusCode is null in this case } - return request; + return Task.FromResult(request); } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs index 3224a8f3fe..55b6b28690 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs @@ -6,7 +6,7 @@ using System.Net.Http; using System.Net.Security; using System.Security.Cryptography.X509Certificates; -using Microsoft.Identity.Client.ApiConfig.Parameters; +using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; @@ -75,7 +75,7 @@ private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri en } } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override Task CreateRequestAsync(string resource) { ManagedIdentityRequest request = new ManagedIdentityRequest(HttpMethod.Get, _endpoint); @@ -102,7 +102,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) break; } - return request; + return Task.FromResult(request); } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index fbcdf877a2..726553c32f 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -121,9 +121,7 @@ private static bool ValidateCsrMetadataResponse( * "1556" // index 1: captured group (\d+) * ] */ - // Imds bug: headers are missing - // TODO: uncomment this when the bug is fixed - /*string serverHeader = response.HeadersAsDictionary.TryGetValue("server", out var value) ? value : null; + string serverHeader = response.HeadersAsDictionary.TryGetValue("server", out var value) ? value : null; if (serverHeader == null) { if (probeMode) @@ -158,7 +156,7 @@ private static bool ValidateCsrMetadataResponse( null, (int)response.StatusCode); } - }*/ + } return true; } @@ -249,12 +247,12 @@ private async Task ExecuteCertificateRequestAsync(st return certificateRequestResponse; } - protected override ManagedIdentityRequest CreateRequest(string resource) + protected override async Task CreateRequestAsync(string resource) { - var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); + var csrMetadata = await GetCsrMetadataAsync(_requestContext, false).ConfigureAwait(false); var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); - var certificateRequestResponse = ExecuteCertificateRequestAsync(csr).GetAwaiter().GetResult(); + var certificateRequestResponse = await ExecuteCertificateRequestAsync(csr).ConfigureAwait(false); throw new NotImplementedException(); } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 7568814a28..39d6c896c5 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -61,9 +61,7 @@ public async Task GetCsrMetadataAsyncSucceedsAfterRetry() } } - // Imds bug: headers are missing - // TODO: uncomment this when the bug is fixed - /*[TestMethod] + [TestMethod] public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() { using (var httpManager = new MockHttpManager()) @@ -78,11 +76,9 @@ public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } - }*/ - - // Imds bug: headers are missing - // TODO: uncomment this when the bug is fixed - /*[TestMethod] + } + + [TestMethod] public async Task GetCsrMetadataAsyncFailsWithInvalidVersion() { using (var httpManager = new MockHttpManager()) @@ -97,7 +93,7 @@ public async Task GetCsrMetadataAsyncFailsWithInvalidVersion() var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } - }*/ + } [TestMethod] public async Task GetCsrMetadataAsyncFailsAfterMaxRetries() From 5e7ab075dec0b5ab548d1be39517e479d6508db6 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Tue, 26 Aug 2025 17:22:06 -0400 Subject: [PATCH 28/29] Updated min version of imds, spec has incorrect info --- .../ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs | 2 +- .../ManagedIdentityTests/ImdsV2Tests.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index 726553c32f..def3034f6a 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -142,7 +142,7 @@ private static bool ValidateCsrMetadataResponse( serverHeader, @"^IMDS/\d+\.\d+\.\d+\.(\d+)$" ); - if (!match.Success || !int.TryParse(match.Groups[1].Value, out int version) || version <= 1324) + if (!match.Success || !int.TryParse(match.Groups[1].Value, out int version) || version < 1854) { if (probeMode) { diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 39d6c896c5..2a67dd80d9 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -83,7 +83,7 @@ public async Task GetCsrMetadataAsyncFailsWithInvalidVersion() { using (var httpManager = new MockHttpManager()) { - httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: "IMDS/150.870.65.1324")); + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: "IMDS/150.870.65.1853")); // min version is 1854 var managedIdentityApp = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager) From 362b4079641714c736ce6341c52bcdc4135092f8 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 27 Aug 2025 11:25:42 -0400 Subject: [PATCH 29/29] Updated a comment --- .../Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs index 87ca519e12..a03f9f69bc 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs @@ -63,7 +63,7 @@ public CsrMetadata() { } /// Validates a JSON decoded CsrMetadata instance. /// /// The CsrMetadata object. - /// false if any required field is null. Note: VmId is required, VmssId is optional. + /// false if any field is null or empty public static bool ValidateCsrMetadata(CsrMetadata csrMetadata) { if (csrMetadata == null ||