Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mis-reported cancellation of message transmission #498

Merged
merged 4 commits into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions src/StreamJsonRpc.Tests/JsonRpcJsonHeadersTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.Serialization;
using System.Text;
Expand Down Expand Up @@ -182,7 +182,7 @@ public async Task CanPassExceptionFromServer_ErrorData()
Assert.StrictEqual(COR_E_UNAUTHORIZEDACCESS, errorData.HResult);
}

protected override void InitializeFormattersAndHandlers()
protected override void InitializeFormattersAndHandlers(bool controlledFlushingClient)
{
this.clientMessageFormatter = new JsonMessageFormatter
{
Expand All @@ -208,7 +208,9 @@ protected override void InitializeFormattersAndHandlers()
};

this.serverMessageHandler = new HeaderDelimitedMessageHandler(this.serverStream, this.serverStream, this.serverMessageFormatter);
this.clientMessageHandler = new HeaderDelimitedMessageHandler(this.clientStream, this.clientStream, this.clientMessageFormatter);
this.clientMessageHandler = controlledFlushingClient
? new DelayedFlushingHandler(this.clientStream, this.clientMessageFormatter)
: new HeaderDelimitedMessageHandler(this.clientStream, this.clientStream, this.clientMessageFormatter);
}

[DataContract]
Expand Down Expand Up @@ -267,4 +269,23 @@ public override void WriteJson(JsonWriter writer, object value, JsonSerializer s
writer.WriteValue(encoded);
}
}

private class DelayedFlushingHandler : HeaderDelimitedMessageHandler, IControlledFlushHandler
{
public DelayedFlushingHandler(Stream stream, IJsonRpcMessageFormatter formatter)
: base(stream, formatter)
{
}

public AsyncAutoResetEvent FlushEntered { get; } = new AsyncAutoResetEvent();

public AsyncManualResetEvent AllowFlushAsyncExit { get; } = new AsyncManualResetEvent();

protected override async ValueTask FlushAsync(CancellationToken cancellationToken)
{
this.FlushEntered.Set();
await this.AllowFlushAsyncExit.WaitAsync();
await base.FlushAsync(cancellationToken);
}
}
}
29 changes: 25 additions & 4 deletions src/StreamJsonRpc.Tests/JsonRpcMessagePackLengthTests.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using MessagePack;
using MessagePack.Formatters;
Expand Down Expand Up @@ -53,7 +53,7 @@ public async Task CanPassExceptionFromServer_ErrorData()
Assert.StrictEqual(COR_E_UNAUTHORIZEDACCESS, errorData.HResult);
}

protected override void InitializeFormattersAndHandlers()
protected override void InitializeFormattersAndHandlers(bool controlledFlushingClient)
{
this.serverMessageFormatter = new MessagePackFormatter();
this.clientMessageFormatter = new MessagePackFormatter();
Expand All @@ -66,7 +66,9 @@ protected override void InitializeFormattersAndHandlers()
((MessagePackFormatter)this.clientMessageFormatter).SetMessagePackSerializerOptions(options);

this.serverMessageHandler = new LengthHeaderMessageHandler(this.serverStream, this.serverStream, this.serverMessageFormatter);
this.clientMessageHandler = new LengthHeaderMessageHandler(this.clientStream, this.clientStream, this.clientMessageFormatter);
this.clientMessageHandler = controlledFlushingClient
? new DelayedFlushingHandler(this.clientStream, this.clientMessageFormatter)
: new LengthHeaderMessageHandler(this.clientStream, this.clientStream, this.clientMessageFormatter);
}

private class UnserializableTypeFormatter : IMessagePackFormatter<CustomSerializedType>
Expand Down Expand Up @@ -94,4 +96,23 @@ public void Serialize(ref MessagePackWriter writer, TypeThrowsWhenDeserialized v
writer.WriteArrayHeader(0);
}
}

private class DelayedFlushingHandler : LengthHeaderMessageHandler, IControlledFlushHandler
{
public DelayedFlushingHandler(Stream stream, IJsonRpcMessageFormatter formatter)
: base(stream, stream, formatter)
{
}

public AsyncAutoResetEvent FlushEntered { get; } = new AsyncAutoResetEvent();

public AsyncManualResetEvent AllowFlushAsyncExit { get; } = new AsyncManualResetEvent();

protected override async ValueTask FlushAsync(CancellationToken cancellationToken)
{
this.FlushEntered.Set();
await this.AllowFlushAsyncExit.WaitAsync();
await base.FlushAsync(cancellationToken);
}
}
}
42 changes: 39 additions & 3 deletions src/StreamJsonRpc.Tests/JsonRpcTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@ public JsonRpcTests(ITestOutputHelper logger)
this.clientRpc.StartListening();
}

protected interface IControlledFlushHandler : IJsonRpcMessageHandler
{
/// <summary>
/// Gets an event that is raised when <see cref="MessageHandlerBase.FlushAsync(CancellationToken)"/> is invoked.
/// </summary>
AsyncAutoResetEvent FlushEntered { get; }

/// <summary>
/// Gets an event that must be set before <see cref="MessageHandlerBase.FlushAsync(CancellationToken)"/> is allowed to return.
/// </summary>
AsyncManualResetEvent AllowFlushAsyncExit { get; }
}

private interface IServer
{
[JsonRpcMethod("AnotherName")]
Expand Down Expand Up @@ -763,6 +776,29 @@ public async Task InvokeThenCancelToken()
}
}

[Fact]
public async Task InvokeThenCancelToken_BetweenWriteAndFlush()
{
this.ReinitializeRpcWithoutListening(controlledFlushingClient: true);
var clientMessageHandler = (IControlledFlushHandler)this.clientMessageHandler;

this.clientRpc.StartListening();
this.serverRpc.StartListening();

using (var cts = new CancellationTokenSource())
{
this.server.AllowServerMethodToReturn.Set();
Task<string> invokeTask = this.clientRpc.InvokeWithCancellationAsync<string>(nameof(this.server.AsyncMethod), new[] { "a" }, cts.Token);
await clientMessageHandler.FlushEntered.WaitAsync(this.TimeoutToken);
cts.Cancel();
clientMessageHandler.AllowFlushAsyncExit.Set();
await invokeTask.WithCancellation(this.TimeoutToken);

string result = await this.clientRpc.InvokeWithCancellationAsync<string>(nameof(this.server.AsyncMethod), new[] { "b" }, this.TimeoutToken).WithCancellation(this.TimeoutToken);
Assert.Equal("b!", result);
}
}

[Fact]
[Trait("Category", "SkipWhenLiveUnitTesting")] // flaky test
[Trait("GC", "")]
Expand Down Expand Up @@ -1895,7 +1931,7 @@ protected override void Dispose(bool disposing)
base.Dispose(disposing);
}

protected abstract void InitializeFormattersAndHandlers();
protected abstract void InitializeFormattersAndHandlers(bool controlledFlushingClient = false);

protected override Task CheckGCPressureAsync(Func<Task> scenario, int maxBytesAllocated = -1, int iterations = 100, int allowedAttempts = 10)
{
Expand Down Expand Up @@ -1926,13 +1962,13 @@ private static IEnumerable<CommonErrorData> FlattenCommonErrorData(CommonErrorDa
}
}

private void ReinitializeRpcWithoutListening()
private void ReinitializeRpcWithoutListening(bool controlledFlushingClient = false)
{
var streams = Nerdbank.FullDuplexStream.CreateStreams();
this.serverStream = streams.Item1;
this.clientStream = streams.Item2;

this.InitializeFormattersAndHandlers();
this.InitializeFormattersAndHandlers(controlledFlushingClient);

this.serverRpc = new JsonRpc(this.serverMessageHandler, this.server);
this.clientRpc = new JsonRpc(this.clientMessageHandler);
Expand Down
14 changes: 13 additions & 1 deletion src/StreamJsonRpc/MessageHandlerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ private enum MessageHandlerState
/// <returns>A task that represents the asynchronous operation.</returns>
/// <exception cref="InvalidOperationException">Thrown when <see cref="CanWrite"/> returns <c>false</c>.</exception>
/// <exception cref="OperationCanceledException">Thrown if <paramref name="cancellationToken"/> is canceled before message transmission begins.</exception>
/// <exception cref="ObjectDisposedException">Thrown if this instance is disposed before or during transmission.</exception>
/// <remarks>
/// Implementations should expect this method to be invoked concurrently
/// and use a queue to preserve message order as they are transmitted one at a time.
Expand All @@ -189,7 +190,18 @@ public async ValueTask WriteAsync(JsonRpcMessage content, CancellationToken canc
{
cancellationToken.ThrowIfCancellationRequested();
await this.WriteCoreAsync(content, cancellationToken).ConfigureAwait(false);
await this.FlushAsync(cancellationToken).ConfigureAwait(false);

// When flushing, do NOT honor the caller's CancellationToken since the writing is done
// and we must not throw OperationCanceledException back at them as if we hadn't transmitted it.
// But *do* cancel flushing if we're being disposed.
try
{
await this.FlushAsync(this.DisposalToken).ConfigureAwait(false);
}
catch (OperationCanceledException ex) when (this.DisposalToken.IsCancellationRequested)
{
throw new ObjectDisposedException(this.GetType().FullName, ex);
}
}
finally
{
Expand Down