Skip to content

Commit

Permalink
Addressed feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
drewgillies committed May 18, 2021
1 parent 2e4e873 commit 6f3cce4
Show file tree
Hide file tree
Showing 12 changed files with 121 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -370,5 +370,10 @@ public void TrackVerifyPackageKeyEvent(string packageId, string packageVersion,
{
throw new NotImplementedException();
}

public void TrackVulnerabilitiesCacheRefreshDuration(long duration)
{
throw new NotImplementedException();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,10 @@ public interface IPackageVulnerabilitiesManagementService
/// <param name="withdrawn">Whether or not the vulnerability was withdrawn.</param>
Task UpdateVulnerabilityAsync(PackageVulnerability vulnerability, bool withdrawn);

/// <summary>
/// Get a package's collection of vulnerable ranges.
/// </summary>
/// <param name="packageId">The package's Id</param>
/// <returns>The package's vulnerable ranges, connecting it to <see cref="PackageVulnerability" /> instances</returns>
IQueryable<VulnerablePackageVersionRange> GetVulnerableRangesById(string packageId);

/// <summary>
/// Get the full set of vulnerable package entities
/// </summary>
/// <returns>Vulnerable package version ranges</returns>
IQueryable<VulnerablePackageVersionRange> GetAllVulnerableRanges();

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ public async Task UpdateVulnerabilityAsync(PackageVulnerability vulnerability, b
}
}

public IQueryable<VulnerablePackageVersionRange> GetVulnerableRangesById(string packageId) =>
_entitiesContext.VulnerableRanges.Where(x => x.PackageId == packageId);

public IQueryable<VulnerablePackageVersionRange> GetAllVulnerableRanges() =>
_entitiesContext.Set<VulnerablePackageVersionRange>();

Expand Down
6 changes: 6 additions & 0 deletions src/NuGetGallery.Services/Telemetry/ITelemetryService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -404,5 +404,11 @@ void TrackABTestEvaluated(
bool isAuthenticated,
int testBucket,
int testPercentage);

/// <summary>
/// Track how long it takes to populate the vulnerabilities cache
/// </summary>
/// <param name="milliseconds">Refresh duration for vulnerabilities cache</param>
void TrackVulnerabilitiesCacheRefreshDuration(long milliseconds);
}
}
6 changes: 6 additions & 0 deletions src/NuGetGallery.Services/Telemetry/TelemetryService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public class Events
public const string ABTestEvaluated = "ABTestEvaluated";
public const string PackagePushDisconnect = "PackagePushDisconnect";
public const string SymbolPackagePushDisconnect = "SymbolPackagePushDisconnect";
public const string VulnerabilitiesCacheRefreshDuration = "VulnerabilitiesCacheRefreshDuration";
}

private readonly IDiagnosticsSource _diagnosticsSource;
Expand Down Expand Up @@ -1103,6 +1104,11 @@ public void TrackSymbolPackagePushDisconnectEvent()
TrackMetric(Events.SymbolPackagePushDisconnect, 1, p => { });
}

public void TrackVulnerabilitiesCacheRefreshDuration(long milliseconds)
{
TrackMetric(Events.VulnerabilitiesCacheRefreshDuration, milliseconds, properties => { });
}

/// <summary>
/// We use <see cref="ITelemetryClient.TrackMetric(string, double, IDictionary{string, string})"/> instead of
/// <see cref="ITelemetryClient.TrackEvent(string, IDictionary{string, string}, IDictionary{string, double})"/>
Expand Down
10 changes: 10 additions & 0 deletions src/NuGetGallery/App_Start/AppActivator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,16 @@ private static void BackgroundJobsPostStart(IAppConfiguration configuration)
HostingEnvironment.QueueBackgroundWorkItem(_ => cloudDownloadCountService.RefreshAsync());
jobs.Add(new CloudDownloadCountServiceRefreshJob(TimeSpan.FromMinutes(15), cloudDownloadCountService));
}

var packageVulnerabilitiesCacheService =
DependencyResolver.Current.GetService<IPackageVulnerabilitiesCacheService>() as
PackageVulnerabilitiesCacheService;
if (packageVulnerabilitiesCacheService != null)
{
// Perform initial refresh + schedule new refreshes every 30 minutes
HostingEnvironment.QueueBackgroundWorkItem(_ => packageVulnerabilitiesCacheService.RefreshCache());
jobs.Add(new PackageVulnerabilitiesCacheRefreshJob(TimeSpan.FromMinutes(30), packageVulnerabilitiesCacheService));
}
}

if (jobs.AnySafe())
Expand Down
4 changes: 3 additions & 1 deletion src/NuGetGallery/App_Start/DefaultDependenciesModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,9 @@ protected override void Load(ContainerBuilder builder)
.As<IPackageVulnerabilitiesManagementService>()
.InstancePerLifetimeScope();

builder.Register(c => new PackageVulnerabilitiesCacheService(c.Resolve<IPackageVulnerabilitiesManagementService>()))
builder.Register(c =>
new PackageVulnerabilitiesCacheService(c.Resolve<IPackageVulnerabilitiesManagementService>(),
c.Resolve<ITelemetryService>()))
.AsSelf()
.As<IPackageVulnerabilitiesCacheService>()
.SingleInstance();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Threading.Tasks;
using WebBackgrounder;

namespace NuGetGallery
{
public class PackageVulnerabilitiesCacheRefreshJob : Job
{
private readonly PackageVulnerabilitiesCacheService _packageVulnerabilitiesCacheService;

public PackageVulnerabilitiesCacheRefreshJob(TimeSpan interval, PackageVulnerabilitiesCacheService packageVulnerabilitiesCacheService)
: base("", interval)
{
_packageVulnerabilitiesCacheService = packageVulnerabilitiesCacheService;
}

public override Task Execute()
{
return new Task(() => _packageVulnerabilitiesCacheService.RefreshCache());
}
}
}
1 change: 1 addition & 0 deletions src/NuGetGallery/NuGetGallery.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@
<Compile Include="Infrastructure\ABTestEnrollmentState.cs" />
<Compile Include="Infrastructure\ABTestEnrollmentFactory.cs" />
<Compile Include="Infrastructure\CookieBasedABTestService.cs" />
<Compile Include="Infrastructure\Jobs\PackageVulnerabilitiesCacheRefreshJob.cs" />
<Compile Include="Infrastructure\RequestValidationExceptionFilter.cs" />
<Compile Include="Infrastructure\HttpStatusCodeWithHeadersResult.cs" />
<Compile Include="Infrastructure\IABTestEnrollmentFactory.cs" />
Expand Down
105 changes: 58 additions & 47 deletions src/NuGetGallery/Services/PackageVulnerabilitiesCacheService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,32 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Entity;
using System.Diagnostics;
using System.Linq;
using NuGet.Services.Entities;

namespace NuGetGallery
{
public class PackageVulnerabilitiesCacheService : IPackageVulnerabilitiesCacheService
{
private const int CachingLimitMinutes = 1440; // We could make this 1 day instead (same value) but this is easier for spot testing the cache
private readonly object Locker = new object();
private IDictionary<string,
(DateTime cachedAt, Dictionary<int, IReadOnlyList<PackageVulnerability>> vulnerabilitiesById)> vulnerabilitiesByIdCache
= new Dictionary<string, (DateTime, Dictionary<int, IReadOnlyList<PackageVulnerability>>)>();
Dictionary<int, IReadOnlyList<PackageVulnerability>>> _vulnerabilitiesByIdCache
= new ConcurrentDictionary<string, Dictionary<int, IReadOnlyList<PackageVulnerability>>>();
private readonly object _refreshLock = new object();
private bool _isRefreshing;

private readonly IPackageVulnerabilitiesManagementService _packageVulnerabilitiesManagementService;
public PackageVulnerabilitiesCacheService(IPackageVulnerabilitiesManagementService packageVulnerabilitiesManagementService)
private readonly ITelemetryService _telemetryService;

public PackageVulnerabilitiesCacheService(
IPackageVulnerabilitiesManagementService packageVulnerabilitiesManagementService,
ITelemetryService telemetryService)
{
_packageVulnerabilitiesManagementService = packageVulnerabilitiesManagementService;
Initialize();
_telemetryService = telemetryService;
}

public IReadOnlyDictionary<int, IReadOnlyList<PackageVulnerability>> GetVulnerabilitiesById(string id)
Expand All @@ -32,52 +37,58 @@ public IReadOnlyDictionary<int, IReadOnlyList<PackageVulnerability>> GetVulnerab
throw new ArgumentException("Must have a value.", nameof(id));
}

if (ShouldCachedValueBeUpdated(id))
if (_vulnerabilitiesByIdCache.TryGetValue(id, out var result))
{
lock (Locker)
{
if (ShouldCachedValueBeUpdated(id))
{
var packageKeyAndVulnerability = _packageVulnerabilitiesManagementService
.GetVulnerableRangesById(id)
.Include(x => x.Vulnerability)
.SelectMany(x => x.Packages.Select(p => new {PackageKey = p.Key, x.Vulnerability}))
.GroupBy(pv => pv.PackageKey, pv => pv.Vulnerability)
.ToDictionary(pv => pv.Key,
pv => pv.ToList().AsReadOnly() as IReadOnlyList<PackageVulnerability>);

vulnerabilitiesByIdCache[id] = (cachedAt: DateTime.Now, vulnerabilitiesById: packageKeyAndVulnerability);
}
}
return result;
}

return vulnerabilitiesByIdCache[id].vulnerabilitiesById.Any()
? new ReadOnlyDictionary<int, IReadOnlyList<PackageVulnerability>>(vulnerabilitiesByIdCache[id].vulnerabilitiesById)
: null;
return null;
}

private void Initialize()
public void RefreshCache()
{
// We need to build a dictionary of dictionaries. Breaking it down:
// - this give us a list of all vulnerable package version ranges
vulnerabilitiesByIdCache = _packageVulnerabilitiesManagementService.GetAllVulnerableRanges()
.Include(x => x.Vulnerability)
// - from these we want a list in this format: (<id>, (<package key>, <vulnerability>))
// which will allow us to look up the dictionary by id, and return a dictionary of version -> vulnerability
.SelectMany(x => x.Packages.Select(p => new
{PackageId = x.PackageId ?? string.Empty, KeyVulnerability = new {PackageKey = p.Key, x.Vulnerability}}))
.GroupBy(ikv => ikv.PackageId, ikv => ikv.KeyVulnerability)
// - build the outer dictionary, keyed by <id> - each inner dictionary is paired with a time of creation (for cache invalidation)
.ToDictionary(ikv => ikv.Key,
ikv => (cachedAt: DateTime.Now,
vulnerabilitiesById: ikv.GroupBy(kv => kv.PackageKey, kv => kv.Vulnerability)
// - build the inner dictionaries, all under the same <id>, each keyed by <package key>
.ToDictionary(kv => kv.Key,
kv => kv.ToList().AsReadOnly() as IReadOnlyList<PackageVulnerability>)));
}
bool shouldRefresh = false;
lock (_refreshLock)
{
if (!_isRefreshing)
{
_isRefreshing = true;
shouldRefresh = true;
}
}

private bool ShouldCachedValueBeUpdated(string id) => !vulnerabilitiesByIdCache.ContainsKey(id) ||
vulnerabilitiesByIdCache[id].cachedAt
.AddMinutes(CachingLimitMinutes) < DateTime.Now;
if (shouldRefresh)
{
try
{
var stopwatch = Stopwatch.StartNew();

// We need to build a dictionary of dictionaries. Breaking it down:
// - this give us a list of all vulnerable package version ranges
_vulnerabilitiesByIdCache = _packageVulnerabilitiesManagementService.GetAllVulnerableRanges()
.Include(x => x.Vulnerability)
// - from these we want a list in this format: (<id>, (<package key>, <vulnerability>))
// which will allow us to look up the dictionary by id, and return a dictionary of version -> vulnerability
.SelectMany(x => x.Packages.Select(p => new
{ PackageId = x.PackageId ?? string.Empty, KeyVulnerability = new { PackageKey = p.Key, x.Vulnerability } }))
.GroupBy(ikv => ikv.PackageId, ikv => ikv.KeyVulnerability)
// - build the outer dictionary, keyed by <id> - each inner dictionary is paired with a time of creation (for cache invalidation)
.ToDictionary(ikv => ikv.Key,
ikv =>
ikv.GroupBy(kv => kv.PackageKey, kv => kv.Vulnerability)
// - build the inner dictionaries, all under the same <id>, each keyed by <package key>
.ToDictionary(kv => kv.Key,
kv => kv.ToList().AsReadOnly() as IReadOnlyList<PackageVulnerability>));

stopwatch.Stop();

_telemetryService.TrackVulnerabilitiesCacheRefreshDuration(stopwatch.ElapsedMilliseconds);
}
finally
{
_isRefreshing = false;
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public Task UpdateVulnerabilityAsync(PackageVulnerability vulnerability, bool wi
return Task.CompletedTask;
}

public IQueryable<VulnerablePackageVersionRange> GetVulnerableRangesById(string packageId) => throw new NotImplementedException();
public IQueryable<VulnerablePackageVersionRange> GetAllVulnerableRanges() => throw new NotImplementedException();

private void VerifyVulnerabilityInDatabase(PackageVulnerability vulnerability, bool withdrawn)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,23 @@ namespace NuGetGallery.Services
public class PackageVulnerabilitiesCacheServiceFacts : TestContainer
{
[Fact]
public void InitializesVulnerabilitiesCache()
public void RefreshesVulnerabilitiesCache()
{
// Arrange
var vulnerableVersionRanges = GetVersionRanges();
var pvmService = new Mock<IPackageVulnerabilitiesManagementService>();
pvmService.Setup(stub => stub.GetAllVulnerableRanges()).Returns(vulnerableVersionRanges);
pvmService.Setup(stub => stub.GetVulnerableRangesById(It.IsAny<string>())).Verifiable();
var cacheService = new PackageVulnerabilitiesCacheService(pvmService.Object);
var telemetryService = new Mock<ITelemetryService>();
telemetryService.Setup(stub => stub.TrackVulnerabilitiesCacheRefreshDuration(It.IsAny<long>())).Verifiable();
var cacheService = new PackageVulnerabilitiesCacheService(pvmService.Object, telemetryService.Object);
cacheService.RefreshCache();

// Act
var vulnerabilitiesFoo = cacheService.GetVulnerabilitiesById("Foo");
var vulnerabilitiesBar = cacheService.GetVulnerabilitiesById("Bar");

// Assert
// This method should never be called (it's only called when cache can't provide, and these values are loaded into the cache on initialize)
pvmService.Verify(s => s.GetVulnerableRangesById(It.IsAny<string>()), Times.Never);
// Test cache contents
telemetryService.Verify(stub => stub.TrackVulnerabilitiesCacheRefreshDuration(It.IsAny<long>()), Times.Once);
Assert.Equal(4, vulnerabilitiesFoo.Count);
Assert.Equal(1, vulnerabilitiesFoo[0].Count);
Assert.Equal(1, vulnerabilitiesFoo[1].Count);
Expand Down

0 comments on commit 6f3cce4

Please sign in to comment.