From 3d51765d7be63b299c0cb095dd3d03b1d00b03bf Mon Sep 17 00:00:00 2001 From: John Plevyak Date: Thu, 10 Oct 2019 13:04:14 -0700 Subject: [PATCH] Merge pull request #111 from jplevyak/release-1.3-use-after-free Apply fix for use-after-free in Envoy ThreadLocal Slot. Signed-off-by: John Plevyak --- include/envoy/thread_local/thread_local.h | 20 ++++ source/common/common/non_copyable.h | 13 +- source/common/config/config_provider_impl.h | 8 +- source/common/router/rds_impl.cc | 8 +- .../common/thread_local/thread_local_impl.cc | 113 ++++++++++++++++-- .../common/thread_local/thread_local_impl.h | 44 ++++++- .../thread_local/thread_local_impl_test.cc | 86 ++++++++++++- test/mocks/thread_local/mocks.h | 10 ++ 8 files changed, 281 insertions(+), 21 deletions(-) diff --git a/include/envoy/thread_local/thread_local.h b/include/envoy/thread_local/thread_local.h index c6eb0b54bb4a..b464c2bb59d4 100644 --- a/include/envoy/thread_local/thread_local.h +++ b/include/envoy/thread_local/thread_local.h @@ -27,6 +27,15 @@ typedef std::shared_ptr ThreadLocalObjectSharedPtr; class Slot { public: virtual ~Slot() {} + /** + * Returns if there is thread local data for this thread. + * + * This should return true for Envoy worker threads and false for threads which do not have thread + * local storage allocated. + * + * @return true if registerThread has been called for this thread, false otherwise. + */ + virtual bool currentThreadRegistered() PURE; /** * @return ThreadLocalObjectSharedPtr a thread local object stored in the slot. @@ -64,6 +73,17 @@ class Slot { */ typedef std::function InitializeCb; virtual void set(InitializeCb cb) PURE; + + /** + * UpdateCb takes the current stored data, and returns an updated/new version data. + * TLS will run the callback and replace the stored data with the returned value *in each thread*. + * + * NOTE: The update callback is not supposed to capture the Slot, or its owner. 
As the owner may + * be destructed in main thread before the update_cb gets called in a worker thread. + **/ + using UpdateCb = std::function; + virtual void runOnAllThreads(const UpdateCb& update_cb) PURE; + virtual void runOnAllThreads(const UpdateCb& update_cb, Event::PostCb complete_cb) PURE; }; typedef std::unique_ptr SlotPtr; diff --git a/source/common/common/non_copyable.h b/source/common/common/non_copyable.h index 7c394c41de18..34fb35cad00a 100644 --- a/source/common/common/non_copyable.h +++ b/source/common/common/non_copyable.h @@ -2,14 +2,19 @@ namespace Envoy { /** - * Mixin class that makes derived classes not copyable. Like boost::noncopyable without boost. + * Mixin class that makes derived classes not copyable and not moveable. Like boost::noncopyable + * without boost. */ class NonCopyable { protected: NonCopyable() {} -private: - NonCopyable(const NonCopyable&); - NonCopyable& operator=(const NonCopyable&); + // Non-moveable. + NonCopyable(NonCopyable&&) noexcept = delete; + NonCopyable& operator=(NonCopyable&&) noexcept = delete; + + // Non-copyable. 
+ NonCopyable(const NonCopyable&) = delete; + NonCopyable& operator=(const NonCopyable&) = delete; }; } // namespace Envoy diff --git a/source/common/config/config_provider_impl.h b/source/common/config/config_provider_impl.h index ddcd1232caa1..85827eeabcab 100644 --- a/source/common/config/config_provider_impl.h +++ b/source/common/config/config_provider_impl.h @@ -369,8 +369,12 @@ class MutableConfigProviderBase : public MutableConfigProviderCommonBase { if (getConfig() == config) { return; } - tls_->runOnAllThreads( - [this, config]() -> void { tls_->getTyped().config_ = config; }); + tls_->runOnAllThreads([config](ThreadLocal::ThreadLocalObjectSharedPtr previous) + -> ThreadLocal::ThreadLocalObjectSharedPtr { + auto prev_thread_local_config = std::dynamic_pointer_cast(previous); + prev_thread_local_config->config_ = config; + return previous; + }); } protected: diff --git a/source/common/router/rds_impl.cc b/source/common/router/rds_impl.cc index 62a0fd0786e3..5d4a63e9567b 100644 --- a/source/common/router/rds_impl.cc +++ b/source/common/router/rds_impl.cc @@ -167,8 +167,12 @@ Router::ConfigConstSharedPtr RdsRouteConfigProviderImpl::config() { void RdsRouteConfigProviderImpl::onConfigUpdate() { ConfigConstSharedPtr new_config( new ConfigImpl(config_update_info_->routeConfiguration(), factory_context_, false)); - tls_->runOnAllThreads( - [this, new_config]() -> void { tls_->getTyped().config_ = new_config; }); + tls_->runOnAllThreads([new_config](ThreadLocal::ThreadLocalObjectSharedPtr previous) + -> ThreadLocal::ThreadLocalObjectSharedPtr { + auto prev_config = std::dynamic_pointer_cast(previous); + prev_config->config_ = new_config; + return previous; + }); } RouteConfigProviderManagerImpl::RouteConfigProviderManagerImpl(Server::Admin& admin) { diff --git a/source/common/thread_local/thread_local_impl.cc b/source/common/thread_local/thread_local_impl.cc index 14bcf251bd5c..25615925f7d7 100644 --- a/source/common/thread_local/thread_local_impl.cc +++ 
b/source/common/thread_local/thread_local_impl.cc @@ -1,5 +1,6 @@ #include "common/thread_local/thread_local_impl.h" +#include #include #include #include @@ -24,17 +25,30 @@ SlotPtr InstanceImpl::allocateSlot() { ASSERT(std::this_thread::get_id() == main_thread_id_); ASSERT(!shutdown_); - for (uint64_t i = 0; i < slots_.size(); i++) { - if (slots_[i] == nullptr) { - std::unique_ptr slot(new SlotImpl(*this, i)); - slots_[i] = slot.get(); - return slot; - } + if (free_slot_indexes_.empty()) { + std::unique_ptr slot(new SlotImpl(*this, slots_.size())); + auto wrapper = std::make_unique(*this, std::move(slot)); + slots_.push_back(wrapper->slot_.get()); + return wrapper; } + const uint32_t idx = free_slot_indexes_.front(); + free_slot_indexes_.pop_front(); + ASSERT(idx < slots_.size()); + std::unique_ptr slot(new SlotImpl(*this, idx)); + slots_[idx] = slot.get(); + return std::make_unique(*this, std::move(slot)); +} + +bool InstanceImpl::SlotImpl::currentThreadRegistered() { + return thread_local_data_.data_.size() > index_; +} - std::unique_ptr slot(new SlotImpl(*this, slots_.size())); - slots_.push_back(slot.get()); - return slot; +void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb) { + parent_.runOnAllThreads([this, cb]() { setThreadLocal(index_, cb(get())); }); +} + +void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) { + parent_.runOnAllThreads([this, cb]() { setThreadLocal(index_, cb(get())); }, complete_cb); } ThreadLocalObjectSharedPtr InstanceImpl::SlotImpl::get() { @@ -42,6 +56,51 @@ ThreadLocalObjectSharedPtr InstanceImpl::SlotImpl::get() { return thread_local_data_.data_[index_]; } +InstanceImpl::Bookkeeper::Bookkeeper(InstanceImpl& parent, std::unique_ptr&& slot) + : parent_(parent), slot_(std::move(slot)), + ref_count_(/*not used.*/ nullptr, + [slot = slot_.get(), &parent = this->parent_](uint32_t* /* not used */) { + // On destruction, post a cleanup callback on main thread, this could happen on + // 
+ // any thread. + parent.scheduleCleanup(slot); + }) {} + +ThreadLocalObjectSharedPtr InstanceImpl::Bookkeeper::get() { return slot_->get(); } + +void InstanceImpl::Bookkeeper::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) { + slot_->runOnAllThreads( + [cb, ref_count = this->ref_count_](ThreadLocalObjectSharedPtr previous) { + return cb(std::move(previous)); + }, + complete_cb); +} + +void InstanceImpl::Bookkeeper::runOnAllThreads(const UpdateCb& cb) { + slot_->runOnAllThreads([cb, ref_count = this->ref_count_](ThreadLocalObjectSharedPtr previous) { + return cb(std::move(previous)); + }); +} + +bool InstanceImpl::Bookkeeper::currentThreadRegistered() { + return slot_->currentThreadRegistered(); +} + +void InstanceImpl::Bookkeeper::runOnAllThreads(Event::PostCb cb) { + // Use ref_count_ to keep track of how many in-flight callbacks are outstanding. + slot_->runOnAllThreads([cb, ref_count = this->ref_count_]() { cb(); }); +} + +void InstanceImpl::Bookkeeper::runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) { + // Use ref_count_ to keep track of how many in-flight callbacks are outstanding. + slot_->runOnAllThreads([cb, main_callback, ref_count = this->ref_count_]() { cb(); }, + main_callback); +} + +void InstanceImpl::Bookkeeper::set(InitializeCb cb) { + slot_->set([cb, ref_count = this->ref_count_](Event::Dispatcher& dispatcher) + -> ThreadLocalObjectSharedPtr { return cb(dispatcher); }); +} + void InstanceImpl::registerThread(Event::Dispatcher& dispatcher, bool main_thread) { ASSERT(std::this_thread::get_id() == main_thread_id_); ASSERT(!shutdown_); @@ -56,6 +115,38 @@ void InstanceImpl::registerThread(Event::Dispatcher& dispatcher, bool main_threa } } +// Puts the slot into a deferred delete container, the slot will be destructed when its out-going +// callback reference count goes to 0. 
+void InstanceImpl::recycle(std::unique_ptr&& slot) { + ASSERT(std::this_thread::get_id() == main_thread_id_); + ASSERT(slot != nullptr); + auto* slot_addr = slot.get(); + deferred_deletes_.insert({slot_addr, std::move(slot)}); +} + +// Called by the Bookkeeper ref_count destructor, the SlotImpl in the deferred deletes map can be +// destructed now. +void InstanceImpl::scheduleCleanup(SlotImpl* slot) { + if (shutdown_) { + // If server is shutting down, do nothing here. + // The destruction of Bookkeeper has already transferred the SlotImpl to the deferred_deletes_ + // queue. No matter if this method is called from a Worker thread, the SlotImpl will be + // destructed on main thread when InstanceImpl destructs. + return; + } + if (std::this_thread::get_id() == main_thread_id_) { + // If called from main thread, save a callback. + ASSERT(deferred_deletes_.contains(slot)); + deferred_deletes_.erase(slot); + return; + } + main_thread_dispatcher_->post([slot, this]() { + ASSERT(deferred_deletes_.contains(slot)); + // The slot is guaranteed to be put into the deferred_deletes_ map by Bookkeeper destructor. + deferred_deletes_.erase(slot); + }); +} + void InstanceImpl::removeSlot(SlotImpl& slot) { ASSERT(std::this_thread::get_id() == main_thread_id_); @@ -69,6 +160,10 @@ void InstanceImpl::removeSlot(SlotImpl& slot) { const uint64_t index = slot.index_; slots_[index] = nullptr; + ASSERT(std::find(free_slot_indexes_.begin(), free_slot_indexes_.end(), index) == + free_slot_indexes_.end(), + fmt::format("slot index {} already in free slot set!", index)); + free_slot_indexes_.push_back(index); runOnAllThreads([index]() -> void { // This runs on each thread and clears the slot, making it available for a new allocations. 
// This is safe even if a new allocation comes in, because everything happens with post() and diff --git a/source/common/thread_local/thread_local_impl.h b/source/common/thread_local/thread_local_impl.h index 820cd1504a95..5ae02be1f0f3 100644 --- a/source/common/thread_local/thread_local_impl.h +++ b/source/common/thread_local/thread_local_impl.h @@ -8,6 +8,9 @@ #include "envoy/thread_local/thread_local.h" #include "common/common/logger.h" +#include "common/common/non_copyable.h" + +#include "absl/container/flat_hash_map.h" namespace Envoy { namespace ThreadLocal { @@ -15,7 +18,7 @@ namespace ThreadLocal { /** * Implementation of ThreadLocal that relies on static thread_local objects. */ -class InstanceImpl : Logger::Loggable, public Instance { +class InstanceImpl : Logger::Loggable, public NonCopyable, public Instance { public: InstanceImpl() : main_thread_id_(std::this_thread::get_id()) {} ~InstanceImpl(); @@ -34,6 +37,9 @@ class InstanceImpl : Logger::Loggable, public Instance { // ThreadLocal::Slot ThreadLocalObjectSharedPtr get() override; + bool currentThreadRegistered() override; + void runOnAllThreads(const UpdateCb& cb) override; + void runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) override; void runOnAllThreads(Event::PostCb cb) override { parent_.runOnAllThreads(cb); } void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override { parent_.runOnAllThreads(cb, main_callback); @@ -44,22 +50,58 @@ class InstanceImpl : Logger::Loggable, public Instance { const uint64_t index_; }; + // A Wrapper of SlotImpl which on destruction returns the SlotImpl to the deferred delete queue + // (detaches it). 
+ struct Bookkeeper : public Slot { + Bookkeeper(InstanceImpl& parent, std::unique_ptr&& slot); + ~Bookkeeper() override { parent_.recycle(std::move(slot_)); } + + // ThreadLocal::Slot + ThreadLocalObjectSharedPtr get() override; + void runOnAllThreads(const UpdateCb& cb) override; + void runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) override; + bool currentThreadRegistered() override; + void runOnAllThreads(Event::PostCb cb) override; + void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override; + void set(InitializeCb cb) override; + + InstanceImpl& parent_; + std::unique_ptr slot_; + std::shared_ptr ref_count_; + }; + struct ThreadLocalData { Event::Dispatcher* dispatcher_{}; std::vector data_; }; + void recycle(std::unique_ptr&& slot); + // Clean up the deferred deletes queue. + void scheduleCleanup(SlotImpl* slot); + void removeSlot(SlotImpl& slot); void runOnAllThreads(Event::PostCb cb); void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback); static void setThreadLocal(uint32_t index, ThreadLocalObjectSharedPtr object); static thread_local ThreadLocalData thread_local_data_; + + // An indexed container for Slots whose deletion has to be deferred due to outgoing callbacks + // pointing to the Slot. To let the ref_count_ deleter find the SlotImpl by address, the container + // is defined as a map of SlotImpl address to the unique_ptr. + absl::flat_hash_map> deferred_deletes_; + std::vector slots_; + // A list of the indexes of freed slots. + std::list free_slot_indexes_; + std::list> registered_threads_; std::thread::id main_thread_id_; Event::Dispatcher* main_thread_dispatcher_{}; std::atomic shutdown_{}; + + // Test only. 
+ friend class ThreadLocalInstanceImplTest; }; } // namespace ThreadLocal diff --git a/test/common/thread_local/thread_local_impl_test.cc b/test/common/thread_local/thread_local_impl_test.cc index 49ba114a2daf..9195425c0084 100644 --- a/test/common/thread_local/thread_local_impl_test.cc +++ b/test/common/thread_local/thread_local_impl_test.cc @@ -14,7 +14,6 @@ using testing::ReturnPointee; namespace Envoy { namespace ThreadLocal { -namespace { class TestThreadLocalObject : public ThreadLocalObject { public: @@ -46,8 +45,10 @@ class ThreadLocalInstanceImplTest : public testing::Test { object.reset(); return object_ref; } - + int deferredDeletesMapSize() { return tls_.deferred_deletes_.size(); } + int freeSlotIndexesListSize() { return tls_.free_slot_indexes_.size(); } InstanceImpl tls_; + Event::MockDispatcher main_dispatcher_; Event::MockDispatcher thread_dispatcher_; }; @@ -59,15 +60,20 @@ TEST_F(ThreadLocalInstanceImplTest, All) { EXPECT_CALL(thread_dispatcher_, post(_)); SlotPtr slot1 = tls_.allocateSlot(); slot1.reset(); + EXPECT_EQ(deferredDeletesMapSize(), 0); + EXPECT_EQ(freeSlotIndexesListSize(), 1); // Create a new slot which should take the place of the old slot. ReturnPointee() is used to // avoid "leaks" when using InSequence and shared_ptr. SlotPtr slot2 = tls_.allocateSlot(); TestThreadLocalObject& object_ref2 = setObject(*slot2); + EXPECT_EQ(freeSlotIndexesListSize(), 0); EXPECT_CALL(thread_dispatcher_, post(_)); EXPECT_CALL(object_ref2, onDestroy()); + EXPECT_EQ(freeSlotIndexesListSize(), 0); slot2.reset(); + EXPECT_EQ(freeSlotIndexesListSize(), 1); // Make two new slots, shutdown global threading, and delete them. We should not see any // cross-thread posts at this point. We should also see destruction in reverse order. 
@@ -79,12 +85,87 @@ TEST_F(ThreadLocalInstanceImplTest, All) { tls_.shutdownGlobalThreading(); slot3.reset(); slot4.reset(); + EXPECT_EQ(freeSlotIndexesListSize(), 0); + EXPECT_EQ(deferredDeletesMapSize(), 2); EXPECT_CALL(object_ref4, onDestroy()); EXPECT_CALL(object_ref3, onDestroy()); tls_.shutdownThread(); } +TEST_F(ThreadLocalInstanceImplTest, DeferredRecycle) { + InSequence s; + + // Free a slot without ever calling set. + EXPECT_CALL(thread_dispatcher_, post(_)); + SlotPtr slot1 = tls_.allocateSlot(); + slot1.reset(); + // Slot destructed directly, as there are no outgoing callbacks. + EXPECT_EQ(deferredDeletesMapSize(), 0); + EXPECT_EQ(freeSlotIndexesListSize(), 1); + + // Allocate a slot and set value, hold the posted callback and the slot will only be returned + // after the held callback is destructed. + { + SlotPtr slot2 = tls_.allocateSlot(); + EXPECT_EQ(freeSlotIndexesListSize(), 0); + { + Event::PostCb holder; + EXPECT_CALL(thread_dispatcher_, post(_)).WillOnce(Invoke([&](Event::PostCb cb) { + // Holds the posted callback. + holder = cb; + })); + slot2->set( + [](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr { return nullptr; }); + slot2.reset(); + // Not released yet, as holder has a copy of the ref_count_. + EXPECT_EQ(freeSlotIndexesListSize(), 0); + EXPECT_EQ(deferredDeletesMapSize(), 1); + // This post is called when the holder dies. + EXPECT_CALL(thread_dispatcher_, post(_)); + } + // Slot is deleted now that the holder has been destructed. + EXPECT_EQ(deferredDeletesMapSize(), 0); + EXPECT_EQ(freeSlotIndexesListSize(), 1); + } + + tls_.shutdownGlobalThreading(); +} + +// Test that the config passed into the update callback is the previous version stored in the slot. 
+TEST_F(ThreadLocalInstanceImplTest, UpdateCallback) { + InSequence s; + + SlotPtr slot = tls_.allocateSlot(); + + auto newer_version = std::make_shared(); + bool update_called = false; + + TestThreadLocalObject& object_ref = setObject(*slot); + auto update_cb = [&object_ref, &update_called, + newer_version](ThreadLocalObjectSharedPtr obj) -> ThreadLocalObjectSharedPtr { + // The unit test setup has two dispatchers registered but only one thread, so this lambda will be + // called twice in the same thread. + if (!update_called) { + EXPECT_EQ(obj.get(), &object_ref); + update_called = true; + } else { + EXPECT_EQ(obj.get(), newer_version.get()); + } + + return newer_version; + }; + EXPECT_CALL(thread_dispatcher_, post(_)); + EXPECT_CALL(object_ref, onDestroy()); + EXPECT_CALL(*newer_version, onDestroy()); + slot->runOnAllThreads(update_cb); + + EXPECT_EQ(newer_version.get(), &slot->getTyped()); + + tls_.shutdownGlobalThreading(); + tls_.shutdownThread(); +} + // TODO(ramaraochavali): Run this test with real threads. The current issue in the unit // testing environment is, the post to main_dispatcher is not working as expected. 
@@ -145,6 +226,5 @@ TEST(ThreadLocalInstanceImplDispatcherTest, Dispatcher) { tls.shutdownThread(); } -} // namespace } // namespace ThreadLocal } // namespace Envoy diff --git a/test/mocks/thread_local/mocks.h b/test/mocks/thread_local/mocks.h index 24393722887a..4febddfda351 100644 --- a/test/mocks/thread_local/mocks.h +++ b/test/mocks/thread_local/mocks.h @@ -58,10 +58,19 @@ class MockInstance : public Instance { // ThreadLocal::Slot ThreadLocalObjectSharedPtr get() override { return parent_.data_[index_]; } + bool currentThreadRegistered() override { return parent_.registered_; } void runOnAllThreads(Event::PostCb cb) override { parent_.runOnAllThreads(cb); } void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override { parent_.runOnAllThreads(cb, main_callback); } + void runOnAllThreads(const UpdateCb& cb) override { + parent_.runOnAllThreads([cb, this]() { parent_.data_[index_] = cb(parent_.data_[index_]); }); + } + void runOnAllThreads(const UpdateCb& cb, Event::PostCb main_callback) override { + parent_.runOnAllThreads([cb, this]() { parent_.data_[index_] = cb(parent_.data_[index_]); }, + main_callback); + } + void set(InitializeCb cb) override { parent_.data_[index_] = cb(parent_.dispatcher_); } MockInstance& parent_; @@ -72,6 +81,7 @@ class MockInstance : public Instance { testing::NiceMock dispatcher_; std::vector data_; bool shutdown_{}; + bool registered_{true}; }; } // namespace ThreadLocal