Skip to content

Commit

Permalink
[IREE-EP] Integrate iree async module in the IREE-EP
Browse files Browse the repository at this point in the history
Signed-Off-by: Gaurav Shukla <gaurav.shukla@amd.com>
  • Loading branch information
Shukla-Gaurav committed Nov 11, 2024
1 parent 7b2046f commit 16f08f8
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 79 deletions.
215 changes: 136 additions & 79 deletions onnxruntime/core/providers/iree/iree_ep_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "core/providers/iree/iree_ep_runtime.h"

#include "core/session/onnxruntime_cxx_api.h"
#include <iostream>

Check warning on line 7 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: iree_ep_runtime.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:7: Found C++ system header after other header. Should be: iree_ep_runtime.h, c system, c++ system, other. [build/include_order] [4]

namespace onnxruntime::iree_ep_rt {

Expand Down Expand Up @@ -57,10 +58,18 @@ Session::~Session() {
}

iree_status_t Session::Initialize() {
return iree_runtime_session_create_with_device(
iree_status_t res_status = iree_runtime_session_create_with_device(
instance->instance, &session_options, instance->device,
iree_runtime_instance_host_allocator(instance->instance),
&session);
iree_vm_module_t* custom_module = NULL;
iree_allocator_t host_allocator = iree_allocator_system();
IREE_CHECK_OK(iree_custom_module_async_create(
iree_runtime_instance_vm_instance(instance->instance), instance->device,
host_allocator, &custom_module));
IREE_CHECK_OK(iree_runtime_session_append_module(session, custom_module));
iree_vm_module_release(custom_module);
return res_status;
}

iree_status_t Session::AppendBytecodeModule(fs::path vmfb_path, std::function<void()> dispose_callback) {
Expand Down Expand Up @@ -147,6 +156,13 @@ iree_hal_element_type_t ConvertOrtElementType(ONNXTensorElementDataType et) {
common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, OrtKernelContext* ort_context_c) {
// TODO: This is far from the most efficient way to make a call. Synchronous and copying. We can do
// better but this gets points for simplicity and lets us bootstrap the tests.
iree_vm_list_t* inputs = NULL;
iree_allocator_t host_allocator = iree_allocator_system();
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
host_allocator, &inputs));
iree_vm_list_t* outputs = NULL;
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
host_allocator, &outputs));
Ort::KernelContext context(ort_context_c);
SynchronousCall call(session);
ORT_RETURN_IF_ERROR(HandleIREEStatus(call.InitializeByName(entrypoint_name)));
Expand All @@ -161,59 +177,93 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,

// Process inputs. We could be smarter about this in a lot of ways, including carrying
// more state from compilation so we are doing less munging here.
for (size_t i = 0; i < context.GetInputCount(); ++i) {
auto input_tensor = context.GetInput(i);
ORT_ENFORCE(input_tensor.IsTensor());

// The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
// is useful for anything.
auto ort_device_type = input_tensor.GetTensorMemoryInfo().GetDeviceType();
ORT_ENFORCE(ort_device_type == OrtMemoryInfoDeviceType_CPU);

const auto& tensor_type = input_tensor.GetTensorTypeAndShapeInfo();
auto element_type = ConvertOrtElementType(tensor_type.GetElementType());
ORT_ENFORCE(element_type != IREE_HAL_ELEMENT_TYPE_NONE, "Unsupported element type ",
static_cast<int>(tensor_type.GetElementType()));
ORT_ENFORCE(iree_hal_element_is_byte_aligned(element_type));
size_t element_size_bytes = iree_hal_element_dense_byte_count(element_type);

// Yes, that's right, returned as an std::vector by value :(
// And of a different type than we expect.
std::vector<int64_t> shape = tensor_type.GetShape();
dims.resize(shape.size());
std::copy(shape.begin(), shape.end(), dims.begin());

// No convenient way to get the byte size of the raw data.
size_t element_count = tensor_type.GetElementCount();
const void* raw_data = input_tensor.GetTensorRawData();

HalBufferView arg;
iree_hal_buffer_params_t buffer_params;
memset(&buffer_params, 0, sizeof(buffer_params));
buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_view_allocate_buffer_copy(
device, device_allocator,
// Shape rank and dimensions:
dims.size(), dims.data(),
// Element type:
element_type,
// Encoding type:
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
buffer_params,
// The actual heap buffer to wrap or clone and its allocator:
iree_make_const_byte_span(raw_data, element_count * element_size_bytes),
// Buffer view + storage are returned and owned by the caller:
&arg.bv)));

// Add it to the call.
iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
}

std::cout << "input count: " << context.GetInputCount() << "\n";
// for (size_t i = 0; i < context.GetInputCount(); ++i) {
auto input_tensor = context.GetInput(0);
ORT_ENFORCE(input_tensor.IsTensor());

// The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
// is useful for anything.
auto ort_device_type = input_tensor.GetTensorMemoryInfo().GetDeviceType();
ORT_ENFORCE(ort_device_type == OrtMemoryInfoDeviceType_CPU);

const auto& tensor_type = input_tensor.GetTensorTypeAndShapeInfo();
auto element_type = ConvertOrtElementType(tensor_type.GetElementType());
ORT_ENFORCE(element_type != IREE_HAL_ELEMENT_TYPE_NONE, "Unsupported element type ",
static_cast<int>(tensor_type.GetElementType()));
ORT_ENFORCE(iree_hal_element_is_byte_aligned(element_type));
size_t element_size_bytes = iree_hal_element_dense_byte_count(element_type);

// Yes, that's right, returned as an std::vector by value :(
// And of a different type than we expect.
std::vector<int64_t> shape = tensor_type.GetShape();

Check warning on line 200 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:200: Add #include <vector> for vector<> [build/include_what_you_use] [4]
dims.resize(shape.size());
std::copy(shape.begin(), shape.end(), dims.begin());

// No convenient way to get the byte size of the raw data.
size_t element_count = tensor_type.GetElementCount();
const void* raw_data = input_tensor.GetTensorRawData();

HalBufferView arg;
iree_hal_buffer_params_t buffer_params;
memset(&buffer_params, 0, sizeof(buffer_params));
buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_view_allocate_buffer_copy(
device, device_allocator,
// Shape rank and dimensions:
dims.size(), dims.data(),
// Element type:
element_type,
// Encoding type:
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
buffer_params,
// The actual heap buffer to wrap or clone and its allocator:
iree_make_const_byte_span(raw_data, element_count * element_size_bytes),
// Buffer view + storage are returned and owned by the caller:
&arg.bv)));

iree_vm_ref_t input_view_ref = iree_hal_buffer_view_move_ref(arg.bv);
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &input_view_ref));

iree_hal_semaphore_t* semaphore = NULL;
IREE_CHECK_OK(iree_hal_semaphore_create(
device, 0ull, IREE_HAL_SEMAPHORE_FLAG_NONE, &semaphore));
iree_hal_fence_t* fence_t1 = NULL;
IREE_CHECK_OK(
iree_hal_fence_create_at(semaphore, 1ull, host_allocator, &fence_t1));
iree_hal_fence_t* fence_t2 = NULL;
IREE_CHECK_OK(
iree_hal_fence_create_at(semaphore, 2ull, host_allocator, &fence_t2));
iree_hal_semaphore_release(semaphore);
std::cout << "\n semaphore released";
iree_vm_ref_t fence_t1_ref = iree_hal_fence_retain_ref(fence_t1);
std::cout << "\n semaphore released1";
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t1_ref));
std::cout << "\n semaphore released2";
iree_vm_ref_t fence_t2_ref = iree_hal_fence_retain_ref(fence_t2);
std::cout << "\n semaphore released3";
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t2_ref));
std::cout << "\n semaphore released4";
IREE_CHECK_OK(iree_hal_fence_signal(fence_t1));
std::cout << "\n T=1 reached";
// Add it to the call.
iree_string_view_t entry_point = iree_make_cstring_view(entrypoint_name);
IREE_CHECK_OK(
iree_runtime_session_call_by_name(session, entry_point, inputs, outputs));
// We could go do other things now while the async work progresses. Here we
// just immediately wait.
IREE_CHECK_OK(iree_hal_fence_wait(fence_t2, iree_infinite_timeout()));
std::cout << "\n T=2 reached";
// iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
// ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
// }
// Read back the tensor<?xi32> result:

// Invoke.
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, /*flags=*/0)));
// ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, [>flags=<]0)));

// Marshal the outputs.
// TODO: Accessing the ORT output requires the shape and then we could get zero copy
Expand All @@ -222,37 +272,44 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
// convention, which allows passing in slabs of result buffers. Further, that would
// run the host-side computation (which would compute output metadata) inline.
// For static cases, we could also side-load the shape from the compile time.
std::vector<int64_t> shape;
for (size_t i = 0; i < context.GetOutputCount(); ++i) {
HalBufferView ret;
ORT_RETURN_IF_ERROR(HandleIREEStatus(
iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
size_t ret_rank = iree_hal_buffer_view_shape_rank(ret.bv);
const iree_hal_dim_t* ret_dims = iree_hal_buffer_view_shape_dims(ret.bv);
shape.resize(ret_rank);
std::copy(ret_dims, ret_dims + ret_rank, shape.begin());
auto output_tensor = context.GetOutput(i, shape.data(), shape.size());
ORT_ENFORCE(output_tensor.IsTensor());

iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
// TODO: Synchronous mapping read, like everything in this function, is not a
// great idea. It isn't supported on all device types and will need a scrub.
iree_string_view_t device_val = iree_hal_device_id(device);
auto device_str = std::string(device_val.data, device_val.size);
if (device_str == "hip") {
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
iree_runtime_session_device(session),
ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout())));
return common::Status::OK();
}
// std::vector<int64_t> shape;
std::cout << "output count: " << context.GetOutputCount() << "\n";
// for (size_t i = 0; i < context.GetOutputCount(); ++i) {
HalBufferView ret;
ret.bv = iree_vm_list_get_buffer_view_assign(outputs, 0);
// ORT_RETURN_IF_ERROR(HandleIREEStatus(
// iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
size_t ret_rank = iree_hal_buffer_view_shape_rank(ret.bv);
const iree_hal_dim_t* ret_dims = iree_hal_buffer_view_shape_dims(ret.bv);
shape.clear();
shape.resize(ret_rank);
std::copy(ret_dims, ret_dims + ret_rank, shape.begin());

Check warning on line 286 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for copy [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:286: Add #include <algorithm> for copy [build/include_what_you_use] [4]
auto output_tensor = context.GetOutput(0, shape.data(), shape.size());
ORT_ENFORCE(output_tensor.IsTensor());

iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
// TODO: Synchronous mapping read, like everything in this function, is not a

Check warning on line 291 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:291: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// great idea. It isn't supported on all device types and will need a scrub.
iree_string_view_t device_val = iree_hal_device_id(device);
auto device_str = std::string(device_val.data, device_val.size);

Check warning on line 294 in onnxruntime/core/providers/iree/iree_ep_runtime.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.cc:294: Add #include <string> for string [build/include_what_you_use] [4]
if (device_str == "hip") {
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
iree_runtime_session_device(session),
ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout())));
return common::Status::OK();
}
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv))));
}
// }

return common::Status::OK();
iree_vm_list_release(inputs);
iree_vm_list_release(outputs);
iree_hal_fence_release(fence_t1);
iree_hal_fence_release(fence_t2);
return common::Status::OK();
}

} // namespace onnxruntime::iree_ep_rt
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/iree/iree_ep_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

#include "core/common/common.h"
#include "core/session/onnxruntime_c_api.h"
#include "iree/modules/hal/types.h"
#include "iree/runtime/api.h"

#include "module.h"

Check warning on line 11 in onnxruntime/core/providers/iree/iree_ep_runtime.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/providers/iree/iree_ep_runtime.h:11: Include the directory when naming header files [build/include_subdir] [4]

#include <filesystem>

namespace fs = std::filesystem;
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/iree/iree_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << extra_flag;
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag.c_str()));
}
std::string extra_flag_2 = "--iree-execution-model=async-external";
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag_2.c_str()));

ORT_RETURN_IF_ERROR(compiler.Initialize());
std::string module_name = "ort";
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
#include "core/providers/dml/dml_session_options_config_keys.h"
#endif

#ifdef USE_IREE
#include "core/providers/iree/iree_provider_factory.h"
#endif

#ifdef _WIN32
#define strdup _strdup
#endif
Expand Down

0 comments on commit 16f08f8

Please sign in to comment.