diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs index e5861df337..7c71dcdedb 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs @@ -567,17 +567,6 @@ internal SqlDataReader FindLiveReader(SqlCommand command) return reader; } - internal SqlCommand FindLiveCommand(TdsParserStateObject stateObj) - { - SqlCommand command = null; - SqlReferenceCollection referenceCollection = (SqlReferenceCollection)ReferenceCollection; - if (null != referenceCollection) - { - command = referenceCollection.FindLiveCommand(stateObj); - } - return command; - } - abstract protected byte[] GetDTCAddress(); static private byte[] GetTransactionCookie(Transaction transaction, byte[] whereAbouts) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlReferenceCollection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlReferenceCollection.cs index efcb66cf28..f6780b1bf4 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlReferenceCollection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlReferenceCollection.cs @@ -2,18 +2,38 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Diagnostics; +using System.Threading; using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient { sealed internal class SqlReferenceCollection : DbReferenceCollection { + private sealed class FindLiveReaderContext + { + public readonly Func Func; + + private SqlCommand _command; + + public FindLiveReaderContext() => Func = Predicate; + + public void Setup(SqlCommand command) => _command = command; + + public void Clear() => _command = null; + + private bool Predicate(SqlDataReader reader) => (!reader.IsClosed) && (_command == reader.Command); + } + internal const int DataReaderTag = 1; internal const int CommandTag = 2; internal const int BulkCopyTag = 3; - override public void Add(object value, int tag) + private readonly static Func s_hasOpenReaderFunc = HasOpenReaderPredicate; + private static FindLiveReaderContext s_cachedFindLiveReaderContext; + + public override void Add(object value, int tag) { Debug.Assert(DataReaderTag == tag || CommandTag == tag || BulkCopyTag == tag, "unexpected tag?"); Debug.Assert(DataReaderTag != tag || value is SqlDataReader, "tag doesn't match object type: SqlDataReader"); @@ -30,25 +50,24 @@ internal void Deactivate() internal SqlDataReader FindLiveReader(SqlCommand command) { - if (command == null) + if (command is null) { // if null == command, will find first live datareader - return FindItem(DataReaderTag, (dataReader) => (!dataReader.IsClosed)); + return FindItem(DataReaderTag, s_hasOpenReaderFunc); } else { // else will find live datareader associated with the command - return FindItem(DataReaderTag, (dataReader) => ((!dataReader.IsClosed) && (command == dataReader.Command))); + FindLiveReaderContext context = Interlocked.Exchange(ref s_cachedFindLiveReaderContext, null) ?? new FindLiveReaderContext(); + context.Setup(command); + SqlDataReader retval = FindItem(DataReaderTag, context.Func); + context.Clear(); + Interlocked.CompareExchange(ref s_cachedFindLiveReaderContext, context, null); + return retval; } } - // Finds a SqlCommand associated with the given StateObject - internal SqlCommand FindLiveCommand(TdsParserStateObject stateObj) - { - return FindItem(CommandTag, (command) => (command.StateObject == stateObj)); - } - - override protected void NotifyItem(int message, int tag, object value) + protected override void NotifyItem(int message, int tag, object value) { Debug.Assert(0 == message, "unexpected message?"); Debug.Assert(DataReaderTag == tag || CommandTag == tag || BulkCopyTag == tag, "unexpected tag?"); @@ -74,11 +93,13 @@ override protected void NotifyItem(int message, int tag, object value) } } - override public void Remove(object value) + public override void Remove(object value) { Debug.Assert(value is SqlDataReader || value is SqlCommand || value is SqlBulkCopy, "SqlReferenceCollection.Remove expected a SqlDataReader or SqlCommand or SqlBulkCopy"); base.RemoveItem(value); } + + private static bool HasOpenReaderPredicate(SqlDataReader reader) => !reader.IsClosed; } }