Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TLS end to end unit test #1

Merged
merged 2 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/TurboMqtt/Client/IMqttClientFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ public MqttClientFactory(ActorSystem system)
public async Task<IMqttClient> CreateTcpClient(MqttClientConnectOptions options, MqttClientTcpOptions tcpOptions)
{
AssertMqtt311(options);
if (tcpOptions.TlsOptions is { UseTls: true, SslOptions: null })
throw new NullReferenceException("TlsOptions.SslOptions can not be null if TlsOptions.UseTls is true");
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM


var transportManager = new TcpMqttTransportManager(tcpOptions, _mqttClientManager, options.ProtocolVersion);

// create the client
Expand Down
27 changes: 18 additions & 9 deletions src/TurboMqtt/IO/Tcp/FakeMqttTcpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

namespace TurboMqtt.IO.Tcp;

internal sealed class MqttTcpServerOptions
public sealed record MqttTcpServerOptions
{
public MqttTcpServerOptions(string host, int port)
{
Expand All @@ -26,18 +26,18 @@ public MqttTcpServerOptions(string host, int port)
/// <summary>
/// Would love to just do IPV6, but that still meets resistance everywhere
/// </summary>
public AddressFamily AddressFamily { get; set; } = AddressFamily.Unspecified;
public AddressFamily AddressFamily { get; init; } = AddressFamily.Unspecified;

/// <summary>
/// Frames are limited to this size in bytes. A frame can contain multiple packets.
/// </summary>
public int MaxFrameSize { get; set; } = 128 * 1024; // 128kb
public int MaxFrameSize { get; init; } = 128 * 1024; // 128kb

public string Host { get; }
public string Host { get; init; }

public int Port { get; }
public int Port { get; init; }

public SslServerAuthenticationOptions? SslOptions { get; set; }
public SslServerAuthenticationOptions? SslOptions { get; init; }
}

/// <summary>
Expand Down Expand Up @@ -162,9 +162,18 @@ private async Task BeginAcceptAsync()
// check for TLS
if (_options.SslOptions != null)
{
var sslStream = new SslStream(readingStream, false);
await sslStream.AuthenticateAsServerAsync(_options.SslOptions, _shutdownTcs.Token);
readingStream = sslStream;
try
{
var sslStream = new SslStream(readingStream, false);
readingStream = sslStream;
await sslStream.AuthenticateAsServerAsync(_options.SslOptions, _shutdownTcs.Token);
_log.Info("Server authenticated successfully");
}
catch (Exception ex)
{
_log.Error(ex, "Exception during authentication");
throw;
}
}

_ = ProcessClientAsync(readingStream);
Expand Down
9 changes: 9 additions & 0 deletions src/TurboMqtt/IO/Tcp/TcpTransportActor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Channels;
using Akka.Actor;
using Akka.Event;
Expand Down Expand Up @@ -226,6 +227,14 @@ private async Task DoConnectAsync(IPAddress[] addresses, int port, IActorRef des
await _tcpClient.ConnectAsync(addresses, port, ct).ConfigureAwait(false);
connectResult = new ConnectResult(ConnectionStatus.Connected, "Connected.");
_tcpStream = new NetworkStream(_tcpClient, true);

// Check for TLS
if (TcpOptions.TlsOptions.UseTls)
{
var sslStream = new SslStream(_tcpStream, false);
_tcpStream = sslStream;
await sslStream.AuthenticateAsClientAsync(TcpOptions.TlsOptions.SslOptions!, ct);
}
}
catch (Exception ex)
{
Expand Down
24 changes: 12 additions & 12 deletions tests/TurboMqtt.Tests/End2End/MQTT_311/TcpMqtt311End2EndSpecs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

namespace TurboMqtt.Tests.End2End;

[CollectionDefinition(nameof(TcpEnd2EndCollection))]
public sealed class TcpEnd2EndCollection;

[Collection(nameof(TcpEnd2EndCollection))]
public class TcpMqtt311End2EndSpecs : TransportSpecBase
{
public static readonly Config DebugLogging = """
Expand All @@ -26,7 +30,7 @@ public TcpMqtt311End2EndSpecs(ITestOutputHelper output) : base(output: output, c
{
var logger = new BusLogging(Sys.EventStream, "FakeMqttTcpServer", typeof(FakeMqttTcpServer),
Sys.Settings.LogFormatter);
_server = new FakeMqttTcpServer(new MqttTcpServerOptions("localhost", 21883), MqttProtocolVersion.V3_1_1,
_server = new FakeMqttTcpServer(DefaultTcpServerOptions, MqttProtocolVersion.V3_1_1,
logger, TimeSpan.Zero, new DefaultFakeServerHandleFactory());
_server.Bind();
}
Expand All @@ -39,7 +43,8 @@ public override async Task<IMqttClient> CreateClient()
return client;
}

public MqttClientTcpOptions DefaultTcpOptions => new("localhost", 21883);
protected virtual MqttClientTcpOptions DefaultTcpOptions => new("localhost", 21883);
protected virtual MqttTcpServerOptions DefaultTcpServerOptions => new("localhost", 21883);

protected override void AfterAll()
{
Expand Down Expand Up @@ -128,7 +133,7 @@ public async Task ShouldReconnectSuccessfullyIfReconnectFlowFailed()
_server.Shutdown();

var server = new FakeMqttTcpServer(
options: new MqttTcpServerOptions("localhost", 21883),
options: DefaultTcpServerOptions,
version: MqttProtocolVersion.V3_1_1,
log: Log,
heartbeatDelay: TimeSpan.Zero,
Expand Down Expand Up @@ -259,10 +264,7 @@ public async Task ShouldTerminateClientAfterMultipleFailedConnectionAttempts()
[Fact]
public async Task ShouldFailToConnectToNonExistentServer()
{
var updatedTcpOptions = new MqttClientTcpOptions("localhost", 21884)
{
MaxReconnectAttempts = 0
};
var updatedTcpOptions = DefaultTcpOptions with { Port = 21884, MaxReconnectAttempts = 0 };
var client = await ClientFactory.CreateTcpClient(DefaultConnectOptions, updatedTcpOptions);

// we are going to do this, intentionally, without a CTS here - this operation MUST FAIL if we are unable to connect
Expand All @@ -275,10 +277,7 @@ public async Task ShouldFailToConnectToNonExistentServer()
[Fact]
public async Task ShouldSuccessFullyConnectWhenBrokerAvailableAfterFailedConnectionAttempt()
{
var updatedTcpOptions = new MqttClientTcpOptions("localhost", 21889)
{
MaxReconnectAttempts = 0
};
var updatedTcpOptions = DefaultTcpOptions with { Port = 21889, MaxReconnectAttempts = 0 };
var client = await ClientFactory.CreateTcpClient(DefaultConnectOptions, updatedTcpOptions);

// we are going to do this, intentionally, without a CTS here - this operation MUST FAIL if we are unable to connect
Expand All @@ -287,8 +286,9 @@ public async Task ShouldSuccessFullyConnectWhenBrokerAvailableAfterFailedConnect

client.IsConnected.Should().BeFalse();

var updatedServerOptions = DefaultTcpServerOptions with { Port = 21889 };
// start up a new server
var newServer = new FakeMqttTcpServer(new MqttTcpServerOptions("localhost", 21889), MqttProtocolVersion.V3_1_1,
var newServer = new FakeMqttTcpServer(updatedServerOptions, MqttProtocolVersion.V3_1_1,
Sys.Log, TimeSpan.Zero, new DefaultFakeServerHandleFactory());
try
{
Expand Down
137 changes: 137 additions & 0 deletions tests/TurboMqtt.Tests/End2End/MQTT_311/TlsMqtt311End2EndSpecs.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// -----------------------------------------------------------------------
// <copyright file="TlsMqtt311End2EndSpecs.cs" company="Petabridge, LLC">
// Copyright (C) 2024 - 2024 Petabridge, LLC <https://petabridge.com>
// </copyright>
// -----------------------------------------------------------------------

using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using Akka.Event;
using TurboMqtt.Client;
using TurboMqtt.IO.Tcp;
using Xunit.Abstractions;

namespace TurboMqtt.Tests.End2End;

[Collection(nameof(TcpEnd2EndCollection))]
public class TlsMqtt311End2EndSpecs : TcpMqtt311End2EndSpecs
{
// This is a workaround for this issue:
// https://github.com/dotnet/runtime/issues/23749
private static readonly X509Certificate2 RootCert = new (
X509Certificate2.CreateFromEncryptedPemFile("./certs/root_cert.pem", "password")
.Export(X509ContentType.Pkcs12));

private static readonly X509ChainPolicy RootChainPolicy = new()
{
CustomTrustStore = { RootCert },
TrustMode = X509ChainTrustMode.CustomRootTrust,
RevocationMode = X509RevocationMode.NoCheck
};

private static readonly X509Chain RootChain = new ()
{
ChainPolicy = RootChainPolicy
};

public TlsMqtt311End2EndSpecs(ITestOutputHelper output) : base(output)
{
}

// This is a workaround for this issue:
// https://github.com/dotnet/runtime/issues/23749
private static readonly X509Certificate2 ServerCert = new X509Certificate2(
X509Certificate2.CreateFromEncryptedPemFile("./certs/server_cert.pem", "password")
.Export(X509ContentType.Pkcs12));

protected override MqttTcpServerOptions DefaultTcpServerOptions => new ("localhost", 21883)
{
SslOptions = new SslServerAuthenticationOptions
{
ServerCertificate = ServerCert,
ClientCertificateRequired = false,
RemoteCertificateValidationCallback = ValidateClientCertificate
}
};

private bool ValidateClientCertificate(
object sender,
X509Certificate? certificate,
X509Chain? chain,
SslPolicyErrors sslPolicyErrors)
{
if (sslPolicyErrors == SslPolicyErrors.None)
return true;

// Return true if client certificate is not required
if (sslPolicyErrors == SslPolicyErrors.RemoteCertificateNotAvailable)
return true;

// Validate client certificate with a custom chain
if (certificate is not null)
{
var isValid = RootChain.Build(new X509Certificate2(certificate));
if (!isValid)
{
foreach (var status in RootChain.ChainStatus)
{
Log.Error("[Server] Chain error: {0}", status.StatusInformation);
}
}

return isValid;
}

// Refuse everything else
Log.Error("[Server] Certificate error: {0}", sslPolicyErrors);
return false;
}

private bool ValidateServerCertificate(
object sender,
X509Certificate? certificate,
X509Chain? chain,
SslPolicyErrors errors)
{
if (errors == SslPolicyErrors.None)
return true;

// Missing cert or the destination hostname wasn't valid for the cert.
if ((errors & ~SslPolicyErrors.RemoteCertificateChainErrors) != 0)
return false;

// Validate client certificate with a custom chain
if (certificate is not null)
{
chain ??= RootChain;
var isValid = chain.Build(new X509Certificate2(certificate));
if (!isValid)
{
foreach (var status in chain.ChainStatus)
{
Log.Error("[Client] Chain error: [{0}] {1}", status.Status, status.StatusInformation);
}
}

return isValid;
}

// Refuse everything else
Log.Error("[Client] Certificate error: {0}", errors);
return false;
}

protected override MqttClientTcpOptions DefaultTcpOptions => new("localhost", 21883)
{
TlsOptions = new ClientTlsOptions
{
UseTls = true,
SslOptions = new SslClientAuthenticationOptions
{
TargetHost = "localhost",
CertificateChainPolicy = RootChainPolicy,
RemoteCertificateValidationCallback = ValidateServerCertificate
}
}
};
}
14 changes: 14 additions & 0 deletions tests/TurboMqtt.Tests/TurboMqtt.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,18 @@
<ProjectReference Include="..\..\src\TurboMqtt\TurboMqtt.csproj" />
</ItemGroup>

<ItemGroup>
<None Include="..\certs\root_cert.pem">
<Link>certs\root_cert.pem</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>

<ItemGroup>
<None Include="..\certs\server_cert.pem">
<Link>certs\server_cert.pem</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>

</Project>
30 changes: 30 additions & 0 deletions tests/certs/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Generating Test Certificates

* Use WSL to access OpenSSL (the keys in this folder was generated using Ubuntu)
* All private key files uses the password "password"

## 1. Generate Root CA

```
openssl req -x509 -new -nodes -key root_private_key.pem -sha256 -days 3650 -out root_cert.pem -config root_cert_config.cnf
```

* This will generate the **"root_cert.pem"** that is valid for 10 years.
* Open the **"root_private_key.pem"** and copy-paste its content to the end of the **"root_cert.pem"** file.

## 2. Generate CSR

```
openssl req -new -key server_private_key.pem -out server.csr -config server_cert_config.cnf
```

* This generates the **"server.csr"** file.

## 3. Generate The Server Certificate

```
openssl x509 -req -in server.csr -CA root_cert.pem -CAkey root_private_key.pem -CAcreateserial -out server_cert.pem -days 365 -sha256 -extfile v3_ext.cnf
```

* This generates the **"server_cert.pem"** file.
* Open the **"server_private_key.pem"** and copy-paste its content to the end of the **"server_cert.pem"** file.
22 changes: 0 additions & 22 deletions tests/certs/certificate.pem

This file was deleted.

18 changes: 0 additions & 18 deletions tests/certs/csr.pem

This file was deleted.

Loading