Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

quic test improvements #56043

Merged
merged 7 commits into from
Jul 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<TargetFrameworks>$(NetCoreAppCurrent)</TargetFrameworks>
</PropertyGroup>
<ItemGroup>
<Compile Include="System\IO\*" />
<Compile Include="System\IO\*.cs" />
</ItemGroup>
<ItemGroup>
<Compile Include="$(CommonTestPath)System\IO\ConnectedStreams.cs" Link="System\IO\ConnectedStreams.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ public static class TaskTimeoutExtensions
#region WaitAsync polyfills
// Test polyfills when targeting a platform that doesn't have these ConfigureAwait overloads on Task

public static Task WaitAsync(this Task task, int millisecondsTimeout) =>
WaitAsync(task, TimeSpan.FromMilliseconds(millisecondsTimeout), default);

public static Task WaitAsync(this Task task, TimeSpan timeout) =>
WaitAsync(task, timeout, default);

Expand All @@ -28,6 +31,9 @@ public async static Task WaitAsync(this Task task, TimeSpan timeout, Cancellatio
}
}

public static Task<TResult> WaitAsync<TResult>(this Task<TResult> task, int millisecondsTimeout) =>
WaitAsync(task, TimeSpan.FromMilliseconds(millisecondsTimeout), default);

public static Task<TResult> WaitAsync<TResult>(this Task<TResult> task, TimeSpan timeout) =>
WaitAsync(task, timeout, default);

Expand All @@ -48,6 +54,9 @@ public static async Task<TResult> WaitAsync<TResult>(this Task<TResult> task, Ti
public static async Task WhenAllOrAnyFailed(this Task[] tasks, int millisecondsTimeout) =>
await tasks.WhenAllOrAnyFailed().WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));

public static async Task WhenAllOrAnyFailed(Task t1, Task t2, int millisecondsTimeout) =>
await new Task[] {t1, t2}.WhenAllOrAnyFailed(millisecondsTimeout);

public static async Task WhenAllOrAnyFailed(this Task[] tasks)
{
try
Expand Down
74 changes: 34 additions & 40 deletions src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,19 @@ namespace System.Net.Quic.Tests
[ConditionalClass(typeof(QuicTestBase<MsQuicProviderFactory>), nameof(IsSupported))]
public class MsQuicTests : QuicTestBase<MsQuicProviderFactory>
{
readonly ITestOutputHelper _output;
private static ReadOnlyMemory<byte> s_data = Encoding.UTF8.GetBytes("Hello world!");

public MsQuicTests(ITestOutputHelper output)
{
_output = output;
}
public MsQuicTests(ITestOutputHelper output) : base(output) { }

[Fact]
public async Task UnidirectionalAndBidirectionalStreamCountsWork()
{
using QuicListener listener = CreateQuicListener();
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SInce this is a common, repeated pattern, I wonder if we should have a helper method for it, like CreateClientAndServer or something like that. Or EstablishConnection or something.

We have RunClientServer, but that's not really the same.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yah. I wanted to like SslStream. But if I return tuple we loose the using var.
We would need to than use good old

using (clientConn)
using (serverConn)
{
  ....
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW the helper would be also nice if we want to bet in any retry logs for the flaky listener if we don't find way how to fix it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's a little verbose because of the using stuff.

One idea would be to define something like

struct QuicConnectionPair : IDisposable
{
  public QuicConnection Client { get; }
  public QuicConnection Server { get; }
  public void Dispose() { Client.Dispose(); Server.Dispose(); }
}

public static ValueTask<ConnectionPair> EstablishConnectionAsync(...) { ... }

Then you can just write code like this:

  using (var connectionPair = await EstablishConnectionAsync(...);
  // use connectionPair.Client and connectionPair.Server here

If we feel really ambitious we could generalize ConnectionPair to something like DisposablePair<T1, T2>. And maybe add some implicit conversion ops from ValueTuple<T1, T2>...

@stephentoub already has something vaguely like this for the Stream Conformance Tests.

await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Assert.Equal(100, serverConnection.GetRemoteAvailableBidirectionalStreamCount());
Assert.Equal(100, serverConnection.GetRemoteAvailableUnidirectionalStreamCount());
}
Expand All @@ -55,10 +51,10 @@ public async Task UnidirectionalAndBidirectionalChangeValues()
};

using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Assert.Equal(100, clientConnection.GetRemoteAvailableBidirectionalStreamCount());
Assert.Equal(100, clientConnection.GetRemoteAvailableUnidirectionalStreamCount());
Assert.Equal(10, serverConnection.GetRemoteAvailableBidirectionalStreamCount());
Expand Down Expand Up @@ -112,10 +108,9 @@ public async Task ConnectWithCertificateChain()
};

using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
ValueTask clientTask = clientConnection.ConnectAsync();

using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;
}

[Fact]
Expand Down Expand Up @@ -342,10 +337,10 @@ public async Task ConnectWithClientCertificate()
clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };

using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
ValueTask clientTask = clientConnection.ConnectAsync();
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
// Verify functionality of the connections.
await PingPong(clientConnection, serverConnection);
// check we completed the client certificate verification.
Expand All @@ -359,10 +354,9 @@ public async Task WaitForAvailableUnidirectionStreamsAsyncWorks()
{
using QuicListener listener = CreateQuicListener(maxUnidirectionalStreams: 1);
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;
listener.Dispose();

// No stream opened yet, should return immediately.
Expand All @@ -387,9 +381,9 @@ public async Task WaitForAvailableBidirectionStreamsAsyncWorks()
using QuicListener listener = CreateQuicListener(maxBidirectionalStreams: 1);
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

// No stream opened yet, should return immediately.
Assert.True(clientConnection.WaitForAvailableBidirectionalStreamsAsync().IsCompletedSuccessfully);
Expand Down Expand Up @@ -425,16 +419,15 @@ public async Task SetListenerTimeoutWorksWithSmallTimeout()
};

using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

await Assert.ThrowsAsync<QuicOperationAbortedException>(async () => await serverConnection.AcceptStreamAsync().AsTask().WaitAsync(TimeSpan.FromSeconds(100)));
}

[Theory]
[MemberData(nameof(WriteData))]
[ActiveIssue("https://github.com/dotnet/runtime/issues/49157")]
public async Task WriteTests(int[][] writes, WriteType writeType)
{
await RunClientServer(
Expand Down Expand Up @@ -530,9 +523,10 @@ public async Task CallDifferentWriteMethodsWorks()
using QuicListener listener = CreateQuicListener();
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;


ReadOnlyMemory<byte> helloWorld = Encoding.ASCII.GetBytes("Hello world!");
ReadOnlySequence<byte> ros = CreateReadOnlySequenceFromBytes(helloWorld.ToArray());
Expand Down Expand Up @@ -714,9 +708,9 @@ async Task GetStreamIdWithoutStartWorks()
using QuicListener listener = CreateQuicListener();
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

using QuicStream clientStream = clientConnection.OpenBidirectionalStream();
Assert.Equal(0, clientStream.StreamId);
Expand All @@ -737,9 +731,9 @@ async Task GetStreamIdWithoutStartWorks()
using QuicListener listener = CreateQuicListener();
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
Task<QuicConnection> serverTask = listener.AcceptConnectionAsync().AsTask();
await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientConnection.ConnectAsync().AsTask(), serverTask, PassingTestTimeoutMilliseconds);
using QuicConnection serverConnection = serverTask.Result;

using QuicStream clientStream = clientConnection.OpenBidirectionalStream();
Assert.Equal(0, clientStream.StreamId);
Expand Down Expand Up @@ -781,7 +775,7 @@ await Task.Run(async () =>
byte[] buffer = new byte[100];
QuicConnectionAbortedException ex = await Assert.ThrowsAsync<QuicConnectionAbortedException>(() => serverStream.ReadAsync(buffer).AsTask());
Assert.Equal(ExpectedErrorCode, ex.ErrorCode);
}).WaitAsync(TimeSpan.FromSeconds(5));
}).WaitAsync(TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds));
}

[Fact]
Expand All @@ -807,7 +801,7 @@ await Task.Run(async () =>

byte[] buffer = new byte[100];
await Assert.ThrowsAsync<QuicOperationAbortedException>(() => serverStream.ReadAsync(buffer).AsTask());
}).WaitAsync(TimeSpan.FromSeconds(5));
}).WaitAsync(TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ public abstract class QuicConnectionTests<T> : QuicTestBase<T>
{
const int ExpectedErrorCode = 1234;

public QuicConnectionTests(ITestOutputHelper output) : base(output) { }

[Fact]
public async Task TestConnect()
{
Expand Down Expand Up @@ -285,8 +287,14 @@ await RunClientServer(
}
}

public sealed class QuicConnectionTests_MockProvider : QuicConnectionTests<MockProviderFactory> { }
public sealed class QuicConnectionTests_MockProvider : QuicConnectionTests<MockProviderFactory>
{
public QuicConnectionTests_MockProvider(ITestOutputHelper output) : base(output) { }
}

[ConditionalClass(typeof(QuicTestBase<MsQuicProviderFactory>), nameof(QuicTestBase<MsQuicProviderFactory>.IsSupported))]
public sealed class QuicConnectionTests_MsQuicProvider : QuicConnectionTests<MsQuicProviderFactory> { }
public sealed class QuicConnectionTests_MsQuicProvider : QuicConnectionTests<MsQuicProviderFactory>
{
public QuicConnectionTests_MsQuicProvider(ITestOutputHelper output) : base(output) { }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@

using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;

namespace System.Net.Quic.Tests
{
public abstract class QuicListenerTests<T> : QuicTestBase<T>
where T : IQuicImplProviderFactory, new()
{
public QuicListenerTests(ITestOutputHelper output) : base(output) { }

[Fact]
public async Task Listener_Backlog_Success()
{
Expand All @@ -25,8 +28,14 @@ await Task.Run(async () =>
}
}

public sealed class QuicListenerTests_MockProvider : QuicListenerTests<MockProviderFactory> { }
public sealed class QuicListenerTests_MockProvider : QuicListenerTests<MockProviderFactory>
{
public QuicListenerTests_MockProvider(ITestOutputHelper output) : base(output) { }
}

[ConditionalClass(typeof(QuicTestBase<MsQuicProviderFactory>), nameof(QuicTestBase<MsQuicProviderFactory>.IsSupported))]
public sealed class QuicListenerTests_MsQuicProvider : QuicListenerTests<MsQuicProviderFactory> { }
public sealed class QuicListenerTests_MsQuicProvider : QuicListenerTests<MsQuicProviderFactory>
{
public QuicListenerTests_MsQuicProvider(ITestOutputHelper output) : base(output) { }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
using System.Collections.Generic;
using System.IO;
using System.IO.Tests;
using System.Net.Sockets;
using System.Net.Quic.Implementations;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;

namespace System.Net.Quic.Tests
{
Expand All @@ -23,11 +26,17 @@ public sealed class MsQuicQuicStreamConformanceTests : QuicStreamConformanceTest
protected override QuicImplementationProvider Provider => QuicImplementationProviders.MsQuic;
protected override bool UsableAfterCanceledReads => false;
protected override bool BlocksOnZeroByteReads => true;

public MsQuicQuicStreamConformanceTests(ITestOutputHelper output)
{
_output = output;
}
}

public abstract class QuicStreamConformanceTests : ConnectedStreamConformanceTests
{
public X509Certificate2 ServerCertificate = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate();
public ITestOutputHelper _output;

public bool RemoteCertificateValidationCallback(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors)
{
Expand Down Expand Up @@ -75,21 +84,31 @@ await WhenAllOrAnyFailed(
}),
Task.Run(async () =>
{
connection2 = new QuicConnection(
provider,
listener.ListenEndPoint,
GetSslClientAuthenticationOptions());
await connection2.ConnectAsync();
stream2 = connection2.OpenBidirectionalStream();
// OpenBidirectionalStream only allocates ID. We will force stream opening
// by Writing there and receiving data on the other side.
await stream2.WriteAsync(buffer);
try
{
connection2 = new QuicConnection(
provider,
listener.ListenEndPoint,
GetSslClientAuthenticationOptions());
await connection2.ConnectAsync();
stream2 = connection2.OpenBidirectionalStream();
// OpenBidirectionalStream only allocates ID. We will force stream opening
// by Writing there and receiving data on the other side.
await stream2.WriteAsync(buffer);
}
catch (Exception ex)
{
_output?.WriteLine($"Failed to {ex.Message}");
throw;
}
}));

// No need to keep the listener once we have connected connection and streams
listener.Dispose();

var result = new StreamPairWithOtherDisposables(stream1, stream2);
result.Disposables.Add(connection1);
result.Disposables.Add(connection2);
result.Disposables.Add(listener);

return result;
}
Expand Down
Loading