From cabe3e0ea84d015f71714753844776694cfb8aa2 Mon Sep 17 00:00:00 2001 From: Luan Santos Date: Fri, 13 Oct 2023 10:22:56 -0700 Subject: [PATCH] refactor: dispatcher logic for clarity --- src/game/scheduling/dispatcher.cpp | 107 ++++++++++++++++------------- src/game/scheduling/dispatcher.hpp | 32 +++++---- src/game/scheduling/task.hpp | 48 ++++++------- 3 files changed, 102 insertions(+), 85 deletions(-) diff --git a/src/game/scheduling/dispatcher.cpp b/src/game/scheduling/dispatcher.cpp index c7434b0609f..27ef34e760d 100644 --- a/src/game/scheduling/dispatcher.cpp +++ b/src/game/scheduling/dispatcher.cpp @@ -15,7 +15,7 @@ #include "utils/tools.hpp" constexpr static auto ASYNC_TIME_OUT = std::chrono::seconds(15); -constexpr static auto SLEEP_TIME_MS = 15; +static std::mutex dummyMutex; // This is only used for signaling the condition variable and not as an actual lock. Dispatcher &Dispatcher::getInstance() { return inject(); @@ -25,27 +25,22 @@ void Dispatcher::init() { updateClock(); threadPool.addLoad([this] { - std::unique_lock asyncLock(mutex); + std::unique_lock asyncLock(dummyMutex); while (!threadPool.getIoContext().stopped()) { updateClock(); - // Execute all asynchronous events separately by context - for (uint_fast8_t i = 0; i < static_cast(AsyncEventContext::Last); ++i) { - executeAsyncEvents(i, asyncLock); + for (uint_fast8_t i = 0; i < static_cast(TaskGroup::Last); ++i) { + executeEvents(i, asyncLock); } - // Merge all events that were created by async events - mergeEvents(); - - executeEvents(); executeScheduledEvents(); - - // Merge all events that were created by events and scheduled events mergeEvents(); - auto waitDuration = timeUntilNextScheduledTask(); - cv.wait_for(asyncLock, waitDuration); + if (!hasPendingTasks) { + auto waitDuration = timeUntilNextScheduledTask(); + signalSchedule.wait_for(asyncLock, waitDuration); + } } }); } @@ -53,23 +48,37 @@ void Dispatcher::init() { void Dispatcher::addEvent(std::function &&f, std::string_view context, uint32_t expiresAfterMs) { auto &thread = threads[getThreadId()]; std::scoped_lock lock(thread->mutex); - thread->tasks.emplace_back(expiresAfterMs, std::move(f), context); - cv.notify_one(); + bool notify = !hasPendingTasks; + thread->tasks[static_cast(TaskGroup::Serial)].emplace_back(expiresAfterMs, std::move(f), context); + if (notify && !hasPendingTasks) { + hasPendingTasks = true; + signalSchedule.notify_one(); + } } -void Dispatcher::addEvent_async(std::function &&f, AsyncEventContext context) { +void Dispatcher::addEvent_async(std::function &&f, TaskGroup group) { auto &thread = threads[getThreadId()]; std::scoped_lock lock(thread->mutex); - thread->asyncTasks[static_cast(context)].emplace_back(0, std::move(f), "Dispatcher::addEvent_async"); - cv.notify_one(); + bool notify = !hasPendingTasks; + thread->tasks[static_cast(group)].emplace_back(0, std::move(f), "Dispatcher::addEvent_async"); + if (notify && !hasPendingTasks) { + hasPendingTasks = true; + signalSchedule.notify_one(); + } } uint64_t Dispatcher::scheduleEvent(const std::shared_ptr &task) { auto &thread = threads[getThreadId()]; std::scoped_lock lock(thread->mutex); thread->scheduledTasks.emplace_back(task); - cv.notify_one(); - return scheduledTasksRef.emplace(task->generateId(), task).first->first; + bool notify = !hasPendingTasks; + signalSchedule.notify_one(); + auto eventId = scheduledTasksRef.emplace(task->generateId(), task).first->first; + if (notify && !hasPendingTasks) { + hasPendingTasks = true; + signalSchedule.notify_one(); + } + return eventId; } uint64_t Dispatcher::scheduleEvent(uint32_t delay, std::function &&f, std::string_view context, bool cycle) { @@ -87,41 +96,45 @@ void Dispatcher::stopEvent(uint64_t eventId) { scheduledTasksRef.erase(it); } -void Dispatcher::executeEvents() { - for (const auto &task : eventTasks) { +void Dispatcher::executeSerialEvents(std::vector &tasks) { + for (const auto &task : tasks) { if (task.execute()) { ++dispatcherCycle; } } - eventTasks.clear(); + tasks.clear(); } -void Dispatcher::executeAsyncEvents(const uint8_t contextId, std::unique_lock &asyncLock) { - auto &asyncTasks = asyncEventTasks[contextId]; - if (asyncTasks.empty()) { - return; - } - +void Dispatcher::executeParallelEvents(std::vector &tasks, const uint8_t groupId, std::unique_lock &asyncLock) { std::atomic_uint_fast64_t executedTasks = 0; - // Execute Async Task - for (const auto &task : asyncTasks) { - threadPool.addLoad([this, &task, &executedTasks, totalTaskSize = asyncTasks.size()] { + for (const auto &task : tasks) { + threadPool.addLoad([this, &task, &executedTasks, totalTaskSize = tasks.size()] { task.execute(); if (executedTasks.fetch_add(1) == totalTaskSize) { - asyncTasks_cv.notify_one(); + signalAsync.notify_one(); } }); } - // Wait for all the tasks in the current context to be executed. - if (asyncTasks_cv.wait_for(asyncLock, ASYNC_TIME_OUT) == std::cv_status::timeout) { - g_logger().warn("A timeout occurred when executing the async dispatch in the context({}).", contextId); + if (signalAsync.wait_for(asyncLock, ASYNC_TIME_OUT) == std::cv_status::timeout) { + g_logger().warn("A timeout occurred when executing the async dispatch in the context({}).", groupId); } + tasks.clear(); +} - // Clear all async tasks - asyncTasks.clear(); +void Dispatcher::executeEvents(const uint8_t groupId, std::unique_lock &asyncLock) { + auto &tasks = m_tasks[groupId]; + if (tasks.empty()) { + return; + } + + if (groupId == static_cast(TaskGroup::Serial)) { + executeSerialEvents(tasks); + } else { + executeParallelEvents(tasks, groupId, asyncLock); + } } void Dispatcher::executeScheduledEvents() { @@ -147,15 +160,9 @@ void Dispatcher::mergeEvents() { for (auto &thread : threads) { std::scoped_lock lock(thread->mutex); if (!thread->tasks.empty()) { - eventTasks.insert(eventTasks.end(), make_move_iterator(thread->tasks.begin()), make_move_iterator(thread->tasks.end())); - thread->tasks.clear(); - } - - for (uint_fast8_t i = 0; i < static_cast(AsyncEventContext::Last); ++i) { - auto &context = thread->asyncTasks[i]; - if (!context.empty()) { - asyncEventTasks[i].insert(asyncEventTasks[i].end(), make_move_iterator(context.begin()), make_move_iterator(context.end())); - context.clear(); + for (uint_fast8_t i = 0; i < static_cast(TaskGroup::Last); ++i) { + m_tasks[i].insert(m_tasks[i].end(), make_move_iterator(thread->tasks[i].begin()), make_move_iterator(thread->tasks[i].end())); + thread->tasks[i].clear(); } } @@ -167,6 +174,12 @@ void Dispatcher::mergeEvents() { thread->scheduledTasks.clear(); } } + for (uint_fast8_t i = 0; i < static_cast(TaskGroup::Last); ++i) { + if (!m_tasks[i].empty()) { + hasPendingTasks = true; + break; + } + } } std::chrono::nanoseconds Dispatcher::timeUntilNextScheduledTask() { diff --git a/src/game/scheduling/dispatcher.hpp b/src/game/scheduling/dispatcher.hpp index 0a4bc745b6f..c7eb9aeddd8 100644 --- a/src/game/scheduling/dispatcher.hpp +++ b/src/game/scheduling/dispatcher.hpp @@ -15,8 +15,9 @@ static constexpr uint16_t DISPATCHER_TASK_EXPIRATION = 2000; static constexpr uint16_t SCHEDULER_MINTICKS = 50; -enum class AsyncEventContext : uint8_t { - First, +enum class TaskGroup : uint8_t { + Serial, + GenericParallel, Last }; @@ -43,11 +44,11 @@ class Dispatcher { void init(); void shutdown() { - asyncTasks_cv.notify_all(); + signalAsync.notify_all(); } void addEvent(std::function &&f, std::string_view context, uint32_t expiresAfterMs = 0); - void addEvent_async(std::function &&f, AsyncEventContext context = AsyncEventContext::First); + void addEvent_async(std::function &&f, TaskGroup group = TaskGroup::Serial); uint64_t scheduleEvent(const std::shared_ptr &task); uint64_t scheduleEvent(uint32_t delay, std::function &&f, std::string_view context) { @@ -84,36 +85,37 @@ class Dispatcher { uint64_t scheduleEvent(uint32_t delay, std::function &&f, std::string_view context, bool cycle); inline void mergeEvents(); - inline void executeEvents(); - inline void executeAsyncEvents(const uint8_t contextId, std::unique_lock &asyncLock); + inline void executeEvents(const uint8_t groupId, std::unique_lock &asyncLock); inline void executeScheduledEvents(); + + inline void executeSerialEvents(std::vector &tasks); + inline void executeParallelEvents(std::vector &tasks, const uint8_t groupId, std::unique_lock &asyncLock); inline std::chrono::nanoseconds timeUntilNextScheduledTask(); uint_fast64_t dispatcherCycle = 0; ThreadPool &threadPool; - std::mutex mutex; - std::condition_variable asyncTasks_cv; - std::condition_variable cv; - bool hasPendingTasks = false; + std::condition_variable signalAsync; + std::condition_variable signalSchedule; + std::atomic_bool hasPendingTasks = false; // Thread Events struct ThreadTask { ThreadTask() { - tasks.reserve(2000); + for (auto &task : tasks) { + task.reserve(2000); + } scheduledTasks.reserve(2000); } - std::vector tasks; - std::array, static_cast(AsyncEventContext::Last)> asyncTasks; + std::array, static_cast(TaskGroup::Last)> tasks; std::vector> scheduledTasks; std::mutex mutex; }; std::vector> threads; // Main Events - std::vector eventTasks; - std::array, static_cast(AsyncEventContext::Last)> asyncEventTasks; + std::array, static_cast(TaskGroup::Last)> m_tasks; std::priority_queue, std::deque>, Task::Compare> scheduledTasks; phmap::parallel_flat_hash_map_m> scheduledTasksRef; }; diff --git a/src/game/scheduling/task.hpp b/src/game/scheduling/task.hpp index 0b850808947..cf964bd9314 100644 --- a/src/game/scheduling/task.hpp +++ b/src/game/scheduling/task.hpp @@ -94,29 +94,31 @@ class Task { static std::atomic_uint_fast64_t LAST_EVENT_ID; bool hasTraceableContext() const { - const static auto tasksContext = phmap::flat_hash_set({ "Creature::checkCreatureWalk", - "Decay::checkDecay", - "Dispatcher::addEvent_async", - "Game::checkCreatureAttack", - "Game::checkCreatures", - "Game::checkImbuements", - "Game::checkLight", - "Game::createFiendishMonsters", - "Game::createInfluencedMonsters", - "Game::updateCreatureWalk", - "Game::updateForgeableMonsters", - "GlobalEvents::think", - "LuaEnvironment::executeTimerEvent", - "Modules::executeOnRecvbyte", - "OutputMessagePool::sendAll", - "ProtocolGame::addGameTask", - "ProtocolGame::parsePacketFromDispatcher", - "Raids::checkRaids", - "SpawnMonster::checkSpawnMonster", - "SpawnMonster::scheduleSpawn", - "SpawnNpc::checkSpawnNpc", - "Webhook::run", - "sendRecvMessageCallback" }); + const static auto tasksContext = phmap::flat_hash_set({ + "Creature::checkCreatureWalk", + "Decay::checkDecay", + "Dispatcher::addEvent_async", + "Game::checkCreatureAttack", + "Game::checkCreatures", + "Game::checkImbuements", + "Game::checkLight", + "Game::createFiendishMonsters", + "Game::createInfluencedMonsters", + "Game::updateCreatureWalk", + "Game::updateForgeableMonsters", + "GlobalEvents::think", + "LuaEnvironment::executeTimerEvent", + "Modules::executeOnRecvbyte", + "OutputMessagePool::sendAll", + "ProtocolGame::addGameTask", + "ProtocolGame::parsePacketFromDispatcher", + "Raids::checkRaids", + "SpawnMonster::checkSpawnMonster", + "SpawnMonster::scheduleSpawn", + "SpawnNpc::checkSpawnNpc", + "Webhook::run", + "sendRecvMessageCallback", + }); return tasksContext.contains(context); }