Skip to content

Commit

Permalink
fix: Handle read calls that successfully return fewers chars than req…
Browse files Browse the repository at this point in the history
…uest.

(valid). Also reduce number of system calls in the common case.

Bug: 377484615
Change-Id: Ie2da709e10c78d44e8fdb9404a2cbacafff10590
GitOrigin-RevId: 3cae94bfae4d33b29cc9029691b9e9c408b426e7
  • Loading branch information
Privacy Sandbox Team authored and copybara-github committed Nov 5, 2024
1 parent caa2701 commit 5566e1d
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/roma/byob/container/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ cc_binary(
srcs = ["run_workers.cc"],
deps = [
"//src/core/common/uuid",
"//src/roma/byob/dispatcher:dispatcher_cc_proto",
"//src/roma/byob/dispatcher",
"//src/util/status_macro:status_macros",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down
6 changes: 4 additions & 2 deletions src/roma/byob/container/run_workers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "absl/types/span.h"
#include "google/protobuf/util/delimited_message_util.h"
#include "src/core/common/uuid/uuid.h"
#include "src/roma/byob/dispatcher/dispatcher.h"
#include "src/roma/byob/dispatcher/dispatcher.pb.h"
#include "src/util/status_macro/status_macros.h"

Expand All @@ -59,6 +60,7 @@ namespace {
using ::google::protobuf::io::FileInputStream;
using ::google::protobuf::util::ParseDelimitedFromZeroCopyStream;
using ::privacy_sandbox::server_common::byob::DispatcherRequest;
using ::privacy_sandbox::server_common::byob::kNumTokenBytes;

bool ConnectToPath(int fd, std::string_view socket_name) {
::sockaddr_un sa = {
Expand Down Expand Up @@ -209,9 +211,9 @@ constexpr uint32_t MaxIntDecimalLength() {
int WorkerImpl(void* arg) {
const WorkerImplArg& worker_impl_arg = *static_cast<WorkerImplArg*>(arg);
PCHECK(::write(worker_impl_arg.rpc_fd, worker_impl_arg.code_token.data(),
36) == 36);
kNumTokenBytes) == kNumTokenBytes);
PCHECK(::write(worker_impl_arg.rpc_fd, worker_impl_arg.execution_token.data(),
36) == 36);
kNumTokenBytes) == kNumTokenBytes);

// Add one to decimal length because `snprintf` adds a null terminator.
char connection_fd[MaxIntDecimalLength() + 1];
Expand Down
2 changes: 1 addition & 1 deletion src/roma/byob/dispatcher/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ cc_library(
name = "dispatcher",
srcs = ["dispatcher.cc"],
hdrs = ["dispatcher.h"],
visibility = ["//src/roma/byob/interface:__subpackages__"],
visibility = ["//src/roma/byob:__subpackages__"],
deps = [
":dispatcher_cc_proto",
"//src/core/common/uuid",
Expand Down
34 changes: 28 additions & 6 deletions src/roma/byob/dispatcher/dispatcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ using ::google::protobuf::util::ParseDelimitedFromZeroCopyStream;
using ::google::protobuf::util::SerializeDelimitedToFileDescriptor;
using ::google::scp::core::common::Uuid;
using ::google::scp::roma::proto::FunctionBindingIoProto;

absl::StatusOr<std::string> Read(int fd, int size) {
std::string buffer(size, '\0');
size_t read_bytes = 0;
while (read_bytes < size) {
const ssize_t n = ::read(fd, &buffer[read_bytes], size - read_bytes);
if (n == -1) {
return absl::ErrnoToStatus(errno, "Failed to read data from fd.");
} else if (n == 0) {
return absl::UnavailableError("Unexpected EOF.");
}
read_bytes += n;
}
return buffer;
}

} // namespace

Dispatcher::~Dispatcher() {
Expand Down Expand Up @@ -163,19 +179,25 @@ void Dispatcher::AcceptorImpl() {
if (fd == -1) {
break;
}
char code_token[37] = {};
char execution_token[37] = {};
(void)::read(fd, code_token, 36);
(void)::read(fd, execution_token, 36);
// Read code token and exectution token, both are 36 bytes.
// First is code token, second is execution token.
auto data = Read(fd, kNumTokenBytes * 2);
if (!data.ok()) {
LOG(ERROR) << "Read failure closing socket: " << data.status();
::close(fd);
continue;
}
std::string execution_token = data->substr(kNumTokenBytes);
data->resize(kNumTokenBytes);
absl::MutexLock lock(&mu_);
const auto it = code_token_to_fds_and_tokens_.find(code_token);
const auto it = code_token_to_fds_and_tokens_.find(*data);
if (it == code_token_to_fds_and_tokens_.end()) {
LOG(INFO) << "Unrecognized code token.";
continue;
}
it->second.push(FdAndToken{
.fd = fd,
.token = execution_token,
.token = std::move(execution_token),
});
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/roma/byob/dispatcher/dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
#include "src/util/execution_token.h"

namespace privacy_sandbox::server_common::byob {
inline constexpr size_t kNumTokenBytes = 36;

class Dispatcher {
public:
~Dispatcher();
Expand Down
76 changes: 42 additions & 34 deletions src/roma/byob/dispatcher/dispatcher_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ TEST(DispatcherTest, LoadGoesToWorker) {
ASSERT_TRUE(ParseDelimitedFromZeroCopyStream(&request, &input, nullptr));
}
ASSERT_TRUE(request.has_load_binary());
ASSERT_EQ(request.load_binary().code_token().size(), 36);
ASSERT_EQ(request.load_binary().code_token().size(), kNumTokenBytes);
EXPECT_EQ(request.load_binary().num_workers(), 7);
EXPECT_EQ(::close(fd), 0);
});
Expand All @@ -204,7 +204,7 @@ TEST(DispatcherTest, LoadAndDeleteGoToWorker) {
ASSERT_TRUE(
ParseDelimitedFromZeroCopyStream(&load_request, &input, nullptr));
ASSERT_TRUE(load_request.has_load_binary());
ASSERT_EQ(load_request.load_binary().code_token().size(), 36);
ASSERT_EQ(load_request.load_binary().code_token().size(), kNumTokenBytes);
EXPECT_EQ(load_request.load_binary().num_workers(), 3);
DispatcherRequest delete_request;
ASSERT_TRUE(
Expand Down Expand Up @@ -251,19 +251,20 @@ TEST(DispatcherTest, LoadAndExecute) {
ASSERT_TRUE(ParseDelimitedFromZeroCopyStream(&request, &input, nullptr));
}
ASSERT_TRUE(request.has_load_binary());
ASSERT_EQ(request.load_binary().code_token().size(), 36);
ASSERT_EQ(request.load_binary().code_token().size(), kNumTokenBytes);
EXPECT_EQ(request.load_binary().num_workers(), 1);

// Process execution request.
const int connection_fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
ASSERT_NE(connection_fd, -1);
ConnectToPath(connection_fd, "abcd.sock");
EXPECT_EQ(
::write(connection_fd, request.load_binary().code_token().c_str(), 36),
36);
EXPECT_EQ(::write(connection_fd, request.load_binary().code_token().c_str(),
kNumTokenBytes),
kNumTokenBytes);
{
const std::string execution_token(36, 'a');
EXPECT_EQ(::write(connection_fd, execution_token.c_str(), 36), 36);
const std::string execution_token(kNumTokenBytes, 'a');
EXPECT_EQ(::write(connection_fd, execution_token.c_str(), kNumTokenBytes),
kNumTokenBytes);
}
{
// Read UDF input.
Expand Down Expand Up @@ -329,19 +330,20 @@ TEST(DispatcherTest, LoadAndCloseBeforeExecute) {
ASSERT_TRUE(ParseDelimitedFromZeroCopyStream(&request, &input, nullptr));
}
ASSERT_TRUE(request.has_load_binary());
ASSERT_EQ(request.load_binary().code_token().size(), 36);
ASSERT_EQ(request.load_binary().code_token().size(), kNumTokenBytes);
EXPECT_EQ(request.load_binary().num_workers(), 1);

// Process execution request.
const int connection_fd = ::socket(AF_UNIX, SOCK_STREAM, /*protocol=*/0);
ASSERT_NE(connection_fd, -1);
ConnectToPath(connection_fd, "abcd.sock");
EXPECT_EQ(::write(connection_fd, request.load_binary().code_token().c_str(),
/*count=*/36),
36);
/*count=*/kNumTokenBytes),
kNumTokenBytes);
{
const std::string execution_token(36, 'a');
EXPECT_EQ(::write(connection_fd, execution_token.c_str(), 36), 36);
const std::string execution_token(kNumTokenBytes, 'a');
EXPECT_EQ(::write(connection_fd, execution_token.c_str(), kNumTokenBytes),
kNumTokenBytes);
}
EXPECT_EQ(::close(connection_fd), 0);
EXPECT_EQ(::close(fd), 0);
Expand Down Expand Up @@ -385,19 +387,20 @@ TEST(DispatcherTest, LoadAndExecuteWithCallbacks) {
ASSERT_TRUE(ParseDelimitedFromZeroCopyStream(&request, &input, nullptr));
}
ASSERT_TRUE(request.has_load_binary());
ASSERT_EQ(request.load_binary().code_token().size(), 36);
ASSERT_EQ(request.load_binary().code_token().size(), kNumTokenBytes);
EXPECT_EQ(request.load_binary().num_workers(), 1);

// Process execution request.
const int connection_fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
ASSERT_NE(connection_fd, -1);
ConnectToPath(connection_fd, "abcd.sock");
EXPECT_EQ(
::write(connection_fd, request.load_binary().code_token().c_str(), 36),
36);
EXPECT_EQ(::write(connection_fd, request.load_binary().code_token().c_str(),
kNumTokenBytes),
kNumTokenBytes);
{
const std::string execution_token(36, 'a');
EXPECT_EQ(::write(connection_fd, execution_token.c_str(), 36), 36);
const std::string execution_token(kNumTokenBytes, 'a');
EXPECT_EQ(::write(connection_fd, execution_token.c_str(), kNumTokenBytes),
kNumTokenBytes);
}

// Read UDF input.
Expand Down Expand Up @@ -492,19 +495,20 @@ TEST(DispatcherTest, LoadAndExecuteWithCallbacksWithoutReadingResponse) {
FileInputStream input(fd);
ASSERT_TRUE(ParseDelimitedFromZeroCopyStream(&request, &input, nullptr));
}
ASSERT_EQ(request.load_binary().code_token().size(), 36);
ASSERT_EQ(request.load_binary().code_token().size(), kNumTokenBytes);
EXPECT_EQ(request.load_binary().num_workers(), 1);

// Process execution request.
const int connection_fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
ASSERT_NE(connection_fd, -1);
ConnectToPath(connection_fd, "abcd.sock");
EXPECT_EQ(::write(connection_fd, request.load_binary().code_token().c_str(),
/*count=*/36),
36);
/*count=*/kNumTokenBytes),
kNumTokenBytes);
{
const std::string execution_token(36, 'a');
EXPECT_EQ(::write(connection_fd, execution_token.c_str(), 36), 36);
const std::string execution_token(kNumTokenBytes, 'a');
EXPECT_EQ(::write(connection_fd, execution_token.c_str(), kNumTokenBytes),
kNumTokenBytes);
}
{
// Read UDF input.
Expand Down Expand Up @@ -580,7 +584,7 @@ TEST(DispatcherTest, LoadAndExecuteWithCallbacksAndMetadata) {
DispatcherRequest request;
ASSERT_TRUE(ParseDelimitedFromZeroCopyStream(&request, &input, nullptr));
ASSERT_TRUE(request.has_load_binary());
ASSERT_EQ(request.load_binary().code_token().size(), 36);
ASSERT_EQ(request.load_binary().code_token().size(), kNumTokenBytes);
code_token =
std::move(*request.mutable_load_binary()->mutable_code_token());
EXPECT_EQ(request.load_binary().num_workers(), 1);
Expand All @@ -591,10 +595,13 @@ TEST(DispatcherTest, LoadAndExecuteWithCallbacksAndMetadata) {
const int connection_fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
ASSERT_NE(connection_fd, -1);
ConnectToPath(connection_fd, "abcd.sock", /*unlink_path=*/false);
EXPECT_EQ(::write(connection_fd, code_token.c_str(), 36), 36);
EXPECT_EQ(::write(connection_fd, code_token.c_str(), kNumTokenBytes),
kNumTokenBytes);
{
const std::string execution_token(36, 'a');
EXPECT_EQ(::write(connection_fd, execution_token.c_str(), 36), 36);
const std::string execution_token(kNumTokenBytes, 'a');
EXPECT_EQ(
::write(connection_fd, execution_token.c_str(), kNumTokenBytes),
kNumTokenBytes);
}

// Read UDF input.
Expand Down Expand Up @@ -678,18 +685,19 @@ TEST(DispatcherTest, LoadAndExecuteThenCancel) {
DispatcherRequest request;
ASSERT_TRUE(ParseDelimitedFromZeroCopyStream(&request, &input, nullptr));
ASSERT_TRUE(request.has_load_binary());
ASSERT_EQ(request.load_binary().code_token().size(), 36);
ASSERT_EQ(request.load_binary().code_token().size(), kNumTokenBytes);
EXPECT_EQ(request.load_binary().num_workers(), 1);

// Process execution request.
const int connection_fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
ASSERT_NE(connection_fd, -1);
ConnectToPath(connection_fd, "abcd.sock", /*unlink_path=*/false);
EXPECT_EQ(
::write(connection_fd, request.load_binary().code_token().c_str(), 36),
36);
const std::string execution_token(36, 'a');
EXPECT_EQ(::write(connection_fd, execution_token.c_str(), 36), 36);
EXPECT_EQ(::write(connection_fd, request.load_binary().code_token().c_str(),
kNumTokenBytes),
kNumTokenBytes);
const std::string execution_token(kNumTokenBytes, 'a');
EXPECT_EQ(::write(connection_fd, execution_token.c_str(), kNumTokenBytes),
kNumTokenBytes);
{
// Read UDF input.
google::protobuf::Any any;
Expand Down

0 comments on commit 5566e1d

Please sign in to comment.