Skip to content

Commit 9acfcbd

Browse files
authored
Split query execution into compile and execute calls (#78081)
1 parent 605d934 commit 9acfcbd

File tree

8 files changed

+338
-218
lines changed

8 files changed

+338
-218
lines changed

src/Features/CSharpTest/SemanticSearch/CSharpSemanticSearchServiceTests.cs

Lines changed: 58 additions & 125 deletions
Large diffs are not rendered by default.

src/Features/Core/Portable/SemanticSearch/AbstractSemanticSearchService.cs

Lines changed: 87 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,19 @@ protected override IntPtr LoadUnmanagedDll(string unmanagedDllName)
5454
=> IntPtr.Zero;
5555
}
5656

57+
private readonly struct CompiledQuery(MemoryStream peStream, MemoryStream pdbStream, SourceText text) : IDisposable
58+
{
59+
public MemoryStream PEStream { get; } = peStream;
60+
public MemoryStream PdbStream { get; } = pdbStream;
61+
public SourceText Text { get; } = text;
62+
63+
public void Dispose()
64+
{
65+
PEStream.Dispose();
66+
PdbStream.Dispose();
67+
}
68+
}
69+
5770
/// <summary>
5871
/// Mapping from the parameter type of the <c>Find</c> method to the <see cref="QueryKind"/> value.
5972
/// </summary>
@@ -66,76 +79,93 @@ protected override IntPtr LoadUnmanagedDll(string unmanagedDllName)
6679
.Add(typeof(IPropertySymbol), QueryKind.Property)
6780
.Add(typeof(IEventSymbol), QueryKind.Event);
6881

82+
private ImmutableDictionary<CompiledQueryId, CompiledQuery> _compiledQueries = ImmutableDictionary<CompiledQueryId, CompiledQuery>.Empty;
83+
6984
protected abstract Compilation CreateCompilation(SourceText query, IEnumerable<MetadataReference> references, SolutionServices services, out SyntaxTree queryTree, CancellationToken cancellationToken);
7085

71-
public async Task<ExecuteQueryResult> ExecuteQueryAsync(
72-
Solution solution,
86+
public CompileQueryResult CompileQuery(
87+
SolutionServices services,
7388
string query,
7489
string referenceAssembliesDir,
75-
ISemanticSearchResultsObserver observer,
76-
OptionsProvider<ClassificationOptions> classificationOptions,
7790
TraceSource traceSource,
7891
CancellationToken cancellationToken)
7992
{
80-
try
81-
{
82-
// add progress items - one for compilation, one for emit and one for each project:
83-
var remainingProgressItemCount = 2 + solution.ProjectIds.Count;
84-
await observer.AddItemsAsync(remainingProgressItemCount, cancellationToken).ConfigureAwait(false);
85-
86-
var metadataService = solution.Services.GetRequiredService<IMetadataService>();
87-
var metadataReferences = SemanticSearchUtilities.GetMetadataReferences(metadataService, referenceAssembliesDir);
88-
var queryText = SemanticSearchUtilities.CreateSourceText(query);
89-
var queryCompilation = CreateCompilation(queryText, metadataReferences, solution.Services, out var queryTree, cancellationToken);
93+
var metadataService = services.GetRequiredService<IMetadataService>();
94+
var metadataReferences = SemanticSearchUtilities.GetMetadataReferences(metadataService, referenceAssembliesDir);
95+
var queryText = SemanticSearchUtilities.CreateSourceText(query);
96+
var queryCompilation = CreateCompilation(queryText, metadataReferences, services, out var queryTree, cancellationToken);
9097

91-
cancellationToken.ThrowIfCancellationRequested();
98+
cancellationToken.ThrowIfCancellationRequested();
9299

93-
// complete compilation progress item:
94-
remainingProgressItemCount--;
95-
await observer.ItemsCompletedAsync(1, cancellationToken).ConfigureAwait(false);
100+
var emitOptions = new EmitOptions(
101+
debugInformationFormat: DebugInformationFormat.PortablePdb,
102+
instrumentationKinds: [InstrumentationKind.StackOverflowProbing, InstrumentationKind.ModuleCancellation]);
96103

97-
var emitOptions = new EmitOptions(
98-
debugInformationFormat: DebugInformationFormat.PortablePdb,
99-
instrumentationKinds: [InstrumentationKind.StackOverflowProbing, InstrumentationKind.ModuleCancellation]);
104+
var peStream = new MemoryStream();
105+
var pdbStream = new MemoryStream();
100106

101-
using var peStream = new MemoryStream();
102-
using var pdbStream = new MemoryStream();
103-
104-
var emitDifferenceTimer = SharedStopwatch.StartNew();
105-
var emitResult = queryCompilation.Emit(peStream, pdbStream, options: emitOptions, cancellationToken: cancellationToken);
106-
var emitTime = emitDifferenceTimer.Elapsed;
107-
108-
var executionTime = TimeSpan.Zero;
107+
var emitDifferenceTimer = SharedStopwatch.StartNew();
108+
var emitResult = queryCompilation.Emit(peStream, pdbStream, options: emitOptions, cancellationToken: cancellationToken);
109+
var emitTime = emitDifferenceTimer.Elapsed;
109110

110-
cancellationToken.ThrowIfCancellationRequested();
111+
CompiledQueryId queryId;
112+
ImmutableArray<QueryCompilationError> errors;
113+
if (emitResult.Success)
114+
{
115+
queryId = CompiledQueryId.Create(queryCompilation.Language);
116+
Contract.ThrowIfFalse(ImmutableInterlocked.TryAdd(ref _compiledQueries, queryId, new CompiledQuery(peStream, pdbStream, queryText)));
111117

112-
// complete compilation progress item:
113-
remainingProgressItemCount--;
114-
await observer.ItemsCompletedAsync(1, cancellationToken).ConfigureAwait(false);
118+
errors = [];
119+
}
120+
else
121+
{
122+
queryId = default;
115123

116-
if (!emitResult.Success)
124+
foreach (var diagnostic in emitResult.Diagnostics)
117125
{
118-
foreach (var diagnostic in emitResult.Diagnostics)
126+
if (diagnostic.Severity == DiagnosticSeverity.Error)
119127
{
120-
if (diagnostic.Severity == DiagnosticSeverity.Error)
121-
{
122-
traceSource.TraceInformation($"Semantic search query compilation failed: {diagnostic}");
123-
}
128+
traceSource.TraceInformation($"Semantic search query compilation failed: {diagnostic}");
124129
}
130+
}
125131

126-
var errors = emitResult.Diagnostics.SelectAsArray(
127-
d => d.Severity == DiagnosticSeverity.Error,
128-
d => new QueryCompilationError(d.Id, d.GetMessage(), (d.Location.SourceTree == queryTree) ? d.Location.SourceSpan : default));
132+
errors = emitResult.Diagnostics.SelectAsArray(
133+
d => d.Severity == DiagnosticSeverity.Error,
134+
d => new QueryCompilationError(d.Id, d.GetMessage(), (d.Location.SourceTree == queryTree) ? d.Location.SourceSpan : default));
135+
}
129136

130-
return CreateResult(errors, FeaturesResources.Semantic_search_query_failed_to_compile);
131-
}
137+
return new CompileQueryResult(queryId, errors, emitTime);
138+
}
139+
140+
public void DiscardQuery(CompiledQueryId queryId)
141+
{
142+
Contract.ThrowIfFalse(ImmutableInterlocked.TryRemove(ref _compiledQueries, queryId, out var compiledQuery));
143+
compiledQuery.Dispose();
144+
}
132145

133-
peStream.Position = 0;
134-
pdbStream.Position = 0;
146+
public async Task<ExecuteQueryResult> ExecuteQueryAsync(
147+
Solution solution,
148+
CompiledQueryId queryId,
149+
ISemanticSearchResultsObserver observer,
150+
OptionsProvider<ClassificationOptions> classificationOptions,
151+
TraceSource traceSource,
152+
CancellationToken cancellationToken)
153+
{
154+
Contract.ThrowIfFalse(ImmutableInterlocked.TryRemove(ref _compiledQueries, queryId, out var query));
155+
156+
try
157+
{
158+
var executionTime = TimeSpan.Zero;
159+
160+
var remainingProgressItemCount = solution.ProjectIds.Count;
161+
await observer.AddItemsAsync(remainingProgressItemCount, cancellationToken).ConfigureAwait(false);
162+
163+
query.PEStream.Position = 0;
164+
query.PdbStream.Position = 0;
135165
var loadContext = new LoadContext();
136166
try
137167
{
138-
var queryAssembly = loadContext.LoadFromStream(peStream, pdbStream);
168+
var queryAssembly = loadContext.LoadFromStream(query.PEStream, query.PdbStream);
139169
SetModuleCancellationToken(queryAssembly, cancellationToken);
140170

141171
SetToolImplementations(
@@ -146,17 +176,17 @@ public async Task<ExecuteQueryResult> ExecuteQueryAsync(
146176
if (!TryGetFindMethod(queryAssembly, out var findMethod, out var queryKind, out var errorMessage, out var errorMessageArgs))
147177
{
148178
traceSource.TraceInformation($"Semantic search failed: {errorMessage}");
149-
return CreateResult(compilationErrors: [], errorMessage, errorMessageArgs);
179+
return CreateResult(errorMessage, errorMessageArgs);
150180
}
151181

152-
var invocationContext = new QueryExecutionContext(queryText, findMethod, observer, classificationOptions, traceSource);
182+
var invocationContext = new QueryExecutionContext(query.Text, findMethod, observer, classificationOptions, traceSource);
153183
try
154184
{
155185
await invocationContext.InvokeAsync(solution, queryKind, cancellationToken).ConfigureAwait(false);
156186

157187
if (invocationContext.TerminatedWithException)
158188
{
159-
return CreateResult(compilationErrors: [], FeaturesResources.Semantic_search_query_terminated_with_exception);
189+
return CreateResult(FeaturesResources.Semantic_search_query_terminated_with_exception);
160190
}
161191
}
162192
finally
@@ -176,15 +206,19 @@ public async Task<ExecuteQueryResult> ExecuteQueryAsync(
176206
}
177207
}
178208

179-
return CreateResult(compilationErrors: [], errorMessage: null);
209+
return CreateResult(errorMessage: null);
180210

181-
ExecuteQueryResult CreateResult(ImmutableArray<QueryCompilationError> compilationErrors, string? errorMessage, params string[]? args)
182-
=> new(compilationErrors, errorMessage, args, emitTime, executionTime);
211+
ExecuteQueryResult CreateResult(string? errorMessage, params string[]? args)
212+
=> new(errorMessage, args, executionTime);
183213
}
184214
catch (Exception e) when (FatalError.ReportAndPropagateUnlessCanceled(e, cancellationToken, ErrorSeverity.Critical))
185215
{
186216
throw ExceptionUtilities.Unreachable();
187217
}
218+
finally
219+
{
220+
query.Dispose();
221+
}
188222
}
189223

190224
private static void SetModuleCancellationToken(Assembly queryAssembly, CancellationToken cancellationToken)

src/Features/Core/Portable/SemanticSearch/ExecuteQueryResult.cs

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,59 @@
55
using System;
66
using System.Collections.Immutable;
77
using System.Runtime.Serialization;
8+
using System.Threading;
89

910
namespace Microsoft.CodeAnalysis.SemanticSearch;
1011

1112
/// <summary>
1213
/// The result of Semantic Search query execution.
1314
/// </summary>
14-
/// <param name="compilationErrors">Compilation errors.</param>
1515
/// <param name="ErrorMessage">An error message if the execution failed.</param>
1616
/// <param name="ErrorMessageArgs">
1717
/// Arguments to be substituted to <paramref name="ErrorMessage"/>.
1818
/// Use when the values may contain PII that needs to be obscured in telemetry.
1919
/// Otherwise, <paramref name="ErrorMessage"/> should contain the formatted message.
2020
/// </param>
21-
/// <param name="EmitTime">Time it took to emit the query compilation.</param>
2221
/// <param name="ExecutionTime">Time it took to execute the query.</param>
2322
[DataContract]
2423
internal readonly record struct ExecuteQueryResult(
25-
[property: DataMember(Order = 0)] ImmutableArray<QueryCompilationError> compilationErrors,
26-
[property: DataMember(Order = 1)] string? ErrorMessage,
27-
[property: DataMember(Order = 2)] string[]? ErrorMessageArgs = null,
28-
[property: DataMember(Order = 3)] TimeSpan EmitTime = default,
29-
[property: DataMember(Order = 4)] TimeSpan ExecutionTime = default);
24+
[property: DataMember(Order = 0)] string? ErrorMessage,
25+
[property: DataMember(Order = 1)] string[]? ErrorMessageArgs = null,
26+
[property: DataMember(Order = 2)] TimeSpan ExecutionTime = default);
27+
28+
/// <summary>
29+
/// The result of Semantic Search query compilation.
30+
/// </summary>
31+
/// <param name="QueryId">Id of the compiled query if the compilation was successful.</param>
32+
/// <param name="CompilationErrors">Compilation errors.</param>
33+
/// <param name="EmitTime">Time it took to emit the query compilation.</param>
34+
[DataContract]
35+
internal readonly record struct CompileQueryResult(
36+
[property: DataMember(Order = 0)] CompiledQueryId QueryId,
37+
[property: DataMember(Order = 1)] ImmutableArray<QueryCompilationError> CompilationErrors,
38+
[property: DataMember(Order = 2)] TimeSpan EmitTime = default);
39+
40+
[DataContract]
41+
internal readonly record struct CompiledQueryId
42+
{
43+
private static int s_id;
44+
45+
[DataMember(Order = 0)]
46+
#pragma warning disable IDE0052 // Remove unread private members (https://github.com/dotnet/roslyn/issues/77907)
47+
private readonly int _id;
48+
#pragma warning restore IDE0052
49+
50+
[DataMember(Order = 1)]
51+
#pragma warning disable IDE0052 // Remove unread private members (https://github.com/dotnet/roslyn/issues/77907)
52+
public readonly string Language;
53+
#pragma warning restore IDE0052
54+
55+
private CompiledQueryId(int id, string language)
56+
{
57+
_id = id;
58+
Language = language;
59+
}
60+
61+
public static CompiledQueryId Create(string language)
62+
=> new(Interlocked.Increment(ref s_id), language);
63+
}

src/Features/Core/Portable/SemanticSearch/IRemoteSemanticSearchService.cs

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using Microsoft.CodeAnalysis.Classification;
1111
using Microsoft.CodeAnalysis.ErrorReporting;
1212
using Microsoft.CodeAnalysis.FindUsages;
13+
using Microsoft.CodeAnalysis.Host;
1314
using Microsoft.CodeAnalysis.Host.Mef;
1415
using Microsoft.CodeAnalysis.Remote;
1516

@@ -26,7 +27,9 @@ internal interface ICallback
2627
ValueTask ItemsCompletedAsync(RemoteServiceCallbackId callbackId, int itemCount, CancellationToken cancellationToken);
2728
}
2829

29-
ValueTask<ExecuteQueryResult> ExecuteQueryAsync(Checksum solutionChecksum, RemoteServiceCallbackId callbackId, string language, string query, string referenceAssembliesDir, CancellationToken cancellationToken);
30+
ValueTask<CompileQueryResult> CompileQueryAsync(string query, string language, string referenceAssembliesDir, CancellationToken cancellationToken);
31+
ValueTask<ExecuteQueryResult> ExecuteQueryAsync(Checksum solutionChecksum, RemoteServiceCallbackId callbackId, CompiledQueryId queryId, CancellationToken cancellationToken);
32+
ValueTask DiscardQueryAsync(CompiledQueryId queryId, CancellationToken cancellationToken);
3033
}
3134

3235
internal static class RemoteSemanticSearchServiceProxy
@@ -112,19 +115,41 @@ public async ValueTask<ClassificationOptions> GetClassificationOptionsAsync(stri
112115
}
113116
}
114117

115-
public static async ValueTask<ExecuteQueryResult> ExecuteQueryAsync(Solution solution, string language, string query, string referenceAssembliesDir, ISemanticSearchResultsObserver results, OptionsProvider<ClassificationOptions> classificationOptions, CancellationToken cancellationToken)
118+
public static async ValueTask<CompileQueryResult?> CompileQueryAsync(SolutionServices services, string query, string language, string referenceAssembliesDir, CancellationToken cancellationToken)
116119
{
117-
var client = await RemoteHostClient.TryGetClientAsync(solution.Services, cancellationToken).ConfigureAwait(false);
120+
var client = await RemoteHostClient.TryGetClientAsync(services, cancellationToken).ConfigureAwait(false);
118121
if (client == null)
119122
{
120-
return new ExecuteQueryResult(compilationErrors: [], FeaturesResources.Semantic_search_only_supported_on_net_core);
123+
return null;
121124
}
122125

126+
var result = await client.TryInvokeAsync<IRemoteSemanticSearchService, CompileQueryResult>(
127+
(service, cancellationToken) => service.CompileQueryAsync(query, language, referenceAssembliesDir, cancellationToken),
128+
cancellationToken).ConfigureAwait(false);
129+
130+
return result.Value;
131+
}
132+
133+
public static async ValueTask DiscardQueryAsync(SolutionServices services, CompiledQueryId queryId, CancellationToken cancellationToken)
134+
{
135+
var client = await RemoteHostClient.TryGetClientAsync(services, cancellationToken).ConfigureAwait(false);
136+
Contract.ThrowIfNull(client);
137+
138+
await client.TryInvokeAsync<IRemoteSemanticSearchService>(
139+
(service, cancellationToken) => service.DiscardQueryAsync(queryId, cancellationToken),
140+
cancellationToken).ConfigureAwait(false);
141+
}
142+
143+
public static async ValueTask<ExecuteQueryResult> ExecuteQueryAsync(Solution solution, CompiledQueryId queryId, ISemanticSearchResultsObserver results, OptionsProvider<ClassificationOptions> classificationOptions, CancellationToken cancellationToken)
144+
{
145+
var client = await RemoteHostClient.TryGetClientAsync(solution.Services, cancellationToken).ConfigureAwait(false);
146+
Contract.ThrowIfNull(client);
147+
123148
var serverCallback = new ServerCallback(solution, results, classificationOptions);
124149

125150
var result = await client.TryInvokeAsync<IRemoteSemanticSearchService, ExecuteQueryResult>(
126151
solution,
127-
(service, solutionInfo, callbackId, cancellationToken) => service.ExecuteQueryAsync(solutionInfo, callbackId, language, query, referenceAssembliesDir, cancellationToken),
152+
(service, solutionInfo, callbackId, cancellationToken) => service.ExecuteQueryAsync(solutionInfo, callbackId, queryId, cancellationToken),
128153
callbackTarget: serverCallback,
129154
cancellationToken).ConfigureAwait(false);
130155

src/Features/Core/Portable/SemanticSearch/ISemanticSearchService.cs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,36 @@ namespace Microsoft.CodeAnalysis.SemanticSearch;
1313
internal interface ISemanticSearchService : ILanguageService
1414
{
1515
/// <summary>
16-
/// Executes given <paramref name="query"/> query against <paramref name="solution"/>.
16+
/// Compiles a query. The query has to be executed or discarded.
1717
/// </summary>
18-
/// <param name="solution">The solution snapshot.</param>
1918
/// <param name="query">Query (top-level code).</param>
2019
/// <param name="referenceAssembliesDir">Directory that contains refernece assemblies to be used for compilation of the query.</param>
20+
CompileQueryResult CompileQuery(
21+
SolutionServices services,
22+
string query,
23+
string referenceAssembliesDir,
24+
TraceSource traceSource,
25+
CancellationToken cancellationToken);
26+
27+
/// <summary>
28+
/// Executes given query against <paramref name="solution"/> and discards it.
29+
/// </summary>
30+
/// <param name="solution">The solution snapshot.</param>
31+
/// <param name="queryId">Id of a compiled query.</param>
2132
/// <param name="observer">Observer of the found symbols.</param>
2233
/// <param name="classificationOptions">Options to use to classify the textual representation of the found symbols.</param>
2334
/// <param name="cancellationToken">Cancellation token.</param>
24-
/// <returns>Error message on failure.</returns>
2535
Task<ExecuteQueryResult> ExecuteQueryAsync(
2636
Solution solution,
27-
string query,
28-
string referenceAssembliesDir,
37+
CompiledQueryId queryId,
2938
ISemanticSearchResultsObserver observer,
3039
OptionsProvider<ClassificationOptions> classificationOptions,
3140
TraceSource traceSource,
3241
CancellationToken cancellationToken);
42+
43+
/// <summary>
44+
/// Discards resources associated with compiled query.
45+
/// Only call if the query is not executed.
46+
/// </summary>
47+
void DiscardQuery(CompiledQueryId queryId);
3348
}

0 commit comments

Comments
 (0)