From 6661251489c1fb8d38c9b0b2792e60fe7631a86b Mon Sep 17 00:00:00 2001 From: Bright Chen Date: Tue, 23 Jan 2024 13:22:08 +0800 Subject: [PATCH] Support A Multiple Producer, Single Consumer Queue (#2492) --- src/butil/containers/mpsc_queue.h | 189 ++++++++++++++++++++++++++++++ src/butil/thread_key.h | 6 +- test/BUILD.bazel | 1 + test/CMakeLists.txt | 1 + test/Makefile | 1 + test/mpsc_queue_unittest.cc | 124 ++++++++++++++++++++ 6 files changed, 319 insertions(+), 3 deletions(-) create mode 100644 src/butil/containers/mpsc_queue.h create mode 100644 test/mpsc_queue_unittest.cc diff --git a/src/butil/containers/mpsc_queue.h b/src/butil/containers/mpsc_queue.h new file mode 100644 index 0000000000..4c7072da97 --- /dev/null +++ b/src/butil/containers/mpsc_queue.h @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// A Multiple Producer, Single Consumer Queue. +// It allows multiple threads to enqueue, and allows one thread +// (and only one thread) to dequeue. + +#ifndef BUTIL_MPSC_QUEUE_H +#define BUTIL_MPSC_QUEUE_H + +#include "butil/object_pool.h" +#include "butil/type_traits.h" + +namespace butil { + +template +struct BAIDU_CACHELINE_ALIGNMENT MPSCQueueNode { + static MPSCQueueNode* const UNCONNECTED; + + MPSCQueueNode* next{NULL}; + char data_mem[sizeof(T)]{}; + +}; + +template +MPSCQueueNode* const MPSCQueueNode::UNCONNECTED = (MPSCQueueNode*)(intptr_t)-1; + +// Default allocator for MPSCQueueNode. +template +class DefaultAllocator { +public: + void* Alloc() { return malloc(sizeof(MPSCQueueNode)); } + void Free(void* p) { free(p); } +}; + +// Allocator using ObjectPool for MPSCQueueNode. +template +class ObjectPoolAllocator { +public: + void* Alloc() { return get_object>(); } + void Free(void* p) { return_object(p); } +}; + + +template > +class MPSCQueue { +public: + MPSCQueue() + : _head(NULL) + , _cur_enqueue_node(NULL) + , _cur_dequeue_node(NULL) {} + + ~MPSCQueue(); + + // Enqueue data to the queue. + void Enqueue(typename add_const_reference::type data); + void Enqueue(T&& data); + + // Dequeue data from the queue. + bool Dequeue(T& data); + +private: + // Reverse the list until old_head. + void ReverseList(MPSCQueueNode* old_head); + + void EnqueueImpl(MPSCQueueNode* node); + bool DequeueImpl(T* data); + + Alloc _alloc; + atomic*> _head; + atomic*> _cur_enqueue_node; + MPSCQueueNode* _cur_dequeue_node; +}; + +template +MPSCQueue::~MPSCQueue() { + while (DequeueImpl(NULL)); +} + +template +void MPSCQueue::Enqueue(typename add_const_reference::type data) { + auto node = (MPSCQueueNode*)_alloc.Alloc(); + node->next = MPSCQueueNode::UNCONNECTED; + new ((void*)&node->data_mem) T(data); + EnqueueImpl(node); +} + +template +void MPSCQueue::Enqueue(T&& data) { + auto node = (MPSCQueueNode*)_alloc.Alloc(); + node->next = MPSCQueueNode::UNCONNECTED; + new ((void*)&node->data_mem) T(std::forward(data)); + EnqueueImpl(node); +} + +template +void MPSCQueue::EnqueueImpl(MPSCQueueNode* node) { + MPSCQueueNode* prev = _head.exchange(node, memory_order_release); + if (prev) { + node->next = prev; + return; + } + node->next = NULL; + _cur_enqueue_node.store(node, memory_order_relaxed); +} + +template +bool MPSCQueue::Dequeue(T& data) { + return DequeueImpl(&data); +} + +template +bool MPSCQueue::DequeueImpl(T* data) { + MPSCQueueNode* node; + if (_cur_dequeue_node) { + node = _cur_dequeue_node; + } else { + node = _cur_enqueue_node.load(memory_order_relaxed); + } + if (!node) { + return false; + } + + _cur_enqueue_node.store(NULL, memory_order_relaxed); + if (data) { + auto mem = (T* const)node->data_mem; + *data = std::move(*mem); + } + MPSCQueueNode* old_node = node; + if (!node->next) { + ReverseList(node); + } + _cur_dequeue_node = node->next; + return_object(old_node); + + return true; +} + +template +void MPSCQueue::ReverseList(MPSCQueueNode* old_head) { + // Try to set _write_head to NULL to mark that it is done. + MPSCQueueNode* new_head = old_head; + MPSCQueueNode* desired = NULL; + if (_head.compare_exchange_strong( + new_head, desired, memory_order_acquire)) { + // No one added new requests. + return; + } + CHECK_NE(new_head, old_head); + // Above acquire fence pairs release fence of exchange in Enqueue() to make + // sure that we see all fields of requests set. + + // Someone added new requests. + // Reverse the list until old_head. + MPSCQueueNode* tail = NULL; + MPSCQueueNode* p = new_head; + do { + while (p->next == MPSCQueueNode::UNCONNECTED) { + // TODO(gejun): elaborate this + sched_yield(); + } + MPSCQueueNode* const saved_next = p->next; + p->next = tail; + tail = p; + p = saved_next; + CHECK(p); + } while (p != old_head); + + // Link old list with new list. + old_head->next = tail; +} + +} + +#endif // BUTIL_MPSC_QUEUE_H diff --git a/src/butil/thread_key.h b/src/butil/thread_key.h index 48f02f7d02..f8d8f0e47c 100644 --- a/src/butil/thread_key.h +++ b/src/butil/thread_key.h @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -#ifndef BRPC_THREAD_KEY_H -#define BRPC_THREAD_KEY_H +#ifndef BUTIL_THREAD_KEY_H +#define BUTIL_THREAD_KEY_H #include #include @@ -199,4 +199,4 @@ void ThreadLocal::reset(T* ptr) { } -#endif //BRPC_THREAD_KEY_H +#endif // BUTIL_THREAD_KEY_H diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 8c57c10a1e..3bf7cd4597 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -53,6 +53,7 @@ TEST_BUTIL_SOURCES = [ "mru_cache_unittest.cc", "small_map_unittest.cc", "stack_container_unittest.cc", + "mpsc_queue_unittest.cc", "cpu_unittest.cc", "crash_logging_unittest.cc", "leak_tracker_unittest.cc", diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5aaf7d3ab4..a62273182d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -88,6 +88,7 @@ SET(TEST_BUTIL_SOURCES ${PROJECT_SOURCE_DIR}/test/mru_cache_unittest.cc ${PROJECT_SOURCE_DIR}/test/small_map_unittest.cc ${PROJECT_SOURCE_DIR}/test/stack_container_unittest.cc + ${PROJECT_SOURCE_DIR}/test/mpsc_queue_unittest.cc ${PROJECT_SOURCE_DIR}/test/cpu_unittest.cc ${PROJECT_SOURCE_DIR}/test/crash_logging_unittest.cc ${PROJECT_SOURCE_DIR}/test/leak_tracker_unittest.cc diff --git a/test/Makefile b/test/Makefile index 6e0dbc9752..82efc38103 100644 --- a/test/Makefile +++ b/test/Makefile @@ -58,6 +58,7 @@ TEST_BUTIL_SOURCES = \ mru_cache_unittest.cc \ small_map_unittest.cc \ stack_container_unittest.cc \ + mpsc_queue_unittest.cc \ cpu_unittest.cc \ crash_logging_unittest.cc \ leak_tracker_unittest.cc \ diff --git a/test/mpsc_queue_unittest.cc b/test/mpsc_queue_unittest.cc new file mode 100644 index 0000000000..da67cfb7fd --- /dev/null +++ b/test/mpsc_queue_unittest.cc @@ -0,0 +1,124 @@ +#include +#include +#include "butil/containers/mpsc_queue.h" + +namespace { + +const uint MAX_COUNT = 10000000; + +void Consume(butil::MPSCQueue& q, bool allow_empty) { + uint i = 0; + uint empty_count = 0; + while (true) { + uint d; + if (!q.Dequeue(d)) { + ASSERT_TRUE(allow_empty); + ASSERT_LT(empty_count++, (const uint)10000); + ::usleep(10 * 1000); + continue; + } + ASSERT_EQ(i++, d); + if (i == MAX_COUNT) { + break; + } + } +} + +void* ProduceThread(void* arg) { + auto q = (butil::MPSCQueue*)arg; + for (uint i = 0; i < MAX_COUNT; ++i) { + q->Enqueue(i); + } + return NULL; +} + +void* ConsumeThread1(void* arg) { + auto q = (butil::MPSCQueue*)arg; + Consume(*q, true); + return NULL; +} + +TEST(MPSCQueueTest, spsc_single_thread) { + butil::MPSCQueue q; + for (uint i = 0; i < MAX_COUNT; ++i) { + q.Enqueue(i); + } + Consume(q, false); +} + +TEST(MPSCQueueTest, spsc_multi_thread) { + butil::MPSCQueue q; + pthread_t produce_tid; + ASSERT_EQ(0, pthread_create(&produce_tid, NULL, ProduceThread, &q)); + pthread_t consume_tid; + ASSERT_EQ(0, pthread_create(&consume_tid, NULL, ConsumeThread1, &q)); + + pthread_join(produce_tid, NULL); + pthread_join(consume_tid, NULL); + +} + +butil::atomic g_index(0); +void* MultiProduceThread(void* arg) { + auto q = (butil::MPSCQueue*)arg; + while (true) { + uint i = g_index.fetch_add(1, butil::memory_order_relaxed); + if (i >= MAX_COUNT) { + break; + } + q->Enqueue(i); + } + return NULL; +} + +butil::Mutex g_mutex; +bool g_counts[MAX_COUNT]; +void Consume2(butil::MPSCQueue& q) { + uint empty_count = 0; + uint count = 0; + while (true) { + uint d; + if (!q.Dequeue(d)) { + ASSERT_LT(empty_count++, (const uint)10000); + ::usleep(1 * 1000); + continue; + } + ASSERT_LT(d, MAX_COUNT); + { + BAIDU_SCOPED_LOCK(g_mutex); + ASSERT_FALSE(g_counts[d]); + g_counts[d] = true; + } + if (++count >= MAX_COUNT) { + break; + } + } +} + +void* ConsumeThread2(void* arg) { + auto q = (butil::MPSCQueue*)arg; + Consume2(*q); + return NULL; +} + +TEST(MPSCQueueTest, mpsc_multi_thread) { + butil::MPSCQueue q; + + int thread_num = 8; + pthread_t threads[thread_num]; + for (int i = 0; i < thread_num; ++i) { + ASSERT_EQ(0, pthread_create(&threads[i], NULL, MultiProduceThread, &q)); + } + + pthread_t consume_tid; + ASSERT_EQ(0, pthread_create(&consume_tid, NULL, ConsumeThread2, &q)); + + for (int i = 0; i < thread_num; ++i) { + pthread_join(threads[i], NULL); + } + pthread_join(consume_tid, NULL); + +} + + +} \ No newline at end of file