Skip to content

Commit

Permalink
move single callsite SNIProxy methods into the caller
Browse files Browse the repository at this point in the history
  • Loading branch information
Wraith2 committed Mar 2, 2021
1 parent ca2fe25 commit 4de9744
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,38 +32,6 @@ internal class SspiClientContextResult

internal static SNIProxy GetInstance() => s_singleton;

/// <summary>
/// Enable SSL on a connection
/// </summary>
/// <param name="handle">Connection handle</param>
/// <param name="options"></param>
/// <returns>SNI error code</returns>
internal uint EnableSsl(SNIHandle handle, uint options)
{
try
{
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.EnableSsl | Info | Session Id {0}", handle?.ConnectionId);
return handle.EnableSsl(options);
}
catch (Exception e)
{
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.EnableSsl | Err | Session Id {0}, SNI Handshake failed with exception: {1}", handle?.ConnectionId, e?.Message);
return SNICommon.ReportSNIError(SNIProviders.SSL_PROV, SNICommon.HandshakeFailureError, e);
}
}

/// <summary>
/// Disable SSL on a connection
/// </summary>
/// <param name="handle">Connection handle</param>
/// <returns>SNI error code</returns>
internal uint DisableSsl(SNIHandle handle)
{
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.DisableSsl | Info | Session Id {0}", handle?.ConnectionId);
handle.DisableSsl();
return TdsEnums.SNI_SUCCESS;
}

/// <summary>
/// Generate SSPI context
/// </summary>
Expand All @@ -72,7 +40,7 @@ internal uint DisableSsl(SNIHandle handle)
/// <param name="sendBuff">Send buffer</param>
/// <param name="serverName">Service Principal Name buffer</param>
/// <returns>SNI error code</returns>
internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
{
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
Expand Down Expand Up @@ -165,83 +133,6 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
errorCode != SecurityStatusPalErrorCode.Renegotiate;
}

/// <summary>
/// Set connection buffer size
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="bufferSize">Buffer size</param>
/// <returns>SNI error code</returns>
internal uint SetConnectionBufferSize(SNIHandle handle, uint bufferSize)
{
handle.SetBufferSize((int)bufferSize);
return TdsEnums.SNI_SUCCESS;
}

/// <summary>
/// Copies data in SNIPacket to given byte array parameter
/// </summary>
/// <param name="packet">SNIPacket object containing data packets</param>
/// <param name="inBuff">Destination byte array where data packets are copied to</param>
/// <param name="dataSize">Length of data packets</param>
/// <returns>SNI error status</returns>
internal uint PacketGetData(SNIPacket packet, byte[] inBuff, ref uint dataSize)
{
int dataSizeInt = 0;
packet.GetData(inBuff, ref dataSizeInt);
dataSize = (uint)dataSizeInt;

return TdsEnums.SNI_SUCCESS;
}

/// <summary>
/// Read synchronously
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="packet">SNI packet</param>
/// <param name="timeout">Timeout</param>
/// <returns>SNI error status</returns>
internal uint ReadSyncOverAsync(SNIHandle handle, out SNIPacket packet, int timeout)
{
return handle.Receive(out packet, timeout);
}

/// <summary>
/// Get SNI connection ID
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="clientConnectionId">Client connection ID</param>
/// <returns>SNI error status</returns>
internal uint GetConnectionId(SNIHandle handle, ref Guid clientConnectionId)
{
clientConnectionId = handle.ConnectionId;
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.GetConnectionId | Info | Session Id {0}", clientConnectionId);
return TdsEnums.SNI_SUCCESS;
}

/// <summary>
/// Send a packet
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="packet">SNI packet</param>
/// <param name="sync">true if synchronous, false if asynchronous</param>
/// <returns>SNI error status</returns>
internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
{
uint result;
if (sync)
{
result = handle.Send(packet);
handle.ReturnPacket(packet);
}
else
{
result = handle.SendAsync(packet);
}

SqlClientEventSource.Log.TryTraceEvent("SNIProxy.WritePacket | Info | Session Id {0}, SendAsync Result {1}", handle?.ConnectionId, result);
return result;
}

/// <summary>
/// Create a SNI connection handle
/// </summary>
Expand All @@ -257,7 +148,7 @@ internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
/// <param name="cachedFQDN">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNI handle</returns>
internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
internal static SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
instanceName = new byte[1];

Expand Down Expand Up @@ -377,7 +268,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
/// <param name="cachedFQDN">Key for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNITCPHandle</returns>
private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
// TCP Format:
// tcp:<host name>\<instance name>
Expand Down Expand Up @@ -418,16 +309,14 @@ private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool
return new SNITCPHandle(hostName, port, timerExpire, parallel, cachedFQDN, ref pendingDNSInfo);
}



/// <summary>
/// Creates an SNINpHandle object
/// </summary>
/// <param name="details">Data source</param>
/// <param name="timerExpire">Timer expiration</param>
/// <param name="parallel">Should MultiSubnetFailover be used. Only returns an error for named pipes.</param>
/// <returns>SNINpHandle</returns>
private SNINpHandle CreateNpHandle(DataSource details, long timerExpire, bool parallel)
private static SNINpHandle CreateNpHandle(DataSource details, long timerExpire, bool parallel)
{
if (parallel)
{
Expand All @@ -438,39 +327,6 @@ private SNINpHandle CreateNpHandle(DataSource details, long timerExpire, bool pa
return new SNINpHandle(details.PipeHostName, details.PipeName, timerExpire);
}

/// <summary>
/// Read packet asynchronously
/// </summary>
/// <param name="handle">SNI handle</param>
/// <param name="packet">Packet</param>
/// <returns>SNI error status</returns>
internal uint ReadAsync(SNIHandle handle, out SNIPacket packet)
{
packet = null;
return handle.ReceiveAsync(ref packet);
}

/// <summary>
/// Set packet data
/// </summary>
/// <param name="packet">SNI packet</param>
/// <param name="data">Data</param>
/// <param name="length">Length</param>
internal void PacketSetData(SNIPacket packet, byte[] data, int length)
{
packet.AppendData(data, length);
}

/// <summary>
/// Check SNI handle connection
/// </summary>
/// <param name="handle"></param>
/// <returns>SNI error status</returns>
internal uint CheckConnection(SNIHandle handle)
{
return handle.CheckConnection();
}

/// <summary>
/// Get last SNI error on this thread
/// </summary>
Expand All @@ -486,7 +342,7 @@ internal SNIError GetLastError()
/// <param name="fullServerName">The data source</param>
/// <param name="error">Set true when an error occurred while getting LocalDB up</param>
/// <returns></returns>
private string GetLocalDBDataSource(string fullServerName, out bool error)
private static string GetLocalDBDataSource(string fullServerName, out bool error)
{
string localDBConnectionString = null;
bool isBadLocalDBDataSource;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,17 @@ internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
return _marsConnection.CreateMarsSession(callbackObject, async);
}

protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
=> SNIProxy.GetInstance().PacketGetData(packet.ManagedPacket, _inBuff, ref dataSize);
protected override uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize)
{
int dataSizeInt = 0;
packet.ManagedPacket.GetData(inBuff, ref dataSizeInt);
dataSize = (uint)dataSizeInt;
return TdsEnums.SNI_SUCCESS;
}

internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
{
_sessionHandle = SNIProxy.GetInstance().CreateConnectionHandle(serverName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref spnBuffer, flushCache, async, parallel, isIntegratedSecurity, cachedFQDN, ref pendingDNSInfo);
_sessionHandle = SNIProxy.CreateConnectionHandle(serverName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref spnBuffer, flushCache, async, parallel, isIntegratedSecurity, cachedFQDN, ref pendingDNSInfo);
if (_sessionHandle == null)
{
_parser.ProcessSNIError(this);
Expand Down Expand Up @@ -160,7 +165,9 @@ internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint
{
throw ADP.ClosedConnectionError();
}
error = SNIProxy.GetInstance().ReadSyncOverAsync(handle, out SNIPacket packet, timeoutRemaining);

error = handle.Receive(out SNIPacket packet, timeoutRemaining);

SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.ReadSyncOverAsync | Info | State Object Id {0}, Session Id {1}", _objectID, _sessionHandle?.ConnectionId);
#if DEBUG
SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObjectManaged.ReadSyncOverAsync | TRC | State Object Id {0}, Session Id {1}, Packet {2} received, Packet owner Id {3}, Packet dataLeft {4}", _objectID, _sessionHandle?.ConnectionId, packet?._id, packet?._owner.ConnectionId, packet?.DataLeft);
Expand Down Expand Up @@ -189,12 +196,14 @@ internal override void ReleasePacket(PacketHandle syncReadPacket)
internal override uint CheckConnection()
{
SNIHandle handle = Handle;
return handle == null ? TdsEnums.SNI_SUCCESS : SNIProxy.GetInstance().CheckConnection(handle);
return handle == null ? TdsEnums.SNI_SUCCESS : handle.CheckConnection();
}

internal override PacketHandle ReadAsync(SessionHandle handle, out uint error)
{
error = SNIProxy.GetInstance().ReadAsync(handle.ManagedHandle, out SNIPacket packet);
SNIPacket packet = null;
error = handle.ManagedHandle.ReceiveAsync(ref packet);

SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.ReadAsync | Info | State Object Id {0}, Session Id {1}, Packet DataLeft {2}", _objectID, _sessionHandle?.ConnectionId, packet?.DataLeft);
return PacketHandle.FromManagedPacket(packet);
}
Expand All @@ -212,8 +221,24 @@ internal override PacketHandle CreateAndSetAttentionPacket()
return packetHandle;
}

internal override uint WritePacket(PacketHandle packet, bool sync) =>
SNIProxy.GetInstance().WritePacket(Handle, packet.ManagedPacket, sync);
internal override uint WritePacket(PacketHandle packetHandle, bool sync)
{
uint result;
SNIHandle handle = Handle;
SNIPacket packet = packetHandle.ManagedPacket;
if (sync)
{
result = handle.Send(packet);
handle.ReturnPacket(packet);
}
else
{
result = handle.SendAsync(packet);
}

SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.WritePacket | Info | Session Id {0}, SendAsync Result {1}", handle?.ConnectionId, result);
return result;
}

// No- Op in managed SNI
internal override PacketHandle AddPacketToPendingList(PacketHandle packet) => packet;
Expand Down Expand Up @@ -244,11 +269,25 @@ internal override void ClearAllWritePackets()
Debug.Assert(_asyncWriteCount == 0, "Should not clear all write packets if there are packets pending");
}

internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed) => SNIProxy.GetInstance().PacketSetData(packet.ManagedPacket, buffer, bytesUsed);
internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed)
{
packet.ManagedPacket.AppendData(buffer, bytesUsed);
}

internal override uint SniGetConnectionId(ref Guid clientConnectionId) => SNIProxy.GetInstance().GetConnectionId(Handle, ref clientConnectionId);
internal override uint SniGetConnectionId(ref Guid clientConnectionId)
{
clientConnectionId = Handle.ConnectionId;
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.GetConnectionId | Info | Session Id {0}", clientConnectionId);
return TdsEnums.SNI_SUCCESS;
}

internal override uint DisableSsl() => SNIProxy.GetInstance().DisableSsl(Handle);
internal override uint DisableSsl()
{
SNIHandle handle = Handle;
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.DisableSsl | Info | Session Id {0}", handle?.ConnectionId);
handle.DisableSsl();
return TdsEnums.SNI_SUCCESS;
}

internal override uint EnableMars(ref uint info)
{
Expand All @@ -263,9 +302,26 @@ internal override uint EnableMars(ref uint info)
return TdsEnums.SNI_ERROR;
}

internal override uint EnableSsl(ref uint info) => SNIProxy.GetInstance().EnableSsl(Handle, info);
internal override uint EnableSsl(ref uint info)
{
SNIHandle handle = Handle;
try
{
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.EnableSsl | Info | Session Id {0}", handle?.ConnectionId);
return handle.EnableSsl(info);
}
catch (Exception e)
{
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.EnableSsl | Err | Session Id {0}, SNI Handshake failed with exception: {1}", handle?.ConnectionId, e?.Message);
return SNICommon.ReportSNIError(SNIProviders.SSL_PROV, SNICommon.HandshakeFailureError, e);
}
}

internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) => SNIProxy.GetInstance().SetConnectionBufferSize(Handle, unsignedPacketSize);
internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize)
{
Handle.SetBufferSize((int)unsignedPacketSize);
return TdsEnums.SNI_SUCCESS;
}

internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
{
Expand All @@ -274,8 +330,8 @@ internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint recei
_sspiClientContextStatus = new SspiClientContextStatus();
}

SNIProxy.GetInstance().GenSspiClientContext(_sspiClientContextStatus, receivedBuff, ref sendBuff, _sniSpnBuffer);
SqlClientEventSource.Log.TryTraceEvent("SNIProxy.GenerateSspiClientContext | Info | Session Id {0}", _sessionHandle?.ConnectionId);
SNIProxy.GenSspiClientContext(_sspiClientContextStatus, receivedBuff, ref sendBuff, _sniSpnBuffer);
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}", _sessionHandle?.ConnectionId);
sendLength = (uint)(sendBuff != null ? sendBuff.Length : 0);
return 0;
}
Expand Down

0 comments on commit 4de9744

Please sign in to comment.