From 0a54e59380436ac6f2c476025e18af10331aca59 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian Date: Tue, 6 Jun 2023 14:00:06 -0400 Subject: [PATCH] Fix error handling for GPU tensors (#249) * Fix error handling for GPU tensors * Fix GPU buffer handling * Review edit * Fix for dynamically batched responses with GPU tensor * Review edits * Fix unused i variable for GPU=OFF * Review comments * Review edit --- CMakeLists.txt | 2 + src/gpu_buffers.cc | 88 ++++++++++++++++++++++++++++++ src/gpu_buffers.h | 67 +++++++++++++++++++++++ src/infer_request.cc | 11 +++- src/infer_response.cc | 54 +++++++----------- src/infer_response.h | 13 +++-- src/pb_stub.cc | 60 +++++++++++--------- src/pb_utils.h | 10 +--- src/python_be.cc | 121 +++++++++++++---------------------------- src/response_sender.cc | 26 +++++---- 10 files changed, 280 insertions(+), 172 deletions(-) create mode 100644 src/gpu_buffers.cc create mode 100644 src/gpu_buffers.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 213b1927..3659c7bd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -163,6 +163,8 @@ set( src/metric.cc src/metric_family.h src/metric_family.cc + src/gpu_buffers.cc + src/gpu_buffers.h ) set( diff --git a/src/gpu_buffers.cc b/src/gpu_buffers.cc new file mode 100644 index 00000000..6b370ea1 --- /dev/null +++ b/src/gpu_buffers.cc @@ -0,0 +1,88 @@ +// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
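The two new files introduced here, src/gpu_buffers.cc and src/gpu_buffers.h, centralize how GPU buffer handles and staging errors are exchanged through shared memory. As a rough, self-contained model of the protocol they implement (plain uint64_t values stand in for bi::managed_external_buffer::handle_t and an ordinary struct stands in for shared memory; none of these names are the backend's types):

// Simplified standalone model of the GPUBuffersShm/GPUBuffersHelper protocol
// added below. This is an illustration only, not the backend's implementation.
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// The record that crosses the process boundary: either a list of buffer
// handles (success) or an error message (failure), never both.
struct GpuBuffersRecord {
  bool success = true;
  std::string error;                   // Valid only when success is false.
  std::vector<std::uint64_t> buffers;  // Valid only when success is true.
};

class GpuBuffersHelperModel {
 public:
  void AddBuffer(std::uint64_t handle)
  {
    if (completed_) {
      throw std::logic_error("cannot add buffers after Complete()");
    }
    buffers_.push_back(handle);
  }

  void SetError(const std::string& message) { error_ = message; }

  // Publishes either the collected handles or the error, exactly once.
  GpuBuffersRecord Complete()
  {
    if (completed_) {
      throw std::logic_error("Complete() may only be called once");
    }
    completed_ = true;
    GpuBuffersRecord record;
    if (error_) {
      record.success = false;
      record.error = *error_;
    } else {
      record.buffers = buffers_;
    }
    return record;
  }

 private:
  std::vector<std::uint64_t> buffers_;
  std::optional<std::string> error_;
  bool completed_ = false;
};

int
main()
{
  GpuBuffersHelperModel helper;
  helper.AddBuffer(1);                                 // One staged buffer handle.
  helper.SetError("failed to allocate CUDA memory");   // Simulated staging failure.
  GpuBuffersRecord record = helper.Complete();         // Publishes the error, not the handles.
  return record.success ? 0 : 1;
}

In the diff that follows, AddBuffer() is called once per GPU tensor, SetError() replaces the handle list whenever staging throws, and Complete() is called exactly once before the other process is signalled.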
+
+#include "gpu_buffers.h"
+#include "pb_string.h"
+
+namespace triton { namespace backend { namespace python {
+GPUBuffersHelper::GPUBuffersHelper()
+{
+  completed_ = false;
+}
+
+void
+GPUBuffersHelper::AddBuffer(const bi::managed_external_buffer::handle_t& handle)
+{
+  if (completed_) {
+    throw PythonBackendException(
+        "It is not possible to add buffers after 'Complete' has been called on "
+        "a GPUBuffersHelper.");
+  }
+
+  buffers_.emplace_back(handle);
+}
+
+void
+GPUBuffersHelper::SetError(
+    std::unique_ptr<SharedMemoryManager>& shm_pool, const std::string& error)
+{
+  error_shm_ = PbString::Create(shm_pool, error);
+}
+
+void
+GPUBuffersHelper::Complete(std::unique_ptr<SharedMemoryManager>& shm_pool)
+{
+  if (completed_) {
+    throw PythonBackendException(
+        "Complete has already been called. Complete should only be called "
+        "once.");
+  }
+  gpu_buffers_shm_ = shm_pool->Construct<GPUBuffersShm>();
+  if (!error_shm_) {
+    buffers_handle_shm_ =
+        shm_pool->Construct<bi::managed_external_buffer::handle_t>(
+            buffers_.size());
+    gpu_buffers_shm_.data_->buffer_count = buffers_.size();
+    gpu_buffers_shm_.data_->success = true;
+    gpu_buffers_shm_.data_->buffers = buffers_handle_shm_.handle_;
+    for (size_t i = 0; i < buffers_.size(); ++i) {
+      buffers_handle_shm_.data_.get()[i] = buffers_[i];
+    }
+  } else {
+    gpu_buffers_shm_.data_->success = false;
+    gpu_buffers_shm_.data_->error = error_shm_->ShmHandle();
+  }
+  completed_ = true;
+}
+
+
+bi::managed_external_buffer::handle_t
+GPUBuffersHelper::ShmHandle()
+{
+  return gpu_buffers_shm_.handle_;
+}
+
+}}}  // namespace triton::backend::python
diff --git a/src/gpu_buffers.h b/src/gpu_buffers.h
new file mode 100644
index 00000000..fd683ba7
--- /dev/null
+++ b/src/gpu_buffers.h
@@ -0,0 +1,67 @@
+// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//   * Redistributions of source code must retain the above copyright
+//     notice, this list of conditions and the following disclaimer.
+//   * Redistributions in binary form must reproduce the above copyright
+//     notice, this list of conditions and the following disclaimer in the
+//     documentation and/or other materials provided with the distribution.
+//   * Neither the name of NVIDIA CORPORATION nor the names of its
+//     contributors may be used to endorse or promote products derived
+//     from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#pragma once
+
+#include "pb_string.h"
+#include "pb_utils.h"
+#include "scoped_defer.h"
+
+namespace triton { namespace backend { namespace python {
+
+/// \param success indicating whether the process of fetching the GPU buffers
+/// was successful.
+/// \param error if success is false, the error object will be set.
+/// \param buffers list of buffer handles.
+/// \param buffer_count the number of buffers.
+struct GPUBuffersShm {
+  bool success;
+  bi::managed_external_buffer::handle_t error;
+  bi::managed_external_buffer::handle_t buffers;
+  uint32_t buffer_count;
+};
+
+/// Helper class to facilitate transfer of metadata associated with
+/// the GPU buffers in shared memory.
+class GPUBuffersHelper {
+ public:
+  GPUBuffersHelper();
+  void AddBuffer(const bi::managed_external_buffer::handle_t& handle);
+  void Complete(std::unique_ptr<SharedMemoryManager>& shm_pool);
+  void SetError(
+      std::unique_ptr<SharedMemoryManager>& shm_pool, const std::string& error);
+  bi::managed_external_buffer::handle_t ShmHandle();
+
+ private:
+  AllocatedSharedMemory<GPUBuffersShm> gpu_buffers_shm_;
+  std::vector<bi::managed_external_buffer::handle_t> buffers_;
+  AllocatedSharedMemory<bi::managed_external_buffer::handle_t>
+      buffers_handle_shm_;
+  std::unique_ptr<PbString> error_shm_;
+  bool completed_;
+};
+
+}}};  // namespace triton::backend::python
diff --git a/src/infer_request.cc b/src/infer_request.cc
index 2a9799db..3ecde9e8 100644
--- a/src/infer_request.cc
+++ b/src/infer_request.cc
@@ -28,6 +28,7 @@
 
 #include 
 
+#include "gpu_buffers.h"
 #include "pb_utils.h"
 #include "scoped_defer.h"
 #ifdef TRITON_PB_STUB
@@ -481,11 +482,19 @@ InferRequest::Exec(const bool is_decoupled)
       // Additional round trip required for asking the stub process
       // to fill in the GPU tensor buffers
       if (has_gpu_tensor) {
+        AllocatedSharedMemory<GPUBuffersShm> gpu_buffers_shm =
+            shm_pool->Load<GPUBuffersShm>(
+                request_batch_shm_ptr->gpu_buffers_handle);
         AllocatedSharedMemory<bi::managed_external_buffer::handle_t>
             gpu_buffers_handle = shm_pool->Load<bi::managed_external_buffer::handle_t>(
-                request_batch_shm_ptr->gpu_buffers_handle);
+                gpu_buffers_shm.data_->buffers);
 
         try {
+          if (!gpu_buffers_shm.data_->success) {
+            std::unique_ptr<PbString> error = PbString::LoadFromSharedMemory(
+                shm_pool, gpu_buffers_shm.data_->error);
+            throw PythonBackendException(error->String());
+          }
 #ifdef TRITON_ENABLE_GPU
           size_t i = 0;
           for (auto& input_tensor : this->Inputs()) {
diff --git a/src/infer_response.cc b/src/infer_response.cc
index 4defd74b..668a03d1 100644
--- a/src/infer_response.cc
+++ b/src/infer_response.cc
@@ -201,64 +201,50 @@ InferResponse::IsLastResponse()
 }
 
 #ifndef TRITON_PB_STUB
-std::shared_ptr<TRITONSERVER_Error*>
+void
 InferResponse::Send(
-    TRITONBACKEND_ResponseFactory* response_factory, void* cuda_stream,
+    TRITONBACKEND_Response* response, void* cuda_stream,
     bool& requires_deferred_callback, const uint32_t flags,
     std::unique_ptr<SharedMemoryManager>& shm_pool,
+    GPUBuffersHelper& gpu_buffer_helper,
     std::vector<std::pair<std::unique_ptr<PbMemory>, void*>>& output_buffers,
-    const std::set<std::string>& requested_output_names,
-    TRITONBACKEND_Response* response)
+    const std::set<std::string>& requested_output_names)
 {
   std::shared_ptr<TRITONSERVER_Error*> response_error =
       WrapTritonErrorInSharedPtr(nullptr);
   std::unique_ptr<ScopedDefer> response_error_handling;
   requires_deferred_callback = false;
 
-  // Should only destruct the response factory whenever a response factory is
-  // being created.
-  bool destruct_response_factor = (response == nullptr);
-
-  if (response == nullptr) {
-    SET_ERROR_AND_RETURN(
-        response_error,
-        TRITONBACKEND_ResponseNewFromFactory(&response, response_factory));
-  }
-
   // This lambda expression will be called when this function exits, if the
   // inference response doesn't have any GPU tensors. Otherwise, it will be
   // called when the object is destructed or DeferredSendCallback is called.
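The comment above is the crux of the reworked Send(): the TRITONBACKEND_ResponseSend call now lives inside a deferred guard that either fires when Send() returns (no GPU tensors) or is moved into deferred_send_callback_ and fired only after the stub has filled the GPU buffers. A minimal standalone sketch of that move-to-defer pattern (plain C++, not the backend's ScopedDefer):

// Standalone illustration of the deferred send callback. Moving the guard out
// of the function postpones the callback until the new owner is destroyed.
#include <functional>
#include <iostream>
#include <memory>
#include <utility>

class DeferGuard {
 public:
  explicit DeferGuard(std::function<void()> fn) : fn_(std::move(fn)) {}
  ~DeferGuard()
  {
    if (fn_) {
      fn_();  // Fires when the owning object is destroyed.
    }
  }

 private:
  std::function<void()> fn_;
};

int
main()
{
  std::unique_ptr<DeferGuard> deferred_send_callback;
  {
    auto send_response = std::make_unique<DeferGuard>(
        [] { std::cout << "response sent\n"; });
    bool has_gpu_tensors = true;  // Hypothetical condition for the sketch.
    if (has_gpu_tensors) {
      // Mirrors deferred_send_callback_ = std::move(response_error_handling):
      // ownership moves out, so sending does not happen at the end of scope.
      deferred_send_callback = std::move(send_response);
    }
  }
  std::cout << "GPU buffers filled by the stub\n";
  return 0;
}  // deferred_send_callback is destroyed here and the response is finally sent.

The same guard also observes the shared error pointer, which is why the new ScopedDefer in the diff below can report GPU-staging failures through gpu_buffer_helper.SetError() before the send fires.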
- response_error_handling = std::make_unique( - [response, response_error, flags, response_factory, - destruct_response_factor] { + response_error_handling = + std::make_unique([response, response_error, flags] { if (response != nullptr) { LOG_IF_ERROR( TRITONBACKEND_ResponseSend(response, flags, *response_error), "failed to send the response."); - if (flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL && - destruct_response_factor) { - std::unique_ptr< - TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> - response_factory_ptr( - reinterpret_cast( - response_factory)); - } } }); // Moves the response sending callback so that it is not called until the stub // process fills in the GPU buffers. - ScopedDefer deferred_task( - [this, &requires_deferred_callback, &response_error_handling] { - if (requires_deferred_callback) { - deferred_send_callback_ = std::move(response_error_handling); - } - }); + ScopedDefer deferred_task([this, &requires_deferred_callback, + &response_error_handling, &gpu_buffer_helper, + response_error, &shm_pool] { + if (*response_error != nullptr) { + gpu_buffer_helper.SetError( + shm_pool, TRITONSERVER_ErrorMessage(*response_error)); + } + if (requires_deferred_callback) { + deferred_send_callback_ = std::move(response_error_handling); + } + }); if (HasError()) { *response_error = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INTERNAL, Error()->Message().c_str()); - return nullptr; + return; } bool cuda_copy = false; @@ -322,6 +308,7 @@ InferResponse::Send( output_tensor->ByteSize(), reinterpret_cast(buffer), true /* copy_gpu */)); } + gpu_buffer_helper.AddBuffer(output_buffer->ShmHandle()); output_buffers.push_back({std::move(output_buffer), buffer}); #endif } @@ -336,6 +323,7 @@ InferResponse::Send( shm_pool, actual_memory_type, actual_memory_type_id, output_tensor->ByteSize(), nullptr /* data ptr */)); + gpu_buffer_helper.AddBuffer(output_buffer->ShmHandle()); output_buffers.push_back({std::move(output_buffer), buffer}); } @@ -357,8 +345,6 @@ InferResponse::Send( cudaStreamSynchronize(reinterpret_cast(cuda_stream)); } #endif // TRITON_ENABLE_GPU - - return response_error; } #endif diff --git a/src/infer_response.h b/src/infer_response.h index 9197df4e..330354a1 100644 --- a/src/infer_response.h +++ b/src/infer_response.h @@ -27,6 +27,7 @@ #pragma once #include +#include "gpu_buffers.h" #include "pb_error.h" #include "pb_tensor.h" #include "pb_utils.h" @@ -49,7 +50,7 @@ struct ResponseShm { TRITONSERVER_Error* raasnie_err__ = (X); \ if (raasnie_err__ != nullptr) { \ *E = raasnie_err__; \ - return E; \ + return; \ } \ } while (false) @@ -62,7 +63,7 @@ struct ResponseShm { TRITONSERVER_Error* rarie_err__ = TRITONSERVER_ErrorNew( \ TRITONSERVER_ERROR_INTERNAL, pb_exception.what()); \ *E = rarie_err__; \ - return E; \ + return; \ } \ } while (false) @@ -96,13 +97,13 @@ class InferResponse { /// response needs to be done in two step. The boolean /// 'requires_deferred_callback' indicates whether DeferredSendCallback method /// should be called or not. 
- std::shared_ptr Send( - TRITONBACKEND_ResponseFactory* response_factory, void* cuda_stream, + void Send( + TRITONBACKEND_Response* response, void* cuda_stream, bool& requires_deferred_callback, const uint32_t flags, std::unique_ptr& shm_pool, + GPUBuffersHelper& gpu_buffer_helper, std::vector, void*>>& output_buffers, - const std::set& requested_output_names = {}, - TRITONBACKEND_Response* response = nullptr); + const std::set& requested_output_names = {}); void DeferredSendCallback(); #endif diff --git a/src/pb_stub.cc b/src/pb_stub.cc index 234cfd9f..22ecd7e9 100644 --- a/src/pb_stub.cc +++ b/src/pb_stub.cc @@ -356,9 +356,10 @@ Stub::RunCommand() LoadGPUBuffers(ipc_message); } catch (const PythonBackendException& pb_exception) { - LOG_INFO << "An error occurred while trying to load GPU buffers in the " - "Python backend stub: " - << pb_exception.what() << std::endl; + LOG_ERROR + << "An error occurred while trying to load GPU buffers in the " + "Python backend stub: " + << pb_exception.what() << std::endl; } break; @@ -539,43 +540,48 @@ Stub::ProcessResponse(InferResponse* response) void Stub::LoadGPUBuffers(std::unique_ptr& ipc_message) { - AllocatedSharedMemory gpu_buffers_handle = - shm_pool_->Load(ipc_message->Args()); + ScopedDefer load_gpu_buffer_response([this] { + // LoadGPUBuffers must let the parent process know when loading the + // buffers have been finished. + parent_message_queue_->Push(DUMMY_MESSAGE); + gpu_tensors_.clear(); + }); - uint64_t* gpu_buffer_count = - reinterpret_cast(gpu_buffers_handle.data_.get()); - bi::managed_external_buffer::handle_t* gpu_buffers_handle_shm = - reinterpret_cast( - gpu_buffers_handle.data_.get() + sizeof(uint64_t)); - - if (gpu_tensors_.size() != *gpu_buffer_count) { - LOG_INFO - << (std::string( - "GPU buffers size does not match the provided buffers: ") + - std::to_string(gpu_tensors_.size()) + - " != " + std::to_string(*gpu_buffer_count)); - return; + AllocatedSharedMemory gpu_buffers_handle = + shm_pool_->Load(ipc_message->Args()); + + if (!gpu_buffers_handle.data_->success) { + std::unique_ptr error = PbString::LoadFromSharedMemory( + shm_pool_, gpu_buffers_handle.data_->error); + throw PythonBackendException( + "Failed to load GPU buffers: " + error->String()); } - std::vector> dst_buffers; + uint64_t gpu_buffer_count = gpu_buffers_handle.data_->buffer_count; + AllocatedSharedMemory + gpu_buffers_handle_shm = + shm_pool_->Load( + gpu_buffers_handle.data_->buffers); + if (gpu_tensors_.size() != gpu_buffer_count) { + throw PythonBackendException( + std::string("GPU buffers size does not match the provided buffers: ") + + std::to_string(gpu_tensors_.size()) + + " != " + std::to_string(gpu_buffer_count)); + } + + std::vector> dst_buffers; for (size_t i = 0; i < gpu_tensors_.size(); i++) { std::unique_ptr dst_buffer = PbMemory::LoadFromSharedMemory( - shm_pool_, gpu_buffers_handle_shm[i], true /* open_cuda_handle */); + shm_pool_, gpu_buffers_handle_shm.data_.get()[i], + true /* open_cuda_handle */); dst_buffers.emplace_back(std::move(dst_buffer)); } - ScopedDefer load_gpu_buffer_response([this] { - // Push a dummy message to signal the thread to terminate. 
- parent_message_queue_->Push(DUMMY_MESSAGE); - }); - for (size_t i = 0; i < gpu_tensors_.size(); i++) { std::shared_ptr& src_buffer = gpu_tensors_[i]; PbMemory::CopyBuffer(dst_buffers[i], src_buffer->Memory()); } - - gpu_tensors_.clear(); } py::list diff --git a/src/pb_utils.h b/src/pb_utils.h index 36d7e3c7..c05e8411 100644 --- a/src/pb_utils.h +++ b/src/pb_utils.h @@ -212,23 +212,17 @@ struct ResponseSenderBase { struct ResponseSendMessage : ResponseSenderBase { bi::managed_external_buffer::handle_t response; - // GPU Buffers handle + // A shm handle to a GPUBuffersShm object. bi::managed_external_buffer::handle_t gpu_buffers_handle; - // GPU buffers count - uint32_t gpu_buffers_count; - uint32_t flags; }; struct RequestBatch { uint32_t batch_size; - // GPU Buffers handle + // A shm handle to a GPUBuffersShm object. bi::managed_external_buffer::handle_t gpu_buffers_handle; - - // GPU buffers count - uint32_t gpu_buffers_count; }; #ifdef TRITON_ENABLE_GPU diff --git a/src/python_be.cc b/src/python_be.cc index 5c815485..87375348 100644 --- a/src/python_be.cc +++ b/src/python_be.cc @@ -25,6 +25,7 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "python_be.h" +#include "gpu_buffers.h" #include "infer_payload.h" #include "pb_log.h" @@ -623,10 +624,9 @@ ModelInstanceState::ExecuteBLSRequest( is_response_batch_set = true; bool has_gpu_tensor = false; + GPUBuffersHelper gpu_buffer_helper; PythonBackendException pb_exception(std::string{}); - - uint32_t gpu_buffers_count = 0; if (request_batch_shm_ptr->batch_size == 1) { std::shared_ptr infer_request; bi::managed_external_buffer::handle_t* request_handle = @@ -643,7 +643,6 @@ ModelInstanceState::ExecuteBLSRequest( for (auto& input_tensor : infer_request->Inputs()) { if (!input_tensor->IsCPU()) { #ifdef TRITON_ENABLE_GPU - gpu_buffers_count++; BackendMemory* backend_memory; std::unique_ptr lbackend_memory; has_gpu_tensor = true; @@ -661,38 +660,25 @@ ModelInstanceState::ExecuteBLSRequest( lbackend_memory.reset(backend_memory); input_tensor->SetMemory(std::move(PbMemory::Create( Stub()->ShmPool(), std::move(lbackend_memory)))); + gpu_buffer_helper.AddBuffer( + input_tensor->Memory()->ShmHandle()); #endif // TRITON_ENABLE_GPU } } } catch (const PythonBackendException& exception) { + gpu_buffer_helper.SetError(Stub()->ShmPool(), exception.what()); pb_exception = exception; } - AllocatedSharedMemory gpu_handles; // Wait for the extra round trip to complete. The stub process will fill // in the data for the GPU tensors. If there is an error, the extra round // trip must be still completed, otherwise the stub process will always be // waiting for a message from the parent process. 
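The comment above names the deadlock this patch closes: the stub blocks on the GPU-buffer round trip, so the parent must complete it even when staging the buffers threw, now shipping an error through GPUBuffersHelper::SetError instead of a handle list. A small self-contained sketch of that rule, with a std::thread and condition variable standing in for the stub process and the shared-memory message queues (all names here are illustrative):

// Always complete the round trip: the waiting side would hang forever if the
// parent skipped the message after a failure.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <optional>
#include <stdexcept>
#include <string>
#include <thread>

struct RoundTripMessage {
  bool success;
  std::string error;  // Only meaningful when success is false.
};

int
main()
{
  std::mutex mu;
  std::condition_variable cv;
  std::optional<RoundTripMessage> message;

  // Stand-in for the stub: it blocks until the parent pushes the message.
  std::thread stub([&] {
    std::unique_lock<std::mutex> lock(mu);
    cv.wait(lock, [&] { return message.has_value(); });
    if (message->success) {
      std::cout << "stub: filling GPU buffers\n";
    } else {
      std::cout << "stub: parent reported error: " << message->error << "\n";
    }
  });

  RoundTripMessage msg{true, ""};
  try {
    throw std::runtime_error("failed to allocate a GPU buffer");  // Simulated failure.
  }
  catch (const std::exception& ex) {
    msg = RoundTripMessage{false, ex.what()};  // Record the error, but still reply.
  }
  {
    std::lock_guard<std::mutex> lock(mu);
    message = msg;
  }
  cv.notify_one();  // The round trip completes even though staging failed.
  stub.join();
  return 0;
}

This is the same reason the pb_stub.cc change above installs its ScopedDefer before doing any work in LoadGPUBuffers: the DUMMY_MESSAGE is now pushed even when loading throws.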
if (has_gpu_tensor) { - try { - gpu_handles = Stub() - ->ShmPool() - ->Construct( - gpu_buffers_count); - request_batch_shm_ptr->gpu_buffers_count = gpu_buffers_count; - request_batch_shm_ptr->gpu_buffers_handle = gpu_handles.handle_; - size_t i = 0; - for (auto& input_tensor : infer_request->Inputs()) { - if (!input_tensor->IsCPU()) { - gpu_handles.data_.get()[i] = input_tensor->Memory()->ShmHandle(); - ++i; - } - } - } - catch (const PythonBackendException& exception) { - pb_exception = exception; - } + gpu_buffer_helper.Complete(Stub()->ShmPool()); + request_batch_shm_ptr->gpu_buffers_handle = + gpu_buffer_helper.ShmHandle(); bi::scoped_lock lock{ *(ipc_message->ResponseMutex())}; @@ -700,7 +686,7 @@ ModelInstanceState::ExecuteBLSRequest( ipc_message->ResponseCondition()->wait(lock); } - if (pb_exception.what() != nullptr) { + if (pb_exception.what() == std::string{""}) { auto callback = std::bind( &ModelInstanceState::SendBLSDecoupledResponse, this, std::placeholders::_1); @@ -1071,32 +1057,31 @@ ModelInstanceState::ResponseSendDecoupled( false /* open cuda ipc handle */); bool requires_deferred_callback = false; + TRITONBACKEND_Response* response; + SetErrorForResponseSendMessage( + send_message_payload, + WrapTritonErrorInSharedPtr( + TRITONBACKEND_ResponseNewFromFactory(&response, response_factory)), + error_message); + std::vector, void*>> gpu_output_buffers; - std::shared_ptr error = infer_response->Send( - response_factory, CudaStream(), requires_deferred_callback, - send_message_payload->flags, Stub()->ShmPool(), gpu_output_buffers); - SetErrorForResponseSendMessage(send_message_payload, error, error_message); + std::unique_ptr< + TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> + response_factory_ptr; + GPUBuffersHelper gpu_buffer_helper; + if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { + response_factory_ptr.reset( + reinterpret_cast(response_factory)); + } + infer_response->Send( + response, CudaStream(), requires_deferred_callback, + send_message_payload->flags, Stub()->ShmPool(), gpu_buffer_helper, + gpu_output_buffers); if (requires_deferred_callback) { - AllocatedSharedMemory gpu_buffers_handle = - Stub()->ShmPool()->Construct( - sizeof(uint64_t) + - gpu_output_buffers.size() * - sizeof(bi::managed_external_buffer::handle_t)); - uint64_t* gpu_buffer_count = - reinterpret_cast(gpu_buffers_handle.data_.get()); - *gpu_buffer_count = gpu_output_buffers.size(); - bi::managed_external_buffer::handle_t* gpu_buffers_handle_shm = - reinterpret_cast( - gpu_buffers_handle.data_.get() + sizeof(uint64_t)); - send_message_payload->gpu_buffers_handle = gpu_buffers_handle.handle_; - - size_t index = 0; - for (auto& output_buffer_pair : gpu_output_buffers) { - std::unique_ptr& pb_memory = output_buffer_pair.first; - gpu_buffers_handle_shm[index] = pb_memory->ShmHandle(); - ++index; - } + gpu_buffer_helper.Complete(Stub()->ShmPool()); + send_message_payload->gpu_buffers_handle = + gpu_buffer_helper.ShmHandle(); // Additional round trip so that the stub can fill the GPU output buffers. 
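In the reworked ResponseSendDecoupled above, the TRITONBACKEND_Response is created up front from the factory, and the factory itself is handed to a unique_ptr with a custom deleter only when the FINAL flag is set, so it is released exactly once regardless of which path returns. A generic sketch of that conditional-ownership idiom (Resource and release_resource are placeholders for illustration, not Triton APIs):

// Conditional RAII ownership via unique_ptr with a custom deleter, the idiom
// used for response_factory_ptr above.
#include <iostream>
#include <memory>

struct Resource {
  int id;
};

void
release_resource(Resource* r)
{
  std::cout << "releasing resource " << r->id << "\n";
  delete r;
}

struct ResourceDeleter {
  void operator()(Resource* r) const { release_resource(r); }
};

void
send_response(Resource* factory, bool is_final)
{
  // Only take ownership (and therefore only release the factory) on the
  // final response; earlier responses leave the factory alive for reuse.
  std::unique_ptr<Resource, ResourceDeleter> factory_guard;
  if (is_final) {
    factory_guard.reset(factory);
  }
  std::cout << "sending response (final=" << is_final << ")\n";
}  // factory_guard releases the factory here only when is_final was true.

int
main()
{
  Resource* factory = new Resource{42};
  send_response(factory, false);  // Factory kept alive.
  send_response(factory, true);   // Factory released exactly once.
  return 0;
}

Keeping the release in a scoped owner replaces the manual factory cleanup that the old send lambda in infer_response.cc performed.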
{ @@ -1109,7 +1094,6 @@ ModelInstanceState::ResponseSendDecoupled( } } - index = 0; bool cuda_copy = false; for (auto& output_buffer_pair : gpu_output_buffers) { auto& pb_memory = output_buffer_pair.first; @@ -1125,8 +1109,6 @@ ModelInstanceState::ResponseSendDecoupled( CudaStream(), &cuda_used); cuda_copy |= cuda_used; } - gpu_buffers_handle_shm[index] = pb_memory->ShmHandle(); - ++index; #ifdef TRITON_ENABLE_GPU if (cuda_copy) { cudaStreamSynchronize(stream_); @@ -1407,6 +1389,7 @@ ModelInstanceState::ProcessRequests( std::vector> shm_responses; std::vector, void*>>> gpu_output_buffers(request_count); + GPUBuffersHelper gpu_buffer_helper; for (uint32_t r = 0; r < request_count; ++r) { NVTX_RANGE(nvtx_, "LoadingResponse " + Name()); @@ -1466,17 +1449,13 @@ ModelInstanceState::ProcessRequests( gpu_output_buffers[r] = std::vector, void*>>{}; - std::shared_ptr error = infer_response->Send( - nullptr, CudaStream(), require_deferred_callback, + infer_response->Send( + response, CudaStream(), require_deferred_callback, TRITONSERVER_RESPONSE_COMPLETE_FINAL, Stub()->ShmPool(), - gpu_output_buffers[r], requested_output_names, response); - GUARDED_RESPOND_IF_ERROR(responses, r, *error); + gpu_buffer_helper, gpu_output_buffers[r], requested_output_names); requires_deferred_callback[r] = require_deferred_callback; - // Error object will be deleted by the GUARDED_RESPOND macro - *error = nullptr; - error.reset(); if (requires_deferred_callback[r]) { has_gpu_output = true; } @@ -1488,39 +1467,15 @@ ModelInstanceState::ProcessRequests( // If the output tensor is in GPU, there will be a second round trip // required for filling the GPU buffers provided by the main process. if (has_gpu_output) { - size_t total_gpu_buffers_count = 0; - for (auto& gpu_output_buffer : gpu_output_buffers) { - total_gpu_buffers_count += gpu_output_buffer.size(); - } - AllocatedSharedMemory gpu_buffers_handle = - Stub()->ShmPool()->Construct( - sizeof(uint64_t) + - total_gpu_buffers_count * - sizeof(bi::managed_external_buffer::handle_t)); - uint64_t* gpu_buffer_count = - reinterpret_cast(gpu_buffers_handle.data_.get()); - *gpu_buffer_count = total_gpu_buffers_count; - bi::managed_external_buffer::handle_t* gpu_buffers_handle_shm = - reinterpret_cast( - gpu_buffers_handle.data_.get() + sizeof(uint64_t)); - - size_t index = 0; - for (auto& gpu_output_buffer : gpu_output_buffers) { - for (auto& buffer_memory_pair : gpu_output_buffer) { - gpu_buffers_handle_shm[index] = buffer_memory_pair.first->ShmHandle(); - ++index; - } - } - ipc_message->Command() = PYTHONSTUB_CommandType::PYTHONSTUB_LoadGPUBuffers; - ipc_message->Args() = gpu_buffers_handle.handle_; + gpu_buffer_helper.Complete(Stub()->ShmPool()); + ipc_message->Args() = gpu_buffer_helper.ShmHandle(); SendMessageAndReceiveResponse( ipc_message->ShmHandle(), response_message, restart, responses, requests, 0); bool cuda_copy = false; - index = 0; uint32_t response_index = 0; for (auto& gpu_output_buffer : gpu_output_buffers) { for (auto& buffer_memory_pair : gpu_output_buffer) { @@ -1538,8 +1493,6 @@ ModelInstanceState::ProcessRequests( CudaStream(), &cuda_used)); cuda_copy |= cuda_used; } - gpu_buffers_handle_shm[index] = pb_memory->ShmHandle(); - ++index; } response_index++; #ifdef TRITON_ENABLE_GPU diff --git a/src/response_sender.cc b/src/response_sender.cc index e8394df9..31e1be5b 100644 --- a/src/response_sender.cc +++ b/src/response_sender.cc @@ -130,20 +130,21 @@ ResponseSender::Send( } if (has_gpu_output) { - AllocatedSharedMemory gpu_buffers_handle = - 
shm_pool_->Load<char>(send_message_payload->gpu_buffers_handle);
-
-      bi::managed_external_buffer::handle_t* gpu_buffers_handle_shm =
-          reinterpret_cast<bi::managed_external_buffer::handle_t*>(
-              gpu_buffers_handle.data_.get() + sizeof(uint64_t));
-      uint64_t* gpu_buffer_count =
-          reinterpret_cast<uint64_t*>(gpu_buffers_handle.data_.get());
-      if (gpu_tensors.size() != *gpu_buffer_count) {
-        LOG_INFO
+    AllocatedSharedMemory<GPUBuffersShm> gpu_buffers_handle =
+        shm_pool_->Load<GPUBuffersShm>(
+            send_message_payload->gpu_buffers_handle);
+
+    AllocatedSharedMemory<bi::managed_external_buffer::handle_t>
+        gpu_buffers_handle_shm =
+            shm_pool_->Load<bi::managed_external_buffer::handle_t>(
+                gpu_buffers_handle.data_->buffers);
+    uint64_t gpu_buffer_count = gpu_buffers_handle.data_->buffer_count;
+    if (gpu_tensors.size() != gpu_buffer_count) {
+      LOG_ERROR
           << (std::string(
                   "GPU buffers size does not match the provided buffers: ") +
               std::to_string(gpu_tensors.size()) +
-              " != " + std::to_string(*gpu_buffer_count));
+              " != " + std::to_string(gpu_buffer_count));
       return;
     }
 
@@ -151,7 +152,8 @@ ResponseSender::Send(
     for (size_t i = 0; i < gpu_tensors.size(); i++) {
       std::unique_ptr<PbMemory> dst_buffer = PbMemory::LoadFromSharedMemory(
-          shm_pool_, gpu_buffers_handle_shm[i], true /* open_cuda_handle */);
+          shm_pool_, gpu_buffers_handle_shm.data_.get()[i],
+          true /* open_cuda_handle */);
       dst_buffers.emplace_back(std::move(dst_buffer));
       std::shared_ptr<PbTensor>& src_buffer = gpu_tensors[i];
       PbMemory::CopyBuffer(dst_buffers[i], src_buffer->Memory());
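The patch ends with the stub-side half of the exchange: check the success flag, verify that the buffer count matches the stub's GPU tensors, then copy buffer by buffer. A compact self-contained sketch of that consumption pattern (byte vectors stand in for PbMemory and shared memory; none of these types are Triton's):

// Stub-side consumption: validate the success flag and the count before
// copying element by element.
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

struct StagedBuffers {
  bool success;
  std::string error;
  std::vector<std::vector<std::uint8_t>> buffers;  // Destination buffers from the parent.
};

void
fill_staged_buffers(
    const std::vector<std::vector<std::uint8_t>>& gpu_tensors,
    StagedBuffers& staged)
{
  if (!staged.success) {
    // Mirrors the new error path: surface the parent's message instead of
    // silently returning with half-initialized buffers.
    throw std::runtime_error("Failed to load GPU buffers: " + staged.error);
  }
  if (gpu_tensors.size() != staged.buffers.size()) {
    throw std::runtime_error(
        std::string("GPU buffers size does not match the provided buffers: ") +
        std::to_string(gpu_tensors.size()) +
        " != " + std::to_string(staged.buffers.size()));
  }
  for (std::size_t i = 0; i < gpu_tensors.size(); ++i) {
    staged.buffers[i] = gpu_tensors[i];  // Stand-in for PbMemory::CopyBuffer.
  }
}

int
main()
{
  std::vector<std::vector<std::uint8_t>> gpu_tensors = {{1, 2, 3}, {4, 5}};
  StagedBuffers staged{
      true, "", {std::vector<std::uint8_t>(3), std::vector<std::uint8_t>(2)}};
  fill_staged_buffers(gpu_tensors, staged);  // Copies both tensors successfully.
  return 0;
}

Both stub-side call sites now follow this shape: pb_stub.cc throws a PythonBackendException on a mismatch or error, while response_sender.cc logs an error and returns.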