Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache values in dependent project finder #73493

Merged
merged 16 commits into from
May 16, 2024
Merged
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
Expand All @@ -28,7 +29,15 @@ internal static partial class DependentProjectsFinder
/// Cache from the <see cref="MetadataId"/> for a particular <see cref="PortableExecutableReference"/> to the
/// name of the <see cref="IAssemblySymbol"/> defined by it.
/// </summary>
private static ImmutableDictionary<MetadataId, string> s_metadataIdToAssemblyName = ImmutableDictionary<MetadataId, string>.Empty;
private static readonly Dictionary<MetadataId, string?> s_metadataIdToAssemblyName = new();
private static readonly SemaphoreSlim s_metadataIdToAssemblyNameGate = new(initialCount: 1);

private static readonly ConditionalWeakTable<
Solution,
Dictionary<
(IAssemblySymbol assembly, Project? sourceProject, SymbolVisibility visibility),
ImmutableArray<(Project project, bool hasInternalsAccess)>>> s_solutionToDependentProjectMap = new();
private static readonly SemaphoreSlim s_solutionToDependentProjectMapGate = new(initialCount: 1);

public static async Task<ImmutableArray<Project>> GetDependentProjectsAsync(
Solution solution, ImmutableArray<ISymbol> symbols, IImmutableSet<Project> projects, CancellationToken cancellationToken)
Expand Down Expand Up @@ -128,24 +137,59 @@ private static async Task<ImmutableArray<Project>> GetDependentProjectsWorkerAsy
SymbolVisibility visibility,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
var dictionary = s_solutionToDependentProjectMap.GetValue(solution, static _ => new());

var dependentProjects = new HashSet<(Project, bool hasInternalsAccess)>();
var key = (symbolOrigination.assembly, symbolOrigination.sourceProject, visibility);
ImmutableArray<(Project project, bool hasInternalsAccess)> dependentProjects;

// If a symbol was defined in source, then it is always visible to the project it
// was defined in.
if (symbolOrigination.sourceProject != null)
dependentProjects.Add((symbolOrigination.sourceProject, hasInternalsAccess: true));
// Check cache first.
using (await s_solutionToDependentProjectMapGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
if (dictionary.TryGetValue(key, out dependentProjects))
return dependentProjects;
}

// Compute if not in cache.
dependentProjects = await ComputeDependentProjectsWorkerAsync(
solution, symbolOrigination, visibility, cancellationToken).ConfigureAwait(false);

// Try to add to cache, returning existing value if another thread already added it.
using (await s_solutionToDependentProjectMapGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
if (dictionary.TryGetValue(key, out dependentProjects))
return dependentProjects;

return dictionary.GetOrAdd(key, dependentProjects);
}

static async Task<ImmutableArray<(Project project, bool hasInternalsAccess)>> ComputeDependentProjectsWorkerAsync(
Solution solution,
(IAssemblySymbol assembly, Project? sourceProject) symbolOrigination,
SymbolVisibility visibility,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

// If it's not private, then we need to find possible references.
if (visibility != SymbolVisibility.Private)
AddNonSubmissionDependentProjects(solution, symbolOrigination, dependentProjects, cancellationToken);
using var _ = PooledHashSet<(Project, bool hasInternalsAccess)>.GetInstance(out var dependentProjects);

// submission projects are special here. The fields generated inside the Script object is private, but
// further submissions can bind to them.
await AddSubmissionDependentProjectsAsync(solution, symbolOrigination.sourceProject, dependentProjects, cancellationToken).ConfigureAwait(false);
// If a symbol was defined in source, then it is always visible to the project it
// was defined in.
if (symbolOrigination.sourceProject != null)
dependentProjects.Add((symbolOrigination.sourceProject, hasInternalsAccess: true));

return [.. dependentProjects];
// If it's not private, then we need to find possible references.
if (visibility != SymbolVisibility.Private)
{
await AddNonSubmissionDependentProjectsAsync(
solution, symbolOrigination, dependentProjects, cancellationToken).ConfigureAwait(false);
}

// submission projects are special here. The fields generated inside the Script object is private, but
// further submissions can bind to them.
await AddSubmissionDependentProjectsAsync(solution, symbolOrigination.sourceProject, dependentProjects, cancellationToken).ConfigureAwait(false);

return [.. dependentProjects];
}
}
CyrusNajmabadi marked this conversation as resolved.
Show resolved Hide resolved

private static async Task AddSubmissionDependentProjectsAsync(
Expand Down Expand Up @@ -221,7 +265,7 @@ private static bool IsInternalsVisibleToAttribute(AttributeData attr)
attrType.ContainingNamespace.ContainingNamespace.ContainingNamespace.ContainingNamespace?.IsGlobalNamespace == true;
}

private static void AddNonSubmissionDependentProjects(
private static async Task AddNonSubmissionDependentProjectsAsync(
Solution solution,
(IAssemblySymbol assembly, Project? sourceProject) symbolOrigination,
HashSet<(Project project, bool hasInternalsAccess)> dependentProjects,
Expand All @@ -235,7 +279,7 @@ private static void AddNonSubmissionDependentProjects(
foreach (var project in solution.Projects)
{
if (!project.SupportsCompilation ||
!HasReferenceTo(symbolOrigination, project, cancellationToken))
!await HasReferenceToAsync(symbolOrigination, project, cancellationToken).ConfigureAwait(false))
{
continue;
}
Expand Down Expand Up @@ -270,7 +314,7 @@ private static HashSet<string> GetInternalsVisibleToSet(IAssemblySymbol assembly
return set;
}

private static bool HasReferenceTo(
private static async Task<bool> HasReferenceToAsync(
(IAssemblySymbol assembly, Project? sourceProject) symbolOrigination,
Project project,
CancellationToken cancellationToken)
Expand All @@ -284,10 +328,11 @@ private static bool HasReferenceTo(
return project.ProjectReferences.Any(p => p.ProjectId == symbolOrigination.sourceProject.Id);

// Otherwise, if the symbol is from metadata, see if the project's compilation references that metadata assembly.
return HasReferenceToAssembly(project, symbolOrigination.assembly.Name, cancellationToken);
return await HasReferenceToAssemblyAsync(
project, symbolOrigination.assembly.Name, cancellationToken).ConfigureAwait(false);
}

private static bool HasReferenceToAssembly(Project project, string assemblyName, CancellationToken cancellationToken)
private static async Task<bool> HasReferenceToAssemblyAsync(Project project, string assemblyName, CancellationToken cancellationToken)
{
Contract.ThrowIfFalse(project.SupportsCompilation);

Expand All @@ -307,31 +352,41 @@ private static bool HasReferenceToAssembly(Project project, string assemblyName,
if (metadataId is null)
continue;

if (!s_metadataIdToAssemblyName.TryGetValue(metadataId, out var name))
using (await s_metadataIdToAssemblyNameGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
uncomputedReferences.Add((peReference, metadataId));
continue;
if (s_metadataIdToAssemblyName.TryGetValue(metadataId, out var name))
{
// We already know the assembly name for this metadata id. If it matches the one we're looking for,
// we're done. Otherwise, keep looking.
if (name == assemblyName)
return true;
else
continue;
}
}

if (name == assemblyName)
return true;
// We didn't know the name for the metadata id. Add it to the list of things we need to compute below.
uncomputedReferences.Add((peReference, metadataId));
}

if (uncomputedReferences.Count == 0)
return false;

Compilation? compilation = null;
var compilation = CreateCompilation(project);

foreach (var (peReference, metadataId) in uncomputedReferences)
{
cancellationToken.ThrowIfCancellationRequested();

if (!s_metadataIdToAssemblyName.TryGetValue(metadataId, out var name))
// Attempt to get the assembly name for this pe-reference. If we fail, we still want to add that info into
// the dictionary (by mapping us to 'null'). That way we don't keep trying to compute it over and over.
var name = compilation.GetAssemblyOrModuleSymbol(peReference) is IAssemblySymbol { Name: string metadataAssemblyName }
? metadataAssemblyName
: null;

using (await s_metadataIdToAssemblyNameGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
// Defer creating the compilation till needed.
CreateCompilation(project, ref compilation);
if (compilation.GetAssemblyOrModuleSymbol(peReference) is IAssemblySymbol { Name: string metadataAssemblyName })
name = ImmutableInterlocked.GetOrAdd(ref s_metadataIdToAssemblyName, metadataId, metadataAssemblyName);
name = s_metadataIdToAssemblyName.GetOrAdd(metadataId, name);
}

if (name == assemblyName)
Expand All @@ -340,18 +395,15 @@ private static bool HasReferenceToAssembly(Project project, string assemblyName,

return false;

static void CreateCompilation(Project project, [NotNull] ref Compilation? compilation)
static Compilation CreateCompilation(Project project)
{
if (compilation != null)
return;

// Use the project's compilation if it has one.
if (project.TryGetCompilation(out compilation))
return;
if (project.TryGetCompilation(out var compilation))
return compilation;

// Perf: check metadata reference using newly created empty compilation with only metadata references.
var factory = project.Services.GetRequiredService<ICompilationFactoryService>();
compilation = factory
return factory
.CreateCompilation(project.AssemblyName, project.CompilationOptions!)
.AddReferences(project.MetadataReferences);
}
Expand Down
Loading