Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache values in dependent project finder #73493

Merged
merged 16 commits into from
May 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
Expand All @@ -28,7 +29,15 @@ internal static partial class DependentProjectsFinder
/// Cache from the <see cref="MetadataId"/> for a particular <see cref="PortableExecutableReference"/> to the
/// name of the <see cref="IAssemblySymbol"/> defined by it.
/// </summary>
private static ImmutableDictionary<MetadataId, string> s_metadataIdToAssemblyName = ImmutableDictionary<MetadataId, string>.Empty;
private static readonly Dictionary<MetadataId, string?> s_metadataIdToAssemblyName = new();
private static readonly SemaphoreSlim s_metadataIdToAssemblyNameGate = new(initialCount: 1);

private static readonly ConditionalWeakTable<
Solution,
Dictionary<
(IAssemblySymbol assembly, Project? sourceProject, SymbolVisibility visibility),
ImmutableArray<(Project project, bool hasInternalsAccess)>>> s_solutionToDependentProjectMap = new();
private static readonly SemaphoreSlim s_solutionToDependentProjectMapGate = new(initialCount: 1);

public static async Task<ImmutableArray<Project>> GetDependentProjectsAsync(
Solution solution, ImmutableArray<ISymbol> symbols, IImmutableSet<Project> projects, CancellationToken cancellationToken)
Expand Down Expand Up @@ -128,24 +137,56 @@ private static async Task<ImmutableArray<Project>> GetDependentProjectsWorkerAsy
SymbolVisibility visibility,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
var dictionary = s_solutionToDependentProjectMap.GetValue(solution, static _ => new());

var dependentProjects = new HashSet<(Project, bool hasInternalsAccess)>();
var key = (symbolOrigination.assembly, symbolOrigination.sourceProject, visibility);
ImmutableArray<(Project project, bool hasInternalsAccess)> dependentProjects;

// If a symbol was defined in source, then it is always visible to the project it
// was defined in.
if (symbolOrigination.sourceProject != null)
dependentProjects.Add((symbolOrigination.sourceProject, hasInternalsAccess: true));
// Check cache first.
using (await s_solutionToDependentProjectMapGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
if (dictionary.TryGetValue(key, out dependentProjects))
return dependentProjects;
}

// Compute if not in cache.
dependentProjects = await ComputeDependentProjectsWorkerAsync(
solution, symbolOrigination, visibility, cancellationToken).ConfigureAwait(false);

// If it's not private, then we need to find possible references.
if (visibility != SymbolVisibility.Private)
AddNonSubmissionDependentProjects(solution, symbolOrigination, dependentProjects, cancellationToken);
// Try to add to cache, returning existing value if another thread already added it.
using (await s_solutionToDependentProjectMapGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
return dictionary.GetOrAdd(key, dependentProjects);
}

static async Task<ImmutableArray<(Project project, bool hasInternalsAccess)>> ComputeDependentProjectsWorkerAsync(
Solution solution,
(IAssemblySymbol assembly, Project? sourceProject) symbolOrigination,
SymbolVisibility visibility,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

using var _ = PooledHashSet<(Project, bool hasInternalsAccess)>.GetInstance(out var dependentProjects);

// If a symbol was defined in source, then it is always visible to the project it
// was defined in.
if (symbolOrigination.sourceProject != null)
dependentProjects.Add((symbolOrigination.sourceProject, hasInternalsAccess: true));

// If it's not private, then we need to find possible references.
if (visibility != SymbolVisibility.Private)
{
await AddNonSubmissionDependentProjectsAsync(
solution, symbolOrigination, dependentProjects, cancellationToken).ConfigureAwait(false);
}

// submission projects are special here. The fields generated inside the Script object is private, but
// further submissions can bind to them.
await AddSubmissionDependentProjectsAsync(solution, symbolOrigination.sourceProject, dependentProjects, cancellationToken).ConfigureAwait(false);
// submission projects are special here. The fields generated inside the Script object is private, but
// further submissions can bind to them.
await AddSubmissionDependentProjectsAsync(solution, symbolOrigination.sourceProject, dependentProjects, cancellationToken).ConfigureAwait(false);

return [.. dependentProjects];
return [.. dependentProjects];
}
}
CyrusNajmabadi marked this conversation as resolved.
Show resolved Hide resolved

private static async Task AddSubmissionDependentProjectsAsync(
Expand All @@ -154,7 +195,7 @@ private static async Task AddSubmissionDependentProjectsAsync(
if (sourceProject?.IsSubmission != true)
return;

var projectIdsToReferencingSubmissionIds = new Dictionary<ProjectId, List<ProjectId>>();
using var _1 = PooledDictionary<ProjectId, List<ProjectId>>.GetInstance(out var projectIdsToReferencingSubmissionIds);

// search only submission project
foreach (var projectId in solution.ProjectIds)
Expand All @@ -173,15 +214,7 @@ private static async Task AddSubmissionDependentProjectsAsync(
{
var referencedProject = solution.GetProject(previous.Assembly, cancellationToken);
if (referencedProject != null)
{
if (!projectIdsToReferencingSubmissionIds.TryGetValue(referencedProject.Id, out var referencingSubmissions))
{
referencingSubmissions = [];
projectIdsToReferencingSubmissionIds.Add(referencedProject.Id, referencingSubmissions);
}

referencingSubmissions.Add(project.Id);
}
projectIdsToReferencingSubmissionIds.MultiAdd(referencedProject.Id, project.Id);
}
}
}
Expand All @@ -191,7 +224,7 @@ private static async Task AddSubmissionDependentProjectsAsync(
// and 2, even though 2 doesn't have a direct reference to 1. Hence we need to take
// our current set of projects and find the transitive closure over backwards
// submission previous references.
using var _ = ArrayBuilder<ProjectId>.GetInstance(out var projectIdsToProcess);
using var _2 = ArrayBuilder<ProjectId>.GetInstance(out var projectIdsToProcess);
foreach (var dependentProject in dependentProjects.Select(dp => dp.project.Id))
projectIdsToProcess.Push(dependentProject);

Expand Down Expand Up @@ -221,7 +254,7 @@ private static bool IsInternalsVisibleToAttribute(AttributeData attr)
attrType.ContainingNamespace.ContainingNamespace.ContainingNamespace.ContainingNamespace?.IsGlobalNamespace == true;
}

private static void AddNonSubmissionDependentProjects(
private static async Task AddNonSubmissionDependentProjectsAsync(
Solution solution,
(IAssemblySymbol assembly, Project? sourceProject) symbolOrigination,
HashSet<(Project project, bool hasInternalsAccess)> dependentProjects,
Expand All @@ -235,7 +268,7 @@ private static void AddNonSubmissionDependentProjects(
foreach (var project in solution.Projects)
{
if (!project.SupportsCompilation ||
!HasReferenceTo(symbolOrigination, project, cancellationToken))
!await HasReferenceToAsync(symbolOrigination, project, cancellationToken).ConfigureAwait(false))
{
continue;
}
Expand Down Expand Up @@ -270,7 +303,7 @@ private static HashSet<string> GetInternalsVisibleToSet(IAssemblySymbol assembly
return set;
}

private static bool HasReferenceTo(
private static async Task<bool> HasReferenceToAsync(
(IAssemblySymbol assembly, Project? sourceProject) symbolOrigination,
Project project,
CancellationToken cancellationToken)
Expand All @@ -284,10 +317,11 @@ private static bool HasReferenceTo(
return project.ProjectReferences.Any(p => p.ProjectId == symbolOrigination.sourceProject.Id);

// Otherwise, if the symbol is from metadata, see if the project's compilation references that metadata assembly.
return HasReferenceToAssembly(project, symbolOrigination.assembly.Name, cancellationToken);
return await HasReferenceToAssemblyAsync(
project, symbolOrigination.assembly.Name, cancellationToken).ConfigureAwait(false);
}

private static bool HasReferenceToAssembly(Project project, string assemblyName, CancellationToken cancellationToken)
private static async Task<bool> HasReferenceToAssemblyAsync(Project project, string assemblyName, CancellationToken cancellationToken)
{
Contract.ThrowIfFalse(project.SupportsCompilation);

Expand All @@ -307,31 +341,50 @@ private static bool HasReferenceToAssembly(Project project, string assemblyName,
if (metadataId is null)
continue;

if (!s_metadataIdToAssemblyName.TryGetValue(metadataId, out var name))
using (await s_metadataIdToAssemblyNameGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
uncomputedReferences.Add((peReference, metadataId));
continue;
if (s_metadataIdToAssemblyName.TryGetValue(metadataId, out var name))
{
// We already know the assembly name for this metadata id. If it matches the one we're looking for,
// we're done. Otherwise, keep looking.
if (name == assemblyName)
return true;
else
continue;
}
}

if (name == assemblyName)
return true;
// We didn't know the name for the metadata id. Add it to the list of things we need to compute below.
uncomputedReferences.Add((peReference, metadataId));
}

if (uncomputedReferences.Count == 0)
return false;

Compilation? compilation = null;
var compilation = CreateCompilation(project);

foreach (var (peReference, metadataId) in uncomputedReferences)
{
cancellationToken.ThrowIfCancellationRequested();

if (!s_metadataIdToAssemblyName.TryGetValue(metadataId, out var name))
// Attempt to get the assembly name for this pe-reference. If we fail, we still want to add that info into
// the dictionary (by mapping us to 'null'). That way we don't keep trying to compute it over and over.
var name = compilation.GetAssemblyOrModuleSymbol(peReference) is IAssemblySymbol { Name: string metadataAssemblyName }
? metadataAssemblyName
: null;

using (await s_metadataIdToAssemblyNameGate.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
// Defer creating the compilation till needed.
CreateCompilation(project, ref compilation);
if (compilation.GetAssemblyOrModuleSymbol(peReference) is IAssemblySymbol { Name: string metadataAssemblyName })
name = ImmutableInterlocked.GetOrAdd(ref s_metadataIdToAssemblyName, metadataId, metadataAssemblyName);
// Overwrite an existing null name with a non-null one.
if (s_metadataIdToAssemblyName.TryGetValue(metadataId, out var existingName) &&
existingName == null &&
name != null)
{
s_metadataIdToAssemblyName[metadataId] = name;
}

// Return whatever is in the map, adding ourselves if something is not already there.
name = s_metadataIdToAssemblyName.GetOrAdd(metadataId, name);
}

if (name == assemblyName)
Expand All @@ -340,18 +393,15 @@ private static bool HasReferenceToAssembly(Project project, string assemblyName,

return false;

static void CreateCompilation(Project project, [NotNull] ref Compilation? compilation)
static Compilation CreateCompilation(Project project)
{
if (compilation != null)
return;

// Use the project's compilation if it has one.
if (project.TryGetCompilation(out compilation))
return;
if (project.TryGetCompilation(out var compilation))
return compilation;

// Perf: check metadata reference using newly created empty compilation with only metadata references.
var factory = project.Services.GetRequiredService<ICompilationFactoryService>();
compilation = factory
return factory
.CreateCompilation(project.AssemblyName, project.CompilationOptions!)
.AddReferences(project.MetadataReferences);
}
Expand Down
Loading