Skip to content

Commit b0fe40c

Browse files
committed
DEVICE/API: Simplify createGpuXferReq to single-step API
Signed-off-by: Michal Shalev <mshalev@nvidia.com>
1 parent f7254f9 commit b0fe40c

File tree

6 files changed

+184 additions, -55 deletions

src/api/cpp/nixl.h

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -322,12 +322,22 @@ class nixlAgent {
322322
/**
323323
* @brief Create a GPU transfer request from a transfer request.
324324
*
325-
* @param req_hndl [in] Transfer request obtained from makeXferReq/createXferReq
326-
* @param gpu_req_hndl [out] GPU transfer request handle
327-
* @return nixl_status_t Error code if call was not successful
325+
*
326+
* @param local_descs [in] Local descriptor list (empty for signal-only case)
327+
* @param remote_descs [in] Remote descriptor list
328+
* @param remote_agent [in] Remote agent name for accessing the remote data
329+
* @param gpu_req_hndl [out] GPU transfer request handle
330+
* @param req_hndl [out] Transfer request handle
331+
* @param extra_params [in] Optional extra parameters
332+
* @return nixl_status_t Error code if call was not successful
328333
*/
329334
nixl_status_t
330-
createGpuXferReq(const nixlXferReqH &req_hndl, nixlGpuXferReqH &gpu_req_hndl) const;
335+
createGpuXferReq(const nixl_xfer_dlist_t &local_descs,
336+
const nixl_xfer_dlist_t &remote_descs,
337+
const std::string &remote_agent,
338+
nixlGpuXferReqH &gpu_req_hndl,
339+
nixlXferReqH *&req_hndl,
340+
const nixl_opt_args_t *extra_params = nullptr) const;
331341

332342
/**
333343
* @brief Release transfer request from GPU memory

src/core/nixl_agent.cpp

Lines changed: 129 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1217,25 +1217,143 @@ nixlAgent::releaseXferReq(nixlXferReqH *req_hndl) const {
12171217
}
12181218

12191219
nixl_status_t
1220-
nixlAgent::createGpuXferReq(const nixlXferReqH &req_hndl, nixlGpuXferReqH &gpu_req_hndl) const {
1221-
if (!req_hndl.engine) {
1222-
NIXL_ERROR_FUNC << "Invalid request handle[" << &req_hndl << "]: engine is null";
1223-
return NIXL_ERR_INVALID_PARAM;
1220+
nixlAgent::createGpuXferReq(const nixl_xfer_dlist_t &local_descs,
1221+
const nixl_xfer_dlist_t &remote_descs,
1222+
const std::string &remote_agent,
1223+
nixlGpuXferReqH &gpu_req_hndl,
1224+
nixlXferReqH *&req_hndl,
1225+
const nixl_opt_args_t *extra_params) const {
1226+
nixl_status_t ret1, ret2;
1227+
nixl_opt_b_args_t opt_args;
1228+
1229+
std::unique_ptr<backend_set_t> backend_set = std::make_unique<backend_set_t>();
1230+
1231+
req_hndl = nullptr;
1232+
1233+
NIXL_SHARED_LOCK_GUARD(data->lock);
1234+
1235+
if (data->remoteSections.count(remote_agent) == 0)
1236+
{
1237+
NIXL_ERROR_FUNC << "metadata for remote agent '" << remote_agent << "' not found";
1238+
data->addErrorTelemetry(NIXL_ERR_NOT_FOUND);
1239+
return NIXL_ERR_NOT_FOUND;
12241240
}
12251241

1226-
if (!req_hndl.backendHandle) {
1227-
NIXL_ERROR_FUNC << "Invalid request handle[" << &req_hndl << "]: backendHandle is null";
1242+
size_t total_bytes = 0;
1243+
if (local_descs.descCount() != remote_descs.descCount()) {
1244+
NIXL_ERROR_FUNC << "different descriptor list sizes (local=" << local_descs.descCount()
1245+
<< ", remote=" << remote_descs.descCount() << ")";
12281246
return NIXL_ERR_INVALID_PARAM;
12291247
}
1248+
for (int i = 0; i < local_descs.descCount(); ++i) {
1249+
if (local_descs[i].len != remote_descs[i].len) {
1250+
NIXL_ERROR_FUNC << "length mismatch at index " << i;
1251+
return NIXL_ERR_INVALID_PARAM;
1252+
}
1253+
total_bytes += local_descs[i].len;
1254+
}
12301255

1231-
NIXL_SHARED_LOCK_GUARD(data->lock);
1232-
const auto status = req_hndl.engine->createGpuXferReq(
1233-
*req_hndl.backendHandle, *req_hndl.initiatorDescs, *req_hndl.targetDescs, gpu_req_hndl);
1256+
if (!extra_params || extra_params->backends.size() == 0) {
1257+
// Finding backends that support the corresponding memories
1258+
// locally and remotely, and find the common ones.
1259+
backend_set_t* local_set =
1260+
data->memorySection->queryBackends(local_descs.getType());
1261+
backend_set_t* remote_set =
1262+
data->remoteSections[remote_agent]->queryBackends(
1263+
remote_descs.getType());
1264+
if (!local_set || !remote_set) {
1265+
NIXL_ERROR_FUNC << "no backends found for local or remote for their "
1266+
"corresponding memory type";
1267+
return NIXL_ERR_NOT_FOUND;
1268+
}
1269+
1270+
for (auto & elm : *local_set)
1271+
if (remote_set->count(elm) != 0)
1272+
backend_set->insert(elm);
1273+
1274+
if (backend_set->empty()) {
1275+
NIXL_ERROR_FUNC << "no potential backend found to be able to do the transfer";
1276+
return NIXL_ERR_NOT_FOUND;
1277+
}
1278+
} else {
1279+
for (auto & elm : extra_params->backends)
1280+
backend_set->insert(elm->engine);
1281+
}
1282+
1283+
std::unique_ptr<nixlXferReqH> handle = std::make_unique<nixlXferReqH>();
1284+
handle->initiatorDescs = new nixl_meta_dlist_t(local_descs.getType());
1285+
1286+
handle->targetDescs = new nixl_meta_dlist_t(remote_descs.getType());
1287+
1288+
for (auto & backend : *backend_set) {
1289+
ret1 = data->memorySection->populate(
1290+
local_descs, backend, *handle->initiatorDescs);
1291+
ret2 = data->remoteSections[remote_agent]->populate(
1292+
remote_descs, backend, *handle->targetDescs);
1293+
1294+
if ((ret1 == NIXL_SUCCESS) && (ret2 == NIXL_SUCCESS)) {
1295+
NIXL_INFO << "Selected backend: " << backend->getType();
1296+
handle->engine = backend;
1297+
break;
1298+
}
1299+
}
1300+
1301+
if (!handle->engine) {
1302+
NIXL_ERROR_FUNC << "no specified or potential backend had the required "
1303+
"registrations to be able to do the transfer";
1304+
data->addErrorTelemetry(NIXL_ERR_NOT_FOUND);
1305+
return NIXL_ERR_NOT_FOUND;
1306+
}
1307+
1308+
if (extra_params) {
1309+
if (extra_params->hasNotif) {
1310+
opt_args.notifMsg = extra_params->notifMsg;
1311+
opt_args.hasNotif = true;
1312+
}
1313+
1314+
if (extra_params->customParam.length() > 0)
1315+
opt_args.customParam = extra_params->customParam;
1316+
}
1317+
1318+
if (opt_args.hasNotif && (!handle->engine->supportsNotif())) {
1319+
NIXL_ERROR_FUNC << "the selected backend '" << handle->engine->getType()
1320+
<< "' does not support notifications";
1321+
data->addErrorTelemetry(NIXL_ERR_BACKEND);
1322+
return NIXL_ERR_BACKEND;
1323+
}
1324+
1325+
handle->remoteAgent = remote_agent;
1326+
handle->status = NIXL_ERR_NOT_POSTED;
1327+
handle->notifMsg = opt_args.notifMsg;
1328+
handle->hasNotif = opt_args.hasNotif;
1329+
1330+
if (data->telemetryEnabled) {
1331+
handle->telemetry.totalBytes = total_bytes;
1332+
handle->telemetry.descCount = handle->initiatorDescs->descCount();
1333+
}
1334+
1335+
ret1 = handle->engine->prepXfer (handle->backendOp,
1336+
*handle->initiatorDescs,
1337+
*handle->targetDescs,
1338+
handle->remoteAgent,
1339+
handle->backendHandle,
1340+
&opt_args);
1341+
if (ret1 != NIXL_SUCCESS) {
1342+
NIXL_ERROR_FUNC << "backend '" << handle->engine->getType()
1343+
<< "' failed to prepare the transfer request with status " << ret1;
1344+
data->addErrorTelemetry(ret1);
1345+
return ret1;
1346+
}
1347+
1348+
req_hndl = handle.release();
1349+
1350+
const auto status = req_hndl->engine->createGpuXferReq(
1351+
*req_hndl->backendHandle, *req_hndl->initiatorDescs, *req_hndl->targetDescs, gpu_req_hndl);
12341352
if (status == NIXL_SUCCESS) {
1235-
data->gpuReqToEngine.emplace(gpu_req_hndl, req_hndl.engine);
1353+
data->gpuReqToEngine.emplace(gpu_req_hndl, req_hndl->engine);
12361354
}
12371355

1238-
return status;
1356+
return NIXL_SUCCESS;
12391357
}
12401358

12411359
void

src/plugins/ucx/ucx_backend.cpp

Lines changed: 4 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1633,16 +1633,6 @@ nixlUcxEngine::createGpuXferReq(const nixlBackendReqH &req_hndl,
16331633
nixlGpuXferReqH &gpu_req_hndl) const {
16341634
auto intHandle = static_cast<const nixlUcxBackendH *>(&req_hndl);
16351635

1636-
if (local_descs.descCount() != remote_descs.descCount()) {
1637-
NIXL_ERROR << "Mismatch between local and remote descriptor counts";
1638-
return NIXL_ERR_INVALID_PARAM;
1639-
}
1640-
1641-
if (local_descs.descCount() == 0) {
1642-
NIXL_ERROR << "Empty descriptor lists";
1643-
return NIXL_ERR_INVALID_PARAM;
1644-
}
1645-
16461636
auto remoteMd = static_cast<nixlUcxPublicMetadata *>(remote_descs[0].metadataP);
16471637
if (!remoteMd || !remoteMd->conn) {
16481638
NIXL_ERROR << "No connection found in remote metadata";
@@ -1655,9 +1645,11 @@ nixlUcxEngine::createGpuXferReq(const nixlBackendReqH &req_hndl,
16551645
std::vector<nixlUcxMem> local_mems;
16561646
std::vector<const nixl::ucx::rkey *> remote_rkeys;
16571647
std::vector<uint64_t> remote_addrs;
1648+
std::vector<size_t> remote_lengths;
16581649
local_mems.reserve(local_descs.descCount());
16591650
remote_rkeys.reserve(remote_descs.descCount());
16601651
remote_addrs.reserve(remote_descs.descCount());
1652+
remote_lengths.reserve(remote_descs.descCount());
16611653

16621654
for (size_t i = 0; i < static_cast<size_t>(local_descs.descCount()); i++) {
16631655
auto localMd = static_cast<nixlUcxPrivateMetadata *>(local_descs[i].metadataP);
@@ -1666,10 +1658,11 @@ nixlUcxEngine::createGpuXferReq(const nixlBackendReqH &req_hndl,
16661658
local_mems.push_back(localMd->mem);
16671659
remote_rkeys.push_back(&remoteMdDesc->getRkey(workerId));
16681660
remote_addrs.push_back(static_cast<uint64_t>(remote_descs[i].addr));
1661+
remote_lengths.push_back(remote_descs[i].len);
16691662
}
16701663

16711664
try {
1672-
gpu_req_hndl = nixl::ucx::createGpuXferReq(*ep, local_mems, remote_rkeys, remote_addrs);
1665+
gpu_req_hndl = nixl::ucx::createGpuXferReq(*ep, local_mems, remote_rkeys, remote_addrs, remote_lengths);
16731666
return NIXL_SUCCESS;
16741667
}
16751668
catch (const std::exception &e) {

src/utils/ucx/gpu_xfer_req_h.cpp

Lines changed: 24 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -35,34 +35,43 @@ nixlGpuXferReqH
3535
createGpuXferReq(const nixlUcxEp &ep,
3636
const std::vector<nixlUcxMem> &local_mems,
3737
const std::vector<const nixl::ucx::rkey *> &remote_rkeys,
38-
const std::vector<uint64_t> &remote_addrs) {
38+
const std::vector<uint64_t> &remote_addrs,
39+
const std::vector<size_t> &remote_lengths) {
3940
nixl_status_t status = ep.checkTxState();
4041
if (status != NIXL_SUCCESS) {
4142
throw std::runtime_error("Endpoint not in valid state for creating memory list");
4243
}
4344

44-
if (local_mems.empty() || remote_rkeys.empty() || remote_addrs.empty()) {
45-
throw std::invalid_argument("Empty memory, rkey, or address lists provided");
46-
}
47-
48-
if (local_mems.size() != remote_rkeys.size() || local_mems.size() != remote_addrs.size()) {
49-
throw std::invalid_argument(
50-
"Local memory, remote rkey, and remote address lists must have same size");
51-
}
5245

5346
std::vector<ucp_device_mem_list_elem_t> ucp_elements;
5447
ucp_elements.reserve(local_mems.size());
5548

5649
for (size_t i = 0; i < local_mems.size(); i++) {
5750
ucp_device_mem_list_elem_t ucp_elem;
58-
ucp_elem.field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_MEMH |
59-
UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY | UCP_DEVICE_MEM_LIST_ELEM_FIELD_LOCAL_ADDR |
60-
UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR | UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
61-
ucp_elem.memh = local_mems[i].getMemh();
51+
bool has_local_mem = local_mems[i].getMemh() != nullptr;
52+
53+
if (has_local_mem) {
54+
/* Data element with local memory */
55+
ucp_elem.field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_MEMH |
56+
UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY |
57+
UCP_DEVICE_MEM_LIST_ELEM_FIELD_LOCAL_ADDR |
58+
UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR |
59+
UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
60+
ucp_elem.memh = local_mems[i].getMemh();
61+
ucp_elem.local_addr = local_mems[i].getBase();
62+
ucp_elem.length = local_mems[i].getSize();
63+
} else {
64+
/* Signal element without local memory */
65+
ucp_elem.field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY |
66+
UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR |
67+
UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
68+
ucp_elem.memh = nullptr;
69+
ucp_elem.local_addr = nullptr;
70+
ucp_elem.length = remote_lengths[i];
71+
}
72+
6273
ucp_elem.rkey = remote_rkeys[i]->get();
63-
ucp_elem.local_addr = local_mems[i].getBase();
6474
ucp_elem.remote_addr = remote_addrs[i];
65-
ucp_elem.length = local_mems[i].getSize();
6675
ucp_elements.push_back(ucp_elem);
6776
}
6877

src/utils/ucx/gpu_xfer_req_h.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,8 @@ nixlGpuXferReqH
3232
createGpuXferReq(const nixlUcxEp &ep,
3333
const std::vector<nixlUcxMem> &local_mems,
3434
const std::vector<const nixl::ucx::rkey *> &remote_rkeys,
35-
const std::vector<uint64_t> &remote_addrs);
35+
const std::vector<uint64_t> &remote_addrs,
36+
const std::vector<size_t> &remote_lengths);
3637

3738
void
3839
releaseGpuXferReq(nixlGpuXferReqH gpu_req) noexcept;

test/gtest/device_api/single_write_test.cu

Lines changed: 11 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -385,7 +385,7 @@ TEST_P(SingleWriteTest, BasicSingleWriteTest) {
385385
constexpr size_t count = 1;
386386
nixl_mem_t mem_type = VRAM_SEG;
387387
size_t num_threads = 32;
388-
const size_t num_iters = 10000;
388+
const size_t num_iters = 10;
389389
constexpr unsigned index = 0;
390390
const bool is_no_delay = true;
391391

@@ -405,22 +405,21 @@ TEST_P(SingleWriteTest, BasicSingleWriteTest) {
405405
extra_params.notifMsg = NOTIF_MSG;
406406

407407
nixlXferReqH *xfer_req = nullptr;
408+
nixlGpuXferReqH gpu_req_hndl;
409+
408410
nixl_status_t status = getAgent(SENDER_AGENT)
409-
.createXferReq(NIXL_WRITE,
411+
.createGpuXferReq(
410412
makeDescList<nixlBasicDesc>(src_buffers, mem_type),
411413
makeDescList<nixlBasicDesc>(dst_buffers, mem_type),
412414
getAgentName(RECEIVER_AGENT),
415+
gpu_req_hndl,
413416
xfer_req,
414417
&extra_params);
415418

416419
ASSERT_EQ(status, NIXL_SUCCESS)
417-
<< "Failed to create xfer request " << nixlEnumStrings::statusStr(status);
420+
<< "Failed to create GPU xfer request " << nixlEnumStrings::statusStr(status);
418421
EXPECT_NE(xfer_req, nullptr);
419422

420-
nixlGpuXferReqH gpu_req_hndl;
421-
status = getAgent(SENDER_AGENT).createGpuXferReq(*xfer_req, gpu_req_hndl);
422-
ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to create GPU xfer request";
423-
424423
ASSERT_NE(gpu_req_hndl, nullptr) << "GPU request handle is null after createGpuXferReq";
425424

426425
size_t src_offset = 0;
@@ -485,7 +484,7 @@ TEST_P(SingleWriteTest, VariableSizeTest) {
485484
constexpr size_t count = 1;
486485
nixl_mem_t mem_type = VRAM_SEG;
487486
size_t num_threads = 32;
488-
const size_t num_iters = 50000;
487+
const size_t num_iters = 10;
489488
constexpr unsigned index = 0;
490489
const bool is_no_delay = true;
491490

@@ -507,19 +506,18 @@ TEST_P(SingleWriteTest, VariableSizeTest) {
507506
extra_params.notifMsg = NOTIF_MSG;
508507

509508
nixlXferReqH *xfer_req = nullptr;
509+
nixlGpuXferReqH gpu_req_hndl;
510+
510511
nixl_status_t status =
511512
getAgent(SENDER_AGENT)
512-
.createXferReq(NIXL_WRITE,
513+
.createGpuXferReq(
513514
makeDescList<nixlBasicDesc>(src_buffers, mem_type),
514515
makeDescList<nixlBasicDesc>(dst_buffers, mem_type),
515516
getAgentName(RECEIVER_AGENT),
517+
gpu_req_hndl,
516518
xfer_req,
517519
&extra_params);
518520

519-
ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to create xfer request for size " << test_size;
520-
521-
nixlGpuXferReqH gpu_req_hndl;
522-
status = getAgent(SENDER_AGENT).createGpuXferReq(*xfer_req, gpu_req_hndl);
523521
ASSERT_EQ(status, NIXL_SUCCESS)
524522
<< "Failed to create GPU xfer request for size " << test_size;
525523

0 commit comments

Comments
 (0)