diff --git a/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs b/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs index d2f7b516556d76..219809bd315bc7 100644 --- a/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/libraries/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs @@ -142,7 +142,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba } else { - _socket = Connect(serverName, port, ts); + _socket = Connect(serverName, port, ts, isInfiniteTimeOut); } if (_socket == null || !_socket.Connected) @@ -177,7 +177,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba _status = TdsEnums.SNI_SUCCESS; } - private static Socket Connect(string serverName, int port, TimeSpan timeout) + private static Socket Connect(string serverName, int port, TimeSpan timeout, bool isInfiniteTimeout) { IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName); IPAddress serverIPv4 = null; @@ -196,8 +196,8 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout) ipAddresses = new IPAddress[] { serverIPv4, serverIPv6 }; Socket[] sockets = new Socket[2]; - CancellationTokenSource cts = new CancellationTokenSource(); - cts.CancelAfter(timeout); + CancellationTokenSource cts = null; + void Cancel() { for (int i = 0; i < sockets.Length; ++i) @@ -213,35 +213,47 @@ void Cancel() catch { } } } - cts.Token.Register(Cancel); + + if (!isInfiniteTimeout) + { + cts = new CancellationTokenSource(timeout); + cts.Token.Register(Cancel); + } Socket availableSocket = null; - for (int i = 0; i < sockets.Length; ++i) + try { - try + for (int i = 0; i < sockets.Length; ++i) { - if (ipAddresses[i] != null) + try { - sockets[i] = new Socket(ipAddresses[i].AddressFamily, SocketType.Stream, ProtocolType.Tcp); - // enable keep-alive on socket - SNITcpHandle.SetKeepAliveValues(ref sockets[i]); - sockets[i].Connect(ipAddresses[i], port); - if (sockets[i] != null) // sockets[i] can be null if cancel callback is executed during connect() + if (ipAddresses[i] != null) { - if (sockets[i].Connected) - { - availableSocket = sockets[i]; - break; - } - else + sockets[i] = new Socket(ipAddresses[i].AddressFamily, SocketType.Stream, ProtocolType.Tcp); + // enable keep-alive on socket + SNITcpHandle.SetKeepAliveValues(ref sockets[i]); + sockets[i].Connect(ipAddresses[i], port); + if (sockets[i] != null) // sockets[i] can be null if cancel callback is executed during connect() { - sockets[i].Dispose(); - sockets[i] = null; + if (sockets[i].Connected) + { + availableSocket = sockets[i]; + break; + } + else + { + sockets[i].Dispose(); + sockets[i] = null; + } } } } + catch { } } - catch { } + } + finally + { + cts?.Dispose(); } return availableSocket;