Skip to content

Commit d803d95

Browse files
committed
Playing around with validation in SignalR
1 parent f355564 commit d803d95

File tree

5 files changed

+63
-6
lines changed

5 files changed

+63
-6
lines changed

src/SignalR/samples/SignalRSamples/Hubs/Chat.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.ComponentModel.DataAnnotations;
45
using Microsoft.AspNetCore.SignalR;
56

67
namespace SignalRSamples.Hubs;
@@ -19,7 +20,7 @@ public override Task OnDisconnectedAsync(Exception exception)
1920
return Clients.All.SendAsync("Send", $"{name} left the chat");
2021
}
2122

22-
public Task Send(string name, string message)
23+
public Task Send([StringLength(10)] string name, [Required] string message)
2324
{
2425
return Clients.All.SendAsync("Send", $"{name}: {message}");
2526
}

src/SignalR/samples/SignalRSamples/Startup.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ public void ConfigureServices(IServiceCollection services)
2121
services.AddSignalR()
2222
.AddMessagePackProtocol();
2323
//.AddStackExchangeRedis();
24+
25+
services.AddValidation();
2426
}
2527

2628
// This method gets called by the runtime. Use this method to configure the HTTP request pipeline.

src/SignalR/server/Core/src/HubConnectionHandler.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Diagnostics.CodeAnalysis;
66
using System.Linq;
77
using Microsoft.AspNetCore.Connections;
8+
using Microsoft.AspNetCore.Http.Validation;
89
using Microsoft.AspNetCore.SignalR.Internal;
910
using Microsoft.AspNetCore.SignalR.Protocol;
1011
using Microsoft.Extensions.DependencyInjection;
@@ -52,7 +53,8 @@ public HubConnectionHandler(HubLifetimeManager<THub> lifetimeManager,
5253
IOptions<HubOptions<THub>> hubOptions,
5354
ILoggerFactory loggerFactory,
5455
IUserIdProvider userIdProvider,
55-
IServiceScopeFactory serviceScopeFactory
56+
IServiceScopeFactory serviceScopeFactory,
57+
IOptions<ValidationOptions>? validationOptions
5658
)
5759
{
5860
_protocolResolver = protocolResolver;
@@ -101,7 +103,8 @@ IServiceScopeFactory serviceScopeFactory
101103
disableImplicitFromServiceParameters,
102104
new Logger<DefaultHubDispatcher<THub>>(loggerFactory),
103105
hubFilters,
104-
lifetimeManager);
106+
lifetimeManager,
107+
validationOptions?.Value);
105108
}
106109

107110
/// <inheritdoc />

src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.ComponentModel.DataAnnotations;
45
using System.Diagnostics;
56
using System.Diagnostics.CodeAnalysis;
67
using System.Linq;
78
using System.Reflection;
89
using System.Security.Claims;
10+
using System.Text.Json;
911
using System.Threading.Channels;
1012
using Microsoft.AspNetCore.Authorization;
13+
using Microsoft.AspNetCore.Http.Validation;
1114
using Microsoft.AspNetCore.Internal;
1215
using Microsoft.AspNetCore.Shared;
1316
using Microsoft.AspNetCore.SignalR.Protocol;
1417
using Microsoft.Extensions.DependencyInjection;
1518
using Microsoft.Extensions.Internal;
1619
using Microsoft.Extensions.Logging;
20+
using Microsoft.Extensions.Options;
1721
using Log = Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcherLog;
1822

1923
namespace Microsoft.AspNetCore.SignalR.Internal;
@@ -32,6 +36,7 @@ internal sealed partial class DefaultHubDispatcher<[DynamicallyAccessedMembers(H
3236
private readonly Func<HubLifetimeContext, Task>? _onConnectedMiddleware;
3337
private readonly Func<HubLifetimeContext, Exception?, Task>? _onDisconnectedMiddleware;
3438
private readonly HubLifetimeManager<THub> _hubLifetimeManager;
39+
private readonly ValidationOptions? _validationOptions;
3540

3641
[FeatureSwitchDefinition("Microsoft.AspNetCore.SignalR.Hub.IsCustomAwaitableSupported")]
3742
[FeatureGuard(typeof(RequiresDynamicCodeAttribute))]
@@ -40,13 +45,15 @@ internal sealed partial class DefaultHubDispatcher<[DynamicallyAccessedMembers(H
4045
AppContext.TryGetSwitch("Microsoft.AspNetCore.SignalR.Hub.IsCustomAwaitableSupported", out bool customAwaitableSupport) ? customAwaitableSupport : true;
4146

4247
public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext<THub> hubContext, bool enableDetailedErrors,
43-
bool disableImplicitFromServiceParameters, ILogger<DefaultHubDispatcher<THub>> logger, List<IHubFilter>? hubFilters, HubLifetimeManager<THub> lifetimeManager)
48+
bool disableImplicitFromServiceParameters, ILogger<DefaultHubDispatcher<THub>> logger, List<IHubFilter>? hubFilters, HubLifetimeManager<THub> lifetimeManager,
49+
ValidationOptions? validationOptions)
4450
{
4551
_serviceScopeFactory = serviceScopeFactory;
4652
_hubContext = hubContext;
4753
_enableDetailedErrors = enableDetailedErrors;
4854
_logger = logger;
4955
_hubLifetimeManager = lifetimeManager;
56+
_validationOptions = validationOptions;
5057
DiscoverHubMethods(disableImplicitFromServiceParameters);
5158

5259
var count = hubFilters?.Count ?? 0;
@@ -343,6 +350,11 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
343350
return true;
344351
}
345352

353+
if (!await ValidateArguments(descriptor, hubMethodInvocationMessage, connection))
354+
{
355+
return true;
356+
}
357+
346358
try
347359
{
348360
var clientStreamLength = hubMethodInvocationMessage.StreamIds?.Length ?? 0;
@@ -687,6 +699,41 @@ private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provi
687699
return authorizationResult.Succeeded;
688700
}
689701

702+
private async Task<bool> ValidateArguments(HubMethodDescriptor hubMethodDescriptor, HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
703+
{
704+
if (_validationOptions == null || _validationOptions.Resolvers.Count == 0)
705+
{
706+
return true;
707+
}
708+
709+
var validateContext = new ValidateContext()
710+
{
711+
ValidationOptions = _validationOptions,
712+
CurrentValidationPath = $"{_fullHubName}.{hubMethodInvocationMessage.Target}",
713+
};
714+
715+
for (var i = 0; i < hubMethodDescriptor.ParameterInfos.Count; i++)
716+
{
717+
validateContext.ValidationContext = new ValidationContext(hubMethodInvocationMessage.Arguments[i], serviceProvider: null, items: null)
718+
{
719+
DisplayName = hubMethodDescriptor.ParameterInfos[i].Name,
720+
};
721+
if (_validationOptions.TryGetValidatableParameterInfo(hubMethodDescriptor.ParameterInfos[i], out var parameterValidator))
722+
{
723+
await parameterValidator.ValidateAsync(hubMethodInvocationMessage.Arguments[i], validateContext, default);
724+
}
725+
}
726+
727+
if (validateContext.ValidationErrors is { Count: > 0 })
728+
{
729+
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
730+
$"Failed to invoke '{hubMethodInvocationMessage.Target}' because of validation errors: {JsonSerializer.Serialize(validateContext.ValidationErrors)}");
731+
return false;
732+
}
733+
734+
return true;
735+
}
736+
690737
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse,
691738
HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
692739
{

src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider
5858
}
5959

6060
// Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers
61-
ParameterTypes = methodExecutor.MethodParameters.Where((p, index) =>
61+
ParameterInfos = methodExecutor.MethodParameters.Where((p, index) =>
6262
{
6363
// Only streams can take CancellationTokens currently
6464
if (IsStreamResponse && p.ParameterType == typeof(CancellationToken))
@@ -134,7 +134,9 @@ void ThrowIfMarked(bool marked)
134134
}
135135

136136
return true;
137-
}).Select(p => p.ParameterType).ToArray();
137+
}).ToArray();
138+
139+
ParameterTypes = ParameterInfos.Select(p => p.ParameterType).ToArray();
138140

139141
if (HasSyntheticArguments)
140142
{
@@ -164,6 +166,8 @@ private bool MarkServiceParameter(int index)
164166

165167
public IReadOnlyList<Type> ParameterTypes { get; }
166168

169+
public IReadOnlyList<ParameterInfo> ParameterInfos { get; }
170+
167171
public IReadOnlyList<Type>? OriginalParameterTypes { get; }
168172

169173
public Type NonAsyncReturnType { get; }

0 commit comments

Comments
 (0)