From 37b84fdd8bb42288ba9fabb4d7f1c38e17d088ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Dach?= Date: Wed, 17 Jul 2024 14:12:04 +0200 Subject: [PATCH 1/5] Read & use client-provided session GUID for concurrency control --- .../ConcurrentConnectionLimiterTests.cs | 313 +++++++++++++++++- .../ConcurrentConnectionLimiter.cs | 22 +- .../Entities/ConnectionState.cs | 49 ++- .../Hubs/Metadata/MetadataHub.cs | 3 +- .../ServerShuttingDownException.cs | 3 +- 5 files changed, 364 insertions(+), 26 deletions(-) diff --git a/osu.Server.Spectator.Tests/ConcurrentConnectionLimiterTests.cs b/osu.Server.Spectator.Tests/ConcurrentConnectionLimiterTests.cs index c5c37e4d..7d5ce52a 100644 --- a/osu.Server.Spectator.Tests/ConcurrentConnectionLimiterTests.cs +++ b/osu.Server.Spectator.Tests/ConcurrentConnectionLimiterTests.cs @@ -5,9 +5,12 @@ using System.Linq; using System.Security.Claims; using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.AspNetCore.SignalR; using Microsoft.Extensions.Logging; using Moq; +using osu.Game.Online; using osu.Server.Spectator.Entities; using osu.Server.Spectator.Hubs.Spectator; using Xunit; @@ -37,8 +40,301 @@ public ConcurrentConnectionLimiterTests() hubMock = new Mock(); } + #region New path (uses client-side generated session GUID) + + [Fact] + public async Task TestNormalOperation_SessionIDPresent() + { + var hubCallerContextMock = new Mock(); + var httpContextMock = new Mock(); + hubCallerContextMock.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + hubCallerContextMock.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", Guid.NewGuid().ToString()) + }) + })); + hubCallerContextMock.Setup(ctx => ctx.Features.Get()).Returns(httpContextMock.Object); + httpContextMock.Setup(ctx => ctx.HttpContext).Returns(() => + { + var context = new DefaultHttpContext(); + context.Request.Headers[HubClientConnector.CLIENT_SESSION_ID_HEADER] = Guid.NewGuid().ToString(); + return context; + }); + + var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object, loggerFactoryMock.Object); + var lifetimeContext = new HubLifetimeContext(hubCallerContextMock.Object, serviceProviderMock.Object, hubMock.Object); + + bool connected = false; + await filter.OnConnectedAsync(lifetimeContext, _ => + { + connected = true; + return Task.CompletedTask; + }); + Assert.True(connected); + Assert.Single(connectionStates.GetEntityUnsafe(1234)!.ConnectionIds); + + bool methodInvoked = false; + var invocationContext = new HubInvocationContext(hubCallerContextMock.Object, serviceProviderMock.Object, hubMock.Object, + typeof(SpectatorHub).GetMethod(nameof(SpectatorHub.StartWatchingUser))!, new object[] { 1234 }); + await filter.InvokeMethodAsync(invocationContext, _ => + { + methodInvoked = true; + return new ValueTask(new object()); + }); + Assert.True(methodInvoked); + Assert.Single(connectionStates.GetEntityUnsafe(1234)!.ConnectionIds); + + bool disconnected = false; + await filter.OnDisconnectedAsync(lifetimeContext, null, (_, _) => + { + disconnected = true; + return Task.CompletedTask; + }); + Assert.True(disconnected); + Assert.Null(connectionStates.GetEntityUnsafe(1234)); + } + + [Fact] + public async Task TestConcurrencyBlocked_SessionIDPresent() + { + var firstHubCallerContext = new Mock(); + var firstHttpContextMock = new Mock(); + var secondHubCallerContext = new Mock(); + var secondHttpContextMock = new Mock(); + + firstHubCallerContext.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + firstHubCallerContext.Setup(ctx => ctx.ConnectionId).Returns("abcd"); + firstHubCallerContext.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", Guid.NewGuid().ToString()) + }) + })); + firstHubCallerContext.Setup(ctx => ctx.Features.Get()).Returns(firstHttpContextMock.Object); + firstHttpContextMock.Setup(ctx => ctx.HttpContext).Returns(() => + { + var context = new DefaultHttpContext(); + context.Request.Headers[HubClientConnector.CLIENT_SESSION_ID_HEADER] = Guid.NewGuid().ToString(); + return context; + }); + + secondHubCallerContext.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + secondHubCallerContext.Setup(ctx => ctx.ConnectionId).Returns("efgh"); + secondHubCallerContext.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", Guid.NewGuid().ToString()) + }) + })); + secondHubCallerContext.Setup(ctx => ctx.Features.Get()).Returns(secondHttpContextMock.Object); + secondHttpContextMock.Setup(ctx => ctx.HttpContext).Returns(() => + { + var context = new DefaultHttpContext(); + context.Request.Headers[HubClientConnector.CLIENT_SESSION_ID_HEADER] = Guid.NewGuid().ToString(); + return context; + }); + + var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object, loggerFactoryMock.Object); + + var firstLifetimeContext = new HubLifetimeContext(firstHubCallerContext.Object, serviceProviderMock.Object, hubMock.Object); + await filter.OnConnectedAsync(firstLifetimeContext, _ => Task.CompletedTask); + + var secondLifetimeContext = new HubLifetimeContext(secondHubCallerContext.Object, serviceProviderMock.Object, hubMock.Object); + await filter.OnConnectedAsync(secondLifetimeContext, _ => Task.CompletedTask); + + var secondInvocationContext = new HubInvocationContext(secondHubCallerContext.Object, serviceProviderMock.Object, hubMock.Object, + typeof(SpectatorHub).GetMethod(nameof(SpectatorHub.StartWatchingUser))!, new object[] { 1234 }); + // should succeed. + await filter.InvokeMethodAsync(secondInvocationContext, _ => new ValueTask(new object())); + + var firstInvocationContext = new HubInvocationContext(firstHubCallerContext.Object, serviceProviderMock.Object, hubMock.Object, + typeof(SpectatorHub).GetMethod(nameof(SpectatorHub.StartWatchingUser))!, new object[] { 1234 }); + // should throw. + await Assert.ThrowsAsync(() => filter.InvokeMethodAsync(firstInvocationContext, _ => new ValueTask(new object())).AsTask()); + } + + [Fact] + public async Task TestStaleDisconnectIsANoOp_SessionIDPresent() + { + var firstHubCallerContext = new Mock(); + var firstHttpContextMock = new Mock(); + var secondHubCallerContext = new Mock(); + var secondHttpContextMock = new Mock(); + string commonSessionId = Guid.NewGuid().ToString(); + + firstHubCallerContext.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + firstHubCallerContext.Setup(ctx => ctx.ConnectionId).Returns("abcd"); + firstHubCallerContext.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", commonSessionId) + }) + })); + firstHubCallerContext.Setup(ctx => ctx.Features.Get()).Returns(firstHttpContextMock.Object); + firstHttpContextMock.Setup(ctx => ctx.HttpContext).Returns(() => + { + var context = new DefaultHttpContext(); + context.Request.Headers[HubClientConnector.CLIENT_SESSION_ID_HEADER] = commonSessionId; + return context; + }); + + secondHubCallerContext.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + secondHubCallerContext.Setup(ctx => ctx.ConnectionId).Returns("efgh"); + secondHubCallerContext.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", commonSessionId) + }) + })); + secondHubCallerContext.Setup(ctx => ctx.Features.Get()).Returns(secondHttpContextMock.Object); + secondHttpContextMock.Setup(ctx => ctx.HttpContext).Returns(() => + { + var context = new DefaultHttpContext(); + context.Request.Headers[HubClientConnector.CLIENT_SESSION_ID_HEADER] = commonSessionId; + return context; + }); + + var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object, loggerFactoryMock.Object); + + var firstLifetimeContext = new HubLifetimeContext(firstHubCallerContext.Object, serviceProviderMock.Object, hubMock.Object); + await filter.OnConnectedAsync(firstLifetimeContext, _ => Task.CompletedTask); + + var secondLifetimeContext = new HubLifetimeContext(secondHubCallerContext.Object, serviceProviderMock.Object, hubMock.Object); + await filter.OnConnectedAsync(secondLifetimeContext, _ => Task.CompletedTask); + + await filter.OnDisconnectedAsync(firstLifetimeContext, null, (_, _) => Task.CompletedTask); + Assert.Single(connectionStates.GetEntityUnsafe(1234)!.ConnectionIds); + Assert.Equal("efgh", connectionStates.GetEntityUnsafe(1234)!.ConnectionIds.Single().Value); + } + [Fact] - public async Task TestNormalOperation() + public async Task TestHubDisconnectsTrackedSeparately_SessionIDPresent() + { + var firstHubCallerContext = new Mock(); + var firstHttpContextMock = new Mock(); + var secondHubCallerContext = new Mock(); + var secondHttpContextMock = new Mock(); + string commonSessionId = Guid.NewGuid().ToString(); + + firstHubCallerContext.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + firstHubCallerContext.Setup(ctx => ctx.ConnectionId).Returns("abcd"); + firstHubCallerContext.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", commonSessionId) + }) + })); + firstHubCallerContext.Setup(ctx => ctx.Features.Get()).Returns(firstHttpContextMock.Object); + firstHttpContextMock.Setup(ctx => ctx.HttpContext).Returns(() => + { + var context = new DefaultHttpContext(); + context.Request.Headers[HubClientConnector.CLIENT_SESSION_ID_HEADER] = commonSessionId; + return context; + }); + + secondHubCallerContext.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + secondHubCallerContext.Setup(ctx => ctx.ConnectionId).Returns("efgh"); + secondHubCallerContext.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", commonSessionId) + }) + })); + secondHubCallerContext.Setup(ctx => ctx.Features.Get()).Returns(secondHttpContextMock.Object); + secondHttpContextMock.Setup(ctx => ctx.HttpContext).Returns(() => + { + var context = new DefaultHttpContext(); + context.Request.Headers[HubClientConnector.CLIENT_SESSION_ID_HEADER] = commonSessionId; + return context; + }); + + var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object, loggerFactoryMock.Object); + + var firstLifetimeContext = new HubLifetimeContext(firstHubCallerContext.Object, serviceProviderMock.Object, new FirstHub()); + await filter.OnConnectedAsync(firstLifetimeContext, _ => Task.CompletedTask); + + var secondLifetimeContext = new HubLifetimeContext(secondHubCallerContext.Object, serviceProviderMock.Object, new SecondHub()); + await filter.OnConnectedAsync(secondLifetimeContext, _ => Task.CompletedTask); + Assert.Equal(2, connectionStates.GetEntityUnsafe(1234)!.ConnectionIds.Count); + Assert.Equal("abcd", connectionStates.GetEntityUnsafe(1234)!.ConnectionIds[typeof(FirstHub)]); + Assert.Equal("efgh", connectionStates.GetEntityUnsafe(1234)!.ConnectionIds[typeof(SecondHub)]); + + await filter.OnDisconnectedAsync(firstLifetimeContext, null, (_, _) => Task.CompletedTask); + Assert.Single(connectionStates.GetEntityUnsafe(1234)!.ConnectionIds); + Assert.Equal("efgh", connectionStates.GetEntityUnsafe(1234)!.ConnectionIds[typeof(SecondHub)]); + } + + [Fact] + public async Task TestSessionIDOverrulesTokenID() + { + var firstHubCallerContext = new Mock(); + var firstHttpContextMock = new Mock(); + var secondHubCallerContext = new Mock(); + var secondHttpContextMock = new Mock(); + string commonSessionId = Guid.NewGuid().ToString(); + + firstHubCallerContext.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + firstHubCallerContext.Setup(ctx => ctx.ConnectionId).Returns("abcd"); + firstHubCallerContext.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", "first token ID") + }) + })); + firstHubCallerContext.Setup(ctx => ctx.Features.Get()).Returns(firstHttpContextMock.Object); + firstHttpContextMock.Setup(ctx => ctx.HttpContext).Returns(() => + { + var context = new DefaultHttpContext(); + context.Request.Headers[HubClientConnector.CLIENT_SESSION_ID_HEADER] = commonSessionId; + return context; + }); + + secondHubCallerContext.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + secondHubCallerContext.Setup(ctx => ctx.ConnectionId).Returns("efgh"); + secondHubCallerContext.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", "second token ID") + }) + })); + secondHubCallerContext.Setup(ctx => ctx.Features.Get()).Returns(secondHttpContextMock.Object); + secondHttpContextMock.Setup(ctx => ctx.HttpContext).Returns(() => + { + var context = new DefaultHttpContext(); + context.Request.Headers[HubClientConnector.CLIENT_SESSION_ID_HEADER] = commonSessionId; + return context; + }); + + var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object, loggerFactoryMock.Object); + + var firstLifetimeContext = new HubLifetimeContext(firstHubCallerContext.Object, serviceProviderMock.Object, new FirstHub()); + await filter.OnConnectedAsync(firstLifetimeContext, _ => Task.CompletedTask); + + var secondLifetimeContext = new HubLifetimeContext(secondHubCallerContext.Object, serviceProviderMock.Object, new SecondHub()); + await filter.OnConnectedAsync(secondLifetimeContext, _ => Task.CompletedTask); + + var firstInvocationContext = new HubInvocationContext(firstHubCallerContext.Object, serviceProviderMock.Object, new FirstHub(), + typeof(SpectatorHub).GetMethod(nameof(SpectatorHub.StartWatchingUser))!, new object[] { 1234 }); + // should not throw. + await filter.InvokeMethodAsync(firstInvocationContext, _ => new ValueTask(new object())); + } + + #endregion + + #region Legacy path (uses JWT `jti` claim to distinguish clients) + + [Fact] + public async Task TestNormalOperation_SessionIDNotPresent() { var hubCallerContextMock = new Mock(); hubCallerContextMock.Setup(ctx => ctx.UserIdentifier).Returns("1234"); @@ -49,6 +345,7 @@ public async Task TestNormalOperation() new Claim("jti", Guid.NewGuid().ToString()) }) })); + hubCallerContextMock.Setup(ctx => ctx.Features.Get()).Returns(new Mock().Object); var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object, loggerFactoryMock.Object); var lifetimeContext = new HubLifetimeContext(hubCallerContextMock.Object, serviceProviderMock.Object, hubMock.Object); @@ -84,7 +381,7 @@ await filter.OnDisconnectedAsync(lifetimeContext, null, (_, _) => } [Fact] - public async Task TestConcurrencyBlocked() + public async Task TestConcurrencyBlocked_SessionIDNotPresent() { var firstContextMock = new Mock(); var secondContextMock = new Mock(); @@ -98,6 +395,7 @@ public async Task TestConcurrencyBlocked() new Claim("jti", Guid.NewGuid().ToString()) }) })); + firstContextMock.Setup(ctx => ctx.Features.Get()).Returns(new Mock().Object); secondContextMock.Setup(ctx => ctx.UserIdentifier).Returns("1234"); secondContextMock.Setup(ctx => ctx.ConnectionId).Returns("efgh"); @@ -108,6 +406,7 @@ public async Task TestConcurrencyBlocked() new Claim("jti", Guid.NewGuid().ToString()) }) })); + secondContextMock.Setup(ctx => ctx.Features.Get()).Returns(new Mock().Object); var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object, loggerFactoryMock.Object); @@ -129,7 +428,7 @@ public async Task TestConcurrencyBlocked() } [Fact] - public async Task TestStaleDisconnectIsANoOp() + public async Task TestStaleDisconnectIsANoOp_SessionIDNotPresent() { var firstContextMock = new Mock(); var secondContextMock = new Mock(); @@ -144,6 +443,7 @@ public async Task TestStaleDisconnectIsANoOp() new Claim("jti", commonTokenId) }) })); + firstContextMock.Setup(ctx => ctx.Features.Get()).Returns(new Mock().Object); secondContextMock.Setup(ctx => ctx.UserIdentifier).Returns("1234"); secondContextMock.Setup(ctx => ctx.ConnectionId).Returns("efgh"); @@ -154,6 +454,7 @@ public async Task TestStaleDisconnectIsANoOp() new Claim("jti", commonTokenId) }) })); + secondContextMock.Setup(ctx => ctx.Features.Get()).Returns(new Mock().Object); var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object, loggerFactoryMock.Object); @@ -169,7 +470,7 @@ public async Task TestStaleDisconnectIsANoOp() } [Fact] - public async Task TestHubDisconnectsTrackedSeparately() + public async Task TestHubDisconnectsTrackedSeparately_SessionIDNotPresent() { var firstContextMock = new Mock(); var secondContextMock = new Mock(); @@ -184,6 +485,7 @@ public async Task TestHubDisconnectsTrackedSeparately() new Claim("jti", commonTokenId) }) })); + firstContextMock.Setup(ctx => ctx.Features.Get()).Returns(new Mock().Object); secondContextMock.Setup(ctx => ctx.UserIdentifier).Returns("1234"); secondContextMock.Setup(ctx => ctx.ConnectionId).Returns("efgh"); @@ -194,6 +496,7 @@ public async Task TestHubDisconnectsTrackedSeparately() new Claim("jti", commonTokenId) }) })); + secondContextMock.Setup(ctx => ctx.Features.Get()).Returns(new Mock().Object); var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object, loggerFactoryMock.Object); @@ -211,6 +514,8 @@ public async Task TestHubDisconnectsTrackedSeparately() Assert.Equal("efgh", connectionStates.GetEntityUnsafe(1234)!.ConnectionIds[typeof(SecondHub)]); } + #endregion + private class FirstHub : Hub { } diff --git a/osu.Server.Spectator/ConcurrentConnectionLimiter.cs b/osu.Server.Spectator/ConcurrentConnectionLimiter.cs index 2fa99fb8..ae899965 100644 --- a/osu.Server.Spectator/ConcurrentConnectionLimiter.cs +++ b/osu.Server.Spectator/ConcurrentConnectionLimiter.cs @@ -55,7 +55,7 @@ private async Task registerConnection(HubLifetimeContext context) return; } - if (context.Context.GetTokenId() == userState.Item.TokenId) + if (userState.Item.IsConnectionFromSameClient(context)) { // The assumption is that the client has already dropped the old connection, // so we don't bother to ask for a disconnection. @@ -99,15 +99,7 @@ private void log(HubLifetimeContext context, string message) using (var userState = await connectionStates.GetForUse(userId)) { - string? registeredConnectionId = null; - - bool tokenIdMatches = invocationContext.Context.GetTokenId() == userState.Item?.TokenId; - bool hubRegistered = userState.Item?.ConnectionIds.TryGetValue(invocationContext.Hub.GetType(), out registeredConnectionId) == true; - bool connectionIdMatches = registeredConnectionId == invocationContext.Context.ConnectionId; - - bool connectionIsValid = tokenIdMatches && hubRegistered && connectionIdMatches; - - if (!connectionIsValid) + if (userState.Item?.IsInvocationPermitted(invocationContext) != true) throw new InvalidOperationException($"State is not valid for this connection, context: {LoggingHubFilter.GetMethodCallDisplayString(invocationContext)})"); } @@ -129,15 +121,7 @@ private async Task unregisterConnection(HubLifetimeContext context, Exception? e using (var userState = await connectionStates.GetForUse(userId, true)) { - string? registeredConnectionId = null; - - bool tokenIdMatches = context.Context.GetTokenId() == userState.Item?.TokenId; - bool hubRegistered = userState.Item?.ConnectionIds.TryGetValue(context.Hub.GetType(), out registeredConnectionId) == true; - bool connectionIdMatches = registeredConnectionId == context.Context.ConnectionId; - - bool connectionCanBeCleanedUp = tokenIdMatches && hubRegistered && connectionIdMatches; - - if (connectionCanBeCleanedUp) + if (userState.Item?.CanCleanUpConnection(context) == true) { log(context, "disconnected from hub"); userState.Item!.ConnectionIds.Remove(context.Hub.GetType()); diff --git a/osu.Server.Spectator/Entities/ConnectionState.cs b/osu.Server.Spectator/Entities/ConnectionState.cs index c158fc5d..69e96ad3 100644 --- a/osu.Server.Spectator/Entities/ConnectionState.cs +++ b/osu.Server.Spectator/Entities/ConnectionState.cs @@ -4,8 +4,11 @@ using System; using System.Collections.Generic; using Microsoft.AspNetCore.SignalR; +using osu.Game.Online; using osu.Server.Spectator.Extensions; +#pragma warning disable CS0618 // Type or member is obsolete + namespace osu.Server.Spectator.Entities { /// @@ -14,9 +17,19 @@ namespace osu.Server.Spectator.Entities public class ConnectionState { /// - /// The unique ID of the JWT the user is using to authenticate. + /// A client-side generated GUID identifying the client instance connecting to this server. /// This is used to control user uniqueness. /// + public readonly Guid? ClientSessionId; + + /// + /// The unique ID of the JWT the user is using to authenticate. + /// + /// + /// This was previously used as a method of controlling user uniqueness / limiting concurrency, + /// but it turned out to be a bad fit for the purpose (see https://github.com/ppy/osu/issues/26338#issuecomment-2222935517). + /// + [Obsolete("Use ClientSessionId instead.")] public readonly string TokenId; /// @@ -33,6 +46,9 @@ public ConnectionState(HubLifetimeContext context) { TokenId = context.Context.GetTokenId(); + if (tryGetClientSessionID(context, out var clientSessionId)) + ClientSessionId = clientSessionId; + RegisterConnectionId(context); } @@ -42,5 +58,36 @@ public ConnectionState(HubLifetimeContext context) /// The hub context to retrieve information from. public void RegisterConnectionId(HubLifetimeContext context) => ConnectionIds[context.Hub.GetType()] = context.Context.ConnectionId; + + private bool tryGetClientSessionID(HubLifetimeContext context, out Guid clientSessionId) + { + clientSessionId = Guid.Empty; + return context.Context.GetHttpContext()?.Request.Headers.TryGetValue(HubClientConnector.CLIENT_SESSION_ID_HEADER, out var value) == true + && Guid.TryParse(value, out clientSessionId); + } + + public bool IsConnectionFromSameClient(HubLifetimeContext context) + { + if (tryGetClientSessionID(context, out var clientSessionId)) + return ClientSessionId == clientSessionId; + + return TokenId == context.Context.GetTokenId(); + } + + public bool IsInvocationPermitted(HubInvocationContext context) + { + bool hubRegistered = ConnectionIds.TryGetValue(context.Hub.GetType(), out string? registeredConnectionId); + bool connectionIdMatches = registeredConnectionId == context.Context.ConnectionId; + + return hubRegistered && connectionIdMatches; + } + + public bool CanCleanUpConnection(HubLifetimeContext context) + { + bool hubRegistered = ConnectionIds.TryGetValue(context.Hub.GetType(), out string? registeredConnectionId); + bool connectionIdMatches = registeredConnectionId == context.Context.ConnectionId; + + return hubRegistered && connectionIdMatches; + } } } diff --git a/osu.Server.Spectator/Hubs/Metadata/MetadataHub.cs b/osu.Server.Spectator/Hubs/Metadata/MetadataHub.cs index b3242297..1b72b8f8 100644 --- a/osu.Server.Spectator/Hubs/Metadata/MetadataHub.cs +++ b/osu.Server.Spectator/Hubs/Metadata/MetadataHub.cs @@ -7,6 +7,7 @@ using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Primitives; using Microsoft.Extensions.Logging; +using osu.Game.Online; using osu.Game.Online.Metadata; using osu.Game.Users; using osu.Server.Spectator.Database; @@ -48,7 +49,7 @@ public override async Task OnConnectedAsync() { string? versionHash = null; - if (Context.GetHttpContext()?.Request.Headers.TryGetValue("OsuVersionHash", out StringValues headerValue) == true) + if (Context.GetHttpContext()?.Request.Headers.TryGetValue(HubClientConnector.VERSION_HASH_HEADER, out StringValues headerValue) == true) { versionHash = headerValue; diff --git a/osu.Server.Spectator/ServerShuttingDownException.cs b/osu.Server.Spectator/ServerShuttingDownException.cs index 9206487d..75796d77 100644 --- a/osu.Server.Spectator/ServerShuttingDownException.cs +++ b/osu.Server.Spectator/ServerShuttingDownException.cs @@ -2,13 +2,14 @@ // See the LICENCE file in the repository root for full licence text. using Microsoft.AspNetCore.SignalR; +using osu.Game.Online; namespace osu.Server.Spectator { public class ServerShuttingDownException : HubException { public ServerShuttingDownException() - : base("Server is shutting down.") + : base(HubClientConnector.SERVER_SHUTDOWN_MESSAGE) { } } From 15b49baa52d23fa33e7fee27055b272105f0c81f Mon Sep 17 00:00:00 2001 From: Dean Herbert Date: Thu, 18 Jul 2024 13:33:17 +0900 Subject: [PATCH 2/5] Update packages --- SampleMultiplayerClient/SampleMultiplayerClient.csproj | 2 +- SampleSpectatorClient/SampleSpectatorClient.csproj | 2 +- osu.Server.Spectator/osu.Server.Spectator.csproj | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/SampleMultiplayerClient/SampleMultiplayerClient.csproj b/SampleMultiplayerClient/SampleMultiplayerClient.csproj index d33aa350..488db391 100644 --- a/SampleMultiplayerClient/SampleMultiplayerClient.csproj +++ b/SampleMultiplayerClient/SampleMultiplayerClient.csproj @@ -11,7 +11,7 @@ - + diff --git a/SampleSpectatorClient/SampleSpectatorClient.csproj b/SampleSpectatorClient/SampleSpectatorClient.csproj index d33aa350..488db391 100644 --- a/SampleSpectatorClient/SampleSpectatorClient.csproj +++ b/SampleSpectatorClient/SampleSpectatorClient.csproj @@ -11,7 +11,7 @@ - + diff --git a/osu.Server.Spectator/osu.Server.Spectator.csproj b/osu.Server.Spectator/osu.Server.Spectator.csproj index 0da10764..a13e4d24 100644 --- a/osu.Server.Spectator/osu.Server.Spectator.csproj +++ b/osu.Server.Spectator/osu.Server.Spectator.csproj @@ -15,11 +15,11 @@ - - - - - + + + + + From 9e101166f8106073e25382fa66bcb4c5c9ad8194 Mon Sep 17 00:00:00 2001 From: Dean Herbert Date: Thu, 18 Jul 2024 14:10:19 +0900 Subject: [PATCH 3/5] Move retrieval helper method to last and make `static` --- osu.Server.Spectator/Entities/ConnectionState.cs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/osu.Server.Spectator/Entities/ConnectionState.cs b/osu.Server.Spectator/Entities/ConnectionState.cs index 69e96ad3..22119a6e 100644 --- a/osu.Server.Spectator/Entities/ConnectionState.cs +++ b/osu.Server.Spectator/Entities/ConnectionState.cs @@ -59,13 +59,6 @@ public ConnectionState(HubLifetimeContext context) public void RegisterConnectionId(HubLifetimeContext context) => ConnectionIds[context.Hub.GetType()] = context.Context.ConnectionId; - private bool tryGetClientSessionID(HubLifetimeContext context, out Guid clientSessionId) - { - clientSessionId = Guid.Empty; - return context.Context.GetHttpContext()?.Request.Headers.TryGetValue(HubClientConnector.CLIENT_SESSION_ID_HEADER, out var value) == true - && Guid.TryParse(value, out clientSessionId); - } - public bool IsConnectionFromSameClient(HubLifetimeContext context) { if (tryGetClientSessionID(context, out var clientSessionId)) @@ -89,5 +82,12 @@ public bool CanCleanUpConnection(HubLifetimeContext context) return hubRegistered && connectionIdMatches; } + + private static bool tryGetClientSessionID(HubLifetimeContext context, out Guid clientSessionId) + { + clientSessionId = Guid.Empty; + return context.Context.GetHttpContext()?.Request.Headers.TryGetValue(HubClientConnector.CLIENT_SESSION_ID_HEADER, out var value) == true + && Guid.TryParse(value, out clientSessionId); + } } } From 3def1bdf6af3f7b49a9bf2ddf15dc1307efc8681 Mon Sep 17 00:00:00 2001 From: Dean Herbert Date: Thu, 18 Jul 2024 14:13:46 +0900 Subject: [PATCH 4/5] Add notes about obsolete path --- osu.Server.Spectator/Entities/ConnectionState.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/osu.Server.Spectator/Entities/ConnectionState.cs b/osu.Server.Spectator/Entities/ConnectionState.cs index 22119a6e..b79cc9a0 100644 --- a/osu.Server.Spectator/Entities/ConnectionState.cs +++ b/osu.Server.Spectator/Entities/ConnectionState.cs @@ -29,7 +29,7 @@ public class ConnectionState /// This was previously used as a method of controlling user uniqueness / limiting concurrency, /// but it turned out to be a bad fit for the purpose (see https://github.com/ppy/osu/issues/26338#issuecomment-2222935517). /// - [Obsolete("Use ClientSessionId instead.")] + [Obsolete("Use ClientSessionId instead.")] // Can be removed 2024-08-18 public readonly string TokenId; /// @@ -64,6 +64,7 @@ public bool IsConnectionFromSameClient(HubLifetimeContext context) if (tryGetClientSessionID(context, out var clientSessionId)) return ClientSessionId == clientSessionId; + // Legacy pathway using JTI claim left for compatibility with older clients – can be removed 2024-08-18 return TokenId == context.Context.GetTokenId(); } From 0cc6189a675a7e05be86dee739d4731710913254 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Dach?= Date: Thu, 18 Jul 2024 08:35:47 +0200 Subject: [PATCH 5/5] Rename methods to better convey they are overloads of each other --- osu.Server.Spectator/ConcurrentConnectionLimiter.cs | 4 ++-- osu.Server.Spectator/Entities/ConnectionState.cs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/osu.Server.Spectator/ConcurrentConnectionLimiter.cs b/osu.Server.Spectator/ConcurrentConnectionLimiter.cs index ae899965..bdfa9d51 100644 --- a/osu.Server.Spectator/ConcurrentConnectionLimiter.cs +++ b/osu.Server.Spectator/ConcurrentConnectionLimiter.cs @@ -99,7 +99,7 @@ private void log(HubLifetimeContext context, string message) using (var userState = await connectionStates.GetForUse(userId)) { - if (userState.Item?.IsInvocationPermitted(invocationContext) != true) + if (userState.Item?.ExistingConnectionMatches(invocationContext) != true) throw new InvalidOperationException($"State is not valid for this connection, context: {LoggingHubFilter.GetMethodCallDisplayString(invocationContext)})"); } @@ -121,7 +121,7 @@ private async Task unregisterConnection(HubLifetimeContext context, Exception? e using (var userState = await connectionStates.GetForUse(userId, true)) { - if (userState.Item?.CanCleanUpConnection(context) == true) + if (userState.Item?.ExistingConnectionMatches(context) == true) { log(context, "disconnected from hub"); userState.Item!.ConnectionIds.Remove(context.Hub.GetType()); diff --git a/osu.Server.Spectator/Entities/ConnectionState.cs b/osu.Server.Spectator/Entities/ConnectionState.cs index b79cc9a0..48de19e9 100644 --- a/osu.Server.Spectator/Entities/ConnectionState.cs +++ b/osu.Server.Spectator/Entities/ConnectionState.cs @@ -68,7 +68,7 @@ public bool IsConnectionFromSameClient(HubLifetimeContext context) return TokenId == context.Context.GetTokenId(); } - public bool IsInvocationPermitted(HubInvocationContext context) + public bool ExistingConnectionMatches(HubInvocationContext context) { bool hubRegistered = ConnectionIds.TryGetValue(context.Hub.GetType(), out string? registeredConnectionId); bool connectionIdMatches = registeredConnectionId == context.Context.ConnectionId; @@ -76,7 +76,7 @@ public bool IsInvocationPermitted(HubInvocationContext context) return hubRegistered && connectionIdMatches; } - public bool CanCleanUpConnection(HubLifetimeContext context) + public bool ExistingConnectionMatches(HubLifetimeContext context) { bool hubRegistered = ConnectionIds.TryGetValue(context.Hub.GetType(), out string? registeredConnectionId); bool connectionIdMatches = registeredConnectionId == context.Context.ConnectionId;