From 1cee677559d45dbc6993d4ec2335573c31486900 Mon Sep 17 00:00:00 2001 From: danielmarbach Date: Mon, 15 Jun 2020 22:11:49 +0200 Subject: [PATCH] Concurrent dispatch as opt-in --- .../client/api/ConnectionFactory.cs | 10 +++ .../client/impl/AsyncConsumerWorkService.cs | 74 +++++++++++++++++-- .../RabbitMQ.Client/client/impl/Connection.cs | 5 +- .../client/impl/ConsumerWorkService.cs | 73 ++++++++++++++++-- .../Unit/APIApproval.Approve.verified.txt | 1 + projects/Unit/TestAsyncConsumer.cs | 54 ++++++++++++++ projects/Unit/TestConsumer.cs | 65 ++++++++++++++++ 7 files changed, 270 insertions(+), 12 deletions(-) create mode 100644 projects/Unit/TestConsumer.cs diff --git a/projects/RabbitMQ.Client/client/api/ConnectionFactory.cs b/projects/RabbitMQ.Client/client/api/ConnectionFactory.cs index 15d5b07e1c..091c1f6474 100644 --- a/projects/RabbitMQ.Client/client/api/ConnectionFactory.cs +++ b/projects/RabbitMQ.Client/client/api/ConnectionFactory.cs @@ -175,6 +175,16 @@ public sealed class ConnectionFactory : ConnectionFactoryBase, IAsyncConnectionF /// public bool DispatchConsumersAsync { get; set; } = false; + /// + /// Set to a value greater than one to enable concurrent processing. For a concurrency greater than one + /// will be offloaded to the worker thread pool so it is important to choose the value for the concurrency wisely to avoid thread pool overloading. + /// can handle concurrency much more efficiently due to the non-blocking nature of the consumer. + /// Defaults to 1. + /// + /// For concurrency greater than one this removes the guarantee that consumers handle messages in the order they receive them. + /// In addition to that consumers need to be thread/concurrency safe. + public int ProcessingConcurrency { get; set; } = 1; + /// The host to connect to. public string HostName { get; set; } = "localhost"; diff --git a/projects/RabbitMQ.Client/client/impl/AsyncConsumerWorkService.cs b/projects/RabbitMQ.Client/client/impl/AsyncConsumerWorkService.cs index 5686aaf372..dfe1f45f28 100644 --- a/projects/RabbitMQ.Client/client/impl/AsyncConsumerWorkService.cs +++ b/projects/RabbitMQ.Client/client/impl/AsyncConsumerWorkService.cs @@ -9,7 +9,12 @@ namespace RabbitMQ.Client.Impl internal sealed class AsyncConsumerWorkService : ConsumerWorkService { private readonly ConcurrentDictionary _workPools = new ConcurrentDictionary(); - private readonly Func _startNewWorkPoolFunc = model => StartNewWorkPool(model); + private readonly Func _startNewWorkPoolFunc; + + public AsyncConsumerWorkService(int concurrency) : base(concurrency) + { + _startNewWorkPoolFunc = model => StartNewWorkPool(model); + } public void Schedule(ModelBase model, TWork work) where TWork : Work { @@ -22,9 +27,9 @@ public void Schedule(ModelBase model, TWork work) where TWork : Work workPool.Enqueue(work); } - private static WorkPool StartNewWorkPool(IModel model) + private WorkPool StartNewWorkPool(IModel model) { - var newWorkPool = new WorkPool(model as ModelBase); + var newWorkPool = new WorkPool(model as ModelBase, _concurrency); newWorkPool.Start(); return newWorkPool; } @@ -44,16 +49,29 @@ class WorkPool readonly Channel _channel; readonly ModelBase _model; private Task _worker; + private readonly int _concurrency; + private SemaphoreSlim _limiter; + private CancellationTokenSource _tokenSource; - public WorkPool(ModelBase model) + public WorkPool(ModelBase model, int concurrency) { + _concurrency = concurrency; _model = model; _channel = Channel.CreateUnbounded(new UnboundedChannelOptions { SingleReader = true, SingleWriter = false, AllowSynchronousContinuations = false }); } public void Start() { - _worker = Task.Run(Loop, CancellationToken.None); + if (_concurrency == 1) + { + _worker = Task.Run(Loop, CancellationToken.None); + } + else + { + _limiter = new SemaphoreSlim(_concurrency); + _tokenSource = new CancellationTokenSource(); + _worker = Task.Run(() => LoopWithConcurrency(_tokenSource.Token), CancellationToken.None); + } } public void Enqueue(Work work) @@ -83,9 +101,55 @@ async Task Loop() } } + async Task LoopWithConcurrency(CancellationToken cancellationToken) + { + try + { + while (await _channel.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + while (_channel.Reader.TryRead(out Work work)) + { + // Do a quick synchronous check before we resort to async/await with the state-machine overhead. + if(!_limiter.Wait(0)) + { + await _limiter.WaitAsync(cancellationToken).ConfigureAwait(false); + } + + _ = HandleConcurrent(work, _model, _limiter); + } + } + } + catch (OperationCanceledException) + { + // ignored + } + } + + static async Task HandleConcurrent(Work work, ModelBase model, SemaphoreSlim limiter) + { + try + { + Task task = work.Execute(model); + if (!task.IsCompleted) + { + await task.ConfigureAwait(false); + } + } + catch (Exception) + { + + } + finally + { + limiter.Release(); + } + } + public Task Stop() { _channel.Writer.Complete(); + _tokenSource?.Cancel(); + _limiter?.Dispose(); return _worker; } } diff --git a/projects/RabbitMQ.Client/client/impl/Connection.cs b/projects/RabbitMQ.Client/client/impl/Connection.cs index 36ce1a83e7..90662991d3 100644 --- a/projects/RabbitMQ.Client/client/impl/Connection.cs +++ b/projects/RabbitMQ.Client/client/impl/Connection.cs @@ -110,13 +110,14 @@ public Connection(IConnectionFactory factory, bool insist, IFrameHandler frameHa _factory = factory; _frameHandler = frameHandler; + int processingConcurrency = (factory as ConnectionFactory)?.ProcessingConcurrency ?? 1; if (factory is IAsyncConnectionFactory asyncConnectionFactory && asyncConnectionFactory.DispatchConsumersAsync) { - ConsumerWorkService = new AsyncConsumerWorkService(); + ConsumerWorkService = new AsyncConsumerWorkService(processingConcurrency); } else { - ConsumerWorkService = new ConsumerWorkService(); + ConsumerWorkService = new ConsumerWorkService(processingConcurrency); } _sessionManager = new SessionManager(this, 0); diff --git a/projects/RabbitMQ.Client/client/impl/ConsumerWorkService.cs b/projects/RabbitMQ.Client/client/impl/ConsumerWorkService.cs index 1d63069ba6..d8240a51d1 100644 --- a/projects/RabbitMQ.Client/client/impl/ConsumerWorkService.cs +++ b/projects/RabbitMQ.Client/client/impl/ConsumerWorkService.cs @@ -8,7 +8,15 @@ namespace RabbitMQ.Client.Impl internal class ConsumerWorkService { private readonly ConcurrentDictionary _workPools = new ConcurrentDictionary(); - private readonly Func _startNewWorkPoolFunc = model => StartNewWorkPool(model); + private readonly Func _startNewWorkPoolFunc; + protected readonly int _concurrency; + + public ConsumerWorkService(int concurrency) + { + _concurrency = concurrency; + + _startNewWorkPoolFunc = model => StartNewWorkPool(model); + } public void AddWork(IModel model, Action fn) { @@ -21,9 +29,9 @@ public void AddWork(IModel model, Action fn) workPool.Enqueue(fn); } - private static WorkPool StartNewWorkPool(IModel model) + private WorkPool StartNewWorkPool(IModel model) { - var newWorkPool = new WorkPool(); + var newWorkPool = new WorkPool(_concurrency); newWorkPool.Start(); return newWorkPool; } @@ -57,10 +65,13 @@ class WorkPool readonly CancellationTokenSource _tokenSource; readonly CancellationTokenRegistration _tokenRegistration; volatile TaskCompletionSource _syncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly int _concurrency; private Task _worker; + private SemaphoreSlim _limiter; - public WorkPool() + public WorkPool(int concurrency) { + _concurrency = concurrency; _actions = new ConcurrentQueue(); _tokenSource = new CancellationTokenSource(); _tokenRegistration = _tokenSource.Token.Register(() => _syncSource.TrySetCanceled()); @@ -68,7 +79,15 @@ public WorkPool() public void Start() { - _worker = Task.Run(Loop, CancellationToken.None); + if (_concurrency == 1) + { + _worker = Task.Run(Loop, CancellationToken.None); + } + else + { + _limiter = new SemaphoreSlim(_concurrency); + _worker = Task.Run(() => LoopWithConcurrency(_tokenSource.Token), CancellationToken.None); + } } public void Enqueue(Action action) @@ -105,10 +124,54 @@ async Task Loop() } } + async Task LoopWithConcurrency(CancellationToken cancellationToken) + { + while (_tokenSource.IsCancellationRequested == false) + { + try + { + await _syncSource.Task.ConfigureAwait(false); + _syncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + catch (TaskCanceledException) + { + // Swallowing the task cancellation exception for the semaphore in case we are stopping. + } + + while (_actions.TryDequeue(out Action action)) + { + // Do a quick synchronous check before we resort to async/await with the state-machine overhead. + if(!_limiter.Wait(0)) + { + await _limiter.WaitAsync(cancellationToken).ConfigureAwait(false); + } + + _ = OffloadToWorkerThreadPool(action, _limiter); + } + } + } + + static async Task OffloadToWorkerThreadPool(Action action, SemaphoreSlim limiter) + { + try + { + await Task.Run(() => action()); + } + catch (Exception) + { + // ignored + } + finally + { + limiter.Release(); + } + } + public Task Stop() { _tokenSource.Cancel(); _tokenRegistration.Dispose(); + _limiter?.Dispose(); return _worker; } } diff --git a/projects/Unit/APIApproval.Approve.verified.txt b/projects/Unit/APIApproval.Approve.verified.txt index 79af7ec8de..5959e02880 100644 --- a/projects/Unit/APIApproval.Approve.verified.txt +++ b/projects/Unit/APIApproval.Approve.verified.txt @@ -81,6 +81,7 @@ namespace RabbitMQ.Client public System.TimeSpan NetworkRecoveryInterval { get; set; } public string Password { get; set; } public int Port { get; set; } + public int ProcessingConcurrency { get; set; } public ushort RequestedChannelMax { get; set; } public System.TimeSpan RequestedConnectionTimeout { get; set; } public uint RequestedFrameMax { get; set; } diff --git a/projects/Unit/TestAsyncConsumer.cs b/projects/Unit/TestAsyncConsumer.cs index e1e4727a73..74009754aa 100644 --- a/projects/Unit/TestAsyncConsumer.cs +++ b/projects/Unit/TestAsyncConsumer.cs @@ -39,6 +39,7 @@ //--------------------------------------------------------------------------- using System; +using System.Text; using System.Threading; using System.Threading.Tasks; @@ -81,6 +82,59 @@ public void TestBasicRoundtrip() } } + [Test] + public async Task TestBasicRoundtripConcurrent() + { + var cf = new ConnectionFactory{ DispatchConsumersAsync = true, ProcessingConcurrency = 2 }; + using(IConnection c = cf.CreateConnection()) + using(IModel m = c.CreateModel()) + { + QueueDeclareOk q = m.QueueDeclare(); + IBasicProperties bp = m.CreateBasicProperties(); + const string publish1 = "async-hi-1"; + var body = Encoding.UTF8.GetBytes(publish1); + m.BasicPublish("", q.QueueName, bp, body); + const string publish2 = "async-hi-2"; + body = Encoding.UTF8.GetBytes(publish2); + m.BasicPublish("", q.QueueName, bp, body); + + var consumer = new AsyncEventingBasicConsumer(m); + + var publish1SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var publish2SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var maximumWaitTime = TimeSpan.FromSeconds(5); + var tokenSource = new CancellationTokenSource(maximumWaitTime); + tokenSource.Token.Register(() => + { + publish1SyncSource.TrySetResult(false); + publish2SyncSource.TrySetResult(false); + }); + + consumer.Received += async (o, a) => + { + switch (Encoding.UTF8.GetString(a.Body.ToArray())) + { + case publish1: + publish1SyncSource.TrySetResult(true); + await publish2SyncSource.Task; + break; + case publish2: + publish2SyncSource.TrySetResult(true); + await publish1SyncSource.Task; + break; + } + }; + + m.BasicConsume(q.QueueName, true, consumer); + // ensure we get a delivery + + await Task.WhenAll(publish1SyncSource.Task, publish2SyncSource.Task); + + Assert.IsTrue(publish1SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); + Assert.IsTrue(publish2SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); + } + } + [Test] public void TestBasicRoundtripNoWait() { diff --git a/projects/Unit/TestConsumer.cs b/projects/Unit/TestConsumer.cs new file mode 100644 index 0000000000..bba55c575e --- /dev/null +++ b/projects/Unit/TestConsumer.cs @@ -0,0 +1,65 @@ +using System; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using NUnit.Framework; +using RabbitMQ.Client.Events; + +namespace RabbitMQ.Client.Unit +{ + [TestFixture] + public class TestConsumer + { + [Test] + public async Task TestBasicRoundtripConcurrent() + { + var cf = new ConnectionFactory{ ProcessingConcurrency = 2 }; + using(IConnection c = cf.CreateConnection()) + using(IModel m = c.CreateModel()) + { + QueueDeclareOk q = m.QueueDeclare(); + IBasicProperties bp = m.CreateBasicProperties(); + const string publish1 = "sync-hi-1"; + var body = Encoding.UTF8.GetBytes(publish1); + m.BasicPublish("", q.QueueName, bp, body); + const string publish2 = "sync-hi-2"; + body = Encoding.UTF8.GetBytes(publish2); + m.BasicPublish("", q.QueueName, bp, body); + + var consumer = new EventingBasicConsumer(m); + + var publish1SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var publish2SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var maximumWaitTime = TimeSpan.FromSeconds(5); + var tokenSource = new CancellationTokenSource(maximumWaitTime); + tokenSource.Token.Register(() => + { + publish1SyncSource.TrySetResult(false); + publish2SyncSource.TrySetResult(false); + }); + + consumer.Received += (o, a) => + { + switch (Encoding.UTF8.GetString(a.Body.ToArray())) + { + case publish1: + publish1SyncSource.TrySetResult(true); + publish2SyncSource.Task.GetAwaiter().GetResult(); + break; + case publish2: + publish2SyncSource.TrySetResult(true); + publish1SyncSource.Task.GetAwaiter().GetResult(); + break; + } + }; + + m.BasicConsume(q.QueueName, true, consumer); + + await Task.WhenAll(publish1SyncSource.Task, publish2SyncSource.Task); + + Assert.IsTrue(publish1SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); + Assert.IsTrue(publish2SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); + } + } + } +}