Skip to content

Commit

Permalink
Add DNS Caching Phase 1 (#594)
Browse files Browse the repository at this point in the history
This PR adds support for the upcoming DNS Caching feature on the server-side. There are two phases for this feature and this is the first phase support.

When using TCP protocol, the SqlClient driver will record the DNS information such as IP address and port number in an in-memory cache if the server sends back the corresponding TDS token for this caching. The DNS information will stay in the cache during the application lifetime unless the server tells the driver to clean the cache.

In phase 1, the IP and port number are from the DNS resolution during pre-login handshake. If the DNS server fails to resolve the server name, the driver will connect with its cached IP and port if they existed. Otherwise, the connection will fail.

Add support for both .NET Framework and .NET Core.
This change requires the corresponding change in native SNI.
The caching behavior is only for TCP connections.
Need to fail the DNS Server to test this behavior.
  • Loading branch information
karinazhou authored Jun 16, 2020
1 parent 033541f commit ead806e
Show file tree
Hide file tree
Showing 33 changed files with 1,200 additions and 73 deletions.
17 changes: 17 additions & 0 deletions doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,23 @@ GO
]]></format>
</remarks>
</RetrieveStatistics>
<RetrieveInternalInfo>
<summary>Returns a name value pair collection of internal properties at the point in time the method is called.</summary>
<returns>Returns a reference of type <see cref="T:System.Collections.Generic.IDictionary" /> of (string, object) items.</returns>
<remarks>
<format type="text/markdown"><![CDATA[
## Remarks
When this method is called, the values retrieved are those at the current point in time. If you continue using the connection, the values are incorrect. You need to re-execute the method to obtain the most current values.
|Supported internal properties|Type|Information provided|Return value|
|-----------------------------|---------|----------------------------|------------|
|`SQLDNSCachingSupportedState`|string|To indicate the IsSupported flag sent by the server for DNS Caching|"true", "false", "innerConnection is null!"|
|`SQLDNSCachingSupportedStateBeforeRedirect`|string|To indicate the IsSupported flag sent by the server for DNS Caching before redirection.|"true", "false", "innerConnection is null!"|
]]></format>
</remarks>
</RetrieveInternalInfo>
<ServerVersion>
<summary>Gets a string that contains the version of the instance of SQL Server to which the client is connected.</summary>
<value>The version of the instance of SQL Server.</value>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,18 @@ public SqlConnection(string connectionString, Microsoft.Data.SqlClient.SqlCreden
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/ClientConnectionId/*'/>
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public System.Guid ClientConnectionId { get { throw null; } }

///
/// for internal test only
///
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
internal string SQLDNSCachingSupportedState { get { throw null; } }
///
/// for internal test only
///
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
internal string SQLDNSCachingSupportedStateBeforeRedirect { get { throw null; } }

object System.ICloneable.Clone() { throw null; }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/ConnectionString/*'/>
[System.ComponentModel.DefaultValueAttribute("")]
Expand Down Expand Up @@ -639,6 +651,9 @@ public void Open(SqlConnectionOverrides overrides) { }
public void ResetStatistics() { }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RetrieveStatistics/*'/>
public System.Collections.IDictionary RetrieveStatistics() { throw null; }

/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RetrieveInternalInfo/*'/>
public System.Collections.Generic.IDictionary<string, object> RetrieveInternalInfo() { throw null; }
}
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnectionOverrides.xml' path='docs/members[@name="SqlConnectionOverrides"]/SqlConnectionOverrides/*' />
public enum SqlConnectionOverrides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.Data.SqlClient;
using System;
using System.Runtime.InteropServices;
using System.Text;

namespace Microsoft.Data.SqlClient
{
Expand All @@ -20,6 +21,8 @@ internal static partial class SNINativeMethodWrapper
[UnmanagedFunctionPointer(CallingConvention.StdCall)]
internal delegate void SqlAsyncCallbackDelegate(IntPtr m_ConsKey, IntPtr pPacket, uint dwError);

internal const int SniIP6AddrStringBufferLength = 48; // from SNI layer

internal static int SniMaxComposedSpnLength
{
get
Expand Down Expand Up @@ -162,6 +165,20 @@ private unsafe struct SNI_CLIENT_CONSUMER_INFO
public TransparentNetworkResolutionMode transparentNetworkResolution;
public int totalTimeout;
public bool isAzureSqlServerEndpoint;
public SNI_DNSCache_Info DNSCacheInfo;
}

[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
internal struct SNI_DNSCache_Info
{
[MarshalAs(UnmanagedType.LPWStr)]
public string wszCachedFQDN;
[MarshalAs(UnmanagedType.LPWStr)]
public string wszCachedTcpIPv4;
[MarshalAs(UnmanagedType.LPWStr)]
public string wszCachedTcpIPv6;
[MarshalAs(UnmanagedType.LPWStr)]
public string wszCachedTcpPort;
}

[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
Expand Down Expand Up @@ -236,6 +253,15 @@ internal struct SNI_Error
[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out Guid pbQInfo);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out ushort portNum);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Unicode)]
private static extern uint SNIGetPeerAddrStrWrapper([In] SNIHandle pConn, int bufferSize, StringBuilder addrBuffer, out uint addrLen);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out ProviderEnum provNum);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern uint SNIInitialize([In] IntPtr pmo);

Expand All @@ -248,7 +274,8 @@ private static extern uint SNIOpenWrapper(
[MarshalAs(UnmanagedType.LPWStr)] string szConnect,
[In] SNIHandle pConn,
out IntPtr ppConn,
[MarshalAs(UnmanagedType.Bool)] bool fSync);
[MarshalAs(UnmanagedType.Bool)] bool fSync,
[In] ref SNI_DNSCache_Info pDNSCachedInfo);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern IntPtr SNIPacketAllocateWrapper([In] SafeHandle pConn, IOType IOType);
Expand Down Expand Up @@ -283,22 +310,53 @@ internal static uint SniGetConnectionId(SNIHandle pConn, ref Guid connId)
{
return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_CONNID, out connId);
}

internal static uint SniGetProviderNumber(SNIHandle pConn, ref ProviderEnum provNum)
{
return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_PROVIDERNUM, out provNum);
}

internal static uint SniGetConnectionPort(SNIHandle pConn, ref ushort portNum)
{
return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_PEERPORT, out portNum);
}

internal static uint SniGetConnectionIPString(SNIHandle pConn, ref string connIPStr)
{
UInt32 ret;
uint connIPLen = 0;

int bufferSize = SniIP6AddrStringBufferLength;
StringBuilder addrBuffer = new StringBuilder(bufferSize);

ret = SNIGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out connIPLen);

connIPStr = addrBuffer.ToString(0, Convert.ToInt32(connIPLen));

return ret;
}

internal static uint SNIInitialize()
{
return SNIInitialize(IntPtr.Zero);
}

internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync)
internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync, SQLDNSInfo cachedDNSInfo)
{
// initialize consumer info for MARS
Sni_Consumer_Info native_consumerInfo = new Sni_Consumer_Info();
MarshalConsumerInfo(consumerInfo, ref native_consumerInfo);

return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync);
SNI_DNSCache_Info native_cachedDNSInfo = new SNI_DNSCache_Info();
native_cachedDNSInfo.wszCachedFQDN = cachedDNSInfo?.FQDN;
native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4;
native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port;

return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ref native_cachedDNSInfo);
}

internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache, bool fSync, int timeout, bool fParallel)
internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache, bool fSync, int timeout, bool fParallel, SQLDNSInfo cachedDNSInfo)
{
fixed (byte* pin_instanceName = &instanceName[0])
{
Expand All @@ -321,6 +379,11 @@ internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string cons
clientConsumerInfo.totalTimeout = SniOpenTimeOut;
clientConsumerInfo.isAzureSqlServerEndpoint = ADP.IsAzureSqlServerEndpoint(constring);

clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port;

if (spnBuffer != null)
{
fixed (byte* pin_spnBuffer = &spnBuffer[0])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@
<Compile Include="..\..\src\Microsoft\Data\SqlTypes\SqlTypeWorkarounds.cs">
<Link>Microsoft\Data\SqlTypes\SqlTypeWorkarounds.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\SQLFallbackDNSCache.cs">
<Link>Microsoft\Data\SqlClient\SQLFallbackDNSCache.cs</Link>
</Compile>
</ItemGroup>
<ItemGroup Condition="'$(TargetGroup)' == 'netstandard' OR '$(TargetGroup)' == 'netcoreapp' OR '$(IsUAPAssembly)' == 'true'">
<Compile Include="Microsoft.Data.SqlClient.TypeForwards.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,10 @@ public uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
/// <param name="async">Asynchronous connection</param>
/// <param name="parallel">Attempt parallel connects</param>
/// <param name="isIntegratedSecurity"></param>
/// <param name="cachedFQDN">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNI handle</returns>
public SNIHandle CreateConnectionHandle(object callbackObject, string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity)
public SNIHandle CreateConnectionHandle(object callbackObject, string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
instanceName = new byte[1];

Expand All @@ -291,7 +293,7 @@ public SNIHandle CreateConnectionHandle(object callbackObject, string fullServer
case DataSource.Protocol.Admin:
case DataSource.Protocol.None: // default to using tcp if no protocol is provided
case DataSource.Protocol.TCP:
sniHandle = CreateTcpHandle(details, timerExpire, callbackObject, parallel);
sniHandle = CreateTcpHandle(details, timerExpire, callbackObject, parallel, cachedFQDN, ref pendingDNSInfo);
break;
case DataSource.Protocol.NP:
sniHandle = CreateNpHandle(details, timerExpire, callbackObject, parallel);
Expand Down Expand Up @@ -373,8 +375,10 @@ private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrIns
/// <param name="timerExpire">Timer expiration</param>
/// <param name="callbackObject">Asynchronous I/O callback object</param>
/// <param name="parallel">Should MultiSubnetFailover be used</param>
/// <param name="cachedFQDN">Key for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNITCPHandle</returns>
private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, object callbackObject, bool parallel)
private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, object callbackObject, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
// TCP Format:
// tcp:<host name>\<instance name>
Expand Down Expand Up @@ -412,7 +416,7 @@ private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, objec
port = isAdminConnection ? DefaultSqlServerDacPort : DefaultSqlServerPort;
}

return new SNITCPHandle(hostName, port, timerExpire, callbackObject, parallel);
return new SNITCPHandle(hostName, port, timerExpire, callbackObject, parallel, cachedFQDN, ref pendingDNSInfo);
}


Expand Down
Loading

0 comments on commit ead806e

Please sign in to comment.