Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lua: Change the TLS callback function type of ThreadLocalState to Upd… #11944

Merged
merged 10 commits into from
Jul 15, 2020
6 changes: 4 additions & 2 deletions source/extensions/filters/common/lua/lua.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ int ThreadLocalState::getGlobalRef(uint64_t slot) {
}

uint64_t ThreadLocalState::registerGlobal(const std::string& global) {
tls_slot_->runOnAllThreads([this, global]() {
LuaThreadLocal& tls = tls_slot_->getTyped<LuaThreadLocal>();
tls_slot_->runOnAllThreads([global](ThreadLocal::ThreadLocalObjectSharedPtr ptr) -> auto {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does spelling out -> auto (-> ThreadLocal::ThreadLocalObjectSharedPtr) here will be a problem for clang-tidy?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And let's rename ptr as previous?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the existing code, it is acceptable to spell out the return type or not.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyway, I will remove it to make the code more concise.

ASSERT(std::dynamic_pointer_cast<LuaThreadLocal>(ptr));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this assert here?

LuaThreadLocal& tls = *std::dynamic_pointer_cast<LuaThreadLocal>(ptr);
lua_getglobal(tls.state_.get(), global.c_str());
if (lua_isfunction(tls.state_.get(), -1)) {
tls.global_slots_.push_back(luaL_ref(tls.state_.get(), LUA_REGISTRYINDEX));
Expand All @@ -81,6 +82,7 @@ uint64_t ThreadLocalState::registerGlobal(const std::string& global) {
lua_pop(tls.state_.get(), 1);
tls.global_slots_.push_back(LUA_REFNIL);
}
return ptr;
});

return current_global_slot_++;
Expand Down
8 changes: 6 additions & 2 deletions source/extensions/filters/common/lua/lua.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,12 @@ class ThreadLocalState : Logger::Loggable<Logger::Id::lua> {
* all threaded workers.
*/
template <class T> void registerType() {
tls_slot_->runOnAllThreads(
[this]() { T::registerType(tls_slot_->getTyped<LuaThreadLocal>().state_.get()); });
tls_slot_->runOnAllThreads([](ThreadLocal::ThreadLocalObjectSharedPtr ptr) -> auto {
ASSERT(std::dynamic_pointer_cast<LuaThreadLocal>(ptr));
LuaThreadLocal& tls = *std::dynamic_pointer_cast<LuaThreadLocal>(ptr);
T::registerType(tls.state_.get());
return ptr;
});
}

/**
Expand Down
1 change: 1 addition & 0 deletions test/extensions/filters/common/lua/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ envoy_cc_test(
srcs = ["lua_test.cc"],
tags = ["skip_on_windows"],
deps = [
"//source/common/thread_local:thread_local_lib",
"//source/extensions/filters/common/lua:lua_lib",
"//test/mocks:common_lib",
"//test/mocks/thread_local:thread_local_mocks",
Expand Down
51 changes: 51 additions & 0 deletions test/extensions/filters/common/lua/lua_test.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <memory>

#include "common/thread_local/thread_local_impl.h"

#include "extensions/filters/common/lua/lua.h"

#include "test/mocks/common.h"
Expand Down Expand Up @@ -157,6 +159,55 @@ TEST_F(LuaTest, MarkDead) {
lua_gc(cr1->luaState(), LUA_GCCOLLECT, 0);
}

class ThreadSafeTest : public testing::Test {
public:
ThreadSafeTest()
: api_(Api::createApiForTest()), main_dispatcher_(api_->allocateDispatcher("main")),
worker_dispatcher_(api_->allocateDispatcher("worker")) {}

// Use real dispatchers to verify that callback functions can be executed correctly.
Api::ApiPtr api_;
Event::DispatcherPtr main_dispatcher_;
Event::DispatcherPtr worker_dispatcher_;
ThreadLocal::InstanceImpl tls_;

std::unique_ptr<ThreadLocalState> state_;
};

// Test whether ThreadLocalState can be safely released.
TEST_F(ThreadSafeTest, StateDestructedBeforeWorkerRun) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this test fail without the code change in this PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it will be a crash.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for verifying!

const std::string SCRIPT{R"EOF(
function HelloWorld()
print("Hello World!")
end
)EOF"};

tls_.registerThread(*main_dispatcher_, true);
EXPECT_EQ(main_dispatcher_.get(), &tls_.dispatcher());
tls_.registerThread(*worker_dispatcher_, false);

// Some callback functions waiting to be executed will be added to the dispatcher of the Worker
// thread. The callback functions in the main thread will be executed directly.
state_ = std::make_unique<ThreadLocalState>(SCRIPT, tls_);
state_->registerType<TestObject>();

main_dispatcher_->run(Event::Dispatcher::RunType::Block);

// Destroy state_.
state_.reset(nullptr);

// Start a new worker thread to execute the callback functions in the worker dispatcher.
Thread::ThreadPtr thread = Thread::threadFactoryForTest().createThread([this]() {
worker_dispatcher_->run(Event::Dispatcher::RunType::Block);
// Verify we have the expected dispatcher for the new thread thread.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the new worker thread?

EXPECT_EQ(worker_dispatcher_.get(), &tls_.dispatcher());
});
thread->join();

tls_.shutdownGlobalThreading();
tls_.shutdownThread();
}

} // namespace
} // namespace Lua
} // namespace Common
Expand Down