diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs index 56e369593a..a93674a420 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs @@ -17,38 +17,69 @@ namespace Microsoft.Data.SqlClient // CONSIDER creating your own Set method that calls the base Set rather than providing a parameterized ctor, it is friendlier to caching // DO NOT use this class' state after Dispose has been called. It will not throw ObjectDisposedException but it will be a cleared object - internal abstract class AAsyncCallContext : IDisposable + internal abstract class AAsyncCallContext : AAsyncBaseCallContext where TOwner : class + where TDisposable : IDisposable { - protected TOwner _owner; - protected TaskCompletionSource _source; - protected IDisposable _disposable; + protected TDisposable _disposable; protected AAsyncCallContext() { } - protected AAsyncCallContext(TOwner owner, TaskCompletionSource source, IDisposable disposable = null) + protected AAsyncCallContext(TOwner owner, TaskCompletionSource source, TDisposable disposable = default) { Set(owner, source, disposable); } - protected void Set(TOwner owner, TaskCompletionSource source, IDisposable disposable = null) + protected void Set(TOwner owner, TaskCompletionSource source, TDisposable disposable = default) + { + base.Set(owner, source); + _disposable = disposable; + } + + protected override void DisposeCore() + { + TDisposable copyDisposable = _disposable; + _disposable = default; + _isDisposed = true; + copyDisposable?.Dispose(); + } + } + + internal abstract class AAsyncBaseCallContext + { + protected TOwner _owner; + protected TaskCompletionSource _source; + protected bool _isDisposed; + + protected AAsyncBaseCallContext() + { + } + + protected void Set(TOwner owner, TaskCompletionSource source) { _owner = owner ?? throw new ArgumentNullException(nameof(owner)); _source = source ?? throw new ArgumentNullException(nameof(source)); - _disposable = disposable; + _isDisposed = false; } protected void ClearCore() { _source = null; _owner = default; - IDisposable copyDisposable = _disposable; - _disposable = null; - copyDisposable?.Dispose(); + try + { + DisposeCore(); + } + finally + { + _isDisposed = true; + } } + protected abstract void DisposeCore(); + /// /// override this method to cleanup instance data before ClearCore is called which will blank the base data /// @@ -65,16 +96,19 @@ protected virtual void AfterCleared(TOwner owner) public void Dispose() { - TOwner owner = _owner; - try - { - Clear(); - } - finally + if (!_isDisposed) { - ClearCore(); + TOwner owner = _owner; + try + { + Clear(); + } + finally + { + ClearCore(); + } + AfterCleared(owner); } - AfterCleared(owner); } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index 66bacefd98..e1754b4b8f 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -46,7 +46,7 @@ public sealed partial class SqlCommand : DbCommand, ICloneable private static readonly Func s_beginExecuteXmlReaderInternal = BeginExecuteXmlReaderInternalCallback; private static readonly Func s_beginExecuteNonQueryInternal = BeginExecuteNonQueryInternalCallback; - internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext + internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext { public Guid OperationID; public CommandBehavior CommandBehavior; @@ -54,7 +54,7 @@ internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext _owner; public TaskCompletionSource TaskCompletionSource => _source; - public void Set(SqlCommand command, TaskCompletionSource source, IDisposable disposable, CommandBehavior behavior, Guid operationID) + public void Set(SqlCommand command, TaskCompletionSource source, CancellationTokenRegistration disposable, CommandBehavior behavior, Guid operationID) { base.Set(command, source, disposable); CommandBehavior = behavior; @@ -73,6 +73,31 @@ protected override void AfterCleared(SqlCommand owner) } } + internal sealed class ExecuteNonQueryAsyncCallContext : AAsyncCallContext + { + public Guid OperationID; + + public SqlCommand Command => _owner; + + public TaskCompletionSource TaskCompletionSource => _source; + + public void Set(SqlCommand command, TaskCompletionSource source, CancellationTokenRegistration disposable, Guid operationID) + { + base.Set(command, source, disposable); + OperationID = operationID; + } + + protected override void Clear() + { + OperationID = default; + } + + protected override void AfterCleared(SqlCommand owner) + { + + } + } + private CommandType _commandType; private int? _commandTimeout; private UpdateRowSource _updatedRowSource = UpdateRowSource.Both; @@ -2540,37 +2565,56 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok } Task returnedTask = source.Task; + returnedTask = RegisterForConnectionCloseNotification(returnedTask); + + ExecuteNonQueryAsyncCallContext context = new ExecuteNonQueryAsyncCallContext(); + context.Set(this, source, registration, operationId); try { - returnedTask = RegisterForConnectionCloseNotification(returnedTask); - - Task.Factory.FromAsync(BeginExecuteNonQueryAsync, EndExecuteNonQueryAsync, null).ContinueWith((t) => - { - registration.Dispose(); - if (t.IsFaulted) - { - Exception e = t.Exception.InnerException; - _diagnosticListener.WriteCommandError(operationId, this, _transaction, e); - source.SetException(e); - } - else + Task.Factory.FromAsync( + static (AsyncCallback callback, object stateObject) => ((ExecuteNonQueryAsyncCallContext)stateObject).Command.BeginExecuteNonQueryAsync(callback, stateObject), + static (IAsyncResult result) => ((ExecuteNonQueryAsyncCallContext)result.AsyncState).Command.EndExecuteNonQueryAsync(result), + state: context + ).ContinueWith( + static (Task task, object state) => { - if (t.IsCanceled) + ExecuteNonQueryAsyncCallContext context = (ExecuteNonQueryAsyncCallContext)state; + + Guid operationId = context.OperationID; + SqlCommand command = context.Command; + TaskCompletionSource source = context.TaskCompletionSource; + + context.Dispose(); + context = null; + + if (task.IsFaulted) { - source.SetCanceled(); + Exception e = task.Exception.InnerException; + _diagnosticListener.WriteCommandError(operationId, command, command._transaction, e); + source.SetException(e); } else { - source.SetResult(t.Result); + if (task.IsCanceled) + { + source.SetCanceled(); + } + else + { + source.SetResult(task.Result); + } + _diagnosticListener.WriteCommandAfter(operationId, command, command._transaction); } - _diagnosticListener.WriteCommandAfter(operationId, this, _transaction); - } - }, TaskScheduler.Default); + }, + state: context, + scheduler: TaskScheduler.Default + ); } catch (Exception e) { _diagnosticListener.WriteCommandError(operationId, this, _transaction, e); source.SetException(e); + context.Dispose(); } return returnedTask; @@ -2645,11 +2689,11 @@ private Task InternalExecuteReaderAsync(CommandBehavior behavior, } Task returnedTask = source.Task; + ExecuteReaderAsyncCallContext context = null; try { returnedTask = RegisterForConnectionCloseNotification(returnedTask); - ExecuteReaderAsyncCallContext context = null; if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection) { context = Interlocked.Exchange(ref sqlInternalConnection.CachedCommandExecuteReaderAsyncContext, null); @@ -2677,6 +2721,7 @@ private Task InternalExecuteReaderAsync(CommandBehavior behavior, } source.SetException(e); + context.Dispose(); } return returnedTask; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index a801697d8f..48d6167157 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -4406,7 +4406,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken) return source.Task; } - IDisposable registration = null; + CancellationTokenRegistration registration = default; if (cancellationToken.CanBeCanceled) { if (cancellationToken.IsCancellationRequested) @@ -4706,7 +4706,7 @@ out bytesRead Debug.Assert(context.Source != null, "context._source should not be null when continuing"); // setup for cleanup/completing retryTask.ContinueWith( - continuationAction: SqlDataReaderAsyncCallContext.s_completeCallback, + continuationAction: SqlDataReaderBaseAsyncCallContext.s_completeCallback, state: context, TaskScheduler.Default ); @@ -4831,7 +4831,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) return source.Task; } - IDisposable registration = null; + CancellationTokenRegistration registration = default; if (cancellationToken.CanBeCanceled) { registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); @@ -5005,7 +5005,7 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo } // Setup cancellations - IDisposable registration = null; + CancellationTokenRegistration registration = default; if (cancellationToken.CanBeCanceled) { registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); @@ -5021,7 +5021,7 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo context = new IsDBNullAsyncCallContext(); } - Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ISDBNullAsync context not properly disposed"); + Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ISDBNullAsync context not properly disposed"); context.Set(this, source, registration); context._columnIndex = i; @@ -5152,7 +5152,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat } // Setup cancellations - IDisposable registration = null; + CancellationTokenRegistration registration = default; if (cancellationToken.CanBeCanceled) { registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); @@ -5216,49 +5216,63 @@ internal void CompletePendingReadWithFailure(int errorCode, bool resetForcePendi } #endif - - internal abstract class SqlDataReaderAsyncCallContext : AAsyncCallContext + + internal abstract class SqlDataReaderBaseAsyncCallContext : AAsyncBaseCallContext { internal static readonly Action, object> s_completeCallback = CompleteAsyncCallCallback; internal static readonly Func> s_executeCallback = ExecuteAsyncCallCallback; - protected SqlDataReaderAsyncCallContext() + protected SqlDataReaderBaseAsyncCallContext() { } - protected SqlDataReaderAsyncCallContext(SqlDataReader owner, TaskCompletionSource source, IDisposable disposable = null) + protected SqlDataReaderBaseAsyncCallContext(SqlDataReader owner, TaskCompletionSource source) { - Set(owner, source, disposable); + Set(owner, source); } internal abstract Func> Execute { get; } internal SqlDataReader Reader { get => _owner; set => _owner = value; } - public IDisposable Disposable { get => _disposable; set => _disposable = value; } - public TaskCompletionSource Source { get => _source; set => _source = value; } - new public void Set(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable) - { - base.Set(reader, source, disposable); - } - private static Task ExecuteAsyncCallCallback(Task task, object state) { - SqlDataReaderAsyncCallContext context = (SqlDataReaderAsyncCallContext)state; + SqlDataReaderBaseAsyncCallContext context = (SqlDataReaderBaseAsyncCallContext)state; return context.Reader.ContinueAsyncCall(task, context); } private static void CompleteAsyncCallCallback(Task task, object state) { - SqlDataReaderAsyncCallContext context = (SqlDataReaderAsyncCallContext)state; + SqlDataReaderBaseAsyncCallContext context = (SqlDataReaderBaseAsyncCallContext)state; context.Reader.CompleteAsyncCall(task, context); } } - internal sealed class ReadAsyncCallContext : SqlDataReaderAsyncCallContext + internal abstract class SqlDataReaderAsyncCallContext : SqlDataReaderBaseAsyncCallContext + where TDisposable : IDisposable + { + private TDisposable _disposable; + + public TDisposable Disposable { get => _disposable; set => _disposable = value; } + + public void Set(SqlDataReader owner, TaskCompletionSource source, TDisposable disposable) + { + base.Set(owner, source); + _disposable = disposable; + } + + protected override void DisposeCore() + { + TDisposable copy = _disposable; + _disposable = default; + copy.Dispose(); + } + } + + internal sealed class ReadAsyncCallContext : SqlDataReaderAsyncCallContext { internal static readonly Func> s_execute = SqlDataReader.ReadAsyncExecute; @@ -5277,7 +5291,7 @@ protected override void AfterCleared(SqlDataReader owner) } } - internal sealed class IsDBNullAsyncCallContext : SqlDataReaderAsyncCallContext + internal sealed class IsDBNullAsyncCallContext : SqlDataReaderAsyncCallContext { internal static readonly Func> s_execute = SqlDataReader.IsDBNullAsyncExecute; @@ -5293,19 +5307,19 @@ protected override void AfterCleared(SqlDataReader owner) } } - private sealed class HasNextResultAsyncCallContext : SqlDataReaderAsyncCallContext + private sealed class HasNextResultAsyncCallContext : SqlDataReaderAsyncCallContext { private static readonly Func> s_execute = SqlDataReader.NextResultAsyncExecute; - public HasNextResultAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable) - : base(reader, source, disposable) + public HasNextResultAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, CancellationTokenRegistration disposable) { + Set(reader, source, disposable); } internal override Func> Execute => s_execute; } - private sealed class GetBytesAsyncCallContext : SqlDataReaderAsyncCallContext + private sealed class GetBytesAsyncCallContext : SqlDataReaderAsyncCallContext { internal enum OperationMode { @@ -5343,7 +5357,7 @@ protected override void Clear() } } - private sealed class GetFieldValueAsyncCallContext : SqlDataReaderAsyncCallContext + private sealed class GetFieldValueAsyncCallContext : SqlDataReaderAsyncCallContext { private static readonly Func> s_execute = SqlDataReader.GetFieldValueAsyncExecute; @@ -5351,9 +5365,9 @@ private sealed class GetFieldValueAsyncCallContext : SqlDataReaderAsyncCallCo internal GetFieldValueAsyncCallContext() { } - internal GetFieldValueAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable) - : base(reader, source, disposable) + internal GetFieldValueAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, CancellationTokenRegistration disposable) { + Set(reader, source, disposable); } protected override void Clear() @@ -5373,7 +5387,7 @@ protected override void Clear() /// /// /// - private Task InvokeAsyncCall(SqlDataReaderAsyncCallContext context) + private Task InvokeAsyncCall(SqlDataReaderBaseAsyncCallContext context) { TaskCompletionSource source = context.Source; try @@ -5395,7 +5409,7 @@ private Task InvokeAsyncCall(SqlDataReaderAsyncCallContext context) else { task.ContinueWith( - continuationAction: SqlDataReaderAsyncCallContext.s_completeCallback, + continuationAction: SqlDataReaderBaseAsyncCallContext.s_completeCallback, state: context, TaskScheduler.Default ); @@ -5420,7 +5434,7 @@ private Task InvokeAsyncCall(SqlDataReaderAsyncCallContext context) /// /// /// - private Task ExecuteAsyncCall(SqlDataReaderAsyncCallContext context) + private Task ExecuteAsyncCall(AAsyncBaseCallContext context) { // _networkPacketTaskSource could be null if the connection was closed // while an async invocation was outstanding. @@ -5433,7 +5447,7 @@ private Task ExecuteAsyncCall(SqlDataReaderAsyncCallContext context) else { return completionSource.Task.ContinueWith( - continuationFunction: SqlDataReaderAsyncCallContext.s_executeCallback, + continuationFunction: SqlDataReaderBaseAsyncCallContext.s_executeCallback, state: context, TaskScheduler.Default ).Unwrap(); @@ -5449,7 +5463,7 @@ private Task ExecuteAsyncCall(SqlDataReaderAsyncCallContext context) /// /// /// - private Task ContinueAsyncCall(Task task, SqlDataReaderAsyncCallContext context) + private Task ContinueAsyncCall(Task task, SqlDataReaderBaseAsyncCallContext context) { // this function must be an instance function called from the static callback because otherwise a compiler error // is caused by accessing the _cancelAsyncOnCloseToken field of a MarshalByRefObject derived class @@ -5509,7 +5523,7 @@ private Task ContinueAsyncCall(Task task, SqlDataReaderAsyncCallContext /// /// /// - private void CompleteAsyncCall(Task task, SqlDataReaderAsyncCallContext context) + private void CompleteAsyncCall(Task task, SqlDataReaderBaseAsyncCallContext context) { TaskCompletionSource source = context.Source; context.Dispose();