Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stubbing out SSL support for MQTT clients #123

Open
wants to merge 13 commits into
base: dev
Choose a base branch
from
Open
21 changes: 21 additions & 0 deletions src/TurboMqtt/Client/ClientTlsOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// -----------------------------------------------------------------------
// <copyright file="ICertificateProvider.cs" company="Petabridge, LLC">
// Copyright (C) 2024 - 2024 Petabridge, LLC <https://petabridge.com>
// </copyright>
// -----------------------------------------------------------------------

using System.Net.Security;

namespace TurboMqtt.Client;

/// <summary>
/// Used to provide the necessary certificates and keys for establishing a secure connection with the MQTT broker.
/// </summary>
public sealed record ClientTlsOptions
{
public static readonly ClientTlsOptions Default = new();

public bool UseTls { get; init; } = false;

public SslClientAuthenticationOptions? SslOptions { get; init; }
}
3 changes: 3 additions & 0 deletions src/TurboMqtt/Client/IMqttClientFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ public MqttClientFactory(ActorSystem system)
public async Task<IMqttClient> CreateTcpClient(MqttClientConnectOptions options, MqttClientTcpOptions tcpOptions)
{
AssertMqtt311(options);
if (tcpOptions.TlsOptions is { UseTls: true, SslOptions: null })
throw new NullReferenceException("TlsOptions.SslOptions can not be null if TlsOptions.UseTls is true");

var transportManager = new TcpMqttTransportManager(tcpOptions, _mqttClientManager, options.ProtocolVersion);

// create the client
Expand Down
20 changes: 13 additions & 7 deletions src/TurboMqtt/Client/MqttClientTcpOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// </copyright>
// -----------------------------------------------------------------------

using System.Net;
using System.Net.Sockets;

namespace TurboMqtt.Client;
Expand All @@ -23,26 +22,33 @@ public MqttClientTcpOptions(string host, int port)
/// <summary>
/// Would love to just do IPV6, but that still meets resistance everywhere
/// </summary>
public AddressFamily AddressFamily { get; set; } = AddressFamily.Unspecified;
public AddressFamily AddressFamily { get; init; } = AddressFamily.Unspecified;

/// <summary>
/// Frames are limited to this size in bytes. A frame can contain multiple packets.
/// </summary>
public int MaxFrameSize { get; set; } = 128 * 1024; // 128kb
public int MaxFrameSize { get; init; } = 128 * 1024; // 128kb

public string Host { get; }
public string Host { get; init; }

public int Port { get; }
public int Port { get; init; }

/// <summary>
/// How long should we wait before attempting to reconnect the client?
/// </summary>
public TimeSpan ReconnectInterval { get; set; } = TimeSpan.FromSeconds(5);
public TimeSpan ReconnectInterval { get; init; } = TimeSpan.FromSeconds(5);

/// <summary>
/// Maximum number of times we should attempt to reconnect the client before giving up.
///
/// Resets back to 0 after a successful connection.
/// </summary>
public int MaxReconnectAttempts { get; set; } = 10;
public int MaxReconnectAttempts { get; init; } = 10;

/// <summary>
/// The <see cref="ClientTlsOptions"/> to use when connecting to the server.
///
/// Disabled by default.
/// </summary>
public ClientTlsOptions TlsOptions { get; init; } = ClientTlsOptions.Default;
}
98 changes: 60 additions & 38 deletions src/TurboMqtt/IO/Tcp/FakeMqttTcpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
using System.Collections.Concurrent;
using System.IO.Pipelines;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using Akka.Event;
using TurboMqtt.Protocol;

namespace TurboMqtt.IO.Tcp;

internal sealed class MqttTcpServerOptions
public sealed record MqttTcpServerOptions
{
public MqttTcpServerOptions(string host, int port)
{
Expand All @@ -25,16 +26,18 @@ public MqttTcpServerOptions(string host, int port)
/// <summary>
/// Would love to just do IPV6, but that still meets resistance everywhere
/// </summary>
public AddressFamily AddressFamily { get; set; } = AddressFamily.Unspecified;
public AddressFamily AddressFamily { get; init; } = AddressFamily.Unspecified;

/// <summary>
/// Frames are limited to this size in bytes. A frame can contain multiple packets.
/// </summary>
public int MaxFrameSize { get; set; } = 128 * 1024; // 128kb
public int MaxFrameSize { get; init; } = 128 * 1024; // 128kb

public string Host { get; }
public string Host { get; init; }

public int Port { get; }
public int Port { get; init; }

public SslServerAuthenticationOptions? SslOptions { get; init; }
}

/// <summary>
Expand All @@ -47,7 +50,7 @@ internal sealed class FakeMqttTcpServer
private readonly CancellationTokenSource _shutdownTcs = new();
private readonly ILoggingAdapter _log;
private readonly ConcurrentDictionary<string, (CancellationTokenSource ct, Task shutdown)> _clientCts = new();
private readonly ConcurrentDictionary<string, Socket> _clientSockets = new();
private readonly ConcurrentDictionary<string, Stream> _clientSockets = new();
private readonly TimeSpan _heatBeatDelay;
private readonly IFakeServerHandleFactory _handleFactory;
private Socket? _bindSocket;
Expand Down Expand Up @@ -118,16 +121,16 @@ public bool TryKickClient(string clientId)

public bool TryDisconnectClientSocket(string clientId)
{
if (!_clientSockets.TryRemove(clientId, out var socket))
if (!_clientSockets.TryRemove(clientId, out var clientTcpStream))
return false;

if (!socket.Connected)
if (!clientTcpStream.CanRead)
return false;
socket.Disconnect(true);

clientTcpStream.Dispose();
return true;
}

public void Shutdown()
{
_log.Info("Shutting down server.");
Expand All @@ -154,7 +157,26 @@ private async Task BeginAcceptAsync()
while (!_shutdownTcs.IsCancellationRequested)
{
var socket = await _bindSocket!.AcceptAsync();
_ = ProcessClientAsync(socket);
Stream readingStream = new NetworkStream(socket, true);

// check for TLS
if (_options.SslOptions != null)
{
try
{
var sslStream = new SslStream(readingStream, false);
readingStream = sslStream;
await sslStream.AuthenticateAsServerAsync(_options.SslOptions, _shutdownTcs.Token);
_log.Info("Server authenticated successfully");
}
catch (Exception ex)
{
_log.Error(ex, "Exception during authentication");
throw;
}
}

_ = ProcessClientAsync(readingStream);
}
}

Expand All @@ -174,7 +196,7 @@ private static async Task ReadFromPipeAsync(PipeReader reader, IFakeServerHandle
// once we hand the message over to the end-user.
var newMemory = new Memory<byte>(new byte[buffer.Length]);
buffer.CopyTo(newMemory.Span);

handle.HandleBytes(newMemory);
}

Expand All @@ -193,7 +215,9 @@ private static async Task ReadFromPipeAsync(PipeReader reader, IFakeServerHandle
catch (Exception ex)
{
// junk exception that occurs during shutdown
handle.Log.Debug(ex, "Error advancing the reader with buffer size [{0}] with read result of [Completed={1}, Cancelled={2}]", buffer.Length, result.IsCompleted, result.IsCanceled);
handle.Log.Debug(ex,
"Error advancing the reader with buffer size [{0}] with read result of [Completed={1}, Cancelled={2}]",
buffer.Length, result.IsCompleted, result.IsCanceled);
return;
}
}
Expand All @@ -204,10 +228,10 @@ private static async Task ReadFromPipeAsync(PipeReader reader, IFakeServerHandle
}
}
}

private async Task ProcessClientAsync(Socket socket)
private async Task ProcessClientAsync(Stream stream)
{
using (socket)
await using (stream)
{
var closed = false;
var pipe = new Pipe(new PipeOptions(
Expand All @@ -217,18 +241,18 @@ private async Task ProcessClientAsync(Socket socket)
var clientShutdownCts = new CancellationTokenSource();
var linkedCts =
CancellationTokenSource.CreateLinkedTokenSource(clientShutdownCts.Token, _shutdownTcs.Token);

var handle = _handleFactory.CreateServerHandle(PushMessage, ClosingAction, _log, _version, _heatBeatDelay);

_ = handle.WhenClientIdAssigned.ContinueWith(t =>
{
if (t.IsCompletedSuccessfully)
{
_clientCts.TryAdd(t.Result, (clientShutdownCts, handle.WhenTerminated));
_clientSockets.TryAdd(t.Result, socket);
_clientSockets.TryAdd(t.Result, stream);
}
}, clientShutdownCts.Token);


_ = ReadFromPipeAsync(pipe.Reader, handle, linkedCts.Token);

Expand All @@ -239,7 +263,7 @@ private async Task ProcessClientAsync(Socket socket)
try
{
var memory = pipe.Writer.GetMemory(_options.MaxFrameSize / 4);
var bytesRead = await socket.ReceiveAsync(memory, SocketFlags.None, linkedCts.Token);
var bytesRead = await stream.ReadAsync(memory, linkedCts.Token);
if (bytesRead == 0)
{
_log.Info("Client {0} disconnected from server.",
Expand Down Expand Up @@ -287,7 +311,7 @@ private async Task ProcessClientAsync(Socket socket)

// ensure we've cleaned up all resources
await handle.WhenTerminated;

await pipe.Writer.CompleteAsync();
await pipe.Reader.CompleteAsync();

Expand All @@ -297,21 +321,10 @@ bool PushMessage((IMemoryOwner<byte> buffer, int estimatedSize) msg)
{
try
{
if (socket.Connected && linkedCts.Token is { IsCancellationRequested: false })
if (stream.CanWrite && linkedCts.Token is { IsCancellationRequested: false })
{
var sent = socket.Send(msg.buffer.Memory.Span.Slice(0, msg.estimatedSize));
while (sent < msg.estimatedSize)
{
if (sent == 0) return false; // we are shutting down

var remaining = msg.buffer.Memory.Slice(sent);
var sent2 = socket.Send(remaining.Span);
if (sent2 == remaining.Length)
sent += sent2;
else
return false;
}

var task = stream.WriteAsync(msg.buffer.Memory.Slice(0, msg.estimatedSize), linkedCts.Token);
task.GetAwaiter().GetResult();
return true;
}

Expand All @@ -333,7 +346,16 @@ async Task ClosingAction()
closed = true;
// ReSharper disable once AccessToModifiedClosure
await clientShutdownCts.CancelAsync();
if (socket.Connected) socket.Close();
try
{
stream?.Close();
// ReSharper disable once MethodHasAsyncOverload
stream?.Dispose();
}
catch
{
// suppress exceptions during stream disposal
}
}
}
}
Expand Down
Loading
Loading