Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for DI in Hub methods #34047

Merged
merged 11 commits into from
Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
CancellationTokenSource? cts = null;
if (descriptor.HasSyntheticArguments)
{
ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, ref arguments, out cts);
ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, scope, ref arguments, out cts);
}

if (isStreamResponse)
Expand Down Expand Up @@ -601,7 +601,7 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa
}

private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamCall,
HubConnectionContext connection, ref object?[] arguments, out CancellationTokenSource? cts)
HubConnectionContext connection, AsyncServiceScope scope, ref object?[] arguments, out CancellationTokenSource? cts)
{
cts = null;
// In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments)
Expand All @@ -626,6 +626,10 @@ private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocatio
cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
arguments[parameterPointer] = cts.Token;
}
else if (descriptor.IsServiceArgument(parameterPointer))
{
arguments[parameterPointer] = scope.ServiceProvider.GetRequiredService(descriptor.OriginalParameterTypes[parameterPointer]);
}
else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true))
{
Log.StartingParameterStream(_logger, hubMethodInvocationMessage.StreamIds![streamPointer]);
Expand Down
21 changes: 20 additions & 1 deletion src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Reflection;
using System.Threading.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http.Metadata;
using Microsoft.Extensions.Internal;

namespace Microsoft.AspNetCore.SignalR.Internal;
Expand All @@ -22,6 +23,8 @@ internal class HubMethodDescriptor

private readonly MethodInfo? _makeCancelableEnumeratorMethodInfo;
private Func<object, CancellationToken, IAsyncEnumerator<object>>? _makeCancelableEnumerator;
// bitset to store which parameters come from DI
private int _isServiceArgument;

public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
{
Expand Down Expand Up @@ -56,7 +59,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut
}

// Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers
ParameterTypes = methodExecutor.MethodParameters.Where(p =>
ParameterTypes = methodExecutor.MethodParameters.Where((p, index) =>
{
// Only streams can take CancellationTokens currently
if (IsStreamResponse && p.ParameterType == typeof(CancellationToken))
Expand All @@ -75,6 +78,17 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut
HasSyntheticArguments = true;
return false;
}
else if (p.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType)))
BrennanConroy marked this conversation as resolved.
Show resolved Hide resolved
{
if (index >= 32)
{
throw new InvalidOperationException(
"Hub methods can't use services from DI in the parameters after the 32nd parameter.");
}
_isServiceArgument |= (1 << index);
HasSyntheticArguments = true;
return false;
}
return true;
}).Select(p => p.ParameterType).ToArray();

Expand Down Expand Up @@ -104,6 +118,11 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut

public bool HasSyntheticArguments { get; private set; }

public bool IsServiceArgument(int argumentIndex)
{
return (_isServiceArgument & (1 << argumentIndex)) != 0;
}

public IAsyncEnumerator<object> FromReturnedStream(object stream, CancellationToken cancellationToken)
{
// there is the potential for compile to be called times but this has no harmful effect other than perf
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http.Metadata;
using Newtonsoft.Json.Serialization;

namespace Microsoft.AspNetCore.SignalR.Tests;
Expand Down Expand Up @@ -1247,3 +1243,57 @@ public void SetCaller(IClientProxy caller)
Caller = caller;
}
}

[AttributeUsage(AttributeTargets.Parameter, AllowMultiple = false, Inherited = true)]
public class FromService : Attribute, IFromServiceMetadata
{ }
public class Service1
{ }
public class Service2
{ }
public class Service3
{ }

public class ServicesHub : TestHub
{
public bool SingleService([FromService] Service1 service)
{
return true;
}

public bool MultipleServices([FromService] Service1 service, [FromService] Service2 service2, [FromService] Service3 service3)
{
return true;
}

public async Task<int> ServicesAndParams(int value, [FromService] Service1 service, ChannelReader<int> channelReader, [FromService] Service2 service2, bool value2)
{
int total = 0;
while (await channelReader.WaitToReadAsync())
{
total += await channelReader.ReadAsync();
}
return total + value;
}

public int ServiceWithoutAttribute(Service1 service)
{
return 1;
}

public async Task Stream(ChannelReader<int> channelReader)
{
while (await channelReader.WaitToReadAsync())
{
await channelReader.ReadAsync();
}
}
}

public class TooManyParamsHub : Hub
{
public void ManyParams(int a, string b, bool c, float d, string e, int f, int g, int h, int i, int j, int k,
int l, int m, int n, int o, int p, int q, int r, int s, int t, int u, int v, int w, int x, int y, int z,
int aa, int ab, int ac, int ad, int ae, int af, [FromService] Service1 service)
{ }
}
137 changes: 130 additions & 7 deletions src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Security.Claims;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Channels;
using MessagePack;
using MessagePack.Formatters;
using MessagePack.Resolvers;
Expand All @@ -32,7 +27,6 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Newtonsoft.Json.Serialization;
using Xunit;

namespace Microsoft.AspNetCore.SignalR.Tests;

Expand Down Expand Up @@ -4597,6 +4591,135 @@ public async Task CanSendThroughIHubContextBaseHub()
}
}

[Fact]
public async Task HubMethodFailsIfServiceNotFound()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSignalR(o => o.EnableDetailedErrors = true);
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.SingleService)).DefaultTimeout();
Assert.Equal("An unexpected error occurred invoking 'SingleService' on the server. InvalidOperationException: No service for type 'Microsoft.AspNetCore.SignalR.Tests.Service1' has been registered.", res.Error);
}
}

[Fact]
public async Task HubMethodCanInjectService()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSingleton<Service1>();
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.SingleService)).DefaultTimeout();
Assert.True(Assert.IsType<bool>(res.Result));
}
}

[Fact]
public async Task HubMethodCanInjectMultipleServices()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSingleton<Service1>();
provider.AddSingleton<Service2>();
provider.AddSingleton<Service3>();
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.MultipleServices)).DefaultTimeout();
Assert.True(Assert.IsType<bool>(res.Result));
}
}

[Fact]
public async Task HubMethodCanInjectServicesWithOtherParameters()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSingleton<Service1>();
provider.AddSingleton<Service2>();
provider.AddSingleton<Service3>();
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
await client.BeginUploadStreamAsync("0", nameof(ServicesHub.ServicesAndParams), new string[] { "1" }, 10, true).DefaultTimeout();

await client.SendHubMessageAsync(new StreamItemMessage("1", 1)).DefaultTimeout();
await client.SendHubMessageAsync(new StreamItemMessage("1", 14)).DefaultTimeout();

await client.SendHubMessageAsync(CompletionMessage.Empty("1")).DefaultTimeout();

var response = Assert.IsType<CompletionMessage>(await client.ReadAsync().DefaultTimeout());
Assert.Equal(25L, response.Result);
}
}

[Fact]
public async Task StreamFromServiceDoesNotWork()
{
var channel = Channel.CreateBounded<int>(10);
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSingleton(channel.Reader);
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.Stream)).DefaultTimeout();
Assert.Equal("An unexpected error occurred invoking 'Stream' on the server. HubException: Client sent 0 stream(s), Hub method expects 1.", res.Error);
}
}

[Fact]
public async Task ServiceNotResolvedWithoutAttribute()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
});
provider.AddSingleton<Service1>();
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.ServiceWithoutAttribute)).DefaultTimeout();
Assert.Equal("Failed to invoke 'ServiceWithoutAttribute' due to an error on the server. InvalidDataException: Invocation provides 0 argument(s) but target expects 1.", res.Error);
}
}

[Fact]
public void TooManyParametersWithServiceThrows()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSingleton<Service1>();
});
Assert.Throws<InvalidOperationException>(
() => serviceProvider.GetService<HubConnectionHandler<TooManyParamsHub>>());
}

private class CustomHubActivator<THub> : IHubActivator<THub> where THub : Hub
{
public int ReleaseCount;
Expand Down