diff --git a/examples/Echo.Client/Program.cs b/examples/Echo.Client/Program.cs index deb3e12b5..5217ef592 100644 --- a/examples/Echo.Client/Program.cs +++ b/examples/Echo.Client/Program.cs @@ -6,6 +6,7 @@ namespace Echo.Client using System; using System.Diagnostics.Tracing; using System.Net; + using System.Net.Security; using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; using DotNetty.Codecs; @@ -26,6 +27,13 @@ static async Task RunClientAsync() var group = new MultithreadEventLoopGroup(); + X509Certificate2 cert = null; + string targetHost = null; + if (EchoClientSettings.IsSsl) + { + cert = new X509Certificate2("dotnetty.com.pfx", "password"); + targetHost = cert.GetNameInfo(X509NameType.DnsName, false); + } try { var bootstrap = new Bootstrap(); @@ -37,11 +45,9 @@ static async Task RunClientAsync() { IChannelPipeline pipeline = channel.Pipeline; - if (EchoClientSettings.IsSsl) + if (cert != null) { - var cert = new X509Certificate2("dotnetty.com.pfx", "password"); - string targetHost = cert.GetNameInfo(X509NameType.DnsName, false); - pipeline.AddLast(TlsHandler.Client(targetHost, null, (sender, certificate, chain, errors) => true)); + pipeline.AddLast(new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ClientTlsSettings(targetHost))); } pipeline.AddLast(new LengthFieldPrepender(2)); pipeline.AddLast(new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2)); diff --git a/examples/Echo.Server/Program.cs b/examples/Echo.Server/Program.cs index fe61259fd..38840d52e 100644 --- a/examples/Echo.Server/Program.cs +++ b/examples/Echo.Server/Program.cs @@ -5,6 +5,7 @@ namespace Echo.Server { using System; using System.Diagnostics.Tracing; + using System.Net.Security; using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; using DotNetty.Codecs; @@ -26,6 +27,11 @@ static async Task RunServerAsync() var bossGroup = new MultithreadEventLoopGroup(1); var workerGroup = new MultithreadEventLoopGroup(); + X509Certificate2 tlsCertificate = null; + if (EchoServerSettings.IsSsl) + { + tlsCertificate = new X509Certificate2("dotnetty.com.pfx", "password"); + } try { var bootstrap = new ServerBootstrap(); @@ -37,10 +43,9 @@ static async Task RunServerAsync() .ChildHandler(new ActionChannelInitializer(channel => { IChannelPipeline pipeline = channel.Pipeline; - - if (EchoServerSettings.IsSsl) + if (tlsCertificate != null) { - pipeline.AddLast(TlsHandler.Server(new X509Certificate2("dotnetty.com.pfx", "password"))); + pipeline.AddLast(TlsHandler.Server(tlsCertificate)); } pipeline.AddLast(new LengthFieldPrepender(2)); pipeline.AddLast(new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2)); diff --git a/src/DotNetty.Handlers/DotNetty.Handlers.csproj b/src/DotNetty.Handlers/DotNetty.Handlers.csproj index 51d18f586..355fc112b 100644 --- a/src/DotNetty.Handlers/DotNetty.Handlers.csproj +++ b/src/DotNetty.Handlers/DotNetty.Handlers.csproj @@ -47,7 +47,9 @@ + + @@ -58,6 +60,7 @@ + diff --git a/src/DotNetty.Handlers/Tls/ClientTlsSettings.cs b/src/DotNetty.Handlers/Tls/ClientTlsSettings.cs new file mode 100644 index 000000000..80076302b --- /dev/null +++ b/src/DotNetty.Handlers/Tls/ClientTlsSettings.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Handlers.Tls +{ + using System.Collections.Generic; + using System.Linq; + using System.Security.Authentication; + using System.Security.Cryptography.X509Certificates; + + public sealed class ClientTlsSettings : TlsSettings + { + IReadOnlyCollection certificates; + + public ClientTlsSettings(string targetHost) + : this(targetHost, new List()) + { + } + + public ClientTlsSettings(string targetHost, List certificates) + : this(false, certificates, targetHost) + { + } + + public ClientTlsSettings(bool checkCertificateRevocation, List certificates, string targetHost) + : this(SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, checkCertificateRevocation, certificates, targetHost) + { + } + + public ClientTlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation, List certificates, string targetHost) + :base(enabledProtocols, checkCertificateRevocation) + { + this.X509CertificateCollection = new X509CertificateCollection(certificates.ToArray()); + this.TargetHost = targetHost; + this.Certificates = certificates.AsReadOnly(); + } + + internal X509CertificateCollection X509CertificateCollection { get; set; } + + public IReadOnlyCollection Certificates { get; } + + public string TargetHost { get; } + } +} \ No newline at end of file diff --git a/src/DotNetty.Handlers/Tls/ServerTlsSettings.cs b/src/DotNetty.Handlers/Tls/ServerTlsSettings.cs new file mode 100644 index 000000000..d30cef432 --- /dev/null +++ b/src/DotNetty.Handlers/Tls/ServerTlsSettings.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Handlers.Tls +{ + using System.Security.Authentication; + using System.Security.Cryptography.X509Certificates; + + public sealed class ServerTlsSettings : TlsSettings + { + public ServerTlsSettings(X509Certificate certificate) + : this(false, certificate) + { + } + + public ServerTlsSettings(bool checkCertificateRevocation, X509Certificate certificate) + : this(SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, checkCertificateRevocation, certificate) + { + } + + public ServerTlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation, X509Certificate certificate) + : base(enabledProtocols, checkCertificateRevocation) + { + this.Certificate = certificate; + } + + public X509Certificate Certificate { get; } + } +} \ No newline at end of file diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.cs b/src/DotNetty.Handlers/Tls/TlsHandler.cs index 1850504ac..87c07ccd2 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.cs @@ -9,7 +9,6 @@ namespace DotNetty.Handlers.Tls using System.IO; using System.Net.Security; using System.Runtime.ExceptionServices; - using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -21,6 +20,7 @@ namespace DotNetty.Handlers.Tls public sealed class TlsHandler : ByteToMessageDecoder { + readonly TlsSettings settings; const int FallbackReadBufferSize = 256; const int UnencryptedWriteBatchSize = 14 * 1024; @@ -28,41 +28,39 @@ public sealed class TlsHandler : ByteToMessageDecoder static readonly Action HandshakeCompletionCallback = new Action(HandleHandshakeCompleted); readonly SslStream sslStream; + readonly MediationStream mediationStream; + readonly TaskCompletionSource closeFuture; + TlsHandlerState state; int packetLength; - readonly MediationStream mediationStream; volatile IChannelHandlerContext capturedContext; BatchingPendingWriteQueue pendingUnencryptedWrites; Task lastContextWriteTask; - readonly TaskCompletionSource closeFuture; - readonly bool isServer; - readonly X509Certificate2 certificate; - readonly string targetHost; bool firedChannelRead; IByteBuffer pendingSslStreamReadBuffer; Task pendingSslStreamReadFuture; - TlsHandler(bool isServer, X509Certificate2 certificate, string targetHost, RemoteCertificateValidationCallback certificateValidationCallback) + public TlsHandler(TlsSettings settings) + : this(stream => new SslStream(stream, true), settings) { - Contract.Requires(!isServer || certificate != null); - Contract.Requires(isServer || !string.IsNullOrEmpty(targetHost)); + } - this.closeFuture = new TaskCompletionSource(); + public TlsHandler(Func sslStreamFactory, TlsSettings settings) + { + Contract.Requires(sslStreamFactory != null); + Contract.Requires(settings != null); - this.isServer = isServer; - this.certificate = certificate; - this.targetHost = targetHost; + this.settings = settings; + this.closeFuture = new TaskCompletionSource(); this.mediationStream = new MediationStream(this); - this.sslStream = new SslStream(this.mediationStream, true, certificateValidationCallback); + this.sslStream = sslStreamFactory(this.mediationStream); } - public static TlsHandler Client(string targetHost) => new TlsHandler(false, null, targetHost, null); - - public static TlsHandler Client(string targetHost, X509Certificate2 certificate) => new TlsHandler(false, certificate, targetHost, null); + public static TlsHandler Client(string targetHost) => new TlsHandler(new ClientTlsSettings(targetHost)); - public static TlsHandler Client(string targetHost, X509Certificate2 certificate, RemoteCertificateValidationCallback certificateValidationCallback) => new TlsHandler(false, certificate, targetHost, certificateValidationCallback); - - public static TlsHandler Server(X509Certificate2 certificate) => new TlsHandler(true, certificate, null, null); + public static TlsHandler Client(string targetHost, X509Certificate clientCertificate) => new TlsHandler(new ClientTlsSettings(targetHost, new List{ clientCertificate })); + + public static TlsHandler Server(X509Certificate certificate) => new TlsHandler(new ServerTlsSettings(certificate)); public X509Certificate LocalCertificate => this.sslStream.LocalCertificate; @@ -74,7 +72,7 @@ public override void ChannelActive(IChannelHandlerContext context) { base.ChannelActive(context); - if (!this.isServer) + if (this.settings is ServerTlsSettings) { this.EnsureAuthenticated(); } @@ -161,7 +159,7 @@ public override void HandlerAdded(IChannelHandlerContext context) base.HandlerAdded(context); this.capturedContext = context; this.pendingUnencryptedWrites = new BatchingPendingWriteQueue(context, UnencryptedWriteBatchSize); - if (context.Channel.Active && !this.isServer) + if (context.Channel.Active && this.settings is ClientTlsSettings) { // todo: support delayed initialization on an existing/active channel if in client mode this.EnsureAuthenticated(); @@ -217,23 +215,23 @@ protected override void Decode(IChannelHandlerContext context, IByteBuffer input break; } - int packetLength = TlsUtils.GetEncryptedPacketLength(input, offset); - if (packetLength == -1) + int encryptedPacketLength = TlsUtils.GetEncryptedPacketLength(input, offset); + if (encryptedPacketLength == -1) { nonSslRecord = true; break; } - Contract.Assert(packetLength > 0); + Contract.Assert(encryptedPacketLength > 0); - if (packetLength > readableBytes) + if (encryptedPacketLength > readableBytes) { // wait until the whole packet can be read - this.packetLength = packetLength; + this.packetLength = encryptedPacketLength; break; } - int newTotalLength = totalLength + packetLength; + int newTotalLength = totalLength + encryptedPacketLength; if (newTotalLength > TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH) { // Don't read too much. @@ -245,8 +243,8 @@ protected override void Decode(IChannelHandlerContext context, IByteBuffer input // We have a whole packet. // Increment the offset to handle the next packet. - packetLengths.Add(packetLength); - offset += packetLength; + packetLengths.Add(encryptedPacketLength); + offset += encryptedPacketLength; totalLength = newTotalLength; } @@ -482,19 +480,16 @@ bool EnsureAuthenticated() if (!oldState.HasAny(TlsHandlerState.AuthenticationStarted)) { this.state = oldState | TlsHandlerState.Authenticating; - if (this.isServer) + var serverSettings = settings as ServerTlsSettings; + if (serverSettings != null) { - this.sslStream.AuthenticateAsServerAsync(this.certificate, false, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false) // todo: change to begin/end + this.sslStream.AuthenticateAsServerAsync(serverSettings.Certificate, false, serverSettings.EnabledProtocols, serverSettings.CheckCertificateRevocation) .ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously); } else { - var certificateCollection = new X509Certificate2Collection(); - if (this.certificate != null) - { - certificateCollection.Add(this.certificate); - } - this.sslStream.AuthenticateAsClientAsync(this.targetHost, certificateCollection, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false) // todo: change to begin/end + var clientSettings = (ClientTlsSettings)settings; + this.sslStream.AuthenticateAsClientAsync(clientSettings.TargetHost, clientSettings.X509CertificateCollection, clientSettings.EnabledProtocols, clientSettings.CheckCertificateRevocation) .ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously); } return false; diff --git a/src/DotNetty.Handlers/Tls/TlsSettings.cs b/src/DotNetty.Handlers/Tls/TlsSettings.cs new file mode 100644 index 000000000..42846ccd1 --- /dev/null +++ b/src/DotNetty.Handlers/Tls/TlsSettings.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace DotNetty.Handlers.Tls +{ + using System.Security.Authentication; + + public abstract class TlsSettings + { + protected TlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation) + { + this.EnabledProtocols = enabledProtocols; + this.CheckCertificateRevocation = checkCertificateRevocation; + } + + public SslProtocols EnabledProtocols { get; } + + public bool CheckCertificateRevocation { get; } + } +} \ No newline at end of file diff --git a/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs b/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs index bf7d9f616..1779f6e04 100644 --- a/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs +++ b/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs @@ -154,7 +154,9 @@ static async Task> SetupStreamAndChannelAsync( { var tlsCertificate = new X509Certificate2("dotnetty.com.pfx", "password"); string targetHost = tlsCertificate.GetNameInfo(X509NameType.DnsName, false); - TlsHandler tlsHandler = isClient ? TlsHandler.Client(targetHost, null, (_1, _2, _3, _4) => true) : TlsHandler.Server(tlsCertificate); + TlsHandler tlsHandler = isClient ? + new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ClientTlsSettings(targetHost)) : + TlsHandler.Server(tlsCertificate); //var ch = new EmbeddedChannel(new LoggingHandler("BEFORE"), tlsHandler, new LoggingHandler("AFTER")); var ch = new EmbeddedChannel(tlsHandler); diff --git a/test/DotNetty.Tests.End2End/End2EndTests.cs b/test/DotNetty.Tests.End2End/End2EndTests.cs index 87ffbc10b..0a97f9c50 100644 --- a/test/DotNetty.Tests.End2End/End2EndTests.cs +++ b/test/DotNetty.Tests.End2End/End2EndTests.cs @@ -7,6 +7,7 @@ namespace DotNetty.Tests.End2End using System.Collections.Generic; using System.Linq; using System.Net; + using System.Net.Security; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading.Tasks; @@ -53,7 +54,7 @@ public async void EchoServerAndClient() { ch.Pipeline.AddLast("server logger", new LoggingHandler("SERVER")); ch.Pipeline.AddLast("server tls", TlsHandler.Server(tlsCertificate)); - ch.Pipeline.AddLast("server logger2", new LoggingHandler("SER***")); + ch.Pipeline.AddLast("server logger2", new LoggingHandler("SER***")); ch.Pipeline.AddLast("server prepender", new LengthFieldPrepender(2)); ch.Pipeline.AddLast("server decoder", new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2)); ch.Pipeline.AddLast(new EchoChannelHandler()); @@ -67,8 +68,9 @@ public async void EchoServerAndClient() .Handler(new ActionChannelInitializer(ch => { string targetHost = tlsCertificate.GetNameInfo(X509NameType.DnsName, false); + var clientTlsSettings = new ClientTlsSettings(targetHost); ch.Pipeline.AddLast("client logger", new LoggingHandler("CLIENT")); - ch.Pipeline.AddLast("client tls", TlsHandler.Client(targetHost, null, (sender, certificate, chain, errors) => true)); + ch.Pipeline.AddLast("client tls", new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), clientTlsSettings)); ch.Pipeline.AddLast("client logger2", new LoggingHandler("CLI***")); ch.Pipeline.AddLast("client prepender", new LengthFieldPrepender(2)); ch.Pipeline.AddLast("client decoder", new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2)); @@ -109,7 +111,7 @@ public async void MqttServerAndClient() Func closeServerFunc = await this.StartServerAsync(true, ch => { ch.Pipeline.AddLast("server logger", new LoggingHandler("SERVER")); - ch.Pipeline.AddLast("client tls", TlsHandler.Server(tlsCertificate)); + ch.Pipeline.AddLast("server tls", TlsHandler.Server(tlsCertificate)); ch.Pipeline.AddLast("server logger2", new LoggingHandler("SER***")); ch.Pipeline.AddLast( MqttEncoder.Instance, @@ -124,9 +126,11 @@ public async void MqttServerAndClient() .Option(ChannelOption.TcpNodelay, true) .Handler(new ActionChannelInitializer(ch => { - ch.Pipeline.AddLast("client logger", new LoggingHandler("CLIENT")); string targetHost = tlsCertificate.GetNameInfo(X509NameType.DnsName, false); - ch.Pipeline.AddLast("client tls", TlsHandler.Client(targetHost, null, (sender, certificate, chain, errors) => true)); + var clientTlsSettings = new ClientTlsSettings(targetHost); + + ch.Pipeline.AddLast("client logger", new LoggingHandler("CLIENT")); + ch.Pipeline.AddLast("client tls", new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), clientTlsSettings)); ch.Pipeline.AddLast("client logger2", new LoggingHandler("CLI***")); ch.Pipeline.AddLast( MqttEncoder.Instance,