11// Licensed to the .NET Foundation under one or more agreements.
22// The .NET Foundation licenses this file to you under the MIT license.
33
4- using System ;
5- using System . Collections . Generic ;
6- using System . Threading . Tasks ;
4+ using System . Net . WebSockets ;
75using Microsoft . AspNetCore . Http . Connections ;
6+ using Microsoft . AspNetCore . Http . Connections . Client ;
87using Microsoft . AspNetCore . SignalR . Client ;
98using Microsoft . AspNetCore . SignalR . Protocol ;
109using Microsoft . AspNetCore . SignalR . Tests ;
1110using Microsoft . AspNetCore . Testing ;
1211using Microsoft . Extensions . DependencyInjection ;
1312using Microsoft . Extensions . Logging ;
14- using Xunit ;
1513
1614namespace Microsoft . AspNetCore . SignalR . StackExchangeRedis . Tests ;
1715
@@ -211,7 +209,105 @@ public async Task CanSendAndReceiveUserMessagesUserNameWithPatternIsTreatedAsLit
211209 }
212210 }
213211
214- private static HubConnection CreateConnection ( string url , HttpTransportType transportType , IHubProtocol protocol , ILoggerFactory loggerFactory , string userName = null )
212+ [ ConditionalTheory ]
213+ [ SkipIfDockerNotPresent ]
214+ [ InlineData ( "messagepack" ) ]
215+ [ InlineData ( "json" ) ]
216+ public async Task StatefulReconnectPreservesMessageFromOtherServer ( string protocolName )
217+ {
218+ using ( StartVerifiableLog ( ) )
219+ {
220+ var protocol = HubProtocolHelpers . GetHubProtocol ( protocolName ) ;
221+
222+ ClientWebSocket innerWs = null ;
223+ WebSocketWrapper ws = null ;
224+ TaskCompletionSource reconnectTcs = null ;
225+ TaskCompletionSource startedReconnectTcs = null ;
226+
227+ var connection = CreateConnection ( _serverFixture . FirstServer . Url + "/stateful" , HttpTransportType . WebSockets , protocol , LoggerFactory ,
228+ customizeConnection : builder =>
229+ {
230+ builder . WithStatefulReconnect ( ) ;
231+ builder . Services . Configure < HttpConnectionOptions > ( o =>
232+ {
233+ // Replace the websocket creation for the first connection so we can make the client think there was an ungraceful closure
234+ // Which will trigger the stateful reconnect flow
235+ o . WebSocketFactory = async ( context , token ) =>
236+ {
237+ if ( reconnectTcs is null )
238+ {
239+ reconnectTcs = new TaskCompletionSource ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
240+ startedReconnectTcs = new TaskCompletionSource ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
241+ }
242+ else
243+ {
244+ startedReconnectTcs . SetResult ( ) ;
245+ // We only want to wait on the reconnect, not the initial connection attempt
246+ await reconnectTcs . Task . DefaultTimeout ( ) ;
247+ }
248+
249+ innerWs = new ClientWebSocket ( ) ;
250+ ws = new WebSocketWrapper ( innerWs ) ;
251+ await innerWs . ConnectAsync ( context . Uri , token ) ;
252+
253+ _ = Task . Run ( async ( ) =>
254+ {
255+ try
256+ {
257+ while ( innerWs . State == WebSocketState . Open )
258+ {
259+ var buffer = new byte [ 1024 ] ;
260+ var res = await innerWs . ReceiveAsync ( buffer , default ) ;
261+ ws . SetReceiveResult ( ( res , buffer . AsMemory ( 0 , res . Count ) ) ) ;
262+ }
263+ }
264+ // Log but ignore receive errors, that likely just means the connection closed
265+ catch ( Exception ex )
266+ {
267+ Logger . LogInformation ( ex , "Error while reading from inner websocket" ) ;
268+ }
269+ } ) ;
270+
271+ return ws ;
272+ } ;
273+ } ) ;
274+ } ) ;
275+ var secondConnection = CreateConnection ( _serverFixture . SecondServer . Url + "/stateful" , HttpTransportType . WebSockets , protocol , LoggerFactory ) ;
276+
277+ var tcs = new TaskCompletionSource < string > ( ) ;
278+ connection . On < string > ( "SendToAll" , message => tcs . TrySetResult ( message ) ) ;
279+
280+ var tcs2 = new TaskCompletionSource < string > ( ) ;
281+ secondConnection . On < string > ( "SendToAll" , message => tcs2 . TrySetResult ( message ) ) ;
282+
283+ await connection . StartAsync ( ) . DefaultTimeout ( ) ;
284+ await secondConnection . StartAsync ( ) . DefaultTimeout ( ) ;
285+
286+ // Close first connection before the second connection sends a message to all clients
287+ await ws . CloseOutputAsync ( WebSocketCloseStatus . InternalServerError , statusDescription : null , default ) ;
288+ await startedReconnectTcs . Task . DefaultTimeout ( ) ;
289+
290+ // Send to all clients, since both clients are on different servers this means the backplane will be used
291+ // And we want to test that messages are still preserved for stateful reconnect purposes when a client disconnects
292+ // But is on a different server from the original message sender.
293+ await secondConnection . SendAsync ( "SendToAll" , "test message" ) . DefaultTimeout ( ) ;
294+
295+ // Check that second connection still receives the message
296+ Assert . Equal ( "test message" , await tcs2 . Task . DefaultTimeout ( ) ) ;
297+ Assert . False ( tcs . Task . IsCompleted ) ;
298+
299+ // allow first connection to reconnect
300+ reconnectTcs . SetResult ( ) ;
301+
302+ // Check that first connection received the message once it reconnected
303+ Assert . Equal ( "test message" , await tcs . Task . DefaultTimeout ( ) ) ;
304+
305+ await connection . DisposeAsync ( ) . DefaultTimeout ( ) ;
306+ }
307+ }
308+
309+ private static HubConnection CreateConnection ( string url , HttpTransportType transportType , IHubProtocol protocol , ILoggerFactory loggerFactory , string userName = null ,
310+ Action < IHubConnectionBuilder > customizeConnection = null )
215311 {
216312 var hubConnectionBuilder = new HubConnectionBuilder ( )
217313 . WithLoggerFactory ( loggerFactory )
@@ -225,6 +321,8 @@ private static HubConnection CreateConnection(string url, HttpTransportType tran
225321
226322 hubConnectionBuilder . Services . AddSingleton ( protocol ) ;
227323
324+ customizeConnection ? . Invoke ( hubConnectionBuilder ) ;
325+
228326 return hubConnectionBuilder . Build ( ) ;
229327 }
230328
@@ -253,4 +351,67 @@ public static IEnumerable<object[]> TransportTypesAndProtocolTypes
253351 }
254352 }
255353 }
354+
355+ internal sealed class WebSocketWrapper : WebSocket
356+ {
357+ private readonly WebSocket _inner ;
358+ private TaskCompletionSource < ( WebSocketReceiveResult , ReadOnlyMemory < byte > ) > _receiveTcs = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
359+
360+ public WebSocketWrapper ( WebSocket inner )
361+ {
362+ _inner = inner ;
363+ }
364+
365+ public override WebSocketCloseStatus ? CloseStatus => _inner . CloseStatus ;
366+
367+ public override string CloseStatusDescription => _inner . CloseStatusDescription ;
368+
369+ public override WebSocketState State => _inner . State ;
370+
371+ public override string SubProtocol => _inner . SubProtocol ;
372+
373+ public override void Abort ( )
374+ {
375+ _inner . Abort ( ) ;
376+ }
377+
378+ public override Task CloseAsync ( WebSocketCloseStatus closeStatus , string statusDescription , CancellationToken cancellationToken )
379+ {
380+ return _inner . CloseAsync ( closeStatus , statusDescription , cancellationToken ) ;
381+ }
382+
383+ public override Task CloseOutputAsync ( WebSocketCloseStatus closeStatus , string statusDescription , CancellationToken cancellationToken )
384+ {
385+ _receiveTcs . TrySetException ( new IOException ( "force reconnect" ) ) ;
386+ return Task . CompletedTask ;
387+ }
388+
389+ public override void Dispose ( )
390+ {
391+ _inner . Dispose ( ) ;
392+ }
393+
394+ public void SetReceiveResult ( ( WebSocketReceiveResult , ReadOnlyMemory < byte > ) result )
395+ {
396+ _receiveTcs . SetResult ( result ) ;
397+ }
398+
399+ public override async Task < WebSocketReceiveResult > ReceiveAsync ( ArraySegment < byte > buffer , CancellationToken cancellationToken )
400+ {
401+ var res = await _receiveTcs . Task ;
402+ // Handle zero-byte reads
403+ if ( buffer . Count == 0 )
404+ {
405+ return res . Item1 ;
406+ }
407+ _receiveTcs = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
408+ res . Item2 . CopyTo ( buffer ) ;
409+ return res . Item1 ;
410+ }
411+
412+ public override Task SendAsync ( ArraySegment < byte > buffer , WebSocketMessageType messageType , bool endOfMessage , CancellationToken cancellationToken )
413+ {
414+ return _inner . SendAsync ( buffer , messageType , endOfMessage , cancellationToken ) ;
415+ }
416+ }
256417}
0 commit comments