Skip to content

Commit

Permalink
(Fix: DeviceClient): Fix the concurrency issue in MQTT stack (#2234)
Browse files Browse the repository at this point in the history
  • Loading branch information
azabbasi authored Dec 9, 2021
1 parent 2cf35f3 commit 4618ef3
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 43 deletions.
2 changes: 1 addition & 1 deletion iothub/device/src/Transport/Mqtt/MqttTransportHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,6 @@ private Func<IPAddress[], int, Task<IChannel>> CreateChannelFactory(IotHubConnec
.Handler(new ActionChannelInitializer<ISocketChannel>(ch =>
{
var tlsHandler = new TlsHandler(streamFactory, clientTlsSettings);

ch.Pipeline.AddLast(
tlsHandler,
MqttEncoder.Instance,
Expand Down Expand Up @@ -1308,6 +1307,7 @@ private Func<IPAddress[], int, Task<IChannel>> CreateWebSocketChannelFactory(Iot
await websocket.ConnectAsync(websocketUri, cts.Token).ConfigureAwait(false);

var clientWebSocketChannel = new ClientWebSocketChannel(null, websocket);

clientWebSocketChannel
.Option(ChannelOption.Allocator, UnpooledByteBufferAllocator.Default)
.Option(ChannelOption.AutoRead, false)
Expand Down
34 changes: 20 additions & 14 deletions iothub/device/src/Transport/Mqtt/OrderedTwoPhaseWorkQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Collections.Concurrent;
using System.Linq;
using System.Threading.Tasks;
using DotNetty.Transport.Channels;
Expand All @@ -27,7 +27,7 @@ public IncompleteWorkItem(TWorkId id, TWork workItem)

private readonly Func<TWork, TWorkId> _getWorkId;
private readonly Func<IChannelHandlerContext, TWork, Task> _completeWorkAsync;
private readonly Queue<IncompleteWorkItem> _incompleteQueue = new Queue<IncompleteWorkItem>();
private readonly ConcurrentQueue<IncompleteWorkItem> _incompleteQueue = new ConcurrentQueue<IncompleteWorkItem>();

public OrderedTwoPhaseWorkQueue(
Func<IChannelHandlerContext, TWork, Task> workerAsync,
Expand All @@ -46,16 +46,22 @@ public Task CompleteWorkAsync(IChannelHandlerContext context, TWorkId workId)
throw new IotHubException("Nothing to complete.", isTransient: false);
}

IncompleteWorkItem incompleteWorkItem = _incompleteQueue.Peek();
if (incompleteWorkItem.Id.Equals(workId))
if (_incompleteQueue.TryDequeue(out IncompleteWorkItem incompleteWorkItem))
{
_incompleteQueue.Dequeue();
return _completeWorkAsync(context, incompleteWorkItem.WorkItem);
}
if (incompleteWorkItem.Id.Equals(workId))
{
return _completeWorkAsync(context, incompleteWorkItem.WorkItem);
}

throw new IotHubException(
$"Work must be complete in the same order as it was started. Expected work id: '{incompleteWorkItem.Id}', actual work id: '{workId}'",
isTransient: false);
throw new IotHubException(
$"Work must be complete in the same order as it was started. Expected work id: '{incompleteWorkItem.Id}', actual work id: '{workId}'",
isTransient: false);
}
#if NET451
return TaskHelpers.CompletedTask;
#else
return Task.CompletedTask;
#endif
}

protected override async Task DoWorkAsync(IChannelHandlerContext context, TWork work)
Expand All @@ -77,17 +83,17 @@ public override void Abort(Exception exception)
if (stateBefore != State
&& State == States.Aborted)
{
while (_incompleteQueue.Any())
while (_incompleteQueue.TryDequeue(out IncompleteWorkItem workItem))
{
var workItem = _incompleteQueue.Dequeue().WorkItem as ICancellable;
var cancellableWorkItem = workItem.WorkItem as ICancellable;

if (exception == null)
{
workItem?.Cancel();
cancellableWorkItem?.Cancel();
}
else
{
workItem?.Abort(exception);
cancellableWorkItem?.Abort(exception);
}
}
}
Expand Down
13 changes: 6 additions & 7 deletions iothub/device/src/Transport/Mqtt/SimpleWorkQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Collections.Concurrent;
using System.Linq;
using System.Threading.Tasks;
using DotNetty.Common.Utilities;
Expand All @@ -25,14 +25,14 @@ namespace Microsoft.Azure.Devices.Client.Transport.Mqtt
internal class SimpleWorkQueue<TWork>
{
private readonly Func<IChannelHandlerContext, TWork, Task> _workerAsync;
private readonly Queue<TWork> _backlogQueue;
private readonly ConcurrentQueue<TWork> _backlogQueue;
private readonly TaskCompletionSource _completionSource;

public SimpleWorkQueue(Func<IChannelHandlerContext, TWork, Task> workerAsync)
{
_workerAsync = workerAsync;
_completionSource = new TaskCompletionSource();
_backlogQueue = new Queue<TWork>();
_backlogQueue = new ConcurrentQueue<TWork>();
}

protected States State { get; set; }
Expand Down Expand Up @@ -111,9 +111,8 @@ public virtual void Abort(Exception exception)
case States.FinalProcessing:
State = States.Aborted;

while (_backlogQueue.Any())
while (_backlogQueue.TryDequeue(out TWork workItem))
{
TWork workItem = _backlogQueue.Dequeue();
ReferenceCountUtil.Release(workItem);

var cancellableWorkItem = workItem as ICancellable;
Expand Down Expand Up @@ -146,9 +145,9 @@ private async void StartWorkQueueProcessingAsync(IChannelHandlerContext context)
try
{
while (_backlogQueue.Any()
&& State != States.Aborted)
&& State != States.Aborted
&& _backlogQueue.TryDequeue(out TWork workItem))
{
TWork workItem = _backlogQueue.Dequeue();
await DoWorkAsync(context, workItem).ConfigureAwait(false);
}

Expand Down
6 changes: 3 additions & 3 deletions provisioning/device/src/ProvisioningTransportException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ public ProvisioningTransportException(string message, Exception innerException,
public ProvisioningTransportException(string message, Exception innerException, bool isTransient, ProvisioningErrorDetails errorDetails)
: base(message, innerException)
{
this.IsTransient = isTransient;
this.ErrorDetails = errorDetails;
IsTransient = isTransient;
ErrorDetails = errorDetails;
}

/// <summary>
Expand All @@ -125,7 +125,7 @@ protected ProvisioningTransportException(SerializationInfo info, StreamingContex
public override void GetObjectData(SerializationInfo info, StreamingContext context)
{
base.GetObjectData(info, context);
info.AddValue(IsTransientValueSerializationStoreName, this.IsTransient);
info.AddValue(IsTransientValueSerializationStoreName, IsTransient);
}
}
}
27 changes: 9 additions & 18 deletions provisioning/transport/amqp/src/ProvisioningErrorDetailsAmqp.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace Microsoft.Azure.Devices.Provisioning.Client.Transport
{
[SuppressMessage("Microsoft.Performance", "CA1812", Justification = "Is instantiated by json convertor")]
[SuppressMessage("Microsoft.Performance", "CA1812", Justification = "Is instantiated by json converter")]
internal class ProvisioningErrorDetailsAmqp : ProvisioningErrorDetails
{
/// <summary>
Expand All @@ -20,11 +20,9 @@ internal class ProvisioningErrorDetailsAmqp : ProvisioningErrorDetails

public static TimeSpan? GetRetryAfterFromApplicationProperties(AmqpMessage amqpResponse, TimeSpan defaultInterval)
{
object retryAfter;
if (amqpResponse.ApplicationProperties != null && amqpResponse.ApplicationProperties.Map.TryGetValue(RetryAfterKey, out retryAfter))
if (amqpResponse.ApplicationProperties != null && amqpResponse.ApplicationProperties.Map.TryGetValue(RetryAfterKey, out object retryAfter))
{
int secondsToWait;
if (int.TryParse(retryAfter.ToString(), out secondsToWait))
if (int.TryParse(retryAfter.ToString(), out int secondsToWait))
{
var serviceRecommendedDelay = TimeSpan.FromSeconds(secondsToWait);

Expand All @@ -46,23 +44,16 @@ internal class ProvisioningErrorDetailsAmqp : ProvisioningErrorDetails
{
if (rejected.Error != null && rejected.Error.Info != null)
{
object retryAfter;
if (rejected.Error.Info.TryGetValue(RetryAfterKey, out retryAfter))
if (rejected.Error.Info.TryGetValue(RetryAfterKey, out object retryAfter))
{
int secondsToWait = 0;
if (int.TryParse(retryAfter.ToString(), out secondsToWait))
if (int.TryParse(retryAfter.ToString(), out int secondsToWait))
{
if (secondsToWait < defaultInterval.Seconds)
{
return defaultInterval;
}
else
{
return TimeSpan.FromSeconds(secondsToWait);
}
return secondsToWait < defaultInterval.Seconds
? defaultInterval
: TimeSpan.FromSeconds(secondsToWait);
}
}

}

return null;
Expand Down

0 comments on commit 4618ef3

Please sign in to comment.