Skip to content

Commit acc7d8e

Browse files
Load md fixes (#969)
* libfabric: Fix creation of localMD
  When creating the MD for local operations, populate the remote selected endpoints array. Remove the dedicated code for local operations within the same agent process so that they use the general flow, now that localMD is created correctly.
  Signed-off-by: Amit Radzi <amitrad@amazon.com>
* libfabric: Refactor load MD functions
  Create a helper function for the common part of loadLocalMD and loadRemoteMD.
  Signed-off-by: Amit Radzi <amitrad@amazon.com>
---------
Signed-off-by: Amit Radzi <amitrad@amazon.com>
Co-authored-by: Adit Ranadive <aranadive@nvidia.com>
1 parent a43fc9e commit acc7d8e

File tree

2 files changed

+29
-58
lines changed

2 files changed

+29
-58
lines changed

src/plugins/libfabric/libfabric_backend.cpp

Lines changed: 24 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -825,24 +825,32 @@ nixlLibfabricEngine::getPublicData(const nixlBackendMD *meta, std::string &str)
825825
}
826826

827827
nixl_status_t
828-
nixlLibfabricEngine::loadLocalMD(nixlBackendMD *input, nixlBackendMD *&output) {
829-
nixlLibfabricPrivateMetadata *input_md = static_cast<nixlLibfabricPrivateMetadata *>(input);
828+
nixlLibfabricEngine::loadMetadataHelper(const std::vector<uint64_t> &rail_keys,
829+
void *buffer,
830+
std::shared_ptr<nixlLibfabricConnection> conn,
831+
nixlBackendMD *&output) {
830832
auto pub_md = std::make_unique<nixlLibfabricPublicMetadata>();
831-
// Store all rail keys instead of just the first one
832-
pub_md->rail_remote_key_list_.reserve(input_md->rail_key_list_.size());
833-
for (size_t rail_id = 0; rail_id < input_md->rail_key_list_.size(); ++rail_id) {
834-
pub_md->rail_remote_key_list_.push_back(input_md->rail_key_list_[rail_id]);
835-
NIXL_DEBUG << "Added rail " << rail_id << " key: " << input_md->rail_key_list_[rail_id];
836-
}
837833

838-
pub_md->remote_buf_addr_ = reinterpret_cast<uint64_t>(input_md->buffer_);
839-
pub_md->conn_ = connections_[localAgent];
834+
pub_md->rail_remote_key_list_ = std::move(rail_keys);
835+
pub_md->derive_remote_selected_endpoints();
836+
pub_md->remote_buf_addr_ = reinterpret_cast<uint64_t>(buffer);
837+
pub_md->conn_ = conn;
840838

841839
output = pub_md.release();
842-
NIXL_DEBUG << "Loading Local MD with " << input_md->rail_key_list_.size() << " rail keys";
840+
NIXL_DEBUG << "Metadata loaded with"
841+
<< " Remote addr: " << (void *)pub_md->remote_buf_addr_ << " Remote keys for "
842+
<< pub_md->rail_remote_key_list_.size() << " rails"
843+
<< " Remote fi_addr: " << pub_md->conn_->rail_remote_addr_list_[0][0];
843844
return NIXL_SUCCESS;
844845
}
845846

847+
nixl_status_t
848+
nixlLibfabricEngine::loadLocalMD(nixlBackendMD *input, nixlBackendMD *&output) {
849+
nixlLibfabricPrivateMetadata *input_md = static_cast<nixlLibfabricPrivateMetadata *>(input);
850+
return loadMetadataHelper(
851+
input_md->rail_key_list_, input_md->buffer_, connections_[localAgent], output);
852+
}
853+
846854
nixl_status_t
847855
nixlLibfabricEngine::loadRemoteMD(const nixlBlobDesc &input,
848856
const nixl_mem_t &nixl_mem,
@@ -869,19 +877,8 @@ nixlLibfabricEngine::loadRemoteMD(const nixlBlobDesc &input,
869877
return status;
870878
}
871879

872-
// Engine handles connection management and metadata object creation
873-
auto pub_md = std::make_unique<nixlLibfabricPublicMetadata>();
874-
pub_md->conn_ = conn_it->second;
875-
pub_md->rail_remote_key_list_ = std::move(remote_keys);
876-
pub_md->derive_remote_selected_endpoints();
877-
pub_md->remote_buf_addr_ = remote_addr;
878-
NIXL_DEBUG << "Remote metadata loaded with"
879-
<< " Remote addr: " << (void *)pub_md->remote_buf_addr_ << " Remote keys for "
880-
<< pub_md->rail_remote_key_list_.size() << " rails"
881-
<< " Remote fi_addr: " << pub_md->conn_->rail_remote_addr_list_[0][0];
882-
883-
output = pub_md.release();
884-
return NIXL_SUCCESS;
880+
return loadMetadataHelper(
881+
remote_keys, reinterpret_cast<void *>(remote_addr), conn_it->second, output);
885882
}
886883

887884
nixl_status_t
@@ -1041,36 +1038,6 @@ nixlLibfabricEngine::postXfer(const nixl_xfer_op_t &operation,
10411038
NIXL_DEBUG << "DEBUG: remote_agent='" << remote_agent << "' localAgent='" << localAgent
10421039
<< "'";
10431040

1044-
// Check for same-agent (local) transfer - handle with direct memcpy
1045-
if (remote_agent == localAgent) {
1046-
NIXL_DEBUG << "Same-agent transfer detected from localAgent= " << localAgent
1047-
<< "to remote_agent " << remote_agent << "for descriptor " << desc_idx
1048-
<< ", using memcpy fallback for " << transfer_size << " bytes";
1049-
1050-
// For same-agent transfers, we need to copy directly between the descriptor addresses
1051-
// The remote[desc_idx].addr should be the target address for the transfer
1052-
void *remote_addr = reinterpret_cast<void *>(remote[desc_idx].addr);
1053-
1054-
NIXL_DEBUG << "About to perform memcpy: local_addr=" << transfer_addr
1055-
<< " remote_addr=" << remote_addr << " size=" << transfer_size;
1056-
1057-
if (op_type == nixlLibfabricReq::WRITE) {
1058-
// Write: copy from local_addr to remote_addr
1059-
std::memcpy(remote_addr, transfer_addr, transfer_size);
1060-
NIXL_DEBUG << "Same-agent memcpy write completed: " << transfer_addr << " -> "
1061-
<< remote_addr << " (" << transfer_size << " bytes)";
1062-
} else {
1063-
// Read: copy from remote_addr to local_addr
1064-
std::memcpy(transfer_addr, remote_addr, transfer_size);
1065-
NIXL_DEBUG << "Same-agent memcpy read completed: " << remote_addr << " -> "
1066-
<< transfer_addr << " (" << transfer_size << " bytes)";
1067-
}
1068-
1069-
NIXL_DEBUG << "Successfully processed same-agent descriptor " << desc_idx
1070-
<< " using memcpy fallback";
1071-
continue; // Skip the rail manager transfer for this descriptor
1072-
}
1073-
10741041
// Prepare and submit transfer for remote agents
10751042
// Use descriptor's specific target address
10761043
uint64_t remote_target_addr = remote[desc_idx].addr;
@@ -1104,8 +1071,9 @@ nixlLibfabricEngine::postXfer(const nixl_xfer_op_t &operation,
11041071

11051072
NIXL_DEBUG << "Processing complete: submitted "
11061073
<< backend_handle->binary_notif.expected_completions << " requests from "
1107-
<< desc_count << " descriptors" << " with "
1108-
<< backend_handle->binary_notif.expected_completions << " total XFER_IDs";
1074+
<< desc_count << " descriptors"
1075+
<< " with " << backend_handle->binary_notif.expected_completions
1076+
<< " total XFER_IDs";
11091077

11101078
// For same-agent transfers, we need to set the total to 0 since we bypassed all rail operations
11111079
if (remote_agent == localAgent) {
@@ -1578,7 +1546,6 @@ nixlLibfabricEngine::addReceivedXferId(uint16_t xfer_id) {
15781546
checkPendingNotifications();
15791547
}
15801548

1581-
15821549
/****************************************
15831550
* Notification Queuing Helper Methods
15841551
*****************************************/

src/plugins/libfabric/libfabric_backend.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,11 @@ class nixlLibfabricEngine : public nixlBackendEngine {
286286
processConnectionRequest(uint16_t agent_idx,
287287
const std::string &serialized_data,
288288
nixlLibfabricRail *rail);
289-
289+
nixl_status_t
290+
loadMetadataHelper(const std::vector<uint64_t> &rail_keys,
291+
void *buffer,
292+
std::shared_ptr<nixlLibfabricConnection> conn,
293+
nixlBackendMD *&output);
290294

291295
#ifdef HAVE_CUDA
292296
// CUDA context management methods

0 commit comments

Comments
 (0)