Skip to content

Commit

Permalink
[3/3][AOT][DeviceAPI] Wire up cpacked Device API context (apache#9501)
Browse files Browse the repository at this point in the history
* [AOT][DeviceAPI] Wire up cpacked Device API context

Adds the same Device API functionality to the cpacked calling convention. The MakePackedAPI pass now implicitly treats any function parameter whose name matches `kDeviceContextVar` (i.e. `device_api_context`) as the `resource_handle`; this is then used by the `cpacked` calling convention, which always expects some form of resource_handle to be passed.

* Document calling conventions
* Remove superfluous variable
  • Loading branch information
Mousius authored and mehrdadh committed Dec 1, 2021
1 parent 6024d6c commit 8f19164
Show file tree
Hide file tree
Showing 9 changed files with 433 additions and 124 deletions.
23 changes: 22 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,18 +348,39 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

GlobalVar global_var = call_lowered_props.lowered_func;
tir::Var empty_var("no_device_context", DataType::Handle());
bool has_c_device_api_context = device_contexts_.count(global_var) != 0;
bool use_cpacked_api = !use_unpacked_api_;

// The device context is passed to the operator in one of the following calling patterns:
// * Unpacked / direct function call with context:
// operator(arg0, arg1, device_context);
// * Unpacked / direct function call without context:
// operator(arg0, arg1);
// * Type-erased packed function call with context:
// operator(args, type_codes, int num_args, out_ret_value, out_ret_tcode,
// device_context_my_device)
// * Type-erased packed function call without context (we create an empty var for codegen):
// operator(args, type_codes, int num_args, out_ret_value, out_ret_tcode,
// no_device_context)
if (has_c_device_api_context) {
// call_extern calling convention with context
tir::Var context = device_contexts_.Get(global_var).value();
args.push_back(device_contexts_[global_var]);
args.push_back(context);

tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(tir::SeqStmt({
GenerateDeviceHook(context, "Open"),
func_call,
GenerateDeviceHook(context, "Close"),
}));
} else if (use_cpacked_api) {
// call_cpacked calling convention needs a blank context
args.push_back(tir::make_zero(DataType::Handle()));
tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(func_call);
} else {
// call_extern calling convention without context
tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(func_call);
}
Expand Down
17 changes: 12 additions & 5 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar
this->stream << "}\n";
}

void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_args) {
void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_args,
const std::string& resource_handle_name) {
this->PrintIndent();
std::string ret_val = GetUniqueName("ret_val");
std::string ret_type_code = GetUniqueName("ret_type_code");
Expand All @@ -266,7 +267,7 @@ void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_a
<< "(int*) stack_tcode"
<< ", " << num_args << ", "
<< "&" << ret_val << ", "
<< "&" << ret_type_code << ", NULL) != 0){\n";
<< "&" << ret_type_code << ", " << resource_handle_name << ") != 0){\n";

int func_call_scope = this->BeginScope();
this->PrintIndent();
Expand All @@ -276,7 +277,8 @@ void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_a
this->stream << "}\n";
}

CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op) {
CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op,
bool has_resource_handle) {
const StringImmNode* s = op->args[0].as<StringImmNode>();
ICHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name";
int64_t begin = op->args[3].as<IntImmNode>()->value;
Expand All @@ -296,6 +298,10 @@ CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op) {
declared_globals_[packed_func_name] = unique_name;
decl_stream << "static void* " << unique_name << " = NULL;\n";
}
if (has_resource_handle) {
std::string resource_handle_name = op->args[5].as<StringImmNode>()->value;
return {func_name, unique_name, num_args - 1, resource_handle_name};
}
return {func_name, unique_name, num_args};
}

Expand Down Expand Up @@ -327,8 +333,9 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT
this->PrintGetFuncFromBackend(function_info.func_name, function_info.func_name_packed);
this->PrintFuncCall(function_info.func_name_packed, function_info.num_args);
} else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) {
auto function_info = GetFunctionInfo(op);
this->PrintFuncCallC(function_info.func_name, function_info.num_args);
auto function_info = GetFunctionInfo(op, true);
this->PrintFuncCallC(function_info.func_name, function_info.num_args,
function_info.resource_handle_name);
} else if (op->op.same_as(builtin::tvm_throw_last_error())) {
this->PrintIndent();
this->stream << "return -1;\n";
Expand Down
7 changes: 5 additions & 2 deletions src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class CodeGenCHost : public CodeGenC {
std::string func_name_packed;
/* number of arguments required by the function */
int64_t num_args;
/* \brief name of resource_handle to pass */
std::string resource_handle_name;
};
std::string module_name_;
/* \brief mapping global packed func to the unique name */
Expand All @@ -82,10 +84,11 @@ class CodeGenCHost : public CodeGenC {
/*! \brief whether to emit asserts in the resulting C code */
bool emit_asserts_;

FunctionInfo GetFunctionInfo(const CallNode* op);
FunctionInfo GetFunctionInfo(const CallNode* op, bool has_resource_handle = false);
void PrintGetFuncFromBackend(const std::string& func_name, const std::string& packed_func_name);
void PrintFuncCall(const std::string& packed_func_name, int num_args);
void PrintFuncCallC(const std::string& packed_func_name, int num_args);
void PrintFuncCallC(const std::string& packed_func_name, int num_args,
const std::string& resource_handle_name);

/*!
* \brief Print ternary conditional operator implementing binary `op`
Expand Down
17 changes: 15 additions & 2 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,18 @@ class BuiltinLower : public StmtExprMutator {
size_t restore_array_stack = scope.run_array_stack;
size_t arg_stack_begin = scope.run_arg_stack;

scope.run_arg_stack += op->args.size();
size_t arg_count = op->args.size();

// cpacked expects a resource_handle parameter
if (!use_string_lookup) {
arg_count--;
}

scope.run_arg_stack += arg_count;
// Specially handle the buffer packed intrinsic
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
for (size_t i = 1; i < op->args.size(); ++i) {
for (size_t i = 1; i < arg_count; ++i) {
PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
PrimExpr arg = op->args[i];
DataType t = arg.dtype();
Expand Down Expand Up @@ -314,6 +321,12 @@ class BuiltinLower : public StmtExprMutator {
ConstInt32(arg_stack_begin),
ConstInt32(arg_stack_begin + op->args.size() - 1)};

// cpacked call resource_handle
if (!use_string_lookup) {
tir::Var resource_handle = Downcast<Var>(op->args[arg_count]);
packed_args.push_back(StringImm(resource_handle->name_hint));
}

auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered()
: builtin::tvm_call_cpacked_lowered();
return Call(op->dtype, builtin_call, packed_args);
Expand Down
27 changes: 21 additions & 6 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
namespace tvm {
namespace tir {

static constexpr const char* kDeviceContextVar = "device_api_context";

class ReturnRewriter : public StmtMutator {
public:
explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
Expand Down Expand Up @@ -161,15 +163,11 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
};
// ---------------------------
// start of logics
// add signiture for packed arguments.
// add signature for packed arguments.
if (pack_args) {
args.push_back(v_packed_args);
args.push_back(v_packed_arg_type_ids);
args.push_back(v_num_packed_args);
std::ostringstream os;

os << name_hint << ": num_args should be " << num_packed_args;
seq_init.emplace_back(MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
}

// Need to re-declare vars, in case some arguments also appears in the buffer.
Expand All @@ -180,6 +178,13 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
Var param = func_ptr->params[i];
Var v_arg = Var("arg" + std::to_string(i), param->dtype);

// Pluck the device API context out based on name
if (param->name_hint == kDeviceContextVar) {
num_packed_args--;
v_resource_handle = param;
continue;
}

auto it = func_ptr->buffer_map.find(param);
if (it != func_ptr->buffer_map.end()) {
buffer_def.emplace_back(v_arg, (*it).second);
Expand Down Expand Up @@ -262,7 +267,17 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
body = SeqStmt({set_device, body});
}
}
func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);

if (pack_args) {
std::ostringstream num_args_error;
num_args_error << name_hint << ": num_args should be " << num_packed_args;
std::vector<Stmt> arg_assert = {
MakeAssertEQ(v_num_packed_args, num_packed_args, num_args_error.str())};
func_ptr->body =
MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
} else {
func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
}
func_ptr->params = args;

Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
Expand Down
6 changes: 3 additions & 3 deletions tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def deserialize_command_stream(blob):
return cmms


def _create_test_runner(accel):
def create_test_runner(accel="ethos-u55-256"):
file_dir = os.path.dirname(os.path.abspath(__file__))
test_root = os.path.join(file_dir, "reference_system")
ethosu_macs = accel[accel.rfind("-") + 1 :]
Expand All @@ -215,7 +215,7 @@ def _create_test_runner(accel):


def build_source(module, inputs, outputs, accel="ethos-u55-256", output_tolerance=0):
test_runner = _create_test_runner(accel)
test_runner = create_test_runner(accel)
return compile_models(
models=AOTTestModel(
module=module,
Expand All @@ -239,7 +239,7 @@ def verify_source(
This method verifies the generated source from an NPU module by building it and running on an FVP.
"""
interface_api = "c"
test_runner = _create_test_runner(accel)
test_runner = create_test_runner(accel)
run_and_check(
models,
test_runner,
Expand Down
Loading

0 comments on commit 8f19164

Please sign in to comment.