Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Assert for TlsHandler.MediationStream #51

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Binary file modified shared/contoso.com.pfx
Binary file not shown.
Binary file modified shared/dotnetty.com.pfx
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,17 @@ partial class MediationStream

public void SetSource(in ReadOnlyMemory<byte> source)
{
Debug.Assert(SourceReadableBytes == 0);
_input = source;
_inputOffset = 0;
_inputLength = 0;
}

public void ResetSource()
{
Debug.Assert(SourceReadableBytes == 0);
_input = null;
_inputOffset = 0;
_inputLength = 0;
}

Expand All @@ -60,6 +63,7 @@ public void ExpandSource(int count)
if (sslBuffer.IsEmpty)
{
// there is no pending read operation - keep for future
Debug.Assert(_readCompletionSource == null);
Copy link
Collaborator Author

@yyjdelete yyjdelete May 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ReadAsync can also be called with zero length array, and it should be finished if any data is available.

Copy link
Owner

@cuteant cuteant May 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

真是仔细啊, 👍

return;
}
_sslOwnedBuffer = default;
Expand Down Expand Up @@ -87,6 +91,7 @@ public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken
return new ValueTask<int>(read);
}

Debug.Assert(_readCompletionSource == null);
Debug.Assert(_sslOwnedBuffer.IsEmpty);
// take note of buffer - we will pass bytes there once available
_sslOwnedBuffer = buffer;
Expand Down
3 changes: 3 additions & 0 deletions src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetFx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ partial class MediationStream

public void SetSource(byte[] source, int offset)
{
Debug.Assert(SourceReadableBytes == 0);
_input = source;
_inputStartOffset = offset;
_inputOffset = 0;
Expand All @@ -53,7 +54,9 @@ public void SetSource(byte[] source, int offset)

public void ResetSource()
{
Debug.Assert(SourceReadableBytes == 0);
_input = null;
_inputOffset = 0;
_inputLength = 0;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ partial class MediationStream

public void SetSource(byte[] source, int offset)
{
Debug.Assert(SourceReadableBytes == 0);
_input = source;
_inputStartOffset = offset;
_inputOffset = 0;
Expand All @@ -47,7 +48,9 @@ public void SetSource(byte[] source, int offset)

public void ResetSource()
{
Debug.Assert(SourceReadableBytes == 0);
_input = null;
_inputOffset = 0;
_inputLength = 0;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public static IEnumerable<object[]> GetKeepAliveCases() => new[]
};

[Theory]
[MemberData(nameof(GetKeepAliveCases))]
[MemberData(nameof(GetKeepAliveCases), DisableDiscoveryEnumeration = true)]
public void KeepAlive(bool isKeepAliveResponseExpected, HttpVersion httpVersion, HttpResponseStatus responseStatus, string sendKeepAlive, int setSelfDefinedMessageLength, ICharSequence setResponseConnection)
{
var channel = new EmbeddedChannel(new HttpServerKeepAliveHandler());
Expand All @@ -66,7 +66,7 @@ public void KeepAlive(bool isKeepAliveResponseExpected, HttpVersion httpVersion,
}

[Theory]
[MemberData(nameof(GetKeepAliveCases))]
[MemberData(nameof(GetKeepAliveCases), DisableDiscoveryEnumeration = true)]
#pragma warning disable xUnit1026 // Theory methods should use all of their parameters
public void ConnectionCloseHeaderHandledCorrectly(bool isKeepAliveResponseExpected, HttpVersion httpVersion, HttpResponseStatus responseStatus, string sendKeepAlive, int setSelfDefinedMessageLength, ICharSequence setResponseConnection)
#pragma warning restore xUnit1026 // Theory methods should use all of their parameters
Expand All @@ -85,7 +85,7 @@ public void ConnectionCloseHeaderHandledCorrectly(bool isKeepAliveResponseExpect
}

[Theory]
[MemberData(nameof(GetKeepAliveCases))]
[MemberData(nameof(GetKeepAliveCases), DisableDiscoveryEnumeration = true)]
public void PipelineKeepAlive(bool isKeepAliveResponseExpected, HttpVersion httpVersion, HttpResponseStatus responseStatus, string sendKeepAlive, int setSelfDefinedMessageLength, ICharSequence setResponseConnection)
{
var channel = new EmbeddedChannel(new HttpServerKeepAliveHandler());
Expand Down
2 changes: 1 addition & 1 deletion test/DotNetty.Codecs.Http2.Tests/HpackTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public static IEnumerable<object[]> GetJsonFiles()
}

[Theory]
[MemberData(nameof(GetJsonFiles))]
[MemberData(nameof(GetJsonFiles), DisableDiscoveryEnumeration = true)]
public void Test(FileInfo file)
{
using (var fs = file.Open(FileMode.Open))
Expand Down
2 changes: 1 addition & 1 deletion test/DotNetty.Codecs.Protobuf.Tests/RoundTripTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public static IEnumerable<object[]> GetAddressBookCases()
}

[Theory]
[MemberData(nameof(GetAddressBookCases))]
[MemberData(nameof(GetAddressBookCases), DisableDiscoveryEnumeration = true)]
public void Run1(AddressBook addressBook, bool isCompositeBuffer)
{
var channel = new EmbeddedChannel(
Expand Down
26 changes: 8 additions & 18 deletions test/DotNetty.Handlers.Tests/SniHandlerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,7 @@ static SniHandlerTest()
X509Certificate2 tlsCertificate = TestResourceHelper.GetTestCertificate();
X509Certificate2 tlsCertificate2 = TestResourceHelper.GetTestCertificate2();

//#if NETCOREAPP_3_0_GREATER
// SslProtocols serverProtocol = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? SslProtocols.Tls13 : SslProtocols.Tls12;
//#else
SslProtocols serverProtocol = SslProtocols.Tls12;
//#endif
SslProtocols serverProtocol = SslProtocols.None;
SettingMap[tlsCertificate.GetNameInfo(X509NameType.DnsName, false)] = new ServerTlsSettings(tlsCertificate, false, false, serverProtocol);
SettingMap[tlsCertificate2.GetNameInfo(X509NameType.DnsName, false)] = new ServerTlsSettings(tlsCertificate2, false, false, serverProtocol);
}
Expand All @@ -53,11 +49,7 @@ public static IEnumerable<object[]> GetTlsReadTestData()
new[] { 1 }
};
var boolToggle = new[] { false, true };
//#if NETCOREAPP_3_0_GREATER
// var protocols = new[] { RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? SslProtocols.Tls13 : SslProtocols.Tls12 };
//#else
var protocols = new[] { SslProtocols.Tls12 };
//#endif
var protocols = new[] { SslProtocols.None };
var writeStrategyFactories = new Func<IWriteStrategy>[]
{
() => new AsIsWriteStrategy()
Expand All @@ -74,7 +66,7 @@ from targetHost in SettingMap.Keys


[Theory]
[MemberData(nameof(GetTlsReadTestData))]
[MemberData(nameof(GetTlsReadTestData), DisableDiscoveryEnumeration = true)]
public async Task TlsRead(int[] frameLengths, bool isClient, IWriteStrategy writeStrategy, SslProtocols protocol, string targetHost)
{
this.Output.WriteLine($"frameLengths: {string.Join(", ", frameLengths)}");
Expand Down Expand Up @@ -106,7 +98,8 @@ public async Task TlsRead(int[] frameLengths, bool isClient, IWriteStrategy writ
#pragma warning disable CS1998 // 异步方法缺少 "await" 运算符,将以同步方式运行
await ReadOutboundAsync(async () => ch.ReadInbound<IByteBuffer>(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
#pragma warning restore CS1998 // 异步方法缺少 "await" 运算符,将以同步方式运行
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
if (!ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer))
Assert.True(false, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");

if (!isClient)
{
Expand All @@ -131,11 +124,7 @@ public static IEnumerable<object[]> GetTlsWriteTestData()
new[] { 1 }
};
var boolToggle = new[] { false, true };
//#if NETCOREAPP_3_0_GREATER
// var protocols = new[] { RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? SslProtocols.Tls13 : SslProtocols.Tls12 };
//#else
var protocols = new[] { SslProtocols.Tls12 };
//#endif
var protocols = new[] { SslProtocols.None };

return
from frameLengths in lengthVariations
Expand Down Expand Up @@ -186,7 +175,8 @@ await ReadOutboundAsync(
return Unpooled.WrappedBuffer(readBuffer, 0, read);
},
expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
if (!ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer))
Assert.True(false, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");

if (!isClient)
{
Expand Down
Loading