Skip to content

Commit

Permalink
feat: Scale acceptor threads with num_workers
Browse files Browse the repository at this point in the history
Bug: N/A
Change-Id: Ia75bcb4797b183b118d297b9e0c8c7c823563566
GitOrigin-RevId: 5768dabf5d81ef63af65dfcb2227623c7a850410
  • Loading branch information
Privacy Sandbox Team authored and copybara-github committed Dec 14, 2024
1 parent 15d3fa3 commit 3f61745
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 32 deletions.
73 changes: 44 additions & 29 deletions src/roma/byob/dispatcher/dispatcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ Dispatcher::~Dispatcher() {
+[](int* i) { return *i == 0; }, &executor_threads_in_flight_));
}
::shutdown(listen_fd_, SHUT_RDWR);
if (acceptor_.has_value()) {
acceptor_->join();
{
absl::MutexLock lock(&mu_);
mu_.Await(absl::Condition(
+[](int* i) { return *i == 0; }, &acceptor_threads_in_flight_));
}
for (auto& [_, fds_and_tokens] : code_token_to_fds_and_tokens_) {
while (!fds_and_tokens.empty()) {
Expand Down Expand Up @@ -136,7 +138,6 @@ absl::Status Dispatcher::Init(std::filesystem::path control_socket_name,
// Ignore SIGPIPE. Otherwise, host process will crash when UDFs close sockets
// before or while Roma writes requests or callback responses.
::signal(SIGPIPE, SIG_IGN);
acceptor_.emplace(&Dispatcher::AcceptorImpl, this);
return absl::OkStatus();
}

Expand Down Expand Up @@ -183,6 +184,13 @@ absl::StatusOr<std::string> Dispatcher::LoadBinary(
code_token_to_fds_and_tokens_.erase(code_token);
return privacy_sandbox::server_common::ToAbslStatus(status);
}
{
absl::MutexLock lock(&mu_);
acceptor_threads_in_flight_ += num_workers;
}
for (int i = 0; i < num_workers; ++i) {
std::thread(&Dispatcher::AcceptorImpl, this, code_token).detach();
}
return code_token;
}

Expand Down Expand Up @@ -211,6 +219,13 @@ absl::StatusOr<std::string> Dispatcher::LoadBinaryForLogging(
code_token_to_fds_and_tokens_.erase(code_token);
return privacy_sandbox::server_common::ToAbslStatus(status);
}
{
absl::MutexLock lock(&mu_);
acceptor_threads_in_flight_ += num_workers;
}
for (int i = 0; i < num_workers; ++i) {
std::thread(&Dispatcher::AcceptorImpl, this, code_token).detach();
}
return code_token;
}

Expand Down Expand Up @@ -241,20 +256,28 @@ void Dispatcher::Cancel(google::scp::roma::ExecutionToken execution_token) {
stub_->Cancel(&context, request, &response);
}

void Dispatcher::AcceptorImpl() {
absl::Mutex thread_count_mu;
int thread_count = 0; // Guarded by `thread_count_mu`.
auto read_tokens_and_push_to_queue = [this, &thread_count_mu,
&thread_count](int fd) {
// `parent_code_token` identifies the code_token generated by the `Load` call
// that detached this thread. The loop breaks when the code_token is deleted,
// ensuring the number of acceptors scales back. Note that an acceptor will
// accept connections without regard to UDF code_token.
void Dispatcher::AcceptorImpl(std::string parent_code_token) {
while (true) {
const int fd = ::accept(listen_fd_, nullptr, nullptr);
if (fd == -1) {
break;
}

// Read code token and execution token, 36 bytes each.
// First is code token, second is execution token.
if (absl::StatusOr<std::string> data = Read(fd, kNumTokenBytes * 2);
!data.ok()) {
absl::StatusOr<std::string> data = Read(fd, kNumTokenBytes * 2);
if (!data.ok()) {
LOG(ERROR) << "Read failure: " << data.status();
::close(fd);
} else {
std::string execution_token = data->substr(kNumTokenBytes);
data->resize(kNumTokenBytes);
continue;
}
std::string execution_token = data->substr(kNumTokenBytes);
data->resize(kNumTokenBytes);
{
absl::MutexLock lock(&mu_);
if (const auto it = code_token_to_fds_and_tokens_.find(*data);
it != code_token_to_fds_and_tokens_.end()) {
Expand All @@ -266,24 +289,16 @@ void Dispatcher::AcceptorImpl() {
LOG(ERROR) << "Unrecognized code token.";
::close(fd);
}

// Break loop if the code_token generated by the `Load` call that detached
// this thread was deleted.
if (!code_token_to_fds_and_tokens_.contains(parent_code_token)) {
break;
}
}
absl::MutexLock lock(&thread_count_mu);
--thread_count;
};
while (true) {
const int fd = ::accept(listen_fd_, nullptr, nullptr);
if (fd == -1) {
break;
}
{
absl::MutexLock lock(&thread_count_mu);
++thread_count;
}
std::thread(read_tokens_and_push_to_queue, fd).detach();
}
absl::MutexLock lock(&thread_count_mu);
thread_count_mu.Await(
absl::Condition(+[](int* i) { return *i == 0; }, &thread_count));
absl::MutexLock lock(&mu_);
--acceptor_threads_in_flight_;
}

void Dispatcher::ExecutorImpl(const int fd,
Expand Down
5 changes: 2 additions & 3 deletions src/roma/byob/dispatcher/dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <filesystem>
#include <fstream>
#include <memory>
#include <optional>
#include <queue>
#include <string>
#include <string_view>
Expand Down Expand Up @@ -119,16 +118,16 @@ class Dispatcher {

// Accepts connections from newly created UDF instances, reads code tokens,
// and pushes file descriptors to the queue.
void AcceptorImpl();
void AcceptorImpl(std::string parent_code_token) ABSL_LOCKS_EXCLUDED(mu_);
void ExecutorImpl(int fd, const google::protobuf::Message& request,
absl::AnyInvocable<void(int) &&> handler)
ABSL_LOCKS_EXCLUDED(mu_);

int listen_fd_;
std::filesystem::path log_dir_;
std::unique_ptr<WorkerRunnerService::Stub> stub_;
std::optional<std::thread> acceptor_;
absl::Mutex mu_;
int acceptor_threads_in_flight_ ABSL_GUARDED_BY(mu_) = 0;
int executor_threads_in_flight_ ABSL_GUARDED_BY(mu_) = 0;
absl::flat_hash_map<std::string, std::queue<FdAndToken>>
code_token_to_fds_and_tokens_ ABSL_GUARDED_BY(mu_);
Expand Down

0 comments on commit 3f61745

Please sign in to comment.