diff --git a/mooncake-common/common.cmake b/mooncake-common/common.cmake index c60c17353..02c65cc82 100644 --- a/mooncake-common/common.cmake +++ b/mooncake-common/common.cmake @@ -59,13 +59,13 @@ option(BUILD_UNIT_TESTS "Build uint tests" ON) option(USE_CUDA "option for enabling gpu features" OFF) option(USE_NVMEOF "option for using NVMe over Fabric" OFF) option(USE_TCP "option for using TCP transport" ON) -option(USE_ASCEND "option for using npu" OFF) +option(USE_ASCEND "option for using npu" ON) option(USE_MNNVL "option for using Multi-Node NVLink transport" OFF) option(USE_CXL "option for using CXL protocol" OFF) option(USE_ETCD "option for enable etcd as metadata server" OFF) option(USE_ETCD_LEGACY "option for enable etcd based on etcd-cpp-api-v3" OFF) option(USE_REDIS "option for enable redis as metadata server" OFF) -option(USE_HTTP "option for enable http as metadata server" ON) +option(USE_HTTP "option for enable http as metadata server" OFF) option(WITH_RUST_EXAMPLE "build the Rust interface and sample code for the transfer engine" OFF) option(WITH_METRICS "enable metrics and metrics reporting thread" ON) option(USE_3FS "option for using 3FS storage backend" OFF) @@ -115,7 +115,13 @@ if (USE_ASCEND) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DOPEN_BUILD_PROJECT ") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DOPEN_BUILD_PROJECT ") - file(GLOB ASCEND_TOOLKIT_ROOT "/usr/local/Ascend/ascend-toolkit/latest/*-linux") + if("$ENV{ASCEND_TOOLKIT_PATH}" STREQUAL "") + set(ASCEND_TOOLKIT_ROOT "/usr/local/Ascend/ascend-toolkit/latest/*-linux") + else() + set(ASCEND_TOOLKIT_ROOT $ENV{ASCEND_TOOLKIT_PATH}) + endif() + + file(GLOB ASCEND_TOOLKIT_ROOT "${ASCEND_TOOLKIT_ROOT}") set(ASCEND_LIB_DIR "${ASCEND_TOOLKIT_ROOT}/lib64") set(ASCEND_INCLUDE_DIR "${ASCEND_TOOLKIT_ROOT}/include") add_compile_definitions(USE_ASCEND) diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index e89dd44bd..710bdbee5 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -154,7 +154,7 @@ tl::expected DistributedObjectStore::setup_internal( } else { this->local_hostname = local_hostname; } - + LOG(ERROR) << "setup_internal local_hostname:" << this->local_hostname; void **args = (protocol == "rdma") ? rdma_args(rdma_devices) : nullptr; auto client_opt = mooncake::Client::Create(this->local_hostname, metadata_server, @@ -164,24 +164,6 @@ tl::expected DistributedObjectStore::setup_internal( return tl::unexpected(ErrorCode::INVALID_PARAMS); } client_ = *client_opt; - - // Local_buffer_size is allowed to be 0, but we only register memory when - // local_buffer_size > 0. Invoke ibv_reg_mr() with size=0 is UB, and may - // fail in some rdma implementations. - client_buffer_allocator_ = ClientBufferAllocator::create(local_buffer_size); - if (local_buffer_size > 0) { - auto result = client_->RegisterLocalMemory( - client_buffer_allocator_->getBase(), local_buffer_size, - kWildcardLocation, false, true); - if (!result.has_value()) { - LOG(ERROR) << "Failed to register local memory: " - << toString(result.error()); - return tl::unexpected(result.error()); - } - } else { - LOG(INFO) << "Local buffer size is 0, skip registering local memory"; - } - // If global_segment_size is 0, skip mount segment; // If global_segment_size is larger than max_mr_size, split to multiple // segments. @@ -1000,6 +982,21 @@ std::vector DistributedObjectStore::batch_put_from( return results; } +std::vector DistributedObjectStore::batch_put_from_ascend( + const std::string key, const std::vector &buffers, + const std::vector &sizes, const ReplicateConfig &config) { + auto internal_results = + batch_put_from_internal_ascend(key, buffers, sizes, config); + std::vector results; + results.reserve(internal_results.size()); + + for (const auto &result : internal_results) { + results.push_back(to_py_ret(result)); + } + + return results; +} + std::vector DistributedObjectStore::batch_get_into( const std::vector &keys, const std::vector &buffers, const std::vector &sizes) { @@ -1014,6 +1011,25 @@ std::vector DistributedObjectStore::batch_get_into( return results; } +std::vector DistributedObjectStore::batch_get_into_ascend( + const std::string key, const std::vector &buffers, + const std::vector &sizes) { + // auto start = std::chrono::high_resolution_clock::now(); + + auto internal_results = batch_get_into_internal_ascend(key, buffers, sizes); + std::vector results; + results.reserve(internal_results.size()); + + for (const auto &result : internal_results) { + results.push_back(to_py_ret(result)); + } + // auto stop = std::chrono::high_resolution_clock::now(); + // auto duration_call = + // std::chrono::duration_cast(stop - start); + // LOG(INFO) << "key: " << key << ", batch_get_into_ascend: " << duration_call.count() << "us"; + return results; +} + tl::expected DistributedObjectStore::put_from_internal( const std::string &key, void *buffer, size_t size, const ReplicateConfig &config) { @@ -1201,6 +1217,152 @@ DistributedObjectStore::batch_get_into_internal( return results; } +std::vector> +DistributedObjectStore::batch_get_into_internal_ascend( + const std::string key, const std::vector &buffers, + const std::vector &sizes) { + // LOG(INFO) << "GET KEY start: " << key; + // Validate preconditions + if (!client_) { + LOG(ERROR) << "Client is not initialized"; + return std::vector>( + 1, tl::unexpected(ErrorCode::INVALID_PARAMS)); + } + + if (buffers.size() != sizes.size()) { + LOG(ERROR) << "Input vector sizes mismatch: keys=" << 1 + << ", buffers=" << buffers.size() + << ", sizes=" << sizes.size(); + return std::vector>( + 1, tl::unexpected(ErrorCode::INVALID_PARAMS)); + } + + const size_t num_keys = 1; + std::vector> results; + results.reserve(num_keys); + + if (num_keys == 0) { + return results; + } + std::vector keys; + keys.reserve(1); + keys.emplace_back(key); + // Query metadata for all keys + const auto query_results = client_->BatchQuery(keys); + + // Process each key individually and prepare for batch transfer + struct ValidKeyInfo { + std::string key; + size_t original_index; + std::vector replica_list; + std::vector slices; + uint64_t total_size; + }; + + std::vector valid_operations; + valid_operations.reserve(num_keys); + + for (size_t i = 0; i < num_keys; ++i) { + const auto &key = keys[i]; + + // Handle query failures + if (!query_results[i]) { + const auto error = query_results[i].error(); + results.emplace_back(tl::unexpected(error)); + if (error != ErrorCode::OBJECT_NOT_FOUND) { + LOG(ERROR) << "Query failed for key '" << key + << "': " << toString(error); + } + continue; + } + + // Validate replica list + auto replica_list = query_results[i].value(); + if (replica_list.empty()) { + LOG(ERROR) << "Empty replica list for key: " << key; + results.emplace_back(tl::unexpected(ErrorCode::INVALID_REPLICA)); + continue; + } + + // Calculate required buffer size + const auto &replica = replica_list[0]; + uint64_t total_size = calculate_total_size(replica); + int total_key_size = 0; + for (size_t k = 0; k < sizes.size(); ++k) { + total_key_size += sizes[k]; + } + // LOG(INFO) << "KEY: '" << key + // << "': required=" << total_size + // << ", available=" << total_key_size; + // Validate buffer capacity + if (total_key_size < total_size) { + LOG(ERROR) << "Buffer too small for key '" << key + << "': required=" << total_size + << ", available=" << total_key_size; + results.emplace_back(tl::unexpected(ErrorCode::INVALID_PARAMS)); + continue; + } + std::vector key_slices; + // Create slices for this key's buffer + for (size_t j = 0; j < buffers.size(); ++j) { + uint64_t offset = 0; + if (replica.is_memory_replica() == false) { + key_slices.emplace_back(Slice{buffers[j], sizes[j]}); + } else { + key_slices.emplace_back(Slice{buffers[j], sizes[j]}); + } + } + + // Store operation info for batch processing + valid_operations.push_back({.key = key, + .original_index = i, + .replica_list = std::move(replica_list), + .slices = std::move(key_slices), + .total_size = total_size}); + + // Set success result (actual bytes transferred) + results.emplace_back(static_cast(total_size)); + } + + // Early return if no valid operations + if (valid_operations.empty()) { + return results; + } + + // Prepare batch transfer data structures + std::vector batch_keys; + std::vector> batch_replica_lists; + std::unordered_map> batch_slices; + + batch_keys.reserve(valid_operations.size()); + batch_replica_lists.reserve(valid_operations.size()); + + for (const auto &op : valid_operations) { + batch_keys.push_back(op.key); + batch_replica_lists.push_back(op.replica_list); + batch_slices[op.key] = op.slices; + } + + // Execute batch transfer + const auto batch_get_results = + client_->BatchGet(batch_keys, batch_replica_lists, batch_slices); + + // Process transfer results + for (size_t j = 0; j < batch_get_results.size(); ++j) { + const auto &op = valid_operations[j]; + + if (!batch_get_results[j]) { + const auto error = batch_get_results[j].error(); + LOG(ERROR) << "BatchGet failed for key '" << op.key + << "': " << toString(error); + results[op.original_index] = tl::unexpected(error); + } + } + // LOG(INFO) << "GET KEY end: " << key << "end"; + + return results; +} + std::vector> DistributedObjectStore::batch_put_from_internal( const std::vector &keys, const std::vector &buffers, @@ -1255,6 +1417,47 @@ DistributedObjectStore::batch_put_from_internal( return client_->BatchPut(keys, ordered_batched_slices, config); } +std::vector> +DistributedObjectStore::batch_put_from_internal_ascend( + const std::string key, const std::vector &buffers, + const std::vector &sizes, const ReplicateConfig &config) { + LOG(INFO) << "PUT KEY start: " << key; + if (!client_) { + LOG(ERROR) << "Client is not initialized"; + return std::vector>( + 1, tl::unexpected(ErrorCode::INVALID_PARAMS)); + } + + if (buffers.size() != sizes.size()) { + LOG(ERROR) << "Mismatched sizes for key, buffers, and sizes"; + return std::vector>( + 1, tl::unexpected(ErrorCode::INVALID_PARAMS)); + } + + // std::unordered_map> all_slices; + + // Create slices from user buffers + std::vector slices; + slices.reserve(buffers.size()); + for (size_t i = 0; i < buffers.size(); ++i) { + void *buffer = buffers[i]; + size_t size = sizes[i]; + slices.emplace_back(Slice{buffer, size}); + } + std::vector> ordered_batched_slices; + ordered_batched_slices.reserve(1); + ordered_batched_slices.emplace_back(slices); + + std::vector keys; + keys.reserve(1); + keys.emplace_back(key); + LOG(ERROR) << "batch put keys size:" << keys.size() << ", ordered_batched_slices size:" << ordered_batched_slices.size() + << ", slice size len:" << slices.size(); + + // Call client BatchPut and return the vector directly + return client_->BatchPut(keys, ordered_batched_slices, config); +} + std::vector> DistributedObjectStore::batchIsExist_internal( const std::vector &keys) { @@ -1778,7 +1981,45 @@ PYBIND11_MODULE(store, m) { }, py::arg("keys"), py::arg("values"), py::arg("config") = ReplicateConfig{}) - .def("get_hostname", &DistributedObjectStore::get_hostname); + .def("get_hostname", &DistributedObjectStore::get_hostname) + .def( + "batch_put_from_ascend", + [](DistributedObjectStore &self, + const std::string key, + const std::vector &buffer_ptrs, + const std::vector &sizes, + const ReplicateConfig &config = ReplicateConfig{}) { + std::vector buffers; + buffers.reserve(buffer_ptrs.size()); + for (uintptr_t ptr : buffer_ptrs) { + buffers.push_back(reinterpret_cast(ptr)); + } + py::gil_scoped_release release; + return self.batch_put_from_ascend(key, buffers, sizes, config); + }, + py::arg("keys"), py::arg("buffer_ptrs"), py::arg("sizes"), + py::arg("config") = ReplicateConfig{}, + "Put object data directly from pre-allocated buffers for " + "multiple " + "keys") + .def( + "batch_get_into_ascend", + [](DistributedObjectStore &self, + const std::string key, + const std::vector &buffer_ptrs, + const std::vector &sizes) { + std::vector buffers; + buffers.reserve(buffer_ptrs.size()); + for (uintptr_t ptr : buffer_ptrs) { + buffers.push_back(reinterpret_cast(ptr)); + } + py::gil_scoped_release release; + return self.batch_get_into_ascend(key, buffers, sizes); + }, + py::arg("keys"), py::arg("buffer_ptrs"), py::arg("sizes"), + "Get object data directly into pre-allocated buffers for " + "multiple " + "keys"); } } // namespace mooncake diff --git a/mooncake-integration/store/store_py.h b/mooncake-integration/store/store_py.h index f5df34e71..2f9a3d837 100644 --- a/mooncake-integration/store/store_py.h +++ b/mooncake-integration/store/store_py.h @@ -30,10 +30,13 @@ int64_t to_py_ret(const tl::expected &exp) noexcept { } if constexpr (std::is_void_v) { + // LOG(WARNING) << "is_void_v"; return 0; } else if constexpr (std::is_integral_v) { + // LOG(WARNING) << "is_integral bane"; return static_cast(exp.value()); } else { + // LOG(WARNING) << "Unsupported payload type in to_py_ret()"; static_assert(!sizeof(T), "Unsupported payload type in to_py_ret()"); } } @@ -137,6 +140,19 @@ class DistributedObjectStore { const std::vector &buffers, const std::vector &sizes); + /** + * @brief Get object data directly into pre-allocated buffers for multiple + * keys(batch version) on Ascend NPU + * @param keys Key of the objects to get + * @param buffers Vector of pointers to the pre-allocated buffers + * @param sizes Vector of sizes of the buffers + * @return Vector of integers, where each element is the number of bytes + * read on success, or a negative value on error + */ + std::vector batch_get_into_ascend(const std::string key, + const std::vector &buffers, + const std::vector &sizes); + /** * @brief Put object data directly from a pre-allocated buffer * @param key Key of the object to put @@ -188,6 +204,23 @@ class DistributedObjectStore { const std::vector &buffers, const std::vector &sizes, const ReplicateConfig &config = ReplicateConfig{}); + /** + * @brief Put object data directly from pre-allocated buffers for multiple + * keys (batch version) on Ascend NPU + * @param keys keys of the objects to put + * @param buffers Vector of pointers to the pre-allocated buffers + * @param sizes Vector of sizes of the buffers + * @param config Replication configuration + * @return Vector of integers, where each element is 0 on success, or a + * negative value on error + * @note The buffer addresses must be previously registered with + * register_buffer() for zero-copy operations + */ + std::vector batch_put_from_ascend( + const std::string keys, + const std::vector &buffers, const std::vector &sizes, + const ReplicateConfig &config = ReplicateConfig{}); + int put_parts(const std::string &key, std::vector> values, const ReplicateConfig &config = ReplicateConfig{}); @@ -293,6 +326,10 @@ class DistributedObjectStore { const std::vector &keys, const std::vector &buffers, const std::vector &sizes); + std::vector> batch_get_into_internal_ascend( + const std::string key, + const std::vector &buffers, const std::vector &sizes); + tl::expected put_from_internal( const std::string &key, void *buffer, size_t size, const ReplicateConfig &config = ReplicateConfig{}); @@ -302,6 +339,11 @@ class DistributedObjectStore { const std::vector &buffers, const std::vector &sizes, const ReplicateConfig &config = ReplicateConfig{}); + std::vector> batch_put_from_internal_ascend( + const std::string key, + const std::vector &buffers, const std::vector &sizes, + const ReplicateConfig &config = ReplicateConfig{}); + tl::expected put_parts_internal( const std::string &key, std::vector> values, const ReplicateConfig &config = ReplicateConfig{}); diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp index b99a24cfd..dcff5103c 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp +++ b/mooncake-integration/transfer_engine/transfer_engine_py.cpp @@ -116,7 +116,8 @@ int TransferEnginePy::initializeExt(const char *local_hostname, auto device_name_safe = device_name ? std::string(device_name) : ""; auto device_filter = buildDeviceFilter(device_name_safe); - engine_ = std::make_unique(true, device_filter); + engine_ = std::make_shared(true, device_filter); + LOG(INFO) << "TransferEnginePy InitTransferEngine"; if (getenv("MC_LEGACY_RPC_PORT_BINDING")) { auto hostname_port = parseHostNameWithPort(local_hostname); int ret = @@ -129,6 +130,8 @@ int TransferEnginePy::initializeExt(const char *local_hostname, if (ret) return -1; } + g_transfer_engine = engine_; + LOG(INFO) << "TransferEnginePy InitTransferEngine end: " << g_transfer_engine; free_list_.resize(kSlabSizeKBTabLen); #ifndef USE_ASCEND doBuddyAllocate(kMaxClassId); @@ -291,6 +294,7 @@ int TransferEnginePy::transferSync(const char *target_hostname, entry.target_id = handle; entry.target_offset = peer_buffer_address; entry.advise_retry_cnt = retry; + entry.target_offset_type = 1; Status s = engine_->submitTransfer(batch_id, {entry}); if (!s.ok()) return -1; @@ -367,6 +371,7 @@ int TransferEnginePy::batchTransferSync( entry.target_id = handle; entry.target_offset = peer_buffer_addresses[i]; entry.advise_retry_cnt = 0; + entry.target_offset_type = 1; entries.push_back(entry); } @@ -454,6 +459,9 @@ batch_id_t TransferEnginePy::batchTransferAsync( entry.target_id = handle; entry.target_offset = peer_buffer_addresses[i]; entry.advise_retry_cnt = 0; +#ifndef USE_ASCEND + entry.target_offset_type = 1; +#endif entries.push_back(entry); } diff --git a/mooncake-store/include/client.h b/mooncake-store/include/client.h index 7cf8fb482..3f1f51213 100644 --- a/mooncake-store/include/client.h +++ b/mooncake-store/include/client.h @@ -256,7 +256,7 @@ class Client { const std::vector& ops); // Core components - TransferEngine transfer_engine_; + std::shared_ptr transfer_engine_; MasterClient master_client_; std::unique_ptr transfer_submitter_; diff --git a/mooncake-store/include/transfer_task.h b/mooncake-store/include/transfer_task.h index c88c8863e..51305df10 100644 --- a/mooncake-store/include/transfer_task.h +++ b/mooncake-store/include/transfer_task.h @@ -239,9 +239,9 @@ struct MemcpyOperation { void* dest; const void* src; size_t size; - - MemcpyOperation(void* d, const void* s, size_t sz) - : dest(d), src(s), size(sz) {} + bool is_write; + MemcpyOperation(void* d, const void* s, size_t sz, bool w) + : dest(d), src(s), size(sz), is_write(w) {} }; /** @@ -280,7 +280,7 @@ class MemcpyWorkerPool { void submitTask(MemcpyTask task); private: - void workerThread(); + void workerThread(int deviceLogicId); std::vector workers_; std::queue task_queue_; diff --git a/mooncake-store/src/client.cpp b/mooncake-store/src/client.cpp index 33a8d799c..0110dcf4c 100644 --- a/mooncake-store/src/client.cpp +++ b/mooncake-store/src/client.cpp @@ -190,34 +190,57 @@ ErrorCode Client::ConnectToMaster(const std::string& master_server_entry) { } } +std::vector buildDeviceFilter(const std::string &device_names) { + std::stringstream ss(device_names); + std::string item; + std::vector tokens; + while (getline(ss, item, ',')) { + tokens.push_back(item); + } + return tokens; +} + ErrorCode Client::InitTransferEngine(const std::string& local_hostname, const std::string& metadata_connstring, const std::string& protocol, void** protocol_args) { // get auto_discover and filters from env bool auto_discover = get_auto_discover(); - transfer_engine_.setAutoDiscover(auto_discover); - transfer_engine_.setWhitelistFilters( - get_auto_discover_filters(auto_discover)); - auto [hostname, port] = parseHostNameWithPort(local_hostname); - int rc = transfer_engine_.init(metadata_connstring, local_hostname, - hostname, port); - CHECK_EQ(rc, 0) << "Failed to initialize transfer engine"; + LOG(INFO) << "Pooling InitTransferEngine:" << g_transfer_engine; + if (g_transfer_engine) { + transfer_engine_= g_transfer_engine; + LOG(INFO) << "Pooling multiplexing transferEngine"; + } else { + auto [hostname, port] = parseHostNameWithPort(local_hostname); + std::string device_name_safe = ""; + auto device_filter = buildDeviceFilter(device_name_safe); + transfer_engine_ = std::make_shared(true, device_filter); + int rc = transfer_engine_->init(metadata_connstring, local_hostname, + hostname, port); + CHECK_EQ(rc, 0) << "Failed to initialize transfer engine"; + g_transfer_engine = transfer_engine_; + } + transfer_engine_->setAutoDiscover(auto_discover); + transfer_engine_->setWhitelistFilters( + get_auto_discover_filters(auto_discover)); Transport* transport = nullptr; if (protocol == "rdma") { LOG(INFO) << "transport_type=rdma"; - transport = transfer_engine_.installTransport("rdma", protocol_args); + transport = transfer_engine_->installTransport("rdma", protocol_args); } else if (protocol == "tcp") { LOG(INFO) << "transport_type=tcp"; try { - transport = transfer_engine_.installTransport("tcp", protocol_args); + transport = transfer_engine_->installTransport("tcp", protocol_args); } catch (std::exception& e) { LOG(ERROR) << "tcp_transport_install_failed error_message=\"" << e.what() << "\""; return ErrorCode::INTERNAL_ERROR; } + } else if (protocol == "ascend") { + LOG(INFO) << "transport protocol=" << protocol; + transport = transfer_engine_->installTransport("ascend", protocol_args); } else { LOG(ERROR) << "unsupported_protocol protocol=" << protocol; return ErrorCode::INVALID_PARAMS; @@ -228,9 +251,10 @@ ErrorCode Client::InitTransferEngine(const std::string& local_hostname, } // Initialize TransferSubmitter after transfer engine is ready + // transfer_submitter_ = std::make_unique( + // transfer_engine_, local_hostname, storage_backend_); transfer_submitter_ = std::make_unique( - transfer_engine_, local_hostname, storage_backend_); - + *transfer_engine_, local_hostname, storage_backend_); return ErrorCode::OK; } @@ -244,8 +268,16 @@ std::optional> Client::Create( ? std::getenv("MOONCAKE_STORAGE_ROOT_DIR") : ""; - auto client = std::shared_ptr( - new Client(local_hostname, metadata_connstring, storage_root_dir)); + std::string local_name = local_hostname; + if (g_transfer_engine) { + local_name = g_transfer_engine->local_server_name_; + LOG(INFO) << "Pooling multiplexing local_name:" << local_name; + g_separate_pool = false; + } else { + g_separate_pool = true; + } + LOG(INFO) << "master_server_entry:" << master_server_entry; + auto client = std::shared_ptr(new Client(local_name, metadata_connstring, storage_root_dir)); ErrorCode err = client->ConnectToMaster(master_server_entry); if (err != ErrorCode::OK) { @@ -265,13 +297,14 @@ std::optional> Client::Create( // Initialize storage backend client->PrepareStorageBackend(storage_root_dir, response.value()); } - - // Initialize transfer engine - err = client->InitTransferEngine(local_hostname, metadata_connstring, - protocol, protocol_args); - if (err != ErrorCode::OK) { - LOG(ERROR) << "Failed to initialize transfer engine"; - return std::nullopt; + if (protocol != "ascend_no_transport") { + // Initialize transfer engine + err = client->InitTransferEngine(local_name, metadata_connstring, + protocol, protocol_args); + if (err != ErrorCode::OK) { + LOG(ERROR) << "Failed to initialize transfer engine"; + return std::nullopt; + } } return client; @@ -470,8 +503,8 @@ std::vector> Client::BatchGet( continue; } - VLOG(1) << "Submitted transfer for key " << key - << " using strategy: " << static_cast(future->strategy()); + // LOG(INFO) << "Submitted transfer for key " << key + // << " using strategy: " << static_cast(future->strategy()); pending_transfers.emplace_back(i, key, std::move(*future)); } @@ -484,12 +517,12 @@ std::vector> Client::BatchGet( << " with error: " << static_cast(result); results[index] = tl::unexpected(result); } else { - VLOG(1) << "Transfer completed successfully for key: " << key; + // LOG(INFO) << "Transfer completed successfully for key: " << key; results[index] = {}; } } - VLOG(1) << "BatchGet completed for " << object_keys.size() << " keys"; + // LOG(INFO) << "BatchGet completed for " << object_keys.size() << " keys"; return results; } @@ -613,6 +646,7 @@ std::vector Client::CreatePutOperations( ops.reserve(keys.size()); for (size_t i = 0; i < keys.size(); ++i) { ops.emplace_back(keys[i], batched_slices[i]); + // LOG(ERROR) << "batched_slices size: " << batched_slices[i].size(); } return ops; } @@ -624,7 +658,7 @@ void Client::StartBatchPut(std::vector& ops, keys.reserve(ops.size()); slice_lengths.reserve(ops.size()); - + // LOG(ERROR) << "ops size: " << ops.size(); for (const auto& op : ops) { keys.emplace_back(op.key); @@ -633,9 +667,10 @@ void Client::StartBatchPut(std::vector& ops, for (const auto& slice : op.slices) { slice_sizes.emplace_back(slice.size); } + // LOG(ERROR) << "slicen size: " << slice_sizes.size(); slice_lengths.emplace_back(std::move(slice_sizes)); } - + // LOG(ERROR) << "slice_lengths size: " << slice_lengths.size(); auto start_responses = master_client_.BatchPutStart(keys, slice_lengths, config); @@ -978,9 +1013,10 @@ tl::expected Client::MountSegment(const void* buffer, return tl::unexpected(ErrorCode::INVALID_PARAMS); } } + std::string kWildcardLocation_pool = "cpu"; + int rc = transfer_engine_->registerLocalMemory( + (void*)buffer, size, kWildcardLocation_pool, true, true); - int rc = transfer_engine_.registerLocalMemory( - (void*)buffer, size, kWildcardLocation, true, true); if (rc != 0) { LOG(ERROR) << "register_local_memory_failed base=" << buffer << " size=" << size << ", error=" << rc; @@ -1029,7 +1065,7 @@ tl::expected Client::UnmountSegment(const void* buffer, return tl::unexpected(err); } - int rc = transfer_engine_.unregisterLocalMemory( + int rc = transfer_engine_->unregisterLocalMemory( reinterpret_cast(segment->second.base)); if (rc != 0) { LOG(ERROR) << "Failed to unregister transfer buffer with transfer " @@ -1053,7 +1089,7 @@ tl::expected Client::RegisterLocalMemory( if (!check_result) { return tl::unexpected(check_result.error()); } - if (this->transfer_engine_.registerLocalMemory( + if (this->transfer_engine_->registerLocalMemory( addr, length, location, remote_accessible, update_metadata) != 0) { return tl::unexpected(ErrorCode::INVALID_PARAMS); } @@ -1062,7 +1098,7 @@ tl::expected Client::RegisterLocalMemory( tl::expected Client::unregisterLocalMemory( void* addr, bool update_metadata) { - if (this->transfer_engine_.unregisterLocalMemory(addr, update_metadata) != + if (this->transfer_engine_->unregisterLocalMemory(addr, update_metadata) != 0) { return tl::unexpected(ErrorCode::INVALID_PARAMS); } diff --git a/mooncake-store/src/rpc_service.cpp b/mooncake-store/src/rpc_service.cpp index a65a8ded8..8c6e928ba 100644 --- a/mooncake-store/src/rpc_service.cpp +++ b/mooncake-store/src/rpc_service.cpp @@ -320,9 +320,18 @@ WrappedMasterService::BatchPutStart( results; results.reserve(keys.size()); + uint32_t all_slice_len = 0; + std::vector slice_len; for (size_t i = 0; i < keys.size(); ++i) { + slice_len.reserve(keys.size()); + all_slice_len = 0; + for (size_t j = 0; j < slice_lengths[i].size(); ++j) { + all_slice_len += slice_lengths[i][j]; + } + slice_len.emplace_back(all_slice_len); + // LOG(ERROR) << "master_server put start, len:" << slice_lengths[i].size(); results.emplace_back( - master_service_.PutStart(keys[i], slice_lengths[i], config)); + master_service_.PutStart(keys[i], slice_len, config)); } size_t failure_count = 0; diff --git a/mooncake-store/src/transfer_task.cpp b/mooncake-store/src/transfer_task.cpp index 4ce4dbbee..fda9863f2 100644 --- a/mooncake-store/src/transfer_task.cpp +++ b/mooncake-store/src/transfer_task.cpp @@ -6,6 +6,7 @@ #include #include "utils.h" +#include "acl/acl.h" namespace mooncake { @@ -14,7 +15,7 @@ namespace mooncake { // ============================================================================ // to fully utilize the available ssd bandwidth, we use a default of 10 worker // threads. -constexpr int kDefaultFilereadWorkers = 10; +constexpr int kDefaultFilereadWorkers = 1; FilereadWorkerPool::FilereadWorkerPool(std::shared_ptr& backend) : shutdown_(false) { @@ -125,11 +126,16 @@ constexpr int kDefaultMemcpyWorkers = 1; MemcpyWorkerPool::MemcpyWorkerPool() : shutdown_(false) { VLOG(1) << "Creating MemcpyWorkerPool with " << kDefaultMemcpyWorkers << " workers"; - + int deviceLogicId; + int ret = aclrtGetDevice(&deviceLogicId); + if (ret) { + LOG(ERROR) << "MemcpyWorkerPool: aclrtGetDevice failed, ret: " << ret; + return; + } // Start worker threads workers_.reserve(kDefaultMemcpyWorkers); for (int i = 0; i < kDefaultMemcpyWorkers; ++i) { - workers_.emplace_back(&MemcpyWorkerPool::workerThread, this); + workers_.emplace_back(&MemcpyWorkerPool::workerThread, this, deviceLogicId); } } @@ -165,8 +171,17 @@ void MemcpyWorkerPool::submitTask(MemcpyTask task) { queue_cv_.notify_one(); } -void MemcpyWorkerPool::workerThread() { +void MemcpyWorkerPool::workerThread(int deviceLogicId) { VLOG(2) << "MemcpyWorkerPool worker thread started"; + int ret = aclrtSetDevice(deviceLogicId); + if (ret) { + LOG(ERROR) << "HcclTransport: aclrtSetDevice failed ret:" << ret; + } + aclrtStream stream; + ret = aclrtCreateStream(&stream); + if (ret) { + LOG(ERROR) << "HcclTransport: aclrtCreateStream error, ret:" << ret; + } while (true) { MemcpyTask task({}, nullptr); @@ -192,7 +207,19 @@ void MemcpyWorkerPool::workerThread() { if (task.state) { try { for (const auto& op : task.operations) { - std::memcpy(op.dest, op.src, op.size); + if (op.is_write) { + ret = aclrtMemcpy(op.dest, op.size, op.src, op.size, ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to copy data from device to host, ret: " << ret << ", local:" << op.src << ", dest:" << op.dest << ", size:" << op.size; + } + LOG(INFO) << "pool write own ret:" << ret << "op.dest:" << op.dest << ", op.src:" << op.src << ", len:" << op.size; + } else { + ret = aclrtMemcpy(op.dest, op.size, op.src, op.size, ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to copy data from host to device, ret: " << ret << ", local:" << op.src << ", dest:" << op.dest << ", size:" << op.size; + } + LOG(INFO) << "pool read own ret:" << ret << "op.dest:" << op.dest << ", op.src:" << op.src << ", len:" << op.size; + } } VLOG(2) << "Memcpy task completed successfully with " @@ -441,14 +468,14 @@ std::optional TransferSubmitter::submitMemcpyOperation( // buffer) dest = slice.ptr; src = reinterpret_cast(handle.buffer_address_); + operations.emplace_back(dest, src, handle.size_, false); } else { // WRITE: from slice (local buffer) to handle (remote // buffer) dest = reinterpret_cast(handle.buffer_address_); src = slice.ptr; + operations.emplace_back(dest, src, handle.size_, true); } - - operations.emplace_back(dest, src, handle.size_); } // Submit memcpy operations to worker pool for async execution @@ -466,27 +493,33 @@ std::optional TransferSubmitter::submitTransferEngineOperation( std::vector& slices, Transport::TransferRequest::OpCode op_code) { // Create transfer requests std::vector requests; - requests.reserve(handles.size()); - - for (size_t i = 0; i < handles.size(); ++i) { - const auto& handle = handles[i]; + requests.reserve(slices.size()); + uint64_t offset = 0; + const auto& handle = handles[0]; + for (size_t i = 0; i < slices.size(); ++i) { const auto& slice = slices[i]; + int device_id = -1; + auto[host_name, port] = parseHostNameWithPortAscend(handle.segment_name_, &device_id); + std::string segment_id = host_name + ":" + std::to_string(port); Transport::SegmentHandle seg = - engine_.openSegment(handle.segment_name_); + engine_.openSegment(segment_id); if (seg == static_cast(ERR_INVALID_ARGUMENT)) { LOG(ERROR) << "Failed to open segment " << handle.segment_name_; return std::nullopt; } - + // LOG(INFO) << "Transfer Engine, parseHostNameWithPortAscend, Server:" << host_name << ", port" + // << port << ", device_id" << device_id << ", source:" << slice.ptr << ", target_id:" << seg + // << ", target_offset" << handle.buffer_address_ << ", length:" << slice.size; Transport::TransferRequest request; request.opcode = op_code; request.source = static_cast(slice.ptr); request.target_id = seg; - request.target_offset = handle.buffer_address_; - request.length = handle.size_; + request.target_offset = handle.buffer_address_ + offset; + request.length = slice.size; requests.emplace_back(request); + offset += slice.size; } // Allocate batch ID @@ -573,15 +606,13 @@ bool TransferSubmitter::validateTransferParams( << " slices_size=" << slices.size(); return false; } - - for (size_t i = 0; i < handles.size(); ++i) { - if (handles[i].size_ != slices[i].size) { - LOG(ERROR) << "Size of replica partition " << i << " (" - << handles[i].size_ - << ") does not match provided buffer (" << slices[i].size - << ")"; - return false; - } + uint64_t all_slice_len = 0; + for (size_t i = 0; i < slices.size(); ++i) { + all_slice_len += slices[i].size; + } + if (handles[0].size_ != all_slice_len) { + LOG(ERROR) << "handles len:" << handles[0].size_ << ", all_slice_len:" << all_slice_len; + return false; } return true; diff --git a/mooncake-store/tests/transfer_task_test.cpp b/mooncake-store/tests/transfer_task_test.cpp index caa7c7ad7..0b21a8a61 100644 --- a/mooncake-store/tests/transfer_task_test.cpp +++ b/mooncake-store/tests/transfer_task_test.cpp @@ -38,7 +38,7 @@ TEST_F(TransferTaskTest, MemcpyOperationBasic) { std::vector dest_data(data_size, 'B'); // Create memcpy operation - MemcpyOperation op(dest_data.data(), src_data.data(), data_size); + MemcpyOperation op(dest_data.data(), src_data.data(), data_size, false); // Verify operation parameters EXPECT_EQ(op.dest, dest_data.data()); @@ -81,7 +81,7 @@ TEST_F(TransferTaskTest, MemcpyWorkerPoolBasic) { // Create memcpy operations std::vector operations; - operations.emplace_back(dest_data.data(), src_data.data(), data_size); + operations.emplace_back(dest_data.data(), src_data.data(), data_size, false); // Create and submit task MemcpyTask task(std::move(operations), state); @@ -122,7 +122,7 @@ TEST_F(TransferTaskTest, MemcpyWorkerPoolMultipleOperations) { std::vector operations; for (size_t i = 0; i < num_ops; ++i) { operations.emplace_back(dest_buffers[i].data(), src_buffers[i].data(), - data_size); + data_size, false); } // Create and submit task diff --git a/mooncake-transfer-engine/example/CMakeLists.txt b/mooncake-transfer-engine/example/CMakeLists.txt index 123e60722..1533d5d17 100644 --- a/mooncake-transfer-engine/example/CMakeLists.txt +++ b/mooncake-transfer-engine/example/CMakeLists.txt @@ -1,8 +1,3 @@ -file(GLOB ASCEND_TOOLKIT_ROOT "/usr/local/Ascend/ascend-toolkit/latest/*-linux") -set(ASCEND_LIB_DIR "${ASCEND_TOOLKIT_ROOT}/lib64") - -link_directories(${ASCEND_LIB_DIR}) - add_executable(transfer_engine_bench transfer_engine_bench.cpp) target_link_libraries(transfer_engine_bench PUBLIC transfer_engine) @@ -16,9 +11,15 @@ add_executable(memory_pool memory_pool.cpp) target_link_libraries(memory_pool PUBLIC transfer_engine) if (USE_ASCEND) + file(GLOB ASCEND_TOOLKIT_ROOT "/usr/local/Ascend/ascend-toolkit/latest/*-linux") + set(ASCEND_LIB_DIR "${ASCEND_TOOLKIT_ROOT}/lib64") + link_directories(${ASCEND_LIB_DIR}) add_executable(transfer_engine_ascend_one_sided transfer_engine_ascend_one_sided.cpp) target_link_libraries(transfer_engine_ascend_one_sided PUBLIC transfer_engine) add_executable(transfer_engine_ascend_perf transfer_engine_ascend_perf.cpp) target_link_libraries(transfer_engine_ascend_perf PUBLIC transfer_engine) + + add_executable(transfer_engine_ascend_memcpy transfer_engine_ascend_memcpy.cpp) + target_link_libraries(transfer_engine_ascend_memcpy PUBLIC transfer_engine) endif() \ No newline at end of file diff --git a/mooncake-transfer-engine/example/transfer_engine_ascend_memcpy.cpp b/mooncake-transfer-engine/example/transfer_engine_ascend_memcpy.cpp new file mode 100644 index 000000000..4b30875c0 --- /dev/null +++ b/mooncake-transfer-engine/example/transfer_engine_ascend_memcpy.cpp @@ -0,0 +1,233 @@ +#include +#include +#include +#include +#include "common.h" +#include "transfer_engine.h" +#include "transport/transport.h" +#include "acl/acl.h" +#include "hccl.h" + +DEFINE_int32(batch_size, 64, "Batch size"); +DEFINE_uint64(block_size, 131072, "Block size for each transfer request"); +DEFINE_uint64(device_id, 0, "The device logic and phy ID of this machine"); +DEFINE_string(mode, "dtod", "device to device or device to host"); + +using namespace mooncake; + +std::string toLowerCase(const std::string& str) { + std::string lowerStr = str; + std::transform(lowerStr.begin(), lowerStr.end(), lowerStr.begin(), + [](unsigned char c) { return std::tolower(c); }); + return lowerStr; +} + +int aclDeviceToHost() { + void* devAddr; + void* hostAddr; + uint64_t totalSize = FLAGS_block_size * FLAGS_batch_size; + aclError ret = aclrtMalloc(&devAddr, totalSize, ACL_MEM_MALLOC_NORMAL_ONLY); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to allocate device memory, ret:" << ret; + return ret; + } + + ret = aclrtMallocHost(&hostAddr, totalSize); + if (ret != ACL_ERROR_NONE || hostAddr == nullptr) { + LOG(ERROR) << "Failed to allocate host memory, ret:" << ret; + return ret; + } + + for (size_t i = 0; i < totalSize; i += sizeof(uint32_t)) { + *(uint32_t*)((char *)hostAddr + i) = 0x12345678; + } + + ret = aclrtMemcpy(devAddr, totalSize, hostAddr, totalSize, ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to copy data from host to device, ret: " << ret; + return ret; + } + + constexpr std::uint64_t kBatch = 10000; + std::uint64_t cnt = 0; + std::chrono::microseconds total_us{0}; + + while (1) { + auto start = std::chrono::high_resolution_clock::now(); + + ret = aclrtMemcpy(devAddr, totalSize, hostAddr, totalSize, ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to copy data from host to device, ret: " << ret; + return ret; + } + + ret = aclrtMemcpy(hostAddr, totalSize, devAddr, totalSize, ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to copy data from device to host, ret: " << ret; + return ret; + } + + auto stop = std::chrono::high_resolution_clock::now(); + total_us += std::chrono::duration_cast(stop - start); + cnt++; + + if (cnt >= kBatch) { + double avg_us = static_cast(total_us.count()) / kBatch; + LOG(INFO) << ", total_size: " << totalSize * 2 / 1024 << "KB"; + LOG(INFO) << "avg_memcpy: " << avg_us << " us, avg_bw: " + << (static_cast(totalSize) / avg_us / 1e3) << " GB/s"; + cnt = 0; + total_us = std::chrono::microseconds{0}; + } + + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + + ret = aclrtFree(devAddr); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to aclrtFree, ret: " << ret; + return ret; + } + + ret = aclrtFreeHost(hostAddr); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to aclrtFreeHost, ret: " << ret; + return ret; + } + + return 0; +} + +int aclDeviceToDevice() { + void* hugeBlock; // 单个大块内存 + std::vector smallBlocks; // 多个小块内存 + uint64_t block_size = FLAGS_block_size; + uint64_t block_num = FLAGS_batch_size; + uint64_t hugeSize = block_size * block_num; + aclrtStream stream; + + aclError ret = aclrtCreateStream(&stream); + if (ret) { + LOG(ERROR) << "HcclTransport: aclrtCreateStream error, ret: " << ret; + } + + smallBlocks.resize(block_num); + ret = aclrtMalloc(&hugeBlock, hugeSize, ACL_MEM_MALLOC_NORMAL_ONLY); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to allocate device memory, ret:" << ret; + return ret; + } + + for (uint64_t i = 0; i < block_num; ++i) { + ret = aclrtMalloc(&smallBlocks[i], block_size, ACL_MEM_MALLOC_NORMAL_ONLY); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to allocate device memory, ret:" << ret; + return ret; + } + } + + void* tmpAddr = hugeBlock; + for (auto addr : smallBlocks) { + ret = aclrtMemcpyAsync(tmpAddr, block_size, addr, block_size, ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to copy data from device to device, ret: " << ret; + return ret; + } + tmpAddr = static_cast(tmpAddr) + block_size; + } + + constexpr std::uint64_t kBatch = 10000; + std::uint64_t cnt = 0; + std::chrono::microseconds total_us{0}; + + while (1) { + auto start = std::chrono::high_resolution_clock::now(); + tmpAddr = hugeBlock; + for (auto addr : smallBlocks) { + ret = aclrtMemcpyAsync(tmpAddr, block_size, addr, block_size, ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to copy data from device to device, ret: " << ret; + return ret; + } + tmpAddr = static_cast(tmpAddr) + block_size; + } + + ret = aclrtSynchronizeStream(stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to aclrtSynchronizeStream, ret: " << ret; + return ret; + } + + auto stop = std::chrono::high_resolution_clock::now(); + total_us += std::chrono::duration_cast(stop - start); + cnt++; + + if (cnt >= kBatch) { + double avg_us = static_cast(total_us.count()) / kBatch; + LOG(INFO) << " block_size: " << block_size / 1024 + << "KB, block_num: " << block_num + << ", total_size: " << hugeSize / 1024 << "KB"; + LOG(INFO) << "avg_memcpy: " << avg_us << " us, avg_bw: " + << (static_cast(hugeSize) / avg_us / 1e3) << " GB/s"; + cnt = 0; + total_us = std::chrono::microseconds{0}; + } + + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + + for (auto addr : smallBlocks) { + ret = aclrtFree(addr); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to aclrtFree, ret: " << ret; + return ret; + } + } + + ret = aclrtFree(hugeBlock); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to aclrtFree, ret: " << ret; + return ret; + } + + return 0; +} + + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + + int deviceLogicid = FLAGS_device_id; + const char *aclConfigPath = nullptr; + aclError ret = aclInit(aclConfigPath); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to initialize ACL, ret: " << ret; + return -1; + } + + ret = aclrtSetDevice(deviceLogicid); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to set device, ret: " << ret; + return -1; + } + + aclrtContext context = nullptr; + ret = aclrtCreateContext(&context, deviceLogicid); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to create context, ret: " << ret; + return ret; + } + + std::string mode = toLowerCase(FLAGS_mode); + + if (mode == "dtoh" || mode == "htod") { + LOG(INFO) << "Detected mode: " << mode; + return aclDeviceToHost(); + } else if (mode == "dtod") { + LOG(INFO) << "Detected mode: " << mode; + return aclDeviceToDevice(); + } + + LOG(ERROR) << "Unsupported mode: must be 'dtod' or 'dtoh' or 'htod'"; + exit(EXIT_FAILURE); +} \ No newline at end of file diff --git a/mooncake-transfer-engine/example/transfer_engine_ascend_one_sided.cpp b/mooncake-transfer-engine/example/transfer_engine_ascend_one_sided.cpp index 06bed24b3..d82eec282 100644 --- a/mooncake-transfer-engine/example/transfer_engine_ascend_one_sided.cpp +++ b/mooncake-transfer-engine/example/transfer_engine_ascend_one_sided.cpp @@ -60,7 +60,7 @@ using namespace mooncake; int g_deviceLogicId = 0; int g_devicePhyId = 0; -uint64_t g_TotalSize = 0; +uint64_t g_totalSize = 0; const static std::unordered_map RATE_UNIT_MP = { {"GB", 1000ull * 1000ull * 1000ull}, @@ -148,7 +148,7 @@ int initiator() { hostname_port.first.c_str(), hostname_port.second); void *devAddr = nullptr; - ret = allocateDevMem(devAddr, FLAGS_block_size * FLAGS_batch_size); + ret = allocateDevMem(devAddr, g_totalSize); if (ret) { LOG(ERROR) << "Failed to allocateDevMem, ret: " << ret; return ret; @@ -156,7 +156,7 @@ int initiator() { LOG(INFO) << "devAddr_initiator: " << devAddr; - ret = engine->registerLocalMemory(devAddr, g_TotalSize, + ret = engine->registerLocalMemory(devAddr, g_totalSize, "npu:" + std::to_string(g_devicePhyId)); if (ret) { LOG(ERROR) << "Failed to registerLocalMemory, ret: " << ret; @@ -164,7 +164,7 @@ int initiator() { } void *devAddr2 = nullptr; - ret = allocateDevMem(devAddr2, FLAGS_block_size * FLAGS_batch_size); + ret = allocateDevMem(devAddr2, g_totalSize); if (ret) { LOG(ERROR) << "Failed to allocateDevMem, ret: " << ret; return ret; @@ -172,7 +172,7 @@ int initiator() { LOG(INFO) << "devAddr_initiator2: " << devAddr2; - ret = engine->registerLocalMemory(devAddr2, g_TotalSize, + ret = engine->registerLocalMemory(devAddr2, g_totalSize, "npu:" + std::to_string(g_devicePhyId)); if (ret) { LOG(ERROR) << "Failed to registerLocalMemory, ret: " << ret; @@ -208,7 +208,7 @@ int initiator() { entry.source = (uint8_t *)(devAddr) + FLAGS_block_size * i; entry.target_id = segment_id; entry.target_offset = remote_base + FLAGS_block_size * i + - g_TotalSize * FLAGS_initiator_id; + g_totalSize * FLAGS_initiator_id; requests.emplace_back(entry); } @@ -248,7 +248,7 @@ int initiator() { entry.source = (uint8_t *)(devAddr2) + FLAGS_block_size * i; entry.target_id = segment_id; entry.target_offset = remote_base2 + FLAGS_block_size * i + - g_TotalSize * FLAGS_initiator_id; + g_totalSize * FLAGS_initiator_id; requests2.emplace_back(entry); } completed = false; @@ -278,8 +278,8 @@ int initiator() { (stop_tv.tv_usec - start_tv.tv_usec); LOG(INFO) << "Test completed: duration " << duration << "us, batch count " - << FLAGS_batch_size * FLAGS_block_size << ", throughput " - << calculateRate(FLAGS_batch_size * FLAGS_block_size, duration); + << g_totalSize << ", throughput " + << calculateRate(g_totalSize, duration); // When testing 1-to-2 transmission (1 initiator to 2 targets), fill in the // segment_id of the second receiver. If not filled, it defaults to "NA" and @@ -316,7 +316,7 @@ int initiator() { entry.source = (uint8_t *)(devAddr) + FLAGS_block_size * i; entry.target_id = segment_id_2; entry.target_offset = remote_base_desc_2 + FLAGS_block_size * i + - g_TotalSize * FLAGS_initiator_id; + g_totalSize * FLAGS_initiator_id; requests.emplace_back(entry); } @@ -370,7 +370,7 @@ int target() { hostname_port.first.c_str(), hostname_port.second); void *devAddr = nullptr; - ret = allocateDevMem(devAddr, FLAGS_block_size * FLAGS_batch_size); + ret = allocateDevMem(devAddr, g_totalSize); if (ret) { LOG(ERROR) << "Failed to allocateDevMem, ret: " << ret; return ret; @@ -379,7 +379,7 @@ int target() { LOG(INFO) << "devAddr_target: " << devAddr; ret = engine->registerLocalMemory(devAddr, - g_TotalSize * FLAGS_target_recv_count, + g_totalSize * FLAGS_target_recv_count, "npu:" + std::to_string(g_devicePhyId)); if (ret) { LOG(ERROR) << "Failed to registerLocalMemory, ret: " << ret; @@ -387,7 +387,7 @@ int target() { } void *devAddr2 = nullptr; - ret = allocateDevMem(devAddr2, FLAGS_block_size * FLAGS_batch_size); + ret = allocateDevMem(devAddr2, g_totalSize); if (ret) { LOG(ERROR) << "Failed to allocateDevMem, ret: " << ret; return ret; @@ -396,7 +396,7 @@ int target() { LOG(INFO) << "devAddr_target_2: " << devAddr2; ret = engine->registerLocalMemory(devAddr2, - g_TotalSize * FLAGS_target_recv_count, + g_totalSize * FLAGS_target_recv_count, "npu:" + std::to_string(g_devicePhyId)); if (ret) { LOG(ERROR) << "Failed to registerLocalMemory, ret: " << ret; @@ -414,7 +414,7 @@ int target() { int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, false); - g_TotalSize = (uint64_t)(FLAGS_batch_size * FLAGS_block_size); + g_totalSize = (uint64_t)(FLAGS_batch_size * FLAGS_block_size); if (FLAGS_device_id != 65536) { g_deviceLogicId = FLAGS_device_id; diff --git a/mooncake-transfer-engine/example/transfer_engine_ascend_perf.cpp b/mooncake-transfer-engine/example/transfer_engine_ascend_perf.cpp index 6a5637f2b..182ff170d 100644 --- a/mooncake-transfer-engine/example/transfer_engine_ascend_perf.cpp +++ b/mooncake-transfer-engine/example/transfer_engine_ascend_perf.cpp @@ -95,7 +95,7 @@ int allocateDevMem(void *&devAddr, size_t size) { void *host_addr = nullptr; ret = aclrtMallocHost(&host_addr, size); if (ret != ACL_ERROR_NONE || host_addr == nullptr) { - LOG(ERROR) << "Failed to allocate device memory, ret:" << ret; + LOG(ERROR) << "Failed to allocate host memory, ret:" << ret; return ret; } @@ -142,15 +142,15 @@ int initiator() { // Warm-up transmission void *tmp_devAddr = NULL; - ret = allocateDevMem(tmp_devAddr, FLAGS_block_size); + ret = allocateDevMem(tmp_devAddr, FLAGS_batch_size * FLAGS_block_size); if (ret) { LOG(ERROR) << "Failed to allocateDevMem, ret: " << ret; return ret; } LOG(INFO) << "tmp_devAddr_target: " << tmp_devAddr - << ", len: " << FLAGS_block_size; - ret = engine->registerLocalMemory(tmp_devAddr, FLAGS_block_size, + << ", len: " << FLAGS_batch_size * FLAGS_block_size; + ret = engine->registerLocalMemory(tmp_devAddr, FLAGS_batch_size * FLAGS_block_size, "npu:" + std::to_string(g_devicePhyId)); void *devAddr = NULL; @@ -200,7 +200,7 @@ int initiator() { std::vector tmp_requests; TransferRequest entry; entry.opcode = opcode; - entry.length = FLAGS_block_size; + entry.length = FLAGS_batch_size * FLAGS_block_size; entry.source = (uint8_t *)tmp_devAddr; entry.target_id = segment_id; entry.target_offset = remote_base; @@ -302,15 +302,15 @@ int target() { // Warm-up transmission void *tmp_devAddr = NULL; - ret = allocateDevMem(tmp_devAddr, FLAGS_block_size); + ret = allocateDevMem(tmp_devAddr, FLAGS_batch_size * FLAGS_block_size); if (ret) { LOG(ERROR) << "Failed to allocateDevMem, ret: " << ret; return ret; } LOG(INFO) << "tmp_devAddr_target: " << tmp_devAddr - << ", len: " << FLAGS_block_size; - ret = engine->registerLocalMemory(tmp_devAddr, FLAGS_block_size, + << ", len: " << FLAGS_batch_size * FLAGS_block_size; + ret = engine->registerLocalMemory(tmp_devAddr, FLAGS_batch_size * FLAGS_block_size, "npu:" + std::to_string(g_devicePhyId)); void *devAddr = NULL; diff --git a/mooncake-transfer-engine/include/multi_transport.h b/mooncake-transfer-engine/include/multi_transport.h index b5214b58c..3e0489008 100644 --- a/mooncake-transfer-engine/include/multi_transport.h +++ b/mooncake-transfer-engine/include/multi_transport.h @@ -53,13 +53,14 @@ class MultiTransport { std::vector listTransports(); + std::map> transport_map_; + private: Status selectTransport(const TransferRequest &entry, Transport *&transport); private: std::shared_ptr metadata_; std::string local_server_name_; - std::map> transport_map_; RWSpinlock batch_desc_lock_; std::unordered_map> batch_desc_set_; }; diff --git a/mooncake-transfer-engine/include/transfer_engine.h b/mooncake-transfer-engine/include/transfer_engine.h index 2a4ce9424..e098b16cb 100644 --- a/mooncake-transfer-engine/include/transfer_engine.h +++ b/mooncake-transfer-engine/include/transfer_engine.h @@ -212,6 +212,8 @@ class TransferEngine { return local_topology_; } + std::string local_server_name_; + private: struct MemoryRegion { void *addr; @@ -221,7 +223,6 @@ class TransferEngine { }; std::shared_ptr metadata_; - std::string local_server_name_; std::shared_ptr multi_transports_; std::shared_mutex mutex_; std::vector local_memory_regions_; @@ -251,6 +252,8 @@ class TransferEngine { void StopMetricsReportingThread(); #endif }; +extern __attribute__ ((visibility ("default"))) std::shared_ptr g_transfer_engine; +extern __attribute__ ((visibility ("default"))) bool g_separate_pool; } // namespace mooncake #endif diff --git a/mooncake-transfer-engine/include/transfer_metadata.h b/mooncake-transfer-engine/include/transfer_metadata.h index 70f15c8d4..4c0524e94 100644 --- a/mooncake-transfer-engine/include/transfer_metadata.h +++ b/mooncake-transfer-engine/include/transfer_metadata.h @@ -68,12 +68,16 @@ class TransferMetadata { uint64_t rankId = 0xFFFFFFFF; // rank id, user rank std::string hostIp; uint64_t hostPort; - uint64_t deviceLogicId; - uint64_t devicePhyId; uint64_t deviceType = 5; // default - std::string deviceIp; uint64_t devicePort; uint64_t pid; + uint32_t devPid; + int64_t sdid = 0xFFFFFFFF; + uint32_t deviceLogicId; + uint32_t devicePhyId; + uint32_t serverId = 0; + std::string deviceIp; + std::string vnicIp; }; using SegmentID = uint64_t; diff --git a/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_aggTransport_c.h b/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_aggTransport_c.h new file mode 100644 index 000000000..7ad20ec10 --- /dev/null +++ b/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_aggTransport_c.h @@ -0,0 +1,56 @@ +// Copyright 2025 Huawei Technologies Co., Ltd +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HCCL_AGGTRANSPORT_C_H +#define HCCL_AGGTRANSPORT_C_H + +#include "hccl_transport_mem_c.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +struct HugeBuffer { + MemBlock memBlock; + std::atomic freed{true}; + HugeBuffer(const MemBlock &mb, bool is_freed = true) + : memBlock(mb), freed{is_freed} {} +}; + +struct transferReq { + void *local_addr; + void *remote_addr; + uint64_t len; + int opcode; + int isMerge; + int mergeIdx; + std::string key_str; +}; + +/* Aggregated External Interface */ +extern int aggTransportMemTask(RankInfo *local_rank_info, + RankInfo *remote_rank_info, + std::vector &local_memPool, + std::vector &remote_memPool, + int opcode, aclrtStream stream, int mem_type); +extern int aggTransportMemTransfer(aclrtStream stream); +extern int aggTransportMemTarget(aclrtStream stream); +extern void aggRegLocalMem(uint64_t addr, uint64_t length, bool isAggBuffer); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // HCCL_AGGTRANSPORT_C_H diff --git a/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport.h b/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport.h index 94a09787f..7a729e4fc 100644 --- a/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport.h +++ b/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport.h @@ -30,9 +30,9 @@ #include "transfer_metadata.h" #include "transport/transport.h" #include "hccl_transport_mem_c.h" +#include "hccl_aggTransport_c.h" #define THREAD_NUM 1 -#define ASCEND_DEFAULT_HOST_PORT 10000 #define ASCEND_DEFAULT_DEVICE_PORT 16666 namespace mooncake { @@ -79,24 +79,41 @@ class HcclTransport : public Transport { private: int allocateLocalSegmentID(); - - int initPdThread(); - - void initiatorLoop(int deviceLogicId, int selfIdx); - - void acceptLoop(int deviceLogicId); - int getDevIdAndIpPortFromServerName(std::string &local_server_name, std::string &ip, int &ip_port, int &devicePhyId); - - int rankInfoParse(int devicePhyId, std::string hostIp); + int devInfoParse(std::string hostIp); + int prepareTransport(std::vector &slice_list); + + int startNonAggThreads(); + int nonAggTransport(std::vector &slice_list, aclrtStream stream); + void initiatorLoop(int deviceLogicId); // Thread logic for initiator + void targetAcceptLoop(int deviceLogicId); // Thread logic for target + + int startAggThreads(); + int aggTransport(std::vector &slice_list, aclrtStream stream); + void aggInitiatorLoop( + int deviceLogicId); // Thread logic for initiator aggregation/splitting + void aggInitiatorTransferLoop( + int deviceLogicId); // Thread logic for initiator data transfer + void aggTargetAcceptLoop( + int deviceLogicId); // Thread logic for target connection acceptance + void aggTargetLoop( + int deviceLogicId); // Thread logic for target aggregation/splitting private: + bool aggregateEnabled_; std::atomic_bool running_; - std::thread allInitiatorThreads_[THREAD_NUM]; - std::thread allAcceptThreads_[THREAD_NUM]; - std::queue> allReqQueues_[THREAD_NUM]; + + std::thread initiatorThread_; + std::thread targetAcceptThread_; + std::thread targetThread_; + + std::thread aggInitiatorThread_; + std::thread aggInitiatorTransferThread_; + std::thread aggTargetAcceptThread_; + std::thread aggTargetThread_; + std::queue> allReqQueues_; std::mutex initiator_mutex_; std::condition_variable initiator_cond_; RankInfo local_rank_info_; diff --git a/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport_mem_c.h b/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport_mem_c.h index dfe8a9f22..1b78a7c47 100644 --- a/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport_mem_c.h +++ b/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport_mem_c.h @@ -16,102 +16,37 @@ #ifndef HCCL_TRANSPORT_MEM_C_H #define HCCL_TRANSPORT_MEM_C_H -#include -#include #include -#include "acl/acl.h" -#include "adapter_hccp_common.h" -#include "dispatcher.h" -#include "dtype_common.h" -#include "externalinput_pub.h" -#include "hccl.h" -#include "hccl_types.h" -#include "hccl_check_buf_init.h" -#include "hccl_check_common.h" -#include "hccl_ip_address.h" -#include "hccl_network_pub.h" -#include "hccl_opbase_rootinfo_base.h" -#include "hccl_socket.h" -#include "mem_device_pub.h" -#include "notify_pool.h" -#include "p2p_mgmt_pub.h" -#include "sal_pub.h" -#include "transport_mem.h" -#include "transport_pub.h" -#include "hccl_mem.h" -#include "hccl_mem_defs.h" +#include "runtime/dev.h" +#include "hccl_transport_mem_internals.h" #ifdef __cplusplus extern "C" { #endif // __cplusplus -struct RankInfo { - uint64_t rankId = 0xFFFFFFFF; - uint64_t serverIdx; - struct in_addr hostIp; - uint64_t hostPort; - uint64_t deviceLogicId; - uint64_t devicePhyId; - DevType deviceType{DevType::DEV_TYPE_NOSOC}; - struct in_addr deviceIp; - uint64_t devicePort; - uint64_t pid; -}; - -struct RankControlInfo { - uint64_t deviceLogicId; - uint64_t devicePhyId; - struct in_addr hostIp; - struct in_addr deviceIp; - uint64_t pid; -}; - -struct MergeMem { - void *addr = nullptr; - uint64_t len = 0; - MergeMem(void *addr_, size_t len_) : addr(addr_), len(len_) {} -}; - -struct ConnectionInfo { - int tcp_socket; - std::shared_ptr hccl_ctrl_socket; - std::shared_ptr hccl_data_socket; - std::shared_ptr transport_mem; -}; - -// Retry mechanism for initialization function failure -#define RETRY_CALL(funcCall, errorMsg) \ - do { \ - int retryCount = 0; \ - int __ret = funcCall; \ - while (__ret && retryCount < 3) { \ - LOG(ERROR) << errorMsg << ", retrying... (" << ++retryCount \ - << "/3), ret :" << __ret; \ - __ret = funcCall; \ - } \ - if (__ret) { \ - LOG(ERROR) << errorMsg \ - << " failed after 3 retries, ret: " << __ret; \ - return __ret; \ - } \ - } while (0) - -extern int initTransportMem(RankInfo *local_rank_info); - -extern int transportMemTask(RankInfo *local_rank_info, - RankInfo *remote_rank_info, int op_code, - uint64_t offset, uint64_t req_len, void *local_mem, - aclrtStream stream); - -extern int transportMemAccept(RankInfo *local_rank_info); - -extern int regLocalRmaMem(void *addr, uint64_t length); - -extern bool printEnabled(); - +#define READ 0 +#define WRITE 1 +#define DDR 0 +#define HBM 1 +#define VECTOR_RESERVE_SIZE 200 +#define ASCEND_DEFAULT_HOST_PORT 10000 +#define BLOCK_AGGREGATION_THRESHOLD 0x40000 + +/* Public External Interface */ +extern bool enableAscendLogging(); +extern int initTransportMem(RankInfo *local_rank_info, bool aggregateEnabled); +extern int transportMemAccept(RankInfo *local_rank_info, bool aggregateEnabled); + +/* Non-Aggregated External Interface */ +extern void nonAggRegLocalMem(uint64_t addr, uint64_t length, bool is_pool); +extern int nonAggTransportMemTask(RankInfo *local_rank_info, + RankInfo *remote_rank_info, int op_code, + uint64_t offset, uint64_t req_len, + void *local_mem, int mem_type, + aclrtStream stream); extern int transportMemAddOpFence(RankInfo *remote_rank_info, aclrtStream stream); - +extern int transportMemTarget(aclrtStream stream); #ifdef __cplusplus } #endif // __cplusplus diff --git a/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport_mem_internals.h b/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport_mem_internals.h new file mode 100644 index 000000000..8375fb15c --- /dev/null +++ b/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport_mem_internals.h @@ -0,0 +1,174 @@ +// Copyright 2025 Huawei Technologies Co., Ltd +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HCCL_TRANSPORT_MEM_INTERNALS_H +#define HCCL_TRANSPORT_MEM_INTERNALS_H + +#include +#include +#include "acl/acl.h" +#include "adapter_hccp_common.h" +#include "dispatcher.h" +#include "dtype_common.h" +#include "externalinput_pub.h" +#include "hccl.h" +#include "hccl_types.h" +#include "hccl_check_buf_init.h" +#include "hccl_check_common.h" +#include "hccl_ip_address.h" +#include "hccl_network_pub.h" +#include "hccl_opbase_rootinfo_base.h" +#include "hccl_socket.h" +#include "mem_device_pub.h" +#include "notify_pool.h" +#include "p2p_mgmt_pub.h" +#include "sal_pub.h" +#include "transport_mem.h" +#include "transport_pub.h" +#include "hccl_mem.h" +#include "hccl_mem_defs.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#define MAX_EVENTS 32 +#define CONNECT_MAX 1000 +#define HCCL_RETRY_TIMES 3 +#define TOTAL_AGG_DEV_SIZE 0x4000000 +#define PER_HUGE_BUFFER_SIZE 0x2000000 +#define HUGE_BUFFER_NUM 2 +struct RankControlInfo { + char hostIp[128]; + char deviceIp[128]; + char vnicIp[128]; + uint32_t devPid; + int64_t sdid; + uint32_t deviceLogicId; + uint32_t devicePhyId; +}; + +struct RankInfo { + char hostIp[128]; + char deviceIp[128]; + char vnicIp[128]; + uint64_t rankId = 0xFFFFFFFF; + uint64_t hostPort = 0; + DevType deviceType{DevType::DEV_TYPE_NOSOC}; + uint64_t devicePort = 16666; + uint32_t devPid; + int64_t sdid = 0xFFFFFFFF; + uint32_t serverId = 0; + uint32_t deviceLogicId; + uint32_t devicePhyId; + + RankInfo() = default; + RankInfo(const RankControlInfo &controlInfo) + : devPid(controlInfo.devPid), + sdid(controlInfo.sdid), + deviceLogicId(controlInfo.deviceLogicId), + devicePhyId(controlInfo.devicePhyId) { + memcpy(hostIp, controlInfo.hostIp, sizeof(controlInfo.hostIp)); + memcpy(deviceIp, controlInfo.deviceIp, sizeof(controlInfo.deviceIp)); + memcpy(vnicIp, controlInfo.vnicIp, sizeof(controlInfo.vnicIp)); + } +}; + +struct MemBlock { + uint64_t addr = 0; + uint64_t len = 0; + int mem_type = 0; + MemBlock() : addr(0), len(0), mem_type(1) {} + MemBlock(uint64_t a, uint64_t l, int m) : addr(a), len(l), mem_type(m) {} +}; + +// Retry mechanism for initialization function failure +#define RETRY_CALL(funcCall, errorMsg) \ + do { \ + int retryCount = 0; \ + int __ret = funcCall; \ + while (__ret && retryCount < 3) { \ + LOG(ERROR) << errorMsg << ", retrying... (" << ++retryCount \ + << "/3)" << __ret; \ + __ret = funcCall; \ + } \ + if (__ret) { \ + LOG(ERROR) << errorMsg << " failed after 3 retries." << __ret; \ + return __ret; \ + } \ + } while (0) + +struct ConnectionInfo { + int tcp_socket; + std::shared_ptr hccl_ctrl_socket; + std::shared_ptr hccl_data_socket; + std::shared_ptr transport_mem; + int total_len; +}; + +struct SingleCopyInfo { + uint64_t host_addr; + uint64_t device_addr; + uint64_t len; + bool is_read; + bool is_copy; + uint64_t local_id; + uint64_t remote_id; + uint64_t offset; +}; + +extern std::unordered_map + g_target_key_to_connection_map; +extern std::unordered_map + g_target_key_to_accept_map; +extern std::vector g_localBuffer; +extern int g_epoll_fd_agg; +extern int g_epoll_fd_target; + +extern bool a3Enabled(); + +extern int initServerNetSocket(RankInfo *local_rank_info); + +extern int initControlSocket(RankInfo *local_rank_info, bool aggregateEnabled); + +extern int controlInfoSend(RankInfo *local_rank_info, + RankInfo *remote_rank_info); + +extern int createClientSocket(std::shared_ptr &hccl_socket, + RankInfo *local_rank_info, + RankInfo *remote_rank_info, bool is_cross_hccs, + std::string tag); + +extern int createTransportMem( + RankInfo *local_rank_info, RankInfo *remote_rank_info, std::string key_str, + bool is_cross_hccs, std::shared_ptr &transport_mem, + bool is_accept); + +extern int socketEpollWait(); + +extern int acceptFromTarget(); + +extern int acceptHcclSocket(std::shared_ptr &hccl_socket, + std::string baseTag_, + hccl::HcclIpAddress rempoteDevIp, + bool is_cross_hccs); + +extern void getDevIpAddresses(RankInfo *local_rank_info); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // HCCL_TRANSPORT_MEM_INTERNALS_H diff --git a/mooncake-transfer-engine/include/transport/transport.h b/mooncake-transfer-engine/include/transport/transport.h index c76d4ccb6..65d46b25d 100644 --- a/mooncake-transfer-engine/include/transport/transport.h +++ b/mooncake-transfer-engine/include/transport/transport.h @@ -60,6 +60,7 @@ class Transport { uint64_t target_offset; size_t length; int advise_retry_cnt = 0; + int target_offset_type = 0; }; enum TransferStatusEnum { @@ -119,6 +120,7 @@ class Transport { void *dest_addr; } cxl; struct { + int dest_addr_type; uint64_t dest_addr; } hccl; }; diff --git a/mooncake-transfer-engine/src/CMakeLists.txt b/mooncake-transfer-engine/src/CMakeLists.txt index 8fd118c84..a12c880bb 100644 --- a/mooncake-transfer-engine/src/CMakeLists.txt +++ b/mooncake-transfer-engine/src/CMakeLists.txt @@ -4,7 +4,7 @@ add_subdirectory(transport) SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) -add_library(transfer_engine ${ENGINE_SOURCES} $) +add_library(transfer_engine SHARED ${ENGINE_SOURCES} $) if (BUILD_SHARED_LIBS) install(TARGETS transfer_engine DESTINATION lib) endif() diff --git a/mooncake-transfer-engine/src/transfer_engine.cpp b/mooncake-transfer-engine/src/transfer_engine.cpp index fb28cd85d..271d3cc73 100644 --- a/mooncake-transfer-engine/src/transfer_engine.cpp +++ b/mooncake-transfer-engine/src/transfer_engine.cpp @@ -27,7 +27,8 @@ #include "transport/transport.h" namespace mooncake { - +__attribute__ ((visibility ("default"))) std::shared_ptr g_transfer_engine = nullptr; +__attribute__ ((visibility ("default"))) bool g_separate_pool = false; static bool setFilesLimit() { struct rlimit filesLimit; if (getrlimit(RLIMIT_NOFILE, &filesLimit) != 0) { @@ -107,10 +108,12 @@ int TransferEngine::init(const std::string &metadata_conn_string, if (metadata_conn_string == P2PHANDSHAKE) { rpc_binding_method = "P2P handshake"; - desc.rpc_port = findAvailableTcpPort(desc.sockfd); - if (desc.rpc_port == 0) { - LOG(ERROR) << "P2P: No valid port found for local TCP service."; - return -1; + if (!g_separate_pool) { + desc.rpc_port = findAvailableTcpPort(desc.sockfd); + if (desc.rpc_port == 0) { + LOG(ERROR) << "P2P: No valid port found for local TCP service."; + return -1; + } } #ifdef USE_ASCEND // The current version of Ascend Transport does not support IPv6, @@ -263,6 +266,7 @@ Transport *TransferEngine::installTransport(const std::string &proto, LOG(WARNING) << "Transport " << proto << " already installed"; return transport; } + LOG(WARNING) << "Transport not used"; if (args != nullptr && args[0] != nullptr) { const std::string nic_priority_matrix = static_cast(args[0]); diff --git a/mooncake-transfer-engine/src/transfer_metadata.cpp b/mooncake-transfer-engine/src/transfer_metadata.cpp index 52f5a228c..d149738f8 100644 --- a/mooncake-transfer-engine/src/transfer_metadata.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata.cpp @@ -191,15 +191,20 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, rankInfoJSON["hostPort"] = static_cast(desc.rank_info.hostPort); rankInfoJSON["deviceLogicId"] = - static_cast(desc.rank_info.deviceLogicId); + static_cast(desc.rank_info.deviceLogicId); rankInfoJSON["devicePhyId"] = - static_cast(desc.rank_info.devicePhyId); + static_cast(desc.rank_info.devicePhyId); rankInfoJSON["deviceType"] = static_cast(desc.rank_info.deviceType); rankInfoJSON["deviceIp"] = desc.rank_info.deviceIp; rankInfoJSON["devicePort"] = static_cast(desc.rank_info.devicePort); - rankInfoJSON["pid"] = static_cast(desc.rank_info.pid); + rankInfoJSON["devPid"] = + static_cast(desc.rank_info.devPid); + rankInfoJSON["vnicIp"] = desc.rank_info.vnicIp; + rankInfoJSON["sdid"] = static_cast(desc.rank_info.sdid); + rankInfoJSON["serverId"] = + static_cast(desc.rank_info.serverId); segmentJSON["rank_info"] = rankInfoJSON; } else if (segmentJSON["protocol"] == "nvlink") { @@ -388,13 +393,15 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, desc->rank_info.rankId = rankInfoJSON["rankId"].asUInt64(); desc->rank_info.hostIp = rankInfoJSON["hostIp"].asString(); desc->rank_info.hostPort = rankInfoJSON["hostPort"].asUInt64(); - desc->rank_info.deviceLogicId = - rankInfoJSON["deviceLogicId"].asUInt64(); - desc->rank_info.devicePhyId = rankInfoJSON["devicePhyId"].asUInt64(); + desc->rank_info.deviceLogicId = rankInfoJSON["deviceLogicId"].asUInt(); + desc->rank_info.devicePhyId = rankInfoJSON["devicePhyId"].asUInt(); desc->rank_info.deviceType = rankInfoJSON["deviceType"].asUInt64(); desc->rank_info.deviceIp = rankInfoJSON["deviceIp"].asString(); desc->rank_info.devicePort = rankInfoJSON["devicePort"].asUInt64(); - desc->rank_info.pid = rankInfoJSON["pid"].asUInt64(); + desc->rank_info.devPid = rankInfoJSON["devPid"].asUInt(); + desc->rank_info.vnicIp = rankInfoJSON["vnicIp"].asString(); + desc->rank_info.sdid = rankInfoJSON["sdid"].asInt64(); + desc->rank_info.serverId = rankInfoJSON["serverId"].asUInt(); } else if (desc->protocol == "cxl") { desc->cxl_name = segmentJSON["cxl_name"].asString(); desc->cxl_base_addr = segmentJSON["cxl_base_addr"].asUInt64(); diff --git a/mooncake-transfer-engine/src/transport/ascend_transport/CMakeLists.txt b/mooncake-transfer-engine/src/transport/ascend_transport/CMakeLists.txt index 660b2d1cb..42efe9482 100644 --- a/mooncake-transfer-engine/src/transport/ascend_transport/CMakeLists.txt +++ b/mooncake-transfer-engine/src/transport/ascend_transport/CMakeLists.txt @@ -5,16 +5,11 @@ add_library(ascend_transport OBJECT ${ASCEND_SOURCES}) target_link_libraries(ascend_transport PRIVATE ascend_transport_mem) -set_target_properties(ascend_transport - PROPERTIES - PREFIX "" - ) +set_target_properties(ascend_transport PROPERTIES PREFIX "") target_compile_options(ascend_transport PRIVATE -O2 -Xlinker -export-dynamic ) -target_link_options(ascend_transport PRIVATE - -s - ) +target_link_options(ascend_transport PRIVATE -s) diff --git a/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_aggTransport_c.cpp b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_aggTransport_c.cpp new file mode 100644 index 000000000..e0512ce12 --- /dev/null +++ b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_aggTransport_c.cpp @@ -0,0 +1,898 @@ +// Copyright 2025 Huawei Technologies Co., Ltd +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include "transport/ascend_transport/hccl_transport/hccl_aggTransport_c.h" +#include "transport/ascend_transport/hccl_transport/hccl_transport_mem_internals.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +std::mutex g_transfer_mutex; +std::condition_variable g_transfer_cond; +std::queue> g_transferReqList; + +std::mutex g_split_mutex; +std::condition_variable g_split_cond; +std::queue g_splitList; +std::vector> g_localHugeBuffer; +std::vector g_localMemtoSend; + +static int sendMemInfo(int client_socket, const std::vector &memPool, + int opcode) { + static const uint64_t kHdrLen = sizeof(opcode) + sizeof(int) + sizeof(int); + const uint64_t kBodyLen = memPool.size() * sizeof(MemBlock); + const uint64_t kTotal = kHdrLen + kBodyLen; + + struct iovec iov[4]; + int mem_type = memPool[0].mem_type; + iov[0].iov_base = const_cast(static_cast(&mem_type)); + iov[0].iov_len = sizeof(int); + + const int mem_num = static_cast(memPool.size()); + iov[1].iov_base = const_cast(static_cast(&mem_num)); + iov[1].iov_len = sizeof(mem_num); + + iov[2].iov_base = const_cast(static_cast(&opcode)); + iov[2].iov_len = sizeof(opcode); + + iov[3].iov_base = + const_cast(static_cast(memPool.data())); + iov[3].iov_len = kBodyLen; + + uint64_t already_sent = 0; + while (already_sent < kTotal) { + struct msghdr msg{}; + struct iovec iov2[4]; + int iovcnt = 0; + uint64_t skip = already_sent; + + for (int i = 0; i < 4; ++i) { + if (skip >= iov[i].iov_len) { + skip -= iov[i].iov_len; + continue; + } + iov2[iovcnt].iov_base = static_cast(iov[i].iov_base) + skip; + iov2[iovcnt].iov_len = iov[i].iov_len - skip; + skip = 0; + ++iovcnt; + } + + msg.msg_iov = iov2; + msg.msg_iovlen = iovcnt; + + int ret = sendmsg(client_socket, &msg, MSG_NOSIGNAL); + if (ret < 0) { + if (errno == EINTR) continue; + LOG(ERROR) << "sendmsg failed: " << strerror(errno); + return -1; + } + already_sent += static_cast(ret); + } + + return 0; +} + +int aggTransportMemTransfer(aclrtStream stream) { + int ret = 0; + std::shared_ptr transport_mem{}; + std::unique_lock lock(g_transfer_mutex); + if (g_transferReqList.empty()) { + g_transfer_cond.wait(lock); + } + auto req = std::move(g_transferReqList.front()); + g_transferReqList.pop(); + lock.unlock(); + + hccl::TransportMem::RmaOpMem localMem; + localMem.addr = req->local_addr; + localMem.size = req->len; + hccl::TransportMem::RmaOpMem remoteMem; + remoteMem.size = req->len; + transport_mem = g_target_key_to_connection_map[req->key_str].transport_mem; + int opcode = req->opcode; + // LOG(INFO) << "aggTransportMemTransfer addr: " << localMem.addr + // << ", size:" << localMem.size << ", op_code:" << opcode; + + int client_socket = g_target_key_to_connection_map[req->key_str].tcp_socket; + if (req->isMerge == 0 || opcode == WRITE) { + remoteMem.addr = req->remote_addr; + } else if (opcode == READ) { + uint64_t remote_base; + // LOG(INFO) << "recv start"; + ret = recv(client_socket, &remote_base, sizeof(uint64_t), MSG_WAITALL); + if (ret <= 0) { + LOG(ERROR) << "Failed to receive remote_base, ret: " << ret + << ", errno: " << errno + << ", error: " << strerror(errno); + return -1; + } + // LOG(INFO) << "recv end" << remote_base; + remoteMem.addr = (void *)remote_base; + } + // LOG(INFO) << "wirte or read start opcode:" << opcode; + if (opcode == WRITE) { + ret = transport_mem->Write(remoteMem, localMem, stream); + if (ret) { + LOG(ERROR) << "Write failed, localMem: " << localMem.addr + << ", remoteMem: " << remoteMem.addr + << ", req_len: " << req->len << ", ret: " << ret; + return ret; + } + } else { + ret = transport_mem->Read(localMem, remoteMem, stream); + if (ret) { + LOG(ERROR) << "Read failed, localMem: " << localMem.addr + << ", remoteMem: " << remoteMem.addr + << ", req_len: " << req->len << ", ret: " << ret; + return ret; + } + } + + ret = transport_mem->AddOpFence(stream); + if (ret) { + LOG(ERROR) << "AddOpFence failed, localMem: " << localMem.addr + << ", remoteMem: " << remoteMem.addr + << ", req_len: " << req->len << ", ret: " << ret; + return ret; + } + + ret = aclrtSynchronizeStream(stream); + if (ret) { + LOG(ERROR) << "aclrtSynchronizeStream failed, localMem: " + << localMem.addr << ", remoteMem: " << remoteMem.addr + << ", req_len: " << req->len << ", ret: " << ret; + return ret; + } + // LOG(ERROR) << "write or read end:" << opcode; + + if (req->isMerge == 0) { + return 0; + } + int ready = 1; + if (opcode == WRITE) { + // LOG(INFO) << "write send start:" << opcode; + ret = send(client_socket, &ready, sizeof(int), 0); + if (ret < 0) { + LOG(ERROR) << "Failed to send READY to remote, ret: " << ret + << ", errno: " << errno + << ", error: " << strerror(errno); + return ret; + } + + // Check if the object at the specified index has been freed. + ret = g_localHugeBuffer[req->mergeIdx]->freed.load( + std::memory_order_acquire); + if (ret) { + LOG(ERROR) << "Error: Object at index " << req->mergeIdx + << " has been freed!"; + return ret; + } + + g_localHugeBuffer[req->mergeIdx]->freed.store( + true, std::memory_order_release); + } else { + // LOG(ERROR) << "read load start:" << opcode; + ret = send(client_socket, &ready, sizeof(int), 0); + if (ret < 0) { + LOG(ERROR) << "Failed to send ready read, ret: " << ret + << ", errno: " << errno + << ", error: " << strerror(errno); + return -1; + } + // Check if the object at the specified index has been freed. + while (!g_localHugeBuffer[req->mergeIdx]->freed.load( + std::memory_order_acquire)) { + std::this_thread::yield(); + } + // LOG(INFO) << "read load end:" << opcode; + + g_localHugeBuffer[req->mergeIdx]->freed.store( + false, std::memory_order_release); + std::unique_lock lock(g_split_mutex); + g_splitList.push(req->mergeIdx); + lock.unlock(); + g_split_cond.notify_one(); + } + + return 0; +} + +int aggTransportMemTask(RankInfo *local_rank_info, RankInfo *remote_rank_info, + std::vector &local_memPool, + std::vector &remote_memPool, int opcode, + aclrtStream stream, int mem_type) { + int ret = 0; + std::shared_ptr hccl_ctrl_socket; + std::shared_ptr hccl_data_socket; + std::shared_ptr transport_mem{}; + // Check if a connection has been established with the peer, and send local + // information to the peer + std::string key_str = std::string(remote_rank_info->hostIp) + '-' + + std::to_string(remote_rank_info->devicePhyId); + if (mem_type == DDR) { + std::string local_key = std::string(local_rank_info->hostIp) + '-' + + std::to_string(local_rank_info->devicePhyId); + // PUT OWN OBJECT / GET OWN OBJECT + if (local_key == key_str) { + uint64_t req_len = 0; + uint64_t mergeAddrWrite = g_localHugeBuffer[0]->memBlock.addr; + uint64_t mergeAddrRead = g_localHugeBuffer[0]->memBlock.addr; + for (uint32_t i = 0; i < local_memPool.size(); i++) { + if (opcode == WRITE) { + ret = aclrtMemcpyAsync( + (void *)mergeAddrWrite, local_memPool[i].len, + (void *)local_memPool[i].addr, local_memPool[i].len, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) + << "Failed to merge data from device to device, " + "ret: " + << ret << ", mergeAddrWrite: " << mergeAddrWrite + << ", localMem.addr: " << local_memPool[i].addr; + return ret; + } + mergeAddrWrite += local_memPool[i].len; + } + req_len += local_memPool[i].len; + } + + if (opcode == WRITE) { + ret = aclrtSynchronizeStream(stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) + << "Failed to aclrtSynchronizeStream, ret: " << ret; + return ret; + } + ret = aclrtMemcpy( + reinterpret_cast(remote_memPool[0].addr), req_len, + reinterpret_cast(mergeAddrWrite), req_len, + ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) + << "Failed to copy data from device to host, ret: " + << ret << ", local" << mergeAddrWrite + << ", dest:" << remote_memPool[0].addr + << ", len:" << req_len; + return ret; + } + // LOG(INFO) << "PUT: copy data from device to host, ret: " << + // ret + // << ", local" << mergeAddrWrite + // << ", dest:" << remote_memPool[0].addr + // << ", len:" << req_len; + } else { + ret = aclrtMemcpy( + reinterpret_cast(mergeAddrRead), req_len, + reinterpret_cast(remote_memPool[0].addr), req_len, + ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) + << "Failed to copy data from host to device, ret: " + << ret << ", local" << local_memPool[0].addr + << ", dest:" << mergeAddrRead << ", len:" << req_len; + return ret; + } + // LOG(INFO) << "GET: copy data from host to device, ret: " << + // ret + // << ", local" << local_memPool[0].addr + // << ", dest:" << mergeAddrRead << ", len:" << + // req_len; + for (uint32_t i = 0; i < local_memPool.size(); i++) { + ret = aclrtMemcpyAsync( + (void *)local_memPool[i].addr, local_memPool[i].len, + (void *)mergeAddrRead, local_memPool[i].len, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to copy data from device to " + "device, ret: " + << ret << ", local" << mergeAddrRead + << ", dest:" << local_memPool[i].addr + << ", len:" << local_memPool[i].len; + return ret; + } + // LOG(INFO) + // << "GET: copy data from device to device, ret: " << + // ret + // << ", local" << mergeAddrRead + // << ", dest:" << local_memPool[i].addr + // << ", len:" << local_memPool[i].len; + mergeAddrRead += local_memPool[i].len; + } + ret = aclrtSynchronizeStream(stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) + << "Failed to aclrtSynchronizeStream, ret: " << ret; + return ret; + } + } + return 0; + } + } + auto iter = g_target_key_to_connection_map.find(key_str); + if (iter == g_target_key_to_connection_map.end() || + g_target_key_to_connection_map[key_str].tcp_socket <= 0) { + ret = controlInfoSend(local_rank_info, remote_rank_info); + if (ret) { + LOG(ERROR) << "controlInfoSend failed, ret: " << ret; + return ret; + } + bool is_cross_hccs = false; + if (a3Enabled()) { + is_cross_hccs = false; + } else { + bool same_host = (strcmp(local_rank_info->hostIp, + remote_rank_info->hostIp) == 0); + // For A2 series, internal communication among 8 cards does not + // cross HCCS, such as communication among cards 0-7 + bool same_group = (local_rank_info->devicePhyId / 8) == + (remote_rank_info->devicePhyId / 8); + is_cross_hccs = !(same_host && same_group); + } + if (enableAscendLogging()) { + LOG(INFO) << "hccl transport is cross_hccs: " + << (is_cross_hccs ? "true (cross-hccs)" + : "false (same-hccs)"); + } + ret = createClientSocket(hccl_ctrl_socket, local_rank_info, + remote_rank_info, is_cross_hccs, "ctrl"); + if (ret) { + LOG(ERROR) << "createClientSocket hccl_ctrl_socket failed, ret: " + << ret; + return ret; + } + g_target_key_to_connection_map[key_str].hccl_ctrl_socket = + hccl_ctrl_socket; + ret = createClientSocket(hccl_data_socket, local_rank_info, + remote_rank_info, is_cross_hccs, "data"); + if (ret) { + LOG(ERROR) << "createClientSocket hccl_data_socket failed, ret: " + << ret; + return ret; + } + g_target_key_to_connection_map[key_str].hccl_data_socket = + hccl_data_socket; + ret = createTransportMem(local_rank_info, remote_rank_info, key_str, + is_cross_hccs, transport_mem, false); + if (ret) { + LOG(ERROR) << "createTransportMem failed, ret: " << ret; + return ret; + } + int ack = 0; + ret = recv(g_target_key_to_connection_map[key_str].tcp_socket, &ack, + sizeof(int), 0); + if (ret <= 0) { + LOG(ERROR) << "Failed to receive ack, ret: " << ret + << ", errno: " << errno + << ", error: " << strerror(errno); + return -1; + } + } else { + transport_mem = g_target_key_to_connection_map[key_str].transport_mem; + } + + int client_socket = g_target_key_to_connection_map[key_str].tcp_socket; + + // LOG(INFO) << "sendMemInfo start, opcode:" << opcode; + ret = sendMemInfo(client_socket, remote_memPool, opcode); + if (ret) { + LOG(ERROR) << "sendMemInfo failed, ret: " << ret; + return ret; + } + // LOG(INFO) << "sendMemInfo end, opcode:" << opcode; + + std::vector recvBuf(HUGE_BUFFER_NUM); + int recvIdx = 1; + if (opcode == WRITE) { + int total = HUGE_BUFFER_NUM * sizeof(uint64_t); + + struct iovec iov[1]; + iov[0].iov_base = recvBuf.data(); + iov[0].iov_len = total; + + struct msghdr msg = {}; + msg.msg_iov = iov; + msg.msg_iovlen = 1; + + ret = recvmsg(client_socket, &msg, MSG_WAITALL); + if (ret != total) { + LOG(ERROR) << "Failed to receive msg, ret: " << ret + << ", errno: " << errno + << ", error: " << strerror(errno); + return -1; + } + } + // LOG(INFO) << "recvmsg end, opcode:" << opcode; + + uint64_t send_index = 0; + uint64_t idx = 0; + while (idx < local_memPool.size()) { + uint64_t mergeLen = 0; + if (opcode == WRITE) { + while (!g_localHugeBuffer[send_index]->freed.load( + std::memory_order_acquire)) { + std::this_thread::yield(); + } + + g_localHugeBuffer[send_index]->freed.store( + false, std::memory_order_release); + } + + void *mergeAddr = (void *)g_localHugeBuffer[send_index]->memBlock.addr; + while (idx < local_memPool.size()) { + const MemBlock &localMem = local_memPool[idx]; + if (mergeLen + localMem.len > PER_HUGE_BUFFER_SIZE) { + break; + } + if (opcode == WRITE) { + ret = aclrtMemcpyAsync(mergeAddr, localMem.len, + (void *)localMem.addr, localMem.len, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) + << "Failed to merge data from device to device, ret: " + << ret << ", mergeAddr: " << mergeAddr + << ", localMem.addr: " << localMem.addr; + return ret; + } + // LOG(INFO) << "agg mergeAddr:" << mergeAddr + // << ", local:" << localMem.addr + // << ", len:" << localMem.len; + mergeAddr = static_cast(mergeAddr) + localMem.len; + } + + mergeLen += localMem.len; + idx++; + } + + auto req = std::make_shared(); + if (opcode == WRITE) { + ret = aclrtSynchronizeStream(stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to aclrtSynchronizeStream, ret: " << ret; + return ret; + } + req->remote_addr = (void *)recvBuf[recvIdx]; + recvIdx = (recvIdx + 2) % HUGE_BUFFER_NUM; + } + + req->local_addr = (void *)g_localHugeBuffer[send_index]->memBlock.addr; + req->len = mergeLen; + req->opcode = opcode; + req->isMerge = 1; + req->key_str = key_str; + req->mergeIdx = send_index; + // LOG(INFO) << "req local addr:" << req->local_addr + // << ", len:" << mergeLen << ", key:" << key_str + // << ", index:" << req->mergeIdx; + std::unique_lock lock(g_transfer_mutex); + g_transferReqList.push(req); + lock.unlock(); + g_transfer_cond.notify_one(); + mergeLen = 0; + send_index = (send_index + 2) % HUGE_BUFFER_NUM; + } + if (opcode == READ) { + idx = 0; + while (idx < local_memPool.size()) { + std::unique_lock lock(g_split_mutex); + if (g_splitList.empty()) { + g_split_cond.wait(lock); + } + int mergeIdx = std::move(g_splitList.front()); + + g_splitList.pop(); + void *mergeAddr = + (void *)g_localHugeBuffer[mergeIdx]->memBlock.addr; + lock.unlock(); + uint64_t mergeLen = 0; + while (idx < local_memPool.size()) { + const MemBlock &localMem = local_memPool[idx]; + if (mergeLen + localMem.len > PER_HUGE_BUFFER_SIZE) { + break; + } + ret = aclrtMemcpyAsync((void *)localMem.addr, localMem.len, + mergeAddr, localMem.len, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) + << "Failed to split data from device to device, ret: " + << ret; + return ret; + } + // LOG(INFO) << "mergeAddr to addr, localMem:" << localMem.addr + // << ", mergeAddr:" << mergeAddr + // << ", len:" << localMem.len; + + mergeAddr = static_cast(mergeAddr) + localMem.len; + mergeLen += localMem.len; + idx++; + } + ret = aclrtSynchronizeStream(stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to aclrtSynchronizeStream, ret: " << ret; + return ret; + } + + mergeLen = 0; + + ret = g_localHugeBuffer[mergeIdx]->freed.load( + std::memory_order_acquire); + if (ret) { + LOG(ERROR) << "Error: Object at index " << mergeIdx + << " has been freed!"; + return ret; + } + + g_localHugeBuffer[mergeIdx]->freed.store(true, + std::memory_order_release); + } + } + + int ready = 0; + if (opcode == WRITE) { + ret = recv(client_socket, &ready, sizeof(int), MSG_WAITALL); + if (ret <= 0) { + LOG(ERROR) << "Failed to receive ready write, ret: " << ret + << ", errno: " << errno + << ", error: " << strerror(errno); + return -1; + } + } + // else { + // ret = send(client_socket, &ready, sizeof(int), 0); + // if (ret < 0) { + // LOG(ERROR) << "Failed to send ready read, ret: " << ret + // << ", errno: " << errno + // << ", error: " << strerror(errno); + // return -1; + // } + // } + LOG(ERROR) << "slice ok"; + // auto duration_d1 = + // std::chrono::duration_cast(t2 - t1); auto + // duration_d2 = std::chrono::duration_cast(t3 - + // t2); auto duration_d3 = + // std::chrono::duration_cast(t4 - t3); auto + // duration_d4 = std::chrono::duration_cast(t5 - + // t4); + + // LOG(INFO) << "duration_d1: " << duration_d1.count() << "us" + // << ", duration_d2: " << duration_d2.count() << "us" + // << ", duration_d3: " << duration_d3.count() << "us" + // << ", duration_d4: " << duration_d4.count() << "us"; + + return 0; +} + +static int recvMemInfo(int client_socket, aclrtStream stream) { + int ret = 0; + int opcode, mem_num, recv_mem_type; + int a = + recv(client_socket, &recv_mem_type, sizeof(recv_mem_type), MSG_WAITALL); + if (a != sizeof(recv_mem_type)) { + LOG(ERROR) << "Failed to receive recv_mem_type type, ret: " << ret + << ", errno: " << errno << ", error: " << strerror(errno) + << ", a:" << a; + return -1; + } + if (recv(client_socket, &mem_num, sizeof(mem_num), MSG_WAITALL) != + sizeof(mem_num)) { + LOG(ERROR) << "Failed to receive mem_num, ret: " << ret + << ", errno: " << errno << ", error: " << strerror(errno); + return -1; + } + // LOG(INFO) << "recvMemInfo to receive mem_num:" << mem_num; + + uint64_t total_len = sizeof(int) + sizeof(mem_num) + sizeof(int) + + mem_num * sizeof(MemBlock); + + struct iovec iov[2]; + iov[0].iov_base = &opcode; + iov[0].iov_len = sizeof(int); + + std::vector receivedMemPool; + receivedMemPool.resize(mem_num); + iov[1].iov_base = receivedMemPool.data(); + iov[1].iov_len = mem_num * sizeof(MemBlock); + + struct msghdr msg{}; + msg.msg_iov = iov; + msg.msg_iovlen = 2; + + uint64_t already_received = 0; + while (already_received < + total_len - sizeof(recv_mem_type) - sizeof(mem_num)) { + ret = recvmsg(client_socket, &msg, 0); + if (ret <= 0) { + LOG(ERROR) << "Failed to receive msg, ret: " << ret + << ", errno: " << errno + << ", error: " << strerror(errno); + return ret; + } + already_received += ret; + + uint64_t skip = (uint64_t)ret; + for (int i = 0; i < 2; ++i) { + if (skip >= iov[i].iov_len) { + skip -= iov[i].iov_len; + iov[i].iov_len = 0; + } else { + iov[i].iov_base = static_cast(iov[i].iov_base) + skip; + iov[i].iov_len -= skip; + break; + } + } + } + + uint64_t recv_index = 1; + if (opcode == WRITE) { + struct iovec iov[1]; + iov[0].iov_base = static_cast(g_localMemtoSend.data()); + iov[0].iov_len = g_localMemtoSend.size() * sizeof(uint64_t); + + struct msghdr msg = {}; + msg.msg_iov = iov; + msg.msg_iovlen = 1; + + ret = sendmsg(client_socket, &msg, 0); + if (ret < 0) { + LOG(ERROR) << "Failed to send msg to remote, ret: " << ret + << ", errno: " << errno + << ", error: " << strerror(errno); + return ret; + } + } + + uint64_t idx = 0; + int ready = 0; + while (idx < receivedMemPool.size()) { + uint64_t mergeLen = 0; + void *mergeAddr = (void *)g_localHugeBuffer[recv_index]->memBlock.addr; + + while (!g_localHugeBuffer[recv_index]->freed.load( + std::memory_order_acquire)) { + std::this_thread::yield(); + } + g_localHugeBuffer[recv_index]->freed.store(false, + std::memory_order_release); + if (opcode == READ) { + while (idx < receivedMemPool.size()) { + auto &block = receivedMemPool[idx]; + if (mergeLen + block.len > PER_HUGE_BUFFER_SIZE) { + break; + } + if (recv_mem_type == HBM) { + ret = aclrtMemcpyAsync(mergeAddr, block.len, + (void *)block.addr, block.len, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to merge data from device to " + "device, ret: " + << ret; + return ret; + } + // LOG(INFO) << "POOLING RECV D2D mergeAddr:" << mergeAddr + // << ", block_addr:" << block.addr << ", len" + // << block.len; + mergeAddr = static_cast(mergeAddr) + block.len; + } + idx++; + mergeLen += block.len; + } + + if (recv_mem_type == HBM) { + ret = aclrtSynchronizeStream(stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) + << "Failed to aclrtSynchronizeStream, ret: " << ret; + return ret; + } + } else { + ret = aclrtMemcpy(mergeAddr, mergeLen, + (void *)receivedMemPool[0].addr, mergeLen, + ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) + << "Failed to merge data from device to device, ret: " + << ret; + return ret; + } + // LOG(INFO) << "POOLING RECV H2D mergeAddr:" << mergeAddr + // << ", block_addr:" << receivedMemPool[0].addr + // << ", len:" << mergeLen; + } + + uint64_t base_addr = g_localHugeBuffer[recv_index]->memBlock.addr; + ret = send(client_socket, &base_addr, sizeof(uint64_t), 0); + if (ret < 0) { + LOG(ERROR) << "Failed to send base_addr to remote, ret: " << ret + << ", errno: " << errno + << ", error: " << strerror(errno); + return ret; + } + LOG(INFO) << "RECV send ok addr:" << base_addr; + ret = recv(client_socket, &ready, sizeof(int), MSG_WAITALL); + if (ret <= 0) { + LOG(ERROR) << "Failed to receive ready write, ret: " << ret + << ", errno: " << errno + << ", error: " << strerror(errno); + return -1; + } + LOG(INFO) << "RECV ready addr:" << base_addr; + } else { + ret = recv(client_socket, &ready, sizeof(int), MSG_WAITALL); + if (ret <= 0) { + LOG(ERROR) << "Failed to receive ready, ret: " << ret + << ", errno: " << errno + << ", error: " << strerror(errno); + return ret; + } + + while (idx < receivedMemPool.size()) { + auto &block = receivedMemPool[idx]; + if (mergeLen + block.len > PER_HUGE_BUFFER_SIZE) { + break; + } + if (recv_mem_type == HBM) { + ret = aclrtMemcpyAsync((void *)block.addr, block.len, + mergeAddr, block.len, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to split data from device to " + "device, ret: " + << ret; + return ret; + } + // LOG(INFO) << "POOLING RECV D2D mergeAddr:" << mergeAddr + // << ", block_addr:" << block.addr << ", len" + // << block.len; + mergeAddr = static_cast(mergeAddr) + block.len; + } + idx++; + mergeLen += block.len; + } + if (recv_mem_type == HBM) { + ret = aclrtSynchronizeStream(stream); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) + << "Failed to aclrtSynchronizeStream, ret: " << ret; + return ret; + } + } else { + ret = + aclrtMemcpy((void *)receivedMemPool[0].addr, mergeLen, + mergeAddr, mergeLen, ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) + << "Failed to merge data from device to device, ret: " + << ret; + return ret; + } + // LOG(INFO) << "POOLING RECV D2H mergeAddr:" << mergeAddr + // << ", block_addr:" << receivedMemPool[0].addr + // << ", len" << mergeLen; + } + } + + ret = g_localHugeBuffer[recv_index]->freed.load( + std::memory_order_acquire); + if (ret) { + LOG(ERROR) << "Error: Object at index " << recv_index + << " has been freed!"; + return ret; + } + g_localHugeBuffer[recv_index]->freed.store(true, + std::memory_order_release); + recv_index = (recv_index + 2) % HUGE_BUFFER_NUM; + } + + if (opcode == WRITE) { + ret = send(client_socket, &ready, sizeof(int), 0); + if (ret < 0) { + LOG(ERROR) << "Failed to send ready to remote, ret: " << ret + << ", errno: " << errno + << ", error: " << strerror(errno); + return ret; + } + // LOG(INFO) << "RECV send end"; + } + + // auto duration_d1 = + // std::chrono::duration_cast(t2 - t1); auto + // duration_d2 = std::chrono::duration_cast(t3 - + // t2); auto duration_d3 = + // std::chrono::duration_cast(t4 - t3); auto + // duration_d4 = std::chrono::duration_cast(t5 - + // t4); + + // LOG(INFO) << "duration_d1: " << duration_d1.count() << "us" + // << ", duration_d2: " << duration_d2.count() << "us" + // << ", duration_d3: " << duration_d3.count() << "us" + // << ", duration_d4: " << duration_d4.count() << "us"; + + return 1; +} + +int aggTransportMemTarget(aclrtStream stream) { + int ret = 0; + struct epoll_event events[MAX_EVENTS]; + int nfds = epoll_wait(g_epoll_fd_agg, events, MAX_EVENTS, -1); + if (nfds == -1) { + if (errno == EINTR) { + return 0; + } else { + LOG(ERROR) << "epoll_wait failed: " << strerror(errno); + return -1; + } + } + + for (int i = 0; i < nfds; ++i) { + if (events[i].events & EPOLLIN) { + int fd = events[i].data.fd; + ret = recvMemInfo(fd, stream); + if (ret <= 0) { + if (ret == 0) { + LOG(ERROR) << "Peer closed the connection on fd: " << fd; + epoll_ctl(g_epoll_fd_agg, EPOLL_CTL_DEL, fd, NULL); + close(fd); + } else { + LOG(ERROR) << "Failed to recvMemInfo, ret: " << ret + << ", errno: " << errno; + if (errno != EAGAIN && errno != EWOULDBLOCK) { + epoll_ctl(g_epoll_fd_agg, EPOLL_CTL_DEL, fd, NULL); + close(fd); + } + } + } + } + } + + return 0; +} + +void aggRegLocalMem(uint64_t addr, uint64_t length, bool isAggBuffer) { + const uint64_t alignment = 1 << 21; + if (addr % alignment != 0) { + return; + } + + MemBlock memBlock; + memBlock.addr = addr; + memBlock.len = length; + + g_localBuffer.emplace_back(memBlock); + + if (isAggBuffer) { + MemBlock perHugeBuf; + perHugeBuf.addr = addr; + perHugeBuf.len = PER_HUGE_BUFFER_SIZE; + g_localMemtoSend.emplace_back(addr); + g_localHugeBuffer.emplace_back(new HugeBuffer(perHugeBuf, true)); + } + + return; +} + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_transport_mem_c.cpp b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_transport_mem_c.cpp index 05ae41db7..b6a1d18a7 100644 --- a/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_transport_mem_c.cpp +++ b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_transport_mem_c.cpp @@ -13,254 +13,87 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include -#include -#include -#include #include #include -#include -#include #include -#include -#include -#include -#include +#include +#include #include "mpi.h" #include "transport/ascend_transport/hccl_transport/hccl_transport_mem_c.h" +#include "transport/ascend_transport/hccl_transport/hccl_transport_mem_internals.h" + +#include +#include +#include +#include #ifdef __cplusplus extern "C" { #endif // __cplusplus -#define READ 0 -#define WRITE 1 -#define MAX_EVENTS 32 -#define CONNECT_MAX 1000 -#define RETRY_TIMES 3 -#define VECTOR_RESERVE_SIZE 200 +void *g_dev_addr; -HcclNetDevCtx vnicNetDevCtx_{nullptr}; -HcclNetDevCtx nicNetDevCtx_{nullptr}; -std::shared_ptr vnicServerSocket_{nullptr}; -std::shared_ptr nicServerSocket_{nullptr}; -std::unique_ptr notifyPool_; -HcclDispatcher dispatcher_{nullptr}; -std::unordered_map target_key_to_connection_map_; -std::vector g_localMergeMem; -int g_server_socket_ = 0; -int g_epoll_fd = 0; -struct epoll_event g_ev; -struct epoll_event g_events[MAX_EVENTS]; - -bool printEnabled() { +bool enableAscendLogging() { char *env = getenv("ASCEND_TRANSPORT_PRINT"); return env != nullptr && std::string(env) == "1"; } -uint16_t findAvailableTcpPort(int &sockfd, bool use_ipv6) { - static std::random_device rand_gen; - std::mt19937 gen(rand_gen()); - const int min_port = 15000; - const int max_port = 17000; - const int max_attempts = 500; - std::uniform_int_distribution<> rand_dist(min_port, max_port); - - for (int attempt = 0; attempt < max_attempts; ++attempt) { - int port = rand_dist(rand_gen); - if (use_ipv6) { - sockaddr_in6 bind_address; - memset(&bind_address, 0, sizeof(sockaddr_in6)); - bind_address.sin6_family = AF_INET6; - bind_address.sin6_port = htons(port); - bind_address.sin6_addr = IN6ADDR_ANY_INIT; - if (bind(sockfd, (sockaddr *)&bind_address, sizeof(sockaddr_in6)) < - 0) { - continue; - } - } else { - sockaddr_in bind_address; - memset(&bind_address, 0, sizeof(sockaddr_in)); - bind_address.sin_family = AF_INET; - bind_address.sin_port = htons(port); - bind_address.sin_addr.s_addr = INADDR_ANY; - if (bind(sockfd, (sockaddr *)&bind_address, sizeof(sockaddr_in)) < - 0) { - continue; - } - } - - return port; - } - return 0; -} - -static int initServerNetSocket(RankInfo *local_rank_info) { - RETRY_CALL(HcclNetInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, - local_rank_info->devicePhyId, - local_rank_info->deviceLogicId, false), - "HcclNetInit failed"); - - // Use the physical network card of the device across HCCS - hccl::HcclIpAddress localIp(local_rank_info->deviceIp); - RETRY_CALL(HcclNetOpenDev(&nicNetDevCtx_, NicType::DEVICE_NIC_TYPE, - local_rank_info->devicePhyId, - local_rank_info->deviceLogicId, localIp), - "HcclNetOpenDev DEVICE_NIC_TYPE failed"); - - nicServerSocket_ = std::make_shared( - nicNetDevCtx_, local_rank_info->devicePort); - if (nicServerSocket_ == NULL) { - LOG(ERROR) << "make nicNetDevCtx_ failed"; - return -1; - } - - RETRY_CALL(nicServerSocket_->Init(), "nicServerSocket_ Init failed"); - RETRY_CALL(nicServerSocket_->Listen(), "nicServerSocket_ Listen failed"); - - // Use virtual network card within HCCS - hccl::HcclIpAddress localVnicIp(local_rank_info->devicePhyId); - RETRY_CALL( - hrtRaGetSingleSocketVnicIpInfo( - local_rank_info->devicePhyId, DeviceIdType::DEVICE_ID_TYPE_PHY_ID, - local_rank_info->devicePhyId, localVnicIp), - "hrtRaGetSingleSocketVnicIpInfo failed"); - - RETRY_CALL(HcclNetOpenDev(&vnicNetDevCtx_, NicType::VNIC_TYPE, - local_rank_info->devicePhyId, - local_rank_info->deviceLogicId, localVnicIp), - "HcclNetOpenDev vnicNetDevCtx_ failed"); - - // control plane connection, creat serversocket, listening client - vnicServerSocket_ = std::make_shared( - vnicNetDevCtx_, local_rank_info->devicePort); - if (vnicServerSocket_ == NULL) { - LOG(ERROR) << "vnicServerSocket_ make failed"; - return -1; - } - - RETRY_CALL(vnicServerSocket_->Init(), "vnicServerSocket_ Init failed"); - RETRY_CALL(vnicServerSocket_->Listen(), "vnicServerSocket_ Listen failed"); - - RETRY_CALL(HcclDispatcherInit(DispatcherType::DISPATCHER_NORMAL, - local_rank_info->devicePhyId, &dispatcher_), - "client HcclDispatcherInit failed"); - - notifyPool_.reset(new (std::nothrow) hccl::NotifyPool()); - if (notifyPool_ == nullptr) { - LOG(ERROR) << "reset notifyPool error"; - return -1; - } - - RETRY_CALL(notifyPool_->Init(local_rank_info->devicePhyId), - "Init notifyPool error"); - - return 0; -} - -// The out-of-band socket on the host side that ascend_transport depends on, -// used to convey control information such as deviceId and deviceIp -static int initControlSocket(RankInfo *local_rank_info) { +int initTransportMem(RankInfo *local_rank_info, bool aggregateEnabled) { int ret = 0; - g_server_socket_ = socket(AF_INET, SOCK_STREAM, 0); - if (g_server_socket_ < 0) { - LOG(ERROR) << "ascend transport out-of-band socket create failed"; - return g_server_socket_; + if (local_rank_info == NULL) { + LOG(ERROR) << "initTransportMem local_rank_info is NULL"; + return -1; } - int optval = 1; - ret = setsockopt(g_server_socket_, SOL_SOCKET, SO_REUSEADDR, &optval, - sizeof(optval)); - if (ret < 0) { - LOG(ERROR) << "set sock opt failed, ret: " << ret; - close(g_server_socket_); + ret = rtGetDevicePhyIdByIndex(local_rank_info->deviceLogicId, + &local_rank_info->devicePhyId); + if (ret) { + LOG(ERROR) << "HcclTransport: rtGetDevicePhyIdByIndex failed, ret: " + << ret; return ret; } - struct sockaddr_in bind_address; - memset(&bind_address, 0, sizeof(sockaddr_in)); - bind_address.sin_family = AF_INET; - bind_address.sin_addr.s_addr = INADDR_ANY; - bind_address.sin_port = htons(local_rank_info->hostPort); - - ret = bind(g_server_socket_, (struct sockaddr *)&bind_address, - sizeof(bind_address)); - if (ret < 0) { - LOG(INFO) << "bind failed on the default port, default port: " - << local_rank_info->hostPort << ", will find available port"; - uint16_t port = findAvailableTcpPort(g_server_socket_, false); - if (port == 0) { - LOG(ERROR) << "findAvailableTcpPort failed"; - close(g_server_socket_); - return -1; - } - local_rank_info->hostPort = (uint64_t)port; - } + local_rank_info->hostPort = + ASCEND_DEFAULT_HOST_PORT + local_rank_info->devicePhyId; - struct timeval timeout; - timeout.tv_sec = 120; - timeout.tv_usec = 0; - ret = setsockopt(g_server_socket_, SOL_SOCKET, SO_RCVTIMEO, - (const char *)&timeout, sizeof(timeout)); - if (ret < 0) { - LOG(ERROR) << "Set recv timeout failed, ret: " << ret; - close(g_server_socket_); - return ret; - } + getDevIpAddresses(local_rank_info); - ret = listen(g_server_socket_, CONNECT_MAX); - if (ret < 0) { - LOG(ERROR) << "Listen Failed, ret: " << ret; - close(g_server_socket_); - return ret; - } - LOG(INFO) << "initControlSocket successful, listen on hostPort: " - << local_rank_info->hostPort << "..." << " g_server_socket_" - << g_server_socket_; - g_epoll_fd = epoll_create1(0); - if (g_epoll_fd == -1) { - LOG(ERROR) << "epoll create Failed, ret: " << g_epoll_fd; - close(g_server_socket_); - return g_epoll_fd; - } - g_ev.events = EPOLLIN; - g_ev.data.fd = g_server_socket_; - ret = epoll_ctl(g_epoll_fd, EPOLL_CTL_ADD, g_server_socket_, &g_ev); - if (ret < 0) { - LOG(ERROR) << "epoll epoll_ctl Failed, ret: " << ret; - close(g_server_socket_); - return ret; - } - return 0; -} - -int initTransportMem(RankInfo *local_rank_info) { - int ret = 0; - if (local_rank_info == NULL) { - LOG(ERROR) << "initTransportMem local_rank_info is NULL"; - return -1; + if (a3Enabled()) { + const int infoTypeSdid = 26; + ret = rtGetDeviceInfo(local_rank_info->deviceLogicId, + RT_MODULE_TYPE_SYSTEM, infoTypeSdid, + &local_rank_info->sdid); + if (ret) { + LOG(ERROR) << "rtGetDeviceInfo failed, ret: " << ret; + return ret; + } + ret = rtGetServerIDBySDID(local_rank_info->sdid, + &local_rank_info->serverId); + if (ret) { + LOG(ERROR) << "rtGetServerIDBySDID failed, ret: " << ret; + return ret; + } } - uint32_t devPid; - ret = SalGetBareTgid(reinterpret_cast(&devPid)); + ret = SalGetBareTgid(&local_rank_info->devPid); if (ret) { LOG(ERROR) << "SalGetBareTgid failed: " << ret; return ret; } - local_rank_info->pid = (uint64_t)devPid; - LOG(INFO) << "initTransportMem local_rank_info rankId: " << local_rank_info->rankId - << ", serverIdx: " << local_rank_info->serverIdx + << ", serverId: " << local_rank_info->serverId << ", deviceLogicId: " << local_rank_info->deviceLogicId << ", devicePhyId: " << local_rank_info->devicePhyId - << ", deviceIp: " << inet_ntoa(local_rank_info->deviceIp) + << ", deviceIp: " << local_rank_info->deviceIp << ", devicePort: " << local_rank_info->devicePort - << ", hostIp: " << inet_ntoa(local_rank_info->hostIp) + << ", hostIp: " << local_rank_info->hostIp << ", hostPort: " << local_rank_info->hostPort - << ", device pid: " << local_rank_info->pid; + << ", device Pid: " << local_rank_info->devPid + << ", vnicIp: " << local_rank_info->vnicIp + << ", sdid: " << local_rank_info->sdid; // Initialize the virtual network card and socket for the data channel, // exchange RmaMem, and create the QP connection @@ -270,453 +103,158 @@ int initTransportMem(RankInfo *local_rank_info) { return ret; } - ret = initControlSocket(local_rank_info); + ret = initControlSocket(local_rank_info, aggregateEnabled); if (ret) { LOG(ERROR) << "initControlSocket failed, ret: " << ret; return ret; } - g_localMergeMem.reserve(VECTOR_RESERVE_SIZE); + g_localBuffer.reserve(VECTOR_RESERVE_SIZE); return 0; } -static int connectToTarget(std::string target_ip, int target_port) { - int client_socket; - struct sockaddr_in server_addr; - - client_socket = socket(AF_INET, SOCK_STREAM, 0); - if (client_socket < 0) { - LOG(ERROR) << "Socket creation failed"; - return client_socket; - } - - int optval = 1; - int ret = setsockopt(client_socket, SOL_SOCKET, SO_REUSEADDR, &optval, - sizeof(optval)); - if (ret < 0) { - LOG(ERROR) << "set sock opt failed, ret: " << ret; - close(client_socket); - return ret; - } - - memset(&server_addr, 0, sizeof(server_addr)); - server_addr.sin_family = AF_INET; - server_addr.sin_port = htons(target_port); - server_addr.sin_addr.s_addr = inet_addr(target_ip.c_str()); - - if (server_addr.sin_addr.s_addr == INADDR_NONE) { - LOG(ERROR) << "Invalid server IP address"; - close(client_socket); - return -1; - } - - int connected = 0; - - const char *tcp_timeout_str = std::getenv("Ascend_TCP_TIMEOUT"); - int ascend_tcp_timeout = tcp_timeout_str ? std::atoi(tcp_timeout_str) : 30; - int connect_retry_times = ascend_tcp_timeout * 100; - - for (int i = 0; i < connect_retry_times; ++i) { - if (connect(client_socket, (struct sockaddr *)&server_addr, - sizeof(server_addr)) == 0) { - LOG(INFO) << "Connect to host server " << target_ip << ":" - << ntohs(server_addr.sin_port) << " successful"; - connected = 1; - break; - } - - LOG(INFO) << "Connect attempt " << i << " failed: " << strerror(errno) - << ", retry once"; - - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } - - if (!connected) { - LOG(ERROR) << "Failed to connect to server after " - << connect_retry_times << " retries"; - close(client_socket); - return HCCL_E_TIMEOUT; - } - - return client_socket; -} - -int controlInfoSend(RankInfo *local_rank_info, RankInfo *remote_rank_info) { - int ret = 0; - std::string key_str = inet_ntoa(remote_rank_info->hostIp) + - std::to_string(remote_rank_info->devicePhyId); - LOG(INFO) << "transportMemTask local_rank_info rankId: " - << local_rank_info->rankId - << ", serverIdx: " << local_rank_info->serverIdx - << ", deviceLogicId: " << local_rank_info->deviceLogicId - << ", devicePhyId: " << local_rank_info->devicePhyId - << ", deviceIp: " << inet_ntoa(local_rank_info->deviceIp) - << ", devicePort: " << local_rank_info->devicePort - << ", hostIp: " << inet_ntoa(local_rank_info->hostIp) - << ", hostPort: " << local_rank_info->hostPort - << ", device pid: " << local_rank_info->pid; - - LOG(INFO) << "transportMemTask remote_rank_info rankId: " - << remote_rank_info->rankId - << ", serverIdx: " << remote_rank_info->serverIdx - << ", deviceLogicId: " << remote_rank_info->deviceLogicId - << ", devicePhyId: " << remote_rank_info->devicePhyId - << ", deviceIp: " << inet_ntoa(remote_rank_info->deviceIp) - << ", devicePort: " << remote_rank_info->devicePort - << ", hostIp: " << inet_ntoa(remote_rank_info->hostIp) - << ", hostPort: " << remote_rank_info->hostPort - << ", device pid: " << remote_rank_info->pid; - - // Encapsulate control information - RankControlInfo control_info; - control_info.deviceLogicId = local_rank_info->deviceLogicId; - control_info.devicePhyId = local_rank_info->devicePhyId; - control_info.hostIp = local_rank_info->hostIp; - control_info.deviceIp = local_rank_info->deviceIp; - control_info.pid = local_rank_info->pid; - // Self-built out-of-band, host socket for sending control plane - int client_socket = connectToTarget(inet_ntoa(remote_rank_info->hostIp), - remote_rank_info->hostPort); - if (client_socket < 0) { - LOG(ERROR) << "client connect failed"; - return client_socket; - } - ret = send(client_socket, &control_info, sizeof(RankControlInfo), 0); - if (ret < 0) { - LOG(ERROR) << "send control_info failed, ret: " << ret; - close(client_socket); - return ret; - } - target_key_to_connection_map_[key_str].tcp_socket = client_socket; - return 0; -} - -int createClientSocket(std::shared_ptr &hccl_socket, - RankInfo *local_rank_info, RankInfo *remote_rank_info, - bool is_cross_hccs, std::string tag) { - int ret = 0; - hccl::HcclIpAddress rempoteDevIp; - std::string key_str = inet_ntoa(remote_rank_info->hostIp) + +int transportMemAddOpFence(RankInfo *remote_rank_info, aclrtStream stream) { + std::string key_str = std::string(remote_rank_info->hostIp) + '-' + std::to_string(remote_rank_info->devicePhyId); - std::string baseTag_ = inet_ntoa(local_rank_info->hostIp) + - std::to_string(local_rank_info->devicePhyId) + - key_str + tag; - if (!is_cross_hccs) { - std::vector remoteDevPhyId; - remoteDevPhyId.push_back(remote_rank_info->devicePhyId); - ret = hccl::P2PMgmtPub::EnableP2P(remoteDevPhyId); - if (ret) { - LOG(ERROR) << "P2PMgmtPub EnableP2P failed, ret: " << ret; - return ret; - } - ret = hccl::P2PMgmtPub::WaitP2PEnabled(remoteDevPhyId); - if (ret) { - LOG(ERROR) << "P2PMgmtPub WaitP2PEnabled failed, ret: " << ret; - return ret; - } - rempoteDevIp = hccl::HcclIpAddress(remote_rank_info->devicePhyId); - ret = hrtRaGetSingleSocketVnicIpInfo( - local_rank_info->devicePhyId, DeviceIdType::DEVICE_ID_TYPE_PHY_ID, - remote_rank_info->devicePhyId, rempoteDevIp); - if (ret) { - LOG(ERROR) << "hrtRaGetSingleSocketVnicIpInfo, ret: " << ret; - return ret; - } - hccl_socket = std::make_shared( - baseTag_, vnicNetDevCtx_, rempoteDevIp, - remote_rank_info->devicePort, - hccl::HcclSocketRole::SOCKET_ROLE_CLIENT); - } else { - rempoteDevIp = hccl::HcclIpAddress(remote_rank_info->deviceIp); - hccl_socket = std::make_shared( - baseTag_, nicNetDevCtx_, rempoteDevIp, remote_rank_info->devicePort, - hccl::HcclSocketRole::SOCKET_ROLE_CLIENT); - } - - ret = hccl_socket->Init(); - if (ret) { - char deviceIp[64]; - inet_ntop(AF_INET, &rempoteDevIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "client hccl_socket init failed, target devicePhyId: " - << remote_rank_info->devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", remote port: " << remote_rank_info->devicePort - << ", ret: " << ret; - return ret; - } - ret = hccl_socket->Connect(); + int ret = g_target_key_to_connection_map[key_str].transport_mem->AddOpFence( + stream); if (ret) { - char deviceIp[64]; - inet_ntop(AF_INET, &rempoteDevIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "client hccl_socket Connect failed, target devicePhyId: " - << remote_rank_info->devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", remote port: " << remote_rank_info->devicePort - << ", ret: " << ret; + LOG(ERROR) << "transport_mem AddOpFence failed, ret: " << ret; return ret; } - LOG(INFO) << "hccl_socket begin to connect, local devicePhyId: " - << local_rank_info->devicePhyId - << ", target devicePhyId: " << remote_rank_info->devicePhyId; - - hccl::HcclSocketStatus status; - struct timespec start, end; - const char *hccl_socket_timeout_str = - std::getenv("Ascend_HCCL_SOCKET_TIMEOUT"); - int hccl_socket_timeout = - hccl_socket_timeout_str ? std::atoi(hccl_socket_timeout_str) : 30; - long long hccl_socket_timeout_ns = - static_cast(hccl_socket_timeout) * 1000000000LL; - clock_gettime(CLOCK_MONOTONIC, &start); - do { - status = hccl_socket->GetStatus(); - clock_gettime(CLOCK_MONOTONIC, &end); - long long elapsed_time = (end.tv_sec - start.tv_sec) * 1000000000LL + - (end.tv_nsec - start.tv_nsec); - if (elapsed_time > - hccl_socket_timeout_ns) { // Exceeds 20 seconds,TimeOut - LOG(ERROR) << "hccl_socket connect timeout, local devicePhyId: " - << local_rank_info->devicePhyId - << ", target devicePhyId: " - << remote_rank_info->devicePhyId; - return HCCL_E_TIMEOUT; - } - } while (status != hccl::HcclSocketStatus::SOCKET_OK); - - LOG(INFO) << "hccl_socket connect success, local devicePhyId: " - << local_rank_info->devicePhyId - << ", target devicePhyId: " << remote_rank_info->devicePhyId; return 0; } -int createTransportMem(RankInfo *local_rank_info, RankInfo *remote_rank_info, - std::shared_ptr &transport_mem) { +// // 打印单个BFloat16值的详细信息 +// void print_bfloat16_detail(uint16_t bf16_value) { +// uint16_t sign = (bf16_value >> 15) & 0x1; +// uint16_t exponent = (bf16_value >> 7) & 0xFF; +// uint16_t mantissa = bf16_value & 0x7F; + +// LOG(INFO) << "0x" << std::hex << std::setw(4) << std::setfill('0') << +// bf16_value +// << " (二进制: " << std::bitset<16>(bf16_value) << ")" << " [符号: " << +// sign << ", 指数: " +// << std::bitset<8>(exponent) << "(" << std::dec << exponent << ")" +// << ", 尾数: " << std::bitset<7>(mantissa) << "(" << mantissa << ")]"; +// } + +// // 将BFloat16转换为近似的float值(简单近似,不处理特殊情况) +// float bfloat16_to_float_approx(uint16_t bf16_value) { +// // 直接将BFloat16的位模式扩展到float(32位) +// uint32_t float_bits = static_cast(bf16_value) << 16; +// float result; +// std::memcpy(&result, &float_bits, sizeof(float)); +// return result; +// } + +// void print_bfloat16_memory(const void* data, size_t len) { +// // 确保长度是2的倍数(BFloat16是2字节) +// size_t num_elements = len / sizeof(uint16_t); +// if (len % sizeof(uint16_t) != 0) { +// std::cout << "警告: 长度不是2的倍数,可能会截断数据" << std::endl; +// } + +// const uint16_t* bf16_data = static_cast(data); + +// LOG(INFO) << "内存地址: " << data; +// LOG(INFO) << "数据长度: " << len << " 字节 (" << num_elements << " +// 个BFloat16值)"; LOG(INFO) << +// "=========================================="; + +// for (size_t i = 0; i < num_elements; ++i) { +// uint16_t value = bf16_data[i]; + +// std::cout << "[" << i << "] "; +// print_bfloat16_detail(value); + +// // 显示近似的float值 +// float approx_float = bfloat16_to_float_approx(value); +// std::cout << " ≈ " << approx_float << "f"; + +// std::cout << std::endl; +// } + +// // 如果有剩余字节,打印它们 +// if (len % sizeof(uint16_t) != 0) { +// std::cout << "剩余字节: "; +// const uint8_t* byte_data = static_cast(data); +// for (size_t i = num_elements * sizeof(uint16_t); i < len; ++i) { +// std::cout << "0x" << std::hex << std::setw(2) << +// std::setfill('0') +// << static_cast(byte_data[i]) << " "; +// } +// std::cout << std::dec << std::endl; +// } +// } + +int nonAggTransportMemTask(RankInfo *local_rank_info, + RankInfo *remote_rank_info, int op_code, + uint64_t offset, uint64_t req_len, void *local_mem, + int mem_type, aclrtStream stream) { + // Check if a connection has been established with the peer, and send local + // information to the peer int ret = 0; - bool same_host = - local_rank_info->hostIp.s_addr == remote_rank_info->hostIp.s_addr; - // For A2 series, internal communication among 8 cards does not cross HCCS, - // such as communication among cards 0-7 - bool same_group = (local_rank_info->devicePhyId / 8) == - (remote_rank_info->devicePhyId / 8); - bool is_cross_hccs = !(same_host && same_group); - std::string key_str = inet_ntoa(remote_rank_info->hostIp) + + std::string key_str = std::string(remote_rank_info->hostIp) + '-' + std::to_string(remote_rank_info->devicePhyId); - if (printEnabled()) { - LOG(INFO) << "hccl transport is cross_hccs: " - << (is_cross_hccs ? "true (cross-hccs)" - : "false (same-hccs)"); - } + uint64_t local_buffer = 0; std::shared_ptr hccl_ctrl_socket; std::shared_ptr hccl_data_socket; - ret = createClientSocket(hccl_ctrl_socket, local_rank_info, - remote_rank_info, is_cross_hccs, "ctrl"); - if (ret) { - LOG(ERROR) << "createClientSocket hccl_ctrl_socket failed, ret: " - << ret; - return ret; - } - target_key_to_connection_map_[key_str].hccl_ctrl_socket = hccl_ctrl_socket; - ret = createClientSocket(hccl_data_socket, local_rank_info, - remote_rank_info, is_cross_hccs, "data"); - if (ret) { - LOG(ERROR) << "createClientSocket hccl_data_socket failed, ret: " - << ret; - return ret; - } - target_key_to_connection_map_[key_str].hccl_data_socket = hccl_data_socket; - hccl::TransportMem::AttrInfo attrInfo; - attrInfo.localRankId = local_rank_info->deviceLogicId; - attrInfo.remoteRankId = remote_rank_info->deviceLogicId; - attrInfo.sdid = 0xFFFFFFFF; - attrInfo.serverId = local_rank_info->serverIdx; - attrInfo.trafficClass = 132; - attrInfo.serviceLevel = 4; - if (is_cross_hccs) { - transport_mem = hccl::TransportMem::Create( - hccl::TransportMem::TpType::ROCE, notifyPool_, nicNetDevCtx_, - dispatcher_, attrInfo); - } else { - transport_mem = hccl::TransportMem::Create( - hccl::TransportMem::TpType::IPC, notifyPool_, vnicNetDevCtx_, - dispatcher_, attrInfo); - } - ret = transport_mem->SetDataSocket(hccl_data_socket); - if (ret) { - char deviceIp[64]; - inet_ntop(AF_INET, &remote_rank_info->deviceIp, deviceIp, - sizeof(deviceIp)); - LOG(ERROR) << "transport_mem SetDataSocket failed, target devicePhyId: " - << remote_rank_info->devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", remote port: " << remote_rank_info->devicePort - << ", ret: " << ret; - return ret; - } - ret = transport_mem->SetSocket(hccl_ctrl_socket); - if (ret) { - char deviceIp[64]; - inet_ntop(AF_INET, &remote_rank_info->deviceIp, deviceIp, - sizeof(deviceIp)); - LOG(ERROR) << "transport_mem SetSocket failed, target devicePhyId: " - << remote_rank_info->devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", remote port: " << remote_rank_info->devicePort - << ", ret: " << ret; - return ret; - } - const char *transport_mem_timeout_str = - std::getenv("Ascend_TRANSPORT_MEM_TIMEOUT"); - int transport_mem_timeout = - transport_mem_timeout_str ? std::atoi(transport_mem_timeout_str) : 120; - ret = transport_mem->Connect(transport_mem_timeout); - if (ret) { - char deviceIp[64]; - inet_ntop(AF_INET, &remote_rank_info->deviceIp, deviceIp, - sizeof(deviceIp)); - LOG(ERROR) << "transport_mem Connect failed, target devicePhyId: " - << remote_rank_info->devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", remote port: " << remote_rank_info->devicePort - << ", ret: " << ret; - return ret; - } - LOG(INFO) << "transport_mem connect success"; - target_key_to_connection_map_[key_str].transport_mem = transport_mem; - size_t m_num = g_localMergeMem.size(); - std::vector rmaMemDescs(m_num); - - for (size_t i = 0; i < m_num; ++i) { - HcclMem mem; - HcclBuf buf; - mem.addr = g_localMergeMem[i].addr; - mem.size = g_localMergeMem[i].len; - mem.type = HCCL_MEM_TYPE_DEVICE; - if (!is_cross_hccs) { - ret = HcclMemReg(vnicNetDevCtx_, &mem, &buf); - } else { - ret = HcclMemReg(nicNetDevCtx_, &mem, &buf); - } - if (ret != 0 && ret != 20) { - LOG(ERROR) << "HcclMemReg failed, ret: " << ret - << " addr: " << g_localMergeMem[i].addr - << " len: " << g_localMergeMem[i].len; - return ret; - } - char *desc = nullptr; - uint64_t desc_len = 0; - ret = HcclMemExport(&buf, &desc, &desc_len); + std::shared_ptr transport_mem{}; + auto iter = g_target_key_to_connection_map.find(key_str); + if (iter == g_target_key_to_connection_map.end() || + g_target_key_to_connection_map[key_str].tcp_socket == 0) { + ret = controlInfoSend(local_rank_info, remote_rank_info); if (ret) { - LOG(ERROR) << "HcclMemExport failed, ret: " << ret - << ", addr: " << g_localMergeMem[i].addr - << ", len: " << g_localMergeMem[i].len; + LOG(ERROR) << "controlInfoSend failed, ret: " << ret; return ret; } - - rmaMemDescs[i].localRankId = local_rank_info->deviceLogicId; - rmaMemDescs[i].remoteRankId = remote_rank_info->deviceLogicId; - memset_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, 0, - hccl::TRANSPORT_EMD_ESC_SIZE); - if (memcpy_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, desc, - desc_len + 1) != EOK) { - LOG(ERROR) << "memcpy_s failed, ret: " << ret - << ", addr: " << g_localMergeMem[i].addr - << ", len: " << g_localMergeMem[i].len; - return -1; + bool is_cross_hccs = false; + if (a3Enabled()) { + is_cross_hccs = false; + } else { + bool same_host = (strcmp(local_rank_info->hostIp, + remote_rank_info->hostIp) == 0); + // For A2 series, internal communication among 8 cards does not + // cross HCCS, such as communication among cards 0-7 + bool same_group = (local_rank_info->devicePhyId / 8) == + (remote_rank_info->devicePhyId / 8); + is_cross_hccs = !(same_host && same_group); } - - // In the scenario within HCCS, it is necessary to call HcclMemGrant to - // authorize peer memory - if (!is_cross_hccs) { - HcclMemGrantInfo grant_info; - grant_info.remotePid = (int32_t)remote_rank_info->pid; - grant_info.remoteSdid = 0xFFFFFFFF; - ret = HcclMemGrant(&buf, &grant_info); - if (ret) { - LOG(ERROR) << "HcclMemGrant failed, ret: " << ret - << ", addr: " << g_localMergeMem[i].addr - << ", len: " << g_localMergeMem[i].len; - return ret; - } + if (enableAscendLogging()) { + LOG(INFO) << "hccl transport is cross_hccs: " + << (is_cross_hccs ? "true (cross-hccs)" + : "false (same-hccs)"); } - } - hccl::TransportMem::RmaMemDescs localRmaMemDescs; - localRmaMemDescs.array = rmaMemDescs.data(); - localRmaMemDescs.arrayLength = rmaMemDescs.size(); - uint32_t actualNumOfRemote = 0; - std::vector remoteRmaMemDescArray(m_num); - hccl::TransportMem::RmaMemDescs remoteRmaMemDescs; - remoteRmaMemDescs.array = remoteRmaMemDescArray.data(); - remoteRmaMemDescs.arrayLength = m_num; - ret = transport_mem->ExchangeMemDesc(localRmaMemDescs, remoteRmaMemDescs, - actualNumOfRemote); - if (ret) { - LOG(ERROR) << "transport_mem->ExchangeMemDesc failed, ret: " << ret - << ", local_rank: " << local_rank_info->devicePhyId - << ", remote_rank: " << remote_rank_info->devicePhyId; - return ret; - } - std::vector remoteRmaMemArray(m_num); - for (uint32_t i = 0; i < m_num; ++i) { - ret = transport_mem->EnableMemAccess(remoteRmaMemDescArray[i], - remoteRmaMemArray[i]); + ret = createClientSocket(hccl_ctrl_socket, local_rank_info, + remote_rank_info, is_cross_hccs, "ctrl"); if (ret) { - LOG(ERROR) << "transport_mem->EnableMemAccess failed, ret: " << ret - << ", i: " << i - << ", local_rank: " << local_rank_info->devicePhyId - << ", remote_rank: " << remote_rank_info->devicePhyId; + LOG(ERROR) << "createClientSocket hccl_ctrl_socket failed, ret: " + << ret; return ret; } - } - LOG(INFO) << "ExchangeMem and EnableMemAccess Success, local devicePhyId: " - << local_rank_info->devicePhyId - << ", target devicePhyId: " << remote_rank_info->devicePhyId; - return 0; -} - -int transportMemAddOpFence(RankInfo *remote_rank_info, aclrtStream stream) { - std::string key_str = inet_ntoa(remote_rank_info->hostIp) + - std::to_string(remote_rank_info->devicePhyId); - int ret = target_key_to_connection_map_[key_str].transport_mem->AddOpFence( - stream); - if (ret) { - LOG(ERROR) << "transport_mem AddOpFence failed, ret: " << ret; - return ret; - } - - return 0; -} - -int transportMemTask(RankInfo *local_rank_info, RankInfo *remote_rank_info, - int op_code, uint64_t offset, uint64_t req_len, - void *local_mem, aclrtStream stream) { - int ret = 0; - std::shared_ptr transport_mem{}; - // Check if a connection has been established with the peer, and send local - // information to the peer - std::string key_str = inet_ntoa(remote_rank_info->hostIp) + - std::to_string(remote_rank_info->devicePhyId); - auto iter = target_key_to_connection_map_.find(key_str); - if (iter == target_key_to_connection_map_.end()) { - ret = controlInfoSend(local_rank_info, remote_rank_info); + g_target_key_to_connection_map[key_str].hccl_ctrl_socket = + hccl_ctrl_socket; + ret = createClientSocket(hccl_data_socket, local_rank_info, + remote_rank_info, is_cross_hccs, "data"); if (ret) { - LOG(ERROR) << "controlInfoSend failed, ret: " << ret; + LOG(ERROR) << "createClientSocket hccl_data_socket failed, ret: " + << ret; return ret; } - ret = createTransportMem(local_rank_info, remote_rank_info, - transport_mem); + g_target_key_to_connection_map[key_str].hccl_data_socket = + hccl_data_socket; + + ret = createTransportMem(local_rank_info, remote_rank_info, key_str, + is_cross_hccs, transport_mem, false); if (ret) { LOG(ERROR) << "createTransportMem failed, ret: " << ret; return ret; } } else { - transport_mem = target_key_to_connection_map_[key_str].transport_mem; + transport_mem = g_target_key_to_connection_map[key_str].transport_mem; } hccl::TransportMem::RmaOpMem localMem; localMem.addr = local_mem; @@ -724,6 +262,7 @@ int transportMemTask(RankInfo *local_rank_info, RankInfo *remote_rank_info, hccl::TransportMem::RmaOpMem remoteMem; remoteMem.addr = (void *)offset; remoteMem.size = req_len; + if (op_code == WRITE) { ret = transport_mem->Write(remoteMem, localMem, stream); if (ret) { @@ -747,122 +286,90 @@ int transportMemTask(RankInfo *local_rank_info, RankInfo *remote_rank_info, return 0; } -static int acceptFromTarget() { - int client_socket; - struct sockaddr_in client_addr; - socklen_t client_len = sizeof(client_addr); - client_socket = - accept(g_server_socket_, (struct sockaddr *)&client_addr, &client_len); - if (client_socket < 0) { - LOG(ERROR) << "Accept failed"; - return client_socket; - } - - LOG(INFO) << "host client connected from " - << inet_ntoa(client_addr.sin_addr) << ":" - << ntohs(client_addr.sin_port); - return client_socket; -} - -int acceptSocket(std::shared_ptr &hccl_socket, - RankInfo *local_rank_info, RankControlInfo remote_control_info, - std::string baseTag_, hccl::HcclIpAddress rempoteDevIp, - bool is_cross_hccs) { - int ret = 0; - std::vector wlistInfoVec; - SocketWlistInfo wlistInfo = {}; - wlistInfo.connLimit = 1; - memcpy(&wlistInfo.tag[0], baseTag_.c_str(), baseTag_.size() + 1); - wlistInfo.remoteIp.addr = rempoteDevIp.GetBinaryAddress().addr; - wlistInfo.remoteIp.addr6 = rempoteDevIp.GetBinaryAddress().addr6; - wlistInfoVec.push_back(wlistInfo); - auto serverSocket = is_cross_hccs ? nicServerSocket_ : vnicServerSocket_; - ret = serverSocket->AddWhiteList(wlistInfoVec); - if (ret) { - LOG(ERROR) << "serverSocket AddWhiteList failed, ret: " << ret; +int transportMemAccept(RankInfo *local_rank_info, bool aggregateEnabled) { + // Self-built out-of-band, host socket for receiving control plane + int ret = socketEpollWait(); + if (ret == 0) { + return 0; + } else if (ret < 0) { + LOG(ERROR) << "socketEpollWait failed, ret: " << ret; return ret; } - // Before using the device-side network card for communication, it is - // necessary to add the client device address to the whitelist. - LOG(INFO) << "Add the client's Device IP address to the whitelist success."; - ret = serverSocket->Accept(baseTag_, hccl_socket); - if (ret) { - LOG(ERROR) << "serverSocket transportMemAccept ctrl socket failed ret: " - << ret; - return ret; + int recv_socket = acceptFromTarget(); + if (recv_socket < 0) { + return recv_socket; } - return 0; -} -int transportMemAccept(RankInfo *local_rank_info) { - // Self-built out-of-band, host socket for receiving control plane - int ret = 0; - int nfds = epoll_wait(g_epoll_fd, g_events, MAX_EVENTS, -1); - if (nfds == -1) { - return 0; - } - int client_socket = acceptFromTarget(); - if (client_socket < 0) { - return client_socket; - } + // int client_socket = acceptFromTarget(); + // if (client_socket < 0) { + // LOG(ERROR) << "acceptFromTarget failed, client_socket: " + // << client_socket; + // return client_socket; + // } RankControlInfo remote_control_info; - ret = recv(client_socket, &remote_control_info, sizeof(RankControlInfo), 0); + ret = recv(recv_socket, &remote_control_info, sizeof(RankControlInfo), 0); if (ret <= 0) { - if (ret < 0) { - LOG(ERROR) << "recv failed, ret: " << ret; - } else { - LOG(ERROR) << "Peer close the connection, ret: " << ret; - } - close(client_socket); + LOG(ERROR) << "Failed to receive remote_control_info, ret: " << ret + << ", errno: " << errno << ", error: " << strerror(errno); return -1; } LOG(INFO) << "Received remote_control_info, deviceLogicId: " << remote_control_info.deviceLogicId << ", devicePhyId: " << remote_control_info.devicePhyId - << ", hostIp: " << inet_ntoa(remote_control_info.hostIp) - << ", deviceIp: " << inet_ntoa(remote_control_info.deviceIp) - << ", device pid: " << remote_control_info.pid; + << ", hostIp: " << std::string(remote_control_info.hostIp) + << ", deviceIp: " << std::string(remote_control_info.deviceIp) + << ", device pid: " << remote_control_info.devPid; // Check if TransportMem has been established with the peer - std::string key_str = inet_ntoa(remote_control_info.hostIp) + + std::string key_str = std::string(remote_control_info.hostIp) + '-' + std::to_string(remote_control_info.devicePhyId); - auto iter = target_key_to_connection_map_.find(key_str); - if (iter != target_key_to_connection_map_.end()) { + auto iter = g_target_key_to_connection_map.find(key_str); + if (iter != g_target_key_to_connection_map.end()) { LOG(WARNING) << "A duplicate connection request from the same remote endpoint " "has been detected, the remote side may have restarted."; } - std::string baseTag_ = key_str + inet_ntoa(local_rank_info->hostIp) + - std::to_string(local_rank_info->devicePhyId); - hccl::HcclIpAddress rempoteDevIp; + std::string baseTag_ = key_str + std::string(local_rank_info->hostIp) + + '-' + std::to_string(local_rank_info->devicePhyId); + hccl::HcclIpAddress remoteDevIp; std::shared_ptr hccl_ctrl_socket; std::shared_ptr hccl_data_socket; - bool same_host = - local_rank_info->hostIp.s_addr == remote_control_info.hostIp.s_addr; - // For A2 series, internal communication among 8 cards does not cross HCCS, - // such as communication among cards 0-7 - bool same_group = (local_rank_info->devicePhyId / 8) == - (remote_control_info.devicePhyId / 8); - bool is_cross_hccs = !(same_host && same_group); - if (printEnabled()) { + bool is_cross_hccs = false; + if (a3Enabled()) { + is_cross_hccs = false; + } else { + bool same_host = + (strcmp(local_rank_info->hostIp, remote_control_info.hostIp) == 0); + // For A2 series, internal communication among 8 cards does not cross + // HCCS, such as communication among cards 0-7 + bool same_group = (local_rank_info->devicePhyId / 8) == + (remote_control_info.devicePhyId / 8); + is_cross_hccs = !(same_host && same_group); + } + if (enableAscendLogging()) { LOG(INFO) << "transport is cross_hccs: " << (is_cross_hccs ? "true (cross-hccs)" : "false (same-hccs)"); } if (!is_cross_hccs) { std::vector remoteDevPhyId; - rempoteDevIp = hccl::HcclIpAddress(remote_control_info.devicePhyId); - remoteDevPhyId.push_back(remote_control_info.devicePhyId); - ret = hrtRaGetSingleSocketVnicIpInfo( - local_rank_info->devicePhyId, DeviceIdType::DEVICE_ID_TYPE_PHY_ID, - remote_control_info.devicePhyId, rempoteDevIp); - if (ret) { - LOG(ERROR) << "hrtRaGetSingleSocketVnicIpInfo failed, ret: " << ret; - return ret; + if (a3Enabled()) { + remoteDevIp = hccl::HcclIpAddress(remote_control_info.vnicIp); + } else { + ret = hrtRaGetSingleSocketVnicIpInfo( + local_rank_info->devicePhyId, + DeviceIdType::DEVICE_ID_TYPE_PHY_ID, + remote_control_info.devicePhyId, remoteDevIp); + if (ret) { + LOG(ERROR) << "hrtRaGetSingleSocketVnicIpInfo failed, ret: " + << ret; + return ret; + } } + remoteDevPhyId.push_back(remote_control_info.devicePhyId); ret = hccl::P2PMgmtPub::EnableP2P(remoteDevPhyId); if (ret) { LOG(ERROR) << "P2PMgmtPub EnableP2P failed, ret: " << ret; @@ -873,194 +380,149 @@ int transportMemAccept(RankInfo *local_rank_info) { LOG(ERROR) << "P2PMgmtPub EnableP2P failed, ret: " << ret; return ret; } - ret = - acceptSocket(hccl_ctrl_socket, local_rank_info, remote_control_info, - baseTag_ + "ctrl", rempoteDevIp, is_cross_hccs); + ret = acceptHcclSocket(hccl_ctrl_socket, baseTag_ + "ctrl", remoteDevIp, + is_cross_hccs); if (ret) { - LOG(ERROR) << "acceptSocket ctrl failed, ret: " << ret; + LOG(ERROR) << "acceptHcclSocket ctrl failed, ret: " << ret; return ret; } - ret = - acceptSocket(hccl_data_socket, local_rank_info, remote_control_info, - baseTag_ + "data", rempoteDevIp, is_cross_hccs); + ret = acceptHcclSocket(hccl_data_socket, baseTag_ + "data", remoteDevIp, + is_cross_hccs); if (ret) { - LOG(ERROR) << "acceptSocket data failed, ret: " << ret; + LOG(ERROR) << "acceptHcclSocket data failed, ret: " << ret; return ret; } } else { - rempoteDevIp = hccl::HcclIpAddress(remote_control_info.deviceIp); - ret = - acceptSocket(hccl_ctrl_socket, local_rank_info, remote_control_info, - baseTag_ + "ctrl", rempoteDevIp, is_cross_hccs); + remoteDevIp = hccl::HcclIpAddress(remote_control_info.deviceIp); + ret = acceptHcclSocket(hccl_ctrl_socket, baseTag_ + "ctrl", remoteDevIp, + is_cross_hccs); if (ret) { - LOG(ERROR) << "acceptSocket ctrl failed, ret: " << ret; + LOG(ERROR) << "acceptHcclSocket ctrl failed, ret: " << ret; return ret; } - ret = - acceptSocket(hccl_data_socket, local_rank_info, remote_control_info, - baseTag_ + "data", rempoteDevIp, is_cross_hccs); + ret = acceptHcclSocket(hccl_data_socket, baseTag_ + "data", remoteDevIp, + is_cross_hccs); if (ret) { - LOG(ERROR) << "acceptSocket data failed, ret: " << ret; + LOG(ERROR) << "acceptHcclSocket data failed, ret: " << ret; return ret; } } - target_key_to_connection_map_[key_str].hccl_ctrl_socket = hccl_ctrl_socket; - target_key_to_connection_map_[key_str].hccl_data_socket = hccl_data_socket; - LOG(INFO) << "Creating transfer_mem on the accept side"; + LOG(INFO) << "accept hccl socket success."; + + g_target_key_to_accept_map[key_str].hccl_ctrl_socket = hccl_ctrl_socket; + g_target_key_to_accept_map[key_str].hccl_data_socket = hccl_data_socket; + + RankInfo remote_rank_info(remote_control_info); std::shared_ptr transport_mem{}; - hccl::TransportMem::AttrInfo attrInfo; - attrInfo.localRankId = local_rank_info->deviceLogicId; - attrInfo.remoteRankId = remote_control_info.deviceLogicId; - attrInfo.sdid = 0xFFFFFFFF; - attrInfo.serverId = local_rank_info->serverIdx; - attrInfo.trafficClass = 132; - attrInfo.serviceLevel = 4; - if (is_cross_hccs) { - transport_mem = hccl::TransportMem::Create( - hccl::TransportMem::TpType::ROCE, notifyPool_, nicNetDevCtx_, - dispatcher_, attrInfo); - } else { - transport_mem = hccl::TransportMem::Create( - hccl::TransportMem::TpType::IPC, notifyPool_, vnicNetDevCtx_, - dispatcher_, attrInfo); - } - ret = transport_mem->SetDataSocket(hccl_data_socket); + ret = createTransportMem(local_rank_info, &remote_rank_info, key_str, + is_cross_hccs, transport_mem, true); if (ret) { - char deviceIp[64]; - inet_ntop(AF_INET, &rempoteDevIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "transport_mem SetDataSocket failed, target devicePhyId: " - << remote_control_info.devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp << ", ret: " << ret; + LOG(ERROR) << "createTransportMem failed, ret: " << ret; return ret; } - ret = transport_mem->SetSocket(hccl_ctrl_socket); - if (ret) { - char deviceIp[64]; - inet_ntop(AF_INET, &rempoteDevIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "transport_mem SetSocket failed, target devicePhyId: " - << remote_control_info.devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp << ", ret: " << ret; - return ret; + g_target_key_to_accept_map[key_str].tcp_socket = recv_socket; + struct epoll_event event; + event.events = EPOLLIN; + + if (aggregateEnabled) { + event.data.fd = recv_socket; + if (epoll_ctl(g_epoll_fd_agg, EPOLL_CTL_ADD, recv_socket, &event) == + -1) { + LOG(ERROR) << "epoll_ctl: ADD"; + return -1; + } } - const char *transport_mem_timeout_str = - std::getenv("Ascend_TRANSPORT_MEM_TIMEOUT"); - int transport_mem_timeout = - transport_mem_timeout_str ? std::atoi(transport_mem_timeout_str) : 120; - ret = transport_mem->Connect(transport_mem_timeout); - if (ret) { - char deviceIp[64]; - inet_ntop(AF_INET, &rempoteDevIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "transport_mem Connect failed, target devicePhyId: " - << remote_control_info.devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp << ", ret: " << ret; - return ret; + int ack = 1; + ret = send(recv_socket, &ack, sizeof(int), 0); + if (ret < 0) { + LOG(ERROR) << "Failed to send ack, ret: " << ret << ", errno: " << errno + << ", error: " << strerror(errno); + return -1; } - target_key_to_connection_map_[key_str].transport_mem = transport_mem; - - size_t m_num = g_localMergeMem.size(); - std::vector rmaMemDescs(m_num); - for (size_t i = 0; i < m_num; ++i) { - HcclBuf buf; - HcclMem mem; - mem.addr = g_localMergeMem[i].addr; - mem.size = g_localMergeMem[i].len; - mem.type = HCCL_MEM_TYPE_DEVICE; - if (!is_cross_hccs) { - ret = HcclMemReg(vnicNetDevCtx_, &mem, &buf); - } else { - ret = HcclMemReg(nicNetDevCtx_, &mem, &buf); - } - if (ret != 0 && ret != 20) { - LOG(ERROR) << "HcclMemReg failed, ret: " << ret - << ", addr: " << g_localMergeMem[i].addr - << ", len: " << g_localMergeMem[i].len; - return ret; - } - char *desc = nullptr; - uint64_t desc_len = 0; - ret = HcclMemExport(&buf, &desc, &desc_len); - if (ret) { - LOG(ERROR) << "HcclMemExport failed, ret: " << ret - << ", addr: " << g_localMergeMem[i].addr - << ", len: " << g_localMergeMem[i].len; - return ret; - } + return 0; +} - rmaMemDescs[i].localRankId = local_rank_info->deviceLogicId; - rmaMemDescs[i].remoteRankId = remote_control_info.deviceLogicId; - memset_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, 0, - hccl::TRANSPORT_EMD_ESC_SIZE); - if (memcpy_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, desc, - desc_len + 1) != EOK) { - LOG(ERROR) << "memcpy_s failed, ret: " << ret - << ", addr: " << g_localMergeMem[i].addr - << ", len: " << g_localMergeMem[i].len; - return -1; +int recvMemInfo1(int client_socket, aclrtStream stream) { + int ret = 0; + SingleCopyInfo single_copy_info; + ret = recv(client_socket, &single_copy_info, sizeof(single_copy_info), 0); + if (ret <= 0) { + LOG(ERROR) << "failed to receive single_copy_info, ret: " << ret + << ", errno: " << errno << ", error: " << strerror(errno); + return -1; + } + // LOG(INFO) << "recv host addr:" << single_copy_info.host_addr + // << " , client_socket" << client_socket << " , local_id" + // << single_copy_info.remote_id << " , remote_id" + // << single_copy_info.local_id << ", offset" + // << single_copy_info.offset << ", len:" << single_copy_info.len; + uint64_t device_addr = + reinterpret_cast(g_dev_addr) + single_copy_info.offset; + if (single_copy_info.is_read) { + if (single_copy_info.is_copy) { + ret = aclrtMemcpy( + reinterpret_cast(device_addr), single_copy_info.len, + reinterpret_cast(single_copy_info.host_addr), + single_copy_info.len, ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to copy data from host to device, ret:" + << ret; + return ret; + } + // LOG(INFO) << "remote read ok:" << "device_addr:" << device_addr + // << ", host:" << single_copy_info.host_addr + // << ", len:" << single_copy_info.len; + // print_bfloat16_memory(reinterpret_cast(single_copy_info.host_addr), single_copy_info.len); } - - // In the scenario within HCCS, it is necessary to call HcclMemGrant to - // authorize peer memory - if (!is_cross_hccs) { - HcclMemGrantInfo grant_info; - grant_info.remotePid = (int32_t)remote_control_info.pid; - grant_info.remoteSdid = 0xFFFFFFFF; - ret = HcclMemGrant(&buf, &grant_info); - if (ret) { - LOG(ERROR) << "HcclMemGrant failed, ret: " << ret - << ", addr: " << g_localMergeMem[i].addr - << ", len: " << g_localMergeMem[i].len; + } else { + if (single_copy_info.is_copy) { + ret = aclrtMemcpy( + reinterpret_cast(single_copy_info.host_addr), + single_copy_info.len, reinterpret_cast(device_addr), + single_copy_info.len, ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to copy data from host to device, ret:" + << ret; return ret; } + // LOG(INFO) << "remote put ok:" << "device_addr:" << device_addr + // << ", host:" << single_copy_info.host_addr + // << ", len:" << single_copy_info.len; } } - hccl::TransportMem::RmaMemDescs localRmaMemDescs; - localRmaMemDescs.array = rmaMemDescs.data(); - localRmaMemDescs.arrayLength = rmaMemDescs.size(); - uint32_t actualNumOfRemote = 0; - std::vector remoteRmaMemDescArray(m_num); - hccl::TransportMem::RmaMemDescs remoteRmaMemDescs; - remoteRmaMemDescs.array = remoteRmaMemDescArray.data(); - remoteRmaMemDescs.arrayLength = m_num; - ret = transport_mem->ExchangeMemDesc(localRmaMemDescs, remoteRmaMemDescs, - actualNumOfRemote); - if (ret) { - LOG(ERROR) << "transport_mem->ExchangeMemDesc failed, ret: " << ret - << ", local_rank: " << local_rank_info->devicePhyId - << ", remote_rank: " << remote_control_info.devicePhyId; + SingleCopyInfo single_copy_infos; + single_copy_infos.device_addr = device_addr; + // LOG(INFO) << "send device addr:" << device_addr << " , client_socket" + // << client_socket << " , local_id" << single_copy_info.remote_id + // << " , remote_id" << single_copy_info.local_id; + ret = send(client_socket, &single_copy_infos, sizeof(SingleCopyInfo), 0); + if (ret < 0) { + LOG(ERROR) << "send to receive single_copy_info, ret: " << ret + << ", errno: " << errno << ", error: " << strerror(errno); + // close(client_socket); return ret; } - std::vector remoteRmaMemArray(m_num); - for (uint32_t i = 0; i < m_num; ++i) { - ret = transport_mem->EnableMemAccess(remoteRmaMemDescArray[i], - remoteRmaMemArray[i]); - if (ret) { - LOG(ERROR) << "transport_mem->EnableMemAccess failed, ret: " << ret - << ", i: " << i - << ", local_rank: " << local_rank_info->devicePhyId - << ", remote_rank: " << remote_control_info.devicePhyId; - return ret; - } - } - - LOG(INFO) << "ExchangeMem and EnableMemAccess Success, local devicePhyId: " - << local_rank_info->devicePhyId - << ", target devicePhyId: " << remote_control_info.devicePhyId; - return 0; + return client_socket; } -int regLocalRmaMem(void *addr, uint64_t length) { - g_localMergeMem.push_back(MergeMem{addr, length}); - return 0; +void nonAggRegLocalMem(uint64_t addr, uint64_t length, bool is_pool) { + if (is_pool) { + g_dev_addr = (void *)addr; + } + MemBlock memBlock; + memBlock.addr = addr; + memBlock.len = length; + LOG(INFO) << "addr:" << addr << ", len: " << length + << ", is_pool:" << is_pool; + g_localBuffer.emplace_back(memBlock); + return; } - #ifdef __cplusplus } #endif // __cplusplus \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_transport_mem_internals.cpp b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_transport_mem_internals.cpp new file mode 100644 index 000000000..97c1c12a1 --- /dev/null +++ b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_transport_mem_internals.cpp @@ -0,0 +1,906 @@ +// Copyright 2025 Huawei Technologies Co., Ltd +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "transport/ascend_transport/hccl_transport/hccl_transport_mem_internals.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +HcclNetDevCtx vnicNetDevCtx_{nullptr}; +HcclNetDevCtx nicNetDevCtx_{nullptr}; +std::shared_ptr vnicServerSocket_{nullptr}; +std::shared_ptr nicServerSocket_{nullptr}; +std::unique_ptr notifyPool_; +HcclDispatcher dispatcher_{nullptr}; + +std::unordered_map g_target_key_to_connection_map; +std::unordered_map g_target_key_to_accept_map; +std::vector g_localBuffer; + +int g_epoll_fd_agg = 0; +int g_server_socket = 0; +int g_epoll_fd = 0; +int g_epoll_fd_target = 0; +struct epoll_event g_ev; +struct epoll_event g_events[MAX_EVENTS]; +bool g_is_no_roce = false; +bool a3Enabled() { + char *env = getenv("ASCEND_A3_ENABLE"); + return env != nullptr && std::string(env) == "1"; +} + +static uint16_t findAvailableTcpPort(int &sockfd, bool use_ipv6) { + static std::random_device rand_gen; + std::mt19937 gen(rand_gen()); + const int min_port = 15000; + const int max_port = 17000; + const int max_attempts = 500; + std::uniform_int_distribution<> rand_dist(min_port, max_port); + + for (int attempt = 0; attempt < max_attempts; ++attempt) { + int port = rand_dist(rand_gen); + if (use_ipv6) { + sockaddr_in6 bind_address; + memset(&bind_address, 0, sizeof(sockaddr_in6)); + bind_address.sin6_family = AF_INET6; + bind_address.sin6_port = htons(port); + bind_address.sin6_addr = IN6ADDR_ANY_INIT; + if (bind(sockfd, (sockaddr *)&bind_address, sizeof(sockaddr_in6)) < + 0) { + continue; + } + } else { + sockaddr_in bind_address; + memset(&bind_address, 0, sizeof(sockaddr_in)); + bind_address.sin_family = AF_INET; + bind_address.sin_port = htons(port); + bind_address.sin_addr.s_addr = INADDR_ANY; + if (bind(sockfd, (sockaddr *)&bind_address, sizeof(sockaddr_in)) < + 0) { + continue; + } + } + + return port; + } + return 0; +} + +int initServerNetSocket(RankInfo *local_rank_info) { + RETRY_CALL(HcclNetInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, + local_rank_info->devicePhyId, + local_rank_info->deviceLogicId, false), + "HcclNetInit failed"); + + // Use the physical network card of the device across HCCS + hccl::HcclIpAddress localIp(local_rank_info->deviceIp); + RETRY_CALL(HcclNetOpenDev(&nicNetDevCtx_, NicType::DEVICE_NIC_TYPE, + local_rank_info->devicePhyId, + local_rank_info->deviceLogicId, localIp), + "HcclNetOpenDev DEVICE_NIC_TYPE failed"); + + nicServerSocket_ = std::make_shared( + nicNetDevCtx_, local_rank_info->devicePort); + if (nicServerSocket_ == NULL) { + LOG(ERROR) << "make nicNetDevCtx_ failed"; + return -1; + } + + RETRY_CALL(nicServerSocket_->Init(), "nicServerSocket_ Init failed"); + // RETRY_CALL(nicServerSocket_->Listen(), "nicServerSocket_ Listen failed"); + int retryCount = 0; + int ret = nicServerSocket_->Listen(); + while(retryCount < 3) { + LOG(ERROR) << "nicServerSocket_ Listen failed" << ", retrying... (" << ++retryCount + << "/3)" << ret; + ret = nicServerSocket_->Listen(); + if (ret == 0) { + break; + } + } + if (ret != 0 && ret != 19) { + LOG(ERROR) << "nicServerSocket_ Listen failed 3 times, exit."; + return ret; + } + if (ret == 19) { + LOG(ERROR) << "Three monitoring attempts have failed, with a return value of 19" + << "It is temporarily assumed to be a non-ROCe issue"; + g_is_no_roce = true; + } + // Use virtual network card within HCCS + hccl::HcclIpAddress localVnicIp; + + if (a3Enabled()) { + localVnicIp = hccl::HcclIpAddress(local_rank_info->vnicIp); + } else { + localVnicIp = hccl::HcclIpAddress(local_rank_info->devicePhyId); + RETRY_CALL(hrtRaGetSingleSocketVnicIpInfo( + local_rank_info->devicePhyId, + DeviceIdType::DEVICE_ID_TYPE_PHY_ID, + local_rank_info->devicePhyId, localVnicIp), + "hrtRaGetSingleSocketVnicIpInfo failed"); + } + + RETRY_CALL(HcclNetOpenDev(&vnicNetDevCtx_, NicType::VNIC_TYPE, + local_rank_info->devicePhyId, + local_rank_info->deviceLogicId, localVnicIp), + "HcclNetOpenDev vnicNetDevCtx_ failed"); + + // control plane connection, creat serversocket, listening client + vnicServerSocket_ = std::make_shared( + vnicNetDevCtx_, local_rank_info->devicePort); + if (vnicServerSocket_ == NULL) { + LOG(ERROR) << "vnicServerSocket_ make failed"; + return -1; + } + + RETRY_CALL(vnicServerSocket_->Init(), "vnicServerSocket_ Init failed"); + RETRY_CALL(vnicServerSocket_->Listen(), "vnicServerSocket_ Listen failed"); + + RETRY_CALL(HcclDispatcherInit(DispatcherType::DISPATCHER_NORMAL, + local_rank_info->devicePhyId, &dispatcher_), + "client HcclDispatcherInit failed"); + + notifyPool_.reset(new (std::nothrow) hccl::NotifyPool()); + if (notifyPool_ == nullptr) { + LOG(ERROR) << "reset notifyPool error"; + return -1; + } + + RETRY_CALL(notifyPool_->Init(local_rank_info->devicePhyId), + "Init notifyPool error"); + + return 0; +} + +// The out-of-band socket on the host side that ascend_transport depends on, +// used to convey control information such as deviceId and deviceIp +int initControlSocket(RankInfo *local_rank_info, bool aggregateEnabled) { + g_target_key_to_connection_map.clear(); + g_target_key_to_accept_map.clear(); + int ret = 0; + g_server_socket = socket(AF_INET, SOCK_STREAM, 0); + if (g_server_socket < 0) { + LOG(ERROR) << "ascend transport out-of-band socket create failed"; + return g_server_socket; + } + + int optval = 1; + ret = setsockopt(g_server_socket, SOL_SOCKET, SO_REUSEADDR, &optval, + sizeof(optval)); + if (ret < 0) { + LOG(ERROR) << "set sock opt failed, ret: " << ret; + close(g_server_socket); + return ret; + } + + struct sockaddr_in bind_address; + memset(&bind_address, 0, sizeof(sockaddr_in)); + bind_address.sin_family = AF_INET; + bind_address.sin_addr.s_addr = INADDR_ANY; + bind_address.sin_port = htons(local_rank_info->hostPort); + + ret = bind(g_server_socket, (struct sockaddr *)&bind_address, + sizeof(bind_address)); + if (ret < 0) { + LOG(INFO) << "bind failed on the default port, default port: " + << local_rank_info->hostPort << ", will find available port"; + uint16_t port = findAvailableTcpPort(g_server_socket, false); + if (port == 0) { + LOG(ERROR) << "findAvailableTcpPort failed"; + close(g_server_socket); + return -1; + } + local_rank_info->hostPort = (uint64_t)port; + } + + ret = listen(g_server_socket, CONNECT_MAX); + if (ret < 0) { + LOG(ERROR) << "Listen Failed, ret: " << ret; + close(g_server_socket); + return ret; + } + LOG(INFO) << "initControlSocket successful, listening on host port: " + << local_rank_info->hostPort << "..." << " g_server_socket" + << g_server_socket; + g_epoll_fd = epoll_create1(0); + if (g_epoll_fd == -1) { + LOG(ERROR) << "epoll create Failed, ret: " << g_epoll_fd; + close(g_server_socket); + return g_epoll_fd; + } + + g_epoll_fd_target = epoll_create1(0); + if (g_epoll_fd_target == -1) { + LOG(ERROR) << "epoll fd target create Failed, ret: " + << g_epoll_fd_target; + close(g_server_socket); + return g_epoll_fd_target; + } + + if (aggregateEnabled) { + g_epoll_fd_agg = epoll_create1(0); + if (g_epoll_fd_agg == -1) { + LOG(ERROR) << "epoll fd target create Failed, ret: " << ret; + close(g_server_socket); + return g_epoll_fd; + } + LOG(INFO) << "g_epoll_fd_agg create end"; + } + + g_ev.events = EPOLLIN; + g_ev.data.fd = g_server_socket; + ret = epoll_ctl(g_epoll_fd, EPOLL_CTL_ADD, g_server_socket, &g_ev); + if (ret < 0) { + LOG(ERROR) << "epoll epoll_ctl Failed, ret: " << ret; + close(g_server_socket); + return ret; + } + return 0; +} + +void getDevIpAddresses(RankInfo *local_rank_info) { + int devicePhyId = local_rank_info->devicePhyId; + memset(local_rank_info->vnicIp, 0, sizeof(local_rank_info->vnicIp)); + memset(local_rank_info->deviceIp, 0, sizeof(local_rank_info->deviceIp)); + + if (a3Enabled()) { + bool gotVnicIp = false; + for (int i = 0; i < 10; i++) { + std::stringstream vnicCmd; + vnicCmd << "/usr/local/Ascend/driver/tools/hccn_tool -i " + << devicePhyId << " -vnic -g 2>&1"; + + LOG(INFO) << "Attempt " << (i + 1) + << " to get vnicIp with command: " << vnicCmd.str(); + + FILE *vnicPipe = popen(vnicCmd.str().c_str(), "r"); + if (vnicPipe) { + int fd = fileno(vnicPipe); + struct timeval timeout; + timeout.tv_sec = 2; + timeout.tv_usec = 0; + setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, + sizeof(timeout)); + + char buffer[1024] = {0}; + ssize_t bytesRead; + std::string output; + + while ((bytesRead = read(fd, buffer, sizeof(buffer) - 1)) > 0) { + buffer[bytesRead] = '\0'; + output += buffer; + memset(buffer, 0, sizeof(buffer)); + } + + pclose(vnicPipe); + + LOG(INFO) << "Attempt " << (i + 1) + << " vnic command full output: " << output; + + const char *prefix = "vnic ipaddr: "; + size_t pos = output.find(prefix); + if (pos != std::string::npos) { + const char *ipStart = output.c_str() + pos + strlen(prefix); + size_t maxLen = sizeof(local_rank_info->vnicIp) - 1; + size_t ipLen = 0; + while (ipLen < maxLen && ipStart[ipLen] != '\0' && + ipStart[ipLen] != '\n' && ipStart[ipLen] != ' ') { + local_rank_info->vnicIp[ipLen] = ipStart[ipLen]; + ipLen++; + } + local_rank_info->vnicIp[ipLen] = '\0'; + + if (ipLen >= 7) { + int dotCount = 0; + for (size_t j = 0; j < ipLen; j++) { + if (local_rank_info->vnicIp[j] == '.') { + dotCount++; + } + } + if (dotCount == 3) { + gotVnicIp = true; + break; + } + } + } + } else { + LOG(WARNING) << "Failed to open pipe for vnic command, attempt " + << (i + 1); + } + + if (gotVnicIp) break; + + int sleepTime = 10000 * (i + 1); + if (sleepTime > 500000) sleepTime = 500000; + LOG(INFO) << "Retrying vnicIp after " << sleepTime / 1000 << "ms"; + usleep(sleepTime); + } + + if (gotVnicIp) { + LOG(INFO) << "Successfully obtained vnicIp: " + << local_rank_info->vnicIp; + } else { + LOG(WARNING) + << "Failed to obtain valid vnicIp after multiple attempts"; + } + } + + bool gotIp = false; + for (int i = 0; i < 10; i++) { + std::stringstream ipCmd; + ipCmd << "/usr/local/Ascend/driver/tools/hccn_tool -i " << devicePhyId + << " -ip -g 2>&1"; + + LOG(INFO) << "Attempt " << (i + 1) + << " to get deviceIp with command: " << ipCmd.str(); + + FILE *ipPipe = popen(ipCmd.str().c_str(), "r"); + if (ipPipe) { + int fd = fileno(ipPipe); + struct timeval timeout; + timeout.tv_sec = 2; + timeout.tv_usec = 0; + setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)); + + char buffer[1024] = {0}; + ssize_t bytesRead; + std::string output; + + while ((bytesRead = read(fd, buffer, sizeof(buffer) - 1)) > 0) { + buffer[bytesRead] = '\0'; + output += buffer; + memset(buffer, 0, sizeof(buffer)); + } + + pclose(ipPipe); + + LOG(INFO) << "Attempt " << (i + 1) + << " device command full output: " << output; + + const char *prefix = "ipaddr:"; + size_t pos = output.find(prefix); + if (pos != std::string::npos) { + const char *ipStart = output.c_str() + pos + strlen(prefix); + while (*ipStart == ' ') ipStart++; + + size_t maxLen = sizeof(local_rank_info->deviceIp) - 1; + size_t ipLen = 0; + while (ipLen < maxLen && ipStart[ipLen] != '\0' && + ipStart[ipLen] != '\n' && ipStart[ipLen] != ' ') { + local_rank_info->deviceIp[ipLen] = ipStart[ipLen]; + ipLen++; + } + local_rank_info->deviceIp[ipLen] = '\0'; + + if (ipLen >= 7) { + int dotCount = 0; + for (size_t j = 0; j < ipLen; j++) { + if (local_rank_info->deviceIp[j] == '.') { + dotCount++; + } + } + if (dotCount == 3) { + gotIp = true; + break; + } + } + } + } else { + LOG(WARNING) << "Failed to open pipe for device command, attempt " + << (i + 1); + } + + if (gotIp) break; + + int sleepTime = 10000 * (i + 1); + if (sleepTime > 500000) sleepTime = 500000; + LOG(INFO) << "Retrying deviceIp after " << sleepTime / 1000 << "ms"; + usleep(sleepTime); + } + + if (gotIp) { + LOG(INFO) << "Successfully obtained deviceIp: " + << local_rank_info->deviceIp; + } else { + LOG(WARNING) << "Failed to obtain deviceIp after multiple attempts"; + } +} + +static int connectToTarget(std::string target_ip, int target_port) { + int client_socket; + struct sockaddr_in server_addr; + + client_socket = socket(AF_INET, SOCK_STREAM, 0); + if (client_socket < 0) { + LOG(ERROR) << "Socket creation failed"; + return client_socket; + } + + int optval = 1; + int ret = setsockopt(client_socket, SOL_SOCKET, SO_REUSEADDR, &optval, + sizeof(optval)); + if (ret < 0) { + LOG(ERROR) << "set sock opt failed, ret: " << ret; + close(client_socket); + return ret; + } + + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(target_port); + server_addr.sin_addr.s_addr = inet_addr(target_ip.c_str()); + + if (server_addr.sin_addr.s_addr == INADDR_NONE) { + LOG(ERROR) << "Invalid server IP address"; + close(client_socket); + return -1; + } + + int connected = 0; + + const char *tcp_timeout_str = std::getenv("Ascend_TCP_TIMEOUT"); + int ascend_tcp_timeout = tcp_timeout_str ? std::atoi(tcp_timeout_str) : 30; + int connect_retry_times = ascend_tcp_timeout * 100; + + for (int i = 0; i < connect_retry_times; ++i) { + if (connect(client_socket, (struct sockaddr *)&server_addr, + sizeof(server_addr)) == 0) { + LOG(INFO) << "Connect to host server " << target_ip << ":" + << ntohs(server_addr.sin_port) << " successful"; + connected = 1; + break; + } + + LOG(INFO) << "Connect attempt " << i << " failed: " << strerror(errno) + << ", retry once"; + + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + if (!connected) { + LOG(ERROR) << "Failed to connect to server after " + << connect_retry_times << " retries"; + close(client_socket); + return HCCL_E_TIMEOUT; + } + return client_socket; +} + +int controlInfoSend(RankInfo *local_rank_info, RankInfo *remote_rank_info) { + int ret = 0; + std::string key_str = std::string(remote_rank_info->hostIp) + '-' + + std::to_string(remote_rank_info->devicePhyId); + LOG(INFO) << "aggTransportMemTask local_rank_info rankId: " + << local_rank_info->rankId + << ", serverId: " << local_rank_info->serverId + << ", deviceLogicId: " << local_rank_info->deviceLogicId + << ", devicePhyId: " << local_rank_info->devicePhyId + << ", deviceIp: " << local_rank_info->deviceIp + << ", devicePort: " << local_rank_info->devicePort + << ", hostIp: " << local_rank_info->hostIp + << ", hostPort: " << local_rank_info->hostPort + << ", device pid: " << local_rank_info->devPid + << ", vnicIp: " << local_rank_info->vnicIp + << ", sdid: " << local_rank_info->sdid; + + LOG(INFO) << "aggTransportMemTask remote_rank_info rankId: " + << remote_rank_info->rankId + << ", serverId: " << remote_rank_info->serverId + << ", deviceLogicId: " << remote_rank_info->deviceLogicId + << ", devicePhyId: " << remote_rank_info->devicePhyId + << ", deviceIp: " << remote_rank_info->deviceIp + << ", devicePort: " << remote_rank_info->devicePort + << ", hostIp: " << remote_rank_info->hostIp + << ", hostPort: " << remote_rank_info->hostPort + << ", device pid: " << remote_rank_info->devPid + << ", vnicIp: " << remote_rank_info->vnicIp + << ", sdid: " << remote_rank_info->sdid; + + // Encapsulate control information + RankControlInfo control_info; + control_info.deviceLogicId = local_rank_info->deviceLogicId; + control_info.devicePhyId = local_rank_info->devicePhyId; + control_info.devPid = local_rank_info->devPid; + control_info.sdid = local_rank_info->sdid; + + strncpy(control_info.hostIp, local_rank_info->hostIp, 127); + control_info.hostIp[127] = '\0'; + strncpy(control_info.vnicIp, local_rank_info->vnicIp, 127); + control_info.vnicIp[127] = '\0'; + strncpy(control_info.deviceIp, local_rank_info->deviceIp, 127); + control_info.deviceIp[127] = '\0'; + // Self-built out-of-band, host socket for sending control plane + int client_socket = connectToTarget(std::string(remote_rank_info->hostIp), + remote_rank_info->hostPort); + if (client_socket < 0) { + LOG(ERROR) << "client connect failed"; + return client_socket; + } + + ret = send(client_socket, &control_info, sizeof(RankControlInfo), 0); + if (ret < 0) { + LOG(ERROR) << "send control_info failed, ret: " << ret + << ", errno: " << errno << ", error: " << strerror(errno); + return ret; + } + + g_target_key_to_connection_map[key_str].tcp_socket = client_socket; + LOG(INFO) << "target_key:" << key_str << ", tcp_socket:" << client_socket; + + return 0; +} + +int createClientSocket(std::shared_ptr &hccl_socket, + RankInfo *local_rank_info, RankInfo *remote_rank_info, + bool is_cross_hccs, std::string tag) { + int ret = 0; + hccl::HcclIpAddress remoteDevIp; + std::string key_str = std::string(remote_rank_info->hostIp) + '-' + + std::to_string(remote_rank_info->devicePhyId); + std::string baseTag_ = std::string(local_rank_info->hostIp) + '-' + + std::to_string(local_rank_info->devicePhyId) + + key_str + tag; + if (!is_cross_hccs) { + std::vector remoteDevPhyId; + remoteDevPhyId.emplace_back(remote_rank_info->devicePhyId); + ret = hccl::P2PMgmtPub::EnableP2P(remoteDevPhyId); + if (ret) { + LOG(ERROR) << "P2PMgmtPub EnableP2P failed, ret: " << ret; + return ret; + } + ret = hccl::P2PMgmtPub::WaitP2PEnabled(remoteDevPhyId); + if (ret) { + LOG(ERROR) << "P2PMgmtPub WaitP2PEnabled failed, ret: " << ret; + return ret; + } + if (a3Enabled()) { + remoteDevIp = hccl::HcclIpAddress(remote_rank_info->vnicIp); + } else { + remoteDevIp = hccl::HcclIpAddress(remote_rank_info->devicePhyId); + ret = hrtRaGetSingleSocketVnicIpInfo( + local_rank_info->devicePhyId, + DeviceIdType::DEVICE_ID_TYPE_PHY_ID, + remote_rank_info->devicePhyId, remoteDevIp); + if (ret) { + LOG(ERROR) << "hrtRaGetSingleSocketVnicIpInfo, ret: " << ret; + return ret; + } + } + hccl_socket = std::make_shared( + baseTag_, vnicNetDevCtx_, remoteDevIp, remote_rank_info->devicePort, + hccl::HcclSocketRole::SOCKET_ROLE_CLIENT); + } else { + remoteDevIp = hccl::HcclIpAddress(remote_rank_info->deviceIp); + hccl_socket = std::make_shared( + baseTag_, nicNetDevCtx_, remoteDevIp, remote_rank_info->devicePort, + hccl::HcclSocketRole::SOCKET_ROLE_CLIENT); + } + + ret = hccl_socket->Init(); + if (ret) { + char deviceIp[64]; + inet_ntop(AF_INET, &remoteDevIp, deviceIp, sizeof(deviceIp)); + LOG(ERROR) << "client hccl_socket init failed, target devicePhyId: " + << remote_rank_info->devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", remoteDevIp: " << deviceIp + << ", remote port: " << remote_rank_info->devicePort + << ", ret: " << ret; + return ret; + } + ret = hccl_socket->Connect(); + if (ret) { + char deviceIp[64]; + inet_ntop(AF_INET, &remoteDevIp, deviceIp, sizeof(deviceIp)); + LOG(ERROR) << "client hccl_socket Connect failed, target devicePhyId: " + << remote_rank_info->devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", remoteDevIp: " << deviceIp + << ", remote port: " << remote_rank_info->devicePort + << ", ret: " << ret; + return ret; + } + LOG(INFO) << "hccl_socket begin to connect, local devicePhyId: " + << local_rank_info->devicePhyId + << ", target devicePhyId: " << remote_rank_info->devicePhyId; + hccl::HcclSocketStatus status; + struct timespec start, end; + const char *hccl_socket_timeout_str = + std::getenv("Ascend_HCCL_SOCKET_TIMEOUT"); + int hccl_socket_timeout = + hccl_socket_timeout_str ? std::atoi(hccl_socket_timeout_str) : 30; + long long hccl_socket_timeout_ns = + static_cast(hccl_socket_timeout) * 1000000000LL; + clock_gettime(CLOCK_MONOTONIC, &start); + do { + status = hccl_socket->GetStatus(); + clock_gettime(CLOCK_MONOTONIC, &end); + long long elapsed_time = (end.tv_sec - start.tv_sec) * 1000000000LL + + (end.tv_nsec - start.tv_nsec); + if (elapsed_time > + hccl_socket_timeout_ns) { // Exceeds 20 seconds,TimeOut + LOG(ERROR) << "hccl_socket connect timeout, local devicePhyId: " + << local_rank_info->devicePhyId + << ", target devicePhyId: " + << remote_rank_info->devicePhyId; + return HCCL_E_TIMEOUT; + } + } while (status != hccl::HcclSocketStatus::SOCKET_OK); + LOG(INFO) << "hccl_socket connect success, local devicePhyId: " + << local_rank_info->devicePhyId + << ", target devicePhyId: " << remote_rank_info->devicePhyId; + + return 0; +} + +int createTransportMem(RankInfo *local_rank_info, RankInfo *remote_rank_info, + std::string key_str, bool is_cross_hccs, + std::shared_ptr &transport_mem, + bool is_accept) { + int ret = 0; + const char *aggregateEnv = std::getenv("ASCEND_AGGREGATE_ENABLE"); + bool is_agg = + (aggregateEnv != nullptr && std::string(aggregateEnv) == "1"); + hccl::TransportMem::AttrInfo attrInfo; + attrInfo.localRankId = local_rank_info->deviceLogicId; + attrInfo.remoteRankId = remote_rank_info->deviceLogicId; + attrInfo.sdid = local_rank_info->sdid; + attrInfo.serverId = local_rank_info->serverId; + attrInfo.trafficClass = 132; + attrInfo.serviceLevel = 4; + if (is_cross_hccs) { + transport_mem = hccl::TransportMem::Create( + hccl::TransportMem::TpType::ROCE, notifyPool_, nicNetDevCtx_, + dispatcher_, attrInfo); + } else { + transport_mem = hccl::TransportMem::Create( + hccl::TransportMem::TpType::IPC, notifyPool_, vnicNetDevCtx_, + dispatcher_, attrInfo); + } + if (is_accept) { + ret = transport_mem->SetDataSocket( + g_target_key_to_accept_map[key_str].hccl_data_socket); + } else { + ret = transport_mem->SetDataSocket( + g_target_key_to_connection_map[key_str].hccl_data_socket); + } + + if (ret) { + LOG(ERROR) << "client SetDataSocket failed, target devicePhyId: " + << remote_rank_info->devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", remoteDevIp: " << remote_rank_info->deviceIp + << ", ret: " << ret; + return ret; + } + if (is_accept) { + ret = transport_mem->SetSocket( + g_target_key_to_accept_map[key_str].hccl_ctrl_socket); + } else { + ret = transport_mem->SetSocket( + g_target_key_to_connection_map[key_str].hccl_ctrl_socket); + } + if (ret) { + LOG(ERROR) << "client SetSocket failed, target devicePhyId: " + << remote_rank_info->devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", remoteDevIp: " << remote_rank_info->deviceIp + << ", ret: " << ret; + return ret; + } + const char *transport_mem_timeout_str = + std::getenv("Ascend_TRANSPORT_MEM_TIMEOUT"); + int transport_mem_timeout = + transport_mem_timeout_str ? std::atoi(transport_mem_timeout_str) : 120; + ret = transport_mem->Connect(transport_mem_timeout); + if (ret) { + LOG(ERROR) << "client Connect failed, target devicePhyId: " + << remote_rank_info->devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", remoteDevIp: " << remote_rank_info->deviceIp + << ", ret: " << ret; + return ret; + } + LOG(INFO) << "transport_mem connect success"; + uint32_t m_num = g_localBuffer.size() - HUGE_BUFFER_NUM / 2; + if (!is_agg) { + uint32_t m_num = g_localBuffer.size(); + } + LOG(INFO) << "m_num: " << m_num; + + std::vector rmaMemDescs(m_num); + uint32_t idx = 0; + if (is_agg) { + if (is_accept) { + idx = 1; + } + } + + for (uint32_t i = 0; i < m_num; ++i) { + HcclMem mem; + HcclBuf buf; + mem.addr = (void *)g_localBuffer[idx].addr; + mem.size = g_localBuffer[idx].len; + mem.type = HCCL_MEM_TYPE_DEVICE; + if (!is_cross_hccs) { + ret = HcclMemReg(vnicNetDevCtx_, &mem, &buf); + } else { + ret = HcclMemReg(nicNetDevCtx_, &mem, &buf); + } + if (ret != 0 && ret != 20) { + LOG(ERROR) << "HcclMemReg failed, ret: " << ret + << " addr: " << mem.addr << " len: " << mem.size; + return ret; + } + char *desc = nullptr; + uint64_t desc_len = 0; + ret = HcclMemExport(&buf, &desc, &desc_len); + if (ret) { + LOG(ERROR) << "HcclMemExport failed, ret: " << ret + << ", addr: " << mem.addr << ", len: " << mem.size; + return ret; + } + + rmaMemDescs[i].localRankId = local_rank_info->deviceLogicId; + rmaMemDescs[i].remoteRankId = remote_rank_info->deviceLogicId; + memset_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, 0, + hccl::TRANSPORT_EMD_ESC_SIZE); + if (memcpy_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, desc, + desc_len + 1) != EOK) { + LOG(ERROR) << "memcpy_s failed, ret: " << ret + << ", addr: " << mem.addr << ", len: " << mem.size; + return -1; + } + + // In the scenario within HCCS, it is necessary to call HcclMemGrant to + // authorize peer memory + if (!is_cross_hccs) { + HcclMemGrantInfo grant_info; + grant_info.remotePid = (int32_t)remote_rank_info->devPid; + grant_info.remoteSdid = remote_rank_info->sdid; + ret = HcclMemGrant(&buf, &grant_info); + if (ret) { + LOG(ERROR) << "HcclMemGrant failed, ret: " << ret + << ", addr: " << mem.addr << ", len: " << mem.size; + return ret; + } + } + if (is_agg) { + if (idx < (HUGE_BUFFER_NUM - 1)) { + idx += 2; + } else { + idx++; + } + } else { + idx++; + } + } + hccl::TransportMem::RmaMemDescs localRmaMemDescs; + localRmaMemDescs.array = rmaMemDescs.data(); + localRmaMemDescs.arrayLength = rmaMemDescs.size(); + uint32_t actualNumOfRemote = 0; + std::vector remoteRmaMemDescArray(m_num); + hccl::TransportMem::RmaMemDescs remoteRmaMemDescs; + remoteRmaMemDescs.array = remoteRmaMemDescArray.data(); + remoteRmaMemDescs.arrayLength = m_num; + ret = transport_mem->ExchangeMemDesc(localRmaMemDescs, remoteRmaMemDescs, + actualNumOfRemote); + if (ret) { + LOG(ERROR) << "transport_mem->ExchangeMemDesc failed, ret: " << ret + << ", local_rank: " << local_rank_info->devicePhyId + << ", remote_rank: " << remote_rank_info->devicePhyId; + return ret; + } + std::vector remoteRmaMemArray(m_num); + for (uint32_t i = 0; i < m_num; ++i) { + ret = transport_mem->EnableMemAccess(remoteRmaMemDescArray[i], + remoteRmaMemArray[i]); + if (ret) { + LOG(ERROR) << "transport_mem->EnableMemAccess failed, ret: " << ret + << ", i: " << i + << ", local_rank: " << local_rank_info->devicePhyId + << ", remote_rank: " << remote_rank_info->devicePhyId; + return ret; + } + } + LOG(INFO) << "ExchangeMem and EnableMemAccess Success, local devicePhyId: " + << local_rank_info->devicePhyId + << ", target devicePhyId: " << remote_rank_info->devicePhyId; + if (is_accept) { + g_target_key_to_accept_map[key_str].transport_mem = transport_mem; + } else { + g_target_key_to_connection_map[key_str].transport_mem = transport_mem; + } + return 0; +} + +int socketEpollWait() { + int nfds = epoll_wait(g_epoll_fd, g_events, MAX_EVENTS, -1); + if (nfds == -1) { + if (errno == EINTR) { + return 0; + } else { + LOG(ERROR) << "epoll_wait failed: " << strerror(errno); + return -1; + } + } + + return nfds; +} + +int acceptFromTarget() { + int client_socket; + struct sockaddr_in client_addr; + socklen_t client_len = sizeof(client_addr); + client_socket = + accept(g_server_socket, (struct sockaddr *)&client_addr, &client_len); + if (client_socket < 0) { + LOG(ERROR) << "Accept failed"; + return client_socket; + } + + LOG(INFO) << "host client connected from " + << inet_ntoa(client_addr.sin_addr) << ":" + << ntohs(client_addr.sin_port); + return client_socket; +} + +int acceptHcclSocket(std::shared_ptr &hccl_socket, + std::string baseTag_, hccl::HcclIpAddress remoteDevIp, + bool is_cross_hccs) { + int ret = 0; + std::vector wlistInfoVec; + SocketWlistInfo wlistInfo = {}; + wlistInfo.connLimit = 1; + memcpy(&wlistInfo.tag[0], baseTag_.c_str(), baseTag_.size() + 1); + wlistInfo.remoteIp.addr = remoteDevIp.GetBinaryAddress().addr; + wlistInfo.remoteIp.addr6 = remoteDevIp.GetBinaryAddress().addr6; + wlistInfoVec.emplace_back(wlistInfo); + auto serverSocket = is_cross_hccs ? nicServerSocket_ : vnicServerSocket_; + if (g_is_no_roce) { + serverSocket = vnicServerSocket_; + } + ret = serverSocket->AddWhiteList(wlistInfoVec); + if (ret) { + LOG(ERROR) << "serverSocket AddWhiteList failed, ret: " << ret; + return ret; + } + // Before using the device-side network card for communication, it is + // necessary to add the client device address to the whitelist. + LOG(INFO) << "Add the client's Device IP address to the whitelist success."; + + ret = serverSocket->Accept(baseTag_, hccl_socket); + if (ret) { + LOG(ERROR) << "serverSocket transportMemAccept ctrl socket failed ret: " + << ret; + return ret; + } + return 0; +} + +#ifdef __cplusplus +} +#endif // __cplusplus \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/hccl_aggTransport.cpp b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/hccl_aggTransport.cpp new file mode 100644 index 000000000..e38782e33 --- /dev/null +++ b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/hccl_aggTransport.cpp @@ -0,0 +1,266 @@ +// Copyright 2025 Huawei Technologies Co., Ltd +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include "transport/ascend_transport/hccl_transport/hccl_transport.h" + +namespace mooncake { +int HcclTransport::aggTransport(std::vector &slice_list, + aclrtStream stream) { + auto start = std::chrono::high_resolution_clock::now(); + int ret = prepareTransport(slice_list); + if (ret) { + LOG(ERROR) << "HcclTransport: prepareTransport failed" << ret; + return ret; + } + // auto m1 = std::chrono::high_resolution_clock::now(); + std::vector localMemPool; + std::vector remoteMemPool; + localMemPool.reserve(slice_list.size()); + remoteMemPool.reserve(slice_list.size()); + + // auto m2 = std::chrono::high_resolution_clock::now(); + for (auto slice : slice_list) { + uint64_t len = slice->length; + uint64_t source_addr = reinterpret_cast(slice->source_addr); + uint64_t dest_addr = slice->hccl.dest_addr; + auto addr_type = slice->hccl.dest_addr_type; + + if (len > PER_HUGE_BUFFER_SIZE) { + uint64_t remaining = len; + uint64_t current_src = source_addr; + uint64_t current_dest = dest_addr; + + while (remaining > 0) { + uint64_t chunk_size = + std::min(remaining, (uint64_t)PER_HUGE_BUFFER_SIZE); + + localMemPool.emplace_back(current_src, chunk_size, addr_type); + remoteMemPool.emplace_back(current_dest, chunk_size, addr_type); + + current_src += chunk_size; + current_dest += chunk_size; + remaining -= chunk_size; + } + } else { + localMemPool.emplace_back(source_addr, len, addr_type); + remoteMemPool.emplace_back(dest_addr, len, addr_type); + } + } + // auto m3 = std::chrono::high_resolution_clock::now(); + + ret = aggTransportMemTask( + &local_rank_info_, &remote_rank_info_, localMemPool, remoteMemPool, + slice_list[0]->opcode, stream, slice_list[0]->hccl.dest_addr_type); + if (ret) { + LOG(ERROR) << "HcclTransport: aggTransportMemTask error, local " + "devicePhyId: " + << local_rank_info_.devicePhyId + << ", remote devicePhyId: " + << remote_rank_info_.devicePhyId << ", ret: " << ret; + for (auto slice : slice_list) { + slice->markFailed(); + slice->task->transferred_bytes = slice->length; + } + return ret; + } + // auto m4 = std::chrono::high_resolution_clock::now(); + + for (auto slice : slice_list) { + slice->markSuccess(); + slice->task->transferred_bytes = slice->length; + } + + auto stop = std::chrono::high_resolution_clock::now(); + if (enableAscendLogging()) { + pid_t pid = getpid(); + auto duration_call = + std::chrono::duration_cast(stop - + start); + LOG(INFO) << "pid: " << pid + << ", target hostIp: " << remote_rank_info_.hostIp + << ", local devicePhyId: " << local_rank_info_.devicePhyId + << ", target devicePhyId: " + << remote_rank_info_.devicePhyId + << ", batch transfersync spent: " << duration_call.count() + << "us"; + } else { + (void)start; + (void)stop; + } + + // auto duration_d1 = + // std::chrono::duration_cast(m1 - start); + // auto duration_d2 = + // std::chrono::duration_cast(m2 - m1); auto + // duration_d3 = + // std::chrono::duration_cast(m3 - m2); auto + // duration_d4 = + // std::chrono::duration_cast(m4 - m3); auto + // duration_d5 = + // std::chrono::duration_cast(stop - m4); + // LOG(INFO) << ", local devicePhyId: " << local_rank_info_.devicePhyId + // << ", target devicePhyId: " << remote_rank_info_.devicePhyId + // << ", batch duration_d1 spent: "<< duration_d1.count() << "us" + // << ", batch duration_d2 spent: "<< duration_d2.count() << "us" + // << ", batch duration_d3 spent: "<< duration_d3.count() << "us" + // << ", batch duration_d4 spent: "<< duration_d4.count() << "us" + // << ", batch duration_d5 spent: "<< duration_d5.count() << "us"; + return 0; +} + +void HcclTransport::aggInitiatorLoop(int deviceLogicId) { + aclrtStream stream; + int ret = aclrtSetDevice(deviceLogicId); + if (ret) { + LOG(ERROR) << "HcclTransport: aclrtSetDevice error, ret: " << ret; + } + + ret = aclrtCreateStream(&stream); + if (ret) { + LOG(ERROR) << "HcclTransport: aclrtCreateStream error, ret: " << ret; + } + + while (running_) { + std::unique_lock lock(initiator_mutex_); + initiator_cond_.wait( + lock, [this] { return !allReqQueues_.empty() || !running_; }); + auto slice_list = std::move(allReqQueues_.front()); + allReqQueues_.pop(); + lock.unlock(); + if (slice_list.empty()) { + LOG(ERROR) << "HcclTransport: empty transfer request batch"; + } + bool isAgg = true; + // if (slice_list[0]->hccl.dest_addr_type != 0) { + // size_t minLen = slice_list[0]->length; + // size_t maxLen = slice_list[0]->length; + // for (auto slice : slice_list) { + // minLen = std::min(minLen, slice->length); + // maxLen = std::max(maxLen, slice->length); + // if (maxLen > PER_HUGE_BUFFER_SIZE) { + // isAgg = false; + // break; + // } + // } + + // if (minLen > BLOCK_AGGREGATION_THRESHOLD) { + // isAgg = false; + // } + // } + + if (!isAgg) { + ret = nonAggTransport(slice_list, stream); + if (ret) { + LOG(ERROR) << "HcclTransport: nonAggTransport error, ret: " + << ret; + } + } else { + ret = aggTransport(slice_list, stream); + if (ret) { + LOG(ERROR) << "HcclTransport: aggTransport error, ret: " << ret; + } + } + } +} + +void HcclTransport::aggInitiatorTransferLoop(int deviceLogicId) { + aclrtStream stream; + int ret = aclrtSetDevice(deviceLogicId); + if (ret) { + LOG(ERROR) << "HcclTransport: aclrtSetDevice error, ret: " << ret; + } + + ret = aclrtCreateStream(&stream); + if (ret) { + LOG(ERROR) << "HcclTransport: aclrtCreateStream error, ret: " << ret; + } + + while (running_) { + ret = aggTransportMemTransfer(stream); + if (ret) { + LOG(ERROR) << "HcclTransport: aggTransportMemTransfer error"; + } + } +} + +void HcclTransport::aggTargetAcceptLoop(int deviceLogicId) { + int ret = aclrtSetDevice(deviceLogicId); + if (ret) { + LOG(ERROR) << "HcclTransport: aclrtSetDevice failed ret: " << ret; + } + + while (running_) { + ret = transportMemAccept(&local_rank_info_, aggregateEnabled_); + if (ret) { + LOG(ERROR) << "HcclTransport: transportMemAccept failed ret: " + << ret; + } + } +} + +// Target-side Aggregation/Splitting Processing Thread +void HcclTransport::aggTargetLoop(int deviceLogicId) { + aclrtStream stream; + int ret = aclrtSetDevice(deviceLogicId); + if (ret) { + LOG(ERROR) << "HcclTransport: aclrtSetDevice failed ret:" << ret; + } + + ret = aclrtCreateStream(&stream); + if (ret) { + LOG(ERROR) << "HcclTransport: aclrtCreateStream error, ret:" << ret; + } + + while (running_) { + ret = aggTransportMemTarget(stream); + if (ret) { + LOG(ERROR) << "HcclTransport: aggTransportMemTarget failed ret:" + << ret; + } + } +} + +int HcclTransport::startAggThreads() { + pid_t pid = getpid(); + int ret = 0; + int deviceLogicId; + ret = aclrtGetDevice(&deviceLogicId); + if (ret) { + LOG(ERROR) << "HcclTransport: aclrtGetDevice failed, ret: " << ret; + return ret; + } + + aggInitiatorThread_ = + std::thread(&HcclTransport::aggInitiatorLoop, this, deviceLogicId); + aggInitiatorTransferThread_ = std::thread( + &HcclTransport::aggInitiatorTransferLoop, this, deviceLogicId); + aggTargetAcceptThread_ = + std::thread(&HcclTransport::aggTargetAcceptLoop, this, deviceLogicId); + aggTargetThread_ = + std::thread(&HcclTransport::aggTargetLoop, this, deviceLogicId); + + LOG(INFO) << "HcclTransport: startAggThreads, pid: " << pid + << ", deviceLogicId: " << deviceLogicId; + return 0; +} + +} // namespace mooncake diff --git a/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/hccl_transport.cpp b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/hccl_transport.cpp index b1e1b09a6..178a16408 100644 --- a/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/hccl_transport.cpp +++ b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/hccl_transport.cpp @@ -24,23 +24,168 @@ #include "transport/ascend_transport/hccl_transport/hccl_transport.h" namespace mooncake { -HcclTransport::HcclTransport() : running_(-1) { - // TODO +HcclTransport::HcclTransport() : running_(false) { + const char *aggregateEnv = std::getenv("ASCEND_AGGREGATE_ENABLE"); + aggregateEnabled_ = + (aggregateEnv != nullptr && std::string(aggregateEnv) == "1"); } HcclTransport::~HcclTransport() { if (running_) { running_ = false; + initiator_cond_.notify_all(); + } + + if (aggregateEnabled_) { + if (aggInitiatorThread_.joinable()) { + aggInitiatorThread_.join(); + } + if (aggInitiatorTransferThread_.joinable()) { + aggInitiatorTransferThread_.join(); + } + if (aggTargetAcceptThread_.joinable()) { + aggTargetAcceptThread_.join(); + } + if (aggTargetThread_.joinable()) { + aggTargetThread_.join(); + } + } else { + if (initiatorThread_.joinable()) { + initiatorThread_.join(); + } + if (targetAcceptThread_.joinable()) { + targetAcceptThread_.join(); + } + if (targetThread_.joinable()) { + targetThread_.join(); + } + } +} + +int HcclTransport::prepareTransport(std::vector &slice_list) { + auto segment_desc = metadata_->getSegmentDescByID(slice_list[0]->target_id); + if (!segment_desc) { + LOG(ERROR) + << "Unable to get target segment ID, please recheck, segment ID: " + << slice_list[0]->target_id; + for (auto slice : slice_list) { + slice->markFailed(); + } + return -1; + } + + remote_rank_info_.rankId = segment_desc->rank_info.rankId; + remote_rank_info_.hostPort = segment_desc->rank_info.hostPort; + remote_rank_info_.deviceLogicId = segment_desc->rank_info.deviceLogicId; + remote_rank_info_.devicePhyId = segment_desc->rank_info.devicePhyId; + remote_rank_info_.devicePort = segment_desc->rank_info.devicePort; + remote_rank_info_.serverId = segment_desc->rank_info.serverId; + remote_rank_info_.devPid = segment_desc->rank_info.devPid; + remote_rank_info_.sdid = segment_desc->rank_info.sdid; + + strncpy(remote_rank_info_.hostIp, segment_desc->rank_info.hostIp.c_str(), + 127); + remote_rank_info_.hostIp[127] = '\0'; + + strncpy(remote_rank_info_.deviceIp, + segment_desc->rank_info.deviceIp.c_str(), 127); + remote_rank_info_.deviceIp[127] = '\0'; + + strncpy(remote_rank_info_.vnicIp, segment_desc->rank_info.vnicIp.c_str(), + 127); + remote_rank_info_.vnicIp[127] = '\0'; + + return 0; +} + +int HcclTransport::nonAggTransport(std::vector &slice_list, + aclrtStream stream) { + auto start = std::chrono::high_resolution_clock::now(); + int ret = prepareTransport(slice_list); + if (ret) { + LOG(ERROR) << "HcclTransport: prepareTransport failed" << ret; + return ret; + } + + for (auto slice : slice_list) { + LOG(ERROR) << "slice->hccl.dest_addr_type" + << slice->hccl.dest_addr_type; + ret = nonAggTransportMemTask(&local_rank_info_, &remote_rank_info_, + slice->opcode, slice->hccl.dest_addr, + slice->length, slice->source_addr, + slice->hccl.dest_addr_type, stream); + if (ret) { + LOG(ERROR) << "HcclTransport: nonAggTransportMemTask error, local " + "devicePhyId: " + << local_rank_info_.devicePhyId + << ", remote devicePhyId: " + << remote_rank_info_.devicePhyId + << ", source_addr: " << slice->source_addr + << ", dest_addr: " << slice->hccl.dest_addr + << ", dest_addr_type: " << slice->hccl.dest_addr_type + << ", ret: " << ret; + slice->markFailed(); + slice->status = Slice::SliceStatus::FAILED; + return ret; + } + } - for (size_t i = 0; i < THREAD_NUM; ++i) { - allInitiatorThreads_[i].join(); - allAcceptThreads_[i].join(); + auto taskDispatch = std::chrono::high_resolution_clock::now(); + ret = transportMemAddOpFence(&remote_rank_info_, stream); + if (ret) { + LOG(ERROR) << "transportMemAddOpFence failed, local devicePhyId: " + << local_rank_info_.devicePhyId + << ", remote devicePhyId: " << remote_rank_info_.devicePhyId + << ", ret: " << ret; + for (auto slice : slice_list) { + slice->markFailed(); } + return ret; } - metadata_->removeSegmentDesc(local_server_name_); + + ret = aclrtSynchronizeStream(stream); + if (ret) { + LOG(ERROR) << "aclrtSynchronizeStream failed, local devicePhyId: " + << local_rank_info_.devicePhyId + << ", remote devicePhyId: " << remote_rank_info_.devicePhyId + << ", ret: " << ret; + for (auto slice : slice_list) { + slice->markFailed(); + } + return ret; + } + + for (auto slice : slice_list) { + if (slice->status != Slice::SliceStatus::FAILED) { + slice->markSuccess(); + slice->task->transferred_bytes = slice->length; + } + } + + auto stop = std::chrono::high_resolution_clock::now(); + pid_t pid = getpid(); + if (enableAscendLogging()) { + auto duration_call = + std::chrono::duration_cast(taskDispatch - + start); + auto duration_sync = + std::chrono::duration_cast(stop - + taskDispatch); + LOG(INFO) << "pid: " << pid + << ", target hostIp: " << remote_rank_info_.hostIp + << ", local devicePhyId: " << local_rank_info_.devicePhyId + << ", target devicePhyId: " << remote_rank_info_.devicePhyId + << ", batch call spent: " << duration_call.count() << "us" + << ", batch sync spent: " << duration_sync.count() << "us"; + } else { + (void)start; + (void)taskDispatch; + (void)stop; + } + return 0; } -void HcclTransport::initiatorLoop(int deviceLogicId, int selfIdx) { +void HcclTransport::initiatorLoop(int deviceLogicId) { aclrtStream stream; int ret = aclrtSetDevice(deviceLogicId); if (ret) { @@ -52,134 +197,34 @@ void HcclTransport::initiatorLoop(int deviceLogicId, int selfIdx) { LOG(ERROR) << "HcclTransport: aclrtCreateStream error, ret: " << ret; } - while (1) { - auto waitlock = std::chrono::high_resolution_clock::now(); + while (running_) { std::unique_lock lock(initiator_mutex_); - if (allReqQueues_[selfIdx].empty()) { - initiator_cond_.wait(lock); + initiator_cond_.wait( + lock, [this] { return !allReqQueues_.empty() || !running_; }); + + if (!running_) { + break; } - auto start = std::chrono::high_resolution_clock::now(); - auto slice_list = std::move(allReqQueues_[selfIdx].front()); - allReqQueues_[selfIdx].pop(); + auto slice_list = std::move(allReqQueues_.front()); + allReqQueues_.pop(); lock.unlock(); if (slice_list.empty()) { LOG(ERROR) << "HcclTransport: empty transfer request batch"; } - auto segment_desc = - metadata_->getSegmentDescByID(slice_list[0]->target_id); - if (!segment_desc) { - LOG(ERROR) << "Unable to get target segment ID, please recheck, " - "segment ID: " - << slice_list[0]->target_id; - for (auto slice : slice_list) { - slice->markFailed(); - } - continue; - } - - remote_rank_info_.rankId = segment_desc->rank_info.rankId; - inet_pton(AF_INET, segment_desc->rank_info.hostIp.c_str(), - &remote_rank_info_.hostIp); - remote_rank_info_.hostPort = segment_desc->rank_info.hostPort; - remote_rank_info_.deviceLogicId = segment_desc->rank_info.deviceLogicId; - remote_rank_info_.devicePhyId = segment_desc->rank_info.devicePhyId; - inet_pton(AF_INET, segment_desc->rank_info.deviceIp.c_str(), - &remote_rank_info_.deviceIp); - remote_rank_info_.devicePort = segment_desc->rank_info.devicePort; - remote_rank_info_.serverIdx = 0; - remote_rank_info_.pid = segment_desc->rank_info.pid; - - for (auto slice : slice_list) { - ret = transportMemTask(&local_rank_info_, &remote_rank_info_, - slice->opcode, slice->hccl.dest_addr, - slice->length, slice->source_addr, stream); - if (ret) { - LOG(ERROR) << "HcclTransport: transportMemTask error, local " - "devicePhyId: " - << local_rank_info_.devicePhyId - << ", remote devicePhyId: " - << remote_rank_info_.devicePhyId - << ", source_addr: " << slice->source_addr - << ", dest_addr: " << slice->hccl.dest_addr - << ", ret: " << ret; - slice->markFailed(); - slice->status = Slice::SliceStatus::FAILED; - } - } - - auto mid = std::chrono::high_resolution_clock::now(); - ret = transportMemAddOpFence(&remote_rank_info_, stream); + ret = nonAggTransport(slice_list, stream); if (ret) { - LOG(ERROR) << "transportMemAddOpFence failed, local devicePhyId: " - << local_rank_info_.devicePhyId - << ", remote devicePhyId: " - << remote_rank_info_.devicePhyId << ", ret: " << ret; - for (auto slice : slice_list) { - slice->markFailed(); - } - } - auto addOpfence = std::chrono::high_resolution_clock::now(); - - ret = aclrtSynchronizeStream(stream); - if (ret) { - LOG(ERROR) << "aclrtSynchronizeStream failed, local devicePhyId: " - << local_rank_info_.devicePhyId - << ", remote devicePhyId: " - << remote_rank_info_.devicePhyId << ", ret: " << ret; - for (auto slice : slice_list) { - slice->markFailed(); - } - } - for (auto slice : slice_list) { - if (slice->status != Slice::SliceStatus::FAILED) { - slice->markSuccess(); - slice->task->transferred_bytes = slice->length; - } - } - auto stop = std::chrono::high_resolution_clock::now(); - if (printEnabled()) { - pid_t pid = getpid(); - auto duration_wait = - std::chrono::duration_cast(start - - waitlock); - auto duration_call = - std::chrono::duration_cast(mid - - start); - auto duration_addOpfence = - std::chrono::duration_cast( - addOpfence - mid); - auto duration_sync = - std::chrono::duration_cast( - stop - addOpfence); - LOG(INFO) << "pid: " << pid << ", target hostIp: " - << segment_desc->rank_info.hostIp.c_str() - << ", local devicePhyId: " << local_rank_info_.devicePhyId - << ", target devicePhyId: " - << remote_rank_info_.devicePhyId - << ", batch waitlock spent: " << duration_wait.count() - << "ms" - << ", batch call spent: " << duration_call.count() << "us" - << ", batch addOpfence spent: " - << duration_addOpfence.count() << "us" - << ", batch sync spent: " << duration_sync.count() - << "us"; - } else { - (void)waitlock; - (void)start; - (void)mid; - (void)addOpfence; - (void)stop; + LOG(ERROR) << "HcclTransport: nonAggTransport error, ret: " << ret; } } } -void HcclTransport::acceptLoop(int deviceLogicId) { +void HcclTransport::targetAcceptLoop(int deviceLogicId) { int ret = aclrtSetDevice(deviceLogicId); if (ret) { LOG(ERROR) << "HcclTransport: aclrtSetDevice failed ret: " << ret; } while (running_) { - ret = transportMemAccept(&local_rank_info_); + ret = transportMemAccept(&local_rank_info_, aggregateEnabled_); if (ret) { LOG(ERROR) << "HcclTransport: transportMemAccept failed ret: " << ret; @@ -187,7 +232,7 @@ void HcclTransport::acceptLoop(int deviceLogicId) { } } -int HcclTransport::initPdThread() { +int HcclTransport::startNonAggThreads() { pid_t pid = getpid(); int ret = 0; int deviceLogicId; @@ -197,17 +242,13 @@ int HcclTransport::initPdThread() { return ret; } - for (int i = 0; i < THREAD_NUM; ++i) { - allInitiatorThreads_[i] = - std::thread(&HcclTransport::initiatorLoop, this, deviceLogicId, i); - allAcceptThreads_[i] = - std::thread(&HcclTransport::acceptLoop, this, deviceLogicId); - } + initiatorThread_ = + std::thread(&HcclTransport::initiatorLoop, this, deviceLogicId); + targetAcceptThread_ = + std::thread(&HcclTransport::targetAcceptLoop, this, deviceLogicId); - LOG(INFO) << "HcclTransport: initPdThread, pid: " << pid << ";" << "init " - << THREAD_NUM - << " initiator threads and accept threads, deviceLogicId: " - << deviceLogicId; + LOG(INFO) << "HcclTransport: startNonAggThreads, pid: " << pid + << ", deviceLogicId: " << deviceLogicId; return 0; } @@ -243,110 +284,48 @@ int HcclTransport::getDevIdAndIpPortFromServerName(std::string &identifier, std::string npuStr = identifier.substr(secondColon + 1); if (npuStr.find("npu_") != 0) { LOG(ERROR) << "Invalid npu number format - should start with 'npu_'"; - return -1; + return 0; } try { npuId = std::stoi(npuStr.substr(4)); } catch (const std::exception &e) { LOG(ERROR) << "Invalid device_id ID"; - return -1; + return 0; } return 0; } -int HcclTransport::rankInfoParse(int devicePhyId, std::string hostIp) { - int ret = 0; +int HcclTransport::devInfoParse(std::string hostIp) { int deviceLogicId = 0; + int ret = aclrtGetDevice(&deviceLogicId); ret = aclrtGetDevice(&deviceLogicId); if (ret) { LOG(ERROR) << "HcclTransport: aclrtGetDevice failed, ret: " << ret; return ret; } - // Default configuration file path for HCCL - std::ifstream fin("/etc/hccn.conf"); - if (!fin) { - LOG(ERROR) << "can't open conf 文件:/etc/hccn.conf"; - return -1; - } + local_rank_info_.deviceLogicId = (uint32_t)deviceLogicId; + local_rank_info_.rankId = local_rank_info_.deviceLogicId; - std::string line; - while (std::getline(fin, line)) { - if (line.rfind("address_", 0) == 0) { - size_t equal_pos = line.find('='); - if (equal_pos != std::string::npos) { - std::string key = line.substr(8, equal_pos - 8); - key.erase(key.begin(), std::find_if(key.begin(), key.end(), - [](unsigned char c) { - return !std::isspace(c); - })); - if (key == std::to_string(devicePhyId)) { - std::string deviceIp = line.substr(equal_pos + 1); - deviceIp.erase( - deviceIp.begin(), - std::find_if( - deviceIp.begin(), deviceIp.end(), - [](unsigned char c) { return !std::isspace(c); })); - deviceIp.erase( - std::find_if( - deviceIp.rbegin(), deviceIp.rend(), - [](unsigned char c) { return !std::isspace(c); }) - .base(), - deviceIp.end()); - - if (inet_pton(AF_INET, hostIp.c_str(), - &local_rank_info_.hostIp) != 1) { - LOG(ERROR) << "HcclTransport: Invalid Host IP format: " - << hostIp; - return -1; - } - local_rank_info_.rankId = devicePhyId; - local_rank_info_.serverIdx = 0; - local_rank_info_.devicePhyId = devicePhyId; - local_rank_info_.hostPort = - ASCEND_DEFAULT_HOST_PORT + devicePhyId; - local_rank_info_.deviceLogicId = deviceLogicId; - local_rank_info_.devicePort = ASCEND_DEFAULT_DEVICE_PORT; - local_rank_info_.pid = 0; - if (inet_pton(AF_INET, deviceIp.c_str(), - &local_rank_info_.deviceIp) != 1) { - LOG(ERROR) - << "HcclTransport: Invalid Device IP format: " - << deviceIp; - return -1; - } - LOG(INFO) - << "rankInfoParse Success, hostIp: " << hostIp - << ", rankId: " << local_rank_info_.rankId - << ", serverIdx: " << local_rank_info_.serverIdx - << ", devicePhyId: " << local_rank_info_.devicePhyId - << ", hostPort: " << local_rank_info_.hostPort - << ", deviceLogicId: " << local_rank_info_.deviceLogicId - << ", devicePort: " << local_rank_info_.devicePort - << ", deviceIp: " << deviceIp - << ", device pid: " << local_rank_info_.pid; - // Exit after finishing rankInfoParse - return 0; - } - } - } - } - // Not Found - return -1; + strncpy(local_rank_info_.hostIp, hostIp.c_str(), 127); + local_rank_info_.hostIp[127] = '\0'; + + local_rank_info_.devicePort = ASCEND_DEFAULT_DEVICE_PORT; + + return 0; } int HcclTransport::install(std::string &local_server_name, std::shared_ptr meta, std::shared_ptr topo) { - int ret = 0; int port; std::string hostIp; - int devicePhyId; + int devicePhyId = 0; metadata_ = meta; - ret = getDevIdAndIpPortFromServerName(local_server_name, hostIp, port, - devicePhyId); + int ret = getDevIdAndIpPortFromServerName(local_server_name, hostIp, port, + devicePhyId); if (ret) { LOG(ERROR) << "HcclTransport: getDevIdAndIpPortFromServerName failed, ret: " @@ -360,13 +339,13 @@ int HcclTransport::install(std::string &local_server_name, << devicePhyId << ", local_server_name: " << local_server_name; // add to local_rank_info_ - ret = rankInfoParse(devicePhyId, hostIp); + ret = devInfoParse(hostIp); if (ret) { - LOG(ERROR) << "HcclTransport: rankInfoParse failed, ret: " << ret; + LOG(ERROR) << "HcclTransport: devInfoParse failed, ret: " << ret; return ret; } - ret = initTransportMem(&local_rank_info_); + ret = initTransportMem(&local_rank_info_, aggregateEnabled_); if (ret) { LOG(ERROR) << "HcclTransport: initTransportMem failed, ret: " << ret; return ret; @@ -387,10 +366,37 @@ int HcclTransport::install(std::string &local_server_name, return ret; } - ret = initPdThread(); - if (ret) { - LOG(ERROR) << "HcclTransport: initPdThread failed, ret: " << ret; - return ret; + running_ = true; + + if (aggregateEnabled_) { + ret = startAggThreads(); + if (ret) { + LOG(ERROR) << "HcclTransport: startAggThreads failed, ret: " << ret; + return ret; + } + + for (size_t i = 0; i < HUGE_BUFFER_NUM; i++) { + void *devAddr = nullptr; + ret = aclrtMalloc(&devAddr, PER_HUGE_BUFFER_SIZE, + ACL_MEM_MALLOC_NORMAL_ONLY); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "Failed to allocate device memory, ret:" << ret; + return ret; + } + const uint64_t alignment = 1 << 21; + if ((uint64_t)devAddr % alignment != 0) { + LOG(ERROR) << "The Merge malloc address is not 2M aligned."; + return -1; + } + aggRegLocalMem((uint64_t)devAddr, PER_HUGE_BUFFER_SIZE, true); + } + } else { + ret = startNonAggThreads(); + if (ret) { + LOG(ERROR) << "HcclTransport: startNonAggThreads failed, ret: " + << ret; + return ret; + } } return 0; @@ -420,6 +426,7 @@ Status HcclTransport::submitTransfer( slice->length = request.length; slice->opcode = request.opcode; slice->hccl.dest_addr = request.target_offset; + slice->hccl.dest_addr_type = request.target_offset_type; slice->task = &task; slice->target_id = request.target_id; slice->status = Slice::PENDING; @@ -428,10 +435,9 @@ Status HcclTransport::submitTransfer( std::unique_lock lock(initiator_mutex_); slice_list.push_back(slice); lock.unlock(); - initiator_cond_.notify_one(); } std::unique_lock lock(initiator_mutex_); - allReqQueues_[0].push(slice_list); + allReqQueues_.push(slice_list); lock.unlock(); initiator_cond_.notify_one(); @@ -453,6 +459,7 @@ Status HcclTransport::submitTransferTask( slice->length = request.length; slice->opcode = request.opcode; slice->hccl.dest_addr = request.target_offset; + slice->hccl.dest_addr_type = request.target_offset_type; slice->task = &task; slice->target_id = request.target_id; slice->status = Slice::PENDING; @@ -462,7 +469,7 @@ Status HcclTransport::submitTransferTask( slice_list.push_back(slice); } std::unique_lock lock(initiator_mutex_); - allReqQueues_[0].push(slice_list); + allReqQueues_.push(slice_list); lock.unlock(); initiator_cond_.notify_one(); @@ -495,6 +502,7 @@ Status HcclTransport::getTransferStatus(BatchID batch_id, size_t task_id, return Status::OK(); } +uint64_t g_transfer_dev_len = 0x4000000; int HcclTransport::registerLocalMemory(void *addr, size_t length, const std::string &location, bool remote_accessible, @@ -506,12 +514,31 @@ int HcclTransport::registerLocalMemory(void *addr, size_t length, buffer_desc.length = (uint64_t)length; int ret; - ret = regLocalRmaMem(addr, (uint64_t)length); - if (ret) { - LOG(ERROR) << "HcclTransport: reglocalRmaMem failed, ret: " << ret; - return ret; + if (location == "cpu") { + LOG(INFO) << "kv pool HcclTransport: registerLocalMemory: " << addr + << ", length:" << length; + // if (!aggregateEnabled_) { + // void *dev_addr = NULL; + // ret = aclrtMalloc(&dev_addr, g_transfer_dev_len, + // ACL_MEM_MALLOC_NORMAL_ONLY); + // if (ret != ACL_ERROR_NONE) { + // LOG(ERROR) << "failed to allocate device memory len:" + // << g_transfer_dev_len; + // return ret; + // } + // nonAggRegLocalMem((uint64_t)dev_addr, g_transfer_dev_len, true); + // } + } else { + if (aggregateEnabled_) { + // 强制开聚合,这里不注册内存 + // aggRegLocalMem(buffer_desc.addr, buffer_desc.length, false); + ret = 0; + } else { + LOG(INFO) << "aggregateEnabled_:FASLE, P2P HcclTransport: registerLocalMemory: " << buffer_desc.addr + << ", length:" << buffer_desc.length; + nonAggRegLocalMem(buffer_desc.addr, buffer_desc.length, false); + } } - ret = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); if (ret) { LOG(ERROR) << "HcclTransport: addLocalMemoryBuffer failed, ret: " @@ -532,13 +559,16 @@ int HcclTransport::allocateLocalSegmentID() { desc->name = local_server_name_; desc->protocol = "ascend"; desc->rank_info.rankId = local_rank_info_.rankId; - desc->rank_info.hostIp = inet_ntoa(local_rank_info_.hostIp); + desc->rank_info.hostIp = local_rank_info_.hostIp; desc->rank_info.hostPort = local_rank_info_.hostPort; desc->rank_info.deviceLogicId = local_rank_info_.deviceLogicId; desc->rank_info.devicePhyId = local_rank_info_.devicePhyId; - desc->rank_info.deviceIp = inet_ntoa(local_rank_info_.deviceIp); + desc->rank_info.deviceIp = local_rank_info_.deviceIp; desc->rank_info.devicePort = local_rank_info_.devicePort; - desc->rank_info.pid = local_rank_info_.pid; + desc->rank_info.devPid = local_rank_info_.devPid; + desc->rank_info.sdid = local_rank_info_.sdid; + desc->rank_info.serverId = local_rank_info_.serverId; + desc->rank_info.vnicIp = local_rank_info_.vnicIp; metadata_->addLocalSegment(LOCAL_SEGMENT_ID, local_server_name_, std::move(desc)); diff --git a/scripts/ascend/perf/llmdatadist_bandwidth_test_cross_machine_demo.py b/scripts/ascend/perf/llmdatadist_bandwidth_test_cross_machine_demo.py index 6b73acbe2..7ce909573 100644 --- a/scripts/ascend/perf/llmdatadist_bandwidth_test_cross_machine_demo.py +++ b/scripts/ascend/perf/llmdatadist_bandwidth_test_cross_machine_demo.py @@ -69,7 +69,7 @@ def link(datadist, prefill_device_id, decode_device_id): "rank_id": "0" } ], - "server_id": "1" + "server_id": "0" }, { "device": [ @@ -79,7 +79,7 @@ def link(datadist, prefill_device_id, decode_device_id): "rank_id": "1" } ], - "server_id": "2" + "server_id": "1" } ] } diff --git a/scripts/ascend/pkg/ReadMe.txt b/scripts/ascend/pkg/ReadMe.txt deleted file mode 100644 index 424eaba87..000000000 --- a/scripts/ascend/pkg/ReadMe.txt +++ /dev/null @@ -1,18 +0,0 @@ -【替换命令】 -arm环境: -cp transport_mem.h /usr/local/Ascend/ascend-toolkit/latest/aarch64-linux/include/experiment/hccl/ -cp hccl_mem.h /usr/local/Ascend/ascend-toolkit/latest/aarch64-linux/include/experiment/hccl/ -cp hccl_mem_defs.h /usr/local/Ascend/ascend-toolkit/latest/aarch64-linux/include/experiment/hccl/ -cp transport_mem.h /usr/local/Ascend/ascend-toolkit/8.1.RC1/aarch64-linux/include/experiment/hccl/ -cp hccl_mem.h /usr/local/Ascend/ascend-toolkit/8.1.RC1/aarch64-linux/include/experiment/hccl/ -cp hccl_mem_defs.h /usr/local/Ascend/ascend-toolkit/8.1.RC1/aarch64-linux/include/experiment/hccl/ -cp arm/libhccl* /usr/local/Ascend/ascend-toolkit/8.1.RC1/aarch64-linux/lib64/ - -x86环境: -cp transport_mem.h /usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/include/experiment/hccl/ -cp hccl_mem.h /usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/include/experiment/hccl/ -cp hccl_mem_defs.h /usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/include/experiment/hccl/ -cp transport_mem.h /usr/local/Ascend/ascend-toolkit/8.1.RC1/x86_64-linux/include/experiment/hccl/ -cp hccl_mem.h /usr/local/Ascend/ascend-toolkit/8.1.RC1/x86_64-linux/include/experiment/hccl/ -cp hccl_mem_defs.h /usr/local/Ascend/ascend-toolkit/8.1.RC1/x86_64-linux/include/experiment/hccl/ -cp x86/libhccl* /usr/local/Ascend/ascend-toolkit/8.1.RC1/x86_64-linux/lib64/ diff --git a/scripts/ascend/pkg/hccl_mem.h b/scripts/ascend/pkg/hccl_mem.h deleted file mode 100644 index 0cd7e83c6..000000000 --- a/scripts/ascend/pkg/hccl_mem.h +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. - * Description: HCCL内存管理接口,提供跨设备内存注册与访问能力 - */ - -#ifndef HCCL_MEM_H -#define HCCL_MEM_H - -#include "hccl_types.h" -#include "hccl_mem_defs.h" -#include - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -/* 网络设备句柄 */ -typedef void *HcclNetDev; - -/** - * @struct HcclBuf - * @brief 内存缓冲区描述结构体 - * @var addr - 虚拟地址指针 - * @var len - 内存长度(单位字节) - * @var handle - 内存管理句柄 - */ -typedef struct { - void *addr; - uint64_t len; - void *handle; -} HcclBuf; - -/** - * @brief 注册设备可访问内存 - * @param[in] netDev 待绑定的网络设备 - * @param[in] mem 要注册的原始内存 - * @param[out] buf 返回的缓冲区描述符 - * @return 执行状态码 HcclResult - */ -extern HcclResult HcclMemReg(HcclNetDev netDev, const HcclMem *mem, - HcclBuf *buf); - -/** - * @brief 注销已注册的内存区域 - * @param[in] buf 要注销的缓冲区描述符 - * @return 执行状态码 HcclResult - */ -extern HcclResult HcclMemDereg(const HcclBuf *buf); - -/** - * @brief 获取内存描述信息 - * @param[in] buf 已注册的缓冲区 - * @param[out] outDesc 返回描述信息指针(调用方不要释放) - * @param[out] outDescLen 返回描述信息长度 - * @return 执行状态码 HcclResult - */ -extern HcclResult HcclMemExport(HcclBuf *buf, char **outDesc, - uint64_t *outDescLen); - -/** - * @brief 通过描述信息重建内存缓冲区 - * @param[in] description 序列化的描述信息 - * @param[in] descLen 描述信息长度 - * @param[in] isRemote 是否远端访问标识 - * @param[out] outBuf 返回的缓冲区描述符 - * @return 执行状态码 HcclResult - */ -extern HcclResult HcclMemImport(const char *description, uint32_t descLen, - bool isRemote, HcclBuf *outBuf); - -/** - * @brief 关闭已打开的内存缓冲区 - * @param[in] buf 要关闭的缓冲区描述符 - * @return 执行状态码 HcclResult - */ -extern HcclResult HcclMemClose(HcclBuf *buf); - -/** - * @struct HcclMemGrantInfo - * @brief 内存授权信息结构体 - * @var remoteSdid - 目标设备的SuperPod ID - * @var remotePid - 目标进程的进程ID - */ -typedef struct { - uint32_t remoteSdid; - int32_t remotePid; -} HcclMemGrantInfo; - -/** - * @brief 授权本机内存给指定远端进程 - * @param[in] localBuf 本地缓冲区描述符 - * @param[in] remoteGrantInfo 远端授权目标信息 - * @return 执行状态码 HcclResult - */ -extern HcclResult HcclMemGrant(HcclBuf *localBuf, - const HcclMemGrantInfo *remoteGrantInfo); - -/** - * @brief 内存重映射接口 - * @param[in] netDev 目标网络设备 - * @param[in] memArray 内存段数组指针 - * @param[in] arraySize 内存段数组长度 - * @return 执行状态码 HcclResult - * @attention 需确保内存段已经在目标网络设备注册 - */ -extern HcclResult HcclMemRemap(HcclNetDev netDev, const HcclMem *memArray, - uint64_t arraySize); - -#ifdef __cplusplus -} -#endif // __cplusplus -#endif \ No newline at end of file diff --git a/scripts/ascend/pkg/hccl_mem_defs.h b/scripts/ascend/pkg/hccl_mem_defs.h deleted file mode 100644 index 4f52d881d..000000000 --- a/scripts/ascend/pkg/hccl_mem_defs.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. - * Description: HCCL内存类型定义文件,声明内存数据类型 - */ - -#ifndef HCCL_MEM_DEFS_H -#define HCCL_MEM_DEFS_H - -#include - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -/** - * @enum HcclMemType - * @brief 内存类型枚举定义 - */ -typedef enum { - HCCL_MEM_TYPE_DEVICE, ///< 设备侧内存(如NPU等) - HCCL_MEM_TYPE_HOST, ///< 主机侧内存 - HCCL_MEM_TYPE_NUM ///< 内存类型数量 -} HcclMemType; - -/** - * @struct HcclMem - * @brief 内存段元数据描述结构体 - * @var type - 内存物理位置类型,参见HcclMemType - * @var addr - 内存虚拟地址 - * @var size - 内存区域字节数 - */ -typedef struct { - HcclMemType type; - void *addr; - uint64_t size; -} HcclMem; - -#ifdef __cplusplus -} -#endif // __cplusplus -#endif \ No newline at end of file diff --git a/scripts/ascend/pkg/transport_mem.h b/scripts/ascend/pkg/transport_mem.h deleted file mode 100644 index 51e10b245..000000000 --- a/scripts/ascend/pkg/transport_mem.h +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. - * Description: - * Author: l30050806 - * Create: 2024-12-27 - */ - -#ifndef TRANSPORT_MEM_H -#define TRANSPORT_MEM_H - -#include -#include -#include "dispatcher.h" -#include "notify_pool.h" -#include "hccl_socket.h" -#include "hccl_network_pub.h" -#include "hccl_common.h" - -namespace hccl { - -enum class RmaMemType : int { - DEVICE = 0, // device侧内存 - HOST = 1, // host侧内存 - TYPE_NUM -}; - -constexpr size_t TRANSPORT_EMD_ESC_SIZE = 512U - (sizeof(u32) * 2); - -class TransportMem { - public: - enum class TpType : int { IPC = 0, ROCE = 1, TYPE_NUM }; - - struct AttrInfo { - u32 localRankId; - u32 remoteRankId; - u32 sdid; // 本端所属超节点 - u32 serverId; // 本端所属server - }; - - struct RmaMemDesc { - u32 localRankId; - u32 remoteRankId; - char memDesc[TRANSPORT_EMD_ESC_SIZE]; - }; - - struct RmaMemDescs { - RmaMemDesc *array; - u32 arrayLength; - }; - - struct RmaOpMem { - void *addr; - u64 size; - }; - - struct RmaMem { - RmaMemType type; // segment的内存类型 - void *addr; // segment的虚拟地址 - u64 size; // segment的size - }; - - static std::shared_ptr Create( - TpType tpType, const std::unique_ptr ¬ifyPool, - const HcclNetDevCtx &netDevCtx, const HcclDispatcher &dispatcher, - AttrInfo &attrInfo); - - explicit TransportMem(const std::unique_ptr ¬ifyPool, - const HcclNetDevCtx &netDevCtx, - const HcclDispatcher &dispatcher, AttrInfo &attrInfo); - virtual ~TransportMem(); - virtual HcclResult ExchangeMemDesc(const RmaMemDescs &localMemDescs, - RmaMemDescs &remoteMemDescs, - u32 &actualNumOfRemote) = 0; - virtual HcclResult EnableMemAccess(const RmaMemDesc &remoteMemDesc, - RmaMem &remoteMem) = 0; - virtual HcclResult DisableMemAccess(const RmaMemDesc &remoteMemDesc) = 0; - virtual HcclResult SetDataSocket(const std::shared_ptr &socket); - - virtual HcclResult SetSocket(const std::shared_ptr &socket) = 0; - virtual HcclResult Connect(s32 timeoutSec) = 0; - virtual HcclResult Write(const RmaOpMem &remoteMem, - const RmaOpMem &localMem, - const rtStream_t &stream) = 0; - virtual HcclResult Read(const RmaOpMem &localMem, const RmaOpMem &remoteMem, - const rtStream_t &stream) = 0; - virtual HcclResult AddOpFence(const rtStream_t &stream) = 0; - - protected: - // 从 string 拷贝到 memDesc - HcclResult RmaMemDescCopyFromStr(RmaMemDesc &rmaMemDesc, - const std::string &memDescStr) const { - if (memcpy_s(rmaMemDesc.memDesc, TRANSPORT_EMD_ESC_SIZE, - memDescStr.c_str(), memDescStr.size() + 1) != EOK) { - return HCCL_E_INTERNAL; - } - return HCCL_SUCCESS; - } - - // 从 memDesc 转换为 string - std::string RmaMemDescCopyToStr(const RmaMemDesc &rmaMemDesc) const { - return std::string(rmaMemDesc.memDesc, TRANSPORT_EMD_ESC_SIZE); - } - - HcclResult DoExchangeMemDesc(const RmaMemDescs &localMemDescs, - RmaMemDescs &remoteMemDescs, - u32 &actualNumOfRemote); - HcclResult SendLocalMemDesc(const RmaMemDescs &localMemDescs); - HcclResult ReceiveRemoteMemDesc(RmaMemDescs &remoteMemDescs, - u32 &actualNumOfRemote); - - const std::unique_ptr ¬ifyPool_; - HcclNetDevCtx netDevCtx_{nullptr}; - HcclDispatcher dispatcher_{nullptr}; - - u32 localRankId_{0}; - u32 remoteRankId_{0}; - std::shared_ptr socket_{nullptr}; - - std::shared_ptr dataSocket_{nullptr}; -}; -} // namespace hccl -#endif \ No newline at end of file