Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize RPC traffic for IAsyncEnumerable<T> #389

Merged
merged 3 commits into from
Dec 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 238 additions & 26 deletions doc/asyncenumerable.md

Large diffs are not rendered by default.

171 changes: 160 additions & 11 deletions src/StreamJsonRpc.Tests/AsyncEnumerableTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Globalization;
using System.Runtime.CompilerServices;
Expand Down Expand Up @@ -37,26 +38,34 @@ protected AsyncEnumerableTests(ITestOutputHelper logger)

/// <summary>
/// This interface should NOT be implemented by <see cref="Server"/>,
/// since the server implements the one method on this interface with a return type of Task{T}
/// since the server implements the methods on this interface with a return type of Task{T}
/// but we want the client proxy to NOT be that.
/// </summary>
protected interface IServer2
{
IAsyncEnumerable<int> WaitTillCanceledBeforeReturningAsync(CancellationToken cancellationToken);

IAsyncEnumerable<int> GetNumbersParameterizedAsync(int batchSize, int readAhead, int prefetch, int totalCount, CancellationToken cancellationToken);
}

protected interface IServer
{
IAsyncEnumerable<int> GetValuesFromEnumeratedSourceAsync(CancellationToken cancellationToken);

IAsyncEnumerable<int> GetNumbersInBatchesAsync(CancellationToken cancellationToken);

IAsyncEnumerable<int> GetNumbersWithReadAheadAsync(CancellationToken cancellationToken);

IAsyncEnumerable<int> GetNumbersAsync(CancellationToken cancellationToken);

IAsyncEnumerable<int> GetNumbersNoCancellationAsync();

IAsyncEnumerable<int> WaitTillCanceledBeforeFirstItemAsync(CancellationToken cancellationToken);

Task<IAsyncEnumerable<int>> WaitTillCanceledBeforeReturningAsync(CancellationToken cancellationToken);

Task<IAsyncEnumerable<int>> WaitTillCanceledBeforeFirstItemWithPrefetchAsync(CancellationToken cancellationToken);

Task<CompoundEnumerableResult> GetNumbersAndMetadataAsync(CancellationToken cancellationToken);

Task PassInNumbersAsync(IAsyncEnumerable<int> numbers, CancellationToken cancellationToken);
Expand All @@ -78,6 +87,7 @@ public Task InitializeAsync()
this.serverRpc = new JsonRpc(serverHandler, this.server);
this.clientRpc = new JsonRpc(clientHandler);

// Don't use Verbose as it has nasty side-effects leading to test failures, which we need to fix!
this.serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Information);
this.clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Information);

Expand Down Expand Up @@ -122,6 +132,20 @@ public async Task GetIAsyncEnumerableAsReturnType(bool useProxy)
Assert.Equal(Server.ValuesReturnedByEnumerables, realizedValuesCount);
}

[Fact]
public async Task GetIAsyncEnumerableAsReturnType_WithProxy_NoCancellation()
{
int realizedValuesCount = 0;
IAsyncEnumerable<int> enumerable = this.clientProxy.Value.GetNumbersNoCancellationAsync();
await foreach (int number in enumerable)
{
realizedValuesCount++;
this.Logger.WriteLine(number.ToString(CultureInfo.InvariantCulture));
}

Assert.Equal(Server.ValuesReturnedByEnumerables, realizedValuesCount);
}

[Theory]
[PairwiseData]
public async Task GetIAsyncEnumerableAsMemberWithinReturnType(bool useProxy)
Expand Down Expand Up @@ -284,6 +308,44 @@ public async Task Cancellation_AfterFirstMoveNext(bool useProxy)
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await enumerator.MoveNextAsync());
}

[Fact]
public async Task Cancellation_AfterFirstMoveNext_NaturalForEach_Proxy()
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(this.TimeoutToken);
IAsyncEnumerable<int> enumerable = this.clientProxy.Value.GetNumbersAsync(cts.Token);

int iterations = 0;
await Assert.ThrowsAsync<OperationCanceledException>(async delegate
{
await foreach (var item in enumerable)
{
iterations++;
cts.Cancel();
}
}).WithCancellation(this.TimeoutToken);

Assert.Equal(1, iterations);
}

[Fact]
public async Task Cancellation_AfterFirstMoveNext_NaturalForEach_NoProxy()
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(this.TimeoutToken);
var enumerable = await this.clientRpc.InvokeWithCancellationAsync<IAsyncEnumerable<int>>(nameof(Server.GetNumbersAsync), cancellationToken: cts.Token);

int iterations = 0;
await Assert.ThrowsAsync<OperationCanceledException>(async delegate
{
await foreach (var item in enumerable.WithCancellation(cts.Token))
{
iterations++;
cts.Cancel();
}
}).WithCancellation(this.TimeoutToken);

Assert.Equal(1, iterations);
}

[Theory]
[PairwiseData]
public async Task Cancellation_DuringLongRunningServerMoveNext(bool useProxy)
Expand All @@ -297,17 +359,17 @@ public async Task Cancellation_DuringLongRunningServerMoveNext(bool useProxy)
var moveNextTask = enumerator.MoveNextAsync();
Assert.False(moveNextTask.IsCompleted);
cts.Cancel();
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await moveNextTask);
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await moveNextTask).WithCancellation(this.TimeoutToken);
}

[Theory]
[PairwiseData]
public async Task Cancellation_DuringLongRunningServerBeforeReturning(bool useProxy)
public async Task Cancellation_DuringLongRunningServerBeforeReturning(bool useProxy, bool prefetch)
{
var cts = new CancellationTokenSource();
Task<IAsyncEnumerable<int>> enumerable = useProxy
? this.clientProxy.Value.WaitTillCanceledBeforeReturningAsync(cts.Token)
: this.clientRpc.InvokeWithCancellationAsync<IAsyncEnumerable<int>>(nameof(Server.WaitTillCanceledBeforeReturningAsync), cancellationToken: cts.Token);
? (prefetch ? this.clientProxy.Value.WaitTillCanceledBeforeFirstItemWithPrefetchAsync(cts.Token) : this.clientProxy.Value.WaitTillCanceledBeforeReturningAsync(cts.Token))
: this.clientRpc.InvokeWithCancellationAsync<IAsyncEnumerable<int>>(prefetch ? nameof(Server.WaitTillCanceledBeforeFirstItemWithPrefetchAsync) : nameof(Server.WaitTillCanceledBeforeReturningAsync), cancellationToken: cts.Token);

// Make sure the method has been invoked first.
await this.server.MethodEntered.WaitAsync(this.TimeoutToken);
Expand Down Expand Up @@ -397,6 +459,35 @@ public async Task ArgumentEnumerable_ForciblyDisposedAndReleasedWhenNotDisposedW
await Assert.ThrowsAsync<InvalidOperationException>(() => this.server.ArgEnumeratorAfterReturn).WithCancellation(this.TimeoutToken);
}

[SkippableFact]
[Trait("GC", "")]
public async Task ReturnEnumerable_AutomaticallyReleasedOnErrorFromIteratorMethod()
{
WeakReference enumerable = await this.ReturnEnumerable_AutomaticallyReleasedOnErrorFromIteratorMethod_Helper();
AssertCollectedObject(enumerable);
}

[Theory]
[InlineData(1, 0, 2, Server.ValuesReturnedByEnumerables)]
[InlineData(2, 2, 2, Server.ValuesReturnedByEnumerables)]
[InlineData(2, 4, 2, Server.ValuesReturnedByEnumerables)]
[InlineData(2, 2, 4, Server.ValuesReturnedByEnumerables)]
[InlineData(2, 2, Server.ValuesReturnedByEnumerables, Server.ValuesReturnedByEnumerables)]
[InlineData(2, 2, Server.ValuesReturnedByEnumerables + 1, Server.ValuesReturnedByEnumerables)]
[InlineData(2, 2, 2, 0)]
[InlineData(2, 2, 2, 1)]
public async Task Prefetch(int minBatchSize, int maxReadAhead, int prefetch, int totalCount)
{
var proxy = this.clientRpc.Attach<IServer2>();
int enumerated = 0;
await foreach (var item in proxy.GetNumbersParameterizedAsync(minBatchSize, maxReadAhead, prefetch, totalCount, this.TimeoutToken).WithCancellation(this.TimeoutToken))
{
Assert.Equal(++enumerated, item);
}

Assert.Equal(totalCount, enumerated);
}

protected abstract void InitializeFormattersAndHandlers();

private static void AssertCollectedObject(WeakReference weakReference)
Expand Down Expand Up @@ -452,6 +543,28 @@ private async Task<WeakReference> ArgumentEnumerable_ForciblyDisposedAndReleased
return result;
}

[MethodImpl(MethodImplOptions.NoInlining)]
private async Task<WeakReference> ReturnEnumerable_AutomaticallyReleasedOnErrorFromIteratorMethod_Helper()
{
this.server.EnumeratedSource = ImmutableList.Create(1, 2, 3);
WeakReference weakReferenceToSource = new WeakReference(this.server.EnumeratedSource);
var cts = CancellationTokenSource.CreateLinkedTokenSource(this.TimeoutToken);

// Start up th emethod and get the first item.
var enumerable = this.clientProxy.Value.GetValuesFromEnumeratedSourceAsync(cts.Token);
var enumerator = enumerable.GetAsyncEnumerator(cts.Token);
Assert.True(await enumerator.MoveNextAsync());

// Now remove the only strong reference to the source object other than what would be captured by the async iterator method.
this.server.EnumeratedSource = this.server.EnumeratedSource.Clear();

// Now array for the server method to be canceled
cts.Cancel();
await Assert.ThrowsAsync<OperationCanceledException>(async () => await enumerator.MoveNextAsync());

return weakReferenceToSource;
}

protected class Server : IServer
{
/// <summary>
Expand All @@ -476,23 +589,33 @@ protected class Server : IServer

public AsyncManualResetEvent ValueGenerated { get; } = new AsyncManualResetEvent();

public ImmutableList<int> EnumeratedSource { get; set; } = ImmutableList<int>.Empty;

public async IAsyncEnumerable<int> GetValuesFromEnumeratedSourceAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
foreach (var item in this.EnumeratedSource)
{
cancellationToken.ThrowIfCancellationRequested();
await Task.Yield();
yield return item;
}
}

public IAsyncEnumerable<int> GetNumbersInBatchesAsync(CancellationToken cancellationToken)
=> this.GetNumbersAsync(cancellationToken).WithJsonRpcSettings(new JsonRpcEnumerableSettings { MinBatchSize = MinBatchSize });

public IAsyncEnumerable<int> GetNumbersWithReadAheadAsync(CancellationToken cancellationToken)
=> this.GetNumbersAsync(cancellationToken).WithJsonRpcSettings(new JsonRpcEnumerableSettings { MaxReadAhead = MaxReadAhead, MinBatchSize = MinBatchSize });

public IAsyncEnumerable<int> GetNumbersNoCancellationAsync() => this.GetNumbersAsync(CancellationToken.None);

public async IAsyncEnumerable<int> GetNumbersAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
try
{
for (int i = 1; i <= ValuesReturnedByEnumerables; i++)
await foreach (var item in this.GetNumbersAsync(ValuesReturnedByEnumerables, cancellationToken))
{
cancellationToken.ThrowIfCancellationRequested();
await Task.Yield();
this.ActuallyGeneratedValueCount++;
this.ValueGenerated.PulseAll();
yield return i;
yield return item;
}
}
finally
Expand All @@ -501,13 +624,27 @@ public async IAsyncEnumerable<int> GetNumbersAsync([EnumeratorCancellation] Canc
}
}

public async ValueTask<IAsyncEnumerable<int>> GetNumbersParameterizedAsync(int batchSize, int readAhead, int prefetch, int totalCount, CancellationToken cancellationToken)
{
return await this.GetNumbersAsync(totalCount, cancellationToken)
.WithJsonRpcSettings(new JsonRpcEnumerableSettings { MinBatchSize = batchSize, MaxReadAhead = readAhead })
.WithPrefetchAsync(prefetch, cancellationToken);
}

public async IAsyncEnumerable<int> WaitTillCanceledBeforeFirstItemAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSource<int>();
await tcs.Task.WithCancellation(cancellationToken);
yield return 0; // we will never reach this.
}

public async Task<IAsyncEnumerable<int>> WaitTillCanceledBeforeFirstItemWithPrefetchAsync(CancellationToken cancellationToken)
{
this.MethodEntered.Set();
return await this.WaitTillCanceledBeforeFirstItemAsync(cancellationToken)
.WithPrefetchAsync(1, cancellationToken);
}

public Task<IAsyncEnumerable<int>> WaitTillCanceledBeforeReturningAsync(CancellationToken cancellationToken)
{
this.MethodEntered.Set();
Expand Down Expand Up @@ -550,6 +687,18 @@ public Task<CompoundEnumerableResult> GetNumbersAndMetadataAsync(CancellationTok
Enumeration = this.GetNumbersAsync(cancellationToken),
});
}

private async IAsyncEnumerable<int> GetNumbersAsync(int totalCount, [EnumeratorCancellation] CancellationToken cancellationToken)
{
for (int i = 1; i <= totalCount; i++)
{
cancellationToken.ThrowIfCancellationRequested();
await Task.Yield();
this.ActuallyGeneratedValueCount++;
this.ValueGenerated.PulseAll();
yield return i;
}
}
}

[DataContract]
Expand Down
Loading