Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for DI in Hub methods #34047

Merged
merged 11 commits into from
Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public void GlobalSetup()
serviceScopeFactory,
new HubContext<TestHub>(new DefaultHubLifetimeManager<TestHub>(NullLogger<DefaultHubLifetimeManager<TestHub>>.Instance)),
enableDetailedErrors: false,
disableImplicitFromServiceParameters: true,
new Logger<DefaultHubDispatcher<TestHub>>(NullLoggerFactory.Instance),
hubFilters: null);

Expand Down
4 changes: 4 additions & 0 deletions src/SignalR/server/Core/src/HubConnectionHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@ IServiceScopeFactory serviceScopeFactory
_userIdProvider = userIdProvider;

_enableDetailedErrors = false;
bool disableImplicitFromServiceParameters;

List<IHubFilter>? hubFilters = null;
if (_hubOptions.UserHasSetValues)
{
_maximumMessageSize = _hubOptions.MaximumReceiveMessageSize;
_enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _enableDetailedErrors;
_maxParallelInvokes = _hubOptions.MaximumParallelInvocationsPerClient;
disableImplicitFromServiceParameters = _hubOptions.DisableImplicitFromServiceParameters;

if (_hubOptions.HubFilters != null)
{
Expand All @@ -80,6 +82,7 @@ IServiceScopeFactory serviceScopeFactory
_maximumMessageSize = _globalHubOptions.MaximumReceiveMessageSize;
_enableDetailedErrors = _globalHubOptions.EnableDetailedErrors ?? _enableDetailedErrors;
_maxParallelInvokes = _globalHubOptions.MaximumParallelInvocationsPerClient;
disableImplicitFromServiceParameters = _globalHubOptions.DisableImplicitFromServiceParameters;

if (_globalHubOptions.HubFilters != null)
{
Expand All @@ -91,6 +94,7 @@ IServiceScopeFactory serviceScopeFactory
serviceScopeFactory,
new HubContext<THub>(lifetimeManager),
_enableDetailedErrors,
disableImplicitFromServiceParameters,
new Logger<DefaultHubDispatcher<THub>>(loggerFactory),
hubFilters);
}
Expand Down
12 changes: 12 additions & 0 deletions src/SignalR/server/Core/src/HubOptions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.AspNetCore.Http.Metadata;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.AspNetCore.SignalR;

/// <summary>
Expand Down Expand Up @@ -70,4 +73,13 @@ public int MaximumParallelInvocationsPerClient
_maximumParallelInvocationsPerClient = value;
}
}

/// <summary>
/// When <see langword="false"/>, <see cref="IServiceProviderIsService"/> determines if a Hub method parameter will be injected from the DI container.
/// Parameters can be explicitly marked with an attribute that implements <see cref="IFromServiceMetadata"/> with or without this option set.
/// </summary>
/// <remarks>
/// False by default. Hub method arguments will be resolved from a DI container if possible.
/// </remarks>
public bool DisableImplicitFromServiceParameters { get; set; }
}
1 change: 1 addition & 0 deletions src/SignalR/server/Core/src/HubOptionsSetup`T.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public void Configure(HubOptions<THub> options)
options.MaximumReceiveMessageSize = _hubOptions.MaximumReceiveMessageSize;
options.StreamBufferCapacity = _hubOptions.StreamBufferCapacity;
options.MaximumParallelInvocationsPerClient = _hubOptions.MaximumParallelInvocationsPerClient;
options.DisableImplicitFromServiceParameters = _hubOptions.DisableImplicitFromServiceParameters;

options.UserHasSetValues = true;

Expand Down
24 changes: 18 additions & 6 deletions src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ internal partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> where TH
private readonly Func<HubLifetimeContext, Exception?, Task>? _onDisconnectedMiddleware;

public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext<THub> hubContext, bool enableDetailedErrors,
ILogger<DefaultHubDispatcher<THub>> logger, List<IHubFilter>? hubFilters)
bool disableImplicitFromServiceParameters, ILogger<DefaultHubDispatcher<THub>> logger, List<IHubFilter>? hubFilters)
{
_serviceScopeFactory = serviceScopeFactory;
_hubContext = hubContext;
_enableDetailedErrors = enableDetailedErrors;
_logger = logger;
DiscoverHubMethods();
DiscoverHubMethods(disableImplicitFromServiceParameters);

var count = hubFilters?.Count ?? 0;
if (count != 0)
Expand Down Expand Up @@ -307,7 +307,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
CancellationTokenSource? cts = null;
if (descriptor.HasSyntheticArguments)
{
ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, ref arguments, out cts);
ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, scope, ref arguments, out cts);
}

if (isStreamResponse)
Expand Down Expand Up @@ -601,7 +601,7 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa
}

private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamCall,
HubConnectionContext connection, ref object?[] arguments, out CancellationTokenSource? cts)
HubConnectionContext connection, AsyncServiceScope scope, ref object?[] arguments, out CancellationTokenSource? cts)
{
cts = null;
// In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments)
Expand All @@ -626,6 +626,10 @@ private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocatio
cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
arguments[parameterPointer] = cts.Token;
}
else if (descriptor.IsServiceArgument(parameterPointer))
{
arguments[parameterPointer] = scope.ServiceProvider.GetRequiredService(descriptor.OriginalParameterTypes[parameterPointer]);
}
else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true))
{
Log.StartingParameterStream(_logger, hubMethodInvocationMessage.StreamIds![streamPointer]);
Expand All @@ -644,12 +648,20 @@ private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocatio
}
}

private void DiscoverHubMethods()
private void DiscoverHubMethods(bool disableImplicitFromServiceParameters)
{
var hubType = typeof(THub);
var hubTypeInfo = hubType.GetTypeInfo();
var hubName = hubType.Name;

using var scope = _serviceScopeFactory.CreateScope();

IServiceProviderIsService? serviceProviderIsService = null;
if (!disableImplicitFromServiceParameters)
{
serviceProviderIsService = scope.ServiceProvider.GetService<IServiceProviderIsService>();
}

foreach (var methodInfo in HubReflectionHelper.GetHubMethods(hubType))
{
if (methodInfo.IsGenericMethod)
Expand All @@ -668,7 +680,7 @@ private void DiscoverHubMethods()

var executor = ObjectMethodExecutor.Create(methodInfo, hubTypeInfo);
var authorizeAttributes = methodInfo.GetCustomAttributes<AuthorizeAttribute>(inherit: true);
_methods[methodName] = new HubMethodDescriptor(executor, authorizeAttributes);
_methods[methodName] = new HubMethodDescriptor(executor, serviceProviderIsService, authorizeAttributes);

Log.HubMethodBound(_logger, hubName, methodName);
}
Expand Down
25 changes: 23 additions & 2 deletions src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using System.Reflection;
using System.Threading.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http.Metadata;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;

namespace Microsoft.AspNetCore.SignalR.Internal;
Expand All @@ -22,8 +24,10 @@ internal class HubMethodDescriptor

private readonly MethodInfo? _makeCancelableEnumeratorMethodInfo;
private Func<object, CancellationToken, IAsyncEnumerator<object>>? _makeCancelableEnumerator;
// bitset to store which parameters come from DI up to 64 arguments
private ulong _isServiceArgument;

public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProviderIsService? serviceProviderIsService, IEnumerable<IAuthorizeData> policies)
{
MethodExecutor = methodExecutor;

Expand Down Expand Up @@ -56,7 +60,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut
}

// Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers
ParameterTypes = methodExecutor.MethodParameters.Where(p =>
ParameterTypes = methodExecutor.MethodParameters.Where((p, index) =>
{
// Only streams can take CancellationTokens currently
if (IsStreamResponse && p.ParameterType == typeof(CancellationToken))
Expand All @@ -75,6 +79,18 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut
HasSyntheticArguments = true;
return false;
}
else if (p.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType)) ||
serviceProviderIsService?.IsService(p.ParameterType) == true)
{
if (index >= 64)
{
throw new InvalidOperationException(
"Hub methods can't use services from DI in the parameters after the 64th parameter.");
}
_isServiceArgument |= (1UL << index);
HasSyntheticArguments = true;
return false;
}
return true;
}).Select(p => p.ParameterType).ToArray();

Expand Down Expand Up @@ -104,6 +120,11 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut

public bool HasSyntheticArguments { get; private set; }

public bool IsServiceArgument(int argumentIndex)
{
return (_isServiceArgument & (1UL << argumentIndex)) != 0;
}

public IAsyncEnumerator<object> FromReturnedStream(object stream, CancellationToken cancellationToken)
{
// there is the potential for compile to be called times but this has no harmful effect other than perf
Expand Down
2 changes: 2 additions & 0 deletions src/SignalR/server/Core/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
#nullable enable
Microsoft.AspNetCore.SignalR.HubOptions.DisableImplicitFromServiceParameters.get -> bool
Microsoft.AspNetCore.SignalR.HubOptions.DisableImplicitFromServiceParameters.set -> void
3 changes: 3 additions & 0 deletions src/SignalR/server/SignalR/test/AddSignalRTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ public void HubSpecificOptionsHaveSameValuesAsGlobalHubOptions()
Assert.Equal(globalHubOptions.SupportedProtocols, hubOptions.SupportedProtocols);
Assert.Equal(globalHubOptions.ClientTimeoutInterval, hubOptions.ClientTimeoutInterval);
Assert.Equal(globalHubOptions.MaximumParallelInvocationsPerClient, hubOptions.MaximumParallelInvocationsPerClient);
Assert.Equal(globalHubOptions.DisableImplicitFromServiceParameters, hubOptions.DisableImplicitFromServiceParameters);
Assert.True(hubOptions.UserHasSetValues);
}

Expand Down Expand Up @@ -145,6 +146,7 @@ public void UserSpecifiedOptionsRunAfterDefaultOptions()
options.SupportedProtocols = null;
options.ClientTimeoutInterval = TimeSpan.FromSeconds(1);
options.MaximumParallelInvocationsPerClient = 3;
options.DisableImplicitFromServiceParameters = true;
});

var serviceProvider = serviceCollection.BuildServiceProvider();
Expand All @@ -158,6 +160,7 @@ public void UserSpecifiedOptionsRunAfterDefaultOptions()
Assert.Null(globalOptions.SupportedProtocols);
Assert.Equal(3, globalOptions.MaximumParallelInvocationsPerClient);
Assert.Equal(TimeSpan.FromSeconds(1), globalOptions.ClientTimeoutInterval);
Assert.True(globalOptions.DisableImplicitFromServiceParameters);
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http.Metadata;
using Newtonsoft.Json.Serialization;

namespace Microsoft.AspNetCore.SignalR.Tests;
Expand Down Expand Up @@ -1247,3 +1243,65 @@ public void SetCaller(IClientProxy caller)
Caller = caller;
}
}

[AttributeUsage(AttributeTargets.Parameter, AllowMultiple = false, Inherited = true)]
public class FromService : Attribute, IFromServiceMetadata
{ }
public class Service1
{ }
public class Service2
{ }
public class Service3
{ }

public class ServicesHub : TestHub
{
public bool SingleService([FromService] Service1 service)
{
return true;
}

public bool MultipleServices([FromService] Service1 service, [FromService] Service2 service2, [FromService] Service3 service3)
{
return true;
}

public async Task<int> ServicesAndParams(int value, [FromService] Service1 service, ChannelReader<int> channelReader, [FromService] Service2 service2, bool value2)
{
int total = 0;
while (await channelReader.WaitToReadAsync())
{
total += await channelReader.ReadAsync();
}
return total + value;
}

public int ServiceWithoutAttribute(Service1 service)
{
return 1;
}

public int ServiceWithAndWithoutAttribute(Service1 service, [FromService] Service2 service2)
{
return 1;
}

public async Task Stream(ChannelReader<int> channelReader)
{
while (await channelReader.WaitToReadAsync())
{
await channelReader.ReadAsync();
}
}
}

public class TooManyParamsHub : Hub
{
public void ManyParams(int a1, string a2, bool a3, float a4, string a5, int a6, int a7, int a8, int a9, int a10, int a11,
int a12, int a13, int a14, int a15, int a16, int a17, int a18, int a19, int a20, int a21, int a22, int a23, int a24,
int a25, int a26, int a27, int a28, int a29, int a30, int a31, int a32, int a33, int a34, int a35, int a36, int a37,
int a38, int a39, int a40, int a41, int a42, int a43, int a44, int a45, int a46, int a47, int a48, int a49, int a50,
int a51, int a52, int a53, int a54, int a55, int a56, int a57, int a58, int a59, int a60, int a61, int a62, int a63,
int a64, [FromService] Service1 service)
{ }
}
Loading