Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exposes SSL Stream and adds more TLS settings #132

Merged
merged 1 commit into from
Jun 16, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions examples/Echo.Client/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Echo.Client
using System;
using System.Diagnostics.Tracing;
using System.Net;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using DotNetty.Codecs;
Expand All @@ -26,6 +27,13 @@ static async Task RunClientAsync()

var group = new MultithreadEventLoopGroup();

X509Certificate2 cert = null;
string targetHost = null;
if (EchoClientSettings.IsSsl)
{
cert = new X509Certificate2("dotnetty.com.pfx", "password");
targetHost = cert.GetNameInfo(X509NameType.DnsName, false);
}
try
{
var bootstrap = new Bootstrap();
Expand All @@ -37,11 +45,9 @@ static async Task RunClientAsync()
{
IChannelPipeline pipeline = channel.Pipeline;

if (EchoClientSettings.IsSsl)
if (cert != null)
{
var cert = new X509Certificate2("dotnetty.com.pfx", "password");
string targetHost = cert.GetNameInfo(X509NameType.DnsName, false);
pipeline.AddLast(TlsHandler.Client(targetHost, null, (sender, certificate, chain, errors) => true));
pipeline.AddLast(new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ClientTlsSettings(targetHost)));
}
pipeline.AddLast(new LengthFieldPrepender(2));
pipeline.AddLast(new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2));
Expand Down
11 changes: 8 additions & 3 deletions examples/Echo.Server/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ namespace Echo.Server
{
using System;
using System.Diagnostics.Tracing;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using DotNetty.Codecs;
Expand All @@ -26,6 +27,11 @@ static async Task RunServerAsync()

var bossGroup = new MultithreadEventLoopGroup(1);
var workerGroup = new MultithreadEventLoopGroup();
X509Certificate2 tlsCertificate = null;
if (EchoServerSettings.IsSsl)
{
tlsCertificate = new X509Certificate2("dotnetty.com.pfx", "password");
}
try
{
var bootstrap = new ServerBootstrap();
Expand All @@ -37,10 +43,9 @@ static async Task RunServerAsync()
.ChildHandler(new ActionChannelInitializer<ISocketChannel>(channel =>
{
IChannelPipeline pipeline = channel.Pipeline;

if (EchoServerSettings.IsSsl)
if (tlsCertificate != null)
{
pipeline.AddLast(TlsHandler.Server(new X509Certificate2("dotnetty.com.pfx", "password")));
pipeline.AddLast(TlsHandler.Server(tlsCertificate));
}
pipeline.AddLast(new LengthFieldPrepender(2));
pipeline.AddLast(new LengthFieldBasedFrameDecoder(ushort.MaxValue, 0, 2, 0, 2));
Expand Down
3 changes: 3 additions & 0 deletions src/DotNetty.Handlers/DotNetty.Handlers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@
<Compile Include="Logging\LogLevel.cs" />
<Compile Include="Logging\LogLevelExtensions.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="Tls\ClientTlsSettings.cs" />
<Compile Include="Tls\NotSslRecordException.cs" />
<Compile Include="Tls\ServerTlsSettings.cs" />
<Compile Include="Tls\TlsHandshakeCompletionEvent.cs" />
<Compile Include="Tls\TlsHandler.cs" />
<Compile Include="Timeout\IdleState.cs" />
Expand All @@ -58,6 +60,7 @@
<Compile Include="Timeout\WriteTimeoutException.cs" />
<Compile Include="Timeout\ReadTimeoutHandler.cs" />
<Compile Include="Timeout\WriteTimeoutHandler.cs" />
<Compile Include="Tls\TlsSettings.cs" />
<Compile Include="Tls\TlsUtils.cs" />
</ItemGroup>
<ItemGroup>
Expand Down
44 changes: 44 additions & 0 deletions src/DotNetty.Handlers/Tls/ClientTlsSettings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace DotNetty.Handlers.Tls
{
using System.Collections.Generic;
using System.Linq;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;

public sealed class ClientTlsSettings : TlsSettings
{
IReadOnlyCollection<X509Certificate2> certificates;

public ClientTlsSettings(string targetHost)
: this(targetHost, new List<X509Certificate>())
{
}

public ClientTlsSettings(string targetHost, List<X509Certificate> certificates)
: this(false, certificates, targetHost)
{
}

public ClientTlsSettings(bool checkCertificateRevocation, List<X509Certificate> certificates, string targetHost)
: this(SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, checkCertificateRevocation, certificates, targetHost)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we remove TLS 1.0 from here by default as well already?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to do a breaking change. Let's do that in version 0.4.0

{
}

public ClientTlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation, List<X509Certificate> certificates, string targetHost)
:base(enabledProtocols, checkCertificateRevocation)
{
this.X509CertificateCollection = new X509CertificateCollection(certificates.ToArray());
this.TargetHost = targetHost;
this.Certificates = certificates.AsReadOnly();
}

internal X509CertificateCollection X509CertificateCollection { get; set; }

public IReadOnlyCollection<X509Certificate> Certificates { get; }

public string TargetHost { get; }
}
}
29 changes: 29 additions & 0 deletions src/DotNetty.Handlers/Tls/ServerTlsSettings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace DotNetty.Handlers.Tls
{
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;

public sealed class ServerTlsSettings : TlsSettings
{
public ServerTlsSettings(X509Certificate certificate)
: this(false, certificate)
{
}

public ServerTlsSettings(bool checkCertificateRevocation, X509Certificate certificate)
: this(SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, checkCertificateRevocation, certificate)
{
}

public ServerTlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation, X509Certificate certificate)
: base(enabledProtocols, checkCertificateRevocation)
{
this.Certificate = certificate;
}

public X509Certificate Certificate { get; }
}
}
71 changes: 33 additions & 38 deletions src/DotNetty.Handlers/Tls/TlsHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ namespace DotNetty.Handlers.Tls
using System.IO;
using System.Net.Security;
using System.Runtime.ExceptionServices;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -21,48 +20,47 @@ namespace DotNetty.Handlers.Tls

public sealed class TlsHandler : ByteToMessageDecoder
{
readonly TlsSettings settings;
const int FallbackReadBufferSize = 256;
const int UnencryptedWriteBatchSize = 14 * 1024;

static readonly Exception ChannelClosedException = new IOException("Channel is closed");
static readonly Action<Task, object> HandshakeCompletionCallback = new Action<Task, object>(HandleHandshakeCompleted);

readonly SslStream sslStream;
readonly MediationStream mediationStream;
readonly TaskCompletionSource closeFuture;

TlsHandlerState state;
int packetLength;
readonly MediationStream mediationStream;
volatile IChannelHandlerContext capturedContext;
BatchingPendingWriteQueue pendingUnencryptedWrites;
Task lastContextWriteTask;
readonly TaskCompletionSource closeFuture;
readonly bool isServer;
readonly X509Certificate2 certificate;
readonly string targetHost;
bool firedChannelRead;
IByteBuffer pendingSslStreamReadBuffer;
Task<int> pendingSslStreamReadFuture;

TlsHandler(bool isServer, X509Certificate2 certificate, string targetHost, RemoteCertificateValidationCallback certificateValidationCallback)
public TlsHandler(TlsSettings settings)
: this(stream => new SslStream(stream, true), settings)
{
Contract.Requires(!isServer || certificate != null);
Contract.Requires(isServer || !string.IsNullOrEmpty(targetHost));
}

this.closeFuture = new TaskCompletionSource();
public TlsHandler(Func<Stream, SslStream> sslStreamFactory, TlsSettings settings)
{
Contract.Requires(sslStreamFactory != null);
Contract.Requires(settings != null);

this.isServer = isServer;
this.certificate = certificate;
this.targetHost = targetHost;
this.settings = settings;
this.closeFuture = new TaskCompletionSource();
this.mediationStream = new MediationStream(this);
this.sslStream = new SslStream(this.mediationStream, true, certificateValidationCallback);
this.sslStream = sslStreamFactory(this.mediationStream);
}

public static TlsHandler Client(string targetHost) => new TlsHandler(false, null, targetHost, null);

public static TlsHandler Client(string targetHost, X509Certificate2 certificate) => new TlsHandler(false, certificate, targetHost, null);
public static TlsHandler Client(string targetHost) => new TlsHandler(new ClientTlsSettings(targetHost));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we leave these for common cases and internalize sslStreamFactory in such a case?


public static TlsHandler Client(string targetHost, X509Certificate2 certificate, RemoteCertificateValidationCallback certificateValidationCallback) => new TlsHandler(false, certificate, targetHost, certificateValidationCallback);

public static TlsHandler Server(X509Certificate2 certificate) => new TlsHandler(true, certificate, null, null);
public static TlsHandler Client(string targetHost, X509Certificate clientCertificate) => new TlsHandler(new ClientTlsSettings(targetHost, new List<X509Certificate>{ clientCertificate }));
public static TlsHandler Server(X509Certificate certificate) => new TlsHandler(new ServerTlsSettings(certificate));

public X509Certificate LocalCertificate => this.sslStream.LocalCertificate;

Expand All @@ -74,7 +72,7 @@ public override void ChannelActive(IChannelHandlerContext context)
{
base.ChannelActive(context);

if (!this.isServer)
if (this.settings is ServerTlsSettings)
{
this.EnsureAuthenticated();
}
Expand Down Expand Up @@ -161,7 +159,7 @@ public override void HandlerAdded(IChannelHandlerContext context)
base.HandlerAdded(context);
this.capturedContext = context;
this.pendingUnencryptedWrites = new BatchingPendingWriteQueue(context, UnencryptedWriteBatchSize);
if (context.Channel.Active && !this.isServer)
if (context.Channel.Active && this.settings is ClientTlsSettings)
{
// todo: support delayed initialization on an existing/active channel if in client mode
this.EnsureAuthenticated();
Expand Down Expand Up @@ -217,23 +215,23 @@ protected override void Decode(IChannelHandlerContext context, IByteBuffer input
break;
}

int packetLength = TlsUtils.GetEncryptedPacketLength(input, offset);
if (packetLength == -1)
int encryptedPacketLength = TlsUtils.GetEncryptedPacketLength(input, offset);
if (encryptedPacketLength == -1)
{
nonSslRecord = true;
break;
}

Contract.Assert(packetLength > 0);
Contract.Assert(encryptedPacketLength > 0);

if (packetLength > readableBytes)
if (encryptedPacketLength > readableBytes)
{
// wait until the whole packet can be read
this.packetLength = packetLength;
this.packetLength = encryptedPacketLength;
break;
}

int newTotalLength = totalLength + packetLength;
int newTotalLength = totalLength + encryptedPacketLength;
if (newTotalLength > TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH)
{
// Don't read too much.
Expand All @@ -245,8 +243,8 @@ protected override void Decode(IChannelHandlerContext context, IByteBuffer input

// We have a whole packet.
// Increment the offset to handle the next packet.
packetLengths.Add(packetLength);
offset += packetLength;
packetLengths.Add(encryptedPacketLength);
offset += encryptedPacketLength;
totalLength = newTotalLength;
}

Expand Down Expand Up @@ -482,19 +480,16 @@ bool EnsureAuthenticated()
if (!oldState.HasAny(TlsHandlerState.AuthenticationStarted))
{
this.state = oldState | TlsHandlerState.Authenticating;
if (this.isServer)
var serverSettings = settings as ServerTlsSettings;
if (serverSettings != null)
{
this.sslStream.AuthenticateAsServerAsync(this.certificate, false, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false) // todo: change to begin/end
this.sslStream.AuthenticateAsServerAsync(serverSettings.Certificate, false, serverSettings.EnabledProtocols, serverSettings.CheckCertificateRevocation)
.ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously);
}
else
{
var certificateCollection = new X509Certificate2Collection();
if (this.certificate != null)
{
certificateCollection.Add(this.certificate);
}
this.sslStream.AuthenticateAsClientAsync(this.targetHost, certificateCollection, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false) // todo: change to begin/end
var clientSettings = (ClientTlsSettings)settings;
this.sslStream.AuthenticateAsClientAsync(clientSettings.TargetHost, clientSettings.X509CertificateCollection, clientSettings.EnabledProtocols, clientSettings.CheckCertificateRevocation)
.ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously);
}
return false;
Expand Down
20 changes: 20 additions & 0 deletions src/DotNetty.Handlers/Tls/TlsSettings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace DotNetty.Handlers.Tls
{
using System.Security.Authentication;

public abstract class TlsSettings
{
protected TlsSettings(SslProtocols enabledProtocols, bool checkCertificateRevocation)
{
this.EnabledProtocols = enabledProtocols;
this.CheckCertificateRevocation = checkCertificateRevocation;
}

public SslProtocols EnabledProtocols { get; }

public bool CheckCertificateRevocation { get; }
}
}
4 changes: 3 additions & 1 deletion test/DotNetty.Handlers.Tests/TlsHandlerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ static async Task<Tuple<EmbeddedChannel, SslStream>> SetupStreamAndChannelAsync(
{
var tlsCertificate = new X509Certificate2("dotnetty.com.pfx", "password");
string targetHost = tlsCertificate.GetNameInfo(X509NameType.DnsName, false);
TlsHandler tlsHandler = isClient ? TlsHandler.Client(targetHost, null, (_1, _2, _3, _4) => true) : TlsHandler.Server(tlsCertificate);
TlsHandler tlsHandler = isClient ?
new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ClientTlsSettings(targetHost)) :
TlsHandler.Server(tlsCertificate);
//var ch = new EmbeddedChannel(new LoggingHandler("BEFORE"), tlsHandler, new LoggingHandler("AFTER"));
var ch = new EmbeddedChannel(tlsHandler);

Expand Down
Loading