diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h
index 2233b7973ee93..d25645123896d 100644
--- a/include/onnxruntime/core/framework/allocator.h
+++ b/include/onnxruntime/core/framework/allocator.h
@@ -72,6 +72,10 @@ struct OrtDevice {
   DeviceId device_id;
 };
 
+inline bool operator==(const OrtDevice& left, const OrtDevice& other) {
+  return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type();
+}
+
 struct OrtAllocatorInfo {
   // use string for name, so we could have customized allocator in execution provider.
   const char* name;
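The new equality operator is the pivot of this change: later code can ask whether a tensor is already on the device a kernel requires with a single comparison, instead of juggling provider types and memory types. A standalone sketch of the comparison semantics, using a stand-in struct with the same three fields rather than the real OrtDevice header:

    #include <cassert>
    #include <cstdint>

    // Stand-in for OrtDevice: the same three identifying fields.
    struct Device {
      int8_t type = 0;      // e.g. CPU = 0, GPU = 1
      int8_t mem_type = 0;  // e.g. default = 0, host-pinned = 1
      int16_t id = 0;       // device index

      friend bool operator==(const Device& left, const Device& other) {
        // Mirrors the diff: equal only if id, memory type, and device type all match.
        return left.id == other.id && left.mem_type == other.mem_type && left.type == other.type;
      }
    };

    int main() {
      Device cpu{};              // CPU, default memory, id 0
      Device gpu0{1, 0, 0};      // GPU 0
      Device pinned{0, 1, 0};    // host memory pinned for GPU transfers

      assert(cpu == Device{});   // same device: no copy would be needed
      assert(!(cpu == gpu0));    // different device type: copy needed
      assert(!(cpu == pinned));  // same device type, different memory type
      return 0;
    }

Because all three fields participate, pinned host memory compares unequal to plain CPU memory even though both live on the host, which is exactly what the copy logic later in this diff relies on.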
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 5d1e35ecafe69..5046b6e7b5fd6 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -347,6 +347,10 @@ class PlannerImpl {
   Status ComputeUseCounts() {
     // Note: for every ml-value, its definition must appear before all its uses in a topological sort of a valid model
+    std::unordered_set<std::string> graph_inputs;
+    for (auto& graph_input : graph_viewer_.GetInputsIncludingInitializers()) {
+      graph_inputs.insert(graph_input->Name());
+    }
 
     for (auto graph_input : graph_viewer_.GetInputs()) {
       OrtValueIndex index = Index(graph_input->Name());
@@ -371,15 +375,7 @@ class PlannerImpl {
     for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) {
       auto pnode = graph_viewer_.GetNode(step.node_index);
       if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index);
-      for (auto node_input : pnode->InputDefs()) {
-        if (node_input->Exists())
-          UseCount(node_input->Name())++;
-      }
-      for (auto node_input : pnode->ImplicitInputDefs()) {
-        if (node_input->Exists())
-          UseCount(node_input->Name())++;
-      }
 
       // Identify where each output of this node should be allocated.
       // This is determined by the opkernel bound to the node.
       const KernelCreateInfo* kernel_create_info = nullptr;
@@ -394,31 +390,45 @@ class PlannerImpl {
         if (!pnode->Name().empty()) errormsg << " (node " << pnode->Name() << ")";
         return Status(ONNXRUNTIME, FAIL, errormsg.str());
       }
-
       auto exec_provider = execution_providers_.Get(*pnode);
       if (exec_provider == nullptr) {
         return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the execution provider ",
                                pnode->GetExecutionProviderType());
       }
-      auto& default_allocator_info = exec_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
+      // increment UseCount and add location information if applicable for the provided input def
+      auto process_input = [&graph_inputs, &exec_provider, &p_kernelDef, this](const NodeArg& input, size_t arg_idx) {
+        const auto& name = input.Name();
+        UseCount(name)++;
+
+        // If it's a graph input or outer scope node arg, set its plan.
+        // NOTE: Copy nodes should have already been added if a graph input is fed as input
+        // to nodes assigned to different providers.
+        if (graph_inputs.find(name) != graph_inputs.cend() ||
+            std::find_if(outer_scope_node_args_.cbegin(), outer_scope_node_args_.cend(),
+                         [&name](const NodeArg* value) {
+                           return value && value->Name() == name;
+                         }) != outer_scope_node_args_.cend()) {
+          OrtValueIndex index = Index(name);
+          plan_.SetLocation(static_cast<size_t>(index),
+                            exec_provider->GetAllocator(0, p_kernelDef->InputMemoryType(arg_idx))->Info());
+        }
+
+        return Status::OK();
+      };
+
+      ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(pnode->InputDefs(), process_input));
+      ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(pnode->ImplicitInputDefs(), process_input));
+
       auto outputs = pnode->OutputDefs();
       auto num_outputs = outputs.size();
-
       for (size_t i = 0; i < num_outputs; ++i) {
         auto* node_output = outputs[i];
         if (!node_output->Exists()) continue;
         OrtValueIndex index = Index(node_output->Name());
         ProcessDef(index, node_output);
         ++UseCount(index);
-        if (strcmp(default_allocator_info.name, CPU) != 0) {
-          // By default, outputs of this node are allocated on the default device allocator,
-          // except for outputs marked for allocation in MemoryType:
-          auto memory_type = p_kernelDef->OutputMemoryType(i);
-          plan_.SetLocation(static_cast<size_t>(index), memory_type == OrtMemTypeDefault
-                                                            ? default_allocator_info
-                                                            : exec_provider->GetAllocator(0, memory_type)->Info());
-        }
+        plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetAllocator(0, p_kernelDef->OutputMemoryType(i))->Info());
       }
       // if sync is needed, mark allocation plan as create_fence_if_async=true
       // note that the input arg may come from an execution provider (i.e. CPU) that does not support async,
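Folding the use-count increment and the location assignment into one process_input lambda keeps explicit and implicit inputs on a single code path. A simplified sketch of the ForEachWithIndex idiom the planner now relies on; Status, NodeArg, and the helper below are stand-ins, not the real ORT declarations:

    #include <iostream>
    #include <string>
    #include <vector>

    // Minimal stand-ins for the ORT types involved.
    struct Status {
      bool ok = true;
      static Status OK() { return {}; }
    };

    struct NodeArg {
      std::string name;
      const std::string& Name() const { return name; }
    };

    // The idiom: apply a callable to each def together with its positional
    // index, stopping at the first failure.
    template <typename Fn>
    Status ForEachWithIndex(const std::vector<NodeArg>& defs, Fn fn) {
      for (size_t i = 0; i < defs.size(); ++i) {
        Status status = fn(defs[i], i);
        if (!status.ok) return status;
      }
      return Status::OK();
    }

    int main() {
      std::vector<NodeArg> inputs{{"X"}, {"W"}, {"B"}};

      // One lambda used for both explicit and implicit inputs, as in the diff.
      auto process_input = [](const NodeArg& input, size_t arg_idx) {
        std::cout << "input " << arg_idx << ": " << input.Name() << "\n";
        return Status::OK();
      };

      ForEachWithIndex(inputs, process_input);
      return 0;
    }

The index matters because the kernel def's InputMemoryType(arg_idx) is positional: the same kernel can want one argument slot in device memory and another on the CPU.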
diff --git a/onnxruntime/core/framework/feeds_fetches_manager.h b/onnxruntime/core/framework/feeds_fetches_manager.h
index 000eaa504176f..d646c82ab23d4 100644
--- a/onnxruntime/core/framework/feeds_fetches_manager.h
+++ b/onnxruntime/core/framework/feeds_fetches_manager.h
@@ -48,9 +48,8 @@ struct FeedsFetchesInfo {
 class FeedsFetchesManager {
  public:
   struct MLValueCopyInfo {
-    int allocation_device_id = 0;
+    OrtDevice target_device;
     const IExecutionProvider* allocation_provider = nullptr;
-    const IExecutionProvider* copy_provider = nullptr;
   };
 
   static Status Create(const std::vector<std::string>& feed_names, const std::vector<std::string>& output_names,
diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc
index df506ddb6af35..fe53971656932 100644
--- a/onnxruntime/core/framework/graph_partitioner.cc
+++ b/onnxruntime/core/framework/graph_partitioner.cc
@@ -175,10 +175,6 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f
       //prepare the func kernel
       KernelDefBuilder builder;
       BuildFusedKernelDef(builder, *node);
-      if (node->GetExecutionProviderType() == onnxruntime::kNGraphExecutionProvider || node->GetExecutionProviderType() == onnxruntime::kNnapiExecutionProvider) {
-        builder.SetDefaultInputsMemoryType(OrtMemTypeCPUInput);
-        builder.SetDefaultOutputMemoryType(OrtMemTypeCPUOutput);
-      }
       ORT_RETURN_IF_ERROR(fused_kernel_registry->Register(
           builder, static_cast<KernelCreatePtrFn>([](const OpKernelInfo& info) -> OpKernel* { return new FunctionKernel(info); })));
     }
diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h
index 9387977dc3d80..92a6f107e5058 100644
--- a/onnxruntime/core/framework/session_state.h
+++ b/onnxruntime/core/framework/session_state.h
@@ -141,17 +141,18 @@ class SessionState {
      * \param p_node0 Nullable
      * \param kci0 Nullable
      */
-    NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0)
+    NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0, const OrtDevice& device0)
         : index(index0),
           p_node(p_node0),
-          kci(kci0) {
-    }
+          kci(kci0),
+          device(&device0) {}
 
     size_t index;
     // Nullable
     const onnxruntime::Node* p_node = nullptr;
     // Nullable
     const KernelCreateInfo* kci = nullptr;
+    const OrtDevice* device = nullptr;
   };
 
   using NameNodeInfoMapType = std::unordered_map<std::string, std::vector<NodeInfo>>;
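MLValueCopyInfo and NodeInfo now both carry device information directly: the copy info holds an OrtDevice by value, while NodeInfo holds a non-owning pointer into the execution plan's per-value locations. A sketch of that ownership pattern, with stand-in types rather than the real classes:

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct Device {
      std::string name;  // stand-in for OrtDevice's type/mem_type/id triple
    };

    // Stand-in for the execution plan, which owns one location per OrtValue.
    struct ExecutionPlan {
      std::vector<Device> locations;
      const Device& GetLocation(int idx) const { return locations[idx]; }
    };

    // Mirrors the new NodeInfo shape: a non-owning pointer into the plan.
    struct NodeInfo {
      size_t index;
      const Device* device = nullptr;
    };

    int main() {
      ExecutionPlan plan{{{"CPU"}, {"GPU:0"}}};

      std::unordered_map<std::string, std::vector<NodeInfo>> input_map;
      input_map["X"].push_back(NodeInfo{0, &plan.GetLocation(1)});

      // Valid for as long as the plan outlives the map, which holds when a
      // single session state owns both.
      std::cout << "X is planned for " << input_map["X"][0].device->name << "\n";
      return 0;
    }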
"Subgraph" : "Graph") << " input with name " << name << " is not used by any node."; + int arg_index; + ORT_RETURN_IF_ERROR(name_to_id.GetIdx(name, arg_index)); + auto& device = exec_plan->GetLocation(arg_index).device; + SessionState::NodeInfo empty_node_info(std::numeric_limits::max(), nullptr, nullptr, device); ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(name, empty_node_info)); } } diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 83e2650109e57..5fc78a4f99715 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -23,9 +23,18 @@ AllocatorPtr GetAllocator(const SessionState& session_state, const OrtAllocatorI return session_state.GetExecutionProviders().GetAllocator(allocator_info); } -common::Status AllocateHelper(const IExecutionProvider& execution_provider, int device_id, const Tensor& fetched_tensor, +bool ProviderIsCpuBased(const std::string& provider_type) { + return provider_type == onnxruntime::kCpuExecutionProvider || + provider_type == onnxruntime::kMklDnnExecutionProvider || + provider_type == onnxruntime::kNGraphExecutionProvider || + provider_type == onnxruntime::kNupharExecutionProvider || + provider_type == onnxruntime::kOpenVINOExecutionProvider || + provider_type == onnxruntime::kNnapiExecutionProvider; +} + +common::Status AllocateHelper(const IExecutionProvider& execution_provider, const OrtDevice& device, const Tensor& fetched_tensor, OrtValue& output_mlvalue) { - auto allocator = execution_provider.GetAllocator(device_id, OrtMemTypeDefault); + auto allocator = execution_provider.GetAllocator(device.Id(), OrtMemTypeDefault); if (!allocator) { return Status(common::ONNXRUNTIME, common::FAIL, "invalid allocator"); } @@ -62,20 +71,20 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr, const FeedsFetchesManager::MLValueCopyInfo& copy_info, const OrtValue& source_mlvalue, OrtValue& target_mlvalue) { - if (copy_info.copy_provider == nullptr) { + if (copy_info.allocation_provider == nullptr){ target_mlvalue = source_mlvalue; - } else { - auto& source_tensor = source_mlvalue.Get(); + return Status::OK(); + } - if (!target_mlvalue.IsAllocated()) { - ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.allocation_device_id, - source_tensor, target_mlvalue)); - } + auto& source_tensor = source_mlvalue.Get(); + if (!target_mlvalue.IsAllocated()) { + ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.target_device, + source_tensor, target_mlvalue)); + } - Tensor* p_output_tensor = target_mlvalue.GetMutable(); + Tensor* p_output_tensor = target_mlvalue.GetMutable(); - ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor)); - } + ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor)); return Status::OK(); } @@ -86,8 +95,6 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons FeedsFetchesManager::MLValueCopyInfo& copy_info) { needed_copy = false; - //TODO: make it configurable - const int target_device_id = 0; std::vector node_info_vec; ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec)); @@ -111,51 +118,23 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons break; } - auto& required_provider_type = GetNodeInputProviderType(node_info); - auto& input_tensor = orig_mlvalue.Get(); - auto& input_tensor_loc = input_tensor.Location(); - - auto* p_input_provider 
diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc
index 83e2650109e57..5fc78a4f99715 100644
--- a/onnxruntime/core/framework/utils.cc
+++ b/onnxruntime/core/framework/utils.cc
@@ -23,9 +23,18 @@ AllocatorPtr GetAllocator(const SessionState& session_state, const OrtAllocatorI
   return session_state.GetExecutionProviders().GetAllocator(allocator_info);
 }
 
-common::Status AllocateHelper(const IExecutionProvider& execution_provider, int device_id, const Tensor& fetched_tensor,
+bool ProviderIsCpuBased(const std::string& provider_type) {
+  return provider_type == onnxruntime::kCpuExecutionProvider ||
+         provider_type == onnxruntime::kMklDnnExecutionProvider ||
+         provider_type == onnxruntime::kNGraphExecutionProvider ||
+         provider_type == onnxruntime::kNupharExecutionProvider ||
+         provider_type == onnxruntime::kOpenVINOExecutionProvider ||
+         provider_type == onnxruntime::kNnapiExecutionProvider;
+}
+
+common::Status AllocateHelper(const IExecutionProvider& execution_provider, const OrtDevice& device, const Tensor& fetched_tensor,
                               OrtValue& output_mlvalue) {
-  auto allocator = execution_provider.GetAllocator(device_id, OrtMemTypeDefault);
+  auto allocator = execution_provider.GetAllocator(device.Id(), OrtMemTypeDefault);
   if (!allocator) {
     return Status(common::ONNXRUNTIME, common::FAIL, "invalid allocator");
   }
@@ -62,20 +71,20 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
                           const FeedsFetchesManager::MLValueCopyInfo& copy_info,
                           const OrtValue& source_mlvalue,
                           OrtValue& target_mlvalue) {
-  if (copy_info.copy_provider == nullptr) {
+  if (copy_info.allocation_provider == nullptr) {
     target_mlvalue = source_mlvalue;
-  } else {
-    auto& source_tensor = source_mlvalue.Get<Tensor>();
+    return Status::OK();
+  }
 
-    if (!target_mlvalue.IsAllocated()) {
-      ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.allocation_device_id,
-                                                source_tensor, target_mlvalue));
-    }
+  auto& source_tensor = source_mlvalue.Get<Tensor>();
+  if (!target_mlvalue.IsAllocated()) {
+    ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.target_device,
+                                              source_tensor, target_mlvalue));
+  }
 
-    Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
+  Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
 
-    ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
-  }
+  ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
 
   return Status::OK();
 }
@@ -86,8 +95,6 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
                                          FeedsFetchesManager::MLValueCopyInfo& copy_info) {
   needed_copy = false;
 
-  //TODO: make it configurable
-  const int target_device_id = 0;
   std::vector<SessionState::NodeInfo> node_info_vec;
   ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec));
@@ -111,51 +118,23 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
       break;
     }
 
-    auto& required_provider_type = GetNodeInputProviderType(node_info);
-    auto& input_tensor = orig_mlvalue.Get<Tensor>();
-    auto& input_tensor_loc = input_tensor.Location();
-
-    auto* p_input_provider = exec_providers.Get(input_tensor_loc);
-    if (!p_input_provider) {
-      p_input_provider = exec_providers.Get(onnxruntime::kCpuExecutionProvider);
-      ORT_ENFORCE(p_input_provider);
-    }
-
-    //no copy for nGraph
-    if (required_provider_type == onnxruntime::kNGraphExecutionProvider) {
-      new_mlvalue = orig_mlvalue;
-      break;
-    }
-
-    auto input_provider_type = p_input_provider->Type();
-    if (input_provider_type == required_provider_type && input_tensor_loc.mem_type == OrtMemTypeDefault) {
-      new_mlvalue = orig_mlvalue;
-      break;
-    }
-
-    // If a node requires input on cpu and input tensor is allocated with pinned memory allocator, don't do copy
-    if (required_provider_type == onnxruntime::kCpuExecutionProvider &&
-        input_tensor_loc.mem_type == OrtMemTypeCPU) {
+    auto& required_device = *node_info.device;
+    auto& input_tensor_device = orig_mlvalue.Get<Tensor>().Location().device;
+    if (required_device == input_tensor_device) {
+      // No copy needed for same device.
       new_mlvalue = orig_mlvalue;
       break;
     }
 
+    auto& required_provider_type = GetNodeInputProviderType(node_info);
     auto* required_provider = exec_providers.Get(required_provider_type);
-    ORT_ENFORCE(required_provider);
-
-    auto* p_copy_provider = (required_provider_type != onnxruntime::kCpuExecutionProvider)
-                                ? required_provider
-                                : p_input_provider;
-
-    copy_info.allocation_device_id = target_device_id;
+    copy_info.target_device = required_device;
     copy_info.allocation_provider = required_provider;
-    copy_info.copy_provider = p_copy_provider;
 
     ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue));
 
     needed_copy = true;
 
-    // } loop of node_info_vec
   } while (false);
 
   return Status::OK();
@@ -344,43 +323,26 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
       continue;
     }
 
-    auto& fetched_tensor = fetched_mlvalue.Get<Tensor>();
-    auto& fetched_tensor_location = fetched_tensor.Location();
-    auto* p_fetched_provider = execution_providers.Get(fetched_tensor_location);
-    if (!p_fetched_provider) {
-      p_fetched_provider = cpu_execution_provider;
-    }
-
-    auto fetched_provider_type = p_fetched_provider->Type();
-
-    auto& output_mlvalue = user_fetches[idx];
     const IExecutionProvider* p_output_provider = nullptr;
-
+    auto target_device = OrtDevice();
+    auto& output_mlvalue = user_fetches[idx];
     if (output_mlvalue.IsAllocated()) {
       Tensor* p_output_tensor = output_mlvalue.GetMutable<Tensor>();
+      target_device = p_output_tensor->Location().device;
      p_output_provider = execution_providers.Get(p_output_tensor->Location());
     }
+    auto fetch_result_device = fetched_mlvalue.Get<Tensor>().Location().device;
+    if (target_device == fetch_result_device) {
+      user_fetches[idx] = fetched_mlvalue;
+      continue;
+    }
 
     if (!p_output_provider) {
       p_output_provider = cpu_execution_provider;
     }
 
-    auto output_provider_type = p_output_provider->Type();
-
-    if (fetched_provider_type == output_provider_type ||
-        (p_output_provider == cpu_execution_provider && fetched_tensor_location.mem_type == OrtMemTypeCPUOutput)) {
-      user_fetches[idx] = fetched_mlvalue;
-      continue;
-    }
-
     needed_copy = true;
-
-    auto* p_copy_provider = (fetched_provider_type != onnxruntime::kCpuExecutionProvider)
-                                ? p_fetched_provider
-                                : p_output_provider;
-
-    const int device_id = 0;  // TODO: As per comment in the copy input code, make this configurable.
-    FeedsFetchesManager::MLValueCopyInfo copy_info{device_id, p_output_provider, p_copy_provider};
+    FeedsFetchesManager::MLValueCopyInfo copy_info{target_device, p_output_provider};
 
     ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, fetched_mlvalue, output_mlvalue));
 
     if (copiers) {
@@ -410,11 +372,7 @@ static common::Status CachedCopyOutputsAcrossDevices(
 static DeviceCopyCheck CheckExecutionProviders(const ExecutionProviders& execution_providers) {
   for (const auto& execution_provider : execution_providers) {
-    if (execution_provider->Type() != onnxruntime::kCpuExecutionProvider &&
-        execution_provider->Type() != onnxruntime::kMklDnnExecutionProvider &&
-        execution_provider->Type() != onnxruntime::kNGraphExecutionProvider &&
-        execution_provider->Type() != onnxruntime::kNupharExecutionProvider &&
-        execution_provider->Type() != onnxruntime::kOpenVINOExecutionProvider) {
+    if (!ProviderIsCpuBased(execution_provider->Type())) {
       return DeviceCopyCheck::Unknown;
     }
   }
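With devices recorded on the plan, the input and output copy paths above reduce to one rule: copy only when the source and required devices differ. The earlier provider-specific carve-outs (the nGraph pass-through, the pinned-memory and CPU-output exemptions) all fall out of that comparison. A minimal sketch of the decision, with a stand-in Device and a hypothetical NeedsCopy helper:

    #include <iostream>

    struct Device {
      int type = 0;  // 0 = CPU, 1 = GPU
      int mem_type = 0;
      int id = 0;
      friend bool operator==(const Device& a, const Device& b) {
        return a.id == b.id && a.mem_type == b.mem_type && a.type == b.type;
      }
    };

    // Hypothetical helper mirroring the new control flow in
    // CopyOneInputAcrossDevices: no provider-type comparisons, just devices.
    bool NeedsCopy(const Device& input_device, const Device& required_device) {
      return !(input_device == required_device);
    }

    int main() {
      Device cpu{0, 0, 0};
      Device gpu{1, 0, 0};

      std::cout << std::boolalpha
                << "cpu -> cpu copy? " << NeedsCopy(cpu, cpu) << "\n"   // false
                << "cpu -> gpu copy? " << NeedsCopy(cpu, gpu) << "\n";  // true
      return 0;
    }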
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 6a44cc27eb5c4..c4de7e676c1bf 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -72,6 +72,11 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in
   DeviceAllocatorRegistrationInfo pinned_allocator_info(
       {OrtMemTypeCPUOutput, [](int) { return std::make_unique<CUDAPinnedAllocator>(0, CUDA_PINNED); }, std::numeric_limits<size_t>::max()});
   InsertAllocator(CreateAllocator(pinned_allocator_info, device_id_));
+
+  // TODO: this is actually used for the cuda kernels which explicitly ask for inputs from CPU.
+  // This will be refactored/removed when allocator and execution provider are decoupled.
+  DeviceAllocatorRegistrationInfo cpu_allocator_info({OrtMemTypeCPUInput, [](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>("CUDA_CPU", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUInput)); }, std::numeric_limits<size_t>::max()});
+  InsertAllocator(CreateAllocator(cpu_allocator_info));
 }
 
 CUDAExecutionProvider::~CUDAExecutionProvider() {
diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc
index 679d90a953806..f4faef586cd22 100644
--- a/onnxruntime/test/providers/provider_test_utils.cc
+++ b/onnxruntime/test/providers/provider_test_utils.cc
@@ -448,7 +448,6 @@ void OpTester::Run(ExpectResult expect_result,
     std::unordered_map<std::string, OrtValue> feeds;
     std::vector<std::string> output_names;
     FillFeedsAndOutputNames(feeds, output_names);
-
     // Run the model
     SessionOptions so;
     so.session_logid = op_;
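The CUDA provider change earlier in this diff is the other half of the scheme: a kernel that declares an input as OrtMemTypeCPUInput needs its own provider to hand back a host allocator, hence the extra CUDA_CPU registration. A toy sketch of allocators keyed by memory type (hypothetical names, not the real registry):

    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>

    enum MemType { Default = 0, CPUInput = 1, CPUOutput = 2 };

    struct Allocator {
      std::string name;
    };

    // Toy registry: one provider exposes several allocators, distinguished by
    // the memory type a kernel asks for.
    struct Provider {
      std::map<MemType, std::shared_ptr<Allocator>> allocators;

      void InsertAllocator(MemType mem_type, std::string name) {
        allocators[mem_type] = std::make_shared<Allocator>(Allocator{std::move(name)});
      }

      std::shared_ptr<Allocator> GetAllocator(MemType mem_type) const {
        auto it = allocators.find(mem_type);
        return it == allocators.end() ? nullptr : it->second;
      }
    };

    int main() {
      Provider cuda;
      cuda.InsertAllocator(Default, "CUDA");           // device memory
      cuda.InsertAllocator(CPUOutput, "CUDA_PINNED");  // pinned host memory
      cuda.InsertAllocator(CPUInput, "CUDA_CPU");      // plain host memory, as in the diff

      std::cout << cuda.GetAllocator(CPUInput)->name << "\n";  // prints CUDA_CPU
      return 0;
    }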