diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs index 6980eb09f2..815f4e13d7 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Net; using System.Net.Security; using System.Security.Cryptography.X509Certificates; @@ -194,6 +195,15 @@ internal static bool ValidateSslServerCertificate(string targetServerName, X509C } } + internal static IPAddress[] GetDnsIpAddresses(string serverName) + { + using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) + { + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0}.", args0: serverName); + return Dns.GetHostAddresses(serverName); + } + } + /// /// Sets last error encountered for SNI /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index 73f2c7e6e7..a9c773ac08 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -279,14 +279,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i Socket availableSocket = null; Task connectTask; - Task serverAddrTask = Dns.GetHostAddressesAsync(hostName); - bool complete = serverAddrTask.Wait(ts); - - // DNS timed out - don't block - if (!complete) - return null; - - IPAddress[] serverAddresses = serverAddrTask.Result; + IPAddress[] serverAddresses = SNICommon.GetDnsIpAddresses(hostName); if (serverAddresses.Length > MaxParallelIpAddresses) { @@ -338,14 +331,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference)); - Task serverAddrTask = Dns.GetHostAddressesAsync(serverName); - bool complete = serverAddrTask.Wait(timeout); - - // DNS timed out - don't block - if (!complete) - return null; - - IPAddress[] ipAddresses = serverAddrTask.Result; + IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(serverName); string IPv4String = null; string IPv6String = null; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index afd0dde34d..22c7dc77ae 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -164,7 +164,16 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re Debug.Assert(port >= 0 && port <= 65535, "Invalid port"); Debug.Assert(requestPacket != null && requestPacket.Length > 0, "requestPacket should not be null or 0-length array"); - bool isIpAddress = IPAddress.TryParse(browserHostname, out IPAddress address); + if (IPAddress.TryParse(browserHostname, out IPAddress address)) + { + SsrpResult response = SendUDPRequest(new IPAddress[] { address }, port, requestPacket, allIPsInParallel); + if (response != null && response.ResponsePacket != null) + return response.ResponsePacket; + else if (response != null && response.Error != null) + throw response.Error; + else + return null; + } TimeSpan ts = default; // In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count @@ -175,27 +184,7 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts; } - IPAddress[] ipAddresses = null; - if (!isIpAddress) - { - Task serverAddrTask = Dns.GetHostAddressesAsync(browserHostname); - bool taskComplete; - try - { - taskComplete = serverAddrTask.Wait(ts); - } - catch (AggregateException ae) - { - throw ae.InnerException; - } - - // If DNS took too long, need to return instead of blocking - if (!taskComplete) - return null; - - ipAddresses = serverAddrTask.Result; - } - + IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(browserHostname); Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve"); switch (ipPreference) @@ -272,7 +261,7 @@ private static SsrpResult SendUDPRequest(IPAddress[] ipAddresses, int port, byte for (int i = 0; i < ipAddresses.Length; i++) { IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port); - tasks.Add(Task.Factory.StartNew(() => SendUDPRequest(endPoint, requestPacket))); + tasks.Add(Task.Factory.StartNew(() => SendUDPRequest(endPoint, requestPacket), cts.Token)); } List> completedTasks = new(); diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/ConnectivityTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/ConnectivityTest.cs index 41a52f48ac..6d1d09334e 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/ConnectivityTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/ConnectivityTest.cs @@ -85,6 +85,21 @@ public static void EnvironmentHostNameSPIDTest() Assert.True(false, "No non-empty hostname found for the application"); } + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + public static async void ConnectionTimeoutInfiniteTest() + { + // Exercise the special-case infinite connect timeout code path + SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString) + { + ConnectTimeout = 0 // Infinite + }; + + using SqlConnection conn = new(builder.ConnectionString); + CancellationTokenSource cts = new(30000); + // Will throw TaskCanceledException and fail the test in the event of a hang + await conn.OpenAsync(cts.Token); + } + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] public static void ConnectionTimeoutTestWithThread() { diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs index ceba949d55..3b264419bb 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Net; using System.Net.Sockets; using System.Text; using System.Threading.Tasks; @@ -12,7 +13,7 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests { public static class InstanceNameTest { - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.AreConnStringsSetup))] public static void ConnectToSQLWithInstanceNameTest() { SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString); @@ -20,9 +21,9 @@ public static void ConnectToSQLWithInstanceNameTest() bool proceed = true; string dataSourceStr = builder.DataSource.Replace("tcp:", ""); string[] serverNamePartsByBackSlash = dataSourceStr.Split('\\'); + string hostname = serverNamePartsByBackSlash[0]; if (!dataSourceStr.Contains(",") && serverNamePartsByBackSlash.Length == 2) { - string hostname = serverNamePartsByBackSlash[0]; proceed = !string.IsNullOrWhiteSpace(hostname) && IsBrowserAlive(hostname); } @@ -31,6 +32,14 @@ public static void ConnectToSQLWithInstanceNameTest() using SqlConnection connection = new(builder.ConnectionString); connection.Open(); connection.Close(); + + // Exercise the IP address-specific code in SSRP + IPAddress[] addresses = Dns.GetHostAddresses(hostname); + builder.DataSource = builder.DataSource.Replace(hostname, addresses[0].ToString()); + builder.TrustServerCertificate = true; + using SqlConnection connection2 = new(builder.ConnectionString); + connection2.Open(); + connection2.Close(); } }