Skip to content

Commit

Permalink
Query: Adds aggressive prefetching for scalar aggregates (#3168)
Browse files Browse the repository at this point in the history
* Add an aggressive prefetch policy and a fully buffered partition range page enumerator to match

* Start setting up test scaffolding

* Fix up some broken unit tests

* Instantiate the correct type of buffered enumerator in cross partition enumerator

* add some UTs for aggressive prefetching

* Add unit tests for aggressive prefetching

* Add some more unit test coverage

* Incorporate code review feedback and add more test coverage

Co-authored-by: j82w <j82w@users.noreply.github.com>
  • Loading branch information
neildsh and j82w authored May 6, 2022
1 parent b04468c commit e0466c1
Show file tree
Hide file tree
Showing 22 changed files with 841 additions and 236 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,11 @@ public static CrossPartitionChangeFeedAsyncEnumerator Create(
documentContainer,
changeFeedPaginationOptions,
cancellationToken),
comparer: default /* this uses a regular queue instead of prioirty queue */,
comparer: default /* this uses a regular queue instead of priority queue */,
maxConcurrency: default,
cancellationToken,
state);
prefetchPolicy: PrefetchPolicy.PrefetchSinglePage,
cancellationToken: cancellationToken,
state: state);

CrossPartitionChangeFeedAsyncEnumerator enumerator = new CrossPartitionChangeFeedAsyncEnumerator(
crossPartitionEnumerator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace Microsoft.Azure.Cosmos.Pagination
using Microsoft.Azure.Cosmos.Query.Core.Monads;
using Microsoft.Azure.Cosmos.Tracing;

internal sealed class BufferedPartitionRangePageAsyncEnumerator<TPage, TState> : PartitionRangePageAsyncEnumerator<TPage, TState>, IPrefetcher
internal sealed class BufferedPartitionRangePageAsyncEnumerator<TPage, TState> : BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState>
where TPage : Page<TState>
where TState : State
{
Expand All @@ -23,7 +23,10 @@ public BufferedPartitionRangePageAsyncEnumerator(PartitionRangePageAsyncEnumerat
this.enumerator = enumerator ?? throw new ArgumentNullException(nameof(enumerator));
}

public override ValueTask DisposeAsync() => this.enumerator.DisposeAsync();
public override ValueTask DisposeAsync()
{
return this.enumerator.DisposeAsync();
}

protected override async Task<TryCatch<TPage>> GetNextPageAsync(ITrace trace, CancellationToken cancellationToken)
{
Expand All @@ -40,20 +43,22 @@ protected override async Task<TryCatch<TPage>> GetNextPageAsync(ITrace trace, Ca
return returnValue;
}

public async ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken)
public override async ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken)
{
if (trace == null)
{
throw new ArgumentNullException(nameof(trace));
}

using (ITrace prefetchTrace = trace.StartChild("Prefetch", TraceComponent.Pagination, TraceLevel.Info))
if (this.bufferedPage.HasValue)
{
if (this.bufferedPage.HasValue)
{
return;
}
return;
}

cancellationToken.ThrowIfCancellationRequested();

using (ITrace prefetchTrace = trace.StartChild("Prefetch", TraceComponent.Pagination, TraceLevel.Info))
{
await this.enumerator.MoveNextAsync(prefetchTrace);
this.bufferedPage = this.enumerator.Current;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// ------------------------------------------------------------

namespace Microsoft.Azure.Cosmos.Pagination
{
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Tracing;

internal abstract class BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState> : PartitionRangePageAsyncEnumerator<TPage, TState>, IPrefetcher
where TPage : Page<TState>
where TState : State
{
protected BufferedPartitionRangePageAsyncEnumeratorBase(FeedRangeState<TState> feedRangeState, CancellationToken cancellationToken)
: base(feedRangeState, cancellationToken)
{
}

public abstract ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,64 +34,24 @@ public CrossPartitionRangePageAsyncEnumerator(
CreatePartitionRangePageAsyncEnumerator<TPage, TState> createPartitionRangeEnumerator,
IComparer<PartitionRangePageAsyncEnumerator<TPage, TState>> comparer,
int? maxConcurrency,
PrefetchPolicy prefetchPolicy,
CancellationToken cancellationToken,
CrossFeedRangeState<TState> state = default)
{
this.feedRangeProvider = feedRangeProvider ?? throw new ArgumentNullException(nameof(feedRangeProvider));
this.createPartitionRangeEnumerator = createPartitionRangeEnumerator ?? throw new ArgumentNullException(nameof(createPartitionRangeEnumerator));
this.cancellationToken = cancellationToken;

this.lazyEnumerators = new AsyncLazy<IQueue<PartitionRangePageAsyncEnumerator<TPage, TState>>>(async (ITrace trace, CancellationToken token) =>
{
ReadOnlyMemory<FeedRangeState<TState>> rangeAndStates;
if (state != default)
{
rangeAndStates = state.Value;
}
else
{
// Fan out to all partitions with default state
List<FeedRangeEpk> ranges = await feedRangeProvider.GetFeedRangesAsync(trace, token);

List<FeedRangeState<TState>> rangesAndStatesBuilder = new List<FeedRangeState<TState>>(ranges.Count);
foreach (FeedRangeInternal range in ranges)
{
rangesAndStatesBuilder.Add(new FeedRangeState<TState>(range, default));
}

rangeAndStates = rangesAndStatesBuilder.ToArray();
}

List<BufferedPartitionRangePageAsyncEnumerator<TPage, TState>> bufferedEnumerators = new List<BufferedPartitionRangePageAsyncEnumerator<TPage, TState>>(rangeAndStates.Length);
for (int i = 0; i < rangeAndStates.Length; i++)
{
FeedRangeState<TState> feedRangeState = rangeAndStates.Span[i];
PartitionRangePageAsyncEnumerator<TPage, TState> enumerator = createPartitionRangeEnumerator(feedRangeState);
BufferedPartitionRangePageAsyncEnumerator<TPage, TState> bufferedEnumerator = new BufferedPartitionRangePageAsyncEnumerator<TPage, TState>(enumerator, cancellationToken);
bufferedEnumerators.Add(bufferedEnumerator);
}

if (maxConcurrency.HasValue)
{
await ParallelPrefetch.PrefetchInParallelAsync(bufferedEnumerators, maxConcurrency.Value, trace, token);
}

IQueue<PartitionRangePageAsyncEnumerator<TPage, TState>> queue;
if (comparer == null)
{
queue = new QueueWrapper<PartitionRangePageAsyncEnumerator<TPage, TState>>(
new Queue<PartitionRangePageAsyncEnumerator<TPage, TState>>(bufferedEnumerators));
}
else
{
queue = new PriorityQueueWrapper<PartitionRangePageAsyncEnumerator<TPage, TState>>(
new PriorityQueue<PartitionRangePageAsyncEnumerator<TPage, TState>>(
bufferedEnumerators,
comparer));
}

return queue;
});
this.lazyEnumerators = new AsyncLazy<IQueue<PartitionRangePageAsyncEnumerator<TPage, TState>>>((ITrace trace, CancellationToken token) =>
InitializeEnumeratorsAsync(
feedRangeProvider,
createPartitionRangeEnumerator,
comparer,
maxConcurrency,
prefetchPolicy,
state,
trace,
token));
}

public TryCatch<CrossFeedRangePage<TPage, TState>> Current { get; private set; }
Expand Down Expand Up @@ -281,6 +241,79 @@ private static bool IsSplitException(Exception exeception)
return enumerators.Peek()?.FeedRangeState;
}

private static async Task<IQueue<PartitionRangePageAsyncEnumerator<TPage, TState>>> InitializeEnumeratorsAsync(
IFeedRangeProvider feedRangeProvider,
CreatePartitionRangePageAsyncEnumerator<TPage, TState> createPartitionRangeEnumerator,
IComparer<PartitionRangePageAsyncEnumerator<TPage, TState>> comparer,
int? maxConcurrency,
PrefetchPolicy prefetchPolicy,
CrossFeedRangeState<TState> state,
ITrace trace,
CancellationToken token)
{
ReadOnlyMemory<FeedRangeState<TState>> rangeAndStates;
if (state != default)
{
rangeAndStates = state.Value;
}
else
{
// Fan out to all partitions with default state
List<FeedRangeEpk> ranges = await feedRangeProvider.GetFeedRangesAsync(trace, token);

List<FeedRangeState<TState>> rangesAndStatesBuilder = new List<FeedRangeState<TState>>(ranges.Count);
foreach (FeedRangeInternal range in ranges)
{
rangesAndStatesBuilder.Add(new FeedRangeState<TState>(range, default));
}

rangeAndStates = rangesAndStatesBuilder.ToArray();
}

IReadOnlyList<BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState>> bufferedEnumerators = CreateBufferedEnumerators(
prefetchPolicy,
createPartitionRangeEnumerator,
rangeAndStates,
token);

if (maxConcurrency.HasValue)
{
await ParallelPrefetch.PrefetchInParallelAsync(bufferedEnumerators, maxConcurrency.Value, trace, token);
}

IQueue<PartitionRangePageAsyncEnumerator<TPage, TState>> queue = comparer == null
? new QueueWrapper<PartitionRangePageAsyncEnumerator<TPage, TState>>(
new Queue<PartitionRangePageAsyncEnumerator<TPage, TState>>(bufferedEnumerators))
: new PriorityQueueWrapper<PartitionRangePageAsyncEnumerator<TPage, TState>>(
new PriorityQueue<PartitionRangePageAsyncEnumerator<TPage, TState>>(
bufferedEnumerators,
comparer));
return queue;
}

private static IReadOnlyList<BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState>> CreateBufferedEnumerators(
PrefetchPolicy policy,
CreatePartitionRangePageAsyncEnumerator<TPage, TState> createPartitionRangeEnumerator,
ReadOnlyMemory<FeedRangeState<TState>> rangeAndStates,
CancellationToken cancellationToken)
{
List<BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState>> bufferedEnumerators = new (rangeAndStates.Length);
for (int i = 0; i < rangeAndStates.Length; i++)
{
FeedRangeState<TState> feedRangeState = rangeAndStates.Span[i];
PartitionRangePageAsyncEnumerator<TPage, TState> enumerator = createPartitionRangeEnumerator(feedRangeState);
BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState> bufferedEnumerator = policy switch
{
PrefetchPolicy.PrefetchSinglePage => new BufferedPartitionRangePageAsyncEnumerator<TPage, TState>(enumerator, cancellationToken),
PrefetchPolicy.PrefetchAll => new FullyBufferedPartitionRangeAsyncEnumerator<TPage, TState>(enumerator, cancellationToken),
_ => throw new ArgumentOutOfRangeException(nameof(policy)),
};
bufferedEnumerators.Add(bufferedEnumerator);
}

return bufferedEnumerators;
}

private interface IQueue<T> : IEnumerable<T>
{
T Peek();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// ------------------------------------------------------------
namespace Microsoft.Azure.Cosmos.Pagination
{
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Query.Core.Monads;
using Microsoft.Azure.Cosmos.Tracing;

internal sealed class FullyBufferedPartitionRangeAsyncEnumerator<TPage, TState> : BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState>
where TPage : Page<TState>
where TState : State
{
private readonly PartitionRangePageAsyncEnumerator<TPage, TState> enumerator;
private readonly List<TPage> bufferedPages;
private int currentIndex;
private Exception exception;

private bool HasPrefetched => (this.exception != null) || (this.bufferedPages.Count > 0);

public FullyBufferedPartitionRangeAsyncEnumerator(PartitionRangePageAsyncEnumerator<TPage, TState> enumerator, CancellationToken cancellationToken)
: base(enumerator.FeedRangeState, cancellationToken)
{
this.enumerator = enumerator ?? throw new ArgumentNullException(nameof(enumerator));
this.bufferedPages = new List<TPage>();
}

public override ValueTask DisposeAsync()
{
return this.enumerator.DisposeAsync();
}

public override async ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken)
{
if (trace == null)
{
throw new ArgumentNullException(nameof(trace));
}

if (this.HasPrefetched)
{
return;
}

cancellationToken.ThrowIfCancellationRequested();

using (ITrace prefetchTrace = trace.StartChild("Prefetch", TraceComponent.Pagination, TraceLevel.Info))
{
while (await this.enumerator.MoveNextAsync(prefetchTrace))
{
cancellationToken.ThrowIfCancellationRequested();
TryCatch<TPage> current = this.enumerator.Current;
if (current.Succeeded)
{
this.bufferedPages.Add(current.Result);
}
else
{
this.exception = current.Exception;
break;
}
}
}
}

protected override async Task<TryCatch<TPage>> GetNextPageAsync(ITrace trace, CancellationToken cancellationToken)
{
TryCatch<TPage> result;
if (this.currentIndex < this.bufferedPages.Count)
{
result = TryCatch<TPage>.FromResult(this.bufferedPages[this.currentIndex]);
}
else if (this.currentIndex == this.bufferedPages.Count && this.exception != null)
{
result = TryCatch<TPage>.FromException(this.exception);
}
else
{
await this.enumerator.MoveNextAsync(trace);
result = this.enumerator.Current;
}

++this.currentIndex;
return result;
}

public override void SetCancellationToken(CancellationToken cancellationToken)
{
base.SetCancellationToken(cancellationToken);
this.enumerator.SetCancellationToken(cancellationToken);
}
}
}
11 changes: 11 additions & 0 deletions Microsoft.Azure.Cosmos/src/Pagination/PrefetchPolicy.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// ------------------------------------------------------------
namespace Microsoft.Azure.Cosmos.Pagination
{
internal enum PrefetchPolicy
{
PrefetchSinglePage,
PrefetchAll
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ private static TryCatch<IQueryPipelineStage> TryCreatePassthroughQueryExecutionC
queryPaginationOptions: new QueryPaginationOptions(
pageSizeHint: inputParameters.MaxItemCount),
partitionKey: inputParameters.PartitionKey,
prefetchPolicy: PrefetchPolicy.PrefetchSinglePage,
maxConcurrency: inputParameters.MaxConcurrency,
cancellationToken: cancellationToken,
continuationToken: inputParameters.InitialUserContinuationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ public static TryCatch<IQueryPipelineStage> MonadicCreate(
Cosmos.PartitionKey? partitionKey,
QueryPaginationOptions queryPaginationOptions,
int maxConcurrency,
PrefetchPolicy prefetchPolicy,
CosmosElement continuationToken,
CancellationToken cancellationToken)
{
Expand All @@ -159,10 +160,11 @@ public static TryCatch<IQueryPipelineStage> MonadicCreate(
CrossFeedRangeState<QueryState> state = monadicExtractState.Result;

CrossPartitionRangePageAsyncEnumerator<QueryPage, QueryState> crossPartitionPageEnumerator = new CrossPartitionRangePageAsyncEnumerator<QueryPage, QueryState>(
documentContainer,
ParallelCrossPartitionQueryPipelineStage.MakeCreateFunction(documentContainer, sqlQuerySpec, queryPaginationOptions, partitionKey, cancellationToken),
feedRangeProvider: documentContainer,
createPartitionRangeEnumerator: ParallelCrossPartitionQueryPipelineStage.MakeCreateFunction(documentContainer, sqlQuerySpec, queryPaginationOptions, partitionKey, cancellationToken),
comparer: Comparer.Singleton,
maxConcurrency,
maxConcurrency: maxConcurrency,
prefetchPolicy: prefetchPolicy,
state: state,
cancellationToken: cancellationToken);

Expand Down
Loading

0 comments on commit e0466c1

Please sign in to comment.