Skip to content

Commit

Permalink
fix: Add Audience for Certificate auth to work with Skills (#6794)
Browse files Browse the repository at this point in the history
* Fix Certificate auth in Skills

* Ensure thread safety for CertificateServiceClientCredentialsFactory

---------

Co-authored-by: Andrés Robinet <andres.robinet@southworks.com>
  • Loading branch information
sw-joelmut and andres-robinet-sw authored May 28, 2024
1 parent b15a8bf commit 259c174
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public CertificateAppCredentials(CertificateAppCredentialsOptions options)
/// <param name="customHttpClient">Optional <see cref="HttpClient"/> to be used when acquiring tokens.</param>
/// <param name="logger">Optional <see cref="ILogger"/> to gather telemetry data while acquiring and managing credentials.</param>
public CertificateAppCredentials(X509Certificate2 clientCertificate, string appId, string channelAuthTenant = null, HttpClient customHttpClient = null, ILogger logger = null)
: this(clientCertificate, false, appId, channelAuthTenant, customHttpClient, logger)
: this(clientCertificate, appId, channelAuthTenant, string.Empty, false, customHttpClient, logger)
{
}

Expand All @@ -62,7 +62,22 @@ public CertificateAppCredentials(X509Certificate2 clientCertificate, string appI
/// <param name="customHttpClient">Optional <see cref="HttpClient"/> to be used when acquiring tokens.</param>
/// <param name="logger">Optional <see cref="ILogger"/> to gather telemetry data while acquiring and managing credentials.</param>
public CertificateAppCredentials(X509Certificate2 clientCertificate, bool sendX5c, string appId, string channelAuthTenant = null, HttpClient customHttpClient = null, ILogger logger = null)
: base(channelAuthTenant, customHttpClient, logger)
: this(clientCertificate, appId, channelAuthTenant, string.Empty, sendX5c, customHttpClient, logger)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="CertificateAppCredentials"/> class.
/// </summary>
/// <param name="clientCertificate">Client certificate to be presented for authentication.</param>
/// <param name="appId">Microsoft application Id related to the certifiacte.</param>
/// <param name="channelAuthTenant">Optional. The oauth token tenant.</param>
/// <param name="oAuthScope">Optional. The scope for the token.</param>
/// <param name="sendX5c">Optional. This parameter, if true, enables application developers to achieve easy certificates roll-over in Azure AD: setting this parameter to true will send the public certificate to Azure AD along with the token request, so that Azure AD can use it to validate the subject name based on a trusted issuer policy. </param>
/// <param name="customHttpClient">Optional <see cref="HttpClient"/> to be used when acquiring tokens.</param>
/// <param name="logger">Optional <see cref="ILogger"/> to gather telemetry data while acquiring and managing credentials.</param>
public CertificateAppCredentials(X509Certificate2 clientCertificate, string appId, string channelAuthTenant = null, string oAuthScope = null, bool sendX5c = false, HttpClient customHttpClient = null, ILogger logger = null)
: base(channelAuthTenant, customHttpClient, logger, oAuthScope: oAuthScope)
{
if (clientCertificate == null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.Net.Http;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
Expand All @@ -16,8 +17,13 @@ namespace Microsoft.Bot.Connector.Authentication
/// </summary>
public class CertificateServiceClientCredentialsFactory : ServiceClientCredentialsFactory
{
private readonly CertificateAppCredentials _certificateAppCredentials;
private readonly X509Certificate2 _certificate;
private readonly string _appId;
private readonly string _tenantId;
private readonly bool _sendX5c;
private readonly HttpClient _httpClient;
private readonly ILogger _logger;
private readonly ConcurrentDictionary<string, CertificateAppCredentials> _certificateAppCredentialsByAudience = new ConcurrentDictionary<string, CertificateAppCredentials>();

/// <summary>
/// Initializes a new instance of the <see cref="CertificateServiceClientCredentialsFactory"/> class.
Expand All @@ -44,16 +50,12 @@ public CertificateServiceClientCredentialsFactory(
throw new ArgumentNullException(nameof(appId));
}

_certificate = certificate ?? throw new ArgumentNullException(nameof(certificate));
_appId = appId;

// Instance must be reused otherwise it will cause throttling on AAD.
_certificateAppCredentials = new CertificateAppCredentials(
certificate ?? throw new ArgumentNullException(nameof(certificate)),
sendX5c,
appId,
tenantId,
httpClient,
logger);
_tenantId = tenantId;
_sendX5c = sendX5c;
_httpClient = httpClient;
_logger = logger;
}

/// <inheritdoc />
Expand All @@ -78,7 +80,20 @@ public override Task<ServiceClientCredentials> CreateCredentialsAsync(
throw new InvalidOperationException("Invalid Managed ID.");
}

return Task.FromResult<ServiceClientCredentials>(_certificateAppCredentials);
// Instance must be reused per audience, otherwise it will cause throttling on AAD.
var certificateAppCredentials = _certificateAppCredentialsByAudience.GetOrAdd(audience, (audience) =>
{
return new CertificateAppCredentials(
_certificate,
_appId,
_tenantId,
audience,
_sendX5c,
_httpClient,
_logger);
});

return Task.FromResult<ServiceClientCredentials>(certificateAppCredentials);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public class CertificateServiceClientCredentialsFactoryTests
private const string TestAppId = nameof(TestAppId);
private const string TestTenantId = nameof(TestTenantId);
private const string TestAudience = nameof(TestAudience);
private const string LoginEndpoint = "https://login.microsoftonline.com";
private readonly Mock<ILogger> logger = new Mock<ILogger>();
private readonly Mock<X509Certificate2> certificate = new Mock<X509Certificate2>();

Expand Down Expand Up @@ -68,19 +69,38 @@ public async void CanCreateCredentials()
var factory = new CertificateServiceClientCredentialsFactory(certificate.Object, TestAppId);

var credentials = await factory.CreateCredentialsAsync(
TestAppId, TestAudience, "https://login.microsoftonline.com", true, CancellationToken.None);
TestAppId, TestAudience, LoginEndpoint, true, CancellationToken.None);

Assert.NotNull(credentials);
Assert.IsType<CertificateAppCredentials>(credentials);
}

[Fact]
public async void ShouldCreateUniqueCredentialsByAudience()
{
var factory = new CertificateServiceClientCredentialsFactory(certificate.Object, TestAppId);

var credentials1 = await factory.CreateCredentialsAsync(
TestAppId, string.Empty, LoginEndpoint, true, CancellationToken.None);
var credentials2 = await factory.CreateCredentialsAsync(
TestAppId, TestAudience, LoginEndpoint, true, CancellationToken.None);
var credentials3 = await factory.CreateCredentialsAsync(
TestAppId, Guid.NewGuid().ToString(), LoginEndpoint, true, CancellationToken.None);
var credentials4 = await factory.CreateCredentialsAsync(
TestAppId, string.Empty, LoginEndpoint, true, CancellationToken.None);

Assert.NotEqual(credentials1, credentials2);
Assert.NotEqual(credentials1, credentials3);
Assert.Equal(credentials1, credentials4);
}

[Fact]
public void CannotCreateCredentialsWithInvalidAppId()
{
var factory = new CertificateServiceClientCredentialsFactory(certificate.Object, TestAppId);

Assert.ThrowsAsync<InvalidOperationException>(() => factory.CreateCredentialsAsync(
"InvalidAppId", TestAudience, "https://login.microsoftonline.com", true, CancellationToken.None));
"InvalidAppId", TestAudience, LoginEndpoint, true, CancellationToken.None));
}
}
}

0 comments on commit 259c174

Please sign in to comment.