diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Unix.cs b/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Unix.cs index 5858b77b44..70c1a74377 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Unix.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Unix.cs @@ -135,7 +135,10 @@ private static SecurityStatusPal EstablishSecurityContext( } catch (Exception ex) { - if (NetEventSource.IsEnabled) NetEventSource.Error(null, ex); + if (NetEventSource.IsEnabled) + { + NetEventSource.Error(null, ex); + } return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, ex); } } @@ -143,7 +146,7 @@ private static SecurityStatusPal EstablishSecurityContext( internal static SecurityStatusPal InitializeSecurityContext( SafeFreeCredentials credentialsHandle, ref SafeDeleteContext securityContext, - string spn, + string[] spns, ContextFlagsPal requestedContextFlags, SecurityBuffer[] inSecurityBufferArray, SecurityBuffer outSecurityBuffer, @@ -156,20 +159,33 @@ internal static SecurityStatusPal InitializeSecurityContext( } SafeFreeNegoCredentials negoCredentialsHandle = (SafeFreeNegoCredentials)credentialsHandle; + SecurityStatusPal status = default; - if (negoCredentialsHandle.IsDefault && string.IsNullOrEmpty(spn)) + foreach (string spn in spns) { - throw new PlatformNotSupportedException(Strings.net_nego_not_supported_empty_target_with_defaultcreds); - } + if (negoCredentialsHandle.IsDefault && string.IsNullOrEmpty(spn)) + { + throw new PlatformNotSupportedException(Strings.net_nego_not_supported_empty_target_with_defaultcreds); + } - SecurityStatusPal status = EstablishSecurityContext( - negoCredentialsHandle, - ref securityContext, - spn, - requestedContextFlags, - ((inSecurityBufferArray != null && inSecurityBufferArray.Length != 0) ? inSecurityBufferArray[0] : null), - outSecurityBuffer, - ref contextFlags); + status = EstablishSecurityContext( + negoCredentialsHandle, + ref securityContext, + spn, + requestedContextFlags, + ((inSecurityBufferArray != null && inSecurityBufferArray.Length != 0) ? inSecurityBufferArray[0] : null), + outSecurityBuffer, + ref contextFlags); + + if (status.ErrorCode != SecurityStatusPalErrorCode.InternalError) + { + break; // Successful case, exit the loop with current SPN. + } + else + { + securityContext = null; // Reset security context to be generated again for next SPN. + } + } // Confidentiality flag should not be set if not requested if (status.ErrorCode == SecurityStatusPalErrorCode.CompleteNeeded) @@ -180,7 +196,6 @@ internal static SecurityStatusPal InitializeSecurityContext( throw new PlatformNotSupportedException(Strings.net_nego_protection_level_not_supported); } } - return status; } @@ -224,7 +239,7 @@ internal static SafeFreeCredentials AcquireCredentialsHandle(string package, boo new SafeFreeNegoCredentials(false, string.Empty, string.Empty, string.Empty) : new SafeFreeNegoCredentials(ntlmOnly, credential.UserName, credential.Password, credential.Domain); } - catch(Exception ex) + catch (Exception ex) { throw new Win32Exception(NTE_FAIL, ex.Message); } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs index 58bf635657..18a0b14cfe 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs @@ -70,7 +70,7 @@ internal static string QueryContextAuthenticationPackage(SafeDeleteContext secur internal static SecurityStatusPal InitializeSecurityContext( SafeFreeCredentials credentialsHandle, ref SafeDeleteContext securityContext, - string spn, + string[] spn, ContextFlagsPal requestedContextFlags, SecurityBuffer[] inSecurityBufferArray, SecurityBuffer outSecurityBuffer, @@ -81,7 +81,7 @@ internal static SecurityStatusPal InitializeSecurityContext( GlobalSSPI.SSPIAuth, credentialsHandle, ref securityContext, - spn, + spn[0], ContextFlagsAdapterPal.GetInteropFromContextFlagsPal(requestedContextFlags), Interop.SspiCli.Endianness.SECURITY_NETWORK_DREP, inSecurityBufferArray, diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index 55ee594a0b..bbe07eb693 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -72,7 +72,7 @@ internal uint DisableSsl(SNIHandle handle) /// Send buffer /// Service Principal Name buffer /// SNI error code - internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[] serverName) + internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName) { SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext; ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags; @@ -104,12 +104,15 @@ internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStat | ContextFlagsPal.Delegate | ContextFlagsPal.MutualAuth; - string serverSPN = System.Text.Encoding.UTF8.GetString(serverName); - + string[] serverSPNs = new string[serverName.Length]; + for (int i = 0; i < serverName.Length; i++) + { + serverSPNs[i] = System.Text.Encoding.UTF8.GetString(serverName[i]); + } SecurityStatusPal statusCode = NegotiateStreamPal.InitializeSecurityContext( credentialsHandle, ref securityContext, - serverSPN, + serverSPNs, requestedContextFlags, inSecurityBufferArray, outSecurityBuffer, @@ -253,7 +256,7 @@ internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync) /// Used for DNS Cache /// Used for DNS Cache /// SNI handle - internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) + internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { instanceName = new byte[1]; @@ -294,7 +297,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO { try { - spnBuffer = GetSqlServerSPN(details); + spnBuffer = GetSqlServerSPNs(details); } catch (Exception e) { @@ -305,7 +308,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO return sniHandle; } - private static byte[] GetSqlServerSPN(DataSource dataSource) + private static byte[][] GetSqlServerSPNs(DataSource dataSource) { Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName)); @@ -319,16 +322,11 @@ private static byte[] GetSqlServerSPN(DataSource dataSource) { postfix = dataSource.InstanceName; } - // For handling tcp: format - else if (dataSource._connectionProtocol == DataSource.Protocol.TCP) - { - postfix = DefaultSqlServerPort.ToString(); - } - return GetSqlServerSPN(hostName, postfix); + return GetSqlServerSPNs(hostName, postfix, dataSource._connectionProtocol); } - private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrInstanceName) + private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol) { Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress)); IPHostEntry hostEntry = null; @@ -347,16 +345,22 @@ private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrIns // If the DNS lookup failed, then resort to using the user provided hostname to construct the SPN. fullyQualifiedDomainName = hostEntry?.HostName ?? hostNameOrAddress; } + string serverSpn = SqlServerSpnHeader + "/" + fullyQualifiedDomainName; + if (!string.IsNullOrWhiteSpace(portOrInstanceName)) { serverSpn += ":" + portOrInstanceName; } - else + else if (protocol == DataSource.Protocol.None || protocol == DataSource.Protocol.TCP) // Default is TCP { - serverSpn += $":{DefaultSqlServerPort}"; + string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}"; + // Set both SPNs with and without Port as Port is optional for default instance + return new byte[][] { Encoding.UTF8.GetBytes(serverSpn), Encoding.UTF8.GetBytes(serverSpnWithDefaultPort) }; } - return Encoding.UTF8.GetBytes(serverSpn); + // else Named Pipes do not need to valid port + + return new byte[][] { Encoding.UTF8.GetBytes(serverSpn) }; } /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 511ceb57c2..7abe7e9582 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -114,7 +114,7 @@ internal sealed partial class TdsParser private bool _isDenali = false; - private byte[] _sniSpnBuffer = null; + private byte[][] _sniSpnBuffer = null; // SqlStatistics private SqlStatistics _statistics = null; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 55c73b99b9..94f7f32ea9 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -789,7 +789,7 @@ private void ResetCancelAndProcessAttention() } } - internal abstract void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity = false); + internal abstract void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity = false); internal abstract void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey, ref SQLDNSInfo pendingDNSInfo); @@ -831,7 +831,7 @@ private void ResetCancelAndProcessAttention() protected abstract void RemovePacketFromPendingList(PacketHandle pointer); - internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer); + internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer); internal bool Deactivate() { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 48a6196f28..e016488603 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -49,7 +49,7 @@ internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async) protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) => SNIProxy.GetInstance().PacketGetData(packet.ManagedPacket, _inBuff, ref dataSize); - internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity) + internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity) { _sessionHandle = SNIProxy.GetInstance().CreateConnectionHandle(serverName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref spnBuffer, flushCache, async, parallel, isIntegratedSecurity, cachedFQDN, ref pendingDNSInfo); if (_sessionHandle == null) @@ -215,7 +215,7 @@ internal override uint EnableMars(ref uint info) internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) => SNIProxy.GetInstance().SetConnectionBufferSize(Handle, unsignedPacketSize); - internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer) + internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer) { if (_sspiClientContextStatus == null) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index 2e638a0502..1b7deb2a0b 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -80,12 +80,12 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache if (string.IsNullOrEmpty(userProtocol)) { - + result = SNINativeMethodWrapper.SniGetProviderNumber(Handle, ref providerNumber); Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetProviderNumber"); _parser.isTcpProtocol = (providerNumber == SNINativeMethodWrapper.ProviderEnum.TCP_PROV); } - else if (userProtocol == TdsEnums.TCP) + else if (userProtocol == TdsEnums.TCP) { _parser.isTcpProtocol = true; } @@ -138,14 +138,14 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) return myInfo; } - internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity) + internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity) { // We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer - spnBuffer = null; + spnBuffer = new byte[1][]; if (isIntegratedSecurity) { // now allocate proper length of buffer - spnBuffer = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength]; + spnBuffer[0] = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength]; } SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async); @@ -172,7 +172,7 @@ internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSni SQLDNSInfo cachedDNSInfo; bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo); - _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, cachedDNSInfo); + _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, cachedDNSInfo); } protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) @@ -385,8 +385,8 @@ internal override uint EnableSsl(ref uint info) internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) => SNINativeMethodWrapper.SNISetInfo(Handle, SNINativeMethodWrapper.QTypes.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize); - internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer) - => SNINativeMethodWrapper.SNISecGenClientContext(Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer); + internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer) + => SNINativeMethodWrapper.SNISecGenClientContext(Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer[0]); internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion) { @@ -421,7 +421,7 @@ internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion) protocolVersion = (int)SslProtocols.Ssl2; #pragma warning restore CS0618 // Type or member is obsolete : SSL is depricated } - else if(nativeProtocol.HasFlag(NativeProtocols.SP_PROT_NONE)) + else if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_NONE)) { protocolVersion = (int)SslProtocols.None; }