1515#include < taskr/taskr.hpp>
1616
1717#define _PROMPT_THREAD_COUNT 16
18+ #define _REQUESTS_PER_THREAD_COUNT 32
1819
1920int main (int argc, char *argv[])
2021{
@@ -176,6 +177,7 @@ int main(int argc, char *argv[])
176177 std::vector<std::unique_ptr<std::thread>> promptThreads;
177178 std::default_random_engine promptTimeRandomEngine;
178179 std::uniform_real_distribution<double > promptTimeRandomDistribution (0.0 , 1.0 );
180+ std::atomic<size_t > finishedPromptThreads = 0 ;
179181 if (isRoot)
180182 {
181183 for (size_t i = 0 ; i < _PROMPT_THREAD_COUNT; i++)
@@ -188,20 +190,23 @@ int main(int argc, char *argv[])
188190 auto session = hllm.createSession ();
189191
190192 // Send a test message
191- size_t currentPrompt = 0 ;
192- for (size_t iterations = 0 ; true ; iterations++)
193+ for (size_t promptCount = 0 ; promptCount < _REQUESTS_PER_THREAD_COUNT; promptCount++)
193194 {
194- const auto prompt = session->createPrompt (std::string (" Hello, World! " ) + std::to_string (currentPrompt++ ));
195+ const auto prompt = session->createPrompt (std::string (" Hello, World! " ) + std::to_string (promptCount ));
195196 session->pushPrompt (prompt);
196197 // printf("[User] Sent prompt: %s\n", prompt->getPrompt().c_str());
197198 while (prompt->hasResponse () == false );
198199 const auto promptId = prompt->getPromptId ();
199200 printf (" [User %04lu] Got response: '%s' for prompt %lu/%lu: '%s'\n " , i, prompt->getResponse ().c_str (), promptId.first , promptId.second , prompt->getPrompt ().c_str ());
200201 usleep (100000.0 * promptTimeRandomDistribution (promptTimeRandomEngine));
201202 }
202-
203- // Violently exit when done with the test
204- exit (0 );
203+
204+ // Increase counter for finished prompt threads
205+ const auto finishedThreads = finishedPromptThreads.fetch_add (1 ) + 1 ;
206+ printf (" Finished Threads: %lu\n " , finishedThreads);
207+
208+ // If ths was the last thread, then ask hllm to shutdown
209+ if (finishedThreads == _PROMPT_THREAD_COUNT) hllm.requestTermination ();
205210 }));
206211 }
207212
@@ -212,5 +217,5 @@ int main(int argc, char *argv[])
212217 if (isRoot) for (auto & thread : promptThreads) thread->join ();
213218
214219 // Finalize Instance Manager
215- instanceManager->finalize ();
220+ // instanceManager->finalize();
216221}
0 commit comments