Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3/3][AOT][DeviceAPI] Wire up cpacked Device API context #9501

Merged
merged 3 commits into from
Nov 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,18 +348,39 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

GlobalVar global_var = call_lowered_props.lowered_func;
tir::Var empty_var("no_device_context", DataType::Handle());
bool has_c_device_api_context = device_contexts_.count(global_var) != 0;
bool use_cpacked_api = !use_unpacked_api_;

// The device context is passed to the operator in one of the following calling patterns:
// * Unpacked / direct function call with context:
// operator(arg0, arg1, device_context);
// * Unpacked / direct function call without context:
// operator(arg0, arg1);
// * Type-erased packed function call with context:
// operator(args, type_codes, int num_args, out_ret_value, out_ret_tcode,
// device_context_my_device)
// * Type-erased packed function call without context (we create an empty var for codegen):
// operator(args, type_codes, int num_args, out_ret_value, out_ret_tcode,
// no_device_context)
if (has_c_device_api_context) {
// call_extern calling convention with context
tir::Var context = device_contexts_.Get(global_var).value();
args.push_back(device_contexts_[global_var]);
args.push_back(context);

tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(tir::SeqStmt({
GenerateDeviceHook(context, "Open"),
func_call,
GenerateDeviceHook(context, "Close"),
}));
} else if (use_cpacked_api) {
// call_cpacked calling convention needs a blank context
args.push_back(tir::make_zero(DataType::Handle()));
tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(func_call);
} else {
// call_extern calling convention without context
tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(func_call);
}
Expand Down
17 changes: 12 additions & 5 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar
this->stream << "}\n";
}

void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_args) {
void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_args,
const std::string& resource_handle_name) {
this->PrintIndent();
std::string ret_val = GetUniqueName("ret_val");
std::string ret_type_code = GetUniqueName("ret_type_code");
Expand All @@ -266,7 +267,7 @@ void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_a
<< "(int*) stack_tcode"
<< ", " << num_args << ", "
<< "&" << ret_val << ", "
<< "&" << ret_type_code << ", NULL) != 0){\n";
<< "&" << ret_type_code << ", " << resource_handle_name << ") != 0){\n";

int func_call_scope = this->BeginScope();
this->PrintIndent();
Expand All @@ -276,7 +277,8 @@ void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_a
this->stream << "}\n";
}

CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op) {
CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op,
bool has_resource_handle) {
const StringImmNode* s = op->args[0].as<StringImmNode>();
ICHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name";
int64_t begin = op->args[3].as<IntImmNode>()->value;
Expand All @@ -296,6 +298,10 @@ CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op) {
declared_globals_[packed_func_name] = unique_name;
decl_stream << "static void* " << unique_name << " = NULL;\n";
}
if (has_resource_handle) {
std::string resource_handle_name = op->args[5].as<StringImmNode>()->value;
return {func_name, unique_name, num_args - 1, resource_handle_name};
}
return {func_name, unique_name, num_args};
}

Expand Down Expand Up @@ -327,8 +333,9 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT
this->PrintGetFuncFromBackend(function_info.func_name, function_info.func_name_packed);
this->PrintFuncCall(function_info.func_name_packed, function_info.num_args);
} else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) {
auto function_info = GetFunctionInfo(op);
this->PrintFuncCallC(function_info.func_name, function_info.num_args);
auto function_info = GetFunctionInfo(op, true);
this->PrintFuncCallC(function_info.func_name, function_info.num_args,
function_info.resource_handle_name);
} else if (op->op.same_as(builtin::tvm_throw_last_error())) {
this->PrintIndent();
this->stream << "return -1;\n";
Expand Down
7 changes: 5 additions & 2 deletions src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class CodeGenCHost : public CodeGenC {
std::string func_name_packed;
/* number of arguments required by the function */
int64_t num_args;
/* \brief name of resource_handle to pass */
std::string resource_handle_name;
};
std::string module_name_;
/* \brief mapping global packed func to the unique name */
Expand All @@ -82,10 +84,11 @@ class CodeGenCHost : public CodeGenC {
/*! \brief whether to emit asserts in the resulting C code */
bool emit_asserts_;

FunctionInfo GetFunctionInfo(const CallNode* op);
FunctionInfo GetFunctionInfo(const CallNode* op, bool has_resource_handle = false);
void PrintGetFuncFromBackend(const std::string& func_name, const std::string& packed_func_name);
void PrintFuncCall(const std::string& packed_func_name, int num_args);
void PrintFuncCallC(const std::string& packed_func_name, int num_args);
void PrintFuncCallC(const std::string& packed_func_name, int num_args,
const std::string& resource_handle_name);

/*!
* \brief Print ternary conditional operator implementing binary `op`
Expand Down
17 changes: 15 additions & 2 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,18 @@ class BuiltinLower : public StmtExprMutator {
size_t restore_array_stack = scope.run_array_stack;
size_t arg_stack_begin = scope.run_arg_stack;

scope.run_arg_stack += op->args.size();
size_t arg_count = op->args.size();

// cpacked expects a resource_handle parameter
if (!use_string_lookup) {
arg_count--;
}

scope.run_arg_stack += arg_count;
// Specially handle the buffer packed intrinsic
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
for (size_t i = 1; i < op->args.size(); ++i) {
for (size_t i = 1; i < arg_count; ++i) {
PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
PrimExpr arg = op->args[i];
DataType t = arg.dtype();
Expand Down Expand Up @@ -314,6 +321,12 @@ class BuiltinLower : public StmtExprMutator {
ConstInt32(arg_stack_begin),
ConstInt32(arg_stack_begin + op->args.size() - 1)};

// cpacked call resource_handle
if (!use_string_lookup) {
tir::Var resource_handle = Downcast<Var>(op->args[arg_count]);
packed_args.push_back(StringImm(resource_handle->name_hint));
}

auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered()
: builtin::tvm_call_cpacked_lowered();
return Call(op->dtype, builtin_call, packed_args);
Expand Down
27 changes: 21 additions & 6 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
namespace tvm {
namespace tir {

static constexpr const char* kDeviceContextVar = "device_api_context";

class ReturnRewriter : public StmtMutator {
public:
explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
Expand Down Expand Up @@ -161,15 +163,11 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
};
// ---------------------------
// start of logics
// add signiture for packed arguments.
// add signature for packed arguments.
if (pack_args) {
args.push_back(v_packed_args);
args.push_back(v_packed_arg_type_ids);
args.push_back(v_num_packed_args);
std::ostringstream os;

os << name_hint << ": num_args should be " << num_packed_args;
seq_init.emplace_back(MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
}

// Need to re-declare vars, in case some arguments also appears in the buffer.
Expand All @@ -180,6 +178,13 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
Var param = func_ptr->params[i];
Var v_arg = Var("arg" + std::to_string(i), param->dtype);

// Pluck the device API context out based on name
if (param->name_hint == kDeviceContextVar) {
num_packed_args--;
v_resource_handle = param;
continue;
}

auto it = func_ptr->buffer_map.find(param);
if (it != func_ptr->buffer_map.end()) {
buffer_def.emplace_back(v_arg, (*it).second);
Expand Down Expand Up @@ -262,7 +267,17 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
body = SeqStmt({set_device, body});
}
}
func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);

if (pack_args) {
std::ostringstream num_args_error;
num_args_error << name_hint << ": num_args should be " << num_packed_args;
std::vector<Stmt> arg_assert = {
MakeAssertEQ(v_num_packed_args, num_packed_args, num_args_error.str())};
func_ptr->body =
MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
} else {
func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
}
func_ptr->params = args;

Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
Expand Down
6 changes: 3 additions & 3 deletions tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def deserialize_command_stream(blob):
return cmms


def _create_test_runner(accel):
def create_test_runner(accel="ethos-u55-256"):
file_dir = os.path.dirname(os.path.abspath(__file__))
test_root = os.path.join(file_dir, "reference_system")
ethosu_macs = accel[accel.rfind("-") + 1 :]
Expand All @@ -215,7 +215,7 @@ def _create_test_runner(accel):


def build_source(module, inputs, outputs, accel="ethos-u55-256", output_tolerance=0):
test_runner = _create_test_runner(accel)
test_runner = create_test_runner(accel)
return compile_models(
models=AOTTestModel(
module=module,
Expand All @@ -239,7 +239,7 @@ def verify_source(
This method verifies the generated source from an NPU module by building it and running on an FVP.
"""
interface_api = "c"
test_runner = _create_test_runner(accel)
test_runner = create_test_runner(accel)
run_and_check(
models,
test_runner,
Expand Down
Loading