-
Notifications
You must be signed in to change notification settings - Fork 103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bugfix: Fix interrupted network read operation #120
Changes from 15 commits
5ec1c2f
4205109
0704c72
adce766
f5aa1b3
76166d0
1d73bb7
0c5a9e3
34adc86
2a3cc9a
70b181a
68fcb16
4ba8c53
b710790
f8586dc
c59d39b
af1ca02
86798bf
e1de841
9a792bc
ebcfb92
fe0a071
32f0291
dd91637
a4bf796
bf6f94d
092e7ff
f0a1808
fcaf4bb
b923ea5
b1eba61
25e74d6
b4b0c78
95abb93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -71,17 +71,25 @@ static uint32_t calculateElapsedTime( uint32_t later, | |
static MQTTPubAckType_t getAckFromPacketType( uint8_t packetType ); | ||
|
||
/** | ||
* @brief Receive bytes into the network buffer, with a timeout. | ||
* @brief Receive bytes into the network buffer. | ||
* | ||
* @param[in] pContext Initialized MQTT Context. | ||
* @param[in] bytesToRecv Number of bytes to receive. | ||
* @param[in] timeoutMs Time remaining to receive the packet. | ||
* | ||
* @note This operation calls the transport receive function | ||
* repeatedly to read bytes from the network until either: | ||
* 1. The requested number of bytes @a bytesToRecv are read. | ||
* OR | ||
* 2. No data is received from the network for MQTT_RECV_POLLING_TIMEOUT_MS duration. | ||
* | ||
* OR | ||
* 3. There is an error in reading from the network. | ||
* | ||
* | ||
* @return Number of bytes received, or negative number on network error. | ||
*/ | ||
static int32_t recvExact( const MQTTContext_t * pContext, | ||
size_t bytesToRecv, | ||
uint32_t timeoutMs ); | ||
size_t bytesToRecv ); | ||
|
||
/** | ||
* @brief Discard a packet from the transport interface. | ||
|
@@ -683,13 +691,12 @@ static MQTTPubAckType_t getAckFromPacketType( uint8_t packetType ) | |
/*-----------------------------------------------------------*/ | ||
|
||
static int32_t recvExact( const MQTTContext_t * pContext, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. static int32_t recvExact( const MQTTContext_t * pContext,
size_t bytesToRecv )
{
uint8_t * pIndex = NULL;
size_t bytesRemaining = bytesToRecv;
int32_t totalBytesRecvd = 0, bytesRecvd;
uint32_t timestampBeforeRecvMs = 0U, timeSpentInRecvMs = 0U;
TransportRecv_t recvFunc = NULL;
MQTTGetCurrentTimeFunc_t getTimeStampMs = NULL;
bool receiveError = false;
assert( pContext != NULL );
assert( bytesToRecv <= pContext->networkBuffer.size );
assert( pContext->getTime != NULL );
assert( pContext->transportInterface.recv != NULL );
assert( pContext->networkBuffer.pBuffer != NULL );
pIndex = pContext->networkBuffer.pBuffer;
recvFunc = pContext->transportInterface.recv;
getTimeStampMs = pContext->getTime;
while( ( bytesRemaining > 0U ) && ( receiveError == false ) )
{
timestampBeforeRecvMs = getTimeStampMs();
bytesRecvd = recvFunc( pContext->transportInterface.pNetworkContext,
pIndex,
bytesRemaining );
timeSpentInRecvMs = calculateElapsedTime( getTimeStampMs(), timestampBeforeRecvMs );
if( bytesRecvd < 0 )
{
LogError( ( "Network error while receiving packet: ReturnCode=%ld.",
( long int ) bytesRecvd ) );
totalBytesRecvd = bytesRecvd;
receiveError = true;
}
else if( bytesRecvd > 0 )
{
/* It is a bug in the application's transport receive implementation
* if more bytes than expected are received. To avoid a possible
* overflow in converting bytesRemaining from unsigned to signed,
* this assert must exist after the check for bytesRecvd being
* negative. */
assert( ( size_t ) bytesRecvd <= bytesRemaining );
bytesRemaining -= ( size_t ) bytesRecvd;
totalBytesRecvd += ( int32_t ) bytesRecvd;
pIndex += bytesRecvd;
LogDebug( ( "BytesReceived=%ld, BytesRemaining=%lu, "
"TotalBytesReceived=%ld.",
( long int ) bytesRecvd,
( unsigned long ) bytesRemaining,
( long int ) totalBytesRecvd ) );
}
else
{
/* No bytes were read from the network. */
}
/* If there is more data to be received and nothing was received
* for MQTT_RECV_POLLING_TIMEOUT_MS, treat this as error. */
if( ( bytesRemaining > 0U ) &&
( timeSpentInRecvMs >= MQTT_RECV_POLLING_TIMEOUT_MS ) &&
( bytesRecvd == 0 ) )
{
LogError( ( "Time expired while receiving packet." ) );
receiveError = true;
}
}
return totalBytesRecvd;
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will only calculate the time spent for each call of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In your code suggestion, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree. In that case, changing the variable names would make it a bit more readable: entryTimeMs --> lastDataRecvdTimestampMs
noDataRecvdTimeMs --> timeSinceLastDataWasRecvd |
||
size_t bytesToRecv, | ||
uint32_t timeoutMs ) | ||
size_t bytesToRecv ) | ||
{ | ||
uint8_t * pIndex = NULL; | ||
size_t bytesRemaining = bytesToRecv; | ||
int32_t totalBytesRecvd = 0, bytesRecvd; | ||
uint32_t entryTimeMs = 0U, elapsedTimeMs = 0U; | ||
uint32_t lastDataRecvTimeMs = 0U, timeSinceLastDataWasRecvdMs = 0U; | ||
TransportRecv_t recvFunc = NULL; | ||
MQTTGetCurrentTimeFunc_t getTimeStampMs = NULL; | ||
bool receiveError = false; | ||
|
@@ -704,7 +711,8 @@ static int32_t recvExact( const MQTTContext_t * pContext, | |
recvFunc = pContext->transportInterface.recv; | ||
getTimeStampMs = pContext->getTime; | ||
|
||
entryTimeMs = getTimeStampMs(); | ||
/* Part of the MQTT packet has been read before calling this function. */ | ||
lastDataRecvTimeMs = getTimeStampMs(); | ||
|
||
while( ( bytesRemaining > 0U ) && ( receiveError == false ) ) | ||
{ | ||
|
@@ -719,8 +727,11 @@ static int32_t recvExact( const MQTTContext_t * pContext, | |
totalBytesRecvd = bytesRecvd; | ||
receiveError = true; | ||
} | ||
else | ||
else if( bytesRecvd > 0 ) | ||
{ | ||
/* Reset the starting time as we have received some data from the network. */ | ||
lastDataRecvTimeMs = getTimeStampMs(); | ||
|
||
/* It is a bug in the application's transport receive implementation | ||
* if more bytes than expected are received. To avoid a possible | ||
* overflow in converting bytesRemaining from unsigned to signed, | ||
|
@@ -737,10 +748,14 @@ static int32_t recvExact( const MQTTContext_t * pContext, | |
( unsigned long ) bytesRemaining, | ||
( long int ) totalBytesRecvd ) ); | ||
} | ||
else | ||
{ | ||
/* No bytes were read from the network. */ | ||
timeSinceLastDataWasRecvdMs = calculateElapsedTime( getTimeStampMs(), lastDataRecvTimeMs ); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think mapping a return value of 0 to mean no data to receive is an assumption that needs to be documented more clearly. In POSIX sockets and OpenSSL, for example, a return value of 0 means the server has closed the connection or that a close-notify has been sent. No data to receive will actually return a negative error code. In our transport implementations, that error code is mapped appropriately to 0, but this is not going to be immediately obvious to other engineers who try to make their own transport implementations. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, I have updated that in the transport interface API doc. |
||
} | ||
|
||
elapsedTimeMs = calculateElapsedTime( getTimeStampMs(), entryTimeMs ); | ||
|
||
if( ( bytesRemaining > 0U ) && ( elapsedTimeMs >= timeoutMs ) ) | ||
if( ( bytesRemaining > 0U ) && | ||
RichardBarry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
( timeSinceLastDataWasRecvdMs >= MQTT_RECV_POLLING_TIMEOUT_MS ) ) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: I don't really like this name. Not saying it's not accurate, it's just long and I don't think the extra words are very meaningful. For example, something like I don't think the variable needs to be this explicit if it's clear from reading the rest of the code. Feel free to disagree. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I like the recommendation. Have made the name concise. |
||
{ | ||
LogError( ( "Time expired while receiving packet." ) ); | ||
receiveError = true; | ||
|
@@ -760,7 +775,6 @@ static MQTTStatus_t discardPacket( const MQTTContext_t * pContext, | |
int32_t bytesReceived = 0; | ||
size_t bytesToReceive = 0U; | ||
uint32_t totalBytesReceived = 0U, entryTimeMs = 0U, elapsedTimeMs = 0U; | ||
uint32_t remainingTimeMs = timeoutMs; | ||
MQTTGetCurrentTimeFunc_t getTimeStampMs = NULL; | ||
bool receiveError = false; | ||
|
||
|
@@ -779,7 +793,7 @@ static MQTTStatus_t discardPacket( const MQTTContext_t * pContext, | |
bytesToReceive = remainingLength - totalBytesReceived; | ||
} | ||
|
||
bytesReceived = recvExact( pContext, bytesToReceive, remainingTimeMs ); | ||
bytesReceived = recvExact( pContext, bytesToReceive ); | ||
|
||
if( bytesReceived != ( int32_t ) bytesToReceive ) | ||
{ | ||
|
@@ -795,12 +809,8 @@ static MQTTStatus_t discardPacket( const MQTTContext_t * pContext, | |
|
||
elapsedTimeMs = calculateElapsedTime( getTimeStampMs(), entryTimeMs ); | ||
|
||
/* Update remaining time and check for timeout. */ | ||
if( elapsedTimeMs < timeoutMs ) | ||
{ | ||
remainingTimeMs = timeoutMs - elapsedTimeMs; | ||
} | ||
else | ||
/* Check for timeout. */ | ||
if( elapsedTimeMs >= timeoutMs ) | ||
{ | ||
LogError( ( "Time expired while discarding packet." ) ); | ||
receiveError = true; | ||
|
@@ -846,7 +856,7 @@ static MQTTStatus_t receivePacket( const MQTTContext_t * pContext, | |
else | ||
{ | ||
bytesToReceive = incomingPacket.remainingLength; | ||
bytesReceived = recvExact( pContext, bytesToReceive, remainingTimeMs ); | ||
bytesReceived = recvExact( pContext, bytesToReceive ); | ||
|
||
if( bytesReceived == ( int32_t ) bytesToReceive ) | ||
{ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -132,6 +132,7 @@ typedef struct ProcessLoopReturns | |
MQTTStatus_t processLoopStatus; /**< @brief Return value of the process loop. */ | ||
bool incomingPublish; /**< @brief Whether the incoming packet is a publish. */ | ||
MQTTPublishInfo_t * pPubInfo; /**< @brief Publish information to be returned by the deserializer. */ | ||
uint32_t timeoutMs; /**< @brief The timeout value to call MQTT_ProcessLoop API with. */ | ||
} ProcessLoopReturns_t; | ||
|
||
/** | ||
|
@@ -365,6 +366,20 @@ static int32_t transportRecvOneByte( NetworkContext_t * pNetworkContext, | |
return 1; | ||
} | ||
|
||
/** | ||
* @brief Mocked transport returning zero bytes to simulate reception | ||
* of no data over network. | ||
*/ | ||
static int32_t transportRecvNoData( NetworkContext_t * pNetworkContext, | ||
void * pBuffer, | ||
size_t bytesToRead ) | ||
{ | ||
( void ) pNetworkContext; | ||
( void ) pBuffer; | ||
( void ) bytesToRead; | ||
return 0; | ||
} | ||
|
||
/** | ||
* @brief Initialize the transport interface with the mocked functions for | ||
* send and receive. | ||
|
@@ -405,6 +420,7 @@ static void resetProcessLoopParams( ProcessLoopReturns_t * pExpectParams ) | |
pExpectParams->processLoopStatus = MQTTSuccess; | ||
pExpectParams->incomingPublish = false; | ||
pExpectParams->pPubInfo = NULL; | ||
pExpectParams->timeoutMs = MQTT_NO_TIMEOUT_MS; | ||
} | ||
|
||
/** | ||
|
@@ -548,7 +564,7 @@ static void expectProcessLoopCalls( MQTTContext_t * const pContext, | |
} | ||
|
||
/* Expect the above calls when running MQTT_ProcessLoop. */ | ||
mqttStatus = MQTT_ProcessLoop( pContext, MQTT_NO_TIMEOUT_MS ); | ||
mqttStatus = MQTT_ProcessLoop( pContext, pExpectParams->timeoutMs ); | ||
TEST_ASSERT_EQUAL( processLoopStatus, mqttStatus ); | ||
|
||
/* Any final assertions to end the test. */ | ||
|
@@ -853,7 +869,7 @@ void test_MQTT_Connect_partial_receive() | |
|
||
setupTransportInterface( &transport ); | ||
setupNetworkBuffer( &networkBuffer ); | ||
transport.recv = transportRecvOneByte; | ||
transport.recv = transportRecvNoData; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to write a test to replicate the scenario which caught this issue:
This test should fail without this change and pass with this change. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can achieve that with the existing mock recv function that returns only byte at a time. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added the test |
||
|
||
memset( &mqttContext, 0x0, sizeof( mqttContext ) ); | ||
MQTT_Init( &mqttContext, &transport, getTime, eventCallback, &networkBuffer ); | ||
|
@@ -864,14 +880,16 @@ void test_MQTT_Connect_partial_receive() | |
incomingPacket.type = MQTT_PACKET_TYPE_CONNACK; | ||
incomingPacket.remainingLength = 2; | ||
|
||
/* Not enough time to receive entire packet, for branch coverage. This is due | ||
* to the fact the mocked receive function reads only one byte at a time. */ | ||
timeout = 1; | ||
/* Timeout in receiving entire packet, for branch coverage. This is due to the fact that the mocked | ||
* receive function always returns 0 bytes read. */ | ||
MQTT_GetIncomingPacketTypeAndLength_ExpectAnyArgsAndReturn( MQTTSuccess ); | ||
MQTT_GetIncomingPacketTypeAndLength_ReturnThruPtr_pIncomingPacket( &incomingPacket ); | ||
status = MQTT_Connect( &mqttContext, &connectInfo, NULL, timeout, &sessionPresent ); | ||
TEST_ASSERT_EQUAL_INT( MQTTRecvFailed, status ); | ||
|
||
/* Update to use mock receive function that receives one byte at a time for the | ||
* rest of the test. */ | ||
mqttContext.transportInterface.recv = transportRecvOneByte; | ||
timeout = 10; | ||
|
||
/* Not enough space for packet, discard it. */ | ||
|
@@ -1573,6 +1591,53 @@ void test_MQTT_ProcessLoop_handleIncomingPublish_Error_Paths( void ) | |
TEST_ASSERT_FALSE( isEventCallbackInvoked ); | ||
} | ||
|
||
/** | ||
* @brief This test checks that the ProcessLoop API function is able to | ||
* support receiving an entire incoming MQTT packet over the network when | ||
* the transport recv function only reads less than requested bytes at a | ||
* time, and the timeout passed to the API is "0ms". | ||
*/ | ||
void test_MQTT_ProcessLoop_Zero_Duration_And_Partial_Network_Read( void ) | ||
{ | ||
MQTTStatus_t mqttStatus; | ||
MQTTContext_t context; | ||
TransportInterface_t transport; | ||
MQTTFixedBuffer_t networkBuffer; | ||
ProcessLoopReturns_t expectParams = { 0 }; | ||
|
||
setupNetworkBuffer( &networkBuffer ); | ||
|
||
transport.send = transportSendSuccess; | ||
|
||
/* Set the transport recv function for the test to the mock function that represents | ||
* partial read of data from network (i.e. less than requested number of bytes) | ||
* at a time. */ | ||
transport.recv = transportRecvOneByte; | ||
|
||
/* Initialize the context. */ | ||
mqttStatus = MQTT_Init( &context, &transport, getTime, eventCallback, &networkBuffer ); | ||
TEST_ASSERT_EQUAL( MQTTSuccess, mqttStatus ); | ||
|
||
/* Set flag required for configuring behavior of expectProcessLoopCalls() | ||
* helper function. */ | ||
modifyIncomingPacketStatus = MQTTSuccess; | ||
|
||
/* Test the ProcessLoop() call with zero duration timeout to verify that it | ||
* will be able to support reading the packet over network over multiple calls to | ||
* the transport receive function. */ | ||
expectParams.timeoutMs = MQTT_NO_TIMEOUT_MS; | ||
|
||
/* Test with an incoming PUBLISH packet whose payload is read only one byte | ||
* per call to the transport recv function. */ | ||
currentPacketType = MQTT_PACKET_TYPE_PUBLISH; | ||
/* Set expected return values during the loop. */ | ||
resetProcessLoopParams( &expectParams ); | ||
expectParams.stateAfterDeserialize = MQTTPubAckSend; | ||
expectParams.stateAfterSerialize = MQTTPublishDone; | ||
expectParams.incomingPublish = true; | ||
expectProcessLoopCalls( &context, &expectParams ); | ||
} | ||
|
||
/** | ||
* @brief This test case covers all calls to the private method, | ||
* handleIncomingAck(...), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Due to the removal of this parameter, some more timeouts could probably be removed due to similar reasoning. I don't think it's critical though and definitely shouldn't be done in this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I agree that the timeout passed to
discardPacket
and the timeout logic in the function can also be removed (as hygiene). However, I decided to keep the scope of this PR limited to just fixing therecvExact
issue.