From 39db97819c9c2f0f787f3ba11b7baeda6ea32cf0 Mon Sep 17 00:00:00 2001 From: timtay-microsoft Date: Fri, 26 May 2023 11:49:16 -0700 Subject: [PATCH] fix(iot-dev): Make correlation callback cleanup/get operations atomic We don't want cases where one thread is calling callbacks.get("some id") while another just removed that entry. #1718 --- .../iot/device/transport/IotHubTransport.java | 152 ++++++++++-------- 1 file changed, 85 insertions(+), 67 deletions(-) diff --git a/iothub/device/iot-device-client/src/main/java/com/microsoft/azure/sdk/iot/device/transport/IotHubTransport.java b/iothub/device/iot-device-client/src/main/java/com/microsoft/azure/sdk/iot/device/transport/IotHubTransport.java index 994954f9eb..d9b4ba7d41 100644 --- a/iothub/device/iot-device-client/src/main/java/com/microsoft/azure/sdk/iot/device/transport/IotHubTransport.java +++ b/iothub/device/iot-device-client/src/main/java/com/microsoft/azure/sdk/iot/device/transport/IotHubTransport.java @@ -126,6 +126,7 @@ public class IotHubTransport implements IotHubListener // A job that runs periodically to remove any stale correlation callbacks private Thread correlationCallbackCleanupThread = new Thread(() -> checkForOldMessages()); private static final int CORRELATION_CALLBACK_CLEANUP_PERIOD_MILLISECONDS = 60 * 60 * 1000; + private final Object correlationCallbackOperationLock = new Object(); /** * Constructor for an IotHubTransport object with default values @@ -268,20 +269,22 @@ public void onMessageSent(Message message, String deviceId, TransportException e if (!correlationId.isEmpty()) { - CorrelatingMessageCallback callback = correlationCallbacks.get(correlationId); - - if (callback != null) + synchronized (this.correlationCallbackOperationLock) { - Object context = correlationCallbackContexts.get(correlationId); - IotHubClientException clientException = null; - if (e != null) + CorrelatingMessageCallback callback = correlationCallbacks.get(correlationId); + + if (callback != null) { - clientException = e.toIotHubClientException(); + Object context = correlationCallbackContexts.get(correlationId); + IotHubClientException clientException = null; + if (e != null) + { + clientException = e.toIotHubClientException(); + } + callback.onRequestAcknowledged(packet.getMessage(), context, clientException); } - callback.onRequestAcknowledged(packet.getMessage(), context, clientException); } } - } catch (Exception ex) { @@ -321,31 +324,34 @@ else if (message != null) String correlationId = message.getCorrelationId(); if (!correlationId.isEmpty()) { - CorrelatingMessageCallback callback = correlationCallbacks.get(correlationId); - - if (callback != null) + synchronized (this.correlationCallbackOperationLock) { - Object context = correlationCallbackContexts.get(correlationId); - IotHubClientException clientException = null; - if (e != null) - { - // This case indicates that the transport layer failed to construct a valid message out of - // a message delivered by the service - clientException = e.toIotHubClientException(); - } - else + CorrelatingMessageCallback callback = correlationCallbacks.get(correlationId); + + if (callback != null) { - // This case indicates that the transport layer constructed a valid message out of a message - // delivered by the service, but that message may contain an unsuccessful status code in cases - // such as if an operation was rejected because it was badly formatted. - IotHubStatusCode statusCode = IotHubStatusCode.getIotHubStatusCode(Integer.parseInt(message.getStatus())); - if (!IotHubStatusCode.isSuccessful(statusCode)) + Object context = correlationCallbackContexts.get(correlationId); + IotHubClientException clientException = null; + if (e != null) { - clientException = new IotHubClientException(statusCode, "Received an unsuccessful operation error code from the service: " + statusCode); + // This case indicates that the transport layer failed to construct a valid message out of + // a message delivered by the service + clientException = e.toIotHubClientException(); + } + else + { + // This case indicates that the transport layer constructed a valid message out of a message + // delivered by the service, but that message may contain an unsuccessful status code in cases + // such as if an operation was rejected because it was badly formatted. + IotHubStatusCode statusCode = IotHubStatusCode.getIotHubStatusCode(Integer.parseInt(message.getStatus())); + if (!IotHubStatusCode.isSuccessful(statusCode)) + { + clientException = new IotHubClientException(statusCode, "Received an unsuccessful operation error code from the service: " + statusCode); + } } - } - callback.onResponseReceived(message, context, clientException); + callback.onResponseReceived(message, context, clientException); + } } } } @@ -767,12 +773,15 @@ public void sendMessages() if (!correlationId.isEmpty()) { - CorrelatingMessageCallback callback = correlationCallbacks.get(correlationId); - - if (callback != null) + synchronized (this.correlationCallbackOperationLock) { - Object context = correlationCallbackContexts.get(correlationId); - callback.onRequestSent(message, context); + CorrelatingMessageCallback callback = correlationCallbacks.get(correlationId); + + if (callback != null) + { + Object context = correlationCallbackContexts.get(correlationId); + callback.onRequestSent(message, context); + } } } } @@ -859,19 +868,22 @@ private void checkForOldMessages() List correlationIdsToRemove = new ArrayList<>(); - for (String correlationId : correlationCallbacks.keySet()) + synchronized (this.correlationCallbackOperationLock) { - if (System.currentTimeMillis() - correlationStartTimeMillis.get(correlationId) >= DEFAULT_CORRELATION_ID_LIVE_TIME) + for (String correlationId : correlationCallbacks.keySet()) { - correlationIdsToRemove.add(correlationId); - correlationCallbackContexts.remove(correlationId); - correlationStartTimeMillis.remove(correlationId); + if (System.currentTimeMillis() - correlationStartTimeMillis.get(correlationId) >= DEFAULT_CORRELATION_ID_LIVE_TIME) + { + correlationIdsToRemove.add(correlationId); + correlationCallbackContexts.remove(correlationId); + correlationStartTimeMillis.remove(correlationId); + } } - } - for (String correlationId : correlationIdsToRemove) - { - correlationCallbacks.remove(correlationId); + for (String correlationId : correlationIdsToRemove) + { + correlationCallbacks.remove(correlationId); + } } } } @@ -1216,22 +1228,22 @@ private void acknowledgeReceivedMessage(IotHubTransportMessage receivedMessage) String correlationId = receivedMessage.getCorrelationId(); if (!correlationId.isEmpty()) { - CorrelatingMessageCallback callback = correlationCallbacks.get(correlationId); - - if (callback != null) + synchronized (this.correlationCallbackOperationLock) { - Object context = correlationCallbackContexts.get(correlationId); - callback.onResponseAcknowledged(receivedMessage, context); - } + CorrelatingMessageCallback callback = correlationCallbacks.get(correlationId); - // We need to remove the CorrelatingMessageCallback with the current correlation ID from the map after the received C2D - // message has been acknowledged. Otherwise, the size of map will grow endlessly which results in OutOfMemory eventually. - new Thread(() -> - { - correlationCallbacks.remove(correlationId); //TODO wat + if (callback != null) + { + Object context = correlationCallbackContexts.get(correlationId); + callback.onResponseAcknowledged(receivedMessage, context); + } + + // We need to remove the CorrelatingMessageCallback with the current correlation ID from the map after the received C2D + // message has been acknowledged. Otherwise, the size of map will grow endlessly which results in OutOfMemory eventually. + correlationCallbacks.remove(correlationId); correlationCallbackContexts.remove(correlationId); correlationStartTimeMillis.remove(correlationId); - }).start(); + } } } catch (Exception ex) @@ -1268,12 +1280,15 @@ private void addReceivedMessagesOverHttpToReceivedQueue() throws TransportExcept String correlationId = transportMessage.getCorrelationId(); if (!correlationId.isEmpty()) { - CorrelatingMessageCallback callback = correlationCallbacks.get(correlationId); - - if (callback != null) + synchronized (this.correlationCallbackOperationLock) { - Object context = correlationCallbackContexts.get(correlationId); - callback.onResponseReceived(transportMessage, context, null); + CorrelatingMessageCallback callback = correlationCallbacks.get(correlationId); + + if (callback != null) + { + Object context = correlationCallbackContexts.get(correlationId); + callback.onResponseReceived(transportMessage, context, null); + } } } } @@ -1873,15 +1888,18 @@ private void addToWaitingQueue(IotHubTransportPacket packet) CorrelatingMessageCallback correlationCallback = message.getCorrelatingMessageCallback(); if (!correlationId.isEmpty() && correlationCallback != null) { - correlationCallbacks.put(correlationId, correlationCallback); - correlationStartTimeMillis.put(correlationId, System.currentTimeMillis()); - - Object correlationCallbackContext = message.getCorrelatingMessageCallbackContext(); - if (correlationCallbackContext != null) + synchronized (this.correlationCallbackOperationLock) { - correlationCallbackContexts.put(correlationId, correlationCallbackContext); + correlationCallbacks.put(correlationId, correlationCallback); + correlationStartTimeMillis.put(correlationId, System.currentTimeMillis()); + + Object correlationCallbackContext = message.getCorrelatingMessageCallbackContext(); + if (correlationCallbackContext != null) + { + correlationCallbackContexts.put(correlationId, correlationCallbackContext); + } + correlationCallback.onRequestQueued(message, correlationCallbackContext); } - correlationCallback.onRequestQueued(message, correlationCallbackContext); } } }