From e04ce514f850cdcaa003b64ec19384662c53e18a Mon Sep 17 00:00:00 2001 From: kevin-montrose Date: Mon, 1 Apr 2024 17:34:37 -0400 Subject: [PATCH] Performance: Refactors query prefetch mechanism (#4361) * sketch out improved ParallelPrefetcher; focus is on reducing allocations, but we also can't be substantially slower to start all tasks * little more cleanup to further reduce allocations, and save a tiny amount of CPU * start on testing * some tweaks and testing for buffer management * test exception handling; fix a bug in high concurrency case that would swallow exceptions * test cancellation * more testing, a little bit of cleanup * test the case where the enumerator faults; fixes a couple leaks of Tasks and buffers that could occur in that some places * tiny bit of cleanup * cleanup and expand comments; code is tricky, it needs documentation * address a whole bunch of style nits, just to keep compiler Message counts down * address some feedback on comment clarity * don't rely on finalizers for testing, it's too brittle; hold up was not allocating more in the non-test cases, but found a field to reuse; needs benchmarking * complete ITrace proxy for testing * style nits and a bit more commentary * explicit test for concurrent access to the inner IEnumerator * explicit test that IEnumerator is disposed * explicitly implement all ITrace members * address feedback: internal class members should be public or private * address feedback: break test-only bits of ParallelPrefetch out into a partial * address feedback: move const above type declarations * address feedback: naming nits * address feedback: use the existing NoOpTrace * address feedback: remove pointless using * update baseline trace text for QueryAsync test --------- Co-authored-by: neildsh <35383880+neildsh@users.noreply.github.com> --- .../Pagination/ParallelPrefetch.Testing.cs | 121 ++ .../src/Pagination/ParallelPrefetch.cs | 736 +++++++++- ...EndTraceWriterBaselineTests.QueryAsync.xml | 35 - .../Pagination/ParallelPrefetchTests.cs | 1255 +++++++++++++++++ 4 files changed, 2081 insertions(+), 66 deletions(-) create mode 100644 Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.Testing.cs create mode 100644 Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Pagination/ParallelPrefetchTests.cs diff --git a/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.Testing.cs b/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.Testing.cs new file mode 100644 index 0000000000..b37fa2f456 --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.Testing.cs @@ -0,0 +1,121 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ +namespace Microsoft.Azure.Cosmos.Pagination +{ + using System; + using System.Buffers; + using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Tracing; + + /// + /// Holds the "just for testing"-bits of . + /// + internal static partial class ParallelPrefetch + { + /// + /// For testing purposes, provides ways to instrument . + /// + /// You shouldn't be using this outside of test projects. 
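+ ///
+ /// A rough sketch of the intended wiring from a test (the pools here are placeholders;
+ /// any ArrayPool implementations work, and the entry point and counters referenced
+ /// are the ones defined in this class and in ParallelPrefetch.cs):
+ ///
+ ///     ParallelPrefetchTestConfig config = new(
+ ///         ArrayPool<IPrefetcher>.Shared, ArrayPool<Task>.Shared, ArrayPool<object>.Shared);
+ ///     await ParallelPrefetch.PrefetchInParallelCoreAsync(prefetchers, maxConcurrency, NoOpTrace.Singleton, config, default);
+ ///     Assert.AreEqual(config.StartedTasks, config.AwaitedTasks);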
+ ///
+ internal sealed class ParallelPrefetchTestConfig : ITrace
+ {
+ private ITrace innerTrace;
+
+ private int startedTasks;
+ private int awaitedTasks;
+
+ public ArrayPool<IPrefetcher> PrefetcherPool { get; private set; }
+ public ArrayPool<Task> TaskPool { get; private set; }
+ public ArrayPool<object> ObjectPool { get; private set; }
+
+ public int StartedTasks
+ => this.startedTasks;
+
+ public int AwaitedTasks
+ => this.awaitedTasks;
+
+ string ITrace.Name => this.innerTrace.Name;
+
+ Guid ITrace.Id => this.innerTrace.Id;
+
+ DateTime ITrace.StartTime => this.innerTrace.StartTime;
+
+ TimeSpan ITrace.Duration => this.innerTrace.Duration;
+
+ TraceLevel ITrace.Level => this.innerTrace.Level;
+
+ TraceComponent ITrace.Component => this.innerTrace.Component;
+
+ TraceSummary ITrace.Summary => this.innerTrace.Summary;
+
+ ITrace ITrace.Parent => this.innerTrace.Parent;
+
+ IReadOnlyList<ITrace> ITrace.Children => this.innerTrace.Children;
+
+ IReadOnlyDictionary<string, object> ITrace.Data => this.innerTrace.Data;
+
+ public ParallelPrefetchTestConfig(
+ ArrayPool<IPrefetcher> prefetcherPool,
+ ArrayPool<Task> taskPool,
+ ArrayPool<object> objectPool)
+ {
+ this.PrefetcherPool = prefetcherPool;
+ this.TaskPool = taskPool;
+ this.ObjectPool = objectPool;
+ }
+
+ public void SetInnerTrace(ITrace trace)
+ {
+ this.innerTrace = trace;
+ }
+
+ public void TaskStarted()
+ {
+ Interlocked.Increment(ref this.startedTasks);
+ }
+
+ public void TaskAwaited()
+ {
+ Interlocked.Increment(ref this.awaitedTasks);
+ }
+
+ ITrace ITrace.StartChild(string name)
+ {
+ return this.innerTrace.StartChild(name);
+ }
+
+ ITrace ITrace.StartChild(string name, TraceComponent component, TraceLevel level)
+ {
+ return this.innerTrace.StartChild(name, component, level);
+ }
+
+ void ITrace.AddDatum(string key, TraceDatum traceDatum)
+ {
+ this.innerTrace.AddDatum(key, traceDatum);
+ }
+
+ void ITrace.AddDatum(string key, object value)
+ {
+ this.innerTrace.AddDatum(key, value);
+ }
+
+ void ITrace.AddOrUpdateDatum(string key, object value)
+ {
+ this.innerTrace.AddOrUpdateDatum(key, value);
+ }
+
+ void ITrace.AddChild(ITrace trace)
+ {
+ this.innerTrace.AddChild(trace);
+ }
+
+ void IDisposable.Dispose()
+ {
+ this.innerTrace.Dispose();
+ }
+ }
+ }
+}
diff --git a/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.cs b/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.cs
index 56bc732658..c4782329f7 100644
--- a/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.cs
+++ b/Microsoft.Azure.Cosmos/src/Pagination/ParallelPrefetch.cs
@@ -5,73 +5,747 @@ namespace Microsoft.Azure.Cosmos.Pagination
 {
 using System;
+ using System.Buffers;
 using System.Collections.Generic;
+ using System.Runtime.ExceptionServices;
 using System.Threading;
 using System.Threading.Tasks;
 using Microsoft.Azure.Cosmos.Tracing;
- internal static class ParallelPrefetch
+ internal static partial class ParallelPrefetch
 {
- public static async Task PrefetchInParallelAsync(
+ ///
+ /// Number of tasks started at one time, maximum, when working through prefetchers.
+ ///
+ /// Also used as the limit between Low and High concurrency implementations.
+ ///
+ /// This number should be reasonably large, but less than the point where a
+ /// Task[BatchLimit] ends up on the LOH (which will be around 8,192).
+ ///
+ private const int BatchLimit = 512;
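+
+ // As a quick sanity check on the numbers above (assuming 64-bit, where object references
+ // are 8 bytes and the large object heap cutoff is 85,000 bytes): a Task[] stays off the
+ // LOH below 85,000 / 8 = 10,625 elements, so the "around 8,192" figure above is a
+ // conservative power-of-two bound and 512 sits comfortably under it.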
+
+ /// Common state that is needed for all tasks started via PrefetchInParallelAsync, unless
+ /// certain special cases hold.
+ ///
+ /// Also used as a synchronization primitive.
+ ///
+ private sealed class CommonPrefetchState
+ {
+ // we also use this to signal if we're finished enumerating, to save space
+ private IEnumerator<IPrefetcher> enumerator;
+
+ ///
+ /// If this is true, it's a signal that new work should not be queued up.
+ ///
+ public bool FinishedEnumerating
+ => Volatile.Read(ref this.enumerator) == null;
+
+ ///
+ /// Common ITrace to be used by all tasks.
+ ///
+ /// When testing, this can also be a ParallelPrefetchTestConfig.
+ /// We reuse this to keep allocations down in non-test cases.
+ ///
+ public ITrace PrefetchTrace { get; private set; }
+
+ ///
+ /// The IEnumerator<IPrefetcher> which will produce the next
+ /// IPrefetcher to use.
+ ///
+ /// Once at least one Task has been started, this should only be accessed under a lock.
+ ///
+ /// If FinishedEnumerating == true, this returns null.
+ ///
+ public IEnumerator<IPrefetcher> Enumerator
+ => Volatile.Read(ref this.enumerator);
+
+ ///
+ /// CancellationToken provided via PrefetchInParallelAsync.
+ ///
+ public CancellationToken CancellationToken { get; private set; }
+
+ public CommonPrefetchState(ITrace prefetchTrace, IEnumerator<IPrefetcher> enumerator, CancellationToken cancellationToken)
+ {
+ this.PrefetchTrace = prefetchTrace;
+ this.enumerator = enumerator;
+ this.CancellationToken = cancellationToken;
+ }
+
+ ///
+ /// Cause FinishedEnumerating to return true.
+ ///
+ public void SetFinishedEnumerating()
+ {
+ Volatile.Write(ref this.enumerator, null);
+ }
+ }
+
+ ///
+ /// State passed when we start a Task with an initial IPrefetcher.
+ ///
+ /// That started Task will obtain its next IPrefetchers using the CommonPrefetchState
+ /// that is also provided.
+ ///
+ private sealed class SinglePrefetchState
+ {
+ ///
+ /// State common to the whole call.
+ ///
+ public CommonPrefetchState CommonState { get; private set; }
+
+ ///
+ /// IPrefetcher which must be invoked next.
+ ///
+ public IPrefetcher CurrentPrefetcher { get; set; }
+
+ public SinglePrefetchState(CommonPrefetchState commonState, IPrefetcher initialPrefetcher)
+ {
+ this.CommonState = commonState;
+ this.CurrentPrefetcher = initialPrefetcher;
+ }
+ }
+
+ public static Task PrefetchInParallelAsync(
+ IEnumerable<IPrefetcher> prefetchers,
+ int maxConcurrency,
+ ITrace trace,
+ CancellationToken cancellationToken)
+ {
+ prefetchers = prefetchers ?? throw new ArgumentNullException(nameof(prefetchers));
+ trace = trace ?? throw new ArgumentNullException(nameof(trace));
+
+ return PrefetchInParallelCoreAsync(prefetchers, maxConcurrency, trace, null, cancellationToken);
+ }
+
+ ///
+ /// Exposed for testing purposes, do not call directly.
+ ///
+ public static Task PrefetchInParallelCoreAsync(
 IEnumerable<IPrefetcher> prefetchers,
 int maxConcurrency,
 ITrace trace,
+ ParallelPrefetchTestConfig config,
 CancellationToken cancellationToken)
 {
- if (prefetchers == null)
+ if (maxConcurrency <= 0)
+ {
+ // old code would just... allocate and then do nothing
+ //
+ // so we do nothing here, for compatibility purposes
+ return Task.CompletedTask;
+ }
+ else if (maxConcurrency == 1)
+ {
+ return SingleConcurrencyPrefetchInParallelAsync(prefetchers, trace, config, cancellationToken);
+ }
+ else if (maxConcurrency <= BatchLimit)
 {
- throw new ArgumentNullException(nameof(prefetchers));
+ return LowConcurrencyPrefetchInParallelAsync(prefetchers, maxConcurrency, trace, config, cancellationToken);
 }
+ else
+ {
+ return HighConcurrencyPrefetchInParallelAsync(prefetchers, maxConcurrency, trace, config, cancellationToken);
+ }
+ }
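+
+ // For orientation, the caller-side shape is just (prefetchers/trace/token supplied by the
+ // query pipeline; this sketch is illustrative, not code from this change):
+ //
+ //     IEnumerable<IPrefetcher> prefetchers = ...;
+ //     await ParallelPrefetch.PrefetchInParallelAsync(prefetchers, maxConcurrency: 64, trace, cancellationToken);
+ //
+ // everything below exists to make that call allocate as little as possible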
+ /// + private static ITrace CommonStartTrace(ITrace trace) + { + return trace.StartChild(name: "Prefetching", TraceComponent.Pagination, TraceLevel.Info); + } - if (trace == null) + /// + /// Helper for grabbing a reusable array. + /// + private static T[] RentArray(ParallelPrefetchTestConfig config, int minSize, bool clear) + { + T[] result; + if (config != null) { - throw new ArgumentNullException(nameof(trace)); +#pragma warning disable IDE0045 // Convert to conditional expression - chained else if is clearer + if (typeof(T) == typeof(IPrefetcher)) + { + result = (T[])(object)config.PrefetcherPool.Rent(minSize); + } + else if (typeof(T) == typeof(Task)) + { + result = (T[])(object)config.TaskPool.Rent(minSize); + } + else + { + result = (T[])(object)config.ObjectPool.Rent(minSize); + } +#pragma warning restore IDE0045 + } + else + { + result = ArrayPool.Shared.Rent(minSize); } - using (ITrace prefetchTrace = trace.StartChild(name: "Prefetching", TraceComponent.Pagination, TraceLevel.Info)) + if (clear) { - HashSet tasks = new HashSet(); - IEnumerator prefetchersEnumerator = prefetchers.GetEnumerator(); - for (int i = 0; i < maxConcurrency; i++) + Array.Clear(result, 0, result.Length); + } + + return result; + } + + /// + /// Helper for returning arrays what were rented via . + /// + private static void ReturnRentedArray(ParallelPrefetchTestConfig config, T[] array, int clearThrough) + { + if (array == null) + { + return; + } + + // this is important, otherwise we might leave Tasks and IPrefetchers + // rooted long enough to cause problems + Array.Clear(array, 0, clearThrough); + + if (config != null) + { + if (typeof(T) == typeof(IPrefetcher)) { - if (!prefetchersEnumerator.MoveNext()) + config.PrefetcherPool.Return((IPrefetcher[])(object)array); + } + else if (typeof(T) == typeof(Task)) + { + config.TaskPool.Return((Task[])(object)array); + } + else + { + config.ObjectPool.Return((object[])(object)array); + } + } + else + { + ArrayPool.Shared.Return(array); + } + } + + /// + /// Starts a new Task that first calls on the passed + /// , and then grabs new ones from and repeats the process + /// until either the enumerator finishes or something sets . + /// + private static Task CommonStartTaskAsync(ParallelPrefetchTestConfig config, CommonPrefetchState commonState, IPrefetcher firstPrefetcher) + { + config?.TaskStarted(); + + SinglePrefetchState state = new (commonState, firstPrefetcher); + + // this is mimicing the behavior of Task.Run(...) (that is, default CancellationToken, default Scheduler, DenyAttachChild, etc.) 
+ // but in a way that let's us pass a context object + // + // this lets us declare a static delegate, and thus let's compiler reuse the delegate allocation + Task taskLoop = + Task.Factory.StartNew( + static async (context) => { - break; - } + // this method is structured a bit oddly to prevent the compiler from putting more data into the + // state of the Task - basically, don't have any locals (except context) that survive across an await + // + // we could go harder here and just not use async/await but that's awful for maintainability + try + { + while (true) + { + // step up to the initial await + { + SinglePrefetchState innerState = (SinglePrefetchState)context; + + CommonPrefetchState innerCommonState = innerState.CommonState; + (ITrace prefetchTrace, CancellationToken cancellationToken) = (innerCommonState.PrefetchTrace, innerCommonState.CancellationToken); + + // we smuggle a test config in as the common ITrace + // + // in most code, this will be null - but this pattern + // let's use keep CommonPrefetchState small + ParallelPrefetchTestConfig config = prefetchTrace as ParallelPrefetchTestConfig; + + config?.TaskStarted(); + config?.TaskAwaited(); + await innerState.CurrentPrefetcher.PrefetchAsync(prefetchTrace, cancellationToken); + } + + // step for preparing the next prefetch + { + SinglePrefetchState innerState = (SinglePrefetchState)context; - IPrefetcher prefetcher = prefetchersEnumerator.Current; - tasks.Add(Task.Run(async () => await prefetcher.PrefetchAsync(prefetchTrace, cancellationToken))); + CommonPrefetchState innerCommonState = innerState.CommonState; + + if (innerCommonState.FinishedEnumerating) + { + // we're done, bail + return; + } + + // proceed to the next item + // + // we need this lock because at this point there + // are other Tasks potentially also looking to call + // enumerator.MoveNext() + lock (innerCommonState) + { + // this can have transitioned to null since we last checked + // so this is basically double-check locking + IEnumerator enumerator = innerCommonState.Enumerator; + if (enumerator == null) + { + return; + } + + if (!enumerator.MoveNext()) + { + // we're done, signal to every other task to also bail + innerCommonState.SetFinishedEnumerating(); + + return; + } + + // move on to the new IPrefetcher just obtained + innerState.CurrentPrefetcher = enumerator.Current; + } + } + } + } + catch + { + SinglePrefetchState innerState = (SinglePrefetchState)context; + + // some error was encountered, we should tell other tasks to stop starting new prefetch tasks + // because we're about to cancel + innerState.CommonState.SetFinishedEnumerating(); + + // percolate the error up + throw; + } + }, + state, + default, + TaskCreationOptions.DenyChildAttach, + TaskScheduler.Default); + + // we _could_ maybe optimize this more... perhaps using a SemaphoreSlim or something + // but that complicates error reporting and is also awful for maintability + Task unwrapped = taskLoop.Unwrap(); + + return unwrapped; + } + + /// + /// Fills a portion of an IPrefetcher[] using the passed enumerator. + /// + /// Returns the index that would next be filled. + /// + /// Updates the passed if the end of the enumerator is reached. + /// + /// Synchronization is the concern of the caller, not this method. 
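+
+ // The StartNew-with-state shape above, reduced to its skeleton (MyState/DoWorkAsync are
+ // hypothetical, purely for illustration):
+ //
+ //     Task<Task> wrapped = Task.Factory.StartNew(
+ //         static async (object state) => await ((MyState)state).DoWorkAsync(),  // static lambda: no closure captured
+ //         myState,
+ //         default,
+ //         TaskCreationOptions.DenyChildAttach,
+ //         TaskScheduler.Default);
+ //     Task unwrapped = wrapped.Unwrap();  // Task.Run would have done this Unwrap for us
+ //
+ // a static lambda lets the compiler cache one delegate for all calls, so each started task
+ // costs a single state object instead of a closure plus a fresh delegate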
+
+ ///
+ /// Fills a portion of an IPrefetcher[] using the passed enumerator.
+ ///
+ /// Returns the index that would next be filled.
+ ///
+ /// Updates the passed CommonPrefetchState if the end of the enumerator is reached.
+ ///
+ /// Synchronization is the concern of the caller, not this method.
+ ///
+ private static int FillPrefetcherBuffer(CommonPrefetchState commonState, IPrefetcher[] prefetchers, int startIndex, int endIndex, IEnumerator<IPrefetcher> enumerator)
+ {
+ int curIndex;
+ for (curIndex = startIndex; curIndex < endIndex; curIndex++)
+ {
+ if (!enumerator.MoveNext())
+ {
+ commonState.SetFinishedEnumerating();
+ break;
 }
- while (tasks.Count != 0)
+ prefetchers[curIndex] = enumerator.Current;
+ }
+
+ return curIndex;
+ }
+
+ ///
+ /// Special case for when maxConcurrency == 1.
+ ///
+ /// This devolves into a foreach loop.
+ ///
+ private static async Task SingleConcurrencyPrefetchInParallelAsync(IEnumerable<IPrefetcher> prefetchers, ITrace trace, ParallelPrefetchTestConfig config, CancellationToken cancellationToken)
+ {
+ using (ITrace prefetchTrace = CommonStartTrace(trace))
+ {
+ foreach (IPrefetcher prefetcher in prefetchers)
 {
- Task completedTask = await Task.WhenAny(tasks);
- tasks.Remove(completedTask);
- try
- {
- await completedTask;
- }
- catch
+ config?.TaskStarted();
+ config?.TaskAwaited();
+ await prefetcher.PrefetchAsync(prefetchTrace, cancellationToken);
+ }
+ }
+ }
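+
+ // Both the per-task loop in CommonStartTaskAsync and the batching path below guard the
+ // shared enumerator with the same volatile pre-check plus double-checked lock; stripped
+ // to a skeleton (names illustrative):
+ //
+ //     if (commonState.FinishedEnumerating) return;      // cheap volatile read, no lock
+ //     lock (commonState)
+ //     {
+ //         if (commonState.FinishedEnumerating) return;  // may have flipped while we waited
+ //         if (!commonState.Enumerator.MoveNext()) { commonState.SetFinishedEnumerating(); return; }
+ //         next = commonState.Enumerator.Current;
+ //     }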
+
+ ///
+ /// The case where maxConcurrency is less than or equal to BatchLimit.
+ ///
+ /// This starts up to maxConcurrency simultaneous Tasks, doing so in a way that
+ /// requires rented arrays of maxConcurrency size.
+ ///
+ private static async Task LowConcurrencyPrefetchInParallelAsync(
+ IEnumerable<IPrefetcher> prefetchers,
+ int maxConcurrency,
+ ITrace trace,
+ ParallelPrefetchTestConfig config,
+ CancellationToken cancellationToken)
+ {
+ IPrefetcher[] initialPrefetchers = null;
+ Task[] runningTasks = null;
+
+ int nextPrefetcherIndex = 0;
+ int nextRunningTaskIndex = 0;
+
+ try
+ {
+ using (ITrace prefetchTrace = CommonStartTrace(trace))
+ {
+ config?.SetInnerTrace(prefetchTrace);
+
+ using (IEnumerator<IPrefetcher> enumerator = prefetchers.GetEnumerator())
 {
- // Observe the remaining tasks
- try
+ if (!enumerator.MoveNext())
 {
- await Task.WhenAll(tasks);
+ // literally nothing to prefetch
+ return;
 }
- catch
+
+ IPrefetcher first = enumerator.Current;
+
+ if (!enumerator.MoveNext())
 {
+ // special case: a single prefetcher... just await it, and skip all the heavy work
+ config?.TaskStarted();
+ config?.TaskAwaited();
+ await first.PrefetchAsync(prefetchTrace, cancellationToken);
+
+ return;
 }
- throw;
+ // need to actually do things to start prefetching in parallel
+ // so grab some state and stash the first two prefetchers off
+
+ initialPrefetchers = RentArray<IPrefetcher>(config, maxConcurrency, clear: false);
+ initialPrefetchers[0] = first;
+ initialPrefetchers[1] = enumerator.Current;
+
+ CommonPrefetchState commonState = new (config ?? prefetchTrace, enumerator, cancellationToken);
+
+ // batch up a bunch of IPrefetchers to kick off
+ //
+ // we do this separately from starting the Tasks so we can avoid a lock
+ // and quickly get to maxConcurrency degrees of parallelism
+ nextPrefetcherIndex = FillPrefetcherBuffer(commonState, initialPrefetchers, 2, maxConcurrency, enumerator);
+
+ // actually start all the tasks, stashing them in a rented Task[]
+ runningTasks = RentArray<Task>(config, nextPrefetcherIndex, clear: false);
+
+ for (nextRunningTaskIndex = 0; nextRunningTaskIndex < nextPrefetcherIndex; nextRunningTaskIndex++)
+ {
+ IPrefetcher toStart = initialPrefetchers[nextRunningTaskIndex];
+ Task startedTask = CommonStartTaskAsync(config, commonState, toStart);
+
+ runningTasks[nextRunningTaskIndex] = startedTask;
+ }
+
+ // hand the prefetcher array back early, so other callers can use it
+ ReturnRentedArray(config, initialPrefetchers, nextPrefetcherIndex);
+ initialPrefetchers = null;
+
+ // now await all Tasks in turn
+ for (int toAwaitTaskIndex = 0; toAwaitTaskIndex < nextRunningTaskIndex; toAwaitTaskIndex++)
+ {
+ Task toAwait = runningTasks[toAwaitTaskIndex];
+
+ try
+ {
+ config?.TaskAwaited();
+ await toAwait;
+ }
+ catch
+ {
+ // if we encountered some exception, tell the remaining tasks to bail
+ // the next time they check commonState
+ commonState.SetFinishedEnumerating();
+
+ // we still need to observe all the tasks we haven't yet to avoid an UnobservedTaskException
+ for (int awaitAndIgnoreTaskIndex = toAwaitTaskIndex + 1; awaitAndIgnoreTaskIndex < nextRunningTaskIndex; awaitAndIgnoreTaskIndex++)
+ {
+ try
+ {
+ config?.TaskAwaited();
+ await runningTasks[awaitAndIgnoreTaskIndex];
+ }
+ catch
+ {
+ // intentionally left empty, we swallow all errors after the first
+ }
+ }
+
+ throw;
+ }
+ }
+ }
 }
- if (prefetchersEnumerator.MoveNext())
 {
- IPrefetcher bufferable = prefetchersEnumerator.Current;
- tasks.Add(Task.Run(async () => await bufferable.PrefetchAsync(prefetchTrace, cancellationToken)));
 }
 }
 }
+ finally
+ {
+ ReturnRentedArray(config, initialPrefetchers, nextPrefetcherIndex);
+ ReturnRentedArray(config, runningTasks, nextRunningTaskIndex);
+ }
+ }
+
+ ///
+ /// The case where maxConcurrency is greater than BatchLimit.
+ ///
+ /// This starts up to maxConcurrency simultaneous Tasks, doing so in batches
+ /// of BatchLimit (or less) size. Active Tasks are tracked in a pseudo-linked-list
+ /// over rented object[].
+ ///
+ /// This is more complicated, less likely to hit maxConcurrency degrees of
+ /// parallelism, and less allocation efficient when compared to LowConcurrencyPrefetchInParallelAsync.
+ ///
+ /// However, it doesn't allocate gigantic arrays and doesn't wait for full enumeration
+ /// before starting to prefetch.
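+ ///
+ /// Concretely, the chunk list looks like this (BatchLimit shown as 4 for brevity):
+ ///
+ ///     object[4]: [ Task, Task, Task, next ] --> object[4]: [ Task, Task, null, null ]
+ ///
+ /// Tasks fill slots 0 through Length - 2, the final slot holds the next chunk (or null),
+ /// and the first null in either position ends the walk.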
+ /// + private static async Task HighConcurrencyPrefetchInParallelAsync( + IEnumerable prefetchers, + int maxConcurrency, + ITrace trace, + ParallelPrefetchTestConfig config, + CancellationToken cancellationToken) + { + IPrefetcher[] currentBatch = null; + + // this ends up holding a sort of linked list where + // each entry is actually a Task until the very last one + // which is an object[] + // + // as soon as a null is encountered, either where a Task or + // an object[] is expected, the linked list is done + object[] runningTasks = null; + + try + { + using (ITrace prefetchTrace = CommonStartTrace(trace)) + { + config?.SetInnerTrace(prefetchTrace); + + using (IEnumerator enumerator = prefetchers.GetEnumerator()) { - IPrefetcher bufferable = prefetchersEnumerator.Current; - tasks.Add(Task.Run(async () => await bufferable.PrefetchAsync(prefetchTrace, cancellationToken))); + if (!enumerator.MoveNext()) + { + // no prefetchers at all + return; + } + + IPrefetcher first = enumerator.Current; + + if (!enumerator.MoveNext()) + { + // special case: a single prefetcher... just await it, and skip all the heavy work + config?.TaskStarted(); + config?.TaskAwaited(); + await first.PrefetchAsync(prefetchTrace, cancellationToken); + return; + } + + // need to actually do things to start prefetching in parallel + // so grab some state and stash the first two prefetchers off + + currentBatch = RentArray(config, BatchLimit, clear: false); + currentBatch[0] = first; + currentBatch[1] = enumerator.Current; + + // we need this all null because we use null as a stopping condition later + runningTasks = RentArray(config, BatchLimit, clear: true); + + CommonPrefetchState commonState = new (config ?? prefetchTrace, enumerator, cancellationToken); + + // what we do here is buffer up to BatchLimit IPrefetchers to start + // and then... start them all + // + // we stagger this so we quickly get a bunch of tasks started without spending too + // much time pre-loading everything + + // grab our first bunch of prefetchers outside of the lock + // + // we know that maxConcurrency > BatchLimit, so can just pass it as our cutoff here + int bufferedPrefetchers = FillPrefetcherBuffer(commonState, currentBatch, 2, BatchLimit, enumerator); + + int nextChunkIndex = 0; + object[] currentChunk = runningTasks; + + int remainingConcurrency = maxConcurrency; + + // if we encounter any error, we remember it + // but as soon as we start a single task we've got + // to see most of this code through so we observe them + ExceptionDispatchInfo capturedException = null; + + while (true) + { + // start and store the last set of Tasks we got from FillPrefetcherBuffer + for (int toStartIndex = 0; toStartIndex < bufferedPrefetchers; toStartIndex++) + { + IPrefetcher prefetcher = currentBatch[toStartIndex]; + Task startedTask = CommonStartTaskAsync(config, commonState, prefetcher); + + currentChunk[nextChunkIndex] = startedTask; + nextChunkIndex++; + + // check if we need a new slab to store tasks + if (nextChunkIndex == currentChunk.Length - 1) + { + // we need this all null because we use null as a stopping condition later + object[] newChunk = RentArray(config, BatchLimit, clear: true); + + currentChunk[currentChunk.Length - 1] = newChunk; + + currentChunk = newChunk; + nextChunkIndex = 0; + } + } + + remainingConcurrency -= bufferedPrefetchers; + + // check to see if we've started all the concurrent Tasks we can + if (remainingConcurrency == 0) + { + break; + } + + int nextBatchSizeLimit = remainingConcurrency < BatchLimit ? 
remainingConcurrency : BatchLimit; + + // if one of the previously started Tasks exhausted the enumerator + // we're done, even if we still have space + if (commonState.FinishedEnumerating) + { + break; + } + + // now that Tasks have started, we MUST synchronize access to + // the enumerator + lock (commonState) + { + // the answer might have changed, so we double-check + // this once we've got the lock + if (commonState.FinishedEnumerating) + { + break; + } + + // grab the next set of prefetchers to start + try + { + bufferedPrefetchers = FillPrefetcherBuffer(commonState, currentBatch, 0, nextBatchSizeLimit, enumerator); + } + catch (Exception exc) + { + // this can get raised if the enumerator faults + // + // in this case we might have some tasks started, and so we need to _stop_ starting new tasks but + // still move on to observing everything we've already started + + commonState.SetFinishedEnumerating(); + capturedException = ExceptionDispatchInfo.Capture(exc); + + break; + } + } + + // if we got nothing back, we can break right here + if (bufferedPrefetchers == 0) + { + break; + } + } + + // hand the prefetch array back, we're done with it + ReturnRentedArray(config, currentBatch, BatchLimit); + currentBatch = null; + + // now wait for all the tasks to complete + // + // we walk through all of them, even if we encounter an error + // because we need to walk the whole linked-list and this is + // simpler than an explicit error code path + + int toAwaitIndex = 0; + while (runningTasks != null) + { + Task toAwait = (Task)runningTasks[toAwaitIndex]; + + // if we see a null, we're done + if (toAwait == null) + { + // hand the last of the arrays back + ReturnRentedArray(config, runningTasks, toAwaitIndex); + runningTasks = null; + + break; + } + + try + { + config?.TaskAwaited(); + await toAwait; + } + catch (Exception ex) + { + if (capturedException == null) + { + // if we encountered some exception, tell the remaining tasks to bail + // the next time they check commonState + commonState.SetFinishedEnumerating(); + + // save the exception so we can rethrow it later + capturedException = ExceptionDispatchInfo.Capture(ex); + } + } + + // advance, moving to the next chunk if we've hit that limit + toAwaitIndex++; + + if (toAwaitIndex == runningTasks.Length - 1) + { + object[] oldChunk = runningTasks; + + runningTasks = (object[])runningTasks[runningTasks.Length - 1]; + toAwaitIndex = 0; + + // we're done with this, let some other caller reuse it immediately + ReturnRentedArray(config, oldChunk, oldChunk.Length); + } + } + + // fault, if any task failed, after we've finished cleaning up + capturedException?.Throw(); } } } + finally + { + // cleanup if something went wrong while these were still rented + // + // this can basically only happen if the enumerator itself faults + // which is unlikely, but far from impossible + + ReturnRentedArray(config, currentBatch, BatchLimit); + + while (runningTasks != null) + { + object[] oldChunk = runningTasks; + + runningTasks = (object[])runningTasks[runningTasks.Length - 1]; + + ReturnRentedArray(config, oldChunk, oldChunk.Length); + } + } } } } diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/EndToEndTraceWriterBaselineTests.QueryAsync.xml b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/EndToEndTraceWriterBaselineTests.QueryAsync.xml index 7577f2ae4f..460eaa593c 100644 --- 
a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/EndToEndTraceWriterBaselineTests.QueryAsync.xml +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/EndToEndTraceWriterBaselineTests.QueryAsync.xml @@ -41,7 +41,6 @@ │ ├── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ └── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds - │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds @@ -246,10 +245,6 @@ "name": "MoveNextAsync", "duration in milliseconds": 0, "children": [ - { - "name": "Prefetching", - "duration in milliseconds": 0 - }, { "name": "[,05C1CFFFFFFFF8) move next", "duration in milliseconds": 0, @@ -771,7 +766,6 @@ │ ├── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ ├── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds - │ │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ │ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds @@ -980,10 +974,6 @@ "name": "MoveNextAsync", "duration in milliseconds": 0, "children": [ - { - "name": "Prefetching", - "duration in milliseconds": 0 - }, { "name": "[,05C1CFFFFFFFF8) move next", "duration in milliseconds": 0, @@ -1522,7 +1512,6 @@ │ ├── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ └── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds - │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds @@ -1727,10 +1716,6 @@ "name": "MoveNextAsync", "duration in milliseconds": 0, "children": [ - { - "name": "Prefetching", - "duration in milliseconds": 0 - }, { "name": "[,05C1CFFFFFFFF8) move next", "duration in milliseconds": 0, @@ -2253,7 +2238,6 @@ │ ├── Get Partition Key 
Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ ├── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds - │ │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ │ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds @@ -2462,10 +2446,6 @@ "name": "MoveNextAsync", "duration in milliseconds": 0, "children": [ - { - "name": "Prefetching", - "duration in milliseconds": 0 - }, { "name": "[,05C1CFFFFFFFF8) move next", "duration in milliseconds": 0, @@ -3028,7 +3008,6 @@ │ ├── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ ├── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds - │ │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ │ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds @@ -3287,10 +3266,6 @@ "name": "MoveNextAsync", "duration in milliseconds": 0, "children": [ - { - "name": "Prefetching", - "duration in milliseconds": 0 - }, { "name": "[,05C1CFFFFFFFF8) move next", "duration in milliseconds": 0, @@ -3834,7 +3809,6 @@ │ │ └── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ └── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds - │ ├── Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds @@ -4057,10 +4031,6 @@ "name": "MoveNextAsync", "duration in milliseconds": 0, "children": [ - { - "name": "Prefetching", - "duration in milliseconds": 0 - }, { "name": "[,05C1CFFFFFFFF8) move next", "duration in milliseconds": 0, @@ -4588,7 +4558,6 @@ │ │ └── Get Partition Key Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ │ └── Try Get Overlapping Ranges(00000000-0000-0000-0000-000000000000) Routing-Component 00:00:00:000 0.00 milliseconds │ ├── MoveNextAsync(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds - │ │ ├── 
Prefetching(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ │ └── Prefetch(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds │ │ └── [,05C1CFFFFFFFF8) move next(00000000-0000-0000-0000-000000000000) Pagination-Component 00:00:00:000 0.00 milliseconds @@ -4815,10 +4784,6 @@ "name": "MoveNextAsync", "duration in milliseconds": 0, "children": [ - { - "name": "Prefetching", - "duration in milliseconds": 0 - }, { "name": "[,05C1CFFFFFFFF8) move next", "duration in milliseconds": 0, diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Pagination/ParallelPrefetchTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Pagination/ParallelPrefetchTests.cs new file mode 100644 index 0000000000..84511c8337 --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Pagination/ParallelPrefetchTests.cs @@ -0,0 +1,1255 @@ +namespace Microsoft.Azure.Cosmos.Tests.Pagination +{ + using System; + using System.Buffers; + using System.Collections; + using System.Collections.Concurrent; + using System.Collections.Generic; + using System.Diagnostics; + using System.Linq; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Pagination; + using Microsoft.Azure.Cosmos.Tracing; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class ParallelPrefetchTests + { + /// + /// IPrefetcher which can only be run once, and invokes callbacks as it executes. + /// + private sealed class TestOncePrefetcher : IPrefetcher + { + private int hasRun; + + private readonly Action beforeAwait; + private readonly Action afterAwait; + + internal TestOncePrefetcher(Action beforeAwait, Action afterAwait) + { + this.hasRun = 0; + this.beforeAwait = beforeAwait; + this.afterAwait = afterAwait; + } + + public async ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken) + { + // test that ParallelPrefetch doesn't start the same Task twice. + int oldRun = Interlocked.Exchange(ref this.hasRun, 1); + Assert.AreEqual(0, oldRun); + + // we use two callbacks to test that ParallelPrefetch is correctly monitoring + // continuations - without this, we might incorrectly consider a Task completed + // despite it awaiting an inner Task + + this.beforeAwait(); + + await Task.Yield(); + cancellationToken.ThrowIfCancellationRequested(); + + this.afterAwait(); + cancellationToken.ThrowIfCancellationRequested(); + } + } + + /// + /// IPrefetcher that does complicated things. 
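+ ///
+ /// (Its await graph deliberately spans a yield, timer delays, a semaphore, and a WhenAll,
+ /// so "fully completed" can't be faked by only observing the prefetcher's first await.)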
+ /// + private sealed class ComplicatedPrefetcher : IPrefetcher + { + public long StartTimestamp { get; private set; } + public long AfterYieldTimestamp { get; private set; } + public long AfterDelay1Timestamp { get; private set; } + public long AfterSemaphoreTimestamp { get; private set; } + public long AfterDelay2Timestamp { get; private set; } + public long AfterDelay3Timestamp { get; private set; } + public long AfterDelay4Timestamp { get; private set; } + public long WhenAllTimestamp { get; private set; } + public long EndTimestamp { get; private set; } + + public async ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken) + { + this.StartTimestamp = Stopwatch.GetTimestamp(); + + await Task.Yield(); + + this.AfterYieldTimestamp = Stopwatch.GetTimestamp(); + + using (SemaphoreSlim semaphore = new(0, 1)) + { + Task delay = Task.Delay(5, cancellationToken).ContinueWith(_ => { this.AfterDelay1Timestamp = Stopwatch.GetTimestamp(); semaphore.Release(); }, cancellationToken); + + await semaphore.WaitAsync(cancellationToken); + this.AfterSemaphoreTimestamp = Stopwatch.GetTimestamp(); + + await delay; + } + + await Task.WhenAll( + Task.Delay(2, cancellationToken).ContinueWith(_ => this.AfterDelay2Timestamp = Stopwatch.GetTimestamp(), cancellationToken), + Task.Delay(3, cancellationToken).ContinueWith(_ => this.AfterDelay3Timestamp = Stopwatch.GetTimestamp(), cancellationToken), + Task.Delay(4, cancellationToken).ContinueWith(_ => this.AfterDelay4Timestamp = Stopwatch.GetTimestamp(), cancellationToken)); + this.WhenAllTimestamp = Stopwatch.GetTimestamp(); + + await Task.Yield(); + + this.EndTimestamp = Stopwatch.GetTimestamp(); + } + + internal void AssertCorrect() + { + Assert.IsTrue(this.StartTimestamp > 0); + + Assert.IsTrue(this.AfterYieldTimestamp > this.StartTimestamp); + Assert.IsTrue(this.AfterDelay1Timestamp > this.AfterYieldTimestamp); + Assert.IsTrue(this.AfterSemaphoreTimestamp > this.AfterDelay1Timestamp); + + // these can all fire in any order (delay doesn't guarantee any particular order) + Assert.IsTrue(this.AfterDelay2Timestamp > this.AfterSemaphoreTimestamp); + Assert.IsTrue(this.AfterDelay3Timestamp > this.AfterSemaphoreTimestamp); + Assert.IsTrue(this.AfterDelay4Timestamp > this.AfterSemaphoreTimestamp); + + // but by WhenAll()'ing them, we can assert WhenAll completes after all the other delays + Assert.IsTrue(this.WhenAllTimestamp > this.AfterDelay2Timestamp); + Assert.IsTrue(this.WhenAllTimestamp > this.AfterDelay3Timestamp); + Assert.IsTrue(this.WhenAllTimestamp > this.AfterDelay4Timestamp); + + Assert.IsTrue(this.EndTimestamp > this.WhenAllTimestamp); + } + } + + /// + /// IPrefetcher that asserts it got a trace with an expected parent. + /// + private sealed class ExpectedParentTracePrefetcher : IPrefetcher + { + private readonly ITrace expectedParentTrace; + + internal ExpectedParentTracePrefetcher(ITrace expectedParentTrace) + { + this.expectedParentTrace = expectedParentTrace; + } + + public ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken) + { + Assert.AreSame(this.expectedParentTrace, trace.Parent); + + return default; + } + } + + /// + /// IEnumerable which throws if touched. + /// + private sealed class ThrowsEnumerable : IEnumerable + { + public IEnumerator GetEnumerator() + { + throw new NotSupportedException(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + } + + /// + /// IEnumerable whose IEnumerator throws if access concurrently. 
+ ///
+ private sealed class NonConcurrentAssertingEnumerable<T> : IEnumerable<T>
+ {
+ private sealed class Enumerator : IEnumerator<T>
+ {
+ private readonly T[] inner;
+ private int index;
+
+ private int active;
+
+ internal Enumerator(T[] inner)
+ {
+ this.inner = inner;
+ }
+
+ public T Current { get; private set; }
+
+ object IEnumerator.Current => this.Current;
+
+ public void Dispose()
+ {
+ }
+
+ public bool MoveNext()
+ {
+ int isActive = Interlocked.Exchange(ref this.active, 1);
+ Assert.AreEqual(0, isActive, "Modified concurrently");
+
+ try
+ {
+ if (this.index < this.inner.Length)
+ {
+ this.Current = this.inner[this.index];
+ this.index++;
+ return true;
+ }
+
+ return false;
+ }
+ finally
+ {
+ int wasActive = Interlocked.Exchange(ref this.active, 0);
+ Assert.AreEqual(1, wasActive, "Modified concurrently");
+ }
+ }
+
+ public void Reset()
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ private readonly T[] inner;
+
+ internal NonConcurrentAssertingEnumerable(T[] inner)
+ {
+ this.inner = inner;
+ }
+
+ public IEnumerator<T> GetEnumerator()
+ {
+ return new Enumerator(this.inner);
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return this.GetEnumerator();
+ }
+ }
+
+ ///
+ /// IEnumerable whose IEnumerator tracks how many times it has been disposed.
+ ///
+ private sealed class DisposeTrackingEnumerable<T> : IEnumerable<T>
+ {
+ private sealed class Enumerator : IEnumerator<T>
+ {
+ private readonly DisposeTrackingEnumerable<T> outer;
+
+ private int index;
+
+ internal Enumerator(DisposeTrackingEnumerable<T> outer)
+ {
+ this.outer = outer;
+ }
+
+ public T Current { get; private set; }
+
+ object IEnumerator.Current => this.Current;
+
+ public void Dispose()
+ {
+ this.outer.DisposeCalls++;
+ }
+
+ public bool MoveNext()
+ {
+ if (this.index < this.outer.inner.Length)
+ {
+ this.Current = this.outer.inner[this.index];
+ this.index++;
+ return true;
+ }
+
+ return false;
+ }
+
+ public void Reset()
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ private readonly T[] inner;
+
+ internal DisposeTrackingEnumerable(T[] inner)
+ {
+ this.inner = inner;
+ }
+
+ internal int DisposeCalls { get; private set; }
+
+ public IEnumerator<T> GetEnumerator()
+ {
+ return new Enumerator(this);
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return this.GetEnumerator();
+ }
+ }
+
+ ///
+ /// IEnumerable whose IEnumerator throws after a certain number of
+ /// calls to MoveNext().
+ /// + private sealed class ThrowsAfterEnumerable : IEnumerable + { + private sealed class Enumerator : IEnumerator + { + private readonly IEnumerator inner; + private readonly int throwAfter; + private int callNumber; + + public T Current { get; set; } + + object IEnumerator.Current => this.Current; + + + internal Enumerator(IEnumerator inner, int throwAfter) + { + this.inner = inner; + this.throwAfter = throwAfter; + } + public void Dispose() + { + this.inner.Dispose(); + } + + public bool MoveNext() + { + if (this.callNumber >= this.throwAfter) + { + throw new InvalidOperationException(); + } + + this.callNumber++; + + if (this.inner.MoveNext()) + { + this.Current = this.inner.Current; + return true; + } + + this.Current = default; + return false; + } + + public void Reset() + { + this.inner.Reset(); + } + } + + private readonly IEnumerable inner; + private readonly int throwAfter; + + public ThrowsAfterEnumerable(IEnumerable inner, int throwAfter) + { + this.inner = inner; + this.throwAfter = throwAfter; + } + + public IEnumerator GetEnumerator() + { + return new Enumerator(this.inner.GetEnumerator(), this.throwAfter); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + } + + /// + /// IPrefetcher which throws if touched. + /// + private sealed class ThrowsPrefetcher : IPrefetcher + { + public ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + } + + /// + /// ArrayPool that tracks leaks, double returns, and includes non-null values in + /// returned arrays. + /// + private sealed class ValidatingRandomizedArrayPool : ArrayPool + where T : class + { + private readonly T existingValue; + + private readonly ConcurrentBag created; + private readonly ConcurrentDictionary rented; + + internal ValidatingRandomizedArrayPool(T existingValue) + { + this.existingValue = existingValue; + this.created = new(); + this.rented = new(); + } + + public override T[] Rent(int minimumLength) + { + int extra = Random.Shared.Next(6); + + T[] ret = new T[minimumLength + extra]; + for (int i = 0; i < ret.Length; i++) + { + ret[i] = Random.Shared.Next(2) == 0 ? this.existingValue : null; + } + + this.created.Add(ret); + + Assert.IsTrue(this.rented.TryAdd(ret, null)); + + return ret; + } + + public override void Return(T[] array, bool clearArray = false) + { + Assert.IsFalse(clearArray, "Caller should clean up array itself"); + + Assert.IsTrue(this.rented.TryRemove(array, out _), "Tried to return array that isn't rented"); + + for (int i = 0; i < array.Length; i++) + { + object value = array[i]; + + if (object.ReferenceEquals(value, this.existingValue)) + { + continue; + } + + Assert.IsNull(value, "Returned array shouldn't have any non-null values, except those included by the original Rent call"); + } + } + + internal void AssertAllReturned() + { + Assert.IsTrue(this.rented.IsEmpty); + } + } + + /// + /// ITrace which only traces children and parents. 
+ /// + private sealed class SimpleTrace : ITrace + { + public string Name { get; private set; } + + public Guid Id { get; } = Guid.NewGuid(); + + public DateTime StartTime { get; } = DateTime.UtcNow; + + public TimeSpan Duration => DateTime.UtcNow - this.StartTime; + + public Cosmos.Tracing.TraceLevel Level { get; private set; } + + public TraceComponent Component { get; private set; } + + public TraceSummary Summary => new(); + + public ITrace Parent { get; private set; } + + public IReadOnlyList Children { get; } = new List(); + + public IReadOnlyDictionary Data { get; } = new Dictionary(); + + internal SimpleTrace(ITrace parent, string name, TraceComponent component, Cosmos.Tracing.TraceLevel level) + { + this.Parent = parent; + this.Name = name; + this.Component = component; + this.Level = level; + } + + public void AddChild(ITrace trace) + { + List children = (List)this.Children; + lock (children) + { + children.Add(trace); + } + } + + public void AddDatum(string key, TraceDatum traceDatum) + { + } + + public void AddDatum(string key, object value) + { + } + + public void AddOrUpdateDatum(string key, object value) + { + } + + public void Dispose() + { + } + + public ITrace StartChild(string name) + { + return this.StartChild(name, TraceComponent.Unknown, Cosmos.Tracing.TraceLevel.Off); + } + + public ITrace StartChild(string name, TraceComponent component, Cosmos.Tracing.TraceLevel level) + { + ITrace child = new SimpleTrace(this, name, component, level); + + List children = (List)this.Children; + lock (children) + { + children.Add(child); + } + + return child; + } + } + + /// + /// Different task counts which explore different code paths. + /// + private static readonly int[] TaskCounts = new[] { 0, 1, 2, 511, 512, 513, 1024, 1025 }; + + /// + /// Different max concurrencies which explore different code paths. 
+ ///
+ private static readonly int[] Concurrencies = new[] { 1, 2, 511, 512, 513, int.MaxValue };
+
+ private static readonly ITrace EmptyTrace = NoOpTrace.Singleton;
+
+ [TestMethod]
+ public async Task ParameterValidationAsync()
+ {
+ // test contract for parameters
+
+ ArgumentNullException prefetchersArg = Assert.ThrowsException<ArgumentNullException>(
+ () =>
+ ParallelPrefetch.PrefetchInParallelAsync(
+ null,
+ 123,
+ EmptyTrace,
+ default));
+ Assert.AreEqual("prefetchers", prefetchersArg.ParamName);
+
+ ArgumentNullException traceArg = Assert.ThrowsException<ArgumentNullException>(
+ () =>
+ ParallelPrefetch.PrefetchInParallelAsync(
+ Array.Empty<IPrefetcher>(),
+ 123,
+ null,
+ default));
+ Assert.AreEqual("trace", traceArg.ParamName);
+
+ // maxConcurrency can be < 0; check that it doesn't throw
+ await ParallelPrefetch.PrefetchInParallelAsync(Array.Empty<IPrefetcher>(), -123, EmptyTrace, default);
+ }
+
+ [TestMethod]
+ public async Task ZeroConcurrencyOptimizationAsync()
+ {
+ // test that we correctly special case maxConcurrency == 0 as "do nothing"
+
+ IEnumerable<IPrefetcher> prefetchers = new ThrowsEnumerable<IPrefetcher>();
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ prefetchers,
+ 0,
+ EmptyTrace,
+ default);
+ }
+
+ [TestMethod]
+ public async Task AllExecutedAsync()
+ {
+ // test that all prefetchers are actually invoked
+
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ foreach (int taskCount in TaskCounts)
+ {
+ int executed1 = 0;
+ int executed2 = 0;
+ IEnumerable<IPrefetcher> prefetchers = CreatePrefetchers(taskCount, () => Interlocked.Increment(ref executed1), () => Interlocked.Increment(ref executed2));
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ default);
+
+ Assert.AreEqual(taskCount, executed1);
+ Assert.AreEqual(taskCount, executed2);
+ }
+ }
+
+ static IEnumerable<IPrefetcher> CreatePrefetchers(int count, Action beforeAwait, Action afterAwait)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ yield return new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+ }
+ }
+
+ [TestMethod]
+ public async Task EnumeratorNotConcurrentlyAccessedAsync()
+ {
+ // test that the IEnumerator is only accessed by one thread at a time
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ foreach (int taskCount in TaskCounts)
+ {
+ IEnumerable<IPrefetcher> prefetchers = CreatePrefetchers(taskCount, static () => { }, static () => { });
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ default);
+ }
+ }
+
+ static IEnumerable<IPrefetcher> CreatePrefetchers(int count, Action beforeAwait, Action afterAwait)
+ {
+ IPrefetcher[] inner = new IPrefetcher[count];
+ for (int i = 0; i < count; i++)
+ {
+ inner[i] = new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+
+ return new NonConcurrentAssertingEnumerable<IPrefetcher>(inner);
+ }
+ }
+
+ [TestMethod]
+ public async Task EnumeratorDisposedAsync()
+ {
+ // test that the IEnumerator is disposed exactly once
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ foreach (int taskCount in TaskCounts)
+ {
+ DisposeTrackingEnumerable<IPrefetcher> prefetchers = CreatePrefetchers(taskCount, static () => { }, static () => { });
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ default);
+
+ Assert.AreEqual(1, prefetchers.DisposeCalls);
+ }
+ }
+
+ static DisposeTrackingEnumerable<IPrefetcher> CreatePrefetchers(int count, Action beforeAwait, Action afterAwait)
+ {
+ IPrefetcher[] inner = new IPrefetcher[count];
+ for (int i = 0; i < count; i++)
+ {
+ inner[i] = new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+
+ return new DisposeTrackingEnumerable<IPrefetcher>(inner);
+ }
+ }
+
+ [TestMethod]
+ public async Task ComplicatedPrefetcherAsync()
+ {
+ // test that a complicated prefetcher is fully started and completed
+ //
+ // the rest of the tests don't use a completely trivial
+ // IPrefetcher, but they are substantially simpler
+
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ ComplicatedPrefetcher prefetcher = new();
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ new IPrefetcher[] { prefetcher },
+ maxConcurrency,
+ EmptyTrace,
+ default);
+
+ prefetcher.AssertCorrect();
+ }
+ }
+
+ [TestMethod]
+ public async Task MaxConcurrencyRespectedAsync()
+ {
+ // test that we never get above maxConcurrency
+ //
+ // whether or not we _reach_ it is dependent on the scheduler
+ // so we can't reliably test that
+
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ foreach (int taskCount in TaskCounts)
+ {
+ int observedMax = 0;
+ int current = 0;
+
+ IEnumerable<IPrefetcher> prefetchers =
+ CreatePrefetchers(
+ taskCount,
+ () =>
+ {
+ int newCurrent = Interlocked.Increment(ref current);
+ Assert.IsTrue(newCurrent <= maxConcurrency);
+
+ int oldMax = Volatile.Read(ref observedMax);
+
+ while (newCurrent > oldMax)
+ {
+ oldMax = Interlocked.CompareExchange(ref observedMax, newCurrent, oldMax);
+ }
+ },
+ () =>
+ {
+ int newCurrent = Interlocked.Decrement(ref current);
+
+ Assert.IsTrue(newCurrent >= 0);
+ });
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ default);
+
+ Assert.IsTrue(Volatile.Read(ref observedMax) <= maxConcurrency);
+ Assert.AreEqual(0, Volatile.Read(ref current));
+ }
+ }
+
+ static IEnumerable<IPrefetcher> CreatePrefetchers(int count, Action beforeAwait, Action afterAwait)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ yield return new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+ }
+ }
+
+ [TestMethod]
+ public async Task TraceCorrectlyPassedAsync()
+ {
+ // test that we make ONE ITrace per invocation
+ // and that the returned child trace is correctly
+ // passed to all IPrefetchers
+
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ foreach (int taskCount in TaskCounts)
+ {
+ using ITrace simpleTrace = new SimpleTrace(null, "Root", TraceComponent.Batch, Cosmos.Tracing.TraceLevel.Off);
+
+ IEnumerable<IPrefetcher> prefetchers = CreatePrefetchers(taskCount, simpleTrace);
+
+ await ParallelPrefetch.PrefetchInParallelAsync(
+ prefetchers,
+ maxConcurrency,
+ simpleTrace,
+ default);
+
+ // our prefetchers don't create any children, but we expect one
+ // to be created by ParallelPrefetch
+ Assert.AreEqual(1, simpleTrace.Children.Count);
+ Assert.AreEqual(0, simpleTrace.Children[0].Children.Count);
+
+ // the one trace we start has a well known set of attributes, so check them
+ Assert.AreEqual("Prefetching", simpleTrace.Children[0].Name);
+ Assert.AreEqual(TraceComponent.Pagination, simpleTrace.Children[0].Component);
+ Assert.AreEqual(Cosmos.Tracing.TraceLevel.Info, simpleTrace.Children[0].Level);
+ }
+ }
+
+ static IEnumerable<IPrefetcher> CreatePrefetchers(int count, ITrace expectedParentTrace)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ yield return new ExpectedParentTracePrefetcher(expectedParentTrace);
+ }
+ }
+ }
+
+ [TestMethod]
+ public async Task RentedBuffersAllReturnedAsync()
+ {
+ // test that all rented buffers are correctly returned
+ // (and in the expected state)
+
+ Task faultedTask = Task.FromException(new NotSupportedException());
+
+ try
+ {
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ foreach (int taskCount in TaskCounts)
+ {
+ IEnumerable<IPrefetcher> prefetchers =
CreatePrefetchers(taskCount, static () => { }, static () => { }); + + ValidatingRandomizedArrayPool prefetcherPool = new(new ThrowsPrefetcher()); + ValidatingRandomizedArrayPool taskPool = new(faultedTask); + ValidatingRandomizedArrayPool objectPool = new("unexpected value"); + + ParallelPrefetch.ParallelPrefetchTestConfig config = + new( + prefetcherPool, + taskPool, + objectPool + ); + + await ParallelPrefetch.PrefetchInParallelCoreAsync( + prefetchers, + maxConcurrency, + EmptyTrace, + config, + default); + + Assert.AreEqual(config.StartedTasks, config.AwaitedTasks, $"maxConcurrency={maxConcurrency}, taskCount={taskCount}; some tasks left unawaited"); + + prefetcherPool.AssertAllReturned(); + taskPool.AssertAllReturned(); + objectPool.AssertAllReturned(); + } + } + } + finally + { + // observe this intentionally faulted task, no matter what + try + { + await faultedTask; + } + catch + { + // intentionally empty + } + } + + static IEnumerable CreatePrefetchers(int count, Action beforeAwait, Action afterAwait) + { + for (int i = 0; i < count; i++) + { + yield return new TestOncePrefetcher(beforeAwait, afterAwait); + } + } + } + + [TestMethod] + public async Task TaskSingleExceptionHandledAsync() + { + // test that raising exceptions during processing tasks + // doesn't leak or otherwise fail + + Task faultedTask = Task.FromException(new NotSupportedException()); + + try + { + foreach (int maxConcurrency in Concurrencies) + { + if (maxConcurrency <= 1) + { + // we won't do anything fancy, so skip + continue; + } + + foreach (int taskCount in TaskCounts) + { + if (taskCount <= 1) + { + // we won't do anything fancy, so skip + continue; + } + + for (int faultOnTask = 0; faultOnTask < taskCount; faultOnTask++) + { + IEnumerable prefetchers = CreatePrefetchers(taskCount, faultOnTask, static () => { }, static () => { }); + + ValidatingRandomizedArrayPool prefetcherPool = new(new ThrowsPrefetcher()); + ValidatingRandomizedArrayPool taskPool = new(faultedTask); + ValidatingRandomizedArrayPool objectPool = new("unexpected value"); + + ParallelPrefetch.ParallelPrefetchTestConfig config = + new( + prefetcherPool, + taskPool, + objectPool + ); + + Exception caught = null; + try + { + await ParallelPrefetch.PrefetchInParallelCoreAsync( + prefetchers, + maxConcurrency, + EmptyTrace, + config, + default); + } + catch (Exception e) + { + caught = e; + } + + Assert.IsNotNull(caught, $"concurrency={maxConcurrency}, tasks={taskCount}, faultOn={faultOnTask} - didn't produce exception as expected"); + + Assert.AreEqual(config.StartedTasks, config.AwaitedTasks, $"maxConcurrency={maxConcurrency}, taskCount={taskCount}, faultOnTask={faultedTask}; some tasks left unawaited"); + + // buffer management can't break in the face of errors, so check here too + prefetcherPool.AssertAllReturned(); + taskPool.AssertAllReturned(); + objectPool.AssertAllReturned(); + } + } + } + } + finally + { + // observe this intentionally faulted task, no matter what + try + { + await faultedTask; + } + catch + { + // intentionally empty + } + } + + static IEnumerable CreatePrefetchers(int count, int faultOnTask, Action beforeAwait, Action afterAwait) + { + for (int i = 0; i < count; i++) + { + if (faultOnTask == i) + { + yield return new ThrowsPrefetcher(); + } + else + { + yield return new TestOncePrefetcher(beforeAwait, afterAwait); + } + } + } + } + + [TestMethod] + public async Task TaskMultipleExceptionsHandledAsync() + { + // we throw a lot of exceptions in this test, which is expensive + // so we only probe proportionally to 
+ [TestMethod]
+ public async Task TaskMultipleExceptionsHandledAsync()
+ {
+ // we throw a lot of exceptions in this test, which is expensive,
+ // so for expediency's sake we only probe at a number of points
+ // proportional to this constant
+ const int StepRatio = 10;
+
+ // test that raising exceptions during processing tasks
+ // doesn't leak or otherwise fail
+
+ Task faultedTask = Task.FromException(new NotSupportedException());
+
+ try
+ {
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ if (maxConcurrency <= 1)
+ {
+ // we won't do anything fancy, so skip
+ continue;
+ }
+
+ foreach (int taskCount in TaskCounts)
+ {
+ if (taskCount <= 1)
+ {
+ // we won't do anything fancy, so skip
+ continue;
+ }
+
+ int step = Math.Max(1, taskCount / StepRatio);
+
+ for (int faultOnAndAfterTask = 0; faultOnAndAfterTask < taskCount; faultOnAndAfterTask += step)
+ {
+ IEnumerable<IPrefetcher> prefetchers = CreatePrefetchers(taskCount, faultOnAndAfterTask, static () => { }, static () => { });
+
+ ValidatingRandomizedArrayPool<IPrefetcher> prefetcherPool = new(new ThrowsPrefetcher());
+ ValidatingRandomizedArrayPool<Task> taskPool = new(faultedTask);
+ ValidatingRandomizedArrayPool<object> objectPool = new("unexpected value");
+
+ ParallelPrefetch.ParallelPrefetchTestConfig config =
+ new(
+ prefetcherPool,
+ taskPool,
+ objectPool
+ );
+
+ Exception caught = null;
+ try
+ {
+ await ParallelPrefetch.PrefetchInParallelCoreAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ config,
+ default);
+ }
+ catch (Exception e)
+ {
+ caught = e;
+ }
+
+ Assert.IsNotNull(caught, $"concurrency={maxConcurrency}, tasks={taskCount}, faultOnAndAfterTask={faultOnAndAfterTask} - didn't produce exception as expected");
+
+ Assert.AreEqual(config.StartedTasks, config.AwaitedTasks, $"maxConcurrency={maxConcurrency}, taskCount={taskCount}, faultOnAndAfterTask={faultOnAndAfterTask}; some tasks left unawaited");
+
+ // buffer management can't break in the face of errors, so check here too
+ prefetcherPool.AssertAllReturned();
+ taskPool.AssertAllReturned();
+ objectPool.AssertAllReturned();
+ }
+ }
+ }
+ }
+ finally
+ {
+ // observe this intentionally faulted task, no matter what
+ try
+ {
+ await faultedTask;
+ }
+ catch
+ {
+ // intentionally empty
+ }
+ }
+
+ static IEnumerable<IPrefetcher> CreatePrefetchers(int count, int faultOnAndAfterTask, Action beforeAwait, Action afterAwait)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ if (i >= faultOnAndAfterTask)
+ {
+ yield return new ThrowsPrefetcher();
+ }
+ else
+ {
+ yield return new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+ }
+ }
+ }
+
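+ // [editor's illustration, not part of the patch] every test above asserts
+ // AssertAllReturned() on a ValidatingRandomizedArrayPool, which is defined
+ // elsewhere in this test file. A hypothetical minimal equivalent, to show
+ // what that assertion amounts to; assumes System.Buffers is imported:
+ private sealed class LeakTrackingPoolSketch<T> : ArrayPool<T>
+ {
+ private int outstanding;
+
+ public override T[] Rent(int minimumLength)
+ {
+ // hand out a fresh array, remembering that it is now owed back
+ Interlocked.Increment(ref this.outstanding);
+ return new T[minimumLength];
+ }
+
+ public override void Return(T[] array, bool clearArray = false)
+ {
+ Interlocked.Decrement(ref this.outstanding);
+ }
+
+ public void AssertAllReturned()
+ {
+ // every Rent must be paired with a Return, even on
+ // exception and cancellation paths
+ Assert.AreEqual(0, Volatile.Read(ref this.outstanding));
+ }
+ }
+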
+ [TestMethod]
+ public async Task EnumerableExceptionsHandledAsync()
+ {
+ // test that raising exceptions during enumeration
+ // doesn't leak or otherwise fail
+
+ Task faultedTask = Task.FromException(new NotSupportedException());
+
+ try
+ {
+ foreach (int maxConcurrency in Concurrencies.Reverse())
+ {
+ if (maxConcurrency <= 1)
+ {
+ // we won't do anything fancy, so skip
+ continue;
+ }
+
+ foreach (int taskCount in TaskCounts)
+ {
+ for (int faultAfter = 0; faultAfter < taskCount; faultAfter++)
+ {
+ IEnumerable<IPrefetcher> prefetchersRaw = CreatePrefetchers(taskCount, faultAfter, static () => { }, static () => { });
+ IEnumerable<IPrefetcher> prefetchers = new ThrowsAfterEnumerable<IPrefetcher>(prefetchersRaw, faultAfter);
+
+ ValidatingRandomizedArrayPool<IPrefetcher> prefetcherPool = new(new ThrowsPrefetcher());
+ ValidatingRandomizedArrayPool<Task> taskPool = new(faultedTask);
+ ValidatingRandomizedArrayPool<object> objectPool = new("unexpected value");
+
+ ParallelPrefetch.ParallelPrefetchTestConfig config =
+ new(
+ prefetcherPool,
+ taskPool,
+ objectPool
+ );
+
+ Exception caught = null;
+ try
+ {
+ await ParallelPrefetch.PrefetchInParallelCoreAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ config,
+ default);
+ }
+ catch (Exception e)
+ {
+ caught = e;
+ }
+
+ Assert.IsNotNull(caught, $"concurrency={maxConcurrency}, tasks={taskCount}, faultAfter={faultAfter} - didn't produce exception as expected");
+
+ Assert.AreEqual(config.StartedTasks, config.AwaitedTasks, $"maxConcurrency={maxConcurrency}, taskCount={taskCount}, faultAfter={faultAfter}; some tasks left unawaited");
+
+ // buffer management can't break in the face of errors, so check here too
+ prefetcherPool.AssertAllReturned();
+ taskPool.AssertAllReturned();
+ objectPool.AssertAllReturned();
+ }
+ }
+ }
+ }
+ finally
+ {
+ // observe this intentionally faulted task, no matter what
+ try
+ {
+ await faultedTask;
+ }
+ catch
+ {
+ // intentionally empty
+ }
+ }
+
+ // faulting is injected by ThrowsAfterEnumerable, so every prefetcher here is well-behaved
+ static IEnumerable<IPrefetcher> CreatePrefetchers(int count, int faultOnTask, Action beforeAwait, Action afterAwait)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ yield return new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+ }
+ }
+
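+ // [editor's illustration, not part of the patch] the test above wraps its
+ // prefetchers in a ThrowsAfterEnumerable, defined elsewhere in this file.
+ // A hypothetical minimal version, to show the shape being exercised:
+ private sealed class ThrowsAfterSketch<T> : IEnumerable<T>
+ {
+ private readonly IEnumerable<T> inner;
+ private readonly int throwAfter;
+
+ public ThrowsAfterSketch(IEnumerable<T> inner, int throwAfter)
+ {
+ this.inner = inner;
+ this.throwAfter = throwAfter;
+ }
+
+ public IEnumerator<T> GetEnumerator()
+ {
+ int yielded = 0;
+ foreach (T item in this.inner)
+ {
+ if (yielded == this.throwAfter)
+ {
+ // fault the enumeration itself, mid-stream
+ throw new InvalidOperationException();
+ }
+
+ yielded++;
+ yield return item;
+ }
+ }
+
+ System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
+ {
+ return this.GetEnumerator();
+ }
+ }
+
+ // [editor's illustration, not part of the patch] the cancellation test below
+ // expects prefetchers to surface cancellation cooperatively. A minimal sketch
+ // of a well-behaved prefetcher, assuming IPrefetcher exposes a
+ // PrefetchAsync(ITrace, CancellationToken) method as these tests imply:
+ private sealed class CancellableSketchPrefetcher : IPrefetcher
+ {
+ public async ValueTask PrefetchAsync(ITrace trace, CancellationToken cancellationToken)
+ {
+ // bail out promptly if cancellation has already been requested
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // simulated work; Task.Delay observes the token as well
+ await Task.Delay(TimeSpan.FromMilliseconds(1), cancellationToken);
+ }
+ }
+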
+ [TestMethod]
+ public async Task CancellationHandledAsync()
+ {
+ // cancellation is expensive, so rather than check every
+ // cancellation point we just probe at a number of points
+ // proportional to this constant
+ const int StepRatio = 10;
+
+ // test that cancellation during processing
+ // doesn't leak or otherwise fail
+
+ Task faultedTask = Task.FromException(new NotSupportedException());
+
+ try
+ {
+ foreach (int maxConcurrency in Concurrencies)
+ {
+ if (maxConcurrency <= 1)
+ {
+ // we won't do anything fancy, so skip
+ continue;
+ }
+
+ foreach (int taskCount in TaskCounts)
+ {
+ if (taskCount <= 1)
+ {
+ // we won't do anything fancy, so skip
+ continue;
+ }
+
+ int step = Math.Max(1, taskCount / StepRatio);
+
+ for (int cancelBeforeTask = 0; cancelBeforeTask < taskCount; cancelBeforeTask += step)
+ {
+ using CancellationTokenSource cts = new();
+
+ int startedBeforeCancellation = 0;
+
+ IEnumerable<IPrefetcher> prefetchers =
+ CreatePrefetchers(
+ taskCount,
+ () =>
+ {
+ if (!cts.IsCancellationRequested)
+ {
+ int newValue = Interlocked.Increment(ref startedBeforeCancellation);
+
+ if (newValue >= cancelBeforeTask)
+ {
+ cts.Cancel();
+ }
+ }
+ },
+ () => { });
+
+ ValidatingRandomizedArrayPool<IPrefetcher> prefetcherPool = new(new ThrowsPrefetcher());
+ ValidatingRandomizedArrayPool<Task> taskPool = new(faultedTask);
+ ValidatingRandomizedArrayPool<object> objectPool = new("unexpected value");
+
+ ParallelPrefetch.ParallelPrefetchTestConfig config =
+ new(
+ prefetcherPool,
+ taskPool,
+ objectPool
+ );
+
+ Exception caught = null;
+ try
+ {
+ await ParallelPrefetch.PrefetchInParallelCoreAsync(
+ prefetchers,
+ maxConcurrency,
+ EmptyTrace,
+ config,
+ cts.Token);
+ }
+ catch (Exception e)
+ {
+ caught = e;
+ }
+
+ Assert.IsNotNull(caught, $"concurrency={maxConcurrency}, tasks={taskCount}, cancelBeforeTask={cancelBeforeTask} - didn't produce exception as expected");
+
+ // we might burst above this, but we should always at least _reach_ it
+ Assert.IsTrue(cancelBeforeTask <= startedBeforeCancellation, $"{cancelBeforeTask} > {startedBeforeCancellation}; we should have reached our cancellation point");
+
+ Assert.IsTrue(caught is OperationCanceledException);
+
+ Assert.AreEqual(config.StartedTasks, config.AwaitedTasks, $"maxConcurrency={maxConcurrency}, taskCount={taskCount}, cancelBeforeTask={cancelBeforeTask}; some tasks left unawaited");
+
+ // buffer management can't break in the face of cancellation, so check here too
+ prefetcherPool.AssertAllReturned();
+ taskPool.AssertAllReturned();
+ objectPool.AssertAllReturned();
+ }
+ }
+
+ static IEnumerable<IPrefetcher> CreatePrefetchers(int count, Action beforeAwait, Action afterAwait)
+ {
+ for (int i = 0; i < count; i++)
+ {
+ yield return new TestOncePrefetcher(beforeAwait, afterAwait);
+ }
+ }
+ }
+ }
+ finally
+ {
+ // observe this intentionally faulted task, no matter what
+ try
+ {
+ await faultedTask;
+ }
+ catch
+ {
+ // intentionally empty
+ }
+ }
+ }
+ }
+}