Skip to content

Commit

Permalink
Merge pull request #65 from doujiang24/cmake-add-test
Browse files Browse the repository at this point in the history
[TransferEngine] test: cmake enable testing.
  • Loading branch information
alogfans authored Jan 7, 2025
2 parents 282bfa8 + 607fd4f commit 295d094
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 67 deletions.
20 changes: 17 additions & 3 deletions .github/workflows/build-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,26 @@ jobs:
run: |
sudo apt update -y
sudo bash -x dependencies.sh
mkdir build
mkdir build
cd build
cmake ..
cmake .. -DUSE_HTTP=ON
shell: bash
- name: make
run: |
cd build
make -j
shell: bash
shell: bash
- name: start-metadata-server
run: |
cd mooncake-transfer-engine/example/http-metadata-server
export PATH=$PATH:/usr/local/go/bin
go mod tidy && go build -o http-metadata-server .
./http-metadata-server --addr=:8090 &
shell: bash
- name: test
run: |
cd build
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib
ldconfig -v || echo "always continue"
MC_METADATA_SERVER=http://127.0.0.1:8090/metadata make test -j ARGS="-V"
shell: bash
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
cmake_minimum_required(VERSION 3.16)
project(mooncake CXX C)

enable_testing()

set(CMAKE_C_STANDARD 99)
set(CMAKE_CXX_STANDARD 17)

Expand Down
4 changes: 4 additions & 0 deletions dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,8 @@ cd build
cmake ..
make -j$(nproc) && sudo make install

echo "*** Download and installing [golang-1.22] ***"
wget https://go.dev/dl/go1.22.linux-amd64.tar.gz
sudo tar -C /usr/local -xzf go1.22.linux-amd64.tar.gz

echo "*** Dependencies Installed! ***"
2 changes: 1 addition & 1 deletion mooncake-transfer-engine/include/multi_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class MultiTransport {
private:
std::shared_ptr<TransferMetadata> metadata_;
std::string local_server_name_;
std::map<std::string, Transport *> transport_map_;
std::map<std::string, std::shared_ptr<Transport>> transport_map_;
RWSpinlock batch_desc_lock_;
std::unordered_map<BatchID, std::shared_ptr<BatchDesc>> batch_desc_set_;
};
Expand Down
11 changes: 6 additions & 5 deletions mooncake-transfer-engine/src/multi_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,13 @@ Transport *MultiTransport::installTransport(const std::string &proto,
return nullptr;
}

transport_map_[proto] = transport;
transport_map_[proto] = std::shared_ptr<Transport>(transport);
return transport;
}

Transport *MultiTransport::selectTransport(const TransferRequest &entry) {
if (entry.target_id == LOCAL_SEGMENT_ID && transport_map_.count("local"))
return transport_map_["local"];
return transport_map_["local"].get();
auto target_segment_desc = metadata_->getSegmentDescByID(entry.target_id);
if (!target_segment_desc) {
LOG(ERROR) << "MultiTransport: Incorrect target segment id "
Expand All @@ -165,17 +165,18 @@ Transport *MultiTransport::selectTransport(const TransferRequest &entry) {
LOG(ERROR) << "MultiTransport: Transport " << proto << " not installed";
return nullptr;
}
return transport_map_[proto];
return transport_map_[proto].get();
}

Transport *MultiTransport::getTransport(const std::string &proto) {
if (!transport_map_.count(proto)) return nullptr;
return transport_map_[proto];
return transport_map_[proto].get();
}

std::vector<Transport *> MultiTransport::listTransports() {
std::vector<Transport *> transport_list;
for (auto &entry : transport_map_) transport_list.push_back(entry.second);
for (auto &entry : transport_map_)
transport_list.push_back(entry.second.get());
return transport_list;
}

Expand Down
38 changes: 24 additions & 14 deletions mooncake-transfer-engine/src/transfer_metadata_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,18 @@ static inline const std::string getNetworkAddress(struct sockaddr *addr) {
}

struct SocketHandShakePlugin : public HandShakePlugin {
SocketHandShakePlugin() : listener_running_(false) {}
SocketHandShakePlugin() : listener_running_(false), listen_fd_(-1) {}

void closeListen() {
if (listen_fd_ >= 0) {
LOG(INFO) << "SocketHandShakePlugin: closing listen socket";
close(listen_fd_);
listen_fd_ = -1;
}
}

virtual ~SocketHandShakePlugin() {
closeListen();
if (listener_running_) {
listener_running_ = false;
listener_.join();
Expand All @@ -383,54 +392,54 @@ struct SocketHandShakePlugin : public HandShakePlugin {
virtual int startDaemon(OnReceiveCallBack on_recv_callback,
uint16_t listen_port) {
sockaddr_in bind_address;
int on = 1, listen_fd = -1;
int on = 1;
memset(&bind_address, 0, sizeof(sockaddr_in));
bind_address.sin_family = AF_INET;
bind_address.sin_port = htons(listen_port);
bind_address.sin_addr.s_addr = INADDR_ANY;

listen_fd = socket(AF_INET, SOCK_STREAM, 0);
if (listen_fd < 0) {
listen_fd_ = socket(AF_INET, SOCK_STREAM, 0);
if (listen_fd_ < 0) {
PLOG(ERROR) << "SocketHandShakePlugin: socket()";
return ERR_SOCKET;
}

struct timeval timeout;
timeout.tv_sec = 1;
timeout.tv_usec = 0;
if (setsockopt(listen_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout,
if (setsockopt(listen_fd_, SOL_SOCKET, SO_RCVTIMEO, &timeout,
sizeof(timeout))) {
PLOG(ERROR) << "SocketHandShakePlugin: setsockopt(SO_RCVTIMEO)";
close(listen_fd);
closeListen();
return ERR_SOCKET;
}

if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on))) {
if (setsockopt(listen_fd_, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on))) {
PLOG(ERROR) << "SocketHandShakePlugin: setsockopt(SO_REUSEADDR)";
close(listen_fd);
closeListen();
return ERR_SOCKET;
}

if (bind(listen_fd, (sockaddr *)&bind_address, sizeof(sockaddr_in)) <
if (bind(listen_fd_, (sockaddr *)&bind_address, sizeof(sockaddr_in)) <
0) {
PLOG(ERROR) << "SocketHandShakePlugin: bind (port " << listen_port
<< ")";
close(listen_fd);
closeListen();
return ERR_SOCKET;
}

if (listen(listen_fd, 5)) {
if (listen(listen_fd_, 5)) {
PLOG(ERROR) << "SocketHandShakePlugin: listen()";
close(listen_fd);
closeListen();
return ERR_SOCKET;
}

listener_running_ = true;
listener_ = std::thread([this, listen_fd, on_recv_callback]() {
listener_ = std::thread([this, on_recv_callback]() {
while (listener_running_) {
sockaddr_in addr;
socklen_t addr_len = sizeof(sockaddr_in);
int conn_fd = accept(listen_fd, (sockaddr *)&addr, &addr_len);
int conn_fd = accept(listen_fd_, (sockaddr *)&addr, &addr_len);
if (conn_fd < 0) {
if (errno != EWOULDBLOCK)
PLOG(ERROR) << "SocketHandShakePlugin: accept()";
Expand Down Expand Up @@ -584,6 +593,7 @@ struct SocketHandShakePlugin : public HandShakePlugin {

std::atomic<bool> listener_running_;
std::thread listener_;
int listen_fd_;
};

std::shared_ptr<HandShakePlugin> HandShakePlugin::Create(
Expand Down
12 changes: 9 additions & 3 deletions mooncake-transfer-engine/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
add_executable(rdma_transport_test rdma_transport_test.cpp)
target_link_libraries(rdma_transport_test PUBLIC transfer_engine)
# add_test(NAME rdma_transport_test COMMAND rdma_transport_test)

add_executable(transport_uint_test transport_uint_test.cpp)
target_link_libraries(transport_uint_test PUBLIC transfer_engine gtest gtest_main )
target_link_libraries(transport_uint_test PUBLIC transfer_engine gtest gtest_main )
add_test(NAME transport_uint_test COMMAND transport_uint_test)

add_executable(rdma_transport_test2 rdma_transport_test2.cpp)
target_link_libraries(rdma_transport_test2 PUBLIC transfer_engine gtest gtest_main )
target_link_libraries(rdma_transport_test2 PUBLIC transfer_engine gtest gtest_main )
# add_test(NAME rdma_transport_test2 COMMAND rdma_transport_test2)

add_executable(tcp_transport_test tcp_transport_test.cpp)
target_link_libraries(tcp_transport_test PUBLIC transfer_engine gtest gtest_main )
target_link_libraries(tcp_transport_test PUBLIC transfer_engine gtest gtest_main )
add_test(NAME tcp_transport_test COMMAND tcp_transport_test)

add_executable(transfer_metadata_test transfer_metadata_test.cpp)
target_link_libraries(transfer_metadata_test PUBLIC transfer_engine gtest gtest_main)
add_test(NAME transfer_metadata_test COMMAND transfer_metadata_test)

add_executable(topology_test topology_test.cpp)
target_link_libraries(topology_test PUBLIC transfer_engine gtest gtest_main)
add_test(NAME topology_test COMMAND topology_test)
49 changes: 33 additions & 16 deletions mooncake-transfer-engine/tests/tcp_transport_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ DEFINE_int32(gpu_id, 0, "GPU ID to use");

using namespace mooncake;

//// etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls
/// http://10.0.0.1:2379 / ./tcp_transport_test
namespace mooncake {

class TCPTransportTest : public ::testing::Test {
Expand All @@ -58,12 +56,29 @@ class TCPTransportTest : public ::testing::Test {
void SetUp() override {
google::InitGoogleLogging("TCPTransportTest");
FLAGS_logtostderr = 1;

const char *env = std::getenv("MC_METADATA_SERVER");
if (env)
metadata_server = env;
else
metadata_server = metadata_server;
LOG(INFO) << "metadata_server: " << metadata_server;

env = std::getenv("MC_LOCAL_SERVER_NAME");
if (env)
local_server_name = env;
else
local_server_name = "127.0.0.2:12345";
LOG(INFO) << "local_server_name: " << local_server_name;
}

void TearDown() override {
// 清理 glog
google::ShutdownGoogleLogging();
}

std::string metadata_server;
std::string local_server_name;
};

static void *allocateMemoryPool(size_t size, int socket_id,
Expand All @@ -73,9 +88,10 @@ static void *allocateMemoryPool(size_t size, int socket_id,

TEST_F(TCPTransportTest, GetTcpTest) {
auto engine = std::make_unique<TransferEngine>();
auto hostname_port = parseHostNameWithPort("127.0.0.2:12345");
engine->init("127.0.0.1:2379", "127.0.0.2:12345",
hostname_port.first.c_str(), hostname_port.second);
auto hostname_port = parseHostNameWithPort(local_server_name);
auto rc = engine->init(metadata_server, local_server_name,
hostname_port.first.c_str(), hostname_port.second);
LOG_ASSERT(rc == 0);
Transport *xport = nullptr;
xport = engine->installTransport("tcp", nullptr);
LOG_ASSERT(xport != nullptr);
Expand All @@ -86,22 +102,23 @@ TEST_F(TCPTransportTest, Writetest) {
void *addr = nullptr;
const size_t ram_buffer_size = 1ull << 30;
auto engine = std::make_unique<TransferEngine>();
auto hostname_port = parseHostNameWithPort("127.0.0.2:12345");
engine->init("127.0.0.1:2379", "127.0.0.2:12345",
hostname_port.first.c_str(), hostname_port.second);
auto hostname_port = parseHostNameWithPort(local_server_name);
auto rc = engine->init(metadata_server, local_server_name,
hostname_port.first.c_str(), hostname_port.second);
LOG_ASSERT(rc == 0);
Transport *xport = nullptr;
xport = engine->installTransport("tcp", nullptr);
LOG_ASSERT(xport != nullptr);

addr = allocateMemoryPool(ram_buffer_size, 0, false);
int rc = engine->registerLocalMemory(addr, ram_buffer_size, "cpu:0");
rc = engine->registerLocalMemory(addr, ram_buffer_size, "cpu:0");
LOG_ASSERT(!rc);

for (size_t offset = 0; offset < kDataLength; ++offset)
*((char *)(addr) + offset) = 'a' + lrand48() % 26;
auto batch_id = engine->allocateBatchID(1);
int ret = 0;
auto segment_id = engine->openSegment("127.0.0.2:12345");
auto segment_id = engine->openSegment(local_server_name);
TransferRequest entry;
auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id);
uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr;
Expand Down Expand Up @@ -129,8 +146,8 @@ TEST_F(TCPTransportTest, WriteAndReadtest) {
void *addr = nullptr;
const size_t ram_buffer_size = 1ull << 30;
auto engine = std::make_unique<TransferEngine>();
auto hostname_port = parseHostNameWithPort("127.0.0.2:12345");
engine->init("127.0.0.1:2379", "127.0.0.2:12345",
auto hostname_port = parseHostNameWithPort(local_server_name);
engine->init(metadata_server, local_server_name,
hostname_port.first.c_str(), hostname_port.second);
Transport *xport = nullptr;
xport = engine->installTransport("tcp", nullptr);
Expand All @@ -142,7 +159,7 @@ TEST_F(TCPTransportTest, WriteAndReadtest) {
for (size_t offset = 0; offset < kDataLength; ++offset)
*((char *)(addr) + offset) = 'a' + lrand48() % 26;

auto segment_id = engine->openSegment("127.0.0.2:12345");
auto segment_id = engine->openSegment(local_server_name);
auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id);
uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr;
{
Expand Down Expand Up @@ -200,8 +217,8 @@ TEST_F(TCPTransportTest, WriteAndRead2test) {
void *addr = nullptr;
const size_t ram_buffer_size = 1ull << 30;
auto engine = std::make_unique<TransferEngine>();
auto hostname_port = parseHostNameWithPort("127.0.0.2:12345");
engine->init("127.0.0.1:2379", "127.0.0.2:12345",
auto hostname_port = parseHostNameWithPort(local_server_name);
engine->init(metadata_server, local_server_name,
hostname_port.first.c_str(), hostname_port.second);
Transport *xport = nullptr;
xport = engine->installTransport("tcp", nullptr);
Expand All @@ -213,7 +230,7 @@ TEST_F(TCPTransportTest, WriteAndRead2test) {
for (size_t offset = 0; offset < kDataLength; ++offset)
*((char *)(addr) + offset) = 'a' + lrand48() % 26;

auto segment_id = engine->openSegment("127.0.0.2:12345");
auto segment_id = engine->openSegment(local_server_name);
auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id);
uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr;

Expand Down
Loading

0 comments on commit 295d094

Please sign in to comment.