Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Adds aggressive prefetching for scalar aggregates #3168

Merged
merged 9 commits into from
May 6, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,11 @@ public static CrossPartitionChangeFeedAsyncEnumerator Create(
documentContainer,
changeFeedPaginationOptions,
cancellationToken),
comparer: default /* this uses a regular queue instead of prioirty queue */,
comparer: default /* this uses a regular queue instead of priority queue */,
maxConcurrency: default,
cancellationToken,
state);
prefetchPolicy: PrefetchPolicy.PrefetchSinglePage,
cancellationToken: cancellationToken,
state: state);

CrossPartitionChangeFeedAsyncEnumerator enumerator = new CrossPartitionChangeFeedAsyncEnumerator(
crossPartitionEnumerator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace Microsoft.Azure.Cosmos.Pagination
using Microsoft.Azure.Cosmos.Query.Core.Monads;
using Microsoft.Azure.Cosmos.Tracing;

internal sealed class BufferedPartitionRangePageAsyncEnumerator<TPage, TState> : PartitionRangePageAsyncEnumerator<TPage, TState>, IPrefetcher
internal sealed class BufferedPartitionRangePageAsyncEnumerator<TPage, TState> : BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState>
where TPage : Page<TState>
where TState : State
{
Expand All @@ -23,7 +23,10 @@ public BufferedPartitionRangePageAsyncEnumerator(PartitionRangePageAsyncEnumerat
this.enumerator = enumerator ?? throw new ArgumentNullException(nameof(enumerator));
}

public override ValueTask DisposeAsync() => this.enumerator.DisposeAsync();
public override ValueTask DisposeAsync()
{
return this.enumerator.DisposeAsync();
}

protected override async Task<TryCatch<TPage>> GetNextPageAsync(ITrace trace, CancellationToken cancellationToken)
{
Expand All @@ -40,20 +43,22 @@ protected override async Task<TryCatch<TPage>> GetNextPageAsync(ITrace trace, Ca
return returnValue;
}

public async ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken)
public override async ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken)
{
if (trace == null)
{
throw new ArgumentNullException(nameof(trace));
}

using (ITrace prefetchTrace = trace.StartChild("Prefetch", TraceComponent.Pagination, TraceLevel.Info))
if (this.bufferedPage.HasValue)
{
if (this.bufferedPage.HasValue)
{
return;
}
return;
}

cancellationToken.ThrowIfCancellationRequested();

using (ITrace prefetchTrace = trace.StartChild("Prefetch", TraceComponent.Pagination, TraceLevel.Info))
{
await this.enumerator.MoveNextAsync(prefetchTrace);
this.bufferedPage = this.enumerator.Current;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// ------------------------------------------------------------

namespace Microsoft.Azure.Cosmos.Pagination
{
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Tracing;

/// <summary>
/// Abstract base for partition range page enumerators that also expose a prefetch operation:
/// it derives from <see cref="PartitionRangePageAsyncEnumerator{TPage, TState}"/> and implements <see cref="IPrefetcher"/>.
/// </summary>
/// <typeparam name="TPage">The page type; constrained to <see cref="Page{TState}"/>.</typeparam>
/// <typeparam name="TState">The state type; constrained to <see cref="State"/>.</typeparam>
internal abstract class BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState> : PartitionRangePageAsyncEnumerator<TPage, TState>, IPrefetcher
where TPage : Page<TState>
where TState : State
{
/// <summary>
/// Passes both arguments through, unchanged, to the base enumerator constructor.
/// </summary>
/// <param name="feedRangeState">The feed range and state handed to the base class.</param>
/// <param name="cancellationToken">The cancellation token handed to the base class.</param>
protected BufferedPartitionRangePageAsyncEnumeratorBase(FeedRangeState<TState> feedRangeState, CancellationToken cancellationToken)
: base(feedRangeState, cancellationToken)
{
}

/// <summary>
/// Performs the prefetch; what gets buffered is decided entirely by the derived class.
/// </summary>
/// <param name="trace">The trace passed to the implementation.</param>
/// <param name="cancellationToken">The cancellation token passed to the implementation.</param>
/// <returns>A <see cref="ValueTask"/> representing the prefetch operation.</returns>
public abstract ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,64 +34,24 @@ public CrossPartitionRangePageAsyncEnumerator(
CreatePartitionRangePageAsyncEnumerator<TPage, TState> createPartitionRangeEnumerator,
IComparer<PartitionRangePageAsyncEnumerator<TPage, TState>> comparer,
int? maxConcurrency,
PrefetchPolicy prefetchPolicy,
CancellationToken cancellationToken,
CrossFeedRangeState<TState> state = default)
{
this.feedRangeProvider = feedRangeProvider ?? throw new ArgumentNullException(nameof(feedRangeProvider));
this.createPartitionRangeEnumerator = createPartitionRangeEnumerator ?? throw new ArgumentNullException(nameof(createPartitionRangeEnumerator));
this.cancellationToken = cancellationToken;

this.lazyEnumerators = new AsyncLazy<IQueue<PartitionRangePageAsyncEnumerator<TPage, TState>>>(async (ITrace trace, CancellationToken token) =>
{
ReadOnlyMemory<FeedRangeState<TState>> rangeAndStates;
if (state != default)
{
rangeAndStates = state.Value;
}
else
{
// Fan out to all partitions with default state
List<FeedRangeEpk> ranges = await feedRangeProvider.GetFeedRangesAsync(trace, token);

List<FeedRangeState<TState>> rangesAndStatesBuilder = new List<FeedRangeState<TState>>(ranges.Count);
foreach (FeedRangeInternal range in ranges)
{
rangesAndStatesBuilder.Add(new FeedRangeState<TState>(range, default));
}

rangeAndStates = rangesAndStatesBuilder.ToArray();
}

List<BufferedPartitionRangePageAsyncEnumerator<TPage, TState>> bufferedEnumerators = new List<BufferedPartitionRangePageAsyncEnumerator<TPage, TState>>(rangeAndStates.Length);
for (int i = 0; i < rangeAndStates.Length; i++)
{
FeedRangeState<TState> feedRangeState = rangeAndStates.Span[i];
PartitionRangePageAsyncEnumerator<TPage, TState> enumerator = createPartitionRangeEnumerator(feedRangeState);
BufferedPartitionRangePageAsyncEnumerator<TPage, TState> bufferedEnumerator = new BufferedPartitionRangePageAsyncEnumerator<TPage, TState>(enumerator, cancellationToken);
bufferedEnumerators.Add(bufferedEnumerator);
}

if (maxConcurrency.HasValue)
{
await ParallelPrefetch.PrefetchInParallelAsync(bufferedEnumerators, maxConcurrency.Value, trace, token);
}

IQueue<PartitionRangePageAsyncEnumerator<TPage, TState>> queue;
if (comparer == null)
{
queue = new QueueWrapper<PartitionRangePageAsyncEnumerator<TPage, TState>>(
new Queue<PartitionRangePageAsyncEnumerator<TPage, TState>>(bufferedEnumerators));
}
else
{
queue = new PriorityQueueWrapper<PartitionRangePageAsyncEnumerator<TPage, TState>>(
new PriorityQueue<PartitionRangePageAsyncEnumerator<TPage, TState>>(
bufferedEnumerators,
comparer));
}

return queue;
});
this.lazyEnumerators = new AsyncLazy<IQueue<PartitionRangePageAsyncEnumerator<TPage, TState>>>((ITrace trace, CancellationToken token) =>
InitializeEnumeratorsAsync(
feedRangeProvider,
createPartitionRangeEnumerator,
comparer,
maxConcurrency,
prefetchPolicy,
state,
trace,
token));
}

public TryCatch<CrossFeedRangePage<TPage, TState>> Current { get; private set; }
Expand Down Expand Up @@ -281,6 +241,79 @@ private static bool IsSplitException(Exception exeception)
return enumerators.Peek()?.FeedRangeState;
}

/// <summary>
/// Creates the buffered per-range enumerators, optionally prefetches them in parallel,
/// and returns them wrapped in a queue.
/// </summary>
/// <param name="feedRangeProvider">Queried for feed ranges only when <paramref name="state"/> is default.</param>
/// <param name="createPartitionRangeEnumerator">Factory invoked once per feed range state.</param>
/// <param name="comparer">When null a plain <see cref="Queue{T}"/> is used; otherwise a priority queue ordered by this comparer.</param>
/// <param name="maxConcurrency">When it has a value, the enumerators are prefetched with that concurrency; when null no prefetch happens here.</param>
/// <param name="prefetchPolicy">Forwarded to <c>CreateBufferedEnumerators</c> to pick the buffered wrapper type.</param>
/// <param name="state">Feed range states to start from; default means "every range from the provider, each with default state".</param>
/// <param name="trace">Trace passed to the provider and to the prefetch.</param>
/// <param name="token">Cancellation token passed to the provider, the wrappers and the prefetch.</param>
/// <returns>The queue holding the buffered enumerators.</returns>
private static async Task<IQueue<PartitionRangePageAsyncEnumerator<TPage, TState>>> InitializeEnumeratorsAsync(
    IFeedRangeProvider feedRangeProvider,
    CreatePartitionRangePageAsyncEnumerator<TPage, TState> createPartitionRangeEnumerator,
    IComparer<PartitionRangePageAsyncEnumerator<TPage, TState>> comparer,
    int? maxConcurrency,
    PrefetchPolicy prefetchPolicy,
    CrossFeedRangeState<TState> state,
    ITrace trace,
    CancellationToken token)
{
    ReadOnlyMemory<FeedRangeState<TState>> feedRangeStates;
    if (state != default)
    {
        // Caller supplied the ranges and states to start from.
        feedRangeStates = state.Value;
    }
    else
    {
        // Nothing supplied: cover every range the provider reports, each starting from default state.
        List<FeedRangeEpk> providerRanges = await feedRangeProvider.GetFeedRangesAsync(trace, token);

        FeedRangeState<TState>[] initialStates = new FeedRangeState<TState>[providerRanges.Count];
        int nextSlot = 0;
        foreach (FeedRangeInternal providerRange in providerRanges)
        {
            initialStates[nextSlot] = new FeedRangeState<TState>(providerRange, default);
            nextSlot++;
        }

        feedRangeStates = initialStates;
    }

    IReadOnlyList<BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState>> prefetchers = CreateBufferedEnumerators(
        prefetchPolicy,
        createPartitionRangeEnumerator,
        feedRangeStates,
        token);

    if (maxConcurrency is int concurrency)
    {
        await ParallelPrefetch.PrefetchInParallelAsync(prefetchers, concurrency, trace, token);
    }

    if (comparer == null)
    {
        // No ordering requested: plain first-in-first-out queue.
        return new QueueWrapper<PartitionRangePageAsyncEnumerator<TPage, TState>>(
            new Queue<PartitionRangePageAsyncEnumerator<TPage, TState>>(prefetchers));
    }

    return new PriorityQueueWrapper<PartitionRangePageAsyncEnumerator<TPage, TState>>(
        new PriorityQueue<PartitionRangePageAsyncEnumerator<TPage, TState>>(
            prefetchers,
            comparer));
}

/// <summary>
/// Creates one enumerator per feed range state and wraps each in the buffered
/// enumerator type selected by <paramref name="policy"/>.
/// </summary>
/// <param name="policy">
/// <see cref="PrefetchPolicy.PrefetchSinglePage"/> wraps with <c>BufferedPartitionRangePageAsyncEnumerator</c>;
/// <see cref="PrefetchPolicy.PrefetchAll"/> wraps with <c>FullyBufferedPartitionRangeAsyncEnumerator</c>.
/// </param>
/// <param name="createPartitionRangeEnumerator">Factory invoked for each feed range state, in order.</param>
/// <param name="rangeAndStates">The feed range states to create enumerators for.</param>
/// <param name="cancellationToken">Cancellation token handed to every wrapper.</param>
/// <returns>The wrappers, in the same order as <paramref name="rangeAndStates"/>.</returns>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="policy"/> is not a handled value (only raised when there is at least one feed range state).
/// </exception>
private static IReadOnlyList<BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState>> CreateBufferedEnumerators(
    PrefetchPolicy policy,
    CreatePartitionRangePageAsyncEnumerator<TPage, TState> createPartitionRangeEnumerator,
    ReadOnlyMemory<FeedRangeState<TState>> rangeAndStates,
    CancellationToken cancellationToken)
{
    List<BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState>> wrappers =
        new List<BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState>>(rangeAndStates.Length);

    foreach (FeedRangeState<TState> rangeState in rangeAndStates.Span)
    {
        // The inner enumerator is created before the policy is inspected.
        PartitionRangePageAsyncEnumerator<TPage, TState> inner = createPartitionRangeEnumerator(rangeState);

        switch (policy)
        {
            case PrefetchPolicy.PrefetchSinglePage:
                wrappers.Add(new BufferedPartitionRangePageAsyncEnumerator<TPage, TState>(inner, cancellationToken));
                break;

            case PrefetchPolicy.PrefetchAll:
                wrappers.Add(new FullyBufferedPartitionRangeAsyncEnumerator<TPage, TState>(inner, cancellationToken));
                break;

            default:
                throw new ArgumentOutOfRangeException(nameof(policy));
        }
    }

    return wrappers;
}

private interface IQueue<T> : IEnumerable<T>
{
T Peek();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// ------------------------------------------------------------
namespace Microsoft.Azure.Cosmos.Pagination
{
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Query.Core.Monads;
using Microsoft.Azure.Cosmos.Tracing;

/// <summary>
/// Buffered enumerator whose prefetch drains the wrapped enumerator: every page is read and
/// held in memory until the wrapped enumerator stops or returns a failed page.
/// This is the wrapper chosen for <see cref="PrefetchPolicy.PrefetchAll"/>.
/// </summary>
/// <typeparam name="TPage">The page type; constrained to <see cref="Page{TState}"/>.</typeparam>
/// <typeparam name="TState">The state type; constrained to <see cref="State"/>.</typeparam>
internal sealed class FullyBufferedPartitionRangeAsyncEnumerator<TPage, TState> : BufferedPartitionRangePageAsyncEnumeratorBase<TPage, TState>
    where TPage : Page<TState>
    where TState : State
{
    // The enumerator being wrapped; all real fetching is delegated to it.
    private readonly PartitionRangePageAsyncEnumerator<TPage, TState> enumerator;

    // Pages collected by PrefetchAsync, in the order they were read.
    private readonly List<TPage> bufferedPages;

    // Number of GetNextPageAsync calls served so far; doubles as the index into bufferedPages.
    private int currentIndex;

    // Failure captured during prefetch, replayed after the successfully buffered pages.
    private Exception exception;

    // A prefetch is considered done once it has produced at least one page or a failure.
    private bool HasPrefetched => (this.exception != null) || (this.bufferedPages.Count > 0);

    /// <summary>
    /// Wraps <paramref name="enumerator"/>, starting from its current feed range state.
    /// </summary>
    /// <param name="enumerator">The enumerator to wrap.</param>
    /// <param name="cancellationToken">The cancellation token handed to the base class.</param>
    /// <exception cref="ArgumentNullException"><paramref name="enumerator"/> is null.</exception>
    public FullyBufferedPartitionRangeAsyncEnumerator(PartitionRangePageAsyncEnumerator<TPage, TState> enumerator, CancellationToken cancellationToken)
        // The base initializer runs before the constructor body, so the null guard has to live here;
        // otherwise a null argument fails with NullReferenceException on the FeedRangeState access.
        : base((enumerator ?? throw new ArgumentNullException(nameof(enumerator))).FeedRangeState, cancellationToken)
    {
        this.enumerator = enumerator;
        this.bufferedPages = new List<TPage>();
    }

    /// <summary>
    /// Disposes the wrapped enumerator.
    /// </summary>
    /// <returns>The wrapped enumerator's dispose task.</returns>
    public override ValueTask DisposeAsync()
    {
        return this.enumerator.DisposeAsync();
    }

    /// <summary>
    /// Reads pages from the wrapped enumerator until it reports no more pages or yields a failed page,
    /// buffering the successful pages and remembering the failure. Does nothing if a previous
    /// prefetch already buffered a page or recorded a failure.
    /// </summary>
    /// <param name="trace">Parent trace; a "Prefetch" child is started under it.</param>
    /// <param name="cancellationToken">Checked before the first read and after each successful read.</param>
    /// <returns>A task that completes when prefetching has finished.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="trace"/> is null.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
    public override async ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken)
    {
        if (trace == null)
        {
            throw new ArgumentNullException(nameof(trace));
        }

        if (this.HasPrefetched)
        {
            return;
        }

        cancellationToken.ThrowIfCancellationRequested();

        using (ITrace prefetchTrace = trace.StartChild("Prefetch", TraceComponent.Pagination, TraceLevel.Info))
        {
            while (await this.enumerator.MoveNextAsync(prefetchTrace))
            {
                cancellationToken.ThrowIfCancellationRequested();
                TryCatch<TPage> current = this.enumerator.Current;
                if (current.Succeeded)
                {
                    this.bufferedPages.Add(current.Result);
                }
                else
                {
                    // Keep the failure so it can be surfaced after the pages buffered before it.
                    this.exception = current.Exception;
                    break;
                }
            }
        }
    }

    /// <summary>
    /// Returns the next page: buffered pages first, then the failure captured during prefetch
    /// (if any), and otherwise whatever the wrapped enumerator produces.
    /// </summary>
    /// <param name="trace">Trace passed to the wrapped enumerator when it has to be consulted.</param>
    /// <param name="cancellationToken">Not used directly by this method.</param>
    /// <returns>The next page or failure.</returns>
    protected override async Task<TryCatch<TPage>> GetNextPageAsync(ITrace trace, CancellationToken cancellationToken)
    {
        TryCatch<TPage> result;
        if (this.currentIndex < this.bufferedPages.Count)
        {
            result = TryCatch<TPage>.FromResult(this.bufferedPages[this.currentIndex]);
        }
        else if (this.currentIndex == this.bufferedPages.Count && this.exception != null)
        {
            // Exactly one slot past the buffered pages: replay the prefetch failure once.
            result = TryCatch<TPage>.FromException(this.exception);
        }
        else
        {
            await this.enumerator.MoveNextAsync(trace);
            result = this.enumerator.Current;
        }

        ++this.currentIndex;
        return result;
    }

    /// <summary>
    /// Applies the token to this enumerator and to the wrapped one.
    /// </summary>
    /// <param name="cancellationToken">The new cancellation token.</param>
    public override void SetCancellationToken(CancellationToken cancellationToken)
    {
        base.SetCancellationToken(cancellationToken);
        this.enumerator.SetCancellationToken(cancellationToken);
    }
}
}
11 changes: 11 additions & 0 deletions Microsoft.Azure.Cosmos/src/Pagination/PrefetchPolicy.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// ------------------------------------------------------------
namespace Microsoft.Azure.Cosmos.Pagination
{
/// <summary>
/// Selects which buffered wrapper is placed around each partition range enumerator.
/// </summary>
internal enum PrefetchPolicy
{
/// <summary>
/// Wrap with <c>BufferedPartitionRangePageAsyncEnumerator</c>, whose prefetch reads a single page and holds it.
/// </summary>
PrefetchSinglePage,
/// <summary>
/// Wrap with <c>FullyBufferedPartitionRangeAsyncEnumerator</c>, whose prefetch keeps reading pages
/// until the wrapped enumerator stops or returns a failed page.
/// </summary>
PrefetchAll
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ private static TryCatch<IQueryPipelineStage> TryCreatePassthroughQueryExecutionC
queryPaginationOptions: new QueryPaginationOptions(
pageSizeHint: inputParameters.MaxItemCount),
partitionKey: inputParameters.PartitionKey,
prefetchPolicy: PrefetchPolicy.PrefetchSinglePage,
maxConcurrency: inputParameters.MaxConcurrency,
cancellationToken: cancellationToken,
continuationToken: inputParameters.InitialUserContinuationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ public static TryCatch<IQueryPipelineStage> MonadicCreate(
Cosmos.PartitionKey? partitionKey,
QueryPaginationOptions queryPaginationOptions,
int maxConcurrency,
PrefetchPolicy prefetchPolicy,
CosmosElement continuationToken,
CancellationToken cancellationToken)
{
Expand All @@ -159,10 +160,11 @@ public static TryCatch<IQueryPipelineStage> MonadicCreate(
CrossFeedRangeState<QueryState> state = monadicExtractState.Result;

CrossPartitionRangePageAsyncEnumerator<QueryPage, QueryState> crossPartitionPageEnumerator = new CrossPartitionRangePageAsyncEnumerator<QueryPage, QueryState>(
documentContainer,
ParallelCrossPartitionQueryPipelineStage.MakeCreateFunction(documentContainer, sqlQuerySpec, queryPaginationOptions, partitionKey, cancellationToken),
feedRangeProvider: documentContainer,
createPartitionRangeEnumerator: ParallelCrossPartitionQueryPipelineStage.MakeCreateFunction(documentContainer, sqlQuerySpec, queryPaginationOptions, partitionKey, cancellationToken),
comparer: Comparer.Singleton,
maxConcurrency,
maxConcurrency: maxConcurrency,
prefetchPolicy: prefetchPolicy,
state: state,
cancellationToken: cancellationToken);

Expand Down
Loading