From 9cc46201e45c23f36a4972283e7da2a3e39f511a Mon Sep 17 00:00:00 2001
From: Rainy-Memory <2630737606@qq.com>
Date: Sun, 12 Feb 2023 21:14:13 +0800
Subject: [PATCH] [MetaSchedule] Introduce Async Pipeline in MultiLevelTiling

This PR introduces an async pipeline into TVM's current MultiLevelTiling rules. This PR is blocked on apache/tvm#13966, since some conv2d workloads will use `tir.if_then_else` to pad the input to the correct size, and this PR uses async copy in such copy statements.

1. Add a subrule in `src/meta_schedule/schedule_rule/multi_level_tiling.h/.cc` that annotates async copy for mlt.

On CUDA Core, this PR has a perf boost of around 1T GFLOP/s in most Conv2d test cases and 1T ~ 2T in most GEMM test cases.

All generated codes, scripts, and traces are available at https://github.com/Rainy-Memory/tvm-async-rule-benchmark.

Currently tested on commit `afbfb7aa7e43732cb716f8e443df696110be6afc` in conv2d NHWC workload, with an RTX 3080 GPU.

Workload: Conv2d NHWC

|Shape|Mainline TVM|Mainline TVM with Async|
|-|-|-|
|N=1_H=224_W=224_C=3_K=64_R=7_S=7_STR=2_PAD=3_DIL=1|13838.05219|14687.89452|
|N=1_H=56_W=56_C=64_K=64_R=1_S=1_STR=1_PAD=0_DIL=1|5398.305085|5613.892553|
|N=1_H=56_W=56_C=64_K=64_R=3_S=3_STR=1_PAD=1_DIL=1|11652.96825|13157.88249|
|N=1_H=56_W=56_C=64_K=256_R=1_S=1_STR=1_PAD=0_DIL=1|10638.8309|11674.68499|
|N=1_H=56_W=56_C=256_K=64_R=1_S=1_STR=1_PAD=0_DIL=1|8692.32829|9469.264089|
|N=1_H=56_W=56_C=256_K=128_R=1_S=1_STR=2_PAD=0_DIL=1|4685.767442|5698.19634|
|N=1_H=28_W=28_C=128_K=128_R=3_S=3_STR=1_PAD=1_DIL=1|9872.787087|10404.60405|
|N=1_H=28_W=28_C=128_K=512_R=1_S=1_STR=1_PAD=0_DIL=1|9974.281496|10073.31657|
|N=1_H=28_W=28_C=512_K=128_R=1_S=1_STR=1_PAD=0_DIL=1|7075.866932|8564.572712|
|N=1_H=28_W=28_C=512_K=256_R=1_S=1_STR=2_PAD=0_DIL=1|3648.330914|4021.923142|
|N=1_H=14_W=14_C=256_K=256_R=3_S=3_STR=1_PAD=1_DIL=1|8192.954618|9160.182054|
|N=1_H=14_W=14_C=256_K=1024_R=1_S=1_STR=1_PAD=0_DIL=1|8008.870153|9362.825279|
|N=1_H=14_W=14_C=1024_K=256_R=1_S=1_STR=1_PAD=0_DIL=1|5210.062241|6051.208379| |N=1_H=14_W=14_C=1024_K=512_R=1_S=1_STR=2_PAD=0_DIL=1|2550.787202|3587.902938| |N=1_H=7_W=7_C=512_K=512_R=3_S=3_STR=1_PAD=1_DIL=1|4350.626084|5432.788068| |N=1_H=7_W=7_C=512_K=2048_R=1_S=1_STR=1_PAD=0_DIL=1|6672.068026|7663.725217| |N=1_H=7_W=7_C=2048_K=512_R=1_S=1_STR=1_PAD=0_DIL=1|3142.564263|4297.988014| Workload: GEMM NN |Shape|Mainline TVM|Mainline TVM with Async| |-|-|-| |M=512_N=256_K=640|8678.46|10607.37| |M=512_N=384_K=256|8109.13|10290.72| |M=512_N=512_K=512|11419.83|14000.86| |M=512_N=3072_K=768|19709.39|18351.61| |M=512_N=768_K=3072|12844.59|13730.88| |M=896_N=896_K=896|16149.91|16131.39| |M=1024_N=1024_K=1024|18842.11|19662.8| |M=1152_N=1152_K=1152|15386.79|16736.1| |M=1536_N=1536_K=1536|18522.67|18872.06| |M=2048_N=2048_K=2048|19515.42|18874.85| |M=3072_N=3072_K=3072|19233.9|19291.42| |M=4096_N=4096_K=4096|17122.17|19259.01| --- .../template_project/microtvm_api_server.py | 10 +- .../example_project/{model.c => platform.c} | 25 +- .../example_project/{model.h => platform.h} | 4 - .../src/example_project/project.ino | 4 +- .../{model_support.c => platform.c} | 12 +- .../template_project/CMakeLists.txt.template | 2 +- .../template_project/microtvm_api_server.py | 67 ++- .../src/aot_standalone_demo/main.c | 193 +++---- .../src/aot_standalone_demo/platform.c | 126 +++++ .../src/aot_standalone_demo/zephyr_uart.c | 87 --- .../src/aot_standalone_demo/zephyr_uart.h | 50 -- .../src/host_driven/fvp/semihost.c | 2 +- .../template_project/src/host_driven/main.c | 154 +----- .../src/host_driven/platform.c | 155 ++++++ .../src/mlperftiny/platform.cc | 68 +++ .../src/mlperftiny/submitter_implemented.cc | 174 +++++- .../src/mlperftiny/tvmruntime.cc | 164 ------ .../src/mlperftiny/tvmruntime.h | 62 --- .../src/mlperftiny/zephyr_uart.cc | 89 --- .../src/mlperftiny/zephyr_uart.h | 51 -- .../minimal_cross_isa_jenkinsfile.groovy | 4 +- .../minimal_cross_isa_jenkinsfile.groovy.j2 | 2 +- 
cmake/modules/CRT.cmake | 7 +- cmake/modules/Zephyr.cmake | 2 - .../work_with_microtvm/micro_mlperftiny.py | 2 +- .../work_with_microtvm/micro_pytorch.py | 2 +- include/tvm/relay/attrs/transform.h | 5 +- include/tvm/runtime/crt/platform.h | 9 + include/tvm/runtime/ndarray.h | 8 - include/tvm/script/ir_builder/tir/ir.h | 5 +- include/tvm/target/virtual_device.h | 18 +- python/tvm/meta_schedule/runner/rpc_runner.py | 6 +- python/tvm/micro/model_library_format.py | 14 +- python/tvm/relay/frontend/onnx.py | 53 +- python/tvm/relay/op/transform.py | 8 +- python/tvm/script/ir_builder/tir/__init__.py | 2 +- python/tvm/script/ir_builder/tir/ir.py | 37 +- python/tvm/script/parser/core/entry.py | 2 + python/tvm/script/parser/ir/parser.py | 11 + python/tvm/script/parser/tir/entry.py | 17 +- python/tvm/script/parser/tir/parser.py | 2 +- python/tvm/testing/aot.py | 4 +- python/tvm/tir/tensor_intrin/cuda.py | 40 +- python/tvm/topi/cuda/scatter.py | 16 +- python/tvm/topi/scatter.py | 28 +- python/tvm/utils/roofline/cuda.py | 2 +- python/tvm/utils/roofline/x86.py | 2 +- src/meta_schedule/database/memory_database.cc | 2 +- .../schedule_rule/multi_level_tiling.cc | 55 ++ .../schedule_rule/multi_level_tiling.h | 4 + src/relay/analysis/graph_partitioner.cc | 334 ++++++++++++ src/relay/analysis/graph_partitioner.h | 269 +++++++++ src/relay/transforms/fuse_ops.cc | 516 +----------------- src/runtime/crt/host/Makefile.template | 4 +- src/runtime/crt/host/main.cc | 78 +-- src/runtime/crt/host/microtvm_api_server.py | 31 +- src/runtime/crt/host/platform.cc | 126 +++++ src/runtime/crt/platform-template.c | 80 +++ src/runtime/hexagon/hexagon_buffer.cc | 44 +- src/script/ir_builder/tir/ir.cc | 41 +- src/script/ir_builder/tir/utils.h | 1 + src/script/printer/ir/ir.cc | 54 +- src/script/printer/ir/misc.cc | 3 + src/script/printer/tir/expr.cc | 31 +- src/script/printer/tir/ir.cc | 6 +- src/script/printer/utils.h | 14 +- src/target/source/codegen_cuda.cc | 8 +- src/target/source/ptx.cc | 31 ++ 
src/target/source/ptx.h | 16 + src/target/virtual_device.cc | 12 +- src/tir/ir/stmt.cc | 6 + src/tir/transforms/inject_ptx_async_copy.cc | 156 +++--- tests/micro/arduino/test_arduino_workflow.py | 14 +- tests/micro/arduino/testdata/project.ino | 5 +- tests/micro/zephyr/test_zephyr.py | 36 ++ .../zephyr/test_zephyr_aot_exec_standalone.py | 4 +- tests/micro/zephyr/utils.py | 8 +- .../test_copy_compute_reordering.py | 76 +-- .../test_ethosu/test_merge_constants.py | 40 +- tests/python/frontend/onnx/test_forward.py | 2 - tests/python/integration/test_lower.py | 12 +- .../aot/test_aot_create_executor_metadata.py | 2 +- .../relay/aot/test_pass_aot_lower_main.py | 4 +- tests/python/relay/test_op_level3.py | 19 +- tests/python/topi/python/test_topi_scatter.py | 18 +- .../unittest/test_aot_legalize_packed_call.py | 16 +- .../unittest/test_arith_domain_touched.py | 4 +- .../unittest/test_auto_scheduler_feature.py | 6 +- .../unittest/test_cp_async_in_if_then_else.py | 238 ++++++++ .../unittest/test_meta_schedule_database.py | 13 +- ..._meta_schedule_postproc_verify_gpu_code.py | 12 +- .../unittest/test_meta_schedule_runner.py | 41 +- .../test_meta_schedule_trace_apply.py | 40 +- .../test_micro_model_library_format.py | 25 + .../unittest/test_te_create_primfunc.py | 16 +- .../python/unittest/test_tir_analysis_oob.py | 2 +- tests/python/unittest/test_tir_constructor.py | 3 +- tests/python/unittest/test_tir_intrin.py | 10 +- .../unittest/test_tir_lower_match_buffer.py | 26 +- tests/python/unittest/test_tir_renew_defs.py | 6 +- .../unittest/test_tir_schedule_rfactor.py | 2 +- .../unittest/test_tir_schedule_tensorize.py | 24 +- tests/python/unittest/test_tir_specialize.py | 18 +- .../test_tir_transform_common_subexpr_elim.py | 4 +- .../test_tir_transform_hoist_expression.py | 4 +- ..._plan_update_buffer_allocation_location.py | 2 +- .../test_tir_transform_storage_flatten.py | 2 +- .../test_tir_transform_storage_rewrite.py | 4 +- ...orm_convert_pool_allocations_to_offsets.py | 36 
+- .../unittest/test_tvmscript_error_report.py | 4 +- .../unittest/test_tvmscript_ir_builder_tir.py | 32 +- .../unittest/test_tvmscript_parser_tir.py | 4 +- .../unittest/test_tvmscript_printer_tir.py | 54 +- .../unittest/test_tvmscript_roundtrip.py | 70 +-- .../unittest/test_tvmscript_syntax_sugar.py | 6 +- .../task_config_build_minimal_cross_isa.sh | 1 + 116 files changed, 2754 insertions(+), 1940 deletions(-) rename apps/microtvm/arduino/template_project/src/example_project/{model.c => platform.c} (80%) rename apps/microtvm/arduino/template_project/src/example_project/{model.h => platform.h} (94%) rename apps/microtvm/arduino/template_project/src/host_driven/{model_support.c => platform.c} (85%) create mode 100644 apps/microtvm/zephyr/template_project/src/aot_standalone_demo/platform.c delete mode 100644 apps/microtvm/zephyr/template_project/src/aot_standalone_demo/zephyr_uart.c delete mode 100644 apps/microtvm/zephyr/template_project/src/aot_standalone_demo/zephyr_uart.h create mode 100644 apps/microtvm/zephyr/template_project/src/host_driven/platform.c create mode 100644 apps/microtvm/zephyr/template_project/src/mlperftiny/platform.cc delete mode 100644 apps/microtvm/zephyr/template_project/src/mlperftiny/tvmruntime.cc delete mode 100644 apps/microtvm/zephyr/template_project/src/mlperftiny/tvmruntime.h delete mode 100644 apps/microtvm/zephyr/template_project/src/mlperftiny/zephyr_uart.cc delete mode 100644 apps/microtvm/zephyr/template_project/src/mlperftiny/zephyr_uart.h create mode 100644 src/relay/analysis/graph_partitioner.cc create mode 100644 src/relay/analysis/graph_partitioner.h create mode 100644 src/runtime/crt/host/platform.cc create mode 100644 src/runtime/crt/platform-template.c create mode 100644 tests/python/unittest/test_cp_async_in_if_then_else.py diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index 05c17ee194b2b..f121238a678d2 100644 --- 
a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -197,8 +197,8 @@ def _disassemble_mlf(self, mlf_tar_path, source_dir): metadata = json.load(f) return metadata - def _template_model_header(self, source_dir, metadata): - with open(source_dir / "model.h", "r") as f: + def _template_model(self, source_dir, metadata): + with open(source_dir / "platform.c", "r") as f: model_h_template = Template(f.read()) all_module_names = [] @@ -218,7 +218,7 @@ def _template_model_header(self, source_dir, metadata): "workspace_size_bytes": workspace_size_bytes, } - with open(source_dir / "model.h", "w") as f: + with open(source_dir / "platform.c", "w") as f: f.write(model_h_template.substitute(template_values)) # Arduino ONLY recognizes .ino, .ccp, .c, .h @@ -415,9 +415,9 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec metadata = self._disassemble_mlf(model_library_format_path, source_dir) shutil.copy2(model_library_format_path, project_dir / MODEL_LIBRARY_FORMAT_RELPATH) - # For AOT, template model.h with metadata to minimize space usage + # For AOT, template platform.c with metadata to minimize space usage if project_type == "example_project": - self._template_model_header(source_dir, metadata) + self._template_model(source_dir, metadata) self._change_cpp_file_extensions(source_dir) diff --git a/apps/microtvm/arduino/template_project/src/example_project/model.c b/apps/microtvm/arduino/template_project/src/example_project/platform.c similarity index 80% rename from apps/microtvm/arduino/template_project/src/example_project/model.c rename to apps/microtvm/arduino/template_project/src/example_project/platform.c index 46f43752ef2a4..973b8aa18cc43 100644 --- a/apps/microtvm/arduino/template_project/src/example_project/model.c +++ b/apps/microtvm/arduino/template_project/src/example_project/platform.c @@ -17,17 +17,22 @@ * under the License. */ -#include "model.h" +/*! 
+ * \brief Implementation of TVMPlatform functions in tvm/runtime/crt/platform.h + */ #include "Arduino.h" #include "standalone_crt/include/dlpack/dlpack.h" #include "standalone_crt/include/tvm/runtime/crt/stack_allocator.h" +#define TVM_WORKSPACE_SIZE_BYTES $workspace_size_bytes + // AOT memory array, stack allocator wants it aligned -static uint8_t g_aot_memory[WORKSPACE_SIZE] +static uint8_t g_aot_memory[TVM_WORKSPACE_SIZE_BYTES] __attribute__((aligned(TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES))); tvm_workspace_t app_workspace; +// Called when an internal error occurs and execution cannot continue. // Blink code for debugging purposes void TVMPlatformAbort(tvm_crt_error_t error) { TVMLogf("TVMPlatformAbort: 0x%08x\n", error); @@ -45,19 +50,23 @@ void TVMPlatformAbort(tvm_crt_error_t error) { } } -void TVMLogf(const char* msg, ...) {} - +// Allocate memory for use by TVM. tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr); } +// Free memory used by TVM. tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { return StackMemoryManager_Free(&app_workspace, ptr); } +// Internal logging API call implementation. +void TVMLogf(const char* msg, ...) {} + unsigned long g_utvm_start_time_micros; int g_utvm_timer_running = 0; +// Start a device timer. tvm_crt_error_t TVMPlatformTimerStart() { if (g_utvm_timer_running) { return kTvmErrorPlatformTimerBadState; @@ -67,6 +76,7 @@ tvm_crt_error_t TVMPlatformTimerStart() { return kTvmErrorNoError; } +// Stop the running device timer and get the elapsed time (in microseconds). tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { if (!g_utvm_timer_running) { return kTvmErrorPlatformTimerBadState; @@ -77,6 +87,7 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { return kTvmErrorNoError; } +// Fill a buffer with random data. 
tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { for (size_t i = 0; i < num_bytes; i++) { buffer[i] = rand(); @@ -84,7 +95,11 @@ tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { return kTvmErrorNoError; } -void TVMInitialize() { StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE); } +// Initialize TVM inference. +tvm_crt_error_t TVMPlatformInitialize() { + StackMemoryManager_Init(&app_workspace, g_aot_memory, sizeof(g_aot_memory)); + return kTvmErrorNoError; +} void TVMExecute(void* input_data, void* output_data) { int ret_val = tvmgen_default___tvm_main__(input_data, output_data); diff --git a/apps/microtvm/arduino/template_project/src/example_project/model.h b/apps/microtvm/arduino/template_project/src/example_project/platform.h similarity index 94% rename from apps/microtvm/arduino/template_project/src/example_project/model.h rename to apps/microtvm/arduino/template_project/src/example_project/platform.h index 7381c97e9b3fa..d6f10e13e96e8 100644 --- a/apps/microtvm/arduino/template_project/src/example_project/model.h +++ b/apps/microtvm/arduino/template_project/src/example_project/platform.h @@ -17,14 +17,10 @@ * under the License. */ -#define WORKSPACE_SIZE $workspace_size_bytes - #ifdef __cplusplus extern "C" { #endif -void TVMInitialize(); - /* TODO template this function signature with the input and output * data types and sizes. For example: * diff --git a/apps/microtvm/arduino/template_project/src/example_project/project.ino b/apps/microtvm/arduino/template_project/src/example_project/project.ino index 5f5683161e0af..666396b407ae5 100644 --- a/apps/microtvm/arduino/template_project/src/example_project/project.ino +++ b/apps/microtvm/arduino/template_project/src/example_project/project.ino @@ -17,10 +17,10 @@ * under the License. 
*/ -#include "src/model.h" +#include "src/standalone_crt/include/tvm/runtime/crt/platform.h" void setup() { - TVMInitialize(); + TVMPlatformInitialize(); // If desired, initialize the RNG with random noise // randomSeed(analogRead(0)); } diff --git a/apps/microtvm/arduino/template_project/src/host_driven/model_support.c b/apps/microtvm/arduino/template_project/src/host_driven/platform.c similarity index 85% rename from apps/microtvm/arduino/template_project/src/host_driven/model_support.c rename to apps/microtvm/arduino/template_project/src/host_driven/platform.c index bcc9a109cace5..0a276134d4190 100644 --- a/apps/microtvm/arduino/template_project/src/host_driven/model_support.c +++ b/apps/microtvm/arduino/template_project/src/host_driven/platform.c @@ -17,22 +17,28 @@ * under the License. */ +/*! + * \brief Implementation of TVMPlatform functions in tvm/runtime/crt/platform.h + */ + #include "standalone_crt/include/dlpack/dlpack.h" #include "standalone_crt/include/tvm/runtime/crt/error_codes.h" #include "stdarg.h" -// Blink code for debugging purposes +// Called when an internal error occurs and execution cannot continue. void TVMPlatformAbort(tvm_crt_error_t error) { TVMLogf("TVMPlatformAbort: 0x%08x\n", error); for (;;) ; } +// Called by the microTVM RPC server to implement TVMLogf. size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, va_list args) { return vsnprintf(out_buf, out_buf_size_bytes, fmt, args); } +// Allocate memory for use by TVM. tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { if (num_bytes == 0) { num_bytes = sizeof(int); @@ -41,6 +47,7 @@ tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** return (*out_ptr == NULL) ? kTvmErrorPlatformNoMemory : kTvmErrorNoError; } +// Free memory used by TVM. 
tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { free(ptr); return kTvmErrorNoError; @@ -49,6 +56,7 @@ tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { unsigned long g_utvm_start_time_micros; int g_utvm_timer_running = 0; +// Start a device timer. tvm_crt_error_t TVMPlatformTimerStart() { if (g_utvm_timer_running) { return kTvmErrorPlatformTimerBadState; @@ -58,6 +66,7 @@ tvm_crt_error_t TVMPlatformTimerStart() { return kTvmErrorNoError; } +// Stop the running device timer and get the elapsed time (in microseconds). tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { if (!g_utvm_timer_running) { return kTvmErrorPlatformTimerBadState; @@ -68,6 +77,7 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { return kTvmErrorNoError; } +// Fill a buffer with random data. tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { for (size_t i = 0; i < num_bytes; i++) { buffer[i] = rand(); diff --git a/apps/microtvm/zephyr/template_project/CMakeLists.txt.template b/apps/microtvm/zephyr/template_project/CMakeLists.txt.template index 1aff9ece6bfa1..cec29e7248822 100644 --- a/apps/microtvm/zephyr/template_project/CMakeLists.txt.template +++ b/apps/microtvm/zephyr/template_project/CMakeLists.txt.template @@ -83,4 +83,4 @@ endif() file(GLOB_RECURSE app_srcs src/**.c src/**.cc) target_sources(app PRIVATE ${app_srcs} ${cmsis_lib_srcs}) -target_include_directories(app PRIVATE crt_config ${CMAKE_SOURCE_DIR}/include crt/include ${cmsis_includes}) +target_include_directories(app PRIVATE crt_config include ${CMAKE_SOURCE_DIR}/include crt/include ${cmsis_includes}) diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index 17dd991229b8d..227b0389445a7 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -210,14 +210,14 @@ def 
_get_board_mem_size_bytes(zephyr_base: str, board: str): return None -DEFAULT_HEAP_SIZE_BYTES = 216 * 1024 +DEFAULT_WORKSPACE_SIZE_BYTES = 216 * 1024 def _get_recommended_heap_size_bytes(board: str): prop = BOARD_PROPERTIES[board] if "recommended_heap_size_bytes" in prop: return prop["recommended_heap_size_bytes"] - return DEFAULT_HEAP_SIZE_BYTES + return DEFAULT_WORKSPACE_SIZE_BYTES def generic_find_serial_port(serial_number: str = None): @@ -358,11 +358,11 @@ def _get_nrf_device_args(serial_number: str = None) -> list: help="Run on the FVP emulator instead of hardware.", ), server.ProjectOption( - "heap_size_bytes", + "workspace_size_bytes", optional=["generate_project"], type="int", default=None, - help="Sets the value for HEAP_SIZE_BYTES passed to K_HEAP_DEFINE() to service TVM memory allocation requests.", + help="Sets the value for TVM_WORKSPACE_SIZE_BYTES passed to K_HEAP_DEFINE() to service TVM memory allocation requests.", ), ] @@ -403,7 +403,13 @@ def server_info_query(self, tvm_version): } def _create_prj_conf( - self, project_dir: pathlib.Path, board: str, project_type: str, config_main_stack_size + self, + project_dir: pathlib.Path, + board: str, + project_type: str, + config_main_stack_size: int, + config_led: bool, + use_fvp: bool, ): with open(project_dir / "prj.conf", "w") as f: f.write( @@ -413,6 +419,13 @@ def _create_prj_conf( "CONFIG_UART_INTERRUPT_DRIVEN=y\n" "\n" ) + if ( + config_led + and not self._is_qemu(board, use_fvp) + and not self._is_fvp(board, use_fvp) + ): + f.write("# For debugging.\n" "CONFIG_LED=y\n" "\n") + f.write("# For TVMPlatformAbort().\n" "CONFIG_REBOOT=y\n" "\n") if project_type == "host_driven": @@ -522,6 +535,18 @@ def _generate_cmake_args( return cmake_args + def _copy_src_and_header_files(self, src_dir: pathlib.Path, dst_dir: pathlib.Path): + """Copy content of src_dir from template project to dst_dir in separate + source and header sub-directories. 
+ """ + for file in os.listdir(src_dir): + file = src_dir / file + if file.is_file(): + if file.suffix in [".cc", ".c"]: + shutil.copy2(file, dst_dir / "src") + elif file.suffix in [".h"]: + shutil.copy2(file, dst_dir / "include" / "tvm") + def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): zephyr_board = options["board"] project_type = options["project_type"] @@ -533,7 +558,7 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec verbose = options.get("verbose") recommended_heap_size = _get_recommended_heap_size_bytes(zephyr_board) - heap_size_bytes = options.get("heap_size_bytes") or recommended_heap_size + workspace_size_bytes = options.get("workspace_size_bytes") or recommended_heap_size board_mem_size = _get_board_mem_size_bytes(zephyr_base, zephyr_board) compile_definitions = options.get("compile_definitions") @@ -602,7 +627,7 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec else: shutil.copy2(src_path, dst_path) - # Populate Makefile. + # Populate CMakeLists. with open(project_dir / CMAKELIST_FILENAME, "w") as cmake_f: with open(API_SERVER_DIR / f"{CMAKELIST_FILENAME}.template", "r") as cmake_template_f: for line in cmake_template_f: @@ -629,10 +654,10 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec if board_mem_size is not None: assert ( - heap_size_bytes < board_mem_size - ), f"Heap size {heap_size_bytes} is larger than memory size {board_mem_size} on this board." + workspace_size_bytes < board_mem_size + ), f"Workspace size {workspace_size_bytes} is larger than memory size {board_mem_size} on this board." 
cmake_f.write( - f"target_compile_definitions(app PUBLIC -DHEAP_SIZE_BYTES={heap_size_bytes})\n" + f"target_compile_definitions(app PUBLIC -DTVM_WORKSPACE_SIZE_BYTES={workspace_size_bytes})\n" ) if compile_definitions: @@ -649,7 +674,9 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec if self._is_fvp(zephyr_board, use_fvp): cmake_f.write(f"target_compile_definitions(app PUBLIC -DFVP=1)\n") - self._create_prj_conf(project_dir, zephyr_board, project_type, config_main_stack_size) + self._create_prj_conf( + project_dir, zephyr_board, project_type, config_main_stack_size, verbose, use_fvp + ) # Populate crt-config.h crt_config_dir = project_dir / "crt_config" @@ -658,13 +685,19 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec API_SERVER_DIR / "crt_config" / "crt_config.h", crt_config_dir / "crt_config.h" ) - # Populate src/ + # Populate `src` and `include` src_dir = project_dir / "src" - if project_type != "host_driven" or self._is_fvp(zephyr_board, use_fvp): - shutil.copytree(API_SERVER_DIR / "src" / project_type, src_dir) - else: - src_dir.mkdir() - shutil.copy2(API_SERVER_DIR / "src" / project_type / "main.c", src_dir) + src_dir.mkdir() + include_dir = project_dir / "include" / "tvm" + include_dir.mkdir(parents=True) + src_project_type_dir = API_SERVER_DIR / "src" / project_type + self._copy_src_and_header_files(src_project_type_dir, project_dir) + + if self._is_fvp(zephyr_board, use_fvp): + self._copy_src_and_header_files(src_project_type_dir / "fvp", project_dir) + + if project_type == "mlperftiny": + shutil.copytree(src_project_type_dir / "api", src_dir / "api") # Populate extra_files if extra_files_tar: diff --git a/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/main.c b/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/main.c index 9ba521ae171d2..fff8f5787597f 100644 --- a/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/main.c +++ 
b/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/main.c @@ -25,23 +25,18 @@ #include #include #include +#include #include -#include +#include -#include "input_data.h" -#include "output_data.h" +#include "tvm/input_data.h" +#include "tvm/output_data.h" #include "tvmgen_default.h" -#include "zephyr_uart.h" #ifdef CONFIG_ARCH_POSIX #include "posix_board_if.h" #endif -// WORKSPACE_SIZE defined in Project API Makefile - -static uint8_t g_aot_memory[WORKSPACE_SIZE]; -tvm_workspace_t app_workspace; - // Transport Commands. // Commands on host end with `\n` // Commands on microTVM device end with `%` @@ -53,126 +48,84 @@ const unsigned char CMD_INFER[] = "infer"; #define CMD_SIZE 80u #define CMD_TERMINATOR '%' -size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, - va_list args) { - return vsnprintk(out_buf, out_buf_size_bytes, fmt, args); -} +static uint8_t main_rx_buf[128]; +static uint8_t g_cmd_buf[128]; +static size_t g_cmd_buf_ind; -void TVMLogf(const char* msg, ...) { - char buffer[256]; - int size; - va_list args; - va_start(args, msg); - size = vsprintf(buffer, msg, args); - va_end(args); - TVMPlatformWriteSerial(buffer, (uint32_t)size); -} +static const struct device* g_microtvm_uart; +#define RING_BUF_SIZE_BYTES (TVM_CRT_MAX_PACKET_SIZE_BYTES + 100) -void TVMPlatformAbort(tvm_crt_error_t error) { - TVMLogf("TVMPlatformAbort: %08x\n", error); - sys_reboot(SYS_REBOOT_COLD); - for (;;) - ; -} +// Ring buffer used to store data read from the UART on rx interrupt. 
+RING_BUF_DECLARE(uart_rx_rbuf, RING_BUF_SIZE_BYTES); -tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { - return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr); +uint32_t UartTxWrite(const char* data, uint32_t size) { + for (uint32_t i = 0; i < size; i++) { + uart_poll_out(g_microtvm_uart, data[i]); + } + return size; } -tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { - return StackMemoryManager_Free(&app_workspace, ptr); +uint32_t UartRxRead(uint8_t* data, uint32_t data_size_bytes) { + unsigned int key = irq_lock(); + uint32_t bytes_read = ring_buf_get(&uart_rx_rbuf, data, data_size_bytes); + irq_unlock(key); + return bytes_read; } -void timer_expiry_function(struct k_timer* timer_id) { return; } - -#define MILLIS_TIL_EXPIRY 200 -#define TIME_TIL_EXPIRY (K_MSEC(MILLIS_TIL_EXPIRY)) -struct k_timer g_microtvm_timer; -uint32_t g_microtvm_start_time; -int g_microtvm_timer_running = 0; - -// Called to start system timer. -tvm_crt_error_t TVMPlatformTimerStart() { - if (g_microtvm_timer_running) { - TVMLogf("timer already running"); - return kTvmErrorPlatformTimerBadState; - } - - k_timer_start(&g_microtvm_timer, TIME_TIL_EXPIRY, TIME_TIL_EXPIRY); - g_microtvm_start_time = k_cycle_get_32(); - g_microtvm_timer_running = 1; - return kTvmErrorNoError; +// Initialize UART +void UartInit() { + // Claim console device. + g_microtvm_uart = DEVICE_DT_GET(DT_CHOSEN(zephyr_console)); + const struct uart_config config = {.baudrate = 115200, + .parity = UART_CFG_PARITY_NONE, + .stop_bits = UART_CFG_STOP_BITS_1, + .data_bits = UART_CFG_DATA_BITS_8, + .flow_ctrl = UART_CFG_FLOW_CTRL_NONE}; + uart_configure(g_microtvm_uart, &config); + uart_rx_init(&uart_rx_rbuf, g_microtvm_uart); } -// Called to stop system timer. 
-tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { - if (!g_microtvm_timer_running) { - TVMLogf("timer not running"); - return kTvmErrorSystemErrorMask | 2; - } - - uint32_t stop_time = k_cycle_get_32(); - - // compute how long the work took - uint32_t cycles_spent = stop_time - g_microtvm_start_time; - if (stop_time < g_microtvm_start_time) { - // we rolled over *at least* once, so correct the rollover it was *only* - // once, because we might still use this result - cycles_spent = ~((uint32_t)0) - (g_microtvm_start_time - stop_time); - } - - uint32_t ns_spent = (uint32_t)k_cyc_to_ns_floor64(cycles_spent); - double hw_clock_res_us = ns_spent / 1000.0; - - // need to grab time remaining *before* stopping. when stopped, this function - // always returns 0. - int32_t time_remaining_ms = k_timer_remaining_get(&g_microtvm_timer); - k_timer_stop(&g_microtvm_timer); - // check *after* stopping to prevent extra expiries on the happy path - if (time_remaining_ms < 0) { - return kTvmErrorSystemErrorMask | 3; - } - uint32_t num_expiries = k_timer_status_get(&g_microtvm_timer); - uint32_t timer_res_ms = ((num_expiries * MILLIS_TIL_EXPIRY) + time_remaining_ms); - double approx_num_cycles = - (double)k_ticks_to_cyc_floor32(1) * (double)k_ms_to_ticks_ceil32(timer_res_ms); - // if we approach the limits of the HW clock datatype (uint32_t), use the - // coarse-grained timer result instead - if (approx_num_cycles > (0.5 * (~((uint32_t)0)))) { - *elapsed_time_seconds = timer_res_ms / 1000.0; - } else { - *elapsed_time_seconds = hw_clock_res_us / 1e6; +static uint8_t uart_data[8]; +// UART interrupt callback. +void uart_irq_cb(const struct device* dev, void* user_data) { + while (uart_irq_update(dev) && uart_irq_is_pending(dev)) { + struct ring_buf* rbuf = (struct ring_buf*)user_data; + if (uart_irq_rx_ready(dev) != 0) { + for (;;) { + // Read a small chunk of data from the UART. 
+ int bytes_read = uart_fifo_read(dev, uart_data, sizeof(uart_data)); + if (bytes_read < 0) { + TVMPlatformAbort((tvm_crt_error_t)(0xbeef1)); + } else if (bytes_read == 0) { + break; + } + // Write it into the ring buffer. + int bytes_written = ring_buf_put(rbuf, uart_data, bytes_read); + if (bytes_read != bytes_written) { + TVMPlatformAbort((tvm_crt_error_t)(0xbeef2)); + } + } + } } - - g_microtvm_timer_running = 0; - return kTvmErrorNoError; } -void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, int dtype_code_hint, - int dtype_bits_hint) { - tvm_crt_error_t err = kTvmErrorNoError; - void* ptr = 0; - DLDevice dev = {device_type, device_id}; - assert(nbytes > 0); - err = TVMPlatformMemoryAllocate(nbytes, dev, &ptr); - CHECK_EQ(err, kTvmErrorNoError, - "TVMBackendAllocWorkspace(%d, %d, %" PRIu64 ", %d, %d) -> %" PRId32, device_type, - device_id, nbytes, dtype_code_hint, dtype_bits_hint, err); - return ptr; +// Used to initialize the UART receiver. +void uart_rx_init(struct ring_buf* rbuf, const struct device* dev) { + uart_irq_callback_user_data_set(dev, uart_irq_cb, (void*)rbuf); + uart_irq_rx_enable(dev); } -int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { - tvm_crt_error_t err = kTvmErrorNoError; - DLDevice dev = {device_type, device_id}; - err = TVMPlatformMemoryFree(ptr, dev); - return err; +void TVMLogf(const char* msg, ...) 
{ + char buffer[256]; + int size; + va_list args; + va_start(args, msg); + size = vsprintf(buffer, msg, args); + va_end(args); + UartTxWrite(buffer, (uint32_t)size); } -static uint8_t main_rx_buf[128]; -static uint8_t g_cmd_buf[128]; -static size_t g_cmd_buf_ind; - -void TVMInfer() { +void Infer() { struct tvmgen_default_inputs inputs = { .input_1 = input_data, }; @@ -180,8 +133,6 @@ void TVMInfer() { .Identity = output_data, }; - StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE); - double elapsed_time = 0; TVMPlatformTimerStart(); int ret_val = tvmgen_default_run(&inputs, &outputs); @@ -206,11 +157,11 @@ void TVMInfer() { // Execute functions based on received command void command_ready(char* command) { if (strncmp(command, CMD_INIT, CMD_SIZE) == 0) { - TVMPlatformWriteSerial(CMD_WAKEUP, sizeof(CMD_WAKEUP)); + UartTxWrite(CMD_WAKEUP, sizeof(CMD_WAKEUP)); } else if (strncmp(command, CMD_INFER, CMD_SIZE) == 0) { - TVMInfer(); + Infer(); } else { - TVMPlatformWriteSerial(CMD_READY, sizeof(CMD_READY)); + UartTxWrite(CMD_READY, sizeof(CMD_READY)); } } @@ -229,13 +180,13 @@ void serial_callback(char* message, int len_bytes) { } void main(void) { + TVMPlatformInitialize(); + UartInit(); g_cmd_buf_ind = 0; memset((char*)g_cmd_buf, 0, sizeof(g_cmd_buf)); - TVMPlatformUARTInit(); - k_timer_init(&g_microtvm_timer, NULL, NULL); while (true) { - int bytes_read = TVMPlatformUartRxRead(main_rx_buf, sizeof(main_rx_buf)); + int bytes_read = UartRxRead(main_rx_buf, sizeof(main_rx_buf)); if (bytes_read > 0) { serial_callback(main_rx_buf, bytes_read); } diff --git a/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/platform.c b/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/platform.c new file mode 100644 index 0000000000000..c66dad5711552 --- /dev/null +++ b/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/platform.c @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Implementation of TVMPlatform functions in tvm/runtime/crt/platform.h + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "crt_config.h" +#include "dlpack/dlpack.h" +#include "tvmgen_default.h" + +// TVM_WORKSPACE_SIZE_BYTES defined in Project API Makefile +static uint8_t g_aot_memory[TVM_WORKSPACE_SIZE_BYTES]; +tvm_workspace_t app_workspace; + +#define MILLIS_TIL_EXPIRY 200 +#define TIME_TIL_EXPIRY (K_MSEC(MILLIS_TIL_EXPIRY)) +struct k_timer g_microtvm_timer; +uint32_t g_microtvm_start_time; +int g_microtvm_timer_running = 0; + +size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, + va_list args) { + return vsnprintk(out_buf, out_buf_size_bytes, fmt, args); +} + +void TVMPlatformAbort(tvm_crt_error_t error) { + TVMLogf("TVMPlatformAbort: %08x\n", error); + sys_reboot(SYS_REBOOT_COLD); + for (;;) + ; +} + +tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { + return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr); +} + +tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { + return StackMemoryManager_Free(&app_workspace, ptr); +} + +tvm_crt_error_t TVMPlatformInitialize() { + 
k_timer_init(&g_microtvm_timer, NULL, NULL);
+  StackMemoryManager_Init(&app_workspace, g_aot_memory, sizeof(g_aot_memory));
+  return kTvmErrorNoError;
+}
+
+tvm_crt_error_t TVMPlatformTimerStart() {
+  if (g_microtvm_timer_running) {
+    TVMLogf("timer already running");
+    return kTvmErrorPlatformTimerBadState;
+  }
+
+  k_timer_start(&g_microtvm_timer, TIME_TIL_EXPIRY, TIME_TIL_EXPIRY);
+  g_microtvm_start_time = k_cycle_get_32();
+  g_microtvm_timer_running = 1;
+  return kTvmErrorNoError;
+}
+
+tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
+  if (!g_microtvm_timer_running) {
+    TVMLogf("timer not running");
+    return kTvmErrorSystemErrorMask | 2;
+  }
+
+  uint32_t stop_time = k_cycle_get_32();
+
+  // compute how long the work took
+  uint32_t cycles_spent = stop_time - g_microtvm_start_time;
+  if (stop_time < g_microtvm_start_time) {
+    // we rolled over *at least* once, so correct the rollover if it was *only*
+    // once, because we might still use this result
+    cycles_spent = ~((uint32_t)0) - (g_microtvm_start_time - stop_time);
+  }
+
+  uint32_t ns_spent = (uint32_t)k_cyc_to_ns_floor64(cycles_spent);
+  double hw_clock_res_us = ns_spent / 1000.0;
+
+  // need to grab time remaining *before* stopping. when stopped, this function
+  // always returns 0.
+ int32_t time_remaining_ms = k_timer_remaining_get(&g_microtvm_timer); + k_timer_stop(&g_microtvm_timer); + // check *after* stopping to prevent extra expiries on the happy path + if (time_remaining_ms < 0) { + return kTvmErrorSystemErrorMask | 3; + } + uint32_t num_expiries = k_timer_status_get(&g_microtvm_timer); + uint32_t timer_res_ms = ((num_expiries * MILLIS_TIL_EXPIRY) + time_remaining_ms); + double approx_num_cycles = + (double)k_ticks_to_cyc_floor32(1) * (double)k_ms_to_ticks_ceil32(timer_res_ms); + // if we approach the limits of the HW clock datatype (uint32_t), use the + // coarse-grained timer result instead + if (approx_num_cycles > (0.5 * (~((uint32_t)0)))) { + *elapsed_time_seconds = timer_res_ms / 1000.0; + } else { + *elapsed_time_seconds = hw_clock_res_us / 1e6; + } + + g_microtvm_timer_running = 0; + return kTvmErrorNoError; +} diff --git a/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/zephyr_uart.c b/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/zephyr_uart.c deleted file mode 100644 index 8d5f912081660..0000000000000 --- a/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/zephyr_uart.c +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include "zephyr_uart.h" - -#include -#include - -#include "crt_config.h" - -static const struct device* g_microtvm_uart; -#define RING_BUF_SIZE_BYTES (TVM_CRT_MAX_PACKET_SIZE_BYTES + 100) - -// Ring buffer used to store data read from the UART on rx interrupt. -RING_BUF_DECLARE(uart_rx_rbuf, RING_BUF_SIZE_BYTES); - -static uint8_t uart_data[8]; -// UART interrupt callback. -void uart_irq_cb(const struct device* dev, void* user_data) { - while (uart_irq_update(dev) && uart_irq_is_pending(dev)) { - struct ring_buf* rbuf = (struct ring_buf*)user_data; - if (uart_irq_rx_ready(dev) != 0) { - for (;;) { - // Read a small chunk of data from the UART. - int bytes_read = uart_fifo_read(dev, uart_data, sizeof(uart_data)); - if (bytes_read < 0) { - TVMPlatformAbort((tvm_crt_error_t)(0xbeef1)); - } else if (bytes_read == 0) { - break; - } - // Write it into the ring buffer. - int bytes_written = ring_buf_put(rbuf, uart_data, bytes_read); - if (bytes_read != bytes_written) { - TVMPlatformAbort((tvm_crt_error_t)(0xbeef2)); - } - } - } - } -} - -// Used to initialize the UART receiver. -void uart_rx_init(struct ring_buf* rbuf, const struct device* dev) { - uart_irq_callback_user_data_set(dev, uart_irq_cb, (void*)rbuf); - uart_irq_rx_enable(dev); -} - -uint32_t TVMPlatformUartRxRead(uint8_t* data, uint32_t data_size_bytes) { - unsigned int key = irq_lock(); - uint32_t bytes_read = ring_buf_get(&uart_rx_rbuf, data, data_size_bytes); - irq_unlock(key); - return bytes_read; -} - -uint32_t TVMPlatformWriteSerial(const char* data, uint32_t size) { - for (uint32_t i = 0; i < size; i++) { - uart_poll_out(g_microtvm_uart, data[i]); - } - return size; -} - -// Initialize UART -void TVMPlatformUARTInit() { - // Claim console device. 
- g_microtvm_uart = DEVICE_DT_GET(DT_CHOSEN(zephyr_console)); - const struct uart_config config = {.baudrate = 115200, - .parity = UART_CFG_PARITY_NONE, - .stop_bits = UART_CFG_STOP_BITS_1, - .data_bits = UART_CFG_DATA_BITS_8, - .flow_ctrl = UART_CFG_FLOW_CTRL_NONE}; - uart_configure(g_microtvm_uart, &config); - uart_rx_init(&uart_rx_rbuf, g_microtvm_uart); -} diff --git a/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/zephyr_uart.h b/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/zephyr_uart.h deleted file mode 100644 index 771cb490d0d63..0000000000000 --- a/apps/microtvm/zephyr/template_project/src/aot_standalone_demo/zephyr_uart.h +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_APPS_MICROTVM_ZEPHYR_AOT_STANDALONE_DEMO_ZEPHYR_UART_H_ -#define TVM_APPS_MICROTVM_ZEPHYR_AOT_STANDALONE_DEMO_ZEPHYR_UART_H_ - -#include - -// Used to read data from the UART. - -/*! - * \brief Read Uart Rx buffer. - * \param data Pointer to read data. - * \param data_size_bytes Read request size in bytes. - * - * \return Number of data read in bytes. - */ -uint32_t TVMPlatformUartRxRead(uint8_t* data, uint32_t data_size_bytes); - -/*! 
- * \brief Write data in serial. - * \param data Pointer to data to write. - * \param size Size of data in bytes. - * - * \return Number of write in bytes. - */ -uint32_t TVMPlatformWriteSerial(const char* data, uint32_t size); - -/*! - * \brief Initialize Uart. - */ -void TVMPlatformUARTInit(); - -#endif /* TVM_APPS_MICROTVM_ZEPHYR_AOT_STANDALONE_DEMO_ZEPHYR_UART_H_ */ diff --git a/apps/microtvm/zephyr/template_project/src/host_driven/fvp/semihost.c b/apps/microtvm/zephyr/template_project/src/host_driven/fvp/semihost.c index 2e03df096307d..f51aa47c9f71e 100644 --- a/apps/microtvm/zephyr/template_project/src/host_driven/fvp/semihost.c +++ b/apps/microtvm/zephyr/template_project/src/host_driven/fvp/semihost.c @@ -22,7 +22,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "semihost.h" +#include "tvm/semihost.h" int32_t stdout_fd; int32_t stdin_fd; diff --git a/apps/microtvm/zephyr/template_project/src/host_driven/main.c b/apps/microtvm/zephyr/template_project/src/host_driven/main.c index c01daab6d0d6b..1c63474817de6 100644 --- a/apps/microtvm/zephyr/template_project/src/host_driven/main.c +++ b/apps/microtvm/zephyr/template_project/src/host_driven/main.c @@ -29,7 +29,6 @@ * this logic into your own application. 
*/ #include -#include #include #include #include @@ -37,15 +36,7 @@ #include #include #include -#include -#include -#include #include -#include - -#ifdef FVP -#include "fvp/semihost.h" -#endif #ifdef CONFIG_ARCH_POSIX #include "posix_board_if.h" @@ -53,15 +44,11 @@ #include "crt_config.h" -static const struct device* tvm_uart; +#ifdef FVP +#include "tvm/semihost.h" +#endif -#ifdef CONFIG_LED -#define LED0_NODE DT_ALIAS(led0) -#define LED0 DT_GPIO_LABEL(LED0_NODE, gpios) -#define LED0_PIN DT_GPIO_PIN(LED0_NODE, gpios) -#define LED0_FLAGS DT_GPIO_FLAGS(LED0_NODE, gpios) -static const struct device* led0_pin; -#endif // CONFIG_LED +static const struct device* tvm_uart; static size_t g_num_bytes_requested = 0; static size_t g_num_bytes_written = 0; @@ -69,20 +56,11 @@ static size_t g_num_bytes_in_rx_buffer = 0; // Called by TVM to write serial data to the UART. ssize_t uart_write(void* unused_context, const uint8_t* data, size_t size) { -#ifdef CONFIG_LED - gpio_pin_set(led0_pin, LED0_PIN, 1); -#endif g_num_bytes_requested += size; - for (size_t i = 0; i < size; i++) { uart_poll_out(tvm_uart, data[i]); g_num_bytes_written++; } - -#ifdef CONFIG_LED - gpio_pin_set(led0_pin, LED0_PIN, 0); -#endif - return size; } @@ -94,105 +72,6 @@ ssize_t serial_write(void* unused_context, const uint8_t* data, size_t size) { #endif } -// This is invoked by Zephyr from an exception handler, which will be invoked -// if the device crashes. Here, we turn on the LED and spin. -void k_sys_fatal_error_handler(unsigned int reason, const z_arch_esf_t* esf) { -#ifdef CONFIG_LED - gpio_pin_set(led0_pin, LED0_PIN, 1); -#endif - for (;;) - ; -} - -// Called by TVM when a message needs to be formatted. -size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, - va_list args) { - return vsnprintk(out_buf, out_buf_size_bytes, fmt, args); -} - -// Called by TVM when an internal invariant is violated, and execution cannot continue. 
-void TVMPlatformAbort(tvm_crt_error_t error) { - TVMLogf("TVMError: 0x%x", error); - sys_reboot(SYS_REBOOT_COLD); -#ifdef CONFIG_LED - gpio_pin_set(led0_pin, LED0_PIN, 1); -#endif - for (;;) - ; -} - -// Called by TVM to generate random data. -tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { - uint32_t random; // one unit of random data. - - // Fill parts of `buffer` which are as large as `random`. - size_t num_full_blocks = num_bytes / sizeof(random); - for (int i = 0; i < num_full_blocks; ++i) { - random = sys_rand32_get(); - memcpy(&buffer[i * sizeof(random)], &random, sizeof(random)); - } - - // Fill any leftover tail which is smaller than `random`. - size_t num_tail_bytes = num_bytes % sizeof(random); - if (num_tail_bytes > 0) { - random = sys_rand32_get(); - memcpy(&buffer[num_bytes - num_tail_bytes], &random, num_tail_bytes); - } - return kTvmErrorNoError; -} - -// Heap for use by TVMPlatformMemoryAllocate. -K_HEAP_DEFINE(tvm_heap, HEAP_SIZE_BYTES); - -// Called by TVM to allocate memory. -tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { - *out_ptr = k_heap_alloc(&tvm_heap, num_bytes, K_NO_WAIT); - return (*out_ptr == NULL) ? kTvmErrorPlatformNoMemory : kTvmErrorNoError; -} - -// Called by TVM to deallocate memory. -tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { - k_heap_free(&tvm_heap, ptr); - return kTvmErrorNoError; -} - -volatile timing_t g_microtvm_start_time, g_microtvm_end_time; -int g_microtvm_timer_running = 0; - -// Called to start system timer. -tvm_crt_error_t TVMPlatformTimerStart() { - if (g_microtvm_timer_running) { - TVMLogf("timer already running"); - return kTvmErrorPlatformTimerBadState; - } - -#ifdef CONFIG_LED - gpio_pin_set(led0_pin, LED0_PIN, 1); -#endif - g_microtvm_start_time = timing_counter_get(); - g_microtvm_timer_running = 1; - return kTvmErrorNoError; -} - -// Called to stop system timer. 
-tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { - if (!g_microtvm_timer_running) { - TVMLogf("timer not running"); - return kTvmErrorSystemErrorMask | 2; - } - -#ifdef CONFIG_LED - gpio_pin_set(led0_pin, LED0_PIN, 0); -#endif - - g_microtvm_end_time = timing_counter_get(); - uint64_t cycles = timing_cycles_get(&g_microtvm_start_time, &g_microtvm_end_time); - uint64_t ns_spent = timing_cycles_to_ns(cycles); - *elapsed_time_seconds = ns_spent / (double)1e9; - g_microtvm_timer_running = 0; - return kTvmErrorNoError; -} - // Ring buffer used to store data read from the UART on rx interrupt. // This ring buffer size is only required for testing with QEMU and not for physical hardware. #define RING_BUF_SIZE_BYTES (TVM_CRT_MAX_PACKET_SIZE_BYTES + 100) @@ -223,8 +102,6 @@ void uart_irq_cb(const struct device* dev, void* user_data) { if (err != 0) { TVMPlatformAbort((tvm_crt_error_t)0xbeef2); } - // CHECK_EQ(bytes_read, bytes_written, "bytes_read: %d; bytes_written: %d", bytes_read, - // bytes_written); } } } @@ -238,29 +115,12 @@ void uart_rx_init(struct ring_buf* rbuf, const struct device* dev) { // The main function of this application. extern void __stdout_hook_install(int (*hook)(int)); void main(void) { -#ifdef CONFIG_LED - int ret; - led0_pin = device_get_binding(LED0); - if (led0_pin == NULL) { - for (;;) - ; - } - ret = gpio_pin_configure(led0_pin, LED0_PIN, GPIO_OUTPUT_ACTIVE | LED0_FLAGS); - if (ret < 0) { - TVMPlatformAbort((tvm_crt_error_t)0xbeef4); - } - gpio_pin_set(led0_pin, LED0_PIN, 1); -#endif + TVMPlatformInitialize(); // Claim console device. tvm_uart = DEVICE_DT_GET(DT_CHOSEN(zephyr_console)); uart_rx_init(&uart_rx_rbuf, tvm_uart); - // Initialize system timing. We could stop and start it every time, but we'll - // be using it enough we should just keep it enabled. 
- timing_init(); - timing_start(); - #ifdef FVP init_semihosting(); // send some dummy log to speed up the initialization @@ -274,10 +134,6 @@ void main(void) { microtvm_rpc_server_t server = MicroTVMRpcServerInit(serial_write, NULL); TVMLogf("microTVM Zephyr runtime - running"); -#ifdef CONFIG_LED - gpio_pin_set(led0_pin, LED0_PIN, 0); -#endif - // The main application loop. We continuously read commands from the UART // and dispatch them to MicroTVMRpcServerLoop(). while (true) { diff --git a/apps/microtvm/zephyr/template_project/src/host_driven/platform.c b/apps/microtvm/zephyr/template_project/src/host_driven/platform.c new file mode 100644 index 0000000000000..8aa9abf235c73 --- /dev/null +++ b/apps/microtvm/zephyr/template_project/src/host_driven/platform.c @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \brief Implementation of TVMPlatform functions in tvm/runtime/crt/platform.h + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +K_HEAP_DEFINE(tvm_heap, TVM_WORKSPACE_SIZE_BYTES); + +volatile timing_t g_microtvm_start_time, g_microtvm_end_time; +int g_microtvm_timer_running = 0; + +#ifdef CONFIG_LED +#define LED0_NODE DT_ALIAS(led0) +static const struct gpio_dt_spec led0 = GPIO_DT_SPEC_GET(LED0_NODE, gpios); +#endif // CONFIG_LED + +// This is invoked by Zephyr from an exception handler, which will be invoked +// if the device crashes. Here, we turn on the LED and spin. +void k_sys_fatal_error_handler(unsigned int reason, const z_arch_esf_t* esf) { +#ifdef CONFIG_LED + gpio_pin_set_dt(&led0, 1); +#endif + for (;;) + ; +} + +void TVMPlatformAbort(tvm_crt_error_t error) { + TVMLogf("TVMError: 0x%x", error); + sys_reboot(SYS_REBOOT_COLD); +#ifdef CONFIG_LED + gpio_pin_set_dt(&led0, 1); +#endif + for (;;) + ; +} + +size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, + va_list args) { + return vsnprintk(out_buf, out_buf_size_bytes, fmt, args); +} + +tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { + *out_ptr = k_heap_alloc(&tvm_heap, num_bytes, K_NO_WAIT); + return (*out_ptr == NULL) ? kTvmErrorPlatformNoMemory : kTvmErrorNoError; +} + +tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { + k_heap_free(&tvm_heap, ptr); + return kTvmErrorNoError; +} + +// Called to start system timer. +tvm_crt_error_t TVMPlatformTimerStart() { + if (g_microtvm_timer_running) { + TVMLogf("timer already running"); + return kTvmErrorPlatformTimerBadState; + } + +#ifdef CONFIG_LED + gpio_pin_set_dt(&led0, 1); +#endif + g_microtvm_start_time = timing_counter_get(); + g_microtvm_timer_running = 1; + return kTvmErrorNoError; +} + +// Called to stop system timer. 
+tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { + if (!g_microtvm_timer_running) { + TVMLogf("timer not running"); + return kTvmErrorSystemErrorMask | 2; + } + +#ifdef CONFIG_LED + gpio_pin_set_dt(&led0, 0); +#endif + + g_microtvm_end_time = timing_counter_get(); + uint64_t cycles = timing_cycles_get(&g_microtvm_start_time, &g_microtvm_end_time); + uint64_t ns_spent = timing_cycles_to_ns(cycles); + *elapsed_time_seconds = ns_spent / (double)1e9; + g_microtvm_timer_running = 0; + return kTvmErrorNoError; +} + +tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { + uint32_t random; // one unit of random data. + + // Fill parts of `buffer` which are as large as `random`. + size_t num_full_blocks = num_bytes / sizeof(random); + for (int i = 0; i < num_full_blocks; ++i) { + random = sys_rand32_get(); + memcpy(&buffer[i * sizeof(random)], &random, sizeof(random)); + } + + // Fill any leftover tail which is smaller than `random`. + size_t num_tail_bytes = num_bytes % sizeof(random); + if (num_tail_bytes > 0) { + random = sys_rand32_get(); + memcpy(&buffer[num_bytes - num_tail_bytes], &random, num_tail_bytes); + } + return kTvmErrorNoError; +} + +tvm_crt_error_t TVMPlatformInitialize() { +#ifdef CONFIG_LED + if (!device_is_ready(led0.port)) { + for (;;) + ; + } + int ret = gpio_pin_configure_dt(&led0, GPIO_OUTPUT_ACTIVE); + if (ret < 0) { + TVMPlatformAbort((tvm_crt_error_t)0xbeef4); + } + gpio_pin_set_dt(&led0, 0); +#endif + + // Initialize system timing. We could stop and start it every time, but we'll + // be using it enough we should just keep it enabled. 
+ timing_init(); + timing_start(); + + return kTvmErrorNoError; +} diff --git a/apps/microtvm/zephyr/template_project/src/mlperftiny/platform.cc b/apps/microtvm/zephyr/template_project/src/mlperftiny/platform.cc new file mode 100644 index 0000000000000..9dc4516271df7 --- /dev/null +++ b/apps/microtvm/zephyr/template_project/src/mlperftiny/platform.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \brief Implementation of TVMPlatform functions in tvm/runtime/crt/platform.h + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "crt_config.h" + +// TVM_WORKSPACE_SIZE_BYTES is defined in python +static uint8_t g_aot_memory[TVM_WORKSPACE_SIZE_BYTES]; +tvm_workspace_t app_workspace; + +size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, + va_list args) { + return vsnprintk(out_buf, out_buf_size_bytes, fmt, args); +} + +void TVMPlatformAbort(tvm_crt_error_t error) { + TVMLogf("TVMPlatformAbort: %08x\n", error); + sys_reboot(SYS_REBOOT_COLD); + for (;;) + ; +} + +tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { + return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr); +} + +tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { + return StackMemoryManager_Free(&app_workspace, ptr); +} + +tvm_crt_error_t TVMPlatformInitialize() { + StackMemoryManager_Init(&app_workspace, g_aot_memory, sizeof(g_aot_memory)); + return kTvmErrorNoError; +} diff --git a/apps/microtvm/zephyr/template_project/src/mlperftiny/submitter_implemented.cc b/apps/microtvm/zephyr/template_project/src/mlperftiny/submitter_implemented.cc index 91ae0c025c6e4..72b679c6408a6 100644 --- a/apps/microtvm/zephyr/template_project/src/mlperftiny/submitter_implemented.cc +++ b/apps/microtvm/zephyr/template_project/src/mlperftiny/submitter_implemented.cc @@ -19,20 +19,46 @@ #include "api/submitter_implemented.h" +#include +#include #include #include #include +#include #include +#include #include "api/internally_implemented.h" -#include "tvmruntime.h" -#include "zephyr_uart.h" +#include "crt_config.h" +#include "tvm/output_data.h" +#include "tvmgen_default.h" + +// ############################################################### +// Model +// 
############################################################### +#define MODEL_KWS 1 +#define MODEL_VWW 2 +#define MODEL_AD 3 +#define MODEL_IC 4 static void* g_input_data; -#if TARGET_MODEL == 3 // AD +#if TARGET_MODEL == MODEL_AD static uint8_t __aligned(4) g_input_data_buffer_aligned[MAX_DB_INPUT_SIZE]; #endif +// OUT_QUANT_SCALE and OUT_QUANT_ZERO are set in python. +#if TARGET_MODEL == MODEL_AD +float* g_output_data = output_data; +#else +int8_t* g_output_data = output_data; +float g_quant_scale = OUT_QUANT_SCALE; +int8_t g_quant_zero = OUT_QUANT_ZERO; +#endif +size_t g_output_data_len = output_data_len; + +// ############################################################### +// GPIO +// ############################################################### #if EE_CFG_ENERGY_MODE == 1 && NRF_BOARD != 1 // use GPIO PC6 which is on connector CN7 pin 1 on the nucleo_l4r5zi static const char* g_gpio_device_name = "GPIOC"; @@ -40,23 +66,141 @@ static const struct device* g_gpio_dev; static const gpio_pin_t g_gpio_pin = 6; #endif +// ############################################################### +// UART +// ############################################################### +#define TVM_UART_DEFAULT_BAUDRATE 115200 +static const struct device* g_microtvm_uart; + +void UartInit(uint32_t baudrate = TVM_UART_DEFAULT_BAUDRATE) { + // Claim console device. 
+ g_microtvm_uart = DEVICE_DT_GET(DT_CHOSEN(zephyr_console)); + const struct uart_config config = {.baudrate = baudrate, + .parity = UART_CFG_PARITY_NONE, + .stop_bits = UART_CFG_STOP_BITS_1, + .data_bits = UART_CFG_DATA_BITS_8, + .flow_ctrl = UART_CFG_FLOW_CTRL_NONE}; + uart_configure(g_microtvm_uart, &config); +} + +char UartRxRead() { + unsigned char c; + int ret = -1; + while (ret != 0) { + ret = uart_poll_in(g_microtvm_uart, &c); + } + return (char)c; +} + +uint32_t UartTxWrite(const char* data, uint32_t size) { + for (uint32_t i = 0; i < size; i++) { + uart_poll_out(g_microtvm_uart, data[i]); + } + return size; +} + +// ############################################################### +// TVM +// ############################################################### +#ifdef __cplusplus +extern "C" { +#endif +// TODO(mehrdadh): remove and reuse the CRT +// implementation in src/runtime/crt/common/crt_backend_api.c +void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, int dtype_code_hint, + int dtype_bits_hint) { + tvm_crt_error_t err = kTvmErrorNoError; + void* ptr = 0; + DLDevice dev = {(DLDeviceType)device_type, device_id}; + assert(nbytes > 0); + err = TVMPlatformMemoryAllocate(nbytes, dev, &ptr); + CHECK_EQ(err, kTvmErrorNoError, + "TVMBackendAllocWorkspace(%d, %d, %" PRIu64 ", %d, %d) -> %" PRId32, device_type, + device_id, nbytes, dtype_code_hint, dtype_bits_hint, err); + return ptr; +} + +// TODO(mehrdadh): remove and reuse the CRT +// implementation in src/runtime/crt/common/crt_backend_api.c +int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { + tvm_crt_error_t err = kTvmErrorNoError; + DLDevice dev = {(DLDeviceType)device_type, device_id}; + err = TVMPlatformMemoryFree(ptr, dev); + CHECK_EQ(err, kTvmErrorNoError, "TVMBackendFreeWorkspace(%d, %d)", device_type, device_id); + return err; +} + +void TVMLogf(const char* msg, ...) 
{
+  char buffer[128];
+  int size;
+  va_list args;
+  va_start(args, msg);
+  size = TVMPlatformFormatMessage(buffer, 128, msg, args);
+  va_end(args);
+  UartTxWrite(buffer, (size_t)size);
+}
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+void Infer(void* input_ptr) {
+  struct tvmgen_default_inputs inputs = {
+#if TARGET_MODEL == MODEL_KWS
+      .input_1 = input_ptr,
+#elif TARGET_MODEL == MODEL_IC
+      .input_1_int8 = input_ptr,
+#elif TARGET_MODEL == MODEL_VWW
+      .input_1_int8 = input_ptr,
+#elif TARGET_MODEL == MODEL_AD
+      .input_1 = input_ptr,
+#else
+#error Wrong model.
+#endif
+  };
+
+  struct tvmgen_default_outputs outputs = {
+#if TARGET_MODEL == MODEL_KWS
+#if COMPILE_WITH_CMSISNN
+      .Identity = output_data,
+#else
+      .output = output_data,
+#endif
+#elif TARGET_MODEL == MODEL_IC
+      .Identity_int8 = output_data,
+#elif TARGET_MODEL == MODEL_VWW
+      .Identity_int8 = output_data,
+#elif TARGET_MODEL == MODEL_AD
+      .Identity = output_data,
+#endif
+  };
+
+  int ret_val = tvmgen_default_run(&inputs, &outputs);
+  if (ret_val != 0) {
+    th_printf("Error: %d\n", ret_val);
+  }
+}
+
+// ###############################################################
+// MLPerftiny APIs
+// ###############################################################
 // Implement this method to prepare for inference and preprocess inputs.
// Modified from source void th_load_tensor() { -#if TARGET_MODEL == 1 // KWS +#if TARGET_MODEL == MODEL_KWS g_input_data = static_cast(ee_get_buffer_pointer()); -#elif TARGET_MODEL == 2 // VWW +#elif TARGET_MODEL == MODEL_VWW // Converting uint8 to int8 int8_t* temp_int = reinterpret_cast(ee_get_buffer_pointer()); for (size_t i = 0; i < MAX_DB_INPUT_SIZE; i++) { temp_int[i] -= 128; } g_input_data = static_cast(temp_int); -#elif TARGET_MODEL == 3 // AD +#elif TARGET_MODEL == MODEL_AD uint8_t* buffer = ee_get_buffer_pointer(); memcpy(g_input_data_buffer_aligned, buffer, sizeof(g_input_data_buffer_aligned)); g_input_data = g_input_data_buffer_aligned; -#elif TARGET_MODEL == 4 // IC +#elif TARGET_MODEL == MODEL_IC uint8_t* temp_uint = reinterpret_cast(ee_get_buffer_pointer()); int8_t* temp_int = reinterpret_cast(ee_get_buffer_pointer()); for (size_t i = 0; i < MAX_DB_INPUT_SIZE; i++) { @@ -71,7 +215,7 @@ void th_load_tensor() { #endif } -#if TARGET_MODEL == 3 // model AD +#if TARGET_MODEL == MODEL_AD // calculate |output - input| for AD model static float calculate_result() { size_t feature_size = g_output_data_len; @@ -95,7 +239,7 @@ void th_results() { * The results need to be printed back in exactly this format; if easier * to just modify this loop than copy to results[] above, do that. */ -#if TARGET_MODEL == 3 // model AD +#if TARGET_MODEL == MODEL_AD th_printf("m-results-[%0.3f]\r\n", calculate_result()); #else size_t kCategoryCount = g_output_data_len; @@ -114,11 +258,11 @@ void th_results() { // Implement this method with the logic to perform one inference cycle. // Modified from source -void th_infer() { TVMInfer(g_input_data); } +void th_infer() { Infer(g_input_data); } /// \brief optional API. // Modified from source -void th_final_initialize(void) { TVMRuntimeInit(); } +void th_final_initialize(void) { TVMPlatformInitialize(); } void th_pre() {} void th_post() {} @@ -156,18 +300,18 @@ void th_printf(const char* p_fmt, ...) 
{ va_start(args, p_fmt); size = TVMPlatformFormatMessage(buffer, 128, p_fmt, args); va_end(args); - TVMPlatformWriteSerial(buffer, (size_t)size); + UartTxWrite(buffer, (size_t)size); } // Modified from source -char th_getchar() { return TVMPlatformUartRxRead(); } +char th_getchar() { return UartRxRead(); } // Modified from source void th_serialport_initialize(void) { #if EE_CFG_ENERGY_MODE == 1 && NRF_BOARD != 1 - TVMPlatformUARTInit(9600); + UartInit(9600); #else - TVMPlatformUARTInit(); + UartInit(); #endif } diff --git a/apps/microtvm/zephyr/template_project/src/mlperftiny/tvmruntime.cc b/apps/microtvm/zephyr/template_project/src/mlperftiny/tvmruntime.cc deleted file mode 100644 index 3fb7ccf8eb39f..0000000000000 --- a/apps/microtvm/zephyr/template_project/src/mlperftiny/tvmruntime.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "tvmruntime.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "output_data.h" -#include "tvmgen_default.h" -#include "zephyr_uart.h" - -#ifdef CONFIG_ARCH_POSIX -#include "posix_board_if.h" -#endif - -// OUT_QUANT_SCALE and OUT_QUANT_ZERO are set in python. 
-#if TARGET_MODEL == 3 -float* g_output_data = output_data; -#else -int8_t* g_output_data = output_data; -float g_quant_scale = OUT_QUANT_SCALE; -int8_t g_quant_zero = OUT_QUANT_ZERO; -#endif -size_t g_output_data_len = output_data_len; - -// WORKSPACE_SIZE is defined in python -static uint8_t g_aot_memory[WORKSPACE_SIZE]; -tvm_workspace_t app_workspace; - -size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, - va_list args) { - return vsnprintk(out_buf, out_buf_size_bytes, fmt, args); -} - -void TVMLogf(const char* msg, ...) { - char buffer[128]; - int size; - va_list args; - va_start(args, msg); - size = TVMPlatformFormatMessage(buffer, 128, msg, args); - va_end(args); - TVMPlatformWriteSerial(buffer, (size_t)size); -} - -void __attribute__((noreturn)) TVMPlatformAbort(tvm_crt_error_t error) { - TVMLogf("TVMPlatformAbort: %08x\n", error); - sys_reboot(SYS_REBOOT_COLD); - for (;;) - ; -} - -tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { - return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr); -} - -tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { - return StackMemoryManager_Free(&app_workspace, ptr); -} - -void timer_expiry_function(struct k_timer* timer_id) { return; } - -#ifdef __cplusplus -extern "C" { -#endif -void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, int dtype_code_hint, - int dtype_bits_hint) { - tvm_crt_error_t err = kTvmErrorNoError; - void* ptr = 0; - DLDevice dev = {(DLDeviceType)device_type, device_id}; - assert(nbytes > 0); - err = TVMPlatformMemoryAllocate(nbytes, dev, &ptr); - CHECK_EQ(err, kTvmErrorNoError, - "TVMBackendAllocWorkspace(%d, %d, %" PRIu64 ", %d, %d) -> %" PRId32, device_type, - device_id, nbytes, dtype_code_hint, dtype_bits_hint, err); - return ptr; -} - -int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { - tvm_crt_error_t err = kTvmErrorNoError; - DLDevice dev = 
{(DLDeviceType)device_type, device_id}; - err = TVMPlatformMemoryFree(ptr, dev); - CHECK_EQ(err, kTvmErrorNoError, "TVMBackendFreeWorkspace(%d, %d)", device_type, device_id); - return err; -} - -#ifdef __cplusplus -} // extern "C" -#endif - -void TVMRuntimeInit() { StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE); } - -void TVMInfer(void* input_ptr) { - struct tvmgen_default_inputs inputs = { -#if TARGET_MODEL == MODEL_KWS - .input_1 = input_ptr, -#elif TARGET_MODEL == MODEL_IC - .input_1_int8 = input_ptr, -#elif TARGET_MODEL == MODEL_VWW - .input_1_int8 = input_ptr, -#elif TARGET_MODEL == MODEL_AD - .input_1 = input_ptr, -#elif -#error Wrong model. -#endif - }; - - struct tvmgen_default_outputs outputs = { -#if TARGET_MODEL == MODEL_KWS -#if COMPILE_WITH_CMSISNN - .Identity = output_data, -#else - .output = output_data, -#endif -#elif TARGET_MODEL == MODEL_IC - .Identity_int8 = output_data, -#elif TARGET_MODEL == MODEL_VWW - .Identity_int8 = output_data, -#elif TARGET_MODEL == MODEL_AD - .Identity = output_data, -#endif - }; - - int ret_val = tvmgen_default_run(&inputs, &outputs); - if (ret_val != 0) { - TVMLogf("Error: %d\n", ret_val); - } -} - -int8_t QuantizeFloatToInt8(float value, float scale, int zero_point) { - int32_t result = round(value / scale) + zero_point; - if (result < INT8_MIN) { - result = INT8_MIN; - } - if (result > INT8_MAX) { - result = INT8_MAX; - } - return (int8_t)(result); -} diff --git a/apps/microtvm/zephyr/template_project/src/mlperftiny/tvmruntime.h b/apps/microtvm/zephyr/template_project/src/mlperftiny/tvmruntime.h deleted file mode 100644 index 940d64634d592..0000000000000 --- a/apps/microtvm/zephyr/template_project/src/mlperftiny/tvmruntime.h +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef APPS_MICROTVM_ZEPHYR_TEMPLATE_PROJECT_SRC_MLPERFTINY_TVMRUNTIME_H_ -#define APPS_MICROTVM_ZEPHYR_TEMPLATE_PROJECT_SRC_MLPERFTINY_TVMRUNTIME_H_ - -#include -#include -#include - -#define MODEL_KWS 1 -#define MODEL_VWW 2 -#define MODEL_AD 3 -#define MODEL_IC 4 - -extern const unsigned char g_wakeup_sequence[]; -extern size_t g_output_data_len; - -#if TARGET_MODEL == 3 -extern float* g_output_data; -#else -extern int8_t* g_output_data; -#endif - -extern float g_quant_scale; -extern int8_t g_quant_zero; - -/*! - * \brief Initialize TVM runtime. - */ -void TVMRuntimeInit(); - -/*! - * \brief Run TVM inference. - */ -void TVMInfer(void* input_ptr); - -/*! - * \brief Quantize float to int8. - * \param value Input data in float. - * \param scale Quantization scale factor. - * \param zero_point Quantization zero point. 
- */ -int8_t QuantizeFloatToInt8(float value, float scale, int zero_point); - -#endif /* APPS_MICROTVM_ZEPHYR_TEMPLATE_PROJECT_SRC_MLPERFTINY_TVMRUNTIME_H_ */ diff --git a/apps/microtvm/zephyr/template_project/src/mlperftiny/zephyr_uart.cc b/apps/microtvm/zephyr/template_project/src/mlperftiny/zephyr_uart.cc deleted file mode 100644 index 0922c32133634..0000000000000 --- a/apps/microtvm/zephyr/template_project/src/mlperftiny/zephyr_uart.cc +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "zephyr_uart.h" - -#include -#include -#include - -#include "crt_config.h" - -static const struct device* g_microtvm_uart; - -static uint8_t uart_data[8]; - -// UART interrupt callback. -void uart_irq_cb(const struct device* dev, void* user_data) { - while (uart_irq_update(dev) && uart_irq_is_pending(dev)) { - struct ring_buf* rbuf = (struct ring_buf*)user_data; - if (uart_irq_rx_ready(dev) != 0) { - for (;;) { - // Read a small chunk of data from the UART. - int bytes_read = uart_fifo_read(dev, uart_data, sizeof(uart_data)); - if (bytes_read < 0) { - TVMPlatformAbort((tvm_crt_error_t)(0xbeef1)); - } else if (bytes_read == 0) { - break; - } - // Write it into the ring buffer. 
- int bytes_written = ring_buf_put(rbuf, uart_data, bytes_read); - if (bytes_read != bytes_written) { - TVMPlatformAbort((tvm_crt_error_t)(0xbeef2)); - } - } - } - } -} - -// Initialize the UART receiver. -void uart_rx_init(struct ring_buf* rbuf, const struct device* dev) { - uart_irq_callback_user_data_set(dev, uart_irq_cb, (void*)rbuf); - uart_irq_rx_enable(dev); -} - -// UART read. -char TVMPlatformUartRxRead() { - unsigned char c; - int ret = -1; - while (ret != 0) { - ret = uart_poll_in(g_microtvm_uart, &c); - } - return (char)c; -} - -// UART write. -uint32_t TVMPlatformWriteSerial(const char* data, uint32_t size) { - for (uint32_t i = 0; i < size; i++) { - uart_poll_out(g_microtvm_uart, data[i]); - } - return size; -} - -// Initialize UART. -void TVMPlatformUARTInit(uint32_t baudrate /* = TVM_UART_DEFAULT_BAUDRATE */) { - // Claim console device. - g_microtvm_uart = DEVICE_DT_GET(DT_CHOSEN(zephyr_console)); - const struct uart_config config = {.baudrate = baudrate, - .parity = UART_CFG_PARITY_NONE, - .stop_bits = UART_CFG_STOP_BITS_1, - .data_bits = UART_CFG_DATA_BITS_8, - .flow_ctrl = UART_CFG_FLOW_CTRL_NONE}; - uart_configure(g_microtvm_uart, &config); -} diff --git a/apps/microtvm/zephyr/template_project/src/mlperftiny/zephyr_uart.h b/apps/microtvm/zephyr/template_project/src/mlperftiny/zephyr_uart.h deleted file mode 100644 index f10cf02622246..0000000000000 --- a/apps/microtvm/zephyr/template_project/src/mlperftiny/zephyr_uart.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef APPS_MICROTVM_ZEPHYR_TEMPLATE_PROJECT_SRC_MLPERFTINY_ZEPHYR_UART_H_ -#define APPS_MICROTVM_ZEPHYR_TEMPLATE_PROJECT_SRC_MLPERFTINY_ZEPHYR_UART_H_ - -#include - -#define TVM_UART_DEFAULT_BAUDRATE 115200 - -/*! - * \brief Read Uart Rx buffer. - * \param data Pointer to read data. - * \param data_size_bytes Read request size in bytes. - * - * \return Number of data read in bytes. - */ -char TVMPlatformUartRxRead(); - -/*! - * \brief Write data in serial. - * \param data Pointer to data to write. - * \param size Size of data in bytes. - * - * \return Number of write in bytes. - */ -uint32_t TVMPlatformWriteSerial(const char* data, uint32_t size); - -/*! - * \brief Initialize Uart. - * \param baudrate Desired UART baudrate. 
- */ -void TVMPlatformUARTInit(uint32_t baudrate = TVM_UART_DEFAULT_BAUDRATE); - -#endif /* APPS_MICROTVM_ZEPHYR_TEMPLATE_PROJECT_SRC_MLPERFTINY_ZEPHYR_UART_H_ */ diff --git a/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy b/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy index 479a4d2f8d630..4c748e3f20d78 100644 --- a/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy +++ b/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-02T20:12:16.769381 +// Generated at 2023-02-07T23:01:16.071376 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -554,7 +554,7 @@ def build() { ) cmake_build(ci_minimal, 'build', '-j2') sh( - script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/cpu-minimal-cross-isa --items build/libtvm.so build/libtvm_runtime.so build/config.cmake build/libtvm_allvisible.so", + script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/cpu-minimal-cross-isa --items build/libtvm.so build/libtvm_runtime.so build/config.cmake build/libtvm_allvisible.so build/standalone_crt build/build.ninja build/microtvm_template_projects", label: 'Upload artifacts to S3', ) } diff --git a/ci/jenkins/templates/minimal_cross_isa_jenkinsfile.groovy.j2 b/ci/jenkins/templates/minimal_cross_isa_jenkinsfile.groovy.j2 index f418b2a08ec4b..dce5ead041ac2 100644 --- a/ci/jenkins/templates/minimal_cross_isa_jenkinsfile.groovy.j2 +++ b/ci/jenkins/templates/minimal_cross_isa_jenkinsfile.groovy.j2 @@ -29,7 +29,7 @@ label: 'Create CPU minimal cmake config', ) cmake_build(ci_minimal, 'build', '-j2') - {{ m.upload_artifacts(tag='cpu-minimal-cross-isa', filenames=tvm_lib + tvm_allvisible) }} 
+ {{ m.upload_artifacts(tag='cpu-minimal-cross-isa', filenames=tvm_lib + tvm_allvisible + standalone_crt + microtvm_template_projects) }} {% endcall %} diff --git a/cmake/modules/CRT.cmake b/cmake/modules/CRT.cmake index 518a613dc1022..3ddbb5298f847 100644 --- a/cmake/modules/CRT.cmake +++ b/cmake/modules/CRT.cmake @@ -22,8 +22,9 @@ if(USE_MICRO) CRT_TEMPLATE_FILE_COPY_JOBS "src/runtime/crt/host microtvm_api_server.py -> crt" "src/runtime/crt/host Makefile.template -> crt" - "src/runtime/crt crt_config-template.h -> crt" "src/runtime/crt/host main.cc -> crt/src" + "src/runtime/crt/host platform.cc -> crt/src" + "src/runtime/crt crt_config-template.h -> crt/crt_config" ) foreach(job_spec IN LISTS CRT_TEMPLATE_FILE_COPY_JOBS) @@ -66,6 +67,10 @@ if(USE_MICRO) endforeach() endforeach() + # Add template files for Model Library Format + configure_file("src/runtime/crt/crt_config-template.h" "${MICROTVM_TEMPLATE_PROJECTS}/crt/templates/crt_config.h.template" COPYONLY) + configure_file("src/runtime/crt/platform-template.c" "${MICROTVM_TEMPLATE_PROJECTS}/crt/templates/platform.c.template" COPYONLY) + add_custom_target(crt DEPENDS ${crt_template_deps}) endfunction() diff --git a/cmake/modules/Zephyr.cmake b/cmake/modules/Zephyr.cmake index a13aef33195f1..38551f1dd44d8 100644 --- a/cmake/modules/Zephyr.cmake +++ b/cmake/modules/Zephyr.cmake @@ -26,11 +26,9 @@ if(USE_MICRO) "apps/microtvm/zephyr/template_project boards.json -> zephyr" "apps/microtvm/zephyr/template_project CMakeLists.txt.template -> zephyr" "apps/microtvm/zephyr/template_project/src/aot_standalone_demo *.c -> zephyr/src/aot_standalone_demo" - "apps/microtvm/zephyr/template_project/src/aot_standalone_demo *.h -> zephyr/src/aot_standalone_demo" "apps/microtvm/zephyr/template_project/src/host_driven *.c -> zephyr/src/host_driven" "apps/microtvm/zephyr/template_project/src/host_driven *.h -> zephyr/src/host_driven" "apps/microtvm/zephyr/template_project/src/mlperftiny *.cc -> zephyr/src/mlperftiny" - 
"apps/microtvm/zephyr/template_project/src/mlperftiny *.h -> zephyr/src/mlperftiny" "3rdparty/mlperftiny/api * -> zephyr/src/mlperftiny/api" "apps/microtvm/zephyr/template_project/fvp-hack * -> zephyr/fvp-hack" "apps/microtvm/zephyr/template_project/qemu-hack * -> zephyr/qemu-hack" diff --git a/gallery/how_to/work_with_microtvm/micro_mlperftiny.py b/gallery/how_to/work_with_microtvm/micro_mlperftiny.py index e8c6a253ad2bc..bb7abddddce61 100644 --- a/gallery/how_to/work_with_microtvm/micro_mlperftiny.py +++ b/gallery/how_to/work_with_microtvm/micro_mlperftiny.py @@ -226,7 +226,7 @@ shape=output_shape, dtype=output_dtype, ), - "include", + "include/tvm", tf, ) diff --git a/gallery/how_to/work_with_microtvm/micro_pytorch.py b/gallery/how_to/work_with_microtvm/micro_pytorch.py index a7f5f10280476..a0f4ebddee485 100644 --- a/gallery/how_to/work_with_microtvm/micro_pytorch.py +++ b/gallery/how_to/work_with_microtvm/micro_pytorch.py @@ -131,7 +131,7 @@ # template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects("crt")) -project_options = {"verbose": False, "memory_size_bytes": 6 * 1024 * 1024} +project_options = {"verbose": False, "workspace_size_bytes": 6 * 1024 * 1024} temp_dir = tvm.contrib.utils.tempdir() / "project" project = tvm.micro.generate_project( diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 274a421e57195..82403d7c40ee2 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -168,8 +168,9 @@ struct ScatterNDAttrs : public tvm::AttrsNode { String mode; TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") { - TVM_ATTR_FIELD(mode).describe( - "Accumulation mode of the scatter, either \"update\" or \"add\"."); + TVM_ATTR_FIELD(mode).set_default("update").describe( + "Accumulation mode of the ScatterND, " + "either \"update\", \"add\", \"mul\", \"min\" or \"max\"."); } }; diff --git a/include/tvm/runtime/crt/platform.h 
b/include/tvm/runtime/crt/platform.h index 1bc610e6cc538..85121fd0f520e 100644 --- a/include/tvm/runtime/crt/platform.h +++ b/include/tvm/runtime/crt/platform.h @@ -139,6 +139,15 @@ tvm_crt_error_t TVMPlatformAfterMeasurement(); */ tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes); +/*! \brief Initialize TVM inference. + * + * Placeholder function for TVM inference initializations on a specific platform. + * A common use of this function is setting up workspace memory for TVM inference. + * + * \return kTvmErrorNoError if successful. + */ +tvm_crt_error_t TVMPlatformInitialize(); + #ifdef __cplusplus } // extern "C" #endif diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index b6a4cfe453c14..119d0f7fd3395 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -42,14 +42,6 @@ namespace tvm { // alias DLDevice using Device = DLDevice; -// A 'null' device type, does not correspond to any DLDeviceType enum. -// TODO(mbs): This is to help us as we transition away from representing the 'homogenous' case -// as a singleton target map indexed by the invalid DLDeviceType '0'. -constexpr DLDeviceType kNullDeviceType = static_cast(0); - -// An 'invalid' device type, does not correspond to any DLDeviceType enum. -constexpr DLDeviceType kInvalidDeviceType = static_cast(-1); - namespace runtime { /*! diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 5cba879205807..d5cc1de5c675d 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -415,12 +415,12 @@ void Prefetch(Buffer buffer, Array bounds); void Evaluate(PrimExpr value); /*! - * \brief The pointer declaration function. + * \brief Create a TIR var that represents a pointer * \param dtype The data type of the pointer. * \param storage_scope The storage scope of the pointer. * \return The pointer. 
*/ -PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global"); +Var Handle(runtime::DataType dtype = runtime::DataType::Void(), String storage_scope = "global"); #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ inline PrimExpr FuncName(Optional expr = NullOpt) { \ @@ -455,7 +455,6 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle()); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); #undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index c26ae5befe66a..9d8c91403309a 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -44,6 +44,16 @@ namespace tvm { */ using MemoryScope = String; +// NOTE: cannot use enum as they are out of bound of the original enum +// and results in an undefined behavior +// A 'null' device type, does not correspond to any DLDeviceType enum. +// TODO(mbs): This is to help us as we transition away from representing the 'homogenous' case +// as a singleton target map indexed by the invalid DLDeviceType '0'. +constexpr int kNullDeviceType = 0; + +// An 'invalid' device type, does not correspond to any DLDeviceType enum. +constexpr int kInvalidDeviceType = -1; + /*! * \brief Describes at compile time the constraints on where data is to be stored at runtime * down to the (virtual) device and memory scope level, and how to compile code to compute that @@ -229,7 +239,7 @@ class VirtualDeviceNode : public AttrsNode { * Physical Devices" above. 
*/ Device ToDevice() const { - ICHECK(device_type() != kInvalidDeviceType); + ICHECK(device_type_int != kInvalidDeviceType); ICHECK(virtual_device_id != -1); Device device; device.device_type = device_type(); @@ -262,7 +272,7 @@ class VirtualDevice : public ObjectRef { public: /*! * \brief Construct a virtual device. - * \param device_type The device type for the virtual device, or \p kInvalidDeviceType if + * \param device_type_int The device type for the virtual device, or \p kInvalidDeviceType if * unconstrained. If \p target is defined then must match its \p target->GetTargetDeviceType(). * \param virtual_device_id The device id for the virtual device, or -1 if unconstrained. * \param target The target describing how to compile for the virtual device, or null if @@ -271,7 +281,7 @@ class VirtualDevice : public ObjectRef { * unconstrained. * \return The virtual device. */ - explicit VirtualDevice(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1, + explicit VirtualDevice(int device_type_int = kInvalidDeviceType, int virtual_device_id = -1, Target target = {}, MemoryScope memory_scope = {}); /*! \brief Returns the unique fully unconstrained \p VirtualDevice. */ @@ -349,7 +359,7 @@ class VirtualDevice : public ObjectRef { class VirtualDeviceCache { public: /*! \brief Returns the unique \p VirtualDevice representing given fields. */ - VirtualDevice Make(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1, + VirtualDevice Make(int device_type = kInvalidDeviceType, int virtual_device_id = -1, Target target = {}, MemoryScope memory_scope = {}); /*! 
diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py index 9bdf715756cc5..b249be7ded749 100644 --- a/python/tvm/meta_schedule/runner/rpc_runner.py +++ b/python/tvm/meta_schedule/runner/rpc_runner.py @@ -27,7 +27,6 @@ from ..logging import get_logger from ..profiler import Profiler from ..utils import ( - cpu_count, derived_object, get_global_func_on_rpc_session, get_global_func_with_default_on_worker, @@ -270,7 +269,7 @@ def __init__( f_cleanup: Union[T_CLEANUP, str, None] The function name to cleanup the session or the function itself. max_workers: Optional[int] = None - The maximum number of connections. Defaults to number of logical CPU cores. + The maximum number of connections. Defaults to 1. initializer: Optional[Callable[[], None]] The initializer function. """ @@ -285,11 +284,10 @@ def __init__( self.f_run_evaluator = f_run_evaluator self.f_cleanup = f_cleanup if max_workers is None: - max_workers = cpu_count(logical=True) + max_workers = 1 logger.info("RPCRunner: max_workers = %d", max_workers) self.pool = PopenPoolExecutor( max_workers=max_workers, - timeout=rpc_config.session_timeout_sec, initializer=initializer, ) self._sanity_check() diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index fc32fe34d6c91..b16877b256456 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -26,7 +26,7 @@ import typing import tvm -from tvm.micro import get_standalone_crt_dir +from tvm.micro import get_standalone_crt_dir, get_microtvm_template_projects from .._ffi import get_global_func from ..contrib import utils @@ -39,6 +39,7 @@ # This should be kept identical to runtime::symbol::tvm_module_main MAIN_FUNC_NAME_STR = "__tvm_main__" STANDALONE_CRT_URL = "./runtime" +CRT_TEMPLATE_FILES_URL = "./templates" METADATA_FILE = "metadata.json" @@ -373,7 +374,18 @@ def reset(tarinfo): for mod in modules: is_aot = isinstance(mod, 
executor_factory.AOTExecutorFactoryModule) if is_aot and str(mod.runtime) == "crt": + crt_template_path = pathlib.Path(get_microtvm_template_projects("crt")) tar_f.add(get_standalone_crt_dir(), arcname=STANDALONE_CRT_URL) + + # Add template files from CRT template project + for file in [ + "templates/crt_config.h.template", + "templates/platform.c.template", + ]: + tar_f.add( + crt_template_path / pathlib.Path(file), + arcname=f"{CRT_TEMPLATE_FILES_URL}/{pathlib.Path(file).name}", + ) break diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 8de5e0e08bd8c..aebc6daa5ebe6 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2856,12 +2856,63 @@ def _impl_v1(cls, inputs, attr, params): class ScatterND(OnnxOpConverter): """Operator converter for ScatterND.""" + @classmethod + def _inputs_check(cls, inputs): + assert ( + len(inputs) == 3 + ), "ScatterND takes 3 inputs (data, indices, updates), {} given".format(len(inputs)) + assert infer_type(inputs[1]).checked_type.dtype == "int64" + + data_rank = len(infer_shape(inputs[0])) + assert data_rank > 0, "Data rank higher than 0 is expected" + indices_rank = len(infer_shape(inputs[1])) + assert indices_rank > 0, "Indices rank higher than 0 is expected" + updates_rank = len(infer_shape(inputs[2])) + assert ( + updates_rank == data_rank + indices_rank - infer_shape(inputs[1])[-1] - 1 + ), "Updates rank should be equal to data_rank + indices_rank - indices_shape[-1] - 1" + + @classmethod + def _reduction_check(cls, attr, red_valids=None): + reduction = attr.get("reduction", None) + if reduction is None: + reduction = b"update" + reduction = reduction.decode("utf-8") + if red_valids is None: + red_valids = ["update"] + assert reduction in red_valids, "Only {} reductions are supported, but {} is gotten".format( + red_valids, reduction + ) + + return reduction + @classmethod def _impl_v11(cls, inputs, attr, params): + cls._inputs_check(inputs) + indices_dim 
= len(infer_shape(inputs[1])) + axes = list(range(indices_dim)) + return _op.scatter_nd(inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2]) + + @classmethod + def _impl_v16(cls, inputs, attr, params): + cls._inputs_check(inputs) + reduction = cls._reduction_check(attr, ["update", "add", "mul"]) + + indices_dim = len(infer_shape(inputs[1])) + axes = list(range(indices_dim)) + return _op.scatter_nd( + inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], reduction + ) + + @classmethod + def _impl_v18(cls, inputs, attr, params): + cls._inputs_check(inputs) + reduction = cls._reduction_check(attr, ["update", "add", "mul", "min", "max"]) + indices_dim = len(infer_shape(inputs[1])) axes = list(range(indices_dim)) return _op.scatter_nd( - inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], "update" + inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], reduction ) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index c7234f3403955..782797dadb830 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -420,7 +420,13 @@ def scatter_nd(data, indices, updates, mode="update"): The values to update. mode : string, optional - The accumulation mode for scatter. "update" or "add" + The accumulation mode for scatter. 
"update", "add", "mul", "min" or "max" + If update, the update values will replace the input data + If add, the update values will be added to the input data + If mul, the update values will be multiply to the input data + If min, there is choice of minimal between the update values and the input data + If max, there is choice of maximal between the update values and the input data + It is "update" by default Returns ------- diff --git a/python/tvm/script/ir_builder/tir/__init__.py b/python/tvm/script/ir_builder/tir/__init__.py index 563ac56f7b109..db2fc6aca0958 100644 --- a/python/tvm/script/ir_builder/tir/__init__.py +++ b/python/tvm/script/ir_builder/tir/__init__.py @@ -17,4 +17,4 @@ """Package tvm.script.ir_builder.tir""" from .ir import * # pylint: disable=wildcard-import,redefined-builtin from .ir import boolean as bool # pylint: disable=redefined-builtin -from .ir import buffer_decl as Buffer +from .ir import buffer as Buffer diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index fdb27df2a9d1c..5f4e9d4f2cf0b 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -86,7 +86,7 @@ # pylint: enable=unused-import -def buffer_decl( +def buffer( shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], dtype: str = "float32", data: Var = None, @@ -138,7 +138,7 @@ def buffer_decl( The declared buffer. """ shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape - return _ffi_api.BufferDecl( # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint: disable=no-member shape, dtype, "", @@ -153,6 +153,11 @@ def buffer_decl( ) +@deprecated("T.buffer_decl(...)", "T.Buffer(...)") +def buffer_decl(*args, **kwargs): + return buffer(*args, **kwargs) + + def prim_func() -> frame.PrimFuncFrame: """The primitive function statement. 
@@ -1177,7 +1182,11 @@ def env_thread(thread_tag: str) -> IterVar: return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member -def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, slice]]) -> None: +def buffer_store( + buffer: Buffer, # pylint: disable=redefined-outer-name + value: PrimExpr, + indices: List[Union[PrimExpr, slice]], +) -> None: """Buffer store node. Parameters @@ -1211,7 +1220,10 @@ def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr, ) -def prefetch(buffer: Buffer, bounds: List[Range]) -> None: +def prefetch( + buffer: Buffer, # pylint: disable=redefined-outer-name + bounds: List[Range], +) -> None: """The prefetch hint for a buffer. Parameters @@ -1358,20 +1370,23 @@ def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr: return _ffi_api.Boolean(expr) # type: ignore[attr-defined] # pylint: disable=no-member -def handle(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type handle or cast expression to type handle. +def handle(dtype: str = "void", storage_scope: str = "global") -> Var: + """Create a TIR var that represents a pointer. Parameters ---------- - expr: PrimExpr - The expression to be cast. + dtype: str + The data type of the pointer. + + storage_scope: str + The storage scope of the pointer. Returns ------- res : PrimExpr The new tir.Var with type handle or casted expression with type handle. """ - return _ffi_api.Handle(expr) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Handle(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member def void(expr: Optional[PrimExpr] = None) -> PrimExpr: @@ -1390,6 +1405,7 @@ def void(expr: Optional[PrimExpr] = None) -> PrimExpr: return _ffi_api.Void(expr) # type: ignore[attr-defined] # pylint: disable=no-member +@deprecated("T.var", "T.{dtype}") def var(dtype: str, name: str = "") -> Var: """Construct a new tir.Var. 
@@ -1428,7 +1444,7 @@ def ptr(dtype: str, storage_scope: str = "global") -> Var: return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member -@deprecated("T.buffer_var", "T.Ptr") +@deprecated("T.buffer_var", "T.handle") def buffer_var(dtype: str, storage_scope: str = "global") -> Var: """The pointer declaration function. @@ -1811,6 +1827,7 @@ def wrapped(*args, **kwargs): "float16x64", "float32x64", "float64x64", + "buffer", "buffer_decl", "prim_func", "arg", diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index bf6a118672df6..9e6c100c954d8 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -41,10 +41,12 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) The parsed TVMScript program. """ if extra_vars is None: + import tvm # pylint: disable=import-outside-toplevel from tvm.script.parser import ir # pylint: disable=import-outside-toplevel from tvm.script.parser import tir # pylint: disable=import-outside-toplevel extra_vars = { + "tvm": tvm, "I": ir, "ir": ir, "T": tir, diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index 9532e7e32c005..e0268412d284e 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -64,3 +64,14 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None: node : doc.ClassDef The doc AST expression node. 
""" + + +@dispatch.register(token="default", type_name="Assign") +def visit_assign(self: Parser, node: doc.Assign) -> None: + if len(node.targets) != 1: + self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.") + lhs = node.targets[0] + rhs = self.eval_expr(node.value) + self.eval_assign( + target=lhs, source=rhs, bind_value=lambda _a, _b, _c, value: value, allow_shadowing=True + ) diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index bacf92c14287e..411a7f8f3c83d 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -21,7 +21,7 @@ from tvm.ir.base import deprecated from tvm.tir import Buffer, PrimFunc -from ...ir_builder.tir import buffer_decl, ptr +from ...ir_builder.tir import buffer, ptr from .._core import parse, utils @@ -49,9 +49,7 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]: class BufferProxy: - """Buffer proxy class for constructing tir buffer. - Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer(). - """ + """Buffer proxy class for constructing tir buffer.""" def __call__( self, @@ -66,7 +64,7 @@ def __call__( buffer_type="", axis_separators=None, ) -> Buffer: - return buffer_decl( + return buffer( shape, dtype=dtype, data=data, @@ -79,7 +77,7 @@ def __call__( axis_separators=axis_separators, ) - @deprecated("T.Buffer(...)", "T.Buffer(...)") + @deprecated("T.Buffer[...]", "T.Buffer(...)") def __getitem__(self, keys) -> Buffer: if not isinstance(keys, tuple): return self(keys) @@ -89,16 +87,15 @@ def __getitem__(self, keys) -> Buffer: class PtrProxy: - """Ptr proxy class for constructing tir pointer. - Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr(). 
- """ + """Ptr proxy class for constructing tir pointer.""" + @deprecated("T.Ptr(...)", "T.handle(...)") def __call__(self, dtype, storage_scope="global"): if callable(dtype): dtype = dtype().dtype return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore - @deprecated("T.Ptr(...)", "T.Ptr(...)") + @deprecated("T.Ptr[...]", "T.handle(...)") def __getitem__(self, keys): if not isinstance(keys, tuple): return self(keys) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 0e74114ba29cf..fbef1a969179f 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -143,7 +143,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - IRBuilder.name(var_name, value) return value elif isinstance(value, PrimExpr): - var = T.var(value.dtype) + var = Var("", value.dtype) IRBuilder.name(var_name, var) frame = T.let(var, value) frame.add_callback(partial(frame.__exit__, None, None, None)) diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index 5ddbdcabacc9a..a975eb95bcf01 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -725,8 +725,8 @@ def run_and_check_body(base_path): os.mkdir(include_path) crt_root = tvm.micro.get_microtvm_template_projects("crt") shutil.copy2( - os.path.join(crt_root, "crt_config-template.h"), - os.path.join(include_path, "crt_config.h"), + pathlib.Path(crt_root) / "crt_config" / "crt_config-template.h", + pathlib.Path(include_path) / "crt_config.h", ) workspace_bytes = 0 diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 0703811ea79f5..6483b99454a36 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -146,8 +146,8 @@ def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None: @T.prim_func def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: - s0 = T.var("int32") - 
s1 = T.var("int32") + s0 = T.int32() + s1 = T.int32() shared = T.match_buffer( shared_handle, shmem_shape, @@ -385,8 +385,8 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None: @T.prim_func def mma_store_impl(a: T.handle, c: T.handle) -> None: - s0 = T.var("int32") - s1 = T.var("int32") + s0 = T.int32() + s1 = T.int32() C_warp = T.match_buffer( a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 @@ -530,10 +530,10 @@ def wmma_load_desc(a: T.handle, c: T.handle) -> None: @T.prim_func def wmma_load_impl(a: T.handle, c: T.handle) -> None: - s1 = T.var("int32") - s0 = T.var("int32") - d1 = T.var("int32") - d0 = T.var("int32") + s1 = T.int32() + s0 = T.int32() + d1 = T.int32() + d0 = T.int32() A = T.match_buffer( a, (m_dim, n_dim), @@ -593,8 +593,8 @@ def wmma_fill_desc(c: T.handle) -> None: @T.prim_func def wmma_fill_impl(c: T.handle) -> None: - d1 = T.var("int32") - d0 = T.var("int32") + d1 = T.int32() + d0 = T.int32() C = T.match_buffer( c, (m_dim, n_dim), @@ -643,10 +643,10 @@ def wmma_store_desc(a: T.handle, c: T.handle) -> None: @T.prim_func def wmma_store_impl(a: T.handle, c: T.handle) -> None: - s1 = T.var("int32") - s0 = T.var("int32") - d1 = T.var("int32") - d0 = T.var("int32") + s1 = T.int32() + s0 = T.int32() + d1 = T.int32() + d0 = T.int32() A = T.match_buffer( a, (m_dim, n_dim), @@ -726,12 +726,12 @@ def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: @T.prim_func def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: - a1 = T.var("int32") - a0 = T.var("int32") - b1 = T.var("int32") - b0 = T.var("int32") - c1 = T.var("int32") - c0 = T.var("int32") + a1 = T.int32() + a0 = T.int32() + b1 = T.int32() + b0 = T.int32() + c1 = T.int32() + c0 = T.int32() A = T.match_buffer( a, diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index fa7545cd323a4..1bdd531566230 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -17,7 +17,7 @@ # pylint: 
disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument """Scatter operator """ import tvm -from tvm import te, autotvm +from tvm import te, tir, autotvm from ..scatter import _verify_scatter_nd_inputs from ..generic import schedule_extern from .nms import atomic_add @@ -871,8 +871,20 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): out[index] = updates[i * fused_updates_dimension + j] elif mode == "add": out[index] += updates[i * fused_updates_dimension + j] + elif mode == "mul": + out[index] *= updates[i * fused_updates_dimension + j] + elif mode == "min": + out[index] = tir.min( + out[index], updates[i * fused_updates_dimension + j] + ) + elif mode == "max": + out[index] = tir.max( + out[index], updates[i * fused_updates_dimension + j] + ) else: - raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) + raise NotImplementedError( + "scatter_nd mode not in [update, add, mul, min, max]:", mode + ) return ib.get() diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index e0578aab41b9b..45629c005f799 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -16,11 +16,11 @@ # under the License. 
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks """Scatter operator""" -from ..te import extern, hybrid -from ..tir import decl_buffer, expr, ir_builder +from tvm import te, tir # hide redefinition of min and max +from tvm.tir import expr -@hybrid.script +@te.hybrid.script def _scatter_1d(data, indices, updates): out = output_tensor(data.shape, data.dtype) for i in range(data.shape[0]): @@ -30,7 +30,7 @@ def _scatter_1d(data, indices, updates): return out -@hybrid.script +@te.hybrid.script def _scatter_2d(data, indices, updates, axis): out = output_tensor(data.shape, data.dtype) for i in range(data.shape[0]): @@ -52,7 +52,7 @@ def _scatter_2d(data, indices, updates, axis): return out -@hybrid.script +@te.hybrid.script def _scatter_3d(data, indices, updates, axis): out = output_tensor(data.shape, data.dtype) for i in range(data.shape[0]): @@ -96,7 +96,7 @@ def _scatter_3d(data, indices, updates, axis): return out -@hybrid.script +@te.hybrid.script def _scatter_4d(data, indices, updates, axis): out = output_tensor(data.shape, data.dtype) for i in range(data.shape[0]): @@ -269,7 +269,7 @@ def scatter_nd(data, indices, updates, mode): def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # pylint: disable=invalid-name - ib = ir_builder.create() + ib = tir.ir_builder.create() data = ib.buffer_ptr(data_ptr) indices = ib.buffer_ptr(indices_ptr) @@ -308,13 +308,21 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): out[index] = updates[i * fused_updates_dimension + j] elif mode == "add": out[index] += updates[i * fused_updates_dimension + j] + elif mode == "mul": + out[index] *= updates[i * fused_updates_dimension + j] + elif mode == "min": + out[index] = tir.min(out[index], updates[i * fused_updates_dimension + j]) + elif mode == "max": + out[index] = tir.max(out[index], updates[i * fused_updates_dimension + j]) else: - raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) + raise NotImplementedError( + "scatter_nd 
mode not in [update, add, mul, min, max]:", mode + ) return ib.get() - out_buf = decl_buffer(data.shape, data.dtype, "out_buf") - return extern( + out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") + return te.extern( [data.shape], [data, indices, updates], lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), diff --git a/python/tvm/utils/roofline/cuda.py b/python/tvm/utils/roofline/cuda.py index 5d80c808801b6..b83a902b7fda4 100644 --- a/python/tvm/utils/roofline/cuda.py +++ b/python/tvm/utils/roofline/cuda.py @@ -299,7 +299,7 @@ def estimate_peak_flops( @T.prim_func def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: T.int32) -> None: # pylint: disable=invalid-name, missing-function-docstring - N = T.var("int32") + N = T.int32() A = T.match_buffer(a, [blocks, N, 4, warp_size], "float32") B = T.match_buffer(b, [blocks, 4, warp_size], "float32") for i in T.thread_binding(blocks, "blockIdx.x"): diff --git a/python/tvm/utils/roofline/x86.py b/python/tvm/utils/roofline/x86.py index 37a666d2527ad..5d2dd27e523b3 100644 --- a/python/tvm/utils/roofline/x86.py +++ b/python/tvm/utils/roofline/x86.py @@ -216,7 +216,7 @@ def estimate_peak_fma_flops( @T.prim_func def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: T.int32) -> None: # pylint: disable=invalid-name, missing-function-docstring - N = T.var("int32") + N = T.int32() A = T.match_buffer(a, [threads, N, 4, vec_width], "float32") B = T.match_buffer(b, [threads, 4, vec_width], "float32") # Parallelism is necessary to hit all cores/nodes diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index 8cbde46f83b71..533a86acacfd3 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -85,7 +85,7 @@ class MemoryDatabaseNode : public DatabaseNode { } std::stable_sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs()); if (results.size() > 
static_cast(top_k)) { - return {results.begin(), results.end() + top_k}; + return {results.begin(), results.begin() + top_k}; } else { if (results.size() < static_cast(top_k)) { LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not " diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 324eedafb98a1..bda088bc5f1e8 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -87,6 +87,21 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target"; } } + if (Optional opt_sm = context->target.value()->GetAttr("arch")) { + std::string sm = opt_sm.value(); + if (support::StartsWith(sm, "sm_")) { + sm = sm.substr(3); + try { + // only sm_80 or higher supports async memcopy + if (std::stoi(sm) >= 80) { + this->stages.insert(this->stages.end(), {4, 5}); + } + } catch (const std::invalid_argument& e) { + LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm + << ". 
Details: " << e.what(); + } + } + } logger = context->logger; } @@ -115,6 +130,9 @@ std::vector MultiLevelTilingNode::ApplySubRules(std::vector states states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); }); states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); }); states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); }); + states = SubRule(std::move(states), [&](State state) { + return AddAsyncPipeline(std::move(state)); + }); return states; } @@ -280,6 +298,43 @@ std::vector MultiLevelTilingNode::AddReadReuse(State state) const { return results; } +std::vector MultiLevelTilingNode::AddAsyncPipeline(State state) const { + // For arch that does not support async pipeline, this->stages will be an empty vector + if (r_indices_.size() < 1 || this->stages.empty()) { + return {state}; + } + // Current only support default config used by ScheduleRule::DefaultCUDA + // @see src/meta_schedule/schedule_rule/schedule_rule.cc + // check the reduce loop contains exactly 3 for loops + // therefore it matches the notation array size in the following code + tir::StmtSRef r_loop_sref = state->sch->GetSRef(state->tiles[r_indices_[0]].back()); + const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref); + Array seq = Downcast(r_for_loop->body)->seq; + if (seq.size() != 3) { + return {state}; + } + for (auto& stmt : seq) { + if (!stmt.as()) { + return {state}; + } + } + + LoopRV r_loop_fused = state->sch->Fuse(state->tiles[r_indices_[0]]); + std::vector ret; + ret.push_back(state); + for (int stage : this->stages) { + State new_state = state->Copy(); + new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_stage, + Array{0, 0, stage - 2}); + new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_order, + Array{0, 1, 2}); + new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_async_stages, + Array{0}); + 
ret.push_back(std::move(new_state)); + } + return ret; +} + void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, const tir::BlockRV& block) const { // Filter out invalid vector lanes according to the data type. diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index d8725a3060b1e..ff38756ff06be 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -148,6 +148,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode { std::vector TileLoopNest(State state) const; // SubRule 3. add read cache std::vector AddReadReuse(State state) const; + // SubRule 4. add async pipeline + std::vector AddAsyncPipeline(State state) const; // Do nothing; Inherited from ScheduleRuleNode void InitializeWithTuneContext(const TuneContext& context) final; @@ -192,6 +194,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode { int thread_warp_size_; /*! \brief The maximum number of threads to be used size of a thread warp */ int max_threads_per_block_; + /*! \brief All available async pipeline stages. */ + std::vector stages; /*! \brief The logging function */ PackedFunc logger; /*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */ diff --git a/src/relay/analysis/graph_partitioner.cc b/src/relay/analysis/graph_partitioner.cc new file mode 100644 index 0000000000000..861fd58d9e5c8 --- /dev/null +++ b/src/relay/analysis/graph_partitioner.cc @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./graph_partitioner.h" + +#include + +namespace tvm { +namespace relay { + +DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { + DominatorTree tree; + tree.nodes.resize(graph.post_dfs_order.size(), nullptr); + // reverse topo order + for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { + size_t index = i - 1; + tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]); + } + return tree; +} + +DominatorTree::Node* DominatorTree::LeastCommonAncestor(Node* lhs, Node* rhs, + OpPatternKind* edge_pattern) { + while (lhs != rhs) { + if (lhs == nullptr) return nullptr; + if (rhs == nullptr) return nullptr; + if (lhs->depth < rhs->depth) { + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); + rhs = rhs->parent; + } else if (rhs->depth < lhs->depth) { + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); + lhs = lhs->parent; + } else { + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); + lhs = lhs->parent; + rhs = rhs->parent; + } + } + return lhs; +} + +DominatorTree::Node* DominatorTree::LeastCommonAncestor( + const LinkedList& input_nodes, OpPatternKind* edge_pattern) { + auto link = input_nodes.head; + if (link == nullptr) { + return nullptr; + } + auto get_node = [&](const IndexedForwardGraph::Edge& edge) { + size_t oindex = edge.node->index; + ICHECK_LT(oindex, nodes.size()); + Node* onode = nodes[oindex]; + ICHECK(onode != nullptr); + return onode; + }; + Node* 
parent = get_node(link->value); + *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); + link = link->next; + for (; link != nullptr; link = link->next) { + parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern); + *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); + } + return parent; +} + +DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena, + IndexedForwardGraph::Node* gnode) { + Node* tnode = arena->make(); + tnode->gnode = gnode; + if (gnode->extern_ref) { + tnode->depth = 1; + tnode->parent = nullptr; + tnode->pattern = kOpaque; + } else { + // find the LCAs of all outputs. + OpPatternKind pattern = kElemWise; + Node* parent = LeastCommonAncestor(gnode->outputs, &pattern); + tnode->depth = parent ? parent->depth + 1 : 1; + tnode->parent = parent; + tnode->pattern = pattern; + } + return tnode; +} + +std::vector GraphPartitioner::Partition( + const IndexedForwardGraph& graph) { + this->InitGroups(graph); + if (opt_level_ == 0) return std::move(groups_); + // get post dominator tree + auto post_dom_tree = DominatorTree::PostDom(arena_, graph); + // run fusion algorithm. + for (int phase = 0; phase < 3; ++phase) { + this->RunFuse(graph, post_dom_tree, phase); + } + return std::move(groups_); +} + +GraphPartitioner::Group* GraphPartitioner::Group::FindRoot() { + // fast path + if (this->parent == nullptr) return this; + // slow path with path compression. 
+ Group* root = this; + while (root->parent != nullptr) { + root = root->parent; + } + for (Group* p = this; p != root;) { + Group* parent = p->parent; + p->parent = root; + p = parent; + } + return root; +} + +template +bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, + F fcond) { + if (visited_.count(src)) return true; + visited_.insert(src); + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + gnode = gnode->FindRoot(); + if (!fcond(gnode->pattern, src == sink)) return false; + if (src == sink) return true; + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + if (!CheckPath_(link->value.node, sink, fcond)) return false; + } + return true; +} + +template +bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, + F fcond) { + ICHECK(!src->extern_ref); + visited_.clear(); + ICHECK(src != sink); + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + if (!CheckPath_(link->value.node, sink, fcond)) return false; + } + return true; +} + +OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { + if (lhs > relay::kBroadcast && rhs > relay::kBroadcast) { + LOG(FATAL) << "Cannot merge two complex group together"; + } + if (lhs > rhs) return lhs; + return rhs; +} + +void GraphPartitioner::MergeFromTo(Group* child, Group* parent) { + child = child->FindRoot(); + parent = parent->FindRoot(); + if (child == parent) return; + // update the number of nodes of the parent group + parent->num_nodes += child->num_nodes; + child->parent = parent; + // update anchor ref and pattern + if (child->anchor_ref != nullptr) { + ICHECK(parent->anchor_ref == nullptr); + parent->anchor_ref = child->anchor_ref; + parent->pattern = CombinePattern(child->pattern, parent->pattern); + } +} + +void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, + Group* target) { + if (src == sink) 
return; + if (visited_.count(src)) return; + visited_.insert(src); + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + // merge the current group to the parent if possible. + MergeFromTo(gnode, target); + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + CommitFuse_(link->value.node, sink, target); + } +} + +void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { + Group* target = groups_[sink->index]; + visited_.clear(); + ICHECK(src != sink); + CommitFuse_(src, sink, target); +} + +size_t GraphPartitioner::CountNodesUptoSink_(IndexedForwardGraph::Node* src, + IndexedForwardGraph::Node* sink) { + if (src == sink || visited_.count(src)) return 0; + visited_.insert(src); + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + auto sum = gnode->num_nodes; + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + sum += CountNodesUptoSink_(link->value.node, sink); + } + return sum; +} + +size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, + IndexedForwardGraph::Node* dom_parent) { + Group* target = groups_[dom_parent->index]; + visited_.clear(); + ICHECK(child != dom_parent); + return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent); +} + +void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) { + groups_.resize(graph.post_dfs_order.size()); + for (size_t nid = 0; nid < groups_.size(); ++nid) { + const auto* graph_node = graph.post_dfs_order[nid]; + auto* group_node = arena_->make(); + group_node->pattern = graph_node->pattern; + group_node->root_ref = graph_node->ref; + // set anchor ref if necessary. 
+ if (group_node->pattern == relay::kOutEWiseFusable) { + group_node->anchor_ref = graph_node->ref; + } + groups_[nid] = group_node; + } +} + +void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, // + const DominatorTree& post_dom_tree, // + int phase) { + for (size_t nid = 0; nid < groups_.size(); ++nid) { + // the group of current node has been specified already. + auto* graph_node = graph.post_dfs_order[nid]; + auto* dom_node = post_dom_tree.nodes[nid]; + Group* group_node = groups_[nid]; + ICHECK(group_node != nullptr); + // no actions for opaque nodes + if (group_node->pattern == kOpaque) continue; + // no actions needed if the current node have no dominator + if (dom_node->parent == nullptr) continue; + ICHECK(!graph_node->extern_ref); + size_t dom_parent_gindex = dom_node->parent->gnode->index; + + // refuse the fusion if too many ops are going to be fused together + if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_) + continue; + + if (phase == 2) { + // Fuse injective ops into intermediate tuples, if any + if (group_node->pattern > relay::kInjective) continue; + Group* dom_parent_group = groups_[dom_parent_gindex]; + Group* dom_root_group = dom_parent_group->FindRoot(); + // If dom node group has a tuple as its root, we do not fuse tuple fields into it + if (dom_root_group->pattern == relay::kTuple) continue; + if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) { + // Now we know the tuple has been fused into subsequent injective ops + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; + // dom_root_group can also be tuple, as in inception layers + // CheckPath is needed to avoid fusing two intermediate tuples + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + continue; + } + + // Skip if current node is already fused to the parent. 
+ if (groups_[dom_parent_gindex] != nullptr && + group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) { + continue; + } + // Do not fuse into tuple for now + if (groups_[dom_parent_gindex]->pattern == kTuple) continue; + // Try to fuse current node to its post-dominator. + if (group_node->pattern == kOutEWiseFusable) { + if (phase != 0) continue; + // Path for OutEWiseFusable: conv2d + // Check if the dominator relation is elemwise. + if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { + ICHECK(dom_node->parent->gnode != nullptr); + // The fuse can be executed if all the intermediate ops are still broadcast. + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + } else if (group_node->pattern <= kBroadcast) { + // Pre-condition: can only be fused to parent which is injective or reduction. + if (dom_node->parent != nullptr && + (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) { + // Check if all the intermediate ops are still broadcast. + // The final terminal node can already be fused to a OutEWiseFusable group. + auto fcond = [](OpPatternKind kind, bool is_sink) { + if (!is_sink) { + // Elemwise, broadcast, and injective ops on the parallel branches + // are allowed be fused to the elemwise/broadcast anchor. + return kind <= kInjective; + } else { + return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective || + kind == kOutEWiseFusable); + } + }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) { + // defer injective fusion to second phase. + // so conv2d always finishes fusing. + if (phase != 1) continue; + // Check if all path are injective. 
+ auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } else { + // do nothing. + ICHECK(group_node->pattern == kCommReduce); + } + } +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/analysis/graph_partitioner.h b/src/relay/analysis/graph_partitioner.h new file mode 100644 index 0000000000000..9433aafa119d4 --- /dev/null +++ b/src/relay/analysis/graph_partitioner.h @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/graph_partitioner.h + * \brief The helper function for op fusion. + */ + +#ifndef TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_ +#define TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_ + +#include + +#include +#include +#include + +#include "../../support/arena.h" + +namespace tvm { +namespace relay { + +using support::LinkedList; +using support::LinkNode; + +/*! + * \brief Indexed data flow graph in forward direction. + * This is a temporary data structure used for operator fusion analysis. 
+ * + * This data structure only captures the dataflow fragment and + * could ignore blocks like let by simply ordering each dataflow block + * and mark the output node as extern_ref; + */ +class IndexedForwardGraph { + public: + struct Node; + /*! + * The forward edge in the dataflow graph. + */ + struct Edge { + /*! \brief The corresponding node */ + Node* node{nullptr}; + /*! \brief The respective pattern of this op */ + OpPatternKind pattern{kOpaque}; + }; + /*! \brief A node in the graph. */ + struct Node { + /*! \brief weak reference to the corresponding edge. */ + const tvm::Object* ref{nullptr}; + /*! \brief The index of the node in topological order. */ + size_t index{0}; + /*! \brief Whether this node is referenced by external source */ + bool extern_ref{false}; + /*! \brief The general pattern in the node */ + OpPatternKind pattern{kOpaque}; + /*! \brief The outputs of the node. */ + LinkedList outputs; + }; + /*! \brief The node map that maps node to graph */ + std::unordered_map node_map; + /*! \brief All the nodes in post DFS order */ + std::vector post_dfs_order; + + /*! \brief Dump the graph into string. */ + void DebugDump() { + std::ostringstream os; + for (size_t i = 0; i < post_dfs_order.size(); ++i) { + Node* node = post_dfs_order[i]; + os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; + for (auto* link = node->outputs.head; link != nullptr; link = link->next) { + os << link->value.node->index << ", "; + } + os << "]\n"; + } + LOG(INFO) << os.str(); + } +}; + +/*! + * \brief Dominator tree that represent domination or + * post domination relation of the node. + */ +class DominatorTree { + public: + /*! + * \brief A node in the dominator tree. + */ + struct Node { + /*! \brief The node in the tree */ + IndexedForwardGraph::Node* gnode{nullptr}; + /*! \brief parent of the tree */ + Node* parent{nullptr}; + /*! \brief current depth*/ + int depth{0}; + /*! 
+ * \brief Convert the Node from an IndexedForwardGraph Node into DominatorTree Node.
+ * \brief Group as a union find data structure. + */ + struct Group { + /*! \brief The parent in the union find data structure. */ + Group* parent{nullptr}; + /*! \brief The pattern of the group */ + OpPatternKind pattern; + /*! \brief reference to the root node. */ + const tvm::Object* root_ref{nullptr}; + /*! + * \brief Reference to the anchor node, + * this field is not nullptr only if pattern is kOutEWiseFusable. + */ + const tvm::Object* anchor_ref{nullptr}; + /*! + * \brief The number of nodes belonging to this group + */ + uint32_t num_nodes{1}; + + /*! \brief Optional attributes to annotate the grouped function. */ + runtime::Map attrs; + /*! + * \brief Find the group root, perform path compression + * \return The root type node. + */ + Group* FindRoot(); + }; + /*! + * \brief Partition a graph. + * \return group assignments of each node. + */ + std::vector Partition(const IndexedForwardGraph& graph); + + private: + /*! \brief The internal arena for temporary space. */ + support::Arena* arena_; + /*! \brief optimization level for fuse operation. */ + int opt_level_; + /*! \brief The maximum number of operations in one fused function */ + size_t max_fuse_depth_; + /*! \brief The internal groups. */ + std::vector groups_; + /*! \brief internal field used for deduplication */ + std::unordered_set visited_; + // Internal implementation of CheckPath + template + bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond); + + /*! + * \brief Check all the node and edge pattern + * between src and sink satisfies fcond. + * + * src is not checked. + * + * \param src The source node. + * \param sink The termination node. + * \param fcond The condition to be checked. + * \tparam F the condition function, with signature + * \note sink must be a post-dominator of src. + */ + template + bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond); + + /*! + * \brief Merge the child group to the parent. 
+ * \param child The child group. + * \param parent The parent group. + */ + void MergeFromTo(Group* child, Group* parent); + + // Internal implementation of CommitFuse + void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target); + + /*! + * \brief Commit fusion operation. + * \param src The source node. + * \param sink The termination node. + * \note sink must be a post-dominator of src. + */ + void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink); + + size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink); + + // Count the number of nodes in a fused subgraph if child is additionally fused. + // dom_parent is already known to be a part of the subgraph. + // For a diamond structure, there can be multiple paths connecting child and dom_parent. + // All intermediate nodes between child and dom_parent are taken into account. + // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot() + // is important for correct calculation. + size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, + IndexedForwardGraph::Node* dom_parent); + + // Initialize the groups. + void InitGroups(const IndexedForwardGraph& graph); + + // execute the fusion algorithm. 
+ void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase); +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_ diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index afa60f1bb4e54..1fb857cb1cb31 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -32,6 +32,7 @@ #include #include "../../support/arena.h" +#include "../analysis/graph_partitioner.h" #include "../op/annotation/annotation.h" #include "./pass_utils.h" #include "./pattern_utils.h" @@ -88,72 +89,16 @@ static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion"); TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer); TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.link_params", Bool); -/*! - * \brief Indexed data flow graph in forward direction. - * This is a temporary data structure used for operator fusion analysis. - * - * This data structure only captures the dataflow fragment and - * could ignore blocks like let by simply ordering each dataflow block - * and mark the output node as extern_ref; - */ -class IndexedForwardGraph { +// Creator of post dominator tree of the dataflow +class IndexedForwardGraphCreator : private ExprVisitor { public: - struct Node; - /*! - * The forward edge in the dataflow graph. - */ - struct Edge { - /*! \brief The corresponding node */ - Node* node{nullptr}; - /*! \brief The respective pattern of this op */ - OpPatternKind pattern{kOpaque}; - }; - /*! \brief A node in the graph. */ - struct Node { - /*! \brief weak reference to the corresponding edge. */ - const tvm::Object* ref{nullptr}; - /*! \brief The index of the node in topological order. */ - size_t index{0}; - /*! \brief Whether this node is referenced by external source */ - bool extern_ref{false}; - /*! \brief The general pattern in the node */ - OpPatternKind pattern{kOpaque}; - /*! \brief The outputs of the node. 
*/ - LinkedList outputs; - }; - /*! \brief The node map that maps node to graph */ - std::unordered_map node_map; - /*! \brief All the nodes in post DFS order */ - std::vector post_dfs_order; - - /*! \brief Dump the graph into string. */ - void DebugDump() { - std::ostringstream os; - for (size_t i = 0; i < post_dfs_order.size(); ++i) { - Node* node = post_dfs_order[i]; - os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; - for (auto* link = node->outputs.head; link != nullptr; link = link->next) { - os << link->value.node->index << ", "; - } - os << "]\n"; - } - LOG(INFO) << os.str(); + static IndexedForwardGraph Create(support::Arena* arena, const Expr& body) { + IndexedForwardGraphCreator creator(arena); + return creator.Prepare(body); } - /*! - * \brief create a indexed forward graph. - * \param arena The arena used for data allocation. - * \param body The body of the expression to create a graph. - */ - static IndexedForwardGraph Create(support::Arena* arena, const Expr& body); private: - class Creator; -}; - -// Creator of post dominator tree of the dataflow -class IndexedForwardGraph::Creator : private ExprVisitor { - public: - explicit Creator(support::Arena* arena) : arena_(arena) {} + explicit IndexedForwardGraphCreator(support::Arena* arena) : arena_(arena) {} IndexedForwardGraph Prepare(const Expr& body) { this->Update(body, nullptr, kOpaque); @@ -213,7 +158,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const ConstantNode* op) final { this->AddNode(op); - Node* node = graph_.node_map.at(op); + IndexedForwardGraph::Node* node = graph_.node_map.at(op); DataType dtype = DataType(op->data->dtype); // This rule must be consistent with code generator. 
bool is_simple_const = @@ -230,7 +175,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const CallNode* call) final { ICHECK(graph_.node_map.count(call)); - Node* node = graph_.node_map.at(call); + IndexedForwardGraph::Node* node = graph_.node_map.at(call); static auto fpattern = Op::GetAttrMap("TOpPattern"); // Now we set the pattern of this call. // @@ -274,7 +219,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const TupleNode* op) final { ICHECK(graph_.node_map.count(op)); - Node* tuple_node = graph_.node_map.at(op); + IndexedForwardGraph::Node* tuple_node = graph_.node_map.at(op); tuple_node->pattern = kTuple; for (const Expr& field : op->fields) { if (field->checked_type().as()) { @@ -306,7 +251,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->Update(op->tuple, nullptr, kOpaque); } else { ICHECK(graph_.node_map.count(op)); - Node* node = graph_.node_map.at(op); + IndexedForwardGraph::Node* node = graph_.node_map.at(op); node->pattern = kInjective; this->Update(op->tuple, node, kInjective); } @@ -372,443 +317,6 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } }; -IndexedForwardGraph IndexedForwardGraph::Create(support::Arena* arena, const Expr& body) { - return Creator(arena).Prepare(body); -} - -/*! - * \brief Dominator tree that represent domination or - * post domination relation of the node. - */ -class DominatorTree { - public: - /*! - * \brief A node in the dominator tree. - */ - struct Node { - /*! \brief The node in the tree */ - IndexedForwardGraph::Node* gnode{nullptr}; - /*! \brief parent of the tree */ - Node* parent{nullptr}; - /*! \brief current depth*/ - int depth{0}; - /*! \brief aggregated pattern to parent */ - OpPatternKind pattern{kOpaque}; - }; - // index -> node. - std::vector nodes; - /*! - * \brief compute a post dominator relation for a given dataflow graph. - * \param arena The arena used for node allocation. 
- * \param graph The graph to be analyzed. - * \return The dominator tree of the graph. - * \note This algorithm makes use of the fact that graph is DAG, - * and runs a single pass algorithm via LCA (Least Common Ancestor) - */ - static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph); - - private: - // Combine pattern together. - static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { - if (lhs > rhs) return lhs; - return rhs; - } - /*! - * \brief Find the least common ancestor of the two nodes. - * \param lhs The left node. - * \param rhs The right node. - * \param edge_pattern - * The combined edge pattern across all the parents. - * \return The least common ancestor of the two. - */ - static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) { - while (lhs != rhs) { - if (lhs == nullptr) return nullptr; - if (rhs == nullptr) return nullptr; - if (lhs->depth < rhs->depth) { - edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); - rhs = rhs->parent; - } else if (rhs->depth < lhs->depth) { - edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); - lhs = lhs->parent; - } else { - edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); - edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); - lhs = lhs->parent; - rhs = rhs->parent; - } - } - return lhs; - } - /*! - * \brief Find the least common ancestor of a list of nodes. - * \param nodes the nodes. - * \param edge_pattern - * The combined edge pattern across all the parents. - * \return The least common ancestor of all nodes. 
- */ - Node* LeastCommonAncestor(const LinkedList& input_nodes, - OpPatternKind* edge_pattern) { - auto link = input_nodes.head; - if (link == nullptr) { - return nullptr; - } - auto get_node = [&](const IndexedForwardGraph::Edge& edge) { - size_t oindex = edge.node->index; - ICHECK_LT(oindex, nodes.size()); - Node* onode = nodes[oindex]; - ICHECK(onode != nullptr); - return onode; - }; - Node* parent = get_node(link->value); - *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); - link = link->next; - for (; link != nullptr; link = link->next) { - parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern); - *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); - } - return parent; - } - /*! - * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node. - * \param arena The Arena. - * \param gnode An IndexedForwardGraph Node. - * \return The DominatorTree Node. - */ - Node* GetNode(support::Arena* arena, IndexedForwardGraph::Node* gnode) { - Node* tnode = arena->make(); - tnode->gnode = gnode; - if (gnode->extern_ref) { - tnode->depth = 1; - tnode->parent = nullptr; - tnode->pattern = kOpaque; - } else { - // find the LCAs of all outputs. - OpPatternKind pattern = kElemWise; - Node* parent = LeastCommonAncestor(gnode->outputs, &pattern); - tnode->depth = parent ? parent->depth + 1 : 1; - tnode->parent = parent; - tnode->pattern = pattern; - } - return tnode; - } -}; - -DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { - DominatorTree tree; - tree.nodes.resize(graph.post_dfs_order.size(), nullptr); - // reverse topo order - for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { - size_t index = i - 1; - tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]); - } - return tree; -} - -/*! - * \brief A partition of the graph marked by union find data structure. 
- */ -class GraphPartitioner { - public: - explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth) - : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {} - /*! - * \brief Group as a union find data structure. - */ - struct Group { - /*! \brief The parent in the union find data structure. */ - Group* parent{nullptr}; - /*! \brief The pattern of the group */ - OpPatternKind pattern; - /*! \brief reference to the root node. */ - const tvm::Object* root_ref{nullptr}; - /*! - * \brief Reference to the anchor node, - * this field is not nullptr only if pattern is kOutEWiseFusable. - */ - const tvm::Object* anchor_ref{nullptr}; - /*! - * \brief Find the group root, perform path compression - * \return The root type node. - */ - Group* FindRoot() { - // fast path - if (this->parent == nullptr) return this; - // slow path with path compression. - Group* root = this; - while (root->parent != nullptr) { - root = root->parent; - } - for (Group* p = this; p != root;) { - Group* parent = p->parent; - p->parent = root; - p = parent; - } - return root; - } - - /*! - * \brief The number of nodes belonging to this group - */ - uint32_t num_nodes{1}; - }; - /*! - * \brief Partition a graph. - * \return group assignments of each node. - */ - std::vector Partition(const IndexedForwardGraph& graph); - - private: - /*! \brief The internal arena for temporary space. */ - support::Arena* arena_; - /*! \brief optimization level for fuse operation. */ - int opt_level_; - /*! \brief The maximum number of operations in one fused function */ - size_t max_fuse_depth_; - /*! \brief The internal groups. */ - std::vector groups_; - /*! 
\brief internal field used for deduplication */ - std::unordered_set visited_; - // Internal implelementation of CheckPath - template - bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { - if (visited_.count(src)) return true; - visited_.insert(src); - Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); - gnode = gnode->FindRoot(); - if (!fcond(gnode->pattern, src == sink)) return false; - if (src == sink) return true; - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - if (!CheckPath_(link->value.node, sink, fcond)) return false; - } - return true; - } - /*! - * \brief Check all the node and edge pattern - * between src and sink satisfies fcond. - * - * src is not checked. - * - * \param src The source node. - * \param sink The termination node. - * \param fcond The condition to be checked. - * \tparam F the condition function, with signature - * \note sink must be a post-dominator of src. - */ - template - bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { - ICHECK(!src->extern_ref); - visited_.clear(); - ICHECK(src != sink); - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - if (!CheckPath_(link->value.node, sink, fcond)) return false; - } - return true; - } - // Combine two patterns together. - static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { - if (lhs > kBroadcast && rhs > kBroadcast) { - LOG(FATAL) << "Cannot merge two complex group together"; - } - if (lhs > rhs) return lhs; - return rhs; - } - /*! - * \brief Merge the child group to the parent. - * \param child The child group. - * \param parent The parent group. 
- */ - void MergeFromTo(Group* child, Group* parent) { - child = child->FindRoot(); - parent = parent->FindRoot(); - if (child == parent) return; - // update the number of nodes of the parent group - parent->num_nodes += child->num_nodes; - child->parent = parent; - // update anchor ref and pattern - if (child->anchor_ref != nullptr) { - ICHECK(parent->anchor_ref == nullptr); - parent->anchor_ref = child->anchor_ref; - parent->pattern = CombinePattern(child->pattern, parent->pattern); - } - } - // Internal implelementation of CommitFuse - void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target) { - if (src == sink) return; - if (visited_.count(src)) return; - visited_.insert(src); - Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); - // merge the current group to the parent if possible. - MergeFromTo(gnode, target); - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - CommitFuse_(link->value.node, sink, target); - } - } - /*! - * \brief Commit fusion operation. - * \param src The source node. - * \param sink The termination node. - * \note sink must be a post-dominator of src. - */ - void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { - Group* target = groups_[sink->index]; - visited_.clear(); - ICHECK(src != sink); - CommitFuse_(src, sink, target); - } - - size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { - if (src == sink || visited_.count(src)) return 0; - visited_.insert(src); - Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); - auto sum = gnode->num_nodes; - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - sum += CountNodesUptoSink_(link->value.node, sink); - } - return sum; - } - - // Count the number of nodes in a fused subgraph if child is additionaly fused. - // dom_parent is already known to be a part of the subgraph. 
- // For a diamond structure, there can be multiple paths connecting child and dom_parent. - // All intermediate nodes between child and dom_parent are taken into account. - // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot() - // is important for correct calculation. - size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, - IndexedForwardGraph::Node* dom_parent) { - Group* target = groups_[dom_parent->index]; - visited_.clear(); - ICHECK(child != dom_parent); - return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent); - } - - // Initialize the groups. - void InitGroups(const IndexedForwardGraph& graph) { - groups_.resize(graph.post_dfs_order.size()); - for (size_t nid = 0; nid < groups_.size(); ++nid) { - const auto* graph_node = graph.post_dfs_order[nid]; - auto* group_node = arena_->make(); - group_node->pattern = graph_node->pattern; - group_node->root_ref = graph_node->ref; - // set anchor ref if necessary. - if (group_node->pattern == kOutEWiseFusable) { - group_node->anchor_ref = graph_node->ref; - } - groups_[nid] = group_node; - } - } - - // execute the fusion algorithm. - void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) { - for (size_t nid = 0; nid < groups_.size(); ++nid) { - // the group of current node has been specified already. 
- auto* graph_node = graph.post_dfs_order[nid]; - auto* dom_node = post_dom_tree.nodes[nid]; - Group* group_node = groups_[nid]; - ICHECK(group_node != nullptr); - // no actions for opaque nodes - if (group_node->pattern == kOpaque) continue; - // no actions needed if the current node have no dominator - if (dom_node->parent == nullptr) continue; - ICHECK(!graph_node->extern_ref); - size_t dom_parent_gindex = dom_node->parent->gnode->index; - - // refuse the fusion if too many ops are going to be fused together - if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_) - continue; - - if (phase == 2) { - // Fuse injective ops into intermediate tuples, if any - if (group_node->pattern > kInjective) continue; - Group* dom_parent_group = groups_[dom_parent_gindex]; - Group* dom_root_group = dom_parent_group->FindRoot(); - // If dom node group has a tuple as its root, we do not fuse tuple fields into it - if (dom_root_group->pattern == kTuple) continue; - if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) { - // Now we know the tuple has been fused into subsequent injective ops - auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; - // dom_root_group can also be tuple, as in inception layers - // CheckPath is needed to avoid fusing two intermediate tuples - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } - continue; - } - - // Skip if current node is already fused to the parent. - if (groups_[dom_parent_gindex] != nullptr && - group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) { - continue; - } - // Do not fuse into tuple for now - if (groups_[dom_parent_gindex]->pattern == kTuple) continue; - // Try to fuse current node to its post-dominator. 
- if (group_node->pattern == kOutEWiseFusable) { - if (phase != 0) continue; - // Path for OutEWiseFusable: conv2d - // Check if the dominator relation is elemwise. - if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { - ICHECK(dom_node->parent->gnode != nullptr); - // The fuse can be executed if all the intermediate ops are still broadcast. - auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; }; - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } - } else if (group_node->pattern <= kBroadcast) { - // Pre-condition: can only be fused to parent which is injective or reduction. - if (dom_node->parent != nullptr && - (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) { - // Check if all the intermediate ops are still broadcast. - // The final terminal node can already be fused to a OutEWiseFusable group. - auto fcond = [](OpPatternKind kind, bool is_sink) { - if (!is_sink) { - // Elemwise, broadcast, and injective ops on the parallel branches - // are allowed be fused to the elemwise/broadcast anchor. - return kind <= kInjective; - } else { - return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective || - kind == kOutEWiseFusable); - } - }; - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } - } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) { - // defer injective fusion to second phase. - // so conv2d always finishes fusing. - if (phase != 1) continue; - // Check if all path are injective. - auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } else { - // do nothing. 
- ICHECK(group_node->pattern == kCommReduce); - } - } - } -}; - -std::vector GraphPartitioner::Partition( - const IndexedForwardGraph& graph) { - this->InitGroups(graph); - if (opt_level_ == 0) return std::move(groups_); - // get post dominator tree - auto post_dom_tree = DominatorTree::PostDom(arena_, graph); - // run fusion algorithm. - for (int phase = 0; phase < 3; ++phase) { - this->RunFuse(graph, post_dom_tree, phase); - } - return std::move(groups_); -} - class FuseMutator : private MixedModeMutator { public: FuseMutator(int fuse_opt_level, size_t max_fuse_depth, bool link_params) @@ -825,7 +333,7 @@ class FuseMutator : private MixedModeMutator { // Run the transform Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) { // setup the group map. - auto graph = IndexedForwardGraph::Create(&arena_, body); + auto graph = IndexedForwardGraphCreator::Create(&arena_, body); auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph); for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { ICHECK(graph.post_dfs_order[nid]->ref != nullptr); diff --git a/src/runtime/crt/host/Makefile.template b/src/runtime/crt/host/Makefile.template index 2caf7ba0bc231..526b17deb73f3 100644 --- a/src/runtime/crt/host/Makefile.template +++ b/src/runtime/crt/host/Makefile.template @@ -16,9 +16,9 @@ # under the License. 
INCLUDES ?= -isystem crt/include -Icrt_config -MEMORY_SIZE_BYTES := +TVM_WORKSPACE_SIZE_BYTES := CFLAGS ?= -Werror -Wall -CXXFLAGS ?= -Werror -Wall -std=c++11 -DTVM_HOST_USE_GRAPH_EXECUTOR_MODULE -DMEMORY_SIZE_BYTES=$(MEMORY_SIZE_BYTES) +CXXFLAGS ?= -Werror -Wall -std=c++11 -DTVM_HOST_USE_GRAPH_EXECUTOR_MODULE -DTVM_WORKSPACE_SIZE_BYTES=$(TVM_WORKSPACE_SIZE_BYTES) LDFLAGS ?= -Werror -Wall # Codegen produces spurious lines like: int32_t arg2_code = ((int32_t*)arg_type_ids)[(2)]; diff --git a/src/runtime/crt/host/main.cc b/src/runtime/crt/host/main.cc index e9f6813f9b3cd..0607d4b287191 100644 --- a/src/runtime/crt/host/main.cc +++ b/src/runtime/crt/host/main.cc @@ -22,14 +22,12 @@ * \brief main entry point for host subprocess-based CRT */ #include -#include #include +#include #include #include -#include #include -#include #include #include "crt_config.h" @@ -38,10 +36,6 @@ #include #endif -#include - -using namespace std::chrono; - extern "C" { ssize_t MicroTVMWriteFunc(void* context, const uint8_t* data, size_t num_bytes) { @@ -50,70 +44,8 @@ ssize_t MicroTVMWriteFunc(void* context, const uint8_t* data, size_t num_bytes) fsync(STDOUT_FILENO); return to_return; } - -size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, - va_list args) { - return vsnprintf(out_buf, out_buf_size_bytes, fmt, args); -} - -void TVMPlatformAbort(tvm_crt_error_t error_code) { - std::cerr << "TVMPlatformAbort: " << error_code << std::endl; - throw "Aborted"; } -MemoryManagerInterface* memory_manager; - -tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { - return memory_manager->Allocate(memory_manager, num_bytes, dev, out_ptr); -} - -tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { - return memory_manager->Free(memory_manager, ptr, dev); -} - -steady_clock::time_point g_microtvm_start_time; -int g_microtvm_timer_running = 0; - -tvm_crt_error_t TVMPlatformTimerStart() { - if 
(g_microtvm_timer_running) { - std::cerr << "timer already running" << std::endl; - return kTvmErrorPlatformTimerBadState; - } - g_microtvm_start_time = std::chrono::steady_clock::now(); - g_microtvm_timer_running = 1; - return kTvmErrorNoError; -} - -tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { - if (!g_microtvm_timer_running) { - std::cerr << "timer not running" << std::endl; - return kTvmErrorPlatformTimerBadState; - } - auto microtvm_stop_time = std::chrono::steady_clock::now(); - std::chrono::microseconds time_span = std::chrono::duration_cast( - microtvm_stop_time - g_microtvm_start_time); - *elapsed_time_seconds = static_cast(time_span.count()) / 1e6; - g_microtvm_timer_running = 0; - return kTvmErrorNoError; -} - -static_assert(RAND_MAX >= (1 << 8), "RAND_MAX is smaller than acceptable"); -unsigned int random_seed = 0; -tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { - if (random_seed == 0) { - random_seed = (unsigned int)time(NULL); - } - for (size_t i = 0; i < num_bytes; ++i) { - int random = rand_r(&random_seed); - buffer[i] = (uint8_t)random; - } - - return kTvmErrorNoError; -} -} - -uint8_t memory[MEMORY_SIZE_BYTES]; - static char** g_argv = NULL; int testonly_reset_server(TVMValue* args, int* type_codes, int num_args, TVMValue* out_ret_value, @@ -125,13 +57,7 @@ int testonly_reset_server(TVMValue* args, int* type_codes, int num_args, TVMValu int main(int argc, char** argv) { g_argv = argv; - int status = - PageMemoryManagerCreate(&memory_manager, memory, sizeof(memory), 8 /* page_size_log2 */); - if (status != 0) { - fprintf(stderr, "error initiailizing memory manager\n"); - return 2; - } - + TVMPlatformInitialize(); microtvm_rpc_server_t rpc_server = MicroTVMRpcServerInit(&MicroTVMWriteFunc, nullptr); #ifdef TVM_HOST_USE_GRAPH_EXECUTOR_MODULE diff --git a/src/runtime/crt/host/microtvm_api_server.py b/src/runtime/crt/host/microtvm_api_server.py index e5b82f96b0ff8..57b7506b879fc 100644 --- 
a/src/runtime/crt/host/microtvm_api_server.py +++ b/src/runtime/crt/host/microtvm_api_server.py @@ -38,7 +38,7 @@ IS_TEMPLATE = not os.path.exists(os.path.join(PROJECT_DIR, MODEL_LIBRARY_FORMAT_RELPATH)) # Used this size to pass most CRT tests in TVM. -MEMORY_SIZE_BYTES = 2 * 1024 * 1024 +WORKSPACE_SIZE_BYTES = 2 * 1024 * 1024 MAKEFILE_FILENAME = "Makefile" @@ -67,11 +67,11 @@ def server_info_query(self, tvm_version): help="Run make with verbose output", ), server.ProjectOption( - "memory_size_bytes", + "workspace_size_bytes", optional=["generate_project"], type="int", - default=MEMORY_SIZE_BYTES, - help="Sets the value of MEMORY_SIZE_BYTES.", + default=WORKSPACE_SIZE_BYTES, + help="Sets the value of TVM_WORKSPACE_SIZE_BYTES.", ), ], ) @@ -90,7 +90,7 @@ def _populate_makefile( ): """Generate Makefile from template.""" flags = { - "MEMORY_SIZE_BYTES": str(memory_size), + "TVM_WORKSPACE_SIZE_BYTES": str(memory_size), } regex = re.compile(r"([A-Z_]+) := (<[A-Z_]+>)") @@ -106,6 +106,7 @@ def _populate_makefile( def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): # Make project directory. project_dir.mkdir(parents=True) + current_dir = pathlib.Path(__file__).parent.absolute() # Copy ourselves to the generated project. TVM may perform further build steps on the generated project # by launching the copy. 
@@ -135,25 +136,29 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec # Populate Makefile self._populate_makefile( - pathlib.Path(__file__).parent / f"{MAKEFILE_FILENAME}.template", + current_dir / f"{MAKEFILE_FILENAME}.template", project_dir / MAKEFILE_FILENAME, - options.get("memory_size_bytes", MEMORY_SIZE_BYTES), + options.get("workspace_size_bytes", WORKSPACE_SIZE_BYTES), ) # Populate crt-config.h crt_config_dir = project_dir / "crt_config" crt_config_dir.mkdir() shutil.copy2( - os.path.join(os.path.dirname(__file__), "crt_config-template.h"), - os.path.join(crt_config_dir, "crt_config.h"), + current_dir / "crt_config" / "crt_config-template.h", + crt_config_dir / "crt_config.h", ) # Populate src/ - src_dir = os.path.join(project_dir, "src") - os.mkdir(src_dir) + src_dir = project_dir / "src" + src_dir.mkdir() shutil.copy2( - os.path.join(os.path.dirname(__file__), "src", "main.cc"), - os.path.join(src_dir, "main.cc"), + current_dir / "src" / "main.cc", + src_dir / "main.cc", + ) + shutil.copy2( + current_dir / "src" / "platform.cc", + src_dir / "platform.cc", ) def build(self, options): diff --git a/src/runtime/crt/host/platform.cc b/src/runtime/crt/host/platform.cc new file mode 100644 index 0000000000000..f5af08a9be884 --- /dev/null +++ b/src/runtime/crt/host/platform.cc @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Implementation of TVMPlatform functions in tvm/runtime/crt/platform.h + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace std::chrono; + +extern "C" { + +uint8_t memory[TVM_WORKSPACE_SIZE_BYTES]; +MemoryManagerInterface* memory_manager; + +steady_clock::time_point g_microtvm_start_time; +int g_microtvm_timer_running = 0; + +// Called when an internal error occurs and execution cannot continue. +void TVMPlatformAbort(tvm_crt_error_t error_code) { + std::cerr << "TVMPlatformAbort: " << error_code << std::endl; + throw "Aborted"; +} + +// Called by the microTVM RPC server to implement TVMLogf. +size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, + va_list args) { + return vsprintf(out_buf, fmt, args); +} + +// Allocate memory for use by TVM. +tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { + return memory_manager->Allocate(memory_manager, num_bytes, dev, out_ptr); +} + +// Free memory used by TVM. +tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { + return memory_manager->Free(memory_manager, ptr, dev); +} + +// Start a device timer. 
+tvm_crt_error_t TVMPlatformTimerStart() { + if (g_microtvm_timer_running) { + std::cerr << "timer already running" << std::endl; + return kTvmErrorPlatformTimerBadState; + } + g_microtvm_start_time = std::chrono::steady_clock::now(); + g_microtvm_timer_running = 1; + return kTvmErrorNoError; +} + +// Stop the running device timer and get the elapsed time (in microseconds). +tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { + if (!g_microtvm_timer_running) { + std::cerr << "timer not running" << std::endl; + return kTvmErrorPlatformTimerBadState; + } + auto microtvm_stop_time = std::chrono::steady_clock::now(); + std::chrono::microseconds time_span = std::chrono::duration_cast( + microtvm_stop_time - g_microtvm_start_time); + *elapsed_time_seconds = static_cast(time_span.count()) / 1e6; + g_microtvm_timer_running = 0; + return kTvmErrorNoError; +} + +// Platform-specific before measurement call. +tvm_crt_error_t TVMPlatformBeforeMeasurement() { return kTvmErrorNoError; } + +// Platform-specific after measurement call. +tvm_crt_error_t TVMPlatformAfterMeasurement() { return kTvmErrorNoError; } + +static_assert(RAND_MAX >= (1 << 8), "RAND_MAX is smaller than acceptable"); +unsigned int random_seed = 0; +// Fill a buffer with random data. +tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { + if (random_seed == 0) { + random_seed = (unsigned int)time(NULL); + } + for (size_t i = 0; i < num_bytes; ++i) { + int random = rand_r(&random_seed); + buffer[i] = (uint8_t)random; + } + return kTvmErrorNoError; +} + +// Initialize TVM inference. 
+tvm_crt_error_t TVMPlatformInitialize() { + int status = + PageMemoryManagerCreate(&memory_manager, memory, sizeof(memory), 8 /* page_size_log2 */); + if (status != 0) { + fprintf(stderr, "error initiailizing memory manager\n"); + return kTvmErrorPlatformMemoryManagerInitialized; + } + return kTvmErrorNoError; +} + +} // extern C diff --git a/src/runtime/crt/platform-template.c b/src/runtime/crt/platform-template.c new file mode 100644 index 0000000000000..b93fd1459be61 --- /dev/null +++ b/src/runtime/crt/platform-template.c @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Implementation of TVMPlatform functions in tvm/runtime/crt/platform.h + */ + +#include +#include +#include +#include +#include +#include +#include + +uint8_t memory[TVM_WORKSPACE_SIZE_BYTES]; +MemoryManagerInterface* memory_manager; + +// Called when an internal error occurs and execution cannot continue. +void TVMPlatformAbort(tvm_crt_error_t error_code) { exit(1); } + +// Called by the microTVM RPC server to implement TVMLogf. 
+size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, + va_list args) { + return vsprintf(out_buf, fmt, args); +} + +// Allocate memory for use by TVM. +tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { + return memory_manager->Allocate(memory_manager, num_bytes, dev, out_ptr); +} + +// Free memory used by TVM. +tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { + return memory_manager->Free(memory_manager, ptr, dev); +} + +// Start a device timer. +tvm_crt_error_t TVMPlatformTimerStart() { return kTvmErrorNoError; } + +// Stop the running device timer and get the elapsed time (in microseconds). +tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { return kTvmErrorNoError; } + +// Platform-specific before measurement call. +tvm_crt_error_t TVMPlatformBeforeMeasurement() { return kTvmErrorNoError; } + +// Platform-specific after measurement call. +tvm_crt_error_t TVMPlatformAfterMeasurement() { return kTvmErrorNoError; } + +// Fill a buffer with random data. +tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { + return kTvmErrorNoError; +} + +// Initialize TVM inference. 
+tvm_crt_error_t TVMPlatformInitialize() { + int status = + PageMemoryManagerCreate(&memory_manager, memory, sizeof(memory), 8 /* page_size_log2 */); + if (status != 0) { + fprintf(stderr, "error initiailizing memory manager\n"); + return kTvmErrorPlatformMemoryManagerInitialized; + } + return kTvmErrorNoError; +} diff --git a/src/runtime/hexagon/hexagon_buffer.cc b/src/runtime/hexagon/hexagon_buffer.cc index 3a3444faf4da5..48afa5770afd1 100644 --- a/src/runtime/hexagon/hexagon_buffer.cc +++ b/src/runtime/hexagon/hexagon_buffer.cc @@ -50,6 +50,16 @@ struct DDRAllocation : public Allocation { DDRAllocation(size_t nbytes, size_t alignment) : Allocation(nbytes, alignment) { int ret = posix_memalign(&data_, alignment, nbytes); CHECK_EQ(ret, 0); + + // The heap used by malloc on Hexagon is always mapped as cacheable. The heap manager may not + // perform cache invalidation on a prior memory free. So, a subsequent memory allocation request + // to the heap manager may allocate memory that resides in part or in full in the cache. Hence, + // we must invalidate the allocation from the cache to ensure that DMA with cache bypass enabled + // will function properly. DMA with cache bypass enabled assumes that HexagonBuffer objects are + // not cached unless explicitly modified by the primfunc. We must invalidate after malloc to + // uphold this assumption. + qurt_mem_cache_clean(reinterpret_cast(data_), nbytes, QURT_MEM_CACHE_INVALIDATE, + QURT_MEM_DCACHE); } ~DDRAllocation() { free(data_); } }; @@ -224,7 +234,8 @@ std::vector MemoryCopy::MergeAdjacent(std::vector micro_ } void hexagon_buffer_copy_across_regions(const BufferSet& dest, const BufferSet& src, - size_t bytes_to_copy) { + size_t bytes_to_copy, bool src_is_hexbuff, + bool dest_is_hexbuff) { // First, determine all copies that do not cross boundaries in // either source or destination region. 
auto micro_copies = BufferSet::MemoryCopies(dest, src, bytes_to_copy); @@ -235,19 +246,21 @@ void hexagon_buffer_copy_across_regions(const BufferSet& dest, const BufferSet& // Finally, do the memory copies. for (const auto& copy : macro_copies) { - // clean Hexagon cache before / after memcpy to ensure clean cache state to enable usage of DMA - // bypass mode for increased DMA bandwidth + // if src is a HexagonBuffer, invalidate it before the memcpy + if (src_is_hexbuff) { + qurt_mem_cache_clean(reinterpret_cast(copy.src), copy.num_bytes, + QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE); + } + // TODO(HWE): Switch to ION Buffer to avoid need for memcpy and potentially lighten or alleviate // the burden of cache invalidation in this code - qurt_mem_cache_clean(reinterpret_cast(copy.dest), copy.num_bytes, - QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE); - qurt_mem_cache_clean(reinterpret_cast(copy.src), copy.num_bytes, - QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE); memcpy(copy.dest, copy.src, copy.num_bytes); - qurt_mem_cache_clean(reinterpret_cast(copy.dest), copy.num_bytes, - QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE); - qurt_mem_cache_clean(reinterpret_cast(copy.src), copy.num_bytes, - QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE); + + // if dest is a HexagonBuffer, flush it after the memcpy + if (dest_is_hexbuff) { + qurt_mem_cache_clean(reinterpret_cast(copy.dest), copy.num_bytes, + QURT_MEM_CACHE_FLUSH, QURT_MEM_DCACHE); + } } } @@ -255,21 +268,24 @@ void HexagonBuffer::CopyTo(void* data, size_t nbytes) const { BufferSet src(allocations_.data(), allocations_.size(), nbytes_per_allocation_); BufferSet dest(&data, 1, nbytes); - hexagon_buffer_copy_across_regions(dest, src, nbytes); + hexagon_buffer_copy_across_regions(dest, src, nbytes, true /* src_is_hexbuff */, + false /* dest_is_hexbuff */); } void HexagonBuffer::CopyFrom(void* data, size_t nbytes) { BufferSet src(&data, 1, nbytes); BufferSet dest(allocations_.data(), allocations_.size(), nbytes_per_allocation_); - 
hexagon_buffer_copy_across_regions(dest, src, nbytes); + hexagon_buffer_copy_across_regions(dest, src, nbytes, false /* src_is_hexbuff */, + true /* dest_is_hexbuff */); } void HexagonBuffer::CopyFrom(const HexagonBuffer& other, size_t nbytes) { BufferSet src(other.allocations_.data(), other.allocations_.size(), other.nbytes_per_allocation_); BufferSet dest(allocations_.data(), allocations_.size(), nbytes_per_allocation_); - hexagon_buffer_copy_across_regions(dest, src, nbytes); + hexagon_buffer_copy_across_regions(dest, src, nbytes, true /* src_is_hexbuff */, + true /* dest_is_hexbuff */); } } // namespace hexagon diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 822e8e4683776..30102b6877223 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -492,6 +492,34 @@ Var EnvThread(String thread_tag) { } void BufferStore(Buffer buffer, PrimExpr value, Array indices) { + runtime::DataType buffer_dtype = buffer->dtype; + int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; + runtime::DataType lhs_dtype = buffer_dtype.with_lanes(buffer_dtype.lanes() * index_lanes); + runtime::DataType rhs_dtype = value->dtype; + if (lhs_dtype != rhs_dtype) { + if (lhs_dtype.lanes() != rhs_dtype.lanes()) { + LOG(FATAL) << "TypeError: Incompatible types in BufferStore" + << ": LHS is `" << lhs_dtype << "`, RHS is `" << rhs_dtype + << "`, indexing lanes: " << index_lanes; + } + if (lhs_dtype.code() != rhs_dtype.code()) { + if ( + // Case 1. lhs is handle, and rhs needs to be casted to handle. + (lhs_dtype.code() == runtime::DataType::kHandle) || + // Case 2. rhs is handle, and it needs to be casted to non-handle. + (rhs_dtype.code() == runtime::DataType::kHandle) || + // Case 3. rhs is float or bfloat, and casting to non-float can lose precision. 
+ ((lhs_dtype.code() == runtime::DataType::kInt || + lhs_dtype.code() == runtime::DataType::kUInt) && + (rhs_dtype.code() == runtime::DataType::kFloat || + rhs_dtype.code() == runtime::DataType::kBFloat))) { + LOG(WARNING) << "Casting in BufferStore may lose precision" + << ": LHS is `" << lhs_dtype << "`, RHS is `" << rhs_dtype + << "`, indexing lanes: " << index_lanes; + } + } + value = tvm::cast(lhs_dtype, value); + } AddToParent(tvm::tir::BufferStore(buffer, value, indices)); } @@ -517,6 +545,16 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope) { return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope)); } +Var Handle(runtime::DataType dtype, String storage_scope) { + Type type_annotation{nullptr}; + if (dtype.is_void() && storage_scope == "global") { + type_annotation = PrimType(runtime::DataType::Handle()); + } else { + type_annotation = PointerType(PrimType(dtype), storage_scope); + } + return tvm::tir::Var("", type_annotation); +} + using tvm::script::ir_builder::details::Namer; TVM_STATIC_IR_FUNCTOR(Namer, vtable) @@ -555,8 +593,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) Namer::Name(var->var, name); }); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferDecl").set_body_typed(BufferDecl); - +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Buffer").set_body_typed(BufferDecl); TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg") .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 733c975fad7e3..485757063867b 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -21,6 +21,7 @@ #include #include +#include #include namespace tvm { diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 190669aa7a6cb..e6f4a1eaee2c2 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -26,27 +26,47 @@ 
namespace printer { TVM_REGISTER_NODE_TYPE(IRFrameNode); +struct SortableFunction { + int priority; + GlobalVar gv; + BaseFunc func; + + explicit SortableFunction(const std::pair& obj) + : priority(0), gv(obj.first), func(obj.second) { + if (gv->name_hint == "main") { + priority = 1000; + } else if (obj.second->GetTypeKey() == "tir.PrimFunc") { + priority = 1; + } else if (obj.second->GetTypeKey() == "relax.expr.ExternFunc") { + priority = 2; + } else if (obj.second->GetTypeKey() == "relax.expr.Function") { + priority = 3; + } else { + LOG(FATAL) << "TypeError: TVMScript cannot print functions of type: " + << obj.second->GetTypeKey(); + } + } + + bool operator<(const SortableFunction& other) const { + if (this->priority != other.priority) { + return this->priority < other.priority; + } + return this->gv->name_hint < other.gv->name_hint; + } +}; + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](IRModule mod, ObjectPath p, IRDocsifier d) -> Doc { - std::vector> functions{mod->functions.begin(), - mod->functions.end()}; - // print "main" first - std::sort(functions.begin(), functions.end(), [](const auto& lhs, const auto& rhs) { - String lhs_name = lhs.first->name_hint; - String rhs_name = rhs.first->name_hint; - if (lhs_name == "main") { - lhs_name = ""; - } - if (rhs_name == "main") { - rhs_name = ""; - } - return lhs_name < rhs_name; - }); + std::vector functions; + for (const auto& kv : mod->functions) { + functions.push_back(SortableFunction(kv)); + } + std::sort(functions.begin(), functions.end()); With f(d); (*f)->AddDispatchToken(d, "ir"); - for (const auto& kv : functions) { - GlobalVar gv = kv.first; - BaseFunc func = kv.second; + for (const auto& entry : functions) { + const GlobalVar& gv = entry.gv; + const BaseFunc& func = entry.func; d->cfg->binding_names.push_back(gv->name_hint); Doc doc = d->AsDoc(func, p->Attr("functions")->MapValue(gv)); d->cfg->binding_names.pop_back(); diff --git a/src/script/printer/ir/misc.cc 
b/src/script/printer/ir/misc.cc index cb78dc3ff5c33..ef68b89b5bf45 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -24,6 +24,9 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](String s, ObjectPath p, IRDocsifier d) -> Doc { + if (HasMultipleLines(s)) { + return d->AddMetadata(s); + } return LiteralDoc::Str(s, p); }); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index cc37f46e6036c..d860eeb2a7da4 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -29,14 +29,29 @@ Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) if (Optional opt_f = FindLowestVarDef(var, d)) { ExprDoc lhs = DefineVar(var, opt_f.value(), d); Type type = var->type_annotation; + ObjectPath type_p = var_p->Attr("type_annotation"); + ExprDoc rhs{nullptr}; if (const auto* ptr_type = type.as()) { - ICHECK(ptr_type->element_type->IsInstance()); - ExprDoc rhs = d->AsDoc(type, var_p->Attr("type_annotation")); - opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + const auto* prim_type = ptr_type->element_type.as(); + ICHECK(prim_type); + ExprDoc element_type = + LiteralDoc::DataType(prim_type->dtype, type_p->Attr("element_type")->Attr("dtype")); + rhs = TIR(d, "handle"); + rhs->source_paths.push_back(var_p->Attr("dtype")); + if (ptr_type->storage_scope == "") { + rhs = rhs->Call({element_type}); + } else { + rhs = rhs->Call({element_type, + LiteralDoc::Str(ptr_type->storage_scope, // + type_p->Attr("storage_scope"))}); + } } else { - ExprDoc rhs = TIR(d, "var")->Call({LiteralDoc::DataType(var->dtype, var_p->Attr("dtype"))}); - opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + rhs = TIR(d, DType2Str(var->dtype)); + rhs->source_paths.push_back(var_p->Attr("dtype")); + rhs = rhs->Call({}); } + rhs->source_paths.push_back(type_p); + opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } else { LOG(WARNING) << 
"Didn't find variable definition for: " << var->name_hint; } @@ -79,7 +94,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::StringImm s, ObjectPath p, IRDocsifier d) -> Doc { - return d->AsDoc(s->value, p->Attr("value")); + if (HasMultipleLines(s->value)) { + return d->AddMetadata(s); + } else { + return d->AsDoc(s->value, p->Attr("value")); + } }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index ce10ff6816d7d..78e50a5eb5daa 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -73,10 +73,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) element_type = d->AsDoc(ty->element_type, ty_p->Attr("element_type")); } if (ty->storage_scope == "") { - return TIR(d, "Ptr")->Call({element_type}); + return TIR(d, "handle")->Call({element_type}); } else { - return TIR(d, "Ptr")->Call( - {element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))}); + return TIR(d, "handle") + ->Call({element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))}); } }); diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index ade19b345215a..90300518b75b2 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -27,6 +27,8 @@ #include #include +#include "../../support/str_escape.h" + namespace tvm { namespace script { namespace printer { @@ -76,9 +78,10 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra std::ostringstream os; if (!d->metadata.empty()) { if (d->cfg->show_meta) { - os << "metadata = tvm.ir.load_json(" - << SaveJSON(Map(d->metadata.begin(), d->metadata.end())) << ")" - << "\n"; + os << "metadata = tvm.ir.load_json(\"" + << support::StrEscape( + SaveJSON(Map(d->metadata.begin(), d->metadata.end()))) + << "\")\n"; } else { f->stmts.push_back( CommentDoc("Metadata omitted. 
Use show_meta=True in script() method to show it.")); @@ -130,6 +133,11 @@ inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) { return doc; } +/*! \brief Check if a string has multiple lines. */ +inline bool HasMultipleLines(const std::string& str) { + return str.find_first_of('\n') != std::string::npos; +} + inline Optional GetBindingName(const IRDocsifier& d) { return d->cfg->binding_names.empty() ? Optional(NullOpt) : d->cfg->binding_names.back(); } diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index c891ec5a28cf0..9bf0109cace15 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -914,7 +914,13 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string src = this->PrintExpr(op->args[2]); std::string src_offset = this->PrintExpr(op->args[3]); std::string size = this->PrintExpr(op->args[4]); - this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); + // use size of argument list to indicate whether or not to use predicated cp.async + if (op->args.size() == 5) { + this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); + } else { + this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size, + this->PrintExpr(op->args[5])); + } } else if (op->op.same_as(builtin::ptx_commit_group())) { this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n"; } else if (op->op.same_as(builtin::ptx_wait_group())) { diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index 886242efe08cf..b5299b4e4b2a3 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -659,5 +659,36 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr, return asm_code; } +std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, + const std::string& 
bytes, + const std::string& predicate_value) { + std::string predicated_asm_code = R"( + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)({smem_addr})) + ); + int src_bytes = {pred_guard} ? {bytes} : 0; + __asm__ __volatile__( + "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;" + :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes) + ); + } +)"; + Replacer replacer; + replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); + replacer.register_rule("{bytes}", bytes); + replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca"); + replacer.register_rule("{pred_guard}", predicate_value); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + } // namespace codegen } // namespace tvm diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h index c811a1b9c1d6b..1e49b57c1790a 100644 --- a/src/target/source/ptx.h +++ b/src/target/source/ptx.h @@ -92,6 +92,22 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr, const std::string& global_ptr, const std::string& global_elem_offset, const std::string& bytes); +/*! + * \brief Print predicated ptx cp.async assembly string given parameters. + * \param shared_ptr: The pointer to the destination shared memory. + * \param shared_elem_offset: The offset into the shared memory. + * \param global_ptr: The pointer to the global memory. + * \param global_elem_offset: The offset into the global memory. + * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. + * \param predicate_value: The value of predicate `@p`. 
+ */ +std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, + const std::string& bytes, + const std::string& predicate_value); + } // namespace codegen } // namespace tvm diff --git a/src/target/virtual_device.cc b/src/target/virtual_device.cc index 39bb11ff157b9..3842776a6fd4a 100644 --- a/src/target/virtual_device.cc +++ b/src/target/virtual_device.cc @@ -66,13 +66,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); -VirtualDevice::VirtualDevice(DLDeviceType device_type, int virtual_device_id, Target target, +VirtualDevice::VirtualDevice(int device_type_int, int virtual_device_id, Target target, MemoryScope memory_scope) { - ICHECK(!target.defined() || device_type == target->GetTargetDeviceType()) + ICHECK(!target.defined() || device_type_int == target->GetTargetDeviceType()) << "target " << target->ToDebugString() << " has device type " - << target->GetTargetDeviceType() << " but virtual device has device type " << device_type; + << target->GetTargetDeviceType() << " but virtual device has device type " << device_type_int; auto node = make_object(); - node->device_type_int = device_type; + node->device_type_int = device_type_int; node->virtual_device_id = virtual_device_id; node->target = std::move(target); node->memory_scope = std::move(memory_scope); @@ -166,8 +166,8 @@ VirtualDevice VirtualDevice::Default(const VirtualDevice& lhs, const VirtualDevi defaulted_memory_scope); } -VirtualDevice VirtualDeviceCache::Make(DLDeviceType device_type, int virtual_device_id, - Target target, MemoryScope memory_scope) { +VirtualDevice VirtualDeviceCache::Make(int device_type, int virtual_device_id, Target target, + MemoryScope memory_scope) { VirtualDevice prototype(device_type, virtual_device_id, std::move(target), std::move(memory_scope)); if (prototype->IsFullyUnconstrained()) { diff --git a/src/tir/ir/stmt.cc 
b/src/tir/ir/stmt.cc index 355a3b16b8553..1652786cb510e 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -507,6 +507,12 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, << "Cannot store value with " << value.dtype().lanes() << ", expected value with " << index_lanes * buffer_lanes << " (" << index_lanes << " index lanes * " << buffer_lanes << " buffer element lanes)"; + if (buffer->dtype.with_lanes(buffer_lanes * index_lanes) != value.dtype()) { + LOG(FATAL) << "TypeError: dtype mismatch on BufferStore: " // + << "buffer's dtype is `" << buffer->dtype // + << "`, the lanes of indexing are: `" << index_lanes // + << "`, but RHS's dtype is `" << value.dtype() << "`"; + } ObjectPtr node = make_object(); node->buffer = std::move(buffer); diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index 8ee0d054e56dc..2e3c906e89c17 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -47,73 +47,105 @@ class PTXAsyncCopyInjector : public StmtMutator { return StmtMutator::VisitStmt_(attr); } + Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store, bool predicated = false, + PrimExpr predicate_value = PrimExpr()) { + if (load->buffer.scope() == "global") { + ICHECK(load->indices.size() == 1 && store->indices.size() == 1); + ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes()); + + const int indices_lanes = load->indices[0]->dtype.lanes(); + const int bytes = indices_lanes * load->buffer->dtype.bytes(); + + if (bytes == 4 || bytes == 8 || bytes == 16) { + auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation); + auto src_elem_type = GetPointerType(load->buffer->data->type_annotation); + ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) + << "Both store and load buffer should have a pointer type annotation."; + + int index_factor = 1; + if (dst_elem_type.value() != 
src_elem_type.value()) { + // The only case where src and dst have different dtypes is when the dst shared memory + // is a byte buffer generated by merging dynamic shared memory. + ICHECK(store->buffer.scope() == "shared.dyn"); + ICHECK(dst_elem_type.value() == DataType::UInt(8)); + // BufferStore/Load have the "pointer reinterpret" semantics according to their + // "value" dtype. Their "indices" are supposed to be applied after such pointer cast, + // for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value; + // To replace BufferStore/Load with cp.async, we need to multiply the store index by + // the byte size of the "value" dtype, to get the correct offset into the byte buffer. + index_factor = src_elem_type->bytes(); + } + + if (indices_lanes == 1) { + auto src_offset = load->indices[0]; + auto dst_offset = store->indices[0]; + Array args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes)}; + // use arguments size to indicate whether or not to use predicated cp.async + if (predicated) { + args.push_back(predicate_value); + } + return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args)); + } + + // Predicated load don't support vectorized indexing. + if (!predicated) { + // Only some vectorized indexing patterns are supported for now. + auto src_offset = [=]() -> PrimExpr { + if (load->indices[0]->IsInstance()) { + return load->indices[0].as()->base; + } + return PrimExpr(); + }(); + + auto dst_offset = [=]() -> PrimExpr { + if (store->indices[0].as()) { + return store->indices[0].as()->base; + } else if (store->indices[0].as()) { + // The case where the dst buffer is a byte buffer generated by merging dynamic + // shared memory. 
+ // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)] + auto* add = store->indices[0].as(); + if (!add->a->IsInstance()) return PrimExpr(); + if (!add->b->IsInstance()) return PrimExpr(); + return tir::Add(add->a.as()->base, add->b.as()->value); + } + return PrimExpr(); + }(); + + if (src_offset.defined() && dst_offset.defined()) { + return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), + {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes)})); + } + } + } + } + return StmtMutator::VisitStmt_(store); + } + Stmt VisitStmt_(const BufferStoreNode* store) { if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) { if (auto* load = store->value.as()) { - if (load->buffer.scope() == "global") { - ICHECK(load->indices.size() == 1 && store->indices.size() == 1); - ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes()); - - const int indices_lanes = load->indices[0]->dtype.lanes(); - const int bytes = indices_lanes * load->buffer->dtype.bytes(); - - if (bytes == 4 || bytes == 8 || bytes == 16) { - auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation); - auto src_elem_type = GetPointerType(load->buffer->data->type_annotation); - ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) - << "Both store and load buffer should have a pointer type annotation."; - - int index_factor = 1; - if (dst_elem_type.value() != src_elem_type.value()) { - // The only case where src and dst have different dtypes is when the dst shared memory - // is a byte buffer generated by merging dynamic shared memory. - ICHECK(store->buffer.scope() == "shared.dyn"); - ICHECK(dst_elem_type.value() == DataType::UInt(8)); - // BufferStore/Load have the "pointer reinterpret" semantics according to their - // "value" dtype. 
Their "indices" are supposed to be applied after such pointer cast, - // for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value; - // To replace BufferStore/Load with cp.async, we need to multiply the store index by - // the byte size of the "value" dtype, to get the correct offset into the byte buffer. - index_factor = src_elem_type->bytes(); + return InjectPTX(load, store); + } else if (auto* call = store->value.as()) { + // tir.if_then_else is a call to tir::builtin::if_then_else() + if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) { + if (auto* load = call->args[1].as()) { + // Only default value of 0 is supported since 0 is the default value used by cp.async + // ptx. @see section 9.7.8.22.3. of + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations + bool else_value_is_zero = false; + if (auto* b = call->args[2].as()) { + if (auto* f = b->value.as()) { + else_value_is_zero = f->value == 0.0f; + } } - - if (indices_lanes == 1) { - auto src_offset = load->indices[0]; - auto dst_offset = store->indices[0]; - return Evaluate( - Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), - {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes)})); + if (auto* f = call->args[2].as()) { + else_value_is_zero = f->value == 0.0f; } - - // Only some vectorized indexing patterns are supported for now. - auto src_offset = [=]() -> PrimExpr { - if (load->indices[0]->IsInstance()) { - return load->indices[0].as()->base; - } - return PrimExpr(); - }(); - - auto dst_offset = [=]() -> PrimExpr { - if (store->indices[0].as()) { - return store->indices[0].as()->base; - } else if (store->indices[0].as()) { - // The case where the dst buffer is a byte buffer generated by merging dynamic - // shared memory. 
- // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)] - auto* add = store->indices[0].as(); - if (!add->a->IsInstance()) return PrimExpr(); - if (!add->b->IsInstance()) return PrimExpr(); - return tir::Add(add->a.as()->base, add->b.as()->value); - } - return PrimExpr(); - }(); - - if (src_offset.defined() && dst_offset.defined()) { - return Evaluate( - Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), - {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes)})); + if (else_value_is_zero) { + return InjectPTX(load, store, true, call->args[0]); } } } diff --git a/tests/micro/arduino/test_arduino_workflow.py b/tests/micro/arduino/test_arduino_workflow.py index 42874ad6c3499..73cdd9b85d281 100644 --- a/tests/micro/arduino/test_arduino_workflow.py +++ b/tests/micro/arduino/test_arduino_workflow.py @@ -71,7 +71,7 @@ def test_project_folder_structure(project_dir, project): source_dir = project_dir / "src" assert _get_directory_elements(source_dir) == set( - ["model", "standalone_crt", "model.c", "model.h"] + ["model", "standalone_crt", "platform.c", "platform.h"] ) @@ -82,15 +82,15 @@ def test_project_model_integrity(project_dir, project): ) -def test_model_header_templating(project_dir, project): - # Ensure model.h was templated with correct WORKSPACE_SIZE - with (project_dir / "src" / "model.h").open() as f: - model_h = f.read() - workspace_size_defs = re.findall(r"\#define WORKSPACE_SIZE ([0-9]*)", model_h) +def test_model_platform_templating(project_dir, project): + # Ensure platform.c was templated with correct TVM_WORKSPACE_SIZE_BYTES + with (project_dir / "src" / "platform.c").open() as f: + platform_c = f.read() + workspace_size_defs = re.findall(r"\#define TVM_WORKSPACE_SIZE_BYTES ([0-9]*)", platform_c) assert workspace_size_defs assert len(workspace_size_defs) == 1 - # Make sure the WORKSPACE_SIZE we define is a reasonable size. 
We don't want + # Make sure the TVM_WORKSPACE_SIZE_BYTES we define is a reasonable size. We don't want # to set an exact value, as this test shouldn't break if an improvement to # TVM causes the amount of memory needed to decrease. workspace_size = int(workspace_size_defs[0]) diff --git a/tests/micro/arduino/testdata/project.ino b/tests/micro/arduino/testdata/project.ino index ebd1c5e0e6506..d7ef155b33f66 100644 --- a/tests/micro/arduino/testdata/project.ino +++ b/tests/micro/arduino/testdata/project.ino @@ -17,11 +17,12 @@ * under the License. */ -#include "src/model.h" +#include "src/platform.h" #include "src/data/yes.c" #include "src/data/no.c" #include "src/data/unknown.c" #include "src/data/silence.c" +#include "src/standalone_crt/include/tvm/runtime/crt/platform.h" void performInference(int8_t input_data[1960], char *data_name) { int8_t output_data[4]; @@ -41,7 +42,7 @@ void performInference(int8_t input_data[1960], char *data_name) { } void setup() { - TVMInitialize(); + TVMPlatformInitialize(); Serial.begin(115200); } diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index 59c4cab88147a..79e4f46e0f932 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -608,5 +608,41 @@ def test_schedule_build_with_cmsis_dependency(workspace_dir, board, microtvm_deb assert "CMSIS-NN/Include" in cmake_content +@tvm.testing.requires_micro +def test_debugging_enabled(workspace_dir): + """Test debugging enabled for LED. `verbose=True` in project option enables + debugging. For this test a physical board(nucleo_l4r5zi) is used instead of + QEMU since LED config is not available on QEMU. 
+ """ + board = "nucleo_l4r5zi" + project_options = { + "project_type": "host_driven", + "board": board, + "verbose": True, + } + shape = (10,) + dtype = "int8" + x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) + xx = relay.multiply(x, x) + z = relay.add(xx, relay.const(np.ones(shape=shape, dtype=dtype))) + func = relay.Function([x], z) + ir_mod = tvm.IRModule.from_expr(func) + + runtime = Runtime("crt", {"system-lib": True}) + executor = Executor("aot") + target = tvm.micro.testing.get_target("zephyr", board) + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build(ir_mod, target=target, runtime=runtime, executor=executor) + + project = tvm.micro.generate_project( + str(utils.TEMPLATE_PROJECT_DIR), + mod, + workspace_dir / "project", + project_options, + ) + project.build() + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/micro/zephyr/test_zephyr_aot_exec_standalone.py b/tests/micro/zephyr/test_zephyr_aot_exec_standalone.py index 16c1f9e30814a..8c6bc272f0f0b 100644 --- a/tests/micro/zephyr/test_zephyr_aot_exec_standalone.py +++ b/tests/micro/zephyr/test_zephyr_aot_exec_standalone.py @@ -63,7 +63,9 @@ def test_tflite(workspace_dir, board, microtvm_debug, serial_number): "aot", {"unpacked-api": True, "interface-api": "c", "workspace-byte-alignment": 4} ) runtime = Runtime("crt") - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + with tvm.transform.PassContext( + opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": True} + ): lowered = relay.build(relay_mod, target, params=params, runtime=runtime, executor=executor) sample_url = "https://github.com/tlc-pack/web-data/raw/967fc387dadb272c5a7f8c3461d34c060100dbf1/testdata/microTVM/data/keyword_spotting_int8_6.pyc.npy" diff --git a/tests/micro/zephyr/utils.py b/tests/micro/zephyr/utils.py index 26f9d6a10e2d0..fdd873c8e8c35 100644 --- a/tests/micro/zephyr/utils.py +++ 
b/tests/micro/zephyr/utils.py @@ -153,8 +153,8 @@ def generate_project( with tempfile.NamedTemporaryFile() as tar_temp_file: with tarfile.open(tar_temp_file.name, "w:gz") as tf: with tempfile.TemporaryDirectory() as tar_temp_dir: - model_files_path = os.path.join(tar_temp_dir, "include") - os.mkdir(model_files_path) + model_files_path = pathlib.Path(tar_temp_dir) / "include" + model_files_path.mkdir(parents=True) if load_cmsis: loadCMSIS(model_files_path) tf.add( @@ -174,9 +174,9 @@ def generate_project( ) tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir)) - create_header_file("input_data", sample, "include", tf) + create_header_file("input_data", sample, "include/tvm", tf) create_header_file( - "output_data", np.zeros(shape=output_shape, dtype=output_type), "include", tf + "output_data", np.zeros(shape=output_shape, dtype=output_type), "include/tvm", tf ) project, project_dir = build_project( diff --git a/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py b/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py index 99bd273115a7e..1a00e01b60310 100644 --- a/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py +++ b/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py @@ -476,16 +476,16 @@ class ModuleBefore: def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_encoded_3: T.Buffer(112, "uint8"), ethosu_write: T.Buffer(43672, "int8")) -> None: # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "main", "from_legacy_te_schedule": True}) - ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32") - ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32") - ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32") - ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32") - nn = T.var("int32") - nn_1 = T.var("int32") - nn_2 = T.var("int32") - nn_3 = 
T.var("int32") - nn_4 = T.var("int32") - nn_5 = T.var("int32") + ax0_ax1_fused_ax2_fused_ax3_fused = T.int32() + ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32() + ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32() + ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.int32() + nn = T.int32() + nn_1 = T.int32() + nn_2 = T.int32() + nn_3 = T.int32() + nn_4 = T.int32() + nn_5 = T.int32() # body placeholder_d_global = T.decl_buffer([208], "uint8") placeholder_d_global_1 = T.decl_buffer([112], "uint8") @@ -524,16 +524,16 @@ class ModuleAfter: def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_encoded_3: T.Buffer(112, "uint8"), ethosu_write: T.Buffer(43672, "int8")) -> None: # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "main", "from_legacy_te_schedule": True}) - ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32") - ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32") - ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32") - ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32") - nn = T.var("int32") - nn_1 = T.var("int32") - nn_2 = T.var("int32") - nn_3 = T.var("int32") - nn_4 = T.var("int32") - nn_5 = T.var("int32") + ax0_ax1_fused_ax2_fused_ax3_fused = T.int32() + ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32() + ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32() + ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.int32() + nn = T.int32() + nn_1 = T.int32() + nn_2 = T.int32() + nn_3 = T.int32() + nn_4 = T.int32() + nn_5 = T.int32() # body placeholder_d_global = T.decl_buffer([208], "uint8") placeholder_d_global_1 = T.decl_buffer([112], "uint8") @@ -579,15 +579,15 @@ class ModuleBefore: def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_1: T.Buffer(256, "int8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_2: 
T.Buffer(256, "int8"), placeholder_3: T.Buffer(256, "int8"), ethosu_write: T.Buffer(46200, "int8")) -> None: # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "main", "from_legacy_te_schedule": True}) - ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32") - ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32") - ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32") - nn = T.var("int32") - nn_1 = T.var("int32") - nn_2 = T.var("int32") - nn_3 = T.var("int32") - nn_4 = T.var("int32") - nn_5 = T.var("int32") + ax0_ax1_fused_ax2_fused_ax3_fused = T.int32() + ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32() + ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32() + nn = T.int32() + nn_1 = T.int32() + nn_2 = T.int32() + nn_3 = T.int32() + nn_4 = T.int32() + nn_5 = T.int32() # body placeholder_d_d_global = T.decl_buffer([208], "uint8") placeholder_d_d_global_1 = T.decl_buffer([112], "uint8") @@ -629,15 +629,15 @@ class ModuleAfter: def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded: T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"), placeholder_1: T.Buffer(256, "int8"), placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_2: T.Buffer(256, "int8"), placeholder_3: T.Buffer(256, "int8"), ethosu_write: T.Buffer(46200, "int8")) -> None: # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "main", "from_legacy_te_schedule": True}) - ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32") - ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32") - ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32") - nn = T.var("int32") - nn_1 = T.var("int32") - nn_2 = T.var("int32") - nn_3 = T.var("int32") - nn_4 = T.var("int32") - nn_5 = T.var("int32") + ax0_ax1_fused_ax2_fused_ax3_fused = T.int32() + ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32() + ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32() + nn = T.int32() + nn_1 = T.int32() + nn_2 = T.int32() + nn_3 = T.int32() + nn_4 = T.int32() + nn_5 = T.int32() # body 
placeholder_d_d_global = T.decl_buffer([208], "uint8") placeholder_d_d_global_1 = T.decl_buffer([112], "uint8") diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py b/tests/python/contrib/test_ethosu/test_merge_constants.py index 909f9fe67365e..624bef00c7f8b 100644 --- a/tests/python/contrib/test_ethosu/test_merge_constants.py +++ b/tests/python/contrib/test_ethosu/test_merge_constants.py @@ -650,18 +650,18 @@ class InputModule: def main(buffer2: T.Buffer((128,), "uint8"), buffer3: T.Buffer((32,), "uint8"), buffer4: T.Buffer((112,), "uint8"), buffer5: T.Buffer((32,), "uint8"), buffer6: T.Buffer((112,), "uint8"), buffer7: T.Buffer((32,), "uint8"), buffer8: T.Buffer((112,), "uint8"), buffer9: T.Buffer((32,), "uint8")) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - v1a = T.var("int32") - v1b = T.var("int32") - v1c = T.var("int32") - v2a = T.var("int32") - v2b = T.var("int32") - v2c = T.var("int32") - v3a = T.var("int32") - v3b = T.var("int32") - v3c = T.var("int32") - v4a = T.var("int32") - v4b = T.var("int32") - v4c = T.var("int32") + v1a = T.int32() + v1b = T.int32() + v1c = T.int32() + v2a = T.int32() + v2b = T.int32() + v2c = T.int32() + v3a = T.int32() + v3b = T.int32() + v3c = T.int32() + v4a = T.int32() + v4b = T.int32() + v4c = T.int32() buffer1 = T.Buffer([8192], "int8") buffer10 = T.Buffer([2048], "int8") # body @@ -713,14 +713,14 @@ class ReferenceModule: def main(buffer2: T.Buffer((160,), "uint8"), buffer4: T.Buffer((144,), "uint8"), buffer6: T.Buffer((144,), "uint8"), buffer8: T.Buffer((144,), "uint8")) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - v1a = T.var("int32") - v1c = T.var("int32") - v2a = T.var("int32") - v2c = T.var("int32") - v3a = T.var("int32") - v3c = T.var("int32") - v4a = T.var("int32") - v4c = T.var("int32") + v1a = T.int32() + v1c = T.int32() + v2a = T.int32() + 
v2c = T.int32() + v3a = T.int32() + v3c = T.int32() + v4a = T.int32() + v4c = T.int32() buffer1 = T.Buffer([8192], "int8") buffer10 = T.Buffer([2048], "int8") # body diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 0a032843267aa..470a67e86c930 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5398,8 +5398,6 @@ def verify_eyelike(indata, dynamic=False): "test_reduce_sum_negative_axes_keepdims_random", "test_roialign_aligned_true", "test_scatter_elements_with_duplicate_indices", - "test_scatternd_add", - "test_scatternd_multiply", "test_sequence_insert_at_back", "test_sequence_insert_at_front", "test_sequence_map_add_1_sequence_1_tensor", diff --git a/tests/python/integration/test_lower.py b/tests/python/integration/test_lower.py index 1ccdde8b13374..965ab80bebb28 100644 --- a/tests/python/integration/test_lower.py +++ b/tests/python/integration/test_lower.py @@ -136,8 +136,8 @@ def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, handle_c: T.handle) axis_vk * 16 : axis_vk * 16 + 16, ] ) - stride0 = T.var("int32") - stride1 = T.var("int32") + stride0 = T.int32() + stride1 = T.int32() match_buffer_a0 = T.match_buffer( shared_a[ new_axis_vi * 16 : new_axis_vi * 16 + 16, @@ -198,8 +198,8 @@ def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, handle_c: T.handle) axis_vk * 16 : axis_vk * 16 + 16, ] ) - stride0 = T.var("int32") - stride1 = T.var("int32") + stride0 = T.int32() + stride1 = T.int32() match_buffer_b0 = T.match_buffer( shared_b[ new_axis_vj * 16 : new_axis_vj * 16 + 16, @@ -335,8 +335,8 @@ def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, handle_c: T.handle) new_axis_vj * 16 : new_axis_vj * 16 + 16, ] ) - stride0 = T.var("int32") - stride1 = T.var("int32") + stride0 = T.int32() + stride1 = T.int32() wmma_c2 = T.match_buffer( wmma_c[ new_axis_vi * 16 : new_axis_vi * 16 + 16, diff --git 
a/tests/python/relay/aot/test_aot_create_executor_metadata.py b/tests/python/relay/aot/test_aot_create_executor_metadata.py index 1bc79fe2a607a..804738a7866a0 100644 --- a/tests/python/relay/aot/test_aot_create_executor_metadata.py +++ b/tests/python/relay/aot/test_aot_create_executor_metadata.py @@ -53,7 +53,7 @@ def test_create_executor_metadata_single_func(): class Module: @T.prim_func def __tvm_main__( - a: T.handle, output: T.handle, workspace: T.Ptr(T.uint8), constants: T.Ptr(T.uint8) + a: T.handle, output: T.handle, workspace: T.handle("uint8"), constants: T.handle("uint8") ) -> None: # function attr dict T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind": "llvm", "tag": "", "keys": ["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": ["test_device"]}) diff --git a/tests/python/relay/aot/test_pass_aot_lower_main.py b/tests/python/relay/aot/test_pass_aot_lower_main.py index f2455e97a051b..bc58812cd67ca 100644 --- a/tests/python/relay/aot/test_pass_aot_lower_main.py +++ b/tests/python/relay/aot/test_pass_aot_lower_main.py @@ -178,13 +178,13 @@ def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { def func(a: T.handle, output: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": []}) - tmp_read = T.Ptr("uint8", "") + tmp_read = T.handle("uint8", "") # buffer definition tmp_read_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_read) a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16) output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16) # body - tmp_write: T.Ptr(T.uint8) = output_buffer.data + tmp_write: T.handle("uint8") = output_buffer.data tmp_write_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_write) for i in T.serial(140): tmp_write_1[i] = 
T.let(tmp_read, a_buffer.data, tmp_read_1[i]) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index efd37f2ecd22e..225210f4d617a 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1983,6 +1983,7 @@ def verify_scatter_nd_with_stack( ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol) + # TODO(vcchernov): check frameworks' int type requirements. ONNX expects int64 only for indice_dtype in ["uint8", "uint16", "uint32"]: data = np.zeros((2, 2)).astype("int64") indices = np.array([[1, 1, 0], [0, 1, 0]]).astype(indice_dtype) @@ -2009,7 +2010,7 @@ def verify_scatter_nd_with_stack( verify_scatter_nd(data, indices, updates, out, mode="add") verify_scatter_nd_with_stack(data, indices, updates, out) - for mode in ["add", "update"]: + for mode in ["update", "add", "mul", "min", "max"]: indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype( indice_dtype ) @@ -2019,10 +2020,20 @@ def verify_scatter_nd_with_stack( out = data.copy() for i in range(indices.shape[1]): for j in range(updates.shape[1]): - if mode == "add": - out[indices[0, i], indices[1, i], j] += updates[i, j] - elif mode == "update": + if mode == "update": out[indices[0, i], indices[1, i], j] = updates[i, j] + elif mode == "add": + out[indices[0, i], indices[1, i], j] += updates[i, j] + elif mode == "mul": + out[indices[0, i], indices[1, i], j] *= updates[i, j] + elif mode == "min": + out[indices[0, i], indices[1, i], j] = min( + out[indices[0, i], indices[1, i], j], updates[i, j] + ) + elif mode == "max": + out[indices[0, i], indices[1, i], j] = max( + out[indices[0, i], indices[1, i], j], updates[i, j] + ) verify_scatter_nd(data, indices, updates, out, mode) verify_scatter_nd_with_stack(data, indices, updates, out, mode) diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py index 025e44889d63c..ccc34837a05aa 100644 --- 
a/tests/python/topi/python/test_topi_scatter.py +++ b/tests/python/topi/python/test_topi_scatter.py @@ -61,7 +61,7 @@ def check_scatter_nd(data, indices, updates, out, mode="add"): out[0, :] += updates[2, :] check_scatter_nd(data, indices, updates, out) - for mode in ["add", "update"]: + for mode in ["update", "add", "mul", "min", "max"]: updates = np.ones((5, 3)).astype("float64") indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype( "int64" @@ -71,10 +71,20 @@ def check_scatter_nd(data, indices, updates, out, mode="add"): out = data.copy() for i in range(indices.shape[1]): for j in range(updates.shape[1]): - if mode == "add": - out[indices[0, i], indices[1, i], j] += updates[i, j] - elif mode == "update": + if mode == "update": out[indices[0, i], indices[1, i], j] = updates[i, j] + elif mode == "add": + out[indices[0, i], indices[1, i], j] += updates[i, j] + elif mode == "mul": + out[indices[0, i], indices[1, i], j] *= updates[i, j] + elif mode == "min": + out[indices[0, i], indices[1, i], j] = min( + out[indices[0, i], indices[1, i], j], updates[i, j] + ) + elif mode == "max": + out[indices[0, i], indices[1, i], j] = max( + out[indices[0, i], indices[1, i], j], updates[i, j] + ) check_scatter_nd(data, indices, updates, out, mode) diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py index ad970d52c0824..6f66f3a43283b 100644 --- a/tests/python/unittest/test_aot_legalize_packed_call.py +++ b/tests/python/unittest/test_aot_legalize_packed_call.py @@ -35,10 +35,10 @@ def tvm_test_cpacked( @T.prim_func def tir_packed_call() -> None: - A = T.var("handle") - B = T.var("handle") - C = T.var("handle") - device_context = T.var("handle") + A = T.handle() + B = T.handle() + C = T.handle() + device_context = T.handle() # body T.evaluate( T.tvm_call_cpacked( @@ -65,10 +65,10 @@ def tvm_test_cpacked( @T.prim_func def tir_packed_call() -> None: - A = T.var("handle") - B = 
T.var("handle") - C = T.var("handle") - device_context = T.var("handle") + A = T.handle() + B = T.handle() + C = T.handle() + device_context = T.handle() # body T.evaluate( diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index 9f7eee096362e..e19991b3b83a0 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -21,7 +21,7 @@ @T.prim_func def scalar_func(a: T.handle, b: T.handle): - m = T.var("int32") + m = T.int32() n = 100 A = T.match_buffer(a, (n, m)) B = T.match_buffer(b, (n, m)) @@ -73,7 +73,7 @@ def test_domain_touched_vector(): @T.prim_func def func(a: T.handle, b: T.handle): - n = T.var("int32") + n = T.int32() A = T.match_buffer(a, (n * m,)) B = T.match_buffer(b, (n * m,)) diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py index ddd86347c2ec9..c8edebfd3b875 100644 --- a/tests/python/unittest/test_auto_scheduler_feature.py +++ b/tests/python/unittest/test_auto_scheduler_feature.py @@ -209,9 +209,9 @@ def tir_matmul( ) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - A_flat = T.buffer_decl([16384], dtype="float32", data=A.data) - B_flat = T.buffer_decl([16384], dtype="float32", data=B.data) - C_flat = T.buffer_decl([16384], dtype="float32", data=C.data) + A_flat = T.Buffer([16384], dtype="float32", data=A.data) + B_flat = T.Buffer([16384], dtype="float32", data=B.data) + C_flat = T.Buffer([16384], dtype="float32", data=C.data) # body for x, y in T.grid(128, 128): C_flat[x * 128 + y] = T.float32(0) diff --git a/tests/python/unittest/test_cp_async_in_if_then_else.py b/tests/python/unittest/test_cp_async_in_if_then_else.py new file mode 100644 index 0000000000000..08de5ba34da10 --- /dev/null +++ b/tests/python/unittest/test_cp_async_in_if_then_else.py @@ -0,0 +1,238 @@ +# Licensed to 
the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""test the correctness of inject async memory copy from an if_then_else load""" +import tvm +import numpy as np + +from tvm.script import tir as T +import tvm.testing + +expected_cuda_script = r""" +#ifdef _WIN32 + using uint = unsigned int; + using uchar = unsigned char; + using ushort = unsigned short; + using int64_t = long long; + using uint64_t = unsigned long long; +#else + #define uint unsigned int + #define uchar unsigned char + #define ushort unsigned short + #define int64_t long long + #define uint64_t unsigned long long +#endif +extern "C" __global__ void __launch_bounds__(16) main_kernel0(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { + __shared__ float A_shared[64]; + __shared__ float B_shared[64]; + A_shared[((int)threadIdx.x)] = 0.000000e+00f; + B_shared[((int)threadIdx.x)] = 0.000000e+00f; +__asm__ __volatile__("cp.async.commit_group;"); + + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(A_shared + (((int)threadIdx.x) + 16))) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)(A + (((int)threadIdx.x) 
* 14))), "n"(4) + ); + } + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(B_shared + (((int)threadIdx.x) + 16))) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)(B + (((int)threadIdx.x) * 14))), "n"(4) + ); + } +__asm__ __volatile__("cp.async.commit_group;"); + + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(A_shared + (((int)threadIdx.x) + 32))) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)(A + ((((int)threadIdx.x) * 14) + 1))), "n"(4) + ); + } + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(B_shared + (((int)threadIdx.x) + 32))) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)(B + ((((int)threadIdx.x) * 14) + 1))), "n"(4) + ); + } +__asm__ __volatile__("cp.async.commit_group;"); + + for (int i = 0; i < 13; ++i) { + bool cse_var_1 = (i < 12); + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)))) + ); + int src_bytes = cse_var_1 ? 
4 : 0; + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2, %3;" + :: "r"(addr), "l"((void*)(A + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(src_bytes) + ); + } +__asm__ __volatile__("cp.async.commit_group;"); + +__asm__ __volatile__("cp.async.wait_group 5;"); + + __syncthreads(); + C[((((int)threadIdx.x) * 16) + i)] = (A_shared[(((i & 3) * 16) + ((int)threadIdx.x))] + B_shared[(((i & 3) * 16) + ((int)threadIdx.x))]); + __syncthreads(); + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)))) + ); + int src_bytes = cse_var_1 ? 4 : 0; + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2, %3;" + :: "r"(addr), "l"((void*)(B + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(src_bytes) + ); + } +__asm__ __volatile__("cp.async.commit_group;"); + + } +__asm__ __volatile__("cp.async.wait_group 2;"); + + __syncthreads(); + C[((((int)threadIdx.x) * 16) + 13)] = (A_shared[(((int)threadIdx.x) + 16)] + B_shared[(((int)threadIdx.x) + 16)]); +__asm__ __volatile__("cp.async.wait_group 1;"); + + __syncthreads(); + C[((((int)threadIdx.x) * 16) + 14)] = (A_shared[(((int)threadIdx.x) + 32)] + B_shared[(((int)threadIdx.x) + 32)]); +__asm__ __volatile__("cp.async.wait_group 0;"); + + __syncthreads(); + C[((((int)threadIdx.x) * 16) + 15)] = (A_shared[(((int)threadIdx.x) + 48)] + B_shared[(((int)threadIdx.x) + 48)]); +} + +""" + + +generated_code = "" +support_async = True + + +@tvm.register_func +def tvm_callback_cuda_postproc(code): + global generated_code + global support_async + generated_code = code + # return a dummy code so that device < sm80 could build correctly + if not support_async: + ret = "" + for line in code.split("\n"): + ret += line + "\n" + if line.startswith('extern "C" __global__'): + break + ret += "}" + return ret + return code + + +@tvm.testing.requires_cuda +def 
test_cp_async_in_if_then_else(): + global support_async + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # At least sm80 is required + support_async = False + + @T.prim_func + def simple_compute( + A: T.Buffer((16, 14), "float32"), + B: T.Buffer((16, 14), "float32"), + C: T.Buffer((16, 16), "float32"), + ): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 16, + annotations={ + "software_pipeline_stage": [0, 0, 3], + "software_pipeline_order": [0, 2, 1], + "software_pipeline_async_stages": [0], + }, + ): + with T.block("compute"): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + A_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + B_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(A_shared[tx, 0]) + A_shared[tx, 0] = T.if_then_else( + 1 <= i and i < 15, A[tx, i - 1], T.float32(0), dtype="float32" + ) + with T.block(): + T.reads(B[tx, i]) + T.writes(B_shared[tx, 0]) + B_shared[tx, 0] = T.if_then_else( + 1 <= i and i < 15, B[tx, i - 1], T.float32(0), dtype="float32" + ) + with T.block(): + T.reads(A_shared[tx, 0], B_shared[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = A_shared[tx, 0] + B_shared[tx, 0] + + mod = tvm.IRModule.from_expr(simple_compute) + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + tvm.build(mod, target="cuda") + + assert generated_code == expected_cuda_script + + if not support_async: + # avoid return dummy code to other tests + support_async = True + + +if __name__ == "__main__": + test_cp_async_in_if_then_else() diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py index 806ea2d1827bb..11fbeb811ea78 100644 --- a/tests/python/unittest/test_meta_schedule_database.py +++ 
b/tests/python/unittest/test_meta_schedule_database.py @@ -18,19 +18,19 @@ """Test Meta Schedule Database""" import os.path as osp import tempfile -import pytest -from typing import Callable, Optional, List +from typing import Callable, List, Optional +import pytest import tvm import tvm.testing -from tvm.target import Target from tvm import meta_schedule as ms -from tvm.meta_schedule.database import TuningRecord, Workload -from tvm import tir +from tvm import relay, tir from tvm.ir.module import IRModule +from tvm.meta_schedule.database import TuningRecord, Workload from tvm.script import tir as T +from tvm.target import Target from tvm.tir import Schedule -from tvm import relay + # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off @@ -556,6 +556,7 @@ def call_get_top_k(run_secs_list, database, k): "k,expected", [ (0, []), + (1, [[0.0, 2.0]]), (4, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]), ], diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index 59de0b0c570a9..0facc9b961e91 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -399,12 +399,12 @@ def GMMCUDATensorCore( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - s0 = T.var("int32") - s0_1 = T.var("int32") - s0_2 = T.var("int32") - s1 = T.var("int32") - s1_1 = T.var("int32") - s1_2 = T.var("int32") + s0 = T.int32() + s0_1 = T.int32() + s0_2 = T.int32() + s1 = T.int32() + s1_1 = T.int32() + s1_2 = T.int32() # body # with T.block("root") Z_wmma_accumulator = T.alloc_buffer([1024, 1024], dtype="float32", scope="wmma.accumulator") diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py index 
a79498304b2f5..59c7d5441efa3 100644 --- a/tests/python/unittest/test_meta_schedule_runner.py +++ b/tests/python/unittest/test_meta_schedule_runner.py @@ -131,6 +131,23 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s C[vi] = A[vi] + B[vi] +# A huge matmul that must cause timeout in the timeout test below. +@tvm.script.ir_module +class MatmulHugeModule: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (4096, 4096), "float32") + B = T.match_buffer(b, (4096, 4096), "float32") + C = T.match_buffer(c, (4096, 4096), "float32") + for i, j, k in T.grid(4096, 4096, 4096): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @@ -372,22 +389,20 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: def test_meta_schedule_rpc_runner_time_out(): - """Test meta schedule RPC Runner time out""" + """Test meta schedule RPC Runner time out by using a super large workload""" - def initializer(): - @register_func("meta_schedule.runner.test_time_out") - def timeout_session_creator( # pylint: disable=unused-variable - rpc_config: RPCConfig, # pylint: disable=unused-argument - ) -> RPCSession: - time.sleep(2) + builder = LocalBuilder() + builder_inputs = [BuilderInput(MatmulHugeModule, Target("llvm"))] + builder_results = builder.build(builder_inputs) + builder_results[0].artifact_path runner_input = RunnerInput( - "test", + builder_results[0].artifact_path, "llvm", [ - TensorInfo("float32", (MATMUL_N, MATMUL_N)), - TensorInfo("float32", (MATMUL_N, MATMUL_N)), - TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (4096, 4096)), + TensorInfo("float32", (4096, 4096)), + 
TensorInfo("float32", (4096, 4096)), ], ) @@ -408,15 +423,13 @@ def timeout_session_creator( # pylint: disable=unused-variable runner = RPCRunner( rpc_config, evaluator_config, - initializer=initializer, - f_create_session="meta_schedule.runner.test_time_out", ) # Run the module (runner_future,) = runner.run([runner_input]) runner_result = runner_future.result() assert runner_result.error_msg is not None and runner_result.error_msg.startswith( - "RPCRunner: Timeout, killed after" + "RPCRunner: An exception occurred" ) assert runner_result.run_secs is None diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py index ae65cc1a815bd..d09f2a226cba1 100644 --- a/tests/python/unittest/test_meta_schedule_trace_apply.py +++ b/tests/python/unittest/test_meta_schedule_trace_apply.py @@ -637,26 +637,26 @@ class Conv2dInt8_tensorcore_scheduled: def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((1, 1, 1, 256), "int64"), p5: T.Buffer((1, 1, 1, 256), "int64"), p6: T.Buffer((1, 1, 1, 256), "int64"), p7: T.Buffer((), "int32"), p8: T.Buffer(1, "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "uint8")) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_s0 = T.var("int32") - A_s0_1 = T.var("int32") - A_s0_2 = T.var("int32") - A_s0_3 = T.var("int32") - A_s1 = T.var("int32") - A_s1_1 = T.var("int32") - A_s1_2 = T.var("int32") - A_s1_3 = T.var("int32") - B_s0 = T.var("int32") - B_s1 = T.var("int32") - C_s0 = T.var("int32") - C_s0_1 = T.var("int32") - C_s0_2 = T.var("int32") - C_s0_3 = T.var("int32") - C_s0_4 = T.var("int32") - C_s1 = T.var("int32") - C_s1_1 = T.var("int32") - C_s1_2 = T.var("int32") - C_s1_3 = T.var("int32") - C_s1_4 = T.var("int32") + A_s0 = T.int32() + A_s0_1 = T.int32() + A_s0_2 = T.int32() 
+ A_s0_3 = T.int32() + A_s1 = T.int32() + A_s1_1 = T.int32() + A_s1_2 = T.int32() + A_s1_3 = T.int32() + B_s0 = T.int32() + B_s1 = T.int32() + C_s0 = T.int32() + C_s0_1 = T.int32() + C_s0_2 = T.int32() + C_s0_3 = T.int32() + C_s0_4 = T.int32() + C_s1 = T.int32() + C_s1_1 = T.int32() + C_s1_2 = T.int32() + C_s1_3 = T.int32() + C_s1_4 = T.int32() # body # with T.block("root") conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], dtype="int32", scope="shared") diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 734404fb34501..6f79723de456d 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -715,5 +715,30 @@ def test_output_names_many(): } +@tvm.testing.requires_micro +def test_template_files(): + """Check template files in generated model library format.""" + mod = get_conv2d_relay_module() + + executor = Executor("aot", {"unpacked-api": True, "interface-api": "c"}) + runtime = Runtime("crt") + target = tvm.target.target.micro("host") + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + factory = tvm.relay.build(mod, target, runtime=runtime, executor=executor, mod_name="mod") + + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir / "lib.tar" + micro.export_model_library_format(factory, mlf_tar_path) + + tf = tarfile.open(mlf_tar_path) + extract_dir = temp_dir / "extract" + os.mkdir(extract_dir) + tf.extractall(extract_dir) + + assert (extract_dir / "templates" / "crt_config.h.template").is_file() + assert (extract_dir / "templates" / "platform.c.template").is_file() + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 0b6f87b833a32..2598d620bac8f 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ 
b/tests/python/unittest/test_te_create_primfunc.py @@ -199,8 +199,8 @@ def te_multi_output(): @T.prim_func def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - m = T.var("int32") - n = T.var("int32") + m = T.int32() + n = T.int32() A0 = T.match_buffer(a0, (m, n)) A1 = T.match_buffer(a1, (m, n)) B0 = T.match_buffer(b0, (m, n)) @@ -491,8 +491,8 @@ def tir_argmax_idx_val( var_idx: T.handle, var_val: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - m = T.var("int32") - n = T.var("int32") + m = T.int32() + n = T.int32() idx = T.match_buffer(var_idx, [m, n], dtype="int32") val = T.match_buffer(var_val, [m, n], dtype="float32") argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="int32") @@ -538,8 +538,8 @@ def tir_argmax_val_idx( var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - m = T.var("int32") - n = T.var("int32") + m = T.int32() + n = T.int32() val = T.match_buffer(var_val, [m, n], dtype="float32") idx = T.match_buffer(var_idx, [m, n], dtype="int32") argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32") @@ -711,8 +711,8 @@ def tir_resize2d_symbolic( var_resize: T.handle, ): T.func_attr({"global_symbol": "main", "tir.noalias": True}) - oh = T.var("int64") - ow = T.var("int64") + oh = T.int64() + ow = T.int64() resize = T.match_buffer(var_resize, [T.int64(2), T.int64(3), oh, ow], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), oh, ow): with T.block("resize"): diff --git a/tests/python/unittest/test_tir_analysis_oob.py b/tests/python/unittest/test_tir_analysis_oob.py index 83c0294176243..7c8ceed36e107 100644 --- a/tests/python/unittest/test_tir_analysis_oob.py +++ b/tests/python/unittest/test_tir_analysis_oob.py @@ -44,7 +44,7 @@ def bad_store_loop(A: 
T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32" @T.prim_func def unknown_bounds(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32")): - N = T.var("int32") + N = T.int32() for i in range(3): B[0, N] = A[1, i] diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 6f591efc2d2d4..2df644d7e1989 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -16,7 +16,6 @@ # under the License. import pytest - import tvm from tvm import te @@ -153,7 +152,7 @@ def test_stmt_constructor(): buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("uint1"))) buffer = tvm.tir.decl_buffer([16], "uint1", data=buffer_var) - x = tvm.tir.BufferStore(buffer, 1, [10]) + x = tvm.tir.BufferStore(buffer, tvm.tir.IntImm("bool", 1), [10]) assert isinstance(x, tvm.tir.BufferStore) assert x.buffer == buffer assert x.buffer.data == buffer_var diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index f887f8877a221..1ee709191c41d 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -193,11 +193,11 @@ class Module: def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "test_fma", "tir.noalias": True}) - n = T.var("int32") - stride = T.var("int32") - stride_1 = T.var("int32") - stride_2 = T.var("int32") - stride_3 = T.var("int32") + n = T.int32() + stride = T.int32() + stride_1 = T.int32() + stride_2 = T.int32() + stride_3 = T.int32() A_1 = T.match_buffer( A, [n], diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 535e0bb3294f1..5bea77ffe35ae 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -93,8 +93,8 @@ def opaque_access(a: T.handle, 
b: T.handle) -> None: ) for i, j, k in T.grid(64, 2, 8): with T.block(): - Bs_0 = T.var("int32") - Bs_1 = T.var("int32") + Bs_0 = T.int32() + Bs_1 = T.int32() T.reads([]) T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) sub_B = T.match_buffer( @@ -154,8 +154,8 @@ def high_dim_opaque_access(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64)) for i, j, k in T.grid(16, 2, 4): with T.block(): - As_0 = T.var("int32") - As_1 = T.var("int32") + As_0 = T.int32() + As_1 = T.int32() T.reads([]) T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) sub_A = T.match_buffer( @@ -200,8 +200,8 @@ def high_dim_opaque_access_with_source_strides(a: T.handle) -> None: A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1]) for i, j, k in T.grid(16, 2, 4): with T.block(): - As_0 = T.var("int32") - As_1 = T.var("int32") + As_0 = T.int32() + As_1 = T.int32() T.reads([]) T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) sub_A = T.match_buffer( @@ -254,8 +254,8 @@ def recursive_match(a: T.handle, b: T.handle) -> None: B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], ] ) - As_0 = T.var("int32") - As_1 = T.var("int32") + As_0 = T.int32() + As_1 = T.int32() sub_A = T.match_buffer( A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], (16, 16), @@ -276,8 +276,8 @@ def recursive_match(a: T.handle, b: T.handle) -> None: sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], ] ) - Ass_0 = T.var("int32") - Ass_1 = T.var("int32") + Ass_0 = T.int32() + Ass_1 = T.int32() sub_sub_A = T.match_buffer( sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], (4, 4), @@ -355,8 +355,8 @@ def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) -> None: with T.block(): T.reads([]) T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) - Bs_0 = T.var("int32") - Bs_1 = T.var("int32") + Bs_0 = T.int32() + Bs_1 = T.int32() sub_A = T.match_buffer(A[i * m : i * m + m, 0:m], (m, m), offset_factor=1) sub_B = T.match_buffer( B[i * n : i * n + 2, 0 : m * 4], (2, m * 4), 
strides=[Bs_0, Bs_1], offset_factor=1 @@ -470,7 +470,7 @@ def fail_buffer_bind(a: T.handle) -> None: A = T.match_buffer(a, (8, 8)) for i, j in T.grid(8, 2): with T.block(): - stride = T.var("int32") + stride = T.int32() sub_A = T.match_buffer( A[i, j * 4 : j * 4 + 4], (1, 4), strides=[stride, stride], offset_factor=1 ) diff --git a/tests/python/unittest/test_tir_renew_defs.py b/tests/python/unittest/test_tir_renew_defs.py index e14cd5a89832c..e01f5ecb12ead 100644 --- a/tests/python/unittest/test_tir_renew_defs.py +++ b/tests/python/unittest/test_tir_renew_defs.py @@ -88,8 +88,8 @@ def test_match_buffer(): # A and B should be remapped def func_match_buffer(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")): with T.block("root"): - s = T.var("int32") - e = T.var("int32") + s = T.int32() + e = T.int32() # A0 should be remapped A0 = T.match_buffer( A[0:128, 0:128], @@ -157,7 +157,7 @@ def _get_buffer_store_buffer(f): def test_symbolic_func(): @T.prim_func def symbolic_func(a: T.handle, b: T.handle, n: T.int32): - m = T.var("int32") + m = T.int32() A = T.match_buffer(a, (n, m)) B = T.match_buffer(b, (n, m * 2)) for i, j in T.grid(n, m): diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index 766cc3f8671c0..199e822e84b88 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -954,7 +954,7 @@ def argmax_split_body_bufferstore_value_unbound_var( argmax_v0: T.Buffer((128,), "int32"), argmax_v1: T.Buffer((128,), "float32"), ) -> None: - v_unbound = T.var("int32") + v_unbound = T.int32() for i0, i1_0, i1_1 in T.grid(128, 4, 32): with T.block("argmax"): i = T.axis.spatial(128, i0) diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index 143cf87d04e14..fcb4bacbba7b7 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ 
b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -195,9 +195,9 @@ def tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: ] ) T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A_elem_offset = T.var("int32") - B_elem_offset = T.var("int32") - C_elem_offset = T.var("int32") + A_elem_offset = T.int32() + B_elem_offset = T.int32() + C_elem_offset = T.int32() A_sub = T.match_buffer( A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], [16, 16], @@ -267,9 +267,9 @@ def tensorized_batch_matmul_mma( B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], ) T.writes(C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A_elem_offset = T.var("int32") - B_elem_offset = T.var("int32") - C_elem_offset = T.var("int32") + A_elem_offset = T.int32() + B_elem_offset = T.int32() + C_elem_offset = T.int32() A_sub = T.match_buffer( A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], (16, 16), @@ -429,9 +429,9 @@ def annotated_tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: ] ) T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A_elem_offset = T.var("int32") - B_elem_offset = T.var("int32") - C_elem_offset = T.var("int32") + A_elem_offset = T.int32() + B_elem_offset = T.int32() + C_elem_offset = T.int32() A_sub = T.match_buffer( A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], [16, 16], @@ -745,9 +745,9 @@ def tensorized_matmul_int64_shape( ] ) T.writes(C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)]) - A_elem_offset = T.var("int64") - B_elem_offset = T.var("int64") - C_elem_offset = T.var("int64") + A_elem_offset = T.int64() + B_elem_offset = T.int64() + C_elem_offset = T.int64() A_sub = T.match_buffer( A[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)], [T.int64(16), T.int64(16)], diff --git a/tests/python/unittest/test_tir_specialize.py b/tests/python/unittest/test_tir_specialize.py 
index 72666a89ebcb0..ebae827ef5add 100644 --- a/tests/python/unittest/test_tir_specialize.py +++ b/tests/python/unittest/test_tir_specialize.py @@ -22,7 +22,7 @@ @T.prim_func def matmul(a: T.handle, b: T.handle, c: T.handle, n: T.int32) -> None: - m = T.var("int32") + m = T.int32() A = T.match_buffer(a, [m, n]) B = T.match_buffer(b, [m, n]) C = T.match_buffer(c, [m, m]) @@ -51,7 +51,7 @@ def matmul_128(a: T.handle, b: T.handle, c: T.handle) -> None: @T.prim_func def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None: - m = T.var("int32") + m = T.int32() A = T.match_buffer(a, [m, 128]) B = T.match_buffer(b, [m, 128]) C = T.match_buffer(c, [m, m]) @@ -66,8 +66,8 @@ def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None: @T.prim_func def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None: - x = T.var("int32") - m = T.var("int32") + x = T.int32() + m = T.int32() A = T.match_buffer(a, [m, x * 8]) B = T.match_buffer(b, [m, x * 8]) C = T.match_buffer(c, [m, m]) @@ -82,8 +82,8 @@ def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None: @T.prim_func def element_wise(a: T.handle, c: T.handle) -> None: - m = T.var("int32") - n = T.var("int32") + m = T.int32() + n = T.int32() A = T.match_buffer(a, (m, n), "float32") C = T.match_buffer(c, (m, n), "float32") @@ -119,7 +119,7 @@ def element_wise_128_64(a: T.handle, c: T.handle) -> None: @T.prim_func def element_wise_128_n(a: T.handle, c: T.handle) -> None: - n = T.var("int32") + n = T.int32() A = T.match_buffer(a, (128, n), "float32") C = T.match_buffer(c, (128, n), "float32") B = T.alloc_buffer((128, n), "float32") @@ -170,7 +170,7 @@ def mem_copy_m_n_p_n(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int3 @T.prim_func def param_in_arith_exprs(a: T.handle, b: T.handle) -> None: - n = T.var("int32") + n = T.int32() A = T.match_buffer(a, [n // 8, 8], "int32") B = T.match_buffer(b, [n], "int32") for i in range(n - 1): @@ -181,7 +181,7 @@ def param_in_arith_exprs(a: T.handle, b: T.handle) -> 
None: @T.prim_func def param_in_arith_exprs_n_16(a: T.handle, b: T.handle) -> None: - n = T.var("int32") + n = T.int32() A = T.match_buffer(a, [2, 8], "int32") B = T.match_buffer(b, [16], "int32") for i in range(15): diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index 113d9f047478f..1755a66ec9fb2 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -359,7 +359,7 @@ def func_distributivity_expected( i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: B = T.Buffer((50,), "int32") - cse_var_1 = T.var("int32") + cse_var_1 = T.int32() with T.let(cse_var_1, x * y + x * z): B[i1] = cse_var_1 B[i2] = cse_var_1 @@ -377,7 +377,7 @@ def func_associativity_expected( i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: B = T.Buffer((50,), "int32") - cse_var_1 = T.var("int32") + cse_var_1 = T.int32() with T.let(cse_var_1, (x + y) + z): B[i1] = cse_var_1 B[i2] = cse_var_1 diff --git a/tests/python/unittest/test_tir_transform_hoist_expression.py b/tests/python/unittest/test_tir_transform_hoist_expression.py index 77862f64d6291..ca37915597a5c 100644 --- a/tests/python/unittest/test_tir_transform_hoist_expression.py +++ b/tests/python/unittest/test_tir_transform_hoist_expression.py @@ -447,7 +447,7 @@ class TestHoistLetExpr(BaseBeforeAfter): @T.prim_func def before(A: T.Buffer((4, 4), "float32")): for i, j in T.grid(4, 4): - x = T.var("float32") + x = T.float32() A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32")) @T.prim_func @@ -466,7 +466,7 @@ class TestSuppressHoistLetExpr(BaseBeforeAfter): @T.prim_func def before(A: T.Buffer((4, 4), "float32")): for i, j in T.grid(4, 4): - x = T.var("float32") + x = T.float32() A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32")) expected = before diff --git 
a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index 05d71de5bca69..758a395da6d77 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -424,7 +424,7 @@ def test_buffer_conditional_lowering(): """ @T.prim_func - def before(A: T.Ptr("float32")): + def before(A: T.handle("float32")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i in range(1): A_1 = T.Buffer((1,), data=A) diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 29623b498f432..39009164e708f 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -139,7 +139,7 @@ def main(): T.func_attr({"from_legacy_te_schedule": True}) # If a pointer defined using a LetStmt, - A_data: T.Ptr("int32") = T.call_extern("dummy_extern_function", dtype="handle") + A_data: T.handle("int32") = T.call_extern("dummy_extern_function", dtype="handle") # and a buffer is backed by that pointer, A = T.decl_buffer([1], dtype="float32", data=A_data) diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 4766022121df2..c46754fb17422 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -689,12 +689,12 @@ class TestLetBufferRewrite(BaseCompare): """ def before() -> None: - A_data: T.Ptr("int32") = T.call_extern("dummy_func", dtype="handle") + A_data: T.handle("int32") = T.call_extern("dummy_func", dtype="handle") A = T.Buffer([8], "int32", data=A_data) A[0:8] = T.broadcast(42, 8) def expected() -> None: - A_data: T.Ptr("int32x8") 
= T.call_extern("dummy_func", dtype="handle") + A_data: T.handle("int32x8") = T.call_extern("dummy_func", dtype="handle") A = T.Buffer([1], "int32x8", data=A_data) A[0] = T.broadcast(42, 8) diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 5bbedd3492598..58f37f04967d9 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -144,20 +144,20 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None: @tvm.script.ir_module class LinearStructurePlanned: @T.prim_func - def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr("uint8"), slow_memory_1_var: T.Ptr("uint8"), output: T.handle) -> None: + def __tvm_main__(input: T.handle, fast_memory_0_var: T.handle("uint8"), slow_memory_1_var: T.handle("uint8"), output: T.handle) -> None: fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_9_let: T.Ptr("int8") = T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle") - sid_8_let: T.Ptr("int8") = T.address_of(slow_memory_1_buffer_var[0], dtype="handle") + sid_9_let: T.handle("int8") = T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle") + sid_8_let: T.handle("int8") = T.address_of(slow_memory_1_buffer_var[0], dtype="handle") T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) 
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8_let, output, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) @T.prim_func - def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr("uint8"), slow_memory_7_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.handle("uint8"), slow_memory_7_var: T.handle("uint8")) -> None: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8") T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16") fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) @@ -174,7 +174,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T_cast_7[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3] = T.cast(tensor_2_let[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3], "int16") @T.prim_func - def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr("uint8"), slow_memory_3_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.handle("uint8"), slow_memory_3_var: T.handle("uint8")) -> None: placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8") placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") @@ -185,7 +185,7 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T 
T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.Ptr("uint8"), slow_memory_5_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.handle("uint8"), slow_memory_5_var: T.handle("uint8")) -> None: placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16") placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32") @@ -380,7 +380,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @tvm.script.ir_module class ResnetStructurePlanned: @T.prim_func - def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle("uint8")) -> None: placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") @@ -390,7 +390,7 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") 
+ placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle("uint8")) -> None: placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") @@ -414,7 +414,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3_let[ax3_inner_4] + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8") @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.handle("uint8")) 
-> None: placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") @@ -437,7 +437,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3] = T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2_let[ax3_inner_3] + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136 @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle("uint8")) -> None: placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") @@ -459,7 +459,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_let[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func - def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.Ptr("uint8")) -> None: + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: 
T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle("uint8")) -> None: placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") @@ -481,15 +481,15 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1_let[ax3_inner_2] + placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") @T.prim_func - def __tvm_main__(input: T.handle, global_workspace_0_var: T.Ptr("uint8"), output: T.handle) -> None: + def __tvm_main__(input: T.handle, global_workspace_0_var: T.handle("uint8"), output: T.handle) -> None: global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body T.attr("default", "device_id", 0) T.attr("default", "device_type", 1) - sid_2_let: T.Ptr("int8") = T.address_of(global_workspace_0_buffer_var[5760000], dtype="handle") - sid_6_let: T.Ptr("int8") = T.address_of(global_workspace_0_buffer_var[0], dtype="handle") - sid_7_let: T.Ptr("int8") = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") - sid_8_let: T.Ptr("int8") = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") + sid_2_let: T.handle("int8") = T.address_of(global_workspace_0_buffer_var[5760000], dtype="handle") + sid_6_let: T.handle("int8") = T.address_of(global_workspace_0_buffer_var[0], dtype="handle") + sid_7_let: T.handle("int8") = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") + sid_8_let: T.handle("int8") = T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle") 
T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32")) T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32")) @@ -557,7 +557,7 @@ def __tvm_main__(input: T.handle, output: T.handle) -> None: @tvm.script.ir_module class TensorIntrinStructurePlanned: @T.prim_func - def tensor_intrin_primfunc(global_workspace_1_var: T.Ptr("uint8")) -> None: + def tensor_intrin_primfunc(global_workspace_1_var: T.handle("uint8")) -> None: global_workspace_1_buffer_var = T.match_buffer( global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 ) @@ -576,7 +576,7 @@ def tensor_intrin_primfunc(global_workspace_1_var: T.Ptr("uint8")) -> None: @T.prim_func def __tvm_main__( - input: T.handle, global_workspace_1_var: T.Ptr("uint8"), output: T.handle + input: T.handle, global_workspace_1_var: T.handle("uint8"), output: T.handle ) -> None: global_workspace_1_buffer_var = T.match_buffer( global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index d2f275ac3d5f8..2713669bd3c30 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -511,7 +511,7 @@ def test_report_error_root_block(): def test_load_var(): def load_var_multiple() -> None: - d = T.var("float32") + d = 
T.float32() d[2] = d[2, 1] # error cannot provide two indices to load check_error(load_var_multiple, 3) @@ -519,7 +519,7 @@ def load_var_multiple() -> None: def test_store_var(): def store_var_multiple() -> None: - d = T.var("float32") + d = T.float32() d[2, 1] = d[1] # error cannot provide two indices to store check_error(store_var_multiple, 3) diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 85d2e808b3d83..889f0c9eda33b 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -52,7 +52,7 @@ def test_ir_builder_tir_primfunc_complete(): with IRBuilder() as ib: with T.prim_func(): T.arg("a", T.handle()) - T.arg("b", T.var("int64")) + T.arg("b", T.int64()) T.arg("c", T.Buffer((128, 128), "float32")) d = T.arg("d", T.handle()) e = T.arg("e", T.Buffer((1024,), "int8")) @@ -119,12 +119,12 @@ def test_ir_builder_tir_block_base(): def test_ir_builder_tir_block_complete(): with IRBuilder() as ib: - a = T.var("int64", "a") + a = T.int64() b = T.Buffer((128, 128), "float32") c = T.Buffer((128, 128), "float32") - d = T.var("int32", "d") + d = T.int32() e = T.Buffer((128, 128), "float32") - f = T.var("int32", "f") + f = T.int32() with T.block("block"): T.where(a > 1) T.reads(b[0:16, 0:16]) @@ -169,10 +169,10 @@ def test_ir_builder_tir_block_complete(): def test_ir_builder_tir_axis(): with IRBuilder() as ib: - a = T.var("int32", "a") - b = T.var("int32", "b") - c = T.var("int32", "c") - d = T.var("int32", "d") + a = T.int32() + b = T.int32() + c = T.int32() + d = T.int32() with T.block("block"): T.axis.spatial(8, a) T.axis.reduce(16, b) @@ -269,15 +269,13 @@ def test_ir_builder_tir_for(): def test_ir_builder_tir_assert(): with IRBuilder() as ib: - with T.Assert(T.var("int32", name="a") == 0, message="a is 0"): + with T.Assert(T.int32() == 0, message="a is 0"): T.evaluate(0) # the assert generated by IRBuilder assert_actual 
= ib.get() # the expected assert statement - assert_expected = tir.AssertStmt( - T.var("int32", name="a") == 0, tir.StringImm("a is 0"), tir.Evaluate(0) - ) + assert_expected = tir.AssertStmt(T.int32() == 0, tir.StringImm("a is 0"), tir.Evaluate(0)) # Check if the generated ir is expected assert_structural_equal(assert_actual, assert_expected, map_free_vars=True) @@ -285,13 +283,13 @@ def test_ir_builder_tir_assert(): def test_ir_builder_tir_let(): with IRBuilder() as ib: - with T.let(T.var("int32", name="a"), tir.IntImm("int32", 2)): + with T.let(T.int32(), tir.IntImm("int32", 2)): T.evaluate(0) # the let binding generated by IRBuilder let_actual = ib.get() # the expected Let statement - let_expected = tir.LetStmt(T.var("int32", name="a"), tir.IntImm("int32", 2), tir.Evaluate(0)) + let_expected = tir.LetStmt(T.int32(), tir.IntImm("int32", 2), tir.Evaluate(0)) # Check if the generated ir is expected assert_structural_equal(let_actual, let_expected, map_free_vars=True) @@ -381,7 +379,7 @@ def test_ir_builder_tir_allocate_const(): def test_ir_builder_tir_while(): with IRBuilder() as ib: - with T.While(T.var("int32", "x") > 0): + with T.While(T.int32() > 0): T.evaluate(0) # the while generated by IRBuilder @@ -396,7 +394,7 @@ def test_ir_builder_tir_while(): def test_ir_builder_tir_if_then_else(): with IRBuilder() as ib: - with T.If(T.var("int32", "c") < 12): + with T.If(T.int32() < 12): with T.Then(): T.evaluate(T.int32(0)) with T.Else(): @@ -418,7 +416,7 @@ def test_ir_builder_tir_if_then_else(): def test_ir_builder_tir_buffer_store(): buffer_a = T.Buffer((10, 10), "float32") - i = T.var("int32", "x") + i = T.int32() with IRBuilder() as ib: T.buffer_store(buffer_a, 0.1, [0, i]) diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index e96ae4da8c2e4..20be6d1498086 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -40,7 +40,7 @@ 
def test_tir_buffer_proxy(): def test_tir_ptr_proxy(): - ptr_0 = T.Ptr("int32", "global") + ptr_0 = T.handle("int32", "global") assert ( isinstance(ptr_0, tir.Var) and ptr_0.dtype == "handle" @@ -49,7 +49,7 @@ def test_tir_ptr_proxy(): and ptr_0.type_annotation.storage_scope == "global" ) - ptr_1 = T.Ptr("float32", "shared") + ptr_1 = T.handle("float32", "shared") assert ( isinstance(ptr_1, tir.Var) and ptr_1.dtype == "handle" diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 6f96b3a3dd31a..13aaacb3b7584 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -117,9 +117,9 @@ def test_block_realize(): _assert_print( obj, """ -i = T.var("int32") -j = T.var("int32") -k = T.var("int32") +i = T.int32() +j = T.int32() +k = T.int32() with T.block("block"): vi = T.axis.spatial(128, i) vj = T.axis.spatial(64, j) @@ -248,13 +248,13 @@ def test_for(): def test_let_stmt(): with IRBuilder() as ib: - with T.let(T.var("float32"), T.float32(10)): + with T.let(T.float32(), T.float32(10)): T.evaluate(0) obj = ib.get() _assert_print( obj, """ -v = T.var("float32") +v = T.float32() with T.let(v, T.float32(10)): T.evaluate(0) """, @@ -291,14 +291,14 @@ def test_assert_stmt(): def test_while(): with IRBuilder() as ib: - x = T.var("int32") + x = T.int32() with T.While(x < 10): T.evaluate(0) obj = ib.get() _assert_print( obj, """ -v = T.var("int32") +v = T.int32() while v < 10: T.evaluate(0) """, @@ -410,7 +410,7 @@ def test_seq_stmt(): def test_if_then_else(): with IRBuilder() as ib: - with T.If(T.var("int32") == 1): + with T.If(T.int32() == 1): with T.Then(): T.evaluate(0) @@ -418,7 +418,7 @@ def test_if_then_else(): _assert_print( obj, """ -v = T.var("int32") +v = T.int32() if v == 1: T.evaluate(0) """, @@ -458,7 +458,7 @@ def test_var(): _assert_print( a, """ -a = T.var("float32") +a = T.float32() a""", ) @@ -468,7 +468,7 @@ def 
test_size_var(): _assert_print( a, """ -a = T.var("float32") +a = T.float32() a""", ) @@ -478,7 +478,7 @@ def test_iter_var(): _assert_print( a, """ -a = T.var("int32") +a = T.int32() T.iter_var(a, T.Range(0, 8), "DataPar", "") """, ) @@ -494,7 +494,7 @@ def test_cast(): _assert_print( obj, """ -a = T.var("float32") +a = T.float32() T.Cast("float64", a) """, ) @@ -521,15 +521,15 @@ def test_binary_arith(): obj = op(a, b) if sign.isalpha(): expected = """ -a = T.var("float32") -b = T.var("float32") +a = T.float32() +b = T.float32() T.{}(a, b)""".format( sign ) else: expected = """ -a = T.var("float32") -b = T.var("float32") +a = T.float32() +b = T.float32() a {} b""".format( sign ) @@ -537,28 +537,28 @@ def test_binary_arith(): def test_logical(): - a = T.var("bool", "a") - b = T.var("bool", "b") + a = tir.Var("a", "bool") + b = tir.Var("b", "bool") _assert_print( tir.And(a, b), """ -a = T.var("bool") -b = T.var("bool") +a = T.bool() +b = T.bool() a and b """, ) _assert_print( tir.Or(a, b), """ -a = T.var("bool") -b = T.var("bool") +a = T.bool() +b = T.bool() a or b """, ) _assert_print( tir.Not(a), """ -a = T.var("bool") +a = T.bool() not a """, ) @@ -579,7 +579,7 @@ def test_ramp(): _assert_print( obj, """ -a = T.var("int32") +a = T.int32() T.Ramp(a, 1, 32) """, ) @@ -601,7 +601,7 @@ def test_let_expr(): _assert_print( obj, """ -x = T.var("int32") +x = T.int32() T.let(x, 1, x + 1) """, ) @@ -674,7 +674,7 @@ def test_prim_type(): def test_pointer_type(): obj = ir.PointerType(ir.PrimType("int32"), "global") - _assert_print(obj, 'T.Ptr("int32", "global")') + _assert_print(obj, 'T.handle("int32", "global")') def test_tuple_type(): diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 05a3270d158b5..48a59994690b8 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -204,30 +204,30 @@ def mmult( arg2: T.handle = T.tvm_struct_get(args, 
2, 12, dtype="handle") arg2_code: T.int32 = buf_type_ids[2] - A_data: T.Ptr("int32") = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + A_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 1, dtype="handle") T.attr(A_data, "storage_alignment", 128) A = T.Buffer([1024 * 1024], dtype="int32", data=A_data) - buf0_shape_data: T.Ptr("int32") = T.tvm_struct_get(arg0, 0, 2, dtype="handle") + buf0_shape_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 2, dtype="handle") buf0_shape = T.Buffer([2], dtype="int32", data=buf0_shape_data) - buf0_strides_data: T.Ptr("int32") = T.tvm_struct_get(arg0, 0, 3, dtype="handle") + buf0_strides_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 3, dtype="handle") buf0_strides = T.Buffer([2], dtype="int32", data=buf0_strides_data) dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") - B_data: T.Ptr("int32") = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + B_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 1, dtype="handle") T.attr(B_data, "storage_alignment", 128) B = T.Buffer([1024 * 1024], dtype="int32", data=B_data) - buf1_shape_data: T.Ptr("int32") = T.tvm_struct_get(arg1, 0, 2, dtype="handle") + buf1_shape_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 2, dtype="handle") buf1_shape = T.Buffer([2], dtype="int32", data=buf1_shape_data) - buf1_strides_data: T.Ptr("int32") = T.tvm_struct_get(arg1, 0, 3, dtype="handle") + buf1_strides_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 3, dtype="handle") buf1_strides = T.Buffer([2], dtype="int32", data=buf1_strides_data) - C_data: T.Ptr("int32") = T.tvm_struct_get(arg2, 0, 1, dtype="handle") + C_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 1, dtype="handle") T.attr(C_data, "storage_alignment", 128) C = T.Buffer([1024 * 1024], dtype="int32", data=C_data) - buf2_shape_data: T.Ptr("int32") = T.tvm_struct_get(arg2, 0, 2, dtype="handle") + buf2_shape_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 2, dtype="handle") buf2_shape = T.Buffer([2], dtype="int32", 
data=buf2_shape_data) - buf2_strides_data: T.Ptr("int32") = T.tvm_struct_get(arg2, 0, 3, dtype="handle") + buf2_strides_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 3, dtype="handle") buf2_strides = T.Buffer([2], dtype="int32", data=buf2_strides_data) assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( @@ -2238,7 +2238,7 @@ def opt_conv_tensorcore_mod_host( } ) # body - stack_tcode_data: T.Ptr("int32") = T.tvm_stack_alloca("arg_tcode", 10, dtype="handle") + stack_tcode_data: T.handle("int32") = T.tvm_stack_alloca("arg_tcode", 10, dtype="handle") stack_tcode = T.Buffer([9], "int32", data=stack_tcode_data) stack_value: T.handle = T.tvm_stack_alloca("arg_value", 10, dtype="handle") assert num_args == 3, "default_function: num_args should be 3" @@ -2251,25 +2251,25 @@ def opt_conv_tensorcore_mod_host( A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle") T.attr(A, "storage_alignment", 128) - arg0_shape_data: T.Ptr("int64") = T.tvm_struct_get(arg0, 0, 2, dtype="handle") + arg0_shape_data: T.handle("int64") = T.tvm_struct_get(arg0, 0, 2, dtype="handle") arg0_shape = T.Buffer([6], "int64", data=arg0_shape_data) - arg0_strides_data: T.Ptr("int64") = T.tvm_struct_get(arg0, 0, 3, dtype="handle") + arg0_strides_data: T.handle("int64") = T.tvm_struct_get(arg0, 0, 3, dtype="handle") arg0_strides = T.Buffer([6], "int64", data=arg0_strides_data) dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") W: T.handle = T.tvm_struct_get(arg1, 0, 1, dtype="handle") T.attr(W, "storage_alignment", 128) - arg1_shape_data: T.Ptr("int64") = T.tvm_struct_get(arg1, 0, 2, dtype="handle") + arg1_shape_data: T.handle("int64") = T.tvm_struct_get(arg1, 0, 2, dtype="handle") arg1_shape = T.Buffer([6], "int64", data=arg1_shape_data) - arg1_strides_data: T.Ptr("int64") = T.tvm_struct_get(arg1, 0, 3, dtype="handle") + arg1_strides_data: T.handle("int64") = T.tvm_struct_get(arg1, 0, 3, dtype="handle") arg1_strides = T.Buffer([6], "int64", data=arg1_strides_data) 
Conv: T.handle = T.tvm_struct_get(arg2, 0, 1, dtype="handle") T.attr(Conv, "storage_alignment", 128) - arg2_shape_data: T.Ptr("int64") = T.tvm_struct_get(arg2, 0, 2, dtype="handle") + arg2_shape_data: T.handle("int64") = T.tvm_struct_get(arg2, 0, 2, dtype="handle") arg2_shape = T.Buffer([6], "int64", data=arg2_shape_data) - arg2_strides_data: T.Ptr("int64") = T.tvm_struct_get(arg2, 0, 3, dtype="handle") + arg2_strides_data: T.handle("int64") = T.tvm_struct_get(arg2, 0, 3, dtype="handle") arg2_strides = T.Buffer([6], "int64", data=arg2_strides_data) assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( @@ -2904,10 +2904,10 @@ def constant_folding(a: T.handle) -> None: def simplify_bracket(): @T.prim_func def simplify_bracket() -> None: - a = T.var("int32") - b = T.var("int32") - c = T.var("int32") - d = T.var("int32") + a = T.int32() + b = T.int32() + c = T.int32() + d = T.int32() T.evaluate(a + b * (c + d)) return simplify_bracket @@ -3039,8 +3039,8 @@ def multiple_commreducer() -> None: def func_div_mod(): @T.prim_func def func_div_mod(): - a = T.var("int32") - b = T.var("int32") + a = T.int32() + b = T.int32() T.evaluate(a // b) T.evaluate(a % b) T.evaluate(T.truncmod(a, b)) @@ -3145,7 +3145,7 @@ def func(A: T.Buffer(1, "int32")): def func_T_ptr_let_statement(): @T.prim_func def func_T_ptr_let_statement( - args: T.handle, arg_type_ids_handle: T.Ptr("int32"), num_args: T.int32 + args: T.handle, arg_type_ids_handle: T.handle("int32"), num_args: T.int32 ) -> None: # The T.Ptr declaration in the parameter list should parse # correctly, and should be usable as the data pointer in a buffer. @@ -3157,14 +3157,14 @@ def func_T_ptr_let_statement( # Functions that return a "handle" can be assigned to a T.Ptr # variable. A variable annotated with T.Ptr still has dtype of # T.handle, but has type annotation as a pointer type. 
- A_data: T.Ptr("float32") = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + A_data: T.handle("float32") = T.tvm_struct_get(arg0, 0, 1, dtype="handle") # The buffer declaration has a data pointer defined earlier in # this function. It should only be defined after the data pointer # has been defined, and should not be hoisted into the header of # the function as other buffer_decl statements can be. A = T.Buffer([1024], dtype="float32", data=A_data) - B_data: T.Ptr("float32") = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + B_data: T.handle("float32") = T.tvm_struct_get(arg1, 0, 1, dtype="handle") B = T.Buffer([1024], dtype="float32", data=B_data) B[0] = A[0] @@ -3266,13 +3266,13 @@ def string_annotation_of_special_chars(): def pointer_type(): @T.prim_func - def func_with_ptr_type_annotations(x: T.Ptr("int32"), y: T.Ptr("int32", "shared")): + def func_with_ptr_type_annotations(x: T.handle("int32"), y: T.handle("int32", "shared")): xx_data = T.allocate([16], "int32", "global") xx = T.Buffer(shape=[16], dtype="int32", scope="global", data=xx_data) yy_data = T.allocate([16], "int32", "shared") yy = T.Buffer(shape=[16], dtype="int32", scope="shared", data=yy_data) - a: T.Ptr("int32") = T.address_of(xx[0], dtype="handle") - b: T.Ptr("int32", "shared") = T.address_of(yy[0], dtype="handle") + a: T.handle("int32") = T.address_of(xx[0], dtype="handle") + b: T.handle("int32", "shared") = T.address_of(yy[0], dtype="handle") T.evaluate(T.call_extern("copy", a, b, dtype="")) return func_with_ptr_type_annotations @@ -3316,7 +3316,7 @@ def buffer_ramp_access(a: T.handle, b: T.handle, c: T.handle) -> None: def let_expression(): @T.prim_func def func(): - x = T.var("int32") + x = T.int32() T.evaluate(T.let(x, 1, x + 1)) return func @@ -3324,7 +3324,7 @@ def func(): def void_ptr(): @T.prim_func - def func(out_ret_value: T.Ptr("void")): + def func(out_ret_value: T.handle("void")): T.evaluate(out_ret_value) return func @@ -3542,8 +3542,8 @@ def func(): def let_stmt_var(): @T.prim_func 
def func(): - x = T.var("int32") - y = T.var("int32") + x = T.int32() + y = T.int32() with T.let(x, 0): with T.let(y, 0): T.evaluate(0) @@ -3555,8 +3555,8 @@ def func(): def let_stmt_value(): @T.prim_func def func(): - x = T.var("int32") - y = T.var("int32") + x = T.int32() + y = T.int32() with T.let(x, y): with T.let(y, 0): T.evaluate(0) @@ -3630,7 +3630,7 @@ def func(): def test_roundtrip(ir_generator): original = ir_generator() - after_roundtrip = tvm.script.from_source(original.script()) + after_roundtrip = tvm.script.from_source(original.script(show_meta=True)) tvm.ir.assert_structural_equal(original, after_roundtrip, True) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index a840722bea8ce..e4ba1f7950ab7 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -155,9 +155,9 @@ def func_with_sugar(A: T.Buffer(16, "float32")): # dynamic shape gemm @T.prim_func def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle): - N = T.var("int32") - M = T.var("int32") - K = T.var("int32") + N = T.int32() + M = T.int32() + K = T.int32() A = T.match_buffer(a, (N, K), "float32") B = T.match_buffer(b, (K, M), "float32") C = T.match_buffer(c, (N, M), "float32") diff --git a/tests/scripts/task_config_build_minimal_cross_isa.sh b/tests/scripts/task_config_build_minimal_cross_isa.sh index ac556d48ed2c8..1c08cb285d211 100755 --- a/tests/scripts/task_config_build_minimal_cross_isa.sh +++ b/tests/scripts/task_config_build_minimal_cross_isa.sh @@ -24,6 +24,7 @@ cd "$BUILD_DIR" cp ../cmake/config.cmake . echo set\(USE_SORT ON\) >> config.cmake +echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_RELAY_DEBUG ON\) >> config.cmake echo set\(CMAKE_BUILD_TYPE=Debug\) >> config.cmake echo set\(CMAKE_CXX_FLAGS \"-Werror -Wp,-D_GLIBCXX_ASSERTIONS\"\) >> config.cmake