Skip to content

Commit

Permalink
Add ICancellationStrategy.OutboundRequestEnded
Browse files Browse the repository at this point in the history
This is important to support strategy implementations that create a file in `CancelOutboundRequest` and must delete it when the cancellation request has completed.
  • Loading branch information
AArnott committed Sep 4, 2020
1 parent d0796bc commit d535fbf
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 21 deletions.
109 changes: 99 additions & 10 deletions src/StreamJsonRpc.Tests/CustomCancellationStrategyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.Threading;
Expand All @@ -25,7 +23,7 @@ public class CustomCancellationStrategyTests : TestBase
public CustomCancellationStrategyTests(ITestOutputHelper logger)
: base(logger)
{
this.mockStrategy = new MockCancellationStrategy(logger);
this.mockStrategy = new MockCancellationStrategy(this, logger);

var streams = FullDuplexStream.CreatePair();
this.clientRpc = new JsonRpc(streams.Item1)
Expand Down Expand Up @@ -55,48 +53,119 @@ public CustomCancellationStrategyTests(ITestOutputHelper logger)
/// Verifies that cancellation can occur through a custom strategy.
/// </summary>
[Fact]
public async Task CancelRequest()
public async Task CancelRequest_ServerMethodReturns()
{
using var cts = new CancellationTokenSource();
Task invokeTask = this.clientRpc.InvokeWithCancellationAsync(nameof(Server.NoticeCancellationAsync), cancellationToken: cts.Token);
Task invokeTask = this.clientRpc.InvokeWithCancellationAsync(nameof(Server.NoticeCancellationAsync), new object?[] { false }, cancellationToken: cts.Token);
var completingTask = await Task.WhenAny(invokeTask, this.server.MethodEntered.WaitAsync()).WithCancellation(this.TimeoutToken);
await completingTask; // rethrow an exception if there is one.

cts.Cancel();
await invokeTask.WithCancellation(this.TimeoutToken);
Assert.True(this.mockStrategy.CancelRequestMade);
await this.mockStrategy.OutboundRequestEndedInvoked.WaitAsync(this.TimeoutToken);
}

/// <summary>
/// Verifies that cancellation can occur through a custom strategy.
/// </summary>
[Fact]
public async Task CancelRequest_ServerMethodThrows()
{
using var cts = new CancellationTokenSource();
Task invokeTask = this.clientRpc.InvokeWithCancellationAsync(nameof(Server.NoticeCancellationAsync), new object?[] { true }, cancellationToken: cts.Token);
var completingTask = await Task.WhenAny(invokeTask, this.server.MethodEntered.WaitAsync()).WithCancellation(this.TimeoutToken);
await completingTask; // rethrow an exception if there is one.

cts.Cancel();
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => invokeTask).WithCancellation(this.TimeoutToken);
Assert.True(this.mockStrategy.CancelRequestMade);
await this.mockStrategy.OutboundRequestEndedInvoked.WaitAsync(this.TimeoutToken);
}

[Fact]
public async Task UncanceledRequest_GetsNoClientSideInvocations()
{
using var cts = new CancellationTokenSource();
await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.EmptyMethod), cancellationToken: cts.Token);

Assert.False(this.mockStrategy.CancelRequestMade);
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => this.mockStrategy.OutboundRequestEndedInvoked.WaitAsync(ExpectedTimeoutToken));
}

/// <summary>
/// This test attempts to force the timing issue where <see cref="ICancellationStrategy.OutboundRequestEnded(RequestId)"/>
/// is called before <see cref="ICancellationStrategy.CancelOutboundRequest(RequestId)"/> has finished executing
/// as this would be most undesirable for any implementation that needs to clean up state in the former that is created by the latter.
/// </summary>
[Fact]
public async Task OutboundCancellationStartAndRequestFinishOverlap()
{
using var cts = new CancellationTokenSource();
Task invokeTask = this.clientRpc.InvokeWithCancellationAsync(nameof(Server.EmptyMethod), cancellationToken: cts.Token);
var completingTask = await Task.WhenAny(invokeTask, this.server.MethodEntered.WaitAsync()).WithCancellation(this.TimeoutToken);
await completingTask; // rethrow an exception if there is one.

this.mockStrategy.AllowCancelOutboundRequestToExit.Reset();
cts.Cancel();

// This may be invoked, but if the product doesn't invoke it, that's ok too.
////await this.mockStrategy.OutboundRequestEndedInvoked.WaitAsync(this.TimeoutToken);
}

private class Server
{
internal AsyncManualResetEvent MethodEntered { get; } = new AsyncManualResetEvent();

public async Task NoticeCancellationAsync(CancellationToken cancellationToken)
public async Task NoticeCancellationAsync(bool throwOnCanceled, CancellationToken cancellationToken)
{
this.MethodEntered.Set();
var canceled = new AsyncManualResetEvent();
using (cancellationToken.Register(canceled.Set))
{
await canceled.WaitAsync();
if (throwOnCanceled)
{
cancellationToken.ThrowIfCancellationRequested();
}
}
}

public void EmptyMethod(CancellationToken cancellationToken)
{
this.MethodEntered.Set();
}
}

private class MockCancellationStrategy : ICancellationStrategy
{
private readonly Dictionary<RequestId, CancellationTokenSource> cancelableRequests = new Dictionary<RequestId, CancellationTokenSource>();
private readonly CustomCancellationStrategyTests owner;
private readonly ITestOutputHelper logger;
private readonly List<RequestId> endedRequestIds = new List<RequestId>();

internal MockCancellationStrategy(ITestOutputHelper logger)
internal MockCancellationStrategy(CustomCancellationStrategyTests owner, ITestOutputHelper logger)
{
this.owner = owner;
this.logger = logger;
}

internal bool CancelRequestMade { get; private set; }

internal AsyncAutoResetEvent OutboundRequestEndedInvoked { get; } = new AsyncAutoResetEvent();

internal ManualResetEventSlim AllowCancelOutboundRequestToExit { get; } = new ManualResetEventSlim(initialState: true);

public void CancelOutboundRequest(RequestId requestId)
{
this.logger.WriteLine("Canceling outbound request: {0}", requestId);
this.logger.WriteLine($"{nameof(this.CancelOutboundRequest)}({requestId})");
lock (this.endedRequestIds)
{
// We should NEVER be called about a request ID that has already ended.
// If so, the product can cause a custom cancellation strategy to leak state.
Assert.DoesNotContain(requestId, this.endedRequestIds);
}

CancellationTokenSource? cts;
lock (this.cancelableRequests)
{
Expand All @@ -105,11 +174,31 @@ public void CancelOutboundRequest(RequestId requestId)

cts?.Cancel();
this.CancelRequestMade = true;

// Wait for the out of order invocation to happen if it's possible,
// so the OutboundCancellationStartAndRequestFinishOverlap test can catch it.
// Otherwise timeout, which is necessary to avoid a test hang when the product DOES work,
// since it shouldn't allow OutboundRequestEnded to execute before this method exits.
if (!this.AllowCancelOutboundRequestToExit.Wait(ExpectedTimeout))
{
this.logger.WriteLine("Timed out waiting for " + nameof(this.AllowCancelOutboundRequestToExit) + " to be signaled (good thing).");
}
}

public void OutboundRequestEnded(RequestId requestId)
{
this.logger.WriteLine($"{nameof(this.OutboundRequestEnded)}({requestId}) invoked.");
lock (this.endedRequestIds)
{
this.endedRequestIds.Add(requestId);
}

this.OutboundRequestEndedInvoked.Set();
}

public void IncomingRequestStarted(RequestId requestId, CancellationTokenSource cancellationTokenSource)
{
this.logger.WriteLine("Recognizing incoming request start: {0}", requestId);
this.logger.WriteLine($"{nameof(this.IncomingRequestStarted)}({requestId})");
lock (this.cancelableRequests)
{
this.cancelableRequests.Add(requestId, cancellationTokenSource);
Expand All @@ -118,7 +207,7 @@ public void IncomingRequestStarted(RequestId requestId, CancellationTokenSource

public void IncomingRequestEnded(RequestId requestId)
{
this.logger.WriteLine("Recognizing incoming request end: {0}", requestId);
this.logger.WriteLine($"{nameof(this.IncomingRequestEnded)}({requestId})");
lock (this.cancelableRequests)
{
this.cancelableRequests.Remove(requestId);
Expand Down
19 changes: 17 additions & 2 deletions src/StreamJsonRpc/ICancellationStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,25 @@ public interface ICancellationStrategy
/// the RPC server can understand.
/// </summary>
/// <param name="requestId">The ID of the canceled request.</param>
/// <remarks>
/// Every call to this method is followed by a subsequent call to <see cref="OutboundRequestEnded(RequestId)"/>.
/// </remarks>
void CancelOutboundRequest(RequestId requestId);

/// <summary>
/// Reports an incoming request and the <see cref="CancellationTokenSource"/> that is assigned to it.
/// Cleans up any state associated with an earlier <see cref="CancelOutboundRequest(RequestId)"/> call.
/// </summary>
/// <param name="requestId">The ID of the canceled request.</param>
/// <remarks>
/// This method is invoked by <see cref="JsonRpc"/> when the response to a canceled request has been received.
/// It *may* be invoked for requests for which a prior call to <see cref="CancelOutboundRequest(RequestId)"/> was *not* made, due to timing.
/// But it should never be invoked concurrently with <see cref="CancelOutboundRequest(RequestId)"/> for the same <see cref="RequestId"/>.
/// </remarks>
void OutboundRequestEnded(RequestId requestId);

/// <summary>
/// Associates the <see cref="RequestId"/> from an incoming request with the <see cref="CancellationTokenSource"/>
/// that is used for the <see cref="CancellationToken"/> passed to that RPC method so it can be canceled later.
/// </summary>
/// <param name="requestId">The ID of the incoming request.</param>
/// <param name="cancellationTokenSource">A means to cancel the <see cref="CancellationToken"/> that will be used when invoking the RPC server method.</param>
Expand All @@ -34,7 +49,7 @@ public interface ICancellationStrategy
void IncomingRequestStarted(RequestId requestId, CancellationTokenSource cancellationTokenSource);

/// <summary>
/// Reports that an incoming request is no longer a candidate for cancellation.
/// Cleans up any state associated with an earlier <see cref="IncomingRequestStarted(RequestId, CancellationTokenSource)"/> call.
/// </summary>
/// <param name="requestId">The ID of the request that has been fulfilled.</param>
/// <remarks>
Expand Down
16 changes: 13 additions & 3 deletions src/StreamJsonRpc/JsonRpc.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1866,10 +1866,20 @@ private static bool TryGetTaskFromValueTask(object? result, [NotNullWhen(true)]
}

// Arrange for sending a cancellation message if canceled while we're waiting for a response.
using (cancellationToken.Register(this.cancelPendingOutboundRequestAction, request.RequestId, useSynchronizationContext: false))
try
{
using (cancellationToken.Register(this.cancelPendingOutboundRequestAction, request.RequestId, useSynchronizationContext: false))
{
// This task will be completed when the Response object comes back from the other end of the pipe
return await tcs.Task.ConfigureAwait(false);
}
}
finally
{
// This task will be completed when the Response object comes back from the other end of the pipe
return await tcs.Task.ConfigureAwait(false);
if (cancellationToken.IsCancellationRequested)
{
this.CancellationStrategy?.OutboundRequestEnded(request.RequestId);
}
}
}
}
Expand Down
14 changes: 8 additions & 6 deletions src/StreamJsonRpc/StandardCancellationStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@ namespace StreamJsonRpc
using System.Collections.Generic;
using System.Diagnostics;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft;
using Microsoft.VisualStudio.Threading;
using StreamJsonRpc.Protocol;

internal class StandardCancellationStrategy : ICancellationStrategy
{
Expand Down Expand Up @@ -56,7 +53,7 @@ public StandardCancellationStrategy(JsonRpc jsonRpc)
internal JsonRpc JsonRpc { get; }

/// <inheritdoc />
public virtual void IncomingRequestStarted(RequestId requestId, CancellationTokenSource cancellationTokenSource)
public void IncomingRequestStarted(RequestId requestId, CancellationTokenSource cancellationTokenSource)
{
lock (this.inboundCancellationSources)
{
Expand All @@ -65,7 +62,7 @@ public virtual void IncomingRequestStarted(RequestId requestId, CancellationToke
}

/// <inheritdoc />
public virtual void IncomingRequestEnded(RequestId requestId)
public void IncomingRequestEnded(RequestId requestId)
{
lock (this.inboundCancellationSources)
{
Expand All @@ -74,7 +71,7 @@ public virtual void IncomingRequestEnded(RequestId requestId)
}

/// <inheritdoc />
public virtual void CancelOutboundRequest(RequestId requestId)
public void CancelOutboundRequest(RequestId requestId)
{
Task.Run(async delegate
{
Expand All @@ -95,6 +92,11 @@ public virtual void CancelOutboundRequest(RequestId requestId)
}).Forget();
}

/// <inheritdoc />
public void OutboundRequestEnded(RequestId requestId)
{
}

/// <summary>
/// Cancels an inbound request that was previously received by <see cref="IncomingRequestStarted(RequestId, CancellationTokenSource)"/>.
/// </summary>
Expand Down
1 change: 1 addition & 0 deletions src/StreamJsonRpc/netcoreapp2.1/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ StreamJsonRpc.ICancellationStrategy
StreamJsonRpc.ICancellationStrategy.CancelOutboundRequest(StreamJsonRpc.RequestId requestId) -> void
StreamJsonRpc.ICancellationStrategy.IncomingRequestEnded(StreamJsonRpc.RequestId requestId) -> void
StreamJsonRpc.ICancellationStrategy.IncomingRequestStarted(StreamJsonRpc.RequestId requestId, System.Threading.CancellationTokenSource! cancellationTokenSource) -> void
StreamJsonRpc.ICancellationStrategy.OutboundRequestEnded(StreamJsonRpc.RequestId requestId) -> void
StreamJsonRpc.JsonRpc.AddLocalRpcTarget(System.Type! exposingMembersOn, object! target, StreamJsonRpc.JsonRpcTargetOptions? options) -> void
StreamJsonRpc.JsonRpc.AddLocalRpcTarget<T>(T target, StreamJsonRpc.JsonRpcTargetOptions? options) -> void
StreamJsonRpc.JsonRpc.CancellationStrategy.get -> StreamJsonRpc.ICancellationStrategy?
Expand Down
1 change: 1 addition & 0 deletions src/StreamJsonRpc/netstandard2.0/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ StreamJsonRpc.ICancellationStrategy
StreamJsonRpc.ICancellationStrategy.CancelOutboundRequest(StreamJsonRpc.RequestId requestId) -> void
StreamJsonRpc.ICancellationStrategy.IncomingRequestEnded(StreamJsonRpc.RequestId requestId) -> void
StreamJsonRpc.ICancellationStrategy.IncomingRequestStarted(StreamJsonRpc.RequestId requestId, System.Threading.CancellationTokenSource! cancellationTokenSource) -> void
StreamJsonRpc.ICancellationStrategy.OutboundRequestEnded(StreamJsonRpc.RequestId requestId) -> void
StreamJsonRpc.JsonRpc.AddLocalRpcTarget(System.Type! exposingMembersOn, object! target, StreamJsonRpc.JsonRpcTargetOptions? options) -> void
StreamJsonRpc.JsonRpc.AddLocalRpcTarget<T>(T target, StreamJsonRpc.JsonRpcTargetOptions? options) -> void
StreamJsonRpc.JsonRpc.CancellationStrategy.get -> StreamJsonRpc.ICancellationStrategy?
Expand Down

0 comments on commit d535fbf

Please sign in to comment.