Skip to content

Commit

Permalink
[AOT][DeviceAPI] Wire up cpacked Device API context
Browse files Browse the repository at this point in the history
Adding the same functionality for the Device API to the cpacked calling convention. The MakePackedAPI pass now implicitly uses any variable named `kDeviceContextVar` as the `resource_handle` and this is then used in the `cpacked` calling convention which always expects some form of resource_handle to be passed.
  • Loading branch information
Mousius committed Nov 16, 2021
1 parent 948641c commit aff40bb
Show file tree
Hide file tree
Showing 9 changed files with 425 additions and 124 deletions.
14 changes: 13 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,18 +347,30 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

GlobalVar global_var = call_lowered_props.lowered_func;
tir::Var empty_var("no_device_context", DataType::Handle());
bool has_c_device_api_context = device_contexts_.count(global_var) != 0;
bool use_cpacked_api = !use_unpacked_api_;

if (has_c_device_api_context) {
// call_extern calling convention with context
tir::Var context = device_contexts_.Get(global_var).value();
args.push_back(device_contexts_[global_var]);
args.push_back(context);

tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(tir::SeqStmt({
GenerateDeviceHook(context, "Open"),
func_call,
GenerateDeviceHook(context, "Close"),
}));
} else if (use_cpacked_api) {
// call_cpacked calling convention needs a blank context
args.push_back(empty_var);

tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
tir::LetStmt set_zero(empty_var, tir::make_zero(DataType::Handle()), func_call);
create_func_call_stmts.push_back(set_zero);
} else {
// call_extern calling convention without context
tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
create_func_call_stmts.push_back(func_call);
}
Expand Down
17 changes: 12 additions & 5 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar
this->stream << "}\n";
}

void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_args) {
void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_args,
const std::string& resource_handle_name) {
this->PrintIndent();
std::string ret_val = GetUniqueName("ret_val");
std::string ret_type_code = GetUniqueName("ret_type_code");
Expand All @@ -265,7 +266,7 @@ void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_a
<< "(int*) stack_tcode"
<< ", " << num_args << ", "
<< "&" << ret_val << ", "
<< "&" << ret_type_code << ", NULL) != 0){\n";
<< "&" << ret_type_code << ", " << resource_handle_name << ") != 0){\n";

int func_call_scope = this->BeginScope();
this->PrintIndent();
Expand All @@ -275,7 +276,8 @@ void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_a
this->stream << "}\n";
}

CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op) {
CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op,
bool has_resource_handle) {
const StringImmNode* s = op->args[0].as<StringImmNode>();
ICHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name";
int64_t begin = op->args[3].as<IntImmNode>()->value;
Expand All @@ -295,6 +297,10 @@ CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op) {
declared_globals_[packed_func_name] = unique_name;
decl_stream << "static void* " << unique_name << " = NULL;\n";
}
if (has_resource_handle) {
std::string resource_handle_name = op->args[5].as<StringImmNode>()->value;
return {func_name, unique_name, num_args - 1, resource_handle_name};
}
return {func_name, unique_name, num_args};
}

Expand Down Expand Up @@ -326,8 +332,9 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT
this->PrintGetFuncFromBackend(function_info.func_name, function_info.func_name_packed);
this->PrintFuncCall(function_info.func_name_packed, function_info.num_args);
} else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) {
auto function_info = GetFunctionInfo(op);
this->PrintFuncCallC(function_info.func_name, function_info.num_args);
auto function_info = GetFunctionInfo(op, true);
this->PrintFuncCallC(function_info.func_name, function_info.num_args,
function_info.resource_handle_name);
} else if (op->op.same_as(builtin::tvm_throw_last_error())) {
this->PrintIndent();
this->stream << "return -1;\n";
Expand Down
7 changes: 5 additions & 2 deletions src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class CodeGenCHost : public CodeGenC {
std::string func_name_packed;
/* number of arguments required by the function */
int64_t num_args;
/* \brief name of resource_handle to pass */
std::string resource_handle_name;
};
std::string module_name_;
/* \brief mapping global packed func to the unique name */
Expand All @@ -82,10 +84,11 @@ class CodeGenCHost : public CodeGenC {
/*! \brief whether to emit asserts in the resulting C code */
bool emit_asserts_;

FunctionInfo GetFunctionInfo(const CallNode* op);
FunctionInfo GetFunctionInfo(const CallNode* op, bool has_resource_handle = false);
void PrintGetFuncFromBackend(const std::string& func_name, const std::string& packed_func_name);
void PrintFuncCall(const std::string& packed_func_name, int num_args);
void PrintFuncCallC(const std::string& packed_func_name, int num_args);
void PrintFuncCallC(const std::string& packed_func_name, int num_args,
const std::string& resource_handle_name);

/*!
* \brief Print ternary conditional operator implementing binary `op`
Expand Down
17 changes: 15 additions & 2 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,18 @@ class BuiltinLower : public StmtExprMutator {
size_t restore_array_stack = scope.run_array_stack;
size_t arg_stack_begin = scope.run_arg_stack;

scope.run_arg_stack += op->args.size();
size_t arg_count = op->args.size();

// cpacked expects a resource_handle parameter
if (!use_string_lookup) {
arg_count--;
}

scope.run_arg_stack += arg_count;
// Specially handle the buffer packed intrinsic
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
for (size_t i = 1; i < op->args.size(); ++i) {
for (size_t i = 1; i < arg_count; ++i) {
PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
PrimExpr arg = op->args[i];
DataType t = arg.dtype();
Expand Down Expand Up @@ -314,6 +321,12 @@ class BuiltinLower : public StmtExprMutator {
ConstInt32(arg_stack_begin),
ConstInt32(arg_stack_begin + op->args.size() - 1)};

// cpacked call resource_handle
if (!use_string_lookup) {
tir::Var resource_handle = Downcast<Var>(op->args[arg_count]);
packed_args.push_back(StringImm(resource_handle->name_hint));
}

auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered()
: builtin::tvm_call_cpacked_lowered();
return Call(op->dtype, builtin_call, packed_args);
Expand Down
27 changes: 21 additions & 6 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
namespace tvm {
namespace tir {

static constexpr const char* kDeviceContextVar = "device_api_context";

class ReturnRewriter : public StmtMutator {
public:
explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
Expand Down Expand Up @@ -161,15 +163,11 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
};
// ---------------------------
// start of logics
// add signiture for packed arguments.
// add signature for packed arguments.
if (pack_args) {
args.push_back(v_packed_args);
args.push_back(v_packed_arg_type_ids);
args.push_back(v_num_packed_args);
std::ostringstream os;

os << name_hint << ": num_args should be " << num_packed_args;
seq_init.emplace_back(MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
}

// Need to re-declare vars, in case some arguments also appears in the buffer.
Expand All @@ -180,6 +178,13 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
Var param = func_ptr->params[i];
Var v_arg = Var("arg" + std::to_string(i), param->dtype);

// Pluck the device API context out based on name
if (param->name_hint == kDeviceContextVar) {
num_packed_args--;
v_resource_handle = param;
continue;
}

auto it = func_ptr->buffer_map.find(param);
if (it != func_ptr->buffer_map.end()) {
buffer_def.emplace_back(v_arg, (*it).second);
Expand Down Expand Up @@ -262,7 +267,17 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
body = SeqStmt({set_device, body});
}
}
func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);

if (pack_args) {
std::ostringstream num_args_error;
num_args_error << name_hint << ": num_args should be " << num_packed_args;
std::vector<Stmt> arg_assert = {
MakeAssertEQ(v_num_packed_args, num_packed_args, num_args_error.str())};
func_ptr->body =
MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
} else {
func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
}
func_ptr->params = args;

Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
Expand Down
6 changes: 3 additions & 3 deletions tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def deserialize_command_stream(blob):
return cmms


def _create_test_runner(accel):
def create_test_runner(accel="ethos-u55-256"):
file_dir = os.path.dirname(os.path.abspath(__file__))
test_root = os.path.join(file_dir, "reference_system")
ethosu_macs = accel[accel.rfind("-") + 1 :]
Expand All @@ -215,7 +215,7 @@ def _create_test_runner(accel):


def build_source(module, inputs, outputs, accel="ethos-u55-256", output_tolerance=0):
test_runner = _create_test_runner(accel)
test_runner = create_test_runner(accel)
return compile_models(
models=AOTTestModel(
module=module,
Expand All @@ -239,7 +239,7 @@ def verify_source(
This method verifies the generated source from an NPU module by building it and running on an FVP.
"""
interface_api = "c"
test_runner = _create_test_runner(accel)
test_runner = create_test_runner(accel)
run_and_check(
models,
test_runner,
Expand Down
Loading

0 comments on commit aff40bb

Please sign in to comment.