From 440751fb0ad22ebcd5a5e1f1397a994825361468 Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Tue, 23 Nov 2021 07:47:25 +0000 Subject: [PATCH] [3/3][AOT][DeviceAPI] Wire up cpacked Device API context (#9501) * [AOT][DeviceAPI] Wire up cpacked Device API context Adding the same functionality for the Device API to the cpacked calling convention. The MakePackedAPI pass now implicitly uses any variable named `kDeviceContextVar` as the `resource_handle` and this is then used in the `cpacked` calling convention which always expects some form of resource_handle to be passed. * Document calling conventions * Remove superfluous variable --- src/relay/backend/aot_executor_codegen.cc | 23 +- src/target/source/codegen_c_host.cc | 17 +- src/target/source/codegen_c_host.h | 7 +- src/tir/transforms/lower_tvm_builtin.cc | 17 +- src/tir/transforms/make_packed_api.cc | 27 +- tests/python/contrib/test_ethosu/infra.py | 6 +- tests/python/relay/aot/test_c_device_api.py | 255 ++++++++++++++++++ tests/python/relay/aot/test_crt_aot.py | 104 ------- .../test_tir_transform_make_packed_api.py | 101 ++++++- 9 files changed, 433 insertions(+), 124 deletions(-) create mode 100644 tests/python/relay/aot/test_c_device_api.py diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index a1bd026958ca..22a6542c8b9c 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -348,10 +348,25 @@ class AOTExecutorCodegen : public MixedModeVisitor { } GlobalVar global_var = call_lowered_props.lowered_func; + tir::Var empty_var("no_device_context", DataType::Handle()); bool has_c_device_api_context = device_contexts_.count(global_var) != 0; + bool use_cpacked_api = !use_unpacked_api_; + + // The device context is passed to the operator in one of the following calling patterns: + // * Unpacked / direct function call with context: + // operator(arg0, arg1, device_context); + // * Unpacked / direct function call without context: + // operator(arg0, arg1); + // * Type-erased packed function call with context: + // operator(args, type_codes, int num_args, out_ret_value, out_ret_tcode, + // device_context_my_device) + // * Type-erased packed function call without context (we create an empty var for codegen): + // operator(args, type_codes, int num_args, out_ret_value, out_ret_tcode, + // no_device_context) if (has_c_device_api_context) { + // call_extern calling convention with context tir::Var context = device_contexts_.Get(global_var).value(); - args.push_back(device_contexts_[global_var]); + args.push_back(context); tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); create_func_call_stmts.push_back(tir::SeqStmt({ @@ -359,7 +374,13 @@ class AOTExecutorCodegen : public MixedModeVisitor { func_call, GenerateDeviceHook(context, "Close"), })); + } else if (use_cpacked_api) { + // call_cpacked calling convention needs a blank context + args.push_back(tir::make_zero(DataType::Handle())); + tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); + create_func_call_stmts.push_back(func_call); } else { + // call_extern calling convention without context tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); create_func_call_stmts.push_back(func_call); } diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 37d54571859e..0715f07316d0 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -251,7 +251,8 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar this->stream << "}\n"; } -void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_args) { +void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_args, + const std::string& resource_handle_name) { this->PrintIndent(); std::string ret_val = GetUniqueName("ret_val"); std::string ret_type_code = GetUniqueName("ret_type_code"); @@ -266,7 +267,7 @@ void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_a << "(int*) stack_tcode" << ", " << num_args << ", " << "&" << ret_val << ", " - << "&" << ret_type_code << ", NULL) != 0){\n"; + << "&" << ret_type_code << ", " << resource_handle_name << ") != 0){\n"; int func_call_scope = this->BeginScope(); this->PrintIndent(); @@ -276,7 +277,8 @@ void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_a this->stream << "}\n"; } -CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op) { +CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op, + bool has_resource_handle) { const StringImmNode* s = op->args[0].as(); ICHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name"; int64_t begin = op->args[3].as()->value; @@ -296,6 +298,10 @@ CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op) { declared_globals_[packed_func_name] = unique_name; decl_stream << "static void* " << unique_name << " = NULL;\n"; } + if (has_resource_handle) { + std::string resource_handle_name = op->args[5].as()->value; + return {func_name, unique_name, num_args - 1, resource_handle_name}; + } return {func_name, unique_name, num_args}; } @@ -327,8 +333,9 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT this->PrintGetFuncFromBackend(function_info.func_name, function_info.func_name_packed); this->PrintFuncCall(function_info.func_name_packed, function_info.num_args); } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) { - auto function_info = GetFunctionInfo(op); - this->PrintFuncCallC(function_info.func_name, function_info.num_args); + auto function_info = GetFunctionInfo(op, true); + this->PrintFuncCallC(function_info.func_name, function_info.num_args, + function_info.resource_handle_name); } else if (op->op.same_as(builtin::tvm_throw_last_error())) { this->PrintIndent(); this->stream << "return -1;\n"; diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index d72e2b37ee8a..d11fa5dd255f 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -73,6 +73,8 @@ class CodeGenCHost : public CodeGenC { std::string func_name_packed; /* number of arguments required by the function */ int64_t num_args; + /* \brief name of resource_handle to pass */ + std::string resource_handle_name; }; std::string module_name_; /* \brief mapping global packed func to the unique name */ @@ -82,10 +84,11 @@ class CodeGenCHost : public CodeGenC { /*! \brief whether to emit asserts in the resulting C code */ bool emit_asserts_; - FunctionInfo GetFunctionInfo(const CallNode* op); + FunctionInfo GetFunctionInfo(const CallNode* op, bool has_resource_handle = false); void PrintGetFuncFromBackend(const std::string& func_name, const std::string& packed_func_name); void PrintFuncCall(const std::string& packed_func_name, int num_args); - void PrintFuncCallC(const std::string& packed_func_name, int num_args); + void PrintFuncCallC(const std::string& packed_func_name, int num_args, + const std::string& resource_handle_name); /*! * \brief Print ternary conditional operator implementing binary `op` diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 062d67eef165..3343e1062e57 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -280,11 +280,18 @@ class BuiltinLower : public StmtExprMutator { size_t restore_array_stack = scope.run_array_stack; size_t arg_stack_begin = scope.run_arg_stack; - scope.run_arg_stack += op->args.size(); + size_t arg_count = op->args.size(); + + // cpacked expects a resource_handle parameter + if (!use_string_lookup) { + arg_count--; + } + + scope.run_arg_stack += arg_count; // Specially handle the buffer packed intrinsic PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - for (size_t i = 1; i < op->args.size(); ++i) { + for (size_t i = 1; i < arg_count; ++i) { PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1); PrimExpr arg = op->args[i]; DataType t = arg.dtype(); @@ -314,6 +321,12 @@ class BuiltinLower : public StmtExprMutator { ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1)}; + // cpacked call resource_handle + if (!use_string_lookup) { + tir::Var resource_handle = Downcast(op->args[arg_count]); + packed_args.push_back(StringImm(resource_handle->name_hint)); + } + auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered() : builtin::tvm_call_cpacked_lowered(); return Call(op->dtype, builtin_call, packed_args); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 393ce6c286b4..d7e1beff03d3 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -40,6 +40,8 @@ namespace tvm { namespace tir { +static constexpr const char* kDeviceContextVar = "device_api_context"; + class ReturnRewriter : public StmtMutator { public: explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} @@ -161,15 +163,11 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { }; // --------------------------- // start of logics - // add signiture for packed arguments. + // add signature for packed arguments. if (pack_args) { args.push_back(v_packed_args); args.push_back(v_packed_arg_type_ids); args.push_back(v_num_packed_args); - std::ostringstream os; - - os << name_hint << ": num_args should be " << num_packed_args; - seq_init.emplace_back(MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); } // Need to re-declare vars, in case some arguments also appears in the buffer. @@ -180,6 +178,13 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { Var param = func_ptr->params[i]; Var v_arg = Var("arg" + std::to_string(i), param->dtype); + // Pluck the device API context out based on name + if (param->name_hint == kDeviceContextVar) { + num_packed_args--; + v_resource_handle = param; + continue; + } + auto it = func_ptr->buffer_map.find(param); if (it != func_ptr->buffer_map.end()) { buffer_def.emplace_back(v_arg, (*it).second); @@ -262,7 +267,17 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { body = SeqStmt({set_device, body}); } } - func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); + + if (pack_args) { + std::ostringstream num_args_error; + num_args_error << name_hint << ": num_args should be " << num_packed_args; + std::vector arg_assert = { + MakeAssertEQ(v_num_packed_args, num_packed_args, num_args_error.str())}; + func_ptr->body = + MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); + } else { + func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); + } func_ptr->params = args; Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 1c0b78cebf92..38bd88c10e48 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -189,7 +189,7 @@ def deserialize_command_stream(blob): return cmms -def _create_test_runner(accel): +def create_test_runner(accel="ethos-u55-256"): file_dir = os.path.dirname(os.path.abspath(__file__)) test_root = os.path.join(file_dir, "reference_system") ethosu_macs = accel[accel.rfind("-") + 1 :] @@ -215,7 +215,7 @@ def _create_test_runner(accel): def build_source(module, inputs, outputs, accel="ethos-u55-256", output_tolerance=0): - test_runner = _create_test_runner(accel) + test_runner = create_test_runner(accel) return compile_models( models=AOTTestModel( module=module, @@ -239,7 +239,7 @@ def verify_source( This method verifies the generated source from an NPU module by building it and running on an FVP. """ interface_api = "c" - test_runner = _create_test_runner(accel) + test_runner = create_test_runner(accel) run_and_check( models, test_runner, diff --git a/tests/python/relay/aot/test_c_device_api.py b/tests/python/relay/aot/test_c_device_api.py new file mode 100644 index 000000000000..3de4fecf5544 --- /dev/null +++ b/tests/python/relay/aot/test_c_device_api.py @@ -0,0 +1,255 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys +from collections import OrderedDict + +import numpy as np +import pytest + +from tvm import relay +from tvm.ir.module import IRModule +from aot_test_utils import ( + AOT_DEFAULT_RUNNER, + AOTTestModel, + generate_ref_data, + compile_models, +) + + +@pytest.fixture +def device_api_main_func(): + # Ideally we should have a sample Target registered here + # but we're going to re-use this for now + pytest.importorskip("ethosu.vela") + import tensorflow as tf + import tflite.Model + + from tests.python.contrib.test_ethosu.infra import create_test_runner, generate_ref_data_tflite + from tvm.relay.op.contrib.ethosu import partition_for_ethosu + + tf.config.run_functions_eagerly(True) + + class Model(tf.Module): + @tf.function + def tf_function(self, x): + return tf.nn.max_pool(x, [1, 2], [1, 2], "SAME") + + def representative_dataset(): + for _ in range(100): + data = np.random.rand(1, 3, 4, 3) + yield [data.astype(np.float32)] + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec([1, 3, 4, 3], dtype=tf.float32) + ) + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + + tflite_graph = converter.convert() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"x": [1, 3, 4, 3]}, + dtype_dict={"x": "int8"}, + ) + mod = partition_for_ethosu(relay_module, params) + + # Generate reference data + input_data, output_data = generate_ref_data_tflite(tflite_graph) + + def compile_to_main_func(interface_api="c", use_unpacked_api=True): + test_runner = create_test_runner() + compiled_models = compile_models( + models=AOTTestModel( + module=mod, + inputs=input_data, + outputs=output_data, + ), + interface_api=interface_api, + use_unpacked_api=use_unpacked_api, + workspace_byte_alignment=16, + pass_config=test_runner.pass_config, + ) + main_ir_module = list(compiled_models[0].executor_factory.lowered_ir_mods.values())[0] + main_func = main_ir_module["run_model"] + return main_func + + return compile_to_main_func + + +@pytest.fixture +def non_device_api_main_func(): + x = relay.var("x", shape=(10, 10)) + y = relay.var("y", shape=(1, 10)) + func = relay.Function([x, y], relay.multiply(x, y)) + x_data = np.random.rand(10, 10).astype("float32") + y_data = np.random.rand(1, 10).astype("float32") + + inputs = OrderedDict([("x", x_data), ("y", y_data)]) + output_list = generate_ref_data(func, inputs) + + def compile_to_main_func(interface_api="c", use_unpacked_api=True): + test_runner = AOT_DEFAULT_RUNNER + compiled_models = compile_models( + models=AOTTestModel( + module=IRModule.from_expr(func), + inputs=inputs, + outputs=output_list, + ), + interface_api=interface_api, + use_unpacked_api=use_unpacked_api, + workspace_byte_alignment=16, + pass_config=test_runner.pass_config, + ) + main_ir_module = list(compiled_models[0].executor_factory.lowered_ir_mods.values())[0] + main_func = main_ir_module["run_model"] + return main_func + + return compile_to_main_func + + +def test_device_api_hooks_unpacked_api(device_api_main_func): + """Check for Device API hooks with unpacked internal calls""" + main_func = device_api_main_func(interface_api="c", use_unpacked_api=True) + + # Activate Device + assert ( + str(main_func.body[0][0].value) + == "@tir.call_extern(" + + '"TVMDeviceEthosUActivate",' + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + # Open Device + assert ( + str(main_func.body[1].body.body[0][0][0].value) + == "@tir.call_extern(" + + '"TVMDeviceEthosUOpen",' + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + # Device Call + assert ( + str(main_func.body[1].body.body[0][0][1].value) + == "@tir.call_extern(" + + '"tvmgen_default_ethos_u_main_0",' + + " input: handle, output: handle," + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + # Close Device + assert ( + str(main_func.body[1].body.body[0][0][2].value) + == "@tir.call_extern(" + + '"TVMDeviceEthosUClose",' + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + # Deactivate Device + assert ( + str(main_func.body[2][0].value) + == "@tir.call_extern(" + + '"TVMDeviceEthosUDeactivate",' + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + + +def test_device_api_hooks_packed_api(device_api_main_func): + """Check for Device API hooks with packed internal calls""" + main_func = device_api_main_func(interface_api="packed", use_unpacked_api=False) + + # Activate Device + assert ( + str(main_func.body[0][0].value) + == "@tir.call_extern(" + + '"TVMDeviceEthosUActivate",' + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + # Open Device + assert ( + str(main_func.body[1].body.body[0][0][0].value) + == "@tir.call_extern(" + + '"TVMDeviceEthosUOpen",' + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + # Device Call + assert ( + str(main_func.body[1].body.body[0][0][1][0].value) + == "@tir.tvm_call_cpacked(" + + '"tvmgen_default_ethos_u_main_0",' + + " input: handle, output: handle," + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + # Close Device + assert ( + str(main_func.body[1].body.body[0][0][2].value) + == "@tir.call_extern(" + + '"TVMDeviceEthosUClose",' + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + # Deactivate Device + assert ( + str(main_func.body[2][0].value) + == "@tir.call_extern(" + + '"TVMDeviceEthosUDeactivate",' + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + + +def test_without_device_api_unpacked_api(non_device_api_main_func): + """Test a graph without the Device API with the unpacked internal calls""" + + main_func = non_device_api_main_func(interface_api="c", use_unpacked_api=True) + + assert ( + str(main_func.body[1].body.body[0][0].value) + == "@tir.call_extern(" + + '"tvmgen_default_fused_multiply",' + + " input: handle, input_1: handle, output: handle," + + " dtype=int32)" + ) + + +def test_without_device_api_packed_api(non_device_api_main_func): + """Test a graph without the Device API with the packed internal calls""" + + main_func = non_device_api_main_func(interface_api="packed", use_unpacked_api=False) + + assert ( + str(main_func.body[1].body.body[0][0]) + == 'let tvm_value_0 = tir.tvm_stack_alloca("array", 1)\n' + + "tir.tvm_struct_set(tvm_value_0, 0, 1, tir.reinterpret((uint64)0))\n" + + 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", input, input, output, tvm_value_0)\n' + ) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 605d061918bd..3c7e836a7639 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -17,7 +17,6 @@ from collections import OrderedDict import sys -import re import numpy as np import pytest @@ -734,108 +733,5 @@ def @main(%data: Tensor[(1, 4, 4, 4), float32], %weight: Tensor[(4, 4, 3, 3), fl assert source.count("TVMBackendAllocWorkspace") == 3 -def test_device_api_hooks(): - """Check for Device API hooks""" - - # Ideally we should have a sample Target registered here - # but we're going to re-use this for now - pytest.importorskip("ethosu.vela") - import tensorflow as tf - import tflite.Model - - from tests.python.contrib.test_ethosu import infra - from tvm.relay.op.contrib.ethosu import partition_for_ethosu - - def create_tflite_graph(): - tf.config.run_functions_eagerly(True) - - class Model(tf.Module): - @tf.function - def tf_function(self, x): - return tf.nn.max_pool(x, [1, 2], [1, 2], "SAME") - - def representative_dataset(): - for _ in range(100): - data = np.random.rand(*tuple([1, 3, 4, 3])) - yield [data.astype(np.float32)] - - model = Model() - concrete_func = model.tf_function.get_concrete_function( - tf.TensorSpec([1, 3, 4, 3], dtype=tf.float32) - ) - - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.representative_dataset = representative_dataset - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.int8 - converter.inference_output_type = tf.int8 - tflite_model = converter.convert() - return tflite_model - - tflite_graph = create_tflite_graph() - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) - - relay_module, params = relay.frontend.from_tflite( - tflite_model, - shape_dict={"x": [1, 3, 4, 3]}, - dtype_dict={"x": "int8"}, - ) - mod = partition_for_ethosu(relay_module, params) - - # Generate reference data - input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) - - compiled_models = infra.build_source( - mod, - input_data, - output_data, - ) - main_ir_module = list(compiled_models[0].executor_factory.lowered_ir_mods.values())[0] - main_func = main_ir_module["run_model"] - - # Activate Device - assert ( - str(main_func.body[0][0].value) - == "@tir.call_extern(" - + '"TVMDeviceEthosUActivate",' - + " device_context_ethos_u: handle," - + " dtype=int32)" - ) - # Open Device - assert ( - str(main_func.body[1].body.body[0][0][0].value) - == "@tir.call_extern(" - + '"TVMDeviceEthosUOpen",' - + " device_context_ethos_u: handle," - + " dtype=int32)" - ) - # Device Call - assert ( - str(main_func.body[1].body.body[0][0][1].value) - == "@tir.call_extern(" - + '"tvmgen_default_ethos_u_main_0",' - + " input: handle, output: handle," - + " device_context_ethos_u: handle," - + " dtype=int32)" - ) - # Close Device - assert ( - str(main_func.body[1].body.body[0][0][2].value) - == "@tir.call_extern(" - + '"TVMDeviceEthosUClose",' - + " device_context_ethos_u: handle," - + " dtype=int32)" - ) - # Deactivate Device - assert ( - str(main_func.body[2][0].value) - == "@tir.call_extern(" - + '"TVMDeviceEthosUDeactivate",' - + " device_context_ethos_u: handle," - + " dtype=int32)" - ) - - if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 1ab6bdaad90a..ca7f2315e51c 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import numpy import tvm from tvm import te @@ -45,5 +44,105 @@ def test_makeapi(): assert len(f.params) == 8 +def _find_assignment(stmt, var_name): + while not isinstance(stmt, tvm.tir.LetStmt): + stmt = stmt.body + + if stmt.var.name != var_name: + return _find_assignment(stmt.body, var_name) + + return stmt + + +def _find_next(stmt, type): + while not isinstance(stmt, type): + stmt = stmt.body + return stmt + + +def test_variable_passed_from_args(): + ib = tvm.tir.ir_builder.create() + + input_buffer = tvm.tir.decl_buffer(name="input_buffer", shape=[1]) + not_device_context = tvm.tir.Var("not_device_context", dtype="handle") + + ib.emit( + tvm.tir.call_extern("float32", "some_external_call", input_buffer.data, not_device_context), + ) + stmt = ib.get() + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, not_device_context], stmt)) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) + func = tvm.tir.transform.MakePackedAPI()(mod)["main"] + + num_args = func.params[2] + + # num_args assertion + assert func.body.condition.a == num_args + assert func.body.condition.b == 2 + + # Arguments unpacking + assignment = _find_assignment(func.body, "arg0") + assert str(assignment.value) == "@tir.tvm_struct_get(args: handle, 0, 12, dtype=handle)" + + assignment = _find_assignment(func.body, "arg1") + assert str(assignment.value) == "@tir.tvm_struct_get(args: handle, 1, 12, dtype=handle)" + + assignment = _find_assignment(func.body, "input_buffer") + assert str(assignment.value) == "@tir.tvm_struct_get(arg0: handle, 0, 1, dtype=handle)" + unpacked_input_buffer = assignment.var + + assignment = _find_assignment(func.body, "not_device_context") + assert str(assignment.value) == "arg1: handle" + unpacked_not_device_context = assignment.var + + seq_stmt = _find_next(assignment, tvm.tir.SeqStmt) + call = _find_next(seq_stmt[1], tvm.tir.Evaluate) + call_extern = call.value + + assert call_extern.args[1] == unpacked_input_buffer + assert call_extern.args[2] == unpacked_not_device_context + + +def test_device_api_context_implicit_resource_handle(): + ib = tvm.tir.ir_builder.create() + + input_buffer = tvm.tir.decl_buffer(name="input_buffer", shape=[1]) + device_context = tvm.tir.Var("device_api_context", dtype="handle") + + ib.emit( + tvm.tir.call_extern("float32", "some_external_call", input_buffer.data, device_context), + ) + stmt = ib.get() + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, device_context], stmt)) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) + func = tvm.tir.transform.MakePackedAPI()(mod)["main"] + + num_args = func.params[2] + device_context_in_resource_handle = func.params[5] + + # num_args assertion + assert func.body.condition.a == num_args + assert func.body.condition.b == 1 + + # Arguments unpacking + assignment = _find_assignment(func.body, "arg0") + assert str(assignment.value) == "@tir.tvm_struct_get(args: handle, 0, 12, dtype=handle)" + + assignment = _find_assignment(func.body, "input_buffer") + assert str(assignment.value) == "@tir.tvm_struct_get(arg0: handle, 0, 1, dtype=handle)" + unpacked_input_buffer = assignment.var + + seq_stmt = _find_next(assignment, tvm.tir.SeqStmt) + call = _find_next(seq_stmt[1], tvm.tir.Evaluate) + call_extern = call.value + + assert call_extern.args[1] == unpacked_input_buffer + assert call_extern.args[2] == device_context_in_resource_handle + + if __name__ == "__main__": test_makeapi()