diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 102d6325b1..d374688503 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -33,7 +33,7 @@ Microsoft\Data\SqlClient\SqlClientEventSource.cs - + Microsoft\Data\SqlClient\SqlClientLogger.cs @@ -57,7 +57,7 @@ Microsoft\Data\SqlClient\ColumnEncryptionKeyInfo.cs - + Microsoft\Data\SqlClient\OnChangedEventHandler.cs @@ -635,6 +635,7 @@ + True True diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/ProviderBase/DbConnectionFactory.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/ProviderBase/DbConnectionFactory.cs index e42f08d963..faf87c9f53 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/ProviderBase/DbConnectionFactory.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/ProviderBase/DbConnectionFactory.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Data; using System.Data.Common; using System.Diagnostics; @@ -15,6 +16,7 @@ namespace Microsoft.Data.ProviderBase { internal abstract partial class DbConnectionFactory { + private static readonly Action, object> s_tryGetConnectionCompletedContinuation = TryGetConnectionCompletedContinuation; internal bool TryGetConnection(DbConnection owningConnection, TaskCompletionSource retry, DbConnectionOptions userOptions, DbConnectionInternal oldConnection, out DbConnectionInternal connection) { @@ -82,25 +84,7 @@ internal bool TryGetConnection(DbConnection owningConnection, TaskCompletionSour // now that we have an antecedent task, schedule our work when it is completed. // If it is a new slot or a completed task, this continuation will start right away. - newTask = s_pendingOpenNonPooled[idx].ContinueWith((_) => - { - Transaction originalTransaction = ADP.GetCurrentTransaction(); - try - { - ADP.SetCurrentTransaction(retry.Task.AsyncState as Transaction); - var newConnection = CreateNonPooledConnection(owningConnection, poolGroup, userOptions); - if ((oldConnection != null) && (oldConnection.State == ConnectionState.Open)) - { - oldConnection.PrepareForReplaceConnection(); - oldConnection.Dispose(); - } - return newConnection; - } - finally - { - ADP.SetCurrentTransaction(originalTransaction); - } - }, cancellationTokenSource.Token, TaskContinuationOptions.LongRunning, TaskScheduler.Default); + newTask = CreateReplaceConnectionContinuation(s_pendingOpenNonPooled[idx], owningConnection, retry, userOptions, oldConnection, poolGroup, cancellationTokenSource); // Place this new task in the slot so any future work will be queued behind it s_pendingOpenNonPooled[idx] = newTask; @@ -114,29 +98,11 @@ internal bool TryGetConnection(DbConnection owningConnection, TaskCompletionSour } // once the task is done, propagate the final results to the original caller - newTask.ContinueWith((task) => - { - cancellationTokenSource.Dispose(); - if (task.IsCanceled) - { - retry.TrySetException(ADP.ExceptionWithStackTrace(ADP.NonPooledOpenTimeout())); - } - else if (task.IsFaulted) - { - retry.TrySetException(task.Exception.InnerException); - } - else - { - if (!retry.TrySetResult(task.Result)) - { - // The outer TaskCompletionSource was already completed - // Which means that we don't know if someone has messed with the outer connection in the middle of creation - // So the best thing to do now is to destroy the newly created connection - task.Result.DoomThisConnection(); - task.Result.Dispose(); - } - } - }, TaskScheduler.Default); + newTask.ContinueWith( + continuationAction: s_tryGetConnectionCompletedContinuation, + state: Tuple.Create(cancellationTokenSource, retry), + scheduler: TaskScheduler.Default + ); return false; } @@ -188,5 +154,62 @@ internal bool TryGetConnection(DbConnection owningConnection, TaskCompletionSour return true; } + + private Task CreateReplaceConnectionContinuation(Task task, DbConnection owningConnection, TaskCompletionSource retry, DbConnectionOptions userOptions, DbConnectionInternal oldConnection, DbConnectionPoolGroup poolGroup, CancellationTokenSource cancellationTokenSource) + { + return task.ContinueWith( + (_) => + { + Transaction originalTransaction = ADP.GetCurrentTransaction(); + try + { + ADP.SetCurrentTransaction(retry.Task.AsyncState as Transaction); + var newConnection = CreateNonPooledConnection(owningConnection, poolGroup, userOptions); + if ((oldConnection != null) && (oldConnection.State == ConnectionState.Open)) + { + oldConnection.PrepareForReplaceConnection(); + oldConnection.Dispose(); + } + return newConnection; + } + finally + { + ADP.SetCurrentTransaction(originalTransaction); + } + }, + cancellationTokenSource.Token, + TaskContinuationOptions.LongRunning, + TaskScheduler.Default + ); + } + + private static void TryGetConnectionCompletedContinuation(Task task, object state) + { + Tuple> parameters = (Tuple>)state; + CancellationTokenSource source = parameters.Item1; + source.Dispose(); + + TaskCompletionSource retryCompletionSrouce = parameters.Item2; + + if (task.IsCanceled) + { + retryCompletionSrouce.TrySetException(ADP.ExceptionWithStackTrace(ADP.NonPooledOpenTimeout())); + } + else if (task.IsFaulted) + { + retryCompletionSrouce.TrySetException(task.Exception.InnerException); + } + else + { + if (!retryCompletionSrouce.TrySetResult(task.Result)) + { + // The outer TaskCompletionSource was already completed + // Which means that we don't know if someone has messed with the outer connection in the middle of creation + // So the best thing to do now is to destroy the newly created connection + task.Result.DoomThisConnection(); + task.Result.Dispose(); + } + } + } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs new file mode 100644 index 0000000000..c70724a1be --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading.Tasks; + +namespace Microsoft.Data.SqlClient +{ + // this class is a base class for creating derived objects that will store state for async operations + // avoiding the use of closures and allowing caching/reuse of the instances for frequently used async + // calls + // + // DO derive from this and seal and your class + // DO add additional fields or properties needed for the async operation and then override Clear to zero them + // DO override AfterClear and use the owner parameter to return the object to a cache location if you have one, this is the purpose of the method + // CONSIDER creating your own Set method that calls the base Set rather than providing a parameterized ctor, it is friendlier to caching + // DO NOT use this class after Dispose has been called. It will not throw ObjectDisposedException but it will be a cleared object + + internal abstract class AAsyncCallContext : IDisposable + where TOwner : class + { + protected TOwner _owner; + protected TaskCompletionSource _source; + protected IDisposable _disposable; + + protected AAsyncCallContext() + { + } + + protected AAsyncCallContext(TOwner owner, TaskCompletionSource source, IDisposable disposable = null) + { + Set(owner, source, disposable); + } + + protected void Set(TOwner owner, TaskCompletionSource source, IDisposable disposable = null) + { + _owner = owner ?? throw new ArgumentNullException(nameof(owner)); + _source = source ?? throw new ArgumentNullException(nameof(source)); + _disposable = disposable; + } + + protected void ClearCore() + { + _source = null; + _owner = default; + IDisposable copyDisposable = _disposable; + _disposable = null; + copyDisposable?.Dispose(); + } + + /// + /// overrride this method to cleanup instance data before ClearCore is called which will blank the base data + /// + protected virtual void Clear() + { + } + + /// + /// override this method to do work after the instance has been totally blanked, intended for cache return etc + /// + protected virtual void AfterCleared(TOwner owner) + { + + } + + public void Dispose() + { + TOwner owner = _owner; + try + { + Clear(); + } + finally + { + ClearCore(); + } + AfterCleared(owner); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 686e194397..62d3418319 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -2224,7 +2224,7 @@ private Task ReadWriteColumnValueAsync(int col) return writeTask; } - private void RegisterForConnectionCloseNotification(ref Task outerTask) + private Task RegisterForConnectionCloseNotification(Task outerTask) { SqlConnection connection = _connection; if (connection == null) @@ -2233,7 +2233,7 @@ private void RegisterForConnectionCloseNotification(ref Task outerTask) throw ADP.ClosedConnectionError(); } - connection.RegisterForConnectionCloseNotification(ref outerTask, this, SqlReferenceCollection.BulkCopyTag); + return connection.RegisterForConnectionCloseNotification(outerTask, this, SqlReferenceCollection.BulkCopyTag); } // Runs a loop to copy all columns of a single row. @@ -3016,7 +3016,7 @@ private Task WriteToServerInternalAsync(CancellationToken ctoken) source = new TaskCompletionSource(); // Creating the completion source/Task that we pass to application resultTask = source.Task; - RegisterForConnectionCloseNotification(ref resultTask); + resultTask = RegisterForConnectionCloseNotification(resultTask); } if (_destinationTableName == null) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index 66bd59c4e2..980f576e35 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -28,6 +28,42 @@ public sealed partial class SqlCommand : DbCommand, ICloneable private static int _objectTypeCount; // EventSource Counter internal readonly int ObjectID = Interlocked.Increment(ref _objectTypeCount); private string _commandText; + private static readonly Func s_beginExecuteReaderAsync = BeginExecuteReaderAsyncCallback; + private static readonly Func s_endExecuteReaderAsync = EndExecuteReaderAsyncCallback; + private static readonly Action> s_cleanupExecuteReaderAsync = CleanupExecuteReaderAsyncCallback; + private static readonly Func s_internalEndExecuteNonQuery = InternalEndExecuteNonQueryCallback; + private static readonly Func s_internalEndExecuteReader = InternalEndExecuteReaderCallback; + private static readonly Func s_beginExecuteReaderInternal = BeginExecuteReaderInternalCallback; + private static readonly Func s_beginExecuteXmlReaderInternal = BeginExecuteXmlReaderInternalCallback; + private static readonly Func s_beginExecuteNonQueryInternal = BeginExecuteNonQueryInternalCallback; + + internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext + { + public Guid OperationID; + public CommandBehavior CommandBehavior; + + public SqlCommand Command => _owner; + public TaskCompletionSource TaskCompletionSource => _source; + + public void Set(SqlCommand command, TaskCompletionSource source, IDisposable disposable, CommandBehavior behavior, Guid operationID) + { + base.Set(command, source, disposable); + CommandBehavior = behavior; + OperationID = operationID; + } + + protected override void Clear() + { + OperationID = default; + CommandBehavior = default; + } + + protected override void AfterCleared(SqlCommand owner) + { + owner?.SetCachedCommandExecuteReaderAsyncContext(this); + } + } + private CommandType _commandType; private int _commandTimeout = ADP.DefaultCommandTimeout; private UpdateRowSource _updatedRowSource = UpdateRowSource.Both; @@ -1139,7 +1175,7 @@ private IAsyncResult BeginExecuteNonQueryInternal(CommandBehavior behavior, Asyn // When we use query caching for parameter encryption we need to retry on specific errors. // In these cases finalize the call internally and trigger a retry when needed. - if (!TriggerInternalEndAndRetryIfNecessary(behavior, stateObject, timeout, usedCache, inRetry, asyncWrite, globalCompletion, localCompletion, InternalEndExecuteNonQuery, BeginExecuteNonQueryInternal, nameof(EndExecuteNonQuery))) + if (!TriggerInternalEndAndRetryIfNecessary(behavior, stateObject, timeout, usedCache, inRetry, asyncWrite, globalCompletion, localCompletion, s_internalEndExecuteNonQuery, s_beginExecuteNonQueryInternal, nameof(EndExecuteNonQuery))) { globalCompletion = localCompletion; } @@ -1642,7 +1678,7 @@ private IAsyncResult BeginExecuteXmlReaderInternal(CommandBehavior behavior, Asy // When we use query caching for parameter encryption we need to retry on specific errors. // In these cases finalize the call internally and trigger a retry when needed. - if (!TriggerInternalEndAndRetryIfNecessary(behavior, stateObject, timeout, usedCache, inRetry, asyncWrite, globalCompletion, localCompletion, InternalEndExecuteReader, BeginExecuteXmlReaderInternal, endMethod: nameof(EndExecuteXmlReader))) + if (!TriggerInternalEndAndRetryIfNecessary(behavior, stateObject, timeout, usedCache, inRetry, asyncWrite, globalCompletion, localCompletion, s_internalEndExecuteReader, s_beginExecuteXmlReaderInternal, endMethod: nameof(EndExecuteXmlReader))) { globalCompletion = localCompletion; } @@ -1800,8 +1836,6 @@ private XmlReader CompleteXmlReader(SqlDataReader ds, bool isAsync = false) /// public IAsyncResult BeginExecuteReader(AsyncCallback callback, object stateObject) => BeginExecuteReader(callback, stateObject, CommandBehavior.Default); - - /// public IAsyncResult BeginExecuteReader(CommandBehavior behavior) => BeginExecuteReader(null, null, behavior); @@ -1895,7 +1929,6 @@ public SqlDataReader EndExecuteReader(IAsyncResult asyncResult) } } - internal SqlDataReader EndExecuteReaderAsync(IAsyncResult asyncResult) { SqlClientEventSource.Log.CorrelationTraceEvent(" ObjectID{0}, ActivityID {1}", ObjectID, ActivityCorrelator.Current); @@ -1961,9 +1994,55 @@ private SqlDataReader EndExecuteReaderInternal(IAsyncResult asyncResult) } } - private IAsyncResult BeginExecuteReaderAsync(CommandBehavior behavior, AsyncCallback callback, object stateObject) + private void CleanupExecuteReaderAsync(Task task, TaskCompletionSource source, Guid operationId) { - return BeginExecuteReaderInternal(behavior, callback, stateObject, CommandTimeout, inRetry: false, asyncWrite: true); + if (task.IsFaulted) + { + Exception e = task.Exception.InnerException; + if (!_parentOperationStarted) + { + _diagnosticListener.WriteCommandError(operationId, this, e); + } + source.SetException(e); + } + else + { + if (task.IsCanceled) + { + source.SetCanceled(); + } + else + { + source.SetResult(task.Result); + } + if (!_parentOperationStarted) + { + _diagnosticListener.WriteCommandAfter(operationId, this); + } + } + } + + private static IAsyncResult BeginExecuteReaderAsyncCallback(AsyncCallback callback, object stateObject) + { + ExecuteReaderAsyncCallContext args = (ExecuteReaderAsyncCallContext)stateObject; + return args.Command.BeginExecuteReaderInternal(args.CommandBehavior, callback, stateObject, args.Command.CommandTimeout, inRetry: false, asyncWrite: true); + } + + private static SqlDataReader EndExecuteReaderAsyncCallback(IAsyncResult asyncResult) + { + ExecuteReaderAsyncCallContext args = (ExecuteReaderAsyncCallContext)asyncResult.AsyncState; + return args.Command.EndExecuteReaderAsync(asyncResult); + } + + private static void CleanupExecuteReaderAsyncCallback(Task task) + { + ExecuteReaderAsyncCallContext context = (ExecuteReaderAsyncCallContext)task.AsyncState; + SqlCommand command = context.Command; + Guid operationId = context.OperationID; + TaskCompletionSource source = context.TaskCompletionSource; + context.Dispose(); + + command.CleanupExecuteReaderAsync(task, source, operationId); } private IAsyncResult BeginExecuteReaderInternal(CommandBehavior behavior, AsyncCallback callback, object stateObject, int timeout, bool inRetry, bool asyncWrite = false) @@ -2027,7 +2106,7 @@ private IAsyncResult BeginExecuteReaderInternal(CommandBehavior behavior, AsyncC // When we use query caching for parameter encryption we need to retry on specific errors. // In these cases finalize the call internally and trigger a retry when needed. - if (!TriggerInternalEndAndRetryIfNecessary(behavior, stateObject, timeout, usedCache, inRetry, asyncWrite, globalCompletion, localCompletion, InternalEndExecuteReader, BeginExecuteReaderInternal, nameof(EndExecuteReader))) + if (!TriggerInternalEndAndRetryIfNecessary(behavior, stateObject, timeout, usedCache, inRetry, asyncWrite, globalCompletion, localCompletion, s_internalEndExecuteReader, s_beginExecuteReaderInternal, nameof(EndExecuteReader))) { globalCompletion = localCompletion; } @@ -2049,6 +2128,42 @@ private IAsyncResult BeginExecuteReaderInternal(CommandBehavior behavior, AsyncC } } + /// + /// used to convert an invocation through a cached static delegate back to an instance call + /// + private static SqlDataReader InternalEndExecuteReaderCallback(SqlCommand command, IAsyncResult asyncResult, bool isInternal, string endMethod) + { + return command.InternalEndExecuteReader(asyncResult, isInternal, endMethod); + } + /// + /// used to convert an invocation through a cached static delegate back to an instance call + /// + private static IAsyncResult BeginExecuteReaderInternalCallback(SqlCommand command, CommandBehavior behavior, AsyncCallback callback, object stateObject, int timeout, bool inRetry, bool asyncWrite) + { + return command.BeginExecuteReaderInternal(behavior, callback, stateObject, timeout, inRetry, asyncWrite); + } + /// + /// used to convert an invocation through a cached static delegate back to an instance call + /// + private static IAsyncResult BeginExecuteXmlReaderInternalCallback(SqlCommand command, CommandBehavior behavior, AsyncCallback callback, object stateObject, int timeout, bool inRetry, bool asyncWrite) + { + return command.BeginExecuteXmlReaderInternal(behavior, callback, stateObject, timeout, inRetry, asyncWrite); + } + /// + /// used to convert an invocation through a cached static delegate back to an instance call + /// + private static object InternalEndExecuteNonQueryCallback(SqlCommand command, IAsyncResult asyncResult, bool isInternal, string endMethod) + { + return command.InternalEndExecuteNonQuery(asyncResult, isInternal, endMethod); + } + /// + /// used to convert an invocation through a cached static delegate back to an instance call + /// + private static IAsyncResult BeginExecuteNonQueryInternalCallback(SqlCommand command, CommandBehavior behavior, AsyncCallback callback, object stateObject, int timeout, bool inRetry, bool asyncWrite) + { + return command.BeginExecuteNonQueryInternal(behavior, callback, stateObject, timeout, inRetry, asyncWrite); + } + private bool TriggerInternalEndAndRetryIfNecessary( CommandBehavior behavior, object stateObject, @@ -2058,9 +2173,9 @@ private bool TriggerInternalEndAndRetryIfNecessary( bool asyncWrite, TaskCompletionSource globalCompletion, TaskCompletionSource localCompletion, - Func endFunc, - Func retryFunc, - [CallerMemberName] string endMethod = "" + Func endFunc, + Func retryFunc, + string endMethod ) { // We shouldn't be using the cache if we are in retry. @@ -2086,9 +2201,19 @@ private bool TriggerInternalEndAndRetryIfNecessary( } } - private void CreateLocalCompletionTask(CommandBehavior behavior, object stateObject, int timeout, bool usedCache, - bool asyncWrite, TaskCompletionSource globalCompletion, TaskCompletionSource localCompletion, Func endFunc, - Func retryFunc, string endMethod, long firstAttemptStart) + private void CreateLocalCompletionTask( + CommandBehavior behavior, + object stateObject, + int timeout, + bool usedCache, + bool asyncWrite, + TaskCompletionSource globalCompletion, + TaskCompletionSource localCompletion, + Func endFunc, + Func retryFunc, + string endMethod, + long firstAttemptStart + ) { localCompletion.Task.ContinueWith(tsk => { @@ -2111,7 +2236,7 @@ private void CreateLocalCompletionTask(CommandBehavior behavior, object stateObj // lock on _stateObj prevents races with close/cancel. lock (_stateObj) { - endFunc(tsk, true, endMethod /*inInternal*/); + endFunc(this, tsk, true, endMethod /*inInternal*/); } globalCompletion.TrySetResult(tsk.Result); @@ -2172,7 +2297,7 @@ private void CreateLocalCompletionTask(CommandBehavior behavior, object stateObj { // Kick off the retry. _internalEndExecuteInitiated = false; - Task retryTask = (Task)retryFunc(behavior, null, stateObject, + Task retryTask = (Task)retryFunc(this, behavior, null, stateObject, TdsParserStaticMethods.GetRemainingTimeout(timeout, firstAttemptStart), true /*inRetry*/, asyncWrite); @@ -2266,7 +2391,7 @@ public override Task ExecuteNonQueryAsync(CancellationToken cancellationTok Task returnedTask = source.Task; try { - RegisterForConnectionCloseNotification(ref returnedTask); + returnedTask = RegisterForConnectionCloseNotification(returnedTask); Task.Factory.FromAsync(BeginExecuteNonQueryAsync, EndExecuteNonQueryAsync, null).ContinueWith((t) => { @@ -2337,7 +2462,9 @@ protected override Task ExecuteDbDataReaderAsync(CommandBehavior b SqlClientEventSource.Log.CorrelationTraceEvent(" ObjectID {0}, behavior={1}, ActivityID {2}", ObjectID, (int)behavior, ActivityCorrelator.Current); Guid operationId = default(Guid); if (!_parentOperationStarted) + { operationId = _diagnosticListener.WriteCommandBefore(this); + } TaskCompletionSource source = new TaskCompletionSource(); @@ -2355,37 +2482,34 @@ protected override Task ExecuteDbDataReaderAsync(CommandBehavior b Task returnedTask = source.Task; try { - RegisterForConnectionCloseNotification(ref returnedTask); + returnedTask = RegisterForConnectionCloseNotification(returnedTask); - Task.Factory.FromAsync((commandBehavior, callback, stateObject) => BeginExecuteReaderAsync(commandBehavior, callback, stateObject), EndExecuteReaderAsync, behavior, null).ContinueWith((t) => + ExecuteReaderAsyncCallContext context = null; + if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection) { - registration.Dispose(); - if (t.IsFaulted) - { - Exception e = t.Exception.InnerException; - if (!_parentOperationStarted) - _diagnosticListener.WriteCommandError(operationId, this, e); - source.SetException(e); - } - else - { - if (t.IsCanceled) - { - source.SetCanceled(); - } - else - { - source.SetResult(t.Result); - } - if (!_parentOperationStarted) - _diagnosticListener.WriteCommandAfter(operationId, this); - } - }, TaskScheduler.Default); + context = Interlocked.Exchange(ref sqlInternalConnection.CachedCommandExecuteReaderAsyncContext, null); + } + if (context is null) + { + context = new ExecuteReaderAsyncCallContext(); + } + context.Set(this, source, registration,behavior, operationId); + + Task.Factory.FromAsync( + beginMethod: s_beginExecuteReaderAsync, + endMethod: s_endExecuteReaderAsync, + state: context + ).ContinueWith( + continuationAction: s_cleanupExecuteReaderAsync, + TaskScheduler.Default + ); } catch (Exception e) { if (!_parentOperationStarted) + { _diagnosticListener.WriteCommandError(operationId, this, e); + } source.SetException(e); } @@ -2393,6 +2517,14 @@ protected override Task ExecuteDbDataReaderAsync(CommandBehavior b return returnedTask; } + private void SetCachedCommandExecuteReaderAsyncContext(ExecuteReaderAsyncCallContext instance) + { + if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection) + { + Interlocked.CompareExchange(ref sqlInternalConnection.CachedCommandExecuteReaderAsyncContext, instance, null); + } + } + /// public override Task ExecuteScalarAsync(CancellationToken cancellationToken) { @@ -2504,7 +2636,7 @@ public Task ExecuteXmlReaderAsync(CancellationToken cancellationToken Task returnedTask = source.Task; try { - RegisterForConnectionCloseNotification(ref returnedTask); + returnedTask = RegisterForConnectionCloseNotification(returnedTask); Task.Factory.FromAsync(BeginExecuteXmlReaderAsync, EndExecuteXmlReaderAsync, null).ContinueWith((t) => { @@ -4717,7 +4849,7 @@ private void FinishExecuteReader(SqlDataReader ds, RunBehavior runBehavior, stri } - private void RegisterForConnectionCloseNotification(ref Task outerTask) + private Task RegisterForConnectionCloseNotification(Task outerTask) { SqlConnection connection = _activeConnection; if (connection == null) @@ -4726,7 +4858,7 @@ private void RegisterForConnectionCloseNotification(ref Task outerTask) throw ADP.ClosedConnectionError(); } - connection.RegisterForConnectionCloseNotification(ref outerTask, this, SqlReferenceCollection.CommandTag); + return connection.RegisterForConnectionCloseNotification(outerTask, this, SqlReferenceCollection.CommandTag); } // validates that a command has commandText and a non-busy open connection diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs index 4f847bc1a7..868bb4d433 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -99,6 +99,7 @@ private static readonly ConcurrentDictionary> _ColumnEncry comparer: StringComparer.OrdinalIgnoreCase); private static readonly Action s_openAsyncCancel = OpenAsyncCancel; + private static readonly Action, object> s_openAsyncComplete = OpenAsyncComplete; /// public static TimeSpan ColumnEncryptionKeyCacheTtl { get; set; } = TimeSpan.FromHours(2); @@ -1282,22 +1283,16 @@ public override Task OpenAsync(CancellationToken cancellationToken) System.Transactions.Transaction transaction = ADP.GetCurrentTransaction(); TaskCompletionSource completion = new TaskCompletionSource(transaction); - TaskCompletionSource result = new TaskCompletionSource(); + TaskCompletionSource result = new TaskCompletionSource(state: this); if (s_diagnosticListener.IsEnabled(SqlClientDiagnosticListenerExtensions.SqlAfterOpenConnection) || s_diagnosticListener.IsEnabled(SqlClientDiagnosticListenerExtensions.SqlErrorOpenConnection)) { - result.Task.ContinueWith((t) => - { - if (t.Exception != null) - { - s_diagnosticListener.WriteConnectionOpenError(operationId, this, t.Exception); - } - else - { - s_diagnosticListener.WriteConnectionOpenAfter(operationId, this); - } - }, TaskScheduler.Default); + result.Task.ContinueWith( + continuationAction: s_openAsyncComplete, + state: operationId, // connection is passed in TaskCompletionSource async state + scheduler: TaskScheduler.Default + ); } if (cancellationToken.IsCancellationRequested) @@ -1355,6 +1350,20 @@ public override Task OpenAsync(CancellationToken cancellationToken) } } + private static void OpenAsyncComplete(Task task, object state) + { + Guid operationId = (Guid)state; + SqlConnection connection = (SqlConnection)task.AsyncState; + if (task.Exception != null) + { + s_diagnosticListener.WriteConnectionOpenError(operationId, connection, task.Exception); + } + else + { + s_diagnosticListener.WriteConnectionOpenAfter(operationId, connection); + } + } + private static void OpenAsyncCancel(object state) { ((TaskCompletionSource)state).TrySetCanceled(); @@ -1864,14 +1873,22 @@ private static void ChangePassword(string connectionString, SqlConnectionString // this only happens once per connection // SxS: using named file mapping APIs - internal void RegisterForConnectionCloseNotification(ref Task outerTask, object value, int tag) + internal Task RegisterForConnectionCloseNotification(Task outerTask, object value, int tag) { // Connection exists, schedule removal, will be added to ref collection after calling ValidateAndReconnect - outerTask = outerTask.ContinueWith(task => - { - RemoveWeakReference(value); - return task; - }, TaskScheduler.Default).Unwrap(); + return outerTask.ContinueWith( + continuationFunction: (task,state) => + { + Tuple parameters = (Tuple)state; + SqlConnection connection = parameters.Item1; + object obj = parameters.Item2; + + connection.RemoveWeakReference(obj); + return task; + }, + state: Tuple.Create(this, value), + scheduler: TaskScheduler.Default + ).Unwrap(); } /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs index e11677b0a1..34038ffc7b 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs @@ -22,6 +22,8 @@ abstract internal class SqlInternalConnection : DbConnectionInternal private bool _isGlobalTransactionEnabledForServer = false; // Whether Global Transactions are enabled for this Azure SQL DB Server private static readonly Guid _globalTransactionTMID = new Guid("1c742caf-6680-40ea-9c26-6b6846079764"); // ID of the Non-MSDTC, Azure SQL DB Transaction Manager + internal SqlCommand.ExecuteReaderAsyncCallContext CachedCommandExecuteReaderAsyncContext; + // if connection is not open: null // if connection is open: currently active database internal string CurrentDatabase { get; set; } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 915a28d664..5da0415401 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -2840,7 +2840,7 @@ public void ReadAsyncCallback(IntPtr key, PacketHandle packet, uint error) { if (_executionContext != null) { - ExecutionContext.Run(_executionContext, (state) => ReadAsyncCallbackCaptureException(source), null); + ExecutionContext.Run(_executionContext, state => ReadAsyncCallbackCaptureException((TaskCompletionSource)state), source); } else {