Skip to content

Commit

Permalink
convert ExecuteNonQueryAsync to use async context object
Browse files Browse the repository at this point in the history
  • Loading branch information
Wraith2 committed Jul 30, 2022
1 parent 9b1996a commit 829b9cf
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,69 @@ namespace Microsoft.Data.SqlClient
// CONSIDER creating your own Set method that calls the base Set rather than providing a parameterized ctor, it is friendlier to caching
// DO NOT use this class' state after Dispose has been called. It will not throw ObjectDisposedException but it will be a cleared object

internal abstract class AAsyncCallContext<TOwner, TTask> : IDisposable
internal abstract class AAsyncCallContext<TOwner, TTask, TDisposable> : AAsyncBaseCallContext<TOwner,TTask>
where TOwner : class
where TDisposable : IDisposable
{
protected TOwner _owner;
protected TaskCompletionSource<TTask> _source;
protected IDisposable _disposable;
protected TDisposable _disposable;

protected AAsyncCallContext()
{
}

protected AAsyncCallContext(TOwner owner, TaskCompletionSource<TTask> source, IDisposable disposable = null)
protected AAsyncCallContext(TOwner owner, TaskCompletionSource<TTask> source, TDisposable disposable = default)
{
Set(owner, source, disposable);
}

protected void Set(TOwner owner, TaskCompletionSource<TTask> source, IDisposable disposable = null)
protected void Set(TOwner owner, TaskCompletionSource<TTask> source, TDisposable disposable = default)
{
base.Set(owner, source);
_disposable = disposable;
}

protected override void DisposeCore()
{
TDisposable copyDisposable = _disposable;
_disposable = default;
_isDisposed = true;
copyDisposable?.Dispose();
}
}

internal abstract class AAsyncBaseCallContext<TOwner, TTask>
{
protected TOwner _owner;
protected TaskCompletionSource<TTask> _source;
protected bool _isDisposed;

protected AAsyncBaseCallContext()
{
}

protected void Set(TOwner owner, TaskCompletionSource<TTask> source)
{
_owner = owner ?? throw new ArgumentNullException(nameof(owner));
_source = source ?? throw new ArgumentNullException(nameof(source));
_disposable = disposable;
_isDisposed = false;
}

protected void ClearCore()
{
_source = null;
_owner = default;
IDisposable copyDisposable = _disposable;
_disposable = null;
copyDisposable?.Dispose();
try
{
DisposeCore();
}
finally
{
_isDisposed = true;
}
}

protected abstract void DisposeCore();

/// <summary>
/// override this method to cleanup instance data before ClearCore is called which will blank the base data
/// </summary>
Expand All @@ -65,16 +96,19 @@ protected virtual void AfterCleared(TOwner owner)

public void Dispose()
{
TOwner owner = _owner;
try
{
Clear();
}
finally
if (!_isDisposed)
{
ClearCore();
TOwner owner = _owner;
try
{
Clear();
}
finally
{
ClearCore();
}
AfterCleared(owner);
}
AfterCleared(owner);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ public sealed partial class SqlCommand : DbCommand, ICloneable
private static readonly Func<SqlCommand, CommandBehavior, AsyncCallback, object, int, bool, bool, IAsyncResult> s_beginExecuteXmlReaderInternal = BeginExecuteXmlReaderInternalCallback;
private static readonly Func<SqlCommand, CommandBehavior, AsyncCallback, object, int, bool, bool, IAsyncResult> s_beginExecuteNonQueryInternal = BeginExecuteNonQueryInternalCallback;

internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext<SqlCommand, SqlDataReader>
internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext<SqlCommand, SqlDataReader, CancellationTokenRegistration>
{
public Guid OperationID;
public CommandBehavior CommandBehavior;

public SqlCommand Command => _owner;
public TaskCompletionSource<SqlDataReader> TaskCompletionSource => _source;

public void Set(SqlCommand command, TaskCompletionSource<SqlDataReader> source, IDisposable disposable, CommandBehavior behavior, Guid operationID)
public void Set(SqlCommand command, TaskCompletionSource<SqlDataReader> source, CancellationTokenRegistration disposable, CommandBehavior behavior, Guid operationID)
{
base.Set(command, source, disposable);
CommandBehavior = behavior;
Expand All @@ -73,6 +73,31 @@ protected override void AfterCleared(SqlCommand owner)
}
}

internal sealed class ExecuteNonQueryAsyncCallContext : AAsyncCallContext<SqlCommand, int, CancellationTokenRegistration>
{
public Guid OperationID;

public SqlCommand Command => _owner;

public TaskCompletionSource<int> TaskCompletionSource => _source;

public void Set(SqlCommand command, TaskCompletionSource<int> source, CancellationTokenRegistration disposable, Guid operationID)
{
base.Set(command, source, disposable);
OperationID = operationID;
}

protected override void Clear()
{
OperationID = default;
}

protected override void AfterCleared(SqlCommand owner)
{

}
}

private CommandType _commandType;
private int? _commandTimeout;
private UpdateRowSource _updatedRowSource = UpdateRowSource.Both;
Expand Down Expand Up @@ -2540,37 +2565,56 @@ private Task<int> InternalExecuteNonQueryAsync(CancellationToken cancellationTok
}

Task<int> returnedTask = source.Task;
returnedTask = RegisterForConnectionCloseNotification(returnedTask);

ExecuteNonQueryAsyncCallContext context = new ExecuteNonQueryAsyncCallContext();
context.Set(this, source, registration, operationId);
try
{
returnedTask = RegisterForConnectionCloseNotification(returnedTask);

Task<int>.Factory.FromAsync(BeginExecuteNonQueryAsync, EndExecuteNonQueryAsync, null).ContinueWith((t) =>
{
registration.Dispose();
if (t.IsFaulted)
{
Exception e = t.Exception.InnerException;
_diagnosticListener.WriteCommandError(operationId, this, _transaction, e);
source.SetException(e);
}
else
Task<int>.Factory.FromAsync(
static (AsyncCallback callback, object stateObject) => ((ExecuteNonQueryAsyncCallContext)stateObject).Command.BeginExecuteNonQueryAsync(callback, stateObject),
static (IAsyncResult result) => ((ExecuteNonQueryAsyncCallContext)result.AsyncState).Command.EndExecuteNonQueryAsync(result),
state: context
).ContinueWith(
static (Task<int> task, object state) =>
{
if (t.IsCanceled)
ExecuteNonQueryAsyncCallContext context = (ExecuteNonQueryAsyncCallContext)state;
Guid operationId = context.OperationID;
SqlCommand command = context.Command;
TaskCompletionSource<int> source = context.TaskCompletionSource;
context.Dispose();
context = null;
if (task.IsFaulted)
{
source.SetCanceled();
Exception e = task.Exception.InnerException;
_diagnosticListener.WriteCommandError(operationId, command, command._transaction, e);
source.SetException(e);
}
else
{
source.SetResult(t.Result);
if (task.IsCanceled)
{
source.SetCanceled();
}
else
{
source.SetResult(task.Result);
}
_diagnosticListener.WriteCommandAfter(operationId, command, command._transaction);
}
_diagnosticListener.WriteCommandAfter(operationId, this, _transaction);
}
}, TaskScheduler.Default);
},
state: context,
scheduler: TaskScheduler.Default
);
}
catch (Exception e)
{
_diagnosticListener.WriteCommandError(operationId, this, _transaction, e);
source.SetException(e);
context.Dispose();
}

return returnedTask;
Expand Down Expand Up @@ -2645,11 +2689,11 @@ private Task<SqlDataReader> InternalExecuteReaderAsync(CommandBehavior behavior,
}

Task<SqlDataReader> returnedTask = source.Task;
ExecuteReaderAsyncCallContext context = null;
try
{
returnedTask = RegisterForConnectionCloseNotification(returnedTask);

ExecuteReaderAsyncCallContext context = null;
if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection)
{
context = Interlocked.Exchange(ref sqlInternalConnection.CachedCommandExecuteReaderAsyncContext, null);
Expand Down Expand Up @@ -2677,6 +2721,7 @@ private Task<SqlDataReader> InternalExecuteReaderAsync(CommandBehavior behavior,
}

source.SetException(e);
context.Dispose();
}

return returnedTask;
Expand Down
Loading

0 comments on commit 829b9cf

Please sign in to comment.