Skip to content

Commit

Permalink
Use Utf8.IsValid to optimize ManagedWebSocket.TryValidateUtf8 (#104865)
Browse files Browse the repository at this point in the history
* Use Utf8.IsValid to optimize ManagedWebSocket.TryValidateUtf8

* Skip ASCII in loop

* Port old Microsoft.AspNetCore.WebSockets.Protocol.Test tests
  • Loading branch information
stephentoub authored Jul 15, 2024
1 parent 57441c7 commit 3212e3f
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using System.Text;
using System.Text.Unicode;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -1546,22 +1547,31 @@ private void ThrowIfOperationInProgress(bool operationCompleted, [CallerMemberNa
// It checks for valid formatting, overlong encodings, surrogates, and value ranges.
private static bool TryValidateUtf8(ReadOnlySpan<byte> span, bool endOfMessage, Utf8MessageState state)
{
// If no prior segment spilled over and this one is the last, we can validate it efficiently as a complete message.
if (endOfMessage && !state.SequenceInProgress)
{
return Utf8.IsValid(span);
}

for (int i = 0; i < span.Length;)
{
// Have we started a character sequence yet?
if (!state.SequenceInProgress)
{
// Skip past ASCII bytes.
int firstNonAscii = span.Slice(i).IndexOfAnyExceptInRange((byte)0, (byte)127);
if (firstNonAscii < 0)
{
break;
}
i += firstNonAscii;

// The first byte tells us how many bytes are in the sequence.
state.SequenceInProgress = true;
byte b = span[i];
i++;
if ((b & 0x80) == 0) // 0bbbbbbb, single byte
{
state.AdditionalBytesExpected = 0;
state.CurrentDecodeBits = b & 0x7F;
state.ExpectedValueMin = 0;
}
else if ((b & 0xC0) == 0x80)
Debug.Assert((b & 0x80) != 0, "Should have already skipped past ASCII");
if ((b & 0xC0) == 0x80)
{
// Misplaced 10bbbbbb continuation byte. This cannot be the first byte.
return false;
Expand Down Expand Up @@ -1589,6 +1599,7 @@ private static bool TryValidateUtf8(ReadOnlySpan<byte> span, bool endOfMessage,
return false;
}
}

while (state.AdditionalBytesExpected > 0 && i < span.Length)
{
byte b = span[i];
Expand All @@ -1608,12 +1619,14 @@ private static bool TryValidateUtf8(ReadOnlySpan<byte> span, bool endOfMessage,
// This is going to end up in the range of 0xD800-0xDFFF UTF-16 surrogates that are not allowed in UTF-8;
return false;
}

if (state.AdditionalBytesExpected == 2 && state.CurrentDecodeBits >= 0x110)
{
// This is going to be out of the upper Unicode bound 0x10FFFF.
return false;
}
}

if (state.AdditionalBytesExpected == 0)
{
state.SequenceInProgress = false;
Expand All @@ -1624,11 +1637,8 @@ private static bool TryValidateUtf8(ReadOnlySpan<byte> span, bool endOfMessage,
}
}
}
if (endOfMessage && state.SequenceInProgress)
{
return false;
}
return true;

return !endOfMessage || !state.SequenceInProgress;
}

private sealed class Utf8MessageState
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
<Compile Include="WebSocketTestStream.cs" />
<Compile Include="WebSocketDeflateOptionsTests.cs" />
<Compile Include="WebSocketDeflateTests.cs" />
<Compile Include="$(CommonTestPath)System\Net\Configuration.cs"
Link="Common\System\Net\Configuration.cs" />
<Compile Include="$(CommonTestPath)System\Net\Configuration.WebSockets.cs"
Link="Common\System\Net\Configuration.WebSockets.cs" />
<Compile Include="$(CommonPath)System\Net\HttpKnownHeaderNames.cs"
Link="Common\System\Net\HttpKnownHeaderNames.cs" />
<Compile Include="WebSocketUtf8Tests.cs" />
<Compile Include="$(CommonTestPath)System\Net\Configuration.cs" Link="Common\System\Net\Configuration.cs" />
<Compile Include="$(CommonTestPath)System\Net\Configuration.WebSockets.cs" Link="Common\System\Net\Configuration.WebSockets.cs" />
<Compile Include="$(CommonPath)System\Net\HttpKnownHeaderNames.cs" Link="Common\System\Net\HttpKnownHeaderNames.cs" />
<Compile Include="$(CommonTestPath)System\IO\ConnectedStreams.cs" Link="Common\System\IO\ConnectedStreams.cs" />
<Compile Include="$(CommonPath)System\Net\StreamBuffer.cs" Link="ProductionCode\Common\System\Net\StreamBuffer.cs" />
<Compile Include="$(CommonPath)System\Net\MultiArrayBuffer.cs" Link="ProductionCode\Common\System\Net\MultiArrayBuffer.cs" />
</ItemGroup>
</Project>
144 changes: 144 additions & 0 deletions src/libraries/System.Net.WebSockets/tests/WebSocketUtf8Tests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace System.Net.WebSockets.Tests
{
public class WebSocketUtf8Tests
{
/// <summary>
/// Checks that each payload is accepted when sent as one complete text message, and again when the
/// same bytes are replayed one byte per frame with only the last byte marked as end of message.
/// </summary>
/// <param name="data">The UTF-8 encoded payload that is expected to pass validation.</param>
[Theory]
[InlineData(new byte[] { })]
[InlineData(new byte[] { 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x20, 0x57, 0x6F, 0x72, 0x6C, 0x64 })] // Hello World
[InlineData(new byte[] { 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x2D, 0xC2, 0xB5, 0x40, 0xC3, 0x9F, 0xC3, 0xB6, 0xC3, 0xA4, 0xC3, 0xBC, 0xC3, 0xA0, 0xC3, 0xA1 })] // "Hello-µ@ßöäüàá"
[InlineData(new byte[] { 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0xf0, 0xa4, 0xad, 0xa2, 0x77, 0x6f, 0x72, 0x6c, 0x64 })] // "hello\U00024b62world"
[InlineData(new byte[] { 0xf0, 0xa4, 0xad, 0xa2 })] // "\U00024b62"
public async Task ValidateSingleValidSegments_Valid(byte[] data)
{
    await WithConnectedWebSockets(async (sender, receiver) =>
    {
        // First pass: the whole payload travels in a single, final frame.
        bool wholeMessageAccepted = await IsValidUtf8Async(sender, receiver, data, endOfMessage: true);
        Assert.True(wholeMessageAccepted);

        // Second pass: one byte per frame, so multi-byte sequences are split across frames.
        int offset = 0;
        while (offset < data.Length)
        {
            bool isFinalByte = offset == data.Length - 1;
            bool byteAccepted = await IsValidUtf8Async(sender, receiver, data.AsMemory(offset, 1), endOfMessage: isFinalByte);
            Assert.True(byteAccepted);
            offset++;
        }
    });
}

/// <summary>
/// Checks that a payload split across three non-final frames is accepted (including a case where a
/// two-byte sequence straddles a frame boundary), and then that the same bytes are accepted when each
/// byte is sent as its own frame within the same connection.
/// </summary>
/// <param name="data1">First segment of the payload.</param>
/// <param name="data2">Second segment of the payload.</param>
/// <param name="data3">Third segment of the payload.</param>
[Theory]
[InlineData(new byte[] { }, new byte[] { }, new byte[] { })]
[InlineData(new byte[] { 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x20 }, new byte[] { }, new byte[] { 0x57, 0x6F, 0x72, 0x6C, 0x64 })] // Hello ,, World
[InlineData(new byte[] { 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x2D, 0xC2, }, new byte[] { 0xB5, 0x40, 0xC3, 0x9F, 0xC3, 0xB6, 0xC3, 0xA4, }, new byte[] { 0xC3, 0xBC, 0xC3, 0xA0, 0xC3, 0xA1 })] // "Hello-µ@ßöäüàá"
public async Task ValidateMultipleValidSegments_Valid(byte[] data1, byte[] data2, byte[] data3)
{
    await WithConnectedWebSockets(async (sender, receiver) =>
    {
        // Sends every byte of the segment as its own frame; only the segment's last byte may end the message.
        async Task AssertEachByteAcceptedAsync(byte[] segment, bool lastByteEndsMessage)
        {
            for (int offset = 0; offset < segment.Length; offset++)
            {
                bool endsMessage = lastByteEndsMessage && offset == segment.Length - 1;
                Assert.True(await IsValidUtf8Async(sender, receiver, segment.AsMemory(offset, 1), endOfMessage: endsMessage));
            }
        }

        // Whole segments first, none of them terminating the message.
        foreach (byte[] segment in new[] { data1, data2, data3 })
        {
            Assert.True(await IsValidUtf8Async(sender, receiver, segment, endOfMessage: false));
        }

        // Then the same bytes again, one per frame; the final byte of the third segment ends the message.
        await AssertEachByteAcceptedAsync(data1, lastByteEndsMessage: false);
        await AssertEachByteAcceptedAsync(data2, lastByteEndsMessage: false);
        await AssertEachByteAcceptedAsync(data3, lastByteEndsMessage: true);
    });
}

/// <summary>
/// Checks that each malformed payload is rejected when sent as one complete text message.
/// </summary>
/// <param name="data">The byte sequence that is expected to fail UTF-8 validation.</param>
[Theory]
[InlineData(new byte[] { 0xfe })]
[InlineData(new byte[] { 0xff })]
[InlineData(new byte[] { 0xfe, 0xfe, 0xff, 0xff })]
[InlineData(new byte[] { 0xc0, 0xb1 })] // Overlong Ascii
[InlineData(new byte[] { 0xc1, 0xb1 })] // Overlong Ascii
[InlineData(new byte[] { 0xe0, 0x80, 0xaf })] // Overlong
[InlineData(new byte[] { 0xf0, 0x80, 0x80, 0xaf })] // Overlong
[InlineData(new byte[] { 0xf8, 0x80, 0x80, 0x80, 0xaf })] // Overlong
[InlineData(new byte[] { 0xfc, 0x80, 0x80, 0x80, 0x80, 0xaf })] // Overlong
[InlineData(new byte[] { 0xed, 0xa0, 0x80, 0x65, 0x64, 0x69, 0x74, 0x65, 0x64 })] // 0xEDA080 decodes to 0xD800, which is a reserved high surrogate character.
public async Task ValidateSingleInvalidSegment_Invalid(byte[] data)
{
    await WithConnectedWebSockets(async (sender, receiver) =>
    {
        // The payload is sent as a single, final frame.
        bool accepted = await IsValidUtf8Async(sender, receiver, data, endOfMessage: true);
        Assert.False(accepted);
    });
}

/// <summary>
/// Checks that a payload containing an encoded surrogate (0xED 0xA0 0x80) is rejected both when sent as
/// one non-final frame and when sent byte-by-byte, in which case the failure is expected at the 0xA0 byte.
/// </summary>
[Fact]
public async Task ValidateIndividualInvalidSegments_Invalid()
{
    byte[] data = [0xce, 0xba, 0xe1, 0xbd, 0xb9, 0xcf, 0x83, 0xce, 0xbc, 0xce, 0xb5, 0xed, 0xa0, 0x80, 0x65, 0x64, 0x69, 0x74, 0x65, 0x64];

    // Index of the 0xA0 byte that follows the 0xED lead byte; every byte before it must be accepted.
    const int FirstRejectedOffset = 12;

    // Entire payload in one non-final frame.
    await WithConnectedWebSockets(async (sender, receiver) =>
    {
        bool accepted = await IsValidUtf8Async(sender, receiver, data, endOfMessage: false);
        Assert.False(accepted);
    });

    // Fresh connection, one byte per frame; the offset is used as the assertion message.
    await WithConnectedWebSockets(async (sender, receiver) =>
    {
        int offset = 0;
        while (offset < FirstRejectedOffset)
        {
            Assert.True(await IsValidUtf8Async(sender, receiver, data.AsMemory(offset, 1), endOfMessage: false), offset.ToString());
            offset++;
        }

        Assert.False(await IsValidUtf8Async(sender, receiver, data.AsMemory(FirstRejectedOffset, 1), endOfMessage: false), FirstRejectedOffset.ToString());
    });
}

/// <summary>
/// Checks that an invalid sequence split across frames is detected: the first segment ends with the
/// lead byte 0xF4 and is accepted on its own, and the following frame carrying 0x90 is rejected.
/// The first segment is exercised both as a single frame and byte-by-byte.
/// </summary>
[Fact]
public async Task ValidateMultipleInvalidSegments_Invalid()
{
    byte[] leadingSegment = [0xce, 0xba, 0xe1, 0xbd, 0xb9, 0xcf, 0x83, 0xce, 0xbc, 0xce, 0xb5, 0xf4];
    byte[] offendingSegment = [0x90];

    // Leading segment as one frame, then the offending byte.
    await WithConnectedWebSockets(async (sender, receiver) =>
    {
        bool leadingAccepted = await IsValidUtf8Async(sender, receiver, leadingSegment, endOfMessage: false);
        Assert.True(leadingAccepted);

        bool offendingAccepted = await IsValidUtf8Async(sender, receiver, offendingSegment, endOfMessage: false);
        Assert.False(offendingAccepted);
    });

    // Fresh connection: leading segment one byte per frame, then the offending byte.
    await WithConnectedWebSockets(async (sender, receiver) =>
    {
        int offset = 0;
        while (offset < leadingSegment.Length)
        {
            Assert.True(await IsValidUtf8Async(sender, receiver, leadingSegment.AsMemory(offset, 1), endOfMessage: false));
            offset++;
        }

        Assert.False(await IsValidUtf8Async(sender, receiver, offendingSegment, endOfMessage: false));
    });
}

/// <summary>
/// Sends <paramref name="buffer"/> from <paramref name="sender"/> as a text frame and then receives on
/// <paramref name="receiver"/> into that same buffer.
/// </summary>
/// <param name="sender">The socket the frame is sent from.</param>
/// <param name="receiver">The socket the frame is received on.</param>
/// <param name="buffer">The payload to send; it is also used as the destination of the receive.</param>
/// <param name="endOfMessage">Whether the sent frame is marked as the last one of the message.</param>
/// <returns>
/// <see langword="true"/> if the receive completes; <see langword="false"/> if it throws a
/// <see cref="WebSocketException"/> whose error code is <see cref="WebSocketError.Faulted"/>, which the
/// tests in this class treat as the payload having been rejected. Any other exception propagates.
/// </returns>
private static async ValueTask<bool> IsValidUtf8Async(WebSocket sender, WebSocket receiver, Memory<byte> buffer, bool endOfMessage)
{
    await sender.SendAsync(buffer, WebSocketMessageType.Text, endOfMessage, CancellationToken.None).ConfigureAwait(false);

    bool received = true;
    try
    {
        await receiver.ReceiveAsync(buffer, CancellationToken.None).ConfigureAwait(false);
    }
    catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.Faulted)
    {
        received = false;
    }

    return received;
}

/// <summary>
/// Creates a client/server pair of web sockets over a bidirectional pair of connected streams, invokes
/// <paramref name="callback"/> with them, and disposes both sockets afterwards.
/// </summary>
/// <param name="callback">The test body; it receives the client socket first and the server socket second.</param>
private static async Task WithConnectedWebSockets(Func<WebSocket, WebSocket, Task> callback)
{
    (Stream clientStream, Stream serverStream) = ConnectedStreams.CreateBidirectional();

    using (WebSocket client = WebSocket.CreateFromStream(clientStream, isServer: false, subProtocol: null, Timeout.InfiniteTimeSpan))
    using (WebSocket server = WebSocket.CreateFromStream(serverStream, isServer: true, subProtocol: null, Timeout.InfiniteTimeSpan))
    {
        await callback(client, server).ConfigureAwait(false);
    }
}
}
}

0 comments on commit 3212e3f

Please sign in to comment.