Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Release 2.1] Fix | Fixes Kerberos auth when SPN does not contain port #935

Merged
merged 1 commit into from
Feb 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,18 @@ private static SecurityStatusPal EstablishSecurityContext(
}
catch (Exception ex)
{
if (NetEventSource.IsEnabled) NetEventSource.Error(null, ex);
if (NetEventSource.IsEnabled)
{
NetEventSource.Error(null, ex);
}
return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, ex);
}
}

internal static SecurityStatusPal InitializeSecurityContext(
SafeFreeCredentials credentialsHandle,
ref SafeDeleteContext securityContext,
string spn,
string[] spns,
ContextFlagsPal requestedContextFlags,
SecurityBuffer[] inSecurityBufferArray,
SecurityBuffer outSecurityBuffer,
Expand All @@ -156,20 +159,33 @@ internal static SecurityStatusPal InitializeSecurityContext(
}

SafeFreeNegoCredentials negoCredentialsHandle = (SafeFreeNegoCredentials)credentialsHandle;
SecurityStatusPal status = default;

if (negoCredentialsHandle.IsDefault && string.IsNullOrEmpty(spn))
foreach (string spn in spns)
{
throw new PlatformNotSupportedException(Strings.net_nego_not_supported_empty_target_with_defaultcreds);
}
if (negoCredentialsHandle.IsDefault && string.IsNullOrEmpty(spn))
{
throw new PlatformNotSupportedException(Strings.net_nego_not_supported_empty_target_with_defaultcreds);
}

SecurityStatusPal status = EstablishSecurityContext(
negoCredentialsHandle,
ref securityContext,
spn,
requestedContextFlags,
((inSecurityBufferArray != null && inSecurityBufferArray.Length != 0) ? inSecurityBufferArray[0] : null),
outSecurityBuffer,
ref contextFlags);
status = EstablishSecurityContext(
negoCredentialsHandle,
ref securityContext,
spn,
requestedContextFlags,
((inSecurityBufferArray != null && inSecurityBufferArray.Length != 0) ? inSecurityBufferArray[0] : null),
outSecurityBuffer,
ref contextFlags);

if (status.ErrorCode != SecurityStatusPalErrorCode.InternalError)
{
break; // Successful case, exit the loop with current SPN.
}
else
{
securityContext = null; // Reset security context to be generated again for next SPN.
}
}

// Confidentiality flag should not be set if not requested
if (status.ErrorCode == SecurityStatusPalErrorCode.CompleteNeeded)
Expand All @@ -180,7 +196,6 @@ internal static SecurityStatusPal InitializeSecurityContext(
throw new PlatformNotSupportedException(Strings.net_nego_protection_level_not_supported);
}
}

return status;
}

Expand Down Expand Up @@ -224,7 +239,7 @@ internal static SafeFreeCredentials AcquireCredentialsHandle(string package, boo
new SafeFreeNegoCredentials(false, string.Empty, string.Empty, string.Empty) :
new SafeFreeNegoCredentials(ntlmOnly, credential.UserName, credential.Password, credential.Domain);
}
catch(Exception ex)
catch (Exception ex)
{
throw new Win32Exception(NTE_FAIL, ex.Message);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ internal static string QueryContextAuthenticationPackage(SafeDeleteContext secur
internal static SecurityStatusPal InitializeSecurityContext(
SafeFreeCredentials credentialsHandle,
ref SafeDeleteContext securityContext,
string spn,
string[] spn,
ContextFlagsPal requestedContextFlags,
SecurityBuffer[] inSecurityBufferArray,
SecurityBuffer outSecurityBuffer,
Expand All @@ -81,7 +81,7 @@ internal static SecurityStatusPal InitializeSecurityContext(
GlobalSSPI.SSPIAuth,
credentialsHandle,
ref securityContext,
spn,
spn[0],
ContextFlagsAdapterPal.GetInteropFromContextFlagsPal(requestedContextFlags),
Interop.SspiCli.Endianness.SECURITY_NETWORK_DREP,
inSecurityBufferArray,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ internal uint DisableSsl(SNIHandle handle)
/// <param name="sendBuff">Send buffer</param>
/// <param name="serverName">Service Principal Name buffer</param>
/// <returns>SNI error code</returns>
internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[] serverName)
internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
{
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
Expand Down Expand Up @@ -104,12 +104,15 @@ internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStat
| ContextFlagsPal.Delegate
| ContextFlagsPal.MutualAuth;

string serverSPN = System.Text.Encoding.UTF8.GetString(serverName);

string[] serverSPNs = new string[serverName.Length];
for (int i = 0; i < serverName.Length; i++)
{
serverSPNs[i] = System.Text.Encoding.UTF8.GetString(serverName[i]);
}
SecurityStatusPal statusCode = NegotiateStreamPal.InitializeSecurityContext(
credentialsHandle,
ref securityContext,
serverSPN,
serverSPNs,
requestedContextFlags,
inSecurityBufferArray,
outSecurityBuffer,
Expand Down Expand Up @@ -253,7 +256,7 @@ internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
/// <param name="cachedFQDN">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNI handle</returns>
internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
instanceName = new byte[1];

Expand Down Expand Up @@ -294,7 +297,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO
{
try
{
spnBuffer = GetSqlServerSPN(details);
spnBuffer = GetSqlServerSPNs(details);
}
catch (Exception e)
{
Expand All @@ -305,7 +308,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO
return sniHandle;
}

private static byte[] GetSqlServerSPN(DataSource dataSource)
private static byte[][] GetSqlServerSPNs(DataSource dataSource)
{
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));

Expand All @@ -319,16 +322,11 @@ private static byte[] GetSqlServerSPN(DataSource dataSource)
{
postfix = dataSource.InstanceName;
}
// For handling tcp:<hostname> format
else if (dataSource._connectionProtocol == DataSource.Protocol.TCP)
{
postfix = DefaultSqlServerPort.ToString();
}

return GetSqlServerSPN(hostName, postfix);
return GetSqlServerSPNs(hostName, postfix, dataSource._connectionProtocol);
}

private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrInstanceName)
private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
{
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
IPHostEntry hostEntry = null;
Expand All @@ -347,16 +345,22 @@ private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrIns
// If the DNS lookup failed, then resort to using the user provided hostname to construct the SPN.
fullyQualifiedDomainName = hostEntry?.HostName ?? hostNameOrAddress;
}

string serverSpn = SqlServerSpnHeader + "/" + fullyQualifiedDomainName;

if (!string.IsNullOrWhiteSpace(portOrInstanceName))
{
serverSpn += ":" + portOrInstanceName;
}
else
else if (protocol == DataSource.Protocol.None || protocol == DataSource.Protocol.TCP) // Default is TCP
{
serverSpn += $":{DefaultSqlServerPort}";
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
// Set both SPNs with and without Port as Port is optional for default instance
return new byte[][] { Encoding.UTF8.GetBytes(serverSpn), Encoding.UTF8.GetBytes(serverSpnWithDefaultPort) };
}
return Encoding.UTF8.GetBytes(serverSpn);
// else Named Pipes do not need to valid port

return new byte[][] { Encoding.UTF8.GetBytes(serverSpn) };
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ internal sealed partial class TdsParser

private bool _isDenali = false;

private byte[] _sniSpnBuffer = null;
private byte[][] _sniSpnBuffer = null;

// SqlStatistics
private SqlStatistics _statistics = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ private void ResetCancelAndProcessAttention()
}
}

internal abstract void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity = false);
internal abstract void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity = false);

internal abstract void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey, ref SQLDNSInfo pendingDNSInfo);

Expand Down Expand Up @@ -831,7 +831,7 @@ private void ResetCancelAndProcessAttention()

protected abstract void RemovePacketFromPendingList(PacketHandle pointer);

internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer);
internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer);

internal bool Deactivate()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
=> SNIProxy.GetInstance().PacketGetData(packet.ManagedPacket, _inBuff, ref dataSize);

internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
{
_sessionHandle = SNIProxy.GetInstance().CreateConnectionHandle(serverName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref spnBuffer, flushCache, async, parallel, isIntegratedSecurity, cachedFQDN, ref pendingDNSInfo);
if (_sessionHandle == null)
Expand Down Expand Up @@ -215,7 +215,7 @@ internal override uint EnableMars(ref uint info)

internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) => SNIProxy.GetInstance().SetConnectionBufferSize(Handle, unsignedPacketSize);

internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer)
internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
{
if (_sspiClientContextStatus == null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache

if (string.IsNullOrEmpty(userProtocol))
{

result = SNINativeMethodWrapper.SniGetProviderNumber(Handle, ref providerNumber);
Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetProviderNumber");
_parser.isTcpProtocol = (providerNumber == SNINativeMethodWrapper.ProviderEnum.TCP_PROV);
}
else if (userProtocol == TdsEnums.TCP)
else if (userProtocol == TdsEnums.TCP)
{
_parser.isTcpProtocol = true;
}
Expand Down Expand Up @@ -138,14 +138,14 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async)
return myInfo;
}

internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
{
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
spnBuffer = null;
spnBuffer = new byte[1][];
if (isIntegratedSecurity)
{
// now allocate proper length of buffer
spnBuffer = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
spnBuffer[0] = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
}

SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async);
Expand All @@ -172,7 +172,7 @@ internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSni
SQLDNSInfo cachedDNSInfo;
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);

_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, cachedDNSInfo);
_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, cachedDNSInfo);
}

protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
Expand Down Expand Up @@ -385,8 +385,8 @@ internal override uint EnableSsl(ref uint info)
internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize)
=> SNINativeMethodWrapper.SNISetInfo(Handle, SNINativeMethodWrapper.QTypes.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);

internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer)
=> SNINativeMethodWrapper.SNISecGenClientContext(Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer);
internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
=> SNINativeMethodWrapper.SNISecGenClientContext(Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer[0]);

internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
{
Expand Down Expand Up @@ -421,7 +421,7 @@ internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
protocolVersion = (int)SslProtocols.Ssl2;
#pragma warning restore CS0618 // Type or member is obsolete : SSL is depricated
}
else if(nativeProtocol.HasFlag(NativeProtocols.SP_PROT_NONE))
else if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_NONE))
{
protocolVersion = (int)SslProtocols.None;
}
Expand Down