Skip to content

Commit

Permalink
Add TLS end to end unit test (#1)
Browse files Browse the repository at this point in the history
* Add TLS end to end unit test

* Make TLS spec inherit TCP spec
  • Loading branch information
Arkatufus authored May 31, 2024
1 parent 9167ca2 commit 3b53dee
Show file tree
Hide file tree
Showing 18 changed files with 451 additions and 113 deletions.
3 changes: 3 additions & 0 deletions src/TurboMqtt/Client/IMqttClientFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ public MqttClientFactory(ActorSystem system)
public async Task<IMqttClient> CreateTcpClient(MqttClientConnectOptions options, MqttClientTcpOptions tcpOptions)
{
AssertMqtt311(options);
if (tcpOptions.TlsOptions is { UseTls: true, SslOptions: null })
throw new NullReferenceException("TlsOptions.SslOptions can not be null if TlsOptions.UseTls is true");

var transportManager = new TcpMqttTransportManager(tcpOptions, _mqttClientManager, options.ProtocolVersion);

// create the client
Expand Down
27 changes: 18 additions & 9 deletions src/TurboMqtt/IO/Tcp/FakeMqttTcpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

namespace TurboMqtt.IO.Tcp;

internal sealed class MqttTcpServerOptions
public sealed record MqttTcpServerOptions
{
public MqttTcpServerOptions(string host, int port)
{
Expand All @@ -26,18 +26,18 @@ public MqttTcpServerOptions(string host, int port)
/// <summary>
/// Would love to just do IPV6, but that still meets resistance everywhere
/// </summary>
public AddressFamily AddressFamily { get; set; } = AddressFamily.Unspecified;
public AddressFamily AddressFamily { get; init; } = AddressFamily.Unspecified;

/// <summary>
/// Frames are limited to this size in bytes. A frame can contain multiple packets.
/// </summary>
public int MaxFrameSize { get; set; } = 128 * 1024; // 128kb
public int MaxFrameSize { get; init; } = 128 * 1024; // 128kb

public string Host { get; }
public string Host { get; init; }

public int Port { get; }
public int Port { get; init; }

public SslServerAuthenticationOptions? SslOptions { get; set; }
public SslServerAuthenticationOptions? SslOptions { get; init; }
}

/// <summary>
Expand Down Expand Up @@ -162,9 +162,18 @@ private async Task BeginAcceptAsync()
// check for TLS
if (_options.SslOptions != null)
{
var sslStream = new SslStream(readingStream, false);
await sslStream.AuthenticateAsServerAsync(_options.SslOptions, _shutdownTcs.Token);
readingStream = sslStream;
try
{
var sslStream = new SslStream(readingStream, false);
readingStream = sslStream;
await sslStream.AuthenticateAsServerAsync(_options.SslOptions, _shutdownTcs.Token);
_log.Info("Server authenticated successfully");
}
catch (Exception ex)
{
_log.Error(ex, "Exception during authentication");
throw;
}
}

_ = ProcessClientAsync(readingStream);
Expand Down
9 changes: 9 additions & 0 deletions src/TurboMqtt/IO/Tcp/TcpTransportActor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Channels;
using Akka.Actor;
using Akka.Event;
Expand Down Expand Up @@ -226,6 +227,14 @@ private async Task DoConnectAsync(IPAddress[] addresses, int port, IActorRef des
await _tcpClient.ConnectAsync(addresses, port, ct).ConfigureAwait(false);
connectResult = new ConnectResult(ConnectionStatus.Connected, "Connected.");
_tcpStream = new NetworkStream(_tcpClient, true);

// Check for TLS
if (TcpOptions.TlsOptions.UseTls)
{
var sslStream = new SslStream(_tcpStream, false);
_tcpStream = sslStream;
await sslStream.AuthenticateAsClientAsync(TcpOptions.TlsOptions.SslOptions!, ct);
}
}
catch (Exception ex)
{
Expand Down
24 changes: 12 additions & 12 deletions tests/TurboMqtt.Tests/End2End/MQTT_311/TcpMqtt311End2EndSpecs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

namespace TurboMqtt.Tests.End2End;

[CollectionDefinition(nameof(TcpEnd2EndCollection))]
public sealed class TcpEnd2EndCollection;

[Collection(nameof(TcpEnd2EndCollection))]
public class TcpMqtt311End2EndSpecs : TransportSpecBase
{
public static readonly Config DebugLogging = """
Expand All @@ -26,7 +30,7 @@ public TcpMqtt311End2EndSpecs(ITestOutputHelper output) : base(output: output, c
{
var logger = new BusLogging(Sys.EventStream, "FakeMqttTcpServer", typeof(FakeMqttTcpServer),
Sys.Settings.LogFormatter);
_server = new FakeMqttTcpServer(new MqttTcpServerOptions("localhost", 21883), MqttProtocolVersion.V3_1_1,
_server = new FakeMqttTcpServer(DefaultTcpServerOptions, MqttProtocolVersion.V3_1_1,
logger, TimeSpan.Zero, new DefaultFakeServerHandleFactory());
_server.Bind();
}
Expand All @@ -39,7 +43,8 @@ public override async Task<IMqttClient> CreateClient()
return client;
}

public MqttClientTcpOptions DefaultTcpOptions => new("localhost", 21883);
protected virtual MqttClientTcpOptions DefaultTcpOptions => new("localhost", 21883);
protected virtual MqttTcpServerOptions DefaultTcpServerOptions => new("localhost", 21883);

protected override void AfterAll()
{
Expand Down Expand Up @@ -128,7 +133,7 @@ public async Task ShouldReconnectSuccessfullyIfReconnectFlowFailed()
_server.Shutdown();

var server = new FakeMqttTcpServer(
options: new MqttTcpServerOptions("localhost", 21883),
options: DefaultTcpServerOptions,
version: MqttProtocolVersion.V3_1_1,
log: Log,
heartbeatDelay: TimeSpan.Zero,
Expand Down Expand Up @@ -259,10 +264,7 @@ public async Task ShouldTerminateClientAfterMultipleFailedConnectionAttempts()
[Fact]
public async Task ShouldFailToConnectToNonExistentServer()
{
var updatedTcpOptions = new MqttClientTcpOptions("localhost", 21884)
{
MaxReconnectAttempts = 0
};
var updatedTcpOptions = DefaultTcpOptions with { Port = 21884, MaxReconnectAttempts = 0 };
var client = await ClientFactory.CreateTcpClient(DefaultConnectOptions, updatedTcpOptions);

// we are going to do this, intentionally, without a CTS here - this operation MUST FAIL if we are unable to connect
Expand All @@ -275,10 +277,7 @@ public async Task ShouldFailToConnectToNonExistentServer()
[Fact]
public async Task ShouldSuccessFullyConnectWhenBrokerAvailableAfterFailedConnectionAttempt()
{
var updatedTcpOptions = new MqttClientTcpOptions("localhost", 21889)
{
MaxReconnectAttempts = 0
};
var updatedTcpOptions = DefaultTcpOptions with { Port = 21889, MaxReconnectAttempts = 0 };
var client = await ClientFactory.CreateTcpClient(DefaultConnectOptions, updatedTcpOptions);

// we are going to do this, intentionally, without a CTS here - this operation MUST FAIL if we are unable to connect
Expand All @@ -287,8 +286,9 @@ public async Task ShouldSuccessFullyConnectWhenBrokerAvailableAfterFailedConnect

client.IsConnected.Should().BeFalse();

var updatedServerOptions = DefaultTcpServerOptions with { Port = 21889 };
// start up a new server
var newServer = new FakeMqttTcpServer(new MqttTcpServerOptions("localhost", 21889), MqttProtocolVersion.V3_1_1,
var newServer = new FakeMqttTcpServer(updatedServerOptions, MqttProtocolVersion.V3_1_1,
Sys.Log, TimeSpan.Zero, new DefaultFakeServerHandleFactory());
try
{
Expand Down
137 changes: 137 additions & 0 deletions tests/TurboMqtt.Tests/End2End/MQTT_311/TlsMqtt311End2EndSpecs.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// -----------------------------------------------------------------------
// <copyright file="TlsMqtt311End2EndSpecs.cs" company="Petabridge, LLC">
// Copyright (C) 2024 - 2024 Petabridge, LLC <https://petabridge.com>
// </copyright>
// -----------------------------------------------------------------------

using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using Akka.Event;
using TurboMqtt.Client;
using TurboMqtt.IO.Tcp;
using Xunit.Abstractions;

namespace TurboMqtt.Tests.End2End;

[Collection(nameof(TcpEnd2EndCollection))]
public class TlsMqtt311End2EndSpecs : TcpMqtt311End2EndSpecs
{
// This is a workaround for this issue:
// https://github.com/dotnet/runtime/issues/23749
private static readonly X509Certificate2 RootCert = new (
X509Certificate2.CreateFromEncryptedPemFile("./certs/root_cert.pem", "password")
.Export(X509ContentType.Pkcs12));

private static readonly X509ChainPolicy RootChainPolicy = new()
{
CustomTrustStore = { RootCert },
TrustMode = X509ChainTrustMode.CustomRootTrust,
RevocationMode = X509RevocationMode.NoCheck
};

private static readonly X509Chain RootChain = new ()
{
ChainPolicy = RootChainPolicy
};

public TlsMqtt311End2EndSpecs(ITestOutputHelper output) : base(output)
{
}

// This is a workaround for this issue:
// https://github.com/dotnet/runtime/issues/23749
private static readonly X509Certificate2 ServerCert = new X509Certificate2(
X509Certificate2.CreateFromEncryptedPemFile("./certs/server_cert.pem", "password")
.Export(X509ContentType.Pkcs12));

protected override MqttTcpServerOptions DefaultTcpServerOptions => new ("localhost", 21883)
{
SslOptions = new SslServerAuthenticationOptions
{
ServerCertificate = ServerCert,
ClientCertificateRequired = false,
RemoteCertificateValidationCallback = ValidateClientCertificate
}
};

private bool ValidateClientCertificate(
object sender,
X509Certificate? certificate,
X509Chain? chain,
SslPolicyErrors sslPolicyErrors)
{
if (sslPolicyErrors == SslPolicyErrors.None)
return true;

// Return true if client certificate is not required
if (sslPolicyErrors == SslPolicyErrors.RemoteCertificateNotAvailable)
return true;

// Validate client certificate with a custom chain
if (certificate is not null)
{
var isValid = RootChain.Build(new X509Certificate2(certificate));
if (!isValid)
{
foreach (var status in RootChain.ChainStatus)
{
Log.Error("[Server] Chain error: {0}", status.StatusInformation);
}
}

return isValid;
}

// Refuse everything else
Log.Error("[Server] Certificate error: {0}", sslPolicyErrors);
return false;
}

private bool ValidateServerCertificate(
object sender,
X509Certificate? certificate,
X509Chain? chain,
SslPolicyErrors errors)
{
if (errors == SslPolicyErrors.None)
return true;

// Missing cert or the destination hostname wasn't valid for the cert.
if ((errors & ~SslPolicyErrors.RemoteCertificateChainErrors) != 0)
return false;

// Validate client certificate with a custom chain
if (certificate is not null)
{
chain ??= RootChain;
var isValid = chain.Build(new X509Certificate2(certificate));
if (!isValid)
{
foreach (var status in chain.ChainStatus)
{
Log.Error("[Client] Chain error: [{0}] {1}", status.Status, status.StatusInformation);
}
}

return isValid;
}

// Refuse everything else
Log.Error("[Client] Certificate error: {0}", errors);
return false;
}

protected override MqttClientTcpOptions DefaultTcpOptions => new("localhost", 21883)
{
TlsOptions = new ClientTlsOptions
{
UseTls = true,
SslOptions = new SslClientAuthenticationOptions
{
TargetHost = "localhost",
CertificateChainPolicy = RootChainPolicy,
RemoteCertificateValidationCallback = ValidateServerCertificate
}
}
};
}
14 changes: 14 additions & 0 deletions tests/TurboMqtt.Tests/TurboMqtt.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,18 @@
<ProjectReference Include="..\..\src\TurboMqtt\TurboMqtt.csproj" />
</ItemGroup>

<ItemGroup>
<None Include="..\certs\root_cert.pem">
<Link>certs\root_cert.pem</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>

<ItemGroup>
<None Include="..\certs\server_cert.pem">
<Link>certs\server_cert.pem</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>

</Project>
30 changes: 30 additions & 0 deletions tests/certs/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Generating Test Certificates

* Use WSL to access OpenSSL (the keys in this folder was generated using Ubuntu)
* All private key files uses the password "password"

## 1. Generate Root CA

```
openssl req -x509 -new -nodes -key root_private_key.pem -sha256 -days 3650 -out root_cert.pem -config root_cert_config.cnf
```

* This will generate the **"root_cert.pem"** that is valid for 10 years.
* Open the **"root_private_key.pem"** and copy-paste its content to the end of the **"root_cert.pem"** file.

## 2. Generate CSR

```
openssl req -new -key server_private_key.pem -out server.csr -config server_cert_config.cnf
```

* This generates the **"server.csr"** file.

## 3. Generate The Server Certificate

```
openssl x509 -req -in server.csr -CA root_cert.pem -CAkey root_private_key.pem -CAcreateserial -out server_cert.pem -days 365 -sha256 -extfile v3_ext.cnf
```

* This generates the **"server_cert.pem"** file.
* Open the **"server_private_key.pem"** and copy-paste its content to the end of the **"server_cert.pem"** file.
22 changes: 0 additions & 22 deletions tests/certs/certificate.pem

This file was deleted.

18 changes: 0 additions & 18 deletions tests/certs/csr.pem

This file was deleted.

Loading

0 comments on commit 3b53dee

Please sign in to comment.