From 61901d805dd84da0f6b1c71132b236861355acd4 Mon Sep 17 00:00:00 2001 From: Andrew Arnott Date: Mon, 12 Feb 2024 18:34:43 -0700 Subject: [PATCH] Fixes early destruction of `IAsyncEnumerable` sent as return value from RPC methods In certain scenarios, there may have also been effectively a memory leak that this also fixes. Fixes microsoft/vs-streamjsonrpc#999 --- .../MessageFormatterEnumerableTracker.cs | 29 ++++++++--- .../AsyncEnumerableTests.cs | 48 +++++++++++++++++-- 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs index defcf4164..bb38835d9 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs @@ -39,8 +39,15 @@ public class MessageFormatterEnumerableTracker private static readonly MethodInfo OnDisposeAsyncMethodInfo = typeof(MessageFormatterEnumerableTracker).GetMethod(nameof(OnDisposeAsync), BindingFlags.NonPublic | BindingFlags.Instance)!; /// - /// Dictionary used to map the outbound request id to their progress info so that the progress objects are cleaned after getting the final response. + /// Dictionary used to map the outbound request id to the list of tokens that track state machines it owns + /// so that the state machines are cleaned after getting the final response. /// + /// + /// Note that we only track OUTBOUND REQUESTS that carry enumerables here. + /// OUTBOUND RESPONSES that carry enumerables are not tracked except in . + /// This means that responses that carry enumerables will not be cleaned up if the response is never processed by the client + /// until the connection dies. + /// private readonly Dictionary> generatorTokensByRequestId = new Dictionary>(); private readonly Dictionary generatorsByToken = new Dictionary(); @@ -116,12 +123,20 @@ public long GetToken(IAsyncEnumerable enumerable) long handle = Interlocked.Increment(ref this.nextToken); lock (this.syncObject) { - if (!this.generatorTokensByRequestId.TryGetValue(this.formatterState.SerializingMessageWithId, out ImmutableList? tokens)) + // We only track the token if we are serializing a request, since per our documentation, + // we forcibly terminate the enumerable at the client side when the request has been responded to. + // Storing request IDs for outbound *responses* that carry enumerables would lead to them being disposed of + // when an INBOUND response with the same ID is received. + if (this.formatterState.SerializingRequest) { - tokens = ImmutableList.Empty; + if (!this.generatorTokensByRequestId.TryGetValue(this.formatterState.SerializingMessageWithId, out ImmutableList? tokens)) + { + tokens = ImmutableList.Empty; + } + + this.generatorTokensByRequestId[this.formatterState.SerializingMessageWithId] = tokens.Add(handle); } - this.generatorTokensByRequestId[this.formatterState.SerializingMessageWithId] = tokens.Add(handle); this.generatorsByToken.Add(handle, new GeneratingEnumeratorTracker(this, handle, enumerable, settings: enumerable.GetJsonRpcSettings())); } @@ -173,18 +188,18 @@ private ValueTask OnDisposeAsync(long token) return generator.DisposeAsync(); } - private void CleanUpResources(RequestId requestId) + private void CleanUpResources(RequestId outboundRequestId) { lock (this.syncObject) { - if (this.generatorTokensByRequestId.TryGetValue(requestId, out ImmutableList? tokens)) + if (this.generatorTokensByRequestId.TryGetValue(outboundRequestId, out ImmutableList? tokens)) { foreach (var token in tokens) { this.generatorsByToken.Remove(token); } - this.generatorTokensByRequestId.Remove(requestId); + this.generatorTokensByRequestId.Remove(outboundRequestId); } } } diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs index 6cbe8a40c..7d884573e 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs @@ -11,17 +11,17 @@ using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; using Newtonsoft.Json; -using StreamJsonRpc; -using Xunit; -using Xunit.Abstractions; public abstract class AsyncEnumerableTests : TestBase, IAsyncLifetime { - protected readonly Server server = new Server(); + protected readonly Server server = new(); + protected readonly Client client = new(); + protected JsonRpc serverRpc; protected IJsonRpcMessageFormatter serverMessageFormatter; protected Lazy clientProxy; + protected Lazy serverProxy; protected JsonRpc clientRpc; protected IJsonRpcMessageFormatter clientMessageFormatter; @@ -73,6 +73,13 @@ protected interface IServer Task PassInNumbersAndIgnoreAsync(IAsyncEnumerable numbers, CancellationToken cancellationToken); Task PassInNumbersOnlyStartEnumerationAsync(IAsyncEnumerable numbers, CancellationToken cancellationToken); + + IAsyncEnumerable CallbackClientAndYieldOneValueAsync(CancellationToken cancellationToken); + } + + protected interface IClient + { + Task DoSomethingAsync(CancellationToken cancellationToken); } public Task InitializeAsync() @@ -85,7 +92,7 @@ public Task InitializeAsync() var clientHandler = new LengthHeaderMessageHandler(streams.Item2.UsePipe(), this.clientMessageFormatter); this.serverRpc = new JsonRpc(serverHandler, this.server); - this.clientRpc = new JsonRpc(clientHandler); + this.clientRpc = new JsonRpc(clientHandler, this.client); this.serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Verbose); this.clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Verbose); @@ -97,6 +104,7 @@ public Task InitializeAsync() this.clientRpc.StartListening(); this.clientProxy = new Lazy(() => this.clientRpc.Attach()); + this.serverProxy = new Lazy(() => this.serverRpc.Attach()); return Task.CompletedTask; } @@ -530,6 +538,17 @@ public async Task AsyncIteratorThrows(int minBatchSize, int maxReadAhead, int pr Assert.Equal(Server.FailByDesignExceptionMessage, ex.Message); } + [Fact] + public async Task EnumerableIdDisposal() + { + // This test is specially arranged to create two RPC calls going opposite directions, with the same request ID. + // By doing so, we can verify that the server doesn't dispose the enumerable until the full sequence is sent to the client. + this.server.Client = this.serverProxy.Value; + await foreach (string s in this.clientProxy.Value.CallbackClientAndYieldOneValueAsync(this.TimeoutToken)) + { + } + } + protected abstract void InitializeFormattersAndHandlers(); private static void AssertCollectedObject(WeakReference weakReference) @@ -621,6 +640,8 @@ protected class Server : IServer internal const string FailByDesignExceptionMessage = "Fail by design"; + public IClient? Client { get; set; } + public AsyncManualResetEvent MethodEntered { get; } = new AsyncManualResetEvent(); public AsyncManualResetEvent MethodExited { get; } = new AsyncManualResetEvent(); @@ -745,6 +766,18 @@ public Task GetNumbersAndMetadataAsync(CancellationTok }); } + public async IAsyncEnumerable CallbackClientAndYieldOneValueAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + if (this.Client is null) + { + throw new InvalidOperationException("Client must be set before calling this method."); + } + + // We deliberately make a callback right away such that the request ID for it collides with the request ID that served THIS request. + await this.Client.DoSomethingAsync(cancellationToken); + yield return "Hello"; + } + private async IAsyncEnumerable GetNumbersAsync(int totalCount, bool endWithException, [EnumeratorCancellation] CancellationToken cancellationToken) { for (int i = 1; i <= totalCount; i++) @@ -763,6 +796,11 @@ private async IAsyncEnumerable GetNumbersAsync(int totalCount, bool endWith } } + protected class Client : IClient + { + public Task DoSomethingAsync(CancellationToken cancellationToken) => Task.CompletedTask; + } + [DataContract] protected class CompoundEnumerableResult {