Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix | Fix possible server connection leak if an exception occurs in pooling layer #890

Merged
merged 8 commits into from
Apr 14, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,14 @@ private DbConnectionInternal CreateObject(DbConnection owningObject, DbConnectio

CheckPoolBlockingPeriod(e);

// Close associated Parser if connection already established.
if (newObj?.IsConnectionAlive() == true)
{
newObj.Dispose();
}

newObj = null; // set to null, so we do not return bad new object

// Failed to create instance
_resError = e;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,14 @@ private DbConnectionInternal CreateObject(DbConnection owningObject, DbConnectio
throw;
}

// Close associated Parser if connection already established.
if (newObj?.IsConnectionAlive() == true)
{
newObj.Dispose();
}

newObj = null; // set to null, so we do not return bad new object

// Failed to create instance
_resError = e;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Data;
using System.Data.Common;
using System.Reflection;
using System.Security;
using Microsoft.SqlServer.TDS.Servers;
using Xunit;

namespace Microsoft.Data.SqlClient.Tests
Expand Down Expand Up @@ -40,6 +42,37 @@ public void IntegratedAuthConnectionTest()
}
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotArmProcess))]
[PlatformSpecific(TestPlatforms.Windows)]
public void TransientFaultTest()
{
using (TransientFaultTDSServer server = TransientFaultTDSServer.StartTestServer(true, true, 40613))
{
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder()
{
DataSource = "localhost," + server.Port,
IntegratedSecurity = true
};

using (SqlConnection connection = new SqlConnection(builder.ConnectionString))
{
try
{
connection.Open();
Assert.Equal(ConnectionState.Open, connection.State);
}
catch (Exception e)
{
if (null != connection)
{
Assert.Equal(ConnectionState.Closed, connection.State);
}
Assert.False(true, e.Message);
}
}
}
}

[Fact]
public void SqlConnectionDbProviderFactoryTest()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@ namespace Microsoft.Data.SqlClient.Tests
{
internal class TestTdsServer : GenericTDSServer, IDisposable
{
private const int DefaultConnectionTimeout = 5;

private TDSServerEndPoint _endpoint = null;

private SqlConnectionStringBuilder connectionStringBuilder;
private SqlConnectionStringBuilder _connectionStringBuilder;

public TestTdsServer(TDSServerArguments args) : base(args) { }

public TestTdsServer(QueryEngine engine, TDSServerArguments args) : base(args)
{
this.Engine = engine;
Engine = engine;
}

public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool enableFedAuth = false, bool enableLog = false, [CallerMemberName] string methodName = "")
public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool enableFedAuth = false, bool enableLog = false, int connectionTimeout = DefaultConnectionTimeout, [CallerMemberName] string methodName = "")
{
TDSServerArguments args = new TDSServerArguments()
{
Expand All @@ -32,7 +34,7 @@ public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool

if (enableFedAuth)
{
args.FedAuthRequiredPreLoginOption = Microsoft.SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired;
args.FedAuthRequiredPreLoginOption = SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired;
}

TestTdsServer server = engine == null ? new TestTdsServer(args) : new TestTdsServer(engine, args);
Expand All @@ -43,14 +45,14 @@ public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool
server._endpoint.Start();

int port = server._endpoint.ServerEndPoint.Port;
server.connectionStringBuilder = new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = 5, Encrypt = false };
server.ConnectionString = server.connectionStringBuilder.ConnectionString;
server._connectionStringBuilder = new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = false };
server.ConnectionString = server._connectionStringBuilder.ConnectionString;
return server;
}

public static TestTdsServer StartTestServer(bool enableFedAuth = false, bool enableLog = false, [CallerMemberName] string methodName = "")
public static TestTdsServer StartTestServer(bool enableFedAuth = false, bool enableLog = false, int connectionTimeout = DefaultConnectionTimeout, [CallerMemberName] string methodName = "")
{
return StartServerWithQueryEngine(null, false, false, methodName);
return StartServerWithQueryEngine(null, enableFedAuth, enableLog, connectionTimeout, methodName);
}

public void Dispose() => _endpoint?.Stop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
<Compile Include="RoutingTDSServerArguments.cs" />
<Compile Include="ServerNameFilterType.cs" />
<Compile Include="TDSServerArguments.cs" />
<Compile Include="TransientFaultTDSServer.cs" />
<Compile Include="TransientFaultTDSServerArguments.cs" />
<None Include="TdsServerCertificate.pfx">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Net;
using System.Runtime.CompilerServices;
using System.Threading;
using Microsoft.SqlServer.TDS.Done;
using Microsoft.SqlServer.TDS.EndPoint;
using Microsoft.SqlServer.TDS.Error;
using Microsoft.SqlServer.TDS.Login7;

namespace Microsoft.SqlServer.TDS.Servers
{
/// <summary>
/// TDS Server that authenticates clients according to the requested parameters
/// </summary>
public class TransientFaultTDSServer : GenericTDSServer, IDisposable
{
private static int RequestCounter = 0;

public int Port { get; set; }

/// <summary>
/// Constructor
/// </summary>
public TransientFaultTDSServer() => new TransientFaultTDSServer(new TransientFaultTDSServerArguments());

/// <summary>
/// Constructor
/// </summary>
/// <param name="arguments"></param>
public TransientFaultTDSServer(TransientFaultTDSServerArguments arguments) :
base(arguments)
{ }

/// <summary>
/// Constructor
/// </summary>
/// <param name="engine"></param>
/// <param name="args"></param>
public TransientFaultTDSServer(QueryEngine engine, TransientFaultTDSServerArguments args) : base(args)
{
Engine = engine;
}

private TDSServerEndPoint _endpoint = null;

private static string GetErrorMessage(uint errorNumber)
{
switch (errorNumber)
{
case 40613:
return "Database on server is not currently available. Please retry the connection later. " +
"If the problem persists, contact customer support, and provide them the session tracing ID.";
}
return "Unknown server error occurred";
}

/// <summary>
/// Handler for login request
/// </summary>
public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessage request)
{
// Inflate login7 request from the message
TDSLogin7Token loginRequest = request[0] as TDSLogin7Token;

// Check if arguments are of the transient fault TDS server
if (Arguments is TransientFaultTDSServerArguments)
{
// Cast to transient fault TDS server arguments
TransientFaultTDSServerArguments ServerArguments = Arguments as TransientFaultTDSServerArguments;

// Check if we're still going to raise transient error
if (ServerArguments.IsEnabledTransientError && RequestCounter < 1) // Fail first time, then connect
{
uint errorNumber = ServerArguments.Number;
string errorMessage = ServerArguments.Message;

// Log request to which we're about to send a failure
TDSUtilities.Log(Arguments.Log, "Request", loginRequest);

// Prepare ERROR token with the denial details
TDSErrorToken errorToken = new TDSErrorToken(errorNumber, 1, 20, errorMessage);

// Log response
TDSUtilities.Log(Arguments.Log, "Response", errorToken);

// Serialize the error token into the response packet
TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken);

// Create DONE token
TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error);

// Log response
TDSUtilities.Log(Arguments.Log, "Response", doneToken);

// Serialize DONE token into the response packet
responseMessage.Add(doneToken);

RequestCounter++;

// Put a single message into the collection and return it
return new TDSMessageCollection(responseMessage);
}
}

// Return login response from the base class
return base.OnLogin7Request(session, request);
}

public static TransientFaultTDSServer StartTestServer(bool isEnabledTransientFault, bool enableLog, uint errorNumber, [CallerMemberName] string methodName = "")
=> StartServerWithQueryEngine(null, isEnabledTransientFault, enableLog, errorNumber, methodName);

public static TransientFaultTDSServer StartServerWithQueryEngine(QueryEngine engine, bool isEnabledTransientFault, bool enableLog, uint errorNumber, [CallerMemberName] string methodName = "")
{
TransientFaultTDSServerArguments args = new TransientFaultTDSServerArguments()
{
Log = enableLog ? Console.Out : null,
IsEnabledTransientError = isEnabledTransientFault,
Number = errorNumber,
Message = GetErrorMessage(errorNumber)
};

TransientFaultTDSServer server = engine == null ? new TransientFaultTDSServer(args) : new TransientFaultTDSServer(engine, args);
server._endpoint = new TDSServerEndPoint(server) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) };
server._endpoint.EndpointName = methodName;

// The server EventLog should be enabled as it logs the exceptions.
server._endpoint.EventLog = Console.Out;
server._endpoint.Start();

server.Port = server._endpoint.ServerEndPoint.Port;
return server;
}

public void Dispose() => Dispose(true);

private void Dispose(bool isDisposing)
{
if (isDisposing)
{
_endpoint?.Stop();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.SqlServer.TDS.Servers
{
public class TransientFaultTDSServerArguments : TDSServerArguments
{
/// <summary>
/// Transient error number to be raised by server.
/// </summary>
public uint Number { get; set; }

/// <summary>
/// Transient error message to be raised by server.
/// </summary>
public string Message { get; set; }

/// <summary>
/// Flag to consider when raising Transient error.
/// </summary>
public bool IsEnabledTransientError { get; set; }

/// <summary>
/// Constructor to initialize
/// </summary>
public TransientFaultTDSServerArguments()
{
Number = 0;
Message = string.Empty;
IsEnabledTransientError = false;
}
}
}