Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DNS Caching Phase 1 #594

Merged
merged 10 commits into from
Jun 16, 2020
14 changes: 14 additions & 0 deletions doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,20 @@ GO
]]></format>
</remarks>
</RetrieveStatistics>
<RetrieveInternalInfo>
<summary>Returns a name value pair collection of internal properties at the point in time the method is called.</summary>
<returns>Returns a reference of type <see cref="T:System.Collections.Generic.IDictionary" /> of (string, object) items.</returns>
<remarks>
<format type="text/markdown"><![CDATA[

## Remarks
When this method is called, the values retrieved are those at the current point in time. If you continue using the connection, the values are incorrect. You need to re-execute the method to obtain the most current values.

Supported internal properties: `SQLDNSCachingSupportedState` and `SQLDNSCachingSupportedStateBeforeRedirect`.
karinazhou marked this conversation as resolved.
Show resolved Hide resolved

]]></format>
</remarks>
</RetrieveInternalInfo>
<ServerVersion>
<summary>Gets a string that contains the version of the instance of SQL Server to which the client is connected.</summary>
<value>The version of the instance of SQL Server.</value>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,18 @@ public SqlConnection(string connectionString, Microsoft.Data.SqlClient.SqlCreden
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/ClientConnectionId/*'/>
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public System.Guid ClientConnectionId { get { throw null; } }

///
/// for internal test only
///
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
internal string SQLDNSCachingSupportedState { get { throw null; } }
///
/// for internal test only
///
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
internal string SQLDNSCachingSupportedStateBeforeRedirect { get { throw null; } }

object System.ICloneable.Clone() { throw null; }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/ConnectionString/*'/>
[System.ComponentModel.DefaultValueAttribute("")]
Expand Down Expand Up @@ -599,6 +611,9 @@ public void Open(SqlConnectionOverrides overrides) { }
public void ResetStatistics() { }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RetrieveStatistics/*'/>
public System.Collections.IDictionary RetrieveStatistics() { throw null; }

/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RetrieveInternalInfo/*'/>
public System.Collections.Generic.IDictionary<string, object> RetrieveInternalInfo() { throw null; }
}
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnectionOverrides.xml' path='docs/members[@name="SqlConnectionOverrides"]/SqlConnectionOverrides/*' />
public enum SqlConnectionOverrides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.Data.SqlClient;
using System;
using System.Runtime.InteropServices;
using System.Text;

namespace Microsoft.Data.SqlClient
{
Expand All @@ -20,6 +21,8 @@ internal static partial class SNINativeMethodWrapper
[UnmanagedFunctionPointer(CallingConvention.StdCall)]
internal delegate void SqlAsyncCallbackDelegate(IntPtr m_ConsKey, IntPtr pPacket, uint dwError);

internal const int SniIP6AddrStringBufferLength = 48; // from SNI layer

internal static int SniMaxComposedSpnLength
{
get
Expand Down Expand Up @@ -162,6 +165,20 @@ private unsafe struct SNI_CLIENT_CONSUMER_INFO
public TransparentNetworkResolutionMode transparentNetworkResolution;
public int totalTimeout;
public bool isAzureSqlServerEndpoint;
public SNI_DNSCache_Info DNSCacheInfo;
}

[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
internal struct SNI_DNSCache_Info
{
[MarshalAs(UnmanagedType.LPWStr)]
public string wszCachedFQDN;
[MarshalAs(UnmanagedType.LPWStr)]
public string wszCachedTcpIPv4;
[MarshalAs(UnmanagedType.LPWStr)]
public string wszCachedTcpIPv6;
[MarshalAs(UnmanagedType.LPWStr)]
public string wszCachedTcpPort;
}

[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
Expand Down Expand Up @@ -236,6 +253,17 @@ internal struct SNI_Error
[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out Guid pbQInfo);

// kz start
karinazhou marked this conversation as resolved.
Show resolved Hide resolved
[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out ushort portNum);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Unicode)]
private static extern uint SNIGetPeerAddrStrWrapper([In] SNIHandle pConn, int bufferSize, StringBuilder addrBuffer, out uint addrLen);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out ProviderEnum provNum);
// kz end

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern uint SNIInitialize([In] IntPtr pmo);

Expand All @@ -248,7 +276,8 @@ private static extern uint SNIOpenWrapper(
[MarshalAs(UnmanagedType.LPWStr)] string szConnect,
[In] SNIHandle pConn,
out IntPtr ppConn,
[MarshalAs(UnmanagedType.Bool)] bool fSync);
[MarshalAs(UnmanagedType.Bool)] bool fSync,
[In] ref SNI_DNSCache_Info pDNSCachedInfo);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern IntPtr SNIPacketAllocateWrapper([In] SafeHandle pConn, IOType IOType);
Expand Down Expand Up @@ -283,22 +312,55 @@ internal static uint SniGetConnectionId(SNIHandle pConn, ref Guid connId)
{
return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_CONNID, out connId);
}

// kz start
internal static uint SniGetProviderNumber(SNIHandle pConn, ref ProviderEnum provNum)
{
return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_PROVIDERNUM, out provNum);
}

internal static uint SniGetConnectionPort(SNIHandle pConn, ref ushort portNum)
{
return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_PEERPORT, out portNum);
}

internal static uint SniGetConnectionIPString(SNIHandle pConn, ref string connIPStr)
{
UInt32 ret;
uint connIPLen = 0;

int bufferSize = SniIP6AddrStringBufferLength;
StringBuilder addrBuffer = new StringBuilder(bufferSize);

ret = SNIGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out connIPLen);

connIPStr = addrBuffer.ToString(0, Convert.ToInt32(connIPLen));

return ret;
}
// kz end

internal static uint SNIInitialize()
{
return SNIInitialize(IntPtr.Zero);
}

internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync)
internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync, SQLDNSInfo cachedDNSInfo)
{
// initialize consumer info for MARS
Sni_Consumer_Info native_consumerInfo = new Sni_Consumer_Info();
MarshalConsumerInfo(consumerInfo, ref native_consumerInfo);

return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync);
SNI_DNSCache_Info native_cachedDNSInfo = new SNI_DNSCache_Info();
native_cachedDNSInfo.wszCachedFQDN = cachedDNSInfo?.FQDN;
native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4;
native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port;

return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ref native_cachedDNSInfo);
}

internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache, bool fSync, int timeout, bool fParallel)
internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache, bool fSync, int timeout, bool fParallel, SQLDNSInfo cachedDNSInfo)
{
fixed (byte* pin_instanceName = &instanceName[0])
{
Expand All @@ -321,6 +383,11 @@ internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string cons
clientConsumerInfo.totalTimeout = SniOpenTimeOut;
clientConsumerInfo.isAzureSqlServerEndpoint = ADP.IsAzureSqlServerEndpoint(constring);

clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port;

if (spnBuffer != null)
{
fixed (byte* pin_spnBuffer = &spnBuffer[0])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@
<Compile Include="..\..\src\Microsoft\Data\SqlTypes\SqlTypeWorkarounds.cs">
<Link>Microsoft\Data\SqlTypes\SqlTypeWorkarounds.cs</Link>
</Compile>
<Compile Include="..\..\src\Microsoft\Data\SqlClient\SQLDNSCache.cs">
<Link>Microsoft\Data\SqlClient\SQLDNSCache.cs</Link>
</Compile>
</ItemGroup>
<ItemGroup Condition="'$(TargetGroup)' == 'netstandard' OR '$(TargetGroup)' == 'netcoreapp' OR '$(IsUAPAssembly)' == 'true'">
<Compile Include="Microsoft.Data.SqlClient.TypeForwards.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,10 @@ public uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
/// <param name="async">Asynchronous connection</param>
/// <param name="parallel">Attempt parallel connects</param>
/// <param name="isIntegratedSecurity"></param>
/// <param name="cachedFQDN">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNI handle</returns>
public SNIHandle CreateConnectionHandle(object callbackObject, string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity)
public SNIHandle CreateConnectionHandle(object callbackObject, string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
instanceName = new byte[1];

Expand All @@ -291,7 +293,7 @@ public SNIHandle CreateConnectionHandle(object callbackObject, string fullServer
case DataSource.Protocol.Admin:
case DataSource.Protocol.None: // default to using tcp if no protocol is provided
case DataSource.Protocol.TCP:
sniHandle = CreateTcpHandle(details, timerExpire, callbackObject, parallel);
sniHandle = CreateTcpHandle(details, timerExpire, callbackObject, parallel, cachedFQDN, ref pendingDNSInfo);
break;
case DataSource.Protocol.NP:
sniHandle = CreateNpHandle(details, timerExpire, callbackObject, parallel);
Expand Down Expand Up @@ -373,8 +375,10 @@ private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrIns
/// <param name="timerExpire">Timer expiration</param>
/// <param name="callbackObject">Asynchronous I/O callback object</param>
/// <param name="parallel">Should MultiSubnetFailover be used</param>
/// <param name="cachedFQDN">Key for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNITCPHandle</returns>
private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, object callbackObject, bool parallel)
private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, object callbackObject, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
// TCP Format:
// tcp:<host name>\<instance name>
Expand Down Expand Up @@ -412,7 +416,7 @@ private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, objec
port = isAdminConnection ? DefaultSqlServerDacPort : DefaultSqlServerPort;
}

return new SNITCPHandle(hostName, port, timerExpire, callbackObject, parallel);
return new SNITCPHandle(hostName, port, timerExpire, callbackObject, parallel, cachedFQDN, ref pendingDNSInfo);
}


Expand Down
Loading