diff --git a/onnxruntime/core/providers/iree/iree_ep_runtime.cc b/onnxruntime/core/providers/iree/iree_ep_runtime.cc index 086ef9962465a..7caee054d94dc 100644 --- a/onnxruntime/core/providers/iree/iree_ep_runtime.cc +++ b/onnxruntime/core/providers/iree/iree_ep_runtime.cc @@ -27,9 +27,6 @@ Instance::~Instance() { if (instance) { iree_runtime_instance_release(instance); } - if (device) { - iree_hal_device_release(device); - } } iree_status_t Instance::Initialize(std::string device_str) { @@ -234,22 +231,13 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, auto output_tensor = context.GetOutput(i, shape.data(), shape.size()); ORT_ENFORCE(output_tensor.IsTensor()); - iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv); // TODO: Synchronous mapping read, like everything in this function, is not a // great idea. It isn't supported on all device types and will need a scrub. - iree_string_view_t device_val = iree_hal_device_id(device); - auto device_str = std::string(device_val.data, device_val.size); - if (device_str == "hip") { - ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h( - iree_runtime_session_device(session), - ret_buffer, 0, output_tensor.GetTensorMutableRawData(), - iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, - iree_infinite_timeout()))); - return common::Status::OK(); - } - ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0, - output_tensor.GetTensorMutableRawData(), - iree_hal_buffer_view_byte_length(ret.bv)))); + ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h( + iree_runtime_session_device(session), + iree_hal_buffer_view_buffer(ret.bv), 0, output_tensor.GetTensorMutableRawData(), + iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, + iree_infinite_timeout()))); } return common::Status::OK(); diff --git a/onnxruntime/core/providers/iree/iree_execution_provider.cc b/onnxruntime/core/providers/iree/iree_execution_provider.cc index d504561707e60..a3e037eb04ac8 100644 --- a/onnxruntime/core/providers/iree/iree_execution_provider.cc +++ b/onnxruntime/core/providers/iree/iree_execution_provider.cc @@ -168,6 +168,11 @@ common::Status IREEExecutionProvider::Compile(const std::vectorInitialize())); + // Release hal device after session initialization. + if (rt_instance_->device) { + iree_hal_device_release(rt_instance_->device); + } + // Load the compiled module, releasing our ownership of the CompilerOutput. ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(vmfb_path, vmfb_output.Release(vmfb_path))));