diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs index 004a797cd114e4..8e6cb95ef77a61 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs @@ -928,7 +928,7 @@ internal void ProcessHandshakeSuccess() _headerSize = streamSizes.Header; _trailerSize = streamSizes.Trailer; - _maxDataSize = checked(streamSizes.MaximumMessage - (_headerSize + _trailerSize)); + _maxDataSize = streamSizes.MaximumMessage; Debug.Assert(_maxDataSize > 0, "_maxDataSize > 0"); SslStreamPal.QueryContextConnectionInfo(_securityContext!, ref _connectionInfo); @@ -942,18 +942,6 @@ internal void ProcessHandshakeSuccess() #endif } - /*++ - Encrypt - Encrypts our bytes before we send them over the wire - - PERF: make more efficient, this does an extra copy when the offset - is non-zero. - - Input: - buffer - bytes for sending - offset - - size - - output - Encrypted bytes - --*/ internal ProtocolToken Encrypt(ReadOnlyMemory buffer) { if (NetEventSource.Log.IsEnabled()) NetEventSource.DumpBuffer(this, buffer.Span); @@ -1337,7 +1325,7 @@ internal void EnsureAvailableSpace(int size) var oldPayload = Payload; - Payload = RentBuffer? ArrayPool.Shared.Rent(Size + size) : new byte[Size + size]; + Payload = RentBuffer ? ArrayPool.Shared.Rent(Size + size) : new byte[Size + size]; if (oldPayload != null) { oldPayload.AsSpan().CopyTo(Payload); diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs index 0b1744bd44bd15..a9fe7f0a4e741d 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Security.Authentication; using System.Security.Authentication.ExtendedProtection; @@ -49,7 +50,8 @@ public static Exception GetException(SecurityStatusPal status) private static byte[] InitSessionTokenBuffer() { - var schannelSessionToken = new Interop.SChannel.SCHANNEL_SESSION_TOKEN() { + var schannelSessionToken = new Interop.SChannel.SCHANNEL_SESSION_TOKEN() + { dwTokenType = Interop.SChannel.SCHANNEL_SESSION, dwFlags = Interop.SChannel.SSL_SESSION_DISABLE_RECONNECTS, }; @@ -61,7 +63,7 @@ public static void VerifyPackageInfo() SSPIWrapper.GetVerifyPackageInfo(GlobalSSPI.SSPISecureChannel, SecurityPackage, true); } - private static unsafe void SetAlpn(ref InputSecurityBuffers inputBuffers, List alpn, Span localBuffer) + private static void SetAlpn(ref InputSecurityBuffers inputBuffers, List alpn, Span localBuffer) { if (alpn.Count == 1 && alpn[0] == SslApplicationProtocol.Http11) { @@ -82,7 +84,7 @@ private static unsafe void SetAlpn(ref InputSecurityBuffers inputBuffers, List() + protocolLength; Span alpnBuffer = bufferLength <= localBuffer.Length ? localBuffer : new byte[bufferLength]; Interop.Sec_Application_Protocols.SetProtocols(alpnBuffer, alpn, protocolLength); @@ -99,7 +101,7 @@ public static SecurityStatusPal SelectApplicationProtocol( throw new PlatformNotSupportedException(nameof(SelectApplicationProtocol)); } - public static unsafe ProtocolToken AcceptSecurityContext( + public static ProtocolToken AcceptSecurityContext( ref SafeFreeCredentials? credentialsHandle, ref SafeDeleteSslContext? context, ReadOnlySpan inputBuffer, @@ -141,7 +143,7 @@ public static bool TryUpdateClintCertificate( return false; } - public static unsafe ProtocolToken InitializeSecurityContext( + public static ProtocolToken InitializeSecurityContext( ref SafeFreeCredentials? credentialsHandle, ref SafeDeleteSslContext? context, string? targetName, @@ -445,32 +447,32 @@ public static unsafe ProtocolToken EncryptMessage(SafeDeleteSslContext securityC input.Span.CopyTo(token.AvailableSpan.Slice(headerSize, input.Length)); const int NumSecBuffers = 4; // header + data + trailer + empty - Interop.SspiCli.SecBuffer* unmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers]; + Span unmanagedBuffers = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers]; Interop.SspiCli.SecBufferDesc sdcInOut = new Interop.SspiCli.SecBufferDesc(NumSecBuffers) { - pBuffers = unmanagedBuffer + pBuffers = Unsafe.AsPointer(ref MemoryMarshal.GetReference(unmanagedBuffers)) }; fixed (byte* outputPtr = token.Payload) { - Interop.SspiCli.SecBuffer* headerSecBuffer = &unmanagedBuffer[0]; - headerSecBuffer->BufferType = SecurityBufferType.SECBUFFER_STREAM_HEADER; - headerSecBuffer->pvBuffer = (IntPtr)outputPtr; - headerSecBuffer->cbBuffer = headerSize; + ref Interop.SspiCli.SecBuffer headerSecBuffer = ref unmanagedBuffers[0]; + headerSecBuffer.BufferType = SecurityBufferType.SECBUFFER_STREAM_HEADER; + headerSecBuffer.pvBuffer = (IntPtr)outputPtr; + headerSecBuffer.cbBuffer = headerSize; - Interop.SspiCli.SecBuffer* dataSecBuffer = &unmanagedBuffer[1]; - dataSecBuffer->BufferType = SecurityBufferType.SECBUFFER_DATA; - dataSecBuffer->pvBuffer = (IntPtr)(outputPtr + headerSize); - dataSecBuffer->cbBuffer = input.Length; + ref Interop.SspiCli.SecBuffer dataSecBuffer = ref unmanagedBuffers[1]; + dataSecBuffer.BufferType = SecurityBufferType.SECBUFFER_DATA; + dataSecBuffer.pvBuffer = (IntPtr)(outputPtr + headerSize); + dataSecBuffer.cbBuffer = input.Length; - Interop.SspiCli.SecBuffer* trailerSecBuffer = &unmanagedBuffer[2]; - trailerSecBuffer->BufferType = SecurityBufferType.SECBUFFER_STREAM_TRAILER; - trailerSecBuffer->pvBuffer = (IntPtr)(outputPtr + headerSize + input.Length); - trailerSecBuffer->cbBuffer = trailerSize; + ref Interop.SspiCli.SecBuffer trailerSecBuffer = ref unmanagedBuffers[2]; + trailerSecBuffer.BufferType = SecurityBufferType.SECBUFFER_STREAM_TRAILER; + trailerSecBuffer.pvBuffer = (IntPtr)(outputPtr + headerSize + input.Length); + trailerSecBuffer.cbBuffer = trailerSize; - Interop.SspiCli.SecBuffer* emptySecBuffer = &unmanagedBuffer[3]; - emptySecBuffer->BufferType = SecurityBufferType.SECBUFFER_EMPTY; - emptySecBuffer->cbBuffer = 0; - emptySecBuffer->pvBuffer = IntPtr.Zero; + ref Interop.SspiCli.SecBuffer emptySecBuffer = ref unmanagedBuffers[3]; + emptySecBuffer.BufferType = SecurityBufferType.SECBUFFER_EMPTY; + emptySecBuffer.cbBuffer = 0; + emptySecBuffer.pvBuffer = IntPtr.Zero; int errorCode = GlobalSSPI.SSPISecureChannel.EncryptMessage(securityContext, ref sdcInOut, 0); @@ -483,10 +485,10 @@ public static unsafe ProtocolToken EncryptMessage(SafeDeleteSslContext securityC return token; } - Debug.Assert(headerSecBuffer->cbBuffer >= 0 && dataSecBuffer->cbBuffer >= 0 && trailerSecBuffer->cbBuffer >= 0); - Debug.Assert(checked(headerSecBuffer->cbBuffer + dataSecBuffer->cbBuffer + trailerSecBuffer->cbBuffer) <= token.Payload!.Length); + Debug.Assert(headerSecBuffer.cbBuffer >= 0 && dataSecBuffer.cbBuffer >= 0 && trailerSecBuffer.cbBuffer >= 0); + Debug.Assert(checked(headerSecBuffer.cbBuffer + dataSecBuffer.cbBuffer + trailerSecBuffer.cbBuffer) <= token.Payload!.Length); - token.Size = checked(headerSecBuffer->cbBuffer + dataSecBuffer->cbBuffer + trailerSecBuffer->cbBuffer); + token.Size = checked(headerSecBuffer.cbBuffer + dataSecBuffer.cbBuffer + trailerSecBuffer.cbBuffer); token.Status = new SecurityStatusPal(SecurityStatusPalErrorCode.OK); } @@ -496,25 +498,26 @@ public static unsafe ProtocolToken EncryptMessage(SafeDeleteSslContext securityC public static unsafe SecurityStatusPal DecryptMessage(SafeDeleteSslContext? securityContext, Span buffer, out int offset, out int count) { const int NumSecBuffers = 4; // data + empty + empty + empty - fixed (byte* bufferPtr = buffer) + + Span unmanagedBuffers = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers]; + for (int i = 1; i < NumSecBuffers; i++) { - Interop.SspiCli.SecBuffer* unmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers]; - Interop.SspiCli.SecBuffer* dataBuffer = &unmanagedBuffer[0]; - dataBuffer->BufferType = SecurityBufferType.SECBUFFER_DATA; - dataBuffer->pvBuffer = (IntPtr)bufferPtr; - dataBuffer->cbBuffer = buffer.Length; + ref Interop.SspiCli.SecBuffer emptyBuffer = ref unmanagedBuffers[i]; + emptyBuffer.BufferType = SecurityBufferType.SECBUFFER_EMPTY; + emptyBuffer.pvBuffer = IntPtr.Zero; + emptyBuffer.cbBuffer = 0; + } - for (int i = 1; i < NumSecBuffers; i++) - { - Interop.SspiCli.SecBuffer* emptyBuffer = &unmanagedBuffer[i]; - emptyBuffer->BufferType = SecurityBufferType.SECBUFFER_EMPTY; - emptyBuffer->pvBuffer = IntPtr.Zero; - emptyBuffer->cbBuffer = 0; - } + fixed (byte* bufferPtr = buffer) + { + ref Interop.SspiCli.SecBuffer dataBuffer = ref unmanagedBuffers[0]; + dataBuffer.BufferType = SecurityBufferType.SECBUFFER_DATA; + dataBuffer.pvBuffer = (IntPtr)bufferPtr; + dataBuffer.cbBuffer = buffer.Length; Interop.SspiCli.SecBufferDesc sdcInOut = new Interop.SspiCli.SecBufferDesc(NumSecBuffers) { - pBuffers = unmanagedBuffer + pBuffers = Unsafe.AsPointer(ref MemoryMarshal.GetReference(unmanagedBuffers)) }; Interop.SECURITY_STATUS errorCode = (Interop.SECURITY_STATUS)GlobalSSPI.SSPISecureChannel.DecryptMessage(securityContext!, ref sdcInOut, out _); @@ -525,12 +528,12 @@ public static unsafe SecurityStatusPal DecryptMessage(SafeDeleteSslContext? secu for (int i = 0; i < NumSecBuffers; i++) { // Successfully decoded data and placed it at the following position in the buffer, - if ((errorCode == Interop.SECURITY_STATUS.OK && unmanagedBuffer[i].BufferType == SecurityBufferType.SECBUFFER_DATA) + if ((errorCode == Interop.SECURITY_STATUS.OK && unmanagedBuffers[i].BufferType == SecurityBufferType.SECBUFFER_DATA) // or we failed to decode the data, here is the encoded data. - || (errorCode != Interop.SECURITY_STATUS.OK && unmanagedBuffer[i].BufferType == SecurityBufferType.SECBUFFER_EXTRA)) + || (errorCode != Interop.SECURITY_STATUS.OK && unmanagedBuffers[i].BufferType == SecurityBufferType.SECBUFFER_EXTRA)) { - offset = (int)((byte*)unmanagedBuffer[i].pvBuffer - bufferPtr); - count = unmanagedBuffer[i].cbBuffer; + offset = (int)((byte*)unmanagedBuffers[i].pvBuffer - bufferPtr); + count = unmanagedBuffers[i].cbBuffer; // output is ignored on Windows. We always decrypt in place and we set outputOffset to indicate where the data start. Debug.Assert(offset >= 0 && count >= 0, $"Expected offset and count greater than 0, got {offset} and {count}");