Skip to content

Commit

Permalink
For backward compatibility, the legacy message consumer calls should …
Browse files Browse the repository at this point in the history
…ignore 'connected' events.
  • Loading branch information
fpagliughi committed Jul 9, 2024
1 parent 1c08a3e commit c892569
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 27 deletions.
41 changes: 29 additions & 12 deletions include/mqtt/async_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -782,13 +782,21 @@ class async_client : public virtual iasync_client
const_message_ptr* msg, const std::chrono::duration<Rep, Period>& relTime
) {
event_type evt;
if (!que_->try_get_for(&evt, relTime))
return false;

if (const auto* pval = std::get_if<const_message_ptr>(&evt))
*msg = std::move(*pval);
else
*msg = const_message_ptr{};
while (true) {
if (!que_->try_get_for(&evt, relTime))
return false;

if (const auto* pval = std::get_if<const_message_ptr>(&evt)) {
*msg = std::move(*pval);
break;
}

if (!std::holds_alternative<connected_event>(evt)) {
*msg = const_message_ptr{};
break;
}
}
return true;
}
/**
Expand Down Expand Up @@ -817,13 +825,22 @@ class async_client : public virtual iasync_client
const_message_ptr* msg, const std::chrono::time_point<Clock, Duration>& absTime
) {
event_type evt;
if (!que_->try_get_until(&evt, absTime))
return false;

if (const auto* pval = std::get_if<const_message_ptr>(&evt))
*msg = std::move(*pval);
else
*msg = const_message_ptr{};
while (true) {
if (!que_->try_get_until(&evt, absTime))
return false;

if (const auto* pval = std::get_if<const_message_ptr>(&evt)) {
*msg = std::move(*pval);
break;
}

if (!std::holds_alternative<connected_event>(evt)) {
*msg = const_message_ptr{};
break;
}
}

return true;
}
/**
Expand Down
35 changes: 25 additions & 10 deletions src/async_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -876,22 +876,37 @@ void async_client::stop_consuming()

const_message_ptr async_client::consume_message()
{
auto evt = que_->get();
if (const auto* pval = std::get_if<const_message_ptr>(&evt))
return *pval;
return const_message_ptr{};
// For backward compatibility we ignore the 'connected' events,
// whereas disconnected/lost return an empty pointer.
while (true) {
auto evt = que_->get();

if (const auto* pval = std::get_if<const_message_ptr>(&evt))
return *pval;

if (!std::holds_alternative<connected_event>(evt))
return const_message_ptr{};
}
}

bool async_client::try_consume_message(const_message_ptr* msg)
{
event_type evt;
if (!que_->try_get(&evt))
return false;

if (const auto* pval = std::get_if<const_message_ptr>(&evt))
*msg = std::move(*pval);
else
*msg = const_message_ptr{};
while (true) {
if (!que_->try_get(&evt))
return false;

if (const auto* pval = std::get_if<const_message_ptr>(&evt)) {
*msg = std::move(*pval);
break;
}

if (!std::holds_alternative<connected_event>(evt)) {
*msg = const_message_ptr{};
break;
}
}
return true;
}

Expand Down
48 changes: 43 additions & 5 deletions test/unit/test_thread_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,34 @@ TEST_CASE("que put/get", "[thread_queue]")
REQUIRE(que.get() == 3);
}

TEST_CASE("que tryget", "[thread_queue]")
{
thread_queue<int> que;
int n;

// try_get's should fail on empty queue
REQUIRE(!que.try_get(&n));
REQUIRE(!que.try_get_for(&n, 5ms));

auto timeout = steady_clock::now() + 15ms;
REQUIRE(!que.try_get_until(&n, timeout));

que.put(1);
que.put(2);
REQUIRE(que.try_get(&n));
REQUIRE(n == 1);

que.put(3);
REQUIRE(que.try_get(&n));
REQUIRE(n == 2);
REQUIRE(que.try_get(&n));
REQUIRE(n == 3);

// Empty now. Try should fail and leave 'n' unchanged
REQUIRE(!que.try_get(&n));
REQUIRE(n == 3);
}

TEST_CASE("que mt put/get", "[thread_queue]")
{
thread_queue<string> que;
Expand All @@ -54,9 +82,13 @@ TEST_CASE("que mt put/get", "[thread_queue]")

auto producer = [&que, &N]() {
string s;
for (size_t i = 0; i < 512; ++i) s.push_back('a' + i % 26);
for (size_t i = 0; i < 512; ++i) {
s.push_back('a' + i % 26);
}

for (size_t i = 0; i < N; ++i) que.put(s);
for (size_t i = 0; i < N; ++i) {
que.put(s);
}
};

auto consumer = [&que, &N]() {
Expand All @@ -71,11 +103,17 @@ TEST_CASE("que mt put/get", "[thread_queue]")
std::vector<std::thread> producers;
std::vector<std::future<bool>> consumers;

for (size_t i = 0; i < N_THR; ++i) producers.push_back(std::thread(producer));
for (size_t i = 0; i < N_THR; ++i) {
producers.push_back(std::thread(producer));
}

for (size_t i = 0; i < N_THR; ++i) consumers.push_back(std::async(consumer));
for (size_t i = 0; i < N_THR; ++i) {
consumers.push_back(std::async(consumer));
}

for (size_t i = 0; i < N_THR; ++i) producers[i].join();
for (size_t i = 0; i < N_THR; ++i) {
producers[i].join();
}

for (size_t i = 0; i < N_THR; ++i) {
REQUIRE(consumers[i].get());
Expand Down

0 comments on commit c892569

Please sign in to comment.