Skip to content

Commit

Permalink
Exposes SSL Stream and adds more TLS settings (#132)
Browse files Browse the repository at this point in the history
Motivation:

Some important SSL Stream settings are hidden in the TlsHandler class

Modifications:
SSLStream is provided by user now via factory method;
TLS settings extended

Results:
More advanced scenarios, like X509 client authentication, are possible to do now
  • Loading branch information
Mikhail Tuhckov authored and nayato committed Jun 16, 2016
1 parent 8ab8c8b commit fb18eaf
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 51 deletions.
14 changes: 10 additions & 4 deletions examples/Echo.Client/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Echo.Client
using System;
using System.Diagnostics.Tracing;
using System.Net;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using DotNetty.Codecs;
Expand All @@ -26,6 +27,13 @@ static async Task RunClientAsync()

var group = new MultithreadEventLoopGroup();

X509Certificate2 cert = null;
string targetHost = null;
if (EchoClientSettings.IsSsl)
{
cert = new X509Certificate2("dotnetty.com.pfx", "password");
targetHost = cert.GetNameInfo(X509NameType.DnsName, false);
}
try
{
var bootstrap = new Bootstrap();
Expand All @@ -37,11 +45,9 @@ static async Task RunClientAsync()
{
IChannelPipeline pipeline = channel.Pipeline;

if (EchoClientSettings.IsSsl)
if (cert != null)
{
var cert = new X509Certificate2("dotnetty.com.pfx", "password");
string targetHost = cert.GetNameInfo(X509NameType.DnsName, false);
pipeline.AddLast(TlsHandler.Client(targetHost, null, (sender, certificate, chain, errors) => true));
pipeline.AddLast(new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ClientTlsSettings(targetHost)));
}
pipeline.AddLast(new LengthFieldPrepender(2));
pipeline.AddLast(new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2));
Expand Down
11 changes: 8 additions & 3 deletions examples/Echo.Server/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ namespace Echo.Server
{
using System;
using System.Diagnostics.Tracing;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using DotNetty.Codecs;
Expand All @@ -26,6 +27,11 @@ static async Task RunServerAsync()

var bossGroup = new MultithreadEventLoopGroup(1);
var workerGroup = new MultithreadEventLoopGroup();
X509Certificate2 tlsCertificate = null;
if (EchoServerSettings.IsSsl)
{
tlsCertificate = new X509Certificate2("dotnetty.com.pfx", "password");
}
try
{
var bootstrap = new ServerBootstrap();
Expand All @@ -37,10 +43,9 @@ static async Task RunServerAsync()
.ChildHandler(new ActionChannelInitializer<ISocketChannel>(channel =>
{
IChannelPipeline pipeline = channel.Pipeline;

if (EchoServerSettings.IsSsl)
if (tlsCertificate != null)
{
pipeline.AddLast(TlsHandler.Server(new X509Certificate2("dotnetty.com.pfx", "password")));
pipeline.AddLast(TlsHandler.Server(tlsCertificate));
}
pipeline.AddLast(new LengthFieldPrepender(2));
pipeline.AddLast(new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2));
Expand Down
3 changes: 3 additions & 0 deletions src/DotNetty.Handlers/DotNetty.Handlers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@
<Compile Include="Logging\LogLevel.cs" />
<Compile Include="Logging\LogLevelExtensions.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="Tls\ClientTlsSettings.cs" />
<Compile Include="Tls\NotSslRecordException.cs" />
<Compile Include="Tls\ServerTlsSettings.cs" />
<Compile Include="Tls\TlsHandshakeCompletionEvent.cs" />
<Compile Include="Tls\TlsHandler.cs" />
<Compile Include="Timeout\IdleState.cs" />
Expand All @@ -58,6 +60,7 @@
<Compile Include="Timeout\WriteTimeoutException.cs" />
<Compile Include="Timeout\ReadTimeoutHandler.cs" />
<Compile Include="Timeout\WriteTimeoutHandler.cs" />
<Compile Include="Tls\TlsSettings.cs" />
<Compile Include="Tls\TlsUtils.cs" />
</ItemGroup>
<ItemGroup>
Expand Down
44 changes: 44 additions & 0 deletions src/DotNetty.Handlers/Tls/ClientTlsSettings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace DotNetty.Handlers.Tls
{
using System.Collections.Generic;
using System.Linq;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;

public sealed class ClientTlsSettings : TlsSettings
{
IReadOnlyCollection<X509Certificate2> certificates;

public ClientTlsSettings(string targetHost)
: this(targetHost, new List<X509Certificate>())
{
}

public ClientTlsSettings(string targetHost, List<X509Certificate> certificates)
: this(false, certificates, targetHost)
{
}

public ClientTlsSettings(bool checkCertificateRevocation, List<X509Certificate> certificates, string targetHost)
: this(SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, checkCertificateRevocation, certificates, targetHost)
{
}

public ClientTlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation, List<X509Certificate> certificates, string targetHost)
:base(enabledProtocols, checkCertificateRevocation)
{
this.X509CertificateCollection = new X509CertificateCollection(certificates.ToArray());
this.TargetHost = targetHost;
this.Certificates = certificates.AsReadOnly();
}

internal X509CertificateCollection X509CertificateCollection { get; set; }

public IReadOnlyCollection<X509Certificate> Certificates { get; }

public string TargetHost { get; }
}
}
29 changes: 29 additions & 0 deletions src/DotNetty.Handlers/Tls/ServerTlsSettings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace DotNetty.Handlers.Tls
{
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;

public sealed class ServerTlsSettings : TlsSettings
{
public ServerTlsSettings(X509Certificate certificate)
: this(false, certificate)
{
}

public ServerTlsSettings(bool checkCertificateRevocation, X509Certificate certificate)
: this(SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, checkCertificateRevocation, certificate)
{
}

public ServerTlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation, X509Certificate certificate)
: base(enabledProtocols, checkCertificateRevocation)
{
this.Certificate = certificate;
}

public X509Certificate Certificate { get; }
}
}
71 changes: 33 additions & 38 deletions src/DotNetty.Handlers/Tls/TlsHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ namespace DotNetty.Handlers.Tls
using System.IO;
using System.Net.Security;
using System.Runtime.ExceptionServices;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -21,48 +20,47 @@ namespace DotNetty.Handlers.Tls

public sealed class TlsHandler : ByteToMessageDecoder
{
readonly TlsSettings settings;
const int FallbackReadBufferSize = 256;
const int UnencryptedWriteBatchSize = 14 * 1024;

static readonly Exception ChannelClosedException = new IOException("Channel is closed");
static readonly Action<Task, object> HandshakeCompletionCallback = new Action<Task, object>(HandleHandshakeCompleted);

readonly SslStream sslStream;
readonly MediationStream mediationStream;
readonly TaskCompletionSource closeFuture;

TlsHandlerState state;
int packetLength;
readonly MediationStream mediationStream;
volatile IChannelHandlerContext capturedContext;
BatchingPendingWriteQueue pendingUnencryptedWrites;
Task lastContextWriteTask;
readonly TaskCompletionSource closeFuture;
readonly bool isServer;
readonly X509Certificate2 certificate;
readonly string targetHost;
bool firedChannelRead;
IByteBuffer pendingSslStreamReadBuffer;
Task<int> pendingSslStreamReadFuture;

TlsHandler(bool isServer, X509Certificate2 certificate, string targetHost, RemoteCertificateValidationCallback certificateValidationCallback)
public TlsHandler(TlsSettings settings)
: this(stream => new SslStream(stream, true), settings)
{
Contract.Requires(!isServer || certificate != null);
Contract.Requires(isServer || !string.IsNullOrEmpty(targetHost));
}

this.closeFuture = new TaskCompletionSource();
public TlsHandler(Func<Stream, SslStream> sslStreamFactory, TlsSettings settings)
{
Contract.Requires(sslStreamFactory != null);
Contract.Requires(settings != null);

this.isServer = isServer;
this.certificate = certificate;
this.targetHost = targetHost;
this.settings = settings;
this.closeFuture = new TaskCompletionSource();
this.mediationStream = new MediationStream(this);
this.sslStream = new SslStream(this.mediationStream, true, certificateValidationCallback);
this.sslStream = sslStreamFactory(this.mediationStream);
}

public static TlsHandler Client(string targetHost) => new TlsHandler(false, null, targetHost, null);

public static TlsHandler Client(string targetHost, X509Certificate2 certificate) => new TlsHandler(false, certificate, targetHost, null);
public static TlsHandler Client(string targetHost) => new TlsHandler(new ClientTlsSettings(targetHost));

public static TlsHandler Client(string targetHost, X509Certificate2 certificate, RemoteCertificateValidationCallback certificateValidationCallback) => new TlsHandler(false, certificate, targetHost, certificateValidationCallback);

public static TlsHandler Server(X509Certificate2 certificate) => new TlsHandler(true, certificate, null, null);
public static TlsHandler Client(string targetHost, X509Certificate clientCertificate) => new TlsHandler(new ClientTlsSettings(targetHost, new List<X509Certificate>{ clientCertificate }));
public static TlsHandler Server(X509Certificate certificate) => new TlsHandler(new ServerTlsSettings(certificate));

public X509Certificate LocalCertificate => this.sslStream.LocalCertificate;

Expand All @@ -74,7 +72,7 @@ public override void ChannelActive(IChannelHandlerContext context)
{
base.ChannelActive(context);

if (!this.isServer)
if (this.settings is ServerTlsSettings)
{
this.EnsureAuthenticated();
}
Expand Down Expand Up @@ -161,7 +159,7 @@ public override void HandlerAdded(IChannelHandlerContext context)
base.HandlerAdded(context);
this.capturedContext = context;
this.pendingUnencryptedWrites = new BatchingPendingWriteQueue(context, UnencryptedWriteBatchSize);
if (context.Channel.Active && !this.isServer)
if (context.Channel.Active && this.settings is ClientTlsSettings)
{
// todo: support delayed initialization on an existing/active channel if in client mode
this.EnsureAuthenticated();
Expand Down Expand Up @@ -217,23 +215,23 @@ protected override void Decode(IChannelHandlerContext context, IByteBuffer input
break;
}

int packetLength = TlsUtils.GetEncryptedPacketLength(input, offset);
if (packetLength == -1)
int encryptedPacketLength = TlsUtils.GetEncryptedPacketLength(input, offset);
if (encryptedPacketLength == -1)
{
nonSslRecord = true;
break;
}

Contract.Assert(packetLength > 0);
Contract.Assert(encryptedPacketLength > 0);

if (packetLength > readableBytes)
if (encryptedPacketLength > readableBytes)
{
// wait until the whole packet can be read
this.packetLength = packetLength;
this.packetLength = encryptedPacketLength;
break;
}

int newTotalLength = totalLength + packetLength;
int newTotalLength = totalLength + encryptedPacketLength;
if (newTotalLength > TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH)
{
// Don't read too much.
Expand All @@ -245,8 +243,8 @@ protected override void Decode(IChannelHandlerContext context, IByteBuffer input

// We have a whole packet.
// Increment the offset to handle the next packet.
packetLengths.Add(packetLength);
offset += packetLength;
packetLengths.Add(encryptedPacketLength);
offset += encryptedPacketLength;
totalLength = newTotalLength;
}

Expand Down Expand Up @@ -482,19 +480,16 @@ bool EnsureAuthenticated()
if (!oldState.HasAny(TlsHandlerState.AuthenticationStarted))
{
this.state = oldState | TlsHandlerState.Authenticating;
if (this.isServer)
var serverSettings = settings as ServerTlsSettings;
if (serverSettings != null)
{
this.sslStream.AuthenticateAsServerAsync(this.certificate, false, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false) // todo: change to begin/end
this.sslStream.AuthenticateAsServerAsync(serverSettings.Certificate, false, serverSettings.EnabledProtocols, serverSettings.CheckCertificateRevocation)
.ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously);
}
else
{
var certificateCollection = new X509Certificate2Collection();
if (this.certificate != null)
{
certificateCollection.Add(this.certificate);
}
this.sslStream.AuthenticateAsClientAsync(this.targetHost, certificateCollection, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false) // todo: change to begin/end
var clientSettings = (ClientTlsSettings)settings;
this.sslStream.AuthenticateAsClientAsync(clientSettings.TargetHost, clientSettings.X509CertificateCollection, clientSettings.EnabledProtocols, clientSettings.CheckCertificateRevocation)
.ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously);
}
return false;
Expand Down
20 changes: 20 additions & 0 deletions src/DotNetty.Handlers/Tls/TlsSettings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace DotNetty.Handlers.Tls
{
using System.Security.Authentication;

public abstract class TlsSettings
{
protected TlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation)
{
this.EnabledProtocols = enabledProtocols;
this.CheckCertificateRevocation = checkCertificateRevocation;
}

public SslProtocols EnabledProtocols { get; }

public bool CheckCertificateRevocation { get; }
}
}
4 changes: 3 additions & 1 deletion test/DotNetty.Handlers.Tests/TlsHandlerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ static async Task<Tuple<EmbeddedChannel, SslStream>> SetupStreamAndChannelAsync(
{
var tlsCertificate = new X509Certificate2("dotnetty.com.pfx", "password");
string targetHost = tlsCertificate.GetNameInfo(X509NameType.DnsName, false);
TlsHandler tlsHandler = isClient ? TlsHandler.Client(targetHost, null, (_1, _2, _3, _4) => true) : TlsHandler.Server(tlsCertificate);
TlsHandler tlsHandler = isClient ?
new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ClientTlsSettings(targetHost)) :
TlsHandler.Server(tlsCertificate);
//var ch = new EmbeddedChannel(new LoggingHandler("BEFORE"), tlsHandler, new LoggingHandler("AFTER"));
var ch = new EmbeddedChannel(tlsHandler);

Expand Down
Loading

0 comments on commit fb18eaf

Please sign in to comment.