Skip to content

Commit

Permalink
[IREE][EP] Add support for rocm backend
Browse files Browse the repository at this point in the history
This commit adds support for rocm backend in iree-ep.

Signed-Off-by: Gaurav Shukla<gaurav.shukla@amd.com>
(cherry picked from commit 00b1696)
  • Loading branch information
Shukla-Gaurav committed Sep 11, 2024
1 parent 08acace commit 146425f
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 33 deletions.
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/iree/compiler/jit_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,16 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
return common::Status::OK();
}

common::Status CompilerInvocation::CompileAndOutputVMFB(iree_compiler_output_t* output) {
common::Status CompilerInvocation::CompileAndOutputVMFB(iree_compiler_output_t* output, std::string save_to) {
// Main compilation.
if (!ireeCompilerInvocationPipeline(inv, IREE_COMPILER_PIPELINE_STD)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "IREE compilation error.", ConsumeDiagnostics());
}

// Attach the compiled output to a file.
save_to.append("compiled_model.vmfb");
ireeCompilerOutputOpenFile(save_to.c_str(), &output);

// Output.
if (auto* err = ireeCompilerInvocationOutputVMBytecode(inv, output)) {
return ErrorToStatus(err, "Failure emitting VM bytecode: ");
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/iree/compiler/jit_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct CompilerOutput {
// Releases ownership of the output, returning a callback that can be used to
// destroy it at a later date.
std::function<void()> Release() {
iree_compiler_output_t* local_output = output;
iree_compiler_output_t* local_output = this->output;
this->output = nullptr;
return [local_output]() {
if (local_output) {
Expand Down Expand Up @@ -84,7 +84,7 @@ struct CompilerInvocation {
common::Status ImportSubgraph(const onnxruntime::GraphViewer& graph_view, const std::string& func_name);

// Compile and output a VMFB.
common::Status CompileAndOutputVMFB(iree_compiler_output_t* output);
common::Status CompileAndOutputVMFB(iree_compiler_output_t* output, std::string save_to);

// If there are any diagnostics, clears them and returns a loggable string.
std::string ConsumeDiagnostics();
Expand Down
39 changes: 21 additions & 18 deletions onnxruntime/core/providers/iree/iree_ep_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,7 @@ common::Status HandleFailingIREEStatus(iree_status_t iree_status) {
return common::Status::OK();
}

std::string buffer;
iree_host_size_t actual_len;
buffer.resize(1024);
if (!iree_status_format(iree_status, buffer.size(), buffer.data(),
&actual_len)) {
buffer.resize(actual_len);
if (!iree_status_format(iree_status, buffer.size(), buffer.data(),
&actual_len)) {
actual_len = 0;
}
}
buffer.resize(actual_len);
std::string buffer = iree::Status::ToString(iree_status);

return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "IREE Runtime Error: ", std::move(buffer));
}
Expand All @@ -43,13 +32,13 @@ Instance::~Instance() {
}
}

iree_status_t Instance::Initialize() {
iree_status_t Instance::Initialize(std::string device_str) {
IREE_RETURN_IF_ERROR(iree_runtime_instance_create(
&options, iree_allocator_system(), &instance));

// TODO: Need real device selection.
IREE_RETURN_IF_ERROR(iree_runtime_instance_try_create_default_device(
instance, iree_make_cstring_view("local-task"), &device));
instance, iree_make_cstring_view(device_str.c_str()), &device));

return iree_ok_status();
}
Expand All @@ -74,11 +63,15 @@ iree_status_t Session::Initialize() {
&session);
}

iree_status_t Session::AppendBytecodeModule(void* contents, uint64_t size, std::function<void()> dispose_callback) {
iree_status_t Session::AppendBytecodeModule(std::string file_loc, std::function<void()> dispose_callback) {
dispose_callbacks.push_back(std::move(dispose_callback));
return iree_runtime_session_append_bytecode_module_from_memory(
session, iree_make_const_byte_span(contents, size),
iree_allocator_null());
// TODO: load from memory instead of file.

Check warning on line 68 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:68: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// return iree_runtime_session_append_bytecode_module_from_memory(
// session, iree_make_const_byte_span(contents, size),
// iree_allocator_null());
file_loc.append("compiled_model.vmfb");
return iree_runtime_session_append_bytecode_module_from_file(
session, file_loc.c_str());
}

namespace {
Expand Down Expand Up @@ -245,6 +238,16 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
// TODO: Synchronous mapping read, like everything in this function, is not a
// great idea. It isn't supported on all device types and will need a scrub.
iree_string_view_t device_val = iree_hal_device_id(device);
auto device_str = std::string(device_val.data, device_val.size);

Check warning on line 242 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:242: Add #include <string> for string [build/include_what_you_use] [4]
if (device_str == "hip") {
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
iree_runtime_session_device(session),
ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout())));
return common::Status::OK();
}
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv))));
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/iree/iree_ep_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct Instance {

// Initializes the instance.
// TODO: We should probably pass the options in here and use it to set up.
iree_status_t Initialize();
iree_status_t Initialize(std::string device_str);

// Instance globals.
iree_runtime_instance_options_t options;
Expand All @@ -48,7 +48,7 @@ struct Session {
// Append a user-compiled bytecode module buffer to the session, along with a dispose callback.
// The dispose callback will be invoked when Session is destroyed regardless of success/failure
// of this call.
iree_status_t AppendBytecodeModule(void* contents, uint64_t size, std::function<void()> dispose_callback);
iree_status_t AppendBytecodeModule(std::string file_loc, std::function<void()> dispose_callback);

Check warning on line 51 in onnxruntime/core/providers/iree/iree_ep_runtime.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.h:51: Add #include <string> for string [build/include_what_you_use] [4]

// Calls the entrypoint. This returns an ORT Status and normalizes any IREE statuses to that
// because that can arise from ORT interactions.
Expand Down
39 changes: 29 additions & 10 deletions onnxruntime/core/providers/iree/iree_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ IREEExecutionProvider::~IREEExecutionProvider() {
}

common::Status IREEExecutionProvider::Initialize() {
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_instance_->Initialize()));
if (info_.find("device") == info_.end())
info_["device"] = "local-task";
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider runtime device set as " << info_["device"];
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_instance_->Initialize(info_["device"])));
return common::Status::OK();
}

Expand Down Expand Up @@ -102,15 +105,25 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
// TODO: The target needs to be synchronized with the runtime based on EP options.
// TODO: We should just be adding the target to the module instead of specifying via
// flags.
std::string device_flag = "--iree-hal-target-backends=";
std::string device_flag = "--iree-hal-target-device=";
if (info_.find("hal_target_device") == info_.end()) {
// In case device info is absent, set `llvm-cpu` as default hal-target-backend.
// In case device info is absent, set `llvm-cpu` as default hal-target-device.
device_flag.append("llvm-cpu");
} else {
device_flag.append(info_["hal_target_device"]);
}
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting device flag as " << device_flag;
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << device_flag;
ORT_RETURN_IF_ERROR(compiler.SetFlag(device_flag.c_str()));

// Set all the compile-time flags.
// TODO: Use ireeCompilerSessionSetFlags API to set all the flags at once.

Check warning on line 119 in onnxruntime/core/providers/iree/iree_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/iree/iree_execution_provider.cc:119: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// TODO: support more than one extra flags by parsing the input string.

Check warning on line 120 in onnxruntime/core/providers/iree/iree_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/iree/iree_execution_provider.cc:120: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
if (info_.find("compile_time_flags") != info_.end()) {
std::string extra_flag = info_["compile_time_flags"];
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << extra_flag;
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag.c_str()));
}

ORT_RETURN_IF_ERROR(compiler.Initialize());
std::string module_name = "ort";
iree_ep_jit::CompilerInvocation inv(compiler, module_name.c_str());
Expand All @@ -137,20 +150,26 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
if (auto* err = ireeCompilerOutputOpenMembuffer(&vmfb_output.output)) {
return iree_ep_jit::ErrorToStatus(err, "Failure opening compiler output buffer: ");
}
ORT_RETURN_IF_ERROR(inv.CompileAndOutputVMFB(vmfb_output.output));

// This will save the compiled module to current working directory.
if (info_.find("save_to") == info_.end())
info_["save_to"] = "";
ORT_RETURN_IF_ERROR(inv.CompileAndOutputVMFB(vmfb_output.output, info_["save_to"]));
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compiled vmfb saved at this location " << info_["save_to"];

// Map raw memory.
void* vmfb_contents;
uint64_t vmfb_size;
ORT_RETURN_IF_ERROR(vmfb_output.MapMemory(&vmfb_contents, &vmfb_size));
// void* vmfb_contents = nullptr;
// uint64_t vmfb_size = 0;
// TODO: Map memory instead of storing the compiled module as a file

Check warning on line 163 in onnxruntime/core/providers/iree/iree_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/iree/iree_execution_provider.cc:163: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// ORT_RETURN_IF_ERROR(vmfb_output.MapMemory(&vmfb_contents, &vmfb_size));

// Create a new runtime session.
auto rt_session = std::make_shared<iree_ep_rt::Session>(rt_instance_);
// In case device info is absent, set `local-task` as default device.
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->Initialize()));

// Load the compiled module, releasing our ownership of the CompilerOutput.
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(
vmfb_contents, vmfb_size, vmfb_output.Release())));
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(info_["save_to"], vmfb_output.Release())));

Check warning on line 172 in onnxruntime/core/providers/iree/iree_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/iree/iree_execution_provider.cc:172: Lines should be <= 120 characters long [whitespace/line_length] [2]

for (auto& entrypoint_name : entrypoint_names) {
node_compute_funcs.push_back(CreateNodeComputeFunc(entrypoint_name, rt_session));
Expand Down

0 comments on commit 146425f

Please sign in to comment.