From 00b1696a4145fcd01e863c83341baa7253a9d56e Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Wed, 11 Sep 2024 20:20:16 +0530 Subject: [PATCH] [IREE][EP] Add support for rocm backend This commit adds support for rocm backend in iree-ep. Signed-Off-by: Gaurav Shukla --- .../providers/iree/compiler/jit_compiler.cc | 6 ++- .../providers/iree/compiler/jit_compiler.h | 4 +- .../core/providers/iree/iree_ep_runtime.cc | 39 ++++++++++--------- .../core/providers/iree/iree_ep_runtime.h | 4 +- .../providers/iree/iree_execution_provider.cc | 39 ++++++++++++++----- 5 files changed, 59 insertions(+), 33 deletions(-) diff --git a/onnxruntime/core/providers/iree/compiler/jit_compiler.cc b/onnxruntime/core/providers/iree/compiler/jit_compiler.cc index b2627753887c4..06b68c22f95f5 100644 --- a/onnxruntime/core/providers/iree/compiler/jit_compiler.cc +++ b/onnxruntime/core/providers/iree/compiler/jit_compiler.cc @@ -222,12 +222,16 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer return common::Status::OK(); } -common::Status CompilerInvocation::CompileAndOutputVMFB(iree_compiler_output_t* output) { +common::Status CompilerInvocation::CompileAndOutputVMFB(iree_compiler_output_t* output, std::string save_to) { // Main compilation. if (!ireeCompilerInvocationPipeline(inv, IREE_COMPILER_PIPELINE_STD)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "IREE compilation error.", ConsumeDiagnostics()); } + // Attach the compiled output to a file. + save_to.append("compiled_model.vmfb"); + ireeCompilerOutputOpenFile(save_to.c_str(), &output); + // Output. if (auto* err = ireeCompilerInvocationOutputVMBytecode(inv, output)) { return ErrorToStatus(err, "Failure emitting VM bytecode: "); diff --git a/onnxruntime/core/providers/iree/compiler/jit_compiler.h b/onnxruntime/core/providers/iree/compiler/jit_compiler.h index 5b974d971f5d6..8021148d1394e 100644 --- a/onnxruntime/core/providers/iree/compiler/jit_compiler.h +++ b/onnxruntime/core/providers/iree/compiler/jit_compiler.h @@ -45,7 +45,7 @@ struct CompilerOutput { // Releases ownership of the output, returning a callback that can be used to // destroy it at a later date. std::function Release() { - iree_compiler_output_t* local_output = output; + iree_compiler_output_t* local_output = this->output; this->output = nullptr; return [local_output]() { if (local_output) { @@ -84,7 +84,7 @@ struct CompilerInvocation { common::Status ImportSubgraph(const onnxruntime::GraphViewer& graph_view, const std::string& func_name); // Compile and output a VMFB. - common::Status CompileAndOutputVMFB(iree_compiler_output_t* output); + common::Status CompileAndOutputVMFB(iree_compiler_output_t* output, std::string save_to); // If there are any diagnostics, clears them and returns a loggable string. std::string ConsumeDiagnostics(); diff --git a/onnxruntime/core/providers/iree/iree_ep_runtime.cc b/onnxruntime/core/providers/iree/iree_ep_runtime.cc index ed258ebb5de07..fe2e4da79cc65 100644 --- a/onnxruntime/core/providers/iree/iree_ep_runtime.cc +++ b/onnxruntime/core/providers/iree/iree_ep_runtime.cc @@ -13,18 +13,7 @@ common::Status HandleFailingIREEStatus(iree_status_t iree_status) { return common::Status::OK(); } - std::string buffer; - iree_host_size_t actual_len; - buffer.resize(1024); - if (!iree_status_format(iree_status, buffer.size(), buffer.data(), - &actual_len)) { - buffer.resize(actual_len); - if (!iree_status_format(iree_status, buffer.size(), buffer.data(), - &actual_len)) { - actual_len = 0; - } - } - buffer.resize(actual_len); + std::string buffer = iree::Status::ToString(iree_status); return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "IREE Runtime Error: ", std::move(buffer)); } @@ -43,13 +32,13 @@ Instance::~Instance() { } } -iree_status_t Instance::Initialize() { +iree_status_t Instance::Initialize(std::string device_str) { IREE_RETURN_IF_ERROR(iree_runtime_instance_create( &options, iree_allocator_system(), &instance)); // TODO: Need real device selection. IREE_RETURN_IF_ERROR(iree_runtime_instance_try_create_default_device( - instance, iree_make_cstring_view("local-task"), &device)); + instance, iree_make_cstring_view(device_str.c_str()), &device)); return iree_ok_status(); } @@ -74,11 +63,15 @@ iree_status_t Session::Initialize() { &session); } -iree_status_t Session::AppendBytecodeModule(void* contents, uint64_t size, std::function dispose_callback) { +iree_status_t Session::AppendBytecodeModule(std::string file_loc, std::function dispose_callback) { dispose_callbacks.push_back(std::move(dispose_callback)); - return iree_runtime_session_append_bytecode_module_from_memory( - session, iree_make_const_byte_span(contents, size), - iree_allocator_null()); + // TODO: load from memory instead of file. + // return iree_runtime_session_append_bytecode_module_from_memory( + // session, iree_make_const_byte_span(contents, size), + // iree_allocator_null()); + file_loc.append("compiled_model.vmfb"); + return iree_runtime_session_append_bytecode_module_from_file( + session, file_loc.c_str()); } namespace { @@ -245,6 +238,16 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv); // TODO: Synchronous mapping read, like everything in this function, is not a // great idea. It isn't supported on all device types and will need a scrub. + iree_string_view_t device_val = iree_hal_device_id(device); + auto device_str = std::string(device_val.data, device_val.size); + if (device_str == "hip") { + ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h( + iree_runtime_session_device(session), + ret_buffer, 0, output_tensor.GetTensorMutableRawData(), + iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, + iree_infinite_timeout()))); + return common::Status::OK(); + } ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0, output_tensor.GetTensorMutableRawData(), iree_hal_buffer_view_byte_length(ret.bv)))); diff --git a/onnxruntime/core/providers/iree/iree_ep_runtime.h b/onnxruntime/core/providers/iree/iree_ep_runtime.h index 8adad33dfe8d5..28bcc3fa561f7 100644 --- a/onnxruntime/core/providers/iree/iree_ep_runtime.h +++ b/onnxruntime/core/providers/iree/iree_ep_runtime.h @@ -27,7 +27,7 @@ struct Instance { // Initializes the instance. // TODO: We should probably pass the options in here and use it to set up. - iree_status_t Initialize(); + iree_status_t Initialize(std::string device_str); // Instance globals. iree_runtime_instance_options_t options; @@ -48,7 +48,7 @@ struct Session { // Append a user-compiled bytecode module buffer to the session, along with a dispose callback. // The dispose callback will be invoked when Session is destroyed regardless of success/failure // of this call. - iree_status_t AppendBytecodeModule(void* contents, uint64_t size, std::function dispose_callback); + iree_status_t AppendBytecodeModule(std::string file_loc, std::function dispose_callback); // Calls the entrypoint. This returns an ORT Status and normalizes any IREE statuses to that // because that can arise from ORT interactions. diff --git a/onnxruntime/core/providers/iree/iree_execution_provider.cc b/onnxruntime/core/providers/iree/iree_execution_provider.cc index 70094efa23788..4e83f7647a613 100644 --- a/onnxruntime/core/providers/iree/iree_execution_provider.cc +++ b/onnxruntime/core/providers/iree/iree_execution_provider.cc @@ -33,7 +33,10 @@ IREEExecutionProvider::~IREEExecutionProvider() { } common::Status IREEExecutionProvider::Initialize() { - ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_instance_->Initialize())); + if (info_.find("device") == info_.end()) + info_["device"] = "local-task"; + LOGS(*GetLogger(), INFO) << "IREEExecutionProvider runtime device set as " << info_["device"]; + ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_instance_->Initialize(info_["device"]))); return common::Status::OK(); } @@ -102,15 +105,25 @@ common::Status IREEExecutionProvider::Compile(const std::vector(rt_instance_); + // In case device info is absent, set `local-task` as default device. ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->Initialize())); // Load the compiled module, releasing our ownership of the CompilerOutput. - ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule( - vmfb_contents, vmfb_size, vmfb_output.Release()))); + ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(info_["save_to"], vmfb_output.Release()))); for (auto& entrypoint_name : entrypoint_names) { node_compute_funcs.push_back(CreateNodeComputeFunc(entrypoint_name, rt_session));