From 67f9f7539517b4baaa37b7893ef12e7405a5e25d Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Wed, 13 Nov 2024 17:38:11 -0600 Subject: [PATCH 01/12] Cleanup member variables of SniNativeWrapper --- .../Interop/Windows/Sni/SniNativeWrapper.cs | 129 +++++++++--------- 1 file changed, 66 insertions(+), 63 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index cd28e2f162..2a5c663537 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -18,24 +18,27 @@ namespace Microsoft.Data.SqlClient { internal static class SniNativeWrapper { -#if NETFRAMEWORK - private static readonly ISniNativeMethods NativeMethods = RuntimeInformation.ProcessArchitecture switch + #region Member Variables + + private const int SniIpv6AddrStringBufferLength = 48; + private const int SniOpenTimeOut = -1; + + #if NETFRAMEWORK + private static readonly ISniNativeMethods s_nativeMethods = RuntimeInformation.ProcessArchitecture switch { Architecture.Arm64 => new SniNativeMethodsArm64(), Architecture.X64 => new SniNativeMethodsX64(), Architecture.X86 => new SniNativeMethodsX86(), _ => new SniNativeMethodsNotSupported(RuntimeInformation.ProcessArchitecture) }; -#else - private static readonly ISniNativeMethods NativeMethods = new SniNativeMethods(); -#endif - + #else + private static readonly SniNativeMethods s_nativeMethods = new SniNativeMethods(); + #endif + private static int s_sniMaxComposedSpnLength = -1; - - private const int SniOpenTimeOut = -1; // infinite - - internal const int SniIP6AddrStringBufferLength = 48; // from SNI layer - + + #endregion + internal static int SniMaxComposedSpnLength { get @@ -48,7 +51,7 @@ internal static int SniMaxComposedSpnLength } } -#if NETFRAMEWORK + #if NETFRAMEWORK static AppDomain GetDefaultAppDomainInternal() { return AppDomain.CurrentDomain; @@ -86,86 +89,86 @@ internal unsafe static void SetData(Byte[] data) SqlDependencyProcessDispatcherStorage.NativeSetData(pin_dispatcher, data.Length); } } -#endif + #endif #region DLL Imports internal static uint SNIAddProvider(SNIHandle pConn, Provider ProvNum, [In] ref uint pInfo) => - NativeMethods.SniAddProvider(pConn, ProvNum, ref pInfo); + s_nativeMethods.SniAddProvider(pConn, ProvNum, ref pInfo); internal static uint SNIAddProvider(SNIHandle pConn, Provider ProvNum, [In] ref AuthProviderInfo pInfo) => - NativeMethods.SniAddProvider(pConn, ProvNum, ref pInfo); + s_nativeMethods.SniAddProvider(pConn, ProvNum, ref pInfo); internal static uint SNICheckConnection([In] SNIHandle pConn) => - NativeMethods.SniCheckConnection(pConn); + s_nativeMethods.SniCheckConnection(pConn); internal static uint SNIClose(IntPtr pConn) => - NativeMethods.SniClose(pConn); + s_nativeMethods.SniClose(pConn); internal static void SNIGetLastError(out SniError pErrorStruct) => - NativeMethods.SniGetLastError(out pErrorStruct); + s_nativeMethods.SniGetLastError(out pErrorStruct); internal static void SNIPacketRelease(IntPtr pPacket) => - NativeMethods.SniPacketRelease(pPacket); + s_nativeMethods.SniPacketRelease(pPacket); internal static void SNIPacketReset([In] SNIHandle pConn, IoType IOType, SNIPacket pPacket, ConsumerNumber ConsNum) => - NativeMethods.SniPacketReset(pConn, IOType, pPacket, ConsNum); + s_nativeMethods.SniPacketReset(pConn, IOType, pPacket, ConsNum); internal static uint SNIQueryInfo(QueryType QType, ref uint pbQInfo) => - NativeMethods.SniQueryInfo(QType, ref pbQInfo); + s_nativeMethods.SniQueryInfo(QType, ref pbQInfo); internal static uint SNIQueryInfo(QueryType QType, ref IntPtr pbQInfo) => - NativeMethods.SniQueryInfo(QType, ref pbQInfo); + s_nativeMethods.SniQueryInfo(QType, ref pbQInfo); internal static uint SNIReadAsync(SNIHandle pConn, ref IntPtr ppNewPacket) => - NativeMethods.SniReadAsync(pConn, ref ppNewPacket); + s_nativeMethods.SniReadAsync(pConn, ref ppNewPacket); internal static uint SNIReadSyncOverAsync(SNIHandle pConn, ref IntPtr ppNewPacket, int timeout) => - NativeMethods.SniReadSyncOverAsync(pConn, ref ppNewPacket, timeout); + s_nativeMethods.SniReadSyncOverAsync(pConn, ref ppNewPacket, timeout); internal static uint SNIRemoveProvider(SNIHandle pConn, Provider ProvNum) => - NativeMethods.SniRemoveProvider(pConn, ProvNum); + s_nativeMethods.SniRemoveProvider(pConn, ProvNum); internal static uint SNISecInitPackage(ref uint pcbMaxToken) => - NativeMethods.SniSecInitPackage(ref pcbMaxToken); + s_nativeMethods.SniSecInitPackage(ref pcbMaxToken); internal static uint SNISetInfo(SNIHandle pConn, QueryType QType, [In] ref uint pbQInfo) => - NativeMethods.SniSetInfo(pConn, QType, ref pbQInfo); + s_nativeMethods.SniSetInfo(pConn, QType, ref pbQInfo); internal static uint SNITerminate() => - NativeMethods.SniTerminate(); + s_nativeMethods.SniTerminate(); internal static uint SNIWaitForSSLHandshakeToComplete([In] SNIHandle pConn, int dwMilliseconds, out uint pProtocolVersion) => - NativeMethods.SniWaitForSslHandshakeToComplete(pConn, dwMilliseconds, out pProtocolVersion); + s_nativeMethods.SniWaitForSslHandshakeToComplete(pConn, dwMilliseconds, out pProtocolVersion); internal static uint UnmanagedIsTokenRestricted([In] IntPtr token, [MarshalAs(UnmanagedType.Bool)] out bool isRestricted) => - NativeMethods.SniIsTokenRestricted(token, out isRestricted); + s_nativeMethods.SniIsTokenRestricted(token, out isRestricted); private static uint GetSniMaxComposedSpnLength() => - NativeMethods.SniGetMaxComposedSpnLength(); + s_nativeMethods.SniGetMaxComposedSpnLength(); private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out Guid pbQInfo) => - NativeMethods.SniGetInfoWrapper(pConn, QType, out pbQInfo); + s_nativeMethods.SniGetInfoWrapper(pConn, QType, out pbQInfo); -#if NETFRAMEWORK + #if NETFRAMEWORK private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, [MarshalAs(UnmanagedType.Bool)] out bool pbQInfo) => - NativeMethods.SniGetInfoWrapper(pConn, QType, out pbQInfo); -#endif + s_nativeMethods.SniGetInfoWrapper(pConn, QType, out pbQInfo); + #endif private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out ushort portNum) => - NativeMethods.SniGetInfoWrapper(pConn, QType, out portNum); + s_nativeMethods.SniGetInfoWrapper(pConn, QType, out portNum); private static uint SNIGetPeerAddrStrWrapper([In] SNIHandle pConn, int bufferSize, StringBuilder addrBuffer, out uint addrLen) => - NativeMethods.SniGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out addrLen); + s_nativeMethods.SniGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out addrLen); private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out Provider provNum) => - NativeMethods.SniGetInfoWrapper(pConn, QType, out provNum); + s_nativeMethods.SniGetInfoWrapper(pConn, QType, out provNum); private static uint SNIInitialize([In] IntPtr pmo) => - NativeMethods.SniInitialize(pmo); + s_nativeMethods.SniInitialize(pmo); private static uint SNIOpenSyncExWrapper(ref SniClientConsumerInfo pClientConsumerInfo, out IntPtr ppConn) => - NativeMethods.SniOpenSyncExWrapper(ref pClientConsumerInfo, out ppConn); + s_nativeMethods.SniOpenSyncExWrapper(ref pClientConsumerInfo, out ppConn); private static uint SNIOpenWrapper( [In] ref SniConsumerInfo pConsumerInfo, @@ -175,7 +178,7 @@ private static uint SNIOpenWrapper( [MarshalAs(UnmanagedType.Bool)] bool fSync, SqlConnectionIPAddressPreference ipPreference, [In] ref SniDnsCacheInfo pDNSCachedInfo) => - NativeMethods.SniOpenWrapper( + s_nativeMethods.SniOpenWrapper( ref pConsumerInfo, szConnect, pConn, @@ -185,14 +188,14 @@ private static uint SNIOpenWrapper( ref pDNSCachedInfo); private static IntPtr SNIPacketAllocateWrapper([In] SafeHandle pConn, IoType IOType) => - NativeMethods.SniPacketAllocateWrapper(pConn, IOType); + s_nativeMethods.SniPacketAllocateWrapper(pConn, IOType); private static uint SNIPacketGetDataWrapper([In] IntPtr packet, [In, Out] byte[] readBuffer, uint readBufferLength, out uint dataSize) => - NativeMethods.SniPacketGetDataWrapper(packet, readBuffer, readBufferLength, out dataSize); + s_nativeMethods.SniPacketGetDataWrapper(packet, readBuffer, readBufferLength, out dataSize); private static unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf, uint cbBuf) => - NativeMethods.SniPacketSetData(pPacket, pbBuf, cbBuf); - + s_nativeMethods.SniPacketSetData(pPacket, pbBuf, cbBuf); + private static unsafe uint SNISecGenClientContextWrapper( [In] SNIHandle pConn, [In, Out] ReadOnlySpan pIn, @@ -210,7 +213,7 @@ private static unsafe uint SNISecGenClientContextWrapper( fixed (byte* pOutPtr = pOut) fixed (byte* pServerInfo = serverInfo) { - return NativeMethods.SniSecGenClientContextWrapper( + return s_nativeMethods.SniSecGenClientContextWrapper( pConn, pInPtr, (uint)pIn.Length, @@ -225,23 +228,23 @@ private static unsafe uint SNISecGenClientContextWrapper( } private static uint SNIWriteAsyncWrapper(SNIHandle pConn, [In] SNIPacket pPacket) => - NativeMethods.SniWriteAsyncWrapper(pConn, pPacket); + s_nativeMethods.SniWriteAsyncWrapper(pConn, pPacket); private static uint SNIWriteSyncOverAsync(SNIHandle pConn, [In] SNIPacket pPacket) => - NativeMethods.SniWriteSyncOverAsync(pConn, pPacket); + s_nativeMethods.SniWriteSyncOverAsync(pConn, pPacket); internal static IntPtr SNIServerEnumOpen() => - NativeMethods.SniServerEnumOpen(); + s_nativeMethods.SniServerEnumOpen(); internal static void SNIServerEnumClose([In] IntPtr packet) => - NativeMethods.SniServerEnumClose(packet); + s_nativeMethods.SniServerEnumClose(packet); internal static int SNIServerEnumRead( [In] IntPtr packet, [In][MarshalAs(UnmanagedType.LPArray)] char[] readBuffer, [In] int bufferLength, [MarshalAs(UnmanagedType.Bool)] out bool more) => - NativeMethods.SniServerEnumRead(packet, readBuffer, bufferLength, out more); + s_nativeMethods.SniServerEnumRead(packet, readBuffer, bufferLength, out more); #endregion @@ -265,7 +268,7 @@ internal static uint SniGetConnectionIPString(SNIHandle pConn, ref string connIP UInt32 ret; uint connIPLen = 0; - int bufferSize = SniIP6AddrStringBufferLength; + int bufferSize = SniIpv6AddrStringBufferLength; StringBuilder addrBuffer = new StringBuilder(bufferSize); ret = SNIGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out connIPLen); @@ -305,12 +308,12 @@ internal static unsafe uint SNIOpenSyncEx( bool fSync, int timeout, bool fParallel, - -#if NETFRAMEWORK + + #if NETFRAMEWORK Int32 transparentNetworkResolutionStateNo, Int32 totalTimeout, -#endif - + #endif + SqlConnectionIPAddressPreference ipPreference, SQLDNSInfo cachedDNSInfo, string hostNameInCertificate) @@ -332,7 +335,7 @@ internal static unsafe uint SNIOpenSyncEx( clientConsumerInfo.timeout = timeout; clientConsumerInfo.fParallel = fParallel; -#if NETFRAMEWORK + #if NETFRAMEWORK switch (transparentNetworkResolutionStateNo) { case (0): @@ -346,10 +349,10 @@ internal static unsafe uint SNIOpenSyncEx( break; }; clientConsumerInfo.totalTimeout = totalTimeout; -#else + #else clientConsumerInfo.transparentNetworkResolution = TransparentNetworkResolutionMode.DisabledMode; clientConsumerInfo.totalTimeout = SniOpenTimeOut; -#endif + #endif clientConsumerInfo.isAzureSqlServerEndpoint = ADP.IsAzureSqlServerEndpoint(constring); @@ -422,7 +425,7 @@ internal static unsafe uint SNIOpenSyncEx( } } -#if NETFRAMEWORK + #if NETFRAMEWORK [ResourceExposure(ResourceScope.None)] [ResourceConsumption(ResourceScope.Machine, ResourceScope.Machine)] internal static uint SNIAddProvider(SNIHandle pConn, @@ -445,7 +448,7 @@ internal static uint SNIAddProvider(SNIHandle pConn, return ret; } -#endif + #endif internal static void SNIPacketAllocate(SafeHandle pConn, IoType IOType, ref IntPtr pPacket) { @@ -465,7 +468,7 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int } } -#if NETFRAMEWORK + #if NETFRAMEWORK //[ResourceExposure(ResourceScope::None)] // // Notes on SecureString: Writing out security sensitive information to managed buffer should be avoid as these can be moved @@ -589,7 +592,7 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w } } } -#endif + #endif internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan inBuff, Span outBuff, ref uint sendLength, string serverUserName) { From 87ddbb4c69a2d409419c99de457f7ad5b7b2610b Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Wed, 13 Nov 2024 17:54:09 -0600 Subject: [PATCH 02/12] Sort public methods (as per target names) --- .../Interop/Windows/Sni/SniNativeWrapper.cs | 445 +++++++++--------- 1 file changed, 224 insertions(+), 221 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index 2a5c663537..c29cfc3be0 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -51,218 +51,50 @@ internal static int SniMaxComposedSpnLength } } + #region Public Methods + + internal static uint SNIAddProvider(SNIHandle pConn, Provider ProvNum, [In] ref AuthProviderInfo pInfo) => + s_nativeMethods.SniAddProvider(pConn, ProvNum, ref pInfo); + #if NETFRAMEWORK - static AppDomain GetDefaultAppDomainInternal() + [ResourceExposure(ResourceScope.None)] + [ResourceConsumption(ResourceScope.Machine, ResourceScope.Machine)] + internal static uint SNIAddProvider(SNIHandle pConn, + Provider providerEnum, + AuthProviderInfo authInfo) { - return AppDomain.CurrentDomain; - } + UInt32 ret; + uint ERROR_SUCCESS = 0; - internal static _AppDomain GetDefaultAppDomain() - { - return GetDefaultAppDomainInternal(); - } + Debug.Assert(authInfo.clientCertificateCallback == null, "CTAIP support has been removed"); - [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider - [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] - internal unsafe static byte[] GetData() - { - int size; - IntPtr ptr = (IntPtr)(SqlDependencyProcessDispatcherStorage.NativeGetData(out size)); - byte[] result = null; + ret = SNIAddProvider(pConn, providerEnum, ref authInfo); - if (ptr != IntPtr.Zero) + if (ret == ERROR_SUCCESS) { - result = new byte[size]; - Marshal.Copy(ptr, result, 0, size); + // added a provider, need to requery for sync over async support + ret = SNIGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_SUPPORTS_SYNC_OVER_ASYNC, out bool _); + Debug.Assert(ret == ERROR_SUCCESS, "SNIGetInfo cannot fail with this QType"); } - return result; - } - - [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider - [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] - internal unsafe static void SetData(Byte[] data) - { - //cli::pin_ptr pin_dispatcher = &data[0]; - fixed (byte* pin_dispatcher = &data[0]) - { - SqlDependencyProcessDispatcherStorage.NativeSetData(pin_dispatcher, data.Length); - } + return ret; } #endif - - #region DLL Imports - + internal static uint SNIAddProvider(SNIHandle pConn, Provider ProvNum, [In] ref uint pInfo) => s_nativeMethods.SniAddProvider(pConn, ProvNum, ref pInfo); - - internal static uint SNIAddProvider(SNIHandle pConn, Provider ProvNum, [In] ref AuthProviderInfo pInfo) => - s_nativeMethods.SniAddProvider(pConn, ProvNum, ref pInfo); - + internal static uint SNICheckConnection([In] SNIHandle pConn) => s_nativeMethods.SniCheckConnection(pConn); - + internal static uint SNIClose(IntPtr pConn) => s_nativeMethods.SniClose(pConn); - - internal static void SNIGetLastError(out SniError pErrorStruct) => - s_nativeMethods.SniGetLastError(out pErrorStruct); - - internal static void SNIPacketRelease(IntPtr pPacket) => - s_nativeMethods.SniPacketRelease(pPacket); - - internal static void SNIPacketReset([In] SNIHandle pConn, IoType IOType, SNIPacket pPacket, ConsumerNumber ConsNum) => - s_nativeMethods.SniPacketReset(pConn, IOType, pPacket, ConsNum); - - internal static uint SNIQueryInfo(QueryType QType, ref uint pbQInfo) => - s_nativeMethods.SniQueryInfo(QType, ref pbQInfo); - - internal static uint SNIQueryInfo(QueryType QType, ref IntPtr pbQInfo) => - s_nativeMethods.SniQueryInfo(QType, ref pbQInfo); - - internal static uint SNIReadAsync(SNIHandle pConn, ref IntPtr ppNewPacket) => - s_nativeMethods.SniReadAsync(pConn, ref ppNewPacket); - - internal static uint SNIReadSyncOverAsync(SNIHandle pConn, ref IntPtr ppNewPacket, int timeout) => - s_nativeMethods.SniReadSyncOverAsync(pConn, ref ppNewPacket, timeout); - - internal static uint SNIRemoveProvider(SNIHandle pConn, Provider ProvNum) => - s_nativeMethods.SniRemoveProvider(pConn, ProvNum); - - internal static uint SNISecInitPackage(ref uint pcbMaxToken) => - s_nativeMethods.SniSecInitPackage(ref pcbMaxToken); - - internal static uint SNISetInfo(SNIHandle pConn, QueryType QType, [In] ref uint pbQInfo) => - s_nativeMethods.SniSetInfo(pConn, QType, ref pbQInfo); - - internal static uint SNITerminate() => - s_nativeMethods.SniTerminate(); - - internal static uint SNIWaitForSSLHandshakeToComplete([In] SNIHandle pConn, int dwMilliseconds, out uint pProtocolVersion) => - s_nativeMethods.SniWaitForSslHandshakeToComplete(pConn, dwMilliseconds, out pProtocolVersion); - - internal static uint UnmanagedIsTokenRestricted([In] IntPtr token, [MarshalAs(UnmanagedType.Bool)] out bool isRestricted) => - s_nativeMethods.SniIsTokenRestricted(token, out isRestricted); - - private static uint GetSniMaxComposedSpnLength() => - s_nativeMethods.SniGetMaxComposedSpnLength(); - - private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out Guid pbQInfo) => - s_nativeMethods.SniGetInfoWrapper(pConn, QType, out pbQInfo); - - #if NETFRAMEWORK - private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, [MarshalAs(UnmanagedType.Bool)] out bool pbQInfo) => - s_nativeMethods.SniGetInfoWrapper(pConn, QType, out pbQInfo); - #endif - - private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out ushort portNum) => - s_nativeMethods.SniGetInfoWrapper(pConn, QType, out portNum); - - private static uint SNIGetPeerAddrStrWrapper([In] SNIHandle pConn, int bufferSize, StringBuilder addrBuffer, out uint addrLen) => - s_nativeMethods.SniGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out addrLen); - - private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out Provider provNum) => - s_nativeMethods.SniGetInfoWrapper(pConn, QType, out provNum); - - private static uint SNIInitialize([In] IntPtr pmo) => - s_nativeMethods.SniInitialize(pmo); - - private static uint SNIOpenSyncExWrapper(ref SniClientConsumerInfo pClientConsumerInfo, out IntPtr ppConn) => - s_nativeMethods.SniOpenSyncExWrapper(ref pClientConsumerInfo, out ppConn); - - private static uint SNIOpenWrapper( - [In] ref SniConsumerInfo pConsumerInfo, - [MarshalAs(UnmanagedType.LPWStr)] string szConnect, - [In] SNIHandle pConn, - out IntPtr ppConn, - [MarshalAs(UnmanagedType.Bool)] bool fSync, - SqlConnectionIPAddressPreference ipPreference, - [In] ref SniDnsCacheInfo pDNSCachedInfo) => - s_nativeMethods.SniOpenWrapper( - ref pConsumerInfo, - szConnect, - pConn, - out ppConn, - fSync, - ipPreference, - ref pDNSCachedInfo); - - private static IntPtr SNIPacketAllocateWrapper([In] SafeHandle pConn, IoType IOType) => - s_nativeMethods.SniPacketAllocateWrapper(pConn, IOType); - - private static uint SNIPacketGetDataWrapper([In] IntPtr packet, [In, Out] byte[] readBuffer, uint readBufferLength, out uint dataSize) => - s_nativeMethods.SniPacketGetDataWrapper(packet, readBuffer, readBufferLength, out dataSize); - - private static unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf, uint cbBuf) => - s_nativeMethods.SniPacketSetData(pPacket, pbBuf, cbBuf); - private static unsafe uint SNISecGenClientContextWrapper( - [In] SNIHandle pConn, - [In, Out] ReadOnlySpan pIn, - [In, Out] Span pOut, - [In] ref uint pcbOut, - [MarshalAsAttribute(UnmanagedType.Bool)] - out bool pfDone, - ReadOnlySpan serverInfo, - [MarshalAsAttribute(UnmanagedType.LPWStr)] - string pwszUserName, - [MarshalAsAttribute(UnmanagedType.LPWStr)] - string pwszPassword) - { - fixed (byte* pInPtr = pIn) - fixed (byte* pOutPtr = pOut) - fixed (byte* pServerInfo = serverInfo) - { - return s_nativeMethods.SniSecGenClientContextWrapper( - pConn, - pInPtr, - (uint)pIn.Length, - pOutPtr, - ref pcbOut, - out pfDone, - pServerInfo, - (uint)serverInfo.Length, - pwszUserName, - pwszPassword); - } - } - - private static uint SNIWriteAsyncWrapper(SNIHandle pConn, [In] SNIPacket pPacket) => - s_nativeMethods.SniWriteAsyncWrapper(pConn, pPacket); - - private static uint SNIWriteSyncOverAsync(SNIHandle pConn, [In] SNIPacket pPacket) => - s_nativeMethods.SniWriteSyncOverAsync(pConn, pPacket); - - internal static IntPtr SNIServerEnumOpen() => - s_nativeMethods.SniServerEnumOpen(); - - internal static void SNIServerEnumClose([In] IntPtr packet) => - s_nativeMethods.SniServerEnumClose(packet); - - internal static int SNIServerEnumRead( - [In] IntPtr packet, - [In][MarshalAs(UnmanagedType.LPArray)] char[] readBuffer, - [In] int bufferLength, - [MarshalAs(UnmanagedType.Bool)] out bool more) => - s_nativeMethods.SniServerEnumRead(packet, readBuffer, bufferLength, out more); - - #endregion - internal static uint SniGetConnectionId(SNIHandle pConn, ref Guid connId) { return SNIGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_CONNID, out connId); } - - internal static uint SniGetProviderNumber(SNIHandle pConn, ref Provider provNum) - { - return SNIGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PROVIDERNUM, out provNum); - } - - internal static uint SniGetConnectionPort(SNIHandle pConn, ref ushort portNum) - { - return SNIGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PEERPORT, out portNum); - } - + internal static uint SniGetConnectionIPString(SNIHandle pConn, ref string connIPStr) { UInt32 ret; @@ -277,12 +109,28 @@ internal static uint SniGetConnectionIPString(SNIHandle pConn, ref string connIP return ret; } - + + internal static uint SniGetConnectionPort(SNIHandle pConn, ref ushort portNum) + { + return SNIGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PEERPORT, out portNum); + } + + internal static void SNIGetLastError(out SniError pErrorStruct) => + s_nativeMethods.SniGetLastError(out pErrorStruct); + + internal static uint SniGetProviderNumber(SNIHandle pConn, ref Provider provNum) + { + return SNIGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PROVIDERNUM, out provNum); + } + internal static uint SNIInitialize() { return SNIInitialize(IntPtr.Zero); } - + + internal static uint UnmanagedIsTokenRestricted([In] IntPtr token, [MarshalAs(UnmanagedType.Bool)] out bool isRestricted) => + s_nativeMethods.SniIsTokenRestricted(token, out isRestricted); + internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync, SqlConnectionIPAddressPreference ipPreference, SQLDNSInfo cachedDNSInfo) { // initialize consumer info for MARS @@ -425,31 +273,6 @@ internal static unsafe uint SNIOpenSyncEx( } } - #if NETFRAMEWORK - [ResourceExposure(ResourceScope.None)] - [ResourceConsumption(ResourceScope.Machine, ResourceScope.Machine)] - internal static uint SNIAddProvider(SNIHandle pConn, - Provider providerEnum, - AuthProviderInfo authInfo) - { - UInt32 ret; - uint ERROR_SUCCESS = 0; - - Debug.Assert(authInfo.clientCertificateCallback == null, "CTAIP support has been removed"); - - ret = SNIAddProvider(pConn, providerEnum, ref authInfo); - - if (ret == ERROR_SUCCESS) - { - // added a provider, need to requery for sync over async support - ret = SNIGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_SUPPORTS_SYNC_OVER_ASYNC, out bool _); - Debug.Assert(ret == ERROR_SUCCESS, "SNIGetInfo cannot fail with this QType"); - } - - return ret; - } - #endif - internal static void SNIPacketAllocate(SafeHandle pConn, IoType IOType, ref IntPtr pPacket) { pPacket = SNIPacketAllocateWrapper(pConn, IOType); @@ -459,7 +282,10 @@ internal static unsafe uint SNIPacketGetData(IntPtr packet, byte[] readBuffer, r { return SNIPacketGetDataWrapper(packet, readBuffer, (uint)readBuffer.Length, out dataSize); } - + + internal static void SNIPacketRelease(IntPtr pPacket) => + s_nativeMethods.SniPacketRelease(pPacket); + internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int length) { fixed (byte* pin_data = &data[0]) @@ -593,7 +419,22 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w } } #endif - + + internal static void SNIPacketReset([In] SNIHandle pConn, IoType IOType, SNIPacket pPacket, ConsumerNumber ConsNum) => + s_nativeMethods.SniPacketReset(pConn, IOType, pPacket, ConsNum); + + internal static uint SNIQueryInfo(QueryType QType, ref uint pbQInfo) => + s_nativeMethods.SniQueryInfo(QType, ref pbQInfo); + + internal static uint SNIQueryInfo(QueryType QType, ref IntPtr pbQInfo) => + s_nativeMethods.SniQueryInfo(QType, ref pbQInfo); + + internal static uint SNIReadAsync(SNIHandle pConn, ref IntPtr ppNewPacket) => + s_nativeMethods.SniReadAsync(pConn, ref ppNewPacket); + + internal static uint SNIReadSyncOverAsync(SNIHandle pConn, ref IntPtr ppNewPacket, int timeout) => + s_nativeMethods.SniReadSyncOverAsync(pConn, ref ppNewPacket, timeout); + internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan inBuff, Span outBuff, ref uint sendLength, string serverUserName) { var serverWriter = SqlObjectPools.BufferWriter.Rent(); @@ -617,7 +458,32 @@ internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, SqlObjectPools.BufferWriter.Return(serverWriter); } } - + + internal static uint SNISecInitPackage(ref uint pcbMaxToken) => + s_nativeMethods.SniSecInitPackage(ref pcbMaxToken); + + internal static void SNIServerEnumClose([In] IntPtr packet) => + s_nativeMethods.SniServerEnumClose(packet); + + internal static IntPtr SNIServerEnumOpen() => + s_nativeMethods.SniServerEnumOpen(); + + internal static int SNIServerEnumRead( + [In] IntPtr packet, + [In] [MarshalAs(UnmanagedType.LPArray)] char[] readBuffer, + [In] int bufferLength, + [MarshalAs(UnmanagedType.Bool)] out bool more) => + s_nativeMethods.SniServerEnumRead(packet, readBuffer, bufferLength, out more); + + internal static uint SNISetInfo(SNIHandle pConn, QueryType QType, [In] ref uint pbQInfo) => + s_nativeMethods.SniSetInfo(pConn, QType, ref pbQInfo); + + internal static uint SNITerminate() => + s_nativeMethods.SniTerminate(); + + internal static uint SNIWaitForSSLHandshakeToComplete([In] SNIHandle pConn, int dwMilliseconds, out uint pProtocolVersion) => + s_nativeMethods.SniWaitForSslHandshakeToComplete(pConn, dwMilliseconds, out pProtocolVersion); + internal static uint SNIWritePacket(SNIHandle pConn, SNIPacket packet, bool sync) { if (sync) @@ -629,6 +495,143 @@ internal static uint SNIWritePacket(SNIHandle pConn, SNIPacket packet, bool sync return SNIWriteAsyncWrapper(pConn, packet); } } + + #endregion + + + + #if NETFRAMEWORK + static AppDomain GetDefaultAppDomainInternal() + { + return AppDomain.CurrentDomain; + } + + internal static _AppDomain GetDefaultAppDomain() + { + return GetDefaultAppDomainInternal(); + } + + [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider + [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] + internal unsafe static byte[] GetData() + { + int size; + IntPtr ptr = (IntPtr)(SqlDependencyProcessDispatcherStorage.NativeGetData(out size)); + byte[] result = null; + + if (ptr != IntPtr.Zero) + { + result = new byte[size]; + Marshal.Copy(ptr, result, 0, size); + } + + return result; + } + + [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider + [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] + internal unsafe static void SetData(Byte[] data) + { + //cli::pin_ptr pin_dispatcher = &data[0]; + fixed (byte* pin_dispatcher = &data[0]) + { + SqlDependencyProcessDispatcherStorage.NativeSetData(pin_dispatcher, data.Length); + } + } + #endif + + #region DLL Imports + + private static uint GetSniMaxComposedSpnLength() => + s_nativeMethods.SniGetMaxComposedSpnLength(); + + private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out Guid pbQInfo) => + s_nativeMethods.SniGetInfoWrapper(pConn, QType, out pbQInfo); + + #if NETFRAMEWORK + private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, [MarshalAs(UnmanagedType.Bool)] out bool pbQInfo) => + s_nativeMethods.SniGetInfoWrapper(pConn, QType, out pbQInfo); + #endif + + private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out ushort portNum) => + s_nativeMethods.SniGetInfoWrapper(pConn, QType, out portNum); + + private static uint SNIGetPeerAddrStrWrapper([In] SNIHandle pConn, int bufferSize, StringBuilder addrBuffer, out uint addrLen) => + s_nativeMethods.SniGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out addrLen); + + private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out Provider provNum) => + s_nativeMethods.SniGetInfoWrapper(pConn, QType, out provNum); + + private static uint SNIInitialize([In] IntPtr pmo) => + s_nativeMethods.SniInitialize(pmo); + + private static uint SNIOpenSyncExWrapper(ref SniClientConsumerInfo pClientConsumerInfo, out IntPtr ppConn) => + s_nativeMethods.SniOpenSyncExWrapper(ref pClientConsumerInfo, out ppConn); + + private static uint SNIOpenWrapper( + [In] ref SniConsumerInfo pConsumerInfo, + [MarshalAs(UnmanagedType.LPWStr)] string szConnect, + [In] SNIHandle pConn, + out IntPtr ppConn, + [MarshalAs(UnmanagedType.Bool)] bool fSync, + SqlConnectionIPAddressPreference ipPreference, + [In] ref SniDnsCacheInfo pDNSCachedInfo) => + s_nativeMethods.SniOpenWrapper( + ref pConsumerInfo, + szConnect, + pConn, + out ppConn, + fSync, + ipPreference, + ref pDNSCachedInfo); + + private static IntPtr SNIPacketAllocateWrapper([In] SafeHandle pConn, IoType IOType) => + s_nativeMethods.SniPacketAllocateWrapper(pConn, IOType); + + private static uint SNIPacketGetDataWrapper([In] IntPtr packet, [In, Out] byte[] readBuffer, uint readBufferLength, out uint dataSize) => + s_nativeMethods.SniPacketGetDataWrapper(packet, readBuffer, readBufferLength, out dataSize); + + private static unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf, uint cbBuf) => + s_nativeMethods.SniPacketSetData(pPacket, pbBuf, cbBuf); + + private static unsafe uint SNISecGenClientContextWrapper( + [In] SNIHandle pConn, + [In, Out] ReadOnlySpan pIn, + [In, Out] Span pOut, + [In] ref uint pcbOut, + [MarshalAsAttribute(UnmanagedType.Bool)] + out bool pfDone, + ReadOnlySpan serverInfo, + [MarshalAsAttribute(UnmanagedType.LPWStr)] + string pwszUserName, + [MarshalAsAttribute(UnmanagedType.LPWStr)] + string pwszPassword) + { + fixed (byte* pInPtr = pIn) + fixed (byte* pOutPtr = pOut) + fixed (byte* pServerInfo = serverInfo) + { + return s_nativeMethods.SniSecGenClientContextWrapper( + pConn, + pInPtr, + (uint)pIn.Length, + pOutPtr, + ref pcbOut, + out pfDone, + pServerInfo, + (uint)serverInfo.Length, + pwszUserName, + pwszPassword); + } + } + + private static uint SNIWriteAsyncWrapper(SNIHandle pConn, [In] SNIPacket pPacket) => + s_nativeMethods.SniWriteAsyncWrapper(pConn, pPacket); + + private static uint SNIWriteSyncOverAsync(SNIHandle pConn, [In] SNIPacket pPacket) => + s_nativeMethods.SniWriteSyncOverAsync(pConn, pPacket); + + #endregion private static void MarshalConsumerInfo(ConsumerInfo consumerInfo, ref SniConsumerInfo native_consumerInfo) { From f991395cea5c4a3807beb9dcbf8644ef43b50588 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Wed, 13 Nov 2024 17:58:15 -0600 Subject: [PATCH 03/12] Sort private methods (most of these will be removed in next commit) --- .../Interop/Windows/Sni/SniNativeWrapper.cs | 119 +++++++++--------- 1 file changed, 59 insertions(+), 60 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index c29cfc3be0..0e44db1b79 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -497,77 +497,47 @@ internal static uint SNIWritePacket(SNIHandle pConn, SNIPacket packet, bool sync } #endregion - - - - #if NETFRAMEWORK - static AppDomain GetDefaultAppDomainInternal() - { - return AppDomain.CurrentDomain; - } - - internal static _AppDomain GetDefaultAppDomain() - { - return GetDefaultAppDomainInternal(); - } - - [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider - [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] - internal unsafe static byte[] GetData() - { - int size; - IntPtr ptr = (IntPtr)(SqlDependencyProcessDispatcherStorage.NativeGetData(out size)); - byte[] result = null; - if (ptr != IntPtr.Zero) - { - result = new byte[size]; - Marshal.Copy(ptr, result, 0, size); - } - - return result; - } - - [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider - [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] - internal unsafe static void SetData(Byte[] data) - { - //cli::pin_ptr pin_dispatcher = &data[0]; - fixed (byte* pin_dispatcher = &data[0]) - { - SqlDependencyProcessDispatcherStorage.NativeSetData(pin_dispatcher, data.Length); - } - } - #endif - - #region DLL Imports + #region Private Methods private static uint GetSniMaxComposedSpnLength() => s_nativeMethods.SniGetMaxComposedSpnLength(); + private static void MarshalConsumerInfo(ConsumerInfo consumerInfo, ref SniConsumerInfo native_consumerInfo) + { + native_consumerInfo.DefaultUserDataLength = consumerInfo.defaultBufferSize; + native_consumerInfo.fnReadComp = consumerInfo.readDelegate != null + ? Marshal.GetFunctionPointerForDelegate(consumerInfo.readDelegate) + : IntPtr.Zero; + native_consumerInfo.fnWriteComp = consumerInfo.writeDelegate != null + ? Marshal.GetFunctionPointerForDelegate(consumerInfo.writeDelegate) + : IntPtr.Zero; + native_consumerInfo.ConsumerKey = consumerInfo.key; + } + private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out Guid pbQInfo) => s_nativeMethods.SniGetInfoWrapper(pConn, QType, out pbQInfo); - + #if NETFRAMEWORK private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, [MarshalAs(UnmanagedType.Bool)] out bool pbQInfo) => s_nativeMethods.SniGetInfoWrapper(pConn, QType, out pbQInfo); #endif + private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out Provider provNum) => + s_nativeMethods.SniGetInfoWrapper(pConn, QType, out provNum); + private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out ushort portNum) => s_nativeMethods.SniGetInfoWrapper(pConn, QType, out portNum); - + private static uint SNIGetPeerAddrStrWrapper([In] SNIHandle pConn, int bufferSize, StringBuilder addrBuffer, out uint addrLen) => s_nativeMethods.SniGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out addrLen); - - private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out Provider provNum) => - s_nativeMethods.SniGetInfoWrapper(pConn, QType, out provNum); - + private static uint SNIInitialize([In] IntPtr pmo) => s_nativeMethods.SniInitialize(pmo); - + private static uint SNIOpenSyncExWrapper(ref SniClientConsumerInfo pClientConsumerInfo, out IntPtr ppConn) => s_nativeMethods.SniOpenSyncExWrapper(ref pClientConsumerInfo, out ppConn); - + private static uint SNIOpenWrapper( [In] ref SniConsumerInfo pConsumerInfo, [MarshalAs(UnmanagedType.LPWStr)] string szConnect, @@ -632,18 +602,47 @@ private static uint SNIWriteSyncOverAsync(SNIHandle pConn, [In] SNIPacket pPacke s_nativeMethods.SniWriteSyncOverAsync(pConn, pPacket); #endregion + + + #if NETFRAMEWORK + static AppDomain GetDefaultAppDomainInternal() + { + return AppDomain.CurrentDomain; + } - private static void MarshalConsumerInfo(ConsumerInfo consumerInfo, ref SniConsumerInfo native_consumerInfo) + internal static _AppDomain GetDefaultAppDomain() { - native_consumerInfo.DefaultUserDataLength = consumerInfo.defaultBufferSize; - native_consumerInfo.fnReadComp = consumerInfo.readDelegate != null - ? Marshal.GetFunctionPointerForDelegate(consumerInfo.readDelegate) - : IntPtr.Zero; - native_consumerInfo.fnWriteComp = consumerInfo.writeDelegate != null - ? Marshal.GetFunctionPointerForDelegate(consumerInfo.writeDelegate) - : IntPtr.Zero; - native_consumerInfo.ConsumerKey = consumerInfo.key; + return GetDefaultAppDomainInternal(); } + + [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider + [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] + internal unsafe static byte[] GetData() + { + int size; + IntPtr ptr = (IntPtr)(SqlDependencyProcessDispatcherStorage.NativeGetData(out size)); + byte[] result = null; + + if (ptr != IntPtr.Zero) + { + result = new byte[size]; + Marshal.Copy(ptr, result, 0, size); + } + + return result; + } + + [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider + [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] + internal unsafe static void SetData(Byte[] data) + { + //cli::pin_ptr pin_dispatcher = &data[0]; + fixed (byte* pin_dispatcher = &data[0]) + { + SqlDependencyProcessDispatcherStorage.NativeSetData(pin_dispatcher, data.Length); + } + } + #endif } } From 9252f4652bac99b97c8fab0a36d4f2527e05ef22 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Wed, 13 Nov 2024 18:17:56 -0600 Subject: [PATCH 04/12] Remove unneeded private methods --- .../Interop/Windows/Sni/SniNativeWrapper.cs | 187 +++++------------- 1 file changed, 49 insertions(+), 138 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index 0e44db1b79..0e7016a43f 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -45,7 +45,7 @@ internal static int SniMaxComposedSpnLength { if (s_sniMaxComposedSpnLength == -1) { - s_sniMaxComposedSpnLength = checked((int)GetSniMaxComposedSpnLength()); + s_sniMaxComposedSpnLength = checked((int)s_nativeMethods.SniGetMaxComposedSpnLength()); } return s_sniMaxComposedSpnLength; } @@ -73,7 +73,7 @@ internal static uint SNIAddProvider(SNIHandle pConn, if (ret == ERROR_SUCCESS) { // added a provider, need to requery for sync over async support - ret = SNIGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_SUPPORTS_SYNC_OVER_ASYNC, out bool _); + ret = s_nativeMethods.SniGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_SUPPORTS_SYNC_OVER_ASYNC, out bool _); Debug.Assert(ret == ERROR_SUCCESS, "SNIGetInfo cannot fail with this QType"); } @@ -90,10 +90,8 @@ internal static uint SNICheckConnection([In] SNIHandle pConn) => internal static uint SNIClose(IntPtr pConn) => s_nativeMethods.SniClose(pConn); - internal static uint SniGetConnectionId(SNIHandle pConn, ref Guid connId) - { - return SNIGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_CONNID, out connId); - } + internal static uint SniGetConnectionId(SNIHandle pConn, ref Guid connId) => + s_nativeMethods.SniGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_CONNID, out connId); internal static uint SniGetConnectionIPString(SNIHandle pConn, ref string connIPStr) { @@ -103,7 +101,7 @@ internal static uint SniGetConnectionIPString(SNIHandle pConn, ref string connIP int bufferSize = SniIpv6AddrStringBufferLength; StringBuilder addrBuffer = new StringBuilder(bufferSize); - ret = SNIGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out connIPLen); + ret = s_nativeMethods.SniGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out connIPLen); connIPStr = addrBuffer.ToString(0, Convert.ToInt32(connIPLen)); @@ -112,7 +110,7 @@ internal static uint SniGetConnectionIPString(SNIHandle pConn, ref string connIP internal static uint SniGetConnectionPort(SNIHandle pConn, ref ushort portNum) { - return SNIGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PEERPORT, out portNum); + return s_nativeMethods.SniGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PEERPORT, out portNum); } internal static void SNIGetLastError(out SniError pErrorStruct) => @@ -120,13 +118,11 @@ internal static void SNIGetLastError(out SniError pErrorStruct) => internal static uint SniGetProviderNumber(SNIHandle pConn, ref Provider provNum) { - return SNIGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PROVIDERNUM, out provNum); - } - - internal static uint SNIInitialize() - { - return SNIInitialize(IntPtr.Zero); + return s_nativeMethods.SniGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PROVIDERNUM, out provNum); } + + internal static uint SNIInitialize() => + s_nativeMethods.SniInitialize(IntPtr.Zero); internal static uint UnmanagedIsTokenRestricted([In] IntPtr token, [MarshalAs(UnmanagedType.Bool)] out bool isRestricted) => s_nativeMethods.SniIsTokenRestricted(token, out isRestricted); @@ -143,7 +139,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port; - return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ipPreference, ref native_cachedDNSInfo); + return s_nativeMethods.SniOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ipPreference, ref native_cachedDNSInfo); } internal static unsafe uint SNIOpenSyncEx( @@ -212,9 +208,9 @@ internal static unsafe uint SNIOpenSyncEx( if (spn != null) { - // An empty string implies we need to find the SPN so we supply a buffer for the max size if (spn.Length == 0) { + // An empty string implies we need to find the SPN so we supply a buffer for the max size var array = ArrayPool.Shared.Rent(SniMaxComposedSpnLength); array.AsSpan().Clear(); @@ -225,8 +221,7 @@ internal static unsafe uint SNIOpenSyncEx( clientConsumerInfo.szSPN = pin_spnBuffer; clientConsumerInfo.cchSPN = (uint)SniMaxComposedSpnLength; - var result = SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn); - + var result = s_nativeMethods.SniOpenSyncExWrapper(ref clientConsumerInfo, out pConn); if (result == 0) { spn = Encoding.Unicode.GetString(array).TrimEnd('\0'); @@ -240,23 +235,23 @@ internal static unsafe uint SNIOpenSyncEx( ArrayPool.Shared.Return(array); } } - - // We have a value of the SPN, so we marshal that and send it to the native layer else { + // We have a value of the SPN, so we marshal that and send it to the native layer var writer = SqlObjectPools.BufferWriter.Rent(); try { // Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code. Encoding.Unicode.GetBytes(spn, writer); - Trace.Assert(writer.WrittenCount <= SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size."); + Trace.Assert(writer.WrittenCount <= SniMaxComposedSpnLength, + "Length of the provided SPN exceeded the buffer size."); fixed (byte* pin_spnBuffer = writer.WrittenSpan) { clientConsumerInfo.szSPN = pin_spnBuffer; clientConsumerInfo.cchSPN = (uint)writer.WrittenCount; - return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + return s_nativeMethods.SniOpenSyncExWrapper(ref clientConsumerInfo, out pConn); } } finally @@ -268,20 +263,16 @@ internal static unsafe uint SNIOpenSyncEx( else { // else leave szSPN null (SQL Auth) - return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + return s_nativeMethods.SniOpenSyncExWrapper(ref clientConsumerInfo, out pConn); } } } - internal static void SNIPacketAllocate(SafeHandle pConn, IoType IOType, ref IntPtr pPacket) - { - pPacket = SNIPacketAllocateWrapper(pConn, IOType); - } - - internal static unsafe uint SNIPacketGetData(IntPtr packet, byte[] readBuffer, ref uint dataSize) - { - return SNIPacketGetDataWrapper(packet, readBuffer, (uint)readBuffer.Length, out dataSize); - } + internal static void SNIPacketAllocate(SafeHandle pConn, IoType IOType, ref IntPtr pPacket) => + pPacket = s_nativeMethods.SniPacketAllocateWrapper(pConn, IOType); + + internal static unsafe uint SNIPacketGetData(IntPtr packet, byte[] readBuffer, ref uint dataSize) => + s_nativeMethods.SniPacketGetDataWrapper(packet, readBuffer, (uint)readBuffer.Length, out dataSize); internal static void SNIPacketRelease(IntPtr pPacket) => s_nativeMethods.SniPacketRelease(pPacket); @@ -290,7 +281,7 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int { fixed (byte* pin_data = &data[0]) { - SNIPacketSetData(packet, pin_data, (uint)length); + s_nativeMethods.SniPacketSetData(packet, pin_data, (uint)length); } } @@ -394,10 +385,7 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w packet.DangerousAddRef(ref mustRelease); Debug.Assert(mustRelease, "AddRef Failed!"); - fixed (byte* pin_data = &data[0]) - { - SNIPacketSetData(packet, pin_data, (uint)length); - } + SNIPacketSetData(packet, data, length); } } finally @@ -435,7 +423,12 @@ internal static uint SNIReadAsync(SNIHandle pConn, ref IntPtr ppNewPacket) => internal static uint SNIReadSyncOverAsync(SNIHandle pConn, ref IntPtr ppNewPacket, int timeout) => s_nativeMethods.SniReadSyncOverAsync(pConn, ref ppNewPacket, timeout); - internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan inBuff, Span outBuff, ref uint sendLength, string serverUserName) + internal static unsafe uint SNISecGenClientContext( + SNIHandle pConnectionObject, + ReadOnlySpan inBuff, + Span outBuff, + ref uint sendLength, + string serverUserName) { var serverWriter = SqlObjectPools.BufferWriter.Rent(); @@ -443,15 +436,22 @@ internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, { Encoding.Unicode.GetBytes(serverUserName, serverWriter); - return SNISecGenClientContextWrapper( - pConnectionObject, - inBuff, - outBuff, - ref sendLength, - out _, - serverWriter.WrittenSpan, - null, - null); + fixed (byte* pInBuff = inBuff) + fixed (byte* pOutBuff = outBuff) + fixed (byte* pServerInfo = serverWriter.WrittenSpan) + { + return s_nativeMethods.SniSecGenClientContextWrapper( + pConn: pConnectionObject, + pIn: pInBuff, + cbIn: (uint)inBuff.Length, + pOut: pOutBuff, + pcbOut: ref sendLength, + pfDone: out _, + szServerInfo: pServerInfo, + cbServerInfo: (uint)serverWriter.WrittenSpan.Length, + pwszUserName: null, + pwszPassword: null); + } } finally { @@ -488,11 +488,11 @@ internal static uint SNIWritePacket(SNIHandle pConn, SNIPacket packet, bool sync { if (sync) { - return SNIWriteSyncOverAsync(pConn, packet); + return s_nativeMethods.SniWriteSyncOverAsync(pConn, packet); } else { - return SNIWriteAsyncWrapper(pConn, packet); + return s_nativeMethods.SniWriteAsyncWrapper(pConn, packet); } } @@ -500,9 +500,6 @@ internal static uint SNIWritePacket(SNIHandle pConn, SNIPacket packet, bool sync #region Private Methods - private static uint GetSniMaxComposedSpnLength() => - s_nativeMethods.SniGetMaxComposedSpnLength(); - private static void MarshalConsumerInfo(ConsumerInfo consumerInfo, ref SniConsumerInfo native_consumerInfo) { native_consumerInfo.DefaultUserDataLength = consumerInfo.defaultBufferSize; @@ -515,92 +512,6 @@ private static void MarshalConsumerInfo(ConsumerInfo consumerInfo, ref SniConsum native_consumerInfo.ConsumerKey = consumerInfo.key; } - private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out Guid pbQInfo) => - s_nativeMethods.SniGetInfoWrapper(pConn, QType, out pbQInfo); - - #if NETFRAMEWORK - private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, [MarshalAs(UnmanagedType.Bool)] out bool pbQInfo) => - s_nativeMethods.SniGetInfoWrapper(pConn, QType, out pbQInfo); - #endif - - private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out Provider provNum) => - s_nativeMethods.SniGetInfoWrapper(pConn, QType, out provNum); - - private static uint SNIGetInfoWrapper([In] SNIHandle pConn, QueryType QType, out ushort portNum) => - s_nativeMethods.SniGetInfoWrapper(pConn, QType, out portNum); - - private static uint SNIGetPeerAddrStrWrapper([In] SNIHandle pConn, int bufferSize, StringBuilder addrBuffer, out uint addrLen) => - s_nativeMethods.SniGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out addrLen); - - private static uint SNIInitialize([In] IntPtr pmo) => - s_nativeMethods.SniInitialize(pmo); - - private static uint SNIOpenSyncExWrapper(ref SniClientConsumerInfo pClientConsumerInfo, out IntPtr ppConn) => - s_nativeMethods.SniOpenSyncExWrapper(ref pClientConsumerInfo, out ppConn); - - private static uint SNIOpenWrapper( - [In] ref SniConsumerInfo pConsumerInfo, - [MarshalAs(UnmanagedType.LPWStr)] string szConnect, - [In] SNIHandle pConn, - out IntPtr ppConn, - [MarshalAs(UnmanagedType.Bool)] bool fSync, - SqlConnectionIPAddressPreference ipPreference, - [In] ref SniDnsCacheInfo pDNSCachedInfo) => - s_nativeMethods.SniOpenWrapper( - ref pConsumerInfo, - szConnect, - pConn, - out ppConn, - fSync, - ipPreference, - ref pDNSCachedInfo); - - private static IntPtr SNIPacketAllocateWrapper([In] SafeHandle pConn, IoType IOType) => - s_nativeMethods.SniPacketAllocateWrapper(pConn, IOType); - - private static uint SNIPacketGetDataWrapper([In] IntPtr packet, [In, Out] byte[] readBuffer, uint readBufferLength, out uint dataSize) => - s_nativeMethods.SniPacketGetDataWrapper(packet, readBuffer, readBufferLength, out dataSize); - - private static unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf, uint cbBuf) => - s_nativeMethods.SniPacketSetData(pPacket, pbBuf, cbBuf); - - private static unsafe uint SNISecGenClientContextWrapper( - [In] SNIHandle pConn, - [In, Out] ReadOnlySpan pIn, - [In, Out] Span pOut, - [In] ref uint pcbOut, - [MarshalAsAttribute(UnmanagedType.Bool)] - out bool pfDone, - ReadOnlySpan serverInfo, - [MarshalAsAttribute(UnmanagedType.LPWStr)] - string pwszUserName, - [MarshalAsAttribute(UnmanagedType.LPWStr)] - string pwszPassword) - { - fixed (byte* pInPtr = pIn) - fixed (byte* pOutPtr = pOut) - fixed (byte* pServerInfo = serverInfo) - { - return s_nativeMethods.SniSecGenClientContextWrapper( - pConn, - pInPtr, - (uint)pIn.Length, - pOutPtr, - ref pcbOut, - out pfDone, - pServerInfo, - (uint)serverInfo.Length, - pwszUserName, - pwszPassword); - } - } - - private static uint SNIWriteAsyncWrapper(SNIHandle pConn, [In] SNIPacket pPacket) => - s_nativeMethods.SniWriteAsyncWrapper(pConn, pPacket); - - private static uint SNIWriteSyncOverAsync(SNIHandle pConn, [In] SNIPacket pPacket) => - s_nativeMethods.SniWriteSyncOverAsync(pConn, pPacket); - #endregion From 51c2342e5615916bdfe94db6cd7e288996228518 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Wed, 13 Nov 2024 18:23:59 -0600 Subject: [PATCH 05/12] Rename methods as per naming rules --- .../Data/SqlClient/TdsParser.Windows.cs | 2 +- .../SqlClient/TdsParserStateObjectNative.cs | 28 ++++---- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 16 ++--- .../SqlClient/TdsParserStateObject.netfx.cs | 20 +++--- .../Interop/Windows/Sni/SniNativeWrapper.cs | 69 ++++++++++--------- .../SqlDataSourceEnumeratorNativeHelper.cs | 16 ++--- .../SSPI/NativeSSPIContextProvider.cs | 4 +- .../SqlClient/TdsParserSafeHandles.Windows.cs | 20 +++--- 8 files changed, 89 insertions(+), 86 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs index 7c10f0aa4f..de4bcb1338 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs @@ -79,7 +79,7 @@ private SNIErrorDetails GetSniErrorDetails() } else { - SniNativeWrapper.SNIGetLastError(out SniError sniError); + SniNativeWrapper.SniGetLastError(out SniError sniError); details.sniErrorNumber = sniError.sniError; details.errorMessage = sniError.errorMessage; details.nativeError = sniError.nativeError; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index 49192f21e4..043b6f3a5e 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -184,7 +184,7 @@ internal override void CreatePhysicalSNIHandle( protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) { Debug.Assert(packet.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); - return SniNativeWrapper.SNIPacketGetData(packet.NativePointer, _inBuff, ref dataSize); + return SniNativeWrapper.SniPacketGetData(packet.NativePointer, _inBuff, ref dataSize); } protected override bool CheckPacket(PacketHandle packet, TaskCompletionSource source) @@ -264,7 +264,7 @@ internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint throw ADP.ClosedConnectionError(); } IntPtr readPacketPtr = IntPtr.Zero; - error = SniNativeWrapper.SNIReadSyncOverAsync(handle, ref readPacketPtr, GetTimeoutRemaining()); + error = SniNativeWrapper.SniReadSyncOverAsync(handle, ref readPacketPtr, GetTimeoutRemaining()); return PacketHandle.FromNativePointer(readPacketPtr); } @@ -281,20 +281,20 @@ internal override bool IsPacketEmpty(PacketHandle readPacket) internal override void ReleasePacket(PacketHandle syncReadPacket) { Debug.Assert(syncReadPacket.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); - SniNativeWrapper.SNIPacketRelease(syncReadPacket.NativePointer); + SniNativeWrapper.SniPacketRelease(syncReadPacket.NativePointer); } internal override uint CheckConnection() { SNIHandle handle = Handle; - return handle == null ? TdsEnums.SNI_SUCCESS : SniNativeWrapper.SNICheckConnection(handle); + return handle == null ? TdsEnums.SNI_SUCCESS : SniNativeWrapper.SniCheckConnection(handle); } internal override PacketHandle ReadAsync(SessionHandle handle, out uint error) { Debug.Assert(handle.Type == SessionHandle.NativeHandleType, "unexpected handle type when requiring NativePointer"); IntPtr readPacketPtr = IntPtr.Zero; - error = SniNativeWrapper.SNIReadAsync(handle.NativeHandle, ref readPacketPtr); + error = SniNativeWrapper.SniReadAsync(handle.NativeHandle, ref readPacketPtr); return PacketHandle.FromNativePointer(readPacketPtr); } @@ -310,7 +310,7 @@ internal override PacketHandle CreateAndSetAttentionPacket() internal override uint WritePacket(PacketHandle packet, bool sync) { Debug.Assert(packet.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket"); - return SniNativeWrapper.SNIWritePacket(Handle, packet.NativePacket, sync); + return SniNativeWrapper.SniWritePacket(Handle, packet.NativePacket, sync); } internal override PacketHandle AddPacketToPendingList(PacketHandle packetToAdd) @@ -343,7 +343,7 @@ internal override PacketHandle GetResetWritePacket(int dataSize) { if (_sniPacket != null) { - SniNativeWrapper.SNIPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI); + SniNativeWrapper.SniPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI); } else { @@ -372,17 +372,17 @@ internal override void ClearAllWritePackets() internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed) { Debug.Assert(packet.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket"); - SniNativeWrapper.SNIPacketSetData(packet.NativePacket, buffer, bytesUsed); + SniNativeWrapper.SniPacketSetData(packet.NativePacket, buffer, bytesUsed); } internal override uint SniGetConnectionId(ref Guid clientConnectionId) => SniNativeWrapper.SniGetConnectionId(Handle, ref clientConnectionId); internal override uint DisableSsl() - => SniNativeWrapper.SNIRemoveProvider(Handle, Provider.SSL_PROV); + => SniNativeWrapper.SniRemoveProvider(Handle, Provider.SSL_PROV); internal override uint EnableMars(ref uint info) - => SniNativeWrapper.SNIAddProvider(Handle, Provider.SMUX_PROV, ref info); + => SniNativeWrapper.SniAddProvider(Handle, Provider.SMUX_PROV, ref info); internal override uint EnableSsl(ref uint info, bool tlsFirst, string serverCertificateFilename) { @@ -392,15 +392,15 @@ internal override uint EnableSsl(ref uint info, bool tlsFirst, string serverCert authInfo.serverCertFileName = serverCertificateFilename; // Add SSL (Encryption) SNI provider. - return SniNativeWrapper.SNIAddProvider(Handle, Provider.SSL_PROV, ref authInfo); + return SniNativeWrapper.SniAddProvider(Handle, Provider.SSL_PROV, ref authInfo); } internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) - => SniNativeWrapper.SNISetInfo(Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize); + => SniNativeWrapper.SniSetInfo(Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize); internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion) { - uint returnValue = SniNativeWrapper.SNIWaitForSSLHandshakeToComplete(Handle, GetTimeoutRemaining(), out uint nativeProtocolVersion); + uint returnValue = SniNativeWrapper.SniWaitForSslHandshakeToComplete(Handle, GetTimeoutRemaining(), out uint nativeProtocolVersion); var nativeProtocol = (NativeProtocols)nativeProtocolVersion; #pragma warning disable CA5398 // Avoid hardcoded SslProtocols values @@ -469,7 +469,7 @@ public SNIPacket Take(SNIHandle sniHandle) { // Success - reset the packet packet = _packets.Pop(); - SniNativeWrapper.SNIPacketReset(sniHandle, IoType.WRITE, packet, ConsumerNumber.SNI_Consumer_SNI); + SniNativeWrapper.SniPacketReset(sniHandle, IoType.WRITE, packet, ConsumerNumber.SNI_Consumer_SNI); } else { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 06ffdcd629..2b41165983 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -699,7 +699,7 @@ internal void RemoveEncryption() uint error = 0; // Remove SSL (Encryption) SNI provider since we only wanted to encrypt login. - error = SniNativeWrapper.SNIRemoveProvider(_physicalStateObj.Handle, Provider.SSL_PROV); + error = SniNativeWrapper.SniRemoveProvider(_physicalStateObj.Handle, Provider.SSL_PROV); if (error != TdsEnums.SNI_SUCCESS) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -726,7 +726,7 @@ internal void EnableMars() uint info = 0; // Add SMUX (MARS) SNI provider. - error = SniNativeWrapper.SNIAddProvider(_pMarsPhysicalConObj.Handle, Provider.SMUX_PROV, ref info); + error = SniNativeWrapper.SniAddProvider(_pMarsPhysicalConObj.Handle, Provider.SMUX_PROV, ref info); if (error != TdsEnums.SNI_SUCCESS) { @@ -747,12 +747,12 @@ internal void EnableMars() { _pMarsPhysicalConObj.IncrementPendingCallbacks(); - error = SniNativeWrapper.SNIReadAsync(_pMarsPhysicalConObj.Handle, ref temp); + error = SniNativeWrapper.SniReadAsync(_pMarsPhysicalConObj.Handle, ref temp); if (temp != IntPtr.Zero) { // Be sure to release packet, otherwise it will be leaked by native. - SniNativeWrapper.SNIPacketRelease(temp); + SniNativeWrapper.SniPacketRelease(temp); } } Debug.Assert(IntPtr.Zero == temp, "unexpected syncReadPacket without corresponding SNIPacketRelease"); @@ -1025,7 +1025,7 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ Debug.Assert((_encryptionOption & EncryptionOptions.CLIENT_CERT) == 0, "Client certificate authentication support has been removed"); - error = SniNativeWrapper.SNIAddProvider(_physicalStateObj.Handle, Provider.SSL_PROV, authInfo); + error = SniNativeWrapper.SniAddProvider(_physicalStateObj.Handle, Provider.SSL_PROV, authInfo); if (error != TdsEnums.SNI_SUCCESS) { @@ -1037,7 +1037,7 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ // wait for SSL handshake to complete, so that the SSL context is fully negotiated before we try to use its // Channel Bindings as part of the Windows Authentication context build (SSL handshake must complete // before calling SNISecGenClientContext). - error = SniNativeWrapper.SNIWaitForSSLHandshakeToComplete(_physicalStateObj.Handle, _physicalStateObj.GetTimeoutRemaining(), out uint protocolVersion); + error = SniNativeWrapper.SniWaitForSslHandshakeToComplete(_physicalStateObj.Handle, _physicalStateObj.GetTimeoutRemaining(), out uint protocolVersion); if (error != TdsEnums.SNI_SUCCESS) { @@ -1591,7 +1591,7 @@ internal SqlError ProcessSNIError(TdsParserStateObject stateObj) Debug.Assert(SniContext.Undefined != stateObj.DebugOnlyCopyOfSniContext || ((_fMARS) && ((_state == TdsParserState.Closed) || (_state == TdsParserState.Broken))), "SniContext must not be None"); #endif SniError sniError = new SniError(); - SniNativeWrapper.SNIGetLastError(out sniError); + SniNativeWrapper.SniGetLastError(out sniError); if (sniError.sniError != 0) { @@ -2906,7 +2906,7 @@ private TdsOperationStatus TryProcessEnvChange(int tokenLength, TdsParserStateOb // Update SNI ConsumerInfo value to be resulting packet size uint unsignedPacketSize = (uint)packetSize; - uint bufferSizeResult = SniNativeWrapper.SNISetInfo(_physicalStateObj.Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize); + uint bufferSizeResult = SniNativeWrapper.SniSetInfo(_physicalStateObj.Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize); Debug.Assert(bufferSizeResult == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SNISetInfo"); } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs index 024d698822..461223c237 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs @@ -270,20 +270,20 @@ internal PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint error) { SNIHandle handle = Handle ?? throw ADP.ClosedConnectionError(); PacketHandle readPacket = default; - error = SniNativeWrapper.SNIReadSyncOverAsync(handle, ref readPacket, timeoutRemaining); + error = SniNativeWrapper.SniReadSyncOverAsync(handle, ref readPacket, timeoutRemaining); return readPacket; } internal PacketHandle ReadAsync(SessionHandle handle, out uint error) { PacketHandle readPacket = default; - error = SniNativeWrapper.SNIReadAsync(handle.NativeHandle, ref readPacket); + error = SniNativeWrapper.SniReadAsync(handle.NativeHandle, ref readPacket); return readPacket; } - internal uint CheckConnection() => SniNativeWrapper.SNICheckConnection(Handle); + internal uint CheckConnection() => SniNativeWrapper.SniCheckConnection(Handle); - internal void ReleasePacket(PacketHandle syncReadPacket) => SniNativeWrapper.SNIPacketRelease(syncReadPacket); + internal void ReleasePacket(PacketHandle syncReadPacket) => SniNativeWrapper.SniPacketRelease(syncReadPacket); [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] internal int DecrementPendingCallbacks(bool release) @@ -401,7 +401,7 @@ internal bool ValidateSNIConnection() SNIHandle handle = Handle; if (handle != null) { - error = SniNativeWrapper.SNICheckConnection(handle); + error = SniNativeWrapper.SniCheckConnection(handle); } } finally @@ -518,7 +518,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) private uint GetSniPacket(PacketHandle packet, ref uint dataSize) { - return SniNativeWrapper.SNIPacketGetData(packet, _inBuff, ref dataSize); + return SniNativeWrapper.SniPacketGetData(packet, _inBuff, ref dataSize); } private void ChangeNetworkPacketTimeout(int dueTime, int period) @@ -1007,7 +1007,7 @@ private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out uint sniErro } finally { - sniError = SniNativeWrapper.SNIWritePacket(handle, packet, sync); + sniError = SniNativeWrapper.SniWritePacket(handle, packet, sync); } if (sniError == TdsEnums.SNI_SUCCESS_IO_PENDING) @@ -1119,7 +1119,7 @@ internal void SendAttention(bool mustTakeWriteLock = false, bool asyncClose = fa SNIPacket attnPacket = new SNIPacket(Handle); _sniAsyncAttnPacket = attnPacket; - SniNativeWrapper.SNIPacketSetData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN, null, null); + SniNativeWrapper.SniPacketSetData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN, null, null); RuntimeHelpers.PrepareConstrainedRegions(); try @@ -1183,7 +1183,7 @@ private Task WriteSni(bool canAccumulate) { // Prepare packet, and write to packet. SNIPacket packet = GetResetWritePacket(); - SniNativeWrapper.SNIPacketSetData(packet, _outBuff, _outBytesUsed, _securePasswords, _securePasswordOffsetsInBuffer); + SniNativeWrapper.SniPacketSetData(packet, _outBuff, _outBytesUsed, _securePasswords, _securePasswordOffsetsInBuffer); Debug.Assert(Parser.Connection._parserLock.ThreadMayHaveLock(), "Thread is writing without taking the connection lock"); Task task = SNIWritePacket(Handle, packet, out _, canAccumulate, callerHasConnectionLock: true); @@ -1238,7 +1238,7 @@ internal SNIPacket GetResetWritePacket() { if (_sniPacket != null) { - SniNativeWrapper.SNIPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI); + SniNativeWrapper.SniPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI); } else { diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index 0e7016a43f..d91b441baa 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -53,13 +53,13 @@ internal static int SniMaxComposedSpnLength #region Public Methods - internal static uint SNIAddProvider(SNIHandle pConn, Provider ProvNum, [In] ref AuthProviderInfo pInfo) => + internal static uint SniAddProvider(SNIHandle pConn, Provider ProvNum, [In] ref AuthProviderInfo pInfo) => s_nativeMethods.SniAddProvider(pConn, ProvNum, ref pInfo); #if NETFRAMEWORK [ResourceExposure(ResourceScope.None)] [ResourceConsumption(ResourceScope.Machine, ResourceScope.Machine)] - internal static uint SNIAddProvider(SNIHandle pConn, + internal static uint SniAddProvider(SNIHandle pConn, Provider providerEnum, AuthProviderInfo authInfo) { @@ -68,7 +68,7 @@ internal static uint SNIAddProvider(SNIHandle pConn, Debug.Assert(authInfo.clientCertificateCallback == null, "CTAIP support has been removed"); - ret = SNIAddProvider(pConn, providerEnum, ref authInfo); + ret = SniAddProvider(pConn, providerEnum, ref authInfo); if (ret == ERROR_SUCCESS) { @@ -81,13 +81,13 @@ internal static uint SNIAddProvider(SNIHandle pConn, } #endif - internal static uint SNIAddProvider(SNIHandle pConn, Provider ProvNum, [In] ref uint pInfo) => + internal static uint SniAddProvider(SNIHandle pConn, Provider ProvNum, [In] ref uint pInfo) => s_nativeMethods.SniAddProvider(pConn, ProvNum, ref pInfo); - internal static uint SNICheckConnection([In] SNIHandle pConn) => + internal static uint SniCheckConnection([In] SNIHandle pConn) => s_nativeMethods.SniCheckConnection(pConn); - internal static uint SNIClose(IntPtr pConn) => + internal static uint SniClose(IntPtr pConn) => s_nativeMethods.SniClose(pConn); internal static uint SniGetConnectionId(SNIHandle pConn, ref Guid connId) => @@ -113,7 +113,7 @@ internal static uint SniGetConnectionPort(SNIHandle pConn, ref ushort portNum) return s_nativeMethods.SniGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PEERPORT, out portNum); } - internal static void SNIGetLastError(out SniError pErrorStruct) => + internal static void SniGetLastError(out SniError pErrorStruct) => s_nativeMethods.SniGetLastError(out pErrorStruct); internal static uint SniGetProviderNumber(SNIHandle pConn, ref Provider provNum) @@ -121,13 +121,13 @@ internal static uint SniGetProviderNumber(SNIHandle pConn, ref Provider provNum) return s_nativeMethods.SniGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PROVIDERNUM, out provNum); } - internal static uint SNIInitialize() => + internal static uint SniInitialize() => s_nativeMethods.SniInitialize(IntPtr.Zero); - internal static uint UnmanagedIsTokenRestricted([In] IntPtr token, [MarshalAs(UnmanagedType.Bool)] out bool isRestricted) => + internal static uint SniIsTokenRestricted([In] IntPtr token, [MarshalAs(UnmanagedType.Bool)] out bool isRestricted) => s_nativeMethods.SniIsTokenRestricted(token, out isRestricted); - internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync, SqlConnectionIPAddressPreference ipPreference, SQLDNSInfo cachedDNSInfo) + internal static unsafe uint SniOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync, SqlConnectionIPAddressPreference ipPreference, SQLDNSInfo cachedDNSInfo) { // initialize consumer info for MARS SniConsumerInfo native_consumerInfo = new SniConsumerInfo(); @@ -142,7 +142,7 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan return s_nativeMethods.SniOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ipPreference, ref native_cachedDNSInfo); } - internal static unsafe uint SNIOpenSyncEx( + internal static unsafe uint SniOpenSyncEx( ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, @@ -268,23 +268,23 @@ internal static unsafe uint SNIOpenSyncEx( } } - internal static void SNIPacketAllocate(SafeHandle pConn, IoType IOType, ref IntPtr pPacket) => + internal static void SniPacketAllocate(SafeHandle pConn, IoType IOType, ref IntPtr pPacket) => pPacket = s_nativeMethods.SniPacketAllocateWrapper(pConn, IOType); - internal static unsafe uint SNIPacketGetData(IntPtr packet, byte[] readBuffer, ref uint dataSize) => + internal static unsafe uint SniPacketGetData(IntPtr packet, byte[] readBuffer, ref uint dataSize) => s_nativeMethods.SniPacketGetDataWrapper(packet, readBuffer, (uint)readBuffer.Length, out dataSize); - internal static void SNIPacketRelease(IntPtr pPacket) => + internal static void SniPacketRelease(IntPtr pPacket) => s_nativeMethods.SniPacketRelease(pPacket); - internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int length) + internal static unsafe void SniPacketSetData(SNIPacket packet, byte[] data, int length) { fixed (byte* pin_data = &data[0]) { s_nativeMethods.SniPacketSetData(packet, pin_data, (uint)length); } } - + #if NETFRAMEWORK //[ResourceExposure(ResourceScope::None)] // @@ -297,7 +297,7 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int // to loose encryption algorithm is changed it should be done in both in this method as well as TdsParserStaticMethods.EncryptPassword. // Up to current release, it is also guaranteed that both password and new change password will fit into a single login packet whose size is fixed to 4096 // So, there is no splitting logic is needed. - internal static void SNIPacketSetData(SNIPacket packet, + internal static void SniPacketSetData(SNIPacket packet, Byte[] data, Int32 length, SecureString[] passwords, // pointer to the passwords which need to be written out to SNI Packet @@ -385,7 +385,7 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w packet.DangerousAddRef(ref mustRelease); Debug.Assert(mustRelease, "AddRef Failed!"); - SNIPacketSetData(packet, data, length); + SniPacketSetData(packet, data, length); } } finally @@ -408,22 +408,25 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w } #endif - internal static void SNIPacketReset([In] SNIHandle pConn, IoType IOType, SNIPacket pPacket, ConsumerNumber ConsNum) => + internal static void SniPacketReset([In] SNIHandle pConn, IoType IOType, SNIPacket pPacket, ConsumerNumber ConsNum) => s_nativeMethods.SniPacketReset(pConn, IOType, pPacket, ConsNum); - internal static uint SNIQueryInfo(QueryType QType, ref uint pbQInfo) => + internal static uint SniQueryInfo(QueryType QType, ref uint pbQInfo) => s_nativeMethods.SniQueryInfo(QType, ref pbQInfo); - internal static uint SNIQueryInfo(QueryType QType, ref IntPtr pbQInfo) => + internal static uint SniQueryInfo(QueryType QType, ref IntPtr pbQInfo) => s_nativeMethods.SniQueryInfo(QType, ref pbQInfo); - internal static uint SNIReadAsync(SNIHandle pConn, ref IntPtr ppNewPacket) => + internal static uint SniReadAsync(SNIHandle pConn, ref IntPtr ppNewPacket) => s_nativeMethods.SniReadAsync(pConn, ref ppNewPacket); - internal static uint SNIReadSyncOverAsync(SNIHandle pConn, ref IntPtr ppNewPacket, int timeout) => + internal static uint SniReadSyncOverAsync(SNIHandle pConn, ref IntPtr ppNewPacket, int timeout) => s_nativeMethods.SniReadSyncOverAsync(pConn, ref ppNewPacket, timeout); - internal static unsafe uint SNISecGenClientContext( + internal static uint SniRemoveProvider(SNIHandle pConn, Provider ProvNum) => + s_nativeMethods.SniRemoveProvider(pConn, ProvNum); + + internal static unsafe uint SniSecGenClientContext( SNIHandle pConnectionObject, ReadOnlySpan inBuff, Span outBuff, @@ -459,32 +462,32 @@ internal static unsafe uint SNISecGenClientContext( } } - internal static uint SNISecInitPackage(ref uint pcbMaxToken) => + internal static uint SniSecInitPackage(ref uint pcbMaxToken) => s_nativeMethods.SniSecInitPackage(ref pcbMaxToken); - internal static void SNIServerEnumClose([In] IntPtr packet) => + internal static void SniServerEnumClose([In] IntPtr packet) => s_nativeMethods.SniServerEnumClose(packet); - internal static IntPtr SNIServerEnumOpen() => + internal static IntPtr SniServerEnumOpen() => s_nativeMethods.SniServerEnumOpen(); - internal static int SNIServerEnumRead( + internal static int SniServerEnumRead( [In] IntPtr packet, [In] [MarshalAs(UnmanagedType.LPArray)] char[] readBuffer, [In] int bufferLength, [MarshalAs(UnmanagedType.Bool)] out bool more) => s_nativeMethods.SniServerEnumRead(packet, readBuffer, bufferLength, out more); - internal static uint SNISetInfo(SNIHandle pConn, QueryType QType, [In] ref uint pbQInfo) => + internal static uint SniSetInfo(SNIHandle pConn, QueryType QType, [In] ref uint pbQInfo) => s_nativeMethods.SniSetInfo(pConn, QType, ref pbQInfo); - internal static uint SNITerminate() => + internal static uint SniTerminate() => s_nativeMethods.SniTerminate(); - internal static uint SNIWaitForSSLHandshakeToComplete([In] SNIHandle pConn, int dwMilliseconds, out uint pProtocolVersion) => + internal static uint SniWaitForSslHandshakeToComplete([In] SNIHandle pConn, int dwMilliseconds, out uint pProtocolVersion) => s_nativeMethods.SniWaitForSslHandshakeToComplete(pConn, dwMilliseconds, out pProtocolVersion); - internal static uint SNIWritePacket(SNIHandle pConn, SNIPacket packet, bool sync) + internal static uint SniWritePacket(SNIHandle pConn, SNIPacket packet, bool sync) { if (sync) { @@ -564,7 +567,7 @@ internal static class Win32NativeMethods internal static bool IsTokenRestrictedWrapper(IntPtr token) { bool isRestricted; - uint result = SniNativeWrapper.UnmanagedIsTokenRestricted(token, out isRestricted); + uint result = SniNativeWrapper.SniIsTokenRestricted(token, out isRestricted); if (result != 0) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Sql/SqlDataSourceEnumeratorNativeHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Sql/SqlDataSourceEnumeratorNativeHelper.cs index b99e91414a..138e671dc9 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Sql/SqlDataSourceEnumeratorNativeHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Sql/SqlDataSourceEnumeratorNativeHelper.cs @@ -50,23 +50,23 @@ internal static DataTable GetDataSources() { } finally { - handle = SniNativeWrapper.SNIServerEnumOpen(); + handle = SniNativeWrapper.SniServerEnumOpen(); SqlClientEventSource.Log.TryTraceEvent(" {2} returned handle = {3}.", nameof(SqlDataSourceEnumeratorNativeHelper), nameof(GetDataSources), - nameof(SniNativeWrapper.SNIServerEnumOpen), handle); + nameof(SniNativeWrapper.SniServerEnumOpen), handle); } if (handle != ADP.s_ptrZero) { while (more && !TdsParserStaticMethods.TimeoutHasExpired(s_timeoutTime)) { - readLength = SniNativeWrapper.SNIServerEnumRead(handle, buffer, bufferSize, out more); + readLength = SniNativeWrapper.SniServerEnumRead(handle, buffer, bufferSize, out more); SqlClientEventSource.Log.TryTraceEvent(" {2} returned 'readlength':{3}, and 'more':{4} with 'bufferSize' of {5}", nameof(SqlDataSourceEnumeratorNativeHelper), nameof(GetDataSources), - nameof(SniNativeWrapper.SNIServerEnumRead), + nameof(SniNativeWrapper.SniServerEnumRead), readLength, more, bufferSize); if (readLength > bufferSize) { @@ -84,21 +84,21 @@ internal static DataTable GetDataSources() { if (handle != ADP.s_ptrZero) { - SniNativeWrapper.SNIServerEnumClose(handle); + SniNativeWrapper.SniServerEnumClose(handle); SqlClientEventSource.Log.TryTraceEvent(" {2} called.", nameof(SqlDataSourceEnumeratorNativeHelper), nameof(GetDataSources), - nameof(SniNativeWrapper.SNIServerEnumClose)); + nameof(SniNativeWrapper.SniServerEnumClose)); } } if (failure) { - Debug.Assert(false, $"{nameof(GetDataSources)}:{nameof(SniNativeWrapper.SNIServerEnumRead)} returned bad length"); + Debug.Assert(false, $"{nameof(GetDataSources)}:{nameof(SniNativeWrapper.SniServerEnumRead)} returned bad length"); SqlClientEventSource.Log.TryTraceEvent(" {2} returned bad length, requested buffer {3}, received {4}", nameof(SqlDataSourceEnumeratorNativeHelper), nameof(GetDataSources), - nameof(SniNativeWrapper.SNIServerEnumRead), + nameof(SniNativeWrapper.SniServerEnumRead), bufferSize, readLength); throw ADP.ArgumentOutOfRange(StringsHelper.GetString(Strings.ADP_ParameterValueOutOfRange, readLength), nameof(readLength)); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs index 0a2fa8aeb7..621ec5b4cc 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs @@ -34,7 +34,7 @@ private void LoadSSPILibrary() // use local for ref param to defer setting s_maxSSPILength until we know the call succeeded. uint maxLength = 0; - if (0 != SniNativeWrapper.SNISecInitPackage(ref maxLength)) + if (0 != SniNativeWrapper.SniSecInitPackage(ref maxLength)) SSPIError(SQLMessage.SSPIInitializeError(), TdsEnums.INIT_SSPI_PACKAGE); s_maxSSPILength = maxLength; @@ -62,7 +62,7 @@ protected override void GenerateSspiClientContext(ReadOnlySpan incomingBlo var sendLength = s_maxSSPILength; var outBuff = outgoingBlobWriter.GetSpan((int)sendLength); - if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, serverSpns[0])) + if (0 != SniNativeWrapper.SniSecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, serverSpns[0])) { throw new InvalidOperationException(SQLMessage.SSPIGenerateError()); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs index 2e8d96857b..f2f9143ed7 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs @@ -36,7 +36,7 @@ private SNILoadHandle() : base(IntPtr.Zero, true) { } finally { - _sniStatus = SniNativeWrapper.SNIInitialize(); + _sniStatus = SniNativeWrapper.SniInitialize(); base.handle = (IntPtr)1; // Initialize to non-zero dummy variable. } } @@ -57,7 +57,7 @@ public bool ClientOSEncryptionSupport { uint value = 0; // Query OS to find out whether encryption is supported. - SniNativeWrapper.SNIQueryInfo(QueryType.SNI_QUERY_CLIENT_ENCRYPT_POSSIBLE, ref value); + SniNativeWrapper.SniQueryInfo(QueryType.SNI_QUERY_CLIENT_ENCRYPT_POSSIBLE, ref value); _clientOSEncryptionSupport = value != 0; } catch (Exception e) @@ -79,7 +79,7 @@ override protected bool ReleaseHandle() if (TdsEnums.SNI_SUCCESS == _sniStatus) { LocalDbApi.ReleaseDllHandles(); - SniNativeWrapper.SNITerminate(); + SniNativeWrapper.SniTerminate(); } base.handle = IntPtr.Zero; } @@ -186,7 +186,7 @@ internal SNIHandle( #if NETFRAMEWORK int transparentNetworkResolutionStateNo = (int)transparentNetworkResolutionState; - _status = SniNativeWrapper.SNIOpenSyncEx( + _status = SniNativeWrapper.SniOpenSyncEx( myInfo, serverName, ref base.handle, @@ -202,7 +202,7 @@ internal SNIHandle( cachedDNSInfo, hostNameInCertificate); #else - _status = SniNativeWrapper.SNIOpenSyncEx( + _status = SniNativeWrapper.SniOpenSyncEx( myInfo, serverName, ref base.handle, @@ -226,7 +226,7 @@ internal SNIHandle(ConsumerInfo myInfo, SNIHandle parent, SqlConnectionIPAddress { } finally { - _status = SniNativeWrapper.SNIOpenMarsSession(myInfo, parent, ref base.handle, parent._fSync, ipPreference, cachedDNSInfo); + _status = SniNativeWrapper.SniOpenMarsSession(myInfo, parent, ref base.handle, parent._fSync, ipPreference, cachedDNSInfo); } } @@ -245,7 +245,7 @@ override protected bool ReleaseHandle() base.handle = IntPtr.Zero; if (IntPtr.Zero != ptr) { - if (0 != SniNativeWrapper.SNIClose(ptr)) + if (0 != SniNativeWrapper.SniClose(ptr)) { return false; // SNIClose should never fail. } @@ -266,7 +266,7 @@ internal sealed class SNIPacket : SafeHandle { internal SNIPacket(SafeHandle sniHandle) : base(IntPtr.Zero, true) { - SniNativeWrapper.SNIPacketAllocate(sniHandle, IoType.WRITE, ref base.handle); + SniNativeWrapper.SniPacketAllocate(sniHandle, IoType.WRITE, ref base.handle); if (IntPtr.Zero == base.handle) { throw SQL.SNIPacketAllocationFailure(); @@ -288,7 +288,7 @@ override protected bool ReleaseHandle() base.handle = IntPtr.Zero; if (IntPtr.Zero != ptr) { - SniNativeWrapper.SNIPacketRelease(ptr); + SniNativeWrapper.SniPacketRelease(ptr); } return true; } @@ -312,7 +312,7 @@ public SNIPacket Take(SNIHandle sniHandle) { // Success - reset the packet packet = _packets.Pop(); - SniNativeWrapper.SNIPacketReset(sniHandle, IoType.WRITE, packet, ConsumerNumber.SNI_Consumer_SNI); + SniNativeWrapper.SniPacketReset(sniHandle, IoType.WRITE, packet, ConsumerNumber.SNI_Consumer_SNI); } else { From 992e82e0a18faa6c0afe284ed7c785befc4c5428 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 14 Nov 2024 12:43:52 -0600 Subject: [PATCH 06/12] Cleanup suggestions and follow naming conventions --- .../SqlClient/TdsParserStateObjectNative.cs | 2 +- .../Data/SqlClient/TdsParser.netfx.cs | 2 +- .../Interop/Windows/Sni/SniNativeWrapper.cs | 325 +++++++++--------- .../src/Interop/Windows/SystemErrors.cs | 4 +- 4 files changed, 166 insertions(+), 167 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index 043b6f3a5e..390ca39cd2 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -99,7 +99,7 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache result = SniNativeWrapper.SniGetConnectionPort(Handle, ref portFromSNI); Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionPort"); - result = SniNativeWrapper.SniGetConnectionIPString(Handle, ref IPStringFromSNI); + result = SniNativeWrapper.SniGetConnectionIpString(Handle, ref IPStringFromSNI); Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString"); pendingDNSInfo = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString()); diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.netfx.cs index 7150aadfd2..690d0351ec 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.netfx.cs @@ -75,7 +75,7 @@ internal void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey) Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionPort"); - result = SniNativeWrapper.SniGetConnectionIPString(_physicalStateObj.Handle, ref IPStringFromSNI); + result = SniNativeWrapper.SniGetConnectionIpString(_physicalStateObj.Handle, ref IPStringFromSNI); Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString"); _connHandler.pendingSQLDNSObject = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString()); diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index d91b441baa..2a1a87ae98 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -7,13 +7,19 @@ using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; -using System.Runtime.Versioning; -using System.Security; using System.Text; using Interop.Windows.Sni; using Microsoft.Data.Common; using Microsoft.Data.SqlClient; +#if NETFRAMEWORK +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.Versioning; +using System.Security; +using Interop.Windows; +#endif + namespace Microsoft.Data.SqlClient { internal static class SniNativeWrapper @@ -21,7 +27,10 @@ internal static class SniNativeWrapper #region Member Variables private const int SniIpv6AddrStringBufferLength = 48; + + #if NET private const int SniOpenTimeOut = -1; + #endif #if NETFRAMEWORK private static readonly ISniNativeMethods s_nativeMethods = RuntimeInformation.ProcessArchitecture switch @@ -53,8 +62,8 @@ internal static int SniMaxComposedSpnLength #region Public Methods - internal static uint SniAddProvider(SNIHandle pConn, Provider ProvNum, [In] ref AuthProviderInfo pInfo) => - s_nativeMethods.SniAddProvider(pConn, ProvNum, ref pInfo); + internal static uint SniAddProvider(SNIHandle pConn, Provider provNum, ref AuthProviderInfo pInfo) => + s_nativeMethods.SniAddProvider(pConn, provNum, ref pInfo); #if NETFRAMEWORK [ResourceExposure(ResourceScope.None)] @@ -63,28 +72,24 @@ internal static uint SniAddProvider(SNIHandle pConn, Provider providerEnum, AuthProviderInfo authInfo) { - UInt32 ret; - uint ERROR_SUCCESS = 0; - Debug.Assert(authInfo.clientCertificateCallback == null, "CTAIP support has been removed"); - ret = SniAddProvider(pConn, providerEnum, ref authInfo); - - if (ret == ERROR_SUCCESS) + uint ret = SniAddProvider(pConn, providerEnum, ref authInfo); + if (ret == SystemErrors.ERROR_SUCCESS) { // added a provider, need to requery for sync over async support ret = s_nativeMethods.SniGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_SUPPORTS_SYNC_OVER_ASYNC, out bool _); - Debug.Assert(ret == ERROR_SUCCESS, "SNIGetInfo cannot fail with this QType"); + Debug.Assert(ret == SystemErrors.ERROR_SUCCESS, "SNIGetInfo cannot fail with this QType"); } return ret; } #endif - internal static uint SniAddProvider(SNIHandle pConn, Provider ProvNum, [In] ref uint pInfo) => - s_nativeMethods.SniAddProvider(pConn, ProvNum, ref pInfo); + internal static uint SniAddProvider(SNIHandle pConn, Provider provNum, ref uint pInfo) => + s_nativeMethods.SniAddProvider(pConn, provNum, ref pInfo); - internal static uint SniCheckConnection([In] SNIHandle pConn) => + internal static uint SniCheckConnection(SNIHandle pConn) => s_nativeMethods.SniCheckConnection(pConn); internal static uint SniClose(IntPtr pConn) => @@ -93,58 +98,69 @@ internal static uint SniClose(IntPtr pConn) => internal static uint SniGetConnectionId(SNIHandle pConn, ref Guid connId) => s_nativeMethods.SniGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_CONNID, out connId); - internal static uint SniGetConnectionIPString(SNIHandle pConn, ref string connIPStr) + internal static uint SniGetConnectionIpString(SNIHandle pConn, ref string connIpStr) { - UInt32 ret; - uint connIPLen = 0; + StringBuilder addrBuffer = new StringBuilder(SniIpv6AddrStringBufferLength); - int bufferSize = SniIpv6AddrStringBufferLength; - StringBuilder addrBuffer = new StringBuilder(bufferSize); + uint ret = s_nativeMethods.SniGetPeerAddrStrWrapper( + pConn, + SniIpv6AddrStringBufferLength, + addrBuffer, + out uint connIpLen); - ret = s_nativeMethods.SniGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out connIPLen); - - connIPStr = addrBuffer.ToString(0, Convert.ToInt32(connIPLen)); + connIpStr = addrBuffer.ToString(0, Convert.ToInt32(connIpLen)); return ret; } - internal static uint SniGetConnectionPort(SNIHandle pConn, ref ushort portNum) - { - return s_nativeMethods.SniGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PEERPORT, out portNum); - } + internal static uint SniGetConnectionPort(SNIHandle pConn, ref ushort portNum) => + s_nativeMethods.SniGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PEERPORT, out portNum); internal static void SniGetLastError(out SniError pErrorStruct) => s_nativeMethods.SniGetLastError(out pErrorStruct); - internal static uint SniGetProviderNumber(SNIHandle pConn, ref Provider provNum) - { - return s_nativeMethods.SniGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PROVIDERNUM, out provNum); - } + internal static uint SniGetProviderNumber(SNIHandle pConn, ref Provider provNum) => + s_nativeMethods.SniGetInfoWrapper(pConn, QueryType.SNI_QUERY_CONN_PROVIDERNUM, out provNum); internal static uint SniInitialize() => s_nativeMethods.SniInitialize(IntPtr.Zero); - internal static uint SniIsTokenRestricted([In] IntPtr token, [MarshalAs(UnmanagedType.Bool)] out bool isRestricted) => + internal static uint SniIsTokenRestricted(IntPtr token, out bool isRestricted) => s_nativeMethods.SniIsTokenRestricted(token, out isRestricted); - internal static unsafe uint SniOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync, SqlConnectionIPAddressPreference ipPreference, SQLDNSInfo cachedDNSInfo) + internal static uint SniOpenMarsSession( + ConsumerInfo consumerInfo, + SNIHandle parent, + ref IntPtr pConn, + bool fSync, + SqlConnectionIPAddressPreference ipPreference, + SQLDNSInfo cachedDnsInfo) { // initialize consumer info for MARS - SniConsumerInfo native_consumerInfo = new SniConsumerInfo(); - MarshalConsumerInfo(consumerInfo, ref native_consumerInfo); - - SniDnsCacheInfo native_cachedDNSInfo = new SniDnsCacheInfo(); - native_cachedDNSInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; - native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; - native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + SniConsumerInfo nativeConsumerInfo = new SniConsumerInfo(); + MarshalConsumerInfo(consumerInfo, ref nativeConsumerInfo); - return s_nativeMethods.SniOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ipPreference, ref native_cachedDNSInfo); + SniDnsCacheInfo nativeCachedDnsInfo = new SniDnsCacheInfo() + { + wszCachedFQDN = cachedDnsInfo?.FQDN, + wszCachedTcpIPv4 = cachedDnsInfo?.AddrIPv4, + wszCachedTcpIPv6 = cachedDnsInfo?.AddrIPv6, + wszCachedTcpPort = cachedDnsInfo?.Port, + }; + + return s_nativeMethods.SniOpenWrapper( + pConsumerInfo: ref nativeConsumerInfo, + connect: "session:", + pConn: parent, + ppConn: out pConn, + fSync, + ipPreference, + pDnsCacheInfo: ref nativeCachedDnsInfo); } internal static unsafe uint SniOpenSyncEx( ConsumerInfo consumerInfo, - string constring, + string connString, ref IntPtr pConn, ref string spn, byte[] instanceName, @@ -154,25 +170,25 @@ internal static unsafe uint SniOpenSyncEx( bool fParallel, #if NETFRAMEWORK - Int32 transparentNetworkResolutionStateNo, - Int32 totalTimeout, + int transparentNetworkResolutionStateNo, + int totalTimeout, #endif SqlConnectionIPAddressPreference ipPreference, - SQLDNSInfo cachedDNSInfo, + SQLDNSInfo cachedDnsInfo, string hostNameInCertificate) { - fixed (byte* pin_instanceName = &instanceName[0]) + fixed (byte* pInstanceName = instanceName) { SniClientConsumerInfo clientConsumerInfo = new SniClientConsumerInfo(); // initialize client ConsumerInfo part first MarshalConsumerInfo(consumerInfo, ref clientConsumerInfo.ConsumerInfo); - clientConsumerInfo.wszConnectionString = constring; + clientConsumerInfo.wszConnectionString = connString; clientConsumerInfo.HostNameInCertificate = hostNameInCertificate; clientConsumerInfo.networkLibrary = Prefix.UNKNOWN_PREFIX; - clientConsumerInfo.szInstanceName = pin_instanceName; + clientConsumerInfo.szInstanceName = pInstanceName; clientConsumerInfo.cchInstanceName = (uint)instanceName.Length; clientConsumerInfo.fOverrideLastConnectCache = fOverrideCache; clientConsumerInfo.fSynchronousConnection = fSync; @@ -182,13 +198,13 @@ internal static unsafe uint SniOpenSyncEx( #if NETFRAMEWORK switch (transparentNetworkResolutionStateNo) { - case (0): + case 0: clientConsumerInfo.transparentNetworkResolution = TransparentNetworkResolutionMode.DisabledMode; break; - case (1): + case 1: clientConsumerInfo.transparentNetworkResolution = TransparentNetworkResolutionMode.SequentialMode; break; - case (2): + case 2: clientConsumerInfo.transparentNetworkResolution = TransparentNetworkResolutionMode.ParallelMode; break; }; @@ -198,15 +214,15 @@ internal static unsafe uint SniOpenSyncEx( clientConsumerInfo.totalTimeout = SniOpenTimeOut; #endif - clientConsumerInfo.isAzureSqlServerEndpoint = ADP.IsAzureSqlServerEndpoint(constring); + clientConsumerInfo.isAzureSqlServerEndpoint = ADP.IsAzureSqlServerEndpoint(connString); clientConsumerInfo.ipAddressPreference = ipPreference; - clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; - clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDnsInfo?.FQDN; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDnsInfo?.AddrIPv4; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDnsInfo?.AddrIPv6; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDnsInfo?.Port; - if (spn != null) + if (spn is not null) { if (spn.Length == 0) { @@ -260,18 +276,16 @@ internal static unsafe uint SniOpenSyncEx( } } } - else - { - // else leave szSPN null (SQL Auth) - return s_nativeMethods.SniOpenSyncExWrapper(ref clientConsumerInfo, out pConn); - } + + // Otherwise leave szSPN null (SQL Auth) + return s_nativeMethods.SniOpenSyncExWrapper(ref clientConsumerInfo, out pConn); } } - internal static void SniPacketAllocate(SafeHandle pConn, IoType IOType, ref IntPtr pPacket) => - pPacket = s_nativeMethods.SniPacketAllocateWrapper(pConn, IOType); + internal static void SniPacketAllocate(SafeHandle pConn, IoType ioType, ref IntPtr pPacket) => + pPacket = s_nativeMethods.SniPacketAllocateWrapper(pConn, ioType); - internal static unsafe uint SniPacketGetData(IntPtr packet, byte[] readBuffer, ref uint dataSize) => + internal static uint SniPacketGetData(IntPtr packet, byte[] readBuffer, ref uint dataSize) => s_nativeMethods.SniPacketGetDataWrapper(packet, readBuffer, (uint)readBuffer.Length, out dataSize); internal static void SniPacketRelease(IntPtr pPacket) => @@ -279,46 +293,47 @@ internal static void SniPacketRelease(IntPtr pPacket) => internal static unsafe void SniPacketSetData(SNIPacket packet, byte[] data, int length) { - fixed (byte* pin_data = &data[0]) + fixed (byte* pData = data) { - s_nativeMethods.SniPacketSetData(packet, pin_data, (uint)length); + s_nativeMethods.SniPacketSetData(packet, pData, (uint)length); } } #if NETFRAMEWORK - //[ResourceExposure(ResourceScope::None)] - // - // Notes on SecureString: Writing out security sensitive information to managed buffer should be avoid as these can be moved - // around by GC. There are two set of information which falls into this category: passwords and new changed password which - // are passed in as SecureString by a user. Writing out clear passwords information is delayed until this layer to ensure that - // the information is written out to buffer which is pinned in this method already. This also ensures that processing a clear password - // is done right before it is written out to SNI_Packet where gets encrypted properly. - // TdsParserStaticMethods.EncryptPassword operation is also done here to minimize the time the clear password is held in memory. Any changes - // to loose encryption algorithm is changed it should be done in both in this method as well as TdsParserStaticMethods.EncryptPassword. - // Up to current release, it is also guaranteed that both password and new change password will fit into a single login packet whose size is fixed to 4096 - // So, there is no splitting logic is needed. - internal static void SniPacketSetData(SNIPacket packet, - Byte[] data, - Int32 length, - SecureString[] passwords, // pointer to the passwords which need to be written out to SNI Packet - Int32[] passwordOffsets // Offset into data buffer where the password to be written out to - ) + // Notes on SecureString: Writing out security sensitive information to managed buffer + // should be avoided as these can be moved around by GC. There are two set of + // information which falls into this category: passwords and new changed password which + // are passed in as SecureString by a user. Writing out clear passwords information is + // delayed until this layer to ensure that the information is written out to buffer + // which is pinned in this method already. This also ensures that processing a clear + // password is done right before it is written out to SNI_Packet where gets encrypted + // properly. TdsParserStaticMethods.EncryptPassword operation is also done here to + // minimize the time the clear password is held in memory. Any time loose encryption + // algorithms are changed it should be done in both in this method and + // TdsParserStaticMethods.EncryptPassword. + // Up to current release, it is also guaranteed that both password and new change + // password will fit into a single login packet whose size is fixed to 4096 So, no + // splitting logic is needed. + internal static void SniPacketSetData( + SNIPacket packet, + byte[] data, + int length, + SecureString[] passwords, // pointer to the passwords which need to be written out to SNI Packet + int[] passwordOffsets) // Offset into data buffer where the password to be written out to { - Debug.Assert(passwords == null || (passwordOffsets != null && passwords.Length == passwordOffsets.Length), "The number of passwords does not match the number of password offsets"); + Debug.Assert(passwords is null || (passwordOffsets is not null && passwords.Length == passwordOffsets.Length), "The number of passwords does not match the number of password offsets"); bool mustRelease = false; bool mustClearBuffer = false; IntPtr clearPassword = IntPtr.Zero; - // provides a guaranteed finally block – without this it isn’t guaranteed – non interruptable by fatal exceptions + // provides a guaranteed finally block – without this it isn’t guaranteed – non- + // interruptible by fatal exceptions RuntimeHelpers.PrepareConstrainedRegions(); try { unsafe { - - fixed (byte* pin_data = &data[0]) - { } if (passwords != null) { // Process SecureString @@ -327,48 +342,40 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w // SecureString is used if (passwords[i] != null) { - // provides a guaranteed finally block – without this it isn’t guaranteed – non interruptable by fatal exceptions + // provides a guaranteed finally block – without this it isn’t + // guaranteed – non-interruptible by fatal exceptions RuntimeHelpers.PrepareConstrainedRegions(); try { - // ========================================================================== - // Get the clear text of secure string without converting it to String type - // ========================================================================== + // ============================================================ + // Get the clear text of secure string without converting it + // to string type + // ============================================================ clearPassword = Marshal.SecureStringToCoTaskMemUnicode(passwords[i]); - // ========================================================================================================================== - // Losely encrypt the clear text - The encryption algorithm should exactly match the TdsParserStaticMethods.EncryptPassword - // ========================================================================================================================== + // ============================================================ + // Loosely encrypt the clear text - The encryption algorithm + // should exactly match the TdsParserStaticMethods.EncryptPassword + // ============================================================ + char* pwChar = (char*)clearPassword.ToPointer(); + byte* pByte = (byte*)clearPassword.ToPointer(); - unsafe + int passwordsLength = passwords[i].Length; + for (int j = 0; j < passwordsLength; ++j) { - - char* pwChar = (char*)clearPassword.ToPointer(); - byte* pByte = (byte*)(clearPassword.ToPointer()); - - - - - int s; - byte bLo; - byte bHi; - int passwordsLength = passwords[i].Length; - for (int j = 0; j < passwordsLength; ++j) - { - s = (int)*pwChar; - bLo = (byte)(s & 0xff); - bHi = (byte)((s >> 8) & 0xff); - *(pByte++) = (Byte)((((bLo & 0x0f) << 4) | (bLo >> 4)) ^ 0xa5); - *(pByte++) = (Byte)((((bHi & 0x0f) << 4) | (bHi >> 4)) ^ 0xa5); - ++pwChar; - } - - // =========================================================== - // Write out the losely encrypted passwords to data buffer - // =========================================================== - mustClearBuffer = true; - Marshal.Copy(clearPassword, data, passwordOffsets[i], passwordsLength * 2); + int s = *pwChar; + byte bLo = (byte)(s & 0xff); + byte bHi = (byte)((s >> 8) & 0xff); + *(pByte++) = (byte)((((bLo & 0x0f) << 4) | (bLo >> 4)) ^ 0xa5); + *(pByte++) = (byte)((((bHi & 0x0f) << 4) | (bHi >> 4)) ^ 0xa5); + ++pwChar; } + + // ============================================================ + // Write out the loosely encrypted passwords to data buffer + // ============================================================ + mustClearBuffer = true; + Marshal.Copy(clearPassword, data, passwordOffsets[i], passwordsLength * 2); } finally { @@ -408,14 +415,14 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w } #endif - internal static void SniPacketReset([In] SNIHandle pConn, IoType IOType, SNIPacket pPacket, ConsumerNumber ConsNum) => - s_nativeMethods.SniPacketReset(pConn, IOType, pPacket, ConsNum); + internal static void SniPacketReset(SNIHandle pConn, IoType ioType, SNIPacket pPacket, ConsumerNumber consNum) => + s_nativeMethods.SniPacketReset(pConn, ioType, pPacket, consNum); - internal static uint SniQueryInfo(QueryType QType, ref uint pbQInfo) => - s_nativeMethods.SniQueryInfo(QType, ref pbQInfo); + internal static uint SniQueryInfo(QueryType qType, ref uint pbQInfo) => + s_nativeMethods.SniQueryInfo(qType, ref pbQInfo); - internal static uint SniQueryInfo(QueryType QType, ref IntPtr pbQInfo) => - s_nativeMethods.SniQueryInfo(QType, ref pbQInfo); + internal static uint SniQueryInfo(QueryType qType, ref IntPtr pbQInfo) => + s_nativeMethods.SniQueryInfo(qType, ref pbQInfo); internal static uint SniReadAsync(SNIHandle pConn, ref IntPtr ppNewPacket) => s_nativeMethods.SniReadAsync(pConn, ref ppNewPacket); @@ -423,8 +430,8 @@ internal static uint SniReadAsync(SNIHandle pConn, ref IntPtr ppNewPacket) => internal static uint SniReadSyncOverAsync(SNIHandle pConn, ref IntPtr ppNewPacket, int timeout) => s_nativeMethods.SniReadSyncOverAsync(pConn, ref ppNewPacket, timeout); - internal static uint SniRemoveProvider(SNIHandle pConn, Provider ProvNum) => - s_nativeMethods.SniRemoveProvider(pConn, ProvNum); + internal static uint SniRemoveProvider(SNIHandle pConn, Provider provNum) => + s_nativeMethods.SniRemoveProvider(pConn, provNum); internal static unsafe uint SniSecGenClientContext( SNIHandle pConnectionObject, @@ -465,54 +472,46 @@ internal static unsafe uint SniSecGenClientContext( internal static uint SniSecInitPackage(ref uint pcbMaxToken) => s_nativeMethods.SniSecInitPackage(ref pcbMaxToken); - internal static void SniServerEnumClose([In] IntPtr packet) => + internal static void SniServerEnumClose(IntPtr packet) => s_nativeMethods.SniServerEnumClose(packet); internal static IntPtr SniServerEnumOpen() => s_nativeMethods.SniServerEnumOpen(); - internal static int SniServerEnumRead( - [In] IntPtr packet, - [In] [MarshalAs(UnmanagedType.LPArray)] char[] readBuffer, - [In] int bufferLength, - [MarshalAs(UnmanagedType.Bool)] out bool more) => + internal static int SniServerEnumRead(IntPtr packet, char[] readBuffer, int bufferLength, out bool more) => s_nativeMethods.SniServerEnumRead(packet, readBuffer, bufferLength, out more); - internal static uint SniSetInfo(SNIHandle pConn, QueryType QType, [In] ref uint pbQInfo) => - s_nativeMethods.SniSetInfo(pConn, QType, ref pbQInfo); + internal static uint SniSetInfo(SNIHandle pConn, QueryType qType, ref uint pbQInfo) => + s_nativeMethods.SniSetInfo(pConn, qType, ref pbQInfo); internal static uint SniTerminate() => s_nativeMethods.SniTerminate(); - internal static uint SniWaitForSslHandshakeToComplete([In] SNIHandle pConn, int dwMilliseconds, out uint pProtocolVersion) => + internal static uint SniWaitForSslHandshakeToComplete( + SNIHandle pConn, + int dwMilliseconds, + out uint pProtocolVersion) => s_nativeMethods.SniWaitForSslHandshakeToComplete(pConn, dwMilliseconds, out pProtocolVersion); - - internal static uint SniWritePacket(SNIHandle pConn, SNIPacket packet, bool sync) - { - if (sync) - { - return s_nativeMethods.SniWriteSyncOverAsync(pConn, packet); - } - else - { - return s_nativeMethods.SniWriteAsyncWrapper(pConn, packet); - } - } + + internal static uint SniWritePacket(SNIHandle pConn, SNIPacket packet, bool sync) => + sync + ? s_nativeMethods.SniWriteSyncOverAsync(pConn, packet) + : s_nativeMethods.SniWriteAsyncWrapper(pConn, packet); #endregion #region Private Methods - private static void MarshalConsumerInfo(ConsumerInfo consumerInfo, ref SniConsumerInfo native_consumerInfo) + private static void MarshalConsumerInfo(ConsumerInfo consumerInfo, ref SniConsumerInfo nativeConsumerInfo) { - native_consumerInfo.DefaultUserDataLength = consumerInfo.defaultBufferSize; - native_consumerInfo.fnReadComp = consumerInfo.readDelegate != null + nativeConsumerInfo.DefaultUserDataLength = consumerInfo.defaultBufferSize; + nativeConsumerInfo.fnReadComp = consumerInfo.readDelegate is not null ? Marshal.GetFunctionPointerForDelegate(consumerInfo.readDelegate) : IntPtr.Zero; - native_consumerInfo.fnWriteComp = consumerInfo.writeDelegate != null + nativeConsumerInfo.fnWriteComp = consumerInfo.writeDelegate is not null ? Marshal.GetFunctionPointerForDelegate(consumerInfo.writeDelegate) : IntPtr.Zero; - native_consumerInfo.ConsumerKey = consumerInfo.key; + nativeConsumerInfo.ConsumerKey = consumerInfo.key; } #endregion @@ -531,10 +530,9 @@ internal static _AppDomain GetDefaultAppDomain() [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] - internal unsafe static byte[] GetData() + internal static unsafe byte[] GetData() { - int size; - IntPtr ptr = (IntPtr)(SqlDependencyProcessDispatcherStorage.NativeGetData(out size)); + IntPtr ptr = (IntPtr)SqlDependencyProcessDispatcherStorage.NativeGetData(out int size); byte[] result = null; if (ptr != IntPtr.Zero) @@ -548,12 +546,11 @@ internal unsafe static byte[] GetData() [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] - internal unsafe static void SetData(Byte[] data) + internal static unsafe void SetData(byte[] data) { - //cli::pin_ptr pin_dispatcher = &data[0]; - fixed (byte* pin_dispatcher = &data[0]) + fixed (byte* pDispatcher = data) { - SqlDependencyProcessDispatcherStorage.NativeSetData(pin_dispatcher, data.Length); + SqlDependencyProcessDispatcherStorage.NativeSetData(pDispatcher, data.Length); } } #endif diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/SystemErrors.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/SystemErrors.cs index 6a2edef310..40819ef4d9 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/SystemErrors.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/SystemErrors.cs @@ -5,8 +5,10 @@ namespace Interop.Windows { // https://msdn.microsoft.com/en-us/library/windows/desktop/ms681382.aspx - internal partial class SystemErrors + internal class SystemErrors { + internal const int ERROR_SUCCESS = 0x00; + internal const int ERROR_FILE_NOT_FOUND = 0x2; internal const int ERROR_INVALID_HANDLE = 0x6; internal const int ERROR_SHARING_VIOLATION = 0x20; From 635f6f9bac7a722223afcd9accd29a8f78b485d6 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 14 Nov 2024 13:21:44 -0600 Subject: [PATCH 07/12] Remove AppDomain method from SniNativeWrapper --- .../Interop/Windows/Sni/SniNativeWrapper.cs | 11 ----- .../Microsoft/Data/SqlClient/SqlDependency.cs | 45 ++++++++----------- 2 files changed, 18 insertions(+), 38 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index 2a1a87ae98..bf118dd0d6 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -516,18 +516,7 @@ private static void MarshalConsumerInfo(ConsumerInfo consumerInfo, ref SniConsum #endregion - #if NETFRAMEWORK - static AppDomain GetDefaultAppDomainInternal() - { - return AppDomain.CurrentDomain; - } - - internal static _AppDomain GetDefaultAppDomain() - { - return GetDefaultAppDomainInternal(); - } - [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] internal static unsafe byte[] GetData() diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDependency.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDependency.cs index d2fdd6ef0e..787cbedac8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDependency.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDependency.cs @@ -473,47 +473,38 @@ private static void ObtainProcessDispatcher() #if DEBUG // Possibly expensive, limit to debug. SqlClientEventSource.Log.TryNotificationTraceEvent(" AppDomain.CurrentDomain.FriendlyName: {0}", AppDomain.CurrentDomain.FriendlyName); - #endif // DEBUG - _AppDomain masterDomain = SniNativeWrapper.GetDefaultAppDomain(); - - if (masterDomain != null) + + _AppDomain masterDomain = AppDomain.CurrentDomain; + + ObjectHandle handle = CreateProcessDispatcher(masterDomain); + if (handle != null) { - ObjectHandle handle = CreateProcessDispatcher(masterDomain); + SqlDependencyProcessDispatcher dependency = (SqlDependencyProcessDispatcher)handle.Unwrap(); - if (handle != null) + if (dependency != null) { - SqlDependencyProcessDispatcher dependency = (SqlDependencyProcessDispatcher)handle.Unwrap(); - - if (dependency != null) - { - s_processDispatcher = SqlDependencyProcessDispatcher.SingletonProcessDispatcher; // Set to static instance. + s_processDispatcher = SqlDependencyProcessDispatcher.SingletonProcessDispatcher; // Set to static instance. - // Serialize and set in native. - using (MemoryStream stream = new()) - { - SqlClientObjRef objRef = new(s_processDispatcher); - DataContractSerializer serializer = new(objRef.GetType()); - GetSerializedObject(objRef, serializer, stream); - SniNativeWrapper.SetData(stream.ToArray()); // Native will be forced to synchronize and not overwrite. - } - } - else + // Serialize and set in native. + using (MemoryStream stream = new()) { - SqlClientEventSource.Log.TryNotificationTraceEvent(" ERROR - ObjectHandle.Unwrap returned null!"); - throw ADP.InternalError(ADP.InternalErrorCode.SqlDependencyObtainProcessDispatcherFailureObjectHandle); + SqlClientObjRef objRef = new(s_processDispatcher); + DataContractSerializer serializer = new(objRef.GetType()); + GetSerializedObject(objRef, serializer, stream); + SniNativeWrapper.SetData(stream.ToArray()); // Native will be forced to synchronize and not overwrite. } } else { - SqlClientEventSource.Log.TryNotificationTraceEvent(" ERROR - AppDomain.CreateInstance returned null!"); - throw ADP.InternalError(ADP.InternalErrorCode.SqlDependencyProcessDispatcherFailureCreateInstance); + SqlClientEventSource.Log.TryNotificationTraceEvent(" ERROR - ObjectHandle.Unwrap returned null!"); + throw ADP.InternalError(ADP.InternalErrorCode.SqlDependencyObtainProcessDispatcherFailureObjectHandle); } } else { - SqlClientEventSource.Log.TryNotificationTraceEvent(" ERROR - unable to obtain default AppDomain!"); - throw ADP.InternalError(ADP.InternalErrorCode.SqlDependencyProcessDispatcherFailureAppDomain); + SqlClientEventSource.Log.TryNotificationTraceEvent(" ERROR - AppDomain.CreateInstance returned null!"); + throw ADP.InternalError(ADP.InternalErrorCode.SqlDependencyProcessDispatcherFailureCreateInstance); } } else From 476aad33b8a00b08afe0c382b3dab988c91985ef Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 14 Nov 2024 13:30:12 -0600 Subject: [PATCH 08/12] Remove Win32NativeMethods --- .../Interop/Windows/Sni/SniNativeWrapper.cs | 34 ++++++------------- .../DbConnectionPoolIdentity.Windows.cs | 7 ++-- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index bf118dd0d6..e4305f982d 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -10,7 +10,6 @@ using System.Text; using Interop.Windows.Sni; using Microsoft.Data.Common; -using Microsoft.Data.SqlClient; #if NETFRAMEWORK using System.Diagnostics; @@ -124,9 +123,17 @@ internal static uint SniGetProviderNumber(SNIHandle pConn, ref Provider provNum) internal static uint SniInitialize() => s_nativeMethods.SniInitialize(IntPtr.Zero); - - internal static uint SniIsTokenRestricted(IntPtr token, out bool isRestricted) => - s_nativeMethods.SniIsTokenRestricted(token, out isRestricted); + + internal static uint SniIsTokenRestricted(IntPtr token, out bool isRestricted) + { + uint result = s_nativeMethods.SniIsTokenRestricted(token, out isRestricted); + if (result != 0) + { + Marshal.ThrowExceptionForHR(unchecked((int)result)); + } + + return result; + } internal static uint SniOpenMarsSession( ConsumerInfo consumerInfo, @@ -545,22 +552,3 @@ internal static unsafe void SetData(byte[] data) #endif } } - -namespace Microsoft.Data -{ - internal static class Win32NativeMethods - { - internal static bool IsTokenRestrictedWrapper(IntPtr token) - { - bool isRestricted; - uint result = SniNativeWrapper.SniIsTokenRestricted(token, out isRestricted); - - if (result != 0) - { - Marshal.ThrowExceptionForHR(unchecked((int)result)); - } - - return isRestricted; - } - } -} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/DbConnectionPoolIdentity.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/DbConnectionPoolIdentity.Windows.cs index 175d4c8595..99783ff3c7 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/DbConnectionPoolIdentity.Windows.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/DbConnectionPoolIdentity.Windows.cs @@ -43,10 +43,13 @@ private static DbConnectionPoolIdentity GetCurrentNative() string sidString = user.Value; // Win32NativeMethods.IsTokenRestricted will raise exception if the native call fails - bool isRestricted = Win32NativeMethods.IsTokenRestrictedWrapper(token); + SniNativeWrapper.SniIsTokenRestricted(token, out bool isRestricted); var lastIdentity = s_lastIdentity; - if ((lastIdentity != null) && (lastIdentity._sidString == sidString) && (lastIdentity._isRestricted == isRestricted) && (lastIdentity._isNetwork == isNetwork)) + if (lastIdentity != null && + lastIdentity._sidString == sidString && + lastIdentity._isRestricted == isRestricted && + lastIdentity._isNetwork == isNetwork) { current = lastIdentity; } From d924304a16277c8c2276e9d20e882e138b38a9a1 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 14 Nov 2024 13:48:41 -0600 Subject: [PATCH 09/12] Remove sqldependency process storage methods from SniNativeWrapper --- .../SqlClient/TdsParserStateObject.netcore.cs | 4 +- .../SqlClient/TdsParserStateObjectManaged.cs | 2 +- .../SqlClient/TdsParserStateObjectNative.cs | 2 +- .../Interop/Windows/Sni/SniNativeWrapper.cs | 28 -------- ...ependencyProcessDispatcherStorage.netfx.cs | 66 +++++++++++-------- .../SqlClient/LocalDb/LocalDbApi.Windows.cs | 4 +- .../Microsoft/Data/SqlClient/SqlDependency.cs | 6 +- .../TdsParserStateObject.Multiplexer.cs | 12 ++-- .../TdsParserStateObject.TestHarness.cs | 12 ++-- 9 files changed, 61 insertions(+), 75 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index b2763a786f..0b564d1674 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -229,7 +229,7 @@ internal abstract void CreatePhysicalSNIHandle( internal abstract void ReleasePacket(PacketHandle syncReadPacket); - protected abstract uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize); + protected abstract uint SniPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize); internal abstract PacketHandle GetResetWritePacket(int dataSize); @@ -401,7 +401,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) private uint GetSniPacket(PacketHandle packet, ref uint dataSize) { - return SNIPacketGetData(packet, _inBuff, ref dataSize); + return SniPacketGetData(packet, _inBuff, ref dataSize); } private void ChangeNetworkPacketTimeout(int dueTime, int period) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 66606e26b8..3a709d03c9 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -69,7 +69,7 @@ internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async) /// Destination byte array where data packets are copied to /// Length of data packets /// SNI error status - protected override uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize) + protected override uint SniPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize) { int dataSizeInt = 0; packet.ManagedPacket.GetData(inBuff, ref dataSizeInt); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index 390ca39cd2..929056b306 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -181,7 +181,7 @@ internal override void CreatePhysicalSNIHandle( spns = new[] { serverSPN.TrimEnd() }; } - protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) + protected override uint SniPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) { Debug.Assert(packet.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); return SniNativeWrapper.SniPacketGetData(packet.NativePointer, _inBuff, ref dataSize); diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index e4305f982d..42abf2c052 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -522,33 +522,5 @@ private static void MarshalConsumerInfo(ConsumerInfo consumerInfo, ref SniConsum } #endregion - - #if NETFRAMEWORK - [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider - [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] - internal static unsafe byte[] GetData() - { - IntPtr ptr = (IntPtr)SqlDependencyProcessDispatcherStorage.NativeGetData(out int size); - byte[] result = null; - - if (ptr != IntPtr.Zero) - { - result = new byte[size]; - Marshal.Copy(ptr, result, 0, size); - } - - return result; - } - - [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider - [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] - internal static unsafe void SetData(byte[] data) - { - fixed (byte* pDispatcher = data) - { - SqlDependencyProcessDispatcherStorage.NativeSetData(pDispatcher, data.Length); - } - } - #endif } } diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SqlDependencyProcessDispatcherStorage.netfx.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SqlDependencyProcessDispatcherStorage.netfx.cs index c4693a32b4..bbeda47473 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SqlDependencyProcessDispatcherStorage.netfx.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SqlDependencyProcessDispatcherStorage.netfx.cs @@ -4,52 +4,64 @@ #if NETFRAMEWORK +using System; using System.Diagnostics; using System.Runtime.InteropServices; +using System.Runtime.Versioning; using System.Threading; namespace Interop.Windows.Sni { internal unsafe class SqlDependencyProcessDispatcherStorage { - static void* data; + private static void* s_data; + private static int s_size; + private static volatile int s_lock; // Int used for a spin-lock. - static int size; - static volatile int thelock; // Int used for a spin-lock. - - public static void* NativeGetData(out int passedSize) + [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider + [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] + public static byte[] NativeGetData() { - passedSize = size; - return data; + IntPtr ptr = (IntPtr)s_data; + + byte[] result = null; + if (ptr != IntPtr.Zero) + { + result = new byte[s_size]; + Marshal.Copy(ptr, result, 0, s_size); + } + + return result; } - internal static bool NativeSetData(void* passedData, int passedSize) + [ResourceExposure(ResourceScope.Process)] // SxS: there is no way to set scope = Instance, using Process which is wider + [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] + internal static void NativeSetData(byte[] data) { - bool success = false; - - while (0 != Interlocked.CompareExchange(ref thelock, 1, 0)) - { // Spin until we have the lock. - Thread.Sleep(50); // Sleep with short-timeout to prevent starvation. - } - Trace.Assert(1 == thelock); // Now that we have the lock, lock should be equal to 1. - - if (data == null) + fixed (byte* pDispatcher = data) { - data = Marshal.AllocHGlobal(passedSize).ToPointer(); + while (Interlocked.CompareExchange(ref s_lock, 1, 0) != 0) + { + // Spin until we have the lock. + Thread.Sleep(50); // Sleep with short-timeout to prevent starvation. + } + Trace.Assert(s_lock == 1); // Now that we have the lock, lock should be equal to 1. - Trace.Assert(data != null); + if (s_data == null) + { + s_data = Marshal.AllocHGlobal(data.Length).ToPointer(); - System.Buffer.MemoryCopy(passedData, data, passedSize, passedSize); + Trace.Assert(s_data != null); - Trace.Assert(0 == size); // Size should still be zero at this point. - size = passedSize; - success = true; - } + Buffer.MemoryCopy(pDispatcher, s_data, data.Length, data.Length); - int result = Interlocked.CompareExchange(ref thelock, 0, 1); - Trace.Assert(1 == result); // The release of the lock should have been successful. + Trace.Assert(0 == s_size); // Size should still be zero at this point. + s_size = data.Length; + } - return success; + int result = Interlocked.CompareExchange(ref s_lock, 0, 1); + Trace.Assert(1 == result); // The release of the lock should have been successful. + } } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalDb/LocalDbApi.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalDb/LocalDbApi.Windows.cs index ba7b3cea2f..c160af38d3 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalDb/LocalDbApi.Windows.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalDb/LocalDbApi.Windows.cs @@ -150,7 +150,7 @@ private static IntPtr UserInstanceDllHandle { if (s_userInstanceDllHandle == IntPtr.Zero) { - SniNativeWrapper.SNIQueryInfo(QueryType.SNI_QUERY_LOCALDB_HMODULE, ref s_userInstanceDllHandle); + SniNativeWrapper.SniQueryInfo(QueryType.SNI_QUERY_LOCALDB_HMODULE, ref s_userInstanceDllHandle); if (s_userInstanceDllHandle != IntPtr.Zero) { #if NETFRAMEWORK @@ -161,7 +161,7 @@ private static IntPtr UserInstanceDllHandle } else { - SniNativeWrapper.SNIGetLastError(out SniError sniError); + SniNativeWrapper.SniGetLastError(out SniError sniError); throw CreateLocalDbException( errorMessage: StringsHelper.GetString("LocalDB_FailedGetDLLHandle"), sniError: sniError.sniError); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDependency.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDependency.cs index 787cbedac8..65d9c0cd92 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDependency.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDependency.cs @@ -465,7 +465,7 @@ public void AddCommandDependency(SqlCommand command) [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)] private static void ObtainProcessDispatcher() { - byte[] nativeStorage = SniNativeWrapper.GetData(); + byte[] nativeStorage = SqlDependencyProcessDispatcherStorage.NativeGetData(); if (nativeStorage == null) { @@ -492,7 +492,9 @@ private static void ObtainProcessDispatcher() SqlClientObjRef objRef = new(s_processDispatcher); DataContractSerializer serializer = new(objRef.GetType()); GetSerializedObject(objRef, serializer, stream); - SniNativeWrapper.SetData(stream.ToArray()); // Native will be forced to synchronize and not overwrite. + + // Native will be forced to synchronize and not overwrite. + SqlDependencyProcessDispatcherStorage.NativeSetData(stream.ToArray()); } } else diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs index 2bb72e9bf2..19d3d5add2 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs @@ -513,12 +513,12 @@ public void ProcessSniPacketCompat(PacketHandle packet, uint error) else { uint dataSize = 0; - - uint getDataError = -#if NETFRAMEWORK - SniNativeWrapper. -#endif - SNIPacketGetData(packet, _inBuff, ref dataSize); + + #if NETFRAMEWORK + uint getDataError = SniNativeWrapper.SniPacketGetData(packet, _inBuff, ref dataSize); + #else + uint getDataError = SniPacketGetData(packet, _inBuff, ref dataSize); + #endif if (getDataError == TdsEnums.SNI_SUCCESS) { diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs index c512c3385b..914942e2ff 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs @@ -55,7 +55,7 @@ internal enum TdsParserState private uint GetSniPacket(PacketHandle packet, ref uint dataSize) { - return SNIPacketGetData(packet, _inBuff, ref dataSize); + return SniPacketGetData(packet, _inBuff, ref dataSize); } private class StringsHelper @@ -71,7 +71,7 @@ internal class Strings public class Parser { - internal object ProcessSNIError(TdsParserStateObject tdsParserStateObject) => "ProcessSNIError"; + internal object ProcessSniError(TdsParserStateObject tdsParserStateObject) => "ProcessSNIError"; public TdsParserState State = TdsParserState.OpenLoggedIn; } @@ -118,7 +118,7 @@ public TdsParserStateObject(List input, int packetSize, bool isAsync } } [DebuggerStepThrough] - private uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize) + private uint SniPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize) { Span target = inBuff.AsSpan(0, _packetSize); Span source = Current.Array.AsSpan(Current.Start, Current.Length); @@ -161,7 +161,7 @@ public static bool UseCompatibilityProcessSni } } -#if NETFRAMEWORK + #if NETFRAMEWORK private SniNativeWrapperImpl _native; internal SniNativeWrapperImpl SniNativeWrapper { @@ -180,9 +180,9 @@ internal class SniNativeWrapperImpl private readonly TdsParserStateObject _parent; internal SniNativeWrapperImpl(TdsParserStateObject parent) => _parent = parent; - internal uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize) => _parent.SNIPacketGetData(packet, inBuff, ref dataSize); + internal uint SniPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize) => _parent.SniPacketGetData(packet, inBuff, ref dataSize); } -#endif + #endif } internal static class TdsEnums From c6987bb0acd16fa63042e5a5a2ab145664c54b8a Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Tue, 4 Mar 2025 16:02:48 -0600 Subject: [PATCH 10/12] Fix extraneous using statements --- .../src/Interop/Windows/Sni/SniNativeWrapper.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index 42abf2c052..ae754d9421 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -5,14 +5,12 @@ using System; using System.Buffers; using System.Diagnostics; -using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; using Interop.Windows.Sni; using Microsoft.Data.Common; #if NETFRAMEWORK -using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.Versioning; using System.Security; From f4594e9d32662d4f5a84d024c36e36d47e8c43a2 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 6 Mar 2025 18:35:19 -0600 Subject: [PATCH 11/12] Revert a name change in the test harness --- .../tests/FunctionalTests/TdsParserStateObject.TestHarness.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs index 914942e2ff..eded4d1986 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs @@ -71,7 +71,7 @@ internal class Strings public class Parser { - internal object ProcessSniError(TdsParserStateObject tdsParserStateObject) => "ProcessSNIError"; + internal object ProcessSNIError(TdsParserStateObject tdsParserStateObject) => "ProcessSNIError"; public TdsParserState State = TdsParserState.OpenLoggedIn; } From ac9be8403871aac221cf53cbd67cbc8a9b5327ed Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 6 Mar 2025 18:42:25 -0600 Subject: [PATCH 12/12] Removing no-longer-used error code --- .../src/Microsoft/Data/Common/AdapterUtil.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs index 97e6a2039a..3f23a186f8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs @@ -945,7 +945,7 @@ internal enum InternalErrorCode SqlDependencyObtainProcessDispatcherFailureObjectHandle = 50, SqlDependencyProcessDispatcherFailureCreateInstance = 51, - SqlDependencyProcessDispatcherFailureAppDomain = 52, + SqlDependencyCommandHashIsNotAssociatedWithNotification = 53, UnknownTransactionFailure = 60,