Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experimental] Make DataLoader batching more efficient #7308

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
using System;
using System.Threading;
using System.Threading.Tasks;

namespace HotChocolate.Execution;

/// <summary>
/// Provides the base implementation for a executable task.
/// Provides the base implementation for an executable task.
/// </summary>
/// <remarks>
/// The task is by default a parallel execution task.
Expand Down Expand Up @@ -42,23 +41,23 @@ public abstract class ExecutionTask : IExecutionTask
public bool IsRegistered { get; set; }

/// <inheritdoc />
public void BeginExecute(CancellationToken cancellationToken)
public void BeginExecute(IExecutionTaskScheduler scheduler)
{
Status = ExecutionTaskStatus.Running;
_task = ExecuteInternalAsync(cancellationToken).AsTask();
_task = scheduler.Schedule(ExecuteInternalAsync);
}

/// <inheritdoc />
public Task WaitForCompletionAsync(CancellationToken cancellationToken)
public Task WaitForCompletionAsync()
=> _task ?? Task.CompletedTask;

private async ValueTask ExecuteInternalAsync(CancellationToken cancellationToken)
private async Task ExecuteInternalAsync()
{
try
{
using (Context.Track(this))
{
await ExecuteAsync(cancellationToken).ConfigureAwait(false);
await ExecuteAsync().ConfigureAwait(false);
}
}
catch (OperationCanceledException)
Expand All @@ -72,7 +71,7 @@ private async ValueTask ExecuteInternalAsync(CancellationToken cancellationToken
{
Faulted();

if (!cancellationToken.IsCancellationRequested)
if (!Context.RequestAborted.IsCancellationRequested)
{
Context.ReportError(this, ex);
}
Expand All @@ -83,12 +82,9 @@ private async ValueTask ExecuteInternalAsync(CancellationToken cancellationToken
}

/// <summary>
/// This execute method represents the work of this task.
/// This execute-method represents the work of this task.
/// </summary>
/// <param name="cancellationToken">
/// The cancellation token.
/// </param>
protected abstract ValueTask ExecuteAsync(CancellationToken cancellationToken);
protected abstract ValueTask ExecuteAsync();

/// <summary>
/// Completes the task as faulted.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,13 @@ public interface IExecutionTask
/// <summary>
/// Begins executing this task.
/// </summary>
/// <param name="cancellationToken">
/// The cancellation token.
/// <param name="scheduler">
/// The task scheduler to schedule work for this task.
/// </param>
void BeginExecute(CancellationToken cancellationToken);
void BeginExecute(IExecutionTaskScheduler scheduler);

/// <summary>
/// The running task can be awaited to track completion of this particular task.
/// </summary>
/// <param name="cancellationToken">
/// The cancellation token.
/// </param>
Task WaitForCompletionAsync(CancellationToken cancellationToken);
Task WaitForCompletionAsync();
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Threading;

namespace HotChocolate.Execution;

Expand All @@ -8,6 +9,12 @@ namespace HotChocolate.Execution;
/// </summary>
public interface IExecutionTaskContext
{
/// <summary>
/// Notifies when the connection underlying this request is aborted
/// and thus request operations should be cancelled.
/// </summary>
CancellationToken RequestAborted { get; }

/// <summary>
/// Tracks the running task for the diagnostics.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using System;
using System.Threading.Tasks;

namespace HotChocolate.Execution;

public interface IExecutionTaskScheduler
{
event EventHandler AllTasksCompleted;

/// <summary>
/// Specified if the scheduler is still processing scheduled work.
/// </summary>
bool IsProcessing { get; }

/// <summary>
/// Schedules work to be processed.
/// </summary>
/// <param name="work">
/// The work that shall be processed.
/// </param>
/// <returns>
/// Returns a task representing the work.
/// </returns>
Task Schedule(Func<Task> work);
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,8 @@ internal static IServiceCollection TryAddDefaultDocumentHashProvider(
internal static IServiceCollection TryAddDefaultBatchDispatcher(
this IServiceCollection services)
{
services.TryAddScoped<IBatchHandler, BatchScheduler>();
services.TryAddScoped<IBatchScheduler, AutoBatchScheduler>();
services.TryAddScoped<IBatchDispatcher>(sp => sp.GetRequiredService<IBatchHandler>());
services.TryAddScoped<IBatchDispatcher>(_ => new DefaultBatchDispatcher());
return services;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public StreamExecutionTask(DeferredStream deferredStream)

public ResolverTask? ChildTask { get; private set; }

protected override async ValueTask ExecuteAsync(CancellationToken cancellationToken)
protected override async ValueTask ExecuteAsync()
{
ChildTask = null;
_deferredStream.Index++;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using HotChocolate.Fetching;

namespace HotChocolate.Execution.Processing;
Expand All @@ -9,9 +10,16 @@ internal sealed class NoopBatchDispatcher : IBatchDispatcher
{
public event EventHandler? TaskEnqueued;

public IExecutionTaskScheduler Scheduler => DefaultExecutionTaskScheduler.Instance;

public bool DispatchOnSchedule { get; set; }

public void BeginDispatch(CancellationToken cancellationToken) { }

public static NoopBatchDispatcher Default { get; } = new();

public void Schedule(Func<ValueTask> dispatch)
{
throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,20 @@ namespace HotChocolate.Execution.Processing.Tasks;

internal sealed partial class ResolverTask
{
private async Task ExecuteAsync(CancellationToken cancellationToken)
private async Task ExecuteAsync()
{
try
{
var ct = _context.RequestAborted;

using (DiagnosticEvents.ResolveFieldValue(_context))
{
// we initialize the field, so we are able to propagate non-null violations
// through the result tree.
_context.ParentResult.InitValueUnsafe(_context.ResponseIndex, _context.Selection);

var success = await TryExecuteAsync(cancellationToken).ConfigureAwait(false);
CompleteValue(success, cancellationToken);
var success = await TryExecuteAsync(ct).ConfigureAwait(false);
CompleteValue(success, ct);

switch (_taskBuffer.Count)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ public ExecutionTaskKind Kind
public bool IsRegistered { get; set; }

/// <inheritdoc />
public void BeginExecute(CancellationToken cancellationToken)
public void BeginExecute(IExecutionTaskScheduler scheduler)
{
Status = ExecutionTaskStatus.Running;
_task = ExecuteAsync(cancellationToken);
_task = scheduler.Schedule(ExecuteAsync);
}

/// <inheritdoc />
public Task WaitForCompletionAsync(CancellationToken cancellationToken)
public Task WaitForCompletionAsync()
=> _task ?? Task.CompletedTask;
}
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,9 @@ public static ResolverTask EnqueueElementTasks(
var result = operationContext.Result.RentObject(selectionsCount);
var includeFlags = operationContext.IncludeFlags;
var final = !selectionSet.IsConditional;

result.SetParent(parentResult, parentIndex);

ref var selection = ref ((SelectionSet)selectionSet).GetSelectionsReference();
ref var end = ref Unsafe.Add(ref selection, selectionsCount);

Expand All @@ -176,7 +176,7 @@ public static ResolverTask EnqueueElementTasks(
{
return null;
}

if (!final && !selection.IsIncluded(includeFlags))
{
goto NEXT;
Expand All @@ -202,7 +202,7 @@ public static ResolverTask EnqueueElementTasks(
responseIndex++,
context.ResolverContext.ScopedContextData));
}

NEXT:
selection = ref Unsafe.Add(ref selection, 1)!;
}
Expand Down Expand Up @@ -233,7 +233,7 @@ private static void ResolveAndCompleteInline(
var resolverContext = context.ResolverContext;
var executedSuccessfully = false;
object? resolverResult = null;

parentResult.InitValueUnsafe(responseIndex, selection);

try
Expand Down Expand Up @@ -381,7 +381,7 @@ public NoOpExecutionTask(OperationContext context)

protected override IExecutionTaskContext Context { get; }

protected override ValueTask ExecuteAsync(CancellationToken cancellationToken)
protected override ValueTask ExecuteAsync()
=> default;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ private async Task ExecuteInternalAsync(IExecutionTask?[] buffer)

if (!first.IsSerial)
{
first.BeginExecute(_ct);
first.BeginExecute(_batchDispatcher.Scheduler);
buffer[0] = null;

// if work is not serial we will just enqueue it and not wait
// for it to finish.
for (var i = 1; i < work; i++)
{
buffer[i]!.BeginExecute(_ct);
buffer[i]!.BeginExecute(_batchDispatcher.Scheduler);
buffer[i] = null;
}
}
Expand All @@ -60,8 +60,8 @@ private async Task ExecuteInternalAsync(IExecutionTask?[] buffer)
try
{
_batchDispatcher.DispatchOnSchedule = true;
first.BeginExecute(_ct);
await first.WaitForCompletionAsync(_ct).ConfigureAwait(false);
first.BeginExecute(_batchDispatcher.Scheduler);
await first.WaitForCompletionAsync().ConfigureAwait(false);
buffer[0] = null;
}
finally
Expand Down Expand Up @@ -136,7 +136,7 @@ private int TryTake(IExecutionTask?[] buffer)
return size;
}

private void BatchDispatcherEventHandler(object? source, EventArgs args)
private void OnTaskEnqueued(object? source, EventArgs args)
{
lock (_sync)
{
Expand All @@ -146,6 +146,9 @@ private void BatchDispatcherEventHandler(object? source, EventArgs args)
_pause.TryContinue();
}

private void OnAllTasksCompleted(object? source, EventArgs args)
=> _pause.TryContinue();

private void HandleError(Exception exception)
{
var error =
Expand Down Expand Up @@ -178,20 +181,19 @@ private void TryDispatchOrComplete()
if (!_isCompleted)
{
var isWaitingForTaskCompletion = _work.HasRunningTasks && _work.IsEmpty;
var hasWork = !_work.IsEmpty || !_serial.IsEmpty;

if (isWaitingForTaskCompletion)
{
_pause.Reset();

if (_hasBatches)
if (_hasBatches && !_batchDispatcher.Scheduler.IsProcessing)
{
_hasBatches = false;
_batchDispatcher.BeginDispatch(_ct);
}
}
else
{
var hasWork = !_work.IsEmpty || !_serial.IsEmpty;
if (!hasWork)
{
_isCompleted = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ internal sealed partial class WorkScheduler(OperationContext operationContext)
public void Initialize(IBatchDispatcher batchDispatcher)
{
_batchDispatcher = batchDispatcher;
_batchDispatcher.TaskEnqueued += BatchDispatcherEventHandler;
_batchDispatcher.TaskEnqueued += OnTaskEnqueued;
_batchDispatcher.Scheduler.AllTasksCompleted += OnAllTasksCompleted;

_errorHandler = operationContext.ErrorHandler;
_result = operationContext.Result;
Expand All @@ -52,7 +53,8 @@ public void Clear()
_serial.Clear();
_pause.Reset();

_batchDispatcher.TaskEnqueued -= BatchDispatcherEventHandler;
_batchDispatcher.TaskEnqueued -= OnTaskEnqueued;
_batchDispatcher.Scheduler.AllTasksCompleted -= OnAllTasksCompleted;
_batchDispatcher = default!;

_requestContext = default!;
Expand Down
Loading
Loading