Skip to content
This repository has been archived by the owner on Dec 18, 2018. It is now read-only.

Allow CancellationToken in streaming hub methods #2818

Merged
merged 6 commits into from
Sep 19, 2018
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,45 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
{
InitializeHub(hub, connection);

var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);
CancellationTokenSource cts = null;
var arguments = hubMethodInvocationMessage.Arguments;
if (descriptor.HasSyntheticArguments)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep this concept. It's cleaner than having the property called "HasCancellationToken"

{
// In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments)
arguments = new object[descriptor.OriginalParameterTypes.Count];

var hubInvocationArgumentPointer = 0;
for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++)
{
if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer &&
hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType() == descriptor.OriginalParameterTypes[parameterPointer])
{
// The types match so it isn't a synthetic argument, just copy it into the arguments array
arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer];
hubInvocationArgumentPointer++;
}
else
{
// This is the only synthetic argument type we currently support
if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken))
{
cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
arguments[parameterPointer] = cts.Token;
}
else
{
// This should never happen
Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{methodExecutor.MethodInfo.Name}'.");
}
}
}
}

var result = await ExecuteHubMethod(methodExecutor, hub, arguments);

if (isStreamedInvocation)
{
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, out var streamCts))
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts))
{
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);

Expand All @@ -204,7 +238,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
disposeScope = false;
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
// Fire-and-forget stream invocations, otherwise they would block other hub invocations from being able to run
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, streamCts);
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts);
}
// Non-empty/null InvocationId ==> Blocking invocation that needs a response
else if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
Expand Down Expand Up @@ -375,29 +409,24 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa
return true;
}

private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, out CancellationTokenSource streamCts)
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, ref CancellationTokenSource streamCts)
{
if (result != null)
{
if (hubMethodDescriptor.IsChannel)
{
streamCts = CreateCancellation();
if (streamCts == null)
{
streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
}
connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts);
enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
return true;
}
}

streamCts = null;
enumerator = null;
return false;

CancellationTokenSource CreateCancellation()
{
var userCts = new CancellationTokenSource();
connection.ActiveRequestCancellationSources.TryAdd(invocationId, userCts);

return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, userCts.Token);
}
}

private void DiscoverHubMethods()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ internal class HubMethodDescriptor
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
{
MethodExecutor = methodExecutor;
ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
Policies = policies.ToArray();

NonAsyncReturnType = (MethodExecutor.IsMethodAsync)
? MethodExecutor.AsyncResultType
Expand All @@ -34,6 +32,25 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut
IsChannel = true;
StreamReturnType = channelItemType;
}

// Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers
ParameterTypes = methodExecutor.MethodParameters.Where(p =>
{
// Only streams can take CancellationTokens currently
if (IsStreamable && p.ParameterType == typeof(CancellationToken))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we put this check into the SyntheticArgumentManager? Something like SyntheticArgumentManager.IsSyntheticArgumentType(...)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Disregard this comment now.

{
HasSyntheticArguments = true;
return false;
}
return true;
}).Select(p => p.ParameterType).ToArray();

if (HasSyntheticArguments)
{
OriginalParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
}

Policies = policies.ToArray();
}

private Func<object, CancellationToken, IAsyncEnumerator<object>> _convertToEnumerator;
Expand All @@ -42,6 +59,8 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut

public IReadOnlyList<Type> ParameterTypes { get; }

public IReadOnlyList<Type> OriginalParameterTypes { get; }

public Type NonAsyncReturnType { get; }

public bool IsChannel { get; }
Expand All @@ -52,6 +71,8 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut

public IList<IAuthorizeData> Policies { get; }

public bool HasSyntheticArguments { get; private set; }

private static bool IsChannelType(Type type, out Type payloadType)
{
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
Expand Down Expand Up @@ -165,6 +166,10 @@ public Task ProtocolError()
return Clients.Caller.SendAsync("Send", new string('x', 3000), new SelfRef());
}

public void InvalidArgument(CancellationToken token)
{
}

private class SelfRef
{
public SelfRef()
Expand Down Expand Up @@ -547,6 +552,51 @@ public async Task<ChannelReader<string>> LongRunningStream()
return Channel.CreateUnbounded<string>().Reader;
}

public ChannelReader<int> CancelableStream(CancellationToken token)
{
var channel = Channel.CreateBounded<int>(10);

Task.Run(() =>
{
_tcsService.StartedMethod.SetResult(null);
token.WaitHandle.WaitOne();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turn this into a TCS and await it instead. Having spent a large part of the past few weeks fighting legacy SignalR tests that use WaitHandles, I don't want to start a precedent :).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest an making an extension method that does the conversion: WaitForCancellationAsync

channel.Writer.TryComplete();
_tcsService.EndMethod.SetResult(null);
});

return channel.Reader;
}

public ChannelReader<int> CancelableStream2(int ignore, int ignore2, CancellationToken token)
{
var channel = Channel.CreateBounded<int>(10);

Task.Run(() =>
{
_tcsService.StartedMethod.SetResult(null);
token.WaitHandle.WaitOne();
channel.Writer.TryComplete();
_tcsService.EndMethod.SetResult(null);
});

return channel.Reader;
}

public ChannelReader<int> CancelableStreamMiddle(int ignore, CancellationToken token, int ignore2)
{
var channel = Channel.CreateBounded<int>(10);

Task.Run(() =>
{
_tcsService.StartedMethod.SetResult(null);
token.WaitHandle.WaitOne();
channel.Writer.TryComplete();
_tcsService.EndMethod.SetResult(null);
});

return channel.Reader;
}

public int SimpleMethod()
{
return 21;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2381,6 +2381,95 @@ public async Task ConnectionAbortedIfSendFailsWithProtocolError()
}
}

[Theory]
[InlineData(nameof(LongRunningHub.CancelableStream))]
[InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)]
[InlineData(nameof(LongRunningHub.CancelableStreamMiddle), 1, 2)]
public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnCancellation(string methodName, params object[] args)
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();

var streamInvocationId = await client.SendStreamInvocationAsync(methodName, args).OrTimeout();
// Wait for the stream method to start
await tcsService.StartedMethod.Task.OrTimeout();

// Cancel the stream which should trigger the CancellationToken in the hub method
await client.SendHubMessageAsync(new CancelInvocationMessage(streamInvocationId)).OrTimeout();

var result = await client.ReadAsync().OrTimeout();

var simpleCompletion = Assert.IsType<CompletionMessage>(result);
Assert.Null(simpleCompletion.Result);

// CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled.
await tcsService.EndMethod.Task.OrTimeout();

// Shut down
client.Dispose();

await connectionHandlerTask.OrTimeout();
}
}

[Fact]
public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnConnectionAborted()
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();

var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStream)).OrTimeout();
// Wait for the stream method to start
await tcsService.StartedMethod.Task.OrTimeout();

// Shut down the client which should trigger the CancellationToken in the hub method
client.Dispose();

// CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled.
await tcsService.EndMethod.Task.OrTimeout();

await connectionHandlerTask.OrTimeout();
}
}

[Fact]
public async Task InvokeHubMethodCannotAcceptCancellationTokenAsArgument()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();

var invocationId = await client.SendInvocationAsync(nameof(MethodHub.InvalidArgument)).OrTimeout();

var completion = Assert.IsType<CompletionMessage>(await client.ReadAsync().OrTimeout());

Assert.Equal("Failed to invoke 'InvalidArgument' due to an error on the server.", completion.Error);

client.Dispose();

await connectionHandlerTask.OrTimeout();
}
}

private class CustomHubActivator<THub> : IHubActivator<THub> where THub : Hub
{
public int ReleaseCount;
Expand Down