diff --git a/include/aws/mqtt/private/client_impl.h b/include/aws/mqtt/private/client_impl.h index 68efb114..18e6fc21 100644 --- a/include/aws/mqtt/private/client_impl.h +++ b/include/aws/mqtt/private/client_impl.h @@ -308,6 +308,11 @@ struct aws_mqtt_client_connection_311_impl { */ uint16_t packet_id; + /** + * The last request complete time + */ + uint64_t last_request_write_timestamp; + } synced_data; struct { diff --git a/source/client.c b/source/client.c index 8968a8cc..43e7a541 100644 --- a/source/client.c +++ b/source/client.c @@ -93,6 +93,15 @@ void mqtt_connection_set_state( connection->synced_data.state = state; } +/* Set the socket write timestamp to current clock time */ +void s_mqtt_connection_sync_socket_write_time(struct aws_mqtt_client_connection_311_impl *connection) { + ASSERT_SYNCED_DATA_LOCK_HELD(connection); + if (connection->slot != NULL && connection->slot->channel != NULL) { + aws_channel_current_clock_time( + connection->slot->channel, &connection->synced_data.last_request_write_timestamp); + } +} + struct request_timeout_wrapper; /* used for timeout task */ @@ -1841,6 +1850,8 @@ static enum aws_mqtt_client_request_state s_subscribe_send(uint16_t packet_id, b */ if (aws_channel_slot_send_message(task_arg->connection->slot, message, AWS_CHANNEL_DIR_WRITE)) { aws_mem_release(message->allocator, message); + } else { + s_mqtt_connection_sync_socket_write_time(task_arg->connection); } if (!task_arg->tree_updated) { @@ -2295,6 +2306,8 @@ static enum aws_mqtt_client_request_state s_resubscribe_send( /* This is not necessarily a fatal error; if the send fails, it'll just retry. Still need to clean up though. */ if (aws_channel_slot_send_message(task_arg->connection->slot, message, AWS_CHANNEL_DIR_WRITE)) { aws_mem_release(message->allocator, message); + } else { + s_mqtt_connection_sync_socket_write_time(task_arg->connection); } return AWS_MQTT_CLIENT_REQUEST_ONGOING; @@ -2535,6 +2548,8 @@ static enum aws_mqtt_client_request_state s_unsubscribe_send( if (aws_channel_slot_send_message(task_arg->connection->slot, message, AWS_CHANNEL_DIR_WRITE)) { goto handle_error; + } else { + s_mqtt_connection_sync_socket_write_time(task_arg->connection); } /* TODO: timing should start from the message written into the socket, which is aws_io_message->on_completion @@ -2810,6 +2825,8 @@ static enum aws_mqtt_client_request_state s_publish_send(uint16_t packet_id, boo /* If it's QoS 0, telling user that the message haven't been sent, else, the message will be resent once the * connection is back */ return is_qos_0 ? AWS_MQTT_CLIENT_REQUEST_ERROR : AWS_MQTT_CLIENT_REQUEST_ONGOING; + } else { + s_mqtt_connection_sync_socket_write_time(task_arg->connection); } /* If there's still payload left, get a new message and start again. */ diff --git a/source/client_channel_handler.c b/source/client_channel_handler.c index 6f5502d6..77baab2e 100644 --- a/source/client_channel_handler.c +++ b/source/client_channel_handler.c @@ -32,6 +32,16 @@ static void s_update_next_ping_time(struct aws_mqtt_client_connection_311_impl * } } +/* push off next ping time on ack received. The function must be called in critical section. */ +static void s_pushoff_next_ping_time(struct aws_mqtt_client_connection_311_impl *connection) { + ASSERT_SYNCED_DATA_LOCK_HELD(connection); + uint64_t last_socket_write_time = connection->synced_data.last_request_write_timestamp; + aws_add_u64_checked(last_socket_write_time, connection->keep_alive_time_ns, &last_socket_write_time); + if (last_socket_write_time > connection->next_ping_time) { + connection->next_ping_time = last_socket_write_time; + } +} + /******************************************************************************* * Packet State Machine ******************************************************************************/ @@ -426,7 +436,6 @@ static int s_packet_handler_unsuback(struct aws_byte_cursor message_cursor, void AWS_LS_MQTT_CLIENT, "id=%p: received ack for message id %" PRIu16, (void *)connection, ack.packet_identifier); mqtt_request_complete(connection, AWS_ERROR_SUCCESS, ack.packet_identifier); - return AWS_OP_SUCCESS; } @@ -528,7 +537,6 @@ static int s_packet_handler_pubcomp(struct aws_byte_cursor message_cursor, void AWS_LS_MQTT_CLIENT, "id=%p: received ack for message id %" PRIu16, (void *)connection, ack.packet_identifier); mqtt_request_complete(connection, AWS_ERROR_SUCCESS, ack.packet_identifier); - return AWS_OP_SUCCESS; } @@ -813,9 +821,6 @@ static void s_request_outgoing_task(struct aws_channel_task *task, void *arg, en aws_mqtt_connection_statistics_change_operation_statistic_state( request->connection, request, AWS_MQTT_OSS_NONE); - /* Since a request has complete, update the next ping time */ - s_update_next_ping_time(connection); - aws_hash_table_remove( &connection->synced_data.outstanding_requests_table, &request->packet_id, NULL, NULL); aws_memory_pool_release(&connection->synced_data.requests_pool, request); @@ -837,9 +842,6 @@ static void s_request_outgoing_task(struct aws_channel_task *task, void *arg, en aws_mqtt_connection_statistics_change_operation_statistic_state( request->connection, request, AWS_MQTT_OSS_INCOMPLETE | AWS_MQTT_OSS_UNACKED); - /* Since a request has complete, update the next ping time */ - s_update_next_ping_time(connection); - mqtt_connection_unlock_synced_data(connection); } /* END CRITICAL SECTION */ @@ -1003,6 +1005,7 @@ void mqtt_request_complete(struct aws_mqtt_client_connection_311_impl *connectio { /* BEGIN CRITICAL SECTION */ mqtt_connection_lock_synced_data(connection); + s_pushoff_next_ping_time(connection); struct aws_hash_element *elem = NULL; aws_hash_table_find(&connection->synced_data.outstanding_requests_table, &packet_id, &elem); if (elem != NULL) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 407bd07f..7af8c4eb 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -75,6 +75,7 @@ add_test_case(mqtt_connection_unsub_timeout) add_test_case(mqtt_connection_publish_QoS1_timeout_connection_lost_reset_time) add_test_case(mqtt_connection_ping_norm) add_test_case(mqtt_connection_ping_no) +add_test_case(mqtt_connection_ping_noack) add_test_case(mqtt_connection_ping_basic_scenario) add_test_case(mqtt_connection_ping_double_scenario) add_test_case(mqtt_connection_close_callback_simple) diff --git a/tests/v3/connection_state_test.c b/tests/v3/connection_state_test.c index 854b0b59..8b2aada4 100644 --- a/tests/v3/connection_state_test.c +++ b/tests/v3/connection_state_test.c @@ -3366,8 +3366,8 @@ AWS_TEST_CASE_FIXTURE( &test_data) /** - * Makes a CONNECT, with 1 second keep alive ping interval, send a publish roughly every second, and then ensure NO - * pings were sent + * Makes a CONNECT, with 1 second keep alive ping interval, send a different operation every second, and then ensure NO + * pings were sent. (The ping time will be push off on ack ) */ static int s_test_mqtt_connection_ping_no_fn(struct aws_allocator *allocator, void *ctx) { (void)allocator; @@ -3387,10 +3387,102 @@ static int s_test_mqtt_connection_ping_no_fn(struct aws_allocator *allocator, vo ASSERT_SUCCESS(aws_mqtt_client_connection_connect(state_test_data->mqtt_connection, &connection_options)); s_wait_for_connection_to_complete(state_test_data); + struct aws_byte_cursor pub_topic = aws_byte_cursor_from_c_str("/test/topic"); + struct aws_byte_cursor payload_1 = aws_byte_cursor_from_c_str("Test Message 1"); + + /* Publish */ + uint16_t packet_id = aws_mqtt_client_connection_publish( + state_test_data->mqtt_connection, + &pub_topic, + AWS_MQTT_QOS_AT_LEAST_ONCE, + false, + &payload_1, + s_on_op_complete, + state_test_data); + ASSERT_TRUE(packet_id > 0); + + /* Wait 0.8 seconds */ + aws_thread_current_sleep(800000000); + + /* Subscribe */ + packet_id = aws_mqtt_client_connection_subscribe( + state_test_data->mqtt_connection, + &pub_topic, + AWS_MQTT_QOS_AT_LEAST_ONCE, + s_on_publish_received, + state_test_data, + NULL, + s_on_suback, + state_test_data); + ASSERT_TRUE(packet_id > 0); + + /* Wait 0.8 seconds */ + aws_thread_current_sleep(800000000); + + /* Resub */ + uint16_t resub_packet_id = + aws_mqtt_resubscribe_existing_topics(state_test_data->mqtt_connection, s_on_multi_suback, state_test_data); + ASSERT_TRUE(resub_packet_id > 0); + /* Wait 0.8 seconds */ + aws_thread_current_sleep(800000000); + + /*Unsub*/ + packet_id = aws_mqtt_client_connection_unsubscribe( + state_test_data->mqtt_connection, &pub_topic, s_on_op_complete, state_test_data); + ASSERT_TRUE(packet_id > 0); + + /* Wait 0.8 seconds */ + aws_thread_current_sleep(800000000); + + /* Ensure the server got 0 PING packets */ + ASSERT_INT_EQUALS(0, mqtt_mock_server_get_ping_count(state_test_data->mock_server)); + + ASSERT_SUCCESS( + aws_mqtt_client_connection_disconnect(state_test_data->mqtt_connection, s_on_disconnect_fn, state_test_data)); + s_wait_for_disconnect_to_complete(state_test_data); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE_FIXTURE( + mqtt_connection_ping_no, + s_setup_mqtt_server_fn, + s_test_mqtt_connection_ping_no_fn, + s_clean_up_mqtt_server_fn, + &test_data) + +/** + * Makes a CONNECT, with 1 second keep alive ping interval, disable the server auto ack so that we never received ack + * back. Send a qos1 publish roughly every second for 4 seconds. As we never received the ACK, we should send a total of + * 4 ping + */ +static int s_test_mqtt_connection_ping_noack_fn(struct aws_allocator *allocator, void *ctx) { + (void)allocator; + struct mqtt_connection_state_test *state_test_data = ctx; + + struct aws_mqtt_connection_options connection_options = { + .user_data = state_test_data, + .clean_session = true, + .client_id = aws_byte_cursor_from_c_str("client1234"), + .host_name = aws_byte_cursor_from_c_str(state_test_data->endpoint.address), + .socket_options = &state_test_data->socket_options, + .on_connection_complete = s_on_connection_complete_fn, + .keep_alive_time_secs = 1, + .ping_timeout_ms = 100, + }; + + ASSERT_SUCCESS(aws_mqtt_client_connection_connect(state_test_data->mqtt_connection, &connection_options)); + s_wait_for_connection_to_complete(state_test_data); + + /* Disable the auto ACK packets sent by the server, to pretend to be a bad network */ + mqtt_mock_server_disable_auto_ack(state_test_data->mock_server); + + struct aws_byte_cursor pub_topic = aws_byte_cursor_from_c_str("/test/topic"); + struct aws_byte_cursor payload_1 = aws_byte_cursor_from_c_str("Test Message 1"); + for (int i = 0; i < 4; i++) { - struct aws_byte_cursor pub_topic = aws_byte_cursor_from_c_str("/test/topic"); - struct aws_byte_cursor payload_1 = aws_byte_cursor_from_c_str("Test Message 1"); - uint16_t packet_id_1 = aws_mqtt_client_connection_publish( + /* Publish qos1*/ + uint16_t packet_id = aws_mqtt_client_connection_publish( state_test_data->mqtt_connection, &pub_topic, AWS_MQTT_QOS_AT_LEAST_ONCE, @@ -3398,14 +3490,17 @@ static int s_test_mqtt_connection_ping_no_fn(struct aws_allocator *allocator, vo &payload_1, s_on_op_complete, state_test_data); - ASSERT_TRUE(packet_id_1 > 0); + ASSERT_TRUE(packet_id > 0); /* Wait 0.8 seconds */ aws_thread_current_sleep(800000000); } - /* Ensure the server got 0 PING packets */ - ASSERT_INT_EQUALS(0, mqtt_mock_server_get_ping_count(state_test_data->mock_server)); + /* + * We would like to wait for a total of ~4.5 seconds to account for slight drift/jitter. + * We have been waiting for 0.8*4=3.2 sec. Wait for another 1 sec here. + */ + aws_thread_current_sleep(1000000000); ASSERT_SUCCESS( aws_mqtt_client_connection_disconnect(state_test_data->mqtt_connection, s_on_disconnect_fn, state_test_data)); @@ -3415,9 +3510,9 @@ static int s_test_mqtt_connection_ping_no_fn(struct aws_allocator *allocator, vo } AWS_TEST_CASE_FIXTURE( - mqtt_connection_ping_no, + mqtt_connection_ping_noack, s_setup_mqtt_server_fn, - s_test_mqtt_connection_ping_no_fn, + s_test_mqtt_connection_ping_noack_fn, s_clean_up_mqtt_server_fn, &test_data) diff --git a/tests/v3/mqtt_mock_server_handler.c b/tests/v3/mqtt_mock_server_handler.c index 13f7171f..51bb448f 100644 --- a/tests/v3/mqtt_mock_server_handler.c +++ b/tests/v3/mqtt_mock_server_handler.c @@ -187,7 +187,9 @@ static int s_mqtt_mock_server_handler_process_packet( bool auto_ack = server->synced.auto_ack; aws_mutex_unlock(&server->synced.lock); - if (auto_ack) { + uint8_t qos = (publish_packet.fixed_header.flags >> 1) & 0x3; + // Do not send puback if qos0 + if (auto_ack && qos != 0) { struct aws_io_message *puback_msg = aws_channel_acquire_message_from_pool(server->slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, 256); struct aws_mqtt_packet_ack puback;