Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ internal static int GetAlpnProtocolListSerializedLength(List<SslApplicationProto
}

protocolSize += protocol.Protocol.Length + 1;

if (protocolSize > ushort.MaxValue)
{
throw new ArgumentException(SR.net_ssl_app_protocols_invalid, nameof(applicationProtocols));
}
}

return protocolSize;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ internal struct Sec_Application_Protocols
{
public uint ProtocolListsSize;
public ApplicationProtocolNegotiationExt ProtocolExtensionType;
public short ProtocolListSize;
public ushort ProtocolListSize;

public static int GetProtocolLength(List<SslApplicationProtocol> applicationProtocols)
{
Expand All @@ -30,7 +30,7 @@ public static int GetProtocolLength(List<SslApplicationProtocol> applicationProt

protocolListSize += protocolLength + 1;

if (protocolListSize > short.MaxValue)
if (protocolListSize > ushort.MaxValue)
{
throw new ArgumentException(SR.net_ssl_app_protocols_invalid, nameof(applicationProtocols));
}
Expand All @@ -49,7 +49,7 @@ public static unsafe byte[] ToByteArray(List<SslApplicationProtocol> application
protocols.ProtocolListsSize = (uint)(protocolListConstSize + protocolListSize);

protocols.ProtocolExtensionType = ApplicationProtocolNegotiationExt.ALPN;
protocols.ProtocolListSize = (short)protocolListSize;
protocols.ProtocolListSize = (ushort)protocolListSize;

byte[] buffer = new byte[sizeof(Sec_Application_Protocols) + protocolListSize];
int index = 0;
Expand All @@ -73,7 +73,7 @@ public static unsafe void SetProtocols(Span<byte> buffer, List<SslApplicationPro
Span<Sec_Application_Protocols> alpn = MemoryMarshal.Cast<byte, Sec_Application_Protocols>(buffer);
alpn[0].ProtocolListsSize = (uint)(sizeof(Sec_Application_Protocols) - sizeof(uint) + protocolLength);
alpn[0].ProtocolExtensionType = ApplicationProtocolNegotiationExt.ALPN;
alpn[0].ProtocolListSize = (short)protocolLength;
alpn[0].ProtocolListSize = (ushort)protocolLength;

Span<byte> data = buffer.Slice(sizeof(Sec_Application_Protocols));
for (int i = 0; i < applicationProtocols.Count; i++)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Collections.Generic;
using System.Net.Security;
using System.Runtime.InteropServices;
using System.Security.Authentication;
Expand Down Expand Up @@ -306,6 +307,7 @@ private unsafe void InitializeSslContext(
if (authOptions.ApplicationProtocols != null && authOptions.ApplicationProtocols.Count != 0
&& Interop.AndroidCrypto.SSLSupportsApplicationProtocolsConfiguration())
{
ValidateAlpnProtocolListSize(authOptions.ApplicationProtocols);
// Set application protocols if the platform supports it. Otherwise, we will silently ignore the option.
Interop.AndroidCrypto.SSLStreamSetApplicationProtocols(handle, authOptions.ApplicationProtocols);
}
Expand All @@ -320,5 +322,18 @@ private unsafe void InitializeSslContext(
Interop.AndroidCrypto.SSLStreamSetTargetHost(handle, authOptions.TargetHost);
}
}

private static void ValidateAlpnProtocolListSize(List<SslApplicationProtocol> applicationProtocols)
{
int protocolListSize = 0;
foreach (SslApplicationProtocol protocol in applicationProtocols)
{
protocolListSize += protocol.Protocol.Length + 1;
if (protocolListSize > ushort.MaxValue)
{
throw new ArgumentException(SR.net_ssl_app_protocols_invalid, nameof(applicationProtocols));
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -442,10 +442,17 @@ static int GetAlpnProtocolListSerializedLength(List<SslApplicationProtocol>? app
}

int protocolSize = 0;
int wireSize = 0;

foreach (SslApplicationProtocol protocol in applicationProtocols)
{
protocolSize += protocol.Protocol.Length + 2;

wireSize += protocol.Protocol.Length + 1;
if (wireSize > ushort.MaxValue)
{
throw new ArgumentException(SR.net_ssl_app_protocols_invalid, nameof(applicationProtocols));
}
}

return protocolSize;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public SafeDeleteSslContext(SslAuthenticationOptions sslAuthenticationOptions)

if (sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0)
{
ValidateAlpnProtocolListSize(sslAuthenticationOptions.ApplicationProtocols);

if (sslAuthenticationOptions.IsClient)
{
// On macOS coreTls supports only client side.
Expand Down Expand Up @@ -397,5 +399,18 @@ internal static void SetCertificate(SafeSslHandle sslContext, SslStreamCertifica

Interop.AppleCrypto.SslSetCertificate(sslContext, ptrs);
}

private static void ValidateAlpnProtocolListSize(List<SslApplicationProtocol> applicationProtocols)
{
int protocolListSize = 0;
foreach (SslApplicationProtocol protocol in applicationProtocols)
{
protocolListSize += protocol.Protocol.Length + 1;
if (protocolListSize > ushort.MaxValue)
{
throw new ArgumentException(SR.net_ssl_app_protocols_invalid, nameof(applicationProtocols));
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ private static ProtocolToken HandshakeInternal(
sslContext.ReadPendingWrites(ref token);
return token;
}
catch (Exception exc)
catch (Exception exc) when (exc is not ArgumentException)
{
token.Status = new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, exc);
return token;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ private static ProtocolToken HandshakeInternal(ref SafeDeleteSslContext? context

token.Status = new SecurityStatusPal(errorCode);
}
catch (Exception exc)
catch (Exception exc) when (exc is not ArgumentException)
{
token.Status = new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, exc);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,53 @@ public static IEnumerable<object[]> Alpn_TestData()
yield return new object[] { proto, null, null, default(SslApplicationProtocol) };
}
}

[ConditionalFact(nameof(BackendSupportsAlpn))]
public async Task SslStream_StreamToStream_AlpnListTotalSizeExceedsLimit_Throws()
{
// Each protocol is 255 bytes, serialized with a 1-byte length prefix = 256 bytes each.
// Per RFC 7301, TLS wire format limits ProtocolNameList to 2^16-1 (65,535) bytes.
// All platforms enforce this via managed validation before calling native APIs.
// 256 * 256 = 65,536 > 65,535
const int protocolCount = 256;
List<SslApplicationProtocol> oversizedProtocols = new List<SslApplicationProtocol>();
for (int i = 0; i < protocolCount; i++)
{
byte[] proto = new byte[255];
proto.AsSpan().Fill((byte)'a');
proto[0] = (byte)((i >> 8) + 1);
proto[1] = (byte)(i & 0xFF);
oversizedProtocols.Add(new SslApplicationProtocol(proto));
}

using X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate();

SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions
{
ApplicationProtocols = oversizedProtocols,
RemoteCertificateValidationCallback = delegate { return true; },
TargetHost = Guid.NewGuid().ToString("N"),
};

SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions
{
ApplicationProtocols = new List<SslApplicationProtocol> { SslApplicationProtocol.Http2 },
ServerCertificateContext = SslStreamCertificateContext.Create(certificate, null)
};

(Stream clientStream, Stream serverStream) = TestHelper.GetConnectedStreams();
using (clientStream)
using (serverStream)
using (var client = new SslStream(clientStream, false))
using (var server = new SslStream(serverStream, false))
{
Task serverTask = server.AuthenticateAsServerAsync(TestAuthenticateAsync, serverOptions);
await Assert.ThrowsAsync<ArgumentException>(() => client.AuthenticateAsClientAsync(TestAuthenticateAsync, clientOptions));
server.Dispose();

await Assert.ThrowsAnyAsync<Exception>(() => serverTask.WaitAsync(TestConfiguration.PassingTestTimeout));
}
}
}

public sealed class SslStreamAlpnTest_Async : SslStreamAlpnTestBase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,51 @@ public void Constructor_ByteArray_Copies()
Assert.NotSame(expected, arraySegment.Array);
}

[Theory]
[InlineData(0, true)]
[InlineData(1, false)]
[InlineData(254, false)]
[InlineData(255, false)]
[InlineData(256, true)]
[InlineData(512, true)]
public void Constructor_ProtocolSizeBoundary_ThrowsForInvalidSize(int size, bool shouldThrow)
{
byte[] protocol = new byte[size];
protocol.AsSpan().Fill((byte)'a');

if (shouldThrow)
{
AssertExtensions.Throws<ArgumentException>("protocol", () => new SslApplicationProtocol(protocol));
}
else
{
SslApplicationProtocol alpn = new SslApplicationProtocol(protocol);
Assert.Equal(size, alpn.Protocol.Length);
}
}

[Theory]
[InlineData(0, true)]
[InlineData(1, false)]
[InlineData(254, false)]
[InlineData(255, false)]
[InlineData(256, true)]
[InlineData(512, true)]
public void Constructor_StringSizeBoundary_ThrowsForInvalidSize(int size, bool shouldThrow)
{
string protocol = new string('a', size);

if (shouldThrow)
{
AssertExtensions.Throws<ArgumentException>("protocol", () => new SslApplicationProtocol(protocol));
}
else
{
SslApplicationProtocol alpn = new SslApplicationProtocol(protocol);
Assert.Equal(size, alpn.Protocol.Length);
}
}

[Theory]
[MemberData(nameof(Protocol_Equality_TestData))]
public void Equality_Tests_Succeeds(SslApplicationProtocol left, SslApplicationProtocol right)
Expand Down
Loading