Skip to content

Commit 8aab131

Browse files
committed
Semantic Search: Add support for async queries and FAR tool
Finder
1 parent 43fd6f9 commit 8aab131

25 files changed

+453
-112
lines changed

src/Features/CSharp/Portable/SemanticSearch/CSharpSemanticSearchService.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ protected override Compilation CreateCompilation(
2929
{
3030
var syntaxTreeFactory = services.GetRequiredLanguageService<ISyntaxTreeFactoryService>(LanguageNames.CSharp);
3131

32-
var globalUsingsTree = syntaxTreeFactory.ParseSyntaxTree(
32+
var globalUsingsAndToolsTree = syntaxTreeFactory.ParseSyntaxTree(
3333
filePath: null,
3434
CSharpSemanticSearchUtilities.ParseOptions,
35-
SemanticSearchUtilities.CreateSourceText(CSharpSemanticSearchUtilities.Configuration.GlobalUsings),
35+
SemanticSearchUtilities.CreateSourceText(CSharpSemanticSearchUtilities.Configuration.GlobalUsingsAndTools),
3636
cancellationToken);
3737

3838
queryTree = syntaxTreeFactory.ParseSyntaxTree(
@@ -43,7 +43,7 @@ protected override Compilation CreateCompilation(
4343

4444
return CSharpCompilation.Create(
4545
assemblyName: SemanticSearchUtilities.QueryProjectName,
46-
[queryTree, globalUsingsTree],
46+
[queryTree, globalUsingsAndToolsTree],
4747
references,
4848
CSharpSemanticSearchUtilities.CompilationOptions);
4949
}

src/Features/CSharp/Portable/SemanticSearch/CSharpSemanticSearchUtilities.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,25 @@ static IEnumerable<ISymbol> Find(Compilation compilation)
2020
return compilation.Assembly.GlobalNamespace.GetMembers("C");
2121
}
2222
""",
23-
GlobalUsings = """
23+
GlobalUsingsAndTools = $$"""
2424
global using System;
2525
global using System.Collections.Generic;
2626
global using System.Collections.Immutable;
2727
global using System.Linq;
2828
global using System.Threading;
2929
global using System.Threading.Tasks;
3030
global using Microsoft.CodeAnalysis;
31+
32+
public static class {{SemanticSearchUtilities.ToolsTypeName}}
33+
{
34+
private static Func<ISymbol, IAsyncEnumerable<ISymbol>> {{SemanticSearchUtilities.FindReferencingSymbolsImplName}};
35+
36+
/// <summary>
37+
/// Returns all symbols that reference (use) a given <paramref name="symbol" />.
38+
/// </summary>
39+
public static IAsyncEnumerable<ISymbol> FindReferencingSymbols(this ISymbol symbol)
40+
=> {{SemanticSearchUtilities.FindReferencingSymbolsImplName}}(symbol);
41+
}
3142
""",
3243
EditorConfig = """
3344
is_global = true

src/Features/CSharpTest/SemanticSearch/CSharpSemanticSearchServiceTests.cs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,64 @@ static IEnumerable<ISymbol> Find(IEventSymbol e)
264264
AssertEx.Equal(["event Action C.E"], results.Select(Inspect));
265265
}
266266

267+
[ConditionalFact(typeof(CoreClrOnly))]
268+
public async Task FindReferencingSymbols()
269+
{
270+
using var workspace = TestWorkspace.Create("""
271+
<Workspace>
272+
<Project Language="C#" CommonReferences="true">
273+
<Document FilePath="File1.cs">
274+
class C
275+
{
276+
void F()
277+
{
278+
}
279+
}
280+
281+
class D
282+
{
283+
void R1() => new C().F();
284+
void R2() => new C().F();
285+
}
286+
</Document>
287+
</Project>
288+
</Workspace>
289+
""", composition: FeaturesTestCompositions.Features);
290+
291+
var solution = workspace.CurrentSolution;
292+
293+
var service = solution.Services.GetRequiredLanguageService<ISemanticSearchService>(LanguageNames.CSharp);
294+
295+
var query = """
296+
static async IAsyncEnumerable<ISymbol> Find(IMethodSymbol e)
297+
{
298+
if (e.Name != "F")
299+
{
300+
yield break;
301+
}
302+
303+
await foreach (var s in e.FindReferencingSymbols())
304+
{
305+
yield return s;
306+
}
307+
}
308+
""";
309+
310+
var results = new List<DefinitionItem>();
311+
var observer = new MockSemanticSearchResultsObserver() { OnDefinitionFoundImpl = results.Add };
312+
var traceSource = new TraceSource("test");
313+
314+
var options = workspace.GlobalOptions.GetClassificationOptionsProvider();
315+
var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None);
316+
317+
Assert.Null(result.ErrorMessage);
318+
AssertEx.Equal(
319+
[
320+
"void D.R1()",
321+
"void D.R2()"
322+
], results.Select(Inspect));
323+
}
324+
267325
[ConditionalFact(typeof(CoreClrOnly))]
268326
public async Task ForcedCancellation()
269327
{

src/Features/Core/Portable/FeaturesResources.resx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3165,8 +3165,11 @@ Zero-width positive lookbehind assertions are typically used at the beginning of
31653165
<data name="Top_level_function_0_must_have_a_single_parameter" xml:space="preserve">
31663166
<value>Top-level function '{0}' must have a single parameter</value>
31673167
</data>
3168-
<data name="Type_0_is_not_among_supported_types_1" xml:space="preserve">
3169-
<value>Type '{0}' is not among supported types: {1}</value>
3168+
<data name="Parameter_type_0_is_not_among_supported_types_1" xml:space="preserve">
3169+
<value>Parameter type '{0}' is not among supported types: {1}</value>
3170+
</data>
3171+
<data name="Return_type_0_is_not_among_supported_types_1" xml:space="preserve">
3172+
<value>Return type '{0}' is not among supported types: {1}</value>
31703173
</data>
31713174
<data name="Unable_to_load_type_0_1" xml:space="preserve">
31723175
<value>Unable to load type '{0}': '{1}'</value>

src/Features/Core/Portable/SemanticSearch/AbstractSemanticSearchService.cs

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
using System.Text;
1818
using System.Text.RegularExpressions;
1919
using System.Threading;
20+
using System.Threading.Channels;
2021
using System.Threading.Tasks;
2122
using Microsoft.CodeAnalysis.Classification;
2223
using Microsoft.CodeAnalysis.Diagnostics;
@@ -28,12 +29,14 @@
2829
using Microsoft.CodeAnalysis.Host.Mef;
2930
using Microsoft.CodeAnalysis.Internal.Log;
3031
using Microsoft.CodeAnalysis.NavigateTo;
32+
using Microsoft.CodeAnalysis.Notification;
3133
using Microsoft.CodeAnalysis.PooledObjects;
3234
using Microsoft.CodeAnalysis.Shared.Collections;
3335
using Microsoft.CodeAnalysis.Shared.Extensions;
3436
using Microsoft.CodeAnalysis.Shared.Utilities;
3537
using Microsoft.CodeAnalysis.Tags;
3638
using Microsoft.CodeAnalysis.Text;
39+
using Microsoft.CodeAnalysis.Threading;
3740
using Roslyn.Utilities;
3841

3942
namespace Microsoft.CodeAnalysis.SemanticSearch;
@@ -133,14 +136,10 @@ public async Task<ExecuteQueryResult> ExecuteQueryAsync(
133136
try
134137
{
135138
var queryAssembly = loadContext.LoadFromStream(peStream, pdbStream);
139+
SetModuleCancellationToken(queryAssembly, cancellationToken);
140+
SetFindReferencingSymbolsImpl(queryAssembly, new ReferencingSymbolsFinder(solution, classificationOptions, cancellationToken));
136141

137-
var pidType = queryAssembly.GetType("<PrivateImplementationDetails>", throwOnError: true);
138-
Contract.ThrowIfNull(pidType);
139-
var moduleCancellationTokenField = pidType.GetField("ModuleCancellationToken", BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Static);
140-
Contract.ThrowIfNull(moduleCancellationTokenField);
141-
moduleCancellationTokenField.SetValue(null, cancellationToken);
142-
143-
if (!TryGetFindMethod(queryAssembly, out var findMethod, out var queryKind, out var errorMessage, out var errorMessageArgs))
142+
if (!TryGetFindMethod(queryAssembly, out var findMethod, out var queryKind, out var isAsync, out var errorMessage, out var errorMessageArgs))
144143
{
145144
traceSource.TraceInformation($"Semantic search failed: {errorMessage}");
146145
return CreateResult(compilationErrors: [], errorMessage, errorMessageArgs);
@@ -149,7 +148,7 @@ public async Task<ExecuteQueryResult> ExecuteQueryAsync(
149148
var invocationContext = new QueryExecutionContext(queryText, findMethod, observer, classificationOptions, traceSource);
150149
try
151150
{
152-
await invocationContext.InvokeAsync(solution, queryKind, cancellationToken).ConfigureAwait(false);
151+
await invocationContext.InvokeAsync(solution, queryKind, isAsync, cancellationToken).ConfigureAwait(false);
153152

154153
if (invocationContext.TerminatedWithException)
155154
{
@@ -184,12 +183,31 @@ ExecuteQueryResult CreateResult(ImmutableArray<QueryCompilationError> compilatio
184183
}
185184
}
186185

187-
private static bool TryGetFindMethod(Assembly queryAssembly, [NotNullWhen(true)] out MethodInfo? method, out QueryKind queryKind, out string? error, out string[]? errorMessageArgs)
186+
private static void SetModuleCancellationToken(Assembly queryAssembly, CancellationToken cancellationToken)
187+
{
188+
var pidType = queryAssembly.GetType("<PrivateImplementationDetails>", throwOnError: true);
189+
Contract.ThrowIfNull(pidType);
190+
var moduleCancellationTokenField = pidType.GetField("ModuleCancellationToken", BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Static);
191+
Contract.ThrowIfNull(moduleCancellationTokenField);
192+
moduleCancellationTokenField.SetValue(null, cancellationToken);
193+
}
194+
195+
private static void SetFindReferencingSymbolsImpl(Assembly queryAssembly, ReferencingSymbolsFinder finder)
196+
{
197+
var toolsType = queryAssembly.GetType(SemanticSearchUtilities.ToolsTypeName, throwOnError: true);
198+
Contract.ThrowIfNull(toolsType);
199+
var findReferencingSymbolsImplField = toolsType.GetField(SemanticSearchUtilities.FindReferencingSymbolsImplName, BindingFlags.NonPublic | BindingFlags.Static);
200+
Contract.ThrowIfNull(findReferencingSymbolsImplField);
201+
findReferencingSymbolsImplField.SetValue(null, new Func<ISymbol, IAsyncEnumerable<ISymbol>>(finder.FindSymbolsAsync));
202+
}
203+
204+
private static bool TryGetFindMethod(Assembly queryAssembly, [NotNullWhen(true)] out MethodInfo? method, out QueryKind queryKind, out bool isAsync, out string? error, out string[]? errorMessageArgs)
188205
{
189206
method = null;
190207
error = null;
191208
errorMessageArgs = null;
192209
queryKind = default;
210+
isAsync = false;
193211

194212
Type? program;
195213
try
@@ -207,7 +225,7 @@ private static bool TryGetFindMethod(Assembly queryAssembly, [NotNullWhen(true)]
207225
{
208226
try
209227
{
210-
(method, queryKind) = GetFindMethod(program, ref error);
228+
(method, queryKind, isAsync) = GetFindMethod(program, ref error);
211229
}
212230
catch
213231
{
@@ -223,7 +241,7 @@ private static bool TryGetFindMethod(Assembly queryAssembly, [NotNullWhen(true)]
223241
return false;
224242
}
225243

226-
private static (MethodInfo? method, QueryKind queryKind) GetFindMethod(Type type, ref string? error)
244+
private static (MethodInfo? method, QueryKind queryKind, bool isAsync) GetFindMethod(Type type, ref string? error)
227245
{
228246
try
229247
{
@@ -262,17 +280,30 @@ private static (MethodInfo? method, QueryKind queryKind) GetFindMethod(Type type
262280
return default;
263281
}
264282

265-
if (!s_queryKindByParameterType.TryGetValue(parameter.ParameterType, out var entity))
283+
if (!s_queryKindByParameterType.TryGetValue(parameter.ParameterType, out var queryKind))
266284
{
267285
error = string.Format(
268-
FeaturesResources.Type_0_is_not_among_supported_types_1,
269-
SemanticSearchUtilities.FindMethodName,
286+
FeaturesResources.Parameter_type_0_is_not_among_supported_types_1,
287+
parameter.ParameterType,
270288
string.Join(", ", s_queryKindByParameterType.Keys.Select(t => $"'{t.Name}'")));
271289

272290
return default;
273291
}
274292

275-
return (method, entity);
293+
// IEnumerable<ISymbol>
294+
// IAsyncEnumerable<ISymbol>
295+
bool? isAsync = method.ReturnType == typeof(IEnumerable<ISymbol>) ? false : method.ReturnType == typeof(IAsyncEnumerable<ISymbol>) ? true : null;
296+
if (isAsync == null)
297+
{
298+
error = string.Format(
299+
FeaturesResources.Return_type_0_is_not_among_supported_types_1,
300+
method.ReturnType,
301+
$"'{typeof(IEnumerable<ISymbol>).Name}', '{typeof(IAsyncEnumerable<ISymbol>).Name}'");
302+
303+
return default;
304+
}
305+
306+
return (method, queryKind, isAsync.Value);
276307
}
277308
catch (Exception e)
278309
{

src/Features/Core/Portable/SemanticSearch/QueryExecutionContext.cs

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ internal sealed class QueryExecutionContext(
4141
public long ExecutionTime => _executionTime;
4242
public int ProcessedProjectCount => _processedProjectCount;
4343

44-
public async Task InvokeAsync(Solution solution, QueryKind targetEntity, CancellationToken cancellationToken)
44+
public async Task InvokeAsync(Solution solution, QueryKind queryKind, bool isAsync, CancellationToken cancellationToken)
4545
{
4646
// Invoke query on projects and types in parallel and on members serially.
4747
// Cancel execution if the query throws an exception.
@@ -63,16 +63,16 @@ await Parallel.ForEachAsync(solution.Projects, symbolEnumerationCancellationSour
6363
// only search source symbols:
6464
var rootNamespace = compilation.Assembly.GlobalNamespace;
6565

66-
switch (targetEntity)
66+
switch (queryKind)
6767
{
6868
case QueryKind.Compilation:
69-
await InvokeAsync(project, compilation, entity: compilation, symbolEnumerationCancellationSource, cancellationToken).ConfigureAwait(false);
69+
await InvokeAsync(project, compilation, entity: compilation, isAsync, symbolEnumerationCancellationSource, cancellationToken).ConfigureAwait(false);
7070
break;
7171

7272
case QueryKind.Namespace:
7373
await Parallel.ForEachAsync(rootNamespace.GetAllNamespaces(cancellationToken), cancellationToken, async (namespaceSymbol, cancellationToken) =>
7474
{
75-
await InvokeAsync(project, compilation, entity: namespaceSymbol, symbolEnumerationCancellationSource, cancellationToken).ConfigureAwait(false);
75+
await InvokeAsync(project, compilation, entity: namespaceSymbol, isAsync, symbolEnumerationCancellationSource, cancellationToken).ConfigureAwait(false);
7676
}).ConfigureAwait(false);
7777
break;
7878

@@ -82,21 +82,21 @@ await Parallel.ForEachAsync(rootNamespace.GetAllNamespaces(cancellationToken), c
8282
case QueryKind.Property:
8383
case QueryKind.Event:
8484

85-
var kind = GetSymbolKind(targetEntity);
85+
var kind = GetSymbolKind(queryKind);
8686

8787
await Parallel.ForEachAsync(rootNamespace.GetAllTypes(cancellationToken), async (type, cancellationToken) =>
8888
{
8989
if (kind == SymbolKind.NamedType)
9090
{
91-
await InvokeAsync(project, compilation, entity: type, symbolEnumerationCancellationSource, cancellationToken).ConfigureAwait(false);
91+
await InvokeAsync(project, compilation, entity: type, isAsync, symbolEnumerationCancellationSource, cancellationToken).ConfigureAwait(false);
9292
}
9393
else
9494
{
9595
foreach (var member in type.GetMembers())
9696
{
9797
if (member.Kind == kind)
9898
{
99-
await InvokeAsync(project, compilation, entity: member, symbolEnumerationCancellationSource, cancellationToken).ConfigureAwait(false);
99+
await InvokeAsync(project, compilation, entity: member, isAsync, symbolEnumerationCancellationSource, cancellationToken).ConfigureAwait(false);
100100
}
101101
}
102102
}
@@ -118,7 +118,7 @@ await Parallel.ForEachAsync(rootNamespace.GetAllTypes(cancellationToken), async
118118
}
119119
}
120120

121-
private async ValueTask InvokeAsync(Project project, Compilation compilation, object entity, CancellationTokenSource symbolEnumerationCancellationSource, CancellationToken cancellationToken)
121+
private async ValueTask InvokeAsync(Project project, Compilation compilation, object entity, bool isAsync, CancellationTokenSource symbolEnumerationCancellationSource, CancellationToken cancellationToken)
122122
{
123123
cancellationToken.ThrowIfCancellationRequested();
124124

@@ -130,30 +130,49 @@ private async ValueTask InvokeAsync(Project project, Compilation compilation, ob
130130

131131
try
132132
{
133-
var symbols = (IEnumerable<ISymbol?>?)method.Invoke(null, [entity]) ?? [];
133+
var symbols = method.Invoke(null, [entity]);
134+
if (symbols != null)
135+
{
136+
if (isAsync)
137+
{
138+
await foreach (var symbol in ((IAsyncEnumerable<ISymbol?>)symbols).WithCancellation(cancellationToken).ConfigureAwait(false))
139+
{
140+
await ObserveSymbolAsync(symbol).ConfigureAwait(false);
141+
}
142+
}
143+
else
144+
{
145+
foreach (var symbol in (IEnumerable<ISymbol?>)symbols)
146+
{
147+
await ObserveSymbolAsync(symbol).ConfigureAwait(false);
148+
}
149+
}
150+
}
134151

135-
foreach (var symbol in symbols)
152+
async ValueTask ObserveSymbolAsync(ISymbol? symbol)
136153
{
137154
cancellationToken.ThrowIfCancellationRequested();
138155

139-
if (symbol != null)
156+
if (symbol == null)
140157
{
141-
executionTime += Stopwatch.GetElapsedTime(executionStart);
158+
return;
159+
}
142160

143-
try
144-
{
145-
var definitionItem = await symbol.ToClassifiedDefinitionItemAsync(
146-
classificationOptions, project.Solution, s_findReferencesSearchOptions, isPrimary: true, includeHiddenLocations: false, cancellationToken).ConfigureAwait(false);
161+
executionTime += Stopwatch.GetElapsedTime(executionStart);
147162

148-
await resultsObserver.OnDefinitionFoundAsync(definitionItem, cancellationToken).ConfigureAwait(false);
149-
}
150-
catch (Exception e) when (FatalError.ReportAndCatchUnlessCanceled(e, cancellationToken))
151-
{
152-
// skip symbol
153-
}
163+
try
164+
{
165+
var definitionItem = await symbol.ToClassifiedDefinitionItemAsync(
166+
classificationOptions, project.Solution, s_findReferencesSearchOptions, isPrimary: true, includeHiddenLocations: false, cancellationToken).ConfigureAwait(false);
154167

155-
executionStart = Stopwatch.GetTimestamp();
168+
await resultsObserver.OnDefinitionFoundAsync(definitionItem, cancellationToken).ConfigureAwait(false);
156169
}
170+
catch (Exception e) when (FatalError.ReportAndCatchUnlessCanceled(e, cancellationToken))
171+
{
172+
// skip symbol
173+
}
174+
175+
executionStart = Stopwatch.GetTimestamp();
157176
}
158177
}
159178
finally

0 commit comments

Comments
 (0)