Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify SNIProxy #934

Merged
merged 3 commits into from
Jun 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,48 +21,9 @@ internal class SNIProxy
private const int DefaultSqlServerDacPort = 1434;
private const string SqlServerSpnHeader = "MSSQLSvc";

internal class SspiClientContextResult
{
internal const uint OK = 0;
internal const uint Failed = 1;
internal const uint KerberosTicketMissing = 2;
}

internal static readonly SNIProxy s_singleton = new SNIProxy();
private static readonly SNIProxy s_singleton = new SNIProxy();

internal static SNIProxy GetInstance() => s_singleton;

/// <summary>
/// Enable SSL on a connection
/// </summary>
/// <param name="handle">Connection handle</param>
/// <param name="options"></param>
/// <returns>SNI error code</returns>
internal uint EnableSsl(SNIHandle handle, uint options)
{
try
{
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.EnableSsl | Info | Session Id {0}", handle?.ConnectionId);
return handle.EnableSsl(options);
}
catch (Exception e)
{
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.EnableSsl | Err | Session Id {0}, SNI Handshake failed with exception: {1}", handle?.ConnectionId, e?.Message);
return SNICommon.ReportSNIError(SNIProviders.SSL_PROV, SNICommon.HandshakeFailureError, e);
}
}

/// <summary>
/// Disable SSL on a connection
/// </summary>
/// <param name="handle">Connection handle</param>
/// <returns>SNI error code</returns>
internal uint DisableSsl(SNIHandle handle)
{
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.DisableSsl | Info | Session Id {0}", handle?.ConnectionId);
handle.DisableSsl();
return TdsEnums.SNI_SUCCESS;
}
internal static SNIProxy Instance => s_singleton;

/// <summary>
/// Generate SSPI context
Expand All @@ -72,7 +33,7 @@ internal uint DisableSsl(SNIHandle handle)
/// <param name="sendBuff">Send buffer</param>
/// <param name="serverName">Service Principal Name buffer</param>
/// <returns>SNI error code</returns>
internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
{
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
Expand Down Expand Up @@ -165,83 +126,6 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
errorCode != SecurityStatusPalErrorCode.Renegotiate;
}

/// <summary>
/// Set connection buffer size
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="bufferSize">Buffer size</param>
/// <returns>SNI error code</returns>
internal uint SetConnectionBufferSize(SNIHandle handle, uint bufferSize)
{
handle.SetBufferSize((int)bufferSize);
return TdsEnums.SNI_SUCCESS;
}

/// <summary>
/// Copies data in SNIPacket to given byte array parameter
/// </summary>
/// <param name="packet">SNIPacket object containing data packets</param>
/// <param name="inBuff">Destination byte array where data packets are copied to</param>
/// <param name="dataSize">Length of data packets</param>
/// <returns>SNI error status</returns>
Wraith2 marked this conversation as resolved.
Show resolved Hide resolved
internal uint PacketGetData(SNIPacket packet, byte[] inBuff, ref uint dataSize)
{
int dataSizeInt = 0;
packet.GetData(inBuff, ref dataSizeInt);
dataSize = (uint)dataSizeInt;

return TdsEnums.SNI_SUCCESS;
}

/// <summary>
/// Read synchronously
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="packet">SNI packet</param>
/// <param name="timeout">Timeout</param>
/// <returns>SNI error status</returns>
internal uint ReadSyncOverAsync(SNIHandle handle, out SNIPacket packet, int timeout)
{
return handle.Receive(out packet, timeout);
}

/// <summary>
/// Get SNI connection ID
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="clientConnectionId">Client connection ID</param>
/// <returns>SNI error status</returns>
internal uint GetConnectionId(SNIHandle handle, ref Guid clientConnectionId)
{
clientConnectionId = handle.ConnectionId;
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.GetConnectionId | Info | Session Id {0}", clientConnectionId);
return TdsEnums.SNI_SUCCESS;
}

/// <summary>
/// Send a packet
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="packet">SNI packet</param>
/// <param name="sync">true if synchronous, false if asynchronous</param>
/// <returns>SNI error status</returns>
internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
{
uint result;
if (sync)
{
result = handle.Send(packet);
handle.ReturnPacket(packet);
}
else
{
result = handle.SendAsync(packet);
}

SqlClientEventSource.Log.TryTraceEvent("SNIProxy.WritePacket | Info | Session Id {0}, SendAsync Result {1}", handle?.ConnectionId, result);
return result;
}

/// <summary>
/// Create a SNI connection handle
/// </summary>
Expand All @@ -258,7 +142,7 @@ internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
/// <param name="cachedFQDN">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNI handle</returns>
internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer,
internal static SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer,
bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
instanceName = new byte[1];
Expand Down Expand Up @@ -380,7 +264,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
/// <param name="cachedFQDN">Key for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNITCPHandle</returns>
private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
// TCP Format:
// tcp:<host name>\<instance name>
Expand Down Expand Up @@ -421,16 +305,14 @@ private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool
return new SNITCPHandle(hostName, port, timerExpire, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo);
}



/// <summary>
/// Creates an SNINpHandle object
/// </summary>
/// <param name="details">Data source</param>
/// <param name="timerExpire">Timer expiration</param>
/// <param name="parallel">Should MultiSubnetFailover be used. Only returns an error for named pipes.</param>
/// <returns>SNINpHandle</returns>
private SNINpHandle CreateNpHandle(DataSource details, long timerExpire, bool parallel)
private static SNINpHandle CreateNpHandle(DataSource details, long timerExpire, bool parallel)
{
if (parallel)
{
Expand All @@ -441,39 +323,6 @@ private SNINpHandle CreateNpHandle(DataSource details, long timerExpire, bool pa
return new SNINpHandle(details.PipeHostName, details.PipeName, timerExpire);
}

/// <summary>
/// Read packet asynchronously
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="packet">Packet</param>
/// <returns>SNI error status</returns>
internal uint ReadAsync(SNIHandle handle, out SNIPacket packet)
{
packet = null;
return handle.ReceiveAsync(ref packet);
}

/// <summary>
/// Set packet data
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="data">Data</param>
/// <param name="length">Length</param>
internal void PacketSetData(SNIPacket packet, byte[] data, int length)
{
packet.AppendData(data, length);
}

/// <summary>
/// Check SNI handle connection
/// </summary>
/// <param name="handle"></param>
/// <returns>SNI error status</returns>
internal uint CheckConnection(SNIHandle handle)
{
return handle.CheckConnection();
}

/// <summary>
/// Get last SNI error on this thread
/// </summary>
Expand All @@ -489,7 +338,7 @@ internal SNIError GetLastError()
/// <param name="fullServerName">The data source</param>
/// <param name="error">Set true when an error occurred while getting LocalDB up</param>
/// <returns></returns>
private string GetLocalDBDataSource(string fullServerName, out bool error)
private static string GetLocalDBDataSource(string fullServerName, out bool error)
{
string localDBConnectionString = null;
bool isBadLocalDBDataSource;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ private void WaitForSSLHandShakeToComplete(ref uint error, ref int protocolVersi
private SNIErrorDetails GetSniErrorDetails()
{
SNIErrorDetails details;
SNIError sniError = SNIProxy.GetInstance().GetLastError();
SNIError sniError = SNIProxy.Instance.GetLastError();
details.sniErrorNumber = sniError.sniError;
details.errorMessage = sniError.errorMessage;
details.nativeError = sniError.nativeError;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ private SNIErrorDetails GetSniErrorDetails()

if (TdsParserStateObjectFactory.UseManagedSNI)
{
SNIError sniError = SNIProxy.GetInstance().GetLastError();
SNIError sniError = SNIProxy.Instance.GetLastError();
details.sniErrorNumber = sniError.sniError;
details.errorMessage = sniError.errorMessage;
details.nativeError = sniError.nativeError;
Expand Down
Loading