diff --git a/projects/RabbitMQ.Client/client/api/ConnectionFactory.cs b/projects/RabbitMQ.Client/client/api/ConnectionFactory.cs
index 15d5b07e1c..091c1f6474 100644
--- a/projects/RabbitMQ.Client/client/api/ConnectionFactory.cs
+++ b/projects/RabbitMQ.Client/client/api/ConnectionFactory.cs
@@ -175,6 +175,16 @@ public sealed class ConnectionFactory : ConnectionFactoryBase, IAsyncConnectionF
///
public bool DispatchConsumersAsync { get; set; } = false;
+ ///
+ /// Set to a value greater than one to enable concurrent processing. For a concurrency greater than one
+ /// will be offloaded to the worker thread pool so it is important to choose the value for the concurrency wisely to avoid thread pool overloading.
+ /// can handle concurrency much more efficiently due to the non-blocking nature of the consumer.
+ /// Defaults to 1.
+ ///
+ /// For concurrency greater than one this removes the guarantee that consumers handle messages in the order they receive them.
+ /// In addition to that consumers need to be thread/concurrency safe.
+ public int ProcessingConcurrency { get; set; } = 1;
+
/// The host to connect to.
public string HostName { get; set; } = "localhost";
diff --git a/projects/RabbitMQ.Client/client/impl/AsyncConsumerWorkService.cs b/projects/RabbitMQ.Client/client/impl/AsyncConsumerWorkService.cs
index 5686aaf372..dfe1f45f28 100644
--- a/projects/RabbitMQ.Client/client/impl/AsyncConsumerWorkService.cs
+++ b/projects/RabbitMQ.Client/client/impl/AsyncConsumerWorkService.cs
@@ -9,7 +9,12 @@ namespace RabbitMQ.Client.Impl
internal sealed class AsyncConsumerWorkService : ConsumerWorkService
{
private readonly ConcurrentDictionary _workPools = new ConcurrentDictionary();
- private readonly Func _startNewWorkPoolFunc = model => StartNewWorkPool(model);
+ private readonly Func _startNewWorkPoolFunc;
+
+ public AsyncConsumerWorkService(int concurrency) : base(concurrency)
+ {
+ _startNewWorkPoolFunc = model => StartNewWorkPool(model);
+ }
public void Schedule(ModelBase model, TWork work) where TWork : Work
{
@@ -22,9 +27,9 @@ public void Schedule(ModelBase model, TWork work) where TWork : Work
workPool.Enqueue(work);
}
- private static WorkPool StartNewWorkPool(IModel model)
+ private WorkPool StartNewWorkPool(IModel model)
{
- var newWorkPool = new WorkPool(model as ModelBase);
+ var newWorkPool = new WorkPool(model as ModelBase, _concurrency);
newWorkPool.Start();
return newWorkPool;
}
@@ -44,16 +49,29 @@ class WorkPool
readonly Channel _channel;
readonly ModelBase _model;
private Task _worker;
+ private readonly int _concurrency;
+ private SemaphoreSlim _limiter;
+ private CancellationTokenSource _tokenSource;
- public WorkPool(ModelBase model)
+ public WorkPool(ModelBase model, int concurrency)
{
+ _concurrency = concurrency;
_model = model;
_channel = Channel.CreateUnbounded(new UnboundedChannelOptions { SingleReader = true, SingleWriter = false, AllowSynchronousContinuations = false });
}
public void Start()
{
- _worker = Task.Run(Loop, CancellationToken.None);
+ if (_concurrency == 1)
+ {
+ _worker = Task.Run(Loop, CancellationToken.None);
+ }
+ else
+ {
+ _limiter = new SemaphoreSlim(_concurrency);
+ _tokenSource = new CancellationTokenSource();
+ _worker = Task.Run(() => LoopWithConcurrency(_tokenSource.Token), CancellationToken.None);
+ }
}
public void Enqueue(Work work)
@@ -83,9 +101,55 @@ async Task Loop()
}
}
+ async Task LoopWithConcurrency(CancellationToken cancellationToken)
+ {
+ try
+ {
+ while (await _channel.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
+ {
+ while (_channel.Reader.TryRead(out Work work))
+ {
+ // Do a quick synchronous check before we resort to async/await with the state-machine overhead.
+ if(!_limiter.Wait(0))
+ {
+ await _limiter.WaitAsync(cancellationToken).ConfigureAwait(false);
+ }
+
+ _ = HandleConcurrent(work, _model, _limiter);
+ }
+ }
+ }
+ catch (OperationCanceledException)
+ {
+ // ignored
+ }
+ }
+
+ static async Task HandleConcurrent(Work work, ModelBase model, SemaphoreSlim limiter)
+ {
+ try
+ {
+ Task task = work.Execute(model);
+ if (!task.IsCompleted)
+ {
+ await task.ConfigureAwait(false);
+ }
+ }
+ catch (Exception)
+ {
+
+ }
+ finally
+ {
+ limiter.Release();
+ }
+ }
+
public Task Stop()
{
_channel.Writer.Complete();
+ _tokenSource?.Cancel();
+ _limiter?.Dispose();
return _worker;
}
}
diff --git a/projects/RabbitMQ.Client/client/impl/Connection.cs b/projects/RabbitMQ.Client/client/impl/Connection.cs
index 36ce1a83e7..90662991d3 100644
--- a/projects/RabbitMQ.Client/client/impl/Connection.cs
+++ b/projects/RabbitMQ.Client/client/impl/Connection.cs
@@ -110,13 +110,14 @@ public Connection(IConnectionFactory factory, bool insist, IFrameHandler frameHa
_factory = factory;
_frameHandler = frameHandler;
+ int processingConcurrency = (factory as ConnectionFactory)?.ProcessingConcurrency ?? 1;
if (factory is IAsyncConnectionFactory asyncConnectionFactory && asyncConnectionFactory.DispatchConsumersAsync)
{
- ConsumerWorkService = new AsyncConsumerWorkService();
+ ConsumerWorkService = new AsyncConsumerWorkService(processingConcurrency);
}
else
{
- ConsumerWorkService = new ConsumerWorkService();
+ ConsumerWorkService = new ConsumerWorkService(processingConcurrency);
}
_sessionManager = new SessionManager(this, 0);
diff --git a/projects/RabbitMQ.Client/client/impl/ConsumerWorkService.cs b/projects/RabbitMQ.Client/client/impl/ConsumerWorkService.cs
index 1d63069ba6..d8240a51d1 100644
--- a/projects/RabbitMQ.Client/client/impl/ConsumerWorkService.cs
+++ b/projects/RabbitMQ.Client/client/impl/ConsumerWorkService.cs
@@ -8,7 +8,15 @@ namespace RabbitMQ.Client.Impl
internal class ConsumerWorkService
{
private readonly ConcurrentDictionary _workPools = new ConcurrentDictionary();
- private readonly Func _startNewWorkPoolFunc = model => StartNewWorkPool(model);
+ private readonly Func _startNewWorkPoolFunc;
+ protected readonly int _concurrency;
+
+ public ConsumerWorkService(int concurrency)
+ {
+ _concurrency = concurrency;
+
+ _startNewWorkPoolFunc = model => StartNewWorkPool(model);
+ }
public void AddWork(IModel model, Action fn)
{
@@ -21,9 +29,9 @@ public void AddWork(IModel model, Action fn)
workPool.Enqueue(fn);
}
- private static WorkPool StartNewWorkPool(IModel model)
+ private WorkPool StartNewWorkPool(IModel model)
{
- var newWorkPool = new WorkPool();
+ var newWorkPool = new WorkPool(_concurrency);
newWorkPool.Start();
return newWorkPool;
}
@@ -57,10 +65,13 @@ class WorkPool
readonly CancellationTokenSource _tokenSource;
readonly CancellationTokenRegistration _tokenRegistration;
volatile TaskCompletionSource _syncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ private readonly int _concurrency;
private Task _worker;
+ private SemaphoreSlim _limiter;
- public WorkPool()
+ public WorkPool(int concurrency)
{
+ _concurrency = concurrency;
_actions = new ConcurrentQueue();
_tokenSource = new CancellationTokenSource();
_tokenRegistration = _tokenSource.Token.Register(() => _syncSource.TrySetCanceled());
@@ -68,7 +79,15 @@ public WorkPool()
public void Start()
{
- _worker = Task.Run(Loop, CancellationToken.None);
+ if (_concurrency == 1)
+ {
+ _worker = Task.Run(Loop, CancellationToken.None);
+ }
+ else
+ {
+ _limiter = new SemaphoreSlim(_concurrency);
+ _worker = Task.Run(() => LoopWithConcurrency(_tokenSource.Token), CancellationToken.None);
+ }
}
public void Enqueue(Action action)
@@ -105,10 +124,54 @@ async Task Loop()
}
}
+ async Task LoopWithConcurrency(CancellationToken cancellationToken)
+ {
+ while (_tokenSource.IsCancellationRequested == false)
+ {
+ try
+ {
+ await _syncSource.Task.ConfigureAwait(false);
+ _syncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ }
+ catch (TaskCanceledException)
+ {
+ // Swallowing the task cancellation exception for the semaphore in case we are stopping.
+ }
+
+ while (_actions.TryDequeue(out Action action))
+ {
+ // Do a quick synchronous check before we resort to async/await with the state-machine overhead.
+ if(!_limiter.Wait(0))
+ {
+ await _limiter.WaitAsync(cancellationToken).ConfigureAwait(false);
+ }
+
+ _ = OffloadToWorkerThreadPool(action, _limiter);
+ }
+ }
+ }
+
+ static async Task OffloadToWorkerThreadPool(Action action, SemaphoreSlim limiter)
+ {
+ try
+ {
+ await Task.Run(() => action());
+ }
+ catch (Exception)
+ {
+ // ignored
+ }
+ finally
+ {
+ limiter.Release();
+ }
+ }
+
public Task Stop()
{
_tokenSource.Cancel();
_tokenRegistration.Dispose();
+ _limiter?.Dispose();
return _worker;
}
}
diff --git a/projects/Unit/APIApproval.Approve.verified.txt b/projects/Unit/APIApproval.Approve.verified.txt
index 79af7ec8de..5959e02880 100644
--- a/projects/Unit/APIApproval.Approve.verified.txt
+++ b/projects/Unit/APIApproval.Approve.verified.txt
@@ -81,6 +81,7 @@ namespace RabbitMQ.Client
public System.TimeSpan NetworkRecoveryInterval { get; set; }
public string Password { get; set; }
public int Port { get; set; }
+ public int ProcessingConcurrency { get; set; }
public ushort RequestedChannelMax { get; set; }
public System.TimeSpan RequestedConnectionTimeout { get; set; }
public uint RequestedFrameMax { get; set; }
diff --git a/projects/Unit/TestAsyncConsumer.cs b/projects/Unit/TestAsyncConsumer.cs
index e1e4727a73..74009754aa 100644
--- a/projects/Unit/TestAsyncConsumer.cs
+++ b/projects/Unit/TestAsyncConsumer.cs
@@ -39,6 +39,7 @@
//---------------------------------------------------------------------------
using System;
+using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -81,6 +82,59 @@ public void TestBasicRoundtrip()
}
}
+ [Test]
+ public async Task TestBasicRoundtripConcurrent()
+ {
+ var cf = new ConnectionFactory{ DispatchConsumersAsync = true, ProcessingConcurrency = 2 };
+ using(IConnection c = cf.CreateConnection())
+ using(IModel m = c.CreateModel())
+ {
+ QueueDeclareOk q = m.QueueDeclare();
+ IBasicProperties bp = m.CreateBasicProperties();
+ const string publish1 = "async-hi-1";
+ var body = Encoding.UTF8.GetBytes(publish1);
+ m.BasicPublish("", q.QueueName, bp, body);
+ const string publish2 = "async-hi-2";
+ body = Encoding.UTF8.GetBytes(publish2);
+ m.BasicPublish("", q.QueueName, bp, body);
+
+ var consumer = new AsyncEventingBasicConsumer(m);
+
+ var publish1SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var publish2SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var maximumWaitTime = TimeSpan.FromSeconds(5);
+ var tokenSource = new CancellationTokenSource(maximumWaitTime);
+ tokenSource.Token.Register(() =>
+ {
+ publish1SyncSource.TrySetResult(false);
+ publish2SyncSource.TrySetResult(false);
+ });
+
+ consumer.Received += async (o, a) =>
+ {
+ switch (Encoding.UTF8.GetString(a.Body.ToArray()))
+ {
+ case publish1:
+ publish1SyncSource.TrySetResult(true);
+ await publish2SyncSource.Task;
+ break;
+ case publish2:
+ publish2SyncSource.TrySetResult(true);
+ await publish1SyncSource.Task;
+ break;
+ }
+ };
+
+ m.BasicConsume(q.QueueName, true, consumer);
+ // ensure we get a delivery
+
+ await Task.WhenAll(publish1SyncSource.Task, publish2SyncSource.Task);
+
+ Assert.IsTrue(publish1SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}");
+ Assert.IsTrue(publish2SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}");
+ }
+ }
+
[Test]
public void TestBasicRoundtripNoWait()
{
diff --git a/projects/Unit/TestConsumer.cs b/projects/Unit/TestConsumer.cs
new file mode 100644
index 0000000000..bba55c575e
--- /dev/null
+++ b/projects/Unit/TestConsumer.cs
@@ -0,0 +1,65 @@
+using System;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using NUnit.Framework;
+using RabbitMQ.Client.Events;
+
+namespace RabbitMQ.Client.Unit
+{
+ [TestFixture]
+ public class TestConsumer
+ {
+ [Test]
+ public async Task TestBasicRoundtripConcurrent()
+ {
+ var cf = new ConnectionFactory{ ProcessingConcurrency = 2 };
+ using(IConnection c = cf.CreateConnection())
+ using(IModel m = c.CreateModel())
+ {
+ QueueDeclareOk q = m.QueueDeclare();
+ IBasicProperties bp = m.CreateBasicProperties();
+ const string publish1 = "sync-hi-1";
+ var body = Encoding.UTF8.GetBytes(publish1);
+ m.BasicPublish("", q.QueueName, bp, body);
+ const string publish2 = "sync-hi-2";
+ body = Encoding.UTF8.GetBytes(publish2);
+ m.BasicPublish("", q.QueueName, bp, body);
+
+ var consumer = new EventingBasicConsumer(m);
+
+ var publish1SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var publish2SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var maximumWaitTime = TimeSpan.FromSeconds(5);
+ var tokenSource = new CancellationTokenSource(maximumWaitTime);
+ tokenSource.Token.Register(() =>
+ {
+ publish1SyncSource.TrySetResult(false);
+ publish2SyncSource.TrySetResult(false);
+ });
+
+ consumer.Received += (o, a) =>
+ {
+ switch (Encoding.UTF8.GetString(a.Body.ToArray()))
+ {
+ case publish1:
+ publish1SyncSource.TrySetResult(true);
+ publish2SyncSource.Task.GetAwaiter().GetResult();
+ break;
+ case publish2:
+ publish2SyncSource.TrySetResult(true);
+ publish1SyncSource.Task.GetAwaiter().GetResult();
+ break;
+ }
+ };
+
+ m.BasicConsume(q.QueueName, true, consumer);
+
+ await Task.WhenAll(publish1SyncSource.Task, publish2SyncSource.Task);
+
+ Assert.IsTrue(publish1SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}");
+ Assert.IsTrue(publish2SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}");
+ }
+ }
+ }
+}