Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Commit

Permalink
propogate and detect socket errors in managed connections
Browse files Browse the repository at this point in the history
  • Loading branch information
Wraith2 committed Apr 21, 2019
1 parent c608dda commit 853eb42
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public SNIMarsConnection(SNIHandle lowerHandle)
_lowerHandle.SetAsyncCallbacks(HandleReceiveComplete, HandleSendComplete);
}

public SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
public SNIMarsHandle CreateMarsSession(TdsParserStateObject callbackObject, bool async)
{
lock (this)
{
Expand Down Expand Up @@ -126,12 +126,12 @@ public uint CheckConnection()
/// <summary>
/// Process a receive error
/// </summary>
public void HandleReceiveError(SNIPacket packet)
public void HandleReceiveError(SNIPacket packet, uint sniErrorCode)
{
Debug.Assert(Monitor.IsEntered(this), "HandleReceiveError was called without being locked.");
foreach (SNIMarsHandle handle in _sessions.Values)
{
handle.HandleReceiveError(packet);
handle.HandleReceiveError(packet, sniErrorCode);
}
packet?.Dispose();
}
Expand Down Expand Up @@ -161,7 +161,7 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode)
{
lock (this)
{
HandleReceiveError(packet);
HandleReceiveError(packet, sniErrorCode);
return;
}
}
Expand Down Expand Up @@ -192,7 +192,7 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode)
return;
}

HandleReceiveError(packet);
HandleReceiveError(packet, sniErrorCode);
return;
}
}
Expand Down Expand Up @@ -223,7 +223,7 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode)
return;
}

HandleReceiveError(packet);
HandleReceiveError(packet, sniErrorCode);
return;
}
}
Expand All @@ -234,7 +234,7 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode)
if (!_sessions.ContainsKey(_currentHeader.sessionId))
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.SMUX_PROV, 0, SNICommon.InvalidParameterError, string.Empty);
HandleReceiveError(packet);
HandleReceiveError(packet, sniErrorCode);
_lowerHandle.Dispose();
_lowerHandle = null;
return;
Expand Down Expand Up @@ -280,7 +280,7 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode)
return;
}

HandleReceiveError(packet);
HandleReceiveError(packet, sniErrorCode);
return;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ internal class SNIMarsHandle : SNIHandle
private readonly uint _status = TdsEnums.SNI_UNINITIALIZED;
private readonly Queue<SNIPacket> _receivedPacketQueue = new Queue<SNIPacket>();
private readonly Queue<SNIMarsQueuedPacket> _sendPacketQueue = new Queue<SNIMarsQueuedPacket>();
private readonly object _callbackObject;
private readonly TdsParserStateObject _callbackObject;
private readonly Guid _connectionId = Guid.NewGuid();
private readonly ushort _sessionId;
private readonly ManualResetEventSlim _packetEvent = new ManualResetEventSlim(false);
Expand Down Expand Up @@ -78,7 +78,7 @@ public override void Dispose()
/// <param name="sessionId">MARS session ID</param>
/// <param name="callbackObject">Callback object</param>
/// <param name="async">true if connection is asynchronous</param>
public SNIMarsHandle(SNIMarsConnection connection, ushort sessionId, object callbackObject, bool async)
public SNIMarsHandle(SNIMarsConnection connection, ushort sessionId, TdsParserStateObject callbackObject, bool async)
{
_sessionId = sessionId;
_connection = connection;
Expand All @@ -101,7 +101,7 @@ private void SendControlPacket(SNISMUXFlags flags)

SNIPacket packet = new SNIPacket(SNISMUXHeader.HEADER_LENGTH);
packet.AppendData(headerBytes);

_connection.Send(packet);
}

Expand Down Expand Up @@ -295,15 +295,23 @@ public override uint ReceiveAsync(ref SNIPacket packet)
/// <summary>
/// Handle receive error
/// </summary>
public void HandleReceiveError(SNIPacket packet)
public void HandleReceiveError(SNIPacket packet, uint sniErrorCode)
{
lock (_receivedPacketQueue)
{
_connectionError = SNILoadHandle.SingletonInstance.LastError;
_packetEvent.Set();
}

((TdsParserStateObject)_callbackObject).ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 1);
if (sniErrorCode == TdsEnums.SNI_WSAECONNRESET || sniErrorCode == TdsEnums.SNI_ERROR)
{
TdsParser parser = _callbackObject.Parser;
parser.State = TdsParserState.Broken;
parser.Connection.BreakConnection();
}
else
{
_callbackObject.ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), sniErrorCode);
}
}

/// <summary>
Expand All @@ -317,7 +325,7 @@ public void HandleSendComplete(SNIPacket packet, uint sniErrorCode)
{
Debug.Assert(_callbackObject != null);

((TdsParserStateObject)_callbackObject).WriteAsyncCallback(PacketHandle.FromManagedPacket(packet), sniErrorCode);
_callbackObject.WriteAsyncCallback(PacketHandle.FromManagedPacket(packet), sniErrorCode);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System.Buffers;
using System.IO;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

Expand All @@ -18,31 +19,39 @@ internal partial class SNIPacket
/// <param name="callback">Completion callback</param>
public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback)
{
// Treat local function as a static and pass all params otherwise as async will allocate
async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask<int> valueTask)
static async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask<int> valueTask)
{
bool error = false;
uint errorCode = TdsEnums.SNI_SUCCESS;
try
{
packet._length = await valueTask.ConfigureAwait(false);
if (packet._length == 0)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, string.Empty);
error = true;
errorCode = TdsEnums.SNI_WSAECONNRESET;
}
}
catch (IOException ioException) when (
ioException?.InnerException is SocketException socketException &&
socketException != null &&
socketException.SocketErrorCode == SocketError.OperationAborted
)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, socketException);
errorCode = TdsEnums.SNI_WSAECONNRESET;
}
catch (Exception ex)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, ex);
error = true;
errorCode = TdsEnums.SNI_ERROR;
}

if (error)
if (errorCode != TdsEnums.SNI_SUCCESS)
{
packet.Release();
}

cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
cb(packet, errorCode);
}

ValueTask<int> vt = stream.ReadAsync(new Memory<byte>(_data, 0, _capacity), CancellationToken.None);
Expand Down Expand Up @@ -70,8 +79,7 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask<
/// <param name="stream">Stream to write to</param>
public void WriteToStreamAsync(Stream stream, SNIAsyncCallback callback, SNIProviders provider, bool disposeAfterWriteAsync = false)
{
// Treat local function as a static and pass all params otherwise as async will allocate
async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProviders providers, bool disposeAfter, ValueTask valueTask)
static async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProviders providers, bool disposeAfter, ValueTask valueTask)
{
uint status = TdsEnums.SNI_SUCCESS;
try
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System.Buffers;
using System.IO;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

Expand All @@ -18,31 +19,39 @@ internal partial class SNIPacket
/// <param name="callback">Completion callback</param>
public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback)
{
// Treat local function as a static and pass all params otherwise as async will allocate
async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, Task<int> task)
static async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, Task<int> task)
{
bool error = false;
uint errorCode = TdsEnums.SNI_SUCCESS;
try
{
packet._length = await task.ConfigureAwait(false);
if (packet._length == 0)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, string.Empty);
error = true;
errorCode = TdsEnums.SNI_WSAECONNRESET;
}
}
catch (IOException ioException) when (
ioException?.InnerException is SocketException socketException &&
socketException != null &&
socketException.SocketErrorCode == SocketError.OperationAborted
)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, socketException);
errorCode = TdsEnums.SNI_WSAECONNRESET;
}
catch (Exception ex)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, ex);
error = true;
errorCode = TdsEnums.SNI_ERROR;
}

if (error)
if (errorCode != TdsEnums.SNI_SUCCESS)
{
packet.Release();
}

cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
cb(packet, errorCode);
}

Task<int> t = stream.ReadAsync(_data, 0, _capacity, CancellationToken.None);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ public enum FedAuthLibrary : byte
public const uint SNI_SUCCESS_IO_PENDING = 997; // Overlapped I/O operation is in progress.

// Windows Sockets Error Codes
public const short SNI_WSAECONNRESET = 10054; // An existing connection was forcibly closed by the remote host.
public const uint SNI_WSAECONNRESET = 10054; // An existing connection was forcibly closed by the remote host.

// SNI internal errors (shouldn't overlap with Win32 / socket errors)
public const uint SNI_QUEUE_FULL = 1048576; // Packet queue is full
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ protected override void CreateSessionHandle(TdsParserStateObject physicalConnect
_sessionHandle = managedSNIObject.CreateMarsSession(this, async);
}

internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
internal SNIMarsHandle CreateMarsSession(TdsParserStateObject callbackObject, bool async)
{
return _marsConnection.CreateMarsSession(callbackObject, async);
}
Expand Down Expand Up @@ -96,7 +96,7 @@ internal override void Dispose()
{
packetHandle?.Dispose();
asyncAttnPacket?.Dispose();

if (sessionHandle != null)
{
sessionHandle.Dispose();
Expand Down Expand Up @@ -239,7 +239,7 @@ internal override uint EnableMars(ref uint info)
return TdsEnums.SNI_ERROR;
}

internal override uint EnableSsl(ref uint info)=> SNIProxy.Singleton.EnableSsl(Handle, info);
internal override uint EnableSsl(ref uint info) => SNIProxy.Singleton.EnableSsl(Handle, info);

internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) => SNIProxy.Singleton.SetConnectionBufferSize(Handle, unsignedPacketSize);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ public static class ConnectionPoolTest
private static readonly string _tcpMarsConnStr = (new SqlConnectionStringBuilder(DataTestUtility.TcpConnStr) { MultipleActiveResultSets = true, Pooling = true }).ConnectionString;


[ConditionalFact(typeof(DataTestUtility),nameof(DataTestUtility.AreConnStringsSetup), /* [ActiveIssue(33930)]: */ nameof(DataTestUtility.IsUsingNativeSNI))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
public static void ConnectionPool_NonMars()
{
RunDataTestForSingleConnString(_tcpConnStr);
}

[ConditionalFact(typeof(DataTestUtility),nameof(DataTestUtility.AreConnStringsSetup), /* [ActiveIssue(33930)] */ nameof(DataTestUtility.IsUsingNativeSNI))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
public static void ConnectionPool_Mars()
{
RunDataTestForSingleConnString(_tcpMarsConnStr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class PoolBlockPeriodTest
private const int ConnectionTimeout = 15;
private const int CompareMargin = 2;

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), /* [ActiveIssue(33930)] */ nameof(DataTestUtility.IsUsingNativeSNI))]
[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
[InlineData("Azure with Default Policy must Disable blocking (*.database.windows.net)", new object[] { AzureEndpointSample })]
[InlineData("Azure with Default Policy must Disable blocking (*.database.chinacloudapi.cn)", new object[] { AzureChinaEnpointSample })]
[InlineData("Azure with Default Policy must Disable blocking (*.database.usgovcloudapi.net)", new object[] { AzureUSGovernmentEndpointSample })]
Expand All @@ -45,7 +45,7 @@ public void TestAzureBlockingPeriod(string description, object[] Params)
PoolBlockingPeriodAzureTest(connString, policy);
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), /* [ActiveIssue(33930)] */ nameof(DataTestUtility.IsUsingNativeSNI))]
[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
[InlineData("NonAzure with Default Policy must Enable blocking", new object[] { NonExistentServer })]
[InlineData("NonAzure with Auto Policy must Enable Blocking", new object[] { NonExistentServer, PoolBlockingPeriod.Auto })]
[InlineData("NonAzure with Always Policy must Enable Blocking", new object[] { NonExistentServer, PoolBlockingPeriod.AlwaysBlock })]
Expand All @@ -66,7 +66,7 @@ public void TestNonAzureBlockingPeriod(string description, object[] Params)
PoolBlockingPeriodNonAzureTest(connString, policy);
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), /* [ActiveIssue(33930)] */ nameof(DataTestUtility.IsUsingNativeSNI))]
[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
[InlineData("Test policy with Auto (lowercase)", "auto")]
[InlineData("Test policy with Auto (PascalCase)", "Auto")]
[InlineData("Test policy with Always (lowercase)", "alwaysblock")]
Expand Down
Loading

0 comments on commit 853eb42

Please sign in to comment.