Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify SNIProxy #934

Merged
merged 3 commits into from
Jun 14, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,38 +32,6 @@ internal class SspiClientContextResult

internal static SNIProxy GetInstance() => s_singleton;
Wraith2 marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Enable SSL on a connection
/// </summary>
/// <param name="handle">Connection handle</param>
/// <param name="options"></param>
/// <returns>SNI error code</returns>
internal uint EnableSsl(SNIHandle handle, uint options)
{
try
{
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.EnableSsl | Info | Session Id {0}", handle?.ConnectionId);
return handle.EnableSsl(options);
}
catch (Exception e)
{
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.EnableSsl | Err | Session Id {0}, SNI Handshake failed with exception: {1}", handle?.ConnectionId, e?.Message);
return SNICommon.ReportSNIError(SNIProviders.SSL_PROV, SNICommon.HandshakeFailureError, e);
}
}

/// <summary>
/// Disable SSL on a connection
/// </summary>
/// <param name="handle">Connection handle</param>
/// <returns>SNI error code</returns>
internal uint DisableSsl(SNIHandle handle)
{
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.DisableSsl | Info | Session Id {0}", handle?.ConnectionId);
handle.DisableSsl();
return TdsEnums.SNI_SUCCESS;
}

/// <summary>
/// Generate SSPI context
/// </summary>
Expand All @@ -72,7 +40,7 @@ internal uint DisableSsl(SNIHandle handle)
/// <param name="sendBuff">Send buffer</param>
/// <param name="serverName">Service Principal Name buffer</param>
/// <returns>SNI error code</returns>
internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
{
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
Expand Down Expand Up @@ -165,83 +133,6 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
errorCode != SecurityStatusPalErrorCode.Renegotiate;
}

/// <summary>
/// Set connection buffer size
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="bufferSize">Buffer size</param>
/// <returns>SNI error code</returns>
internal uint SetConnectionBufferSize(SNIHandle handle, uint bufferSize)
{
handle.SetBufferSize((int)bufferSize);
return TdsEnums.SNI_SUCCESS;
}

/// <summary>
/// Copies data in SNIPacket to given byte array parameter
/// </summary>
/// <param name="packet">SNIPacket object containing data packets</param>
/// <param name="inBuff">Destination byte array where data packets are copied to</param>
/// <param name="dataSize">Length of data packets</param>
/// <returns>SNI error status</returns>
Wraith2 marked this conversation as resolved.
Show resolved Hide resolved
internal uint PacketGetData(SNIPacket packet, byte[] inBuff, ref uint dataSize)
{
int dataSizeInt = 0;
packet.GetData(inBuff, ref dataSizeInt);
dataSize = (uint)dataSizeInt;

return TdsEnums.SNI_SUCCESS;
}

/// <summary>
/// Read synchronously
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="packet">SNI packet</param>
/// <param name="timeout">Timeout</param>
/// <returns>SNI error status</returns>
internal uint ReadSyncOverAsync(SNIHandle handle, out SNIPacket packet, int timeout)
{
return handle.Receive(out packet, timeout);
}

/// <summary>
/// Get SNI connection ID
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="clientConnectionId">Client connection ID</param>
/// <returns>SNI error status</returns>
internal uint GetConnectionId(SNIHandle handle, ref Guid clientConnectionId)
{
clientConnectionId = handle.ConnectionId;
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.GetConnectionId | Info | Session Id {0}", clientConnectionId);
return TdsEnums.SNI_SUCCESS;
}

/// <summary>
/// Send a packet
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="packet">SNI packet</param>
/// <param name="sync">true if synchronous, false if asynchronous</param>
/// <returns>SNI error status</returns>
internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
{
uint result;
if (sync)
{
result = handle.Send(packet);
handle.ReturnPacket(packet);
}
else
{
result = handle.SendAsync(packet);
}

SqlClientEventSource.Log.TryTraceEvent("SNIProxy.WritePacket | Info | Session Id {0}, SendAsync Result {1}", handle?.ConnectionId, result);
return result;
}

/// <summary>
/// Create a SNI connection handle
/// </summary>
Expand All @@ -258,7 +149,7 @@ internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
/// <param name="cachedFQDN">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNI handle</returns>
internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer,
internal static SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer,
bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
instanceName = new byte[1];
Expand Down Expand Up @@ -380,7 +271,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
/// <param name="cachedFQDN">Key for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNITCPHandle</returns>
private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
// TCP Format:
// tcp:<host name>\<instance name>
Expand Down Expand Up @@ -421,16 +312,14 @@ private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool
return new SNITCPHandle(hostName, port, timerExpire, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo);
}



/// <summary>
/// Creates an SNINpHandle object
/// </summary>
/// <param name="details">Data source</param>
/// <param name="timerExpire">Timer expiration</param>
/// <param name="parallel">Should MultiSubnetFailover be used. Only returns an error for named pipes.</param>
/// <returns>SNINpHandle</returns>
private SNINpHandle CreateNpHandle(DataSource details, long timerExpire, bool parallel)
private static SNINpHandle CreateNpHandle(DataSource details, long timerExpire, bool parallel)
{
if (parallel)
{
Expand All @@ -441,39 +330,6 @@ private SNINpHandle CreateNpHandle(DataSource details, long timerExpire, bool pa
return new SNINpHandle(details.PipeHostName, details.PipeName, timerExpire);
}

/// <summary>
/// Read packet asynchronously
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="packet">Packet</param>
/// <returns>SNI error status</returns>
internal uint ReadAsync(SNIHandle handle, out SNIPacket packet)
{
packet = null;
return handle.ReceiveAsync(ref packet);
}

/// <summary>
/// Set packet data
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="data">Data</param>
/// <param name="length">Length</param>
internal void PacketSetData(SNIPacket packet, byte[] data, int length)
{
packet.AppendData(data, length);
}

/// <summary>
/// Check SNI handle connection
/// </summary>
/// <param name="handle"></param>
/// <returns>SNI error status</returns>
internal uint CheckConnection(SNIHandle handle)
{
return handle.CheckConnection();
}

/// <summary>
/// Get last SNI error on this thread
/// </summary>
Expand All @@ -489,7 +345,7 @@ internal SNIError GetLastError()
/// <param name="fullServerName">The data source</param>
/// <param name="error">Set true when an error occurred while getting LocalDB up</param>
/// <returns></returns>
private string GetLocalDBDataSource(string fullServerName, out bool error)
private static string GetLocalDBDataSource(string fullServerName, out bool error)
{
string localDBConnectionString = null;
bool isBadLocalDBDataSource;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,18 @@ internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
return _marsConnection.CreateMarsSession(callbackObject, async);
}

protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
=> SNIProxy.GetInstance().PacketGetData(packet.ManagedPacket, _inBuff, ref dataSize);
protected override uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize)
{
int dataSizeInt = 0;
packet.ManagedPacket.GetData(inBuff, ref dataSizeInt);
dataSize = (uint)dataSizeInt;
return TdsEnums.SNI_SUCCESS;
}

internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel,
SqlConnectionIPAddressPreference iPAddressPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
{
_sessionHandle = SNIProxy.GetInstance().CreateConnectionHandle(serverName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref spnBuffer, flushCache, async, parallel, isIntegratedSecurity,
_sessionHandle = SNIProxy.CreateConnectionHandle(serverName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref spnBuffer, flushCache, async, parallel, isIntegratedSecurity,
iPAddressPreference, cachedFQDN, ref pendingDNSInfo);
if (_sessionHandle == null)
{
Expand Down Expand Up @@ -162,7 +167,9 @@ internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint
{
throw ADP.ClosedConnectionError();
}
error = SNIProxy.GetInstance().ReadSyncOverAsync(handle, out SNIPacket packet, timeoutRemaining);

error = handle.Receive(out SNIPacket packet, timeoutRemaining);
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved

SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.ReadSyncOverAsync | Info | State Object Id {0}, Session Id {1}", _objectID, _sessionHandle?.ConnectionId);
#if DEBUG
SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObjectManaged.ReadSyncOverAsync | TRC | State Object Id {0}, Session Id {1}, Packet {2} received, Packet owner Id {3}, Packet dataLeft {4}", _objectID, _sessionHandle?.ConnectionId, packet?._id, packet?._owner.ConnectionId, packet?.DataLeft);
Expand Down Expand Up @@ -191,12 +198,14 @@ internal override void ReleasePacket(PacketHandle syncReadPacket)
internal override uint CheckConnection()
{
SNIHandle handle = Handle;
return handle == null ? TdsEnums.SNI_SUCCESS : SNIProxy.GetInstance().CheckConnection(handle);
return handle == null ? TdsEnums.SNI_SUCCESS : handle.CheckConnection();
}

internal override PacketHandle ReadAsync(SessionHandle handle, out uint error)
{
error = SNIProxy.GetInstance().ReadAsync(handle.ManagedHandle, out SNIPacket packet);
SNIPacket packet = null;
error = handle.ManagedHandle.ReceiveAsync(ref packet);

SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.ReadAsync | Info | State Object Id {0}, Session Id {1}, Packet DataLeft {2}", _objectID, _sessionHandle?.ConnectionId, packet?.DataLeft);
return PacketHandle.FromManagedPacket(packet);
}
Expand All @@ -214,8 +223,24 @@ internal override PacketHandle CreateAndSetAttentionPacket()
return packetHandle;
}

internal override uint WritePacket(PacketHandle packet, bool sync) =>
SNIProxy.GetInstance().WritePacket(Handle, packet.ManagedPacket, sync);
internal override uint WritePacket(PacketHandle packetHandle, bool sync)
{
uint result;
SNIHandle handle = Handle;
SNIPacket packet = packetHandle.ManagedPacket;
if (sync)
{
result = handle.Send(packet);
handle.ReturnPacket(packet);
}
else
{
result = handle.SendAsync(packet);
}

SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.WritePacket | Info | Session Id {0}, SendAsync Result {1}", handle?.ConnectionId, result);
return result;
}

// No- Op in managed SNI
internal override PacketHandle AddPacketToPendingList(PacketHandle packet) => packet;
Expand Down Expand Up @@ -246,11 +271,25 @@ internal override void ClearAllWritePackets()
Debug.Assert(_asyncWriteCount == 0, "Should not clear all write packets if there are packets pending");
}

internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed) => SNIProxy.GetInstance().PacketSetData(packet.ManagedPacket, buffer, bytesUsed);
internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed)
{
packet.ManagedPacket.AppendData(buffer, bytesUsed);
}

internal override uint SniGetConnectionId(ref Guid clientConnectionId) => SNIProxy.GetInstance().GetConnectionId(Handle, ref clientConnectionId);
internal override uint SniGetConnectionId(ref Guid clientConnectionId)
{
clientConnectionId = Handle.ConnectionId;
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.GetConnectionId | Info | Session Id {0}", clientConnectionId);
Wraith2 marked this conversation as resolved.
Show resolved Hide resolved
return TdsEnums.SNI_SUCCESS;
}

internal override uint DisableSsl() => SNIProxy.GetInstance().DisableSsl(Handle);
internal override uint DisableSsl()
{
SNIHandle handle = Handle;
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.DisableSsl | Info | Session Id {0}", handle?.ConnectionId);
handle.DisableSsl();
return TdsEnums.SNI_SUCCESS;
}

internal override uint EnableMars(ref uint info)
{
Expand All @@ -265,9 +304,26 @@ internal override uint EnableMars(ref uint info)
return TdsEnums.SNI_ERROR;
}

internal override uint EnableSsl(ref uint info) => SNIProxy.GetInstance().EnableSsl(Handle, info);
internal override uint EnableSsl(ref uint info)
{
SNIHandle handle = Handle;
try
{
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.EnableSsl | Info | Session Id {0}", handle?.ConnectionId);
return handle.EnableSsl(info);
}
catch (Exception e)
{
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.EnableSsl | Err | Session Id {0}, SNI Handshake failed with exception: {1}", handle?.ConnectionId, e?.Message);
return SNICommon.ReportSNIError(SNIProviders.SSL_PROV, SNICommon.HandshakeFailureError, e);
}
}

internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) => SNIProxy.GetInstance().SetConnectionBufferSize(Handle, unsignedPacketSize);
internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize)
{
Handle.SetBufferSize((int)unsignedPacketSize);
return TdsEnums.SNI_SUCCESS;
}

internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
{
Expand All @@ -276,8 +332,8 @@ internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint recei
_sspiClientContextStatus = new SspiClientContextStatus();
}

SNIProxy.GetInstance().GenSspiClientContext(_sspiClientContextStatus, receivedBuff, ref sendBuff, _sniSpnBuffer);
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.GenerateSspiClientContext | Info | Session Id {0}", _sessionHandle?.ConnectionId);
SNIProxy.GenSspiClientContext(_sspiClientContextStatus, receivedBuff, ref sendBuff, _sniSpnBuffer);
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}", _sessionHandle?.ConnectionId);
sendLength = (uint)(sendBuff != null ? sendBuff.Length : 0);
return 0;
}
Expand Down