diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Unix.cs b/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Unix.cs
index 5858b77b44..70c1a74377 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Unix.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Unix.cs
@@ -135,7 +135,10 @@ private static SecurityStatusPal EstablishSecurityContext(
}
catch (Exception ex)
{
- if (NetEventSource.IsEnabled) NetEventSource.Error(null, ex);
+ if (NetEventSource.IsEnabled)
+ {
+ NetEventSource.Error(null, ex);
+ }
return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, ex);
}
}
@@ -143,7 +146,7 @@ private static SecurityStatusPal EstablishSecurityContext(
internal static SecurityStatusPal InitializeSecurityContext(
SafeFreeCredentials credentialsHandle,
ref SafeDeleteContext securityContext,
- string spn,
+ string[] spns,
ContextFlagsPal requestedContextFlags,
SecurityBuffer[] inSecurityBufferArray,
SecurityBuffer outSecurityBuffer,
@@ -156,20 +159,33 @@ internal static SecurityStatusPal InitializeSecurityContext(
}
SafeFreeNegoCredentials negoCredentialsHandle = (SafeFreeNegoCredentials)credentialsHandle;
+ SecurityStatusPal status = default;
- if (negoCredentialsHandle.IsDefault && string.IsNullOrEmpty(spn))
+ foreach (string spn in spns)
{
- throw new PlatformNotSupportedException(Strings.net_nego_not_supported_empty_target_with_defaultcreds);
- }
+ if (negoCredentialsHandle.IsDefault && string.IsNullOrEmpty(spn))
+ {
+ throw new PlatformNotSupportedException(Strings.net_nego_not_supported_empty_target_with_defaultcreds);
+ }
- SecurityStatusPal status = EstablishSecurityContext(
- negoCredentialsHandle,
- ref securityContext,
- spn,
- requestedContextFlags,
- ((inSecurityBufferArray != null && inSecurityBufferArray.Length != 0) ? inSecurityBufferArray[0] : null),
- outSecurityBuffer,
- ref contextFlags);
+ status = EstablishSecurityContext(
+ negoCredentialsHandle,
+ ref securityContext,
+ spn,
+ requestedContextFlags,
+ ((inSecurityBufferArray != null && inSecurityBufferArray.Length != 0) ? inSecurityBufferArray[0] : null),
+ outSecurityBuffer,
+ ref contextFlags);
+
+ if (status.ErrorCode != SecurityStatusPalErrorCode.InternalError)
+ {
+ break; // Successful case, exit the loop with current SPN.
+ }
+ else
+ {
+ securityContext = null; // Reset security context to be generated again for next SPN.
+ }
+ }
// Confidentiality flag should not be set if not requested
if (status.ErrorCode == SecurityStatusPalErrorCode.CompleteNeeded)
@@ -180,7 +196,6 @@ internal static SecurityStatusPal InitializeSecurityContext(
throw new PlatformNotSupportedException(Strings.net_nego_protection_level_not_supported);
}
}
-
return status;
}
@@ -224,7 +239,7 @@ internal static SafeFreeCredentials AcquireCredentialsHandle(string package, boo
new SafeFreeNegoCredentials(false, string.Empty, string.Empty, string.Empty) :
new SafeFreeNegoCredentials(ntlmOnly, credential.UserName, credential.Password, credential.Domain);
}
- catch(Exception ex)
+ catch (Exception ex)
{
throw new Win32Exception(NTE_FAIL, ex.Message);
}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs
index 58bf635657..18a0b14cfe 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/Security/NegotiateStreamPal.Windows.cs
@@ -70,7 +70,7 @@ internal static string QueryContextAuthenticationPackage(SafeDeleteContext secur
internal static SecurityStatusPal InitializeSecurityContext(
SafeFreeCredentials credentialsHandle,
ref SafeDeleteContext securityContext,
- string spn,
+ string[] spn,
ContextFlagsPal requestedContextFlags,
SecurityBuffer[] inSecurityBufferArray,
SecurityBuffer outSecurityBuffer,
@@ -81,7 +81,7 @@ internal static SecurityStatusPal InitializeSecurityContext(
GlobalSSPI.SSPIAuth,
credentialsHandle,
ref securityContext,
- spn,
+ spn[0],
ContextFlagsAdapterPal.GetInteropFromContextFlagsPal(requestedContextFlags),
Interop.SspiCli.Endianness.SECURITY_NETWORK_DREP,
inSecurityBufferArray,
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs
index 55ee594a0b..bbe07eb693 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs
@@ -72,7 +72,7 @@ internal uint DisableSsl(SNIHandle handle)
/// Send buffer
/// Service Principal Name buffer
/// SNI error code
- internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[] serverName)
+ internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
{
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
@@ -104,12 +104,15 @@ internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStat
| ContextFlagsPal.Delegate
| ContextFlagsPal.MutualAuth;
- string serverSPN = System.Text.Encoding.UTF8.GetString(serverName);
-
+ string[] serverSPNs = new string[serverName.Length];
+ for (int i = 0; i < serverName.Length; i++)
+ {
+ serverSPNs[i] = System.Text.Encoding.UTF8.GetString(serverName[i]);
+ }
SecurityStatusPal statusCode = NegotiateStreamPal.InitializeSecurityContext(
credentialsHandle,
ref securityContext,
- serverSPN,
+ serverSPNs,
requestedContextFlags,
inSecurityBufferArray,
outSecurityBuffer,
@@ -253,7 +256,7 @@ internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
/// Used for DNS Cache
/// Used for DNS Cache
/// SNI handle
- internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
+ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
instanceName = new byte[1];
@@ -294,7 +297,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO
{
try
{
- spnBuffer = GetSqlServerSPN(details);
+ spnBuffer = GetSqlServerSPNs(details);
}
catch (Exception e)
{
@@ -305,7 +308,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO
return sniHandle;
}
- private static byte[] GetSqlServerSPN(DataSource dataSource)
+ private static byte[][] GetSqlServerSPNs(DataSource dataSource)
{
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
@@ -319,16 +322,11 @@ private static byte[] GetSqlServerSPN(DataSource dataSource)
{
postfix = dataSource.InstanceName;
}
- // For handling tcp: format
- else if (dataSource._connectionProtocol == DataSource.Protocol.TCP)
- {
- postfix = DefaultSqlServerPort.ToString();
- }
- return GetSqlServerSPN(hostName, postfix);
+ return GetSqlServerSPNs(hostName, postfix, dataSource._connectionProtocol);
}
- private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrInstanceName)
+ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
{
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
IPHostEntry hostEntry = null;
@@ -347,16 +345,22 @@ private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrIns
// If the DNS lookup failed, then resort to using the user provided hostname to construct the SPN.
fullyQualifiedDomainName = hostEntry?.HostName ?? hostNameOrAddress;
}
+
string serverSpn = SqlServerSpnHeader + "/" + fullyQualifiedDomainName;
+
if (!string.IsNullOrWhiteSpace(portOrInstanceName))
{
serverSpn += ":" + portOrInstanceName;
}
- else
+ else if (protocol == DataSource.Protocol.None || protocol == DataSource.Protocol.TCP) // Default is TCP
{
- serverSpn += $":{DefaultSqlServerPort}";
+ string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
+ // Set both SPNs with and without Port as Port is optional for default instance
+ return new byte[][] { Encoding.UTF8.GetBytes(serverSpn), Encoding.UTF8.GetBytes(serverSpnWithDefaultPort) };
}
- return Encoding.UTF8.GetBytes(serverSpn);
+ // else Named Pipes do not need to valid port
+
+ return new byte[][] { Encoding.UTF8.GetBytes(serverSpn) };
}
///
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
index 511ceb57c2..7abe7e9582 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
@@ -114,7 +114,7 @@ internal sealed partial class TdsParser
private bool _isDenali = false;
- private byte[] _sniSpnBuffer = null;
+ private byte[][] _sniSpnBuffer = null;
// SqlStatistics
private SqlStatistics _statistics = null;
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs
index 55c73b99b9..94f7f32ea9 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs
@@ -789,7 +789,7 @@ private void ResetCancelAndProcessAttention()
}
}
- internal abstract void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity = false);
+ internal abstract void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity = false);
internal abstract void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey, ref SQLDNSInfo pendingDNSInfo);
@@ -831,7 +831,7 @@ private void ResetCancelAndProcessAttention()
protected abstract void RemovePacketFromPendingList(PacketHandle pointer);
- internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer);
+ internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer);
internal bool Deactivate()
{
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs
index 48a6196f28..e016488603 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs
@@ -49,7 +49,7 @@ internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
=> SNIProxy.GetInstance().PacketGetData(packet.ManagedPacket, _inBuff, ref dataSize);
- internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
+ internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
{
_sessionHandle = SNIProxy.GetInstance().CreateConnectionHandle(serverName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref spnBuffer, flushCache, async, parallel, isIntegratedSecurity, cachedFQDN, ref pendingDNSInfo);
if (_sessionHandle == null)
@@ -215,7 +215,7 @@ internal override uint EnableMars(ref uint info)
internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) => SNIProxy.GetInstance().SetConnectionBufferSize(Handle, unsignedPacketSize);
- internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer)
+ internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
{
if (_sspiClientContextStatus == null)
{
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs
index 2e638a0502..1b7deb2a0b 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs
@@ -80,12 +80,12 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache
if (string.IsNullOrEmpty(userProtocol))
{
-
+
result = SNINativeMethodWrapper.SniGetProviderNumber(Handle, ref providerNumber);
Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetProviderNumber");
_parser.isTcpProtocol = (providerNumber == SNINativeMethodWrapper.ProviderEnum.TCP_PROV);
}
- else if (userProtocol == TdsEnums.TCP)
+ else if (userProtocol == TdsEnums.TCP)
{
_parser.isTcpProtocol = true;
}
@@ -138,14 +138,14 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async)
return myInfo;
}
- internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
+ internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
{
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
- spnBuffer = null;
+ spnBuffer = new byte[1][];
if (isIntegratedSecurity)
{
// now allocate proper length of buffer
- spnBuffer = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
+ spnBuffer[0] = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
}
SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async);
@@ -172,7 +172,7 @@ internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSni
SQLDNSInfo cachedDNSInfo;
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);
- _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, cachedDNSInfo);
+ _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, cachedDNSInfo);
}
protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
@@ -385,8 +385,8 @@ internal override uint EnableSsl(ref uint info)
internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize)
=> SNINativeMethodWrapper.SNISetInfo(Handle, SNINativeMethodWrapper.QTypes.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);
- internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer)
- => SNINativeMethodWrapper.SNISecGenClientContext(Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer);
+ internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
+ => SNINativeMethodWrapper.SNISecGenClientContext(Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer[0]);
internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
{
@@ -421,7 +421,7 @@ internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
protocolVersion = (int)SslProtocols.Ssl2;
#pragma warning restore CS0618 // Type or member is obsolete : SSL is depricated
}
- else if(nativeProtocol.HasFlag(NativeProtocols.SP_PROT_NONE))
+ else if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_NONE))
{
protocolVersion = (int)SslProtocols.None;
}