diff --git a/src/libraries/Common/src/Interop/Interop.zlib.cs b/src/libraries/Common/src/Interop/Interop.zlib.cs index 280c5558667eb..ad517da4079ca 100644 --- a/src/libraries/Common/src/Interop/Interop.zlib.cs +++ b/src/libraries/Common/src/Interop/Interop.zlib.cs @@ -20,6 +20,9 @@ internal static extern ZLibNative.ErrorCode DeflateInit2_( [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_Deflate")] internal static extern ZLibNative.ErrorCode Deflate(ref ZLibNative.ZStream stream, ZLibNative.FlushCode flush); + [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_DeflateReset")] + internal static extern ZLibNative.ErrorCode DeflateReset(ref ZLibNative.ZStream stream); + [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_DeflateEnd")] internal static extern ZLibNative.ErrorCode DeflateEnd(ref ZLibNative.ZStream stream); @@ -29,6 +32,9 @@ internal static extern ZLibNative.ErrorCode DeflateInit2_( [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_Inflate")] internal static extern ZLibNative.ErrorCode Inflate(ref ZLibNative.ZStream stream, ZLibNative.FlushCode flush); + [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_InflateReset")] + internal static extern ZLibNative.ErrorCode InflateReset(ref ZLibNative.ZStream stream); + [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_InflateEnd")] internal static extern ZLibNative.ErrorCode InflateEnd(ref ZLibNative.ZStream stream); diff --git a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.ZStream.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.ZStream.cs similarity index 100% rename from src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.ZStream.cs rename to src/libraries/Common/src/System/IO/Compression/ZLibNative.ZStream.cs diff --git a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs similarity index 97% rename from src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.cs rename to src/libraries/Common/src/System/IO/Compression/ZLibNative.cs index 8118aeba0ecb8..f0393ebbf35cb 100644 --- a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.cs +++ b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs @@ -23,6 +23,7 @@ public enum FlushCode : int NoFlush = 0, SyncFlush = 2, Finish = 4, + Block = 5 } public enum ErrorCode : int @@ -281,6 +282,13 @@ public ErrorCode Deflate(FlushCode flush) } + public ErrorCode DeflateReset() + { + EnsureNotDisposed(); + EnsureState(State.InitializedForDeflate); + return Interop.zlib.DeflateReset(ref _zStream); + } + public ErrorCode DeflateEnd() { EnsureNotDisposed(); @@ -313,6 +321,13 @@ public ErrorCode Inflate(FlushCode flush) } + public ErrorCode InflateReset() + { + EnsureNotDisposed(); + EnsureState(State.InitializedForInflate); + return Interop.zlib.InflateReset(ref _zStream); + } + public ErrorCode InflateEnd() { EnsureNotDisposed(); diff --git a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs index a092c96648389..d074f618bf16d 100644 --- a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs +++ b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs @@ -9,6 +9,19 @@ namespace System.Net.WebSockets { internal static partial class WebSocketValidate { + /// + /// The minimum value for window bits that the websocket per-message-deflate extension can support. + /// For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported. + /// We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream + /// and thus it needs to know the window bits in advance. + /// + internal const int MinDeflateWindowBits = 9; + + /// + /// The maximum value for window bits that the websocket per-message-deflate extension can support. + /// + internal const int MaxDeflateWindowBits = 15; + internal const int MaxControlFramePayloadLength = 123; private const int CloseStatusCodeAbort = 1006; private const int CloseStatusCodeFailedTLSHandshake = 1015; diff --git a/src/libraries/Native/AnyOS/System.IO.Compression.Native/entrypoints.c b/src/libraries/Native/AnyOS/System.IO.Compression.Native/entrypoints.c index b194b978debe2..f363a91eb1add 100644 --- a/src/libraries/Native/AnyOS/System.IO.Compression.Native/entrypoints.c +++ b/src/libraries/Native/AnyOS/System.IO.Compression.Native/entrypoints.c @@ -28,9 +28,11 @@ static const Entry s_compressionNative[] = DllImportEntry(CompressionNative_Crc32) DllImportEntry(CompressionNative_Deflate) DllImportEntry(CompressionNative_DeflateEnd) + DllImportEntry(CompressionNative_DeflateReset) DllImportEntry(CompressionNative_DeflateInit2_) DllImportEntry(CompressionNative_Inflate) DllImportEntry(CompressionNative_InflateEnd) + DllImportEntry(CompressionNative_InflateReset) DllImportEntry(CompressionNative_InflateInit2_) }; diff --git a/src/libraries/Native/AnyOS/zlib/pal_zlib.c b/src/libraries/Native/AnyOS/zlib/pal_zlib.c index 2c399639d0fa9..aa4dcdca8a29e 100644 --- a/src/libraries/Native/AnyOS/zlib/pal_zlib.c +++ b/src/libraries/Native/AnyOS/zlib/pal_zlib.c @@ -135,6 +135,17 @@ int32_t CompressionNative_Deflate(PAL_ZStream* stream, int32_t flush) return result; } +int32_t CompressionNative_DeflateReset(PAL_ZStream* stream) +{ + assert(stream != NULL); + + z_stream* zStream = GetCurrentZStream(stream); + int32_t result = deflateReset(zStream); + TransferStateToPalZStream(zStream, stream); + + return result; +} + int32_t CompressionNative_DeflateEnd(PAL_ZStream* stream) { assert(stream != NULL); @@ -172,6 +183,17 @@ int32_t CompressionNative_Inflate(PAL_ZStream* stream, int32_t flush) return result; } +int32_t CompressionNative_InflateReset(PAL_ZStream* stream) +{ + assert(stream != NULL); + + z_stream* zStream = GetCurrentZStream(stream); + int32_t result = inflateReset(zStream); + TransferStateToPalZStream(zStream, stream); + + return result; +} + int32_t CompressionNative_InflateEnd(PAL_ZStream* stream) { assert(stream != NULL); diff --git a/src/libraries/Native/AnyOS/zlib/pal_zlib.h b/src/libraries/Native/AnyOS/zlib/pal_zlib.h index b317091b843f6..1eb1baa6b3846 100644 --- a/src/libraries/Native/AnyOS/zlib/pal_zlib.h +++ b/src/libraries/Native/AnyOS/zlib/pal_zlib.h @@ -95,6 +95,14 @@ Returns a PAL_ErrorCode indicating success or an error number on failure. */ FUNCTIONEXPORT int32_t FUNCTIONCALLINGCONVENCTION CompressionNative_Deflate(PAL_ZStream* stream, int32_t flush); +/* +This function is equivalent to DeflateEnd followed by DeflateInit, but does not free and reallocate +the internal compression state. The stream will leave the compression level and any other attributes that may have been set unchanged. + +Returns a PAL_ErrorCode indicating success or an error number on failure. +*/ +FUNCTIONEXPORT int32_t FUNCTIONCALLINGCONVENCTION CompressionNative_DeflateReset(PAL_ZStream* stream); + /* All dynamically allocated data structures for this stream are freed. @@ -117,6 +125,14 @@ Returns a PAL_ErrorCode indicating success or an error number on failure. */ FUNCTIONEXPORT int32_t FUNCTIONCALLINGCONVENCTION CompressionNative_Inflate(PAL_ZStream* stream, int32_t flush); +/* +This function is equivalent to InflateEnd followed by InflateInit, but does not free and reallocate +the internal decompression state. The The stream will keep attributes that may have been set by InflateInit. + +Returns a PAL_ErrorCode indicating success or an error number on failure. +*/ +FUNCTIONEXPORT int32_t FUNCTIONCALLINGCONVENCTION CompressionNative_InflateReset(PAL_ZStream* stream); + /* All dynamically allocated data structures for this stream are freed. diff --git a/src/libraries/Native/Unix/System.IO.Compression.Native/System.IO.Compression.Native_unixexports.src b/src/libraries/Native/Unix/System.IO.Compression.Native/System.IO.Compression.Native_unixexports.src index 08dd1700a52f2..2ac827035f271 100644 --- a/src/libraries/Native/Unix/System.IO.Compression.Native/System.IO.Compression.Native_unixexports.src +++ b/src/libraries/Native/Unix/System.IO.Compression.Native/System.IO.Compression.Native_unixexports.src @@ -15,7 +15,9 @@ BrotliEncoderSetParameter CompressionNative_Crc32 CompressionNative_Deflate CompressionNative_DeflateEnd +CompressionNative_DeflateReset CompressionNative_DeflateInit2_ CompressionNative_Inflate CompressionNative_InflateEnd +CompressionNative_InflateReset CompressionNative_InflateInit2_ diff --git a/src/libraries/Native/Windows/System.IO.Compression.Native/System.IO.Compression.Native.def b/src/libraries/Native/Windows/System.IO.Compression.Native/System.IO.Compression.Native.def index 6821d0e538f51..aecd0dd974618 100644 --- a/src/libraries/Native/Windows/System.IO.Compression.Native/System.IO.Compression.Native.def +++ b/src/libraries/Native/Windows/System.IO.Compression.Native/System.IO.Compression.Native.def @@ -15,7 +15,9 @@ EXPORTS CompressionNative_Crc32 CompressionNative_Deflate CompressionNative_DeflateEnd + CompressionNative_DeflateReset CompressionNative_DeflateInit2_ CompressionNative_Inflate CompressionNative_InflateEnd + CompressionNative_InflateReset CompressionNative_InflateInit2_ diff --git a/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj b/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj index e2a7adee12f57..0ffa0044e2a16 100644 --- a/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj +++ b/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj @@ -25,8 +25,10 @@ - - + + diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs index cee3a5170b862..96cecd9e30f47 100644 --- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs +++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs @@ -36,6 +36,8 @@ internal ClientWebSocketOptions() { } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] + public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } } + [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.IWebProxy? Proxy { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.Security.RemoteCertificateValidationCallback? RemoteCertificateValidationCallback { get { throw null; } set { } } diff --git a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx index 3259b86c99fcb..7b4718b554a15 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx @@ -1,63 +1,4 @@  - @@ -193,8 +134,14 @@ Connection was aborted. - + WebSocket binary type '{0}' not supported. - - + + + The WebSocket failed to negotiate max server window bits. The client requested {0} but the server responded with {1}. + + + The WebSocket failed to negotiate max client window bits. The client requested {0} but the server responded with {1}. + + \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj index b74f3d8962be6..e84ea02f895ba 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj +++ b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj @@ -6,6 +6,7 @@ + @@ -37,6 +38,7 @@ + diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs index 85b0f025b4650..79dd04229b9c3 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs @@ -100,6 +100,13 @@ public TimeSpan KeepAliveInterval set => throw new PlatformNotSupportedException(); } + [UnsupportedOSPlatform("browser")] + public WebSocketDeflateOptions? DangerousDeflateOptions + { + get => throw new PlatformNotSupportedException(); + set => throw new PlatformNotSupportedException(); + } + [UnsupportedOSPlatform("browser")] public void SetBuffer(int receiveBufferSize, int sendBufferSize) { diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs new file mode 100644 index 0000000000000..3faa886d5c306 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.WebSockets +{ + internal static class ClientWebSocketDeflateConstants + { + /// + /// The maximum length that this extension can have, assuming that we're not abusing white space. + /// + /// "permessage-deflate; client_max_window_bits=15; client_no_context_takeover; server_max_window_bits=15; server_no_context_takeover" + /// + public const int MaxExtensionLength = 128; + + public const string Extension = "permessage-deflate"; + + public const string ClientMaxWindowBits = "client_max_window_bits"; + public const string ClientNoContextTakeover = "client_no_context_takeover"; + + public const string ServerMaxWindowBits = "server_max_window_bits"; + public const string ServerNoContextTakeover = "server_no_context_takeover"; + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index a7609a0ff0905..5ab2ad51d94eb 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -148,6 +148,18 @@ public TimeSpan KeepAliveInterval } } + /// + /// Gets or sets the options for the per-message-deflate extension. + /// When present, the options are sent to the server during the handshake phase. If the server + /// supports per-message-deflate and the options are accepted, the instance + /// will be created with compression enabled by default for all messages. + /// Be aware that enabling compression makes the application subject to CRIME/BREACH type of attacks. + /// It is strongly advised to turn off compression when sending data containing secrets by + /// specifying flag for such messages. + /// + [UnsupportedOSPlatform("browser")] + public WebSocketDeflateOptions? DangerousDeflateOptions { get; set; } + internal int ReceiveBufferSize => _receiveBufferSize; internal ArraySegment? Buffer => _buffer; diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index d61f368e7aae8..e0c3902a91590 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.IO; using System.Net.Http; using System.Net.Http.Headers; @@ -22,6 +23,7 @@ internal sealed class WebSocketHandle private readonly CancellationTokenSource _abortSource = new CancellationTokenSource(); private WebSocketState _state = WebSocketState.Connecting; + private WebSocketDeflateOptions? _negotiatedDeflateOptions; public WebSocket? WebSocket { get; private set; } public WebSocketState State => WebSocket?.State ?? _state; @@ -183,6 +185,21 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } } + // Because deflate options are negotiated we need a new object + WebSocketDeflateOptions? negotiatedDeflateOptions = null; + + if (options.DangerousDeflateOptions is not null && response.Headers.TryGetValues(HttpKnownHeaderNames.SecWebSocketExtensions, out IEnumerable? extensions)) + { + foreach (ReadOnlySpan extension in extensions) + { + if (extension.TrimStart().StartsWith(ClientWebSocketDeflateConstants.Extension)) + { + negotiatedDeflateOptions = ParseDeflateOptions(extension, options.DangerousDeflateOptions); + break; + } + } + } + if (response.Content is null) { throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); @@ -192,11 +209,14 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli Stream connectedStream = response.Content.ReadAsStream(); Debug.Assert(connectedStream.CanWrite); Debug.Assert(connectedStream.CanRead); - WebSocket = WebSocket.CreateFromStream( - connectedStream, - isServer: false, - subprotocol, - options.KeepAliveInterval); + WebSocket = WebSocket.CreateFromStream(connectedStream, new WebSocketCreationOptions + { + IsServer = false, + SubProtocol = subprotocol, + KeepAliveInterval = options.KeepAliveInterval, + DangerousDeflateOptions = negotiatedDeflateOptions + }); + _negotiatedDeflateOptions = negotiatedDeflateOptions; } catch (Exception exc) { @@ -226,6 +246,73 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } } + private static WebSocketDeflateOptions ParseDeflateOptions(ReadOnlySpan extension, WebSocketDeflateOptions original) + { + var options = new WebSocketDeflateOptions(); + + while (true) + { + int end = extension.IndexOf(';'); + ReadOnlySpan value = (end >= 0 ? extension[..end] : extension).Trim(); + + if (value.Length > 0) + { + if (value.SequenceEqual(ClientWebSocketDeflateConstants.ClientNoContextTakeover)) + { + options.ClientContextTakeover = false; + } + else if (value.SequenceEqual(ClientWebSocketDeflateConstants.ServerNoContextTakeover)) + { + options.ServerContextTakeover = false; + } + else if (value.StartsWith(ClientWebSocketDeflateConstants.ClientMaxWindowBits)) + { + options.ClientMaxWindowBits = ParseWindowBits(value); + } + else if (value.StartsWith(ClientWebSocketDeflateConstants.ServerMaxWindowBits)) + { + options.ServerMaxWindowBits = ParseWindowBits(value); + } + + static int ParseWindowBits(ReadOnlySpan value) + { + var startIndex = value.IndexOf('='); + + if (startIndex < 0 || + !int.TryParse(value.Slice(startIndex + 1), NumberStyles.Integer, CultureInfo.InvariantCulture, out int windowBits) || + windowBits < WebSocketValidate.MinDeflateWindowBits || + windowBits > WebSocketValidate.MaxDeflateWindowBits) + { + throw new WebSocketException(WebSocketError.HeaderError, + SR.Format(SR.net_WebSockets_InvalidResponseHeader, ClientWebSocketDeflateConstants.Extension, value.ToString())); + } + + return windowBits; + } + } + + if (end < 0) + { + break; + } + extension = extension[(end + 1)..]; + } + + if (options.ClientMaxWindowBits > original.ClientMaxWindowBits) + { + throw new WebSocketException(string.Format(SR.net_WebSockets_ClientWindowBitsNegotiationFailure, + original.ClientMaxWindowBits, options.ClientMaxWindowBits)); + } + + if (options.ServerMaxWindowBits > original.ServerMaxWindowBits) + { + throw new WebSocketException(string.Format(SR.net_WebSockets_ServerWindowBitsNegotiationFailure, + original.ServerMaxWindowBits, options.ServerMaxWindowBits)); + } + + return options; + } + /// Adds the necessary headers for the web socket request. /// The request to which the headers should be added. /// The generated security key to send in the Sec-WebSocket-Key header. @@ -240,6 +327,47 @@ private static void AddWebSocketHeaders(HttpRequestMessage request, string secKe { request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketProtocol, string.Join(", ", options.RequestedSubProtocols)); } + if (options.DangerousDeflateOptions is not null) + { + request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketExtensions, GetDeflateOptions(options.DangerousDeflateOptions)); + + static string GetDeflateOptions(WebSocketDeflateOptions options) + { + var builder = new StringBuilder(ClientWebSocketDeflateConstants.MaxExtensionLength); + builder.Append(ClientWebSocketDeflateConstants.Extension).Append("; "); + + if (options.ClientMaxWindowBits != WebSocketValidate.MaxDeflateWindowBits) + { + builder.Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits).Append('=') + .Append(options.ClientMaxWindowBits.ToString(CultureInfo.InvariantCulture)); + } + else + { + // Advertise that we support this option + builder.Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits); + } + + if (!options.ClientContextTakeover) + { + builder.Append("; ").Append(ClientWebSocketDeflateConstants.ClientNoContextTakeover); + } + + if (options.ServerMaxWindowBits != WebSocketValidate.MaxDeflateWindowBits) + { + builder.Append("; ") + .Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits).Append('=') + .Append(options.ServerMaxWindowBits.ToString(CultureInfo.InvariantCulture)); + } + + if (!options.ServerContextTakeover) + { + builder.Append("; ").Append(ClientWebSocketDeflateConstants.ServerNoContextTakeover); + } + + Debug.Assert(builder.Length <= ClientWebSocketDeflateConstants.MaxExtensionLength); + return builder.ToString(); + } + } } /// diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs new file mode 100644 index 0000000000000..e0a0e1e59fd84 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -0,0 +1,103 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Net.Test.Common; +using System.Reflection; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.WebSockets.Client.Tests +{ + [PlatformSpecific(~TestPlatforms.Browser)] + public class DeflateTests : ClientWebSocketTestBase + { + public DeflateTests(ITestOutputHelper output) : base(output) + { + } + + [ConditionalTheory(nameof(WebSocketsSupported))] + [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)] + [InlineData(15, true, 15, true, "permessage-deflate; client_max_window_bits")] + [InlineData(14, true, 15, true, "permessage-deflate; client_max_window_bits=14")] + [InlineData(15, true, 14, true, "permessage-deflate; client_max_window_bits; server_max_window_bits=14")] + [InlineData(10, true, 11, true, "permessage-deflate; client_max_window_bits=10; server_max_window_bits=11")] + [InlineData(15, false, 15, true, "permessage-deflate; client_max_window_bits; client_no_context_takeover")] + [InlineData(15, true, 15, false, "permessage-deflate; client_max_window_bits; server_no_context_takeover")] + public async Task PerMessageDeflateHeaders(int clientWindowBits, bool clientContextTakeover, + int serverWindowBits, bool serverContextTakover, + string expected) + { + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using var client = new ClientWebSocket(); + using var cancellation = new CancellationTokenSource(TimeOutMilliseconds); + + client.Options.DangerousDeflateOptions = new WebSocketDeflateOptions + { + ClientMaxWindowBits = clientWindowBits, + ClientContextTakeover = clientContextTakeover, + ServerMaxWindowBits = serverWindowBits, + ServerContextTakeover = serverContextTakover + }; + + await client.ConnectAsync(uri, cancellation.Token); + + object webSocketHandle = client.GetType().GetField("_innerWebSocket", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(client); + WebSocketDeflateOptions negotiatedDeflateOptions = (WebSocketDeflateOptions)webSocketHandle.GetType() + .GetField("_negotiatedDeflateOptions", BindingFlags.NonPublic | BindingFlags.Instance) + .GetValue(webSocketHandle); + + Assert.Equal(clientWindowBits - 1, negotiatedDeflateOptions.ClientMaxWindowBits); + Assert.Equal(clientContextTakeover, negotiatedDeflateOptions.ClientContextTakeover); + Assert.Equal(serverWindowBits - 1, negotiatedDeflateOptions.ServerMaxWindowBits); + Assert.Equal(serverContextTakover, negotiatedDeflateOptions.ServerContextTakeover); + }, server => server.AcceptConnectionAsync(async connection => + { + var extensionsReply = CreateDeflateOptionsHeader(new WebSocketDeflateOptions + { + ClientMaxWindowBits = clientWindowBits - 1, + ClientContextTakeover = clientContextTakeover, + ServerMaxWindowBits = serverWindowBits - 1, + ServerContextTakeover = serverContextTakover + }); + Dictionary headers = await LoopbackHelper.WebSocketHandshakeAsync(connection, extensionsReply); + Assert.NotNull(headers); + Assert.True(headers.TryGetValue("Sec-WebSocket-Extensions", out string extensions)); + Assert.Equal(expected, extensions); + }), new LoopbackServer.Options { WebSocketEndpoint = true }); + } + + private static string CreateDeflateOptionsHeader(WebSocketDeflateOptions options) + { + var builder = new StringBuilder(); + builder.Append("permessage-deflate"); + + if (options.ClientMaxWindowBits != 15) + { + builder.Append("; client_max_window_bits=").Append(options.ClientMaxWindowBits); + } + + if (!options.ClientContextTakeover) + { + builder.Append("; client_no_context_takeover"); + } + + if (options.ServerMaxWindowBits != 15) + { + builder.Append("; server_max_window_bits=").Append(options.ServerMaxWindowBits); + } + + if (!options.ServerContextTakeover) + { + builder.Append("; server_no_context_takeover"); + } + + return builder.ToString(); + } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs index 5726326c6ab8f..48d167b072f78 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs @@ -11,7 +11,7 @@ namespace System.Net.WebSockets.Client.Tests { public static class LoopbackHelper { - public static async Task> WebSocketHandshakeAsync(LoopbackServer.Connection connection) + public static async Task> WebSocketHandshakeAsync(LoopbackServer.Connection connection, string? extensions = null) { string serverResponse = null; List headers = await connection.ReadRequestHeaderAsync().ConfigureAwait(false); @@ -34,6 +34,7 @@ public static async Task> WebSocketHandshakeAsync(Loo "Content-Length: 0\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + + (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") + "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n"; } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj index a1323fa83db1e..21ba2a12dd524 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj +++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj @@ -46,6 +46,7 @@ + diff --git a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs index e4ff945bb5b64..32ebf5eb1e804 100644 --- a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs +++ b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs @@ -29,6 +29,7 @@ protected WebSocket() { } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] public static System.Net.WebSockets.WebSocket CreateClientWebSocket(System.IO.Stream innerStream, string? subProtocol, int receiveBufferSize, int sendBufferSize, System.TimeSpan keepAliveInterval, bool useZeroMaskingKey, System.ArraySegment internalBuffer) { throw null; } public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, bool isServer, string? subProtocol, System.TimeSpan keepAliveInterval) { throw null; } + public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, System.Net.WebSockets.WebSocketCreationOptions options) { throw null; } public static System.ArraySegment CreateServerBuffer(int receiveBufferSize) { throw null; } public abstract void Dispose(); [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] @@ -41,6 +42,7 @@ protected WebSocket() { } public static void RegisterPrefixes() { } public abstract System.Threading.Tasks.Task SendAsync(System.ArraySegment buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken); public virtual System.Threading.Tasks.ValueTask SendAsync(System.ReadOnlyMemory buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken) { throw null; } + public virtual System.Threading.Tasks.ValueTask SendAsync(System.ReadOnlyMemory buffer, System.Net.WebSockets.WebSocketMessageType messageType, System.Net.WebSockets.WebSocketMessageFlags messageFlags, System.Threading.CancellationToken cancellationToken) { throw null; } protected static void ThrowOnInvalidState(System.Net.WebSockets.WebSocketState state, params System.Net.WebSockets.WebSocketState[] validStates) { } } public enum WebSocketCloseStatus @@ -131,4 +133,25 @@ public enum WebSocketState Closed = 5, Aborted = 6, } + public sealed partial class WebSocketCreationOptions + { + public bool IsServer { get { throw null; } set { } } + public string? SubProtocol { get { throw null; } set { } } + public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } + public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } } + } + public sealed partial class WebSocketDeflateOptions + { + public int ClientMaxWindowBits { get { throw null; } set { } } + public bool ClientContextTakeover { get { throw null; } set { } } + public int ServerMaxWindowBits { get { throw null; } set { } } + public bool ServerContextTakeover { get { throw null; } set { } } + } + [Flags] + public enum WebSocketMessageFlags + { + None = 0, + EndOfMessage = 1, + DisableCompression = 2 + } } diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx index a4f630ea24c03..693f8d3863fd7 100644 --- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx @@ -138,4 +138,31 @@ The base stream is not writeable. - + + The argument must be a value between {0} and {1}. + + + The WebSocket received a continuation frame with Per-Message Compressed flag set. + + + The WebSocket received compressed frame when compression is not enabled. + + + The underlying compression routine could not be loaded correctly. + + + The stream state of the underlying compression routine is inconsistent. + + + The underlying compression routine could not reserve sufficient memory. + + + The underlying compression routine returned an unexpected error code {0}. + + + The message was compressed using an unsupported compression method. + + + The compression options for a continuation cannot be different than the options used to send the first fragment of the message. + + \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index d65e6c55737af..215cf6b4a9164 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -1,26 +1,48 @@ True - $(NetCoreAppCurrent) + $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser enable + + + + + + + + + + + + + + + + + diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs new file mode 100644 index 0000000000000..e7f1807284243 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -0,0 +1,235 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using static System.IO.Compression.ZLibNative; + +namespace System.Net.WebSockets.Compression +{ + /// + /// Provides a wrapper around the ZLib compression API. + /// + internal sealed class WebSocketDeflater : IDisposable + { + private readonly int _windowBits; + private ZLibStreamHandle? _stream; + private readonly bool _persisted; + + private byte[]? _buffer; + + internal WebSocketDeflater(int windowBits, bool persisted) + { + _windowBits = -windowBits; // Negative for raw deflate + _persisted = persisted; + } + + public void Dispose() + { + if (_stream is not null) + { + _stream.Dispose(); + _stream = null; + } + } + + public void ReleaseBuffer() + { + if (_buffer is not null) + { + ArrayPool.Shared.Return(_buffer); + _buffer = null; + } + } + + public ReadOnlySpan Deflate(ReadOnlySpan payload, bool endOfMessage) + { + Debug.Assert(_buffer is null, "Invalid state, ReleaseBuffer not called."); + + // Do not try to rent more than 1MB initially, because it will actually allocate + // instead of renting. Be optimistic that what we're sending is actually going to fit. + const int MaxInitialBufferLength = 1024 * 1024; + + // For small payloads there might actually be overhead in the compression and the resulting + // output might be larger than the payload. This is why we rent at least 4KB initially. + const int MinInitialBufferLength = 4 * 1024; + + _buffer = ArrayPool.Shared.Rent(Math.Clamp(payload.Length, MinInitialBufferLength, MaxInitialBufferLength)); + int position = 0; + + while (true) + { + DeflatePrivate(payload, _buffer.AsSpan(position), endOfMessage, + out int consumed, out int written, out bool needsMoreOutput); + position += written; + + if (!needsMoreOutput) + { + Debug.Assert(consumed == payload.Length); + break; + } + + payload = payload.Slice(consumed); + + // Rent a 30% bigger buffer + byte[] newBuffer = ArrayPool.Shared.Rent((int)(_buffer.Length * 1.3)); + _buffer.AsSpan(0, position).CopyTo(newBuffer); + ArrayPool.Shared.Return(_buffer); + _buffer = newBuffer; + } + + return new ReadOnlySpan(_buffer, 0, position); + } + + private void DeflatePrivate(ReadOnlySpan payload, Span output, bool endOfMessage, + out int consumed, out int written, out bool needsMoreOutput) + { + _stream ??= CreateDeflater(); + + if (payload.Length == 0) + { + consumed = 0; + written = 0; + } + else + { + UnsafeDeflate(payload, output, out consumed, out written, out needsMoreOutput); + + if (needsMoreOutput) + { + Debug.Assert(written == output.Length); + return; + } + } + + written += UnsafeFlush(output.Slice(written), out needsMoreOutput); + + if (needsMoreOutput) + { + return; + } + Debug.Assert(output.Slice(written - WebSocketInflater.FlushMarkerLength, WebSocketInflater.FlushMarkerLength) + .EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker."); + + if (endOfMessage) + { + // As per RFC we need to remove the flush markers + written -= WebSocketInflater.FlushMarkerLength; + } + + if (endOfMessage && !_persisted) + { + _stream.Dispose(); + _stream = null; + } + } + + private unsafe void UnsafeDeflate(ReadOnlySpan input, Span output, out int consumed, out int written, out bool needsMoreBuffer) + { + Debug.Assert(_stream is not null); + + fixed (byte* fixedInput = input) + fixed (byte* fixedOutput = output) + { + _stream.NextIn = (IntPtr)fixedInput; + _stream.AvailIn = (uint)input.Length; + + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; + + // The flush is set to Z_NO_FLUSH, which allows deflate to decide + // how much data to accumulate before producing output, + // in order to maximize compression. + var errorCode = Deflate(_stream, FlushCode.NoFlush); + + consumed = input.Length - (int)_stream.AvailIn; + written = output.Length - (int)_stream.AvailOut; + + needsMoreBuffer = errorCode == ErrorCode.BufError || _stream.AvailIn > 0; + } + } + + private unsafe int UnsafeFlush(Span output, out bool needsMoreBuffer) + { + Debug.Assert(_stream is not null); + Debug.Assert(_stream.AvailIn == 0); + + fixed (byte* fixedOutput = output) + { + _stream.NextIn = IntPtr.Zero; + _stream.AvailIn = 0; + + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; + + // We need to use Z_BLOCK_FLUSH to instruct the zlib to flush all outstanding + // data but also not to emit a deflate block boundary. After we know that there is no + // more data, we can safely proceed to instruct the library to emit the boundary markers. + ErrorCode errorCode = Deflate(_stream, FlushCode.Block); + Debug.Assert(errorCode is ErrorCode.Ok or ErrorCode.BufError); + + // We need at least 6 bytes to guarantee that we can emit a deflate block boundary. + needsMoreBuffer = _stream.AvailOut < 6; + + if (!needsMoreBuffer) + { + // The flush is set to Z_SYNC_FLUSH, all pending output is flushed + // to the output buffer and the output is aligned on a byte boundary, + // so that the decompressor can get all input data available so far. + // This completes the current deflate block and follows it with an empty + // stored block that is three bits plus filler bits to the next byte, + // followed by four bytes (00 00 ff ff). + errorCode = Deflate(_stream, FlushCode.SyncFlush); + Debug.Assert(errorCode == ErrorCode.Ok); + } + + return output.Length - (int)_stream.AvailOut; + } + } + + private static ErrorCode Deflate(ZLibStreamHandle stream, FlushCode flushCode) + { + ErrorCode errorCode = stream.Deflate(flushCode); + + if (errorCode is ErrorCode.Ok or ErrorCode.StreamEnd or ErrorCode.BufError) + { + return errorCode; + } + + string message = errorCode == ErrorCode.StreamError + ? SR.ZLibErrorInconsistentStream + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); + } + + private ZLibStreamHandle CreateDeflater() + { + ZLibStreamHandle stream; + ErrorCode errorCode; + try + { + errorCode = CreateZLibStreamForDeflate(out stream, + level: CompressionLevel.DefaultCompression, + windowBits: _windowBits, + memLevel: Deflate_DefaultMemLevel, + strategy: CompressionStrategy.DefaultStrategy); + } + catch (Exception cause) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + + if (errorCode == ErrorCode.Ok) + { + return stream; + } + + stream.Dispose(); + + string message = errorCode == ErrorCode.MemError + ? SR.ZLibErrorNotEnoughMemory + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs new file mode 100644 index 0000000000000..6ade12d539a44 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -0,0 +1,285 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using static System.IO.Compression.ZLibNative; + +namespace System.Net.WebSockets.Compression +{ + /// + /// Provides a wrapper around the ZLib decompression API. + /// + internal sealed class WebSocketInflater : IDisposable + { + internal const int FlushMarkerLength = 4; + internal static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; + + private readonly int _windowBits; + private ZLibStreamHandle? _stream; + private readonly bool _persisted; + + /// + /// There is no way of knowing, when decoding data, if the underlying inflater + /// has flushed all outstanding data to consumer other than to provide a buffer + /// and see whether any bytes are written. There are cases when the consumers + /// provide a buffer exactly the size of the uncompressed data and in this case + /// to avoid requiring another read we will use this field. + /// + private byte? _remainingByte; + + /// + /// The last added bytes to the inflater were part of the final + /// payload for the message being sent. + /// + private bool _endOfMessage; + + private byte[]? _buffer; + + /// + /// The position for the next unconsumed byte in the inflate buffer. + /// + private int _position; + + /// + /// How many unconsumed bytes are left in the inflate buffer. + /// + private int _available; + + internal WebSocketInflater(int windowBits, bool persisted) + { + _windowBits = -windowBits; // Negative for raw deflate + _persisted = persisted; + } + + public Memory Memory => _buffer.AsMemory(_position + _available); + + public Span Span => _buffer.AsSpan(_position + _available); + + public void Dispose() + { + if (_stream is not null) + { + _stream.Dispose(); + _stream = null; + } + ReleaseBuffer(); + } + + /// + /// Initializes the inflater by allocating a buffer so the websocket can receive directly onto it. + /// + /// the length of the message payload + /// the length of the buffer where the payload will be inflated + public void Prepare(long payloadLength, int userBufferLength) + { + if (_buffer is not null) + { + Debug.Assert(_available > 0); + + _buffer.AsSpan(_position, _available).CopyTo(_buffer); + _position = 0; + } + else + { + // Rent a buffer as close to the size of the user buffer as possible, + // but not try to rent anything above 1MB because the array pool will allocate. + // If the payload is smaller than the user buffer, rent only as much as we need. + _buffer = ArrayPool.Shared.Rent(Math.Min(userBufferLength, (int)Math.Min(payloadLength, 1024 * 1024))); + } + } + + public void AddBytes(int totalBytesReceived, bool endOfMessage) + { + Debug.Assert(totalBytesReceived == 0 || _buffer is not null, "Prepare must be called."); + + _available += totalBytesReceived; + _endOfMessage = endOfMessage; + + if (endOfMessage) + { + if (_buffer is null) + { + Debug.Assert(_available == 0); + + _buffer = ArrayPool.Shared.Rent(FlushMarkerLength); + _available = FlushMarkerLength; + FlushMarker.CopyTo(_buffer); + } + else + { + if (_buffer.Length < _available + FlushMarkerLength) + { + byte[] newBuffer = ArrayPool.Shared.Rent(_available + FlushMarkerLength); + _buffer.AsSpan(0, _available).CopyTo(newBuffer); + ArrayPool.Shared.Return(_buffer); + + _buffer = newBuffer; + } + + FlushMarker.CopyTo(_buffer.AsSpan(_available)); + _available += FlushMarkerLength; + } + } + } + + /// + /// Inflates the last receive payload into the provided buffer. + /// + public unsafe bool Inflate(Span output, out int written) + { + _stream ??= CreateInflater(); + + if (_available > 0 && output.Length > 0) + { + int consumed; + + fixed (byte* bufferPtr = _buffer) + { + _stream.NextIn = (IntPtr)(bufferPtr + _position); + _stream.AvailIn = (uint)_available; + + written = Inflate(_stream, output, FlushCode.NoFlush); + consumed = _available - (int)_stream.AvailIn; + } + + _position += consumed; + _available -= consumed; + } + else + { + written = 0; + } + + if (_available == 0) + { + ReleaseBuffer(); + return _endOfMessage ? Finish(output, ref written) : true; + } + + return false; + } + + /// + /// Finishes the decoding by flushing any outstanding data to the output. + /// + /// true if the flush completed, false to indicate that there is more outstanding data. + private unsafe bool Finish(Span output, ref int written) + { + Debug.Assert(_stream is not null && _stream.AvailIn == 0); + Debug.Assert(_available == 0); + + if (_remainingByte is not null) + { + if (output.Length == written) + { + return false; + } + output[written] = _remainingByte.GetValueOrDefault(); + _remainingByte = null; + written += 1; + } + + // If we have more space in the output, try to inflate + if (output.Length > written) + { + written += Inflate(_stream, output[written..], FlushCode.SyncFlush); + } + + // After inflate, if we have more space in the output then it means that we + // have finished. Otherwise we need to manually check for more data. + if (written < output.Length || IsFinished(_stream, out _remainingByte)) + { + if (!_persisted) + { + _stream.Dispose(); + _stream = null; + } + return true; + } + + return false; + } + + private void ReleaseBuffer() + { + if (_buffer is not null) + { + ArrayPool.Shared.Return(_buffer); + _buffer = null; + _available = 0; + _position = 0; + } + } + + private static unsafe bool IsFinished(ZLibStreamHandle stream, out byte? remainingByte) + { + // There is no other way to make sure that we've consumed all data + // but to try to inflate again with at least one byte of output buffer. + byte b; + if (Inflate(stream, new Span(&b, 1), FlushCode.SyncFlush) == 0) + { + remainingByte = null; + return true; + } + + remainingByte = b; + return false; + } + + private static unsafe int Inflate(ZLibStreamHandle stream, Span destination, FlushCode flushCode) + { + Debug.Assert(destination.Length > 0); + ErrorCode errorCode; + + fixed (byte* bufPtr = destination) + { + stream.NextOut = (IntPtr)bufPtr; + stream.AvailOut = (uint)destination.Length; + + errorCode = stream.Inflate(flushCode); + + if (errorCode is ErrorCode.Ok or ErrorCode.StreamEnd or ErrorCode.BufError) + { + return destination.Length - (int)stream.AvailOut; + } + } + + string message = errorCode switch + { + ErrorCode.MemError => SR.ZLibErrorNotEnoughMemory, + ErrorCode.DataError => SR.ZLibUnsupportedCompression, + ErrorCode.StreamError => SR.ZLibErrorInconsistentStream, + _ => string.Format(SR.ZLibErrorUnexpected, (int)errorCode) + }; + throw new WebSocketException(message); + } + + private ZLibStreamHandle CreateInflater() + { + ZLibStreamHandle stream; + ErrorCode errorCode; + + try + { + errorCode = CreateZLibStreamForInflate(out stream, _windowBits); + } + catch (Exception exception) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); + } + + if (errorCode == ErrorCode.Ok) + { + return stream; + } + + stream.Dispose(); + + string message = errorCode == ErrorCode.MemError + ? SR.ZLibErrorNotEnoughMemory + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 9a0142f9c73b3..971c2ceff82be 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Diagnostics; using System.IO; +using System.Net.WebSockets.Compression; using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -25,18 +26,6 @@ namespace System.Net.WebSockets /// internal sealed partial class ManagedWebSocket : WebSocket { - /// Creates a from a connected to a websocket endpoint. - /// The connected Stream. - /// true if this is the server-side of the connection; false if this is the client-side of the connection. - /// The agreed upon subprotocol for the connection. - /// The interval to use for keep-alive pings. - /// The created instance. - public static ManagedWebSocket CreateFromConnectedStream( - Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval) - { - return new ManagedWebSocket(stream, isServer, subprotocol, keepAliveInterval); - } - /// Thread-safe random number generator used to generate masks for each send. private static readonly RandomNumberGenerator s_random = RandomNumberGenerator.Create(); /// Encoding for the payload of text messages: UTF8 encoding that throws if invalid bytes are discovered, per the RFC. @@ -113,7 +102,7 @@ public static ManagedWebSocket CreateFromConnectedStream( /// remaining to be received for that header. As a result, between fragments, the payload /// length in this header should be 0. /// - private MessageHeader _lastReceiveHeader = new MessageHeader { Opcode = MessageOpcode.Text, Fin = true }; + private MessageHeader _lastReceiveHeader = new MessageHeader { Opcode = MessageOpcode.Text, Fin = true, Processed = true }; /// The offset of the next available byte in the _receiveBuffer. private int _receiveBufferOffset; /// The number of bytes available in the _receiveBuffer. @@ -137,6 +126,10 @@ public static ManagedWebSocket CreateFromConnectedStream( /// private bool _lastSendWasFragment; /// + /// Whether the last SendAsync had flag set. + /// + private bool _lastSendHadDisableCompression; + /// /// The task returned from the last ReceiveAsync(ArraySegment, ...) operation to not complete synchronously. /// If this is not null and not completed when a subsequent ReceiveAsync is issued, an exception occurs. /// @@ -151,12 +144,15 @@ public static ManagedWebSocket CreateFromConnectedStream( /// private object ReceiveAsyncLock => _utf8TextState; // some object, as we're simply lock'ing on it + private readonly WebSocketInflater? _inflater; + private readonly WebSocketDeflater? _deflater; + /// Initializes the websocket. /// The connected Stream. /// true if this is the server-side of the connection; false if this is the client-side of the connection. /// The agreed upon subprotocol for the connection. /// The interval to use for keep-alive pings. - private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval) + internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval) { Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null"); Debug.Assert(ReceiveAsyncLock != null, $"Expected {nameof(ReceiveAsyncLock)} to be non-null"); @@ -212,6 +208,29 @@ private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Time } } + /// Initializes the websocket. + /// The connected Stream. + /// The options with which the websocket must be created. + internal ManagedWebSocket(Stream stream, WebSocketCreationOptions options) + : this(stream, options.IsServer, options.SubProtocol, options.KeepAliveInterval) + { + var deflateOptions = options.DangerousDeflateOptions; + + if (deflateOptions is not null) + { + if (options.IsServer) + { + _inflater = new WebSocketInflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover); + _deflater = new WebSocketDeflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover); + } + else + { + _inflater = new WebSocketInflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover); + _deflater = new WebSocketDeflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover); + } + } + } + public override void Dispose() { lock (StateUpdateLock) @@ -227,7 +246,10 @@ private void DisposeCore() { _disposed = true; _keepAliveTimer?.Dispose(); - _stream?.Dispose(); + _stream.Dispose(); + _inflater?.Dispose(); + _deflater?.Dispose(); + if (_state < WebSocketState.Aborted) { _state = WebSocketState.Closed; @@ -255,10 +277,10 @@ public override Task SendAsync(ArraySegment buffer, WebSocketMessageType m WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); - return SendPrivateAsync(buffer, messageType, endOfMessage, cancellationToken).AsTask(); + return SendPrivateAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken).AsTask(); } - private ValueTask SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + private ValueTask SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) { if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary) { @@ -277,13 +299,27 @@ private ValueTask SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessage return new ValueTask(Task.FromException(exc)); } - MessageOpcode opcode = - _lastSendWasFragment ? MessageOpcode.Continuation : - messageType == WebSocketMessageType.Binary ? MessageOpcode.Binary : - MessageOpcode.Text; + bool endOfMessage = messageFlags.HasFlag(WebSocketMessageFlags.EndOfMessage); + bool disableCompression = messageFlags.HasFlag(WebSocketMessageFlags.DisableCompression); + MessageOpcode opcode; - ValueTask t = SendFrameAsync(opcode, endOfMessage, buffer, cancellationToken); + if (_lastSendWasFragment) + { + if (_lastSendHadDisableCompression != disableCompression) + { + throw new ArgumentException(SR.net_WebSockets_Argument_MessageFlagsHasDifferentCompressionOptions, nameof(messageFlags)); + } + opcode = MessageOpcode.Continuation; + } + else + { + opcode = messageType == WebSocketMessageType.Binary ? MessageOpcode.Binary : MessageOpcode.Text; + } + + ValueTask t = SendFrameAsync(opcode, endOfMessage, disableCompression, buffer, cancellationToken); _lastSendWasFragment = !endOfMessage; + _lastSendHadDisableCompression = disableCompression; + return t; } @@ -299,7 +335,7 @@ public override Task ReceiveAsync(ArraySegment buf lock (ReceiveAsyncLock) // synchronize with receives in CloseAsync { ThrowIfOperationInProgress(_lastReceiveAsync.IsCompleted); - Task t = ReceiveAsyncPrivate(buffer, cancellationToken).AsTask(); + Task t = ReceiveAsyncPrivate(buffer, cancellationToken).AsTask(); _lastReceiveAsync = t; return t; } @@ -357,7 +393,12 @@ public override void Abort() public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { - return SendPrivateAsync(buffer, messageType, endOfMessage, cancellationToken); + return SendPrivateAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken); + } + + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) + { + return SendPrivateAsync(buffer, messageType, messageFlags, cancellationToken); } public override ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) @@ -371,7 +412,7 @@ public override ValueTask ReceiveAsync(Memory { ThrowIfOperationInProgress(_lastReceiveAsync.IsCompleted); - ValueTask receiveValueTask = ReceiveAsyncPrivate(buffer, cancellationToken); + ValueTask receiveValueTask = ReceiveAsyncPrivate(buffer, cancellationToken); if (receiveValueTask.IsCompletedSuccessfully) { _lastReceiveAsync = receiveValueTask.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask; @@ -400,7 +441,7 @@ private Task ValidateAndReceiveAsync(Task receiveTask, byte[] buffer, Cancellati !(receiveTask is Task wsrr && wsrr.Result.MessageType == WebSocketMessageType.Close) && !(receiveTask is Task vwsrr && vwsrr.Result.MessageType == WebSocketMessageType.Close))) { - ValueTask vt = ReceiveAsyncPrivate(buffer, cancellationToken); + ValueTask vt = ReceiveAsyncPrivate(buffer, cancellationToken); receiveTask = vt.IsCompletedSuccessfully ? (vt.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask) : vt.AsTask(); @@ -409,19 +450,13 @@ private Task ValidateAndReceiveAsync(Task receiveTask, byte[] buffer, Cancellati return receiveTask; } - /// implementation for . - private readonly struct ValueWebSocketReceiveResultGetter : IWebSocketReceiveResultGetter - { - public ValueWebSocketReceiveResult GetResult(int count, WebSocketMessageType messageType, bool endOfMessage, WebSocketCloseStatus? closeStatus, string? closeDescription) => - new ValueWebSocketReceiveResult(count, messageType, endOfMessage); // closeStatus/closeDescription are ignored - } - /// Sends a websocket frame to the network. /// The opcode for the message. /// The value of the FIN bit for the message. - /// The buffer containing the payload data fro the message. + /// Disables compression for the message. + /// The buffer containing the payload data from the message. /// The CancellationToken to use to cancel the websocket. - private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) + private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) { // If a cancelable cancellation token was provided, that would require registering with it, which means more state we have to // pass around (the CancellationTokenRegistration), so if it is cancelable, just immediately go to the fallback path. @@ -430,15 +465,16 @@ private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, ReadOn #pragma warning disable CA1416 // Validate platform compatibility, will not wait because timeout equals 0 return cancellationToken.CanBeCanceled || !_sendFrameAsyncLock.Wait(0, default) ? #pragma warning restore CA1416 - SendFrameFallbackAsync(opcode, endOfMessage, payloadBuffer, cancellationToken) : - SendFrameLockAcquiredNonCancelableAsync(opcode, endOfMessage, payloadBuffer); + SendFrameFallbackAsync(opcode, endOfMessage, disableCompression, payloadBuffer, cancellationToken) : + SendFrameLockAcquiredNonCancelableAsync(opcode, endOfMessage, disableCompression, payloadBuffer); } /// Sends a websocket frame to the network. The caller must hold the sending lock. /// The opcode for the message. /// The value of the FIN bit for the message. + /// Disables compression for the message. /// The buffer containing the payload data fro the message. - private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer) + private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer) { Debug.Assert(_sendFrameAsyncLock.CurrentCount == 0, "Caller should hold the _sendFrameAsyncLock"); @@ -449,7 +485,7 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, try { // Write the payload synchronously to the buffer, then write that buffer out to the network. - int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span); + int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, disableCompression, payloadBuffer.Span); writeTask = _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes)); // If the operation happens to complete synchronously (or, more specifically, by @@ -503,12 +539,12 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask) } } - private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) + private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) { await _sendFrameAsyncLock.WaitAsync(cancellationToken).ConfigureAwait(false); try { - int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span); + int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, disableCompression, payloadBuffer.Span); using (cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this)) { await _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes), cancellationToken).ConfigureAwait(false); @@ -528,10 +564,16 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM } /// Writes a frame into the send buffer, which can then be sent over the network. - private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, ReadOnlySpan payloadBuffer) + private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlySpan payloadBuffer) { - // Ensure we have a _sendBuffer. - AllocateSendBuffer(payloadBuffer.Length + MaxMessageHeaderLength); + if (_deflater is not null && !disableCompression) + { + payloadBuffer = _deflater.Deflate(payloadBuffer, endOfMessage); + } + int payloadLength = payloadBuffer.Length; + + // Ensure we have a _sendBuffer + AllocateSendBuffer(payloadLength + MaxMessageHeaderLength); Debug.Assert(_sendBuffer != null); // Write the message header data to the buffer. @@ -541,31 +583,34 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Read { // The server doesn't send a mask, so the mask offset returned by WriteHeader // is actually the end of the header. - headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false); + headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false, compressed: _deflater is not null && !disableCompression); } else { // We need to know where the mask starts so that we can use the mask to manipulate the payload data, // and we need to know the total length for sending it on the wire. - maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true); + maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true, compressed: _deflater is not null && !disableCompression); headerLength = maskOffset.GetValueOrDefault() + MaskLength; } // Write the payload if (payloadBuffer.Length > 0) { - payloadBuffer.CopyTo(new Span(_sendBuffer, headerLength, payloadBuffer.Length)); + payloadBuffer.CopyTo(new Span(_sendBuffer, headerLength, payloadLength)); + + // Release the deflater buffer if any, we're not going to need the payloadBuffer anymore. + _deflater?.ReleaseBuffer(); // If we added a mask to the header, XOR the payload with the mask. We do the manipulation in the send buffer so as to avoid // changing the data in the caller-supplied payload buffer. if (maskOffset.HasValue) { - ApplyMask(new Span(_sendBuffer, headerLength, payloadBuffer.Length), _sendBuffer, maskOffset.Value, 0); + ApplyMask(new Span(_sendBuffer, headerLength, payloadLength), _sendBuffer, maskOffset.Value, 0); } } // Return the number of bytes in the send buffer - return headerLength + payloadBuffer.Length; + return headerLength + payloadLength; } private void SendKeepAliveFrameAsync() @@ -578,7 +623,7 @@ private void SendKeepAliveFrameAsync() // This exists purely to keep the connection alive; don't wait for the result, and ignore any failures. // The call will handle releasing the lock. We send a pong rather than ping, since it's allowed by // the RFC as a unidirectional heartbeat and we're not interested in waiting for a response. - ValueTask t = SendFrameLockAcquiredNonCancelableAsync(MessageOpcode.Pong, true, ReadOnlyMemory.Empty); + ValueTask t = SendFrameLockAcquiredNonCancelableAsync(MessageOpcode.Pong, endOfMessage: true, disableCompression: true, ReadOnlyMemory.Empty); if (t.IsCompletedSuccessfully) { t.GetAwaiter().GetResult(); @@ -599,7 +644,7 @@ private void SendKeepAliveFrameAsync() } } - private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask) + private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask, bool compressed) { // Client header format: // 1 bit - FIN - 1 if this is the final fragment in the message (it could be the only fragment), otherwise 0 @@ -629,6 +674,11 @@ private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnly { sendBuffer[0] |= 0x80; // 1 bit for FIN } + if (compressed && opcode != MessageOpcode.Continuation) + { + // Per-Message Deflate flag needs to be set only in the first frame + sendBuffer[0] |= 0b_0100_0000; + } // Store the payload length. int maskOffset; @@ -680,13 +730,8 @@ private static void WriteRandomMask(byte[] buffer, int offset) => /// /// The buffer into which payload data should be written. /// The CancellationToken used to cancel the websocket. - /// Used to get the result. Allows the same method to be used with both WebSocketReceiveResult and ValueWebSocketReceiveResult. /// Information about the received message. - private async ValueTask ReceiveAsyncPrivate( - Memory payloadBuffer, - CancellationToken cancellationToken, - TWebSocketReceiveResultGetter resultGetter = default) - where TWebSocketReceiveResultGetter : struct, IWebSocketReceiveResultGetter // constrained to avoid boxing and enable inlining + private async ValueTask ReceiveAsyncPrivate(Memory payloadBuffer, CancellationToken cancellationToken) { // This is a long method. While splitting it up into pieces would arguably help with readability, doing so would // also result in more allocations, as each async method that yields ends up with multiple allocations. The impact @@ -707,7 +752,7 @@ private async ValueTask ReceiveAsyncPrivate ReceiveAsyncPrivate ReceiveAsyncPrivate(0, WebSocketMessageType.Close, true); } // If this is a continuation, replace the opcode with the one of the message it's continuing if (header.Opcode == MessageOpcode.Continuation) { header.Opcode = _lastReceiveHeader.Opcode; + header.Compressed = _lastReceiveHeader.Compressed; } // The message should now be a binary or text message. Handle it by reading the payload and returning the contents. Debug.Assert(header.Opcode == MessageOpcode.Binary || header.Opcode == MessageOpcode.Text, $"Unexpected opcode {header.Opcode}"); // If there's no data to read, return an appropriate result. - if (header.PayloadLength == 0 || payloadBuffer.Length == 0) + if (header.Processed || payloadBuffer.Length == 0) { _lastReceiveHeader = header; - return resultGetter.GetResult( - 0, - header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, - header.Fin && header.PayloadLength == 0, - null, null); + return GetReceiveResult( + count: 0, + messageType: header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, + endOfMessage: header.EndOfMessage); } // Otherwise, read as much of the payload as we can efficiently, and update the header to reflect how much data @@ -779,56 +832,86 @@ private async ValueTask ReceiveAsyncPrivate 0) - { - int receiveBufferBytesToCopy = Math.Min(payloadBuffer.Length, (int)Math.Min(header.PayloadLength, _receiveBufferCount)); - Debug.Assert(receiveBufferBytesToCopy > 0); - _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(payloadBuffer.Span); - ConsumeFromBuffer(receiveBufferBytesToCopy); - totalBytesReceived += receiveBufferBytesToCopy; - Debug.Assert( - _receiveBufferCount == 0 || - totalBytesReceived == payloadBuffer.Length || - totalBytesReceived == header.PayloadLength); - } - // Then read directly into the payload buffer until we've hit a limit. - while (totalBytesReceived < payloadBuffer.Length && - totalBytesReceived < header.PayloadLength) + // Only start a new receive if we haven't received the entire frame. + if (header.PayloadLength > 0) { - int numBytesRead = await _stream.ReadAsync(payloadBuffer.Slice( - totalBytesReceived, - (int)Math.Min(payloadBuffer.Length, header.PayloadLength) - totalBytesReceived), cancellationToken).ConfigureAwait(false); - if (numBytesRead <= 0) + if (header.Compressed) + { + Debug.Assert(_inflater is not null); + _inflater.Prepare(header.PayloadLength, payloadBuffer.Length); + } + + // Read directly into the appropriate buffer until we've hit a limit. + int limit = (int)Math.Min(header.Compressed ? _inflater!.Span.Length : payloadBuffer.Length, header.PayloadLength); + + if (_receiveBufferCount > 0) + { + int receiveBufferBytesToCopy = Math.Min(limit, _receiveBufferCount); + Debug.Assert(receiveBufferBytesToCopy > 0); + + _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo( + header.Compressed ? _inflater!.Span : payloadBuffer.Span); + ConsumeFromBuffer(receiveBufferBytesToCopy); + totalBytesReceived += receiveBufferBytesToCopy; + } + + while (totalBytesReceived < limit) { - ThrowIfEOFUnexpected(throwOnPrematureClosure: true); - break; + int numBytesRead = await _stream.ReadAsync(header.Compressed ? + _inflater!.Memory.Slice(totalBytesReceived, limit - totalBytesReceived) : + payloadBuffer.Slice(totalBytesReceived, limit - totalBytesReceived), + cancellationToken).ConfigureAwait(false); + if (numBytesRead <= 0) + { + ThrowIfEOFUnexpected(throwOnPrematureClosure: true); + break; + } + totalBytesReceived += numBytesRead; + } + + if (_isServer) + { + _receivedMaskOffsetOffset = ApplyMask(header.Compressed ? + _inflater!.Span.Slice(0, totalBytesReceived) : + payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset); + } + + header.PayloadLength -= totalBytesReceived; + + if (header.Compressed) + { + _inflater!.AddBytes(totalBytesReceived, endOfMessage: header.Fin && header.PayloadLength == 0); } - totalBytesReceived += numBytesRead; } - if (_isServer) + if (header.Compressed) + { + // In case of compression totalBytesReceived should actually represent how much we've + // inflated, rather than how much we've read from the stream. + header.Processed = _inflater!.Inflate(payloadBuffer.Span, out totalBytesReceived) && header.PayloadLength == 0; + } + else { - _receivedMaskOffsetOffset = ApplyMask(payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset); + // Without compression the frame is processed as soon as we've received everything + header.Processed = header.PayloadLength == 0; } - header.PayloadLength -= totalBytesReceived; // If this a text message, validate that it contains valid UTF8. if (header.Opcode == MessageOpcode.Text && - !TryValidateUtf8(payloadBuffer.Span.Slice(0, totalBytesReceived), header.Fin && header.PayloadLength == 0, _utf8TextState)) + !TryValidateUtf8(payloadBuffer.Span.Slice(0, totalBytesReceived), header.EndOfMessage, _utf8TextState)) { await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false); } _lastReceiveHeader = header; - return resultGetter.GetResult( + return GetReceiveResult( totalBytesReceived, header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, - header.Fin && header.PayloadLength == 0, - null, null); + header.EndOfMessage); } } - catch (Exception exc) when (!(exc is OperationCanceledException)) + catch (Exception exc) when (exc is not OperationCanceledException) { if (_state == WebSocketState.Aborted) { @@ -849,6 +932,23 @@ private async ValueTask ReceiveAsyncPrivate + /// Returns either or . + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private TResult GetReceiveResult(int count, WebSocketMessageType messageType, bool endOfMessage) + { + if (typeof(TResult) == typeof(ValueWebSocketReceiveResult)) + { + // Although it might seem that this will incur boxing of the struct, + // the JIT is smart enough to figure out it is unncessessary and will emit + // bytecode that returns the ValueWebSocketReceiveResult directly. + return (TResult)(object)new ValueWebSocketReceiveResult(count, messageType, endOfMessage); + } + + return (TResult)(object)new WebSocketReceiveResult(count, messageType, endOfMessage, _closeStatus, _closeStatusDescription); + } + /// Processes a received close message. /// The message header. /// The CancellationToken used to cancel the websocket operation. @@ -967,6 +1067,7 @@ private async ValueTask HandleReceivedPingPongAsync(MessageHeader header, Cancel await SendFrameAsync( MessageOpcode.Pong, endOfMessage: true, + disableCompression: true, _receiveBuffer.Slice(_receiveBufferOffset, (int)header.PayloadLength), cancellationToken).ConfigureAwait(false); } @@ -1051,8 +1152,9 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( Span receiveBufferSpan = _receiveBuffer.Span; header.Fin = (receiveBufferSpan[_receiveBufferOffset] & 0x80) != 0; - bool reservedSet = (receiveBufferSpan[_receiveBufferOffset] & 0x70) != 0; + bool reservedSet = (receiveBufferSpan[_receiveBufferOffset] & 0b_0011_0000) != 0; header.Opcode = (MessageOpcode)(receiveBufferSpan[_receiveBufferOffset] & 0xF); + header.Compressed = (receiveBufferSpan[_receiveBufferOffset] & 0b_0100_0000) != 0; bool masked = (receiveBufferSpan[_receiveBufferOffset + 1] & 0x80) != 0; header.PayloadLength = receiveBufferSpan[_receiveBufferOffset + 1] & 0x7F; @@ -1083,6 +1185,12 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( return SR.net_Websockets_ReservedBitsSet; } + if (header.Compressed && _inflater is null) + { + resultHeader = default; + return SR.net_Websockets_PerMessageCompressedFlagWhenNotEnabled; + } + if (masked) { if (!_isServer) @@ -1106,6 +1214,16 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( resultHeader = default; return SR.net_Websockets_ContinuationFromFinalFrame; } + if (header.Compressed) + { + // Must not mark continuations as compressed + resultHeader = default; + return SR.net_Websockets_PerMessageCompressedFlagInContinuation; + } + + // Set the compressed flag from the previous header so the receive procedure can use it + // directly without needing to check the previous header in case of continuations. + header.Compressed = _lastReceiveHeader.Compressed; break; case MessageOpcode.Binary: @@ -1137,6 +1255,7 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( // Return the read header resultHeader = header; + resultHeader.Processed = header.PayloadLength == 0; return null; } @@ -1248,7 +1367,7 @@ private async ValueTask SendCloseFrameAsync(WebSocketCloseStatus closeStatus, st buffer[0] = (byte)(closeStatusValue >> 8); buffer[1] = (byte)(closeStatusValue & 0xFF); - await SendFrameAsync(MessageOpcode.Close, true, new Memory(buffer, 0, count), cancellationToken).ConfigureAwait(false); + await SendFrameAsync(MessageOpcode.Close, endOfMessage: true, disableCompression: true, new Memory(buffer, 0, count), cancellationToken).ConfigureAwait(false); } finally { @@ -1580,24 +1699,18 @@ private struct MessageHeader internal MessageOpcode Opcode; internal bool Fin; internal long PayloadLength; + internal bool Compressed; internal int Mask; - } - /// - /// Interface used by to enable it to return - /// different result types in an efficient manner. - /// - /// The type of the result - private interface IWebSocketReceiveResultGetter - { - TResult GetResult(int count, WebSocketMessageType messageType, bool endOfMessage, WebSocketCloseStatus? closeStatus, string? closeDescription); - } + /// + /// Returns if frame has been received and processed. + /// + internal bool Processed { get; set; } - /// implementation for . - private readonly struct WebSocketReceiveResultGetter : IWebSocketReceiveResultGetter - { - public WebSocketReceiveResult GetResult(int count, WebSocketMessageType messageType, bool endOfMessage, WebSocketCloseStatus? closeStatus, string? closeDescription) => - new WebSocketReceiveResult(count, messageType, endOfMessage, closeStatus, closeDescription); + /// + /// Returns if message has been received and processed. + /// + internal bool EndOfMessage => Fin && Processed && PayloadLength == 0; } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs index 3bd6835a16f1d..044c7b95536be 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs @@ -58,6 +58,11 @@ public virtual ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessage new ValueTask(SendAsync(arraySegment, messageType, endOfMessage, cancellationToken)) : SendWithArrayPoolAsync(buffer, messageType, endOfMessage, cancellationToken); + public virtual ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken = default) + { + return SendAsync(buffer, messageType, messageFlags.HasFlag(WebSocketMessageFlags.EndOfMessage), cancellationToken); + } + private async ValueTask SendWithArrayPoolAsync( ReadOnlyMemory buffer, WebSocketMessageType messageType, @@ -157,7 +162,24 @@ public static WebSocket CreateFromStream(Stream stream, bool isServer, string? s 0)); } - return ManagedWebSocket.CreateFromConnectedStream(stream, isServer, subProtocol, keepAliveInterval); + return new ManagedWebSocket(stream, isServer, subProtocol, keepAliveInterval); + } + + /// Creates a that operates on a representing a web socket connection. + /// The for the connection. + /// The options with which the websocket must be created. + public static WebSocket CreateFromStream(Stream stream, WebSocketCreationOptions options) + { + if (stream is null) + throw new ArgumentNullException(nameof(stream)); + + if (options is null) + throw new ArgumentNullException(nameof(options)); + + if (!stream.CanRead || !stream.CanWrite) + throw new ArgumentException(!stream.CanRead ? SR.NotReadableStream : SR.NotWriteableStream, nameof(stream)); + + return new ManagedWebSocket(stream, options); } [EditorBrowsable(EditorBrowsableState.Never)] @@ -209,8 +231,7 @@ public static WebSocket CreateClientWebSocket(Stream innerStream, // Ignore useZeroMaskingKey. ManagedWebSocket doesn't currently support that debugging option. // Ignore internalBuffer. ManagedWebSocket uses its own small buffer for headers/control messages. - - return ManagedWebSocket.CreateFromConnectedStream(innerStream, false, subProtocol, keepAliveInterval); + return new ManagedWebSocket(innerStream, false, subProtocol, keepAliveInterval); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs new file mode 100644 index 0000000000000..d042583da5444 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; + +namespace System.Net.WebSockets +{ + /// + /// Options that control how a is created. + /// + public sealed class WebSocketCreationOptions + { + private string? _subProtocol; + private TimeSpan _keepAliveInterval; + + /// + /// Defines if this websocket is the server-side of the connection. The default value is false. + /// + public bool IsServer { get; set; } + + /// + /// The agreed upon sub-protocol that was used when creating the connection. + /// + public string? SubProtocol + { + get => _subProtocol; + set + { + if (value is not null) + { + WebSocketValidate.ValidateSubprotocol(value); + } + _subProtocol = value; + } + } + + /// + /// The keep-alive interval to use, or or to disable keep-alives. + /// The default is . + /// + public TimeSpan KeepAliveInterval + { + get => _keepAliveInterval; + set + { + if (value != Timeout.InfiniteTimeSpan && value < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(KeepAliveInterval), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, 0)); + } + _keepAliveInterval = value; + } + } + + /// + /// The agreed upon options for per message deflate. + /// Be aware that enabling compression makes the application subject to CRIME/BREACH type of attacks. + /// It is strongly advised to turn off compression when sending data containing secrets by + /// specifying flag for such messages. + /// + public WebSocketDeflateOptions? DangerousDeflateOptions { get; set; } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs new file mode 100644 index 0000000000000..e497751db288e --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -0,0 +1,71 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.WebSockets +{ + /// + /// Options to enable per-message deflate compression for . + /// + /// + /// Although the WebSocket spec allows window bits from 8 to 15, the current implementation doesn't support 8 bits. + /// + public sealed class WebSocketDeflateOptions + { + private int _clientMaxWindowBits = WebSocketValidate.MaxDeflateWindowBits; + private int _serverMaxWindowBits = WebSocketValidate.MaxDeflateWindowBits; + + /// + /// This parameter indicates the base-2 logarithm for the LZ77 sliding window size used by + /// the client to compress messages and by the server to decompress them. + /// Must be a value between 9 and 15. The default is 15. + /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.2.2 + public int ClientMaxWindowBits + { + get => _clientMaxWindowBits; + set + { + if (value < WebSocketValidate.MinDeflateWindowBits || value > WebSocketValidate.MaxDeflateWindowBits) + { + throw new ArgumentOutOfRangeException(nameof(ClientMaxWindowBits), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, WebSocketValidate.MinDeflateWindowBits, WebSocketValidate.MaxDeflateWindowBits)); + } + _clientMaxWindowBits = value; + } + } + + /// + /// When true the client-side of the connection indicates that it will persist the deflate context accross messages. + /// The default is true. + /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.1.2 + public bool ClientContextTakeover { get; set; } = true; + + /// + /// This parameter indicates the base-2 logarithm for the LZ77 sliding window size used by + /// the server to compress messages and by the client to decompress them. + /// Must be a value between 9 and 15. The default is 15. + /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.2.1 + public int ServerMaxWindowBits + { + get => _serverMaxWindowBits; + set + { + if (value < WebSocketValidate.MinDeflateWindowBits || value > WebSocketValidate.MaxDeflateWindowBits) + { + throw new ArgumentOutOfRangeException(nameof(ServerMaxWindowBits), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, WebSocketValidate.MinDeflateWindowBits, WebSocketValidate.MaxDeflateWindowBits)); + } + _serverMaxWindowBits = value; + } + } + + /// + /// When true the server-side of the connection indicates that it will persist the deflate context accross messages. + /// The default is true. + /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.1.1 + public bool ServerContextTakeover { get; set; } = true; + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketMessageFlags.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketMessageFlags.cs new file mode 100644 index 0000000000000..9ce165d8de843 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketMessageFlags.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.WebSockets +{ + /// + /// Flags for controlling how the should send a message. + /// + [Flags] + public enum WebSocketMessageFlags + { + /// + /// None + /// + None = 0, + + /// + /// Indicates that the data in "buffer" is the last part of a message. + /// + EndOfMessage = 1, + + /// + /// Disables compression for the message if compression has been enabled for the instance. + /// + DisableCompression = 2 + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index 7cf0328df31ca..4e0bc74ebdaec 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -7,6 +7,9 @@ + + + (() => options.ClientMaxWindowBits = 8); + Assert.Throws(() => options.ClientMaxWindowBits = 16); + + options.ClientMaxWindowBits = 14; + Assert.Equal(14, options.ClientMaxWindowBits); + } + + [Fact] + public void ServerMaxWindowBits() + { + WebSocketDeflateOptions options = new(); + Assert.Equal(15, options.ServerMaxWindowBits); + + Assert.Throws(() => options.ServerMaxWindowBits = 8); + Assert.Throws(() => options.ServerMaxWindowBits = 16); + + options.ServerMaxWindowBits = 14; + Assert.Equal(14, options.ServerMaxWindowBits); + } + + [Fact] + public void ContextTakeover() + { + WebSocketDeflateOptions options = new(); + + Assert.True(options.ClientContextTakeover); + Assert.True(options.ServerContextTakeover); + + options.ClientContextTakeover = false; + Assert.False(options.ClientContextTakeover); + + options.ServerContextTakeover = false; + Assert.False(options.ServerContextTakeover); + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs new file mode 100644 index 0000000000000..25efbe94b1d5b --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -0,0 +1,626 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.WebSockets.Tests +{ + public class WebSocketDeflateTests + { + private readonly CancellationTokenSource? _cancellation; + + public WebSocketDeflateTests() + { + if (!Debugger.IsAttached) + { + _cancellation = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + } + } + + public CancellationToken CancellationToken => _cancellation?.Token ?? default; + + public static IEnumerable SupportedWindowBits + { + get + { + for (var i = 9; i <= 15; ++i) + { + yield return new object[] { i }; + } + } + } + + [Fact] + public async Task ReceiveHelloWithContextTakeover() + { + WebSocketTestStream stream = new(); + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + using WebSocket websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DangerousDeflateOptions = new() + }); + + Memory buffer = new byte[5]; + ValueWebSocketReceiveResult result = await websocket.ReceiveAsync(buffer, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.Span)); + + // Because context takeover is set by default if we try to send + // the same message it would take fewer bytes. + stream.Enqueue(0xc1, 0x05, 0xf2, 0x00, 0x11, 0x00, 0x00); + + buffer.Span.Clear(); + result = await websocket.ReceiveAsync(buffer, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.Span)); + } + + [Fact] + public async Task SendHelloWithContextTakeover() + { + WebSocketTestStream stream = new(); + using WebSocket websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = new() + }); + + await websocket.SendAsync(Encoding.UTF8.GetBytes("Hello"), WebSocketMessageType.Text, true, CancellationToken); + Assert.Equal("C107F248CDC9C90700", Convert.ToHexString(stream.NextAvailableBytes)); + + stream.Clear(); + await websocket.SendAsync(Encoding.UTF8.GetBytes("Hello"), WebSocketMessageType.Text, true, CancellationToken); + + // Because context takeover is set by default if we try to send + // the same message it should result in fewer bytes. + Assert.Equal("C105F200110000", Convert.ToHexString(stream.NextAvailableBytes)); + } + + [Fact] + public async Task SendHelloWithDisableCompression() + { + WebSocketTestStream stream = new(); + using WebSocket websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = new() + }); + + byte[] bytes = Encoding.UTF8.GetBytes("Hello"); + WebSocketMessageFlags flags = WebSocketMessageFlags.DisableCompression | WebSocketMessageFlags.EndOfMessage; + await websocket.SendAsync(bytes, WebSocketMessageType.Text, flags, CancellationToken); + + Assert.Equal(bytes.Length + 2, stream.Available); + Assert.True(stream.NextAvailableBytes.EndsWith(bytes)); + } + + [Fact] + public async Task SendHelloWithEmptyFrame() + { + WebSocketTestStream stream = new(); + using WebSocket websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = new() + }); + + byte[] bytes = Encoding.UTF8.GetBytes("Hello"); + await websocket.SendAsync(Memory.Empty, WebSocketMessageType.Text, endOfMessage: false, CancellationToken); + await websocket.SendAsync(bytes, WebSocketMessageType.Text, endOfMessage: true, CancellationToken); + + using WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = false, + DangerousDeflateOptions = new() + }); + + ValueWebSocketReceiveResult result = await client.ReceiveAsync(bytes.AsMemory(), CancellationToken); + Assert.False(result.EndOfMessage); + Assert.Equal(0, result.Count); + + result = await client.ReceiveAsync(bytes.AsMemory(), CancellationToken); + Assert.True(result.EndOfMessage); + Assert.Equal(5, result.Count); + } + + [Fact] + public async Task ReceiveHelloWithoutContextTakeover() + { + WebSocketTestStream stream = new(); + using WebSocket websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DangerousDeflateOptions = new() + { + ClientContextTakeover = false + } + }); + + Memory buffer = new byte[5]; + + for (var i = 0; i < 100; ++i) + { + // Without context takeover the message should look the same every time + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + buffer.Span.Clear(); + + ValueWebSocketReceiveResult result = await websocket.ReceiveAsync(buffer, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.Span)); + } + } + + [Fact] + public async Task SendHelloWithoutContextTakeover() + { + WebSocketTestStream stream = new(); + using WebSocket websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = new() + { + ServerContextTakeover = false + } + }); + + Memory buffer = new byte[5]; + + for (var i = 0; i < 100; ++i) + { + await websocket.SendAsync(Encoding.UTF8.GetBytes("Hello"), WebSocketMessageType.Text, true, CancellationToken); + + // Without context takeover the message should look the same every time + Assert.Equal("C107F248CDC9C90700", Convert.ToHexString(stream.NextAvailableBytes)); + stream.Clear(); + } + } + + [Fact] + public async Task TwoDeflateBlocksInOneMessage() + { + // Two or more DEFLATE blocks may be used in one message. + WebSocketTestStream stream = new(); + using WebSocket websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DangerousDeflateOptions = new() + }); + // The first 3 octets(0xf2 0x48 0x05) and the least significant two + // bits of the 4th octet(0x00) constitute one DEFLATE block with + // "BFINAL" set to 0 and "BTYPE" set to 01 containing "He". The rest of + // the 4th octet contains the header bits with "BFINAL" set to 0 and + // "BTYPE" set to 00, and the 3 padding bits of 0. Together with the + // following 4 octets(0x00 0x00 0xff 0xff), the header bits constitute + // an empty DEFLATE block with no compression. A DEFLATE block + // containing "llo" follows the empty DEFLATE block. + stream.Enqueue(0x41, 0x08, 0xf2, 0x48, 0x05, 0x00, 0x00, 0x00, 0xff, 0xff); + stream.Enqueue(0x80, 0x05, 0xca, 0xc9, 0xc9, 0x07, 0x00); + + Memory buffer = new byte[5]; + ValueWebSocketReceiveResult result = await websocket.ReceiveAsync(buffer, CancellationToken); + + Assert.Equal(2, result.Count); + Assert.False(result.EndOfMessage); + + result = await websocket.ReceiveAsync(buffer.Slice(result.Count), CancellationToken); + + Assert.Equal(3, result.Count); + Assert.True(result.EndOfMessage); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.Span)); + } + + [Theory] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + [InlineData(true, false)] + public async Task Duplex(bool clientContextTakover, bool serverContextTakover) + { + WebSocketTestStream stream = new(); + using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = new WebSocketDeflateOptions + { + ClientContextTakeover = clientContextTakover, + ServerContextTakeover = serverContextTakover + } + }); + using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DangerousDeflateOptions = new WebSocketDeflateOptions + { + ClientContextTakeover = clientContextTakover, + ServerContextTakeover = serverContextTakover + } + }); + + var buffer = new byte[1024]; + + for (var i = 0; i < 10; ++i) + { + string message = $"Sending number {i} from server."; + await SendTextAsync(message, server, disableCompression: i % 2 == 0); + + ValueWebSocketReceiveResult result = await client.ReceiveAsync(buffer.AsMemory(), CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + Assert.Equal(message, Encoding.UTF8.GetString(buffer.AsSpan(0, result.Count))); + } + + for (var i = 0; i < 10; ++i) + { + string message = $"Sending number {i} from client."; + await SendTextAsync(message, client, disableCompression: i % 2 == 0); + + ValueWebSocketReceiveResult result = await server.ReceiveAsync(buffer.AsMemory(), CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + Assert.Equal(message, Encoding.UTF8.GetString(buffer.AsSpan(0, result.Count))); + } + } + + [Theory] + [MemberData(nameof(SupportedWindowBits))] + public async Task LargeMessageSplitInMultipleFrames(int windowBits) + { + WebSocketTestStream stream = new(); + using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = new() + { + ClientMaxWindowBits = windowBits + } + }); + using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DangerousDeflateOptions = new() + { + ClientMaxWindowBits = windowBits + } + }); + + Memory testData = new byte[ushort.MaxValue]; + Memory receivedData = new byte[testData.Length]; + + // Make the data incompressible to make sure that the output is larger than the input + var rng = new Random(0); + rng.NextBytes(testData.Span); + + // Test it a few times with different frame sizes + for (var i = 0; i < 10; ++i) + { + var frameSize = rng.Next(1024, 2048); + var position = 0; + + while (position < testData.Length) + { + var currentFrameSize = Math.Min(frameSize, testData.Length - position); + var eof = position + currentFrameSize == testData.Length; + + await server.SendAsync(testData.Slice(position, currentFrameSize), WebSocketMessageType.Binary, eof, CancellationToken); + position += currentFrameSize; + } + + Assert.True(testData.Length < stream.Remote.Available, "The compressed data should be bigger."); + Assert.Equal(testData.Length, position); + + // Receive the data from the client side + receivedData.Span.Clear(); + position = 0; + + // Intentionally receive with a frame size that is less than what the sender used + frameSize /= 3; + + while (true) + { + int currentFrameSize = Math.Min(frameSize, testData.Length - position); + ValueWebSocketReceiveResult result = await client.ReceiveAsync(receivedData.Slice(position, currentFrameSize), CancellationToken); + + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + position += result.Count; + + if (result.EndOfMessage) + break; + } + + Assert.Equal(0, stream.Remote.Available); + Assert.Equal(testData.Length, position); + Assert.True(testData.Span.SequenceEqual(receivedData.Span)); + } + } + + [Fact] + public async Task WebSocketWithoutDeflateShouldThrowOnCompressedMessage() + { + WebSocketTestStream stream = new(); + + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + using WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions()); + + var exception = await Assert.ThrowsAsync(() => + client.ReceiveAsync(Memory.Empty, CancellationToken).AsTask()); + + Assert.Equal("The WebSocket received compressed frame when compression is not enabled.", exception.Message); + } + + [Fact] + public async Task ReceiveUncompressedMessageWhenCompressionEnabled() + { + // We should be able to handle the situation where even if we have + // deflate compression enabled, uncompressed messages are OK + WebSocketTestStream stream = new(); + using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = null + }); + using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DangerousDeflateOptions = new WebSocketDeflateOptions() + }); + + // Server sends uncompressed + await SendTextAsync("Hello", server); + + // Although client has deflate options, it should still be able + // to handle uncompressed messages. + Assert.Equal("Hello", await ReceiveTextAsync(client)); + + // Client sends compressed, but server compression is disabled and should throw on receive + await SendTextAsync("Hello back", client); + var exception = await Assert.ThrowsAsync(() => ReceiveTextAsync(server)); + Assert.Equal("The WebSocket received compressed frame when compression is not enabled.", exception.Message); + Assert.Equal(WebSocketState.Aborted, server.State); + + // The client should close if we try to receive + ValueWebSocketReceiveResult result = await client.ReceiveAsync(Memory.Empty, CancellationToken); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.ProtocolError, client.CloseStatus); + Assert.Equal(WebSocketState.CloseReceived, client.State); + } + + [Fact] + public async Task ReceiveInvalidCompressedData() + { + WebSocketTestStream stream = new(); + using WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DangerousDeflateOptions = new WebSocketDeflateOptions() + }); + + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + Assert.Equal("Hello", await ReceiveTextAsync(client)); + + stream.Enqueue(0xc1, 0x07, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00); + var exception = await Assert.ThrowsAsync(() => ReceiveTextAsync(client)); + + Assert.Equal("The message was compressed using an unsupported compression method.", exception.Message); + Assert.Equal(WebSocketState.Aborted, client.State); + } + + [Theory] + [MemberData(nameof(SupportedWindowBits))] + public async Task PayloadShouldHaveSimilarSizeWhenSplitIntoSegments(int windowBits) + { + MemoryStream stream = new(); + using WebSocket client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DangerousDeflateOptions = new WebSocketDeflateOptions() + { + ClientMaxWindowBits = windowBits + } + }); + + // We're using a frame size that is close to the sliding window size for the deflate + int frameSize = 2 << windowBits; + + byte[] message = new byte[frameSize * 10]; + Random random = new(0); + + for (int i = 0; i < message.Length; ++i) + { + message[i] = (byte)random.Next(maxValue: 10); + } + + await client.SendAsync(message, WebSocketMessageType.Binary, true, CancellationToken); + + long payloadLength = stream.Length; + stream.SetLength(0); + + for (int i = 0; i < message.Length; i += frameSize) + { + await client.SendAsync(message.AsMemory(i, frameSize), WebSocketMessageType.Binary, i + frameSize == message.Length, CancellationToken); + } + + double difference = Math.Round(1 - payloadLength * 1.0 / stream.Length, 3); + + // The difference should not be more than 10% in either direction + Assert.InRange(difference, -0.1, 0.1); + } + + [Theory] + [InlineData(9, 15)] + [InlineData(15, 9)] + public async Task SendReceiveWithDifferentWindowBits(int clientWindowBits, int serverWindowBits) + { + WebSocketTestStream stream = new(); + using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DangerousDeflateOptions = new() + { + ClientContextTakeover = false, + ClientMaxWindowBits = clientWindowBits, + ServerContextTakeover = false, + ServerMaxWindowBits = serverWindowBits + } + }); + using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DangerousDeflateOptions = new() + { + ClientContextTakeover = false, + ClientMaxWindowBits = clientWindowBits, + ServerContextTakeover = false, + ServerMaxWindowBits = serverWindowBits + } + }); + + Memory data = new byte[64 * 1024]; + Memory buffer = new byte[data.Length]; + new Random(0).NextBytes(data.Span.Slice(0, data.Length / 2)); + + await server.SendAsync(data, WebSocketMessageType.Binary, true, CancellationToken); + ValueWebSocketReceiveResult result = await client.ReceiveAsync(buffer, CancellationToken); + + Assert.Equal(data.Length, result.Count); + Assert.True(result.EndOfMessage); + Assert.True(data.Span.SequenceEqual(buffer.Span)); + + buffer.Span.Clear(); + + await client.SendAsync(data, WebSocketMessageType.Binary, true, CancellationToken); + result = await server.ReceiveAsync(buffer, CancellationToken); + + Assert.Equal(data.Length, result.Count); + Assert.True(result.EndOfMessage); + Assert.True(data.Span.SequenceEqual(buffer.Span)); + } + + [Fact] + public async Task AutobahnTestCase13_3_1() + { + // When running Autobahn Test Suite some tests failed with zlib error "invalid distance too far back". + // Further investigation lead to a bug fix in zlib intel's implementation - https://github.com/dotnet/runtime/issues/50235. + // This test replicates one of the Autobahn tests to make sure this issue doesn't appear again. + byte[][] messages = new[] + { + new byte[] { 0x7B, 0x0A, 0x20, 0x20, 0x20, 0x22, 0x41, 0x75, 0x74, 0x6F, 0x62, 0x61, 0x68, 0x6E, 0x50, 0x79 }, + new byte[] { 0x74, 0x68, 0x6F, 0x6E, 0x2F, 0x30, 0x2E, 0x36, 0x2E, 0x30, 0x22, 0x3A, 0x20, 0x7B, 0x0A, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x31, 0x2E, 0x31, 0x2E, 0x31, 0x22, 0x3A, 0x20, 0x7B, 0x0A }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69 }, + new byte[] { 0x6F, 0x72, 0x22, 0x3A, 0x20, 0x22, 0x4F, 0x4B, 0x22, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x22, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6F, 0x72, 0x43, 0x6C, 0x6F }, + new byte[] { 0x73, 0x65, 0x22, 0x3A, 0x20, 0x22, 0x4F, 0x4B, 0x22, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x22, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6F, 0x6E, 0x22, 0x3A, 0x20 }, + new byte[] { 0x32, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x72, 0x65, 0x6D }, + new byte[] { 0x6F, 0x74, 0x65, 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x43, 0x6F, 0x64, 0x65, 0x22, 0x3A, 0x20, 0x31 }, + new byte[] { 0x30, 0x30, 0x30, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x72 }, + new byte[] { 0x65, 0x70, 0x6F, 0x72, 0x74, 0x66, 0x69, 0x6C, 0x65, 0x22, 0x3A, 0x20, 0x22, 0x61, 0x75, 0x74 }, + new byte[] { 0x6F, 0x62, 0x61, 0x68, 0x6E, 0x70, 0x79, 0x74, 0x68, 0x6F, 0x6E, 0x5F, 0x30, 0x5F, 0x36, 0x5F }, + new byte[] { 0x30, 0x5F, 0x63, 0x61, 0x73, 0x65, 0x5F, 0x31, 0x5F, 0x31, 0x5F, 0x31, 0x2E, 0x6A, 0x73, 0x6F }, + new byte[] { 0x6E, 0x22, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x7D, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x22, 0x31, 0x2E, 0x31, 0x2E, 0x32, 0x22, 0x3A, 0x20, 0x7B, 0x0A, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6F, 0x72, 0x22 }, + new byte[] { 0x3A, 0x20, 0x22, 0x4F, 0x4B, 0x22, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x22, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6F, 0x72, 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x22 }, + new byte[] { 0x3A, 0x20, 0x22, 0x4F, 0x4B, 0x22, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x22, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6F, 0x6E, 0x22, 0x3A, 0x20, 0x32, 0x2C, 0x0A }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x72, 0x65, 0x6D, 0x6F, 0x74, 0x65 }, + new byte[] { 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x43, 0x6F, 0x64, 0x65, 0x22, 0x3A, 0x20, 0x31, 0x30, 0x30, 0x30 }, + new byte[] { 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x72, 0x65, 0x70, 0x6F }, + new byte[] { 0x72, 0x74, 0x66, 0x69, 0x6C, 0x65, 0x22, 0x3A, 0x20, 0x22, 0x61, 0x75, 0x74, 0x6F, 0x62, 0x61 }, + new byte[] { 0x68, 0x6E, 0x70, 0x79, 0x74, 0x68, 0x6F, 0x6E, 0x5F, 0x30, 0x5F, 0x36, 0x5F, 0x30, 0x5F, 0x63 }, + new byte[] { 0x61, 0x73, 0x65, 0x5F, 0x31, 0x5F, 0x31, 0x5F, 0x32, 0x2E, 0x6A, 0x73, 0x6F, 0x6E, 0x22, 0x0A }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x7D, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22 }, + new byte[] { 0x31, 0x2E, 0x31, 0x2E, 0x33, 0x22, 0x3A, 0x20, 0x7B, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x22, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6F, 0x72, 0x22, 0x3A, 0x20, 0x22 }, + new byte[] { 0x4F, 0x4B, 0x22, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x62 }, + new byte[] { 0x65, 0x68, 0x61, 0x76, 0x69, 0x6F, 0x72, 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x22, 0x3A, 0x20, 0x22 }, + new byte[] { 0x4F, 0x4B, 0x22, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x64 }, + new byte[] { 0x75, 0x72, 0x61, 0x74, 0x69, 0x6F, 0x6E, 0x22, 0x3A, 0x20, 0x32, 0x2C, 0x0A, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x72, 0x65, 0x6D, 0x6F, 0x74, 0x65, 0x43, 0x6C, 0x6F }, + new byte[] { 0x73, 0x65, 0x43, 0x6F, 0x64, 0x65, 0x22, 0x3A, 0x20, 0x31, 0x30, 0x30, 0x30, 0x2C, 0x0A, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x72, 0x65, 0x70, 0x6F, 0x72, 0x74, 0x66 }, + new byte[] { 0x69, 0x6C, 0x65, 0x22, 0x3A, 0x20, 0x22, 0x61, 0x75, 0x74, 0x6F, 0x62, 0x61, 0x68, 0x6E, 0x70 }, + new byte[] { 0x79, 0x74, 0x68, 0x6F, 0x6E, 0x5F, 0x30, 0x5F, 0x36, 0x5F, 0x30, 0x5F, 0x63, 0x61, 0x73, 0x65 }, + new byte[] { 0x5F, 0x31, 0x5F, 0x31, 0x5F, 0x33, 0x2E, 0x6A, 0x73, 0x6F, 0x6E, 0x22, 0x0A, 0x20, 0x20, 0x20 }, + new byte[] { 0x20, 0x20, 0x20, 0x7D, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x31, 0x2E, 0x31 }, + new byte[] { 0x2E, 0x34, 0x22, 0x3A, 0x20, 0x7B, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20 }, + new byte[] { 0x22, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6F, 0x72, 0x22, 0x3A, 0x20, 0x22, 0x4F, 0x4B, 0x22 }, + new byte[] { 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x62, 0x65, 0x68, 0x61 }, + new byte[] { 0x76, 0x69, 0x6F, 0x72, 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x22, 0x3A, 0x20, 0x22, 0x4F, 0x4B, 0x22 }, + new byte[] { 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x22, 0x64, 0x75, 0x72, 0x61 }, + new byte[] { 0x74, 0x69, 0x6F, 0x6E, 0x22, 0x3A, 0x20, 0x32, 0x2C, 0x0A, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20 } + }; + + WebSocketTestStream stream = new(); + using WebSocket server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + KeepAliveInterval = TimeSpan.Zero, + DangerousDeflateOptions = new() + { + ClientMaxWindowBits = 9, + ServerMaxWindowBits = 9 + } + }); + using WebSocket client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + KeepAliveInterval = TimeSpan.Zero, + DangerousDeflateOptions = new() + { + ClientMaxWindowBits = 9, + ServerMaxWindowBits = 9 + } + }); + + foreach (var message in messages) + { + await server.SendAsync(message, WebSocketMessageType.Text, true, CancellationToken); + } + + Memory buffer = new byte[32]; + + for (int i = 0; i < messages.Length; ++i) + { + ValueWebSocketReceiveResult result = await client.ReceiveAsync(buffer, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(messages[i].Length, result.Count); + Assert.True(buffer.Span.Slice(0, result.Count).SequenceEqual(messages[i])); + } + } + + private ValueTask SendTextAsync(string text, WebSocket websocket, bool disableCompression = false) + { + WebSocketMessageFlags flags = WebSocketMessageFlags.EndOfMessage; + if (disableCompression) + { + flags |= WebSocketMessageFlags.DisableCompression; + } + byte[] bytes = Encoding.UTF8.GetBytes(text); + return websocket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, flags, CancellationToken); + } + + private async Task ReceiveTextAsync(WebSocket websocket) + { + using IMemoryOwner buffer = MemoryPool.Shared.Rent(1024 * 32); + ValueWebSocketReceiveResult result = await websocket.ReceiveAsync(buffer.Memory, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + return Encoding.UTF8.GetString(buffer.Memory.Span.Slice(0, result.Count)); + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs new file mode 100644 index 0000000000000..b7dfb3ea7f26d --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs @@ -0,0 +1,238 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets.Tests +{ + /// + /// A helper stream class that can be used simulate sending / receiving (duplex) data in a websocket. + /// + public class WebSocketTestStream : Stream + { + private readonly SemaphoreSlim _inputLock = new(initialCount: 0); + private readonly Queue _inputQueue = new(); + private readonly CancellationTokenSource _disposed = new(); + + public WebSocketTestStream() + { + GC.SuppressFinalize(this); + Remote = new WebSocketTestStream(this); + } + + private WebSocketTestStream(WebSocketTestStream remote) + { + GC.SuppressFinalize(this); + Remote = remote; + } + + public WebSocketTestStream Remote { get; } + + /// + /// Returns the number of unread bytes. + /// + public int Available + { + get + { + int available = 0; + + lock (_inputQueue) + { + foreach (Block x in _inputQueue) + { + available += x.AvailableLength; + } + } + + return available; + } + } + + public Span NextAvailableBytes + { + get + { + lock (_inputQueue) + { + if (_inputQueue.TryPeek(out Block block)) + { + return block.Available; + } + return default; + } + } + } + + /// + /// If set, would cause the next send operation to be delayed + /// and complete asynchronously. Can be used to test cancellation tokens + /// and async code branches. + /// + public TimeSpan DelayForNextSend { get; set; } + + public override bool CanRead => true; + + public override bool CanSeek => false; + + public override bool CanWrite => true; + + public override long Length => -1; + + public override long Position { get => -1; set => throw new NotSupportedException(); } + + protected override void Dispose(bool disposing) + { + if (!_disposed.IsCancellationRequested) + { + _disposed.Cancel(); + + lock (Remote._inputQueue) + { + Remote._inputLock.Release(); + Remote._inputQueue.Enqueue(Block.ConnectionClosed); + } + } + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) + { + using CancellationTokenSource cancellation = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _disposed.Token); + try + { + await _inputLock.WaitAsync(cancellation.Token).ConfigureAwait(false); + } + catch (TaskCanceledException) when (cancellationToken.IsCancellationRequested) + { + throw new OperationCanceledException(cancellationToken); + } + catch (OperationCanceledException) when (_disposed.IsCancellationRequested) + { + return 0; + } + + lock (_inputQueue) + { + Block block = _inputQueue.Peek(); + if (block == Block.ConnectionClosed) + { + return 0; + } + int count = Math.Min(block.AvailableLength, buffer.Length); + + block.Available.Slice(0, count).CopyTo(buffer.Span); + block.Advance(count); + + if (block.AvailableLength == 0) + { + _inputQueue.Dequeue(); + } + else + { + // Because we haven't fully consumed the buffer + // we should release once the input lock so we can acquire + // it again on consequent receive. + _inputLock.Release(); + } + + return count; + } + } + + /// + /// Enqueues the provided data for receive by the WebSocket. + /// + public void Enqueue(params byte[] data) + { + lock (_inputQueue) + { + _inputLock.Release(); + _inputQueue.Enqueue(new Block(data)); + } + } + + /// + /// Enqueues the provided data for receive by the WebSocket. + /// + public void Enqueue(ReadOnlySpan data) + { + lock (_inputQueue) + { + _inputLock.Release(); + _inputQueue.Enqueue(new Block(data.ToArray())); + } + } + + public void Clear() + { + lock (_inputQueue) + { + while (_inputQueue.Count > 0) + { + if (_inputQueue.Peek() == Block.ConnectionClosed) + { + break; + } + _inputQueue.Dequeue(); + } + + while (_inputLock.CurrentCount > _inputQueue.Count) + { + _inputLock.Wait(0); + } + } + } + + public override void Write(ReadOnlySpan buffer) + { + lock (Remote._inputQueue) + { + Remote._inputLock.Release(); + Remote._inputQueue.Enqueue(new Block(buffer.ToArray())); + } + } + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) + { + if (DelayForNextSend > TimeSpan.Zero) + { + await Task.Delay(DelayForNextSend, cancellationToken); + DelayForNextSend = TimeSpan.Zero; + } + + Write(buffer.Span); + } + + public override void Flush() => throw new NotSupportedException(); + + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + + public override void SetLength(long value) => throw new NotSupportedException(); + + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + private sealed class Block + { + public static readonly Block ConnectionClosed = new(Array.Empty()); + + private readonly byte[] _data; + private int _position; + + public Block(byte[] data) + { + _data = data; + } + + public Span Available => _data.AsSpan(_position); + + public int AvailableLength => _data.Length - _position; + + public void Advance(int count) => _position += count; + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs index ad738a00ec864..19b38d8d21760 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO; +using System.Threading.Tasks; using Xunit; namespace System.Net.WebSockets.Tests @@ -171,6 +172,16 @@ public void ValueWebSocketReceiveResult_Ctor_ValidArguments_Roundtrip(int count, Assert.Equal(endOfMessage, r.EndOfMessage); } + [Fact] + public async Task ThrowWhenContinuationWithDifferentCompressionFlags() + { + using WebSocket client = CreateFromStream(new MemoryStream(), isServer: false, null, TimeSpan.Zero); + + await client.SendAsync(Memory.Empty, WebSocketMessageType.Text, WebSocketMessageFlags.DisableCompression, default); + Assert.Throws("messageFlags", () => + client.SendAsync(Memory.Empty, WebSocketMessageType.Binary, WebSocketMessageFlags.EndOfMessage, default)); + } + public abstract class ExposeProtectedWebSocket : WebSocket { public static new bool IsStateTerminal(WebSocketState state) =>