Skip to content

Commit

Permalink
Correct creation of the CryptographyClient in the KeyVaultService (#755)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlemstra authored Aug 14, 2024
1 parent 6584f5d commit 429d24d
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 49 deletions.
71 changes: 46 additions & 25 deletions src/Sign.SignatureProviders.KeyVault/KeyVaultService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,69 +8,90 @@
using Azure;
using Azure.Core;
using Azure.Security.KeyVault.Certificates;
using Azure.Security.KeyVault.Keys;
using Azure.Security.KeyVault.Keys.Cryptography;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Sign.Core;

namespace Sign.SignatureProviders.KeyVault
{
internal sealed class KeyVaultService : ISignatureAlgorithmProvider, ICertificateProvider
internal sealed class KeyVaultService : ISignatureAlgorithmProvider, ICertificateProvider, IDisposable
{
private readonly ILogger<KeyVaultService> _logger;
private readonly Task<KeyVaultCertificateWithPolicy>? _task;
private readonly TokenCredential _tokenCredential;
private readonly Uri _keyVaultUrl;
private readonly string _certificateName;
private readonly ILogger<KeyVaultService> _logger;
private readonly SemaphoreSlim _mutex = new(1);
private KeyVaultCertificateWithPolicy? _certificateWithPolicy;

internal KeyVaultService(
IServiceProvider serviceProvider,
TokenCredential tokenCredential,
Uri keyVaultUrl,
string certificateName)
string certificateName,
ILogger<KeyVaultService> logger)
{
ArgumentNullException.ThrowIfNull(serviceProvider, nameof(serviceProvider));
ArgumentNullException.ThrowIfNull(tokenCredential, nameof(tokenCredential));
ArgumentNullException.ThrowIfNull(keyVaultUrl, nameof(keyVaultUrl));
ArgumentException.ThrowIfNullOrEmpty(certificateName, nameof(certificateName));
ArgumentNullException.ThrowIfNull(logger, nameof(logger));

_tokenCredential = tokenCredential;
_logger = serviceProvider.GetRequiredService<ILogger<KeyVaultService>>();
_keyVaultUrl = keyVaultUrl;
_certificateName = certificateName;
_logger = logger;
}

_task = GetKeyVaultCertificateAsync(keyVaultUrl, tokenCredential, certificateName);
public void Dispose()
{
_mutex.Dispose();
GC.SuppressFinalize(this);
}

public async Task<X509Certificate2> GetCertificateAsync(CancellationToken cancellationToken)
{
KeyVaultCertificateWithPolicy certificateWithPolicy = await _task!;
KeyVaultCertificateWithPolicy certificateWithPolicy = await GetCertificateWithPolicyAsync(cancellationToken);

return new X509Certificate2(certificateWithPolicy.Cer);
}

public async Task<RSA> GetRsaAsync(CancellationToken cancellationToken)
{
KeyVaultCertificateWithPolicy certificateWithPolicy = await _task!;
KeyClient keyClient = new(certificateWithPolicy.KeyId, _tokenCredential);
CryptographyClient cryptoClient = keyClient.GetCryptographyClient(certificateWithPolicy.Name);
KeyVaultCertificateWithPolicy certificateWithPolicy = await GetCertificateWithPolicyAsync(cancellationToken);

CryptographyClient cryptoClient = new(certificateWithPolicy.KeyId, _tokenCredential);
return await cryptoClient.CreateRSAAsync(cancellationToken);
}

private async Task<KeyVaultCertificateWithPolicy> GetKeyVaultCertificateAsync(
Uri keyVaultUrl,
TokenCredential tokenCredential,
string certificateName)
private async Task<KeyVaultCertificateWithPolicy> GetCertificateWithPolicyAsync(CancellationToken cancellationToken)
{
Stopwatch stopwatch = Stopwatch.StartNew();
if (_certificateWithPolicy is not null)
{
return _certificateWithPolicy;
}

await _mutex.WaitAsync(cancellationToken);

try
{
if (_certificateWithPolicy is null)
{
Stopwatch stopwatch = Stopwatch.StartNew();

_logger.LogTrace(Resources.FetchingCertificate);

_logger.LogTrace(Resources.FetchingCertificate);
CertificateClient client = new(_keyVaultUrl, _tokenCredential);
Response<KeyVaultCertificateWithPolicy> response = await client.GetCertificateAsync(_certificateName, cancellationToken);

CertificateClient client = new(keyVaultUrl, tokenCredential);
Response<KeyVaultCertificateWithPolicy>? response =
await client.GetCertificateAsync(certificateName).ConfigureAwait(false);
_logger.LogTrace(Resources.FetchedCertificate, stopwatch.Elapsed.TotalMilliseconds);

_logger.LogTrace(Resources.FetchedCertificate, stopwatch.Elapsed.TotalMilliseconds);
_certificateWithPolicy = response.Value;
}
}
finally
{
_mutex.Release();
}

return response.Value;
return _certificateWithPolicy;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// See the LICENSE.txt file in the project root for more information.

using Azure.Core;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Sign.Core;

namespace Sign.SignatureProviders.KeyVault
Expand Down Expand Up @@ -57,7 +59,8 @@ private KeyVaultService GetService(IServiceProvider serviceProvider)
return _keyVaultService;
}

_keyVaultService = new KeyVaultService(serviceProvider, _tokenCredential, _keyVaultUrl, _certificateName);
ILogger<KeyVaultService> logger = serviceProvider.GetRequiredService<ILogger<KeyVaultService>>();
_keyVaultService = new KeyVaultService(_tokenCredential, _keyVaultUrl, _certificateName, logger);
}

return _keyVaultService;
Expand Down
37 changes: 14 additions & 23 deletions test/Sign.SignatureProviders.KeyVault.Test/KeyVaultServiceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
// See the LICENSE.txt file in the project root for more information.

using Azure.Core;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Moq;
using Sign.TestInfrastructure;

namespace Sign.SignatureProviders.KeyVault.Test
{
Expand All @@ -15,29 +13,13 @@ public class KeyVaultServiceTests
private readonly static TokenCredential TokenCredential = Mock.Of<TokenCredential>();
private readonly static Uri KeyVaultUrl = new("https://keyvault.test");
private const string CertificateName = "a";
private readonly IServiceProvider serviceProvider;

public KeyVaultServiceTests()
{
ServiceCollection services = new();
services.AddSingleton<ILogger<KeyVaultService>>(new TestLogger<KeyVaultService>());
serviceProvider = services.BuildServiceProvider();
}

[Fact]
public void Constructor_WhenServiceProviderIsNull_Throws()
{
ArgumentNullException exception = Assert.Throws<ArgumentNullException>(
() => new KeyVaultService(serviceProvider: null!, TokenCredential, KeyVaultUrl, CertificateName));

Assert.Equal("serviceProvider", exception.ParamName);
}
private readonly static ILogger<KeyVaultService> logger = Mock.Of<ILogger<KeyVaultService>>();

[Fact]
public void Constructor_WhenTokenCredentialIsNull_Throws()
{
ArgumentNullException exception = Assert.Throws<ArgumentNullException>(
() => new KeyVaultService(serviceProvider, tokenCredential: null!, KeyVaultUrl, CertificateName));
() => new KeyVaultService(tokenCredential: null!, KeyVaultUrl, CertificateName, logger));

Assert.Equal("tokenCredential", exception.ParamName);
}
Expand All @@ -46,7 +28,7 @@ public void Constructor_WhenTokenCredentialIsNull_Throws()
public void Constructor_WhenKeyVaultUrlIsNull_Throws()
{
ArgumentNullException exception = Assert.Throws<ArgumentNullException>(
() => new KeyVaultService(serviceProvider, TokenCredential, keyVaultUrl: null!, CertificateName));
() => new KeyVaultService(TokenCredential, keyVaultUrl: null!, CertificateName, logger));

Assert.Equal("keyVaultUrl", exception.ParamName);
}
Expand All @@ -55,7 +37,7 @@ public void Constructor_WhenKeyVaultUrlIsNull_Throws()
public void Constructor_WhenCertificateNameIsNull_Throws()
{
ArgumentNullException exception = Assert.Throws<ArgumentNullException>(
() => new KeyVaultService(serviceProvider, TokenCredential, KeyVaultUrl, certificateName: null!));
() => new KeyVaultService(TokenCredential, KeyVaultUrl, certificateName: null!, logger));

Assert.Equal("certificateName", exception.ParamName);
}
Expand All @@ -64,9 +46,18 @@ public void Constructor_WhenCertificateNameIsNull_Throws()
public void Constructor_WhenCertificateNameIsEmpty_Throws()
{
ArgumentException exception = Assert.Throws<ArgumentException>(
() => new KeyVaultService(serviceProvider, TokenCredential, KeyVaultUrl, certificateName: string.Empty));
() => new KeyVaultService(TokenCredential, KeyVaultUrl, certificateName: string.Empty, logger));

Assert.Equal("certificateName", exception.ParamName);
}

[Fact]
public void Constructor_WhenLoggerIsNull_Throws()
{
ArgumentNullException exception = Assert.Throws<ArgumentNullException>(
() => new KeyVaultService(TokenCredential, KeyVaultUrl, CertificateName, logger: null!));

Assert.Equal("logger", exception.ParamName);
}
}
}

0 comments on commit 429d24d

Please sign in to comment.