diff --git a/sdk/identity/Azure.Identity/CHANGELOG.md b/sdk/identity/Azure.Identity/CHANGELOG.md index 818430ec06a39..05732fafbcd42 100644 --- a/sdk/identity/Azure.Identity/CHANGELOG.md +++ b/sdk/identity/Azure.Identity/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Implement `OnBehalfOfCredential` which enables authentication to Azure Active Directory using an On-Behalf-Of flow. + ### Breaking Changes ### Bugs Fixed @@ -43,6 +45,7 @@ Thank you to our developer community members who helped to make Azure Identity b - Added support to `ManagedIdentityCredential` for Bridge to Kubernetes local development authentication. - TenantId values returned from service challenge responses can now be used to request tokens from the correct tenantId. To support this feature, there is a new `AllowMultiTenantAuthentication` option on `TokenCredentialOptions`. - By default, `AllowMultiTenantAuthentication` is false. When this option property is false and the tenant Id configured in the credential options differs from the tenant Id set in the `TokenRequestContext` sent to a credential, an `AuthorizationFailedException` will be thrown. This is potentially breaking change as it could be a different exception than what was thrown previously. This exception behavior can be overridden by either setting an `AppContext` switch named "Azure.Identity.EnableLegacyTenantSelection" to `true` or by setting the environment variable "AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION" to "true". Note: AppContext switches can also be configured via configuration like below: +- Added `OnBehalfOfFlowCredential` which enables support for AAD On-Behalf-Of (OBO) flow. See the [Azure Active Directory documentation](https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-on-behalf-of-flow) to learn more about OBO flow scenarios. ```xml diff --git a/sdk/identity/Azure.Identity/api/Azure.Identity.netstandard2.0.cs b/sdk/identity/Azure.Identity/api/Azure.Identity.netstandard2.0.cs index 481a0d77a48ae..b9449e0d09a35 100644 --- a/sdk/identity/Azure.Identity/api/Azure.Identity.netstandard2.0.cs +++ b/sdk/identity/Azure.Identity/api/Azure.Identity.netstandard2.0.cs @@ -237,6 +237,22 @@ public ManagedIdentityCredential(string clientId = null, Azure.Identity.TokenCre public override Azure.Core.AccessToken GetToken(Azure.Core.TokenRequestContext requestContext, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override System.Threading.Tasks.ValueTask GetTokenAsync(Azure.Core.TokenRequestContext requestContext, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } } + public partial class OnBehalfOfCredential : Azure.Core.TokenCredential + { + protected OnBehalfOfCredential() { } + public OnBehalfOfCredential(string tenantId, string clientId, System.Security.Cryptography.X509Certificates.X509Certificate2 clientCertificate, string userAssertion) { } + public OnBehalfOfCredential(string tenantId, string clientId, System.Security.Cryptography.X509Certificates.X509Certificate2 clientCertificate, string userAssertion, Azure.Identity.OnBehalfOfCredentialOptions options) { } + public OnBehalfOfCredential(string tenantId, string clientId, string clientSecret, string userAssertion, Azure.Identity.OnBehalfOfCredentialOptions options = null) { } + public override Azure.Core.AccessToken GetToken(Azure.Core.TokenRequestContext requestContext, System.Threading.CancellationToken cancellationToken) { throw null; } + public override System.Threading.Tasks.ValueTask GetTokenAsync(Azure.Core.TokenRequestContext requestContext, System.Threading.CancellationToken cancellationToken) { throw null; } + } + public partial class OnBehalfOfCredentialOptions : Azure.Identity.TokenCredentialOptions + { + public OnBehalfOfCredentialOptions() { } + public Azure.Identity.RegionalAuthority? RegionalAuthority { get { throw null; } set { } } + public bool SendCertificateChain { get { throw null; } set { } } + public Azure.Identity.TokenCachePersistenceOptions TokenCachePersistenceOptions { get { throw null; } set { } } + } [System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)] public readonly partial struct RegionalAuthority : System.IEquatable { @@ -326,6 +342,18 @@ public SharedTokenCacheCredentialOptions(Azure.Identity.TokenCachePersistenceOpt public Azure.Identity.TokenCachePersistenceOptions TokenCachePersistenceOptions { get { throw null; } } public string Username { get { throw null; } set { } } } + [System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)] + public partial struct TokenCacheDetails + { + private object _dummy; + private int _dummyPrimitive; + public System.ReadOnlyMemory CacheBytes { get { throw null; } set { } } + } + public partial class TokenCacheNotificationDetails + { + internal TokenCacheNotificationDetails() { } + public string SuggestedCacheKey { get { throw null; } } + } public partial class TokenCachePersistenceOptions { public TokenCachePersistenceOptions() { } @@ -348,6 +376,7 @@ public abstract partial class UnsafeTokenCacheOptions : Azure.Identity.TokenCach { protected UnsafeTokenCacheOptions() { } protected internal abstract System.Threading.Tasks.Task> RefreshCacheAsync(); + protected internal virtual System.Threading.Tasks.Task RefreshCacheAsync(Azure.Identity.TokenCacheNotificationDetails details) { throw null; } protected internal abstract System.Threading.Tasks.Task TokenCacheUpdatedAsync(Azure.Identity.TokenCacheUpdatedArgs tokenCacheUpdatedArgs); } public partial class UsernamePasswordCredential : Azure.Core.TokenCredential diff --git a/sdk/identity/Azure.Identity/src/ClientCertificateCredential.cs b/sdk/identity/Azure.Identity/src/ClientCertificateCredential.cs index 5269fddabdbeb..0fa09e12b5b15 100644 --- a/sdk/identity/Azure.Identity/src/ClientCertificateCredential.cs +++ b/sdk/identity/Azure.Identity/src/ClientCertificateCredential.cs @@ -41,8 +41,7 @@ public class ClientCertificateCredential : TokenCredential /// Protected constructor for mocking. /// protected ClientCertificateCredential() - { - } + { } /// /// Creates an instance of the ClientCertificateCredential with the details needed to authenticate against Azure Active Directory with the specified certificate. @@ -52,8 +51,7 @@ protected ClientCertificateCredential() /// The path to a file which contains both the client certificate and private key. public ClientCertificateCredential(string tenantId, string clientId, string clientCertificatePath) : this(tenantId, clientId, clientCertificatePath, null, null, null) - { - } + { } /// /// Creates an instance of the ClientCertificateCredential with the details needed to authenticate against Azure Active Directory with the specified certificate. @@ -64,8 +62,7 @@ public ClientCertificateCredential(string tenantId, string clientId, string clie /// Options that allow to configure the management of the requests sent to the Azure Active Directory service. public ClientCertificateCredential(string tenantId, string clientId, string clientCertificatePath, TokenCredentialOptions options) : this(tenantId, clientId, clientCertificatePath, options, null, null) - { - } + { } /// /// Creates an instance of the ClientCertificateCredential with the details needed to authenticate against Azure Active Directory with the specified certificate. @@ -86,8 +83,7 @@ public ClientCertificateCredential(string tenantId, string clientId, string clie /// The authentication X509 Certificate of the service principal public ClientCertificateCredential(string tenantId, string clientId, X509Certificate2 clientCertificate) : this(tenantId, clientId, clientCertificate, null, null, null) - { - } + { } /// /// Creates an instance of the ClientCertificateCredential with the details needed to authenticate against Azure Active Directory with the specified certificate. @@ -97,7 +93,8 @@ public ClientCertificateCredential(string tenantId, string clientId, X509Certifi /// The authentication X509 Certificate of the service principal /// Options that allow to configure the management of the requests sent to the Azure Active Directory service. public ClientCertificateCredential(string tenantId, string clientId, X509Certificate2 clientCertificate, TokenCredentialOptions options) - : this(tenantId, clientId, clientCertificate, options, null, null) {} + : this(tenantId, clientId, clientCertificate, options, null, null) + { } /// /// Creates an instance of the ClientCertificateCredential with the details needed to authenticate against Azure Active Directory with the specified certificate. @@ -108,20 +105,47 @@ public ClientCertificateCredential(string tenantId, string clientId, X509Certifi /// Options that allow to configure the management of the requests sent to the Azure Active Directory service. public ClientCertificateCredential(string tenantId, string clientId, X509Certificate2 clientCertificate, ClientCertificateCredentialOptions options) : this(tenantId, clientId, clientCertificate, options, null, null) - { - } + { } - internal ClientCertificateCredential(string tenantId, string clientId, string certificatePath, TokenCredentialOptions options, CredentialPipeline pipeline, MsalConfidentialClient client) - : this(tenantId, clientId, new X509Certificate2FromFileProvider(certificatePath ?? throw new ArgumentNullException(nameof(certificatePath))), options, pipeline, client) - { - } + internal ClientCertificateCredential( + string tenantId, + string clientId, + string certificatePath, + TokenCredentialOptions options, + CredentialPipeline pipeline, + MsalConfidentialClient client) + : this( + tenantId, + clientId, + new X509Certificate2FromFileProvider(certificatePath ?? throw new ArgumentNullException(nameof(certificatePath))), + options, + pipeline, + client) + { } - internal ClientCertificateCredential(string tenantId, string clientId, X509Certificate2 certificate, TokenCredentialOptions options, CredentialPipeline pipeline, MsalConfidentialClient client) - : this(tenantId, clientId, new X509Certificate2FromObjectProvider(certificate ?? throw new ArgumentNullException(nameof(certificate))), options, pipeline, client) - { - } + internal ClientCertificateCredential( + string tenantId, + string clientId, + X509Certificate2 certificate, + TokenCredentialOptions options, + CredentialPipeline pipeline, + MsalConfidentialClient client) + : this( + tenantId, + clientId, + new X509Certificate2FromObjectProvider(certificate ?? throw new ArgumentNullException(nameof(certificate))), + options, + pipeline, + client) + { } - internal ClientCertificateCredential(string tenantId, string clientId, IX509Certificate2Provider certificateProvider, TokenCredentialOptions options, CredentialPipeline pipeline, MsalConfidentialClient client) + internal ClientCertificateCredential( + string tenantId, + string clientId, + IX509Certificate2Provider certificateProvider, + TokenCredentialOptions options, + CredentialPipeline pipeline, + MsalConfidentialClient client) { TenantId = Validations.ValidateTenantId(tenantId, nameof(tenantId)); @@ -193,152 +217,5 @@ public override async ValueTask GetTokenAsync(TokenRequestContext r throw scope.FailWrapAndThrow(e); } } - - /// - /// IX509Certificate2Provider provides a way to control how the X509Certificate2 object is fetched. - /// - internal interface IX509Certificate2Provider - { - ValueTask GetCertificateAsync(bool async, CancellationToken cancellationToken); - } - - /// - /// X509Certificate2FromObjectProvider provides an X509Certificate2 from an existing instance. - /// - private class X509Certificate2FromObjectProvider : IX509Certificate2Provider - { - private X509Certificate2 Certificate { get; } - - public X509Certificate2FromObjectProvider(X509Certificate2 clientCertificate) - { - Certificate = clientCertificate ?? throw new ArgumentNullException(nameof(clientCertificate)); - } - - public ValueTask GetCertificateAsync(bool async, CancellationToken cancellationToken) - { - return new ValueTask(Certificate); - } - } - - /// - /// X509Certificate2FromFileProvider provides an X509Certificate2 from a file on disk. It supports both - /// "pfx" and "pem" encoded certificates. - /// - internal class X509Certificate2FromFileProvider : IX509Certificate2Provider - { - // Lazy initialized on the first call to GetCertificateAsync, based on CertificatePath. - private X509Certificate2 Certificate { get; set; } - internal string CertificatePath { get; } - - public X509Certificate2FromFileProvider(string clientCertificatePath) - { - CertificatePath = clientCertificatePath ?? throw new ArgumentNullException(nameof(clientCertificatePath)); - } - - public ValueTask GetCertificateAsync(bool async, CancellationToken cancellationToken) - { - if (!(Certificate is null)) - { - return new ValueTask(Certificate); - } - - string fileType = Path.GetExtension(CertificatePath); - - switch (fileType.ToLowerInvariant()) - { - case ".pfx": - return LoadCertificateFromPfxFileAsync(async, CertificatePath, cancellationToken); - case ".pem": - return LoadCertificateFromPemFileAsync(async, CertificatePath, cancellationToken); - default: - throw new CredentialUnavailableException("Only .pfx and .pem files are supported."); - } - } - - private async ValueTask LoadCertificateFromPfxFileAsync(bool async, string clientCertificatePath, CancellationToken cancellationToken) - { - const int BufferSize = 4 * 1024; - - if (!(Certificate is null)) - { - return Certificate; - } - - try - { - if (!async) - { - Certificate = new X509Certificate2(clientCertificatePath); - } - else - { - List certContents = new List(); - byte[] buf = new byte[BufferSize]; - int offset = 0; - using (Stream s = File.OpenRead(clientCertificatePath)) - { - while (true) - { - int read = await s.ReadAsync(buf, offset, buf.Length, cancellationToken).ConfigureAwait(false); - for (int i = 0; i < read; i++) - { - certContents.Add(buf[i]); - } - - if (read == 0) - { - break; - } - } - } - - Certificate = new X509Certificate2(certContents.ToArray()); - } - - return Certificate; - } - catch (Exception e) when (!(e is OperationCanceledException)) - { - throw new CredentialUnavailableException("Could not load certificate file", e); - } - } - - private async ValueTask LoadCertificateFromPemFileAsync(bool async, string clientCertificatePath, CancellationToken cancellationToken) - { - if (!(Certificate is null)) - { - return Certificate; - } - - string certficateText; - - try - { - if (!async) - { - certficateText = File.ReadAllText(clientCertificatePath); - } - else - { - cancellationToken.ThrowIfCancellationRequested(); - - using (StreamReader sr = new StreamReader(clientCertificatePath)) - { - certficateText = await sr.ReadToEndAsync().ConfigureAwait(false); - } - } - - Certificate = PemReader.LoadCertificate(certficateText.AsSpan(), keyType: PemReader.KeyType.RSA); - - return Certificate; - } - catch (Exception e) when (!(e is OperationCanceledException)) - { - throw new CredentialUnavailableException("Could not load certificate file", e); - } - } - - private delegate void ImportPkcs8PrivateKeyDelegate(ReadOnlySpan blob, out int bytesRead); - } } } diff --git a/sdk/identity/Azure.Identity/src/IX509Certificate2Provider.cs b/sdk/identity/Azure.Identity/src/IX509Certificate2Provider.cs new file mode 100644 index 0000000000000..2c265860f37ec --- /dev/null +++ b/sdk/identity/Azure.Identity/src/IX509Certificate2Provider.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace Azure.Identity +{ + /// + /// IX509Certificate2Provider provides a way to control how the X509Certificate2 object is fetched. + /// + internal interface IX509Certificate2Provider + { + ValueTask GetCertificateAsync(bool async, CancellationToken cancellationToken); + } +} diff --git a/sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs b/sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs index 240093ea3a88a..94548f16f4343 100644 --- a/sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs +++ b/sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs @@ -10,9 +10,9 @@ namespace Azure.Identity { internal class MsalConfidentialClient : MsalClientBase { - private readonly string _clientSecret; - private readonly bool _includeX5CClaimHeader; - private readonly ClientCertificateCredential.IX509Certificate2Provider _certificateProvider; + internal readonly string _clientSecret; + internal readonly bool _includeX5CClaimHeader; + internal readonly IX509Certificate2Provider _certificateProvider; /// /// For mocking purposes only. @@ -27,7 +27,7 @@ public MsalConfidentialClient(CredentialPipeline pipeline, string tenantId, stri RegionalAuthority = regionalAuthority; } - public MsalConfidentialClient(CredentialPipeline pipeline, string tenantId, string clientId, ClientCertificateCredential.IX509Certificate2Provider certificateProvider, bool includeX5CClaimHeader, ITokenCacheOptions cacheOptions, RegionalAuthority? regionalAuthority, bool isPiiLoggingEnabled) + public MsalConfidentialClient(CredentialPipeline pipeline, string tenantId, string clientId, IX509Certificate2Provider certificateProvider, bool includeX5CClaimHeader, ITokenCacheOptions cacheOptions, RegionalAuthority? regionalAuthority, bool isPiiLoggingEnabled) : base(pipeline, tenantId, clientId, isPiiLoggingEnabled, cacheOptions) { _includeX5CClaimHeader = includeX5CClaimHeader; @@ -124,5 +124,25 @@ public virtual async ValueTask AcquireTokenByAuthorization .ExecuteAsync(async, cancellationToken) .ConfigureAwait(false); } + + public virtual async ValueTask AcquireTokenOnBehalfOf( + string[] scopes, + string tenantId, + UserAssertion userAssertionValue, + bool async, + CancellationToken cancellationToken) + { + IConfidentialClientApplication client = await GetClientAsync(async, cancellationToken).ConfigureAwait(false); + + var builder = client.AcquireTokenOnBehalfOf(scopes, userAssertionValue); + + if (!string.IsNullOrEmpty(tenantId)) + { + builder.WithAuthority(Pipeline.AuthorityHost.AbsoluteUri, tenantId); + } + return await builder + .ExecuteAsync(async, cancellationToken) + .ConfigureAwait(false); + } } } diff --git a/sdk/identity/Azure.Identity/src/OnBehalfOfCredential.cs b/sdk/identity/Azure.Identity/src/OnBehalfOfCredential.cs new file mode 100644 index 0000000000000..b679aa052fa8b --- /dev/null +++ b/sdk/identity/Azure.Identity/src/OnBehalfOfCredential.cs @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.Pipeline; +using Microsoft.Identity.Client; + +namespace Azure.Identity +{ + /// + /// Enables authentication to Azure Active Directory using an On-Behalf-Of flow. + /// + public class OnBehalfOfCredential : TokenCredential + { + internal readonly MsalConfidentialClient _client; + private readonly string _tenantId; + private readonly CredentialPipeline _pipeline; + private readonly bool _allowMultiTenantAuthentication; + private readonly string _clientId; + private readonly string _clientSecret; + private readonly UserAssertion _userAssertion; + + /// + /// Protected constructor for mocking. + /// + protected OnBehalfOfCredential() + { } + + /// + /// Creates an instance of the OnBehalfOfCredential with the details needed to authenticate against Azure Active Directory with the specified certificate. + /// + /// The Azure Active Directory tenant (directory) Id of the service principal. + /// The client (application) ID of the service principal + /// The authentication X509 Certificate of the service principal + /// The access token that will be used by as the user assertion when requesting On-Behalf-Of tokens. + public OnBehalfOfCredential(string tenantId, string clientId, X509Certificate2 clientCertificate, string userAssertion) + : this(tenantId, clientId, clientCertificate, userAssertion, null, null, null) + { } + + /// + /// Creates an instance of the OnBehalfOfCredential with the details needed to authenticate against Azure Active Directory with the specified certificate. + /// + /// The Azure Active Directory tenant (directory) Id of the service principal. + /// The client (application) ID of the service principal + /// The authentication X509 Certificate of the service principal + /// The access token that will be used by as the user assertion when requesting On-Behalf-Of tokens. + /// Options that allow to configure the management of the requests sent to the Azure Active Directory service. + public OnBehalfOfCredential(string tenantId, string clientId, X509Certificate2 clientCertificate, string userAssertion, OnBehalfOfCredentialOptions options) + : this(tenantId, clientId, clientCertificate, userAssertion, options, null, null) + { } + + internal OnBehalfOfCredential( + string tenantId, + string clientId, + X509Certificate2 certificate, + string userAssertion, + OnBehalfOfCredentialOptions options, + CredentialPipeline pipeline, + MsalConfidentialClient client) + : this( + tenantId, + clientId, + new X509Certificate2FromObjectProvider(certificate ?? throw new ArgumentNullException(nameof(certificate))), + userAssertion, + options, + pipeline, + client) + { } + + internal OnBehalfOfCredential( + string tenantId, + string clientId, + IX509Certificate2Provider certificateProvider, + string userAssertion, + OnBehalfOfCredentialOptions options, + CredentialPipeline pipeline, + MsalConfidentialClient client) + { + _tenantId = Validations.ValidateTenantId(tenantId, nameof(tenantId)); + _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); + _allowMultiTenantAuthentication = options?.AllowMultiTenantAuthentication ?? false; + _pipeline = pipeline ?? CredentialPipeline.GetInstance(options); + options ??= new OnBehalfOfCredentialOptions(); + _userAssertion = new UserAssertion(userAssertion); + + _client = client ?? + new MsalConfidentialClient( + _pipeline, + tenantId, + clientId, + certificateProvider, + options.SendCertificateChain, + options, + options.RegionalAuthority, + options.IsLoggingPIIEnabled); + } + + /// + /// Creates an instance of the with the details needed to authenticate with Azure Active Directory. + /// + /// The Azure Active Directory tenant (directory) Id of the service principal. + /// The client (application) ID of the service principal + /// A client secret that was generated for the App Registration used to authenticate the client. + /// The access token that will be used by as the user assertion when requesting On-Behalf-Of tokens. + /// Options that allow to configure the management of the requests sent to the Azure Active Directory service. + public OnBehalfOfCredential( + string tenantId, + string clientId, + string clientSecret, + string userAssertion, + OnBehalfOfCredentialOptions options = null) + : this(tenantId, clientId, clientSecret, userAssertion, options, null, null) + { } + + internal OnBehalfOfCredential( + string tenantId, + string clientId, + string clientSecret, + string userAssertion, + OnBehalfOfCredentialOptions options, + CredentialPipeline pipeline, + MsalConfidentialClient client) + { + Argument.AssertNotNull(clientId, nameof(clientId)); + Argument.AssertNotNull(clientSecret, nameof(clientSecret)); + + options ??= new OnBehalfOfCredentialOptions(); + _pipeline = pipeline ?? CredentialPipeline.GetInstance(options); + _allowMultiTenantAuthentication = options.AllowMultiTenantAuthentication; + _tenantId = Validations.ValidateTenantId(tenantId, nameof(tenantId)); + _clientId = clientId; + _clientSecret = clientSecret; + _userAssertion = new UserAssertion(userAssertion); + _client = client ?? new MsalConfidentialClient(_pipeline, _tenantId, _clientId, _clientSecret, options, default, options.IsLoggingPIIEnabled); + } + + /// + public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) => + GetTokenInternalAsync(requestContext, false, cancellationToken).EnsureCompleted(); + + /// + public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) => + GetTokenInternalAsync(requestContext, true, cancellationToken); + + internal async ValueTask GetTokenInternalAsync(TokenRequestContext requestContext, bool async, CancellationToken cancellationToken) + { + using CredentialDiagnosticScope scope = _pipeline.StartGetTokenScope("OnBehalfOfCredential.GetToken", requestContext); + + try + { + var tenantId = TenantIdResolver.Resolve(_tenantId, requestContext, _allowMultiTenantAuthentication); + + AuthenticationResult result = await _client + .AcquireTokenOnBehalfOf(requestContext.Scopes, tenantId, _userAssertion, async, cancellationToken) + .ConfigureAwait(false); + + return new AccessToken(result.AccessToken, result.ExpiresOn); + } + catch (Exception e) + { + throw scope.FailWrapAndThrow(e); + } + } + } +} diff --git a/sdk/identity/Azure.Identity/src/OnBehalfOfCredentialOptions.cs b/sdk/identity/Azure.Identity/src/OnBehalfOfCredentialOptions.cs new file mode 100644 index 0000000000000..2e7e340086a61 --- /dev/null +++ b/sdk/identity/Azure.Identity/src/OnBehalfOfCredentialOptions.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Azure.Identity +{ + /// + /// + /// + public class OnBehalfOfCredentialOptions : TokenCredentialOptions, ITokenCacheOptions + { + /// + /// The . + /// + public TokenCachePersistenceOptions TokenCachePersistenceOptions { get; set; } + + /// + /// Will include x5c header in client claims when acquiring a token to enable subject name / issuer based authentication for the . + /// + public bool SendCertificateChain { get; set; } + + /// + /// Specifies either the specific (preferred), or use to attempt to auto-detect the region. + /// If not specified or auto-detection fails the non-regional endpoint will be used. + /// + public RegionalAuthority? RegionalAuthority { get; set; } = Azure.Identity.RegionalAuthority.FromEnvironment(); + } +} diff --git a/sdk/identity/Azure.Identity/src/TokenCache.cs b/sdk/identity/Azure.Identity/src/TokenCache.cs index 73f2a4dba2eda..5c4890dd1f41c 100644 --- a/sdk/identity/Azure.Identity/src/TokenCache.cs +++ b/sdk/identity/Azure.Identity/src/TokenCache.cs @@ -6,7 +6,6 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; -using Azure.Core.Pipeline; using Microsoft.Identity.Client; using Microsoft.Identity.Client.Extensions.Msal; @@ -88,7 +87,7 @@ internal TokenCache(TokenCachePersistenceOptions options, MsalCacheHelperWrapper /// /// A delegate that will be called before the cache is accessed. The data returned will be used to set the current state of the cache. /// - internal Func>> RefreshCacheFromOptionsAsync; + internal Func> RefreshCacheFromOptionsAsync; internal virtual async Task RegisterCache(bool async, ITokenCache tokenCache, CancellationToken cancellationToken) { @@ -143,7 +142,8 @@ private async Task OnBeforeCacheAccessAsync(TokenCacheNotificationArgs args) { if (RefreshCacheFromOptionsAsync != null) { - Data = (await RefreshCacheFromOptionsAsync().ConfigureAwait(false)).ToArray(); + Data = (await RefreshCacheFromOptionsAsync(new TokenCacheNotificationDetails(args)).ConfigureAwait(false)) + .CacheBytes.ToArray(); } args.TokenCache.DeserializeMsalV3(Data, true); diff --git a/sdk/identity/Azure.Identity/src/TokenCacheDetails.cs b/sdk/identity/Azure.Identity/src/TokenCacheDetails.cs new file mode 100644 index 0000000000000..d9ab5e8eed16c --- /dev/null +++ b/sdk/identity/Azure.Identity/src/TokenCacheDetails.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; + +namespace Azure.Identity +{ + /// + /// Details related to a cache delegate. + /// + public struct TokenCacheDetails + { + /// + /// The bytes representing the state of the token cache. + /// + public ReadOnlyMemory CacheBytes { get; set; } + } +} diff --git a/sdk/identity/Azure.Identity/src/TokenCacheNotificationDetails.cs b/sdk/identity/Azure.Identity/src/TokenCacheNotificationDetails.cs new file mode 100644 index 0000000000000..9eeebbf2b77b7 --- /dev/null +++ b/sdk/identity/Azure.Identity/src/TokenCacheNotificationDetails.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.Identity.Client; + +namespace Azure.Identity +{ + /// + /// Args setnt to TokenCache OnBefore and OnAfter events. + /// + public class TokenCacheNotificationDetails + { + /// + /// A suggested token cache key, which can be used with general purpose storage mechanisms that allow + /// storing key-value pairs and key based retrieval. Useful in applications that store 1 token cache per user, + /// the recommended pattern for web apps. + /// + /// The value is: + /// + /// + /// the homeAccountId for AcquireTokenSilent, GetAccount(homeAccountId), RemoveAccount and when writing tokens on confidential client calls + /// clientID + "_AppTokenCache" for AcquireTokenForClient + /// clientID_tenantID + "_AppTokenCache" for AcquireTokenForClient when tenant specific authority + /// the hash of the original token for AcquireTokenOnBehalfOf + /// + /// + public string SuggestedCacheKey { get; } + + internal TokenCacheNotificationDetails(TokenCacheNotificationArgs args) + { + SuggestedCacheKey = args.SuggestedCacheKey; + } + } +} diff --git a/sdk/identity/Azure.Identity/src/UnsafeTokenCacheOptions.cs b/sdk/identity/Azure.Identity/src/UnsafeTokenCacheOptions.cs index 5593d754d31e1..cecba653bed9e 100644 --- a/sdk/identity/Azure.Identity/src/UnsafeTokenCacheOptions.cs +++ b/sdk/identity/Azure.Identity/src/UnsafeTokenCacheOptions.cs @@ -18,9 +18,20 @@ public abstract class UnsafeTokenCacheOptions : TokenCachePersistenceOptions protected internal abstract Task TokenCacheUpdatedAsync(TokenCacheUpdatedArgs tokenCacheUpdatedArgs); /// - /// The bytes used to initialize the token cache. This would most likely have come from the . + /// Returns the bytes used to initialize the token cache. This would most likely have come from the . + /// This implementation will get called by the default implementation of . + /// It is recommended to provide an implementation for rather than this method. /// /// protected internal abstract Task> RefreshCacheAsync(); + + /// + /// Returns the bytes used to initialize the token cache. This would most likely have come from the . + /// It is recommended that if this method is overriden, there is no need to provide a duplicate implementation for the parameterless . + /// + /// + /// + protected internal virtual async Task RefreshCacheAsync(TokenCacheNotificationDetails details) => + new() {CacheBytes = await RefreshCacheAsync().ConfigureAwait(false)}; } } diff --git a/sdk/identity/Azure.Identity/src/X509Certificate2FromFileProvider.cs b/sdk/identity/Azure.Identity/src/X509Certificate2FromFileProvider.cs new file mode 100644 index 0000000000000..e2d2354e8f972 --- /dev/null +++ b/sdk/identity/Azure.Identity/src/X509Certificate2FromFileProvider.cs @@ -0,0 +1,135 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; + +namespace Azure.Identity +{ + /// + /// X509Certificate2FromFileProvider provides an X509Certificate2 from a file on disk. It supports both + /// "pfx" and "pem" encoded certificates. + /// + internal class X509Certificate2FromFileProvider : IX509Certificate2Provider + { + // Lazy initialized on the first call to GetCertificateAsync, based on CertificatePath. + private X509Certificate2 Certificate { get; set; } + internal string CertificatePath { get; } + + public X509Certificate2FromFileProvider(string clientCertificatePath) + { + Argument.AssertNotNull(clientCertificatePath, nameof(clientCertificatePath)); + CertificatePath = clientCertificatePath ?? throw new ArgumentNullException(nameof(clientCertificatePath)); + } + + public ValueTask GetCertificateAsync(bool async, CancellationToken cancellationToken) + { + if (!(Certificate is null)) + { + return new ValueTask(Certificate); + } + + string fileType = Path.GetExtension(CertificatePath); + + switch (fileType.ToLowerInvariant()) + { + case ".pfx": + return LoadCertificateFromPfxFileAsync(async, CertificatePath, cancellationToken); + case ".pem": + return LoadCertificateFromPemFileAsync(async, CertificatePath, cancellationToken); + default: + throw new CredentialUnavailableException("Only .pfx and .pem files are supported."); + } + } + + private async ValueTask LoadCertificateFromPfxFileAsync(bool async, string clientCertificatePath, CancellationToken cancellationToken) + { + const int BufferSize = 4 * 1024; + + if (!(Certificate is null)) + { + return Certificate; + } + + try + { + if (!async) + { + Certificate = new X509Certificate2(clientCertificatePath); + } + else + { + List certContents = new List(); + byte[] buf = new byte[BufferSize]; + int offset = 0; + using (Stream s = File.OpenRead(clientCertificatePath)) + { + while (true) + { + int read = await s.ReadAsync(buf, offset, buf.Length, cancellationToken).ConfigureAwait(false); + for (int i = 0; i < read; i++) + { + certContents.Add(buf[i]); + } + + if (read == 0) + { + break; + } + } + } + + Certificate = new X509Certificate2(certContents.ToArray()); + } + + return Certificate; + } + catch (Exception e) when (!(e is OperationCanceledException)) + { + throw new CredentialUnavailableException("Could not load certificate file", e); + } + } + + private async ValueTask LoadCertificateFromPemFileAsync(bool async, string clientCertificatePath, CancellationToken cancellationToken) + { + if (!(Certificate is null)) + { + return Certificate; + } + + string certficateText; + + try + { + if (!async) + { + certficateText = File.ReadAllText(clientCertificatePath); + } + else + { + cancellationToken.ThrowIfCancellationRequested(); + + using (StreamReader sr = new StreamReader(clientCertificatePath)) + { + certficateText = await sr.ReadToEndAsync().ConfigureAwait(false); + } + } + + Certificate = PemReader.LoadCertificate(certficateText.AsSpan(), keyType: PemReader.KeyType.RSA); + + return Certificate; + } + catch (Exception e) when (!(e is OperationCanceledException)) + { + throw new CredentialUnavailableException("Could not load certificate file", e); + } + } + + private delegate void ImportPkcs8PrivateKeyDelegate(ReadOnlySpan blob, out int bytesRead); + } +} diff --git a/sdk/identity/Azure.Identity/src/X509Certificate2FromObjectProvider.cs b/sdk/identity/Azure.Identity/src/X509Certificate2FromObjectProvider.cs new file mode 100644 index 0000000000000..f721a0a45bb69 --- /dev/null +++ b/sdk/identity/Azure.Identity/src/X509Certificate2FromObjectProvider.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace Azure.Identity +{ + /// + /// X509Certificate2FromObjectProvider provides an X509Certificate2 from an existing instance. + /// + internal class X509Certificate2FromObjectProvider : IX509Certificate2Provider + { + private X509Certificate2 Certificate { get; } + + public X509Certificate2FromObjectProvider(X509Certificate2 clientCertificate) + { + Certificate = clientCertificate ?? throw new ArgumentNullException(nameof(clientCertificate)); + } + + public ValueTask GetCertificateAsync(bool async, CancellationToken cancellationToken) + { + return new ValueTask(Certificate); + } + } +} diff --git a/sdk/identity/Azure.Identity/tests/AuthorizationCodeCredentialTests.cs b/sdk/identity/Azure.Identity/tests/AuthorizationCodeCredentialTests.cs index 614a4ba210094..02e493bbf70ac 100644 --- a/sdk/identity/Azure.Identity/tests/AuthorizationCodeCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/AuthorizationCodeCredentialTests.cs @@ -4,67 +4,20 @@ using System; using System.Threading.Tasks; using Azure.Core; -using Azure.Core.TestFramework; -using Azure.Identity.Tests.Mock; -using Microsoft.Identity.Client; using NUnit.Framework; namespace Azure.Identity.Tests { - public class AuthorizationCodeCredentialTests : ClientTestBase + public class AuthorizationCodeCredentialTests : CredentialTestBase { - private const string ClientId = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"; - private const string TenantId = "a0287521-e002-0026-7112-207c0c000000"; - private const string TenantIdHint = "a0287521-e002-0026-7112-207c0c001234"; - private const string ReplyUrl = "https://myredirect/"; - private const string Scope = "https://vault.azure.net/.default"; - private TokenCredentialOptions options; - private string authCode; - private string expectedToken; - private DateTimeOffset expiresOn; - private MockMsalConfidentialClient mockMsalClient; - private string expectedTenantId; - private string expectedReplyUri; - private string clientSecret = Guid.NewGuid().ToString(); - private Func> silentFactory; - public AuthorizationCodeCredentialTests(bool isAsync) : base(isAsync) { } [SetUp] - public void TestSetup() + public void Setup() { + TestSetup(); expectedTenantId = TenantId; - expectedReplyUri = null; - authCode = Guid.NewGuid().ToString(); - expectedToken = Guid.NewGuid().ToString(); - expiresOn = DateTimeOffset.Now.AddHours(1); - var result = new AuthenticationResult( - expectedToken, - false, - null, - expiresOn, - expiresOn, - TenantId, - new MockAccount("username"), - null, - new[] { Scope }, - Guid.NewGuid(), - null, - "Bearer"); - silentFactory = (_, _tenantId, _replyUri, _) => - { - Assert.AreEqual(expectedTenantId, _tenantId); - Assert.AreEqual(expectedReplyUri, _replyUri); - return new ValueTask(result); - }; - mockMsalClient = new MockMsalConfidentialClient(silentFactory); - mockMsalClient.AuthcodeFactory = (_, _tenantId, _replyUri, _) => - { - Assert.AreEqual(expectedTenantId, _tenantId); - Assert.AreEqual(expectedReplyUri, _replyUri); - return result; - }; } [Test] @@ -79,7 +32,7 @@ public async Task AuthenticateWithAuthCodeHonorsReplyUrl([Values(null, ReplyUrl) expectedReplyUri = replyUri; AuthorizationCodeCredential cred = InstrumentClient( - new AuthorizationCodeCredential(TenantId, ClientId, clientSecret, authCode, options, mockMsalClient)); + new AuthorizationCodeCredential(TenantId, ClientId, clientSecret, authCode, options, mockConfidentialMsalClient)); AccessToken token = await cred.GetTokenAsync(context); @@ -97,7 +50,8 @@ public async Task AuthenticateWithAuthCodeHonorsTenantId([Values(null, TenantIdH var context = new TokenRequestContext(new[] { Scope }, tenantId: tenantId); expectedTenantId = TenantIdResolver.Resolve(TenantId, context, options.AllowMultiTenantAuthentication); - AuthorizationCodeCredential cred = InstrumentClient(new AuthorizationCodeCredential(TenantId, ClientId, clientSecret, authCode, options, mockMsalClient)); + AuthorizationCodeCredential cred = InstrumentClient( + new AuthorizationCodeCredential(TenantId, ClientId, clientSecret, authCode, options, mockConfidentialMsalClient)); AccessToken token = await cred.GetTokenAsync(context); diff --git a/sdk/identity/Azure.Identity/tests/AzureCliCredentialTests.cs b/sdk/identity/Azure.Identity/tests/AzureCliCredentialTests.cs index 19e84dcc257cf..1c9d4afcc7ee3 100644 --- a/sdk/identity/Azure.Identity/tests/AzureCliCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/AzureCliCredentialTests.cs @@ -6,18 +6,13 @@ using System.Threading; using System.Threading.Tasks; using Azure.Core; -using Azure.Core.TestFramework; using Azure.Identity.Tests.Mock; using NUnit.Framework; namespace Azure.Identity.Tests { - public class AzureCliCredentialTests : ClientTestBase + public class AzureCliCredentialTests : CredentialTestBase { - private const string Scope = "https://vault.azure.net/.default"; - private const string TenantId = "explicitTenantId"; - private const string TenantIdHint = "tenantIdChallenge"; - public AzureCliCredentialTests(bool isAsync) : base(isAsync) { } [Test] diff --git a/sdk/identity/Azure.Identity/tests/AzurePowerShellCredentialsTests.cs b/sdk/identity/Azure.Identity/tests/AzurePowerShellCredentialsTests.cs index 086aaf0170dbd..c31c82e870288 100644 --- a/sdk/identity/Azure.Identity/tests/AzurePowerShellCredentialsTests.cs +++ b/sdk/identity/Azure.Identity/tests/AzurePowerShellCredentialsTests.cs @@ -16,15 +16,11 @@ namespace Azure.Identity.Tests { - public class AzurePowerShellCredentialsTests : ClientTestBase + public class AzurePowerShellCredentialsTests : CredentialTestBase { private string tokenXML = "Kg==5/11/2021 8:20:03 PM +00:0072f988bf-86f1-41af-91ab-2d7cd011db47chriss@microsoft.comBearer"; - private const string Scope = "https://vault.azure.net/.default"; - private const string TenantId = "a0287521-e002-0026-7112-207c0c000000"; - private const string TenantIdHint = "a0287521-e002-0026-7112-207c0c001234"; - public AzurePowerShellCredentialsTests(bool isAsync) : base(isAsync) { } diff --git a/sdk/identity/Azure.Identity/tests/ClientCertificateCredentialTests.cs b/sdk/identity/Azure.Identity/tests/ClientCertificateCredentialTests.cs index fbd7f9fdd15db..0c18f526ca5ff 100644 --- a/sdk/identity/Azure.Identity/tests/ClientCertificateCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/ClientCertificateCredentialTests.cs @@ -2,33 +2,18 @@ // Licensed under the MIT License. using System; -using System.Collections.Generic; using System.IO; -using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; -using System.Text; -using System.Text.Json; using System.Threading.Tasks; using Azure.Core; using Azure.Core.TestFramework; using Azure.Identity.Tests.Mock; -using Microsoft.Identity.Client; using NUnit.Framework; namespace Azure.Identity.Tests { - public class ClientCertificateCredentialTests : ClientTestBase + public class ClientCertificateCredentialTests : CredentialTestBase { - private const string Scope = "https://vault.azure.net/.default"; - private const string TenantIdHint = "a0287521-e002-0026-7112-207c0c001234"; - private const string ClientId = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"; - private const string TenantId = "a0287521-e002-0026-7112-207c0c000000"; - private string expectedToken; - private DateTimeOffset expiresOn; - private MockMsalConfidentialClient mockMsalClient; - private string expectedTenantId; - private TokenCredentialOptions options; - public ClientCertificateCredentialTests(bool isAsync) : base(isAsync) { } @@ -36,9 +21,7 @@ public ClientCertificateCredentialTests(bool isAsync) : base(isAsync) public void VerifyCtorParametersValidation() { var tenantId = Guid.NewGuid().ToString(); - var clientId = Guid.NewGuid().ToString(); - var certificatePath = Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "cert.pfx"); var mockCert = new X509Certificate2(certificatePath); @@ -56,25 +39,25 @@ public void VerifyBadCertificateFileBehavior() var tenantId = Guid.NewGuid().ToString(); var clientId = Guid.NewGuid().ToString(); - TokenRequestContext tokenContext = new TokenRequestContext(MockScopes.Default); + TokenRequestContext tokenContext = new(MockScopes.Default); - ClientCertificateCredential missingFileCredential = new ClientCertificateCredential( + ClientCertificateCredential missingFileCredential = new( tenantId, clientId, Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "notfound.pem")); - ClientCertificateCredential invalidPemCredential = new ClientCertificateCredential( + ClientCertificateCredential invalidPemCredential = new( tenantId, clientId, Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "cert-invalid-data.pem")); - ClientCertificateCredential unknownFormatCredential = new ClientCertificateCredential( + ClientCertificateCredential unknownFormatCredential = new( tenantId, clientId, Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "cert.unknown")); - ClientCertificateCredential encryptedCredential = new ClientCertificateCredential( + ClientCertificateCredential encryptedCredential = new( tenantId, clientId, Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "cert-password-protected.pfx")); - ClientCertificateCredential unsupportedCertCredential = new ClientCertificateCredential( + ClientCertificateCredential unsupportedCertCredential = new( tenantId, clientId, Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "ec-cert.pem")); @@ -169,62 +152,13 @@ public async Task UsesTenantIdHint( ClientCertificateCredential credential = InstrumentClient( usePemFile - ? new ClientCertificateCredential(TenantId, ClientId, certificatePathPem, options, default, mockMsalClient) - : new ClientCertificateCredential(TenantId, ClientId, mockCert, options, default, mockMsalClient) + ? new ClientCertificateCredential(TenantId, ClientId, certificatePathPem, options, default, mockConfidentialMsalClient) + : new ClientCertificateCredential(TenantId, ClientId, mockCert, options, default, mockConfidentialMsalClient) ); var token = await credential.GetTokenAsync(context); Assert.AreEqual(token.Token, expectedToken, "Should be the expected token value"); } - - public void TestSetup() - { - options = new TokenCredentialOptions(); - expectedTenantId = null; - expectedToken = Guid.NewGuid().ToString(); - expiresOn = DateTimeOffset.Now.AddHours(1); - var result = new AuthenticationResult( - expectedToken, - false, - null, - expiresOn, - expiresOn, - TenantId, - new MockAccount("username"), - null, - new[] { Scope }, - Guid.NewGuid(), - null, - "Bearer"); - - Func clientFactory = (_, _tenantId) => - { - return result; - }; - mockMsalClient = new MockMsalConfidentialClient(clientFactory); - } - - private static IEnumerable RegionalAuthorityTestData() - { - yield return new TestCaseData(null); - yield return new TestCaseData(RegionalAuthority.AutoDiscoverRegion); - yield return new TestCaseData(RegionalAuthority.USWest); - } - - [Test] - [TestCaseSource("RegionalAuthorityTestData")] - public void VerifyMsalClientRegionalAuthority(RegionalAuthority? regionalAuthority) - { - var expectedTenantId = Guid.NewGuid().ToString(); - - var expectedClientId = Guid.NewGuid().ToString(); - - var certificatePath = Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "cert.pfx"); - - var cred = new ClientCertificateCredential(expectedTenantId, expectedClientId, certificatePath, new ClientCertificateCredentialOptions { RegionalAuthority = regionalAuthority }); - - Assert.AreEqual(regionalAuthority, cred.Client.RegionalAuthority); - } } } diff --git a/sdk/identity/Azure.Identity/tests/ClientSecretCredentialTests.cs b/sdk/identity/Azure.Identity/tests/ClientSecretCredentialTests.cs index 12feff05d2a42..6a01b6b24a91c 100644 --- a/sdk/identity/Azure.Identity/tests/ClientSecretCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/ClientSecretCredentialTests.cs @@ -2,31 +2,17 @@ // Licensed under the MIT License. using System; -using System.Collections.Generic; -using System.IO; -using System.Text; using System.Threading.Tasks; using Azure.Core; using Azure.Core.TestFramework; using Azure.Identity.Tests.Mock; using Microsoft.Identity.Client; -using Moq; using NUnit.Framework; namespace Azure.Identity.Tests { - public class ClientSecretCredentialTests : ClientTestBase + public class ClientSecretCredentialTests : CredentialTestBase { - private const string Scope = "https://vault.azure.net/.default"; - private const string TenantIdHint = "a0287521-e002-0026-7112-207c0c001234"; - private const string ClientId = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"; - private const string TenantId = "a0287521-e002-0026-7112-207c0c000000"; - private string expectedToken; - private DateTimeOffset expiresOn; - private MockMsalConfidentialClient mockMsalClient; - private string expectedTenantId; - private TokenCredentialOptions options; - public ClientSecretCredentialTests(bool isAsync) : base(isAsync) { } @@ -44,7 +30,6 @@ public void VerifyCtorParametersValidation() [Test] public async Task UsesTenantIdHint( - [Values(true, false)] bool usePemFile, [Values(null, TenantIdHint)] string tenantId, [Values(true)] bool allowMultiTenantAuthentication) { @@ -52,7 +37,7 @@ public async Task UsesTenantIdHint( options.AllowMultiTenantAuthentication = allowMultiTenantAuthentication; var context = new TokenRequestContext(new[] { Scope }, tenantId: tenantId); expectedTenantId = TenantIdResolver.Resolve(TenantId, context, options.AllowMultiTenantAuthentication); - ClientSecretCredential client = InstrumentClient(new ClientSecretCredential(expectedTenantId, ClientId, "secret", options, null, mockMsalClient)); + ClientSecretCredential client = InstrumentClient(new ClientSecretCredential(expectedTenantId, ClientId, "secret", options, null, mockConfidentialMsalClient)); var token = await client.GetTokenAsync(new TokenRequestContext(MockScopes.Default)); @@ -93,33 +78,5 @@ public async Task VerifyClientSecretCredentialExceptionAsync() await Task.CompletedTask; } - - public void TestSetup() - { - options = new TokenCredentialOptions(); - expectedTenantId = null; - expectedToken = Guid.NewGuid().ToString(); - expiresOn = DateTimeOffset.Now.AddHours(1); - var result = new AuthenticationResult( - expectedToken, - false, - null, - expiresOn, - expiresOn, - TenantId, - new MockAccount("username"), - null, - new[] { Scope }, - Guid.NewGuid(), - null, - "Bearer"); - - Func clientFactory = (_, _tenantId) => - { - Assert.AreEqual(expectedTenantId, _tenantId); - return result; - }; - mockMsalClient = new MockMsalConfidentialClient(clientFactory); - } } } diff --git a/sdk/identity/Azure.Identity/tests/CredentialTestBase.cs b/sdk/identity/Azure.Identity/tests/CredentialTestBase.cs new file mode 100644 index 0000000000000..2bf5d14da0971 --- /dev/null +++ b/sdk/identity/Azure.Identity/tests/CredentialTestBase.cs @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.TestFramework; +using Azure.Identity.Tests.Mock; +using Microsoft.Identity.Client; +using NUnit.Framework; + +namespace Azure.Identity.Tests +{ + public class CredentialTestBase : ClientTestBase + { + protected const string Scope = "https://vault.azure.net/.default"; + protected const string TenantIdHint = "a0287521-e002-0026-7112-207c0c001234"; + protected const string ClientId = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"; + protected const string TenantId = "a0287521-e002-0026-7112-207c0c000000"; + protected const string expectedUsername = "mockuser@mockdomain.com"; + protected string expectedToken; + protected string expectedUserAssertion; + protected string expectedTenantId; + protected string expectedReplyUri; + protected string authCode; + protected const string ReplyUrl = "https://myredirect/"; + protected string clientSecret = Guid.NewGuid().ToString(); + protected DateTimeOffset expiresOn; + internal MockMsalConfidentialClient mockConfidentialMsalClient; + internal MockMsalPublicClient mockPublicMsalClient; + protected TokenCredentialOptions options; + protected AuthenticationResult result; + protected string expectedCode; + protected DeviceCodeResult deviceCodeResult; + + public CredentialTestBase(bool isAsync) : base(isAsync) + { } + + public void TestSetup() + { + expectedTenantId = null; + expectedReplyUri = null; + authCode = Guid.NewGuid().ToString(); + options = new TokenCredentialOptions(); + expectedToken = Guid.NewGuid().ToString(); + expectedUserAssertion = Guid.NewGuid().ToString(); + expiresOn = DateTimeOffset.Now.AddHours(1); + result = new AuthenticationResult( + expectedToken, + false, + null, + expiresOn, + expiresOn, + TenantId, + new MockAccount("username"), + null, + new[] { Scope }, + Guid.NewGuid(), + null, + "Bearer"); + + mockConfidentialMsalClient = new MockMsalConfidentialClient() + .WithSilentFactory( + (_, _tenantId, _replyUri, _) => + { + Assert.AreEqual(expectedTenantId, _tenantId); + Assert.AreEqual(expectedReplyUri, _replyUri); + return new ValueTask(result); + }) + .WithAuthCodeFactory( + (_, _tenantId, _replyUri, _) => + { + Assert.AreEqual(expectedTenantId, _tenantId); + Assert.AreEqual(expectedReplyUri, _replyUri); + return result; + }) + .WithOnBehalfOfFactory( + (_, _, userAssertion, _, _) => + { + Assert.AreEqual(expectedUserAssertion, userAssertion.Assertion); + return new ValueTask(result); + }) + .WithClientFactory( + (_, _tenantId) => + { + Assert.AreEqual(expectedTenantId, _tenantId); + return result; + }); + + expectedCode = Guid.NewGuid().ToString(); + mockPublicMsalClient = new MockMsalPublicClient(); + deviceCodeResult = MockMsalPublicClient.GetDeviceCodeResult(deviceCode: expectedCode); + mockPublicMsalClient.DeviceCodeResult = deviceCodeResult; + var publicResult = new AuthenticationResult( + expectedToken, + false, + null, + expiresOn, + expiresOn, + TenantId, + new MockAccount("username"), + null, + new[] { Scope }, + Guid.NewGuid(), + null, + "Bearer"); + mockPublicMsalClient.SilentAuthFactory = (_, tId) => + { + Assert.AreEqual(expectedTenantId, tId); + return publicResult; + }; + mockPublicMsalClient.DeviceCodeAuthFactory = (_, _) => + { + // Assert.AreEqual(tenantId, tId); + return publicResult; + }; + mockPublicMsalClient.InteractiveAuthFactory = (_, _, _, _, tenant, _, _) => + { + Assert.AreEqual(expectedTenantId, tenant, "TenantId passed to msal should match"); + return result; + }; + mockPublicMsalClient.SilentAuthFactory = (_, tenant) => + { + Assert.AreEqual(expectedTenantId, tenant, "TenantId passed to msal should match"); + return result; + }; + mockPublicMsalClient.ExtendedSilentAuthFactory = (_, _, _, tenant, _, _) => + { + Assert.AreEqual(expectedTenantId, tenant, "TenantId passed to msal should match"); + return result; + }; + mockPublicMsalClient.UserPassAuthFactory = (_, tenant) => + { + Assert.AreEqual(expectedTenantId, tenant, "TenantId passed to msal should match"); + return result; + }; + mockPublicMsalClient.RefreshTokenFactory = (_, _, _, _, tenant, _, _) => + { + Assert.AreEqual(expectedTenantId, tenant, "TenantId passed to msal should match"); + return result; + }; + } + } +} diff --git a/sdk/identity/Azure.Identity/tests/DeviceCodeCredentialTests.cs b/sdk/identity/Azure.Identity/tests/DeviceCodeCredentialTests.cs index 20943a0aa2856..eb90589f03973 100644 --- a/sdk/identity/Azure.Identity/tests/DeviceCodeCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/DeviceCodeCredentialTests.cs @@ -17,25 +17,11 @@ namespace Azure.Identity.Tests { - public class DeviceCodeCredentialTests : ClientTestBase + public class DeviceCodeCredentialTests : CredentialTestBase { public DeviceCodeCredentialTests(bool isAsync) : base(isAsync) { } - private const string ClientId = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"; - private const string TenantId = "a0287521-e002-0026-7112-207c0c000000"; - private const string TenantIdHint = "a0287521-e002-0026-7112-207c0c001234"; - private const string Scope = "https://vault.azure.net/.default"; - private readonly HashSet _requestedCodes = new HashSet(); - private TokenCredentialOptions options = new TokenCredentialOptions(); - private readonly object _requestedCodesLock = new object(); - private string expectedCode; - private string expectedToken; - private DateTimeOffset expiresOn; - private MockMsalPublicClient mockMsalClient; - private DeviceCodeResult deviceCodeResult; - private string expectedTenantId = null; - private Task VerifyDeviceCode(DeviceCodeInfo codeInfo, string expectedCode) { Assert.AreEqual(expectedCode, codeInfo.DeviceCode); @@ -69,38 +55,9 @@ private async Task ThrowingDeviceCodeCallback(DeviceCodeInfo code, CancellationT } [SetUp] - public void TestSetup() + public void Setup() { - expectedTenantId = null; - expectedCode = Guid.NewGuid().ToString(); - expectedToken = Guid.NewGuid().ToString(); - expiresOn = DateTimeOffset.Now.AddHours(1); - mockMsalClient = new MockMsalPublicClient(); - deviceCodeResult = MockMsalPublicClient.GetDeviceCodeResult(deviceCode: expectedCode); - mockMsalClient.DeviceCodeResult = deviceCodeResult; - var result = new AuthenticationResult( - expectedToken, - false, - null, - expiresOn, - expiresOn, - TenantId, - new MockAccount("username"), - null, - new[] { Scope }, - Guid.NewGuid(), - null, - "Bearer"); - mockMsalClient.SilentAuthFactory = (_, tId) => - { - Assert.AreEqual(expectedTenantId, tId); - return result; - }; - mockMsalClient.DeviceCodeAuthFactory = (_, _) => - { - // Assert.AreEqual(tenantId, tId); - return result; - }; + TestSetup(); } [Test] @@ -110,7 +67,7 @@ public async Task AuthenticateWithDeviceCodeMockAsync([Values(null, TenantIdHint var context = new TokenRequestContext(new[] { Scope }, tenantId: tenantId); expectedTenantId = TenantIdResolver.Resolve(TenantId, context, options.AllowMultiTenantAuthentication) ; var cred = InstrumentClient( - new DeviceCodeCredential((code, _) => VerifyDeviceCode(code, expectedCode), TenantId, ClientId, options, null, mockMsalClient)); + new DeviceCodeCredential((code, _) => VerifyDeviceCode(code, expectedCode), TenantId, ClientId, options, null, mockPublicMsalClient)); AccessToken token = await cred.GetTokenAsync(context); @@ -141,18 +98,18 @@ public async Task AuthenticateWithDeviceCodeNoCallback() try { - var client = new DeviceCodeCredential { Client = mockMsalClient }; + var client = new DeviceCodeCredential { Client = mockPublicMsalClient }; var cred = InstrumentClient(client); AccessToken token = await cred.GetTokenAsync(new TokenRequestContext(new[] { Scope })); Assert.AreEqual(token.Token, expectedToken); - Assert.AreEqual(mockMsalClient.DeviceCodeResult.Message + Environment.NewLine, capturedOut.ToString()); + Assert.AreEqual(mockPublicMsalClient.DeviceCodeResult.Message + Environment.NewLine, capturedOut.ToString()); token = await cred.GetTokenAsync(new TokenRequestContext(new[] { Scope })); Assert.AreEqual(token.Token, expectedToken); - Assert.AreEqual(mockMsalClient.DeviceCodeResult.Message + Environment.NewLine, capturedOut.ToString()); + Assert.AreEqual(mockPublicMsalClient.DeviceCodeResult.Message + Environment.NewLine, capturedOut.ToString()); } finally { @@ -173,7 +130,7 @@ public async Task AuthenticateWithDeviceCodeMockVerifyMsalCancellationAsync() ClientId, options, null, - mockMsalClient)); + mockPublicMsalClient)); var ex = Assert.CatchAsync( async () => await cred.GetTokenAsync(new TokenRequestContext(new[] { Scope }), cancelSource.Token)); @@ -186,7 +143,7 @@ public async Task AuthenticateWithDeviceCodeMockVerifyCallbackCancellationAsync( { var cancelSource = new CancellationTokenSource(); cancelSource.Cancel(); - var cred = InstrumentClient(new DeviceCodeCredential(VerifyDeviceCodeCallbackCancellationToken, null, ClientId, options, null, mockMsalClient)); + var cred = InstrumentClient(new DeviceCodeCredential(VerifyDeviceCodeCallbackCancellationToken, null, ClientId, options, null, mockPublicMsalClient)); var ex = Assert.CatchAsync( async () => await cred.GetTokenAsync(new TokenRequestContext(new[] { Scope }), cancelSource.Token)); @@ -200,7 +157,7 @@ public void AuthenticateWithDeviceCodeCallbackThrowsAsync() var cancelSource = new CancellationTokenSource(); var options = new TokenCredentialOptions(); - var cred = InstrumentClient(new DeviceCodeCredential(ThrowingDeviceCodeCallback, null, ClientId, options, null, mockMsalClient)); + var cred = InstrumentClient(new DeviceCodeCredential(ThrowingDeviceCodeCallback, null, ClientId, options, null, mockPublicMsalClient)); var ex = Assert.ThrowsAsync( async () => await cred.GetTokenAsync(new TokenRequestContext(new[] { testEnvironment.KeyvaultScope }), cancelSource.Token)); diff --git a/sdk/identity/Azure.Identity/tests/EnvironmentCredentialProviderTests.cs b/sdk/identity/Azure.Identity/tests/EnvironmentCredentialProviderTests.cs index 8ff5ee345b122..fc1d500f9c499 100644 --- a/sdk/identity/Azure.Identity/tests/EnvironmentCredentialProviderTests.cs +++ b/sdk/identity/Azure.Identity/tests/EnvironmentCredentialProviderTests.cs @@ -69,7 +69,7 @@ public void CredentialConstructionClientCertificate() Assert.AreEqual("mockclientid", cred.ClientId); Assert.AreEqual("mocktenantid", cred.TenantId); - var certProvider = cred.ClientCertificateProvider as ClientCertificateCredential.X509Certificate2FromFileProvider; + var certProvider = cred.ClientCertificateProvider as X509Certificate2FromFileProvider; Assert.NotNull(certProvider); Assert.AreEqual("mockcertificatepath", certProvider.CertificatePath); diff --git a/sdk/identity/Azure.Identity/tests/InteractiveBrowserCredentialTests.cs b/sdk/identity/Azure.Identity/tests/InteractiveBrowserCredentialTests.cs index dfe399c658220..ec3fdc990f733 100644 --- a/sdk/identity/Azure.Identity/tests/InteractiveBrowserCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/InteractiveBrowserCredentialTests.cs @@ -13,19 +13,8 @@ namespace Azure.Identity.Tests { - public class InteractiveBrowserCredentialTests : ClientTestBase + public class InteractiveBrowserCredentialTests : CredentialTestBase { - private string TenantId = "a0287521-e002-0026-7112-207c0c000000"; - private const string TenantIdHint = "a0287521-e002-0026-7112-207c0c001234"; - private const string Scope = "https://vault.azure.net/.default"; - private const string ClientId = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"; - private string expectedCode; - private string expectedToken; - private DateTimeOffset expiresOn; - private MockMsalPublicClient mockMsalClient; - private DeviceCodeResult deviceCodeResult; - private string expectedTenantId; - public InteractiveBrowserCredentialTests(bool isAsync) : base(isAsync) { } @@ -242,46 +231,12 @@ public async Task UsesTenantIdHint([Values(null, TenantIdHint)] string tenantId, ClientId, options, null, - mockMsalClient)); + mockPublicMsalClient)); var actualToken = await credential.GetTokenAsync(context, CancellationToken.None); Assert.AreEqual(expectedToken, actualToken.Token, "Token should match"); Assert.AreEqual(expiresOn, actualToken.ExpiresOn, "expiresOn should match"); } - - public void TestSetup() - { - expectedTenantId = null; - expectedCode = Guid.NewGuid().ToString(); - expectedToken = Guid.NewGuid().ToString(); - expiresOn = DateTimeOffset.Now.AddHours(1); - mockMsalClient = new MockMsalPublicClient(); - deviceCodeResult = MockMsalPublicClient.GetDeviceCodeResult(deviceCode: expectedCode); - mockMsalClient.DeviceCodeResult = deviceCodeResult; - var result = new AuthenticationResult( - expectedToken, - false, - null, - expiresOn, - expiresOn, - TenantId, - new MockAccount("username"), - null, - new[] { Scope }, - Guid.NewGuid(), - null, - "Bearer"); - mockMsalClient.InteractiveAuthFactory = (_, _, _, _, tenant, _, _) => - { - Assert.AreEqual(expectedTenantId, tenant, "TenantId passed to msal should match"); - return result; - }; - mockMsalClient.SilentAuthFactory = (_, tenant) => - { - Assert.AreEqual(expectedTenantId, tenant, "TenantId passed to msal should match"); - return result; - }; - } } } diff --git a/sdk/identity/Azure.Identity/tests/Mock/MockMsalConfidentialClient.cs b/sdk/identity/Azure.Identity/tests/Mock/MockMsalConfidentialClient.cs index 4d835191aa30d..8f79af8054275 100644 --- a/sdk/identity/Azure.Identity/tests/Mock/MockMsalConfidentialClient.cs +++ b/sdk/identity/Azure.Identity/tests/Mock/MockMsalConfidentialClient.cs @@ -11,35 +11,52 @@ namespace Azure.Identity.Tests.Mock internal class MockMsalConfidentialClient : MsalConfidentialClient { internal Func ClientFactory { get; set; } - internal Func> SilentFactory { get; set; } + internal Func> SilentFactory { get; set; } internal Func AuthcodeFactory { get; set; } + internal Func> OnBehalfOfFactory { get; set; } + + public MockMsalConfidentialClient() + { } public MockMsalConfidentialClient(AuthenticationResult result) { - ClientFactory = (_,_) => result; + ClientFactory = (_, _) => result; + SilentFactory = (_, _, _, _) => new ValueTask(result); AuthcodeFactory = (_, _, _, _) => result; + OnBehalfOfFactory = (_, _, _, _, _) => new ValueTask(result); } public MockMsalConfidentialClient(Exception exception) { - ClientFactory = (_,_) => throw exception; + ClientFactory = (_, _) => throw exception; SilentFactory = (_, _, _, _) => throw exception; AuthcodeFactory = (_, _, _, _) => throw exception; + OnBehalfOfFactory = (_, _, _, _, _) => throw exception; } - public MockMsalConfidentialClient(Func clientFactory) + public MockMsalConfidentialClient WithClientFactory(Func clientFactory) { ClientFactory = clientFactory; + return this; } - public MockMsalConfidentialClient(Func> factory) + public MockMsalConfidentialClient WithSilentFactory(Func> factory) { SilentFactory = factory; + return this; } - public MockMsalConfidentialClient(Func factory) + public MockMsalConfidentialClient WithAuthCodeFactory(Func factory) { AuthcodeFactory = factory; + return this; + } + + public MockMsalConfidentialClient WithOnBehalfOfFactory( + Func> onBehalfOfFactory) + { + OnBehalfOfFactory = onBehalfOfFactory; + return this; } public override ValueTask AcquireTokenForClientAsync(string[] scopes, string tenantId, bool async, CancellationToken cancellationToken) @@ -68,5 +85,15 @@ public override ValueTask AcquireTokenByAuthorizationCodeA { return new(AuthcodeFactory(scopes, tenantId, replyUri, code)); } + + public override async ValueTask AcquireTokenOnBehalfOf( + string[] scopes, + string tenantId, + UserAssertion userAssertionValue, + bool async, + CancellationToken cancellationToken) + { + return await OnBehalfOfFactory(scopes, tenantId, userAssertionValue, async, cancellationToken); + } } } diff --git a/sdk/identity/Azure.Identity/tests/OnBehalfOfCredentialTests.cs b/sdk/identity/Azure.Identity/tests/OnBehalfOfCredentialTests.cs new file mode 100644 index 0000000000000..63ae982c9bccb --- /dev/null +++ b/sdk/identity/Azure.Identity/tests/OnBehalfOfCredentialTests.cs @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Security.Cryptography.X509Certificates; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Identity.Tests.Mock; +using NUnit.Framework; + +namespace Azure.Identity.Tests +{ + public class OnBehalfOfCredentialTests : CredentialTestBase + { + public OnBehalfOfCredentialTests(bool isAsync) : base(isAsync) { } + + [Test] + public void CtorValidation() + { + OnBehalfOfCredential cred; + string userAssertion = Guid.NewGuid().ToString(); + string clientSecret = Guid.NewGuid().ToString(); + + Assert.Throws(() => new OnBehalfOfCredential(null, ClientId, clientSecret, userAssertion)); + Assert.Throws(() => new OnBehalfOfCredential(TenantId, null, clientSecret, userAssertion)); + Assert.Throws(() => new OnBehalfOfCredential(TenantId, ClientId, null, userAssertion)); + Assert.Throws(() => new OnBehalfOfCredential(TenantId, ClientId, clientSecret, null)); + cred = new OnBehalfOfCredential(TenantId, ClientId, clientSecret, userAssertion); + // Assert + Assert.AreEqual(clientSecret, cred._client._clientSecret); + + Assert.Throws(() => new OnBehalfOfCredential(null, ClientId, new X509Certificate2(), userAssertion)); + Assert.Throws(() => new OnBehalfOfCredential(TenantId, null, new X509Certificate2(), userAssertion)); + Assert.Throws(() => new OnBehalfOfCredential(TenantId, ClientId, null, userAssertion)); + Assert.Throws(() => new OnBehalfOfCredential(TenantId, ClientId, new X509Certificate2(), null)); + cred = new OnBehalfOfCredential(TenantId, ClientId, new X509Certificate2(), userAssertion); + // Assert + Assert.NotNull(cred._client._certificateProvider); + + Assert.Throws(() => new OnBehalfOfCredential(null, ClientId, new X509Certificate2(), userAssertion, new OnBehalfOfCredentialOptions())); + Assert.Throws(() => new OnBehalfOfCredential(TenantId, null, new X509Certificate2(), userAssertion, new OnBehalfOfCredentialOptions())); + Assert.Throws(() => new OnBehalfOfCredential(TenantId, ClientId, default(X509Certificate2), userAssertion, new OnBehalfOfCredentialOptions())); + Assert.Throws(() => new OnBehalfOfCredential(TenantId, ClientId, new X509Certificate2(), null, new OnBehalfOfCredentialOptions())); + cred = new OnBehalfOfCredential(TenantId, ClientId, new X509Certificate2(), userAssertion, new OnBehalfOfCredentialOptions()); + // Assert + Assert.NotNull(cred._client._certificateProvider); + } + + [Test] + public async Task UsesTenantIdHint( + [Values(null, TenantIdHint)] string tenantId, + [Values(true)] bool allowMultiTenantAuthentication, + [Values(null, TenantId)] string explicitTenantId) + { + TestSetup(); + options = new OnBehalfOfCredentialOptions(); + options.AllowMultiTenantAuthentication = allowMultiTenantAuthentication; + var context = new TokenRequestContext(new[] { Scope }, tenantId: tenantId); + expectedTenantId = TenantIdResolver.Resolve(explicitTenantId, context, options.AllowMultiTenantAuthentication); + OnBehalfOfCredential client = InstrumentClient( + new OnBehalfOfCredential( + TenantId, + ClientId, + "secret", + expectedUserAssertion, + options as OnBehalfOfCredentialOptions, + null, + mockConfidentialMsalClient)); + + var token = await client.GetTokenAsync(new TokenRequestContext(MockScopes.Default), default); + Assert.AreEqual(token.Token, expectedToken, "Should be the expected token value"); + } + } +} diff --git a/sdk/identity/Azure.Identity/tests/SharedTokenCacheCredentialTests.cs b/sdk/identity/Azure.Identity/tests/SharedTokenCacheCredentialTests.cs index a6d995d099d53..e48e9dc744161 100644 --- a/sdk/identity/Azure.Identity/tests/SharedTokenCacheCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/SharedTokenCacheCredentialTests.cs @@ -12,20 +12,8 @@ namespace Azure.Identity.Tests { - public class SharedTokenCacheCredentialTests : ClientTestBase + public class SharedTokenCacheCredentialTests : CredentialTestBase { - private string TenantId = "a0287521-e002-0026-7112-207c0c000000"; - private const string TenantIdHint = "a0287521-e002-0026-7112-207c0c001234"; - private const string Scope = "https://vault.azure.net/.default"; - private const string ClientId = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"; - private string expectedCode; - private string expectedToken; - private DateTimeOffset expiresOn; - private MockMsalPublicClient mockMsal; - private DeviceCodeResult deviceCodeResult; - private string expectedTenantId; - private const string expectedUsername = "mockuser@mockdomain.com"; - public SharedTokenCacheCredentialTests(bool isAsync) : base(isAsync) { } @@ -590,46 +578,17 @@ public async Task UsesTenantIdHint([Values(null, TenantIdHint)] string tenantId, var options = new SharedTokenCacheCredentialOptions { AllowMultiTenantAuthentication = allowMultiTenantAuthentication }; var context = new TokenRequestContext(new[] { Scope }, tenantId: tenantId); expectedTenantId = TenantIdResolver.Resolve(TenantId, context, options.AllowMultiTenantAuthentication); - mockMsal.Accounts = new List + mockPublicMsalClient.Accounts = new List { new MockAccount(expectedUsername, expectedTenantId) }; - var credential = InstrumentClient(new SharedTokenCacheCredential(TenantId, null, options, null, mockMsal)); + var credential = InstrumentClient(new SharedTokenCacheCredential(TenantId, null, options, null, mockPublicMsalClient)); AccessToken token = await credential.GetTokenAsync(context); Assert.AreEqual(expectedToken, token.Token); Assert.AreEqual(expiresOn, token.ExpiresOn); } - - public void TestSetup() - { - expectedTenantId = null; - expectedCode = Guid.NewGuid().ToString(); - expectedToken = Guid.NewGuid().ToString(); - expiresOn = DateTimeOffset.Now.AddHours(1); - mockMsal = new MockMsalPublicClient(); - deviceCodeResult = MockMsalPublicClient.GetDeviceCodeResult(deviceCode: expectedCode); - mockMsal.DeviceCodeResult = deviceCodeResult; - var result = new AuthenticationResult( - expectedToken, - false, - null, - expiresOn, - expiresOn, - TenantId, - new MockAccount("username"), - null, - new[] { Scope }, - Guid.NewGuid(), - null, - "Bearer"); - mockMsal.ExtendedSilentAuthFactory = (_, _, _, tenant, _, _) => - { - Assert.AreEqual(expectedTenantId, tenant, "TenantId passed to msal should match"); - return result; - }; - } } } diff --git a/sdk/identity/Azure.Identity/tests/TokenCacheTests.cs b/sdk/identity/Azure.Identity/tests/TokenCacheTests.cs index b79f0d30f4cbf..52e668c397505 100644 --- a/sdk/identity/Azure.Identity/tests/TokenCacheTests.cs +++ b/sdk/identity/Azure.Identity/tests/TokenCacheTests.cs @@ -27,16 +27,16 @@ public TokenCacheTests(bool isAsync) : base(isAsync) public Mock mockSerializer2; public Mock mockMSALCache; internal Mock mockWrapper; - public static Random random = new Random(); - public byte[] bytes = new byte[] { 1, 0 }; - public byte[] updatedBytes = new byte[] { 0, 2 }; - public byte[] mergedBytes = new byte[] { 1, 2 }; + public static Random random = new(); + public byte[] bytes = { 1, 0 }; + public byte[] updatedBytes = { 0, 2 }; + public byte[] mergedBytes = { 1, 2 }; public Func main_OnBeforeCacheAccessAsync = null; public Func main_OnAfterCacheAccessAsync = null; public TokenCacheCallback merge_OnBeforeCacheAccessAsync = null; public TokenCacheCallback merge_OnAfterCacheAccessAsync = null; private const int TestBufferSize = 512; - private static Random rand = new Random(); + private static Random rand = new(); [SetUp] public void Setup() @@ -75,7 +75,7 @@ public void CtorAllowsAllPermutations(TokenCachePersistenceOptions options, bool [Test] public async Task NoPersistance_RegisterCacheInitializesEvents() { - cache = new TokenCache(new InMemoryTokenCacheOptions(bytes)); + cache = new TokenCache(new TestInMemoryTokenCacheOptions(bytes)); await cache.RegisterCache(IsAsync, mockMSALCache.Object, default); @@ -86,7 +86,7 @@ public async Task NoPersistance_RegisterCacheInitializesEvents() [Test] public async Task NoPersistance_RegisterCacheInitializesEventsOnlyOnce() { - cache = new TokenCache(new InMemoryTokenCacheOptions(bytes)); + cache = new TokenCache(new TestInMemoryTokenCacheOptions(bytes)); await cache.RegisterCache(IsAsync, mockMSALCache.Object, default); await cache.RegisterCache(IsAsync, mockMSALCache.Object, default); @@ -151,8 +151,8 @@ public async Task RegisterCacheInitializesCacheOnlyOnce() [NonParallelizable] public void RegisterCacheInitializesCacheAndIsThreadSafe() { - ManualResetEventSlim resetEvent2 = new ManualResetEventSlim(); - ManualResetEventSlim resetEvent1 = new ManualResetEventSlim(); + ManualResetEventSlim resetEvent2 = new(); + ManualResetEventSlim resetEvent1 = new(); //The fist call to InitializeAsync will block. The second one will complete immediately. mockWrapper.SetupSequence(m => m.InitializeAsync(It.IsAny(), null)) @@ -218,7 +218,7 @@ public async Task RegisterCacheInitializesCacheIfEncryptionIsUnavailableAndAllow [Test] public async Task Persistance_RegisterCacheDoesNotInitializesEvents() { - var options = System.Environment.OSVersion.Platform switch + var options = Environment.OSVersion.Platform switch { // Linux tests will fail without UnsafeAllowUnencryptedStorage = true. PlatformID.Unix => new TokenCachePersistenceOptions { UnsafeAllowUnencryptedStorage = true }, @@ -235,7 +235,7 @@ public async Task Persistance_RegisterCacheDoesNotInitializesEvents() [Test] public async Task InMemory_RegisterCacheInitializesEvents() { - cache = new TokenCache(new InMemoryTokenCacheOptions(bytes)); + cache = new TokenCache(new TestInMemoryTokenCacheOptions(bytes)); await cache.RegisterCache(IsAsync, mockMSALCache.Object, default); @@ -246,7 +246,7 @@ public async Task InMemory_RegisterCacheInitializesEvents() [Test] public async Task InMemory_RegisterCacheInitializesEventsOnlyOnce() { - cache = new TokenCache(new InMemoryTokenCacheOptions(bytes)); + cache = new TokenCache(new TestInMemoryTokenCacheOptions(bytes)); await cache.RegisterCache(IsAsync, mockMSALCache.Object, default); await cache.RegisterCache(IsAsync, mockMSALCache.Object, default); @@ -258,7 +258,7 @@ public async Task InMemory_RegisterCacheInitializesEventsOnlyOnce() [Test] public async Task RegisteredEventsAreCalledOnFirstUpdate() { - cache = new TokenCache(new InMemoryTokenCacheOptions(bytes)); + cache = new TokenCache(new TestInMemoryTokenCacheOptions(bytes)); TokenCacheNotificationArgs mockArgs = GetMockArgs(mockSerializer, true); bool updatedCalled = false; @@ -331,7 +331,7 @@ public async Task MergeOccursOnSecondUpdate() merge_OnAfterCacheAccessAsync(mockArgs1); }); - cache = new TokenCache(new InMemoryTokenCacheOptions(bytes), default, publicApplicationFactory: new Func(() => mockPublicClient.Object)); + cache = new TokenCache(new TestInMemoryTokenCacheOptions(bytes), default, publicApplicationFactory: new Func(() => mockPublicClient.Object)); await cache.RegisterCache(IsAsync, mockMSALCache.Object, default); await cache.RegisterCache(IsAsync, mergeMSALCache.Object, default); @@ -373,7 +373,7 @@ public async Task Serialize() .Setup(m => m.SetAfterAccessAsync(It.IsAny>())) .Callback>(afterAccess => main_OnAfterCacheAccessAsync = afterAccess); - var cache = new TokenCache(new InMemoryTokenCacheOptions(bytes, updateHandler), default, publicApplicationFactory: new Func(() => mockPublicClient.Object)); + var cache = new TokenCache(new TestInMemoryTokenCacheOptions(bytes, updateHandler), default, publicApplicationFactory: new Func(() => mockPublicClient.Object)); await cache.RegisterCache(IsAsync, mockMSALCache.Object, default); await main_OnBeforeCacheAccessAsync.Invoke(mockArgs); @@ -408,16 +408,9 @@ public async Task UnsafeOptions() mockMSALCache .Setup(m => m.SetAfterAccessAsync(It.IsAny>())) .Callback>(afterAccess => main_OnAfterCacheAccessAsync = afterAccess); - var mockUnsafeOptions = new Mock(); - mockUnsafeOptions - .SetupSequence(m => m.RefreshCacheAsync()) - .ReturnsAsync(bytes1) - .ReturnsAsync(bytes2) - .ReturnsAsync(bytes3); - mockUnsafeOptions - .Setup(m => m.TokenCacheUpdatedAsync(It.IsAny())); - - var cache = new TokenCache(mockUnsafeOptions.Object, default, publicApplicationFactory: new Func(() => mockPublicClient.Object)); + var mockUnsafeOptions = new MockInMemoryTokenCacheOptions( new[]{bytes1, bytes2, bytes3}); + + var cache = new TokenCache(mockUnsafeOptions, default, () => mockPublicClient.Object); await cache.RegisterCache(IsAsync, mockMSALCache.Object, default); await main_OnBeforeCacheAccessAsync.Invoke(mockArgs); @@ -437,11 +430,30 @@ private static TokenCacheNotificationArgs GetMockArgs(Mock[] _bytes; + private int callIndex; + public MockInMemoryTokenCacheOptions(ReadOnlyMemory[] bytes) + { + _bytes = bytes; + } + protected internal override Task> RefreshCacheAsync() + { + return Task.FromResult(_bytes[callIndex++]); + } + + protected internal override Task TokenCacheUpdatedAsync(TokenCacheUpdatedArgs tokenCacheUpdatedArgs) + { + return Task.CompletedTask; + } + } + + public class TestInMemoryTokenCacheOptions : UnsafeTokenCacheOptions { private readonly ReadOnlyMemory _bytes; private readonly Func _updated; - public InMemoryTokenCacheOptions(byte[] bytes, Func updated = null) + public TestInMemoryTokenCacheOptions(byte[] bytes, Func updated = null) { _bytes = bytes; _updated = updated; diff --git a/sdk/identity/Azure.Identity/tests/UsernamePasswordCredentialTests.cs b/sdk/identity/Azure.Identity/tests/UsernamePasswordCredentialTests.cs index 018da493a19ed..f2fb4e0d64bb4 100644 --- a/sdk/identity/Azure.Identity/tests/UsernamePasswordCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/UsernamePasswordCredentialTests.cs @@ -11,21 +11,8 @@ namespace Azure.Identity.Tests { - public class UsernamePasswordCredentialTests : ClientTestBase + public class UsernamePasswordCredentialTests : CredentialTestBase { - private string TenantId = "a0287521-e002-0026-7112-207c0c000000"; - private const string TenantIdHint = "a0287521-e002-0026-7112-207c0c001234"; - private const string Scope = "https://vault.azure.net/.default"; - private const string ClientId = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"; - private string expectedCode; - private string expectedToken; - private DateTimeOffset expiresOn; - private MockMsalPublicClient mockMsal; - private DeviceCodeResult deviceCodeResult; - private string expectedTenantId; - private bool interactiveCalled; - private bool silentCalled; - public UsernamePasswordCredentialTests(bool isAsync) : base(isAsync) { } @@ -83,78 +70,12 @@ public async Task UsesTenantIdHint([Values(null, TenantIdHint)] string tenantId, var context = new TokenRequestContext(new[] { Scope }, tenantId: tenantId); expectedTenantId = TenantIdResolver.Resolve(TenantId, context, options.AllowMultiTenantAuthentication); - var credential = InstrumentClient(new UsernamePasswordCredential("user", "password", TenantId, ClientId, options, null, mockMsal)); - - AccessToken token = await credential.GetTokenAsync(context); - - Assert.AreEqual(expectedToken, token.Token); - Assert.AreEqual(expiresOn, token.ExpiresOn); - } - - [Test] - public async Task CallsGetAzquireTokenSilentAfterFirstTokenAcquired( - [Values(null, TenantIdHint)] string tenantId, - [Values(true)] bool allowMultiTenantAuthentication) - { - TestSetup(); - var options = new UsernamePasswordCredentialOptions { AllowMultiTenantAuthentication = allowMultiTenantAuthentication }; - var context = new TokenRequestContext(new[] { Scope }, tenantId: tenantId); - expectedTenantId = TenantIdResolver.Resolve(TenantId, context, options.AllowMultiTenantAuthentication); - - var credential = InstrumentClient(new UsernamePasswordCredential("user", "password", TenantId, ClientId, options, null, mockMsal)); + var credential = InstrumentClient(new UsernamePasswordCredential("user", "password", TenantId, ClientId, options, null, mockPublicMsalClient)); AccessToken token = await credential.GetTokenAsync(context); Assert.AreEqual(expectedToken, token.Token); Assert.AreEqual(expiresOn, token.ExpiresOn); - Assert.True(interactiveCalled); - Assert.False(silentCalled); - - // Second call should acquireSilent - token = await credential.GetTokenAsync(context); - - Assert.AreEqual(expectedToken, token.Token); - Assert.AreEqual(expiresOn, token.ExpiresOn); - Assert.True(silentCalled); - } - - public void TestSetup() - { - interactiveCalled = false; - silentCalled = false; - expectedTenantId = null; - expectedCode = Guid.NewGuid().ToString(); - expectedToken = Guid.NewGuid().ToString(); - expiresOn = DateTimeOffset.Now.AddHours(1); - mockMsal = new MockMsalPublicClient(); - deviceCodeResult = MockMsalPublicClient.GetDeviceCodeResult(deviceCode: expectedCode); - mockMsal.DeviceCodeResult = deviceCodeResult; - var result = new AuthenticationResult( - expectedToken, - false, - null, - expiresOn, - expiresOn, - TenantId, - new MockAccount("username"), - null, - new[] { Scope }, - Guid.NewGuid(), - null, - "Bearer"); - mockMsal.UserPassAuthFactory = (_, tenant) => - { - interactiveCalled = true; - Assert.AreEqual(expectedTenantId, tenant, "TenantId passed to msal should match"); - return result; - }; - - mockMsal.SilentAuthFactory = (_, tenant) => - { - silentCalled = true; - Assert.AreEqual(expectedTenantId, tenant, "TenantId passed to msal should match"); - return result; - }; } } } diff --git a/sdk/identity/Azure.Identity/tests/VisualStudioCodeCredentialTests.cs b/sdk/identity/Azure.Identity/tests/VisualStudioCodeCredentialTests.cs index a6113d872c657..de4eb631da979 100644 --- a/sdk/identity/Azure.Identity/tests/VisualStudioCodeCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/VisualStudioCodeCredentialTests.cs @@ -1,61 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Azure.Core; using Azure.Core.TestFramework; -using Azure.Identity.Tests.Mock; -using Microsoft.Identity.Client; using NUnit.Framework; namespace Azure.Identity.Tests { - public class VisualStudioCodeCredentialTests : ClientTestBase + public class VisualStudioCodeCredentialTests : CredentialTestBase { - private string TenantId = "a0287521-e002-0026-7112-207c0c000000"; - private const string TenantIdHint = "a0287521-e002-0026-7112-207c0c001234"; - private const string Scope = "https://vault.azure.net/.default"; - private string expectedCode; - private string expectedToken; - private DateTimeOffset expiresOn; - private MockMsalPublicClient mockMsalClient; - private DeviceCodeResult deviceCodeResult; - private string expectedTenantId; - public VisualStudioCodeCredentialTests(bool isAsync) : base(isAsync) { } [SetUp] - public void TestSetup() + public void Setup() { - expectedTenantId = null; - expectedCode = Guid.NewGuid().ToString(); - expectedToken = Guid.NewGuid().ToString(); - expiresOn = DateTimeOffset.Now.AddHours(1); - mockMsalClient = new MockMsalPublicClient(); - deviceCodeResult = MockMsalPublicClient.GetDeviceCodeResult(deviceCode: expectedCode); - mockMsalClient.DeviceCodeResult = deviceCodeResult; - var result = new AuthenticationResult( - expectedToken, - false, - null, - expiresOn, - expiresOn, - TenantId, - new MockAccount("username"), - null, - new[] { Scope }, - Guid.NewGuid(), - null, - "Bearer"); - mockMsalClient.RefreshTokenFactory = (_, _,_, _, tenant, _, _) => - { - Assert.AreEqual(expectedTenantId, tenant, "TenantId passed to msal should match"); - return result; - }; + TestSetup(); } [Test] @@ -72,7 +35,7 @@ public async Task AuthenticateWithVsCodeCredential([Values(null, TenantIdHint)] new VisualStudioCodeCredential( options, null, - mockMsalClient, + mockPublicMsalClient, CredentialTestHelpers.CreateFileSystemForVisualStudioCode(environment), new TestVscAdapter("VS Code Azure", "AzureCloud", expectedToken))); diff --git a/sdk/identity/Azure.Identity/tests/VisualStudioCredentialTests.cs b/sdk/identity/Azure.Identity/tests/VisualStudioCredentialTests.cs index 4816252af7d6f..55848d9daa9b9 100644 --- a/sdk/identity/Azure.Identity/tests/VisualStudioCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/VisualStudioCredentialTests.cs @@ -13,22 +13,17 @@ namespace Azure.Identity.Tests { [RunOnlyOnPlatforms(Windows = true)] // VisualStudioCredential works only on Windows - public class VisualStudioCredentialTests : ClientTestBase + public class VisualStudioCredentialTests : CredentialTestBase { - private string TenantId = "a0287521-e002-0026-7112-207c0c000000"; - private const string TenantIdHint = "a0287521-e002-0026-7112-207c0c001234"; - private const string Scope = "https://vault.azure.net/.default"; - private string expectedTenantId; - public VisualStudioCredentialTests(bool isAsync) : base(isAsync) { } [Test] - public async Task AuthenticateWithVsCredential([Values(null, TenantIdHint)] string tenantId, [Values(true)] bool preferHint) + public async Task AuthenticateWithVsCredential([Values(null, TenantIdHint)] string tenantId, [Values(true)] bool allowMultiTenantAuthentication) { var fileSystem = CredentialTestHelpers.CreateFileSystemForVisualStudio(); var (expectedToken, expectedExpiresOn, processOutput) = CredentialTestHelpers.CreateTokenForVisualStudio(); var testProcess = new TestProcess { Output = processOutput }; - var options = new VisualStudioCredentialOptions { AllowMultiTenantAuthentication = preferHint }; + var options = new VisualStudioCredentialOptions { AllowMultiTenantAuthentication = allowMultiTenantAuthentication }; var credential = InstrumentClient(new VisualStudioCredential(TenantId, default, fileSystem, new TestProcessService(testProcess, true), options)); var context = new TokenRequestContext(new[] { Scope }, tenantId: tenantId); expectedTenantId = TenantIdResolver.Resolve(TenantId, context, options.AllowMultiTenantAuthentication);