Skip to content

Commit

Permalink
[TVM EP] support of TVM Virtual Machine (microsoft#10341)
Browse files Browse the repository at this point in the history
* add executor option (vm or graph) and support virtual machine methods

* nullptr check for compile and run methods (see also PR#10211 from microsoft:onnxruntime)

* get output shapes for VM

* remove run_with_benchmark; remove run methods from the Python API and get them from the native side

* implement get outputs method for VM

* support multiple input for VM

* update python logging and exception

* small fix

* update tvm with patch for VM API

* update nhwc transformations for TVM EP

* add data alignment check and support set_input_zero_copy for GE in TVM EP

* fix logger name

* return back to apache/tvm with VM fixes instead of local dev branch

* hide customized tvm logger while issue is not resolved. fix tvm warning related to target_host

* flake8 fix

Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>
(cherry picked from commit 62cc981)
  • Loading branch information
vvchernov authored and Peter Salas committed Nov 7, 2022
1 parent 0009d33 commit 13fab6f
Show file tree
Hide file tree
Showing 12 changed files with 250 additions and 128 deletions.
2 changes: 1 addition & 1 deletion cgmanifests/cgmanifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "36b48a5707321adba8a70e14da443566a9391e5a",
"commitHash": "d62a364ba783afef92623ee531043ee8dbd43566",
"repositoryUrl": "https://github.com/apache/tvm.git"
},
"comments": "needed for TVM EP"
Expand Down
7 changes: 4 additions & 3 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1414,14 +1414,15 @@ if (onnxruntime_USE_TVM)
set(USE_CUDA ON CACHE BOOL "Only defined for TVM" FORCE)
endif()

add_compile_definitions(TVM_LOG_CUSTOMIZE=1)
add_library(tvm_custom_logger STATIC ${ONNXRUNTIME_ROOT}/core/providers/tvm/custom_logging.cc)
# TODO(vvchernov): customized tvm logger is hidden due to the issue on TVM side (https://github.com/apache/tvm/issues/10139)
# add_compile_definitions(TVM_LOG_CUSTOMIZE=1)
# add_library(tvm_custom_logger STATIC ${ONNXRUNTIME_ROOT}/core/providers/tvm/custom_logging.cc)

set(USE_OPENMP gnu CACHE STRING "Only defined for TVM")
add_subdirectory(${tvm_SOURCE_DIR} ${tvm_BINARY_DIR} EXCLUDE_FROM_ALL)

set_target_properties(tvm PROPERTIES FOLDER ${tvm_SOURCE_DIR})
target_link_libraries(tvm PUBLIC tvm_custom_logger)
# target_link_libraries(tvm PUBLIC tvm_custom_logger)

set(TVM_INCLUDES ${tvm_SOURCE_DIR}/include
${tvm_SOURCE_DIR}/3rdparty/dmlc-core/include
Expand Down
2 changes: 1 addition & 1 deletion cmake/external/tvm.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ if (onnxruntime_USE_TVM)
FetchContent_Declare(
tvm
GIT_REPOSITORY https://github.com/apache/tvm.git
GIT_TAG 36b48a5707321adba8a70e14da443566a9391e5a
GIT_TAG d62a364ba783afef92623ee531043ee8dbd43566
)

FetchContent_GetProperties(tvm)
Expand Down
97 changes: 67 additions & 30 deletions onnxruntime/core/providers/tvm/tvm_api.cc
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/common.h"

#include "tvm_api.h"

#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/target/codegen.h>

#include "core/common/common.h"

#include "tvm_api.h"

namespace onnxruntime {
namespace tvm {
Expand All @@ -16,16 +16,17 @@ using TvmIntArray = ::tvm::Array<::tvm::Integer>;
using TvmPackedFunc = ::tvm::PackedFunc;

TvmModule TVMCompile(const std::string& onnx_txt,
const std::string& model_path,
const std::string& target,
const std::string& target_host,
int opt_level,
int opset,
bool freeze_params,
const std::vector<std::vector<int64_t>>& input_shapes,
bool nhwc,
const std::string& tuning_logfile,
const std::string& tuning_type)
const std::string& model_path,
const std::string& executor,
const std::string& target,
const std::string& target_host,
int opt_level,
int opset,
bool freeze_params,
const std::vector<std::vector<int64_t>>& input_shapes,
bool nhwc,
const std::string& tuning_logfile,
const std::string& tuning_type)
{
::tvm::Array<TvmIntArray> shapes;
for (size_t i = 0; i < input_shapes.size(); ++i)
Expand All @@ -43,6 +44,7 @@ TvmModule TVMCompile(const std::string& onnx_txt,
TvmModule mod = (*compile)(
TVMByteArray{onnx_txt.data(), onnx_txt.size()},
model_path,
executor,
target,
target_host,
opt_level,
Expand All @@ -52,20 +54,56 @@ TvmModule TVMCompile(const std::string& onnx_txt,
nhwc,
tuning_logfile,
tuning_type);
ORT_ENFORCE(mod.get() != nullptr, "Compiled TVM Module is nullptr!");
return mod;
}

void TVMSetInputs(TvmModule& mod,
std::vector<size_t>& inds,
std::vector<DLTensor>& inputs)
{
// TODO(vvchernov): set_input_zero_copy is more preferable but it does not satisfy alignment conditions.
//tvm::PackedFunc set_input = mod.GetFunction("set_input_zero_copy", false);

TvmPackedFunc set_input = mod.GetFunction("set_input", false);
for (auto& i : inds)
TvmPackedFunc set_input_zero_copy = mod.GetFunction("set_input_zero_copy", false);
for (size_t i = 0; i < inds.size(); ++i)
{
set_input(i, &inputs[i]);
if (reinterpret_cast<size_t>(inputs[i].data) % ::tvm::runtime::kAllocAlignment == 0) {
set_input_zero_copy(inds[i], &inputs[i]);
} else {
set_input(inds[i], &inputs[i]);
}
}
}

/// Binds input tensors to the "main" function of a TVM virtual-machine module.
/// @param mod    compiled TVM VM module (must provide "set_one_input")
/// @param inds   input slot indices; inds[i] is the slot for inputs[i] (parallel vectors)
/// @param inputs DLPack tensors to bind; borrowed, not owned — must outlive the run
void TVM_VM_SetInputs(TvmModule& mod,
                      std::vector<size_t>& inds,
                      std::vector<DLTensor>& inputs)
{
  TvmPackedFunc set_input = mod.GetFunction("set_one_input", false);
  // Fail fast with a clear message (consistent with TVM_VM_Run/TVMRun) instead of
  // crashing when the module was built without the VM executor interface.
  ORT_ENFORCE(set_input != nullptr, "Unable to retrieve virtual machine set_one_input.");
  for (size_t i = 0; i < inds.size(); ++i)
  {
    set_input("main", inds[i], &inputs[i]);
  }
}

/// Copies every graph-executor output into the caller-provided DLTensors.
/// @param mod     compiled TVM graph-executor module (must provide "get_output")
/// @param outputs pre-allocated DLPack tensors, one per model output, filled in place
void TVMGetOutputs(TvmModule& mod,
                   std::vector<DLTensor>& outputs)
{
  TvmPackedFunc get_output = mod.GetFunction("get_output", false);
  // Fail fast with a clear message (consistent with TVMRun) instead of crashing
  // on a null functor if the module lacks the graph-executor interface.
  ORT_ENFORCE(get_output != nullptr, "Unable to retrieve graph executor get_output.");
  for (size_t i = 0; i < outputs.size(); ++i)
  {
    get_output(i, &outputs[i]);
  }
}

/// Copies every virtual-machine output into the caller-provided DLTensors.
/// Unlike the graph executor, the VM returns NDArrays, so each output is
/// fetched and then copied into the caller's buffer.
/// @param mod     compiled TVM VM module (must provide "get_output")
/// @param outputs pre-allocated DLPack tensors, one per model output, filled in place
void TVM_VM_GetOutputs(TvmModule& mod,
                       std::vector<DLTensor>& outputs)
{
  TvmPackedFunc get_output = mod.GetFunction("get_output", false);
  // Fail fast with a clear message (consistent with TVM_VM_Run) instead of
  // crashing on a null functor if the module lacks the VM interface.
  ORT_ENFORCE(get_output != nullptr, "Unable to retrieve virtual machine get_output.");
  for (size_t i = 0; i < outputs.size(); ++i)
  {
    // TODO(vvchernov): think about improvement of memory management
    ::tvm::runtime::NDArray output_array = get_output(i);
    output_array.CopyTo(&outputs[i]);
  }
}

Expand All @@ -87,19 +125,18 @@ void TVMGetOutputShapes(TvmModule& mod,
}
}

void TVMRun(TvmModule& mod,
std::vector<DLTensor>& outputs,
[[maybe_unused]] ::tvm::runtime::TVMRetValue *ret)
void TVMRun(TvmModule& mod)
{
const TvmPackedFunc* run = ::tvm::runtime::Registry::Get("tvm_run");
ORT_ENFORCE(run != nullptr, "Unable to retrieve 'tvm_run'.");
(*run)(mod);
TvmPackedFunc run = mod.GetFunction("run", false);
ORT_ENFORCE(run != nullptr, "Unable to retrieve graph executor run.");
run();
}

TvmPackedFunc get_output = mod.GetFunction("get_output", false);
for (size_t i = 0; i < outputs.size(); ++i)
{
get_output(i, &outputs[i]);
}
/// Executes the "main" function of a compiled TVM virtual-machine module.
void TVM_VM_Run(TvmModule& mod)
{
  TvmPackedFunc invoke = mod.GetFunction("invoke", false);
  ORT_ENFORCE(invoke != nullptr, "Unable to retrieve virtual machine invoke.");
  invoke("main");
}

} // namespace tvm
Expand Down
31 changes: 20 additions & 11 deletions onnxruntime/core/providers/tvm/tvm_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,35 @@
#ifndef TVM_API_H
#define TVM_API_H

#include <vector>
#include <string>

#include "tvm_common.h"
#include "tvm_defaults.h"

namespace onnxruntime {
namespace tvm {
TvmModule TVMCompile(const std::string& onnx_txt,
const std::string& model_path,
const std::string& target,
const std::string& target_host,
int opt_level,
int opset,
bool freeze_params,
const std::vector<std::vector<int64_t>>& input_shapes,
bool nhwc = false,
const std::string& tuning_logfile = "",
const std::string& tuning_type = "AutoTVM");
const std::string& model_path,
const std::string& executor,
const std::string& target,
const std::string& target_host,
int opt_level,
int opset,
bool freeze_params,
const std::vector<std::vector<int64_t>>& input_shapes,
bool nhwc = false,
const std::string& tuning_logfile = "",
const std::string& tuning_type = std::string(onnxruntime::tvm::default_tuning_type));
void TVMSetInputs(TvmModule& mod, std::vector<size_t>& inds, std::vector<DLTensor>& inputs);
void TVM_VM_SetInputs(TvmModule& mod, std::vector<size_t>& inds, std::vector<DLTensor>& inputs);
void TVMGetOutputs(TvmModule& mod, std::vector<DLTensor>& outputs);
void TVM_VM_GetOutputs(TvmModule& mod, std::vector<DLTensor>& outputs);
void TVMGetOutputShapes(TvmModule& mod,
size_t num_outputs,
std::vector<std::vector<int64_t>>& output_shapes);
void TVMRun(TvmModule& mod, std::vector<DLTensor>& outputs, ::tvm::runtime::TVMRetValue *ret);
void TVMRun(TvmModule& mod);
void TVM_VM_Run(TvmModule& mod);
} // namespace tvm
} // namespace onnxruntime

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/tvm/tvm_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include <dlpack/dlpack.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/vm/vm.h>

using TvmModule = tvm::runtime::Module;

Expand Down
29 changes: 29 additions & 0 deletions onnxruntime/core/providers/tvm/tvm_defaults.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifndef TVM_DEFAULTS_H
#define TVM_DEFAULTS_H

namespace onnxruntime {
namespace tvm {

// Executor kinds accepted by the TVM EP: the relay virtual machine ("vm",
// the default) or the graph executor ("graph").
constexpr const char* default_executor_type = "vm";
constexpr const char* vm_executor_type = "vm";
constexpr const char* graph_executor_type = "graph";

// Target strings; "cpu" is the default and maps to an LLVM backend.
constexpr const char* default_target_str = "cpu";
constexpr const char* llvm_target_str = "llvm";

// Generic device targets.
constexpr const char* cpu_target_str = "cpu";
constexpr const char* gpu_target_str = "gpu";

// Tuning frameworks: "AutoTVM" (default) or "Ansor" (auto-scheduler).
constexpr const char* default_tuning_type = "AutoTVM";
constexpr const char* autotvm_tuning_type = "AutoTVM";
constexpr const char* ansor_tuning_type = "Ansor";

// Default TVM compiler optimization level.
constexpr const unsigned int default_opt_level = 3;

} // namespace tvm
} // namespace onnxruntime

#endif // TVM_DEFAULTS_H
Loading

0 comments on commit 13fab6f

Please sign in to comment.