Skip to content

Commit

Permalink
support one ggml_tensor rdma
Browse files Browse the repository at this point in the history
  • Loading branch information
MjieYu committed Mar 27, 2024
1 parent 54cbd72 commit 9366e0c
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 242 deletions.
93 changes: 61 additions & 32 deletions rdma-example/src/rdma_client_LLM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ static struct rdma_buffer_attr_vec server_metadata_attrs;
static struct rdma_buffer_attr_vec client_metadata_attrs;
std::vector<struct ibv_mr *> client_src_mrs;
std::vector<struct ibv_mr *> client_dst_mrs;
int total = 1;
std::vector<struct ibv_sge> client_send_sges;
std::vector<struct ibv_recv_wr> server_recv_wrs;

int total = 2;

void printf_value(struct ggml_tensor * tensor )
{
Expand Down Expand Up @@ -72,6 +75,21 @@ static int check_src_dst_LLM()
printf_value(tensor_dst);
return memcmp((void*) tensor_src->data, (void*) tensor_dst->data, ggml_nbytes(tensor_dst));
}
static int check_src_dst_LLM_vec()
{
printf("%s\n",__func__);
for(int i=0;i<total;++i)
{
printf_value(tensor_srcs[i]);
printf_value(tensor_dsts[i]);
if(memcmp((void*) tensor_srcs[i]->data, (void*) tensor_dsts[i]->data, ggml_nbytes(tensor_dsts[i]))!=0)
{
return -1;
}
}
return 0;

}
/* This function prepares client side connection resources for an RDMA connection */
static int client_prepare_connection(struct sockaddr_in *s_addr)
{
Expand Down Expand Up @@ -529,7 +547,7 @@ static int client_xchange_metadata_with_server_LLM_vec()
return ret;
}
debug("Server sent us its buffer location and credentials, showing \n");
show_rdma_buffer_attr(&server_metadata_attr);
show_rdma_buffer_attrs(&server_metadata_attrs);
return 0;
}

Expand Down Expand Up @@ -709,10 +727,11 @@ static int client_remote_memory_ops_LLM()
}
static int client_remote_memory_ops_LLM_vec()
{
struct ibv_wc wc;
struct ibv_wc wc[total];
int ret = -1;
printf_value(tensor_srcs[0]);
for(int i=0;i<total;++i)
{ printf("rdma_buffer_register\n");
{
client_dst_mrs[i] = rdma_buffer_register(pd, //使用rdma_buffer_register()函数将dst缓冲区注册到RDMA设备上,并指定访问权限为本地写、远程写和远程读。
tensor_dsts[i]->data,
ggml_nbytes(tensor_srcs[i]),
Expand All @@ -724,6 +743,7 @@ static int client_remote_memory_ops_LLM_vec()
return -ENOMEM;
}
}
printf_value(tensor_srcs[0]);

/* Step 1: is to copy the local buffer into the remote buffer. We will
* reuse the previous variables. */
Expand All @@ -740,8 +760,8 @@ static int client_remote_memory_ops_LLM_vec()
client_send_wr.opcode = IBV_WR_RDMA_WRITE;
client_send_wr.send_flags = IBV_SEND_SIGNALED;
/* we have to tell server side info for RDMA */ //设置远程RDMA操作的相关信息,包括远程键(rkey)和远程地址。
client_send_wr.wr.rdma.rkey = server_metadata_attr.stag.remote_stag;
client_send_wr.wr.rdma.remote_addr = server_metadata_attr.address;
client_send_wr.wr.rdma.rkey = server_metadata_attrs.stags[i].remote_stag;
client_send_wr.wr.rdma.remote_addr = server_metadata_attrs.address[i];
/* Now we post it */
ret = ibv_post_send(client_qp, //调用ibv_post_send()函数将发送请求发送到RDMA队列中。
&client_send_wr,
Expand All @@ -751,6 +771,8 @@ static int client_remote_memory_ops_LLM_vec()
-errno);
return -errno;
}
sleep(5);

}
/* now we link to the send work request */ //初始化client_send_wr结构体,并设置相关参数,如SGE列表、SGE数量、操作码(IBV_WR_RDMA_WRITE)和发送标志(IBV_SEND_SIGNALED)。

Expand All @@ -760,15 +782,16 @@ static int client_remote_memory_ops_LLM_vec()

/* at this point we are expecting 1 work completion for the write */
printf("process_work_completion_events\n");
ret = process_work_completion_events(io_completion_channel, //函数等待并处理工作完成事件。
&wc, 1);
if(ret != 1) {
rdma_error("We failed to get 1 work completions , ret = %d \n",
ret);
return ret;
}
// ret = process_work_completion_events(io_completion_channel, //函数等待并处理工作完成事件。
// wc, total);
// if(ret != total) {
// rdma_error("We failed to get 1 work completions , ret = %d \n",
// ret);
// return ret;
// }

debug("Client side WRITE is complete \n");
for(int i=0;i<1;++i)
for(int i=0;i<total;++i)
{
client_send_sge.addr = (uint64_t) client_dst_mrs[i]->addr;
client_send_sge.length = (uint32_t) client_dst_mrs[i]->length;
Expand All @@ -780,8 +803,8 @@ static int client_remote_memory_ops_LLM_vec()
client_send_wr.opcode = IBV_WR_RDMA_READ;
client_send_wr.send_flags = IBV_SEND_SIGNALED;
/* we have to tell server side info for RDMA */ // 设置远程RDMA操作的相关信息,包括远程键和远程地址。
client_send_wr.wr.rdma.rkey = server_metadata_attr.stag.remote_stag;
client_send_wr.wr.rdma.remote_addr = server_metadata_attr.address;
client_send_wr.wr.rdma.rkey = server_metadata_attrs.stags[0].remote_stag;
client_send_wr.wr.rdma.remote_addr = server_metadata_attrs.address[0];
/* Now we post it */
ret = ibv_post_send(client_qp, //函数将发送请求发送到RDMA队列中。
&client_send_wr,
Expand All @@ -792,16 +815,21 @@ static int client_remote_memory_ops_LLM_vec()
return -errno;
}
/* Now we prepare a READ using same variables but for destination */ //将目标缓冲区的地址、长度和本地键赋值给client_send_sge结构体,表示接收的数据

sleep(5);
}


/* at this point we are expecting 1 work completion for the write */
ret = process_work_completion_events(io_completion_channel,
&wc, 1);
if(ret != 1) {
rdma_error("We failed to get 1 work completions , ret = %d \n",
ret);
return ret;
}

// ret = process_work_completion_events(io_completion_channel,
// wc, total);


// if(ret != total) {
// printf("We failed to get 1 work completions , ret = %d \n",
// ret);
// return ret;
// }
debug("Client side READ is complete \n");
return 0;
}
Expand Down Expand Up @@ -962,10 +990,10 @@ int main(int argc, char **argv) {
tensor_srcs.resize(total);
client_src_mrs.resize(total);
client_dst_mrs.resize(total);
client_metadata_attrs.address.resize(total);
client_metadata_attrs.length.resize(total);
client_metadata_attrs.stags.resize(total);
client_metadata_attrs.size = total;
// client_metadata_attrs.address.resize(total);
// client_metadata_attrs.length.resize(total);
// client_metadata_attrs.stags.resize(total);
// client_metadata_attrs.size = total;
for(int i=0;i<total;++i)
{
tensor_srcs[i] = ggml_new_tensor_1d(ctx,GGML_TYPE_F32,4);
Expand All @@ -989,7 +1017,7 @@ int main(int argc, char **argv) {
rdma_error("Failed to setup client connection , ret = %d \n", ret);
return ret;
}
ret = client_pre_post_recv_buffer();
ret = client_pre_post_recv_buffer_LLM_vec();
if (ret) {
rdma_error("Failed to setup client connection , ret = %d \n", ret);
return ret;
Expand All @@ -999,17 +1027,18 @@ int main(int argc, char **argv) {
rdma_error("Failed to setup client connection , ret = %d \n", ret);
return ret;
}
ret = client_xchange_metadata_with_server();
ret = client_xchange_metadata_with_server_LLM_vec();
if (ret) {
rdma_error("Failed to setup client connection , ret = %d \n", ret);
return ret;
}
ret = client_remote_memory_ops();
ret = client_remote_memory_ops_LLM_vec();
if (ret) {
rdma_error("Failed to finish remote memory ops, ret = %d \n", ret);
return ret;
}
if (check_src_dst()) {

if (check_src_dst_LLM_vec()) {
rdma_error("src and dst buffers do not match \n");
} else {
printf("...\nSUCCESS, source and destination buffers match \n");
Expand Down
Loading

0 comments on commit 9366e0c

Please sign in to comment.