@@ -1107,7 +1107,6 @@ @MakeLicenseSpan(segment); } } - @:license } else { From 315ab746bc2998c66fc317bbddd0cebcd7ce9e04 Mon Sep 17 00:00:00 2001 From: Daniel Jacinto <17834924+dannyjdev@users.noreply.github.com> Date: Mon, 18 Nov 2024 13:00:15 -0800 Subject: [PATCH 29/33] [Storage Migration] V3 jobs (#10228) --- src/Catalog/Persistence/AzureStorage.cs | 2 +- src/Ng/CommandHelpers.cs | 29 ++++++- src/Ng/Jobs/Catalog2DnxJob.cs | 6 +- src/Ng/Jobs/Db2CatalogJob.cs | 6 +- src/Ng/Jobs/FixCatalogCachingJob.cs | 4 +- src/Ng/Jobs/LightningJob.cs | 75 +++++++++++++------ src/Ng/Jobs/NgJob.cs | 6 +- .../Catalog2RegistrationConfiguration.cs | 22 +++++- .../DependencyInjectionExtensions.cs | 24 +++++- src/NuGet.Jobs.Catalog2Registration/Job.cs | 17 ++++- .../BlobServiceClientAuthType.cs | 12 +++ .../BlobServiceClientFactory.cs | 44 ++++++++--- 12 files changed, 191 insertions(+), 56 deletions(-) create mode 100644 src/NuGet.Services.Storage/BlobServiceClientAuthType.cs diff --git a/src/Catalog/Persistence/AzureStorage.cs b/src/Catalog/Persistence/AzureStorage.cs index b8805f9bb7..91bfea2038 100644 --- a/src/Catalog/Persistence/AzureStorage.cs +++ b/src/Catalog/Persistence/AzureStorage.cs @@ -94,7 +94,7 @@ private static ICloudBlobDirectory GetCloudBlobDirectoryUri(Uri storageBaseUri) var blobEndpoint = new Uri(storageBaseUri.GetComponents(UriComponents.SchemeAndServer, UriFormat.Unescaped)); // Create BlobServiceClient with anonymous credentials - var blobServiceClient = new BlobServiceClientFactory(blobEndpoint, new DefaultAzureCredential()); + var blobServiceClient = new BlobServiceClientFactory(blobEndpoint); string containerName = pathSegments[0]; string pathInContainer = string.Join("/", pathSegments.Skip(1)); diff --git a/src/Ng/CommandHelpers.cs b/src/Ng/CommandHelpers.cs index cfe14ae4ec..298a900d5a 100644 --- a/src/Ng/CommandHelpers.cs +++ b/src/Ng/CommandHelpers.cs @@ -11,6 +11,7 @@ using Azure; using Azure.Core; using Azure.Identity; +using Azure.Storage; using Azure.Storage.Blobs; using Azure.Storage.Queues; using Microsoft.Extensions.Logging; @@ -40,6 +41,8 @@ public static class CommandHelpers }; private static readonly IDictionary ArgumentNames = new Dictionary { + { Arguments.UseManagedIdentity, Arguments.UseManagedIdentity }, + { Arguments.ClientId, Arguments.ClientId}, { Arguments.StorageBaseAddress, Arguments.StorageBaseAddress }, { Arguments.StorageAccountName, Arguments.StorageAccountName }, { Arguments.StorageKeyValue, Arguments.StorageKeyValue }, @@ -52,7 +55,6 @@ public static class CommandHelpers { Arguments.StorageOperationMaxExecutionTimeInSeconds, Arguments.StorageOperationMaxExecutionTimeInSeconds }, { Arguments.StorageServerTimeoutInSeconds, Arguments.StorageServerTimeoutInSeconds }, { Arguments.StorageInitializeContainer, Arguments.StorageInitializeContainer }, - { Arguments.ClientId, Arguments.ClientId }, }; public static IDictionary GetArguments(string[] args, int start, out ICachingSecretInjector secretInjector) @@ -164,6 +166,8 @@ public static CatalogStorageFactory CreateSuffixedStorageFactory( IDictionary names = new Dictionary { + { Arguments.UseManagedIdentity, Arguments.UseManagedIdentity }, + { Arguments.ClientId, Arguments.ClientId}, { Arguments.StorageBaseAddress, Arguments.StorageBaseAddress + suffix }, { Arguments.StorageAccountName, Arguments.StorageAccountName + suffix }, { Arguments.StorageUseManagedIdentity, Arguments.StorageUseManagedIdentity + suffix }, @@ -198,7 +202,6 @@ private static CatalogStorageFactory CreateStorageFactoryImpl( if (!string.IsNullOrEmpty(storageBaseAddressStr)) { storageBaseAddressStr = storageBaseAddressStr.TrimEnd('/') + "/"; - storageBaseAddress = new Uri(storageBaseAddressStr); } @@ -427,7 +430,16 @@ private static IBlobServiceClientFactory GetBlobServiceClient( IDictionary argumentNameMap) { bool storageUseManagedIdentity = arguments.GetOrDefault(argumentNameMap[Arguments.StorageUseManagedIdentity], defaultValue: false); - if (storageUseManagedIdentity) + bool useManagedIdentity = storageUseManagedIdentity || arguments.GetOrDefault(argumentNameMap[Arguments.UseManagedIdentity], defaultValue: false); + + var storageKeyValue = arguments.GetOrDefault(argumentNameMap[Arguments.StorageKeyValue]); + var storageSasValue = arguments.GetOrDefault(argumentNameMap[Arguments.StorageSasValue]); + + bool hasStorageKeyOrSas = !string.IsNullOrEmpty(storageKeyValue) || !string.IsNullOrEmpty(storageSasValue); + + // This comparison is due to some jobs using both global and china storages in a single instance. + // They require MSI auth for global storage and SAS/SAK auth for china storage. + if (useManagedIdentity && !hasStorageKeyOrSas) { var managedIdentityClientId = arguments.GetOrThrow(argumentNameMap[Arguments.ClientId]); var identity = new ManagedIdentityCredential(managedIdentityClientId); @@ -444,7 +456,16 @@ private static QueueServiceClient GetQueueServiceClient( IDictionary argumentNameMap) { bool storageUseManagedIdentity = arguments.GetOrDefault(argumentNameMap[Arguments.StorageUseManagedIdentity], defaultValue: false); - if (storageUseManagedIdentity) + bool useManagedIdentity = storageUseManagedIdentity || arguments.GetOrDefault(argumentNameMap[Arguments.UseManagedIdentity], defaultValue: false); + + var storageKeyValue = arguments.GetOrDefault(argumentNameMap[Arguments.StorageKeyValue]); + var storageSasValue = arguments.GetOrDefault(argumentNameMap[Arguments.StorageSasValue]); + + bool hasStorageKeyOrSas = !string.IsNullOrEmpty(storageKeyValue) || !string.IsNullOrEmpty(storageSasValue); + + // This comparison is due to some jobs using both global and china storages in a single instance. + // They require MSI auth for global storage and SAS/SAK auth for china storage. + if (useManagedIdentity && !hasStorageKeyOrSas) { var managedIdentityClientId = arguments.GetOrThrow(argumentNameMap[Arguments.ClientId]); var identity = new ManagedIdentityCredential(managedIdentityClientId); diff --git a/src/Ng/Jobs/Catalog2DnxJob.cs b/src/Ng/Jobs/Catalog2DnxJob.cs index 30c7c6def8..adc18ea6cb 100644 --- a/src/Ng/Jobs/Catalog2DnxJob.cs +++ b/src/Ng/Jobs/Catalog2DnxJob.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -45,7 +45,7 @@ public override string GetUsage() + $"-{Arguments.StoragePath} " + $"[-{Arguments.VaultName} " + $"-{Arguments.UseManagedIdentity} true|false " - + $"-{Arguments.ClientId} Should not be set if {Arguments.UseManagedIdentity} is true" + + $"-{Arguments.ClientId} If {Arguments.UseManagedIdentity} is true this is used for managed identity authentication, if false, is used for KeyVault certificate authentication" + $"-{Arguments.CertificateThumbprint} Should not be set if {Arguments.UseManagedIdentity} is true" + $"[-{Arguments.ValidateCertificate} true|false]]] " + $"[-{Arguments.Verbose} true|false] " @@ -122,4 +122,4 @@ protected override async Task RunInternalAsync(CancellationToken cancellationTok } } } -} \ No newline at end of file +} diff --git a/src/Ng/Jobs/Db2CatalogJob.cs b/src/Ng/Jobs/Db2CatalogJob.cs index f4463e8b68..5457a6c0fc 100644 --- a/src/Ng/Jobs/Db2CatalogJob.cs +++ b/src/Ng/Jobs/Db2CatalogJob.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -63,7 +63,7 @@ public override string GetUsage() + $"-{Arguments.StoragePath} " + $"[-{Arguments.VaultName} " + $"-{Arguments.UseManagedIdentity} true|false " - + $"-{Arguments.ClientId} Should not be set if {Arguments.UseManagedIdentity} is true" + + $"-{Arguments.ClientId} If {Arguments.UseManagedIdentity} is true this is used for managed identity authentication, if false, is used for KeyVault certificate authentication" + $"-{Arguments.CertificateThumbprint} Should not be set if {Arguments.UseManagedIdentity} is true" + $"[-{Arguments.ValidateCertificate} true|false]]] " + $"-{Arguments.StorageTypeAuditing} file|azure " @@ -413,4 +413,4 @@ private async Task Deletes2Catalog( return lastDeleted; } } -} \ No newline at end of file +} diff --git a/src/Ng/Jobs/FixCatalogCachingJob.cs b/src/Ng/Jobs/FixCatalogCachingJob.cs index 762e23a503..428e39e0e0 100644 --- a/src/Ng/Jobs/FixCatalogCachingJob.cs +++ b/src/Ng/Jobs/FixCatalogCachingJob.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -52,7 +52,7 @@ public override string GetUsage() + $"-{Arguments.StoragePath} " + $"[-{Arguments.VaultName} " + $"-{Arguments.UseManagedIdentity} true|false " - + $"-{Arguments.ClientId} Should not be set if {Arguments.UseManagedIdentity} is true" + + $"-{Arguments.ClientId} If {Arguments.UseManagedIdentity} is true this is used for managed identity authentication, if false, is used for KeyVault certificate authentication" + $"-{Arguments.CertificateThumbprint} Should not be set if {Arguments.UseManagedIdentity} is true" + $"[-{Arguments.ValidateCertificate} true|false ]]"; } diff --git a/src/Ng/Jobs/LightningJob.cs b/src/Ng/Jobs/LightningJob.cs index f71a259c5b..399ed2f2e7 100644 --- a/src/Ng/Jobs/LightningJob.cs +++ b/src/Ng/Jobs/LightningJob.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -400,8 +400,10 @@ private void AddStorageCredentialArgument(StringBuilder argument, string sasToke { AppendArgument(argument, sasTokenArgument); } - - AppendArgument(argument, storageKeyArgument); + else if (!string.IsNullOrEmpty(_arguments.GetOrDefault(storageKeyArgument))) + { + AppendArgument(argument, storageKeyArgument); + } } private async Task StrikeAsync() @@ -504,32 +506,55 @@ private IContainer GetAutofacContainer() services.Configure(config => { + config.StorageUseManagedIdentity = _arguments.GetOrDefault(Arguments.UseManagedIdentity); + config.StorageManagedIdentityClientId = _arguments.GetOrDefault(Arguments.ClientId); + config.LegacyBaseUrl = _arguments.GetOrDefault(Arguments.StorageBaseAddress); config.LegacyStorageContainer = _arguments.GetOrDefault(Arguments.StorageContainer); - config.StorageConnectionString = GetConnectionString( - config.StorageConnectionString, - Arguments.StorageAccountName, - Arguments.StorageKeyValue, - Arguments.StorageSasValue, - Arguments.StorageSuffix); config.GzippedBaseUrl = _arguments.GetOrDefault(Arguments.CompressedStorageBaseAddress); config.GzippedStorageContainer = _arguments.GetOrDefault(Arguments.CompressedStorageContainer); - config.StorageConnectionString = GetConnectionString( - config.StorageConnectionString, - Arguments.CompressedStorageAccountName, - Arguments.CompressedStorageKeyValue, - Arguments.CompressedStorageSasValue, - Arguments.StorageSuffix); config.SemVer2BaseUrl = _arguments.GetOrDefault(Arguments.SemVer2StorageBaseAddress); config.SemVer2StorageContainer = _arguments.GetOrDefault(Arguments.SemVer2StorageContainer); - config.StorageConnectionString = GetConnectionString( - config.StorageConnectionString, - Arguments.SemVer2StorageAccountName, - Arguments.SemVer2StorageKeyValue, - Arguments.SemVer2StorageSasValue, - Arguments.StorageSuffix); + + config.HasSasToken = new List() + { + _arguments.GetOrDefault(Arguments.StorageSasValue), + _arguments.GetOrDefault(Arguments.CompressedStorageSasValue), + _arguments.GetOrDefault(Arguments.SemVer2StorageSasValue) + } + .All(t => !string.IsNullOrEmpty(t)); + + if (config.StorageUseManagedIdentity && !config.HasSasToken) + { + var storageAccountName = _arguments.GetOrDefault(Arguments.StorageAccountName); + var storageSuffix = _arguments.GetOrDefault(Arguments.StorageSuffix, "core.windows.net"); + + config.StorageServiceUrl = $"https://{storageAccountName}.blob.{storageSuffix}"; + config.StorageConnectionString = $"BlobEndpoint={config.StorageServiceUrl}"; + } + else + { + config.StorageConnectionString = GetConnectionString( + config.StorageConnectionString, + Arguments.StorageAccountName, + Arguments.StorageKeyValue, + Arguments.StorageSasValue, + Arguments.StorageSuffix); + config.StorageConnectionString = GetConnectionString( + config.StorageConnectionString, + Arguments.CompressedStorageAccountName, + Arguments.CompressedStorageKeyValue, + Arguments.CompressedStorageSasValue, + Arguments.StorageSuffix); + config.StorageConnectionString = GetConnectionString( + config.StorageConnectionString, + Arguments.SemVer2StorageAccountName, + Arguments.SemVer2StorageKeyValue, + Arguments.SemVer2StorageSasValue, + Arguments.StorageSuffix); + } config.GalleryBaseUrl = _arguments.GetOrThrow(Arguments.GalleryBaseAddress); var contentBaseAddress = _arguments.GetOrThrow(Arguments.ContentBaseAddress); @@ -568,7 +593,13 @@ private string GetConnectionString( } else { - builder.AppendFormat("SharedAccessSignature={0};", _arguments.GetOrThrow(accountSasArgument)); + var sasToken = _arguments.GetOrThrow(accountSasArgument); + // workaround for https://github.com/Azure/azure-sdk-for-net/issues/44373 + if (sasToken.StartsWith("?")) + { + sasToken = sasToken.Substring(1); + } + builder.AppendFormat("SharedAccessSignature={0};", sasToken); } builder.AppendFormat("EndpointSuffix={0}", _arguments.GetOrDefault(endpointSuffixArgument, "core.windows.net")); diff --git a/src/Ng/Jobs/NgJob.cs b/src/Ng/Jobs/NgJob.cs index 8fd2b7b779..a22e94f7b8 100644 --- a/src/Ng/Jobs/NgJob.cs +++ b/src/Ng/Jobs/NgJob.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -51,7 +51,7 @@ public static string GetUsageBase() return "Usage: ng [" + string.Join("|", NgJobFactory.JobMap.Keys) + "] " + $"[-{Arguments.VaultName} " + $"-{Arguments.UseManagedIdentity} true|false " - + $"-{Arguments.ClientId} Should not be set if {Arguments.UseManagedIdentity} is true" + + $"-{Arguments.ClientId} If {Arguments.UseManagedIdentity} is true this is used for managed identity authentication, if false, is used for KeyVault certificate authentication" + $"-{Arguments.CertificateThumbprint} Should not be set if {Arguments.UseManagedIdentity} is true" + $"[-{Arguments.ValidateCertificate} true|false]]"; } @@ -76,4 +76,4 @@ public virtual async Task RunAsync(IDictionary arguments, Cancel await RunInternalAsync(cancellationToken); } } -} \ No newline at end of file +} diff --git a/src/NuGet.Jobs.Catalog2Registration/Catalog2RegistrationConfiguration.cs b/src/NuGet.Jobs.Catalog2Registration/Catalog2RegistrationConfiguration.cs index a81b458633..1146d96f0a 100644 --- a/src/NuGet.Jobs.Catalog2Registration/Catalog2RegistrationConfiguration.cs +++ b/src/NuGet.Jobs.Catalog2Registration/Catalog2RegistrationConfiguration.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -11,6 +11,26 @@ public class Catalog2RegistrationConfiguration : ICommitCollectorConfiguration { private static readonly int DefaultMaxConcurrentHivesPerId = Enum.GetValues(typeof(HiveType)).Length; + /// + /// Whether or not managed identity will be used as credential. + /// + public bool StorageUseManagedIdentity { get; set; } + + /// + /// Specific manage identity client id. + /// + public string StorageManagedIdentityClientId { get; set; } + + /// + /// Whether or not any storage contains a sas token. + /// + public bool HasSasToken { get; set; } + + /// + /// Azure storage service uri. e.g. https://.blob.core.windows.net + /// + public string StorageServiceUrl { get; set; } + /// /// The connection string used to connect to an Azure Blob Storage account. The connection string specifies /// the account name, the endpoint suffix (e.g. Azure vs. Azure China), and authentication credential (e.g. storage diff --git a/src/NuGet.Jobs.Catalog2Registration/DependencyInjectionExtensions.cs b/src/NuGet.Jobs.Catalog2Registration/DependencyInjectionExtensions.cs index e19ec81dd8..dbd96ba094 100644 --- a/src/NuGet.Jobs.Catalog2Registration/DependencyInjectionExtensions.cs +++ b/src/NuGet.Jobs.Catalog2Registration/DependencyInjectionExtensions.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Net.Http; using Autofac; +using Azure.Identity; using Azure.Storage.Blobs; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; @@ -31,10 +32,20 @@ public static ContainerBuilder AddCatalog2Registration(this ContainerBuilder con containerBuilder.AddV3(); RegisterCursorStorage(containerBuilder); - containerBuilder - .RegisterStorageAccount(c => c.StorageConnectionString, requestTimeout: DefaultBlobRequestOptions.ServerTimeout) - .As(); + .Register(c => + { + var options = c.Resolve>(); + + if (options.Value.StorageUseManagedIdentity && !options.Value.HasSasToken) + { + return CloudBlobClientWrapper.UsingMsi(options.Value.StorageConnectionString, clientId: options.Value.StorageManagedIdentityClientId, requestTimeout: DefaultBlobRequestOptions.ServerTimeout); + } + + return new CloudBlobClientWrapper( + options.Value.StorageConnectionString, + requestTimeout: DefaultBlobRequestOptions.ServerTimeout); + }); containerBuilder.Register(c => new Catalog2RegistrationCommand( c.Resolve(), @@ -57,7 +68,12 @@ private static void RegisterCursorStorage(ContainerBuilder containerBuilder) { var options = c.Resolve>(); - // workaround for https://github.com/Azure/azure-sdk-for-net/issues/44373 + if (options.Value.StorageUseManagedIdentity && !options.Value.HasSasToken) + { + var credential = new ManagedIdentityCredential(options.Value.StorageManagedIdentityClientId); + + return new BlobServiceClientFactory(new Uri(options.Value.StorageServiceUrl), credential); + } var connectionString = options.Value.StorageConnectionString.Replace("SharedAccessSignature=?", "SharedAccessSignature="); return new BlobServiceClientFactory(connectionString); diff --git a/src/NuGet.Jobs.Catalog2Registration/Job.cs b/src/NuGet.Jobs.Catalog2Registration/Job.cs index 7084f227df..d340d261b4 100644 --- a/src/NuGet.Jobs.Catalog2Registration/Job.cs +++ b/src/NuGet.Jobs.Catalog2Registration/Job.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Net; @@ -7,6 +7,7 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using NuGet.Jobs.Catalog2Registration; +using NuGet.Services.Configuration; using NuGet.Services.V3; namespace NuGet.Jobs @@ -33,6 +34,20 @@ protected override void ConfigureJobServices(IServiceCollection services, IConfi services.AddCatalog2Registration(GlobalTelemetryDimensions, configurationRoot); services.Configure(configurationRoot.GetSection(ConfigurationSectionName)); + services.Configure((config) => + { + config.StorageUseManagedIdentity = configurationRoot.GetValue(Constants.StorageUseManagedIdentityPropertyName, false); + config.StorageManagedIdentityClientId = configurationRoot.GetValue(Constants.ManagedIdentityClientIdKey, string.Empty); + + if(config.StorageConnectionString.Contains("SharedAccessSignature")) + { + config.HasSasToken = true; + } + else + { + config.StorageServiceUrl = config.StorageConnectionString.Replace("BlobEndpoint=", ""); + } + }); services.Configure(configurationRoot.GetSection(ConfigurationSectionName)); } } diff --git a/src/NuGet.Services.Storage/BlobServiceClientAuthType.cs b/src/NuGet.Services.Storage/BlobServiceClientAuthType.cs new file mode 100644 index 0000000000..f4cb3591e2 --- /dev/null +++ b/src/NuGet.Services.Storage/BlobServiceClientAuthType.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace NuGet.Services.Storage +{ + enum BlobServiceClientAuthType + { + Anonymous, + TokenCredential, + ConnectionString + } +} diff --git a/src/NuGet.Services.Storage/BlobServiceClientFactory.cs b/src/NuGet.Services.Storage/BlobServiceClientFactory.cs index 1452f6c5c1..1e8706e031 100644 --- a/src/NuGet.Services.Storage/BlobServiceClientFactory.cs +++ b/src/NuGet.Services.Storage/BlobServiceClientFactory.cs @@ -14,37 +14,57 @@ namespace NuGet.Services.Storage { public class BlobServiceClientFactory : IBlobServiceClientFactory { - private bool _useTokenCredential = false; + private readonly BlobServiceClientAuthType? _authType; private TokenCredential _credential; private string _connectionString = ""; public virtual Uri Uri { get; set; } - public BlobServiceClientFactory() { } + protected BlobServiceClientFactory() { } public BlobServiceClientFactory(string connectionString) { + if (string.IsNullOrEmpty(connectionString)) + { + throw new ArgumentNullException(nameof(connectionString)); + } + _connectionString = connectionString; this.Uri = new BlobServiceClient(connectionString).Uri; + _authType = BlobServiceClientAuthType.ConnectionString; } - public BlobServiceClientFactory(Uri serviceUri, TokenCredential credential) + public BlobServiceClientFactory(Uri serviceUri, TokenCredential credential = null) { - this.Uri = serviceUri; - _credential = credential ?? throw new ArgumentNullException(nameof(credential)); - _useTokenCredential = true; - } + this.Uri = serviceUri ?? throw new ArgumentNullException(nameof(serviceUri)); - public virtual BlobServiceClient GetBlobServiceClient(BlobClientOptions blobClientOptions = null) - { - if (_useTokenCredential) + if (credential != null) { - return new BlobServiceClient(this.Uri, _credential, blobClientOptions); + _credential = credential; + _authType = BlobServiceClientAuthType.TokenCredential; } else { - return new BlobServiceClient(_connectionString, blobClientOptions); + _authType = BlobServiceClientAuthType.Anonymous; } } + + public virtual BlobServiceClient GetBlobServiceClient(BlobClientOptions blobClientOptions = null) + { + if (_authType.HasValue) + { + switch (_authType) + { + case BlobServiceClientAuthType.TokenCredential: + return new BlobServiceClient(this.Uri, _credential, blobClientOptions); + case BlobServiceClientAuthType.ConnectionString: + return new BlobServiceClient(_connectionString, blobClientOptions); + case BlobServiceClientAuthType.Anonymous: + return new BlobServiceClient(this.Uri, blobClientOptions); + } + } + + throw new Exception("No authentication type configured"); + } } } From 5be598ae146f9498211b6d62bdc95929447bd9ff Mon Sep 17 00:00:00 2001 From: Joel Verhagen Date: Mon, 18 Nov 2024 17:27:40 -0500 Subject: [PATCH 30/33] Do not audit values for removed/revoked API keys (#10272) --- .../Auditing/CredentialAuditRecord.cs | 23 ++++-------- ...FailedAuthenticatedOperationAuditRecord.cs | 4 +-- .../Auditing/UserAuditRecord.cs | 10 +++--- .../Auditing/CredentialAuditRecordTests.cs | 36 ++++++++----------- .../Auditing/UserAuditRecordTests.cs | 6 ++-- .../AuthenticationServiceFacts.cs | 6 ++-- 6 files changed, 32 insertions(+), 53 deletions(-) diff --git a/src/NuGetGallery.Core/Auditing/CredentialAuditRecord.cs b/src/NuGetGallery.Core/Auditing/CredentialAuditRecord.cs index f2f5f36027..0f9a1e1fd5 100644 --- a/src/NuGetGallery.Core/Auditing/CredentialAuditRecord.cs +++ b/src/NuGetGallery.Core/Auditing/CredentialAuditRecord.cs @@ -21,7 +21,7 @@ public class CredentialAuditRecord public string TenantId { get; } public string RevocationSource { get; } - public CredentialAuditRecord(Credential credential, bool removedOrRevoked) + public CredentialAuditRecord(Credential credential) { if (credential == null) { @@ -34,23 +34,12 @@ public CredentialAuditRecord(Credential credential, bool removedOrRevoked) Identity = credential.Identity; TenantId = credential.TenantId; - // Track the value for credentials that are external (object id) or definitely revocable (API Key, etc.) and have been removed + // Track the value for credentials that are external (object id) + // Do not track the credential valid for API keys or passwords, even if they are revoked. if (credential.IsExternal()) { Value = credential.Value; } - else if (removedOrRevoked) - { - if (Type == null) - { - throw new ArgumentNullException(nameof(credential.Type)); - } - - if (!credential.IsPassword()) - { - Value = credential.Value; - } - } Created = credential.Created; Expires = credential.Expires; @@ -65,10 +54,10 @@ public CredentialAuditRecord(Credential credential, bool removedOrRevoked) } } - public CredentialAuditRecord(Credential credential, bool removedOrRevoked, string revocationSource) - : this(credential, removedOrRevoked) + public CredentialAuditRecord(Credential credential, string revocationSource) + : this(credential) { RevocationSource = revocationSource; } } -} \ No newline at end of file +} diff --git a/src/NuGetGallery.Core/Auditing/FailedAuthenticatedOperationAuditRecord.cs b/src/NuGetGallery.Core/Auditing/FailedAuthenticatedOperationAuditRecord.cs index 5dca831102..d4c3413eae 100644 --- a/src/NuGetGallery.Core/Auditing/FailedAuthenticatedOperationAuditRecord.cs +++ b/src/NuGetGallery.Core/Auditing/FailedAuthenticatedOperationAuditRecord.cs @@ -31,7 +31,7 @@ public FailedAuthenticatedOperationAuditRecord( if (attemptedCredential != null) { - AttemptedCredential = new CredentialAuditRecord(attemptedCredential, removedOrRevoked: false); + AttemptedCredential = new CredentialAuditRecord(attemptedCredential); } } @@ -40,4 +40,4 @@ public override string GetPath() return Path; // store in /failedauthenticatedoperation/all } } -} \ No newline at end of file +} diff --git a/src/NuGetGallery.Core/Auditing/UserAuditRecord.cs b/src/NuGetGallery.Core/Auditing/UserAuditRecord.cs index 1c05d96431..6b679fb534 100644 --- a/src/NuGetGallery.Core/Auditing/UserAuditRecord.cs +++ b/src/NuGetGallery.Core/Auditing/UserAuditRecord.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -56,7 +56,7 @@ public UserAuditRecord(User user, AuditedUserAction action) Roles = user.Roles.Select(r => r.Name).ToArray(); Credentials = user.Credentials.Where(CredentialTypes.IsSupportedCredential) - .Select(c => new CredentialAuditRecord(c, removedOrRevoked: false)).ToArray(); + .Select(c => new CredentialAuditRecord(c)).ToArray(); AffectedCredential = Array.Empty(); AffectedPolicies = Array.Empty(); @@ -70,8 +70,7 @@ public UserAuditRecord(User user, AuditedUserAction action, Credential affected, public UserAuditRecord(User user, AuditedUserAction action, IEnumerable affected, string revocationSource) : this(user, action) { - AffectedCredential = affected.Select(c => new CredentialAuditRecord(c, - removedOrRevoked: action == AuditedUserAction.RemoveCredential || action == AuditedUserAction.RevokeCredential, revocationSource: revocationSource)).ToArray(); + AffectedCredential = affected.Select(c => new CredentialAuditRecord(c, revocationSource: revocationSource)).ToArray(); } public UserAuditRecord(User user, AuditedUserAction action, Credential affected) @@ -82,8 +81,7 @@ public UserAuditRecord(User user, AuditedUserAction action, Credential affected) public UserAuditRecord(User user, AuditedUserAction action, IEnumerable affected) : this(user, action) { - AffectedCredential = affected.Select(c => new CredentialAuditRecord(c, - removedOrRevoked: action == AuditedUserAction.RemoveCredential || action == AuditedUserAction.RevokeCredential)).ToArray(); + AffectedCredential = affected.Select(c => new CredentialAuditRecord(c)).ToArray(); } public UserAuditRecord(User user, AuditedUserAction action, string affectedEmailAddress) diff --git a/tests/NuGetGallery.Core.Facts/Auditing/CredentialAuditRecordTests.cs b/tests/NuGetGallery.Core.Facts/Auditing/CredentialAuditRecordTests.cs index 3d21a1cbd9..e1828b7e03 100644 --- a/tests/NuGetGallery.Core.Facts/Auditing/CredentialAuditRecordTests.cs +++ b/tests/NuGetGallery.Core.Facts/Auditing/CredentialAuditRecordTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -13,31 +13,23 @@ public class CredentialAuditRecordTests [Fact] public void Constructor_ThrowsForNullCredential() { - Assert.Throws(() => new CredentialAuditRecord(credential: null, removedOrRevoked: true)); + Assert.Throws(() => new CredentialAuditRecord(credential: null)); } [Fact] - public void Constructor_ThrowsForRemovalWithNullType() - { - var credential = new Credential(); - - Assert.Throws(() => new CredentialAuditRecord(credential, removedOrRevoked: true)); - } - - [Fact] - public void Constructor_RemovalOfNonPasswordSetsValue() + public void Constructor_RemovalOfNonPasswordDoesNotSetValue() { var credential = new Credential(type: "a", value: "b"); - var record = new CredentialAuditRecord(credential, removedOrRevoked: true); + var record = new CredentialAuditRecord(credential); - Assert.Equal("b", record.Value); + Assert.Null(record.Value); } [Fact] public void Constructor_RemovalOfPasswordDoesNotSetValue() { var credential = new Credential(type: CredentialTypes.Password.V3, value: "a"); - var record = new CredentialAuditRecord(credential, removedOrRevoked: true); + var record = new CredentialAuditRecord(credential); Assert.Null(record.Value); } @@ -46,7 +38,7 @@ public void Constructor_RemovalOfPasswordDoesNotSetValue() public void Constructor_NonRemovalOfNonPasswordDoesNotSetsValue() { var credential = new Credential(type: "a", value: "b"); - var record = new CredentialAuditRecord(credential, removedOrRevoked: false); + var record = new CredentialAuditRecord(credential); Assert.Null(record.Value); } @@ -57,7 +49,7 @@ public void Constructor_NonRemovalOfNonPasswordDoesNotSetsValue() public void Constructor_ExternalCredentialSetsValue(string externalType) { var credential = new Credential(type: externalType, value: "b"); - var record = new CredentialAuditRecord(credential, removedOrRevoked: false); + var record = new CredentialAuditRecord(credential); Assert.Equal("b", record.Value); } @@ -66,7 +58,7 @@ public void Constructor_ExternalCredentialSetsValue(string externalType) public void Constructor_NonRemovalOfPasswordDoesNotSetValue() { var credential = new Credential(type: CredentialTypes.Password.V3, value: "a"); - var record = new CredentialAuditRecord(credential, removedOrRevoked: false); + var record = new CredentialAuditRecord(credential); Assert.Null(record.Value); } @@ -90,7 +82,7 @@ public void Constructor_SetsProperties() Type = "e", Value = "f" }; - var record = new CredentialAuditRecord(credential, removedOrRevoked: true); + var record = new CredentialAuditRecord(credential); Assert.Equal(created, record.Created); Assert.Equal("a", record.Description); @@ -104,7 +96,7 @@ public void Constructor_SetsProperties() Assert.Equal("c", scope.Subject); Assert.Equal("d", scope.AllowedAction); Assert.Equal("e", record.Type); - Assert.Equal("f", record.Value); + Assert.Null(record.Value); } [Fact] @@ -112,11 +104,11 @@ public void Constructor_WithRevocationSource_Properties() { var testRevocationSource = "TestRevocationSource"; var credential = new Credential(type: "a", value: "b"); - var record = new CredentialAuditRecord(credential, removedOrRevoked: true, revocationSource: testRevocationSource); + var record = new CredentialAuditRecord(credential, revocationSource: testRevocationSource); Assert.Equal(testRevocationSource, record.RevocationSource); Assert.Equal("a", record.Type); - Assert.Equal("b", record.Value); + Assert.Null(record.Value); } } -} \ No newline at end of file +} diff --git a/tests/NuGetGallery.Core.Facts/Auditing/UserAuditRecordTests.cs b/tests/NuGetGallery.Core.Facts/Auditing/UserAuditRecordTests.cs index 6760589e61..ca42318626 100644 --- a/tests/NuGetGallery.Core.Facts/Auditing/UserAuditRecordTests.cs +++ b/tests/NuGetGallery.Core.Facts/Auditing/UserAuditRecordTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -108,7 +108,7 @@ public void Constructor_WithRevocationSource_SetsProperties() Assert.Single(record.AffectedCredential); Assert.Equal(testRevocationSource, record.AffectedCredential[0].RevocationSource); Assert.Equal("b", record.AffectedCredential[0].Type); - Assert.Equal("c", record.AffectedCredential[0].Value); + Assert.Null(record.AffectedCredential[0].Value); } [Fact] @@ -127,4 +127,4 @@ public void GetPath_ReturnsLowerCasedUserName() Assert.Equal("a", actualPath); } } -} \ No newline at end of file +} diff --git a/tests/NuGetGallery.Facts/Authentication/AuthenticationServiceFacts.cs b/tests/NuGetGallery.Facts/Authentication/AuthenticationServiceFacts.cs index ee1a26eb7e..7a6777e2a2 100644 --- a/tests/NuGetGallery.Facts/Authentication/AuthenticationServiceFacts.cs +++ b/tests/NuGetGallery.Facts/Authentication/AuthenticationServiceFacts.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -1605,7 +1605,7 @@ public async Task WritesAuditRecordRemovingTheOldCredential() ar.AffectedCredential.Length == 1 && ar.AffectedCredential[0].Type == existingCred.Type && ar.AffectedCredential[0].Identity == existingCred.Identity && - ar.AffectedCredential[0].Value == existingCred.Value && + ar.AffectedCredential[0].Value is null && ar.AffectedCredential[0].Created == existingCred.Created && ar.AffectedCredential[0].Expires == existingCred.Expires)); } @@ -2633,4 +2633,4 @@ public static bool VerifyPasswordHash(string hash, string algorithm, string pass return canAuthenticate && !confidenceCheck; } } -} \ No newline at end of file +} From db81abe3006a35e1a1582e19bf6bf7ddd85bb414 Mon Sep 17 00:00:00 2001 From: Joel Verhagen Date: Tue, 19 Nov 2024 17:59:26 -0500 Subject: [PATCH 31/33] Copy list of scopes before removing them (#10275) A collection modified exception is thrown since internally the DbContext modifies the collection we are enumerating. --- .../Authentication/AuthenticationService.cs | 6 +-- .../AuthenticationServiceFacts.cs | 39 +++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/src/NuGetGallery.Services/Authentication/AuthenticationService.cs b/src/NuGetGallery.Services/Authentication/AuthenticationService.cs index 2065a465d6..54529a8fab 100644 --- a/src/NuGetGallery.Services/Authentication/AuthenticationService.cs +++ b/src/NuGetGallery.Services/Authentication/AuthenticationService.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -714,7 +714,7 @@ public virtual async Task RemoveCredential(User user, Credential cred, bool comm public virtual async Task EditCredentialScopes(User user, Credential cred, ICollection newScopes) { - foreach (var oldScope in cred.Scopes) + foreach (var oldScope in cred.Scopes.ToList()) { Entities.Scopes.Remove(oldScope); } @@ -1053,4 +1053,4 @@ private async Task MigrateCredentials(User user, List creds, string } } } -} \ No newline at end of file +} diff --git a/tests/NuGetGallery.Facts/Authentication/AuthenticationServiceFacts.cs b/tests/NuGetGallery.Facts/Authentication/AuthenticationServiceFacts.cs index 7a6777e2a2..01466c530a 100644 --- a/tests/NuGetGallery.Facts/Authentication/AuthenticationServiceFacts.cs +++ b/tests/NuGetGallery.Facts/Authentication/AuthenticationServiceFacts.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Data.Entity; using System.Globalization; using System.Linq; using System.Security.Claims; @@ -2382,6 +2383,44 @@ public async Task SavesChangesInTheDataStore() authService.Entities.VerifyCommitChanges(); } + /// + /// Needed to avoid collection modified exception caused by the entity context. + /// + [Fact] + public async Task CopiesScopeCollectionForDeletion() + { + // Arrange + var credentialBuilder = new CredentialBuilder(); + + var credScopes = + Enumerable.Range(0, 5) + .Select( + i => new Scope { AllowedAction = NuGetScopes.PackagePush, Key = i, Subject = "package" + i }).ToList(); + + var mockScopes = new Mock>(); + var dbContext = GetMock(); + dbContext.Setup(x => x.Scopes).Returns(mockScopes.Object); + mockScopes.Setup(x => x.Remove(It.IsAny())).Callback(x => credScopes.Remove(x)); + + var fakes = Get(); + var cred = credentialBuilder.CreateApiKey(null, out string plaintextApiKey); + var user = fakes.CreateUser("test", credentialBuilder.CreatePasswordCredential(Fakes.Password), cred); + var authService = Get(); + + var newScopes = + Enumerable.Range(1, 2) + .Select( + i => new Scope { AllowedAction = NuGetScopes.PackageUnlist, Key = i * 10, Subject = "otherpackage" + i }).ToList(); + + cred.Scopes = credScopes; + + // Act + await authService.EditCredentialScopes(user, cred, newScopes); + + // Act + Assert.Empty(credScopes); + } + [Fact] public async Task WritesAuditRecordForTheEditedCredential() { From d84cb2c1602a36a48326c200a3d2cc356968eb0f Mon Sep 17 00:00:00 2001 From: Joel Verhagen Date: Thu, 21 Nov 2024 19:06:23 -0500 Subject: [PATCH 32/33] [OIDC] Add method to create a short-lived API key (minimal) (#10267) This is a stub implementation until we have finalized the new API key design. --- .../Authentication/CredentialBuilder.cs | 23 +++++ .../Authentication/ICredentialBuilder.cs | 2 + .../Authentication/CredentialBuilderFacts.cs | 83 +++++++++++++++++++ 3 files changed, 108 insertions(+) create mode 100644 tests/NuGetGallery.Facts/Authentication/CredentialBuilderFacts.cs diff --git a/src/NuGetGallery.Services/Authentication/CredentialBuilder.cs b/src/NuGetGallery.Services/Authentication/CredentialBuilder.cs index 81c42780e9..ea7f96d542 100644 --- a/src/NuGetGallery.Services/Authentication/CredentialBuilder.cs +++ b/src/NuGetGallery.Services/Authentication/CredentialBuilder.cs @@ -6,6 +6,7 @@ using System.Linq; using NuGet.Services.Entities; using NuGetGallery.Authentication; +using NuGetGallery.Services.Authentication; namespace NuGetGallery.Infrastructure.Authentication { @@ -28,6 +29,28 @@ public Credential CreatePasswordCredential(string plaintextPassword) V3Hasher.GenerateHash(plaintextPassword)); } + public Credential CreateShortLivedApiKey(TimeSpan expiration, FederatedCredentialPolicy policy, out string plaintextApiKey) + { + if (policy.PackageOwner is null) + { + throw new ArgumentException($"The {nameof(policy.PackageOwner)} property on the policy must not be null."); + } + + if (expiration <= TimeSpan.Zero || expiration > TimeSpan.FromHours(1)) + { + throw new ArgumentOutOfRangeException(nameof(expiration)); + } + + // TODO: introduce a new API key type for short-lived API keys + // Tracking: https://github.com/NuGet/NuGetGallery/issues/10212 + var credential = CreateApiKey(expiration, out plaintextApiKey); + + credential.Description = "Short-lived API key generated via a federated credential"; + credential.Scopes = [new Scope(policy.PackageOwner, NuGetPackagePattern.AllInclusivePattern, NuGetScopes.All)]; + + return credential; + } + public Credential CreateApiKey(TimeSpan? expiration, out string plaintextApiKey) { var apiKey = ApiKeyV4.Create(); diff --git a/src/NuGetGallery.Services/Authentication/ICredentialBuilder.cs b/src/NuGetGallery.Services/Authentication/ICredentialBuilder.cs index 9da4eb84cd..9c5e9dea1a 100644 --- a/src/NuGetGallery.Services/Authentication/ICredentialBuilder.cs +++ b/src/NuGetGallery.Services/Authentication/ICredentialBuilder.cs @@ -20,5 +20,7 @@ public interface ICredentialBuilder IList BuildScopes(User scopeOwner, string[] scopes, string[] subjects); bool VerifyScopes(User currentUser, IEnumerable scopes); + + Credential CreateShortLivedApiKey(TimeSpan expiration, FederatedCredentialPolicy policy, out string plaintextApiKey); } } diff --git a/tests/NuGetGallery.Facts/Authentication/CredentialBuilderFacts.cs b/tests/NuGetGallery.Facts/Authentication/CredentialBuilderFacts.cs new file mode 100644 index 0000000000..24f273a6cb --- /dev/null +++ b/tests/NuGetGallery.Facts/Authentication/CredentialBuilderFacts.cs @@ -0,0 +1,83 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using NuGet.Services.Entities; +using NuGetGallery.Authentication; +using Xunit; + +namespace NuGetGallery.Infrastructure.Authentication +{ + public class CredentialBuilderFacts + { + public class TheCreateShortLivedApiKeyMethod : CredentialBuilderFacts + { + [Fact] + public void CreatesShortLivedApiKey() + { + // Act + var credential = Target.CreateShortLivedApiKey(Expiration, Policy, out var plaintextApiKey); + + // Assert + Assert.Null(credential.User); + Assert.Equal(default, credential.UserKey); + Assert.StartsWith("oy2", plaintextApiKey, StringComparison.Ordinal); + Assert.Equal(CredentialTypes.ApiKey.V4, credential.Type); + Assert.Equal("Short-lived API key generated via a federated credential", credential.Description); + Assert.Equal(Expiration.Ticks, credential.ExpirationTicks); + Assert.Null(credential.User); + + var scope = Assert.Single(credential.Scopes); + Assert.Equal(NuGetScopes.All, scope.AllowedAction); + Assert.Equal(NuGetPackagePattern.AllInclusivePattern, scope.Subject); + Assert.Same(Policy.PackageOwner, scope.Owner); + } + + [Fact] + public void RejectsMissingPackageOwner() + { + // Arrange + Policy.PackageOwner = null; + + // Act + Assert.Throws(() => Target.CreateShortLivedApiKey(Expiration, Policy, out var plaintextApiKey)); + } + + [Theory] + [InlineData(-1)] + [InlineData(0)] + [InlineData(61)] + public void RejectsOutOfRangeExpiration(int expirationMinutes) + { + // Arrange + Expiration = TimeSpan.FromMinutes(expirationMinutes); + + // Act + Assert.Throws(() => Target.CreateShortLivedApiKey(Expiration, Policy, out var plaintextApiKey)); + } + + public FederatedCredentialPolicy Policy { get; } + + public TheCreateShortLivedApiKeyMethod() + { + Policy = new FederatedCredentialPolicy + { + Key = 23, + PackageOwner = new User { Key = 42 }, + CreatedBy = new User { Key = 43 }, + }; + } + } + + public TimeSpan Expiration { get; set; } + + public CredentialBuilder Target { get; } + + public CredentialBuilderFacts() + { + Expiration = TimeSpan.FromMinutes(13); + + Target = new CredentialBuilder(); + } + } +} From 45b5070973cea66b04c35bc0fee938feae6b770f Mon Sep 17 00:00:00 2001 From: Joel Verhagen Date: Thu, 21 Nov 2024 21:01:18 -0500 Subject: [PATCH 33/33] [OIDC] Add repository for federated credential DB entities (EF wrapper) (#10268) --- .../FederatedCredentialRepository.cs | 84 +++++++++ .../App_Start/DefaultDependenciesModule.cs | 5 + .../FederatedCredentialRepositoryFacts.cs | 159 ++++++++++++++++++ 3 files changed, 248 insertions(+) create mode 100644 src/NuGetGallery.Services/Authentication/Federated/FederatedCredentialRepository.cs create mode 100644 tests/NuGetGallery.Facts/Authentication/Federated/FederatedCredentialRepositoryFacts.cs diff --git a/src/NuGetGallery.Services/Authentication/Federated/FederatedCredentialRepository.cs b/src/NuGetGallery.Services/Authentication/Federated/FederatedCredentialRepository.cs new file mode 100644 index 0000000000..45e9dfd403 --- /dev/null +++ b/src/NuGetGallery.Services/Authentication/Federated/FederatedCredentialRepository.cs @@ -0,0 +1,84 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using System.Linq; +using System.Data; +using System.Collections.Generic; +using System.Data.Entity; +using NuGet.Services.Entities; + +#nullable enable + +namespace NuGetGallery.Services.Authentication +{ + public interface IFederatedCredentialRepository + { + Task AddPolicyAsync(FederatedCredentialPolicy policy, bool saveChanges); + Task SaveFederatedCredentialAsync(FederatedCredential federatedCredential, bool saveChanges); + IReadOnlyList GetPoliciesCreatedByUser(int userKey); + FederatedCredentialPolicy? GetPolicyByKey(int policyKey); + Task DeletePolicyAsync(FederatedCredentialPolicy policy, bool saveChanges); + } + + public class FederatedCredentialRepository : IFederatedCredentialRepository + { + private readonly IEntityRepository _policyRepository; + private readonly IEntityRepository _federatedCredentialRepository; + + public FederatedCredentialRepository( + IEntityRepository policyRepository, + IEntityRepository federatedCredentialRepository) + { + _policyRepository = policyRepository; + _federatedCredentialRepository = federatedCredentialRepository; + } + + public IReadOnlyList GetPoliciesCreatedByUser(int userKey) + { + return _policyRepository + .GetAll() + .Where(p => p.CreatedByUserKey == userKey) + .ToList(); + } + + public FederatedCredentialPolicy? GetPolicyByKey(int policyKey) + { + return _policyRepository + .GetAll() + .Where(p => p.Key == policyKey) + .Include(p => p.CreatedBy) + .FirstOrDefault(); + } + + public async Task SaveFederatedCredentialAsync(FederatedCredential federatedCredential, bool saveChanges) + { + _federatedCredentialRepository.InsertOnCommit(federatedCredential); + + if (saveChanges) + { + await _federatedCredentialRepository.CommitChangesAsync(); + } + } + + public async Task AddPolicyAsync(FederatedCredentialPolicy policy, bool saveChanges) + { + _policyRepository.InsertOnCommit(policy); + + if (saveChanges) + { + await _policyRepository.CommitChangesAsync(); + } + } + + public async Task DeletePolicyAsync(FederatedCredentialPolicy policy, bool saveChanges) + { + _policyRepository.DeleteOnCommit(policy); + + if (saveChanges) + { + await _policyRepository.CommitChangesAsync(); + } + } + } +} diff --git a/src/NuGetGallery/App_Start/DefaultDependenciesModule.cs b/src/NuGetGallery/App_Start/DefaultDependenciesModule.cs index 96641b016a..6b138f24c7 100644 --- a/src/NuGetGallery/App_Start/DefaultDependenciesModule.cs +++ b/src/NuGetGallery/App_Start/DefaultDependenciesModule.cs @@ -549,6 +549,11 @@ protected override void Load(ContainerBuilder builder) private static void ConfigureFederatedCredentials(ContainerBuilder builder, ConfigurationService configuration) { + builder + .RegisterType() + .As() + .InstancePerLifetimeScope(); + builder .Register(c => configuration.FederatedCredential) .As(); diff --git a/tests/NuGetGallery.Facts/Authentication/Federated/FederatedCredentialRepositoryFacts.cs b/tests/NuGetGallery.Facts/Authentication/Federated/FederatedCredentialRepositoryFacts.cs new file mode 100644 index 0000000000..426c44d219 --- /dev/null +++ b/tests/NuGetGallery.Facts/Authentication/Federated/FederatedCredentialRepositoryFacts.cs @@ -0,0 +1,159 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Moq; +using NuGet.Services.Entities; +using Xunit; + +#nullable enable + +namespace NuGetGallery.Services.Authentication +{ + public class FederatedCredentialRepositoryFacts + { + public class TheGetPoliciesCreatedByUserMethod : FederatedCredentialRepositoryFacts + { + [Fact] + public void FiltersByUserKey() + { + // Act + var policies = Target.GetPoliciesCreatedByUser(userKey: 4); + + // Assert + Assert.Equal(2, policies.Count); + Assert.Equal(1, policies[0].Key); + Assert.Equal(2, policies[1].Key); + } + } + + public class TheGetPolicyByKeyMethod : FederatedCredentialRepositoryFacts + { + [Fact] + public void ReturnsPolicyByKey() + { + // Act + var policy = Target.GetPolicyByKey(2); + + // Assert + Assert.Equal(2, policy!.Key); + } + + [Fact] + public void ReturnsNullIfDoesNotExist() + { + // Act + var policy = Target.GetPolicyByKey(23); + + // Assert + Assert.Null(policy); + } + } + + public class TheSaveFederatedCredentialAsyncMethod : FederatedCredentialRepositoryFacts + { + [Fact] + public async Task InsertsCredential() + { + // Arrange + var credential = new FederatedCredential(); + + // Act + await Target.SaveFederatedCredentialAsync(credential, saveChanges: false); + + // Assert + CredentialRepository.Verify(x => x.InsertOnCommit(credential), Times.Once); + CredentialRepository.Verify(x => x.CommitChangesAsync(), Times.Never); + } + + [Fact] + public async Task CommitsChangesIfRequested() + { + // Arrange + var credential = new FederatedCredential(); + + // Act + await Target.SaveFederatedCredentialAsync(credential, saveChanges: true); + + // Assert + CredentialRepository.Verify(x => x.InsertOnCommit(credential), Times.Once); + CredentialRepository.Verify(x => x.CommitChangesAsync(), Times.Once); + } + } + + public class TheAddPolicyAsyncMethod : FederatedCredentialRepositoryFacts + { + [Fact] + public async Task InsertsPolicy() + { + // Act + await Target.AddPolicyAsync(Policies[0], saveChanges: false); + + // Assert + PolicyRepository.Verify(x => x.InsertOnCommit(Policies[0]), Times.Once); + PolicyRepository.Verify(x => x.CommitChangesAsync(), Times.Never); + } + + [Fact] + public async Task CommitsChangesIfRequested() + { + // Act + await Target.AddPolicyAsync(Policies[0], saveChanges: true); + + // Assert + PolicyRepository.Verify(x => x.InsertOnCommit(Policies[0]), Times.Once); + PolicyRepository.Verify(x => x.CommitChangesAsync(), Times.Once); + } + } + + public class TheDeletePolicyAsyncMethod : FederatedCredentialRepositoryFacts + { + [Fact] + public async Task DeletesPolicy() + { + // Act + await Target.DeletePolicyAsync(Policies[0], saveChanges: false); + + // Assert + PolicyRepository.Verify(x => x.DeleteOnCommit(Policies[0]), Times.Once); + PolicyRepository.Verify(x => x.CommitChangesAsync(), Times.Never); + } + + [Fact] + public async Task CommitsChangesIfRequested() + { + // Act + await Target.DeletePolicyAsync(Policies[0], saveChanges: true); + + // Assert + PolicyRepository.Verify(x => x.DeleteOnCommit(Policies[0]), Times.Once); + PolicyRepository.Verify(x => x.CommitChangesAsync(), Times.Once); + } + } + + public FederatedCredentialRepositoryFacts() + { + CredentialRepository = new Mock>(); + PolicyRepository = new Mock>(); + + Policies = new List + { + new() { Key = 1, CreatedByUserKey = 4 }, + new() { Key = 2, CreatedByUserKey = 4 }, + new() { Key = 3, CreatedByUserKey = 5 }, + }; + PolicyRepository.Setup(x => x.GetAll()).Returns(() => Policies.AsQueryable()); + + Target = new FederatedCredentialRepository( + PolicyRepository.Object, + CredentialRepository.Object); + } + + public Mock> CredentialRepository { get; } + public Mock> PolicyRepository { get; } + public List Policies { get; } + public FederatedCredentialRepository Target { get; } + } +}