Skip to content

Commit

Permalink
Implement a cache for ReferenceAssemblies instances
Browse files Browse the repository at this point in the history
Avoids the need to manually keep track of identical instances across
test suites.
  • Loading branch information
sharwell committed Aug 13, 2024
1 parent fd1b713 commit af952dc
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
using System;
using NuGet.Versioning;

#if !NETCOREAPP
using System.Collections.Generic;
#endif

namespace Microsoft.CodeAnalysis.Testing
{
/// <summary>
/// Represents the core identity of a NuGet package.
/// </summary>
/// <seealso cref="NuGet.Packaging.Core.PackageIdentity"/>
public sealed class PackageIdentity
public sealed class PackageIdentity : IEquatable<PackageIdentity?>
{
/// <summary>
/// Initializes a new instance of the <see cref="PackageIdentity"/> class with the specified name and version.
Expand Down Expand Up @@ -41,6 +45,28 @@ public PackageIdentity(string id, string version)
/// <seealso cref="NuGet.Packaging.Core.PackageIdentity.Version"/>
public string Version { get; }

public override int GetHashCode()
{
#if NETCOREAPP
return HashCode.Combine(Id, Version);
#else
var hashCode = -612338121;
hashCode = (hashCode * -1521134295) + EqualityComparer<string>.Default.GetHashCode(Id);
hashCode = (hashCode * -1521134295) + EqualityComparer<string>.Default.GetHashCode(Version);
return hashCode;
#endif
}

public override bool Equals(object? obj)
=> Equals(obj as PackageIdentity);

public bool Equals(PackageIdentity? other)
{
return other is not null
&& Id == other.Id
&& Version == other.Version;
}

internal NuGet.Packaging.Core.PackageIdentity ToNuGetIdentity()
{
return new NuGet.Packaging.Core.PackageIdentity(Id, NuGetVersion.Parse(Version));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ Microsoft.CodeAnalysis.Testing.Model.EvaluatedProjectState.Sources.get -> System
Microsoft.CodeAnalysis.Testing.Model.EvaluatedProjectState.WithAdditionalDiagnostics(System.Collections.Immutable.ImmutableArray<Microsoft.CodeAnalysis.Diagnostic> additionalDiagnostics) -> Microsoft.CodeAnalysis.Testing.Model.EvaluatedProjectState
Microsoft.CodeAnalysis.Testing.Model.EvaluatedProjectState.WithSources(System.Collections.Immutable.ImmutableArray<(string filename, Microsoft.CodeAnalysis.Text.SourceText content)> sources) -> Microsoft.CodeAnalysis.Testing.Model.EvaluatedProjectState
Microsoft.CodeAnalysis.Testing.PackageIdentity
Microsoft.CodeAnalysis.Testing.PackageIdentity.Equals(Microsoft.CodeAnalysis.Testing.PackageIdentity other) -> bool
Microsoft.CodeAnalysis.Testing.PackageIdentity.Id.get -> string
Microsoft.CodeAnalysis.Testing.PackageIdentity.PackageIdentity(string id, string version) -> void
Microsoft.CodeAnalysis.Testing.PackageIdentity.Version.get -> string
Expand Down Expand Up @@ -176,6 +177,7 @@ Microsoft.CodeAnalysis.Testing.ReferenceAssemblies.AddLanguageSpecificAssemblies
Microsoft.CodeAnalysis.Testing.ReferenceAssemblies.AddPackages(System.Collections.Immutable.ImmutableArray<Microsoft.CodeAnalysis.Testing.PackageIdentity> packages) -> Microsoft.CodeAnalysis.Testing.ReferenceAssemblies
Microsoft.CodeAnalysis.Testing.ReferenceAssemblies.Assemblies.get -> System.Collections.Immutable.ImmutableArray<string>
Microsoft.CodeAnalysis.Testing.ReferenceAssemblies.AssemblyIdentityComparer.get -> Microsoft.CodeAnalysis.AssemblyIdentityComparer
Microsoft.CodeAnalysis.Testing.ReferenceAssemblies.Equals(Microsoft.CodeAnalysis.Testing.ReferenceAssemblies other) -> bool
Microsoft.CodeAnalysis.Testing.ReferenceAssemblies.FacadeAssemblies.get -> System.Collections.Immutable.ImmutableArray<string>
Microsoft.CodeAnalysis.Testing.ReferenceAssemblies.LanguageSpecificAssemblies.get -> System.Collections.Immutable.ImmutableDictionary<string, System.Collections.Immutable.ImmutableArray<string>>
Microsoft.CodeAnalysis.Testing.ReferenceAssemblies.Net
Expand Down Expand Up @@ -246,6 +248,10 @@ abstract Microsoft.CodeAnalysis.Testing.CodeActionTest<TVerifier>.SyntaxKindType
override Microsoft.CodeAnalysis.Testing.DiagnosticResult.ToString() -> string
override Microsoft.CodeAnalysis.Testing.EmptyDiagnosticAnalyzer.Initialize(Microsoft.CodeAnalysis.Diagnostics.AnalysisContext context) -> void
override Microsoft.CodeAnalysis.Testing.EmptyDiagnosticAnalyzer.SupportedDiagnostics.get -> System.Collections.Immutable.ImmutableArray<Microsoft.CodeAnalysis.DiagnosticDescriptor>
override Microsoft.CodeAnalysis.Testing.PackageIdentity.Equals(object obj) -> bool
override Microsoft.CodeAnalysis.Testing.PackageIdentity.GetHashCode() -> int
override Microsoft.CodeAnalysis.Testing.ReferenceAssemblies.Equals(object obj) -> bool
override Microsoft.CodeAnalysis.Testing.ReferenceAssemblies.GetHashCode() -> int
static Microsoft.CodeAnalysis.Testing.AnalyzerTest<TVerifier>.Verify.get -> TVerifier
static Microsoft.CodeAnalysis.Testing.AnalyzerVerifier<TAnalyzer, TTest, TVerifier>.Diagnostic() -> Microsoft.CodeAnalysis.Testing.DiagnosticResult
static Microsoft.CodeAnalysis.Testing.AnalyzerVerifier<TAnalyzer, TTest, TVerifier>.Diagnostic(Microsoft.CodeAnalysis.DiagnosticDescriptor descriptor) -> Microsoft.CodeAnalysis.Testing.DiagnosticResult
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

namespace Microsoft.CodeAnalysis.Testing
{
public sealed partial class ReferenceAssemblies
public sealed partial class ReferenceAssemblies : IEquatable<ReferenceAssemblies?>
{
private const string ReferenceAssembliesPackageVersion = "1.0.2";

Expand All @@ -37,6 +37,8 @@ public sealed partial class ReferenceAssemblies
private static ImmutableHashSet<NuGet.Packaging.Core.PackageIdentity> s_emptyPackages
= ImmutableHashSet.Create<NuGet.Packaging.Core.PackageIdentity>(PackageIdentityComparer.Default);

private static ImmutableHashSet<ReferenceAssemblies> s_knownAssemblies = ImmutableHashSet<ReferenceAssemblies>.Empty;

private readonly Dictionary<string, ImmutableArray<MetadataReference>> _references
= new();

Expand Down Expand Up @@ -123,14 +125,83 @@ public static ReferenceAssemblies Default

public string? NuGetConfigFilePath { get; }

private static ReferenceAssemblies GetOrAddReferenceAssemblies(ReferenceAssemblies value)
{
if (s_knownAssemblies.TryGetValue(value, out var existingValue))
{
return existingValue;
}

if (ImmutableInterlocked.Update(
ref s_knownAssemblies,
static (knownAssemblies, value) => knownAssemblies.Add(value),
value))
{
return value;
}

if (!s_knownAssemblies.TryGetValue(value, out existingValue))
{
throw new InvalidOperationException();
}

return existingValue;
}

public override int GetHashCode()
{
#if NETCOREAPP
var hash = default(HashCode);
hash.Add(TargetFramework);
hash.Add(AssemblyIdentityComparer);
hash.Add(ReferenceAssemblyPackage);
hash.Add(ReferenceAssemblyPath);
hash.Add(Assemblies, ImmutableArrayEqualityComparer<string>.Instance);
hash.Add(FacadeAssemblies, ImmutableArrayEqualityComparer<string>.Instance);
hash.Add(LanguageSpecificAssemblies, ImmutableDictionaryWithImmutableArrayValuesEqualityComparer<string, string>.Instance);
hash.Add(Packages, ImmutableArrayEqualityComparer<PackageIdentity>.Instance);
hash.Add(NuGetConfigFilePath);
return hash.ToHashCode();
#else
var hashCode = -450793227;
hashCode = (hashCode * -1521134295) + EqualityComparer<string>.Default.GetHashCode(TargetFramework);
hashCode = (hashCode * -1521134295) + EqualityComparer<AssemblyIdentityComparer>.Default.GetHashCode(AssemblyIdentityComparer);
hashCode = (hashCode * -1521134295) + EqualityComparer<PackageIdentity?>.Default.GetHashCode(ReferenceAssemblyPackage);
hashCode = (hashCode * -1521134295) + EqualityComparer<string?>.Default.GetHashCode(ReferenceAssemblyPath);
hashCode = (hashCode * -1521134295) + ImmutableArrayEqualityComparer<string>.Instance.GetHashCode(Assemblies);
hashCode = (hashCode * -1521134295) + ImmutableArrayEqualityComparer<string>.Instance.GetHashCode(FacadeAssemblies);
hashCode = (hashCode * -1521134295) + ImmutableDictionaryWithImmutableArrayValuesEqualityComparer<string, string>.Instance.GetHashCode(LanguageSpecificAssemblies);
hashCode = (hashCode * -1521134295) + ImmutableArrayEqualityComparer<PackageIdentity>.Instance.GetHashCode(Packages);
hashCode = (hashCode * -1521134295) + EqualityComparer<string?>.Default.GetHashCode(NuGetConfigFilePath);
return hashCode;
#endif
}

public override bool Equals(object? obj)
=> Equals(obj as ReferenceAssemblies);

public bool Equals(ReferenceAssemblies? other)
{
return other is not null
&& TargetFramework == other.TargetFramework
&& EqualityComparer<AssemblyIdentityComparer>.Default.Equals(AssemblyIdentityComparer, other.AssemblyIdentityComparer)
&& EqualityComparer<PackageIdentity?>.Default.Equals(ReferenceAssemblyPackage, other.ReferenceAssemblyPackage)
&& ReferenceAssemblyPath == other.ReferenceAssemblyPath
&& ImmutableArrayEqualityComparer<string>.Instance.Equals(Assemblies, other.Assemblies)
&& ImmutableArrayEqualityComparer<string>.Instance.Equals(FacadeAssemblies, other.FacadeAssemblies)
&& ImmutableDictionaryWithImmutableArrayValuesEqualityComparer<string, string>.Instance.Equals(LanguageSpecificAssemblies, other.LanguageSpecificAssemblies)
&& ImmutableArrayEqualityComparer<PackageIdentity>.Instance.Equals(Packages, other.Packages)
&& NuGetConfigFilePath == other.NuGetConfigFilePath;
}

public ReferenceAssemblies WithAssemblyIdentityComparer(AssemblyIdentityComparer assemblyIdentityComparer)
=> new(TargetFramework, assemblyIdentityComparer, ReferenceAssemblyPackage, ReferenceAssemblyPath, Assemblies, FacadeAssemblies, LanguageSpecificAssemblies, Packages, NuGetConfigFilePath);
=> GetOrAddReferenceAssemblies(new(TargetFramework, assemblyIdentityComparer, ReferenceAssemblyPackage, ReferenceAssemblyPath, Assemblies, FacadeAssemblies, LanguageSpecificAssemblies, Packages, NuGetConfigFilePath));

public ReferenceAssemblies WithAssemblies(ImmutableArray<string> assemblies)
=> new(TargetFramework, AssemblyIdentityComparer, ReferenceAssemblyPackage, ReferenceAssemblyPath, assemblies, FacadeAssemblies, LanguageSpecificAssemblies, Packages, NuGetConfigFilePath);
=> GetOrAddReferenceAssemblies(new(TargetFramework, AssemblyIdentityComparer, ReferenceAssemblyPackage, ReferenceAssemblyPath, assemblies, FacadeAssemblies, LanguageSpecificAssemblies, Packages, NuGetConfigFilePath));

public ReferenceAssemblies WithFacadeAssemblies(ImmutableArray<string> facadeAssemblies)
=> new(TargetFramework, AssemblyIdentityComparer, ReferenceAssemblyPackage, ReferenceAssemblyPath, Assemblies, facadeAssemblies, LanguageSpecificAssemblies, Packages, NuGetConfigFilePath);
=> GetOrAddReferenceAssemblies(new(TargetFramework, AssemblyIdentityComparer, ReferenceAssemblyPackage, ReferenceAssemblyPath, Assemblies, facadeAssemblies, LanguageSpecificAssemblies, Packages, NuGetConfigFilePath));

public ReferenceAssemblies AddAssemblies(ImmutableArray<string> assemblies)
=> WithAssemblies(Assemblies.AddRange(assemblies));
Expand All @@ -139,7 +210,7 @@ public ReferenceAssemblies AddFacadeAssemblies(ImmutableArray<string> facadeAsse
=> WithFacadeAssemblies(FacadeAssemblies.AddRange(facadeAssemblies));

public ReferenceAssemblies WithLanguageSpecificAssemblies(ImmutableDictionary<string, ImmutableArray<string>> languageSpecificAssemblies)
=> new(TargetFramework, AssemblyIdentityComparer, ReferenceAssemblyPackage, ReferenceAssemblyPath, Assemblies, FacadeAssemblies, languageSpecificAssemblies, Packages, NuGetConfigFilePath);
=> GetOrAddReferenceAssemblies(new(TargetFramework, AssemblyIdentityComparer, ReferenceAssemblyPackage, ReferenceAssemblyPath, Assemblies, FacadeAssemblies, languageSpecificAssemblies, Packages, NuGetConfigFilePath));

public ReferenceAssemblies WithLanguageSpecificAssemblies(string language, ImmutableArray<string> assemblies)
=> WithLanguageSpecificAssemblies(LanguageSpecificAssemblies.SetItem(language, assemblies));
Expand All @@ -155,13 +226,13 @@ public ReferenceAssemblies AddLanguageSpecificAssemblies(string language, Immuta
}

public ReferenceAssemblies WithPackages(ImmutableArray<PackageIdentity> packages)
=> new(TargetFramework, AssemblyIdentityComparer, ReferenceAssemblyPackage, ReferenceAssemblyPath, Assemblies, FacadeAssemblies, LanguageSpecificAssemblies, packages, NuGetConfigFilePath);
=> GetOrAddReferenceAssemblies(new(TargetFramework, AssemblyIdentityComparer, ReferenceAssemblyPackage, ReferenceAssemblyPath, Assemblies, FacadeAssemblies, LanguageSpecificAssemblies, packages, NuGetConfigFilePath));

public ReferenceAssemblies AddPackages(ImmutableArray<PackageIdentity> packages)
=> WithPackages(Packages.AddRange(packages));

public ReferenceAssemblies WithNuGetConfigFilePath(string nugetConfigFilePath)
=> new(TargetFramework, AssemblyIdentityComparer, ReferenceAssemblyPackage, ReferenceAssemblyPath, Assemblies, FacadeAssemblies, LanguageSpecificAssemblies, Packages, nugetConfigFilePath);
=> GetOrAddReferenceAssemblies(new(TargetFramework, AssemblyIdentityComparer, ReferenceAssemblyPackage, ReferenceAssemblyPath, Assemblies, FacadeAssemblies, LanguageSpecificAssemblies, Packages, nugetConfigFilePath));

public async Task<ImmutableArray<MetadataReference>> ResolveAsync(string? language, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -1353,5 +1424,140 @@ public static bool IsPackageBased(string targetFramework)
return framework.IsPackageBased;
}
}

private sealed class ImmutableArrayEqualityComparer<T> : IEqualityComparer<ImmutableArray<T>>
{
public static readonly ImmutableArrayEqualityComparer<T> Instance = new();

private ImmutableArrayEqualityComparer()
{
}

public bool Equals(ImmutableArray<T> x, ImmutableArray<T> y)
{
if (x.IsDefault)
{
return y.IsDefault;
}
else if (y.IsDefault)
{
return false;
}

if (x.Length != y.Length)
{
return false;
}

for (var i = 0; i < x.Length; i++)
{
if (!EqualityComparer<T>.Default.Equals(x[i], y[i]))
{
return false;
}
}

return true;
}

public int GetHashCode(ImmutableArray<T> obj)
{
if (obj.IsDefault)
{
return 0;
}

#if NETCOREAPP
var hash = default(HashCode);
foreach (var item in obj)
{
hash.Add(item);
}

return hash.ToHashCode();
#else
var hashCode = -450793227;
foreach (var item in obj)
{
hashCode = (hashCode * -1521134295) + EqualityComparer<T>.Default.GetHashCode(item);
}

return hashCode;
#endif
}
}

private sealed class ImmutableDictionaryWithImmutableArrayValuesEqualityComparer<TKey, TValue> : IEqualityComparer<ImmutableDictionary<TKey, ImmutableArray<TValue>>?>
{
public static readonly ImmutableDictionaryWithImmutableArrayValuesEqualityComparer<TKey, TValue> Instance = new();

private ImmutableDictionaryWithImmutableArrayValuesEqualityComparer()
{
}

public bool Equals(ImmutableDictionary<TKey, ImmutableArray<TValue>>? x, ImmutableDictionary<TKey, ImmutableArray<TValue>>? y)
{
if (x is null)
{
return y is null;
}
else if (y is null)
{
return false;
}

if (x.Count != y.Count)
{
return false;
}

foreach (var (key, valueX) in x)
{
// Use a separate lookup in 'y' since ImmutableDictionary<,> can reorder pairs where the key has the
// same hash code.
if (!y.TryGetValue(key, out var valueY))
{
return false;
}

if (!ImmutableArrayEqualityComparer<TValue>.Instance.Equals(valueX, valueY))
{
return false;
}
}

return true;
}

public int GetHashCode(ImmutableDictionary<TKey, ImmutableArray<TValue>>? obj)
{
if (obj is null)
{
return 0;
}

#if NETCOREAPP
var hash = default(HashCode);
foreach (var (key, _) in obj)
{
// Intentionally ignore values since ImmutableDictionary<,> can reorder pairs where the key has the
// same hash code.
hash.Add(key);
}

return hash.ToHashCode();
#else
var hashCode = -450793227;
foreach (var (key, _) in obj)
{
// Intentionally ignore values since ImmutableDictionary<,> can reorder pairs where the key has the
// same hash code.
hashCode = (hashCode * -1521134295) + EqualityComparer<TKey>.Default.GetHashCode(key);
}

return hashCode;
#endif
}
}
}
}

0 comments on commit af952dc

Please sign in to comment.