diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 5d0e94533bf4..58ee78787176 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,27 +1,158 @@ -# Github owner file -# List of code reviewers for TVM modules +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. -# Global reviewers -* @dmlc/tvm-committers +# Github code owners file +# This file is used as a convenient tool to map +# committers' areas of expertise and faciliate the review process. +# +# This may not be the non-comprehensive list and is meant to be +# updated over time. -# LLVM backends -src/codegen/llvm/* @aatluri +# Per ASF policy, committer have global write permission. +# We normally recommend committers to shepherd code in their area of expertise. +* @apache/tvm-committers -# ROCM runtime -src/runtime/rocm/* @aatluri +# Order is important; the last matching pattern takes the most precedence. +# The sub modules should be ordered first by depth. +# Making sure we append new sub-module rules after exisiting modules rules. -# SGX support -src/runtime/sgx/* @nhynes -apps/sgx/* @nhynes +############################## +# Top-level Fallbacks +############################## +include/** @tqchen @jroesch @yzhliu @icemelon9 @junrushao1994 @comaniac @zhiics +src/** @tqchen @jroesch @yzhliu @icemelon9 @junrushao1994 @comaniac @zhiics +apps/** @tqchen @jroesch @yzhliu @icemelon9 @junrushao1994 @comaniac @zhiics +python/** @tqchen @jroesch @yzhliu @icemelon9 @junrushao1994 @comaniac @zhiics + +# Thirdparty license audit +3rdparty/** @tqchen @jroesch +licenses/** @tqchen @jroesch # JVM language -jvm/* @yzhliu +jvm/** @yzhliu + +# Golang +golang/** @srkreddy1238 + +# WASM +web/** @tqchen @jroesch + +# Docker +docker/** @areusch @leandron @jroesch + +# Conda +conda/** @tqchen @junrushao1994 @comaniac + +# CMake +cmake/** @jroesch @tqchen @areusch @junrushao1994 @comaniac + +# rust bindings +rust/** @jroesch @nhynes @nhynes + +# vta +vta/** @tmoreau89 @vegaluisjose + +# docs +docs/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon9 +tutorials/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon9 + +# tests +tests/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon9 + +############################## +# Specific modules +############################## + +# automation related +src/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 +include/tvm/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 +python/tvm/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 + +python/tvm/autotvm/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 + +# node system and reflection +src/node/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac +include/tvm/node/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac + +# ir: Common IR +src/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac +include/tvm/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac +python/tvm/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac + +# tir +src/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were +include/tvm/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were +python/tvm/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were + +# te +src/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were +include/tvm/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were +python/tvm/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were + +# target +src/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi +include/tvm/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi +python/tvm/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi + +# arith: Arithmetic module and simplifiers +src/arith/** @tqchen @junrushao1994 @vinx13 +include/tvm/arith/** @tqchen @junrushao1994 @vinx13 +python/tvm/arith/** @tqchen @junrushao1994 @vinx13 + +# parser +src/parser/** @jroesch @slyubomirsky + +# runtime +src/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 +include/tvm/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 +python/tvm/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 + +# runtime/micro +src/runtime/micro/** @areusch @liangfu @tmoreau89 +src/runtime/crt/** @areusch @liangfu @tmoreau89 +include/tvm/runtime/crt/** @areusch @liangfu @tmoreau89 +include/tvm/runtime/micro/** @areusch @liangfu @tmoreau89 +python/tvm/micro/** @areusch @liangfu @tmoreau89 + +# relay +src/relay/** @jroesch @slyubomirsky @icemelon9 @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 +include/tvm/relay/** @jroesch @slyubomirsky @icemelon9 @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 +python/tvm/relay/** @jroesch @slyubomirsky @icemelon9 @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 + + +# relay/qnn +src/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang +inlcude/tvm/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang +python/tvm/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang + +# relay/backend/contrib: BYOC +src/relay/backend/contrib/** @zhiics @trevor-m @comaniac @mbaret + +# relay/frontends +python/tvm/relay/frontend/** @jwfromm @mbrookhart @srkreddy1238 @siju-samuel @Huyuwei @hlu1 @kazum @PariksheetPinjari909 -# WebGL backends -src/runtime/opengl/* @phisiart -src/codegen/*opengl* @phisiart +# topi: Operator definitions +src/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 +include/tvm/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 +python/tvm/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 -# TOPI -topi/python/topi/* @Laurawly @Huyuwei +# tvm/driver/ +python/tvm/driver/** @leandron @jwfromm @tqchen @jroesch +# tvm/driver/tvmc +python/tvm/driver/tvmc/** @leandron @jwfromm diff --git a/.gitignore b/.gitignore index e24999ef0d5c..7141aaeb192f 100644 --- a/.gitignore +++ b/.gitignore @@ -63,7 +63,7 @@ instance/ # Sphinx documentation docs/_build/ -docs/gen_modules +docs/_staging/ # PyBuilder /target/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000000..3a2c07de458a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Pre-commit hook +# See documentation at: https://pre-commit.com/ +# +# Pre-commit hook to run the sanity checks from Jenkins locally. +# +# Requirements: +# - How to configure: +# - $ pip install pre-commit +# - $ pre-commit install +# - How to prevent running it: +# - git options: --no-verify or -n +# - $ git commit -n -m "YOUR COMMIT MESSAGE" +# - How to run it as standalone +# - $ pre-commit run +# + +default_language_version: + python: python3.8 +fail_fast: True +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-added-large-files + - id: check-merge-conflict + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: local + hooks: + - id: run-black + name: Running Black... + entry: docker/lint.sh python_format + language: system + always_run: true + pass_filenames: false + - id: run-file-checks + name: Checking File Types.... + entry: docker/lint.sh file_type + language: system + always_run: true + pass_filenames: false + - id: run-headers-check + name: Checking ASF License Headers ... + entry: docker/lint.sh asf + language: system + always_run: true + pass_filenames: false + - id: run-headers-check + name: Linting the C++ code ... + entry: docker/lint.sh cpplint + language: system + always_run: true + pass_filenames: false + - id: run-clang-format + name: Checking Clang format ... + entry: docker/lint.sh clang_format + language: system + always_run: true + pass_filenames: false + - id: run-mypy + name: Type Checking with MyPY ... + entry: docker/lint.sh mypy + language: system + always_run: true + pass_filenames: false diff --git a/CMakeLists.txt b/CMakeLists.txt index c56a929e276d..127ba50b3720 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,6 +50,7 @@ tvm_option(USE_ETHOSN "Build with Arm Ethos-N" OFF) tvm_option(INDEX_DEFAULT_I64 "Defaults the index datatype to int64" ON) tvm_option(USE_LIBBACKTRACE "Build libbacktrace to supply linenumbers on stack traces" AUTO) tvm_option(BUILD_STATIC_RUNTIME "Build static version of libtvm_runtime" OFF) +tvm_option(USE_PAPI "Use Performance Application Programming Interface (PAPI) to read performance counters" OFF) # 3rdparty libraries tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include") @@ -341,11 +342,11 @@ if(USE_GRAPH_RUNTIME AND NOT DEFINED USE_GRAPH_EXECUTOR) endif(USE_GRAPH_RUNTIME AND NOT DEFINED USE_GRAPH_EXECUTOR) # NOTE(areusch): USE_GRAPH_RUNTIME_DEBUG will be deleted in a future release -if(USE_GRAPH_RUNTIME_DEBUG AND NOT DEFINED USE_GRAPH_EXECUTOR_DEBUG) - message(WARNING "USE_GRAPH_RUNTIME_DEBUG renamed to USE_GRAPH_EXECUTOR_DEBUG. Please update your config.cmake") - set(USE_GRAPH_EXECUTOR_DEBUG ${USE_GRAPH_RUNTIME_DEBUG}) +if(USE_GRAPH_RUNTIME_DEBUG AND NOT DEFINED USE_PROFILER) + message(WARNING "USE_GRAPH_RUNTIME_DEBUG renamed to USE_PROFILER. Please update your config.cmake") + set(USE_PROFILER ${USE_GRAPH_RUNTIME_DEBUG}) unset(USE_GRAPH_RUNTIME_DEBUG CACHE) -endif(USE_GRAPH_RUNTIME_DEBUG AND NOT DEFINED USE_GRAPH_EXECUTOR_DEBUG) +endif(USE_GRAPH_RUNTIME_DEBUG AND NOT DEFINED USE_PROFILER) if(USE_GRAPH_EXECUTOR) message(STATUS "Build with Graph Executor support...") @@ -356,10 +357,12 @@ endif(USE_GRAPH_EXECUTOR) # convert old options for profiler if(USE_GRAPH_EXECUTOR_DEBUG) + message(WARNING "USE_GRAPH_EXECUTOR renamed to USE_PROFILER. Please update your config.cmake") unset(USE_GRAPH_EXECUTOR_DEBUG CACHE) set(USE_PROFILER ON) endif() if(USE_VM_PROFILER) + message(WARNING "USE_VM_PROFILER renamed to USE_PROFILER. Please update your config.cmake") unset(USE_VM_PROFILER CACHE) set(USE_PROFILER ON) endif() @@ -376,6 +379,15 @@ if(USE_PROFILER) list(APPEND RUNTIME_SRCS ${RUNTIME_VM_PROFILER_SRCS}) endif(USE_PROFILER) +# Enable ctest if gtest is available +find_path(GTEST_INCLUDE_DIR gtest/gtest.h) +find_library(GTEST_LIB gtest "$ENV{GTEST_LIB}") +if(GTEST_INCLUDE_DIR AND GTEST_LIB) + enable_testing() + include(CTest) + include(GoogleTest) +endif() + # Module rules include(cmake/modules/VTA.cmake) include(cmake/modules/StandaloneCrt.cmake) @@ -422,12 +434,15 @@ else() set(CMAKE_CUDA_STANDARD 14) endif() -add_lib_info(${CMAKE_CURRENT_LIST_DIR}/src/support/libinfo.cc) +set(LIBINFO_FILE ${CMAKE_CURRENT_LIST_DIR}/src/support/libinfo.cc) +add_lib_info(${LIBINFO_FILE}) +list(REMOVE_ITEM COMPILER_SRCS ${LIBINFO_FILE}) add_library(tvm_objs OBJECT ${COMPILER_SRCS}) add_library(tvm_runtime_objs OBJECT ${RUNTIME_SRCS}) +add_library(tvm_libinfo_objs OBJECT ${LIBINFO_FILE}) -add_library(tvm SHARED $ $) +add_library(tvm SHARED $ $ $) set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}") set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}") if(BUILD_STATIC_RUNTIME) @@ -443,14 +458,18 @@ else() set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}") endif() set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}") + target_compile_definitions(tvm_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=) target_compile_definitions(tvm_runtime_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=) +target_compile_definitions(tvm_libinfo_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=) target_compile_definitions(tvm PUBLIC DMLC_USE_LOGGING_LIBRARY=) target_compile_definitions(tvm_runtime PUBLIC DMLC_USE_LOGGING_LIBRARY=) # logging option for libbacktrace include(cmake/modules/Logging.cmake) +include(cmake/modules/contrib/PAPI.cmake) + if(USE_MICRO) # NOTE: cmake doesn't track dependencies at the file level across subdirectories. For the # Unix Makefiles generator, need to add these explicit target-level dependency) @@ -472,19 +491,24 @@ if(USE_RELAY_DEBUG) target_compile_definitions(tvm_objs PRIVATE "TVM_LOG_DEBUG") target_compile_definitions(tvm_runtime_objs PRIVATE "USE_RELAY_DEBUG") target_compile_definitions(tvm_runtime_objs PRIVATE "TVM_LOG_DEBUG") + target_compile_definitions(tvm_libinfo_objs PRIVATE "USE_RELAY_DEBUG") + target_compile_definitions(tvm_libinfo_objs PRIVATE "TVM_LOG_DEBUG") else() target_compile_definitions(tvm_objs PRIVATE "NDEBUG") target_compile_definitions(tvm_runtime_objs PRIVATE "NDEBUG") + target_compile_definitions(tvm_libinfo_objs PRIVATE "NDEBUG") endif(USE_RELAY_DEBUG) if(USE_FALLBACK_STL_MAP) message(STATUS "Building with STL Map...") target_compile_definitions(tvm_objs PRIVATE "USE_FALLBACK_STL_MAP=1") target_compile_definitions(tvm_runtime_objs PRIVATE "USE_FALLBACK_STL_MAP=1") + target_compile_definitions(tvm_libinfo_objs PRIVATE "USE_FALLBACK_STL_MAP=1") else() message(STATUS "Building with TVM Map...") target_compile_definitions(tvm_objs PRIVATE "USE_FALLBACK_STL_MAP=0") target_compile_definitions(tvm_runtime_objs PRIVATE "USE_FALLBACK_STL_MAP=0") + target_compile_definitions(tvm_libinfo_objs PRIVATE "USE_FALLBACK_STL_MAP=0") endif(USE_FALLBACK_STL_MAP) if(BUILD_FOR_HEXAGON) @@ -515,17 +539,23 @@ target_include_directories( target_include_directories( tvm_objs PUBLIC "topi/include") +target_include_directories( + tvm_libinfo_objs + PUBLIC "topi/include") set(CRC16_INCLUDE_PATH "3rdparty/libcrc/include") target_include_directorieS( tvm_objs PRIVATE "${CRC16_INCLUDE_PATH}") +target_include_directorieS( + tvm_libinfo_objs + PRIVATE "${CRC16_INCLUDE_PATH}") target_include_directorieS( tvm_runtime_objs PRIVATE "${CRC16_INCLUDE_PATH}") set(TVM_TEST_LIBRARY_NAME tvm) if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - add_library(tvm_allvisible SHARED $ $) + add_library(tvm_allvisible SHARED $ $ $) target_include_directories(tvm_allvisible PUBLIC "$") target_link_libraries(tvm_allvisible PRIVATE "$") set(TVM_TEST_LIBRARY_NAME tvm_allvisible) @@ -539,26 +569,16 @@ if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") target_compile_definitions(tvm_allvisible PUBLIC DMLC_USE_LOGGING_LIBRARY=) endif() -# Tests -set(TEST_EXECS "") -file(GLOB TEST_SRCS tests/cpp/*.cc) -find_path(GTEST_INCLUDE_DIR gtest/gtest.h) -find_library(GTEST_LIB gtest "$ENV{GTEST_LIB}") - # Create the `cpptest` target if we can find GTest. If not, we create dummy # targets that give the user an informative error message. if(GTEST_INCLUDE_DIR AND GTEST_LIB) - foreach(__srcpath ${TEST_SRCS}) - get_filename_component(__srcname ${__srcpath} NAME) - string(REPLACE ".cc" "" __execname ${__srcname}) - add_executable(${__execname} ${__srcpath}) - list(APPEND TEST_EXECS ${__execname}) - target_include_directories(${__execname} SYSTEM PUBLIC ${GTEST_INCLUDE_DIR}) - target_link_libraries(${__execname} PRIVATE ${TVM_TEST_LIBRARY_NAME} ${GTEST_LIB} pthread dl) - set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1) - set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) - endforeach() - add_custom_target(cpptest DEPENDS ${TEST_EXECS}) + file(GLOB TEST_SRCS tests/cpp/*.cc) + add_executable(cpptest ${TEST_SRCS}) + target_include_directories(cpptest SYSTEM PUBLIC ${GTEST_INCLUDE_DIR}) + target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} ${GTEST_LIB} gtest_main pthread dl) + set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_ALL 1) + set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) + gtest_discover_tests(cpptest) elseif(NOT GTEST_INCLUDE_DIR) add_custom_target(cpptest COMMAND echo "Missing Google Test headers in include path" @@ -603,6 +623,7 @@ endif(INSTALL_DEV) # More target definitions if(MSVC) target_compile_definitions(tvm_objs PRIVATE -DTVM_EXPORTS) + target_compile_definitions(tvm_libinfo_objs PRIVATE -DTVM_EXPORTS) target_compile_definitions(tvm_runtime_objs PRIVATE -DTVM_EXPORTS) endif() @@ -619,6 +640,7 @@ if(TVM_IS_DEBUG_BUILD) if(FILE_PREFIX_MAP_SUPPORTED) target_compile_options(tvm PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) target_compile_options(tvm_objs PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) + target_compile_options(tvm_libinfo_objs PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) target_compile_options(tvm_runtime PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) target_compile_options(tvm_runtime_objs PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) endif() @@ -635,3 +657,32 @@ if(APPLE AND TVM_IS_DEBUG_BUILD) VERBATIM ) endif() + +#Caches the build. +#Note that ccache-3.x doesn't support nvcc well, so CUDA kernels may never hit the cache and still +#need to be re-compiled every time. Using ccache 4.0+ can resolve this issue. + +if(USE_CCACHE) # True for AUTO, ON, /path/to/ccache + if("${USE_CCACHE}" STREQUAL "AUTO") # Auto mode + find_program(CCACHE_FOUND ccache) + if(CCACHE_FOUND) + message(STATUS "Found the path to ccache, enabling ccache") + set(PATH_TO_CCACHE ccache) + else() + message(STATUS "Didn't find the path to CCACHE, disabling ccache") + endif(CCACHE_FOUND) + elseif("${USE_CCACHE}" MATCHES ${IS_TRUE_PATTERN}) + find_program(CCACHE_FOUND ccache) + if(CCACHE_FOUND) + message(STATUS "Found the path to ccache, enabling ccache") + set(PATH_TO_CCACHE ccache) + else() + message(FATAL_ERROR "Cannot find ccache. Set USE_CCACHE mode to AUTO or OFF to build without ccache. USE_CCACHE=" "${USE_CCACHE") + endif(CCACHE_FOUND) + else() # /path/to/ccache + set(PATH_TO_CCACHE USE_CCACHE) + message(STATUS "Setting ccache path to " "${PATH_TO_CCACHE}") + endif() + # Set the flag for ccache + set(CXX_COMPILER_LAUNCHER PATH_TO_CCACHE) +endif(USE_CCACHE) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index b26d38574c6f..8398bdd5e0a2 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -22,17 +22,6 @@ contribute to, and influence the direction of the project. We actively invite co See the [community structure document](https://tvm.apache.org/docs/contribute/community.html) for the explanation of community structure and contribution guidelines. -## Mentors - -TVM is now part of the Apache Incubator. -We are fortunate to have the following mentors. - -- Markus Weimer @markusweimer -- Sebastian Schelter @sscdotopen -- Byung-Gon Chun @bgchun -- Henry Saputra @hsaputra -- Timothy Chen @tnachen -- Furkan KAMACI @kamaci ## Committers @@ -66,9 +55,9 @@ We do encourage everyone to work anything they are interested in. - [Andrew Reusch](https://github.com/areusch): @areusch - runtime, microTVM - [Jared Roesch](https://github.com/jroesch) (PMC): @jroesch - relay - [Siju Samuel](https://github.com/siju-samuel): @siju-samuel - frontends -- [Junru Shao](https://github.com/junrushao1994) @junrushao1994 - relay, compiler +- [Junru Shao](https://github.com/junrushao1994) (PMC): @junrushao1994 - relay, compiler - [Haichen Shen](https://github.com/icemelon9) (PMC): @icemelon9 - relay, topi -- [Siva](https://github.com/srkreddy1238): @srkreddy1238 - frontends, golang +- [Siva Rama Krishna Reddy](https://github.com/srkreddy1238): @srkreddy1238 - frontends, golang - [Zhixun Tan](https://github.com/phisiart): @phisiart - opengl, web - [Andrew Tulloch](https://github.com/ajtulloch): @ajtulloch - topi, compiler, runtime - [Luis Vega](https://github.com/vegaluisjose): @vegaluisjose - vta, chisel @@ -77,7 +66,7 @@ We do encourage everyone to work anything they are interested in. - [Jian Weng](https://github.com/were): @were: - hybrid script - [Zhao Wu](https://github.com/FrozenGene): @FrozenGene - runtime, topi, frontends - [Eddie Yan](https://github.com/eqy) (PMC): @eqy - runtime, autotvm, rpc, topi -- [Hao Yu](https://github.com/comaniac): @comaniac - relay, byoc, auto_scheduler +- [Hao Yu](https://github.com/comaniac): @comaniac (PMC) - relay, byoc, auto_scheduler - [Lianmin Zheng](https://github.com/merrymercy) (PMC): @merrymercy - autotvm, auto_scheduler, topi, relay ## Reviewers @@ -115,6 +104,7 @@ We do encourage everyone to work anything they are interested in. - [Xin Liu](https://github.com/Meteorix): @Meteorix - [Yizhi Liu](https://github.com/yzhliu) : @yzhliu - [Hao Lu](https://github.com/hlu1): @hlu1 +- [Eric Lunderberg](https://github.com/Lunderberg): @Lunderberg - [Steven Lyubomirsky](https://github.com/slyubomirsky): @slyubomirsky - [Masahiro Masuda](https://github.com/masahi): @masahi - [Sergey Mironov](https://github.com/grwlf): @grwlf @@ -123,18 +113,21 @@ We do encourage everyone to work anything they are interested in. - [Trevor Morris](https://github.com/trevor-m): @trevor-m - [Tatsuya Nishiyama](https://github.com/nishi-t): @nishi-t - [Leandro Nunes](https://github.com/leandron): @leandron +- [Lily Orth-Smith](https://github.com/electriclilies): @electriclilies - [Wei Pan](https://github.com/wpan11nv): @wpan11nv - [Krzysztof Parzyszek](https://github.com/kparzysz-quic): @kparzysz-quic - [Pariksheet Pinjari](https://github.com/PariksheetPinjari909): @PariksheetPinjari909 - [Josh Pollock](https://github.com/joshpoll): @joshpoll - [Andrew Reusch](https://github.com/areusch): @areusch - [Jared Roesch](https://github.com/jroesch): @jroesch +- [Gustavo Romero](https://github.com/gromero): @gromero - [Giuseppe Rossini](https://github.com/giuseros): @giuseros - [Siju Samuel](https://github.com/siju-samuel): @siju-samuel - [Junru Shao](https://github.com/junrushao1994): @junrushao1994 - [Haichen Shen](https://github.com/icemelon9): @icemelon9 - [Xingjian Shi](https://github.com/sxjscience): @sxjscience -- [Siva](https://github.com/srkreddy1238): @srkreddy1238 +- [Christopher Sidebottom](https://github.com/mousius): @mousius +- [Siva Rama Krishna Reddy](https://github.com/srkreddy1238): @srkreddy1238 - [Dmitriy Smirnov](https://github.com/d-smirnov): @d-smirnov - [Jon Soifer](https://github.com/soiferj): @soiferj - [Zhixun Tan](https://github.com/phisiart): @phisiart @@ -146,6 +139,7 @@ We do encourage everyone to work anything they are interested in. - [Leyuan Wang](https://github.com/Laurawly): @Laurawly - [Alex Weaver](https://github.com/alex-weaver): @alex-weaver - [Logan Weber](https://github.com/weberlo): @weberlo +- [Matt Welsh](https://github.com/mdw-octoml): @mdw-octoml - [Jian Weng](https://github.com/were): @were - [Yong Wu](https://github.com/yongwww): @yongwww - [Zhao Wu](https://github.com/FrozenGene): @FrozenGene @@ -157,3 +151,14 @@ We do encourage everyone to work anything they are interested in. ## List of Contributors - [Full List of Contributors](https://github.com/apache/tvm/graphs/contributors) + +## Mentors + +TVM is now a top-level Apache project. During our Incubator phase, we were fortunate to have the following mentors. + +- Markus Weimer @markusweimer +- Sebastian Schelter @sscdotopen +- Byung-Gon Chun @bgchun +- Henry Saputra @hsaputra +- Timothy Chen @tnachen +- Furkan KAMACI @kamaci diff --git a/Jenkinsfile b/Jenkinsfile old mode 100644 new mode 100755 index f26b148085fb..fa1629205080 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -44,24 +44,41 @@ // // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> -ci_lint = "tlcpack/ci-lint:v0.66" -ci_gpu = "tlcpack/ci-gpu:v0.75" -ci_cpu = "tlcpack/ci-cpu:v0.74" +ci_lint = "tlcpack/ci-lint:v0.67" +ci_gpu = "tlcpack/ci-gpu:v0.77" +ci_cpu = "tlcpack/ci-cpu:v0.77" ci_wasm = "tlcpack/ci-wasm:v0.71" ci_i386 = "tlcpack/ci-i386:v0.73" -ci_qemu = "tlcpack/ci-qemu:v0.05" -ci_arm = "tlcpack/ci-arm:v0.05" +ci_qemu = "tlcpack/ci-qemu:v0.08" +ci_arm = "tlcpack/ci-arm:v0.06" // <--- End of regex-scanned config. +// Parameters to allow overriding (in Jenkins UI), the images +// to be used by a given build. When provided, they take precedence +// over default values above. +properties([ + parameters([ + string(name: 'ci_lint_param', defaultValue: ""), + string(name: 'ci_cpu_param', defaultValue: ""), + string(name: 'ci_gpu_param', defaultValue: ""), + string(name: 'ci_wasm_param', defaultValue: ""), + string(name: 'ci_i386_param', defaultValue: ""), + string(name: 'ci_qemu_param', defaultValue: ""), + string(name: 'ci_arm_param', defaultValue: "") + ]) +]) + // tvm libraries tvm_runtime = "build/libtvm_runtime.so, build/config.cmake" tvm_lib = "build/libtvm.so, " + tvm_runtime // LLVM upstream lib tvm_multilib = "build/libtvm.so, " + - "build/libvta_tsim.so, " + "build/libvta_fsim.so, " + tvm_runtime +tvm_multilib_tsim = "build/libvta_tsim.so, " + + tvm_multilib + // command to start a docker container docker_run = 'docker/bash.sh' // timeout in minutes @@ -107,6 +124,30 @@ def cancel_previous_build() { cancel_previous_build() +stage('Prepare') { + node('CPU') { + // When something is provided in ci_*_param, use it, otherwise default with ci_* + ci_lint = params.ci_lint_param ?: ci_lint + ci_cpu = params.ci_cpu_param ?: ci_cpu + ci_gpu = params.ci_gpu_param ?: ci_gpu + ci_wasm = params.ci_wasm_param ?: ci_wasm + ci_i386 = params.ci_i386_param ?: ci_i386 + ci_qemu = params.ci_qemu_param ?: ci_qemu + ci_arm = params.ci_arm_param ?: ci_arm + + sh """ + echo "Docker images being used in this build:" + echo " ci_lint = ${ci_lint}" + echo " ci_cpu = ${ci_cpu}" + echo " ci_gpu = ${ci_gpu}" + echo " ci_wasm = ${ci_wasm}" + echo " ci_i386 = ${ci_i386}" + echo " ci_qemu = ${ci_qemu}" + echo " ci_arm = ${ci_arm}" + """ + } +} + stage("Sanity Check") { timeout(time: max_time, unit: 'MINUTES') { node('CPU') { @@ -179,7 +220,7 @@ stage('Build') { init_git() sh "${docker_run} ${ci_cpu} ./tests/scripts/task_config_build_cpu.sh" make(ci_cpu, 'build', '-j2') - pack_lib('cpu', tvm_multilib) + pack_lib('cpu', tvm_multilib_tsim) timeout(time: max_time, unit: 'MINUTES') { sh "${docker_run} ${ci_cpu} ./tests/scripts/task_ci_setup.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_unittest.sh" @@ -188,7 +229,7 @@ stage('Build') { sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_tsim.sh" // sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh" // TODO(@jroesch): need to resolve CI issue will turn back on in follow up patch - // sh "${docker_run} ${ci_cpu} ./tests/scripts/task_rust.sh" + sh "${docker_run} ${ci_cpu} ./tests/scripts/task_rust.sh" junit "build/pytest-results/*.xml" } } @@ -213,7 +254,7 @@ stage('Build') { init_git() sh "${docker_run} ${ci_i386} ./tests/scripts/task_config_build_i386.sh" make(ci_i386, 'build', '-j2') - pack_lib('i386', tvm_multilib) + pack_lib('i386', tvm_multilib_tsim) } } }, @@ -282,6 +323,7 @@ stage('Unit Test') { timeout(time: max_time, unit: 'MINUTES') { sh "${docker_run} ${ci_arm} ./tests/scripts/task_ci_setup.sh" sh "${docker_run} ${ci_arm} ./tests/scripts/task_python_unittest.sh" + sh "${docker_run} ${ci_arm} ./tests/scripts/task_python_arm_compute_library.sh" junit "build/pytest-results/*.xml" // sh "${docker_run} ${ci_arm} ./tests/scripts/task_python_integration.sh" } diff --git a/Makefile b/Makefile index 58cb1b69c5b8..8ebbea610f57 100644 --- a/Makefile +++ b/Makefile @@ -81,7 +81,7 @@ FORCE: CMAKE_TARGETS = all runtime vta cpptest crttest define GEN_CMAKE_RULE -%/$(CMAKE_TARGET): %/CMakeCache.txt +%/$(CMAKE_TARGET): %/CMakeCache.txt FORCE @$$(MAKE) -C $$(@D) $(CMAKE_TARGET) endef $(foreach CMAKE_TARGET,$(CMAKE_TARGETS),$(eval $(GEN_CMAKE_RULE))) diff --git a/README.md b/README.md index eec5bfd5797d..09ceb7ab1d07 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,13 @@ TVM works with deep learning frameworks to provide end to end compilation to dif License ------- -© Contributors Licensed under an [Apache-2.0](LICENSE) license. +TVM is licensed under the [Apache-2.0](LICENSE) license. + +Getting Started +--------------- +Check out the [TVM Documentation](https://tvm.apache.org/docs/) site for installation instructions, tutorials, examples, and more. +The [Getting Started with TVM](https://tvm.apache.org/docs/tutorials/get_started/introduction.html) tutorial is a great +place to start. Contribute to TVM ----------------- diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 1331e1a65ca8..e9dd4faba23f 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -34,6 +34,7 @@ #define TVM_LOG_CUSTOMIZE 1 #include "../src/runtime/c_runtime_api.cc" +#include "../src/runtime/container.cc" #include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/dso_library.cc" #include "../src/runtime/file_utils.cc" @@ -67,7 +68,14 @@ #endif #ifdef TVM_VULKAN_RUNTIME -#include "../src/runtime/vulkan/vulkan.cc" +#include "../src/runtime/vulkan/vulkan_buffer.cc" +#include "../src/runtime/vulkan/vulkan_common.cc" +#include "../src/runtime/vulkan/vulkan_device.cc" +#include "../src/runtime/vulkan/vulkan_device_api.cc" +#include "../src/runtime/vulkan/vulkan_instance.cc" +#include "../src/runtime/vulkan/vulkan_module.cc" +#include "../src/runtime/vulkan/vulkan_stream.cc" +#include "../src/runtime/vulkan/vulkan_wrapped_func.cc" #endif #ifdef USE_SORT diff --git a/apps/bundle_deploy/crt_config/crt_config.h b/apps/bundle_deploy/crt_config/crt_config.h index 11086c0e9a15..58f923512d2e 100644 --- a/apps/bundle_deploy/crt_config/crt_config.h +++ b/apps/bundle_deploy/crt_config/crt_config.h @@ -35,9 +35,9 @@ /*! Maximum supported arguments in generated functions */ #define TVM_CRT_MAX_ARGS 10 /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ -#define TVM_CRT_STRLEN_DLTYPE 10 +#define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_STRLEN_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index e897a975de28..be172cea4e53 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -47,7 +47,7 @@ namespace { std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) { std::string untar_cmd; untar_cmd.reserve(512); -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) untar_cmd += "tar -C "; untar_cmd += output_dir; untar_cmd += " -zxf "; @@ -224,7 +224,7 @@ std::vector ListDir(const std::string& dirname) { return vec; } -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) /*! * \brief LinuxShared Creates a linux shared library * \param output The output file name @@ -280,7 +280,7 @@ void WindowsShared(const std::string& output, const std::vector& fi * \param files The files for building */ void CreateShared(const std::string& output, const std::vector& files) { -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) LinuxShared(output, files); #elif defined(_WIN32) WindowsShared(output, files); @@ -290,7 +290,8 @@ void CreateShared(const std::string& output, const std::vector& fil } std::string BuildSharedLibrary(std::string file) { - if (support::EndsWith(file, ".so") || support::EndsWith(file, ".dll")) { + if (support::EndsWith(file, ".so") || support::EndsWith(file, ".dll") || + support::EndsWith(file, ".dylib")) { return file; } diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 5dc84105388b..f5706d315f4a 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -140,7 +140,7 @@ class RPCServer { * \brief ListenLoopProc The listen process. */ void ListenLoopProc() { - TrackerClient tracker(tracker_addr_, key_, custom_addr_); + TrackerClient tracker(tracker_addr_, key_, custom_addr_, port_); while (true) { support::TCPSocket conn; support::SockAddr addr("0.0.0.0", 0); diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h index 1497ab3251be..58824e6aa9af 100644 --- a/apps/cpp_rpc/rpc_tracker_client.h +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -49,12 +49,17 @@ class TrackerClient { * \brief Constructor. */ TrackerClient(const std::string& tracker_addr, const std::string& key, - const std::string& custom_addr) + const std::string& custom_addr, int port) : tracker_addr_(tracker_addr), key_(key), custom_addr_(custom_addr), + port_(port), gen_(std::random_device{}()), - dis_(0.0, 1.0) {} + dis_(0.0, 1.0) { + if (custom_addr_.empty()) { + custom_addr_ = "null"; + } + } /*! * \brief Destructor. */ @@ -80,7 +85,7 @@ class TrackerClient { std::ostringstream ss; ss << "[" << static_cast(TrackerCode::kUpdateInfo) << ", {\"key\": \"server:" << key_ - << "\"}]"; + << "\", \"addr\": [" << custom_addr_ << ", \"" << port_ << "\"]}]"; tracker_sock_.SendBytes(ss.str()); // Receive status and validate @@ -105,9 +110,6 @@ class TrackerClient { void ReportResourceAndGetKey(int port, std::string* matchkey) { if (!tracker_sock_.IsClosed()) { *matchkey = RandomKey(key_ + ":", old_keyset_); - if (custom_addr_.empty()) { - custom_addr_ = "null"; - } std::ostringstream ss; ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" << port @@ -230,6 +232,7 @@ class TrackerClient { std::string tracker_addr_; std::string key_; std::string custom_addr_; + int port_; support::TCPSocket tracker_sock_; std::set old_keyset_; std::mt19937 gen_; diff --git a/apps/microtvm/arduino/example_project/project.ino b/apps/microtvm/arduino/example_project/project.ino new file mode 100644 index 000000000000..5f5683161e0a --- /dev/null +++ b/apps/microtvm/arduino/example_project/project.ino @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "src/model.h" + +void setup() { + TVMInitialize(); + // If desired, initialize the RNG with random noise + // randomSeed(analogRead(0)); +} + +void loop() { + //TVMExecute(input_data, output_data); +} diff --git a/apps/microtvm/arduino/example_project/src/model.c b/apps/microtvm/arduino/example_project/src/model.c new file mode 100644 index 000000000000..9e7c47f75160 --- /dev/null +++ b/apps/microtvm/arduino/example_project/src/model.c @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "model.h" + +#include "Arduino.h" +#include "standalone_crt/include/tvm/runtime/crt/stack_allocator.h" + +// AOT memory array +static uint8_t g_aot_memory[WORKSPACE_SIZE]; +extern tvm_model_t tvmgen_default_network; +tvm_workspace_t app_workspace; + +// Blink code for debugging purposes +void TVMPlatformAbort(tvm_crt_error_t error) { + TVMLogf("TVMPlatformAbort: 0x%08x\n", error); + for (;;) { +#ifdef LED_BUILTIN + digitalWrite(LED_BUILTIN, HIGH); + delay(250); + digitalWrite(LED_BUILTIN, LOW); + delay(250); + digitalWrite(LED_BUILTIN, HIGH); + delay(250); + digitalWrite(LED_BUILTIN, LOW); + delay(750); +#endif + } +} + +void TVMLogf(const char* msg, ...) {} + +tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { + return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr); +} + +tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { + return StackMemoryManager_Free(&app_workspace, ptr); +} + +unsigned long g_utvm_start_time_micros; +int g_utvm_timer_running = 0; + +tvm_crt_error_t TVMPlatformTimerStart() { + if (g_utvm_timer_running) { + return kTvmErrorPlatformTimerBadState; + } + g_utvm_timer_running = 1; + g_utvm_start_time_micros = micros(); + return kTvmErrorNoError; +} + +tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { + if (!g_utvm_timer_running) { + return kTvmErrorPlatformTimerBadState; + } + g_utvm_timer_running = 0; + unsigned long g_utvm_stop_time = micros() - g_utvm_start_time_micros; + *elapsed_time_seconds = ((double)g_utvm_stop_time) / 1e6; + return kTvmErrorNoError; +} + +tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { + for (size_t i = 0; i < num_bytes; i++) { + buffer[i] = rand(); + } + return kTvmErrorNoError; +} + +void TVMInitialize() { StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE); } + +void TVMExecute(void* input_data, void* output_data) { + int ret_val = tvmgen_default_run_model(input_data, output_data); + if (ret_val != 0) { + TVMPlatformAbort(kTvmErrorPlatformCheckFailure); + } +} diff --git a/apps/microtvm/arduino/example_project/src/model.h b/apps/microtvm/arduino/example_project/src/model.h new file mode 100644 index 000000000000..7381c97e9b3f --- /dev/null +++ b/apps/microtvm/arduino/example_project/src/model.h @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#define WORKSPACE_SIZE $workspace_size_bytes + +#ifdef __cplusplus +extern "C" { +#endif + +void TVMInitialize(); + +/* TODO template this function signature with the input and output + * data types and sizes. For example: + * + * void TVMExecute(uint8_t input_data[9216], uint8_t output_data[3]); + * + * Note this can only be done once MLF has JSON metadata describing + * inputs and outputs. + */ +void TVMExecute(void* input_data, void* output_data); + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/apps/microtvm/zephyr/aot_demo/crt/crt_config.h b/apps/microtvm/arduino/example_project/src/standalone_crt/crt_config/crt_config.h similarity index 74% rename from apps/microtvm/zephyr/aot_demo/crt/crt_config.h rename to apps/microtvm/arduino/example_project/src/standalone_crt/crt_config/crt_config.h index 9ee315aa1763..cf73103aff8b 100644 --- a/apps/microtvm/zephyr/aot_demo/crt/crt_config.h +++ b/apps/microtvm/arduino/example_project/src/standalone_crt/crt_config/crt_config.h @@ -18,45 +18,38 @@ */ /*! - * \file tvm/runtime/crt_config.h.template - * \brief Template for CRT configuration, to be modified on each target. + * \brief CRT configuration for the host-linked CRT. */ -#ifndef TVM_RUNTIME_CRT_CONFIG_H_ -#define TVM_RUNTIME_CRT_CONFIG_H_ - -#include +#ifndef TVM_RUNTIME_MICRO_CRT_CONFIG_H_ +#define TVM_RUNTIME_MICRO_CRT_CONFIG_H_ /*! Log level of the CRT runtime */ #define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG +/*! Support low-level debugging in MISRA-C runtime */ +#define TVM_CRT_DEBUG 0 + /*! Maximum supported dimension in NDArray */ #define TVM_CRT_MAX_NDIM 6 - /*! Maximum supported arguments in generated functions */ #define TVM_CRT_MAX_ARGS 10 - -/*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200 +/*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ +#define TVM_CRT_MAX_STRLEN_DLTYPE 10 +/*! Maximum supported string length in function names */ +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 -/*! Maximum packet size, in bytes, including the length header. */ -#define TVM_CRT_MAX_PACKET_SIZE_BYTES (1 * 1024) - -/*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ -#define TVM_CRT_MAX_STRLEN_DLTYPE 10 +/*! Size of the global function registry, in bytes. */ +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512 -/*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +/*! Maximum packet size, in bytes, including the length header. */ +#define TVM_CRT_MAX_PACKET_SIZE_BYTES 8 * 1024 /*! \brief Maximum length of a PackedFunc function name. */ #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 -/*! \brief Log2 of the page size (bytes) for a virtual memory page. */ -#define TVM_CRT_PAGE_BITS 10 // 1 kB - -/*! \brief Number of pages on device. */ -#define TVM_CRT_MAX_PAGES 300 +// #define TVM_CRT_FRAMER_ENABLE_LOGS -#endif // TVM_RUNTIME_CRT_CONFIG_H_ +#endif // TVM_RUNTIME_MICRO_CRT_CONFIG_H_ diff --git a/apps/microtvm/arduino/host_driven/project.ino b/apps/microtvm/arduino/host_driven/project.ino new file mode 100644 index 000000000000..d394059e1bf5 --- /dev/null +++ b/apps/microtvm/arduino/host_driven/project.ino @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "src/standalone_crt/include/tvm/runtime/crt/microtvm_rpc_server.h" +#include "src/standalone_crt/include/tvm/runtime/crt/logging.h" +microtvm_rpc_server_t server; + +// Called by TVM to write serial data to the UART. +ssize_t write_serial(void* unused_context, const uint8_t* data, size_t size) { + Serial.write(data, size); + return size; +} + +void setup() { + server = MicroTVMRpcServerInit(write_serial, NULL); + TVMLogf("microTVM Arduino runtime - running"); + Serial.begin(115200); + + // If desired, initialize the RNG with random noise + // randomSeed(analogRead(0)); +} + +void loop() { + // Read at most 128 bytes at a time to prevent stack blowup + int to_read = min(Serial.available(), 128); + + uint8_t data[to_read]; + size_t bytes_remaining = Serial.readBytes((char*) data, to_read); + uint8_t* arr_ptr = data; + while (bytes_remaining > 0) { + // Pass the received bytes to the RPC server. + tvm_crt_error_t err = MicroTVMRpcServerLoop(server, &arr_ptr, &bytes_remaining); + if (err != kTvmErrorNoError && err != kTvmErrorFramingShortPacket) { + TVMPlatformAbort(err); + } + } +} diff --git a/apps/microtvm/arduino/host_driven/src/model_support.c b/apps/microtvm/arduino/host_driven/src/model_support.c new file mode 100644 index 000000000000..dfcb031136c5 --- /dev/null +++ b/apps/microtvm/arduino/host_driven/src/model_support.c @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "stdarg.h" + +// Blink code for debugging purposes +void TVMPlatformAbort(tvm_crt_error_t error) { + TVMLogf("TVMPlatformAbort: 0x%08x\n", error); + for (;;) + ; +} + +size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, + va_list args) { + return vsnprintf(out_buf, out_buf_size_bytes, fmt, args); +} + +tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { + if (num_bytes == 0) { + num_bytes = sizeof(int); + } + *out_ptr = malloc(num_bytes); + return (*out_ptr == NULL) ? kTvmErrorPlatformNoMemory : kTvmErrorNoError; +} + +tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { + free(ptr); + return kTvmErrorNoError; +} + +unsigned long g_utvm_start_time_micros; +int g_utvm_timer_running = 0; + +tvm_crt_error_t TVMPlatformTimerStart() { + if (g_utvm_timer_running) { + return kTvmErrorPlatformTimerBadState; + } + g_utvm_timer_running = 1; + g_utvm_start_time_micros = micros(); + return kTvmErrorNoError; +} + +tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { + if (!g_utvm_timer_running) { + return kTvmErrorPlatformTimerBadState; + } + g_utvm_timer_running = 0; + unsigned long g_utvm_stop_time = micros() - g_utvm_start_time_micros; + *elapsed_time_seconds = ((double)g_utvm_stop_time) / 1e6; + return kTvmErrorNoError; +} + +tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { + for (size_t i = 0; i < num_bytes; i++) { + buffer[i] = rand(); + } + return kTvmErrorNoError; +} diff --git a/apps/microtvm/arduino/host_driven/src/standalone_crt/crt_config/crt_config.h b/apps/microtvm/arduino/host_driven/src/standalone_crt/crt_config/crt_config.h new file mode 100644 index 000000000000..cf73103aff8b --- /dev/null +++ b/apps/microtvm/arduino/host_driven/src/standalone_crt/crt_config/crt_config.h @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief CRT configuration for the host-linked CRT. + */ +#ifndef TVM_RUNTIME_MICRO_CRT_CONFIG_H_ +#define TVM_RUNTIME_MICRO_CRT_CONFIG_H_ + +/*! Log level of the CRT runtime */ +#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG + +/*! Support low-level debugging in MISRA-C runtime */ +#define TVM_CRT_DEBUG 0 + +/*! Maximum supported dimension in NDArray */ +#define TVM_CRT_MAX_NDIM 6 +/*! Maximum supported arguments in generated functions */ +#define TVM_CRT_MAX_ARGS 10 +/*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ +#define TVM_CRT_MAX_STRLEN_DLTYPE 10 +/*! Maximum supported string length in function names */ +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 + +/*! Maximum number of registered modules. */ +#define TVM_CRT_MAX_REGISTERED_MODULES 2 + +/*! Size of the global function registry, in bytes. */ +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512 + +/*! Maximum packet size, in bytes, including the length header. */ +#define TVM_CRT_MAX_PACKET_SIZE_BYTES 8 * 1024 + +/*! \brief Maximum length of a PackedFunc function name. */ +#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 + +// #define TVM_CRT_FRAMER_ENABLE_LOGS + +#endif // TVM_RUNTIME_MICRO_CRT_CONFIG_H_ diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py new file mode 100644 index 000000000000..91beaf558249 --- /dev/null +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -0,0 +1,486 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import collections +import functools +import json +import logging +import os +import os.path +import pathlib +import re +import shlex +import shutil +import subprocess +import sys +import tarfile +import tempfile +import time +from string import Template + +import serial +import serial.tools.list_ports +from tvm.micro.project_api import server + +MODEL_LIBRARY_FORMAT_RELPATH = pathlib.Path("src") / "model" / "model.tar" +API_SERVER_DIR = pathlib.Path(os.path.dirname(__file__) or os.path.getcwd()) +BUILD_DIR = API_SERVER_DIR / "build" +MODEL_LIBRARY_FORMAT_PATH = API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH + +IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists() + + +class BoardAutodetectFailed(Exception): + """Raised when no attached hardware is found matching the requested board""" + + +# Data structure to hold the information microtvm_api_server.py needs +# to communicate with each of these boards. Currently just holds the +# components of each board's FQBN, but might be extended in the future +# to include the SRAM, PSRAM, flash, etc. on each board. +BOARD_PROPERTIES = { + "due": { + "package": "arduino", + "architecture": "sam", + "board": "arduino_due_x_dbg", + }, + # Due to the way the Feather S2 bootloader works, compilation + # behaves fine but uploads cannot be done automatically + "feathers2": { + "package": "esp32", + "architecture": "esp32", + "board": "feathers2", + }, + # Spresense only works as of its v2.3.0 sdk + "spresense": { + "package": "SPRESENSE", + "architecture": "spresense", + "board": "spresense", + }, + "nano33ble": { + "package": "arduino", + "architecture": "mbed_nano", + "board": "nano33ble", + }, + "pybadge": { + "package": "adafruit", + "architecture": "samd", + "board": "adafruit_pybadge_m4", + }, + # The Teensy boards are listed here for completeness, but they + # won't work until https://github.com/arduino/arduino-cli/issues/700 + # is finished + "teensy40": { + "package": "teensy", + "architecture": "avr", + "board": "teensy40", + }, + "teensy41": { + "package": "teensy", + "architecture": "avr", + "board": "teensy41", + }, + "wioterminal": { + "package": "Seeeduino", + "architecture": "samd", + "board": "seeed_wio_terminal", + }, +} + +PROJECT_TYPES = ["example_project", "host_driven"] + +PROJECT_OPTIONS = [ + server.ProjectOption( + "arduino_board", + choices=list(BOARD_PROPERTIES), + help="Name of the Arduino board to build for", + ), + server.ProjectOption("arduino_cli_cmd", help="Path to the arduino-cli tool."), + server.ProjectOption("port", help="Port to use for connecting to hardware"), + server.ProjectOption( + "project_type", + help="Type of project to generate.", + choices=tuple(PROJECT_TYPES), + ), + server.ProjectOption( + "verbose", help="True to pass --verbose flag to arduino-cli compile and upload" + ), +] + + +class Handler(server.ProjectAPIHandler): + def __init__(self): + super(Handler, self).__init__() + self._proc = None + self._port = None + self._serial = None + + def server_info_query(self, tvm_version): + return server.ServerInfo( + platform_name="arduino", + is_template=IS_TEMPLATE, + model_library_format_path=MODEL_LIBRARY_FORMAT_PATH, + project_options=PROJECT_OPTIONS, + ) + + def _copy_project_files(self, api_server_dir, project_dir, project_type): + """Copies the files for project_type into project_dir. + + Notes + ----- + template_dir is NOT a project type, and that directory is never copied + in this function. template_dir only holds this file and its unit tests, + so this file is copied separately in generate_project. + + """ + project_types_folder = api_server_dir.parents[0] + for item in (project_types_folder / project_type / "src").iterdir(): + dest = project_dir / "src" / item.name + if item.is_dir(): + shutil.copytree(item, dest) + else: + shutil.copy2(item, dest) + + # Arduino requires the .ino file have the same filename as its containing folder + shutil.copy2( + project_types_folder / project_type / "project.ino", + project_dir / f"{project_dir.stem}.ino", + ) + + CRT_COPY_ITEMS = ("include", "src") + + def _copy_standalone_crt(self, source_dir, standalone_crt_dir): + output_crt_dir = source_dir / "standalone_crt" + for item in self.CRT_COPY_ITEMS: + src_path = os.path.join(standalone_crt_dir, item) + dst_path = output_crt_dir / item + if os.path.isdir(src_path): + shutil.copytree(src_path, dst_path) + else: + shutil.copy2(src_path, dst_path) + + # Example project is the "minimum viable project", + # and doesn't need a fancy RPC server + EXAMPLE_PROJECT_UNUSED_COMPONENTS = [ + "include/dmlc", + "src/support", + "src/runtime/minrpc", + "src/runtime/crt/graph_executor", + "src/runtime/crt/microtvm_rpc_common", + "src/runtime/crt/microtvm_rpc_server", + "src/runtime/crt/tab", + ] + + def _remove_unused_components(self, source_dir, project_type): + unused_components = [] + if project_type == "example_project": + unused_components = self.EXAMPLE_PROJECT_UNUSED_COMPONENTS + + for component in unused_components: + shutil.rmtree(source_dir / "standalone_crt" / component) + + def _disassemble_mlf(self, mlf_tar_path, source_dir): + with tempfile.TemporaryDirectory() as mlf_unpacking_dir_str: + mlf_unpacking_dir = pathlib.Path(mlf_unpacking_dir_str) + with tarfile.open(mlf_tar_path, "r:") as tar: + tar.extractall(mlf_unpacking_dir) + + model_dir = source_dir / "model" + model_dir.mkdir() + + # Copy C files from model. The filesnames and quantity + # depend on the target string, so we just copy all c files + source_dir = mlf_unpacking_dir / "codegen" / "host" / "src" + for file in source_dir.rglob(f"*.c"): + shutil.copy(file, model_dir) + + # Return metadata.json for use in templating + with open(os.path.join(mlf_unpacking_dir, "metadata.json")) as f: + metadata = json.load(f) + return metadata + + def _template_model_header(self, source_dir, metadata): + with open(source_dir / "model.h", "r") as f: + model_h_template = Template(f.read()) + + assert ( + metadata["style"] == "full-model" + ), "when generating AOT, expect only full-model Model Library Format" + + template_values = { + "workspace_size_bytes": metadata["memory"]["functions"]["main"][0][ + "workspace_size_bytes" + ], + } + + with open(source_dir / "model.h", "w") as f: + f.write(model_h_template.substitute(template_values)) + + # Arduino ONLY recognizes .ino, .ccp, .c, .h + + CPP_FILE_EXTENSION_SYNONYMS = ("cc", "cxx") + + def _change_cpp_file_extensions(self, source_dir): + for ext in self.CPP_FILE_EXTENSION_SYNONYMS: + for filename in source_dir.rglob(f"*.{ext}"): + filename.rename(filename.with_suffix(".cpp")) + + for filename in source_dir.rglob(f"*.inc"): + filename.rename(filename.with_suffix(".h")) + + def _convert_includes(self, project_dir, source_dir): + """Changes all #include statements in project_dir to be relevant to their + containing file's location. + + Arduino only supports includes relative to a file's location, so this + function finds each time we #include a file and changes the path to + be relative to the file location. Does not do this for standard C + libraries. Also changes angle brackets syntax to double quotes syntax. + + See Also + ----- + https://www.arduino.cc/reference/en/language/structure/further-syntax/include/ + + """ + for ext in ("c", "h", "cpp"): + for filename in source_dir.rglob(f"*.{ext}"): + with filename.open() as file: + lines = file.readlines() + + for i in range(len(lines)): + # Check if line has an include + result = re.search(r"#include\s*[<\"]([^>]*)[>\"]", lines[i]) + if not result: + continue + new_include = self._find_modified_include_path( + project_dir, filename, result.groups()[0] + ) + + lines[i] = f'#include "{new_include}"\n' + + with filename.open("w") as file: + file.writelines(lines) + + # Most of the files we used to be able to point to directly are under "src/standalone_crt/include/". + # Howver, crt_config.h lives under "src/standalone_crt/crt_config/", and more exceptions might + # be added in the future. + POSSIBLE_BASE_PATHS = ["src/standalone_crt/include/", "src/standalone_crt/crt_config/"] + + def _find_modified_include_path(self, project_dir, file_path, include_path): + """Takes a single #include path, and returns the location it should point to. + + Examples + -------- + >>> _find_modified_include_path( + ... "/path/to/project/dir" + ... "/path/to/project/dir/src/standalone_crt/src/runtime/crt/common/ndarray.c" + ... "tvm/runtime/crt/platform.h" + ... ) + "../../../../../../src/standalone_crt/include/tvm/runtime/crt/platform.h" + + """ + if include_path.endswith(".inc"): + include_path = re.sub(r"\.[a-z]+$", ".h", include_path) + + # Change includes referencing .cc and .cxx files to point to the renamed .cpp file + if include_path.endswith(self.CPP_FILE_EXTENSION_SYNONYMS): + include_path = re.sub(r"\.[a-z]+$", ".cpp", include_path) + + # If the include already works, don't modify it + if (file_path.parents[0] / include_path).exists(): + return include_path + + relative_path = file_path.relative_to(project_dir) + up_dirs_path = "../" * str(relative_path).count("/") + + for base_path in self.POSSIBLE_BASE_PATHS: + full_potential_path = project_dir / base_path / include_path + if full_potential_path.exists(): + return up_dirs_path + base_path + include_path + + # If we can't find the file, just leave it untouched + # It's probably a standard C/C++ header + return include_path + + def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): + # Reference key directories with pathlib + project_dir = pathlib.Path(project_dir) + project_dir.mkdir() + source_dir = project_dir / "src" + source_dir.mkdir() + + # Copies files from the template folder to project_dir + shutil.copy2(API_SERVER_DIR / "microtvm_api_server.py", project_dir) + self._copy_project_files(API_SERVER_DIR, project_dir, options["project_type"]) + + # Copy standalone_crt into src folder + self._copy_standalone_crt(source_dir, standalone_crt_dir) + self._remove_unused_components(source_dir, options["project_type"]) + + # Unpack the MLF and copy the relevant files + metadata = self._disassemble_mlf(model_library_format_path, source_dir) + shutil.copy2(model_library_format_path, source_dir / "model") + + # For AOT, template model.h with metadata to minimize space usage + if options["project_type"] == "example_project": + self._template_model_header(source_dir, metadata) + + self._change_cpp_file_extensions(source_dir) + + # Recursively change includes + self._convert_includes(project_dir, source_dir) + + def _get_fqbn(self, options): + o = BOARD_PROPERTIES[options["arduino_board"]] + return f"{o['package']}:{o['architecture']}:{o['board']}" + + def build(self, options): + BUILD_DIR.mkdir() + + compile_cmd = [ + options["arduino_cli_cmd"], + "compile", + "./project/", + "--fqbn", + self._get_fqbn(options), + "--build-path", + BUILD_DIR.resolve(), + ] + + if options.get("verbose"): + compile_cmd.append("--verbose") + + # Specify project to compile + subprocess.run(compile_cmd) + + BOARD_LIST_HEADERS = ("Port", "Type", "Board Name", "FQBN", "Core") + + def _parse_boards_tabular_str(self, tabular_str): + """Parses the tabular output from `arduino-cli board list` into a 2D array + + Examples + -------- + >>> list(_parse_boards_tabular_str(bytes( + ... "Port Type Board Name FQBN Core \n" + ... "/dev/ttyS4 Serial Port Unknown \n" + ... "/dev/ttyUSB0 Serial Port (USB) Spresense SPRESENSE:spresense:spresense SPRESENSE:spresense\n" + ... "\n", + ... "utf-8"))) + [['/dev/ttys4', 'Serial Port', 'Unknown', '', ''], ['/dev/ttyUSB0', 'Serial Port (USB)', + 'Spresense', 'SPRESENSE:spresense:spresense', 'SPRESENSE:spresense']] + + """ + + str_rows = tabular_str.split("\n")[:-2] + header = str_rows[0] + indices = [header.index(h) for h in self.BOARD_LIST_HEADERS] + [len(header)] + + for str_row in str_rows[1:]: + parsed_row = [] + for cell_index in range(len(self.BOARD_LIST_HEADERS)): + start = indices[cell_index] + end = indices[cell_index + 1] + str_cell = str_row[start:end] + + # Remove trailing whitespace used for padding + parsed_row.append(str_cell.rstrip()) + yield parsed_row + + def _auto_detect_port(self, options): + list_cmd = [options["arduino_cli_cmd"], "board", "list"] + list_cmd_output = subprocess.run(list_cmd, stdout=subprocess.PIPE).stdout.decode("utf-8") + + desired_fqbn = self._get_fqbn(options) + for line in self._parse_boards_tabular_str(list_cmd_output): + if line[3] == desired_fqbn: + return line[0] + + # If no compatible boards, raise an error + raise BoardAutodetectFailed() + + def _get_arduino_port(self, options): + if not self._port: + if "port" in options and options["port"]: + self._port = options["port"] + else: + self._port = self._auto_detect_port(options) + + return self._port + + def flash(self, options): + port = self._get_arduino_port(options) + + upload_cmd = [ + options["arduino_cli_cmd"], + "upload", + "./project", + "--fqbn", + self._get_fqbn(options), + "--input-dir", + BUILD_DIR.resolve(), + "--port", + port, + ] + + if options.get("verbose"): + upload_cmd.append("--verbose") + + subprocess.run(upload_cmd) + + def open_transport(self, options): + # Zephyr example doesn't throw an error in this case + if self._serial is not None: + return + + port = self._get_arduino_port(options) + + # It takes a moment for the Arduino code to finish initializing + # and start communicating over serial + for attempts in range(10): + if any(serial.tools.list_ports.grep(port)): + break + time.sleep(0.5) + + self._serial = serial.Serial(port, baudrate=115200, timeout=5) + + return server.TransportTimeouts( + session_start_retry_timeout_sec=2.0, + session_start_timeout_sec=5.0, + session_established_timeout_sec=5.0, + ) + + def close_transport(self): + if self._serial is None: + return + self._serial.close() + self._serial = None + + def read_transport(self, n, timeout_sec): + # It's hard to set timeout_sec, so we just throw it away + # TODO fix this + if self._serial is None: + raise server.TransportClosedError() + return self._serial.read(n) + + def write_transport(self, data, timeout_sec): + if self._serial is None: + raise server.TransportClosedError() + return self._serial.write(data) + + +if __name__ == "__main__": + server.main(Handler()) diff --git a/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py b/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py new file mode 100644 index 000000000000..00969a5a892b --- /dev/null +++ b/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import subprocess +import sys +from pathlib import Path +from unittest import mock + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent)) +import microtvm_api_server + +sys.path.pop(0) + + +class TestGenerateProject: + DEFAULT_OPTIONS = {"arduino_cli_cmd": "arduino-cli", "arduino_board": "nano33ble"} + + def _set_pathlib_path_exists(self, value): + with mock.patch.object(Path, "exists") as mock_exists: + mock_exists.return_value = value + + @mock.patch("pathlib.Path") + def test_find_modified_include_path(self, mock_pathlib_path): + handler = microtvm_api_server.Handler() + + project_dir = mock_pathlib_path("/dummy/project") + file_path = ( + project_dir + / "src" + / "standalone_crt" + / "src" + / "runtime" + / "crt" + / "graph_executor" + / "load_json.c" + ) + + # Should return C standard libs unmodified + clib_output = handler._find_modified_include_path(project_dir, file_path, "math.h") + assert clib_output == "math.h" + + # If import already works, should return unmodified + valid_arduino_import = "../../../../include/tvm/runtime/crt/platform.h" + self._set_pathlib_path_exists(True) + valid_output = handler._find_modified_include_path( + project_dir, file_path, valid_arduino_import + ) + assert valid_output == valid_arduino_import + + BOARD_CONNECTED_OUTPUT = bytes( + "Port Type Board Name FQBN Core \n" + "/dev/ttyACM0 Serial Port (USB) Arduino Nano 33 BLE arduino:mbed_nano:nano33ble arduino:mbed_nano\n" + "/dev/ttyACM1 Serial Port (USB) Arduino Nano 33 arduino:mbed_nano:nano33 arduino:mbed_nano\n" + "/dev/ttyS4 Serial Port Unknown \n" + "\n", + "utf-8", + ) + BOARD_DISCONNECTED_OUTPUT = bytes( + "Port Type Board Name FQBN Core\n" + "/dev/ttyS4 Serial Port Unknown \n" + "\n", + "utf-8", + ) + + @mock.patch("subprocess.run") + def test_auto_detect_port(self, mock_subprocess_run): + process_mock = mock.Mock() + handler = microtvm_api_server.Handler() + + # Test it returns the correct port when a board is connected + mock_subprocess_run.return_value.stdout = self.BOARD_CONNECTED_OUTPUT + assert handler._auto_detect_port(self.DEFAULT_OPTIONS) == "/dev/ttyACM0" + + # Test it raises an exception when no board is connected + mock_subprocess_run.return_value.stdout = self.BOARD_DISCONNECTED_OUTPUT + with pytest.raises(microtvm_api_server.BoardAutodetectFailed): + handler._auto_detect_port(self.DEFAULT_OPTIONS) + + # Test that the FQBN needs to match EXACTLY + handler._get_fqbn = mock.MagicMock(return_value="arduino:mbed_nano:nano33") + mock_subprocess_run.return_value.stdout = self.BOARD_CONNECTED_OUTPUT + assert ( + handler._auto_detect_port({**self.DEFAULT_OPTIONS, "arduino_board": "nano33"}) + == "/dev/ttyACM1" + ) + + @mock.patch("subprocess.run") + def test_flash(self, mock_subprocess_run): + handler = microtvm_api_server.Handler() + handler._port = "/dev/ttyACM0" + + # Test no exception thrown when command works + handler.flash(self.DEFAULT_OPTIONS) + mock_subprocess_run.assert_called_once() + + # Test exception raised when `arduino-cli upload` returns error code + mock_subprocess_run.side_effect = subprocess.CalledProcessError(2, []) + with pytest.raises(subprocess.CalledProcessError): + handler.flash(self.DEFAULT_OPTIONS) diff --git a/apps/microtvm/reference-vm/README.md b/apps/microtvm/reference-vm/README.md index 7ff75c75b4f9..9303c0a64ece 100644 --- a/apps/microtvm/reference-vm/README.md +++ b/apps/microtvm/reference-vm/README.md @@ -15,53 +15,100 @@ -# microTVM Reference Virtual Machines +# microTVM Reference Virtual Machines (RVM) -This directory contains Vagrant specifications that create reference Virtual Machines for use with -microTVM. These machines help microTVM users collaborate by providing a stable reference test -environment. +This directory contains Vagrant specifications that create Reference Virtual +Machines (RVM) for use with microTVM. These machines help microTVM users +collaborate by providing a stable reference environment to build and test +microTVM. -For more information on how to use them, see the microTVM Reference Virtual Machines tutorial. +For more information on how to use them, see the +[microTVM Reference VM tutorial](../../../tutorials/micro/micro_reference_vm.py). -## Reference VM Developer Information +## microTVM Developer Information -Each RTOS or platform that integrates with microTVM can check-in a Reference VM in this directory to -help the community collaborate. You should use the tools provided here to ensure a uniform release -process across all platforms. Typically, releases need to be created by TVM committers. +Each RTOS or platform (like Zephyr, Ardunio, etc) that integrates with microTVM +can check-in a Reference VM in this directory to help the community collaborate. +You should use the tools provided here to ensure a uniform release process +across all platforms. Typically, releases need to be created by TVM committers. -Generally speaking, it's expected that any integrated platform with a regression test checked-in to -the tvm repository should also define a reference VM. If you want to integrate a new platform, -please raise a discussion on [the forum](https://discuss.tvm.ai). +Generally speaking, it's expected that any integrated platform with a regression +test checked-in to the tvm repository should also define a reference VM. If you +want to integrate a new platform, please raise a discussion on +[the forum](https://discuss.tvm.ai). -### Organization -Reference VMs are organized as follows: +## Reference VMs Organization + +Reference VMs are organized in this directory as follows: + +``` +. ++-- base-box-tool.py - Reference VM build, test, and release tool. ++-- PLATFORM/ - One or more dirs related to the supported platform(s), + like zephyr/ and arduino/. The dir names are the same to + be passed as arguments to base-box-tool.py as PLATFORM. + +-- Vagrantfile - Vagrantfile that end-users will invoke. Should be based + | off a base box which contains dependencies other than the + | TVM python dependencies. + +-- base-box/ - Top-level directory which defines the base box. + +-- Vagrantfile.packer-template - 'packer' template Vagrantfile which + | will be used to build the base box. + +-- test-config.json - JSON file explaining how to perform + release tests to base-box-tool.py. +``` -* `base-box-tool.py` - Reference VM build, test, and release tool -* `/` -** `Vagrantfile` Vagrantfile that end-users will inovke. Should be based off a base box - which contains dependencies other than the TVM python dependencies. -** `base-box` - Top-level directory which defines the base box. -*** `Vagrantfile.packer-template` - Packer template Vagrantfile which will be used to build the - base box. -*** `test-config.json` - JSON file explaining how to perform release tests to `base-box-tool.py` ## Creating Releases -1. Build the base box for the given platform: `$ ./base-box-tool.py [--provider=] build ` -2. Run release tests for each platform: - 1. Connect any needed hardware to the VM host machine. - 2. Run tests: `$ ./base-box-tool.py [--provider=] test [--microtvm-platform=] [--test-device-serial=]`. This - command does the following for each provider: - 1. Copies all files inside `./` except `.vagrant` and `base-box` to - `./release-test`. This is done to avoid reusing any VM the developer may have started. - 2. Executes `$ vagrant up [--provider=]`. - 3. Finds an attached USB device matching the VID and PID specified in `test-config.json`, - and if `--test-device-serial` was given, that serial number (as reported to USB). Creates - a rule to autoconnect this device to the VM, and also attaches it to the VM> - 4. SSHs to the VM, `cd` to the TVM root directory, and runs `test_cmd` from - `test-config.json`. Nonzero status means failure. -3. If release tests fail, fix them and restart from step 1. -4. If release tests pass: `$ ./base-box-tool.py [--provider=] release <--release-version=> <--platform-version=> `. Be sure you've logged - in to Vagrant Cloud using the `vagrant` tool. +1. **Build** the base box for a given platform: +```bash +$ ./base-box-tool.py [--provider=PROVIDER] build PLATFORM +``` + +For example: +```bash +$ ./base-box-tool.py --provider virtualbox build zephyr +``` + +2. **Run** release tests for each platform: + + A. Connect any needed hardware to the VM host machine; + + B. Run tests: + ```bash + $ ./base-box-tool.py [--provider=PROVIDER] test --microtvm-platform=MICROTVM_PLATFORM [--test-device-serial=SERIAL] PLATFORM + ``` + where MICROTVM_PLATFORM is one of the options listed in the + PLATFORM/base-box/test-config.json file. + + For example: + ```base + $ ./base-box-tool.py --provider virtualbox test --microtvm-platform=stm32f746xx_disco zephyr + ``` + + This command does the following for the specified provider: + + * Copies all files inside `PLATFORM/` dir except `.vagrant` and `base-box` to + `release-test/`. This is done to avoid reusing any VM the developer may have + started; + + * Executes `$ vagrant up [--provider=PROVIDER]`; + + * Finds an attached USB device matching the VID and PID specified in + `test-config.json`, and if `--test-device-serial` was given, that serial + number (as reported to USB). Creates a rule to autoconnect this device to the + VM, and also attaches it to the VM; + + * SSHs to the VM, `cd` to the TVM root directory, and runs `test_cmd` from + `test-config.json`. Nonzero status means failure. + +3. If release tests _fail_, fix them and restart from step 1. + +4. If release tests pass, **release** the box: +```bash +$ ./base-box-tool.py [--provider=PROVIDER] release --release-version=RELEASE_VER --platform-version=PLATFORM_VER PLATFORM +``` + For that step be sure you've logged in to Vagrant Cloud using the `vagrant` + tool. diff --git a/apps/microtvm/reference-vm/base-box-tool.py b/apps/microtvm/reference-vm/base-box-tool.py index c22eff4cdbad..be9c5173de73 100755 --- a/apps/microtvm/reference-vm/base-box-tool.py +++ b/apps/microtvm/reference-vm/base-box-tool.py @@ -43,7 +43,8 @@ # List of microTVM platforms for testing. ALL_MICROTVM_PLATFORMS = ( - "stm32f746xx", + "stm32f746xx_nucleo", + "stm32f746xx_disco", "nrf5340dk", "mps2_an521", ) @@ -177,10 +178,7 @@ def attach_vmware(uuid, vid_hex=None, pid_hex=None, serial=None): # Extra scripts required to execute on provisioning # in zephyr/base-box/base_box_provision.sh -EXTRA_SCRIPTS = ( - "docker/install/ubuntu_init_zephyr_project.sh", - "docker/install/ubuntu_install_qemu.sh", -) +EXTRA_SCRIPTS = ("docker/install/ubuntu_init_zephyr_project.sh",) def generate_packer_config(file_path, providers): diff --git a/apps/microtvm/reference-vm/zephyr/Vagrantfile b/apps/microtvm/reference-vm/zephyr/Vagrantfile index bd0094fcec66..28339dc0a6d9 100644 --- a/apps/microtvm/reference-vm/zephyr/Vagrantfile +++ b/apps/microtvm/reference-vm/zephyr/Vagrantfile @@ -46,7 +46,12 @@ Vagrant.configure("2") do |config| end end - config.vm.provision "shell", path: "provision_setup.sh", env: {"TVM_HOME": dirs_to_mount[0]}, privileged: false + config.vm.provision "shell", + path: "provision_setup.sh", + env: {"TVM_HOME": dirs_to_mount[0], + "TVM_CI_NUM_CORES": num_cores + }, + privileged: false # Enable USB Controller on VirtualBox vm_name = "microtvm-#{Time.now.tv_sec}" diff --git a/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh b/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh index 69e6171d06dd..0631e89f3bb3 100644 --- a/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh +++ b/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh @@ -30,8 +30,5 @@ cd ~ # Using most recent commit that passes all the tests. ~/ubuntu_init_zephyr_project.sh ~/zephyr v2.5-branch --commit dabf23758417fd041fec2a2a821d8f526afac29d -# Build QEMU -sudo ~/ubuntu_install_qemu.sh --target-list arm-softmmu - # Cleanup rm -f *.sh diff --git a/apps/microtvm/reference-vm/zephyr/base-box/test-config.json b/apps/microtvm/reference-vm/zephyr/base-box/test-config.json index 48b6915a10f4..f3f2633d9468 100644 --- a/apps/microtvm/reference-vm/zephyr/base-box/test-config.json +++ b/apps/microtvm/reference-vm/zephyr/base-box/test-config.json @@ -1,5 +1,9 @@ { - "stm32f746xx": { + "stm32f746xx_nucleo": { + "vid_hex": "0483", + "pid_hex": "374b" + }, + "stm32f746xx_disco": { "vid_hex": "0483", "pid_hex": "374b" }, diff --git a/apps/microtvm/reference-vm/zephyr/provision_setup.sh b/apps/microtvm/reference-vm/zephyr/provision_setup.sh index f95c7e24f5aa..fcefc1176821 100644 --- a/apps/microtvm/reference-vm/zephyr/provision_setup.sh +++ b/apps/microtvm/reference-vm/zephyr/provision_setup.sh @@ -47,3 +47,4 @@ poetry run pip3 install -r ${ZEPHYR_BASE}/scripts/requirements.txt echo "export TVM_LIBRARY_PATH=\"$TVM_HOME\"/build-microtvm" >>~/.profile echo "VENV_PATH=\$((cd \"$TVM_HOME\"/apps/microtvm/reference-vm/zephyr && poetry env list --full-path) | sed -E 's/^(.*)[[:space:]]\(Activated\)\$/\1/g')" >>~/.profile echo "source \$VENV_PATH/bin/activate" >>~/.profile +echo "export PATH=\"\${PATH}:\${HOME}/zephyr-sdk/sysroots/x86_64-pokysdk-linux/usr/bin\"" >>~/.profile diff --git a/apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh b/apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh index 1cebcf7166af..a4c659438d4d 100755 --- a/apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh +++ b/apps/microtvm/reference-vm/zephyr/rebuild-tvm.sh @@ -36,7 +36,7 @@ fi cp cmake/config.cmake "${BUILD_DIR}" cd "${BUILD_DIR}" sed -i 's/USE_MICRO OFF/USE_MICRO ON/' config.cmake -sed -i 's/USE_GRAPH_EXECUTOR_DEBUG OFF/USE_GRAPH_EXECUTOR_DEBUG ON/' config.cmake +sed -i 's/USE_PROFILER OFF/USE_PROFILER ON/' config.cmake sed -i 's/USE_LLVM OFF/USE_LLVM ON/' config.cmake cmake .. rm -rf standalone_crt host_standalone_crt # remove stale generated files diff --git a/apps/microtvm/zephyr/aot_demo/CMakeLists.txt b/apps/microtvm/zephyr/aot_demo/CMakeLists.txt deleted file mode 100644 index d7ec2a25db14..000000000000 --- a/apps/microtvm/zephyr/aot_demo/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -cmake_minimum_required(VERSION 3.13.1) - -set(ENV{QEMU_BIN_PATH} "${CMAKE_SOURCE_DIR}/qemu-hack") - -set(QEMU_PIPE "\${QEMU_PIPE}") # QEMU_PIPE is set by the calling TVM instance. - -find_package(Zephyr HINTS $ENV{ZEPHYR_BASE}) -project(microtvm_zephyr_runtime) - -set(CMAKE_VERBOSE_MAKEFILE ON) - -target_sources(app PRIVATE src/zephyr_uart.c) -target_sources(app PRIVATE src/main.c) - -foreach(tvm_lib ${TVM_LIBS}) - string(LENGTH ${tvm_lib} tvm_lib_length) - math(EXPR tvm_lib_cut "${tvm_lib_length} - 2") - string(SUBSTRING ${tvm_lib} 3 ${tvm_lib_cut} tvm_lib_name) - add_library(${tvm_lib_name} STATIC IMPORTED) - set_target_properties(${tvm_lib_name} PROPERTIES - IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/${tvm_lib}) - target_link_libraries(app PRIVATE ${tvm_lib_name}) -endforeach(tvm_lib ${TVM_LIBS}) - -target_include_directories(app PRIVATE ${TVM_INCLUDE_DIRS}) diff --git a/apps/microtvm/zephyr/aot_demo/prj.conf b/apps/microtvm/zephyr/aot_demo/prj.conf deleted file mode 100644 index 5f4d7a0689dc..000000000000 --- a/apps/microtvm/zephyr/aot_demo/prj.conf +++ /dev/null @@ -1,35 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# The settings in this file are generic for all boards, and are merged -# with the settings in the file boards/.conf by the Zephyr build -# process. - -# For UART implementation in main(). -CONFIG_RING_BUFFER=y -CONFIG_UART_CONSOLE=n -CONFIG_UART_INTERRUPT_DRIVEN=y - -# For RPC server C++ bindings. -CONFIG_CPLUSPLUS=y -CONFIG_NEWLIB_LIBC=y - -# For models with floating point. -CONFIG_FPU=y - -# For TVMPlatformAbort(). -CONFIG_REBOOT=y diff --git a/apps/microtvm/zephyr/aot_demo/qemu-hack b/apps/microtvm/zephyr/aot_demo/qemu-hack deleted file mode 120000 index b4810f2aab6e..000000000000 --- a/apps/microtvm/zephyr/aot_demo/qemu-hack +++ /dev/null @@ -1 +0,0 @@ -../qemu-hack \ No newline at end of file diff --git a/apps/microtvm/zephyr/host_driven/CMakeLists.txt b/apps/microtvm/zephyr/host_driven/CMakeLists.txt deleted file mode 100644 index f04a792086cb..000000000000 --- a/apps/microtvm/zephyr/host_driven/CMakeLists.txt +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -cmake_minimum_required(VERSION 3.13.1) - -set(ENV{QEMU_BIN_PATH} "${CMAKE_SOURCE_DIR}/qemu-hack") - -set(QEMU_PIPE "\${QEMU_PIPE}") # QEMU_PIPE is set by the calling TVM instance. - -find_package(Zephyr HINTS $ENV{ZEPHYR_BASE}) -project(microtvm_zephyr_runtime) - -set(CMAKE_VERBOSE_MAKEFILE ON) - -target_sources(app PRIVATE src/main.c) - -foreach(tvm_lib ${TVM_LIBS}) - string(LENGTH ${tvm_lib} tvm_lib_length) - math(EXPR tvm_lib_cut "${tvm_lib_length} - 2") - string(SUBSTRING ${tvm_lib} 3 ${tvm_lib_cut} tvm_lib_name) - add_library(${tvm_lib_name} STATIC IMPORTED) - set_target_properties(${tvm_lib_name} PROPERTIES - IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/${tvm_lib}) - target_link_libraries(app PRIVATE ${tvm_lib_name}) -endforeach(tvm_lib ${TVM_LIBS}) - -target_include_directories(app PRIVATE ${TVM_INCLUDE_DIRS}) diff --git a/apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf b/apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf deleted file mode 100644 index 149a69ea3b5b..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf +++ /dev/null @@ -1,31 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# This file is specific to the nRF5340 DK board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# Required for Cortex-M33 devices. -CONFIG_MAIN_STACK_SIZE=1536 - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y - -# For debugging. -CONFIG_LED=y diff --git a/apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf b/apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf deleted file mode 100644 index eba023294894..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf +++ /dev/null @@ -1,30 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# This file is specific to the STM32F746 Nucleo board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# For operations that stack allocates a large float array. -CONFIG_MAIN_STACK_SIZE=1536 - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y - -# For debugging. -CONFIG_LED=y diff --git a/apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf b/apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf deleted file mode 100644 index b87206019026..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf +++ /dev/null @@ -1,31 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# This file is specific to the STM32L4R5ZI Nucleo board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# For operations that stack allocates a large float array. -CONFIG_MAIN_STACK_SIZE=1536 - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y - -# For debugging. -CONFIG_LED=y diff --git a/apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf b/apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf deleted file mode 100644 index a8a055bcc748..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file is specific to the QEMU-emulated RISCV64 microTVM board. - -# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. -CONFIG_TEST_RANDOM_GENERATOR=y -CONFIG_TIMER_RANDOM_GENERATOR=y - -# Default 512, for operations with large floating point data. -CONFIG_MAIN_STACK_SIZE=2048 diff --git a/apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf b/apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf deleted file mode 100644 index f314f59a597a..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file is specific to the QEMU-emulated microTVM board. - -# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. -CONFIG_TEST_RANDOM_GENERATOR=y -CONFIG_TIMER_RANDOM_GENERATOR=y - -# Default stack size is 1k, this is required for debug mode. -CONFIG_MAIN_STACK_SIZE=1536 diff --git a/apps/microtvm/zephyr/host_driven/prj.conf b/apps/microtvm/zephyr/host_driven/prj.conf deleted file mode 100644 index 5f4d7a0689dc..000000000000 --- a/apps/microtvm/zephyr/host_driven/prj.conf +++ /dev/null @@ -1,35 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# The settings in this file are generic for all boards, and are merged -# with the settings in the file boards/.conf by the Zephyr build -# process. - -# For UART implementation in main(). -CONFIG_RING_BUFFER=y -CONFIG_UART_CONSOLE=n -CONFIG_UART_INTERRUPT_DRIVEN=y - -# For RPC server C++ bindings. -CONFIG_CPLUSPLUS=y -CONFIG_NEWLIB_LIBC=y - -# For models with floating point. -CONFIG_FPU=y - -# For TVMPlatformAbort(). -CONFIG_REBOOT=y diff --git a/apps/microtvm/zephyr/host_driven/qemu-hack b/apps/microtvm/zephyr/host_driven/qemu-hack deleted file mode 120000 index b4810f2aab6e..000000000000 --- a/apps/microtvm/zephyr/host_driven/qemu-hack +++ /dev/null @@ -1 +0,0 @@ -../qemu-hack \ No newline at end of file diff --git a/apps/microtvm/zephyr/qemu-hack/qemu-system-arm b/apps/microtvm/zephyr/qemu-hack/qemu-system-arm deleted file mode 120000 index 58fc8296c31f..000000000000 --- a/apps/microtvm/zephyr/qemu-hack/qemu-system-arm +++ /dev/null @@ -1 +0,0 @@ -./qemu-system-i386 \ No newline at end of file diff --git a/apps/microtvm/zephyr/template_project/CMakeLists.txt.template b/apps/microtvm/zephyr/template_project/CMakeLists.txt.template new file mode 100644 index 000000000000..17e9d75c76e8 --- /dev/null +++ b/apps/microtvm/zephyr/template_project/CMakeLists.txt.template @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# SPDX-License-Identifier: Apache-2.0 + +cmake_minimum_required(VERSION 3.13.1) + +set(ENV{QEMU_BIN_PATH} "${CMAKE_SOURCE_DIR}/qemu-hack") + +set(QEMU_PIPE "\${QEMU_PIPE}") # QEMU_PIPE is set by the calling TVM instance. + +find_package(Zephyr HINTS $ENV{ZEPHYR_BASE}) +project(microtvm_autogenerated_project) + +set(CRT_LIBS ) +set(CRT_LIB_BASE crt/src/runtime/crt) +foreach(crt_lib_name ${CRT_LIBS}) + zephyr_library_named(${crt_lib_name}) + file(GLOB_RECURSE crt_lib_srcs ${CRT_LIB_BASE}/${crt_lib_name}/*.c ${CRT_LIB_BASE}/${crt_lib_name}/*.cc) + target_sources(${crt_lib_name} PRIVATE ${crt_lib_srcs}) + zephyr_library_include_directories(${crt_lib_name} PRIVATE crt_config crt/include) + target_link_libraries(app PRIVATE ${crt_lib_name}) +endforeach(crt_lib_name ${CRT_LIBS}) + +# define a library for the model sources. +zephyr_library_named(tvm_model) +file(GLOB_RECURSE tvm_model_srcs model/codegen/host/src/*.c model/codegen/host/lib/*.o) +target_sources(tvm_model PRIVATE ${tvm_model_srcs}) +target_include_directories(tvm_model PRIVATE ${CMAKE_SOURCE_DIR}/include crt_config crt/include) +target_compile_options(tvm_model PRIVATE -Wno-unused-variable) # TVM-generated code tends to include lots of these. +target_link_libraries(app PRIVATE tvm_model) + +file(GLOB_RECURSE app_srcs src/**.c) +target_sources(app PRIVATE ${app_srcs}) +target_include_directories(app PRIVATE crt_config ${CMAKE_SOURCE_DIR}/include crt/include) diff --git a/apps/microtvm/zephyr/host_driven/README.md b/apps/microtvm/zephyr/template_project/README.md similarity index 100% rename from apps/microtvm/zephyr/host_driven/README.md rename to apps/microtvm/zephyr/template_project/README.md diff --git a/apps/microtvm/zephyr/host_driven/crt/crt_config.h b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h similarity index 97% rename from apps/microtvm/zephyr/host_driven/crt/crt_config.h rename to apps/microtvm/zephyr/template_project/crt_config/crt_config.h index 658b97e267ba..f8fc7514a28d 100644 --- a/apps/microtvm/zephyr/host_driven/crt/crt_config.h +++ b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h @@ -42,7 +42,7 @@ #define TVM_CRT_MAX_REGISTERED_MODULES 2 /*! Maximum packet size, in bytes, including the length header. */ -#define TVM_CRT_MAX_PACKET_SIZE_BYTES (4 * 1024) +#define TVM_CRT_MAX_PACKET_SIZE_BYTES 8192 /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py new file mode 100644 index 000000000000..f267648a83f9 --- /dev/null +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -0,0 +1,716 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import atexit +import collections +import collections.abc +import enum +import fcntl +import logging +import os +import os.path +import pathlib +import queue +import re +import select +import shlex +import shutil +import subprocess +import sys +import tarfile +import tempfile +import threading +import time + +import serial +import serial.tools.list_ports +import yaml + +from tvm.micro.project_api import server + + +_LOG = logging.getLogger(__name__) + + +API_SERVER_DIR = pathlib.Path(os.path.dirname(__file__) or os.path.getcwd()) + + +BUILD_DIR = API_SERVER_DIR / "build" + + +MODEL_LIBRARY_FORMAT_RELPATH = "model.tar" + + +IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists() + + +def check_call(cmd_args, *args, **kwargs): + cwd_str = "" if "cwd" not in kwargs else f" (in cwd: {kwargs['cwd']})" + _LOG.info("run%s: %s", cwd_str, " ".join(shlex.quote(a) for a in cmd_args)) + return subprocess.check_call(cmd_args, *args, **kwargs) + + +CACHE_ENTRY_RE = re.compile(r"(?P[^:]+):(?P[^=]+)=(?P.*)") + + +CMAKE_BOOL_MAP = dict( + [(k, True) for k in ("1", "ON", "YES", "TRUE", "Y")] + + [(k, False) for k in ("0", "OFF", "NO", "FALSE", "N", "IGNORE", "NOTFOUND", "")] +) + + +class CMakeCache(collections.abc.Mapping): + def __init__(self, path): + self._path = path + self._dict = None + + def __iter__(self): + return iter(self._dict) + + def __getitem__(self, key): + if self._dict is None: + self._dict = self._read_cmake_cache() + + return self._dict[key] + + def __len__(self): + return len(self._dict) + + def _read_cmake_cache(self): + """Read a CMakeCache.txt-like file and return a dictionary of values.""" + entries = collections.OrderedDict() + with open(self._path, encoding="utf-8") as f: + for line in f: + m = CACHE_ENTRY_RE.match(line.rstrip("\n")) + if not m: + continue + + if m.group("type") == "BOOL": + value = CMAKE_BOOL_MAP[m.group("value").upper()] + else: + value = m.group("value") + + entries[m.group("name")] = value + + return entries + + +CMAKE_CACHE = CMakeCache(BUILD_DIR / "CMakeCache.txt") + + +class BoardError(Exception): + """Raised when an attached board cannot be opened (i.e. missing /dev nodes, etc).""" + + +class BoardAutodetectFailed(Exception): + """Raised when no attached hardware is found matching the board= given to ZephyrCompiler.""" + + +def _get_flash_runner(): + flash_runner = CMAKE_CACHE.get("ZEPHYR_BOARD_FLASH_RUNNER") + if flash_runner is not None: + return flash_runner + + with open(CMAKE_CACHE["ZEPHYR_RUNNERS_YAML"]) as f: + doc = yaml.load(f, Loader=yaml.FullLoader) + return doc["flash-runner"] + + +def _get_device_args(options): + flash_runner = _get_flash_runner() + + if flash_runner == "nrfjprog": + return _get_nrf_device_args(options) + + if flash_runner == "openocd": + return _get_openocd_device_args(options) + + raise BoardError( + f"Don't know how to find serial terminal for board {CMAKE_CACHE['BOARD']} with flash " + f"runner {flash_runner}" + ) + + +# kwargs passed to usb.core.find to find attached boards for the openocd flash runner. +BOARD_USB_FIND_KW = { + "nucleo_l4r5zi": {"idVendor": 0x0483, "idProduct": 0x374B}, + "nucleo_f746zg": {"idVendor": 0x0483, "idProduct": 0x374B}, + "stm32f746g_disco": {"idVendor": 0x0483, "idProduct": 0x374B}, +} + + +def openocd_serial(options): + """Find the serial port to use for a board with OpenOCD flash strategy.""" + if "openocd_serial" in options: + return options["openocd_serial"] + + import usb # pylint: disable=import-outside-toplevel + + find_kw = BOARD_USB_FIND_KW[CMAKE_CACHE["BOARD"]] + boards = usb.core.find(find_all=True, **find_kw) + serials = [] + for b in boards: + serials.append(b.serial_number) + + if len(serials) == 0: + raise BoardAutodetectFailed(f"No attached USB devices matching: {find_kw!r}") + serials.sort() + + autodetected_openocd_serial = serials[0] + _LOG.debug("zephyr openocd driver: autodetected serial %s", serials[0]) + + return autodetected_openocd_serial + + +def _get_openocd_device_args(options): + return ["--serial", openocd_serial(options)] + + +def _get_nrf_device_args(options): + nrfjprog_args = ["nrfjprog", "--ids"] + nrfjprog_ids = subprocess.check_output(nrfjprog_args, encoding="utf-8") + if not nrfjprog_ids.strip("\n"): + raise BoardAutodetectFailed(f'No attached boards recognized by {" ".join(nrfjprog_args)}') + + boards = nrfjprog_ids.split("\n")[:-1] + if len(boards) > 1: + if options["nrfjprog_snr"] is None: + raise BoardError( + "Multiple boards connected; specify one with nrfjprog_snr=: " f'{", ".join(boards)}' + ) + + if str(options["nrfjprog_snr"]) not in boards: + raise BoardError( + f"nrfjprog_snr ({options['nrfjprog_snr']}) not found in {nrfjprog_args}: {boards}" + ) + + return ["--snr", options["nrfjprog_snr"]] + + if not boards: + return [] + + return ["--snr", boards[0]] + + +PROJECT_TYPES = [] +if IS_TEMPLATE: + for d in (API_SERVER_DIR / "src").iterdir(): + if d.is_dir(): + PROJECT_TYPES.append(d.name) + + +PROJECT_OPTIONS = [ + server.ProjectOption( + "extra_files", + help="If given, during generate_project, uncompress the tarball at this path into the project dir", + ), + server.ProjectOption( + "gdbserver_port", help=("If given, port number to use when running the local gdbserver") + ), + server.ProjectOption( + "nrfjprog_snr", + help=( + "When used with nRF targets, serial # of the " "attached board to use, from nrfjprog" + ), + ), + server.ProjectOption( + "openocd_serial", + help=("When used with OpenOCD targets, serial # of the " "attached board to use"), + ), + server.ProjectOption( + "project_type", + help="Type of project to generate.", + choices=tuple(PROJECT_TYPES), + ), + server.ProjectOption("verbose", help="Run build with verbose output"), + server.ProjectOption( + "west_cmd", + help=( + "Path to the west tool. If given, supersedes both the zephyr_base " + "option and ZEPHYR_BASE environment variable." + ), + ), + server.ProjectOption("zephyr_base", help="Path to the zephyr base directory."), + server.ProjectOption("zephyr_board", help="Name of the Zephyr board to build for"), +] + + +class Handler(server.ProjectAPIHandler): + def __init__(self): + super(Handler, self).__init__() + self._proc = None + + def server_info_query(self, tvm_version): + return server.ServerInfo( + platform_name="zephyr", + is_template=IS_TEMPLATE, + model_library_format_path="" + if IS_TEMPLATE + else (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH), + project_options=PROJECT_OPTIONS, + ) + + # These files and directories will be recursively copied into generated projects from the CRT. + CRT_COPY_ITEMS = ("include", "Makefile", "src") + + # Maps extra line added to prj.conf to a tuple or list of zephyr_board for which it is needed. + EXTRA_PRJ_CONF_DIRECTIVES = { + "CONFIG_TIMER_RANDOM_GENERATOR=y": ( + "qemu_x86", + "qemu_riscv32", + "qemu_cortex_r5", + "qemu_riscv64", + ), + "CONFIG_ENTROPY_GENERATOR=y": ( + "mps2_an521", + "nrf5340dk_nrf5340_cpuapp", + "nucleo_f746zg", + "nucleo_l4r5zi", + "stm32f746g_disco", + ), + } + + def _create_prj_conf(self, project_dir, options): + with open(project_dir / "prj.conf", "w") as f: + f.write( + "# For UART used from main().\n" + "CONFIG_RING_BUFFER=y\n" + "CONFIG_UART_CONSOLE=n\n" + "CONFIG_UART_INTERRUPT_DRIVEN=y\n" + "\n" + ) + f.write("# For TVMPlatformAbort().\n" "CONFIG_REBOOT=y\n" "\n") + + if options["project_type"] == "host_driven": + f.write("# For RPC server C++ bindings.\n" "CONFIG_CPLUSPLUS=y\n" "\n") + + f.write("# For math routines\n" "CONFIG_NEWLIB_LIBC=y\n" "\n") + + if self._has_fpu(options["zephyr_board"]): + f.write("# For models with floating point.\n" "CONFIG_FPU=y\n" "\n") + + main_stack_size = None + if self._is_qemu(options) and options["project_type"] == "host_driven": + main_stack_size = 1536 + + # Set main stack size, if needed. + if main_stack_size is not None: + f.write(f"CONFIG_MAIN_STACK_SIZE={main_stack_size}\n") + + f.write("# For random number generation.\n" "CONFIG_TEST_RANDOM_GENERATOR=y\n") + + f.write("\n# Extra prj.conf directives\n") + for line, board_list in self.EXTRA_PRJ_CONF_DIRECTIVES.items(): + if options["zephyr_board"] in board_list: + f.write(f"{line}\n") + + f.write("\n") + + API_SERVER_CRT_LIBS_TOKEN = "" + + CRT_LIBS_BY_PROJECT_TYPE = { + "host_driven": "microtvm_rpc_server microtvm_rpc_common common", + "aot_demo": "memory microtvm_rpc_common common", + } + + def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): + project_dir = pathlib.Path(project_dir) + # Make project directory. + project_dir.mkdir() + + # Copy ourselves to the generated project. TVM may perform further build steps on the generated project + # by launching the copy. + shutil.copy2(__file__, project_dir / os.path.basename(__file__)) + + # Place Model Library Format tarball in the special location, which this script uses to decide + # whether it's being invoked in a template or generated project. + project_model_library_format_tar_path = project_dir / MODEL_LIBRARY_FORMAT_RELPATH + shutil.copy2(model_library_format_path, project_model_library_format_tar_path) + + # Extract Model Library Format tarball.into /model. + extract_path = os.path.splitext(project_model_library_format_tar_path)[0] + with tarfile.TarFile(project_model_library_format_tar_path) as tf: + os.makedirs(extract_path) + tf.extractall(path=extract_path) + + if self._is_qemu(options): + shutil.copytree(API_SERVER_DIR / "qemu-hack", project_dir / "qemu-hack") + + # Populate CRT. + crt_path = project_dir / "crt" + crt_path.mkdir() + for item in self.CRT_COPY_ITEMS: + src_path = os.path.join(standalone_crt_dir, item) + dst_path = crt_path / item + if os.path.isdir(src_path): + shutil.copytree(src_path, dst_path) + else: + shutil.copy2(src_path, dst_path) + + # Populate Makefile. + with open(API_SERVER_DIR / "CMakeLists.txt.template", "r") as cmake_template_f: + with open(project_dir / "CMakeLists.txt", "w") as cmake_f: + for line in cmake_template_f: + if self.API_SERVER_CRT_LIBS_TOKEN in line: + crt_libs = self.CRT_LIBS_BY_PROJECT_TYPE[options["project_type"]] + line = line.replace("", crt_libs) + + cmake_f.write(line) + + self._create_prj_conf(project_dir, options) + + # Populate crt-config.h + crt_config_dir = project_dir / "crt_config" + crt_config_dir.mkdir() + shutil.copy2( + API_SERVER_DIR / "crt_config" / "crt_config.h", crt_config_dir / "crt_config.h" + ) + + # Populate src/ + src_dir = project_dir / "src" + shutil.copytree(API_SERVER_DIR / "src" / options["project_type"], src_dir) + + # Populate extra_files + if options.get("extra_files_tar"): + with tarfile.open(options["extra_files_tar"], mode="r:*") as tf: + tf.extractall(project_dir) + + def build(self, options): + BUILD_DIR.mkdir() + + cmake_args = ["cmake", ".."] + if options.get("verbose"): + cmake_args.append("-DCMAKE_VERBOSE_MAKEFILE:BOOL=TRUE") + + if options.get("zephyr_base"): + cmake_args.append(f"-DZEPHYR_BASE:STRING={options['zephyr_base']}") + + cmake_args.append(f"-DBOARD:STRING={options['zephyr_board']}") + + check_call(cmake_args, cwd=BUILD_DIR) + + args = ["make", "-j2"] + if options.get("verbose"): + args.append("VERBOSE=1") + check_call(args, cwd=BUILD_DIR) + + # A list of all zephyr_board values which are known to launch using QEMU. Many platforms which + # launch through QEMU by default include "qemu" in their name. However, not all do. This list + # includes those tested platforms which do not include qemu. + _KNOWN_QEMU_ZEPHYR_BOARDS = ("mps2_an521",) + + @classmethod + def _is_qemu(cls, options): + return ( + "qemu" in options["zephyr_board"] + or options["zephyr_board"] in cls._KNOWN_QEMU_ZEPHYR_BOARDS + ) + + _KNOWN_FPU_ZEPHYR_BOARDS = ( + "nucleo_f746zg", + "nucleo_l4r5zi", + "nrf5340dk_nrf5340_cpuapp", + "qemu_cortex_r5", + "qemu_riscv32", + "qemu_riscv64", + "qemu_x86", + "stm32f746g_disco", + ) + + @classmethod + def _has_fpu(cls, zephyr_board): + return zephyr_board in cls._KNOWN_FPU_ZEPHYR_BOARDS + + def flash(self, options): + if self._is_qemu(options): + return # NOTE: qemu requires no flash step--it is launched from open_transport. + + zephyr_board = options["zephyr_board"] + + # The nRF5340DK requires an additional `nrfjprog --recover` before each flash cycle. + # This is because readback protection is enabled by default when this device is flashed. + # Otherwise, flashing may fail with an error such as the following: + # ERROR: The operation attempted is unavailable due to readback protection in + # ERROR: your device. Please use --recover to unlock the device. + if zephyr_board.startswith("nrf5340dk") and _get_flash_runner() == "nrfjprog": + recover_args = ["nrfjprog", "--recover"] + recover_args.extend(_get_nrf_device_args(options)) + check_call(recover_args, cwd=API_SERVER_DIR / "build") + + check_call(["make", "flash"], cwd=API_SERVER_DIR / "build") + + def open_transport(self, options): + if self._is_qemu(options): + transport = ZephyrQemuTransport(options) + else: + transport = ZephyrSerialTransport(options) + + to_return = transport.open() + self._transport = transport + atexit.register(lambda: self.close_transport()) + return to_return + + def close_transport(self): + if self._transport is not None: + self._transport.close() + self._transport = None + + def read_transport(self, n, timeout_sec): + if self._transport is None: + raise server.TransportClosedError() + + return self._transport.read(n, timeout_sec) + + def write_transport(self, data, timeout_sec): + if self._transport is None: + raise server.TransportClosedError() + + return self._transport.write(data, timeout_sec) + + +def _set_nonblock(fd): + flag = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flag | os.O_NONBLOCK) + new_flag = fcntl.fcntl(fd, fcntl.F_GETFL) + assert (new_flag & os.O_NONBLOCK) != 0, "Cannot set file descriptor {fd} to non-blocking" + + +class ZephyrSerialTransport: + @classmethod + def _lookup_baud_rate(cls, options): + zephyr_base = options.get("zephyr_base", os.environ["ZEPHYR_BASE"]) + sys.path.insert(0, os.path.join(zephyr_base, "scripts", "dts")) + try: + import dtlib # pylint: disable=import-outside-toplevel + finally: + sys.path.pop(0) + + dt_inst = dtlib.DT(BUILD_DIR / "zephyr" / "zephyr.dts") + uart_baud = ( + dt_inst.get_node("/chosen") + .props["zephyr,console"] + .to_path() + .props["current-speed"] + .to_num() + ) + _LOG.debug("zephyr transport: found UART baudrate from devicetree: %d", uart_baud) + + return uart_baud + + @classmethod + def _find_nrf_serial_port(cls, options): + com_ports = subprocess.check_output( + ["nrfjprog", "--com"] + _get_device_args(options), encoding="utf-8" + ) + ports_by_vcom = {} + for line in com_ports.split("\n")[:-1]: + parts = line.split() + ports_by_vcom[parts[2]] = parts[1] + + return ports_by_vcom["VCOM2"] + + @classmethod + def _find_openocd_serial_port(cls, options): + serial_number = openocd_serial(options) + ports = [p for p in serial.tools.list_ports.grep(serial_number)] + if len(ports) != 1: + raise Exception( + f"_find_openocd_serial_port: expected 1 port to match {serial_number}, " + f"found: {ports!r}" + ) + + return ports[0].device + + @classmethod + def _find_serial_port(cls, options): + flash_runner = _get_flash_runner() + + if flash_runner == "nrfjprog": + return cls._find_nrf_serial_port(options) + + if flash_runner == "openocd": + return cls._find_openocd_serial_port(options) + + raise FlashRunnerNotSupported( + f"Don't know how to deduce serial port for flash runner {flash_runner}" + ) + + def __init__(self, options): + self._options = options + self._port = None + + def open(self): + port_path = self._find_serial_port(self._options) + self._port = serial.Serial(port_path, baudrate=self._lookup_baud_rate(self._options)) + return server.TransportTimeouts( + session_start_retry_timeout_sec=2.0, + session_start_timeout_sec=5.0, + session_established_timeout_sec=5.0, + ) + + def close(self): + self._port.close() + self._port = None + + def read(self, n, timeout_sec): + self._port.timeout = timeout_sec + to_return = self._port.read(n) + if not to_return: + raise server.IoTimeoutError() + + return to_return + + def write(self, data, timeout_sec): + self._port.write_timeout = timeout_sec + bytes_written = 0 + while bytes_written < len(data): + n = self._port.write(data) + data = data[n:] + bytes_written += n + + +class ZephyrQemuMakeResult(enum.Enum): + QEMU_STARTED = "qemu_started" + MAKE_FAILED = "make_failed" + EOF = "eof" + + +class ZephyrQemuTransport: + """The user-facing Zephyr QEMU transport class.""" + + def __init__(self, options): + self.options = options + self.proc = None + self.pipe_dir = None + self.read_fd = None + self.write_fd = None + self._queue = queue.Queue() + + def open(self): + self.pipe_dir = pathlib.Path(tempfile.mkdtemp()) + self.pipe = self.pipe_dir / "fifo" + self.write_pipe = self.pipe_dir / "fifo.in" + self.read_pipe = self.pipe_dir / "fifo.out" + os.mkfifo(self.write_pipe) + os.mkfifo(self.read_pipe) + + if "gdbserver_port" in self.options: + if "env" in self.kwargs: + self.kwargs["env"] = copy.copy(self.kwargs["env"]) + else: + self.kwargs["env"] = os.environ.copy() + + self.kwargs["env"]["TVM_QEMU_GDBSERVER_PORT"] = str(self.options["gdbserver_port"]) + + self.proc = subprocess.Popen( + ["make", "run", f"QEMU_PIPE={self.pipe}"], + cwd=BUILD_DIR, + stdout=subprocess.PIPE, + ) + self._wait_for_qemu() + + # NOTE: although each pipe is unidirectional, open both as RDWR to work around a select + # limitation on linux. Without this, non-blocking I/O can't use timeouts because named + # FIFO are always considered ready to read when no one has opened them for writing. + self.read_fd = os.open(self.read_pipe, os.O_RDWR | os.O_NONBLOCK) + self.write_fd = os.open(self.write_pipe, os.O_RDWR | os.O_NONBLOCK) + _set_nonblock(self.read_fd) + _set_nonblock(self.write_fd) + + return server.TransportTimeouts( + session_start_retry_timeout_sec=2.0, + session_start_timeout_sec=10.0, + session_established_timeout_sec=10.0, + ) + + def close(self): + did_write = False + if self.write_fd is not None: + try: + server.write_with_timeout( + self.write_fd, b"\x01x", 1.0 + ) # Use a short timeout since we will kill the process + did_write = True + except server.IoTimeoutError: + pass + os.close(self.write_fd) + self.write_fd = None + + if self.proc: + if not did_write: + self.proc.terminate() + try: + self.proc.wait(5.0) + except subprocess.TimeoutExpired: + self.proc.kill() + + if self.read_fd: + os.close(self.read_fd) + self.read_fd = None + + if self.pipe_dir is not None: + shutil.rmtree(self.pipe_dir) + self.pipe_dir = None + + def read(self, n, timeout_sec): + return server.read_with_timeout(self.read_fd, n, timeout_sec) + + def write(self, data, timeout_sec): + to_write = bytearray() + escape_pos = [] + for i, b in enumerate(data): + if b == 0x01: + to_write.append(b) + escape_pos.append(i) + to_write.append(b) + + while to_write: + num_written = server.write_with_timeout(self.write_fd, to_write, timeout_sec) + to_write = to_write[num_written:] + + def _qemu_check_stdout(self): + for line in self.proc.stdout: + line = str(line) + _LOG.info("%s", line) + if "[QEMU] CPU" in line: + self._queue.put(ZephyrQemuMakeResult.QEMU_STARTED) + else: + line = re.sub("[^a-zA-Z0-9 \n]", "", line) + pattern = r"recipe for target (\w*) failed" + if re.search(pattern, line, re.IGNORECASE): + self._queue.put(ZephyrQemuMakeResult.MAKE_FAILED) + self._queue.put(ZephyrQemuMakeResult.EOF) + + def _wait_for_qemu(self): + threading.Thread(target=self._qemu_check_stdout, daemon=True).start() + while True: + try: + item = self._queue.get(timeout=120) + except Exception: + raise TimeoutError("QEMU setup timeout.") + + if item == ZephyrQemuMakeResult.QEMU_STARTED: + break + + if item in [ZephyrQemuMakeResult.MAKE_FAILED, ZephyrQemuMakeResult.EOF]: + raise RuntimeError("QEMU setup failed.") + + raise ValueError(f"{item} not expected.") + + +if __name__ == "__main__": + server.main(Handler()) diff --git a/apps/microtvm/zephyr/qemu-hack/qemu-system-riscv32 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm similarity index 100% rename from apps/microtvm/zephyr/qemu-hack/qemu-system-riscv32 rename to apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm diff --git a/apps/microtvm/zephyr/qemu-hack/qemu-system-i386 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-i386 similarity index 91% rename from apps/microtvm/zephyr/qemu-hack/qemu-system-i386 rename to apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-i386 index a30605204d31..6871efbc8b6f 100755 --- a/apps/microtvm/zephyr/qemu-hack/qemu-system-i386 +++ b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-i386 @@ -31,8 +31,8 @@ while [ "$#" -gt 0 ]; do done # For debugging -if [ "${TVM_QEMU_DEBUG}" != "" ]; then - ARGS=( "${ARGS[@]}" -s -S ) +if [ "${TVM_QEMU_GDBSERVER_PORT}" != "" ]; then + ARGS=( "${ARGS[@]}" -gdb "tcp::${TVM_QEMU_GDBSERVER_PORT}" -S ) fi "${ARGS[@]}" diff --git a/apps/microtvm/zephyr/qemu-hack/qemu-system-riscv64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 similarity index 100% rename from apps/microtvm/zephyr/qemu-hack/qemu-system-riscv64 rename to apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 new file mode 120000 index 000000000000..ebbc8ad5ad9d --- /dev/null +++ b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 @@ -0,0 +1 @@ +qemu-system-i386 \ No newline at end of file diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 new file mode 120000 index 000000000000..ebbc8ad5ad9d --- /dev/null +++ b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 @@ -0,0 +1 @@ +qemu-system-i386 \ No newline at end of file diff --git a/apps/microtvm/zephyr/aot_demo/src/main.c b/apps/microtvm/zephyr/template_project/src/aot_demo/main.c similarity index 94% rename from apps/microtvm/zephyr/aot_demo/src/main.c rename to apps/microtvm/zephyr/template_project/src/aot_demo/main.c index 43cc7b33987b..a96e3b4d0a4e 100644 --- a/apps/microtvm/zephyr/aot_demo/src/main.c +++ b/apps/microtvm/zephyr/template_project/src/aot_demo/main.c @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -32,6 +31,7 @@ #include "input_data.h" #include "output_data.h" +#include "tvmgen_default.h" #include "zephyr_uart.h" #ifdef CONFIG_ARCH_POSIX @@ -41,12 +41,10 @@ #define WORKSPACE_SIZE (270 * 1024) static uint8_t g_aot_memory[WORKSPACE_SIZE]; -extern tvm_model_t tvmgen_default_network; tvm_workspace_t app_workspace; // Wakeup sequence used to wake up QEMU on the host. -const unsigned char g_wakeup_sequence[12] = {0xfe, 0xff, 0xfd, 0x03, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x02, 0x66, 0x77}; +const unsigned char g_wakeup_sequence[] = "#wakeup\n"; const char g_start_cmd[] = "start\n"; size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, @@ -194,18 +192,18 @@ void main(void) { } TVMLogf("Zephyr AOT Runtime\n"); - void* inputs[1] = { - input_data, + struct tvmgen_default_inputs inputs = { + .input_1 = input_data, }; - void* outputs[1] = { - output_data, + struct tvmgen_default_outputs outputs = { + .output = output_data, }; StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE); double elapsed_time = 0; TVMPlatformTimerStart(); - int ret_val = tvm_runtime_run(&tvmgen_default_network, inputs, outputs); + int ret_val = tvmgen_default_run(&inputs, &outputs); TVMPlatformTimerStop(&elapsed_time); if (ret_val != 0) { diff --git a/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c b/apps/microtvm/zephyr/template_project/src/aot_demo/zephyr_uart.c similarity index 100% rename from apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c rename to apps/microtvm/zephyr/template_project/src/aot_demo/zephyr_uart.c diff --git a/apps/microtvm/zephyr/aot_demo/include/zephyr_uart.h b/apps/microtvm/zephyr/template_project/src/aot_demo/zephyr_uart.h similarity index 100% rename from apps/microtvm/zephyr/aot_demo/include/zephyr_uart.h rename to apps/microtvm/zephyr/template_project/src/aot_demo/zephyr_uart.h diff --git a/apps/microtvm/zephyr/host_driven/src/main.c b/apps/microtvm/zephyr/template_project/src/host_driven/main.c similarity index 99% rename from apps/microtvm/zephyr/host_driven/src/main.c rename to apps/microtvm/zephyr/template_project/src/host_driven/main.c index 5b93d647eb00..43064e804193 100644 --- a/apps/microtvm/zephyr/host_driven/src/main.c +++ b/apps/microtvm/zephyr/template_project/src/host_driven/main.c @@ -100,7 +100,7 @@ size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const // Called by TVM when an internal invariant is violated, and execution cannot continue. void TVMPlatformAbort(tvm_crt_error_t error) { - TVMLogf("TVMError: %x", error); + TVMLogf("TVMError: 0x%x", error); sys_reboot(SYS_REBOOT_COLD); #ifdef CONFIG_LED gpio_pin_set(led0_pin, LED0_PIN, 1); diff --git a/apps/wasm-standalone/README.md b/apps/wasm-standalone/README.md index e40d218634aa..b8a977f6ae50 100644 --- a/apps/wasm-standalone/README.md +++ b/apps/wasm-standalone/README.md @@ -116,16 +116,10 @@ This project should be considered **experimental** at the very early stage, all - Build DL library in the WebAssembly format. - - Download model + - Compile the model ``` - cd wasm-graph/tools && wget https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v1/resnet50v1.onnx - ``` - - - Compile - - ``` - LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3 ./resnet50v1.onnx + cd wasm-graph/tools && LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3 ``` ### Build wasm-graph package @@ -170,9 +164,14 @@ $ wget -O synset.csv https://raw.githubusercontent.com/kazum/tvm-wasm/master/syn $ ./target/debug/test_graph_resnet50 -g ./wasm_graph_resnet50.wasm -i ./cat.png -l ./synset.csv original image dimensions: (256, 256) resized image dimensions: (224, 224) -input image belongs to the class `tabby, tabby cat` +input image belongs to the class `tiger cat` ``` +Note: this example also works without WASI support. Please modify `wasm-graph/.cargo/config` to change the target to +`wasm32-unknown-unknown` and uncomment the raw wasm engine in `wasm-runtime/src/graph.rs` to run in pure wasm32. SIMD +may not be supported without WASI support. You may also need to delete ` -mattr=+simd128` in the +[build script](apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py). + ## Future Work ### More networks support diff --git a/apps/wasm-standalone/wasm-graph/src/lib.rs b/apps/wasm-standalone/wasm-graph/src/lib.rs index 2b4187849edc..92a3d5c2f3b0 100644 --- a/apps/wasm-standalone/wasm-graph/src/lib.rs +++ b/apps/wasm-standalone/wasm-graph/src/lib.rs @@ -48,6 +48,7 @@ lazy_static! { "/lib/graph.json" ))) .unwrap(); + let params_bytes = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/lib/graph.params")); let params = tvm_graph_rt::load_param_dict(params_bytes) @@ -57,6 +58,7 @@ lazy_static! { .collect::>>(); let mut exec = GraphExecutor::new(graph, &*SYSLIB).unwrap(); + exec.load_params(params); Mutex::new(exec) @@ -68,14 +70,14 @@ pub extern "C" fn run(wasm_addr: i32, in_size: i32) -> i32 { let in_tensor = unsafe { utils::load_input(wasm_addr, in_size as usize) }; let input: TVMTensor = in_tensor.as_dltensor().into(); - GRAPH_EXECUTOR.lock().unwrap().set_input("data", input); - GRAPH_EXECUTOR.lock().unwrap().run(); - let output = GRAPH_EXECUTOR - .lock() - .unwrap() - .get_output(0) - .unwrap() - .as_dltensor(false); + // since this executor is not multi-threaded, we can acquire lock once + let mut executor = GRAPH_EXECUTOR.lock().unwrap(); + + executor.set_input("data", input); + + executor.run(); + + let output = executor.get_output(0).unwrap().as_dltensor(false); let out_tensor: Tensor = output.into(); let out_size = unsafe { utils::store_output(wasm_addr, out_tensor) }; diff --git a/apps/wasm-standalone/wasm-graph/src/types.rs b/apps/wasm-standalone/wasm-graph/src/types.rs index a3761a758cff..f08b7be84990 100644 --- a/apps/wasm-standalone/wasm-graph/src/types.rs +++ b/apps/wasm-standalone/wasm-graph/src/types.rs @@ -24,7 +24,7 @@ use std::{ }; pub use tvm_sys::ffi::DLTensor; use tvm_sys::ffi::{ - DLDevice, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDeviceType_kDLCPU, + DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDevice, DLDeviceType_kDLCPU, }; #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] diff --git a/apps/wasm-standalone/wasm-graph/src/utils.rs b/apps/wasm-standalone/wasm-graph/src/utils.rs index fd4a71745f4f..92d386e3062a 100644 --- a/apps/wasm-standalone/wasm-graph/src/utils.rs +++ b/apps/wasm-standalone/wasm-graph/src/utils.rs @@ -24,13 +24,20 @@ use std::ptr; pub unsafe fn load_input(in_addr: i32, in_size: usize) -> Tensor { let in_addr = in_addr as *mut u8; - let mut data_vec = Vec::new(); - for i in 0..in_size { - data_vec.push(ptr::read(in_addr.offset(i as isize))); - } - let input: Tensor = serde_json::from_slice(&data_vec).unwrap(); + println!("DEBUG: in_addr {:?}, in_size {:?}", in_addr, in_size); + + let data_vec = unsafe { std::slice::from_raw_parts(in_addr, in_size) }; - input + let input = serde_json::from_slice(&data_vec); + match input { + Ok(result) => { + println!("DEBUG: SER SUCCEED!!! and Ok"); + result + } + Err(e) => { + panic!("DEBUG: SER SUCCEED!!! but Err, {:?}", &e); + } + } } pub unsafe fn store_output(out_addr: i32, output: Tensor) -> usize { diff --git a/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py b/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py old mode 100644 new mode 100755 index 3d8a349b8744..b1cdb199a871 --- a/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py +++ b/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -"""Builds a simple graph for testing.""" +"""Builds a simple resnet50 graph for testing.""" import argparse import os import subprocess @@ -25,47 +25,78 @@ import onnx import tvm from tvm import relay, runtime +from tvm.contrib.download import download_testdata +from tvm.contrib import graph_executor +from PIL import Image +import numpy as np +import tvm.relay as relay -def _get_mod_and_params(model_file): - onnx_model = onnx.load(model_file) - shape_dict = {} - for input in onnx_model.graph.input: - shape_dict[input.name] = [dim.dim_value for dim in input.type.tensor_type.shape.dim] +# This example uses resnet50-v2-7 model +model_url = "".join( + [ + "https://github.com/onnx/models/raw/", + "master/vision/classification/resnet/model/", + "resnet50-v2-7.onnx", + ] +) - return relay.frontend.from_onnx(onnx_model, shape_dict) - -def build_graph_lib(model_file, opt_level): +def build_graph_lib(opt_level): """Compiles the pre-trained model with TVM""" out_dir = os.path.join(sys.path[0], "../lib") if not os.path.exists(out_dir): os.makedirs(out_dir) - # Compile the relay mod - mod, params = _get_mod_and_params(model_file) + # Follow the tutorial to download and compile the model + model_path = download_testdata(model_url, "resnet50-v2-7.onnx", module="onnx") + onnx_model = onnx.load(model_path) + + img_url = "https://s3.amazonaws.com/model-server/inputs/kitten.jpg" + img_path = download_testdata(img_url, "imagenet_cat.png", module="data") + + # Resize it to 224x224 + resized_image = Image.open(img_path).resize((224, 224)) + img_data = np.asarray(resized_image).astype("float32") + + # Our input image is in HWC layout while ONNX expects CHW input, so convert the array + img_data = np.transpose(img_data, (2, 0, 1)) + + # Normalize according to the ImageNet input specification + imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) + imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) + norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev + + # Add the batch dimension, as we are expecting 4-dimensional input: NCHW. + img_data = np.expand_dims(norm_img_data, axis=0) + + input_name = "data" + shape_dict = {input_name: img_data.shape} + + mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) target = "llvm -mtriple=wasm32-unknown-unknown -mattr=+simd128 --system-lib" + with tvm.transform.PassContext(opt_level=opt_level): - graph_json, lib, params = relay.build(mod, target=target, params=params) + factory = relay.build(mod, target=target, params=params) # Save the model artifacts to obj_file obj_file = os.path.join(out_dir, "graph.o") - lib.save(obj_file) + factory.get_lib().save(obj_file) + # Run llvm-ar to archive obj_file into lib_file lib_file = os.path.join(out_dir, "libgraph_wasm32.a") cmds = [os.environ.get("LLVM_AR", "llvm-ar-10"), "rcs", lib_file, obj_file] subprocess.run(cmds) + # Save the json and params with open(os.path.join(out_dir, "graph.json"), "w") as f_graph: - f_graph.write(graph_json) - + f_graph.write(factory.get_graph_json()) with open(os.path.join(out_dir, "graph.params"), "wb") as f_params: - f_params.write(runtime.save_param_dict(params)) + f_params.write(runtime.save_param_dict(factory.get_params())) if __name__ == "__main__": parser = argparse.ArgumentParser(description="ONNX model build example") - parser.add_argument("model_file", type=str, help="the path of onnx model file") parser.add_argument( "-O", "--opt-level", @@ -75,4 +106,4 @@ def build_graph_lib(model_file, opt_level): ) args = parser.parse_args() - build_graph_lib(args.model_file, args.opt_level) + build_graph_lib(args.opt_level) diff --git a/apps/wasm-standalone/wasm-runtime/Cargo.toml b/apps/wasm-standalone/wasm-runtime/Cargo.toml index 99f6db54431f..d3f860170d4e 100644 --- a/apps/wasm-standalone/wasm-runtime/Cargo.toml +++ b/apps/wasm-standalone/wasm-runtime/Cargo.toml @@ -26,8 +26,8 @@ license = "Apache-2.0" keywords = ["wasm", "machine learning", "wasmtime"] [dependencies] -wasmtime = "0.16.0" -wasmtime-wasi = "0.16.0" +wasmtime = "0.28.0" +wasmtime-wasi = "0.28.0" anyhow = "1.0.31" serde = "1.0.53" serde_json = "1.0.53" diff --git a/apps/wasm-standalone/wasm-runtime/src/graph.rs b/apps/wasm-standalone/wasm-runtime/src/graph.rs index e7c39cbb0687..bfa1c2f19c56 100644 --- a/apps/wasm-standalone/wasm-runtime/src/graph.rs +++ b/apps/wasm-standalone/wasm-runtime/src/graph.rs @@ -19,7 +19,7 @@ use anyhow::Result; use wasmtime::*; -use wasmtime_wasi::{Wasi, WasiCtx}; +use wasmtime_wasi::{WasiCtx, WasiCtxBuilder}; use super::Tensor; @@ -27,6 +27,9 @@ pub struct GraphExecutor { pub(crate) wasm_addr: i32, pub(crate) input_size: i32, pub(crate) output_size: i32, + pub(crate) store: Option>, + // None-WASI version: + // pub(crate) store: Option>, pub(crate) instance: Option, } @@ -37,25 +40,44 @@ impl GraphExecutor { wasm_addr: 0, input_size: 0, output_size: 0, + store: None, instance: None, } } pub fn instantiate(&mut self, wasm_graph_file: String) -> Result<()> { - let engine = Engine::new(Config::new().wasm_simd(true)); - let store = Store::new(&engine); + // It seems WASI in this example is not necessary + // None WASI version: works with no SIMD + // let engine = Engine::new(Config::new().wasm_simd(true)).unwrap(); + // let mut store = Store::new(&engine, ()); + // let module = Module::from_file(store.engine(), &wasm_graph_file)?; + + // let instance = Instance::new(&mut store, &module, &[])?; + + // self.instance = Some(instance); + // self.store = Some(store); + + // Ok(()) + + // WASI version: + let engine = Engine::new(Config::new().wasm_simd(true)).unwrap(); // First set up our linker which is going to be linking modules together. We // want our linker to have wasi available, so we set that up here as well. - let mut linker = Linker::new(&store); + let mut linker = Linker::new(&engine); + wasmtime_wasi::add_to_linker(&mut linker, |s| s)?; // Create an instance of `Wasi` which contains a `WasiCtx`. Note that // `WasiCtx` provides a number of ways to configure what the target program // will have access to. - let wasi = Wasi::new(&store, WasiCtx::new(std::env::args())?); - wasi.add_to_linker(&mut linker)?; + let wasi = WasiCtxBuilder::new() + .inherit_stdio() + .inherit_args()? + .build(); + let mut store = Store::new(&engine, wasi); - let module = Module::from_file(&store, &wasm_graph_file)?; - self.instance = Some(linker.instantiate(&module)?); + let module = Module::from_file(&engine, &wasm_graph_file)?; + self.instance = Some(linker.instantiate(&mut store, &module)?); + self.store = Some(store); Ok(()) } @@ -65,26 +87,24 @@ impl GraphExecutor { .instance .as_ref() .unwrap() - .get_memory("memory") + .get_memory(self.store.as_mut().unwrap(), "memory") .ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?; // Specify the wasm address to access the wasm memory. - let wasm_addr = memory.data_size(); + let wasm_addr = memory.data_size(self.store.as_mut().unwrap()); + // Serialize the data into a JSON string. let in_data = serde_json::to_vec(&input_data)?; let in_size = in_data.len(); + // Grow up memory size according to in_size to avoid memory leak. - memory.grow((in_size >> 16) as u32 + 1)?; + memory.grow(self.store.as_mut().unwrap(), (in_size >> 16) as u32 + 1)?; - // Insert the input data into wasm memory. - for i in 0..in_size { - unsafe { - memory.data_unchecked_mut()[wasm_addr + i] = *in_data.get(i).unwrap(); - } - } + memory.write(self.store.as_mut().unwrap(), wasm_addr, &in_data)?; self.wasm_addr = wasm_addr as i32; self.input_size = in_size as i32; + Ok(()) } @@ -94,11 +114,12 @@ impl GraphExecutor { .instance .as_ref() .unwrap() - .get_func("run") - .ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))? - .get2::()?; + .get_func(self.store.as_mut().unwrap(), "run") + .ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))?; - let out_size = run(self.wasm_addr, self.input_size)?; + let params = [Val::I32(self.wasm_addr), Val::I32(self.input_size)]; + let out_size = run.call(self.store.as_mut().unwrap(), ¶ms[..])?; + let out_size = (*out_size)[0].unwrap_i32(); if out_size == 0 { panic!("graph run failed!"); } @@ -107,18 +128,22 @@ impl GraphExecutor { Ok(()) } - pub fn get_output(&self) -> Result { + pub fn get_output(&mut self) -> Result { let memory = self .instance .as_ref() .unwrap() - .get_memory("memory") + .get_memory(self.store.as_mut().unwrap(), "memory") .ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?; - let out_data = unsafe { - &memory.data_unchecked()[self.wasm_addr as usize..][..self.output_size as usize] - }; - let out_vec: Tensor = serde_json::from_slice(out_data).unwrap(); + let mut out_data = vec![0 as u8; self.output_size as _]; + memory.read( + self.store.as_mut().unwrap(), + self.wasm_addr as _, + &mut out_data, + )?; + + let out_vec: Tensor = serde_json::from_slice(&out_data).unwrap(); Ok(out_vec) } } diff --git a/cmake/config.cmake b/cmake/config.cmake index ae257d435155..e55f1197d90e 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -273,6 +273,13 @@ set(USE_FALLBACK_STL_MAP OFF) set(USE_HEXAGON_DEVICE OFF) set(USE_HEXAGON_SDK /path/to/sdk) +# Hexagon architecture to target when compiling TVM itself (not the target for +# compiling _by_ TVM). This applies to components like the TVM runtime, but is +# also used to select correct include/library paths from the Hexagon SDK when +# building offloading runtime for Android. +# Valid values are v60, v62, v65, v66, v68. +set(USE_HEXAGON_ARCH "v66") + # Whether to use ONNX codegen set(USE_TARGET_ONNX OFF) @@ -299,3 +306,23 @@ set(USE_LIBBACKTRACE AUTO) # not be included in the final executable. This would make the corresponding # runtime functions to be unavailable to the program. set(BUILD_STATIC_RUNTIME OFF) + + +# Caches the build so that building is faster when switching between branches. +# If you switch branches, build and then encounter a linking error, you may +# need to regenerate the build tree through "make .." (the cache will +# still provide significant speedups). +# Possible values: +# - AUTO: search for path to ccache, disable if not found. +# - ON: enable ccache by searching for the path to ccache, report an error if not found +# - OFF: disable ccache +# - /path/to/ccache: use specific path to ccache +set(USE_CCACHE AUTO) + +# Whether to enable PAPI support in profiling. PAPI provides access to hardware +# counters while profiling. +# Possible values: +# - ON: enable PAPI support. Will search PKG_CONFIG_PATH for a papi.pc +# - OFF: disable PAPI support. +# - /path/to/folder/containing/: Path to folder containing papi.pc. +set(USE_PAPI OFF) diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index 80df76b04645..eb3ad1f5ae4a 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -16,12 +16,12 @@ # under the License. include(ExternalProject) +include(cmake/modules/HexagonSDK.cmake) set(PICK_SIM "sim") set(PICK_HW "target") set(PICK_NONE "OFF") -set(FOUND_HEXAGON_SDK_ROOT FALSE) set(FOUND_HEXAGON_TOOLCHAIN FALSE) function(find_hexagon_toolchain) @@ -47,41 +47,10 @@ function(find_hexagon_toolchain) endif() endfunction() -function(find_hexagon_sdk_root) - if(FOUND_HEXAGON_SDK_ROOT) - return() - endif() - message(STATUS "Checking Hexagon SDK root: ${USE_HEXAGON_SDK}") - file(GLOB_RECURSE HEXAGON_AEESTDDEF "${USE_HEXAGON_SDK}/*/AEEStdDef.h") - if(HEXAGON_AEESTDDEF) - # The path is ${HEXAGON_SDK_ROOT}/incs/stddef/AEEStdDef.h. - get_filename_component(HEXAGON_TMP0 "${HEXAGON_AEESTDDEF}" DIRECTORY) - get_filename_component(HEXAGON_TMP1 "${HEXAGON_TMP0}" DIRECTORY) - get_filename_component(HEXAGON_TMP2 "${HEXAGON_TMP1}" DIRECTORY) - set(HEXAGON_SDK_ROOT "${HEXAGON_TMP2}" CACHE PATH - "Root directory of Hexagon SDK") - set(FOUND_HEXAGON_SDK_ROOT TRUE) - else(HEXAGON_AEESTDDEF) - message(SEND_ERROR "Cannot validate Hexagon SDK in ${USE_HEXAGON_SDK}") - endif() -endfunction() - if(BUILD_FOR_HEXAGON) - find_hexagon_sdk_root() - if(HEXAGON_SDK_ROOT MATCHES "3.5.1") - message(SEND_ERROR "Hexagon SDK 3.5.1 is not supported") - elseif(HEXAGON_SDK_ROOT MATCHES "3\.[0-9]+\.[0-9]+") - include_directories( - SYSTEM "${USE_HEXAGON_SDK}/libs/common/qurt/ADSPv62MP/include/posix" - SYSTEM "${USE_HEXAGON_SDK}/libs/common/qurt/ADSPv62MP/include/qurt") - else() - include_directories( - SYSTEM "${HEXAGON_SDK_ROOT}/rtos/qurt/computev65/include/posix" - SYSTEM "${HEXAGON_SDK_ROOT}/rtos/qurt/computev65/include/qurt") - endif() - include_directories( - SYSTEM "${HEXAGON_SDK_ROOT}/incs" - SYSTEM "${HEXAGON_SDK_ROOT}/incs/stddef") + find_hexagon_sdk_root("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}") + # Add SDK and QuRT includes when building for Hexagon. + include_directories(SYSTEM ${HEXAGON_SDK_INCLUDES} ${HEXAGON_QURT_INCLUDES}) endif() if(USE_HEXAGON_DEVICE STREQUAL "OFF") @@ -113,29 +82,19 @@ if(USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}") CMAKE_ARGS "-DCMAKE_C_COMPILER=${HEXAGON_TOOLCHAIN}/bin/hexagon-clang" "-DCMAKE_CXX_COMPILER=${HEXAGON_TOOLCHAIN}/bin/hexagon-clang++" + "-DHEXAGON_ARCH=${USE_HEXAGON_ARCH}" INSTALL_COMMAND "true" ) elseif(USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") - find_hexagon_sdk_root() + find_hexagon_sdk_root("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}") find_hexagon_toolchain() - message(STATUS "Hexagon SDK: ${HEXAGON_SDK_ROOT}") - if(HEXAGON_SDK_ROOT MATCHES "3.5.1") - message(SEND_ERROR "Hexagon SDK 3.5.1 is not supported") - elseif(HEXAGON_SDK_ROOT MATCHES "3\.[0-9]+\.[0-9]+") - set(RPCMEM_DIR "libs/common/rpcmem") - set(REMOTE_DIR "libs/common/remote/ship/android_Release_aarch64") - else() - set(RPCMEM_DIR "ipc/fastrpc/rpcmem") - set(REMOTE_DIR "ipc/fastrpc/remote/ship/android_aarch64") - endif() file(GLOB RUNTIME_HEXAGON_DEVICE_SRCS src/runtime/hexagon/target/*.cc) - include_directories(SYSTEM "${HEXAGON_SDK_ROOT}/incs/stddef") - include_directories(SYSTEM "${HEXAGON_SDK_ROOT}/${RPCMEM_DIR}/inc") - include_directories( - SYSTEM "${HEXAGON_SDK_ROOT}/incs") - include_directories( - SYSTEM "${HEXAGON_SDK_ROOT}/${REMOTE_DIR}") - include_directories(SYSTEM "${HEXAGON_TOOLCHAIN}/include/iss") + + include_directories(SYSTEM + ${HEXAGON_SDK_INCLUDES} + ${HEXAGON_RPCMEM_ROOT}/inc + ${HEXAGON_REMOTE_ROOT} + ) list(APPEND TVM_RUNTIME_LINKER_LIBS "dl") if(BUILD_FOR_ANDROID) # Hexagon runtime uses __android_log_print, which is in liblog. diff --git a/cmake/modules/HexagonSDK.cmake b/cmake/modules/HexagonSDK.cmake new file mode 100644 index 000000000000..9541f5be821c --- /dev/null +++ b/cmake/modules/HexagonSDK.cmake @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set(FOUND_HEXAGON_SDK_ROOT FALSE) + +macro(set_parent var) + set(${var} ${ARGN} PARENT_SCOPE) +endmacro() + +function(find_hexagon_sdk_root HEXAGON_SDK_PATH HEXAGON_ARCH) + if(FOUND_HEXAGON_SDK_ROOT) + return() + endif() + if(${ARGC} LESS "2") + message(SEND_ERROR "Must provide Hexagon SDK path and Hexagon arch") + endif() + + # Initial verification of the Hexagon SDK. + message(STATUS "Checking Hexagon SDK root: ${HEXAGON_SDK_PATH}") + file(GLOB_RECURSE VERSION_HEADERS "${HEXAGON_SDK_PATH}/*/version.h") + if(VERSION_HEADERS) + foreach(HEADER IN LISTS VERSION_HEADERS) + if(HEADER MATCHES "incs/version.h$") + set(SDK_VERSION_HEADER "${HEADER}") + break() + endif() + endforeach() + # The path is ${HEXAGON_SDK_ROOT}/incs/version.h. + get_filename_component(TMP0 "${SDK_VERSION_HEADER}" DIRECTORY) + get_filename_component(TMP1 "${TMP0}" DIRECTORY) + set(HEXAGON_SDK_ROOT "${TMP1}" CACHE PATH "Root directory of Hexagon SDK") + else() + message(SEND_ERROR "Cannot validate Hexagon SDK in ${HEXAGON_SDK_PATH}") + endif() + + execute_process( + COMMAND grep "#define[ \t]*VERSION_STRING" "${SDK_VERSION_HEADER}" + OUTPUT_VARIABLE SDK_VERSION_DEFINE) + string( + REGEX REPLACE ".*VERSION_STRING.* ([0-9\\.]+) .*" "\\1" + SDK_VERSION_STRING "${SDK_VERSION_DEFINE}") + + if (SDK_VERSION_STRING MATCHES "3.5.1") + message(SEND_ERROR "Hexagon SDK 3.5.1 is not supported") + endif() + + # Set the Hexagon arch directory component. + set(HEXARCH_DIR_v60 "ADSPv60MP") + set(HEXARCH_DIR_v62 "ADSPv62MP") + set(HEXARCH_DIR_v65 "computev65") + set(HEXARCH_DIR_v66 "computev66") + set(HEXARCH_DIR_v68 "computev68") + set(HEXARCH_DIR_STR "HEXARCH_DIR_${HEXAGON_ARCH}") + set(HEXARCH_DIR "${${HEXARCH_DIR_STR}}") + + if(NOT HEXARCH_DIR) + message(SEND_ERROR + "Please set HEXAGON_ARCH to one of v60, v62, v65, v66, v68") + endif() + + # Set parent variables: + # - HEXAGON_SDK_VERSION + # - HEXAGON_SDK_INCLUDES + # - HEXAGON_QURT_INCLUDES + # - HEXAGON_RPCMEM_ROOT + # - HEXAGON_REMOTE_ROOT + # - HEXAGON_QAIC_EXE + set_parent(HEXAGON_SDK_VERSION "${SDK_VERSION_STRING}") + + if(SDK_VERSION_STRING MATCHES "^3\.[0-9]+\.[0-9]+") + # SDK 3.x.y + if(HEXAGON_ARCH MATCHES "v6[7-9]|v[7-9][0-9]") + message(SEND_ERROR + "Hexagon SDK ${SDK_VERSION_STRING} does not support ${HEXAGON_ARCH}") + endif() + set_parent(HEXAGON_SDK_INCLUDES + "${HEXAGON_SDK_ROOT}/incs" + "${HEXAGON_SDK_ROOT}/incs/a1std" + "${HEXAGON_SDK_ROOT}/incs/qlist" + "${HEXAGON_SDK_ROOT}/incs/stddef") + set_parent(HEXAGON_QURT_INCLUDES + "${HEXAGON_SDK_ROOT}/libs/common/qurt/${HEXARCH_DIR}/include/posix" + "${HEXAGON_SDK_ROOT}/libs/common/qurt/${HEXARCH_DIR}/include/qurt") + set_parent(HEXAGON_RPCMEM_ROOT "${HEXAGON_SDK_ROOT}/libs/common/rpcmem") + set_parent(HEXAGON_REMOTE_ROOT + "${HEXAGON_SDK_ROOT}/libs/common/remote/ship/android_Release_aarch64") + set_parent(HEXAGON_QAIC_EXE "${HEXAGON_SDK_ROOT}/tools/qaic/bin/qaic") + else() + # SDK 4.x.y.z + if(HEXAGON_ARCH MATCHES "v6[02]") + message(SEND_ERROR + "Hexagon SDK ${SDK_VERSION_STRING} does not support ${HEXAGON_ARCH}") + endif() + set_parent(HEXAGON_SDK_INCLUDES + "${HEXAGON_SDK_ROOT}/incs" + "${HEXAGON_SDK_ROOT}/incs/stddef") + set_parent(HEXAGON_QURT_INCLUDES + "${HEXAGON_SDK_ROOT}/rtos/qurt/${HEXARCH_DIR}/include/posix" + "${HEXAGON_SDK_ROOT}/rtos/qurt/${HEXARCH_DIR}/include/qurt") + set_parent(HEXAGON_RPCMEM_ROOT "${HEXAGON_SDK_ROOT}/ipc/fastrpc/rpcmem") + set_parent(HEXAGON_REMOTE_ROOT # libadsprpc.so + "${HEXAGON_SDK_ROOT}/ipc/fastrpc/remote/ship/android_aarch64") + set_parent(HEXAGON_QAIC_EXE + "${HEXAGON_SDK_ROOT}/ipc/fastrpc/qaic/Ubuntu16/qaic") + endif() + + set(FOUND_HEXAGON_SDK_ROOT TRUE) +endfunction() + diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index 8d43879e4df4..163a56dbd1d4 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -51,7 +51,7 @@ function(add_lib_info src_file) TVM_INFO_CUDA_VERSION="${TVM_INFO_CUDA_VERSION}" TVM_INFO_USE_STACKVM_RUNTIME="${USE_STACKVM_RUNTIME}" TVM_INFO_USE_GRAPH_EXECUTOR="${USE_GRAPH_EXECUTOR}" - TVM_INFO_USE_GRAPH_EXECUTOR_DEBUG="${USE_GRAPH_EXECUTOR_DEBUG}" + TVM_INFO_USE_PROFILER="${USE_PROFILER}" TVM_INFO_USE_OPENMP="${USE_OPENMP}" TVM_INFO_USE_RELAY_DEBUG="${USE_RELAY_DEBUG}" TVM_INFO_USE_RTTI="${USE_RTTI}" diff --git a/cmake/modules/StandaloneCrt.cmake b/cmake/modules/StandaloneCrt.cmake index 09f2ccc95d85..5ed5f5ead088 100644 --- a/cmake/modules/StandaloneCrt.cmake +++ b/cmake/modules/StandaloneCrt.cmake @@ -44,10 +44,10 @@ if(USE_MICRO) "src/runtime/crt/include *.h -> include" "src/runtime/crt/common *.c -> src/runtime/crt/common" "src/runtime/crt/graph_executor *.c -> src/runtime/crt/graph_executor" - "src/runtime/crt/aot_executor *.c -> src/runtime/crt/aot_executor" "src/runtime/crt/graph_executor_module *.c -> src/runtime/crt/graph_executor_module" - "src/runtime/crt/host crt_config.h -> template/host" "src/runtime/crt/host *.cc -> template/host" + "src/runtime/crt/host *.py -> template/host" + "src/runtime/crt/host Makefile -> template/host" "src/runtime/crt/memory *.c -> src/runtime/crt/memory" "src/runtime/crt/microtvm_rpc_common *.cc -> src/runtime/crt/microtvm_rpc_common" "src/runtime/crt/microtvm_rpc_server *.cc -> src/runtime/crt/microtvm_rpc_server" @@ -98,13 +98,13 @@ if(USE_MICRO) set(make_quiet ) endif(${VERBOSE}) - list(APPEND crt_libraries memory graph_executor aot_executor microtvm_rpc_server microtvm_rpc_common common) # NOTE: listed in link order. + list(APPEND crt_libraries memory graph_executor microtvm_rpc_server microtvm_rpc_common common) # NOTE: listed in link order. foreach(crt_lib_name IN LISTS crt_libraries) list(APPEND crt_library_paths "host_standalone_crt/lib${crt_lib_name}.a") endforeach() set(make_common_args - "CRT_CONFIG=template/host/crt_config.h" + "CRT_CONFIG=${CMAKE_SOURCE_DIR}/src/runtime/micro/crt_config.h" "BUILD_DIR=${host_build_dir_abspath}" "EXTRA_CFLAGS=-fPIC" "EXTRA_CXXFLAGS=-fPIC" @@ -132,26 +132,16 @@ if(USE_MICRO) PUBLIC_HEADER "${crt_headers}") endforeach() - # Standalone CRT tests - file(GLOB TEST_SRCS ${CMAKE_SOURCE_DIR}/tests/crt/*_test.cc) - find_path(GTEST_INCLUDE_DIR gtest/gtest.h) - find_library(GTEST_LIB gtest "$ENV{GTEST_LIB}") - # Create the `crttest` target if we can find GTest. If not, we create dummy # targets that give the user an informative error message. if(GTEST_INCLUDE_DIR AND GTEST_LIB) - foreach(__srcpath ${TEST_SRCS}) - get_filename_component(__srcname ${__srcpath} NAME) - string(REPLACE ".cc" "" __execname ${__srcname}) - add_executable(${__execname} ${__srcpath}) - list(APPEND TEST_EXECS ${__execname}) - target_include_directories(${__execname} PUBLIC ${GTEST_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include ${CMAKE_SOURCE_DIR}/src/runtime/crt/host) - target_compile_options(${__execname} PRIVATE -pthread) - target_link_libraries(${__execname} ${cmake_crt_libraries} ${GTEST_LIB} pthread) - set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1) - set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) - endforeach() - add_custom_target(crttest DEPENDS ${TEST_EXECS}) + file(GLOB TEST_SRCS ${CMAKE_SOURCE_DIR}/tests/crt/*_test.cc) + add_executable(crttest ${TEST_SRCS}) + target_include_directories(crttest SYSTEM PUBLIC ${GTEST_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include ${CMAKE_SOURCE_DIR}/src/runtime/micro) + target_link_libraries(crttest PRIVATE ${cmake_crt_libraries} ${GTEST_LIB} gtest_main pthread dl) + set_target_properties(crttest PROPERTIES EXCLUDE_FROM_ALL 1) + set_target_properties(crttest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) + gtest_discover_tests(crttest) elseif(NOT GTEST_INCLUDE_DIR) add_custom_target(crttest COMMAND echo "Missing Google Test headers in include path" diff --git a/cmake/modules/VTA.cmake b/cmake/modules/VTA.cmake index e520e62711f3..1f9d08b50a10 100644 --- a/cmake/modules/VTA.cmake +++ b/cmake/modules/VTA.cmake @@ -73,6 +73,17 @@ elseif(PYTHON) # Cycle accurate simulator driver build if(USE_VTA_TSIM) + if(DEFINED ENV{VERILATOR_INC_DIR}) + set(VERILATOR_INC_DIR $ENV{VERILATOR_INC_DIR}) + elseif (EXISTS /usr/local/share/verilator/include) + set(VERILATOR_INC_DIR /usr/local/share/verilator/include) + elseif (EXISTS /usr/share/verilator/include) + set(VERILATOR_INC_DIR /usr/share/verilator/include) + else() + message(STATUS "Verilator not found in /usr/local/share/verilator/include") + message(STATUS "Verilator not found in /usr/share/verilator/include") + message(FATAL_ERROR "Cannot find Verilator, VERILATOR_INC_DIR is not defined") + endif() # Add tsim driver sources file(GLOB TSIM_RUNTIME_SRCS ${VTA_HW_PATH}/src/*.cc) file(GLOB TSIM_RUNTIME_SRCS vta/runtime/*.cc) @@ -81,7 +92,7 @@ elseif(PYTHON) list(APPEND TSIM_RUNTIME_SRCS ${VTA_HW_PATH}/src/vmem/virtual_memory.cc) # Target lib: vta_tsim add_library(vta_tsim SHARED ${TSIM_RUNTIME_SRCS}) - target_include_directories(vta_tsim SYSTEM PUBLIC ${VTA_HW_PATH}/include) + target_include_directories(vta_tsim SYSTEM PUBLIC ${VTA_HW_PATH}/include ${VERILATOR_INC_DIR} ${VERILATOR_INC_DIR}/vltstd) target_compile_definitions(vta_tsim PUBLIC DMLC_USE_LOGGING_LIBRARY=) foreach(__def ${VTA_DEFINITIONS}) string(SUBSTRING ${__def} 3 -1 __strip_def) diff --git a/apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf b/cmake/modules/contrib/PAPI.cmake similarity index 61% rename from apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf rename to cmake/modules/contrib/PAPI.cmake index 5f3c4a4bed36..0e03cb8750bd 100644 --- a/apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf +++ b/cmake/modules/contrib/PAPI.cmake @@ -15,11 +15,14 @@ # specific language governing permissions and limitations # under the License. -# This file is specific to the QEMU-emulated microTVM board. +if(USE_PAPI) + find_package(PkgConfig REQUIRED) -# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. -CONFIG_TEST_RANDOM_GENERATOR=y -CONFIG_TIMER_RANDOM_GENERATOR=y - -# Default stack size is 1k, this is required for debug mode. -CONFIG_MAIN_STACK_SIZE=2000 + set(ENV{PKG_CONFIG_PATH} "${USE_PAPI}:$ENV{PKG_CONFIG_PATH}") + pkg_check_modules(PAPI REQUIRED IMPORTED_TARGET papi>=6.0) + message(STATUS "Using PAPI library ${PAPI_LINK_LIBRARIES}") + target_link_libraries(tvm_runtime_objs PRIVATE PkgConfig::PAPI) + target_link_libraries(tvm PRIVATE PkgConfig::PAPI) + target_link_libraries(tvm_runtime PRIVATE PkgConfig::PAPI) + target_sources(tvm_runtime_objs PRIVATE src/runtime/contrib/papi/papi.cc) +endif() diff --git a/conda/recipe/bld.bat b/conda/recipe/bld.bat index e877b8fda1e1..9a90fb13d4c4 100644 --- a/conda/recipe/bld.bat +++ b/conda/recipe/bld.bat @@ -28,7 +28,7 @@ cmake ^ -DUSE_CPP_RPC=ON ^ -DUSE_SORT=ON ^ -DUSE_RANDOM=ON ^ - -DUSE_GRAPH_EXECUTOR_DEBUG=ON ^ + -DUSE_PROFILER=ON ^ -DINSTALL_DEV=ON ^ %SRC_DIR% diff --git a/conda/recipe/build.sh b/conda/recipe/build.sh index a94b9df72440..242d6a28b3d3 100755 --- a/conda/recipe/build.sh +++ b/conda/recipe/build.sh @@ -49,7 +49,7 @@ cmake -DCMAKE_INSTALL_PREFIX="${PREFIX}" \ -DUSE_CPP_RPC=OFF \ -DUSE_SORT=ON \ -DUSE_RANDOM=ON \ - -DUSE_GRAPH_EXECUTOR_DEBUG=ON \ + -DUSE_PROFILER=ON \ -DUSE_LLVM=ON \ -DINSTALL_DEV=ON \ -DUSE_LIBBACKTRACE=AUTO \ diff --git a/conftest.py b/conftest.py index f591fe970de8..28859fd4a17b 100644 --- a/conftest.py +++ b/conftest.py @@ -14,36 +14,5 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import pytest -from pytest import ExitCode -import tvm -import tvm.testing - - -def pytest_configure(config): - print("enabled targets:", "; ".join(map(lambda x: x[0], tvm.testing.enabled_targets()))) - print("pytest marker:", config.option.markexpr) - - -@pytest.fixture -def dev(target): - return tvm.device(target) - - -def pytest_generate_tests(metafunc): - tvm.testing._auto_parametrize_target(metafunc) - tvm.testing._parametrize_correlated_parameters(metafunc) - - -def pytest_collection_modifyitems(config, items): - tvm.testing._count_num_fixture_uses(items) - tvm.testing._remove_global_fixture_definitions(items) - - -def pytest_sessionfinish(session, exitstatus): - # Don't exit with an error if we select a subset of tests that doesn't - # include anything - if session.config.option.markexpr != "": - if exitstatus == ExitCode.NO_TESTS_COLLECTED: - session.exitstatus = ExitCode.OK +pytest_plugins = ["tvm.testing.plugin"] diff --git a/docker/Dockerfile.ci_arm b/docker/Dockerfile.ci_arm index 9479d7194d3b..974998b9d6fe 100644 --- a/docker/Dockerfile.ci_arm +++ b/docker/Dockerfile.ci_arm @@ -32,6 +32,9 @@ RUN bash /install/ubuntu_install_llvm.sh COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh +# Globally disable pip cache +RUN pip config set global.cache-dir false + COPY install/ubuntu_install_cmake_source.sh /install/ubuntu_install_cmake_source.sh RUN bash /install/ubuntu_install_cmake_source.sh diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 65afa6931d9c..a1c997364238 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -27,6 +27,9 @@ RUN bash /install/ubuntu_install_core.sh COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh +# Globally disable pip cache +RUN pip config set global.cache-dir false + COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh RUN bash /install/ubuntu_install_python_package.sh @@ -48,6 +51,7 @@ COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh RUN bash /install/ubuntu_install_rust.sh ENV RUSTUP_HOME /opt/rust ENV CARGO_HOME /opt/rust +ENV PATH $PATH:$CARGO_HOME/bin # AutoTVM deps COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh @@ -56,13 +60,12 @@ RUN bash /install/ubuntu_install_redis.sh # Golang environment COPY install/ubuntu_install_golang.sh /install/ubuntu_install_golang.sh RUN bash /install/ubuntu_install_golang.sh +ENV PATH $PATH:/usr/lib/go-1.10/bin # NNPACK deps COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh RUN bash /install/ubuntu_install_nnpack.sh -ENV PATH $PATH:$CARGO_HOME/bin:/usr/lib/go-1.10/bin - # ANTLR deps COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh RUN bash /install/ubuntu_install_java.sh @@ -109,3 +112,17 @@ RUN bash /install/ubuntu_install_androidsdk.sh ENV ANDROID_HOME=/opt/android-sdk-linux/ ENV ANDROID_NDK_HOME=/opt/android-sdk-linux/ndk/21.3.6528147/ +# Arm(R) Ethos(TM)-U NPU driver +COPY install/ubuntu_install_ethosu_driver_stack.sh /install/ubuntu_install_ethosu_driver_stack.sh +RUN bash /install/ubuntu_install_ethosu_driver_stack.sh + +# Install Vela compiler +COPY install/ubuntu_install_vela.sh /install/ubuntu_install_vela.sh +RUN bash /install/ubuntu_install_vela.sh + +# Update PATH +ENV PATH /opt/arm/gcc-arm-none-eabi/bin:$PATH + +# PaddlePaddle deps +COPY install/ubuntu_install_paddle.sh /install/ubuntu_install_paddle.sh +RUN bash /install/ubuntu_install_paddle.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 09c6425da6fb..3d0704bb27f9 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -28,6 +28,9 @@ RUN bash /install/ubuntu_install_core.sh COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh +# Globally disable pip cache +RUN pip config set global.cache-dir false + COPY install/ubuntu1804_install_llvm.sh /install/ubuntu1804_install_llvm.sh RUN bash /install/ubuntu1804_install_llvm.sh @@ -80,9 +83,13 @@ RUN bash /install/ubuntu_install_caffe2.sh COPY install/ubuntu_install_dgl.sh /install/ubuntu_install_dgl.sh RUN bash /install/ubuntu_install_dgl.sh +ENV NVIDIA_DRIVER_CAPABILITIES compute,graphics,utility COPY install/ubuntu_install_vulkan.sh /install/ubuntu_install_vulkan.sh RUN bash /install/ubuntu_install_vulkan.sh +COPY install/ubuntu_install_paddle.sh /install/ubuntu_install_paddle.sh +RUN bash /install/ubuntu_install_paddle.sh + # Rust env (build early; takes a while) COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh RUN bash /install/ubuntu_install_rust.sh diff --git a/docker/Dockerfile.ci_i386 b/docker/Dockerfile.ci_i386 index 2383f4675e37..564731c12d36 100644 --- a/docker/Dockerfile.ci_i386 +++ b/docker/Dockerfile.ci_i386 @@ -31,6 +31,9 @@ RUN bash /install/ubuntu_install_llvm.sh COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh RUN bash /install/ubuntu_install_python.sh +# Globally disable pip cache +RUN pip config set global.cache-dir false + COPY install/ubuntu_install_cmake_source.sh /install/ubuntu_install_cmake_source.sh RUN bash /install/ubuntu_install_cmake_source.sh diff --git a/docker/Dockerfile.ci_lint b/docker/Dockerfile.ci_lint index 2adb793a3517..20bcfe6de903 100644 --- a/docker/Dockerfile.ci_lint +++ b/docker/Dockerfile.ci_lint @@ -27,9 +27,19 @@ RUN apt-get update && apt-get install -y wget git sudo make COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh -RUN apt-get update && apt-get install -y doxygen graphviz +# Globally disable pip cache +RUN pip config set global.cache-dir false -RUN pip3 install cpplint pylint==2.4.4 mypy==0.902 black==20.8b1 +RUN apt-get update && apt-get install -y doxygen graphviz curl + +RUN pip3 install cpplint pylint==2.4.4 mypy==0.902 black==20.8b1 flake8==3.9.2 + +# Rust env (build early; takes a while) +COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh +RUN bash /install/ubuntu_install_rust.sh +ENV RUSTUP_HOME /opt/rust +ENV CARGO_HOME /opt/rust +ENV PATH $PATH:$CARGO_HOME/bin # java deps for rat COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh diff --git a/docker/Dockerfile.ci_qemu b/docker/Dockerfile.ci_qemu index 72189bd79afa..b907ba7b08a9 100644 --- a/docker/Dockerfile.ci_qemu +++ b/docker/Dockerfile.ci_qemu @@ -26,7 +26,10 @@ RUN bash /install/ubuntu_install_core.sh COPY install/ubuntu1804_install_python_venv.sh /install/ubuntu1804_install_python_venv.sh RUN bash /install/ubuntu1804_install_python_venv.sh -ENV PATH=/opt/tvm-venv/bin:$PATH +ENV PATH=/opt/tvm-venv/bin:/opt/zephyr-sdk/sysroots/x86_64-pokysdk-linux/usr/bin:$PATH + +# Globally disable pip cache +RUN pip config set global.cache-dir false COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh RUN bash /install/ubuntu_install_python_package.sh @@ -56,16 +59,21 @@ RUN bash /install/ubuntu_install_tensorflow.sh COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh RUN bash /install/ubuntu_install_tflite.sh -# QEMU deps -COPY install/ubuntu_install_qemu.sh /install/ubuntu_install_qemu.sh -RUN bash /install/ubuntu_install_qemu.sh - # Zephyr SDK deps COPY install/ubuntu_install_zephyr.sh /install/ubuntu_install_zephyr.sh COPY install/ubuntu_init_zephyr_project.sh /install/ubuntu_init_zephyr_project.sh RUN bash /install/ubuntu_install_zephyr.sh ENV ZEPHYR_BASE=/opt/zephyrproject/zephyr +# Arduino deps +# NOTE: override Arduino directories so packages are installed in a +# CI-accessible location. +ENV ARDUINO_DIRECTORIES_DATA=/arduino15-data +ENV ARDUINO_DIRECTORIES_DOWNLOADS=/arduino15-downloads +ENV ARDUINO_DIRECTORIES_USER=/arduino15-user +COPY install/ubuntu_install_arduino.sh /install/ubuntu_install_arduino.sh +RUN bash /install/ubuntu_install_arduino.sh + # Install ONNX COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh RUN bash /install/ubuntu_install_onnx.sh diff --git a/docker/Dockerfile.ci_wasm b/docker/Dockerfile.ci_wasm index 85f942d57ca3..03a1ded5572f 100644 --- a/docker/Dockerfile.ci_wasm +++ b/docker/Dockerfile.ci_wasm @@ -24,6 +24,9 @@ RUN bash /install/ubuntu_install_core.sh COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh +# Globally disable pip cache +RUN pip config set global.cache-dir false + COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh RUN bash /install/ubuntu_install_python_package.sh diff --git a/docker/README.md b/docker/README.md index 5e470350749d..a05079d30881 100644 --- a/docker/README.md +++ b/docker/README.md @@ -33,7 +33,7 @@ interactive bash session with a given image_name. The script does the following things: -- Mount current directory to /workspace and set it as home +- Mount current directory to the same location in the docker container, and set it as home - Switch user to be the same user that calls the bash.sh - Use the host-side network @@ -102,9 +102,10 @@ The command ``./docker/build.sh image_name COMMANDS`` is almost equivelant to ``./docker/bash.sh image_name COMMANDS`` but in the case of ``bash.sh`` a build attempt is not done. -The build command will map the tvm root to /workspace/ inside the container -with the same user as the user invoking the docker command. -Here are some common use examples to perform CI tasks. +The build command will map the tvm root to the corresponding location +inside the container with the same user as the user invoking the +docker command. Here are some common use examples to perform CI +tasks. - lint the python codes diff --git a/docker/bash.sh b/docker/bash.sh index 80f4a9577be1..2a05abf4f2bc 100755 --- a/docker/bash.sh +++ b/docker/bash.sh @@ -18,9 +18,12 @@ # under the License. # -# Start a bash, mount /workspace to be current directory. +# Start a bash, mount REPO_MOUNT_POINT to be current directory. # -# Usage: bash.sh [-i] [--net=host] [--mount path] +# Usage: docker/bash.sh [-i|--interactive] [--net=host] [-t|--tty] +# [--mount MOUNT_DIR] [--repo-mount-point REPO_MOUNT_POINT] +# [--dry-run] +# [--] [COMMAND] # # Usage: docker/bash.sh # Starts an interactive session @@ -30,152 +33,396 @@ # With -i, execute interactively. # -set -e +set -euo pipefail -source "$(dirname $0)/dev_common.sh" || exit 2 -interactive=0 -if [ "$1" == "-i" ]; then - interactive=1 - shift +function show_usage() { + cat < [--] [COMMAND] + +-h, --help + + Display this help message. + +-i, --interactive + + Start the docker session in interactive mode. + +-t, --tty + + Start the docker session with a pseudo terminal (tty). + +--net=host + + Expose servers run into the container to the host, passing the + "--net=host" argument through to docker. On MacOS, this is + instead passed as "-p 8888:8888" since the host networking driver + isn't supported. + +--mount MOUNT_DIR + + Expose MOUNT_DIR as an additional mount point inside the docker + container. The mount point inside the container is the same as + the folder location outside the container. This option can be + specified multiple times. + +--repo-mount-point REPO_MOUNT_POINT + + The directory inside the docker container at which the TVM + repository should be mounted, and is used as the workspace inside + the docker container. + + If unspecified, the mount location depends on the environment. If + running inside Jenkins, the mount location will be /workspace. + Otherwise, the mount location of the repository will be the same + as the external location of the repository, to maintain + compatibility with git-worktree. + +--dry-run + + Print the docker command to be run, but do not execute it. + +DOCKER_IMAGE_NAME + + The name of the docker container to be run. This can be an + explicit name of a docker image (e.g. "tlcpack/ci-gpu:v0.76") or + can be a shortcut as defined in the TVM Jenkinsfile + (e.g. "ci_gpu"). + +COMMAND + + The command to be run inside the docker container. If this is set + to "bash", both the --interactive and --net=host flags are set. + If no command is specified, defaults to "bash". If the command + contains dash-prefixed arguments, the command should be preceded + by -- to indicate arguments that are not intended for bash.sh. + +EOF +} + + +################################# +### Start of argument parsing ### +################################# + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P)" +REPO_DIR="$(dirname "${SCRIPT_DIR}")" + +DRY_RUN=false +INTERACTIVE=false +TTY=false +USE_NET_HOST=false +DOCKER_IMAGE_NAME= +COMMAND=bash +MOUNT_DIRS=( ) + +# TODO(Lunderberg): Remove this if statement and always set to +# "${REPO_DIR}". The consistent directory for Jenkins is currently +# necessary to allow cmake build commands to run in CI after the build +# steps. +if [[ -n "${JENKINS_HOME:-}" ]]; then + REPO_MOUNT_POINT=/workspace +else + REPO_MOUNT_POINT="${REPO_DIR}" fi -CI_DOCKER_EXTRA_PARAMS=( ) -if [[ "$1" == "--net=host" ]]; then - CI_DOCKER_EXTRA_PARAMS+=('--net=host') - shift 1 + +function parse_error() { + echo "$@" >&2 + show_usage >&2 + exit 1 +} + +# Handle joined flags, such as interpreting -ih as -i -h. Either rewrites +# the current argument if it is a joined argument, or shifts all arguments +# otherwise. Should be called as "eval $break_joined_flag" where joined +# flags are possible. Can't use a function definition, because it needs +# to overwrite the parent scope's behavior. +break_joined_flag='if (( ${#1} == 2 )); then shift; else set -- -"${1#-i}" "${@:2}"; fi' + + +while (( $# )); do + case "$1" in + -h|--help) + show_usage + exit 0 + ;; + + -i*|--interactive) + INTERACTIVE=true + eval $break_joined_flag + ;; + + -t*|--tty) + TTY=true + eval $break_joined_flag + ;; + + --net=host) + USE_NET_HOST=true + shift + ;; + + --mount) + if [[ -n "$2" ]]; then + MOUNT_DIRS+=("$2") + shift 2 + else + parse_error 'ERROR: --mount requires a non-empty argument' + fi + ;; + + --mount=?*) + MOUNT_DIRS+=("${1#*=}") + shift + ;; + + --dry-run) + DRY_RUN=true + shift + ;; + + --repo-mount-point) + if [[ -n "$2" ]]; then + REPO_MOUNT_POINT="$2" + shift 2 + else + parse_error 'ERROR: --repo-mount-point requires a non-empty argument' + fi + ;; + + --repo-mount-point=?*) + REPO_MOUNT_POINT="${1#*=}" + shift + ;; + + --) + shift + COMMAND=( "$@" ) + break + ;; + + -*|--*) + echo "Error: Unknown flag: $1" >&2 + echo " If this flag is intended to be passed to the" >&2 + echo " docker command, please add -- before the docker" >&2 + echo " command (e.g. docker/bash.sh ci_gpu -- build -j2)" >&2 + show_usage >&2 + exit 1 + ;; + + *) + # First positional argument is the image name, all + # remaining below to the COMMAND. + if [[ -z "${DOCKER_IMAGE_NAME}" ]]; then + DOCKER_IMAGE_NAME=$1 + shift + else + COMMAND=( "$@" ) + break + fi + ;; + esac +done + +if [[ -z "${DOCKER_IMAGE_NAME}" ]]; then + echo "Error: Missing DOCKER_IMAGE_NAME" >&2 + show_usage >&2 fi -# Mount external directory to the docker -CI_DOCKER_MOUNT_CMD=( ) -if [ "$1" == "--mount" ]; then - shift 1 - CI_DOCKER_MOUNT_CMD=( -v "$1:$1" ) - shift 1 +if [[ ${COMMAND[@]+"${COMMAND[@]}"} = bash ]]; then + INTERACTIVE=true + USE_NET_HOST=true fi -if [ "$#" -lt 1 ]; then - echo "Usage: docker/bash.sh [-i] [--net=host] [COMMAND]" - exit -1 + + +############################### +### End of argument parsing ### +############################### + +source "$(dirname $0)/dev_common.sh" || exit 2 + +DOCKER_FLAGS=( ) +DOCKER_ENV=( ) +DOCKER_MOUNT=( ) +DOCKER_DEVICES=( ) + + +# If the user gave a shortcut defined in the Jenkinsfile, use it. +EXPANDED_SHORTCUT=$(lookup_image_spec "${DOCKER_IMAGE_NAME}") +if [ -n "${EXPANDED_SHORTCUT}" ]; then + DOCKER_IMAGE_NAME="${EXPANDED_SHORTCUT}" fi -DOCKER_IMAGE_NAME=$(lookup_image_spec "$1") -if [ -z "${DOCKER_IMAGE_NAME}" ]; then - DOCKER_IMAGE_NAME=("$1") +# Set up working directories + +DOCKER_FLAGS+=( --workdir "${REPO_MOUNT_POINT}" ) +DOCKER_MOUNT+=( --volume "${REPO_DIR}":"${REPO_MOUNT_POINT}" + --volume "${SCRIPT_DIR}":/docker + ) + +# Set up CI-specific environment variables +DOCKER_ENV+=( --env CI_BUILD_HOME="${REPO_MOUNT_POINT}" + --env CI_BUILD_USER="$(id -u -n)" + --env CI_BUILD_UID="$(id -u)" + --env CI_BUILD_GROUP="$(id -g -n)" + --env CI_BUILD_GID="$(id -g)" + --env CI_PYTEST_ADD_OPTIONS="${CI_PYTEST_ADD_OPTIONS:-}" + --env CI_IMAGE_NAME="${DOCKER_IMAGE_NAME}" + ) + + +# Pass tvm test data folder through to the docker container, to avoid +# repeated downloads. Check if we have permissions to write to the +# directory first, since the CI may not. +TEST_DATA_PATH="${TVM_DATA_ROOT_PATH:-${HOME}/.tvm_test_data}" +if [[ -d "${TEST_DATA_PATH}" && -w "${TEST_DATA_PATH}" ]]; then + DOCKER_MOUNT+=( --volume "${TEST_DATA_PATH}":"${REPO_MOUNT_POINT}"/.tvm_test_data ) fi -if [ "$#" -eq 1 ]; then - COMMAND="bash" - interactive=1 + +# Remove the container once it finishes running (--rm) and share the +# PID namespace (--pid=host). The process inside does not have pid 1 +# and SIGKILL is propagated to the process inside, allowing jenkins to +# kill it if needed. +DOCKER_FLAGS+=( --rm --pid=host) + +# Expose services running in container to the host. +if $USE_NET_HOST; then if [[ $(uname) == "Darwin" ]]; then # Docker's host networking driver isn't supported on macOS. # Use default bridge network and expose port for jupyter notebook. - CI_DOCKER_EXTRA_PARAMS+=( "${CI_DOCKER_EXTRA_PARAMS[@]}" "-p 8888:8888" ) + DOCKER_FLAGS+=( -p 8888:8888 ) else - CI_DOCKER_EXTRA_PARAMS+=( "${CI_DOCKER_EXTRA_PARAMS[@]}" "--net=host" ) + DOCKER_FLAGS+=(--net=host) fi -else - shift 1 - COMMAND=("$@") fi -if [ $interactive -eq 1 ]; then - CI_DOCKER_EXTRA_PARAMS=( "${CI_DOCKER_EXTRA_PARAMS[@]}" -it ) +# Set up interactive sessions +if ${INTERACTIVE}; then + DOCKER_FLAGS+=( --interactive ) fi -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -WORKSPACE="$(pwd)" - -# Use nvidia-docker if the container is GPU. -if [[ ! -z $CUDA_VISIBLE_DEVICES ]]; then - CUDA_ENV="-e CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}" -else - CUDA_ENV="" +if ${TTY}; then + DOCKER_FLAGS+=( --tty ) fi +# Expose external directories to the docker container +for MOUNT_DIR in ${MOUNT_DIRS[@]+"${MOUNT_DIRS[@]}"}; do + DOCKER_MOUNT+=( --volume "${MOUNT_DIR}:${MOUNT_DIR}" ) +done + +# Use nvidia-docker for GPU container. If nvidia-docker is not +# available, fall back to using "--gpus all" flag, requires docker +# version 19.03 or higher. if [[ "${DOCKER_IMAGE_NAME}" == *"gpu"* || "${DOCKER_IMAGE_NAME}" == *"cuda"* ]]; then - if ! type "nvidia-docker" 1> /dev/null 2> /dev/null - then - DOCKER_BINARY="docker" - CUDA_ENV=" --gpus all "${CUDA_ENV} + if type nvidia-docker 1> /dev/null 2> /dev/null; then + DOCKER_BINARY=nvidia-docker else - DOCKER_BINARY="nvidia-docker" + DOCKER_BINARY=docker + DOCKER_FLAGS+=( --gpus all ) fi + + # nvidia-docker treats Vulkan as a graphics API, so we need to + # request passthrough of graphics APIs. This could also be set in + # the Dockerfile. + DOCKER_ENV+=( --env NVIDIA_DRIVER_CAPABILITIES=compute,graphics,utility ) + + # But as of nvidia-docker version 2.6.0-1, we still need to pass + # through the nvidia icd files ourselves. + ICD_SEARCH_LOCATIONS=( + # https://github.com/KhronosGroup/Vulkan-Loader/blob/master/loader/LoaderAndLayerInterface.md#icd-discovery-on-linux + /usr/local/etc/vulkan/icd.d + /usr/local/share/vulkan/icd.d + /etc/vulkan/icd.d + /usr/share/vulkan/icd.d + # https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md#icd-installation + /etc/glvnd/egl_vendor.d + /usr/share/glvnd/egl_vendor.d + ) + for filename in $(find "${ICD_SEARCH_LOCATIONS[@]}" -name "*nvidia*.json" 2> /dev/null); do + DOCKER_MOUNT+=( --volume "${filename}":"${filename}":ro ) + done + else - DOCKER_BINARY="docker" + DOCKER_BINARY=docker +fi + + + +# Pass any restrictions of allowed CUDA devices from the host to the +# docker container. +if [[ -n ${CUDA_VISIBLE_DEVICES:-} ]]; then + DOCKER_ENV+=( --env CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES}" ) fi + + +# Set TVM import path inside the docker image if [[ "${DOCKER_IMAGE_NAME}" == *"ci"* ]]; then - CI_ADDON_ENV="-e PYTHONPATH=/workspace/python" -else - CI_ADDON_ENV="" + DOCKER_ENV+=( --env PYTHONPATH="${REPO_MOUNT_POINT}"/python ) fi -DOCKER_ENVS="" -DOCKER_DEVICES="" -WORKSPACE_VOLUMES="" -# If the Vitis-AI docker image is selected, expose the Xilinx FPGA devices and required volumes containing e.g. DSA's and overlays + + +# If the Vitis-AI docker image is selected, expose the Xilinx FPGA +# devices and required volumes containing e.g. DSA's and overlays if [[ "${DOCKER_IMAGE_NAME}" == *"demo_vitis_ai"* && -d "/dev/shm" && -d "/opt/xilinx/dsa" && -d "/opt/xilinx/overlaybins" ]]; then - WORKSPACE_VOLUMES="-v /dev/shm:/dev/shm -v /opt/xilinx/dsa:/opt/xilinx/dsa -v /opt/xilinx/overlaybins:/opt/xilinx/overlaybins" + DOCKER_MOUNT+=( --volume /dev/shm:/dev/shm + --volume /opt/xilinx/dsa:/opt/xilinx/dsa + --volume /opt/xilinx/overlaybins:/opt/xilinx/overlaybins + ) + XCLMGMT_DRIVER="$(find /dev -name xclmgmt\*)" - DOCKER_DEVICES="" - for i in ${XCLMGMT_DRIVER} ; - do - DOCKER_DEVICES+="--device=$i " + for DRIVER in "${XCLMGMT_DRIVER}"; do + DOCKER_DEVICES+=( --device="${DRIVER}" ) done RENDER_DRIVER="$(find /dev/dri -name renderD\*)" - for i in ${RENDER_DRIVER} ; - do - DOCKER_DEVICES+="--device=$i " + for DRIVER in "${RENDER_DRIVER}"; do + DOCKER_DEVICES+=( --device="${DRIVER}" ) done fi # Add ROCm devices and set ROCM_ENABLED=1 which is used in the with_the_same_user script # to add the user to the video group if [[ "${DOCKER_IMAGE_NAME}" == *"rocm"* && -d "/dev/dri" ]]; then - DOCKER_DEVICES+="--device=/dev/kfd --device=/dev/dri " - DOCKER_ENVS+="-e ROCM_ENABLED=1 " + DOCKER_DEVICES+=( --device=/dev/kfd --device=/dev/dri ) + DOCKER_ENV+=( --env ROCM_ENABLED=1 ) +fi + +# When running from a git worktree, also mount the original git dir. +if [ -f "${REPO_DIR}/.git" ]; then + git_dir=$(cd ${REPO_DIR} && git rev-parse --git-common-dir) + if [ "${git_dir}" != "${REPO_DIR}/.git" ]; then + DOCKER_MOUNT+=( --volume "${git_dir}:${git_dir}" ) + fi fi # Print arguments. -echo "WORKSPACE: ${WORKSPACE}" +echo "REPO_DIR: ${REPO_DIR}" echo "DOCKER CONTAINER NAME: ${DOCKER_IMAGE_NAME}" echo "" -echo "Running '${COMMAND[@]}' inside ${DOCKER_IMAGE_NAME}..." +echo Running \'${COMMAND[@]+"${COMMAND[@]}"}\' inside ${DOCKER_IMAGE_NAME}... -# When running from a git worktree, also mount the original git dir. -EXTRA_MOUNTS=( ) -if [ -f "${WORKSPACE}/.git" ]; then - git_dir=$(cd ${WORKSPACE} && git rev-parse --git-common-dir) - if [ "${git_dir}" != "${WORKSPACE}/.git" ]; then - EXTRA_MOUNTS=( "${EXTRA_MOUNTS[@]}" -v "${git_dir}:${git_dir}" ) - fi -fi -# By default we cleanup - remove the container once it finish running (--rm) -# and share the PID namespace (--pid=host) so the process inside does not have -# pid 1 and SIGKILL is propagated to the process inside (jenkins can kill it). -${DOCKER_BINARY} run --rm --pid=host\ - ${DOCKER_DEVICES}\ - ${WORKSPACE_VOLUMES}\ - -v ${WORKSPACE}:/workspace \ - -v ${SCRIPT_DIR}:/docker \ - "${CI_DOCKER_MOUNT_CMD[@]}" \ - "${EXTRA_MOUNTS[@]}" \ - -w /workspace \ - -e "CI_BUILD_HOME=/workspace" \ - -e "CI_BUILD_USER=$(id -u -n)" \ - -e "CI_BUILD_UID=$(id -u)" \ - -e "CI_BUILD_GROUP=$(id -g -n)" \ - -e "CI_BUILD_GID=$(id -g)" \ - -e "CI_PYTEST_ADD_OPTIONS=$CI_PYTEST_ADD_OPTIONS" \ - -e "CI_IMAGE_NAME=${DOCKER_IMAGE_NAME}" \ - ${DOCKER_ENVS} \ - ${CI_ADDON_ENV} \ - ${CUDA_ENV} \ - "${CI_DOCKER_EXTRA_PARAMS[@]}" \ - ${DOCKER_IMAGE_NAME} \ - bash --login /docker/with_the_same_user \ - "${COMMAND[@]}" +DOCKER_CMD=(${DOCKER_BINARY} run + ${DOCKER_FLAGS[@]+"${DOCKER_FLAGS[@]}"} + ${DOCKER_ENV[@]+"${DOCKER_ENV[@]}"} + ${DOCKER_MOUNT[@]+"${DOCKER_MOUNT[@]}"} + ${DOCKER_DEVICES[@]+"${DOCKER_DEVICES[@]}"} + "${DOCKER_IMAGE_NAME}" + bash --login /docker/with_the_same_user + ${COMMAND[@]+"${COMMAND[@]}"} + ) + +if ${DRY_RUN}; then + echo ${DOCKER_CMD[@]+"${DOCKER_CMD[@]}"} +else + ${DOCKER_CMD[@]+"${DOCKER_CMD[@]}"} +fi diff --git a/docker/build.sh b/docker/build.sh index e654a6253317..3b58bcc52a75 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -22,7 +22,8 @@ # # Usage: build.sh [--tag ] # [--dockerfile ] [-it] -# [--net=host] [--cache-from ] [] +# [--net=host] [--cache-from ] +# [--context-path ] [] # # CONTAINER_TYPE: Type of the docker container used the run the build, # e.g. "ci_cpu", "ci_gpu" @@ -37,6 +38,9 @@ # IMAGE_NAME: An image to be as a source for cached layers when building the # Docker image requested. # +# CONTEXT_PATH: Path to be used for relative path resolution when building +# the Docker images. +# # COMMAND (optional): Command to be executed in the docker container # SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -47,7 +51,6 @@ shift 1 # Dockerfile to be used in docker build DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.${CONTAINER_TYPE}" -DOCKER_CONTEXT_PATH="${SCRIPT_DIR}" if [[ "$1" == "--tag" ]]; then DOCKER_IMAGE_TAG="$2" @@ -57,9 +60,7 @@ fi if [[ "$1" == "--dockerfile" ]]; then DOCKERFILE_PATH="$2" - DOCKER_CONTEXT_PATH=$(dirname "${DOCKERFILE_PATH}") echo "Using custom Dockerfile path: ${DOCKERFILE_PATH}" - echo "Using custom docker build context path: ${DOCKER_CONTEXT_PATH}" shift 2 fi @@ -85,6 +86,15 @@ if [[ "$1" == "--cache-from" ]]; then shift 1 fi +if [[ "$1" == "--context-path" ]]; then + DOCKER_CONTEXT_PATH="$2" + echo "Using custom context path: ${DOCKER_CONTEXT_PATH}" + shift 2 +else + DOCKER_CONTEXT_PATH=$(dirname "${DOCKERFILE_PATH}") + echo "Using default context path: ${DOCKER_CONTEXT_PATH}" +fi + if [[ ! -f "${DOCKERFILE_PATH}" ]]; then echo "Invalid Dockerfile path: \"${DOCKERFILE_PATH}\"" exit 1 diff --git a/docker/dev_common.sh b/docker/dev_common.sh index 68b9f8d28760..6823f754ae16 100644 --- a/docker/dev_common.sh +++ b/docker/dev_common.sh @@ -55,7 +55,6 @@ function lookup_image_spec() { fi } - function run_docker() { image_name="$1" # Name of the Jenkinsfile var to find shift @@ -66,5 +65,5 @@ function run_docker() { exit 2 fi - "${GIT_TOPLEVEL}/docker/bash.sh" -i "${image_spec}" "$@" + "${GIT_TOPLEVEL}/docker/bash.sh" "${image_spec}" "$@" } diff --git a/apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf b/docker/install/ubuntu_install_arduino.sh similarity index 53% rename from apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf rename to docker/install/ubuntu_install_arduino.sh index 3733568ed02f..d5c4303f211b 100644 --- a/apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf +++ b/docker/install/ubuntu_install_arduino.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,15 +16,21 @@ # specific language governing permissions and limitations # under the License. -# This file is specific to the QEMU-emulated RISCV32 microTVM board. +set -e +set -u +set -o pipefail -# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. -CONFIG_TEST_RANDOM_GENERATOR=y -CONFIG_TIMER_RANDOM_GENERATOR=y +export DEBIAN_FRONTEND=noninteractive +apt-get install -y ca-certificates -# Default is 512, raised here for operations with large floating point data. -CONFIG_MAIN_STACK_SIZE=2048 +# Install arduino-cli latest version +wget -O - https://raw.githubusercontent.com/arduino/arduino-cli/master/install.sh | sh -s -# For floating point operations. It has exception on floating point operations -# without this flag. -CONFIG_FPU_SHARING=y +# Install supported cores from those URLS +arduino-cli core install arduino:mbed_nano +arduino-cli core install arduino:sam + +# ARDUINO_DIRECTORIES_USER wouldn't normally be created until we +# install a package, which would casue chmod to fail +mkdir -p "${ARDUINO_DIRECTORIES_DATA}" "${ARDUINO_DIRECTORIES_USER}" "${ARDUINO_DIRECTORIES_DOWNLOADS}" +chmod -R o+rw "${ARDUINO_DIRECTORIES_DATA}" "${ARDUINO_DIRECTORIES_USER}" "${ARDUINO_DIRECTORIES_DOWNLOADS}" diff --git a/docker/install/ubuntu_install_darknet.sh b/docker/install/ubuntu_install_darknet.sh index 37adf4a30270..8020899f8bf1 100755 --- a/docker/install/ubuntu_install_darknet.sh +++ b/docker/install/ubuntu_install_darknet.sh @@ -23,4 +23,7 @@ set -o pipefail #install the necessary dependancies, cffi, opencv wget -q 'https://github.com/siju-samuel/darknet/blob/master/lib/libdarknet.so?raw=true' -O libdarknet.so debian_version=`cat /etc/debian_version` -pip3 install opencv-python cffi + +pip3 install \ + cffi \ + opencv-python diff --git a/docker/install/ubuntu_install_ethosu_driver_stack.sh b/docker/install/ubuntu_install_ethosu_driver_stack.sh new file mode 100755 index 000000000000..35b2b4c74b7b --- /dev/null +++ b/docker/install/ubuntu_install_ethosu_driver_stack.sh @@ -0,0 +1,94 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -u +set -o pipefail + +fvp_dir="/opt/arm/FVP_Corstone_SSE-300_Ethos-U55" +cmake_dir="/opt/arm/cmake" +ethosu_dir="/opt/arm/ethosu" +ethosu_driver_ver="21.05" +cmsis_ver="5.7.0" + +mkdir -p /opt/arm + +tmpdir=$(mktemp -d) + +cleanup() +{ + rm -rf "$tmpdir" +} + +trap cleanup 0 + +# Ubuntu 18.04 dependencies +apt-get update + +apt-get install -y \ + bsdmainutils \ + build-essential \ + cpp \ + git \ + linux-headers-generic \ + make \ + python-dev \ + python3 \ + ssh \ + wget \ + xxd + +# Download the FVP +mkdir -p "$fvp_dir" +cd "$tmpdir" +curl -sL https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/MPS3/FVP_Corstone_SSE-300_Ethos-U55_11.14_24.tgz | tar -xz +./FVP_Corstone_SSE-300_Ethos-U55.sh --i-agree-to-the-contained-eula --no-interactive -d "$fvp_dir" +rm -rf FVP_Corstone_SSE-300_Ethos-U55.sh license_terms + +# Setup cmake 3.19.5 +mkdir -p "${cmake_dir}" +cd "$tmpdir" +curl -sL -o cmake-3.19.5-Linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.19.5/cmake-3.19.5-Linux-x86_64.sh +chmod +x cmake-3.19.5-Linux-x86_64.sh +./cmake-3.19.5-Linux-x86_64.sh --prefix="${cmake_dir}" --skip-license +rm cmake-3.19.5-Linux-x86_64.sh +export PATH="${cmake_dir}/bin:${PATH}" + +# Install the GCC toolchain +mkdir -p /opt/arm/gcc-arm-none-eabi/ +gcc_arm_url='https://developer.arm.com/-/media/Files/downloads/gnu-rm/10-2020q4/gcc-arm-none-eabi-10-2020-q4-major-x86_64-linux.tar.bz2?revision=ca0cbf9c-9de2-491c-ac48-898b5bbc0443&la=en&hash=68760A8AE66026BCF99F05AC017A6A50C6FD832A' +curl --retry 64 -sSL ${gcc_arm_url} | tar -C /opt/arm/gcc-arm-none-eabi --strip-components=1 -jx +export PATH="/opt/arm/gcc-arm-none-eabi/bin:${PATH}" + +# Clone Arm(R) Ethos(TM)-U NPU driver stack +mkdir "${ethosu_dir}" +cd "${ethosu_dir}" +git clone "https://review.mlplatform.org/ml/ethos-u/ethos-u-core-driver" core_driver +cd core_driver +git checkout tags/${ethosu_driver_ver} + +cd "${ethosu_dir}" +git clone "https://review.mlplatform.org/ml/ethos-u/ethos-u-core-platform" core_platform +cd core_platform +git checkout tags/${ethosu_driver_ver} + +# Clone CMSIS +cd "${ethosu_dir}" +git clone "https://github.com/ARM-software/CMSIS_5.git" cmsis +cd cmsis +git checkout -f tags/${cmsis_ver} diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh index 8f462284c2ba..ef0bf1b012c6 100755 --- a/docker/install/ubuntu_install_onnx.sh +++ b/docker/install/ubuntu_install_onnx.sh @@ -22,11 +22,14 @@ set -o pipefail # We need to fix the onnx version because changing versions tends to break tests # TODO(mbrookhart): periodically update -pip3 install onnx==1.8.1 -pip3 install onnxruntime==1.7.0 +pip3 install \ + onnx==1.8.1 \ + onnxruntime==1.7.0 # torch depends on a number of other packages, but unhelpfully, does # not expose that in the wheel!!! pip3 install future -pip3 install torch==1.7.0 torchvision==0.8.1 +pip3 install \ + torch==1.7.0 \ + torchvision==0.8.1 diff --git a/python/tvm/micro/contrib/__init__.py b/docker/install/ubuntu_install_paddle.sh similarity index 91% rename from python/tvm/micro/contrib/__init__.py rename to docker/install/ubuntu_install_paddle.sh index 13a83393a912..267d59105c06 100644 --- a/python/tvm/micro/contrib/__init__.py +++ b/docker/install/ubuntu_install_paddle.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -14,3 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +set -e +set -u +set -o pipefail + +pip install paddlepaddle==2.1.2 diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index 7989a49a4826..88d68408381c 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -21,4 +21,21 @@ set -u set -o pipefail # install libraries for python package on ubuntu -pip3 install six numpy pytest cython decorator scipy tornado pytest pytest-xdist pytest-profiling mypy orderedset attrs requests Pillow packaging cloudpickle synr +pip3 install \ + attrs \ + cloudpickle \ + cython \ + decorator \ + mypy \ + numpy \ + orderedset \ + packaging \ + Pillow \ + pytest \ + pytest-profiling \ + pytest-xdist \ + requests \ + scipy \ + synr==0.3.0 \ + six \ + tornado diff --git a/docker/install/ubuntu_install_redis.sh b/docker/install/ubuntu_install_redis.sh index 0eb46eb8edec..d2600d828d49 100755 --- a/docker/install/ubuntu_install_redis.sh +++ b/docker/install/ubuntu_install_redis.sh @@ -21,4 +21,6 @@ set -u set -o pipefail apt-get update && apt-get install -y redis-server -pip3 install "xgboost>=1.1.0" psutil +pip3 install \ + psutil \ + "xgboost>=1.1.0" diff --git a/docker/install/ubuntu_install_sphinx.sh b/docker/install/ubuntu_install_sphinx.sh index 8a7ce1d3f798..12208bbe6643 100755 --- a/docker/install/ubuntu_install_sphinx.sh +++ b/docker/install/ubuntu_install_sphinx.sh @@ -21,4 +21,13 @@ set -u set -o pipefail # NOTE: install docutils < 0.17 to work around https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 -pip3 install sphinx sphinx-gallery==0.4.0 autodocsumm sphinx_rtd_theme sphinx_autodoc_annotation matplotlib Image "commonmark>=0.7.3" "docutils>=0.11,<0.17" +pip3 install \ + autodocsumm \ + "commonmark>=0.7.3" \ + "docutils>=0.11,<0.17" \ + Image \ + matplotlib \ + sphinx \ + sphinx_autodoc_annotation \ + sphinx-gallery==0.4.0 \ + sphinx_rtd_theme diff --git a/docker/install/ubuntu_install_tensorflow.sh b/docker/install/ubuntu_install_tensorflow.sh index 0e14c724ae3e..8a51fbbbb178 100755 --- a/docker/install/ubuntu_install_tensorflow.sh +++ b/docker/install/ubuntu_install_tensorflow.sh @@ -20,4 +20,7 @@ set -e set -u set -o pipefail -pip3 install tensorflow==2.4.2 keras==2.4.3 "h5py<3.0" +pip3 install \ + "h5py<3.0" \ + keras==2.4.3 \ + tensorflow==2.4.2 diff --git a/apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf b/docker/install/ubuntu_install_vela.sh similarity index 68% rename from apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf rename to docker/install/ubuntu_install_vela.sh index 52a6753c733b..e75a99d9d563 100644 --- a/apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf +++ b/docker/install/ubuntu_install_vela.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -14,18 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -# This file is specific to the STM32L4R5ZI Nucleo board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# For AOT runtime which requires lots of function call. -CONFIG_MAIN_STACK_SIZE=3000 -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y +set -e +set -u +set -o pipefail -# For debugging. -CONFIG_LED=y +pip3 install -U setuptools +# In a refactor between v2.1.1 and v3.0.0, find_block_configs was removed from Vela. +# Since this is still required for the TVM port, it will be reinstated in Vela in a future release. +# Until then, it needs to be pinned to v2.1.1. +pip3 install ethos-u-vela==2.1.1 diff --git a/docker/lint.sh b/docker/lint.sh index d15ce71b7a98..e709bfb08445 100755 --- a/docker/lint.sh +++ b/docker/lint.sh @@ -20,7 +20,7 @@ source "$(dirname $0)/dev_common.sh" SCRIPT_NAME="$0" -DEFAULT_STEPS=( file_type asf cpplint clang_format pylint python_format jnilint cppdocs ) +DEFAULT_STEPS=( file_type asf cpplint clang_format pylint python_format jnilint cppdocs mypy ) inplace_fix=0 @@ -67,6 +67,9 @@ function run_lint_step() { cppdocs) cmd=( tests/lint/cppdocs.sh ) ;; + mypy) + cmd=( tests/scripts/task_mypy.sh ) + ;; *) echo "error: don't know how to run lint step: $1" >&2 echo "usage: ${SCRIPT_NAME} [-i] " >&2 diff --git a/docs/Makefile b/docs/Makefile index ca4f6e9a08f0..e04e324a0f80 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -23,15 +23,16 @@ SPHINXOPTS = SPHINXBUILD = python3 -m sphinx PAPER = BUILDDIR = _build +STAGINGDIR = _staging # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . +ALLSPHINXOPTS = -d $(PWD)/$(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . # the i18n builder cannot share the environment and doctrees with the others I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . -.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext +.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext staging help: @echo "Please use \`make ' where is one of" @@ -61,44 +62,83 @@ help: @echo " coverage to run coverage check of the documentation (if enabled)" clean: - rm -rf $(BUILDDIR)/* + rm -rf $(BUILDDIR) + rm -rf $(STAGINGDIR) + + # TODO(Lunderberg): Remove these lines once the CI steps have + # propagated. + + # Remove folders that have since been relocated into + # $(STAGINGDIR). This allows `task_sphinx_precheck.sh` to + # run, even if a commit that predates $(STAGINGDIR) was + # previously run on that node. rm -rf gen_modules + rm -rf user_tutorials rm -rf tutorials rm -rf vta/tutorials -html: - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html +staging: + # Prepare the staging directory. Sphinx gallery automatically + # writes new .rst files into the current directory. This can + # cause issues when switching branches. By sequestering the + # auto-generated files into the staging directory, they can be + # removed without knowing the exact directory. + + mkdir -p $(STAGINGDIR) + + # Remove any symlinks that currently exist + find $(STAGINGDIR) -type l -exec rm {} \; + + # Reproduce the directory structure + find . \ + -path ./$(BUILDDIR) -prune -o -path ./$(STAGINGDIR) -prune -o \ + -name "*.rst" \ + -printf "$(STAGINGDIR)/%h\n" \ + | sort | uniq | xargs mkdir -p + + # Symlink all .rst files into the staging directory + find . \ + -path ./$(BUILDDIR) -prune -o -path ./$(STAGINGDIR) -prune -o \ + -name "*.rst" \ + -exec ln -s $(PWD)/{} $(STAGINGDIR)/{} \; + + ln -s $(PWD)/conf.py $(STAGINGDIR)/conf.py + ln -s $(PWD)/_static $(STAGINGDIR)/_static + + +html: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." -dirhtml: - $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml +dirhtml: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/dirhtml @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." -singlehtml: - $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml +singlehtml: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/singlehtml @echo @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." -pickle: - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle +pickle: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/pickle @echo @echo "Build finished; now you can process the pickle files." -json: - $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json +json: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/json @echo @echo "Build finished; now you can process the JSON files." -htmlhelp: - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp +htmlhelp: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/htmlhelp @echo @echo "Build finished; now you can run HTML Help Workshop with the" \ ".hhp project file in $(BUILDDIR)/htmlhelp." -qthelp: - $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp +qthelp: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/qthelp @echo @echo "Build finished; now you can run "qcollectiongenerator" with the" \ ".qhcp project file in $(BUILDDIR)/qthelp, like this:" @@ -106,16 +146,16 @@ qthelp: @echo "To view the help file:" @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/rabit.qhc" -applehelp: - $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp +applehelp: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/applehelp @echo @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." @echo "N.B. You won't be able to view it unless you put it in" \ "~/Library/Documentation/Help or install it in your application" \ "bundle." -devhelp: - $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp +devhelp: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/devhelp @echo @echo "Build finished." @echo "To view the help file:" @@ -123,85 +163,85 @@ devhelp: @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/rabit" @echo "# devhelp" -epub: - $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub +epub: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/epub @echo @echo "Build finished. The epub file is in $(BUILDDIR)/epub." -latex: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex +latex: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/latex @echo @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." @echo "Run \`make' in that directory to run these through (pdf)latex" \ "(use \`make latexpdf' here to do that automatically)." -latexpdf: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex +latexpdf: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/latex @echo "Running LaTeX files through pdflatex..." $(MAKE) -C $(BUILDDIR)/latex all-pdf @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." -latexpdfja: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex +latexpdfja: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/latex @echo "Running LaTeX files through platex and dvipdfmx..." $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." -text: - $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text +text: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/text @echo @echo "Build finished. The text files are in $(BUILDDIR)/text." -man: - $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man +man: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/man @echo @echo "Build finished. The manual pages are in $(BUILDDIR)/man." -texinfo: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo +texinfo: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/texinfo @echo @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." @echo "Run \`make' in that directory to run these through makeinfo" \ "(use \`make info' here to do that automatically)." -info: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo +info: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/texinfo @echo "Running Texinfo files through makeinfo..." make -C $(BUILDDIR)/texinfo info @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." -gettext: - $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale +gettext: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(PWD)/$(BUILDDIR)/locale @echo @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." -changes: - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes +changes: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/changes @echo @echo "The overview file is in $(BUILDDIR)/changes." -linkcheck: - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck +linkcheck: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/linkcheck @echo @echo "Link check complete; look for any errors in the above output " \ "or in $(BUILDDIR)/linkcheck/output.txt." -doctest: - $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest +doctest: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/doctest @echo "Testing of doctests in the sources finished, look at the " \ "results in $(BUILDDIR)/doctest/output.txt." -coverage: - $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage +coverage: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/coverage @echo "Testing of coverage in the sources finished, look at the " \ "results in $(BUILDDIR)/coverage/python.txt." -xml: - $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml +xml: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/xml @echo @echo "Build finished. The XML files are in $(BUILDDIR)/xml." -pseudoxml: - $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml +pseudoxml: staging + cd $(STAGINGDIR) && $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(PWD)/$(BUILDDIR)/pseudoxml @echo @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." diff --git a/docs/conf.py b/docs/conf.py index 3706c6201ba3..eaa17abef5de 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,18 +29,31 @@ # # All configuration values have a default; values that are commented out # serve to show the default. -import sys +import gc +import importlib.util import inspect -import os, subprocess +import os +from pathlib import Path import shlex +import subprocess +import sys + import sphinx_gallery + # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) -sys.path.insert(0, os.path.join(curr_path, "../python/")) -sys.path.insert(0, os.path.join(curr_path, "../vta/python")) +curr_path = Path(__file__).expanduser().absolute().parent +if curr_path.name == "_staging": + # Can't use curr_path.parent, because sphinx_gallery requires a relative path. + tvm_path = Path(os.pardir, os.pardir) +else: + tvm_path = Path(os.pardir) + + +sys.path.insert(0, str(tvm_path / "python")) +sys.path.insert(0, str(tvm_path / "vta" / "python")) # -- General configuration ------------------------------------------------ @@ -55,7 +68,7 @@ def git_describe_version(original_version): """Get git describe version.""" - ver_py = os.path.join(curr_path, "..", "version.py") + ver_py = tvm_path.joinpath("version.py") libver = {"__file__": ver_py} exec(compile(open(ver_py, "rb").read(), ver_py, "exec"), libver, libver) _, gd_version = libver["git_describe_version"]() @@ -117,7 +130,7 @@ def git_describe_version(original_version): # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ["_build"] +exclude_patterns = ["_build", "_staging"] # The reST default role (used for this markup: `text`) to use for all # documents. @@ -190,30 +203,31 @@ def git_describe_version(original_version): intersphinx_mapping = { "python": ("https://docs.python.org/{.major}".format(sys.version_info), None), "numpy": ("https://numpy.org/doc/stable", None), - "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), + "scipy": ("https://docs.scipy.org/doc/scipy/", None), "matplotlib": ("https://matplotlib.org/", None), } from sphinx_gallery.sorting import ExplicitOrder -examples_dirs = ["../tutorials/", "../vta/tutorials/"] +examples_dirs = [tvm_path.joinpath("tutorials"), tvm_path.joinpath("vta", "tutorials")] gallery_dirs = ["tutorials", "vta/tutorials"] subsection_order = ExplicitOrder( - [ - "../tutorials/get_started", - "../tutorials/frontend", - "../tutorials/language", - "../tutorials/optimize", - "../tutorials/autotvm", - "../tutorials/auto_scheduler", - "../tutorials/dev", - "../tutorials/topi", - "../tutorials/deployment", - "../tutorials/micro", - "../vta/tutorials/frontend", - "../vta/tutorials/optimize", - "../vta/tutorials/autotvm", + str(p) + for p in [ + tvm_path / "tutorials" / "get_started", + tvm_path / "tutorials" / "frontend", + tvm_path / "tutorials" / "language", + tvm_path / "tutorials" / "optimize", + tvm_path / "tutorials" / "autotvm", + tvm_path / "tutorials" / "auto_scheduler", + tvm_path / "tutorials" / "dev", + tvm_path / "tutorials" / "topi", + tvm_path / "tutorials" / "deployment", + tvm_path / "tutorials" / "micro", + tvm_path / "vta" / "tutorials" / "frontend", + tvm_path / "vta" / "tutorials" / "optimize", + tvm_path / "vta" / "tutorials" / "autotvm", ] ) @@ -300,6 +314,14 @@ def __call__(self, filename): return filename +# When running the tutorials on GPUs we are dependent on the Python garbage collector +# collecting TVM packed function closures for any device memory to also be released. This +# is not a good setup for machines with lots of CPU ram but constrained GPU ram, so force +# a gc after each example. +def force_gc(gallery_conf, fname): + gc.collect() + + sphinx_gallery_conf = { "backreferences_dir": "gen_modules/backreferences", "doc_module": ("tvm", "numpy"), @@ -317,6 +339,7 @@ def __call__(self, filename): "download_all_examples": False, "min_reported_time": 60, "expected_failing_examples": [], + "reset_modules": ("matplotlib", "seaborn", force_gc), } autodoc_default_options = { diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 60ab5e5ae9d2..efc8b32832c0 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -183,7 +183,7 @@ The first time you invoke the compiled module with ``fadd(a, b, c)``, ``GetFunct auto it = fmap_.find(name); const FunctionInfo& info = it->second; CUDAWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types); } @@ -204,7 +204,7 @@ The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn cal fcache_[device_id] = m_->GetFunc(device_id, func_name_); } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); CUresult result = cuLaunchKernel( fcache_[device_id], wl.grid_dim(0), diff --git a/docs/dev/how_to.rst b/docs/dev/how_to.rst index ff078fce911b..dcf8df2331eb 100644 --- a/docs/dev/how_to.rst +++ b/docs/dev/how_to.rst @@ -29,3 +29,4 @@ various areas of the TVM stack. relay_add_pass relay_bring_your_own_codegen codebase_walkthrough + pytest_target_parametrization diff --git a/docs/dev/index.rst b/docs/dev/index.rst index b4fb37d790f4..76d50f496e75 100644 --- a/docs/dev/index.rst +++ b/docs/dev/index.rst @@ -423,3 +423,4 @@ microTVM :maxdepth: 1 microtvm_design + model_library_format diff --git a/docs/dev/inferbound.rst b/docs/dev/inferbound.rst index 010d0d42d37e..28e034dc44cb 100644 --- a/docs/dev/inferbound.rst +++ b/docs/dev/inferbound.rst @@ -447,13 +447,11 @@ Here is the IR after ScheduleOps (note that loops with extent 1 have been preser :: - // attr [compute(D, 0x2c070b0)] realize_scope = "" realize D([0, 4], [0, 5], [0, 16]) { produce D { for (di, 0, 4) { for (dj, 0, 5) { for (dk, 0, 16) { - // attr [compute(C, 0x2c29990)] realize_scope = "" realize C([dj, 1], [dk, 1]) { produce C { for (i, 0, 1) { diff --git a/docs/dev/model_library_format.rst b/docs/dev/model_library_format.rst new file mode 100644 index 000000000000..fec90de4bcea --- /dev/null +++ b/docs/dev/model_library_format.rst @@ -0,0 +1,169 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Model Library Format +==================== + +About Model Library Format +-------------------------- + +TVM traditionally exports generated libraries as Dynamic Shared Objects (e.g. DLLs (Windows) or .so +(linux)). Inferences can be performed using those libraries by loading them into an executable using +``libtvm_runtime.so``. This process is very dependent on services provided by traditional OS. + +For deployment to unconventional platforms (e.g. those lacking traditional OS), TVM provides another +output format, Model Library Format. Initially, the microTVM project is the primary use case for this +format. Should it become useful in other use cases (and in particular, should it become possible to +export BYOC artifacts in Model Library Format), it could be used as a general-purpose TVM export +format. Model Library Format is a tarball containing a file for each piece of the TVM compiler +output. + +What can be Exported? +--------------------- + +At the time of writing, export is limited to full models built with ``tvm.relay.build``. + +Directory Layout +---------------- + +Model Library Format is contained within a tarball. All paths are relative to the root of the +tarball: + +- ``/`` - Root of the tarball + + - ``codegen`` - Root directory for all generated device code + + - (see `codegen`_ section) + + - ``executor-config/`` - Configuration for the executor which drives model inference + + - ``graph/`` - Root directory containing configuration for the GraphExecutor + + - ``graph.json`` - GraphExecutor JSON configuration + + - ``metadata.json`` - Machine-parseable metadata for this model + + - ``parameters/`` - Root directory where simplified parameters are placed + + - ``.params`` - Parameters for the model tvm.relay._save_params format + + - ``src/`` - Root directory for all source code consumed by TVM + + - ``relay.txt`` - Relay source code for the generated model + +Description of Sub-directories +------------------------------ + +.. _subdir_codegen: + +``codegen`` +^^^^^^^^^^^ + +All TVM-generated code is placed in this directory. At the time of writing, there is 1 file per +Module in the generated Module tree, though this restriction may change in the future. Files in +this directory should have filenames of the form ``/(lib|src)/.``. + +These components are described below: + + * ```` - Identifies the TVM target on which the code should run. Currently, only ``host`` + is supported. + * ```` - A unique slug identifying this file. Currently ``lib``, with ``>`` an + auto-incrementing integer. + * ```` - Suffix identifying the filename format. Currently ``c`` or ``o``. + +An example directory tree for a CPU-only model is shown below: + +- ``codegen/`` - Codegen directory + + - ``host/`` - Generated code for ``target_host`` + + - ``lib/`` - Generated binary object files + + - ``lib0.o`` - LLVM module (if ``llvm`` target is used) + - ``lib1.o`` - LLVM CRT Metadata Module (if ``llvm`` target is used) + + - ``src/`` - Generated C source + + - ``lib0.c`` - C module (if ``c`` target is used) + - ``lib1.c`` - C CRT Metadata module (if ``c`` target is used) + +``executor-config`` +^^^^^^^^^^^^^^^^^^^ + +Contains machine-parsable configuration for executors which can drive model inference. Currently, +only the GraphExecutor produces configuration for this directory, in ``graph/graph.json``. This +file should be read in and the resulting string supplied to the ``GraphExecutor()`` constructor for +parsing. + +``parameters`` +^^^^^^^^^^^^^^ + +Contains machine-parseable parameters. A variety of formats may be provided, but at present, only +the format produced by ``tvm.relay._save_params`` is supplied. When building with +``tvm.relay.build``, the ``name`` parameter is considered to be the model name. A single file is +created in this directory ``.json``. + +``src`` +^^^^^^^ + +Contains source code parsed by TVM. Currently, just the Relay source code is created in +``src/relay.txt``. + +Metadata +-------- + +Machine-parseable metadata is placed in a file ``metadata.json`` at the root of the tarball. +Metadata is a dictionary with these keys: + +- ``export_datetime``: Timestamp when this Model Library Format was generated, in + `strftime `_ + format ``"%Y-%M-%d %H:%M:%SZ",``. +- ``memory``: A summary of the memory usage of each generated function. Documented in + `Memory Usage Summary`_. +- ``model_name``: The name of this model (e.g. the ``name`` parameter supplied to + ``tvm.relay.build``). +- ``executors``: A list of executors supported by this model. Currently, this list is always + ``["graph"]``. +- ``target``: A dictionary mapping ``device_type`` (the underlying integer, as a string) to the + sub-target which describes that relay backend used for that ``device_type``. +- ``version``: A numeric version number that identifies the format used in this Model Library + Format. This number is incremented when the metadata structure or on-disk structure changes. + This document reflects version ``5``. + +Memory Usage Summary +^^^^^^^^^^^^^^^^^^^^ + +A dictionary with these sub-keys: + + - ``"main"``: ``list[MainFunctionWorkspaceUsage]``. A list summarizing memory usage for each + workspace used by the main function and all sub-functions invoked. + - ``"operator_functions"``: ``map[string, list[FunctionWorkspaceUsage]]``. Maps operator function + name to a list summarizing memory usage for each workpace used by the function. + +A ``MainFunctionWorkspaceUsage`` is a dict with these keys: + +- ``"device"``: ``int``. The ``device_type`` associated with this workspace. +- ``"workspace_size_bytes"``: ``int``. Number of bytes needed in this workspace by this function + and all sub-functions invoked. +- ``"constants_size_bytes"``: ``int``. Size of the constants used by the main function. +- ``"io_size_bytes"``: ``int``. Sum of the sizes of the buffers used from this workspace by this + function and sub-functions. + +A ``FunctionWorkspaceUsage`` is a dict with these keys: + +- ``"device"``: ``int``. The ``device_type`` associated with this workspace. +- ``"workspace_size_bytes"``: ``int``. Number of bytes needed in this workspace by this function. diff --git a/docs/dev/pass_infra.rst b/docs/dev/pass_infra.rst index 9fc24d87ef0d..9e76251cc85c 100644 --- a/docs/dev/pass_infra.rst +++ b/docs/dev/pass_infra.rst @@ -325,7 +325,7 @@ favorably use Python APIs to create a specific pass object. .. code:: c++ Pass CreateFunctionPass( - const runtime::TypedPackedFunc& pass_func, + const runtime::TypedPackedFunc& pass_func, int opt_level, String name, Array required); @@ -337,7 +337,7 @@ favorably use Python APIs to create a specific pass object. Array required); Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, + const runtime::TypedPackedFunc& pass_func, int opt_level, String name, Array required); diff --git a/docs/dev/pytest_target_parametrization.rst b/docs/dev/pytest_target_parametrization.rst new file mode 100644 index 000000000000..6dfcaf3633be --- /dev/null +++ b/docs/dev/pytest_target_parametrization.rst @@ -0,0 +1,283 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Python Target Parametrization +============================= + +Summary +------- + +For any supported runtime, TVM should should produce numerically +correct results. Therefore, when writing unit tests that validate +the numeric output, these unit tests should be run on all supported +runtimes. Since this is a very common use case, TVM has helper +functions to parametrize unit tests such that they will run on all +targets that are enabled and have a compatible device. + +A single python function in the test suite can expand to several +parametrized unit tests, each of which tests a single target device. +In order for a test to be run, all of the following must be true. + +- The test exists in a file or directory that has been passed to + `pytest`. + +- The pytest marks applied to the function, either explicitly or + through target parametrization, must be compatible with the + expression passed to pytest's `-m` argument. + +- For parametrized tests using the `target` fixture, the target must + appear in the environment variable `TVM_TEST_TARGETS`. + +- For parametrized tests using the `target` fixture, the build + configuration in `config.cmake` must enable the corresponding + runtime. + +Unit-Test File Contents +----------------------- + +.. _pytest-marks: https://docs.pytest.org/en/6.2.x/mark.html + +The recommended method to run a test on multiple targets is by +parametrizing the test. This can be done explicitly for a fixed list +of targets by decorating with +``@tvm.testing.parametrize_targets('target_1', 'target_2', ...)``, and +accepting ``target`` or ``dev`` as function arguments. The function +will be run once for each target listed, and the success/failure of +each target is reported separately. If a target cannot be run because +it is disabled in the `config.cmake`, or because no appropriate +hardware is present, then that target will be reported as skipped. + +.. code-block:: python + + # Explicit listing of targets to use. + @tvm.testing.parametrize_target('llvm', 'cuda') + def test_function(target, dev): + # Test code goes here + +For tests that should run correctly on all targets, the decorator can +be omitted. Any test that accepts a ``target`` or ``dev`` argument +will automatically be parametrized over all targets specified in +``TVM_TEST_TARGETS``. The parametrization provides the same +pass/fail/skipped report for each target, while allowing the test +suite to be easily extended to cover additional targets. + +.. code-block:: python + + # Implicitly parametrized to run on all targets + # in environment variable TVM_TEST_TARGETS + def test_function(target, dev): + # Test code goes here + +The ``@tvm.testing.parametrize_targets`` can also be used as a bare +decorator to explicitly draw attention to the parametrization, but has +no additional effect. + +.. code-block:: python + + # Explicitly parametrized to run on all targets + # in environment variable TVM_TEST_TARGETS + @tvm.testing.parametrize_targets + def test_function(target, dev): + # Test code goes here + + +Specific targets can be excluded or marked as expected to fail using +the ``@tvm.testing.exclude_targets`` or +``@tvm.testing.known_failing_targets`` decorators. For more +information on their intended use cases, please see their docstrings. + +In some cases it may be necessary to parametrize across multiple +parameters. For instance, there may be target-specific +implementations that should be tested, where some targets have more +than one implementation. These can be done by explicitly +parametrizing over tuples of arguments, such as shown below. In these +cases, only the explicitly listed targets will run, but they will +still have the appropriate ``@tvm.testing.requires_RUNTIME`` mark +applied to them. + +.. code-block:: python + + @pytest.mark.parametrize('target,impl', [ + ('llvm', cpu_implementation), + ('cuda', gpu_implementation_small_batch), + ('cuda', gpu_implementation_large_batch), + ]) + def test_function(target, dev, impl): + # Test code goes here + + +The parametrization functionality is implemented +on top of pytest marks. Each test function can +be decorated with `pytest marks `_ +to include metadata. The most frequently applied +marks are as follows. + +- ``@pytest.mark.gpu`` - Tags a function as using GPU + capabilities. This has no effect on its own, but can be paired with + command-line arguments ``-m gpu`` or ``-m 'not gpu'`` to restrict + which tests pytest will executed. This should not be called on its + own, but is part of other marks used in unit-tests. + +- ``@tvm.testing.uses_gpu`` - Applies ``@pytest.mark.gpu``. This + should be used to mark a unit tests that may use the GPU, if one is + present. This decorator is only needed for tests that explicitly + loop over ``tvm.testing.enabled_targets()``, but that is no longer + the preferred style of writing unit tests (see below). When using + ``tvm.testing.parametrize_targets()``, this decorator is implicit + for GPU targets, and does not need to be explicitly applied. + +- ``@tvm.testing.requires_gpu`` - Applies ``@tvm.testing.uses_gpu``, + and additionally marks that the test should be skipped + (``@pytest.mark.skipif``) entirely if no GPU is present. + +- ``@tvfm.testing.requires_RUNTIME`` - Several decorators + (e.g. ``@tvm.testing.requires_cuda``), each of which skips a test if + the specified runtime cannot be used. A runtime cannot be used if it + is disabled in the ``config.cmake``, or if a compatible device is + not present. For runtimes that use the GPU, this includes + ``@tvm.testing.requires_gpu``. + +When using parametrized targets, each test run is decorated with the +``@tvm.testing.requires_RUNTIME`` that corresponds to the target +being used. As a result, if a target is disabled in ``config.cmake`` +or does not have appropriate hardware to run, it will be explicitly +listed as skipped. + +There also exists a ``tvm.testing.enabled_targets()`` that returns +all targets that are enabled and runnable on the current machine, +based on the environment variable ``TVM_TEST_TARGETS``, the build +configuration, and the physical hardware present. Most current tests +explictly loop over the targets returned from ``enabled_targets()``, +but it should not be used for new tests. The pytest output for this +style silently skips runtimes that are disabled in ``config.cmake``, +or do not have a device on which they can run. In addition, the test +halts on the first target to fail, which is ambiguous as to whether +the error occurs on a particular target, or on every target. + +.. code-block:: python + + # Old style, do not use. + def test_function(): + for target,dev in tvm.testing.enabled_targets(): + # Test code goes here + + + +Running locally +--------------- + +To run the python unit-tests locally, use the command ``pytest`` in +the ``${TVM_HOME}`` directory. + +- Environment variables + - ``TVM_TEST_TARGETS`` should be a semicolon-separated list of + targets to run. If unset, will default to the targets defined in + ``tvm.testing.DEFAULT_TEST_TARGETS``. + + Note: If ``TVM_TEST_TARGETS`` does not contain any targets that + are both enabled, and have an accessible device of that type, + then the tests will fall back to running on the ``llvm`` target + only. + + - ``TVM_LIBRARY_PATH`` should be a path to the ``libtvm.so`` + library. This can be used, for example, to run tests using a + debug build. If unset, will search for ``libtvm.so`` relative to + the TVM source directory. + +- Command-line arguments + + - Passing a path to a folder or file will run only the unit tests + in that folder or file. This can be useful, for example, to + avoid running tests located in ``tests/python/frontend`` on a + system without a specific frontend installed. + + - The ``-m`` argument only runs unit tests that are tagged with a + specific pytest marker. The most frequent usage is to use ``m + gpu`` to run only tests that are marked with + ``@pytest.mark.gpu`` and use a GPU to run. It can also be used + to run only tests that do not use a GPU, by passing ``m 'not + gpu'``. + + Note: This filtering takes place after the selection of targets + based on the ``TVM_TEST_TARGETS`` environment variable. Even if + ``-m gpu`` is specified, if ``TVM_TEST_TARGETS`` does not + contain GPU targets, no GPU tests will be run. + +Running in local docker container +--------------------------------- + +.. _tlcpack: https://hub.docker.com/u/tlcpack + +The ``docker/bash.sh`` script can be used to run unit tests inside the +same docker image as is used by the CI. The first argument should +specify which docker image to run (e.g. ``docker/bash.sh ci_gpu``). +Allowed image names are defined at the top of the Jenkinsfile located +in the TVM source directory, and map to images at `tlcpack`_. + +If no additional arguments are given, the docker image will be loaded +with an interactive bash session. If a script is passed as an +optional argument (e.g. ``docker/bash.sh ci_gpu tests/scripts/task_python_unittest.sh``), then that script will be +executed inside the docker image. + +Note: The docker images contain all system dependencies, but do not +include the ``build/config.cmake`` configuration file for those +systems. The TVM source directory is used as the home directory of +the docker image, and so this will default to using the same +config/build directories as the local config. One solution is to +maintain separate ``build_local`` and ``build_docker`` directories, +and make a symlink from ``build`` to the appropriate folder when +entering/exiting docker. + +Running in CI +------------- + +Everything in the CI starts from the task definitions present in the +Jenkinsfile. This includes defining which docker image gets used, +what the compile-time configuration is, and which tests are included +in which stages. + +- Docker images + + Each task of the Jenkinsfile (e.g. 'BUILD: CPU') makes calls to + ``docker/bash.sh``. The argument following the call to + docker/bash.sh defines the docker image in CI, just as it does + locally. + +- Compile-time configuration + + The docker image does not have the ``config.cmake`` file built into + it, so this is the first step in each of the ``BUILD`` tasks. This + is done using the ``tests/scripts/task_config_build_*.sh`` scripts. + Which script is used depends on the build being tested, and is + specified in the Jenkinsfile. + + Each ``BUILD`` task concludes by packing a library for use in later + tests. + +- Which tests run + + The ``Unit Test`` and ``Integration Test`` stages of the Jenkinsfile + determine how ``pytest`` is called. Each task starts by unpacking a + compiled library that was previous compiled in the ``BUILD`` stage, + then runs a test script + (e.g. ``tests/script/task_python_unittest.sh``). These scripts set + the files/folders and command-line options that are passed to + ``pytest``. + + Several of these scripts include the ``-m gpu`` option, which + restricts the tests to only run tests that include the + ``@pytest.mark.gpu`` mark. diff --git a/docs/index.rst b/docs/index.rst index a7ae68c87b01..491c42712e9a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -78,6 +78,7 @@ For Developers :caption: MISC vta/index + profiling/index Index diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index c8e1ab7fc045..5c3a2544c578 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -63,7 +63,8 @@ The minimal building requirements for the ``TVM`` libraries are: - CMake 3.5 or higher - We highly recommend to build with LLVM to enable all the features. - If you want to use CUDA, CUDA toolkit version >= 8.0 is required. If you are upgrading from an older version, make sure you purge the older version and reboot after installation. - - On macOS, you may want to install `Homebrew ` to easily install and manage dependencies. + - On macOS, you may want to install `Homebrew `_ to easily install and manage dependencies. + - Python is also required. Avoid using Python 3.9.X+ which is not `supported `_. 3.7.X+ and 3.8.X+ should be well supported however. To install the these minimal pre-requisites on Ubuntu/Debian like linux operating systems, execute (in a terminal): @@ -73,6 +74,15 @@ linux operating systems, execute (in a terminal): sudo apt-get update sudo apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev +Use Homebrew to install the required dependencies for macOS running either the Intel or M1 processors. You must follow the post-installation steps specified by +Homebrew to ensure the dependencies are correctly installed and configured: + +.. code:: bash + + brew install gcc git cmake + brew install llvm + brew install python@3.8 + We use cmake to build the library. The configuration of TVM can be modified by editing `config.cmake` and/or by passing cmake flags to the command line: @@ -293,6 +303,21 @@ like ``virtualenv``. pip3 install --user tornado psutil xgboost cloudpickle +Note on M1 macs, you may have trouble installing xgboost / scipy. scipy and xgboost requires some additional dependencies to be installed, +including openblas and its dependencies. Use the following commands to install scipy and xgboost with the required dependencies and +configuration. A workaround for this is to do the following commands: + + .. code:: bash + + brew install openblas gfortran + + pip install pybind11 cython pythran   + + export OPENBLAS=/opt/homebrew/opt/openblas/lib/ + + pip install scipy --no-use-pep517 + + pip install xgboost Install Contrib Libraries ------------------------- @@ -316,7 +341,7 @@ tests in TVM. The easiest way to install GTest is from source. cd googletest mkdir build cd build - cmake -DMAKE_SHARED_LIBS=ON .. + cmake -DBUILD_SHARED_LIBS=ON .. make sudo make install diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index b74c58921d3f..68e77ecfa43e 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -406,7 +406,7 @@ Either match the first pattern or the second pattern. Domination ********** -Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parrent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node betwen the child and the pattern matches the path pattern. +Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node between the child and the pattern matches the path pattern. Function Pattern **************** diff --git a/docs/profiling/index.rst b/docs/profiling/index.rst new file mode 100644 index 000000000000..9443fef25ea6 --- /dev/null +++ b/docs/profiling/index.rst @@ -0,0 +1,24 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Profiling Deep Learning Models +==================================== + +.. toctree:: + :maxdepth: 1 + + papi diff --git a/docs/profiling/papi.rst b/docs/profiling/papi.rst new file mode 100644 index 000000000000..b7c23b2c0c73 --- /dev/null +++ b/docs/profiling/papi.rst @@ -0,0 +1,114 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Getting Started With PAPI +========================= + +The Performance Application Programming Interface (PAPI) is a library that +provides performance counters on a variety of platforms. Performance counters +provide accurate low-level information about processors behavior during a given +execution run. This information can contain simple metrics like total cycle +count, cache misses, and instructions executed as well as more high level +information like total FLOPS and warp occupancy. PAPI makes these metrics +available while profiling. + +Installing PAPI +--------------- + +PAPI can either be installed using your package manager (``apt-get install libpapi-dev`` +on Ubuntu), or from source here: +https://bitbucket.org/icl/papi/src/master/. + + +Building TVM With PAPI +---------------------- + +To include PAPI in your build of TVM, set the following line in you ``config.cmake``: + +.. code:: + + set(USE_PAPI ON) + +If PAPI is installed in a non-standard place, you can specify where it is like so: + +.. code:: + + set(USE_PAPI path/to/papi.pc) + + +Using PAPI While Profiling +-------------------------- + +If TVM has been built with PAPI (see above), then you can pass a +:py:class:`tvm.runtime.profiling.PAPIMetricCollector` to +:py:meth:`tvm.runtime.GraphModule.profile` to collect performance metrics. Here +is an example: + +.. code:: python + + target = "llvm" + dev = tvm.cpu() + mod, params = mlp.get_workload(1) + + exe = relay.vm.compile(mod, target, params=params) + vm = profiler_vm.VirtualMachineProfiler(exe, dev) + + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"), device=dev) + report = vm.profile( + [data], + func_name="main", + collectors=[tvm.runtime.profiling.PAPIMetricCollector()], + ) + print(report) + +.. code:: + + Name perf::CACHE-MISSES perf::CYCLES perf::STALLED-CYCLES-BACKEND perf::INSTRUCTIONS perf::STALLED-CYCLES-FRONTEND + fused_nn_dense_nn_bias_add_nn_relu 2,494 1,570,698 85,608 675,564 39,583 + fused_nn_dense_nn_bias_add_nn_relu_1 1,149 655,101 13,278 202,297 21,380 + fused_nn_dense_nn_bias_add 288 600,184 8,321 163,446 19,513 + fused_nn_batch_flatten 301 587,049 4,636 158,636 18,565 + fused_nn_softmax 154 575,143 8,018 160,738 18,995 + ---------- + Sum 4,386 3,988,175 119,861 1,360,681 118,036 + Total 10,644 8,327,360 179,310 2,660,569 270,044 + +You can also change which metrics are collected: + +.. code:: python + + report = vm.profile( + [data], + func_name="main", + collectors=[tvm.runtime.profiling.PAPIMetricCollector({dev: ["PAPI_FP_OPS"])], + ) + +.. code:: + + Name PAPI_FP_OPS + fused_nn_dense_nn_bias_add_nn_relu 200,832 + fused_nn_dense_nn_bias_add_nn_relu_1 16,448 + fused_nn_dense_nn_bias_add 1,548 + fused_nn_softmax 160 + fused_nn_batch_flatten 0 + ---------- + Sum 218,988 + Total 218,988 + +You can find a list of available metrics by running the ``papi_avail`` and +``papi_native_avail`` commands. diff --git a/golang/sample/complex.go b/golang/sample/complex.go index 911d0a7a28c1..91821c978e96 100644 --- a/golang/sample/complex.go +++ b/golang/sample/complex.go @@ -88,7 +88,7 @@ func main() { // Array allocation attributes tshapeIn := []int64{1, 224, 224, 3} - tshapeOut := []int64{1, 1000} + tshapeOut := []int64{1, 1001} // Allocate input Array inX, err := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0)) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index d671339fb66b..6c72cbeafdd4 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -282,6 +282,18 @@ class IterSumExpr : public IterMapExpr { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer); +/*! + * \brief Use IterVarMap detector to rewrite and simplify the indices + * + * \param indices The indices to detect pattern for. + * \param input_iters Map from variable to iterator's range. + * \param input_pred The predicate constraints on the input iterators + * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. + * + * \return The indices after rewrite + */ +Array IterMapSimplify(const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, bool require_bijective); /*! * \brief Apply the inverse of the affine transformation to the outputs. diff --git a/include/tvm/ir/affine_type.h b/include/tvm/ir/affine_type.h new file mode 100644 index 000000000000..afbe1f343bb8 --- /dev/null +++ b/include/tvm/ir/affine_type.h @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/affine_type.h + * \brief Quantized Tensor Types. + */ +#ifndef TVM_IR_AFFINE_TYPE_H_ +#define TVM_IR_AFFINE_TYPE_H_ + +#include +#include + +namespace tvm { + +/*! + * \brief AffineType representation + * \sa AffineType + */ +class AffineTypeNode : public Object { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + static constexpr const char* _type_key = "AffineType"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(AffineTypeNode, Object); +}; + +/*! + * \brief Managed reference to AffineTypeNode. + * \sa AffineTypeNode + */ +class AffineType : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(AffineType, ObjectRef, AffineTypeNode); +}; + +/*! + * \brief TensorAffineType representation + * \sa TensorAffineType + * + * This Type represents a quantized integer tensor that can be converted + * back to real space via the x_real = scale * (x_quant - zero_point) + */ +class TensorAffineTypeNode : public AffineTypeNode { + public: + /*! \brief The scale of this type */ + RelayExpr scale; + /*! \brief The zero point of this type */ + RelayExpr zero_point; + /*! \brief The data type of this type */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("scale", &scale); + v->Visit("zero_point", &zero_point); + v->Visit("dtype", &dtype); + } + + bool SEqualReduce(const TensorAffineTypeNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(scale, other->scale) && equal(zero_point, other->zero_point) && + equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(scale); + hash_reduce(zero_point); + hash_reduce(dtype); + } + + static constexpr const char* _type_key = "TensorAffineType"; + TVM_DECLARE_BASE_OBJECT_INFO(TensorAffineTypeNode, AffineTypeNode); +}; + +/*! + * \brief Managed reference to AffineTypes. + * \sa AffineTypeNode + */ +class TensorAffineType : public AffineType { + public: + TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorAffineType, AffineType, TensorAffineTypeNode); +}; + +/*! + * \brief TupleAffineType representation + * \sa TupleAffineType + */ +class TupleAffineTypeNode : public AffineTypeNode { + public: + /*! \brief The types of this tuple*/ + Array types; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("types", &types); } + + bool SEqualReduce(const TupleAffineTypeNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(types, other->types); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(types); + } + + static constexpr const char* _type_key = "TupleAffineType"; + TVM_DECLARE_BASE_OBJECT_INFO(TupleAffineTypeNode, AffineTypeNode); +}; + +/*! + * \brief Managed reference to TupleAffineTypes. + * \sa TupleAffineType + */ +class TupleAffineType : public AffineType { + public: + TVM_DLL TupleAffineType(Array types); + + TVM_DEFINE_OBJECT_REF_METHODS(TupleAffineType, AffineType, TupleAffineTypeNode); +}; + +} // namespace tvm +#endif // TVM_IR_AFFINE_TYPE_H_ diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index da7bc12619bd..fa1861051e2f 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -214,6 +214,7 @@ class DictAttrsNode : public BaseAttrsNode { void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array ListFieldInfo() const final; + // type info static constexpr const char* _type_key = "DictAttrs"; TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode); @@ -232,6 +233,72 @@ class DictAttrs : public Attrs { */ TVM_DLL explicit DictAttrs(Map dict); + // Utils for accessing attributes + // This needs to be on DictAttrs, not DictAttrsNode because we return the default + // value if DictAttrsNode is not defined. + /*! + * \brief Get a function attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const BaseFunc& f) { + * auto value = f->attrs.GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + static_assert(std::is_base_of::value, + "Can only call GetAttr with ObjectRef types."); + if (!defined()) return default_value; + const DictAttrsNode* node = this->as(); + + auto it = node->dict.find(attr_key); + if (it != node->dict.end()) { + return Downcast>((*it).second); + } else { + return default_value; + } + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + /*! + * \brief Check whether the function has an non-zero integer attr. + * + * This function can be used to check whether an optional + * attribute mark(e.g. inline) exists. + * + * \param attr_key The key to the attribute. + * \return The check result. + * + * \code + * + * void HasNonzeroAttrExample(const BaseFunc& f) { + * if (f->HasNonzeroAttr(attr::kInline)) { + * // inline the function. + * } + * } + * + * \endcode + */ + bool HasNonzeroAttr(const std::string& attr_key) const { + return GetAttr(attr_key, 0) != 0; + } + TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; @@ -249,6 +316,47 @@ inline TAttrs AttrsWithDefaultValues() { return TAttrs(n); } +/*! + * \brief Copy the function or module, but overrides + * the attribute value key with the value. + * + * \param input The thing to annotate (BaseFunc or IRModule) + * \param attr_key The attribute key. + * \param attr_value The value attribute value. + * + * \tparam TFunc The corresponding function or module type. + * + * \returns The new function or module with updated attributes. + * + * \note This function performs copy on write optimization for func and module. + * If we move a uniquely referenced func or module into WithAttr, + * then no additional copy will be performed. + * + * This is also why we make it as a function instead of a member function + * and why we pass by value in the first argument. + * + * \code + * + * // Recommended way to trigger copy on write + * func = WithAttr(std::move(func), "key1", value1); + * func = WithAttr(std::move(func), "key2", value2); + * + * \endcode + */ +template +inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) { + using TNode = typename TFunc::ContainerType; + static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); + TNode* node = input.CopyOnWrite(); + if (node->attrs.defined()) { + node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); + } else { + Map dict = {{attr_key, attr_value}}; + node->attrs = DictAttrs(dict); + } + return input; +} + // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index c1a012f05318..13b984d9cb35 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -43,7 +43,7 @@ namespace tvm { */ enum class CallingConv : int { /*! - * \brief Default calling convetion. + * \brief Default calling convention. * * - Uses the native calling convention of the target. * - Implementation: specified by the native target. @@ -102,21 +102,14 @@ class BaseFuncNode : public RelayExprNode { Optional GetAttr( const std::string& attr_key, Optional default_value = Optional(nullptr)) const { - static_assert(std::is_base_of::value, - "Can only call GetAttr with ObjectRef types."); - if (!attrs.defined()) return default_value; - auto it = attrs->dict.find(attr_key); - if (it != attrs->dict.end()) { - return Downcast>((*it).second); - } else { - return default_value; - } + return attrs.GetAttr(attr_key, default_value); } // variant that uses TObjectRef to enable implicit conversion to default value. template Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { return GetAttr(attr_key, Optional(default_value)); } + /*! * \brief Check whether the function has an non-zero integer attr. * @@ -136,9 +129,7 @@ class BaseFuncNode : public RelayExprNode { * * \endcode */ - bool HasNonzeroAttr(const std::string& attr_key) const { - return GetAttr(attr_key, 0) != 0; - } + bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); } static constexpr const char* _type_key = "BaseFunc"; static constexpr const uint32_t _type_child_slots = 2; @@ -154,48 +145,6 @@ class BaseFunc : public RelayExpr { TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); }; -/*! - * \brief Create a new function that copies func, but overrides - * the attribute value key with the value. - * - * \param func The input function. - * \param attr_key The attribute key. - * \param attr_value The value attribute value. - * - * \tparam TFunc The corresponding function type. - * - * \returns The new function with updated attributes. - * - * \note This function performs copy on write optimization for func. - * If we move a uniquely referenced func into WithAttr, - * then no additional copy will be performed. - * - * This is also why we make it as a function instead of a member function - * and why we pass by value in the first argument. - * - * \code - * - * // Recommended way to trigger copy on write - * func = WithAttr(std::move(func), "key1", value1); - * func = WithAttr(std::move(func), "key2", value2); - * - * \endcode - */ -template ::value>::type> -inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) { - using TNode = typename TFunc::ContainerType; - static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); - TNode* node = func.CopyOnWrite(); - if (node->attrs.defined()) { - node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); - } else { - Map dict = {{attr_key, attr_value}}; - node->attrs = DictAttrs(dict); - } - return func; -} - /*! * \brief Generic attribute names that can be attached to any function. * diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 638f132e3179..fefb08f878ef 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -36,6 +36,7 @@ #include #include #include +#include #include namespace tvm { @@ -58,6 +59,60 @@ class IRModuleNode : public Object { Map type_definitions; /*! \brief The source map for the module. */ parser::SourceMap source_map; + /* \brief Additional attributes storing meta-data about the module. */ + DictAttrs attrs; + + /*! + * \brief Get a module attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const IRModule& mod) { + * auto value = f->GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + return attrs.GetAttr(attr_key, default_value); + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + + /*! + * \brief Check whether the module has an non-zero integer attr. + * + * This function can be used to check whether an optional + * attribute mark(e.g. inline) exists. + * + * \param attr_key The key to the attribute. + * \return The check result. + * + * \code + * + * void HasNonzeroAttrExample(const IRModule& mod) { + * if (mod->HasNonzeroAttr(attr::kInline)) { + * // inline the function. + * } + * } + * + * \endcode + */ + bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); } IRModuleNode() : source_map() {} @@ -253,6 +308,14 @@ class IRModuleNode : public Object { /*! \brief Helper function for registering a typedef's constructors */ void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type); + /*! + * \brief Returns a version of \p name which is unique amongst all function definitions in module. + * + * \param name The original name. + * \return Updated name which is unique. + */ + String GetUniqueName(const String& name); + /*! \brief A map from string names to global variables that * ensures global uniqueness. */ @@ -307,16 +370,38 @@ class IRModule : public ObjectRef { } /*! - * \brief Construct a module from a standalone expression. + * \brief Constructs a module from a standalone expression \p expr. * - * Allows one to optionally pass a global function map and - * map of type definitions as well. + * If \p expr is a function it will be bound directly. Otherwise a function over the free + * variables of \p expr (possibly none) with \p expr as body is created and bound. + * + * The function is bound to, in preference order: + * - The "global_symbol" attribute of \p expr, if it is a function with that attribute. + * - 'main' + * - A unique name derived from 'main' if 'main' is already bound in \p global_funcs. + * + * Additional global functions and type definitions may be included in the result module. + * + * See also \p FromExpr. * * \param expr The expression to set as the main function to the module. - * \param global_funcs The global function map. - * \param type_definitions Map of global type definitions + * \param global_funcs The global function map. Default empty. + * \param type_definitions The global type definition map. Default empty. + * \param import_set Set of external modules already imported. Default empty. + * + * \returns A module with \p expr set as the main function, and the global var to which + * \p expr was bound (typcially 'main'). * - * \returns A module with expr set as the main function. + * TODO(mbs): Does import_set and the bound global var need to be exposed via ffi? + */ + static std::pair FromExprInContext( + const RelayExpr& expr, const Map& global_funcs = {}, + const Map& type_definitions = {}, + std::unordered_set import_set = {}); + + /*! + * \brief As for \p FromExprInContext, but assuming \p expr is bound to 'main' and no + * imports. */ TVM_DLL static IRModule FromExpr(const RelayExpr& expr, const Map& global_funcs = {}, diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index c772650809fa..5f2c4de3152a 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -164,10 +164,17 @@ class PointerTypeNode : public TypeNode { } bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const { - return equal(element_type, other->element_type); + // Make "global" equal to "" + String lhs_scope = storage_scope.empty() ? "global" : storage_scope; + String rhs_scope = other->storage_scope.empty() ? "global" : other->storage_scope; + return equal(element_type, other->element_type) && equal(lhs_scope, rhs_scope); } - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(element_type); } + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(element_type); + // Make "global" equal to "" + hash_reduce(storage_scope.empty() ? "global" : storage_scope); + } static constexpr const char* _type_key = "PointerType"; TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode); diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index d842c33cce03..18e8db0ace22 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -43,7 +43,7 @@ using runtime::ObjectPtr; using runtime::ObjectRef; /*! - * \brief Visitor class for to get the attributesof a AST/IR node. + * \brief Visitor class to get the attributes of an AST/IR node. * The content is going to be called for each field. * * Each objects that wants reflection will need to implement @@ -75,7 +75,7 @@ class AttrVisitor { /*! * \brief Virtual function table to support IR/AST node reflection. * - * Functions are stored in columar manner. + * Functions are stored in columnar manner. * Each column is a vector indexed by Object's type_index. */ class ReflectionVTable { @@ -205,7 +205,7 @@ class ReflectionVTable::Registry { /*! * \brief Set fcreate function. * \param f The creator function. - * \return rference to self. + * \return Reference to self. */ Registry& set_creator(FCreate f) { // NOLINT(*) ICHECK_LT(type_index_, parent_->fcreate_.size()); @@ -215,7 +215,7 @@ class ReflectionVTable::Registry { /*! * \brief Set bytes repr function. * \param f The ReprBytes function. - * \return rference to self. + * \return Reference to self. */ Registry& set_repr_bytes(FReprBytes f) { // NOLINT(*) ICHECK_LT(type_index_, parent_->frepr_bytes_.size()); @@ -374,7 +374,7 @@ inline ReflectionVTable::Registry ReflectionVTable::Register() { fsequal_reduce_.resize(tindex + 1, nullptr); fshash_reduce_.resize(tindex + 1, nullptr); } - // functor that implemnts the redirection. + // functor that implements the redirection. fvisit_attrs_[tindex] = ::tvm::detail::SelectVisitAttrs::VisitAttrs; fsequal_reduce_[tindex] = ::tvm::detail::SelectSEqualReduce::SEqualReduce; diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 3c7574562676..d28044c3845d 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -1003,16 +1003,45 @@ struct DenseAttrs : public tvm::AttrsNode { } }; -/*! \brief Attributes for batch matmul operator */ +/*! \brief Attributes for dense_pack operator */ +struct DensePackAttrs : public tvm::AttrsNode { + IndexExpr units; + DataType out_dtype; + tvm::String weight_layout; + + TVM_DECLARE_ATTRS(DensePackAttrs, "relay.attrs.DensePackAttrs") { + TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); + + // use 0 bits to indicate none. + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + TVM_ATTR_FIELD(weight_layout) + .set_default("NK") + .describe("Dimension ordering of weight. Packed layouts, such as NK8n, are possible."); + } +}; + +/*! \brief Attributes for batch matmul operator. */ struct BatchMatmulAttrs : public tvm::AttrsNode { - tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite DataType out_dtype; + bool transpose_a; + bool transpose_b; + tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") { // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); + + TVM_ATTR_FIELD(transpose_a) + .set_default(false) + .describe("Whether the first input tensor is in transposed format."); + + TVM_ATTR_FIELD(transpose_b) + .set_default(false) + .describe("Whether the second input tensor is in transposed format."); } }; @@ -1037,12 +1066,16 @@ struct SparseTransposeAttrs : public tvm::AttrsNode { /*! \brief Attributes for sparse_dense operator */ struct SparseConv2DAttrs : public tvm::AttrsNode { std::string layout; + Array kernel_size; TVM_DECLARE_ATTRS(SparseConv2DAttrs, "relay.attrs.SparseConv2DAttrs") { TVM_ATTR_FIELD(layout).set_default("NHWC").describe( "Dimension ordering of input data. Can be 'NCHW', 'NHWC'" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively."); + TVM_ATTR_FIELD(kernel_size) + .set_default(Array{1, 1}) + .describe("Kernel size for SparseConv2D, 1x1 or 3x3. "); } }; diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index 12e4e3f45fef..10e461645c8b 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -47,10 +47,13 @@ class DFPatternCallbackNode : public Object { PackedFunc function; /*! \brief Require InferType to be run before the callback */ bool require_type; + /*! \brief Run the callback only once */ + bool rewrite_once; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pattern", &pattern); v->Visit("require_type", &require_type); + v->Visit("rewrite_once", &rewrite_once); } static constexpr const char* _type_key = "DFPatternCallbackNode"; @@ -63,7 +66,8 @@ class DFPatternCallbackNode : public Object { */ class DFPatternCallback : public ObjectRef { public: - TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type); + TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type, + bool rewrite_once = false); TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode); }; diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 95eaad0b2797..9170bc53ea02 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -126,7 +126,7 @@ namespace attr { /*! \brief Mark the function as a primitive function. */ constexpr const char* kPrimitive = "Primitive"; /*! - * \brief Indicate the compiler that should be used for builing this function. + * \brief Indicate the compiler that should be used for building this function. * When this is unset or set to "default", the default compilation pipeline will be used. */ constexpr const char* kCompiler = "Compiler"; @@ -144,7 +144,6 @@ constexpr const char* kComposite = "Composite"; constexpr const char* kInline = "Inline"; /*! \brief Indicate the function was created by the Pattern Partitioning Pass. */ constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; - /*! \brief Mark the function as only composed of reshape operations. */ constexpr const char* kReshapeOnly = "relay.reshape_only"; } // namespace attr diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 93a56cede77b..eed6d0ffc1e4 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -40,31 +40,11 @@ #include #include +#include + namespace tvm { namespace relay { -/*! - *\brief Create a Interpreter function that can - * evaluate an expression and produce a value. - * - * The resulting value can be passed to Python, making it easy to use - * for testing and debugging. - * - * The interpreter interprets the program fragments not supported by the - * TVM runtime, although the interpreter is naively implemented it uses - * TVM operators for evaluating all operators. - * - * Our intent is that this will never be the most efficient implementation of - * Relay's semantics, but a readable and clear one. - * - * \param mod The function module. - * \param device The primary device that the interepreter runs on. - * \param target Compiler target flag to compile the functions on the context. - * \return A function that takes in an expression and returns a value. - */ -runtime::TypedPackedFunc CreateInterpreter(IRModule mod, Device device, - Target target); - /*! \brief The container type of Closures used by the interpreter. */ class InterpreterClosureObj : public runtime::ClosureObj { public: @@ -164,6 +144,52 @@ class ConstructorValue : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj); }; +/*! + * \brief Returns a packed function over Relay expressions which will evaluate \p expr + * applied to those arguments, where \p expr is w.r.t. the definitions in \p mod. + * + * This function is intended to support the Python 'debug' executor. + * + * The given \p expr should have function type. The given \p mod may be empty or + * undefined if \p expr is self-contained. Relay arguments passed to the result + * packed function must be constants, references, or constructors/tuples over such. + * As much work as possible is done while constructing the result packed function, and + * that function may be reasonably efficiently applied multiple times without redoing + * unnecessary work. + * + * Primitives are lowered and compiled to packed functions for execution on \p device + * with properties given by \p target. All other Relay constructs are interpreted. + * + * The interpreter is intended to be a 'reference' implementation of the Relay semantics + * for testing and interactive use. It is not intended to be particularly efficient. + * + * \param mod A module containing definitions which can be referenced from + * \p expr. May be empty or undefined. + * \param expr An expression of function type to evaluate. May reference definitions from \p mod. + * \param device The device on which all primitives will be executed. + * \param target The compiler target flag for compiling primitives. + * \return A packed function that takes an array of Relay expressions and returns the + * result of applying \p expr to those arguments. + */ +TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, Device device, + Target target); + +/*! + * \brief Evaluates \p expr and returns its result. + * + * This function is intended to support TVM constant evaluation. + * + * \param expr An expression to evaluate. + * \param type_definitions Global type definitions which \p expr may references. + * \param import_set Already imported external modules. + * \param device The device on which all primitives will be executed. + * \param target The compiler target flag for compiling primitives. + * @return The object representing the result. + */ +ObjectRef Eval(Expr expr, Map type_definitions, + std::unordered_set import_set, Device device, Target target); + } // namespace relay } // namespace tvm + #endif // TVM_RELAY_INTERPRETER_H_ diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 17d1ba2a5132..8454b04443a1 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -520,6 +520,14 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); */ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); +/*! + * \brief Convert type index to type key. + * \param tindex The type index. + * \param out_type_key The output type key. + * \return 0 when success, nonzero when failure happens + */ +TVM_DLL int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); + /*! * \brief Increase the reference count of an object. * diff --git a/include/tvm/runtime/contrib/papi.h b/include/tvm/runtime/contrib/papi.h new file mode 100644 index 000000000000..ff2d75c483eb --- /dev/null +++ b/include/tvm/runtime/contrib/papi.h @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \brief Performance counters for profiling via the PAPI library. + */ +#ifndef TVM_RUNTIME_CONTRIB_PAPI_H_ +#define TVM_RUNTIME_CONTRIB_PAPI_H_ + +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace profiling { + +/*! \brief Construct a metric collector that collects data from hardware + * performance counters using the Performance Application Programming Interface + * (PAPI). + * + * \param metrics A mapping from a device type to the metrics that should be + * collected on that device. You can find the names of available metrics by + * running `papi_native_avail`. + */ +TVM_DLL MetricCollector CreatePAPIMetricCollector(Map> metrics); +} // namespace profiling +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_PAPI_H_ diff --git a/include/tvm/runtime/crt/module.h b/include/tvm/runtime/crt/module.h index 7b124c4faa3a..bcbae32fd96e 100644 --- a/include/tvm/runtime/crt/module.h +++ b/include/tvm/runtime/crt/module.h @@ -35,7 +35,7 @@ extern "C" { * \brief Module container of TVM. */ typedef struct TVMModule { - /*! \brief The function registry associated with this mdoule. */ + /*! \brief The function registry associated with this module. */ const TVMFuncRegistry* registry; } TVMModule; diff --git a/include/tvm/runtime/crt/packed_func.h b/include/tvm/runtime/crt/packed_func.h index 0c39fe1a65b8..83d961baf203 100644 --- a/include/tvm/runtime/crt/packed_func.h +++ b/include/tvm/runtime/crt/packed_func.h @@ -62,11 +62,11 @@ void TVMPackedFunc_SetArgs(TVMPackedFunc* pf, const TVMArgs* args); inline TVMModuleHandle TVMArgs_AsModuleHandle(const TVMArgs* args, size_t index) { if (index >= args->values_count) { - TVMPlatformAbort(-1); + TVMPlatformAbort((tvm_crt_error_t)-1); } if (args->tcodes[index] != kTVMModuleHandle) { - TVMPlatformAbort(-1); + TVMPlatformAbort((tvm_crt_error_t)-1); } return args->values[index].v_handle; diff --git a/include/tvm/runtime/crt/rpc_common/framing.h b/include/tvm/runtime/crt/rpc_common/framing.h index 32a0f56dab11..33f37a0af03f 100644 --- a/include/tvm/runtime/crt/rpc_common/framing.h +++ b/include/tvm/runtime/crt/rpc_common/framing.h @@ -134,7 +134,7 @@ class Unframer { /*! \brief number of bytes in buffer that are currently valid. */ size_t num_buffer_bytes_valid_; - /*! \brief number of payload bytes left to write before the CRC begins. */ + /*! \brief number of payload bytes left to receive before the CRC begins. */ size_t num_payload_bytes_remaining_; /*! \brief Running CRC value. */ diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 58b9ff1932cc..c3d83bf2993f 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -85,6 +85,15 @@ class TVM_DLL DeviceAPI { * \sa DeviceAttrKind */ virtual void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) = 0; + + /*! + * \brief Query the device for specified properties. + * + * This is used to expand "-from_device=N" in the target string to + * all properties that can be determined from that device. + */ + virtual void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) {} + /*! * \brief Allocate a data space on device. * \param dev The device device to perform operation. @@ -291,7 +300,14 @@ inline Device RemoveRPCSessionMask(Device dev) { return dev; } -inline std::ostream& operator<<(std::ostream& os, DLDevice dev); +inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*) + if (tvm::runtime::IsRPCSessionDevice(dev)) { + os << "remote[" << tvm::runtime::GetRPCSessionIndex(dev) << "]-"; + dev = tvm::runtime::RemoveRPCSessionMask(dev); + } + os << tvm::runtime::DeviceName(static_cast(dev.device_type)) << "(" << dev.device_id << ")"; + return os; +} /*! * \brief Add a RPC session mask to a Device. @@ -308,14 +324,6 @@ inline Device AddRPCSessionMask(Device dev, int session_table_index) { return dev; } -inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*) - if (IsRPCSessionDevice(dev)) { - os << "remote[" << GetRPCSessionIndex(dev) << "]-"; - dev = RemoveRPCSessionMask(dev); - } - os << runtime::DeviceName(static_cast(dev.device_type)) << "(" << dev.device_id << ")"; - return os; -} } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 9dd7423c6679..71be8d218d2d 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -230,8 +230,10 @@ constexpr const char* tvm_module_main = "__tvm_main__"; constexpr const char* tvm_param_prefix = "__tvm_param__"; /*! \brief A PackedFunc that looks up linked parameters by storage_id. */ constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param"; -/*! \brief The main AOT executor function */ +/*! \brief The main AOT executor function generated from TIR */ constexpr const char* tvm_run_func_suffix = "run_model"; +/*! \brief Model entrypoint generated as an interface to the AOT function outside of TIR */ +constexpr const char* tvm_entrypoint_suffix = "run"; } // namespace symbol // implementations of inline functions. diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 5b44e020f4e4..eea195f64a6d 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -37,6 +37,7 @@ #include namespace tvm { + namespace runtime { /*! \brief Base class for all implementations. @@ -150,6 +151,26 @@ class Timer : public ObjectRef { Timer DefaultTimer(Device dev); namespace profiling { +/*! \brief Wrapper for `Device` because `Device` is not passable across the + * PackedFunc interface. + */ +struct DeviceWrapperNode : public Object { + /*! The device */ + Device device; + + /*! Constructor */ + explicit DeviceWrapperNode(Device device) : device(device) {} + + static constexpr const char* _type_key = "runtime.profiling.DeviceWrapper"; + TVM_DECLARE_BASE_OBJECT_INFO(DeviceWrapperNode, Object); +}; + +/*! \brief Wrapper for `Device`. */ +class DeviceWrapper : public ObjectRef { + public: + explicit DeviceWrapper(Device dev) { data_ = make_object(dev); } + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DeviceWrapper, ObjectRef, DeviceWrapperNode); +}; /*! \brief Data collected from a profiling run. Includes per-call metrics and per-device metrics. */ @@ -184,6 +205,39 @@ class ReportNode : public Object { * `aggregate` is true. */ String AsTable(bool sort = true, bool aggregate = true) const; + /*! \brief Convert this report to JSON. + * + * Output JSON will be of this format: + * \code + * { + * "calls": [ + * { + * "Duration (us)": { + * "microseconds": 12.3 + * }, + * "Name": "fused_dense", + * "Count": { + * "count": 1 + * }, + * "Percent": { + * "percent": 10.3 + * } + * } + * ], + * "device_metrics": { + * "cpu": { + * "Duration (us)": { + * "microseconds": 334.2 + * }, + * "Percent": { + * "percent": 100 + * } + * } + * } + * } + * \endcode + */ + String AsJSON() const; static constexpr const char* _type_key = "runtime.profiling.Report"; TVM_DECLARE_FINAL_OBJECT_INFO(ReportNode, Object); @@ -200,6 +254,57 @@ class Report : public ObjectRef { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Report, ObjectRef, ReportNode); }; +/*! \brief Interface for user defined profiling metric collection. + * + * Users can register their own collector by registering a packed function with + * the name "runtime.profiling.metrics.my_collector_name" where + * "my_collector_name" is the name of their collector. This function should + * take an Array of Device as input which contains the devices the collector + * will be run on. + * + * `MetricCollectorNode`s will be called in the following fashion. + * \code + * MetricCollector mc; + * for (auto op : model) { + * auto o = mc.Start(); + * op(); + * auto metrics = mc.Stop(o); // metrics are added the profiling report + * } + * \endcode + */ +class MetricCollectorNode : public Object { + public: + /*! \brief Initialization call. Called before profiling has started. Any + * expensive precomputation should happen here. + * \param devs The list of devices this collector will be run on. + */ + virtual void Init(Array devs) = 0; + /*! \brief Start colling metrics for a function call. + * \param dev The device the call will be run on. + * \returns An object used to maintain state of the metric collection. This + * object will be passed to the corresponding `Stop` call. If the device is + * not supported, this function will return a nullptr ObjectRef. + */ + virtual ObjectRef Start(Device dev) = 0; + /*! \brief Stop collecting metrics. + * \param obj The object created by the corresponding `Start` call. + * \returns A set of metric names and the associated values. Values must be + * one of DurationNode, PercentNode, CountNode, or StringObj. + */ + virtual Map Stop(ObjectRef obj) = 0; + + virtual ~MetricCollectorNode() {} + + static constexpr const char* _type_key = "runtime.profiling.MetricCollector"; + TVM_DECLARE_BASE_OBJECT_INFO(MetricCollectorNode, Object); +}; + +/*! \brief Wrapper for `MetricCollectorNode`. */ +class MetricCollector : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetricCollector, ObjectRef, MetricCollectorNode); +}; + /*! Information about a single function or operator call. */ struct CallFrame { /*! Device on which the call was made */ @@ -210,6 +315,10 @@ struct CallFrame { Timer timer; /*! Extra performance metrics */ std::unordered_map extra_metrics; + /*! User defined metric collectors. Each pair is the MetricCollector and its + * associated data (returned from MetricCollector.Start). + */ + std::vector> extra_collectors; }; /*! Runtime profiler for function and/or operator calls. Used in the graph @@ -217,9 +326,10 @@ struct CallFrame { * * Example usage: * \code{.cpp} - * Profiler prof; * Device cpu, gpu; - * prof.Start({cpu, gpu}); + * Profiler prof({cpu, gpu}); + * my_gpu_kernel(); // do a warmup iteration + * prof.Start(); * prof.StartCall("my_gpu_kernel", gpu); * my_gpu_kernel(); * prof.StopCall(); @@ -232,13 +342,24 @@ struct CallFrame { */ class Profiler { public: - /*! \brief Start the profiler. + /*! Constructor. + * + * The profiler should be constructed before you do any warmup iterations. + * + * \note + * Calling this constructor will reset the TVM threadpool. It is necessary in + * order to install thread handlers required by certain collectors. + * * \param devs The list of devices the profiler will be running on. Should * include all devices used by profiled operators. + * \param metric_collectors Additional `MetricCollector`s to use with this profiler. + */ + explicit Profiler(std::vector devs, std::vector metric_collectors); + /*! \brief Start the profiler. * * This function should only be called once per object. */ - void Start(const std::vector& devs); + void Start(); /*! \brief Stop the profiler. * * This function should only be called once per object after start has been called. @@ -270,12 +391,14 @@ class Profiler { /*! \brief Check if the profiler is currently running. * \returns Whether or not the profiler is running. */ - bool IsRunning() const { return !global_timers_.empty(); } + bool IsRunning() const { return is_running_; } private: - std::vector> global_timers_; + std::vector devs_; + bool is_running_{false}; std::vector calls_; std::stack in_flight_; + std::vector collectors_; }; /* \brief A duration in time. */ diff --git a/include/tvm/runtime/threading_backend.h b/include/tvm/runtime/threading_backend.h index 95a64049fd45..43636ddbdb1f 100644 --- a/include/tvm/runtime/threading_backend.h +++ b/include/tvm/runtime/threading_backend.h @@ -94,6 +94,14 @@ void Yield(); */ int MaxConcurrency(); +/*! + * \brief Reset the threads in the pool. All current threads are destroyed and + * new ones are created. + * + * Note that this does nothing when openmp is used. + */ +void ResetThreadPool(); + } // namespace threading } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 58c6ee037fb5..2fdfec9452af 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -258,6 +258,16 @@ class VirtualMachine : public runtime::ModuleNode { */ void InvokeGlobal(const VMFunction& func, const std::vector& args); + /*! + * \brief Set inputs to a function. + * \param name The function name + * \param args args[offset:] are arguments to the + * function. If the arguments are not of the correct device for the function, + * they will be copied to the device. + * \param offset Starting offset of the arguments in `args`. + */ + void SetInput(std::string name, TVMArgs args, int offset); + protected: /*! \brief The virtual machine's packed function table. */ std::vector packed_funcs_; diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h new file mode 100644 index 000000000000..6b733d074f6a --- /dev/null +++ b/include/tvm/support/random_engine.h @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file random_engine.h + * \brief Random number generator. It provides a generic interface consistent with + * `std::uniform_random_bit_generator` + */ + +#ifndef TVM_SUPPORT_RANDOM_ENGINE_H_ +#define TVM_SUPPORT_RANDOM_ENGINE_H_ + +#include + +#include // for uint64_t + +namespace tvm { +namespace support { + +/*! + * \brief This linear congruential engine is a drop-in replacement for std::minstd_rand. It strictly + * corresponds to std::minstd_rand and is designed to be platform-independent. + * \note Our linear congruential engine is a complete implementation of + * std::uniform_random_bit_generator so it can be used as generator for any STL random number + * distribution. However, parts of std::linear_congruential_engine's member functions are not + * included for simplification. For full member functions of std::minstd_rand, please check out the + * following link: https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine + */ + +class LinearCongruentialEngine { + public: + /*! + * \brief The result type is defined as uint64_t here to avoid overflow. + * \note The type name is not in Google style because it is used in STL's distribution inferface. + */ + using result_type = uint64_t; + using TRandState = int64_t; + + /*! \brief The multiplier */ + static constexpr TRandState multiplier = 48271; + + /*! \brief The increment */ + static constexpr TRandState increment = 0; + + /*! \brief The modulus */ + static constexpr TRandState modulus = 2147483647; + + /*! + * \brief The minimum possible value of random state here. + * \note The function name is uncapilized because it is used in STL's distribution inferface. + */ + static constexpr result_type min() { return 0; } + + /*! + * \brief The maximum possible value of random state here. + * \note The function name is uncapilized because it is used in STL's distribution inferface. + */ + static constexpr result_type max() { return modulus - 1; } + + /*! + * \brief Operator to move the random state to the next and return the new random state. According + * to definition of linear congruential engine, the new random state value is computed as + * new_random_state = (current_random_state * multiplier + increment) % modulus. + * \return The next current random state value in the type of result_type. + * \note In order for better efficiency, the implementation here has a few assumptions: + * 1. The multiplication and addition won't overflow. + * 2. The given random state pointer `rand_state_ptr` is not nullptr. + * 3. The given random state `*(rand_state_ptr)` is in the range of [0, modulus - 1]. + */ + result_type operator()() { + (*rand_state_ptr_) = ((*rand_state_ptr_) * multiplier + increment) % modulus; + return *rand_state_ptr_; + } + + /*! + * \brief Change the start random state of RNG with the seed of a new random state value. + * \param rand_state The random state given in result_type. + */ + void Seed(TRandState rand_state = 1) { + rand_state %= modulus; // Make sure the seed is within the range of modulus. + if (rand_state == 0) + rand_state = 1; // Avoid getting all 0 given the current parameter set. + else if (rand_state < 0) + rand_state += modulus; // Make sure the rand state is non-negative. + ICHECK(rand_state_ptr_ != nullptr); // Make sure the pointer is not null. + *rand_state_ptr_ = rand_state; // Change pointed random state to given random state value. + } + + /*! + * \brief Construct a random number generator with a random state pointer. + * \param rand_state_ptr The random state pointer given in result_type*. + * \note The random state is not checked for whether it's nullptr and whether it's in the range of + * [0, modulus-1]. We assume the given random state is valid or the Seed function would be + * called right after the constructor before any usage. + */ + explicit LinearCongruentialEngine(TRandState* rand_state_ptr) { + rand_state_ptr_ = rand_state_ptr; + } + + private: + TRandState* rand_state_ptr_; +}; + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_RANDOM_ENGINE_H_ diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index e7da2dd413a0..8a2bbcbd0121 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -377,7 +377,8 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_name() { .add_attr_option("device") \ .add_attr_option("model") \ .add_attr_option>("libs") \ - .add_attr_option("host") + .add_attr_option("host") \ + .add_attr_option("from_device") } // namespace tvm diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 27e48999a7d1..13f39317dbe4 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -125,11 +125,12 @@ class TVM_DLL OperationNode : public Object { * \param stage the op's stage. * \param realize_map The realization domain map of the operators. * \param body The body that is going to get + * \param storage_scope The storage scope associated with this realization * \return A realization statement that wraps body. */ virtual Stmt BuildRealize(const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const = 0; + const std::unordered_map& realize_map, const Stmt& body, + String storage_scope = "") const = 0; /*! * \brief Build the statement that provide the output tensors. * \param stage The schedule stage of the op. @@ -168,7 +169,7 @@ class PlaceholderOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; @@ -212,7 +213,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; virtual size_t num_schedulable_dims() const = 0; static constexpr const char* _type_key = "BaseComputeOp"; @@ -370,7 +371,7 @@ class ScanOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; @@ -433,7 +434,7 @@ class ExternOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; @@ -498,7 +499,7 @@ class HybridOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 63d6fa375c83..dce9736adec7 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -96,22 +96,20 @@ TVM_DLL Array UndefinedVars(const PrimExpr& expr); TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr); /*! - * \brief Whether e expression used any var in variable set. - * \param expr The expression to be checked. - * \param vset_contains The check function to see if var is in the vset. - * \return Whether e uses vset. + * \brief Whether the given Stmt uses any var in the given variable set. + * \param stmt The Stmt to be checked. + * \param vset_contains The check function to see if a var is in the variable set. + * \return Whether `stmt` uses any var in the given variable set. */ -TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function vset_contains); +TVM_DLL bool UsesVar(const Stmt& stmt, std::function vset_contains); /*! - * \brief Whether e expression used var. - * \param expr The expression to be checked. - * \param var The variable. - * \return Whether e uses v. + * \brief Whether the given PrimExpr uses any var in the given variable set. + * \param expr The PrimExpr to be checked. + * \param vset_contains The check function to see if var is in the variable set. + * \return Whether `expr` uses any var in the given variable set. */ -inline bool ExprUseVar(const PrimExpr& expr, const Var& var) { - return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; }); -} +TVM_DLL bool UsesVar(const PrimExpr& expr, std::function vset_contains); /*! * \brief Verifies whether the IR stmt or Expr is in SSA form. diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 017f4f7052b1..28d202cb50a9 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -67,8 +67,6 @@ class BufferNode : public Object { // Meta data /*! \brief optional name of the buffer */ String name; - /*! \brief storage scope of the buffer, if other than global */ - String scope; /*! \brief Alignment requirement of data pointer in bytes. */ int data_alignment; /*! @@ -93,7 +91,6 @@ class BufferNode : public Object { v->Visit("strides", &strides); v->Visit("elem_offset", &elem_offset); v->Visit("name", &name); - v->Visit("scope", &scope); v->Visit("data_alignment", &data_alignment); v->Visit("offset_factor", &offset_factor); v->Visit("buffer_type", &buffer_type); @@ -105,7 +102,7 @@ class BufferNode : public Object { // in its semantics, skip name as name is not important. return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) && - equal.DefEqual(elem_offset, other->elem_offset) && equal(scope, other->scope) && + equal.DefEqual(elem_offset, other->elem_offset) && equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type); } @@ -115,7 +112,6 @@ class BufferNode : public Object { hash_reduce.DefHash(shape); hash_reduce.DefHash(strides); hash_reduce.DefHash(elem_offset); - hash_reduce(scope); hash_reduce(data_alignment); hash_reduce(buffer_type); } @@ -141,8 +137,8 @@ class Buffer : public ObjectRef { // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. TVM_DLL Buffer(Var ptr, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, String name, String scope, int data_alignment, - int offset_factor, BufferType buffer_type, Span span = Span()); + PrimExpr elem_offset, String name, int data_alignment, int offset_factor, + BufferType buffer_type, Span span = Span()); /*! * \brief Return a new buffer that is equivalent with current one @@ -182,6 +178,11 @@ class Buffer : public ObjectRef { */ TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; + /*! + * \brief Return the storage scope associated with this buffer. + */ + TVM_DLL String scope() const; + TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode); }; @@ -191,12 +192,13 @@ class Buffer : public ObjectRef { * \param shape The shape of the buffer, * \param dtype The content data type. * \param name The name of the buffer + * \param storage_scope The storage scope associated with this buffer * \param span The location of this object in the source code. * \return The created buffer. * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", Span span = Span()); + String name = "buffer", String storage_scope = "", Span span = Span()); /*! * \brief Base node for data producers. diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 61280d33f1df..86857a33cdf4 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -600,6 +600,20 @@ TVM_DLL const Op& vectorcombine(); * \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA */ TVM_DLL const Op& atomic_add(); +/*! + * \brief Create a texture 2d memory allocation + */ +TVM_DLL const Op& texture2d_alloca(); + +/*! + * \brief Store to texture 2d memory + */ +TVM_DLL const Op& texture2d_store(); + +/*! + * \brief Load from texture 2d memory + */ +TVM_DLL const Op& texture2d_load(); /*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 40d66a2d8357..8ea48dd592d5 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1126,6 +1126,9 @@ class AnyNode : public PrimExprNode { /*! \brief Convert to var. */ Var ToVar() const { return Var("any_dim", DataType::Int(32)); } + /*! \brief Convert to SizeVar. */ + SizeVar ToSizeVar() const { return SizeVar("any_dim", DataType::Int(32)); } + static constexpr const char* _type_key = "tir.Any"; TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode); }; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 25ed2f9ae8d1..55f4fc62649c 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -240,10 +240,12 @@ namespace attr { * * Call(f, * [arg1, arg2, ..., arg_n, - * work_size_1, work_size_2, ... work_size_m]) + * work_size_1, work_size_2, ... work_size_m, dyn_shmem_size]) * * Here n = len(arg), m = len(work_size) = len(device_thread_axis). * + * When kDeviceUseDynSharedMemory is not set, dyn_shmem_size argument is omitted. + * * The list of device_thread_axis indicates how can be bind the * work_size arguments to the corresponding threads. * @@ -251,6 +253,13 @@ namespace attr { */ constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis"; +/*! + * \brief Whether or not use dynamic shared memory. + * + * Type: Integer + */ +constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory"; + /*! * \brief Whether to set noalias rule on the function arguments. * diff --git a/include/tvm/tir/schedule/block_scope.h b/include/tvm/tir/schedule/block_scope.h index fb08583b7771..be3e79a18331 100644 --- a/include/tvm/tir/schedule/block_scope.h +++ b/include/tvm/tir/schedule/block_scope.h @@ -262,7 +262,7 @@ class BlockScope : public ObjectRef { * \param child_block_srefs The srefs to the leaf blocks * \note We assume the leaf blocks are given in pre-DFS order */ - TVM_DLL BlockScope(const Array& child_block_srefs); + TVM_DLL explicit BlockScope(const Array& child_block_srefs); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode); }; diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/tir/schedule/instruction.h new file mode 100644 index 000000000000..5a9e687dc8c7 --- /dev/null +++ b/include/tvm/tir/schedule/instruction.h @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_INSTRUCTION_H_ +#define TVM_TIR_SCHEDULE_INSTRUCTION_H_ + +#include + +#include + +namespace tvm { + +// Forward declaration +template +class AttrRegistry; + +namespace tir { + +// Forward declaration +class Schedule; + +/*! + * \brief Type of the functor that applies the instruction to a TensorIR schedule + * \param sch The schedule to be applied on + * \param inputs The input random variables + * \param attrs Instruction attributes + * \param decision Decisions made on the instruction + * \return The functor returns an array of output random variables + */ +using FInstructionApply = runtime::TypedPackedFunc( + Schedule sch, const Array& inputs, const Array& attrs, + const Optional& decision)>; + +/*! + * \brief Type of the functor that converts the instruction to a statement in python syntax + * \param inputs Names of the input random variables + * \param attrs Instruction attributes + * \param decisions Decisions made on the instruction + * \param outputs Names of the output random variables + * \return A string representing the python api call + */ +using FInstructionAsPython = runtime::TypedPackedFunc& inputs, const Array& attrs, + const Optional& decision, const Array& outputs)>; + +/*! + * \brief Type of the functor that serialize its attributes to JSON + * \param attrs The attributes to be serialized + * \return An array, serialized attributes + * \note This functor is nullable + */ +using FInstructionAttrsAsJSON = runtime::TypedPackedFunc attrs)>; + +/*! + * \brief Type of the functor that deserialize its attributes from JSON + * \param json_attrs The attributes to be serialized + * \return An array, deserialized attributes + * \note This functor is nullable + */ +using FInstructionAttrsFromJSON = runtime::TypedPackedFunc(ObjectRef json_attrs)>; + +/*! + * \brief Kind of an instruction, e.g. Split, Reorder, etc. + * Besides the name, every kind of instruction has its own properties, including: + * 1) A boolean indicating if the instruction is pure, i.e. change nothing in the schedule state + * 2) A functor that applies the instruction to a TensorIR schedule + * 3) A functor that converts the instruction to a statement in python syntax + * 4) A functor that serialize its attributes to JSON + * 5) A functor that deserialize its attributes from JSON + * + * Unlike `tvm::OpNode`, `InstructionKindNode` doesn't support unstructured properties, + * mainly because there is no such usecase yet to add any other property. + */ +class InstructionKindNode : public runtime::Object { + public: + /*! \brief The name of a kind of instructions */ + String name; + /*! + * \brief Indicates if the instruction is pure, i.e. removing it alone doesn't mutate the schedule + * state. For example, the instruction `GetBlock` is pure because it changes + * nothing, while `ComputeInline` is not because removing it leads to a different resulting + * schedule. + */ + bool is_pure{false}; + /*! \brief A functor that applies the instruction to a TensorIR schedule */ + FInstructionApply f_apply_to_schedule{nullptr}; + /*! \brief A functor that converts the instruction to a statement in python syntax */ + FInstructionAsPython f_as_python{nullptr}; + /*! + * \brief A functor that serialize its attributes to JSON + * \note If the functor is null, it means no conversion is needed + */ + FInstructionAttrsAsJSON f_attrs_as_json{nullptr}; + /*! + * \brief A functor that deserialize its attributes from JSON + * \note If the functor is null, it means no conversion is needed + */ + FInstructionAttrsFromJSON f_attrs_from_json{nullptr}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("_is_pure", &is_pure); + // not visited: f_apply_to_schedule + // not visited: f_as_python + // not visited: f_attrs_as_json + // not visited: f_attrs_from_json + } + + static constexpr const char* _type_key = "tir.InstructionKind"; + TVM_DECLARE_FINAL_OBJECT_INFO(InstructionKindNode, runtime::Object); +}; + +/*! + * \brief Managed reference to InstructionKindNode + * \sa InstructionKindNode + */ +class InstructionKind : public runtime::ObjectRef { + public: + /*! + * \brief Retrieve an InstructionKind using its name + * \param name The registered name of the InstructionKind + * \return The InstructionKind retrieved + */ + static InstructionKind Get(const String& name); + TVM_DEFINE_OBJECT_REF_METHODS(InstructionKind, runtime::ObjectRef, InstructionKindNode); +}; + +/*! \brief Schedule instructions each corresponds to a schedule primitive */ +class InstructionNode : public runtime::Object { + public: + /*! \brief The kind of the instruction */ + InstructionKind kind; + /*! + * \brief The input random variables of the instruction, and the type of each element can be one + * of the following: + * - BlockRV + * - LoopRV + * - ExprRV + * - FloatImm + * - IntImm + * - String + * - null pointer + */ + Array inputs; + /*! + * \brief The attributes of the instruction. Similar to attributes of an operator, + * attributes of an instruction are arbitrary constant metadata required by the instructions. + * For example, the name of the block to be retrieved in `GetBlock`. + */ + Array attrs; + /*! \brief The output random variables of the instruction, and the type of each element can be one + * of the following: + * - BlockRV + * - LoopRV + * - ExprRV, atomic variables only, won't be constants or composite PrimExpr + */ + Array outputs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("kind", &kind); + v->Visit("inputs", &inputs); + v->Visit("attrs", &attrs); + v->Visit("outputs", &outputs); + } + + static constexpr const char* _type_key = "tir.Instruction"; + TVM_DECLARE_FINAL_OBJECT_INFO(InstructionNode, runtime::Object); +}; + +/*! + * \brief Managed reference to InstructionNode + * \sa InstructionNode + */ +class Instruction : public runtime::ObjectRef { + public: + /*! + * \brief Constructor + * \param kind The kind of the instruction + * \param inputs The input random variables of the instruction + * \param attrs The attributes of the instruction + * \param outputs The output random variables of the instruction + */ + explicit Instruction(InstructionKind kind, Array inputs, Array attrs, + Array outputs); + + TVM_DEFINE_OBJECT_REF_METHODS(Instruction, runtime::ObjectRef, InstructionNode); +}; + +/*! + * \brief A helper macro to register InstructionKind, only used in `TVM_REGISTER_INST_KIND` + * \note This macro is not user-facing. + * \sa TVM_REGISTER_INST_KIND + */ +#define TVM_INST_KIND_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::tir::InstructionKindRegEntry& __make_##InstructionKind + +/*! + * \brief Register an InstructionKind + * \param InstructionKindName The name of the InstructionKind + * + * Example: + * + * \code + * + * TVM_REGISTER_INST_KIND("ComputeInline") + * .set_is_pure(false) + * .set_apply_to_schedule(ApplyToSchedule) + * .set_attrs_as_json(AttrsAsJSON) + * .set_attrs_from_json(AttrsFromJSON) + * .set_as_python(AsPython); + * + * \endcode + */ +#define TVM_REGISTER_INST_KIND(InstructionKindName) \ + TVM_STR_CONCAT(TVM_INST_KIND_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::tir::InstructionKindRegEntry::RegisterOrGet(InstructionKindName).set_name() + +/*! \brief An entry in the registry of InstructionKind */ +class InstructionKindRegEntry { + public: + static InstructionKindRegEntry& RegisterOrGet(const String& name); + + InstructionKindRegEntry& set_name() { + get_mutable()->name = this->name; + return *this; + } + + InstructionKindRegEntry& set_is_pure(bool is_pure) { + get_mutable()->is_pure = is_pure; + return *this; + } + + InstructionKindRegEntry& set_apply_to_schedule(FInstructionApply f_apply_to_schedule) { + get_mutable()->f_apply_to_schedule = std::move(f_apply_to_schedule); + return *this; + } + + InstructionKindRegEntry& set_as_python(FInstructionAsPython f_as_python) { + get_mutable()->f_as_python = std::move(f_as_python); + return *this; + } + + InstructionKindRegEntry& set_attrs_as_json(FInstructionAttrsAsJSON f_attrs_as_json) { + get_mutable()->f_attrs_as_json = std::move(f_attrs_as_json); + return *this; + } + + InstructionKindRegEntry& set_attrs_from_json(FInstructionAttrsFromJSON f_attrs_from_json) { + get_mutable()->f_attrs_from_json = std::move(f_attrs_from_json); + return *this; + } + + private: + /*! \brief Private constructor, used only by AttrRegistry */ + explicit InstructionKindRegEntry(uint32_t reg_index); + /*! \brief Get the mutable reference to the internal InstructionKind */ + InstructionKindNode* get_mutable() const { + return const_cast(inst_kind_.get()); + } + + /*! \brief The name of the registry entry */ + String name; + /*! \brief The instruction kind */ + InstructionKind inst_kind_; + template + friend class ::tvm::AttrRegistry; + friend class InstructionKind; +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_INSTRUCTION_H_ diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9a09d0ad211f..79fed09c3e36 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -19,7 +19,9 @@ #ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_ #define TVM_TIR_SCHEDULE_SCHEDULE_H_ +#include #include +#include namespace tvm { namespace tir { @@ -95,19 +97,21 @@ class ScheduleNode : public runtime::Object { virtual ~ScheduleNode() = default; static constexpr const char* _type_key = "tir.Schedule"; - TVM_DECLARE_BASE_OBJECT_INFO(ScheduleNode, runtime::Object); + TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, runtime::Object); public: /*! \brief Get the IRModule associated with this schedule. */ virtual IRModule mod() const { return state()->mod; } /*! \return The internal state of scheduling */ virtual ScheduleState state() const = 0; + /*! \return The internally maintained trace of scheduling program execution */ + virtual Optional trace() const = 0; /*! * \brief Returns a copy of the schedule, including both its state and its symbol table, * guaranteeing that * 1) SRef tree is completely reconstructed; * 2) The IRModule being scheduled is not modified; - * 3) All the random variables are valid in the copy, pointing to the correpsonding sref + * 3) All the random variables are valid in the copy, pointing to the corresponding sref * reconstructed */ virtual Schedule Copy() const = 0; @@ -115,9 +119,9 @@ class ScheduleNode : public runtime::Object { * \brief Seed the randomness * \param seed The new random seed, -1 if use device random, otherwise non-negative */ - virtual void Seed(int64_t seed = -1) { - LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed"; - } + virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) = 0; + /*! \brief Fork the random state */ + virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0; public: /******** Lookup/Remove random variables ********/ @@ -180,7 +184,18 @@ class ScheduleNode : public runtime::Object { virtual void RemoveRV(const ExprRV& expr_rv) = 0; public: - /******** Block/Loop relation ********/ + /******** Schedule: Sampling ********/ + /*! + * \brief Sample an integer given the probability distribution + * \param candidates The candidates + * \param probs The probability distribution of the candidates + * \param decision The sampling decision + * \return The random variable sampled from candidates + */ + virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, + Optional decision = NullOpt) = 0; + + /******** Schedule: Get blocks & loops ********/ /*! * \brief Retrieve a block in a specific function with its name * \param name The name of the block to be retrieved @@ -195,8 +210,79 @@ class ScheduleNode : public runtime::Object { * \return A list of loops above the given block in its scope, from outer to inner */ virtual Array GetLoops(const BlockRV& block_rv) = 0; - /******** Schedule: loops manipulation ********/ - /******** Schedule: compute location ********/ + /******** Schedule: Transform loops ********/ + /*! + * \brief Fuse a list of consecutive loops into one. It requires: + * 1) The loops can't have annotations or thread bindings. + * 2) The (i+1)-th loop must be the only child of the i-th loop. + * 3) All loops must start with 0. + * \param loop_rvs The loops to be fused + * \return The new loop after fusion + */ + virtual LoopRV Fuse(const Array& loop_rvs) = 0; + /*! + * \brief Split a loop into a list of consecutive loops. It requires: + * 1) The loop can't have annotation or thread binding. + * 2) The loop must start with 0. + * \param loop_rv The loop to be split + * \param factors The tiling factors, and at most one of which is -1, which means that + * factor is inferred. + * \return The new loops after split + */ + virtual Array Split(const LoopRV& loop_rv, const Array>& factors) = 0; + /*! + * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. + * It requires: + * 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... , + * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between + * l_1 and l_n (which also indicates they are under the same scope). + * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. + * 3) For every block under the loop nests, its block binding must be affine, and the block + * variables must be either data parallel or reduction. + * 4) No duplicated loops are allowed in the arguments. + * \param ordered_loop_rvs The loops in the new order + */ + virtual void Reorder(const Array& ordered_loop_rvs) = 0; + /******** Schedule: Manipulate ForKind ********/ + /*! + * \brief Parallelize the input loop. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' + * bindings + * \param loop_rv The loop to be parallelized + */ + virtual void Parallel(const LoopRV& loop_rv) = 0; + /*! + * \brief Vectorize the input loop. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' + * bindings + * \param loop_rv The loop to be vectorized + */ + virtual void Vectorize(const LoopRV& loop_rv) = 0; + /*! + * \brief Bind the input loop to the given thread axis. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, if the thread axis starts with "threadIdx`, the loop can only + * be contained in data-parallel block iter and reduction block iters' bindings. Otherwise the + * loop can only be contained in data-parallel block iters' bindings + * \param loop_rv The loop to be bound to the thread axis + * \param thread_axis The thread axis to be bound to the loop + */ + virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0; + /*! + * \brief Unroll the input loop. It requires nothing + * \param loop_rv The loop to be unrolled + */ + virtual void Unroll(const LoopRV& loop_rv) = 0; + /******** Schedule: Insert cache stages ********/ + /******** Schedule: Compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: * 1) The block is a complete non-root block, which only produces one buffer @@ -220,10 +306,45 @@ class ScheduleNode : public runtime::Object { * \param block The block to be inlined to its producer */ virtual void ReverseComputeInline(const BlockRV& block) = 0; - /******** Schedule: loop binding/annotation ********/ - /******** Schedule: cache read/write ********/ - /******** Schedule: reduction ********/ - /******** Schedule: blockize & tensorize ********/ + /******** Schedule: Reduction ********/ + /*! + * \brief Factorize an associative reduction block by the specified loop. + * \details An associative reduction cannot be parallelized directly, + * because it leads to potential race condition during accumulation. + * Alternatively, the reduction could be factorized on a loop with the following steps: + * - Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent + * - Step 2: compute the chunks separately and write the result into `n` intermediate buffers; + * - Step 3: accumulate the `n` separate buffer into the result buffer. + * Note that the Step 2 above introduces opportunities for parallelization. + * RFactor is a schedule primitive that implements the transformation described above. + * \param loop_rv The loop outside block we want to do rfactor + * \param factor_axis The position where the new dimension is placed in the new introduced rfactor + * buffer. Suppose the original reduction block writes to buffer `B` with + * ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1, + * ndim(B)]`, and the negative index will be normalized to a non-negative one + * \return The rfactor block + */ + virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0; + /******** Schedule: Block annotation ********/ + /*! + * \brief Set alignment requirement for specific dimension such that + * stride[axis] == k * factor + offset for some k. This is useful to set memory layout for + * more friendly memory access pattern. For example, we can set alignment to be factor=2, + * offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared + * memory. + * \param block_rv The producer block of the buffer + * \param buffer_index The index of the buffer in block's write region + * \param axis The dimension to be specified for alignment + * \param factor The factor multiple of alignment + * \param offset The required offset factor + */ + virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, + int offset) = 0; + /******** Schedule: Blockize & Tensorize ********/ + /******** Schedule: Annotation ********/ + /******** Schedule: Misc ********/ + /*! \brief A no-op that marks the start of postprocessing phase of scheduling */ + virtual void EnterPostproc() = 0; }; /*! @@ -246,7 +367,8 @@ class Schedule : public runtime::ObjectRef { /*! * \brief Construct a concrete TensorIR schedule from an IRModule * \param mod The IRModule to be scheduled - * \param debug_mode Do extra correctness checking after the class creation + * \param seed The seed value for schedule's random state + * \param debug_mask Do extra correctness checking after the class creation * and each time after calling the Replace method. * \param error_render_level The level of error rendering * \return The concrete schedule created @@ -255,8 +377,23 @@ class Schedule : public runtime::ObjectRef { * 1) VerifySRefTree * 2) VerifyCachedFlags */ - TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode, - ScheduleErrorRenderLevel error_render_level); + TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, ScheduleErrorRenderLevel error_render_level); + /*! + * \brief Construct a traced concrete TensorIR schedule from an IRModule + * \param mod The IRModule to be scheduled + * \param seed The seed value for schedule's random state + * \param debug_mask Do extra correctness checking after the class creation + * and each time after calling the Replace method. + * \param error_render_level The level of error rendering + * \return The concrete schedule created + * \sa ScheduleDebugMask + * \note The checks performed include: + * 1) VerifySRefTree + * 2) VerifyCachedFlags + */ + TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, ScheduleErrorRenderLevel error_render_level); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); }; diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 83ac7150543f..7cd1b00c15ef 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -80,7 +80,7 @@ enum ScheduleDebugMask : uint32_t { * 2) The sref tree of schedulable statements (indicated by the srefs) * 3) The dependency information of each block scope (block_info) * 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref) - * 5) A debug flag, if set, extra checking is enabled (debug_mode) + * 5) A debug flag, if set, extra checking is enabled (debug_mask) */ class ScheduleStateNode : public Object { public: @@ -99,13 +99,13 @@ class ScheduleStateNode : public Object { * and each time after calling the Replace method. * \sa ScheduleDebugMask */ - int debug_mode; + int debug_mask; void VisitAttrs(AttrVisitor* v) { v->Visit("mod", &mod); // `block_info` is not visited // `stmt2ref` is not visited - v->Visit("debug_mode", &debug_mode); + v->Visit("debug_mask", &debug_mask); } /*! * \brief Replace the part of the AST, as being pointed to by `src_sref`, @@ -129,7 +129,7 @@ class ScheduleStateNode : public Object { TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt, const Map& block_sref_reuse); /*! - * \brief Trigger the verification according to the `debug_mode` bitmask. + * \brief Trigger the verification according to the `debug_mask` bitmask. * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree. * 2) If the bitmask `kVerifyCachedFlags` is on, verify the correctness of `affine_binding`, * `region_cover` and `stage_pipeline` @@ -186,18 +186,10 @@ class ScheduleState : public ObjectRef { /*! * \brief Construct a schedule state from an IRModule * \param mod The IRModule to be scheduled - * \param debug_mode Do extra correctness checking after the class creation + * \param debug_mask Do extra correctness checking after the class creation * and each time after calling the Replace method. */ - TVM_DLL explicit ScheduleState(IRModule mod, int debug_mode = 0); - /*! - * \brief Construct a schedule state from a PrimFunc - * \param func The PrimFunc to be scheduled. A new IRModule will be created with - * this specific PrimFunc as "main" function in the module to be scheduled - * \param debug_mode Do extra correctness checking after the class creation - * and each time after calling the Replace method. - */ - TVM_DLL explicit ScheduleState(PrimFunc func, int debug_mode = 0); + TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0); /*! \return The mutable pointer to the ScheduleStateNode */ ScheduleStateNode* get() const { return static_cast(data_.get()); } diff --git a/include/tvm/tir/schedule/trace.h b/include/tvm/tir/schedule/trace.h new file mode 100644 index 000000000000..b6b3b57226c8 --- /dev/null +++ b/include/tvm/tir/schedule/trace.h @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_TRACE_H_ +#define TVM_TIR_SCHEDULE_TRACE_H_ + +#include + +namespace tvm { +namespace tir { + +// Forward declaration +class Trace; + +/*! + * \brief A callback that allows users to mutate decisions on the fly + * when applying instructions. The signature of the callback is: + * \param inst The instruction + * \param inputs The input random variables + * \param attrs The attributes + * \param decision The original decision + * \return A new decision + */ +using FTraceDecisionProvider = runtime::TypedPackedFunc& inputs, const Array& attrs, + const Optional& decision)>; + +/*! + * \brief An execution trace of a scheduling program + * + * A trace has two parts: + * 1) The instructions invoked so far in the program execution + * 2) The random decisions made upon those instructions, if any + * + * A trace can be serialized to: + * 1) Roundtrippable JSON format: can be saved to file and loaded back + * 2) Python syntax: allows users to copy-paste the trace to reproduce the scheduling process + * + * A trace can be applied to a TensorIR schedule by re-applying all its instructions possibly with + * their decisions accordingly. Re-sampling is invoked if a sampling instruction doesn't have its + * corresponding decision; Otherwise the existing decision will be reused accordingly. + */ +class TraceNode : public runtime::Object { + public: + /*! \brief The instructions invoked so far in the program execution */ + Array insts; + /*! \brief The random decisions made upon those instructions */ + Map decisions; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("insts", &insts); + v->Visit("decisions", &decisions); + } + + static constexpr const char* _type_key = "tir.Trace"; + TVM_DECLARE_FINAL_OBJECT_INFO(TraceNode, runtime::Object); + + public: + /*! + * \brief Retrieve the decision made on a specific instruction + * \param inst The instruction whose decision is to be retrieved + * \return The corresponding decision; NullOpt if there is no decision made on the instruction + */ + Optional GetDecision(const Instruction& inst) const; + /*! + * \brief Append a new instruction to the trace + * \param inst The new instruction to be appended + */ + void Append(Instruction inst); + /*! + * \brief Append a new instruction with a random decision to the trace + * \param inst The new instruction to be appended + * \param decision The random decision made on this instruction + * The type of `decision` depends on the instruction, e.g. + * the decision of `SamplePerfectTile` has type `Array` + */ + void Append(Instruction inst, ObjectRef decision); + /*! + * \brief Remove the last instruction, along with the decision made on that instruction, if any + * \return The instruction removed; NullOpt if the trace is empty + */ + Optional Pop(); + /*! + * \brief Apply the trace to a TensorIR schedule + * \param sch The schedule to be applied onto + * \param remove_postproc If postprocessing instructions are removed + * \param decision_provider A callback that allows users to mutate decisions on the fly + * when applying instructions. + * \sa FTraceDecisionProvider + */ + void ApplyToSchedule(Schedule sch, bool remove_postproc, + FTraceDecisionProvider decision_provider = nullptr) const; + /*! + * \brief Serialize the trace as a JSON-style object + * \param remove_postproc If postprocessing instructions are removed + * \return The JSON-style object + */ + ObjectRef AsJSON(bool remove_postproc) const; + /*! + * \brief Serialize the trace as a sequence of python statements + * \param remove_postproc If postprocessing instructions are removed + * \return A sequence of python statements + */ + Array AsPython(bool remove_postproc) const; + /*! + * \brief Create a new trace with an instruction whose decision is changed, + * assuming this instruction exists in the resulting trace + * \param inst The instruction whose decision is to be changed + * \param decision The decision to be changed to + * \param remove_postproc If postprocessing instructions are removed + * \return The new trace with the decision changed + */ + Trace WithDecision(Instruction inst, ObjectRef decision, bool remove_postproc) const; + /*! + * \brief Simplify the trace with dead-code elimination + * \param remove_postproc If postprocessing instructions are removed + * \return A simplified trace + */ + Trace Simplified(bool remove_postproc) const; +}; + +/*! + * \brief Managed reference to TraceNode + * \sa TraceNode + */ +class Trace : public runtime::ObjectRef { + public: + /*! \brief Default constructor. Creating an empty trace. */ + Trace(); + /*! + * \brief Constructor. Creating a trace from existing instructions and their decisions + * \param insts The instructions used + * \param decisions The decisions made in sampling + */ + explicit Trace(Array insts, Map decisions); + /*! + * \brief Apply a JSON-serialized trace to a TensorIR schedule + * \param json The JSON-serialized trace + * \param sch The TensorIR schedule + */ + static void ApplyJSONToSchedule(ObjectRef json, Schedule sch); + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, runtime::ObjectRef, TraceNode); +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_TRACE_H_ diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index cc10c218c8ff..0da8e55be023 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -464,18 +464,22 @@ class ProducerRealizeNode : public StmtNode { PrimExpr condition; /*! \brief The body of realization. */ Stmt body; + /*! \brief The storage scope associated with this realization. */ + String storage_scope; void VisitAttrs(AttrVisitor* v) { v->Visit("producer", &producer); v->Visit("bounds", &bounds); v->Visit("condition", &condition); v->Visit("body", &body); + v->Visit("storage_scope", &storage_scope); v->Visit("span", &span); } bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const { return equal(producer, other->producer) && equal(bounds, other->bounds) && - equal(condition, other->condition) && equal(body, other->body); + equal(condition, other->condition) && equal(body, other->body) && + equal(storage_scope, other->storage_scope); } void SHashReduce(SHashReducer hash_reduce) const { @@ -483,6 +487,7 @@ class ProducerRealizeNode : public StmtNode { hash_reduce(bounds); hash_reduce(condition); hash_reduce(body); + hash_reduce(storage_scope); } static constexpr const char* _type_key = "tir.ProducerRealize"; @@ -496,7 +501,7 @@ class ProducerRealizeNode : public StmtNode { class ProducerRealize : public Stmt { public: TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - Span span = Span()); + String storage_scope = "", Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode); }; @@ -860,6 +865,7 @@ class For : public Stmt { Map annotations = Map(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode); }; /*! @@ -1235,8 +1241,6 @@ constexpr const char* extern_scope = "extern_scope"; * This can hint some code generator to create a new function for compute. */ constexpr const char* compute_scope = "compute_scope"; -/*! \brief Mark storage scope of buffers */ -constexpr const char* storage_scope = "storage_scope"; /*! \brief Mark storage alignement requirement of buffers */ constexpr const char* storage_alignment = "storage_alignment"; /*! \brief Mark storage scope of realization */ @@ -1356,6 +1360,24 @@ TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span()); // overload printing of for type. TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind); +// inline implementations +inline const char* ForKind2String(ForKind t) { + switch (t) { + case ForKind::kSerial: + return "serial"; + case ForKind::kParallel: + return "parallel"; + case ForKind::kVectorized: + return "vectorized"; + case ForKind::kUnrolled: + return "unroll"; + case ForKind::kThreadBinding: + return "thread_binding"; + } + LOG(FATAL) << "Unknown ForKind" << t; + return "Unknown"; +} + } // namespace tir } // namespace tvm #endif // TVM_TIR_STMT_H_ diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 5ee847e2f010..b5998874f7e3 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -423,6 +423,12 @@ TVM_DLL Pass CompactBufferAllocation(); */ TVM_DLL Pass LegalizePackedCalls(); +/*! + * \brief Remove match buffers inside the block. Also, it will validate the binding. + * \return The pass. + */ +TVM_DLL Pass LowerMatchBuffer(); + /*! * \brief Flatten the multi-dimensional BufferLoad and BufferStore * to single dimensional Load/Store. Also remove Block to @@ -431,6 +437,32 @@ TVM_DLL Pass LegalizePackedCalls(); */ TVM_DLL Pass FlattenBuffer(); +/* + * \brief Flatten the multi-dimensional read/write + * to two dimensional texture Load/Store and realize + * texture buffer allocations. + * + * \return The Pass + */ +TVM_DLL Pass TextureFlatten(); + +/*! + * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and + * "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., + * "threadIdx.x") use different IterVars and variables in their AttrStmts. After the + * unification, we use a consolidated IterVar and a variable for them. + * \return The pass. + * \note `vthread` is a legacy behavior that will be deprecated, though thread bindings of `vthread` + * are still also unified in this pass. Please use `vthread.x`, `vthread.y` and `vthread.z` + * instead. + */ +TVM_DLL Pass UnifyThreadBinding(); + +/*! + * A pass to merge multiple TIR-level dynamic shared memory allocations into one + */ +TVM_DLL Pass MergeDynamicSharedMemoryAllocations(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/include/tvm/topi/detail/extern.h b/include/tvm/topi/detail/extern.h index caca1e85e520..2561f8d1ca27 100644 --- a/include/tvm/topi/detail/extern.h +++ b/include/tvm/topi/detail/extern.h @@ -48,7 +48,7 @@ using namespace tvm::te; inline Buffer DeclExternBuffer(Array shape, DataType dtype, std::string name) { auto data = var(name, DataType::Handle()); auto elem_offset = PrimExpr(); - return Buffer(data, dtype, shape, Array(), elem_offset, name, "", -1, 0, kDefault); + return Buffer(data, dtype, shape, Array(), elem_offset, name, -1, 0, kDefault); } /*! diff --git a/include/tvm/topi/detail/ravel_unravel.h b/include/tvm/topi/detail/ravel_unravel.h index dd7bcac09a04..e91d6afb666a 100644 --- a/include/tvm/topi/detail/ravel_unravel.h +++ b/include/tvm/topi/detail/ravel_unravel.h @@ -44,7 +44,9 @@ using namespace tvm::te; */ inline PrimExpr RavelIndex(Array indices, Array shape) { ICHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size"; - ICHECK_GT(indices.size(), 0) << "indices must not be empty"; + if (indices.size() == 0U) { + return 0; + } PrimExpr idx; for (size_t i = 0; i < indices.size(); ++i) { if (i == 0) { diff --git a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java index a7a03d52740e..737fdef24ae8 100644 --- a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java +++ b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java @@ -147,7 +147,7 @@ public NDArray debugGetOutput(String node, NDArray out) { if (fdebugGetOutput != null) { fdebugGetOutput.pushArg(node).pushArg(out).invoke(); } else { - throw new RuntimeException("Please compile runtime with USE_GRAPH_EXECUTOR_DEBUG = 0"); + throw new RuntimeException("Please compile runtime with USE_PROFILER = ON"); } return out; } @@ -162,7 +162,7 @@ public NDArray debugGetOutput(int node, NDArray out) { if (fdebugGetOutput != null) { fdebugGetOutput.pushArg(node).pushArg(out).invoke(); } else { - throw new RuntimeException("Please compile runtime with USE_GRAPH_EXECUTOR_DEBUG = 0"); + throw new RuntimeException("Please compile runtime with USE_PROFILER = ON"); } return out; } diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/ConnectTrackerServerProcessor.java b/jvm/core/src/main/java/org/apache/tvm/rpc/ConnectTrackerServerProcessor.java index 9811ae19afd8..10df50237628 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/ConnectTrackerServerProcessor.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/ConnectTrackerServerProcessor.java @@ -243,8 +243,8 @@ private boolean needRefreshKey() throws IOException { // handcrafted JSON private String generateCinfo(String key) { - String cinfo = "{\"key\" : " + "\"server:" + key + "\", \"addr\": [\"" - + trackerHost + "\", \"" + trackerPort + "\"]}"; + String cinfo = "{\"key\" : " + "\"server:" + key + "\", \"addr\": [null, \"" + + serverPort + "\"]}"; return "[" + RPC.TrackerCode.UPDATE_INFO + ", " + cinfo + "]"; } diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 675f8fe9b5a0..000000000000 --- a/pytest.ini +++ /dev/null @@ -1,26 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -[pytest] -markers = - gpu: mark a test as requiring a gpu - tensorcore: mark a test as requiring a tensorcore - cuda: mark a test as requiring cuda - opencl: mark a test as requiring opencl - rocm: mark a test as requiring rocm - vulkan: mark a test as requiring vulkan - metal: mark a test as requiring metal - llvm: mark a test as requiring llvm diff --git a/python/gen_requirements.py b/python/gen_requirements.py index dc338a3fcd3b..781db2cb872a 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -"""TVM Python requriements.txt generator. +"""TVM Python requirements.txt generator. This script generates a set of requirements.txt files (stored in `./requirements`) that describe TVM's Python dependencies. @@ -75,6 +75,16 @@ ], ), ), + # Provide support for Arm(R) Ethos(TM)-U NPU. + ( + "ethosu", + ( + "Requirements for using Arm(R) Ethos(TM)-U NPU", + [ + "ethos-u-vela", + ], + ), + ), # Relay frontends. ( "importer-caffe2", @@ -205,6 +215,7 @@ "docutils", "<0.17", ), # Work around https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 + ("ethos-u-vela", "==2.1.1"), ("future", None), ("image", None), ("matplotlib", None), @@ -220,7 +231,7 @@ ("sphinx_autodoc_annotation", None), ("sphinx_gallery", None), ("sphinx_rtd_theme", None), - ("synr", ">=0.2.1"), # Requires bugfix commit ee0b12a61c08f01604475f36ff37d4cb110bdc27 + ("synr", "==0.3.0"), ("tensorflow", None), ("tensorflow-estimator", None), ("tflite", None), diff --git a/python/setup.py b/python/setup.py index b47e5b14f6a7..dd13a12d8903 100644 --- a/python/setup.py +++ b/python/setup.py @@ -41,7 +41,7 @@ def get_lib_path(): """Get library path, name and version""" # We can not import `libinfo.py` in setup.py directly since __init__.py - # Will be invoked which introduces dependences + # Will be invoked which introduces dependencies libinfo_py = os.path.join(CURRENT_DIR, "./tvm/_ffi/libinfo.py") libinfo = {"__file__": libinfo_py} exec(compile(open(libinfo_py, "rb").read(), libinfo_py, "exec"), libinfo, libinfo) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index f7a5f39c829a..c212d143f987 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -222,7 +222,7 @@ def rewrite_layout_from_state(self, state): def workload_key(self): """Return the workload key of this compute DAG. - The workload key is a JSON string from a tuple of (hash-key, tensor shapes...) + The workload key is a JSON string from a tuple of (hash of DAG, tensor shapes...) Returns ------- @@ -230,12 +230,19 @@ def workload_key(self): The workload key of this compute DAG """ str_dag = _ffi_api.ComputeDAGPrintDAG(self, True) - str_dag = str_dag.encode(encoding="utf-8") - hash_key = hashlib.md5(str_dag).hexdigest() + hash_func = tvm._ffi.get_global_func( + "auto_scheduler.compute_dag.hash_func", allow_missing=True + ) + + if hash_func is None: + str_dag = str_dag.encode("utf-8") + hash_key = hashlib.md5(str_dag).hexdigest() + else: + hash_key = hash_func(str_dag) io_shapes = [] for tensor in self.tensors: - io_shapes += get_const_tuple(tensor.shape) + io_shapes.append(get_const_tuple(tensor.shape)) return json.dumps([hash_key] + io_shapes) def __str__(self): diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py index c843dcfccdf0..cc1e76b9faa8 100644 --- a/python/tvm/auto_scheduler/dispatcher.py +++ b/python/tvm/auto_scheduler/dispatcher.py @@ -53,7 +53,7 @@ def __init__(self): def query(self, target, workload_key, has_complex_op, dag, func_name): """ Query the context to get the specific config for a workload. - If cannot find the result inside this context, this function will query it + If this function cannot find the result inside this context, it will query the result from the upper contexts. Parameters diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 8d762602bfd1..b746dbf96f43 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -44,6 +44,7 @@ from tvm.ir import transform from tvm.autotvm.measure.measure_methods import set_cuda_target_arch from tvm.contrib import tar, ndk +from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor, StatusKind from tvm.target import Target @@ -374,7 +375,7 @@ class LocalRunner(ProgramRunner): i.e., When the run time of one `repeat` falls below this time, the `number` parameter will be automatically increased. cooldown_interval : float = 0.0 - The cool down interval between two measurements. + The cool down interval between two measurements in seconds. enable_cpu_cache_flush: bool = False Whether to flush cache on CPU between repeated measurements. Flushing cache can make the measured latency of one operator closer to @@ -445,7 +446,7 @@ class RPCRunner(ProgramRunner): i.e., When the run time of one `repeat` falls below this time, the `number` parameter will be automatically increased. cooldown_interval : float = 0.0 - The cool down interval between two measurements. + The cool down interval between two measurements in seconds. enable_cpu_cache_flush: bool = False Whether to flush cache on CPU between repeated measurements. Flushing cache can make the measured latency of one operator closer to @@ -524,7 +525,7 @@ class LocalRPCMeasureContext: i.e., When the run time of one `repeat` falls below this time, the `number` parameter will be automatically increased. cooldown_interval : float = 0.0 - The cool down interval between two measurements. + The cool down interval between two measurements in seconds. enable_cpu_cache_flush: bool = False Whether to flush cache on CPU between repeated measurements. Flushing cache can make the measured latency of one operator closer to @@ -599,7 +600,7 @@ class MeasureErrorNo(object): UNKNOWN_ERROR = 8 # Unknown error -def _timed_func(inp_serialized, build_func, verbose): +def _local_build_worker(inp_serialized, build_func, verbose): tic = time.time() inp = MeasureInput.deserialize(inp_serialized) task = inp.task @@ -658,23 +659,13 @@ def local_build_worker(args): res : BuildResult The build result of this Builder thread. """ - inp, build_func, timeout, verbose = args + inp, build_func, verbose = args assert build_func == BuildFunc.name, ( "BuildFunc.name: " + BuildFunc.name + ", but args is: " + build_func ) build_func = BuildFunc.build_func - res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func, verbose)) - if isinstance(res, TimeoutError): - if verbose >= 1: - print(".T", end="", flush=True) # Build timeout - res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout - elif isinstance(res, Exception): - if verbose >= 1: - print(".E", end="", flush=True) # Build error - res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout - - return res + return _local_build_worker(inp, build_func, verbose) @tvm._ffi.register_func("auto_scheduler.local_builder.build") @@ -701,27 +692,35 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo res : List[BuildResult] The build results of these MeasureInputs. """ - # This pool is not doing computationally intensive work, so we can use threads - pool = multiprocessing.pool.ThreadPool(n_parallel) - tuple_res = pool.map( + executor = PopenPoolExecutor(n_parallel, timeout) + tuple_res = executor.map_with_error_catching( local_build_worker, [ ( i.serialize(), build_func, - timeout, verbose, ) for i in inputs ], ) - pool.terminate() - pool.join() - del pool results = [] for res in tuple_res: - results.append(BuildResult(*res)) + if res.status == StatusKind.COMPLETE: + results.append(BuildResult(*res.value)) + elif res.status == StatusKind.TIMEOUT: + if verbose >= 1: + print(".T", end="", flush=True) # Build timeout + results.append(BuildResult(None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout)) + elif res.status == StatusKind.EXCEPTION: + if verbose >= 1: + print(".E", end="", flush=True) # Build error + results.append( + BuildResult(None, [], MeasureErrorNo.COMPILE_HOST, repr(res.value), timeout) + ) + else: + raise ValueError("Result status is not expected. Unreachable branch") return results @@ -817,9 +816,58 @@ def prepare_input_map(args): return tensor_input_map +def prepare_runner_args(inp, build_res): + """This function prepares the pre-defined arguments in `TASK_INPUT_BUFFER_TABLE` for local/rpc + runner in main process + + Parameters + ---------- + inp : MeasureInput + Measure input to be measured. + + build_res : BuildResult + Build result to be measured. + + Returns + ------- + List[Optional[numpy.ndarray]] : + List of arguments for running the program. If the argument does not have a pre-defined input + buffer, None is added to the list as a placeholder. + + """ + # pylint: disable=import-outside-toplevel + from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency + + task_input_names = inp.task.task_input_names + tensor_input_map = prepare_input_map(build_res.args) + if not task_input_names: + tensor_input_map = {} + args = [] + task_inputs_count = 0 + for arg in build_res.args: + if arg in tensor_input_map: + tensor_name = tensor_input_map[arg] + if tensor_name in task_input_names: + task_input_buffer = get_task_input_buffer(inp.task.workload_key, tensor_name) + # convert tvm.NDArray to picklable numpy.ndarray + args.append(task_input_buffer.numpy()) + task_inputs_count += 1 + else: + raise ValueError( + "%s not found in task_inputs, " % (tensor_name) + + "should provide with `SearchTask(..., task_inputs={...})`" + ) + else: + args.append(None) + if task_inputs_count != len(task_input_names): + raise RuntimeError("task_inputs not fully matched, check if there's any unexpected error") + return args + + def _timed_eval_func( inp_serialized, build_res, + args, number, repeat, min_repeat_ms, @@ -827,11 +875,7 @@ def _timed_eval_func( enable_cpu_cache_flush, verbose, ): - # pylint: disable=import-outside-toplevel - from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency - inp = MeasureInput.deserialize(inp_serialized) - task_input_names = inp.task.task_input_names tic = time.time() error_no = 0 error_msg = None @@ -862,33 +906,18 @@ def _timed_eval_func( try: random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True) assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" - - tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {} - args = [] - task_inputs_count = 0 - for arg in build_res.args: - if arg in tensor_input_map: - tensor_name = tensor_input_map[arg] - if tensor_name in task_input_names: - args.append( - ndarray.array( - get_task_input_buffer(inp.task.workload_key, tensor_name), dev - ) - ) - task_inputs_count += 1 - else: - raise ValueError( - "%s not found in task_inputs, " % (tensor_name) - + "should provide with `SearchTask(..., task_inputs={...})`" - ) - else: - empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev) + assert len(args) == len(build_res.args) + # pylint: disable=consider-using-enumerate + for idx in range(len(args)): + if args[idx] is None: + build_res_arg = build_res.args[idx] + empty_array = ndarray.empty( + get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev + ) random_fill(empty_array) - args.append(empty_array) - if task_inputs_count != len(task_input_names): - raise RuntimeError( - "task_inputs not fully matched, check if there's any unexpected error" - ) + args[idx] = empty_array + else: + args[idx] = ndarray.array(args[idx], dev) dev.sync() costs = time_f(*args).results # pylint: disable=broad-except @@ -950,7 +979,7 @@ def local_run( i.e., When the run time of one `repeat` falls below this time, the `number` parameter will be automatically increased. cooldown_interval : float = 0.0 - The cool down interval between two measurements. + The cool down interval between two measurements in seconds. enable_cpu_cache_flush: bool = False Whether to flush cache on CPU between repeated measurements. Flushing cache can make the measured latency of one operator closer to @@ -968,6 +997,7 @@ def local_run( measure_results = [] assert len(inputs) == len(build_results), "Measure input size should be equal to build results" + worker = PopenWorker() for inp, build_res in zip(inputs, build_results): if build_res.error_no != 0: res = ( @@ -978,12 +1008,15 @@ def local_run( time.time(), ) else: + args = prepare_runner_args(inp, build_res) res = call_func_with_timeout( + worker, timeout, _timed_eval_func, args=( inp.serialize(), build_res, + args, number, repeat, min_repeat_ms, @@ -991,7 +1024,6 @@ def local_run( enable_cpu_cache_flush, verbose, ), - add_thread_wrapper=True, ) if isinstance(res, TimeoutError): if verbose >= 1: @@ -1022,9 +1054,10 @@ def local_run( return measure_results -def _timed_rpc_run( +def _rpc_run( inp_serialized, build_res, + args, key, host, port, @@ -1037,11 +1070,7 @@ def _timed_rpc_run( enable_cpu_cache_flush, verbose, ): - # pylint: disable=import-outside-toplevel - from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency - inp = MeasureInput.deserialize(inp_serialized) - task_input_names = inp.task.task_input_names tic = time.time() error_no = 0 error_msg = None @@ -1080,32 +1109,18 @@ def _timed_rpc_run( random_fill ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices" - tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {} - args = [] - task_inputs_count = 0 - for arg in build_res.args: - if arg in tensor_input_map: - tensor_name = tensor_input_map[arg] - if tensor_name in task_input_names: - args.append( - ndarray.array( - get_task_input_buffer(inp.task.workload_key, tensor_name), dev - ) - ) - task_inputs_count += 1 - else: - raise ValueError( - "%s not found in task_inputs, " % (tensor_name) - + "should provide with `SearchTask(..., task_inputs={...})`" - ) - else: - empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev) + assert len(args) == len(build_res.args) + # pylint: disable=consider-using-enumerate + for idx in range(len(args)): + if args[idx] is None: + build_res_arg = build_res.args[idx] + empty_array = ndarray.empty( + get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev + ) random_fill(empty_array) - args.append(empty_array) - if task_inputs_count != len(task_input_names): - logger.warning( - "task_inputs not fully matched, check if there's any unexpected error" - ) + args[idx] = empty_array + else: + args[idx] = ndarray.array(args[idx], dev) dev.sync() # First run for check that the kernel is correct @@ -1152,7 +1167,7 @@ def _rpc_run_worker(args): res : MeasureResult The measure result of this Runner thread. """ - _, build_res, _, _, _, _, timeout, _, _, _, _, _, verbose = args + _, build_res, _, _, _, _, _, timeout, _, _, _, _, _, verbose = args if build_res.error_no != MeasureErrorNo.NO_ERROR: return ( (MAX_FLOAT,), @@ -1162,24 +1177,16 @@ def _rpc_run_worker(args): time.time(), ) - res = call_func_with_timeout(timeout, _timed_rpc_run, args=args) - if isinstance(res, TimeoutError): - if verbose >= 1: - print("*T", end="") # Run timeout - res = ( - (MAX_FLOAT,), - MeasureErrorNo.RUN_TIMEOUT, - None, - build_res.time_cost + timeout, - time.time(), - ) - elif isinstance(res, Exception): + try: + res = _rpc_run(*args) + # pylint: disable=broad-except + except Exception: if verbose >= 1: print("*E", end="") # Run error res = ( (MAX_FLOAT,), MeasureErrorNo.RUNTIME_DEVICE, - str(res), + make_traceback_info(), build_res.time_cost + timeout, time.time(), ) @@ -1242,7 +1249,7 @@ def rpc_runner_run( i.e., When the run time of one `repeat` falls below this time, the `number` parameter will be automatically increased. cooldown_interval : float = 0.0 - The cool down interval between two measurements. + The cool down interval between two measurements in seconds. enable_cpu_cache_flush: bool = False Whether to flush cache on CPU between repeated measurements. Flushing cache can make the measured latency of one operator closer to @@ -1259,13 +1266,14 @@ def rpc_runner_run( """ assert len(inputs) == len(build_results), "Measure input size should be equal to build results" # This pool is not doing computationally intensive work, so we can use threads - pool = multiprocessing.pool.ThreadPool(n_parallel) - tuple_res = pool.map( + executor = PopenPoolExecutor(n_parallel) + tuple_res = executor.map_with_error_catching( _rpc_run_worker, [ ( inp.serialize(), build_res, + prepare_runner_args(inp, build_res), key, host, port, @@ -1281,13 +1289,25 @@ def rpc_runner_run( for inp, build_res in zip(inputs, build_results) ], ) - pool.terminate() - pool.join() - del pool results = [] - for res in tuple_res: - results.append(MeasureResult(*res)) + for i, res in enumerate(tuple_res): + if res.status == StatusKind.COMPLETE: + results.append(MeasureResult(*res.value)) + else: + assert res.status == StatusKind.TIMEOUT + if verbose >= 1: + print("*T", end="") # Run timeout + build_res = build_results[i] + results.append( + MeasureResult( + (MAX_FLOAT,), + MeasureErrorNo.RUN_TIMEOUT, + None, + build_res.time_cost + timeout, + time.time(), + ) + ) if verbose >= 1: print("") diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 4b402a916267..8b68f4e9002a 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -22,6 +22,7 @@ 2. Provide auto-scheduling for all TOPI compute functions """ +import json import logging import threading from copy import deepcopy @@ -30,11 +31,10 @@ from tvm import autotvm, transform from tvm.ir.transform import PassContext from tvm.runtime import convert_to_object - +from tvm.target import Target from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor from tvm.tir import Reduce from tvm.tir import expr as _expr -from tvm.target import Target from . import _ffi_api from .compute_dag import ComputeDAG, LayoutRewriteOption @@ -97,6 +97,7 @@ def extract_tasks( target_host=None, hardware_params=None, include_simple_tasks=False, + dump_workload_to_dag_log=None, opt_level=3, ): """Extract tuning tasks from a relay program. @@ -115,6 +116,8 @@ def extract_tasks( Hardware parameters used for the search tasks include_simple_tasks: bool Whether to extract simple tasks that do not include complicated ops. + dump_workload_to_dag_log: Optional[str] + A file to dump an association between the workload keys and the actual DAG opt_level : Optional[int] The optimization level of the task extractions. @@ -150,7 +153,7 @@ def extract_tasks( # create search tasks tasks = [] weights = [] - for (func_name, wkl_key), weight in env.wkl_key_to_weight.items(): + for wkl_key, (weight, func_names) in env.wkl_key_to_weight.items(): tasks.append( SearchTask( workload_key=wkl_key, @@ -165,11 +168,15 @@ def extract_tasks( else None ), task_inputs_save_to_file=True, - desc=func_name, + desc=",".join(func_names), ) ) weights.append(weight) + if dump_workload_to_dag_log is not None: + with open(dump_workload_to_dag_log, "w") as f: + json.dump({task.workload_key: str(task.compute_dag) for task in tasks}, f) + return tasks, weights @@ -189,6 +196,7 @@ class TracingEnvironment: def __init__(self, tracing_mode): self.tracing_mode = tracing_mode self.relay_disable_build_cache = "false" + self.func_name_to_wkl_key = {} self.wkl_key_to_weight = {} self.wkl_key_to_input_names = {} @@ -210,10 +218,12 @@ def add_workload_key(self, func_name, workload_key): workload_key: str The workload key of a task. """ - key = (func_name, workload_key) - if key not in self.wkl_key_to_weight: - self.wkl_key_to_weight[key] = 0 - self.wkl_key_to_weight[key] += 1 + self.func_name_to_wkl_key[func_name] = workload_key + if workload_key not in self.wkl_key_to_weight: + self.wkl_key_to_weight[workload_key] = (0, set()) + weight, func_names = self.wkl_key_to_weight[workload_key] + func_names.add(func_name) + self.wkl_key_to_weight[workload_key] = (weight + 1, func_names) def add_workload_input_names(self, workload_key, input_names): """Add special task inputs to this workload. @@ -379,11 +389,37 @@ def auto_schedule_topi(func_name, outs): @tvm._ffi.register_func("auto_scheduler.relay_integration.te_compiler_update_weights") def te_compiler_update_weights(function_weights): - """A callback for updating the weights of extracted tasks.""" + """A callback for updating the weights of extracted tasks. When using the TE compiler + that avoids compiling the same function multiple times by caching, all extracted tasks + have weight 1, so the TE compiler invokes this callback at the end. In this case, + we override existing weights with the use_count in TE compiler cache. + + Parameters + ---------- + function_weights: Dict[str, int] + Mapping from function names to their weights. + """ env = TracingEnvironment.current if env is not None: - for key in env.wkl_key_to_weight: - env.wkl_key_to_weight[key] = function_weights[key[0]] + # Override this map with the weights in the TE compiler. + env.wkl_key_to_weight = {} + + for func_name, weight in function_weights.items(): + # If the function name is not in the map, then it means we are not interested in + # this function during task extraction (e.g., a function without reduction). + if func_name not in env.func_name_to_wkl_key: + continue + + workload_key = env.func_name_to_wkl_key[func_name] + if workload_key not in env.wkl_key_to_weight: + env.wkl_key_to_weight[workload_key] = (0, set()) + + # Note that the function appears multiple times in a model will be renamed + # to make sure function names are unique, so we use the workload key generated + # from the function's TE compute to determine their weights. + old_weight, func_names = env.wkl_key_to_weight[workload_key] + func_names.add(func_name) + env.wkl_key_to_weight[workload_key] = (old_weight + weight, func_names) def tensor_no_check_call(self, *indices): diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index 023fdc770a30..9b975063105f 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -629,7 +629,7 @@ def __init__(self, log_file): def post_tune(self, task_scheduler, task_id): if all(cost < 1e9 for cost in task_scheduler.best_costs): - total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3) + total_latency_str = "%.3f" % (task_scheduler.cur_score.value * 1e3) else: total_latency_str = "N/A" diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index 1c03491c5614..9919bcb470ee 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -20,9 +20,6 @@ from typing import Hashable import json -import multiprocessing -import multiprocessing.pool -import queue import signal import threading import traceback @@ -289,41 +286,15 @@ def wrapper(): return res[0] -def _func_wrapper(que, func, args, kwargs, add_thread_wrapper): - """Call function and return the result over the queue.""" - try: - if add_thread_wrapper: - # Add a new layer of threadinng to avoid the conflict between - # python's multiprocessing and tvm's thread pool. - res = call_func_with_thread(func, args, kwargs) - else: - res = func(*args, **kwargs) - que.put(res) - except Exception: # pylint: disable=broad-except - que.put(Exception(make_traceback_info())) - - -def call_func_with_timeout(timeout, func, args=(), kwargs=None, add_thread_wrapper=False): +def call_func_with_timeout( + worker, timeout, func, args=(), kwargs=None +): # pylint: disable=unused-argument """Call a function with timeout""" - que = multiprocessing.Queue(2) - process = multiprocessing.Process( - target=_func_wrapper, args=(que, func, args, kwargs or {}, add_thread_wrapper) - ) - process.start() - + worker.send(func, args, kwargs, timeout) try: - res = que.get(timeout=timeout) - except queue.Empty: - res = TimeoutError() - - # clean queue and process - kill_child_processes(process.pid) - process.terminate() - process.join() - que.close() - que.join_thread() - del process - del que + res = worker.recv() + except Exception: # pylint: disable=broad-except + res = Exception(make_traceback_info()) return res diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index cd8f8c9d1a3e..885eb0d1d0f8 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -245,7 +245,9 @@ def deserialize_workload_registry_entry(data): name, value = data if name not in WORKLOAD_FUNC_REGISTRY: # pylint: disable=assignment-from-no-return - WORKLOAD_FUNC_REGISTRY[name] = LoadJSON(value) + if not callable(value): + value = LoadJSON(value) + WORKLOAD_FUNC_REGISTRY[name] = value def save_workload_func_registry(filename): diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index 780e6c9a7477..beb1aa03090d 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -375,6 +375,7 @@ def benchmark_layout_transform( layout_records=None, target_host=None, infer_layout=False, + runner=None, ): """Benchmark all possible layout transformation in the graph, given a set of schedule candidates for each workload of target operator. @@ -438,6 +439,8 @@ def benchmark_layout_transform( of benchmarking on target device. This might bring performance loss comparing to benchmarking layout transformation. + runner : Runner, optional + Accept a user-supplied runner """ self._logger.info("Start to benchmark layout transformation...") self._target, target_host = Target.check_and_update_host_consist(self._target, target_host) @@ -483,7 +486,6 @@ def _callback(_, inputs, results): return _callback builder = autotvm.LocalBuilder(n_parallel=n_parallel, build_func=build_func) - runner = autotvm.LocalRunner(number=min_exec_num, repeat=1, timeout=timeout) if use_rpc: if device_key is None: raise RuntimeError("device_key need to be set to use rpc tracker mode.") @@ -496,6 +498,8 @@ def _callback(_, inputs, results): repeat=1, timeout=timeout, ) + elif not runner: + runner = autotvm.LocalRunner(number=min_exec_num, repeat=1, timeout=timeout) measure_option = autotvm.measure_option(builder=builder, runner=runner) for args in args_list: data, in_layout, out_layout = args diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 3de25cb6100b..eab6822b63b8 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -32,6 +32,7 @@ import typing from collections import namedtuple from random import getrandbits +import warnings import tvm._ffi import tvm.ir.transform @@ -235,6 +236,7 @@ def __init__( self.number = number self.repeat = repeat self.min_repeat_ms = min_repeat_ms + self._ref_input = None self.enable_cpu_cache_flush = enable_cpu_cache_flush self.cooldown_interval = cooldown_interval @@ -242,6 +244,26 @@ def __init__( self.executor = LocalExecutor(timeout=timeout * (self.n_parallel + 1)) + @property + def ref_input(self): + """ + Fixed input for tuning special operators, e.g., sparse operators + requiring indices as input. + """ + return self._ref_input + + @ref_input.setter + def ref_input(self, val): + if val is not None: + warnings.warn( + "You are specifying fixed input for tuning the operator. " + "Be sure your input always fits the operator. Some " + "operators may conduct layout transformation during tuning, " + "thus can lead to unexpected behaviors. ", + RuntimeWarning, + ) + self._ref_input = val + def set_task(self, task): self.task = task @@ -308,6 +330,7 @@ def run(self, measure_inputs, build_results): self.min_repeat_ms, self.cooldown_interval, remote_kwargs, + self.ref_input, self.enable_cpu_cache_flush, module_loader, ) @@ -508,6 +531,7 @@ def run_through_rpc( min_repeat_ms, cooldown_interval, remote_kwargs, + ref_input, enable_cpu_cache_flush=False, module_loader=None, ): @@ -539,6 +563,8 @@ def run_through_rpc( The cool down interval between two measurements remote_kwargs: dict Passed to module_loader(). Ultimately, keyword args to request_remote(). + ref_input: List of np.ndarray + The reference input used for tuning. Empty for randomly filled input. enable_cpu_cache_flush: bool Whether to flush cache on CPU between repeated measurements. Flushing cache can make the measured latency of one operator closer to @@ -573,18 +599,22 @@ def run_through_rpc( f_preproc=f_prepare, ) - try: - random_fill = remote.get_function("tvm.contrib.random.random_fill") - except AttributeError: - raise AttributeError( - "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices" - ) - args = [nd.empty(x[0], x[1], dev) for x in build_result.arg_info] - if "scatter" not in measure_input.task.name: - # the index tensor of scatter op cannot be randomly initialized - for arg in args: - random_fill(arg) - dev.sync() + if ref_input: + args = [nd.array(x, device=dev) for x in ref_input] + else: + try: + random_fill = remote.get_function("tvm.contrib.random.random_fill") + except AttributeError: + raise AttributeError( + "Please make sure USE_RANDOM is ON in the config.cmake " + "on the remote devices" + ) + args = [nd.empty(x[0], x[1], dev) for x in build_result.arg_info] + if "scatter" not in measure_input.task.name: + # the index tensor of scatter op cannot be randomly initialized + for arg in args: + random_fill(arg) + dev.sync() costs = time_f(*args).results diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py index 4f11aea2911f..8145563f5075 100644 --- a/python/tvm/autotvm/record.py +++ b/python/tvm/autotvm/record.py @@ -21,7 +21,6 @@ import argparse import base64 import logging -import multiprocessing import pickle import json import time @@ -32,6 +31,7 @@ from .. import build, lower from ..target import Target +from ..contrib import popen_pool from .. import __version__ from . import task from .task import ConfigEntity, ApplyHistoryBest @@ -230,7 +230,7 @@ def split_workload(in_file, clean=True): lines = list(open(in_file).readlines()) logger.info("start converting...") - pool = multiprocessing.Pool() + pool = popen_pool.PopenPoolExecutor() lines = [rec for rec in pool.map(decode, lines) if rec is not None] logger.info("map done %.2f", time.time() - tic) diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py index afbfb4c03988..8a707b872113 100644 --- a/python/tvm/autotvm/task/space.py +++ b/python/tvm/autotvm/task/space.py @@ -824,7 +824,10 @@ def valid(self): def _add_new_transform(self, space_class, name, axes, policy, **kwargs): """Add a new transform space in template""" - if self._collect: + # if we do not have tuned info (_collect == True) but defined KNOB value + # for "default" scheduling before call of _add_new_transform, in this case + # no need to create new space and override previously pointed KNOB values + if self._collect and not (self.is_fallback and name in self._entity_map): # convert schedule axis to space definition axis axes = [x if isinstance(x, (VirtualAxis, Axis)) else self.axis(x) for x in axes] diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index 5c0a46336532..f438bc197afe 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -178,7 +178,7 @@ def download_package(tophub_location, package_name): download_url = "{0}/{1}".format(tophub_location, package_name) logger.info("Download pre-tuned parameters package from %s", download_url) - download(download_url, Path(rootpath, package_name), True, verbose=0) + download(download_url, Path(rootpath, package_name), overwrite=True) # global cache for load_reference_log diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index 81904354c5fd..99972ee3d74e 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -17,12 +17,13 @@ # pylint: disable=invalid-name """XGBoost as cost model""" -import multiprocessing import logging import time import numpy as np +from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind + from .. import feature from ..utils import get_rank from .metric import max_curve, recall_curve, cover_curve @@ -153,20 +154,14 @@ def _reset_pool(self, space, target, task): self._close_pool() - # Use global variable to pass common arguments. This is only used when - # new processes are started with fork. We have to set the globals - # before we create the pool, so that processes in the pool get the - # correct globals. - global _extract_space, _extract_target, _extract_task - _extract_space = space - _extract_target = target - _extract_task = task - self.pool = multiprocessing.Pool(self.num_threads) + self.pool = PopenPoolExecutor( + max_workers=self.num_threads, + initializer=_extract_popen_initializer, + initargs=(space, target, task), + ) def _close_pool(self): if self.pool: - self.pool.terminate() - self.pool.join() self.pool = None def _get_pool(self): @@ -247,13 +242,16 @@ def fit_log(self, records, plan_size, min_seed_records=500): feature_extract_func = _extract_curve_feature_log else: raise RuntimeError("Invalid feature type: " + self.fea_type) - res = pool.map(feature_extract_func, data) + result = pool.map_with_error_catching(feature_extract_func, data) # filter out feature with different shapes fea_len = len(self._get_feature([0])[0]) xs, ys = [], [] - for x, y in res: + for res in result: + if res.status != StatusKind.COMPLETE: + continue + x, y = res.value if len(x) == fea_len: xs.append(x) ys.append(y) @@ -327,14 +325,9 @@ def _get_feature(self, indexes): if need_extract: pool = self._get_pool() - # If we are forking, we can pass arguments in globals for better performance - if multiprocessing.get_start_method(False) == "fork": - feas = pool.map(self.feature_extract_func, need_extract) - else: - args = [(self.space.get(x), self.target, self.task) for x in need_extract] - feas = pool.map(self.feature_extract_func, args) + feas = pool.map_with_error_catching(self.feature_extract_func, need_extract) for i, fea in zip(need_extract, feas): - fea_cache[i] = fea + fea_cache[i] = fea.value if fea.status == StatusKind.COMPLETE else None feature_len = None for idx in indexes: @@ -358,17 +351,20 @@ def __del__(self): _extract_task = None +def _extract_popen_initializer(space, target, task): + global _extract_space, _extract_target, _extract_task + _extract_space = space + _extract_target = target + _extract_task = task + + def _extract_itervar_feature_index(args): """extract iteration var feature for an index in extract_space""" try: - if multiprocessing.get_start_method(False) == "fork": - config = _extract_space.get(args) - with _extract_target: - sch, fargs = _extract_task.instantiate(config) - else: - config, target, task = args - with target: - sch, fargs = task.instantiate(config) + config = _extract_space.get(args) + with _extract_target: + sch, fargs = _extract_task.instantiate(config) + fea = feature.get_itervar_feature_flatten(sch, fargs, take_log=True) fea = np.concatenate((fea, list(config.get_other_option().values()))) return fea @@ -398,10 +394,9 @@ def _extract_itervar_feature_log(arg): def _extract_knob_feature_index(args): """extract knob feature for an index in extract_space""" try: - if multiprocessing.get_start_method(False) == "fork": - config = _extract_space.get(args) - else: - config = args[0] + + config = _extract_space.get(args) + return config.get_flatten_feature() except Exception: # pylint: disable=broad-except return None @@ -428,14 +423,11 @@ def _extract_knob_feature_log(arg): def _extract_curve_feature_index(args): """extract sampled curve feature for an index in extract_space""" try: - if multiprocessing.get_start_method(False) == "fork": - config = _extract_space.get(args) - with _extract_target: - sch, fargs = _extract_task.instantiate(config) - else: - config, target, task = args - with target: - sch, fargs = task.instantiate(config) + + config = _extract_space.get(args) + with _extract_target: + sch, fargs = _extract_task.instantiate(config) + fea = feature.get_buffer_curve_sample_flatten(sch, fargs, sample_n=20) fea = np.concatenate((fea, list(config.get_other_option().values()))) return np.array(fea) diff --git a/python/tvm/autotvm/utils.py b/python/tvm/autotvm/utils.py index fa1dcfd1241b..ec3f18daa6c9 100644 --- a/python/tvm/autotvm/utils.py +++ b/python/tvm/autotvm/utils.py @@ -17,7 +17,6 @@ # pylint: disable=invalid-name """Utilities""" import logging -import multiprocessing import time from random import randrange @@ -25,6 +24,7 @@ import numpy as np import tvm.arith from tvm.tir import expr +from tvm.contrib.popen_pool import PopenPoolExecutor logger = logging.getLogger("autotvm") @@ -111,7 +111,7 @@ def pool_map(func, args, batch_size, verbose=False, pool=None): ret = None tic = time.time() - local_pool = pool or multiprocessing.Pool() + local_pool = pool or PopenPoolExecutor() if verbose: logger.info("mapping begin") for i in range(0, len(args), batch_size): diff --git a/python/tvm/contrib/debugger/debug_executor.py b/python/tvm/contrib/debugger/debug_executor.py index dc043353c475..fc3b245d88ad 100644 --- a/python/tvm/contrib/debugger/debug_executor.py +++ b/python/tvm/contrib/debugger/debug_executor.py @@ -64,8 +64,7 @@ def create(graph_json_str, libmod, device, dump_root=None): fcreate = tvm._ffi.get_global_func("tvm.graph_executor_debug.create") except ValueError: raise ValueError( - "Please set '(USE_GRAPH_EXECUTOR_DEBUG ON)' in " - "config.cmake and rebuild TVM to enable debug mode" + "Please set '(USE_PROFILER ON)' in " "config.cmake and rebuild TVM to enable debug mode" ) func_obj = fcreate(graph_json_str, libmod, *device_type_id) return GraphModuleDebug(func_obj, dev, graph_json_str, dump_root) @@ -268,23 +267,28 @@ def run_individual(self, number, repeat=1, min_repeat_ms=0): ret = self._run_individual(number, repeat, min_repeat_ms) return ret.strip(",").split(",") if ret else [] - def profile(self, **input_dict): + def profile(self, collectors=None, **input_dict): """Run forward execution of the graph and collect overall and per-op performance metrics. Parameters ---------- + collectors : Optional[Sequence[MetricCollector]] + Extra metrics to collect. + input_dict : dict of str to NDArray List of input values to be feed to + Return ------ timing_results : str Per-operator and whole graph timing results in a table format. """ + collectors = [] if collectors is None else collectors if input_dict: self.set_input(**input_dict) - return self._profile() + return self._profile(collectors) def exit(self): """Exits the dump folder and all its contents""" diff --git a/python/tvm/contrib/download.py b/python/tvm/contrib/download.py index f7c68a99229e..e0c13acc8de9 100644 --- a/python/tvm/contrib/download.py +++ b/python/tvm/contrib/download.py @@ -15,15 +15,18 @@ # specific language governing permissions and limitations # under the License. """Helper utility for downloading""" + +import logging +import os from pathlib import Path -from os import environ -import sys -import time -import uuid import shutil +import tempfile +import time + +LOG = logging.getLogger("download") -def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=3): +def download(url, path, overwrite=False, size_compare=False, retries=3): """Downloads the file from the internet. Set the input options correctly to overwrite or do the size comparison @@ -33,19 +36,18 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1, retries= Download url. path : str - Local file path to save downloaded file + Local file path to save downloaded file. overwrite : bool, optional - Whether to overwrite existing file + Whether to overwrite existing file, defaults to False. size_compare : bool, optional - Whether to do size compare to check downloaded file. - - verbose: int, optional - Verbose level + Whether to do size compare to check downloaded file, defaults + to False retries: int, optional - Number of time to retry download, default at 3. + Number of time to retry download, defaults to 3. + """ # pylint: disable=import-outside-toplevel import urllib.request as urllib2 @@ -62,21 +64,19 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1, retries= res_get = urllib2.urlopen(url) url_file_size = int(res_get.headers["Content-Length"]) if url_file_size != file_size: - print("exist file got corrupted, downloading %s file freshly..." % path) - download(url, path, True, False) + LOG.warning("Existing file %s has incorrect size, downloading fresh copy", path) + download(url, path, overwrite=True, size_compare=False, retries=retries) return - print("File {} exists, skip.".format(path)) + + LOG.info("File %s exists, skipping.", path) return - if verbose >= 1: - print("Downloading from url {} to {}".format(url, path)) + LOG.info("Downloading from url %s to %s", url, path) # Stateful start time start_time = time.time() dirpath = path.parent dirpath.mkdir(parents=True, exist_ok=True) - random_uuid = str(uuid.uuid4()) - tempfile = Path(dirpath, random_uuid) def _download_progress(count, block_size, total_size): # pylint: disable=unused-argument @@ -84,44 +84,53 @@ def _download_progress(count, block_size, total_size): if count == 0: return duration = time.time() - start_time - progress_size = int(count * block_size) - speed = int(progress_size / (1024 * duration)) + progress_bytes = int(count * block_size) + progress_megabytes = progress_bytes / (1024.0 * 1024) + speed_kbps = int(progress_bytes / (1024 * duration)) percent = min(int(count * block_size * 100 / total_size), 100) - sys.stdout.write( - "\r...%d%%, %.2f MB, %d KB/s, %d seconds passed" - % (percent, progress_size / (1024.0 * 1024), speed, duration) + + # Temporarily suppress newlines on the output stream. + prev_terminator = logging.StreamHandler.terminator + logging.StreamHandler.terminator = "" + LOG.debug( + "\r...%d%%, %.2f MB, %d KB/s, %d seconds passed", + percent, + progress_megabytes, + speed_kbps, + duration, ) - sys.stdout.flush() - - while retries >= 0: - # Disable pyling too broad Exception - # pylint: disable=W0703 - try: - if sys.version_info >= (3,): - urllib2.urlretrieve(url, tempfile, reporthook=_download_progress) - print("") - else: - f = urllib2.urlopen(url) - data = f.read() - with open(tempfile, "wb") as code: - code.write(data) - shutil.move(tempfile, path) - break - except Exception as err: - retries -= 1 - if retries == 0: - if tempfile.exists(): - tempfile.unlink() - raise err - print( - "download failed due to {}, retrying, {} attempt{} left".format( - repr(err), retries, "s" if retries > 1 else "" + logging.StreamHandler.terminator = prev_terminator + + with tempfile.TemporaryDirectory() as tempdir: + tempdir = Path(tempdir) + download_loc = tempdir.joinpath(path.name) + + for i_retry in range(retries): + # pylint: disable=broad-except + try: + + urllib2.urlretrieve(url, download_loc, reporthook=_download_progress) + LOG.debug("") + try: + download_loc.rename(path) + except OSError: + # Prefer a move, but if the tempdir and final + # location are in different drives, fall back to a + # copy. + shutil.copy2(download_loc, path) + return + + except Exception as err: + if i_retry == retries - 1: + raise err + + LOG.warning( + "%s\nDownload attempt %d/%d failed, retrying.", repr(err), i_retry, retries ) - ) -if "TEST_DATA_ROOT_PATH" in environ: - TEST_DATA_ROOT_PATH = Path(environ.get("TEST_DATA_ROOT_PATH")) +if "TEST_DATA_ROOT_PATH" in os.environ: + TEST_DATA_ROOT_PATH = Path(os.environ.get("TEST_DATA_ROOT_PATH")) else: TEST_DATA_ROOT_PATH = Path(Path("~").expanduser(), ".tvm_test_data") TEST_DATA_ROOT_PATH.mkdir(parents=True, exist_ok=True) @@ -141,10 +150,16 @@ def download_testdata(url, relpath, module=None, overwrite=False): module : Union[str, list, tuple], optional Subdirectory paths under test data folder. + overwrite : bool, defaults to False + If True, will download a fresh copy of the file regardless of + the cache. If False, will only download the file if a cached + version is missing. + Returns ------- abspath : str Absolute file path of downloaded file + """ global TEST_DATA_ROOT_PATH if module is None: diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index a4bc85905f5e..f064f8dbee69 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -157,6 +157,7 @@ def __init__(self, module): self._get_output = module["get_output"] self._get_input = module["get_input"] self._get_num_outputs = module["get_num_outputs"] + self._get_input_index = module["get_input_index"] self._get_num_inputs = module["get_num_inputs"] self._load_params = module["load_params"] self._share_params = module["share_params"] @@ -242,6 +243,21 @@ def get_input(self, index, out=None): return self._get_input(index) + def get_input_index(self, name): + """Get inputs index via input name. + + Parameters + ---------- + name : str + The input key name + + Returns + ------- + index: int + The input index. -1 will be returned if the given input name is not found. + """ + return self._get_input_index(name) + def get_output(self, index, out=None): """Get index-th output to out @@ -304,3 +320,90 @@ def __getitem__(self, key): The key to the module. """ return self.module[key] + + def benchmark( + self, + device, + func_name="run", + repeat=5, + number=5, + min_repeat_ms=None, + end_to_end=False, + **kwargs, + ): + """Calculate runtime of a function by repeatedly calling it. + + Use this function to get an accurate measurement of the runtime of a function. The function + is run multiple times in order to account for variability in measurements, processor speed + or other external factors. Mean, median, standard deviation, min and max runtime are all + reported. On GPUs, CUDA and ROCm specifically, special on-device timers are used so that + synchonization and data transfer operations are not counted towards the runtime. This allows + for fair comparison of runtimes across different functions and models. The `end_to_end` flag + switches this behavior to include data transfer operations in the runtime. + + The benchmarking loop looks approximately like so: + + .. code-block:: python + + for r in range(repeat): + time_start = now() + for n in range(number): + func_name() + time_end = now() + total_times.append((time_end - time_start)/number) + + + Parameters + ---------- + func_name : str + The function to benchmark. This is ignored if `end_to_end` is true. + + repeat : int + Number of times to run the outer loop of the timing code (see above). The output will + contain `repeat` number of datapoints. + + number : int + Number of times to run the inner loop of the timing code. This inner loop is run in + between the timer starting and stopping. In order to amortize any timing overhead, + `number` should be increased when the runtime of the function is small (less than a 1/10 + of a millisecond). + + min_repeat_ms : Optional[float] + If set, the inner loop will be run until it takes longer than `min_repeat_ms` + milliseconds. This can be used to ensure that the function is run enough to get an + accurate measurement. + + end_to_end : bool + If set, include time to transfer input tensors to the device and time to transfer + returned tensors in the total runtime. This will give accurate timings for end to end + workloads. + + kwargs : Dict[str, Object] + Named arguments to the function. These are cached before running timing code, so that + data transfer costs are not counted in the runtime. + + Returns + ------- + timing_results : BenchmarkResult + Runtimes of the function. Use `.mean` to access the mean runtime, use `.results` to + access the individual runtimes (in seconds). + """ + min_repeat_ms = 0 if min_repeat_ms is None else min_repeat_ms + if end_to_end: + # Have to unpack kwargs into a single list + args = [] + for k, v in kwargs.items(): + args.append(k) + args.append(v) + return self.module.time_evaluator( + "run_from_inputs", + device, + repeat=repeat, + number=number, + min_repeat_ms=min_repeat_ms, + )(device.device_type, device.device_id, *args) + if kwargs: + self.set_input(**kwargs) + return self.module.time_evaluator( + func_name, device, repeat=repeat, number=number, min_repeat_ms=min_repeat_ms + )() diff --git a/python/tvm/contrib/hexagon.py b/python/tvm/contrib/hexagon.py index 34b37537776f..6364ef749dd9 100644 --- a/python/tvm/contrib/hexagon.py +++ b/python/tvm/contrib/hexagon.py @@ -176,23 +176,26 @@ def buf_align(var): def visit(stmt): """Collect information about VTCM buffers and their alignments.""" if isinstance(stmt, tvm.tir.AttrStmt): - if stmt.attr_key == "storage_scope" and stmt.value == "local.vtcm": - vtcm_buffers.append(stmt.node) - elif stmt.attr_key == "storage_alignment": + if stmt.attr_key == "storage_alignment": if not stmt.node in alignments: alignments[stmt.node] = [] alignments[stmt.node].append(stmt.value) + elif isinstance(stmt, tvm.tir.Allocate): + scope = stmt.buffer_var.type_annotation.storage_scope + if scope == "local.vtcm": + vtcm_buffers.append(stmt.buffer_var) def mutate(stmt): """Insert calls to VTCM allocation and deallocation routines.""" if isinstance(stmt, tvm.tir.AttrStmt): - if stmt.attr_key == "storage_scope" and stmt.value == "local.vtcm": - vtcm_buffers.pop() - elif stmt.attr_key == "storage_alignment": + if stmt.attr_key == "storage_alignment": alignments[stmt.node].pop() return stmt if isinstance(stmt, tvm.tir.Allocate): var = stmt.buffer_var + scope = var.type_annotation.storage_scope + if scope == "local.vtcm": + vtcm_buffers.pop() if var in vtcm_buffers: is_null = tvm.tir.call_intrin("bool", tvm.ir.Op.get("tir.isnullptr"), var) throw_error = tvm.tir.call_intrin( diff --git a/python/tvm/contrib/miopen.py b/python/tvm/contrib/miopen.py index 112fc320973b..0e336c1c82b9 100644 --- a/python/tvm/contrib/miopen.py +++ b/python/tvm/contrib/miopen.py @@ -136,3 +136,55 @@ def conv2d_forward( ), name="y", ) + + +def softmax(x, axis=-1): + """Compute softmax with MIOpen + + Parameters + ---------- + x : tvm.te.Tensor + The input tensor + + axis : int + The axis to compute softmax over + + Returns + ------- + ret : tvm.te.Tensor + The result tensor + """ + return te.extern( + x.shape, + [x], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.miopen.softmax.forward", ins[0], outs[0], axis + ), + name="y", + ) + + +def log_softmax(x, axis=-1): + """Compute log softmax with MIOpen + + Parameters + ---------- + x : tvm.te.Tensor + The input tensor + + axis : int + The axis to compute log softmax over + + Returns + ------- + ret : tvm.te.Tensor + The result tensor + """ + return te.extern( + x.shape, + [x], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.miopen.log_softmax.forward", ins[0], outs[0], axis + ), + name="y", + ) diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py index 2f552034e9f8..907231c1a9fa 100644 --- a/python/tvm/contrib/popen_pool.py +++ b/python/tvm/contrib/popen_pool.py @@ -84,10 +84,22 @@ class PopenWorker: PopenWorker provides a low-level API to interact with a separate process via Popen. + + Parameters + ---------- + initializer: callable or None + A callable initializer, or None + + initargs: Tuple[object] + A tuple of args for the initializer """ - def __init__(self): + def __init__(self, initializer=None, initargs=()): self._proc = None + self._initializer = initializer + self._initargs = initargs + if self._initializer is not None and not callable(self._initializer): + raise TypeError("initializer must be callable for PopenWorker") def __del__(self): try: @@ -203,6 +215,10 @@ def send(self, fn, args=(), kwargs=None, timeout=None): if self._proc is None: self._start() + # init + if self._initializer is not None: + self.send(self._initializer, self._initargs) + self.recv() kwargs = {} if not kwargs else kwargs data = cloudpickle.dumps((fn, args, kwargs, timeout), protocol=pickle.HIGHEST_PROTOCOL) try: @@ -269,14 +285,33 @@ class PopenPoolExecutor: timeout : float Timeout value for each function submit. + + initializer: callable or None + A callable initializer, or None + + initargs: Tuple[object] + A tuple of args for the initializer + + Note + ---- + If max_workers is NONE then the number returned by + os.cpu_count() is used. This method aligns with the + behavior of multiprocessing.pool(). """ - def __init__(self, max_workers, timeout=None): + def __init__(self, max_workers=None, timeout=None, initializer=None, initargs=()): + if max_workers is None: + max_workers = os.cpu_count() # Use an internal thread pool to send to popen workers self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) self._timeout = timeout self._worker_map = {} self._lock = threading.Lock() + self._initializer = initializer + self._initargs = initargs + + if self._initializer is not None and not callable(self._initializer): + raise TypeError("initializer must be callable for PopenPoolExecutor") def __del__(self): self._lock.acquire() @@ -293,7 +328,7 @@ def _worker_run(self, fn, args, kwargs): self._lock.acquire() tid = threading.get_ident() if tid not in self._worker_map: - proc = PopenWorker() + proc = PopenWorker(self._initializer, self._initargs) self._worker_map[tid] = proc else: proc = self._worker_map[tid] diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index e442c806a2f3..6f8aab23cde1 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -655,13 +655,103 @@ def convert_attributes(cls, attrs): class Cast(OpConverter): - """ Operator converter for Cast.""" + """Operator converter for Cast.""" @classmethod def convert_attributes(cls, attrs): return {"to": getattr(TensorProto, attrs.dtype.upper())} +class Resize(OpConverter): + """Operator converter for Resize.""" + + @classmethod + def convert_attributes(cls, attrs): + method = attrs.get_str("method") + if method == "nearest_neighbor": + mode = b"nearest" + elif "linear" in method: # linear / bilinear + mode = b"linear" + elif "cubic" in method: # cubic / bicubic + mode = b"cubic" + else: + raise RuntimeError("Unsupported method %s in operator Resize" % method) + + coord_trans = attrs.get_str("coordinate_transformation_mode") + if coord_trans == "half_pixel": + coord_trans = b"half_pixel" + elif coord_trans == "align_corners": + coord_trans = b"align_corners" + elif coord_trans == "asymmetric": + coord_trans = b"asymmetric" + else: + raise RuntimeError( + "Unsupported coordinate transform mode %s in operator Resize" % coord_trans + ) + + rounding_method = attrs.get_str("rounding_method") + if rounding_method == "round": + rounding_method = b"round_prefer_ceil" + elif rounding_method == "floor": + rounding_method = b"floor" + elif rounding_method == "ceil": + rounding_method = b"ceil" + else: + raise RuntimeError( + "Unsupported rounding method %s in operator Resize" % rounding_method + ) + + size = attrs.get_int_tuple("size") + + return { + "mode": mode, + "coord_trans": coord_trans, + "size": size, + "nearest_mode": rounding_method, + } + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + attrs = cls.convert_attributes(node_entry["relay_node"].attrs) + + name = node_entry["name"] + input_node = node_dict[node_entry["inputs"][0]] + assert len(input_node) == 1, "input node can not be a Tuple" + input_node = input_node[0] + input_shape = input_node["types"][0].shape + + # (TBD) needed in opset 11 + roi = [0] * len(input_shape) + [1] * len(input_shape) + roi_array = numpy.asarray(roi).astype(numpy.float64) + roi_node = add_input(roi_array, name, "roi", model_container) + + out_size = attrs["size"] + + # (onnx) rank of scale / size must match rank of X + # relay size node contains only spatial dimensions + # pad with 1s to match rank + match_rank_pad = len(input_shape) - len(out_size) + out_size_full_rank = input_shape[:match_rank_pad] + list(out_size) + out_size_array = numpy.asarray(out_size_full_rank).astype(numpy.int64) + + input_size_array = numpy.asarray(list(input_shape)).astype(numpy.int64) + + scale_array = numpy.divide(out_size_array, input_size_array).astype(numpy.float32) + scale_node = add_input(scale_array, name, "scales", model_container) + + input_names = [node_entry["input_names"][0], roi_node, scale_node] + + resize_node = onnx.helper.make_node( + cls.__name__, + input_names, + node_entry["output_names"], + mode=attrs["mode"], + coordinate_transformation_mode=attrs["coord_trans"], + nearest_mode=attrs["nearest_mode"], + ) + model_container.add_nodes([resize_node]) + + relay_to_onnx_op_mapping = { "reshape": Reshape, "nn.conv2d": Conv, @@ -701,6 +791,7 @@ def convert_attributes(cls, attrs): "copy": rename("Identity"), "round": rename("Round"), "cast": Cast, + "image.resize2d": Resize, } diff --git a/python/tvm/contrib/utils.py b/python/tvm/contrib/utils.py index 68c6b3d5bf6b..e2ca182779c6 100644 --- a/python/tvm/contrib/utils.py +++ b/python/tvm/contrib/utils.py @@ -124,7 +124,7 @@ def remove(self): def path(self): return pathlib.Path(self.temp_dir) - def __div__(self, other): + def __truediv__(self, other): if not isinstance(other, (str, pathlib.Path)): raise TypeError( "TempDirectory / operator: must supply str or pathlib.Path; got %r" % (other,) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 0533898ded35..a7ebc00c315f 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -107,7 +107,7 @@ def lower( It should be None if we want to lower TensorIR. name : str - The name of result function. + The name of the result function. binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]] Dictionary that maps the Tensor to Buffer which specified the data layout @@ -161,7 +161,10 @@ def _build_for_device(input_mod, target, target_host): mod_mixed = input_mod mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) - opt_mixed = [tvm.tir.transform.VerifyMemory()] + opt_mixed = [ + tvm.tir.transform.VerifyMemory(), + tvm.tir.transform.MergeDynamicSharedMemoryAllocations(), + ] if len(mod_mixed.functions) == 1: opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))] diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index d3a62b508135..15c09753d46f 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -396,7 +396,7 @@ def parse_shape_string(inputs_string): """ # Create a regex pattern that extracts each separate input mapping. - pattern = r"\w+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" + pattern = r"(?:\w+\/)?\w+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" input_mappings = re.findall(pattern, inputs_string) if not input_mappings: raise argparse.ArgumentTypeError( diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index 8c8828ddd49b..48bb052124ee 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -46,7 +46,7 @@ import os import tarfile import json -from typing import Optional, Union, List, Dict, Callable, TextIO +from typing import Optional, Union, Dict, Callable, TextIO import numpy as np import tvm @@ -54,6 +54,7 @@ from tvm import relay from tvm.contrib import utils from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule +from tvm.runtime.module import BenchmarkResult try: from tvm.micro import export_model_library_format @@ -336,8 +337,8 @@ def import_package(self, package_path: str): with open(temp.relpath("metadata.json")) as metadata_json: metadata = json.load(metadata_json) - is_graph_runtime = "graph" in metadata["runtimes"] - graph = temp.relpath("runtime-config/graph/graph.json") if is_graph_runtime else None + has_graph_executor = "graph" in metadata["executors"] + graph = temp.relpath("executor-config/graph/graph.json") if has_graph_executor else None params = temp.relpath("parameters/default.params") self.type = "mlf" @@ -371,14 +372,14 @@ def import_package(self, package_path: str): class TVMCResult(object): """A class that stores the results of tvmc.run and provides helper utilities.""" - def __init__(self, outputs: Dict[str, np.ndarray], times: List[str]): + def __init__(self, outputs: Dict[str, np.ndarray], times: BenchmarkResult): """Create a convenience wrapper around the output of tvmc.run Parameters ---------- outputs : dict Outputs dictionary mapping the name of the output to its numpy value. - times : list of float + times : BenchmarkResult The execution times measured by the time evaluator in seconds to produce outputs. """ self.outputs = outputs @@ -390,29 +391,15 @@ def format_times(self): This has the effect of producing a small table that looks like: .. code-block:: Execution time summary: - mean (ms) max (ms) min (ms) std (ms) - 0.14310 0.16161 0.12933 0.01004 + mean (ms) median (ms) max (ms) min (ms) std (ms) + 0.14310 0.14310 0.16161 0.12933 0.01004 Returns ------- str A formatted string containing the statistics. """ - - # timestamps - mean_ts = np.mean(self.times) * 1000 - std_ts = np.std(self.times) * 1000 - max_ts = np.max(self.times) * 1000 - min_ts = np.min(self.times) * 1000 - - header = "Execution time summary:\n{0:^10} {1:^10} {2:^10} {3:^10}".format( - "mean (ms)", "max (ms)", "min (ms)", "std (ms)" - ) - stats = "{0:^10.2f} {1:^10.2f} {2:^10.2f} {3:^10.2f}".format( - mean_ts, max_ts, min_ts, std_ts - ) - - return "%s\n%s\n" % (header, stats) + return str(self.times) def get_output(self, name: str): """A helper function to grab one of the outputs by name. diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 916139874579..489604d79cf4 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -417,14 +417,12 @@ def run_module( # Run must be called explicitly if profiling if profile: logger.info("Running the module with profiling enabled.") - module.run() - - # create the module time evaluator (returns a function) - timer = module.module.time_evaluator("run", dev, number=number, repeat=repeat) - # call the evaluator function to invoke the module and save execution times - prof_result = timer() - # collect a list of execution times from the profiling results - times = prof_result.results + report = module.profile() + # This print is intentional + print(report) + + # call the benchmarking function of the executor + times = module.benchmark(dev, number=number, repeat=repeat) logger.debug("Collecting the output tensors.") num_outputs = module.get_num_outputs() diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index b4cc4421b169..83557a3eae19 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -21,6 +21,7 @@ from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType from .tensor_type import TensorType +from .affine_type import TensorAffineType, TupleAffineType from .type_relation import TypeCall, TypeRelation from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range from .op import Op, register_op_attr, register_intrin_lowering diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py new file mode 100644 index 000000000000..a1ce08017b1b --- /dev/null +++ b/python/tvm/ir/affine_type.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Types for quantized Tensors.""" +import tvm._ffi + +from .base import Node +from . import _ffi_api + + +class AffineType(Node): + """The base class of Affine Types.""" + + def __eq__(self, other): + """Compare two types for structural equivalence.""" + return bool(tvm.ir.structural_equal(self, other)) + + def __ne__(self, other): + return not self.__eq__(other) + + +@tvm._ffi.register_object("TensorAffineType") +class TensorAffineType(AffineType): + """The quantized type of a tensor, with scale, zero point, and datatype + + The real space value is calculated as x = x_q * scale + zero_point + + Parameters + ---------- + scale: Expr + The scale + + zero_point: Expr + The zero_point + + dtype : str + The content data type. + """ + + def __init__(self, scale, zero_point, dtype): + self.__init_handle_by_constructor__(_ffi_api.TensorAffineType, scale, zero_point, dtype) + + +@tvm._ffi.register_object("TupleAffineType") +class TupleAffineType(AffineType): + """Affine types of a node with multiple outputs + + Parameters + ---------- + types : List[TensorAffineType] + The shape of the Tensor + + """ + + def __init__(self, types): + self.__init_handle_by_constructor__(_ffi_api.TupleAffineType, types) diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 93aae45930e3..17995bfa7850 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -199,7 +199,7 @@ class Sequential(Pass): The list of passes that the sequential pass is dependent on. """ - def __init__(self, passes=None, opt_level=2, name="sequential", required=None): + def __init__(self, passes=None, opt_level=0, name="sequential", required=None): passes = passes if passes else [] if not isinstance(passes, (list, tuple)): raise TypeError("passes must be a list of Pass objects.") diff --git a/python/tvm/micro/__init__.py b/python/tvm/micro/__init__.py index a70cb96d9b13..88dcde8ceaf0 100644 --- a/python/tvm/micro/__init__.py +++ b/python/tvm/micro/__init__.py @@ -16,18 +16,13 @@ # under the License. """MicroTVM module for bare-metal backends""" -from .artifact import Artifact -from .build import build_static_runtime, default_options, get_standalone_crt_dir -from .build import get_standalone_crt_lib, Workspace -from .compiler import Compiler, DefaultCompiler, Flasher -from .debugger import GdbRemoteDebugger -from .micro_library import MicroLibrary -from .micro_binary import MicroBinary +from .build import get_standalone_crt_dir from .model_library_format import export_model_library_format, UnsupportedInModelLibraryFormatError +from .project import generate_project, GeneratedProject, TemplateProject from .session import ( create_local_graph_executor, create_local_debug_executor, Session, SessionTerminatedError, ) -from .transport import TransportLogger, DebugWrapperTransport, SubprocessTransport +from .transport import TransportLogger diff --git a/python/tvm/micro/artifact.py b/python/tvm/micro/artifact.py deleted file mode 100644 index c8faccb3f512..000000000000 --- a/python/tvm/micro/artifact.py +++ /dev/null @@ -1,295 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""""Defines abstractions around compiler artifacts produced in compiling micro TVM binaries.""" - -import hashlib -import io -import os -import json -import shutil -import tarfile - - -class ArtifactFileNotFoundError(Exception): - """Raised when an artifact file cannot be found on disk.""" - - -class ArtifactBadSymlinkError(Exception): - """Raised when an artifact symlink points outside the base directory.""" - - -class ArtifactBadArchiveError(Exception): - """Raised when an artifact archive is malformed.""" - - -class ImmobileArtifactError(Exception): - """Raised when an artifact is declared immobile and thus cannot be archived.""" - - -class ArchiveModifiedError(Exception): - """Raised when the underlying files in a metadata-only archive were modified after archiving.""" - - -def sha256_hexdigest(path): - with open(path, "rb") as path_fd: - h = hashlib.sha256() - chunk = path_fd.read(1 * 1024 * 1024) - while chunk: - h.update(chunk) - chunk = path_fd.read(1 * 1024 * 1024) - - return h.hexdigest() - - -def _validate_metadata_only(metadata): - """Validate that the files in a metadata-only archive have not changed.""" - problems = [] - for files in metadata["labelled_files"].values(): - for f in files: - disk_path = os.path.join(metadata["base_dir"], f) - try: - sha = sha256_hexdigest(disk_path) - except FileNotFoundError: - problems.append(f"{f}: original file not found") - continue - - expected_sha = metadata["file_digests"][f] - if sha != expected_sha: - problems.append(f"{f}: sha256 mismatch: expected {expected_sha}, got {sha}") - - if problems: - raise ArchiveModifiedError( - "Files in metadata-only archive have been modified:\n" - + "\n".join([f" * {p}" for p in problems]) - ) - - -class Artifact: - """Describes a compiler artifact and defines common logic to archive it for transport.""" - - # A version number written to the archive. - ENCODING_VERSION = 2 - - # A unique string identifying the type of artifact in an archive. Subclasses must redefine this - # variable. - ARTIFACT_TYPE = None - - @classmethod - def unarchive(cls, archive_path, base_dir): - """Unarchive an artifact into base_dir. - - Parameters - ---------- - archive_path : str - Path to the archive file. - base_dir : str - Path to a non-existent, empty directory under which the artifact will live. If working - with a metadata-only archive, this directory will just hold the metadata.json. - - Returns - ------- - Artifact : - The unarchived artifact. - """ - if os.path.exists(base_dir): - raise ValueError(f"base_dir exists: {base_dir}") - - base_dir_parent, base_dir_name = os.path.split(base_dir) - temp_dir = os.path.join(base_dir_parent, f"__tvm__{base_dir_name}") - os.mkdir(temp_dir) - try: - with tarfile.open(archive_path) as tar_f: - tar_f.extractall(temp_dir) - - temp_dir_contents = os.listdir(temp_dir) - if len(temp_dir_contents) != 1: - raise ArtifactBadArchiveError( - "Expected exactly 1 subdirectory at root of archive, got " - f"{temp_dir_contents!r}" - ) - - metadata_path = os.path.join(temp_dir, temp_dir_contents[0], "metadata.json") - if not metadata_path: - raise ArtifactBadArchiveError("No metadata.json found in archive") - - with open(metadata_path) as metadata_f: - metadata = json.load(metadata_f) - - version = metadata.get("version") - if version != cls.ENCODING_VERSION: - raise ArtifactBadArchiveError( - f"archive version: expect {cls.EXPECTED_VERSION}, found {version}" - ) - - metadata_only = metadata.get("metadata_only") - if metadata_only: - _validate_metadata_only(metadata) - - os.rename(os.path.join(temp_dir, temp_dir_contents[0]), base_dir) - - artifact_cls = cls - for sub_cls in cls.__subclasses__(): - if sub_cls.ARTIFACT_TYPE is not None and sub_cls.ARTIFACT_TYPE == metadata.get( - "artifact_type" - ): - artifact_cls = sub_cls - break - - return artifact_cls.from_unarchived( - base_dir if not metadata_only else metadata["base_dir"], - metadata["labelled_files"], - metadata["metadata"], - immobile=metadata.get("immobile"), - ) - finally: - shutil.rmtree(temp_dir) - - @classmethod - def from_unarchived(cls, base_dir, labelled_files, metadata, immobile): - return cls(base_dir, labelled_files, metadata, immobile) - - def __init__(self, base_dir, labelled_files, metadata, immobile=False): - """Create a new artifact. - - Parameters - ---------- - base_dir : str - The path to a directory on disk which contains all the files in this artifact. - labelled_files : Dict[str, str] - A dict mapping a file label to the relative paths of the files that carry that label. - metadata : Dict - A dict containing artitrary JSON-serializable key-value data describing the artifact. - immobile : bool - True when this artifact can't be used after being moved out of its current location on - disk. This can happen when artifacts contain absolute paths or when it's not feasible to - include enough files in the artifact to reliably re-run commands in arbitrary locations. - Setting this flag will cause archive() to raise ImmboileArtifactError. - """ - self.base_dir = os.path.realpath(base_dir) - self.labelled_files = labelled_files - self.metadata = metadata - self.immobile = immobile - - for label, files in labelled_files.items(): - for f in files: - f_path = os.path.join(self.base_dir, f) - if not os.path.lexists(f_path): - raise ArtifactFileNotFoundError(f"{f} (label {label}): not found at {f_path}") - - if os.path.islink(f_path): - link_path = os.path.readlink(f_path) - if os.path.isabs(link_path): - link_fullpath = link_path - else: - link_fullpath = os.path.join(os.path.dirname(f_path), link_path) - - link_fullpath = os.path.realpath(link_fullpath) - if not link_fullpath.startswith(self.base_dir): - raise ArtifactBadSymlinkError( - f"{f} (label {label}): symlink points outside artifact tree" - ) - - def abspath(self, rel_path): - """Return absolute path to the member with the given relative path.""" - return os.path.join(self.base_dir, rel_path) - - def label(self, label): - """Return a list of relative paths to files with the given label.""" - return self.labelled_files[label] - - def label_abspath(self, label): - return [self.abspath(p) for p in self.labelled_files[label]] - - def archive(self, archive_path, metadata_only=False): - """Create a relocatable tar archive of the artifacts. - - Parameters - ---------- - archive_path : str - Path to the tar file to create. Or, path to a directory, under which a tar file will be - created named {base_dir}.tar. - metadata_only : bool - If true, don't archive artifacts; instead, just archive metadata plus original - base_path. A metadata-only archive can be unarchived and used like a regular archive - provided none of the files have changed in their original locations on-disk. - - Returns - ------- - str : - The value of archive_path, after potentially making the computation describe above. - - Raises - ------ - ImmboileArtifactError : - When immobile=True was passed to the constructor. - """ - if self.immobile and not metadata_only: - raise ImmobileArtifactError("This artifact can't be moved") - - if os.path.isdir(archive_path): - archive_path = os.path.join(archive_path, f"{os.path.basename(self.base_dir)}.tar") - - archive_name = os.path.splitext(os.path.basename(archive_path))[0] - with tarfile.open(archive_path, "w") as tar_f: - - def _add_file(name, data, f_type): - tar_info = tarfile.TarInfo(name=name) - tar_info.type = f_type - data_bytes = bytes(data, "utf-8") - tar_info.size = len(data) - tar_f.addfile(tar_info, io.BytesIO(data_bytes)) - - metadata = { - "version": self.ENCODING_VERSION, - "labelled_files": self.labelled_files, - "metadata": self.metadata, - "metadata_only": False, - } - if metadata_only: - metadata["metadata_only"] = True - metadata["base_dir"] = self.base_dir - metadata["immobile"] = self.immobile - metadata["file_digests"] = {} - for files in self.labelled_files.values(): - for f in files: - metadata["file_digests"][f] = sha256_hexdigest(self.abspath(f)) - - _add_file( - f"{archive_name}/metadata.json", - json.dumps(metadata, indent=2, sort_keys=True), - tarfile.REGTYPE, - ) - for dir_path, _, files in os.walk(self.base_dir): - for f in files: - file_path = os.path.join(dir_path, f) - archive_file_path = os.path.join( - archive_name, os.path.relpath(file_path, self.base_dir) - ) - if not os.path.islink(file_path): - tar_f.add(file_path, archive_file_path, recursive=False) - continue - - link_path = os.readlink(file_path) - if not os.path.isabs(link_path): - tar_f.add(file_path, archive_file_path, recursive=False) - continue - - relpath = os.path.relpath(link_path, os.path.dirname(file_path)) - _add_file(archive_file_path, relpath, tarfile.LNKTYPE) - - return archive_path diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index a83ccaa47cda..16e7ed24cb4f 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -17,42 +17,15 @@ """Defines top-level glue functions for building microTVM artifacts.""" -import copy import logging import os -import re -import typing -from tvm.contrib import utils -from .micro_library import MicroLibrary from .._ffi import libinfo _LOG = logging.getLogger(__name__) -class Workspace: - """Defines helper functions for manipulating temporary compilation workspaces.""" - - def __init__(self, root=None, debug=False): - if debug or root is not None: - with utils.TempDirectory.set_keep_for_debug(): - self.tempdir = utils.tempdir(custom_path=root) - _LOG.info("Created debug mode workspace at: %s", self.tempdir.temp_dir) - else: - self.tempdir = utils.tempdir() - - def relpath(self, path): - return self.tempdir.relpath(path) - - def listdir(self): - return self.tempdir.listdir() - - @property - def path(self): - return self.tempdir.temp_dir - - STANDALONE_CRT_DIR = None @@ -84,186 +57,3 @@ def get_standalone_crt_dir() -> str: raise CrtNotFoundError() return STANDALONE_CRT_DIR - - -def get_standalone_crt_lib(name: str) -> str: - """Find a source library directory in the standalone_crt. - - The standalone C runtime is split into various libraries (one per directory underneath - src/runtime/crt). This convenience function returns the full path to one of those libraries - located in get_standalone_crt_dir(). - - Parameters - ---------- - name : str - Name of the library subdirectory underneath src/runtime/crt. - - Returns - ------- - str : - The full path to the the library. - """ - return os.path.join(get_standalone_crt_dir(), "src", "runtime", "crt", name) - - -def get_runtime_libs(executor: str) -> str: - """Return abspath to all CRT directories in link order which contain - source (i.e. not header) files. - """ - if executor == "host-driven": - crt_runtime_lib_names = ["microtvm_rpc_server", "microtvm_rpc_common", "common"] - elif executor == "aot": - crt_runtime_lib_names = ["aot_executor", "common"] - else: - raise ValueError(f"Incorrect executor: {executor}") - return [get_standalone_crt_lib(n) for n in crt_runtime_lib_names] - - -RUNTIME_SRC_REGEX = re.compile(r"^.*\.cc?$", re.IGNORECASE) - - -_COMMON_CFLAGS = ["-Wall", "-Werror", "-DDMLC_USE_LOGGING_LIBRARY="] - - -def _build_default_compiler_options(standalone_crt_dir: typing.Optional[str] = None) -> str: - """Return a dict containing base compile flags for the CRT under gcc common to . - - Parameters - ---------- - standalone_crt_dir : Optional[str] - If given, the path to the standalone_crt - """ - if standalone_crt_dir is None: - standalone_crt_dir = get_standalone_crt_dir() - return { - "cflags": ["-std=c11"] + _COMMON_CFLAGS, - "ccflags": ["-std=c++11"] + _COMMON_CFLAGS, - "ldflags": ["-std=c++11"], - "include_dirs": [os.path.join(standalone_crt_dir, "include")], - } - - -def default_options(crt_config_include_dir, standalone_crt_dir=None): - """Return default opts passed to Compile commands. - - Parameters - ---------- - crt_config_include_dir : str - Path to a directory containing crt_config.h for the target. This will be appended - to the include path for cflags and ccflags. - standalone_crt_dir : Optional[str] - - Returns - ------- - Dict : - A dictionary containing 3 subkeys, each whose value is _build_default_compiler_options() - plus additional customization. - - - "bin_opts" - passed as "options" to Compiler.binary() when building MicroBinary. - - "lib_opts" - passed as "options" to Compiler.library() when building bundled CRT - libraries (or otherwise, non-generated libraries). - - "generated_lib_opts" - passed as "options" to Compiler.library() when building the - generated library. - """ - bin_opts = _build_default_compiler_options(standalone_crt_dir) - bin_opts["include_dirs"].append(crt_config_include_dir) - - lib_opts = _build_default_compiler_options(standalone_crt_dir) - lib_opts["cflags"] = ["-Wno-error=incompatible-pointer-types"] - lib_opts["include_dirs"].append(crt_config_include_dir) - - generated_lib_opts = copy.copy(lib_opts) - - # Disable due to limitation in the TVM C codegen, which generates lots of local variable - # declarations at the top of generated code without caring whether they're used. - # Example: - # void* arg0 = (((TVMValue*)args)[0].v_handle); - # int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)]; - generated_lib_opts["cflags"].append("-Wno-unused-variable") - generated_lib_opts["ccflags"].append("-Wno-unused-variable") - - # Many TVM-intrinsic operators (i.e. expf, in particular) - generated_lib_opts["cflags"].append("-fno-builtin") - - return {"bin_opts": bin_opts, "lib_opts": lib_opts, "generated_lib_opts": generated_lib_opts} - - -def build_static_runtime( - workspace, - compiler, - module, - compiler_options, - executor=None, - extra_libs=None, -): - """Build the on-device runtime, statically linking the given modules. - - Parameters - ---------- - compiler : tvm.micro.Compiler - Compiler instance used to build the runtime. - - module : IRModule - Module to statically link. - - compiler_options : dict - The return value of tvm.micro.default_options(), with any keys overridden to inject - compiler options specific to this build. If not given, tvm.micro.default_options() is - used. This dict contains the `options` parameter passed to Compiler.library() and - Compiler.binary() at various stages in the compilation process. - - executor : Optional[str] - Executor used for runtime. Based on this we determine the libraries that need to be - linked with runtime. - - extra_libs : Optional[List[MicroLibrary|str]] - If specified, extra libraries to be compiled into the binary. If a MicroLibrary, it is - included into the binary directly. If a string, the path to a directory; all direct children - of this directory matching RUNTIME_SRC_REGEX are built into a library. These libraries are - placed before any common CRT libraries in the link order. - - Returns - ------- - MicroBinary : - The compiled runtime. - """ - mod_build_dir = workspace.relpath(os.path.join("build", "module")) - os.makedirs(mod_build_dir) - mod_src_dir = workspace.relpath(os.path.join("src", "module")) - - if not executor: - executor = "host-driven" - - libs = [] - for mod_or_src_dir in (extra_libs or []) + get_runtime_libs(executor): - if isinstance(mod_or_src_dir, MicroLibrary): - libs.append(mod_or_src_dir) - continue - - lib_src_dir = mod_or_src_dir - lib_name = os.path.basename(lib_src_dir) - lib_build_dir = workspace.relpath(f"build/{lib_name}") - os.makedirs(lib_build_dir) - - lib_srcs = [] - for p in os.listdir(lib_src_dir): - if RUNTIME_SRC_REGEX.match(p): - lib_srcs.append(os.path.join(lib_src_dir, p)) - - libs.append(compiler.library(lib_build_dir, lib_srcs, compiler_options["lib_opts"])) - - mod_src_dir = workspace.relpath(os.path.join("src", "module")) - os.makedirs(mod_src_dir) - libs.append( - module.export_library( - mod_build_dir, - workspace_dir=mod_src_dir, - fcompile=lambda bdir, srcs, **kwargs: compiler.library( - bdir, srcs, compiler_options["generated_lib_opts"] - ), - ) - ) - - runtime_build_dir = workspace.relpath(f"build/runtime") - os.makedirs(runtime_build_dir) - return compiler.binary(runtime_build_dir, libs, compiler_options["bin_opts"]) diff --git a/python/tvm/micro/compiler.py b/python/tvm/micro/compiler.py deleted file mode 100644 index 5bc5aba8a1be..000000000000 --- a/python/tvm/micro/compiler.py +++ /dev/null @@ -1,361 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines interfaces and default implementations for compiling and flashing code.""" - -import abc -import glob -import os -import re -import subprocess - -import tvm.target -from . import class_factory -from . import debugger -from . import transport - - -def run_cmd(cmd): - """Runs `cmd` in a subprocess and awaits its completion. - - Parameters - ---------- - cmd : List[str] - list of command-line arguments - - Returns - ------- - output : str - resulting stdout capture from the subprocess - """ - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - (output, _) = proc.communicate() - output = output.decode("utf-8") - if proc.returncode != 0: - cmd_str = " ".join(cmd) - msg = f'error while running command "{cmd_str}":\n{output}' - raise RuntimeError(msg) - - -class DetectTargetError(Exception): - """Raised when no target comment was detected in the sources given.""" - - -class NoDefaultToolchainMatchedError(Exception): - """Raised when no default toolchain matches the target string.""" - - -class Compiler(metaclass=abc.ABCMeta): - """The compiler abstraction used with micro TVM.""" - - TVM_TARGET_RE = re.compile(r"^// tvm target: (.*)$") - - @classmethod - def _target_from_sources(cls, sources): - """Determine the target used to generate the given source files. - - Parameters - ---------- - sources : List[str] - The paths to source files to analyze. - - Returns - ------- - tvm.target.Target : - A Target instance reconstructed from the target string listed in the source files. - """ - target_strs = set() - - for obj in sources: - if os.path.splitext(obj)[1] not in (".cc", ".c"): - continue - - with open(obj) as obj_f: - for line in obj_f: - m = cls.TVM_TARGET_RE.match(line) - if m: - target_strs.add(m.group(1)) - - if len(target_strs) != 1: - raise DetectTargetError( - "autodetecting cross-compiler: could not extract TVM target from C source; regex " - f"{cls.TVM_TARGET_RE.pattern} does not match any line in sources: " - f'{", ".join(sources)}' - ) - - target_str = next(iter(target_strs)) - return tvm.target.Target(target_str) - - # Maps regexes identifying CPUs to the default toolchain prefix for that CPU. - TOOLCHAIN_PREFIX_BY_CPU_REGEX = { - r"cortex-[am].*": "arm-none-eabi-", - "x86[_-]64": "", - "native": "", - } - - def _autodetect_toolchain_prefix(self, target): - # Treat absence of -mcpu as if -mcpu=native is specified. The gcc shipped with OS X - # complains if -mcpu=native is given, so this approach allows model targets to avoid - # specifying this flag e.g. for tutorials. - if "mcpu" not in target.attrs: - return self.TOOLCHAIN_PREFIX_BY_CPU_REGEX["native"] - - matches = [] - for regex, prefix in self.TOOLCHAIN_PREFIX_BY_CPU_REGEX.items(): - if re.match(regex, target.attrs["mcpu"]): - matches.append(prefix) - - if matches: - if len(matches) != 1: - raise NoDefaultToolchainMatchedError( - f'{opt} matched more than 1 default toolchain prefix: {", ".join(matches)}. ' - "Specify cc.cross_compiler to create_micro_library()" - ) - - return matches[0] - - raise NoDefaultToolchainMatchedError( - f"target {str(target)} did not match any default toolchains" - ) - - def _defaults_from_target(self, target): - """Determine the default compiler options from the target specified. - - Parameters - ---------- - target : tvm.target.Target - - Returns - ------- - List[str] : - Default options used the configure the compiler for that target. - """ - opts = [] - # TODO use march for arm(https://gcc.gnu.org/onlinedocs/gcc/ARM-Options.html)? - if target.attrs.get("mcpu"): - opts.append(f'-mcpu={target.attrs["mcpu"]}') - if target.attrs.get("mfpu"): - opts.append(f'-mfpu={target.attrs["mfpu"]}') - if target.attrs.get("march"): - opts.append(f'-march={target.attrs["march"]}') - - return opts - - @abc.abstractmethod - def library(self, output, sources, options=None): - """Build a library from the given source files. - - Parameters - ---------- - output : str - The path to the library that should be created. The containing directory - is guaranteed to be empty and should be the base_dir for the returned - Artifact. - sources : List[str] - A list of paths to source files that should be compiled. - options : Optional[List[str]] - If given, additional command-line flags to pass to the compiler. - - Returns - ------- - MicroLibrary : - The compiled library, as a MicroLibrary instance. - """ - raise NotImplementedError() - - @abc.abstractmethod - def binary(self, output, objects, options=None, link_main=True, main_options=None): - """Link a binary from the given object and/or source files. - - Parameters - ---------- - output : str - The path to the binary that should be created. The containing directory - is guaranteed to be empty and should be the base_dir for the returned - Artifact. - objects : List[MicroLibrary] - A list of paths to source files or libraries that should be compiled. The final binary - should be statically-linked. - options: Optional[List[str]] - If given, additional command-line flags to pass to the compiler. - link_main: Optional[bool] - True if the standard main entry point for this Compiler should be included in the - binary. False if a main entry point is provided in one of `objects`. - main_options: Optional[List[str]] - If given, additional command-line flags to pass to the compiler when compiling the - main() library. In some cases, the main() may be compiled directly into the final binary - along with `objects` for logistical reasons. In those cases, specifying main_options is - an error and ValueError will be raised. - - Returns - ------- - MicroBinary : - The compiled binary, as a MicroBinary instance. - """ - raise NotImplementedError() - - @property - def flasher_factory(self): - """Produce a FlasherFactory for a Flasher instance suitable for this Compiler.""" - raise NotImplementedError("The Compiler base class doesn't define a flasher.") - - def flasher(self, **kw): - """Return a Flasher that can be used to program a produced MicroBinary onto the target.""" - return self.flasher_factory.override_kw(**kw).instantiate() - - -class IncompatibleTargetError(Exception): - """Raised when source files specify a target that differs from the compiler target.""" - - -class DefaultCompiler(Compiler): - """A Compiler implementation that attempts to use the system-installed GCC.""" - - def __init__(self, target=None): - super(DefaultCompiler, self).__init__() - self.target = target - if isinstance(target, str): - self.target = tvm.target.create(target) - - def library(self, output, sources, options=None): - options = options if options is not None else {} - try: - target = self._target_from_sources(sources) - except DetectTargetError: - assert self.target is not None, ( - "Must specify target= to constructor when compiling sources which don't specify a " - "target" - ) - - target = self.target - - if self.target is not None and str(self.target) != str(target): - raise IncompatibleTargetError( - f"auto-detected target {target} differs from configured {self.target}" - ) - - prefix = self._autodetect_toolchain_prefix(target) - outputs = [s for s in sources if os.path.splitext(s)[1] == ".o"] - sources = [s for s in sources if s not in outputs] - for src in sources: - src_base, src_ext = os.path.splitext(os.path.basename(src)) - - compiler_name = {".c": "gcc", ".cc": "g++", ".cpp": "g++"}[src_ext] - args = [prefix + compiler_name, "-g"] - args.extend(self._defaults_from_target(target)) - - args.extend(options.get(f"{src_ext[1:]}flags", [])) - - for include_dir in options.get("include_dirs", []): - args.extend(["-I", include_dir]) - - output_filename = f"{src_base}.o" - output_abspath = os.path.join(output, output_filename) - run_cmd(args + ["-c", "-o", output_abspath, src]) - outputs.append(output_abspath) - - output_filename = f"{os.path.basename(output)}.a" - output_abspath = os.path.join(output, output_filename) - run_cmd([prefix + "ar", "-r", output_abspath] + outputs) - run_cmd([prefix + "ranlib", output_abspath]) - - return tvm.micro.MicroLibrary(output, [output_filename]) - - def binary(self, output, objects, options=None, link_main=True, main_options=None): - assert self.target is not None, ( - "must specify target= to constructor, or compile sources which specify the target " - "first" - ) - - args = [self._autodetect_toolchain_prefix(self.target) + "g++"] - args.extend(self._defaults_from_target(self.target)) - if options is not None: - args.extend(options.get("ldflags", [])) - - for include_dir in options.get("include_dirs", []): - args.extend(["-I", include_dir]) - - output_filename = os.path.basename(output) - output_abspath = os.path.join(output, output_filename) - args.extend(["-g", "-o", output_abspath]) - - if link_main: - host_main_srcs = glob.glob( - os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host", "*.cc") - ) - if main_options: - main_lib = self.library(os.path.join(output, "host"), host_main_srcs, main_options) - for lib_name in main_lib.library_files: - args.append(main_lib.abspath(lib_name)) - else: - args.extend(host_main_srcs) - - for obj in objects: - for lib_name in obj.library_files: - args.append(obj.abspath(lib_name)) - - run_cmd(args) - return tvm.micro.MicroBinary(output, output_filename, []) - - @property - def flasher_factory(self): - return FlasherFactory(HostFlasher, [], {}) - - -class Flasher(metaclass=abc.ABCMeta): - """An interface for flashing binaries and returning a transport factory.""" - - @abc.abstractmethod - def flash(self, micro_binary): - """Flash a binary onto the device. - - Parameters - ---------- - micro_binary : MicroBinary - A MicroBinary instance. - - Returns - ------- - transport.TransportContextManager : - A ContextManager that can be used to create and tear down an RPC transport layer between - this TVM instance and the newly-flashed binary. - """ - raise NotImplementedError() - - -class FlasherFactory(class_factory.ClassFactory): - """A ClassFactory for Flasher instances.""" - - SUPERCLASS = Flasher - - -class HostFlasher(Flasher): - """A Flasher implementation that spawns a subprocess on the host.""" - - def __init__(self, debug=False): - self.debug = debug - - def flash(self, micro_binary): - if self.debug: - gdb_wrapper = debugger.GdbTransportDebugger( - [micro_binary.abspath(micro_binary.binary_file)] - ) - return transport.DebugWrapperTransport( - debugger=gdb_wrapper, transport=gdb_wrapper.transport() - ) - - return transport.SubprocessTransport([micro_binary.abspath(micro_binary.binary_file)]) diff --git a/python/tvm/micro/contrib/base.py b/python/tvm/micro/contrib/base.py deleted file mode 100644 index 9c4f4863e3bc..000000000000 --- a/python/tvm/micro/contrib/base.py +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines common helper functions useful for integrating custom compiler toolchains.""" - -import glob -import os -import shutil - - -GLOB_PATTERNS = ["__tvm_*", "libtvm__*"] - - -def populate_tvm_objs(dest_dir, objs): - """Replace tvm-prefixed files in a build worktree. - - This function is intended to be used to place TVM source files and libraries into a - template on-device runtime project. - - Parameters - ---------- - dest_dir : str - Path to the destination directory. - - objs : List[MicroLibrary] - List of MicroLibrary to place in the project directory. - - Returns - ------- - List[str] : - List of paths, each relative to `dest_dir` to the newly-copied MicroLibrary files. - """ - copied = [] - for p in GLOB_PATTERNS: - for f in glob.glob(os.path.join(dest_dir, p)): - if os.path.isdir(f): - shutil.rmtree(f) - else: - os.unlink(f) - - for obj in objs: - for lib_file in obj.library_files: - obj_base = os.path.basename(lib_file) - if obj_base.endswith(".a"): - dest_basename = f"libtvm__{obj_base}" - else: - dest_basename = f"__tvm_{obj_base}" - - copied.append(dest_basename) - dest = os.path.join(dest_dir, dest_basename) - shutil.copy(obj.abspath(lib_file), dest) - - return copied diff --git a/python/tvm/micro/contrib/zephyr.py b/python/tvm/micro/contrib/zephyr.py deleted file mode 100644 index 77cfb8d09bf2..000000000000 --- a/python/tvm/micro/contrib/zephyr.py +++ /dev/null @@ -1,789 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines a compiler integration that uses an externally-supplied Zephyr project.""" - -import collections -import copy -import logging -import multiprocessing -import os -import pathlib -import re -import tempfile -import textwrap -import shlex -import shutil -import subprocess -import sys -import threading -import queue -import enum - -import yaml - -import tvm.micro -from . import base -from .. import compiler -from .. import debugger -from ..transport import debug -from ..transport import file_descriptor - -from ..transport import serial -from ..transport import Transport, TransportClosedError, TransportTimeouts -from ..transport import wakeup - - -_LOG = logging.getLogger(__name__) - - -class SubprocessEnv(object): - def __init__(self, default_overrides): - self.default_overrides = default_overrides - - def run(self, cmd, **kw): - env = dict(os.environ) - for k, v in self.default_overrides.items(): - env[k] = v - - return subprocess.check_output(cmd, env=env, **kw, universal_newlines=True) - - -class ProjectNotFoundError(Exception): - """Raised when the project_dir supplied to ZephyrCompiler does not exist.""" - - -class FlashRunnerNotSupported(Exception): - """Raised when the FLASH_RUNNER for a project isn't supported by this Zephyr adapter.""" - - -class ZephyrCompiler(tvm.micro.Compiler): - """A Compiler instance that builds against a pre-existing zephyr project.""" - - def __init__( - self, - project_dir=None, - board=None, - west_cmd=None, - zephyr_base=None, - zephyr_toolchain_variant=None, - env_vars=None, - ): - """Configure the compiler for use. - - Parameters - ---------- - project_dir : str - Path to the pre-existing Zephyr project. - board : str - Name of the Zephyr board to build for (i.e. passed to `west build -b`) - west_cmd : Optional[list] - If given, argv that invoke the west build tool. Used only for flashing. - zephyr_base : Optional[str] - If given, path to Zephyr, as would normally be present in the ZEPHYR_BASE environment - variable. If not given, consults this environment variable. This value must be set in - one of those two places. - zephyr_toolchain_variant: Optional[str] - If given, overrides the toolchain used by Zephyr. If not given, uses the default - zephyr toolchain. When running on OS X outside of docker, you need to specify this. - env_vars : Optional[Dict[str,str]] - If given, additional environment variables present when invoking west, cmake, or make. - """ - self._project_dir = project_dir - if not os.path.exists(project_dir): - # Raise this error instead of a potentially-more-cryptic compiler error due to a missing - # prj.conf. - raise ProjectNotFoundError( - f"project_dir supplied to ZephyrCompiler does not exist: {project_dir}" - ) - - self._qemu = "qemu" in board - - # For Zephyr boards that run emulated by default but don't have the prefix "qemu_" in their - # board names, a suffix "-qemu" is added by users of microTVM when specifying the board - # name to inform that the QEMU transporter must be used just like for the boards with - # the prefix. Zephyr does not recognize the suffix, so we trim it off before passing it. - if "-qemu" in board: - board = board.replace("-qemu", "") - - self._board = board - - if west_cmd is None: - self._west_cmd = [sys.executable, "-mwest.app.main"] - elif isinstance(west_cmd, str): - self._west_cmd = [west_cmd] - elif isinstance(west_cmd, list): - self._west_cmd = west_cmd - else: - raise TypeError("west_cmd: expected string, list, or None; got %r" % (west_cmd,)) - - env = {} - if zephyr_toolchain_variant is not None: - env["ZEPHYR_TOOLCHAIN_VARIANT"] = zephyr_toolchain_variant - - self._zephyr_base = zephyr_base or os.environ["ZEPHYR_BASE"] - assert ( - self._zephyr_base is not None - ), f"Must specify zephyr_base=, or ZEPHYR_BASE must be in environment variables" - env["ZEPHYR_BASE"] = self._zephyr_base - - if env_vars: - env.update(env_vars) - - self._subprocess_env = SubprocessEnv(env) - - OPT_KEY_TO_CMAKE_DEFINE = { - "cflags": "CFLAGS", - "ccflags": "CXXFLAGS", - "ldflags": "LDFLAGS", - } - - @classmethod - def _options_to_cmake_args(cls, options): - args = [] - for key, define in cls.OPT_KEY_TO_CMAKE_DEFINE.items(): - if key in options: - quoted_opts = [shlex.quote(o).replace(";", "\\;") for o in options[key]] - args.append(f'-DEXTRA_{define}={" ".join(quoted_opts)}') - - if "cmake_args" in options: - args.extend(options["cmake_args"]) - - return args - - def library(self, output, sources, options=None): - project_name = os.path.basename(output) - if project_name.startswith("lib"): - project_name = project_name[3:] - - lib_prj_conf = os.path.join(output, "prj.conf") - if self._project_dir is not None: - project_dir_conf = os.path.join(self._project_dir, "prj.conf") - if os.path.exists(project_dir_conf): - shutil.copy(project_dir_conf, lib_prj_conf) - - # Copy board-specific Zephyr config file from the project_dir to - # the build lib dir so board-specific configs can be found and used by - # Zephyr's build system in conjunction with the generic prj.conf configs. - board_conf = os.path.join("boards", self._board + ".conf") - project_dir_board_conf = os.path.join(self._project_dir, board_conf) - if os.path.exists(project_dir_board_conf): - os.mkdir(os.path.join(output, "boards")) - lib_dir_board_conf = os.path.join(output, board_conf) - shutil.copy(project_dir_board_conf, lib_dir_board_conf) - - else: - with open(lib_prj_conf, "w") as prj_conf_f: - prj_conf_f.write("CONFIG_CPLUSPLUS=y\n") - - cmakelists_path = os.path.join(output, "CMakeLists.txt") - with open(cmakelists_path, "w") as cmake_f: - sources = " ".join(f'"{o}"' for o in sources) - cmake_f.write( - textwrap.dedent( - f"""\ - cmake_minimum_required(VERSION 3.13.1) - - find_package(Zephyr HINTS $ENV{{ZEPHYR_BASE}}) - project({project_name}_prj) - target_sources(app PRIVATE) - zephyr_library_named({project_name}) - target_sources({project_name} PRIVATE {sources}) - target_sources(app PRIVATE main.c) - target_link_libraries(app PUBLIC {project_name}) - """ - ) - ) - if "include_dirs" in options: - cmake_f.write( - f"target_include_directories({project_name} PRIVATE " - f'{" ".join(os.path.abspath(d) for d in options["include_dirs"])})\n' - ) - - with open(os.path.join(output, "main.c"), "w"): - pass - - # expected not to exist after populate_tvm_libs - build_dir = os.path.join(output, "__tvm_build") - os.mkdir(build_dir) - self._subprocess_env.run( - ["cmake", "..", f"-DBOARD={self._board}"] + self._options_to_cmake_args(options), - cwd=build_dir, - ) - num_cpus = multiprocessing.cpu_count() - self._subprocess_env.run( - ["make", f"-j{num_cpus}", "VERBOSE=1", project_name], cwd=build_dir - ) - return tvm.micro.MicroLibrary(build_dir, [f"lib{project_name}.a"]) - - def _print_make_statistics(self, output): - output = output.splitlines() - lines = iter(output) - for line in lines: - if line.startswith("Memory region"): - # print statistics header - _LOG.info(line) - _LOG.info("--------------------- ---------- ------------ ---------") - line = next(lines) - # while there is a region print it - try: - while ":" in line: - _LOG.info(line) - line = next(lines) - else: - break - except StopIteration: - pass - - def binary(self, output, objects, options=None, link_main=True, main_options=None): - assert link_main, "Must pass link_main=True" - assert self._project_dir is not None, "Must supply project_dir= to build binaries" - - copied_libs = base.populate_tvm_objs(self._project_dir, objects) - - # expected not to exist after populate_tvm_objs - cmake_args = [ - "cmake", - os.path.abspath(self._project_dir), - f"-DBOARD={self._board}", - ] + self._options_to_cmake_args(options) - if "include_dirs" in options: - cmake_args.append( - "-DTVM_INCLUDE_DIRS=" - f'{";".join(os.path.abspath(d) for d in options["include_dirs"])}' - ) - cmake_args.append(f'-DTVM_LIBS={";".join(copied_libs)}') - self._subprocess_env.run(cmake_args, cwd=output) - - make_output = self._subprocess_env.run(["make"], cwd=output) - - self._print_make_statistics(make_output) - - return tvm.micro.MicroBinary( - output, - binary_file=os.path.join("zephyr", "zephyr.elf"), - debug_files=[os.path.join("zephyr", "zephyr.elf")], - labelled_files={ - "cmake_cache": ["CMakeCache.txt"], - "device_tree": [os.path.join("zephyr", "zephyr.dts")], - }, - immobile=bool(self._qemu), - ) - - @property - def flasher_factory(self): - return compiler.FlasherFactory( - ZephyrFlasher, - ( - self._board, - self._qemu, - ), - dict( - zephyr_base=self._zephyr_base, - project_dir=self._project_dir, - subprocess_env=self._subprocess_env.default_overrides, - west_cmd=self._west_cmd, - ), - ) - - -CACHE_ENTRY_RE = re.compile(r"(?P[^:]+):(?P[^=]+)=(?P.*)") - - -CMAKE_BOOL_MAP = dict( - [(k, True) for k in ("1", "ON", "YES", "TRUE", "Y")] - + [(k, False) for k in ("0", "OFF", "NO", "FALSE", "N", "IGNORE", "NOTFOUND", "")] -) - - -def read_cmake_cache(file_name): - """Read a CMakeCache.txt-like file and return a dictionary of values.""" - entries = collections.OrderedDict() - with open(file_name, encoding="utf-8") as f: - for line in f: - m = CACHE_ENTRY_RE.match(line.rstrip("\n")) - if not m: - continue - - if m.group("type") == "BOOL": - value = CMAKE_BOOL_MAP[m.group("value").upper()] - else: - value = m.group("value") - - entries[m.group("name")] = value - - return entries - - -class BoardError(Exception): - """Raised when an attached board cannot be opened (i.e. missing /dev nodes, etc).""" - - -class BoardAutodetectFailed(Exception): - """Raised when no attached hardware is found matching the board= given to ZephyrCompiler.""" - - -class ZephyrFlasher(tvm.micro.compiler.Flasher): - """A Flasher implementation that delegates to Zephyr/west.""" - - def __init__( - self, - board, - qemu, - zephyr_base=None, - project_dir=None, - subprocess_env=None, - nrfjprog_snr=None, - openocd_serial=None, - flash_args=None, - debug_rpc_session=None, - serial_timeouts=None, - west_cmd=None, - ): - zephyr_base = zephyr_base or os.environ["ZEPHYR_BASE"] - sys.path.insert(0, os.path.join(zephyr_base, "scripts", "dts")) - try: - import dtlib # pylint: disable=import-outside-toplevel - - self._dtlib = dtlib - finally: - sys.path.pop(0) - - self._board = board - self._qemu = qemu - self._zephyr_base = zephyr_base - self._project_dir = project_dir - self._west_cmd = west_cmd - self._flash_args = flash_args - self._openocd_serial = openocd_serial - self._autodetected_openocd_serial = None - self._subprocess_env = SubprocessEnv(subprocess_env) - self._debug_rpc_session = debug_rpc_session - self._nrfjprog_snr = nrfjprog_snr - self._serial_timeouts = serial_timeouts - - def _get_nrf_device_args(self): - nrfjprog_args = ["nrfjprog", "--ids"] - nrfjprog_ids = subprocess.check_output(nrfjprog_args, encoding="utf-8") - if not nrfjprog_ids.strip("\n"): - raise BoardAutodetectFailed( - f'No attached boards recognized by {" ".join(nrfjprog_args)}' - ) - - boards = nrfjprog_ids.split("\n")[:-1] - if len(boards) > 1: - if self._nrfjprog_snr is None: - raise BoardError( - "Multiple boards connected; specify one with nrfjprog_snr=: " - f'{", ".join(boards)}' - ) - - if str(self._nrfjprog_snr) not in boards: - raise BoardError( - f"nrfjprog_snr ({self._nrfjprog_snr}) not found in {nrfjprog_args}: {boards}" - ) - - return ["--snr", str(self._nrfjprog_snr)] - - if not boards: - return [] - - return ["--snr", boards[0]] - - # kwargs passed to usb.core.find to find attached boards for the openocd flash runner. - BOARD_USB_FIND_KW = { - "nucleo_l4r5zi": {"idVendor": 0x0483, "idProduct": 0x374B}, - "nucleo_f746zg": {"idVendor": 0x0483, "idProduct": 0x374B}, - "stm32f746g_disco": {"idVendor": 0x0483, "idProduct": 0x374B}, - } - - def openocd_serial(self, cmake_entries): - """Find the serial port to use for a board with OpenOCD flash strategy.""" - if self._openocd_serial is not None: - return self._openocd_serial - - if self._autodetected_openocd_serial is None: - import usb # pylint: disable=import-outside-toplevel - - find_kw = self.BOARD_USB_FIND_KW[cmake_entries["BOARD"]] - boards = usb.core.find(find_all=True, **find_kw) - serials = [] - for b in boards: - serials.append(b.serial_number) - - if len(serials) == 0: - raise BoardAutodetectFailed(f"No attached USB devices matching: {find_kw!r}") - serials.sort() - - self._autodetected_openocd_serial = serials[0] - _LOG.debug("zephyr openocd driver: autodetected serial %s", serials[0]) - - return self._autodetected_openocd_serial - - def _get_openocd_device_args(self, cmake_entries): - return ["--serial", self.openocd_serial(cmake_entries)] - - @classmethod - def _get_flash_runner(cls, cmake_entries): - flash_runner = cmake_entries.get("ZEPHYR_BOARD_FLASH_RUNNER") - if flash_runner is not None: - return flash_runner - - with open(cmake_entries["ZEPHYR_RUNNERS_YAML"]) as f: - doc = yaml.load(f, Loader=yaml.FullLoader) - return doc["flash-runner"] - - def _get_device_args(self, cmake_entries): - flash_runner = self._get_flash_runner(cmake_entries) - - if flash_runner == "nrfjprog": - return self._get_nrf_device_args() - if flash_runner == "openocd": - return self._get_openocd_device_args(cmake_entries) - - raise BoardError( - f"Don't know how to find serial terminal for board {cmake_entries['BOARD']} with flash " - f"runner {flash_runner}" - ) - - def _zephyr_transport(self, micro_binary): - qemu_debugger = None - if self._debug_rpc_session: - qemu_debugger = debugger.RpcDebugger( - self._debug_rpc_session, - debugger.DebuggerFactory( - QemuGdbDebugger, - (micro_binary.abspath(micro_binary.debug_files[0]),), - {}, - ), - ) - - return ZephyrQemuTransport( - micro_binary.base_dir, startup_timeout_sec=30.0, qemu_debugger=qemu_debugger - ) - - def flash(self, micro_binary): - if self._qemu: - return self._zephyr_transport(micro_binary) - - cmake_cache_path = micro_binary.abspath(micro_binary.labelled_files["cmake_cache"][0]) - cmake_entries = read_cmake_cache(cmake_cache_path) - - build_dir = os.path.dirname(cmake_cache_path) - - # The nRF5340DK requires an additional `nrfjprog --recover` before each flash cycle. - # This is because readback protection is enabled by default when this device is flashed. - # Otherwise, flashing may fail with an error such as the following: - # ERROR: The operation attempted is unavailable due to readback protection in - # ERROR: your device. Please use --recover to unlock the device. - if ( - self._board.startswith("nrf5340dk") - and self._get_flash_runner(cmake_entries) == "nrfjprog" - ): - recover_args = ["nrfjprog", "--recover"] - recover_args.extend(self._get_nrf_device_args()) - self._subprocess_env.run(recover_args, cwd=build_dir) - - west_args = ( - self._west_cmd - + ["flash", "--build-dir", build_dir, "--skip-rebuild"] - + self._get_device_args(cmake_entries) - ) - if self._flash_args is not None: - west_args.extend(self._flash_args) - self._subprocess_env.run(west_args, cwd=build_dir) - - return self.transport(micro_binary) - - def _find_nrf_serial_port(self, cmake_entries): - com_ports = subprocess.check_output( - ["nrfjprog", "--com"] + self._get_device_args(cmake_entries), encoding="utf-8" - ) - ports_by_vcom = {} - for line in com_ports.split("\n")[:-1]: - parts = line.split() - ports_by_vcom[parts[2]] = parts[1] - - return {"port_path": ports_by_vcom["VCOM2"]} - - def _find_openocd_serial_port(self, cmake_entries): - return {"grep": self.openocd_serial(cmake_entries)} - - def _find_serial_port(self, micro_binary): - cmake_entries = read_cmake_cache( - micro_binary.abspath(micro_binary.labelled_files["cmake_cache"][0]) - ) - flash_runner = self._get_flash_runner(cmake_entries) - - if flash_runner == "nrfjprog": - return self._find_nrf_serial_port(cmake_entries) - - if flash_runner == "openocd": - return self._find_openocd_serial_port(cmake_entries) - - raise FlashRunnerNotSupported( - f"Don't know how to deduce serial port for flash runner {flash_runner}" - ) - - def transport(self, micro_binary): - """Instantiate the transport for use with non-QEMU Zephyr.""" - dt_inst = self._dtlib.DT( - micro_binary.abspath(micro_binary.labelled_files["device_tree"][0]) - ) - uart_baud = ( - dt_inst.get_node("/chosen") - .props["zephyr,console"] - .to_path() - .props["current-speed"] - .to_num() - ) - _LOG.debug("zephyr transport: found UART baudrate from devicetree: %d", uart_baud) - - port_kwargs = self._find_serial_port(micro_binary) - serial_transport = serial.SerialTransport( - timeouts=self._serial_timeouts, baudrate=uart_baud, **port_kwargs - ) - if self._debug_rpc_session is None: - return serial_transport - - return debug.DebugWrapperTransport( - debugger.RpcDebugger( - self._debug_rpc_session, - debugger.DebuggerFactory( - ZephyrDebugger, - ( - " ".join(shlex.quote(x) for x in self._west_cmd), - os.path.dirname(micro_binary.abspath(micro_binary.label("cmake_cache")[0])), - micro_binary.abspath(micro_binary.debug_files[0]), - self._zephyr_base, - ), - {}, - ), - ), - serial_transport, - ) - - -class QemuGdbDebugger(debugger.GdbDebugger): - def __init__(self, elf_file): - super(QemuGdbDebugger, self).__init__() - self._elf_file = elf_file - - def popen_kwargs(self): - # expect self._elf file to follow the form .../zephyr/zephyr.elf - cmake_cache_path = pathlib.Path(self._elf_file).parent.parent / "CMakeCache.txt" - cmake_cache = read_cmake_cache(cmake_cache_path) - return { - "args": [ - cmake_cache["CMAKE_GDB"], - "-ex", - "target remote localhost:1234", - "-ex", - f"file {self._elf_file}", - ], - } - - -class QemuStartupFailureError(Exception): - """Raised when the qemu pipe is not present within startup_timeout_sec.""" - - -class QemuFdTransport(file_descriptor.FdTransport): - """An FdTransport subclass that escapes written data to accommodate the QEMU monitor. - - It's supposedly possible to disable the monitor, but Zephyr controls most of the command-line - arguments for QEMU and there are too many options which implictly enable the monitor, so this - approach seems more robust. - """ - - def write_monitor_quit(self): - file_descriptor.FdTransport.write(self, b"\x01x", 1.0) - - def close(self): - file_descriptor.FdTransport.close(self) - - def timeouts(self): - assert False, "should not get here" - - def write(self, data, timeout_sec): - """Write data, escaping for QEMU monitor.""" - to_write = bytearray() - escape_pos = [] - for i, b in enumerate(data): - if b == 0x01: - to_write.append(b) - escape_pos.append(i) - to_write.append(b) - - num_written = file_descriptor.FdTransport.write(self, to_write, timeout_sec) - num_written -= sum(1 if x < num_written else 0 for x in escape_pos) - return num_written - - -class ZephyrQemuMakeResult(enum.Enum): - QEMU_STARTED = "qemu_started" - MAKE_FAILED = "make_failed" - EOF = "eof" - - -class ZephyrQemuTransport(Transport): - """The user-facing Zephyr QEMU transport class.""" - - def __init__(self, base_dir, startup_timeout_sec=5.0, qemu_debugger=None, **kwargs): - self.base_dir = base_dir - self.startup_timeout_sec = startup_timeout_sec - self.kwargs = kwargs - self.proc = None - self.fd_transport = None - self.pipe_dir = None - self.qemu_debugger = qemu_debugger - self._queue = queue.Queue() - - def timeouts(self): - return TransportTimeouts( - session_start_retry_timeout_sec=2.0, - session_start_timeout_sec=self.startup_timeout_sec, - session_established_timeout_sec=5.0 if self.qemu_debugger is None else 0, - ) - - def open(self): - self.pipe_dir = tempfile.mkdtemp() - self.pipe = os.path.join(self.pipe_dir, "fifo") - self.write_pipe = os.path.join(self.pipe_dir, "fifo.in") - self.read_pipe = os.path.join(self.pipe_dir, "fifo.out") - - os.mkfifo(self.write_pipe) - os.mkfifo(self.read_pipe) - if self.qemu_debugger is not None: - if "env" in self.kwargs: - self.kwargs["env"] = copy.copy(self.kwargs["env"]) - else: - self.kwargs["env"] = os.environ.copy() - - self.kwargs["env"]["TVM_QEMU_DEBUG"] = "1" - - self.proc = subprocess.Popen( - ["make", "run", f"QEMU_PIPE={self.pipe}"], - cwd=self.base_dir, - **self.kwargs, - stdout=subprocess.PIPE, - ) - try: - self._wait_for_qemu() - except Exception as error: - raise error - - if self.qemu_debugger is not None: - self.qemu_debugger.start() - - # NOTE: although each pipe is unidirectional, open both as RDWR to work around a select - # limitation on linux. Without this, non-blocking I/O can't use timeouts because named - # FIFO are always considered ready to read when no one has opened them for writing. - self.fd_transport = wakeup.WakeupTransport( - QemuFdTransport( - os.open(self.read_pipe, os.O_RDWR | os.O_NONBLOCK), - os.open(self.write_pipe, os.O_RDWR | os.O_NONBLOCK), - self.timeouts(), - ), - b"\xfe\xff\xfd\x03\0\0\0\0\0\x02" b"fw", - ) - self.fd_transport.open() - - def close(self): - if self.qemu_debugger is not None: - self.qemu_debugger.stop() - - if self.fd_transport is not None: - self.fd_transport.child_transport.write_monitor_quit() - self.proc.wait() - self.fd_transport.close() - self.fd_transport = None - - if self.proc is not None: - self.proc = None - - if self.pipe_dir is not None: - shutil.rmtree(self.pipe_dir) - self.pipe_dir = None - - def read(self, n, timeout_sec): - if self.fd_transport is None: - raise TransportClosedError() - return self.fd_transport.read(n, timeout_sec) - - def write(self, data, timeout_sec): - if self.fd_transport is None: - raise TransportClosedError() - return self.fd_transport.write(data, timeout_sec) - - def _qemu_check_stdout(self): - for line in self.proc.stdout: - line = str(line) - _LOG.debug(line) - if "[QEMU] CPU" in line: - self._queue.put(ZephyrQemuMakeResult.QEMU_STARTED) - else: - line = re.sub("[^a-zA-Z0-9 \n]", "", line) - pattern = r"recipe for target (\w*) failed" - if re.search(pattern, line, re.IGNORECASE): - self._queue.put(ZephyrQemuMakeResult.MAKE_FAILED) - self._queue.put(ZephyrQemuMakeResult.EOF) - - def _wait_for_qemu(self): - threading.Thread(target=self._qemu_check_stdout, daemon=True).start() - while True: - try: - item = self._queue.get(timeout=120) - except Exception: - raise TimeoutError("QEMU setup timeout.") - - if item == ZephyrQemuMakeResult.QEMU_STARTED: - break - - if item in [ZephyrQemuMakeResult.MAKE_FAILED, ZephyrQemuMakeResult.EOF]: - raise RuntimeError("QEMU setup failed.") - - raise ValueError(f"{item} not expected.") - - -class ZephyrDebugger(debugger.GdbDebugger): - """A Zephyr debugger implementation.""" - - def __init__(self, west_cmd, build_dir, elf_path, zephyr_base): - super(ZephyrDebugger, self).__init__() - self._west_cmd = shlex.split(west_cmd) - self._build_dir = build_dir - self._elf_path = elf_path - self._zephyr_base = zephyr_base - - def popen_kwargs(self): - env = dict(os.environ) - env["ZEPHYR_BASE"] = self._zephyr_base - - args = dict( - args=self._west_cmd - + [ - "debug", - "--skip-rebuild", - "--build-dir", - self._build_dir, - "--elf-file", - self._elf_path, - ], - env=env, - ) - return args diff --git a/python/tvm/micro/interface_api.py b/python/tvm/micro/interface_api.py new file mode 100644 index 000000000000..8086b1ed6554 --- /dev/null +++ b/python/tvm/micro/interface_api.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines functions for generating a C interface header""" + +import os + +from tvm.relay.backend.utils import mangle_module_name + + +def _emit_brief(header_file, module_name, description): + header_file.write("/*!\n") + header_file.write(f' * \\brief {description} for TVM module "{module_name}" \n') + header_file.write(" */\n") + + +def generate_c_interface_header(module_name, inputs, outputs, output_path): + """Generates a C interface header for a given modules inputs and outputs + + Parameters + ---------- + module_name : str + Name of the module to be used in defining structs and naming the header + inputs : list[str] + List of module input names to be placed in generated structs + outputs : list[str] + List of module output names to be placed in generated structs + output_path : str + Path to the output folder to generate the header into + + Returns + ------- + str : + Name of the generated file. + """ + mangled_name = mangle_module_name(module_name) + metadata_header = os.path.join(output_path, f"{mangled_name}.h") + with open(metadata_header, "w") as header_file: + header_file.write( + "#include \n" + f"#ifndef {mangled_name.upper()}_H_\n" + f"#define {mangled_name.upper()}_H_\n" + ) + + _emit_brief(header_file, module_name, "Input tensor pointers") + header_file.write(f"struct {mangled_name}_inputs {{\n") + for input_name in inputs: + header_file.write(f" void* {input_name};\n") + header_file.write("};\n\n") + + _emit_brief(header_file, module_name, "Output tensor pointers") + header_file.write(f"struct {mangled_name}_outputs {{\n") + for output_name in outputs: + header_file.write(f" void* {output_name};\n") + header_file.write("};\n\n") + + header_file.write( + "/*!\n" + f' * \\brief entrypoint function for TVM module "{module_name}"\n' + " * \\param inputs Input tensors for the module \n" + " * \\param outputs Output tensors for the module \n" + " */\n" + f"int32_t {mangled_name}_run(\n" + f" struct {mangled_name}_inputs* inputs,\n" + f" struct {mangled_name}_outputs* outputs\n" + ");\n" + ) + + header_file.write(f"#endif // {mangled_name.upper()}_H_\n") + + return metadata_header diff --git a/python/tvm/micro/micro_binary.py b/python/tvm/micro/micro_binary.py deleted file mode 100644 index 74b760b67650..000000000000 --- a/python/tvm/micro/micro_binary.py +++ /dev/null @@ -1,65 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines an Artifact implementation for representing compiled micro TVM binaries.""" - -from . import artifact - - -class MicroBinary(artifact.Artifact): - """An Artifact that describes a compiled binary.""" - - ARTIFACT_TYPE = "micro_binary" - - @classmethod - def from_unarchived(cls, base_dir, labelled_files, metadata, immobile): - binary_file = labelled_files["binary_file"][0] - del labelled_files["binary_file"] - - debug_files = None - if "debug_files" in labelled_files: - debug_files = labelled_files["debug_files"] - del labelled_files["debug_files"] - - return cls( - base_dir, - binary_file, - debug_files=debug_files, - labelled_files=labelled_files, - metadata=metadata, - immobile=immobile, - ) - - def __init__( - self, - base_dir, - binary_file, - debug_files=None, - labelled_files=None, - metadata=None, - immobile=False, - ): - labelled_files = {} if labelled_files is None else dict(labelled_files) - metadata = {} if metadata is None else dict(metadata) - labelled_files["binary_file"] = [binary_file] - if debug_files is not None: - labelled_files["debug_files"] = debug_files - - super(MicroBinary, self).__init__(base_dir, labelled_files, metadata, immobile=immobile) - - self.binary_file = binary_file - self.debug_files = debug_files diff --git a/python/tvm/micro/micro_library.py b/python/tvm/micro/micro_library.py deleted file mode 100644 index 74687ede1235..000000000000 --- a/python/tvm/micro/micro_library.py +++ /dev/null @@ -1,93 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines an Artifact subclass that describes a compiled static library.""" - -from tvm.contrib import utils -from . import artifact -from . import compiler - - -class MicroLibrary(artifact.Artifact): - """An Artifact that describes a compiled static library.""" - - ARTIFACT_TYPE = "micro_library" - - @classmethod - def from_unarchived(cls, base_dir, labelled_files, metadata, immobile): - library_files = labelled_files["library_files"] - del labelled_files["library_files"] - - debug_files = None - if "debug_files" in labelled_files: - debug_files = labelled_files["debug_files"] - del labelled_files["debug_files"] - - return cls( - base_dir, - library_files, - debug_files=debug_files, - labelled_files=labelled_files, - metadata=metadata, - immobile=immobile, - ) - - def __init__( - self, - base_dir, - library_files, - debug_files=None, - labelled_files=None, - metadata=None, - immobile=False, - ): - labelled_files = {} if labelled_files is None else dict(labelled_files) - metadata = {} if metadata is None else dict(metadata) - labelled_files["library_files"] = library_files - if debug_files is not None: - labelled_files["debug_files"] = debug_files - - super(MicroLibrary, self).__init__(base_dir, labelled_files, metadata, immobile=immobile) - - self.library_files = library_files - self.debug_file = debug_files - - -def create_micro_library(output, objects, options=None): - """Create a MicroLibrary using the default compiler options. - - Parameters - ---------- - output : str - Path to the output file, expected to end in .tar. - objects : List[str] - Paths to the source files to include in the library. - options : Optional[List[str]] - If given, additional command-line flags for the compiler. - """ - temp_dir = utils.tempdir() - comp = compiler.DefaultCompiler() - output = temp_dir.relpath("micro-library.o") - comp.library(output, objects, options=options) - - with open(output, "rb") as output_f: - elf_data = output_f.read() - - # TODO(areusch): Define a mechanism to determine compiler and linker flags for each lib - # enabled by the target str, and embed here. - micro_lib = MicroLibrary("", elf_data, {"target": comp.target.str()}) - micro_lib.save(output) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 87c067051f82..ed44a3336a52 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -25,7 +25,9 @@ import tarfile import typing +from tvm.ir.type import TupleType from .._ffi import get_global_func +from .interface_api import generate_c_interface_header from ..contrib import utils from ..driver import build_module from ..runtime import ndarray as _nd @@ -55,7 +57,6 @@ def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None): """ dso_modules = mod._collect_dso_modules() - dso_module_handles = [m.handle.value for m in dso_modules] non_dso_modules = mod._collect_from_import_tree(lambda m: m not in dso_modules) if non_dso_modules: raise UnsupportedInModelLibraryFormatError( @@ -169,9 +170,14 @@ def _build_function_memory_map(function_metadata): target_local_entries[func_name] = list() for func_name, finfo in function_metadata.items(): - if func_name == MAIN_FUNC_NAME_STR: + # Skip a few unsupported cases: + # 1. The main function metadata is exported elsewhere. + # 2. BYOC operator implementations do not currently export useful FunctionInfo. + if func_name == MAIN_FUNC_NAME_STR or not finfo.tir_primfuncs: continue - assert len(finfo.constant_sizes.items()) == num_targets + assert ( + len(finfo.constant_sizes.items()) == num_targets + ), f"{func_name}: found {finfo.constant_sizes!r} vs {num_targets}" assert len(finfo.io_sizes.items()) == num_targets target = finfo.workspace_sizes.items()[i][0] workspace_size = finfo.workspace_sizes.items()[i][1] @@ -213,6 +219,39 @@ def _build_function_memory_map(function_metadata): return ret +def _get_main_relay_func(mod: executor_factory.ExecutorFactoryModule): + main_func = mod.function_metadata[MAIN_FUNC_NAME_STR] + target = list(main_func.relay_primfuncs.keys())[0] + return main_func.relay_primfuncs[target] + + +def _convert_tuple_to_outputs(ret_type, offset=0): + outputs = [] + added_fields = len(ret_type.fields) + for output_index in range(added_fields): + next_output = offset + len(outputs) + if isinstance(ret_type.fields[output_index], TupleType): + outputs.extend(_convert_tuple_to_outputs(ret_type.fields[output_index], next_output)) + else: + outputs.append(f"output{next_output}") + return outputs + + +def _get_inputs_and_outputs_from_module(mod): + main_func = _get_main_relay_func(mod) + inputs = [argument.name_hint for argument in main_func.params] + + outputs = ["output"] + if isinstance(main_func.ret_type, TupleType): + outputs = _convert_tuple_to_outputs(main_func.ret_type) + + return inputs, outputs + + +def _should_generate_interface_header(mod): + return any(target.attrs.get("interface-api") == "c" for target in mod.target.values()) + + def _make_tar(source_dir, tar_file_path): """Build a tar file from source_dir.""" with tarfile.open(tar_file_path, "w") as tar_f: @@ -225,7 +264,7 @@ def reset(tarinfo): tar_f.add(str(source_dir), arcname=".", filter=reset) -_GENERATED_VERSION = 4 +_GENERATED_VERSION = 5 def _export_graph_model_library_format( @@ -241,7 +280,7 @@ def _export_graph_model_library_format( Temporary directory to populate with Model Library Format contents. """ is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule) - runtime = ["aot"] if is_aot else ["graph"] + executor = ["aot"] if is_aot else ["graph"] metadata = { "version": _GENERATED_VERSION, @@ -249,7 +288,7 @@ def _export_graph_model_library_format( "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), "memory": _build_memory_map(mod), "target": {int(k): str(v) for k, v in mod.target.items()}, - "runtimes": runtime, + "executors": executor, "style": "full-model", } @@ -260,6 +299,12 @@ def _export_graph_model_library_format( codegen_dir.mkdir() _populate_codegen_dir(mod.lib, codegen_dir, mod.libmod_name) + if _should_generate_interface_header(mod): + include_path = codegen_dir / "host" / "include" + include_path.mkdir() + inputs, outputs = _get_inputs_and_outputs_from_module(mod) + generate_c_interface_header(mod.libmod_name, inputs, outputs, include_path) + parameters_dir = tempdir / "parameters" parameters_dir.mkdir() param_filename = parameters_dir / f"{mod.libmod_name}.params" @@ -272,7 +317,7 @@ def _export_graph_model_library_format( f.write(str(mod.ir_mod)) if not is_aot: - graph_config_dir = tempdir / "runtime-config" / "graph" + graph_config_dir = tempdir / "executor-config" / "graph" graph_config_dir.mkdir(parents=True) with open(graph_config_dir / "graph.json", "w") as f: f.write(mod.get_executor_config()) @@ -363,7 +408,7 @@ def _export_operator_model_library_format(mod: build_module.OperatorModule, temp "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), "memory": memory_map, "target": {k: str(v) for k, v in targets.items()}, - "runtimes": [], + "executors": [], "style": "operator", } with open(tempdir / "metadata.json", "w") as metadata_f: diff --git a/python/tvm/micro/project.py b/python/tvm/micro/project.py new file mode 100644 index 000000000000..8d1408c679fb --- /dev/null +++ b/python/tvm/micro/project.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines glue wrappers around the Project API which mate to TVM interfaces.""" + +import pathlib +import typing + +from .. import __version__ +from ..contrib import utils +from .build import get_standalone_crt_dir +from .model_library_format import ExportableModule, export_model_library_format +from .project_api import client +from .transport import Transport, TransportTimeouts + + +class ProjectTransport(Transport): + """A Transport implementation that uses the Project API client.""" + + def __init__(self, api_client, options): + self._api_client = api_client + self._options = options + self._timeouts = None + + def timeouts(self): + assert self._timeouts is not None, "Transport not yet opened" + return self._timeouts + + def open(self): + reply = self._api_client.open_transport(self._options) + self._timeouts = TransportTimeouts(**reply["timeouts"]) + + def close(self): + if not self._api_client.is_shutdown: + self._api_client.close_transport() + self._api_client.shutdown() + + def write(self, data, timeout_sec): + self._api_client.write_transport(data, timeout_sec) + + def read(self, n, timeout_sec): + return self._api_client.read_transport(n, timeout_sec)["data"] + + +class TemplateProjectError(Exception): + """Raised when the Project API server given to GeneratedProject reports is_template=True.""" + + +class GeneratedProject: + """Defines a glue interface to interact with a generated project through the API server.""" + + @classmethod + def from_directory(cls, project_dir: typing.Union[pathlib.Path, str], options: dict): + return cls(client.instantiate_from_dir(project_dir), options) + + def __init__(self, api_client, options): + self._api_client = api_client + self._options = options + self._info = self._api_client.server_info_query(__version__) + if self._info["is_template"]: + raise TemplateProjectError() + + def build(self): + self._api_client.build(self._options) + + def flash(self): + self._api_client.flash(self._options) + + def transport(self): + return ProjectTransport(self._api_client, self._options) + + +class NotATemplateProjectError(Exception): + """Raised when the API server given to TemplateProject reports is_template=false.""" + + +class TemplateProject: + """Defines a glue interface to interact with a template project through the API Server.""" + + @classmethod + def from_directory(cls, template_project_dir, options): + return cls(client.instantiate_from_dir(template_project_dir), options) + + def __init__(self, api_client, options): + self._api_client = api_client + self._options = options + self._info = self._api_client.server_info_query(__version__) + if not self._info["is_template"]: + raise NotATemplateProjectError() + + def generate_project(self, graph_executor_factory, project_dir): + """Generate a project given GraphRuntimeFactory.""" + model_library_dir = utils.tempdir() + model_library_format_path = model_library_dir.relpath("model.tar") + export_model_library_format(graph_executor_factory, model_library_format_path) + + self._api_client.generate_project( + model_library_format_path=model_library_format_path, + standalone_crt_dir=get_standalone_crt_dir(), + project_dir=project_dir, + options=self._options, + ) + + return GeneratedProject.from_directory(project_dir, self._options) + + +def generate_project( + template_project_dir: typing.Union[pathlib.Path, str], + module: ExportableModule, + generated_project_dir: typing.Union[pathlib.Path, str], + options: dict = None, +): + """Generate a project for an embedded platform that contains the given model. + + Parameters + ---------- + template_project_path : pathlib.Path or str + Path to a template project containing a microTVM Project API server. + + generated_project_path : pathlib.Path or str + Path to a directory to be created and filled with the built project. + + module : ExportableModule + A runtime.Module exportable as Model Library Format. The value returned from tvm.relay.build + or tvm.build. + + options : dict + If given, Project API options given to the microTVM API server found in both + template_project_path and generated_project_path. + + Returns + ------- + GeneratedProject : + A class that wraps the generated project and which can be used to further interact with it. + """ + template = TemplateProject.from_directory(str(template_project_dir), options) + return template.generate_project(module, str(generated_project_dir)) diff --git a/python/tvm/micro/project_api/client.py b/python/tvm/micro/project_api/client.py new file mode 100644 index 000000000000..f650ad946d87 --- /dev/null +++ b/python/tvm/micro/project_api/client.py @@ -0,0 +1,235 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import base64 +import io +import json +import logging +import os +import pathlib +import subprocess +import sys +import typing + +from . import server + +_LOG = logging.getLogger(__name__) + + +class ProjectAPIErrorBase(Exception): + """Base class for all Project API errors.""" + + +class ConnectionShutdownError(ProjectAPIErrorBase): + """Raised when a request is made but the connection has been closed.""" + + +class MalformedReplyError(ProjectAPIErrorBase): + """Raised when the server responds with an invalid reply.""" + + +class MismatchedIdError(ProjectAPIErrorBase): + """Raised when the reply ID does not match the request.""" + + +class ProjectAPIServerNotFoundError(ProjectAPIErrorBase): + """Raised when the Project API server can't be found in the repo.""" + + +class UnsupportedProtocolVersionError(ProjectAPIErrorBase): + """Raised when the protocol version returned by the API server is unsupported.""" + + +class RPCError(ProjectAPIErrorBase): + def __init__(self, request, error): + self.request = request + self.error = error + + def __str__(self): + return f"Calling project API method {self.request['method']}:" "\n" f"{self.error}" + + +class ProjectAPIClient: + """A client for the Project API.""" + + def __init__( + self, + read_file: typing.BinaryIO, + write_file: typing.BinaryIO, + testonly_did_write_request: typing.Optional[typing.Callable] = None, + ): + self.read_file = io.TextIOWrapper(read_file, encoding="UTF-8", errors="strict") + self.write_file = io.TextIOWrapper( + write_file, encoding="UTF-8", errors="strict", write_through=True + ) + self.testonly_did_write_request = testonly_did_write_request + self.next_request_id = 1 + + @property + def is_shutdown(self): + return self.read_file is None + + def shutdown(self): + if self.is_shutdown: + return + + self.read_file.close() + self.write_file.close() + + def _request_reply(self, method, params): + if self.is_shutdown: + raise ConnectionShutdownError("connection already closed") + + request = { + "jsonrpc": "2.0", + "method": method, + "params": params, + "id": self.next_request_id, + } + self.next_request_id += 1 + + request_str = json.dumps(request) + self.write_file.write(request_str) + _LOG.debug("send -> %s", request_str) + self.write_file.write("\n") + if self.testonly_did_write_request: + self.testonly_did_write_request() # Allow test to assert on server processing. + reply_line = self.read_file.readline() + _LOG.debug("recv <- %s", reply_line) + if not reply_line: + self.shutdown() + raise ConnectionShutdownError("got EOF reading reply from API server") + + reply = json.loads(reply_line) + + if reply.get("jsonrpc") != "2.0": + raise MalformedReplyError( + f"Server reply should include 'jsonrpc': '2.0'; " + f"saw jsonrpc={reply.get('jsonrpc')!r}" + ) + + if reply["id"] != request["id"]: + raise MismatchedIdError( + f"Reply id ({reply['id']}) does not equal request id ({request['id']}" + ) + + if "error" in reply: + raise server.JSONRPCError.from_json(f"calling method {method}", reply["error"]) + elif "result" not in reply: + raise MalformedReplyError(f"Expected 'result' key in server reply, got {reply!r}") + + return reply["result"] + + def server_info_query(self, tvm_version: str): + reply = self._request_reply("server_info_query", {"tvm_version": tvm_version}) + if reply["protocol_version"] != server.ProjectAPIServer._PROTOCOL_VERSION: + raise UnsupportedProtocolVersionError( + f'microTVM API Server supports protocol version {reply["protocol_version"]}; ' + f"want {server.ProjectAPIServer._PROTOCOL_VERSION}" + ) + + return reply + + def generate_project( + self, + model_library_format_path: str, + standalone_crt_dir: str, + project_dir: str, + options: dict = None, + ): + return self._request_reply( + "generate_project", + { + "model_library_format_path": model_library_format_path, + "standalone_crt_dir": standalone_crt_dir, + "project_dir": project_dir, + "options": (options if options is not None else {}), + }, + ) + + def build(self, options: dict = None): + return self._request_reply("build", {"options": (options if options is not None else {})}) + + def flash(self, options: dict = None): + return self._request_reply("flash", {"options": (options if options is not None else {})}) + + def open_transport(self, options: dict = None): + return self._request_reply( + "open_transport", {"options": (options if options is not None else {})} + ) + + def close_transport(self): + return self._request_reply("close_transport", {}) + + def read_transport(self, n, timeout_sec): + reply = self._request_reply("read_transport", {"n": n, "timeout_sec": timeout_sec}) + reply["data"] = base64.b85decode(reply["data"]) + return reply + + def write_transport(self, data, timeout_sec): + return self._request_reply( + "write_transport", + {"data": str(base64.b85encode(data), "utf-8"), "timeout_sec": timeout_sec}, + ) + + +# NOTE: windows support untested +SERVER_LAUNCH_SCRIPT_FILENAME = ( + f"launch_microtvm_api_server.{'sh' if os.system != 'win32' else '.bat'}" +) + + +SERVER_PYTHON_FILENAME = "microtvm_api_server.py" + + +def instantiate_from_dir(project_dir: typing.Union[pathlib.Path, str], debug: bool = False): + """Launch server located in project_dir, and instantiate a Project API Client connected to it.""" + args = None + + project_dir = pathlib.Path(project_dir) + + python_script = project_dir / SERVER_PYTHON_FILENAME + if python_script.is_file(): + args = [sys.executable, str(python_script)] + + launch_script = project_dir / SERVER_LAUNCH_SCRIPT_FILENAME + if launch_script.is_file(): + args = [str(launch_script)] + + if args is None: + raise ProjectAPIServerNotFoundError( + f"No Project API server found in project directory: {project_dir}" + "\n" + f"Tried: {SERVER_LAUNCH_SCRIPT_FILENAME}, {SERVER_PYTHON_FILENAME}" + ) + + api_server_read_fd, tvm_write_fd = os.pipe() + tvm_read_fd, api_server_write_fd = os.pipe() + + args.extend(["--read-fd", str(api_server_read_fd), "--write-fd", str(api_server_write_fd)]) + if debug: + args.append("--debug") + + api_server_proc = subprocess.Popen( + args, bufsize=0, pass_fds=(api_server_read_fd, api_server_write_fd), cwd=project_dir + ) + os.close(api_server_read_fd) + os.close(api_server_write_fd) + + return ProjectAPIClient( + os.fdopen(tvm_read_fd, "rb", buffering=0), os.fdopen(tvm_write_fd, "wb", buffering=0) + ) diff --git a/python/tvm/micro/project_api/server.py b/python/tvm/micro/project_api/server.py new file mode 100644 index 000000000000..144f0cb6dee1 --- /dev/null +++ b/python/tvm/micro/project_api/server.py @@ -0,0 +1,776 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines a basic Project API server template. + +This file is meant to be imported or copied into Project API servers, so it should not have any +imports or dependencies outside of things strictly required to run the API server. +""" + +import abc +import argparse +import base64 +import collections +import enum +import io +import json +import logging +import os +import pathlib +import re +import select +import sys +import textwrap +import time +import traceback +import typing + + +_LOG = logging.getLogger(__name__) + + +_ProjectOption = collections.namedtuple("ProjectOption", ("name", "choices", "help")) + + +class ProjectOption(_ProjectOption): + def __new__(cls, name, **kw): + """Override __new__ to force all options except name to be specified as kwargs.""" + assert "name" not in kw + kw["name"] = name + kw.setdefault("choices", None) + return super().__new__(cls, **kw) + + +ServerInfo = collections.namedtuple( + "ServerInfo", ("platform_name", "is_template", "model_library_format_path", "project_options") +) + + +# Timeouts supported by the underlying C++ MicroSession. +# +# session_start_retry_timeout_sec : float +# Number of seconds to wait for the device to send a kSessionStartReply after sending the +# initial session start message. After this time elapses another +# kSessionTerminated-kSessionStartInit train is sent. 0 disables this. +# session_start_timeout_sec : float +# Total number of seconds to wait for the session to be established. After this time, the +# client gives up trying to establish a session and raises an exception. +# session_established_timeout_sec : float +# Number of seconds to wait for a reply message after a session has been established. 0 +# disables this. +TransportTimeouts = collections.namedtuple( + "TransportTimeouts", + [ + "session_start_retry_timeout_sec", + "session_start_timeout_sec", + "session_established_timeout_sec", + ], +) + + +class ErrorCode(enum.IntEnum): + """Enumerates error codes which can be returned. Includes JSON-RPC standard and custom codes.""" + + # Custom (in reserved error code space). + SERVER_ERROR = -32000 # A generic error was raised while processing the request. + + # JSON-RPC standard + PARSE_ERROR = -32700 + INVALID_REQUEST = -32600 + METHOD_NOT_FOUND = -32601 + INVALID_PARAMS = -32602 + INTERNAL_ERROR = -32603 + + +class JSONRPCError(Exception): + """An error class with properties that meet the JSON-RPC error spec.""" + + def __init__(self, code, message, data, client_context=None): + self.code = code + self.message = message + self.data = data + self.client_context = client_context + + def to_json(self): + return { + "code": self.code, + "message": self.message, + "data": self.data, + } + + def __str__(self): + data_str = "" + if self.data: + if isinstance(self.data, dict) and self.data.get("traceback"): + data_str = f'\n{self.data["traceback"]}' + else: + data_str = f"\n{self.data!r}" + return f"JSON-RPC error # {self.code}: {self.message}" + data_str + + @classmethod + def from_json(cls, client_context, json_error): + # Subclasses of ServerError capture exceptions that occur in the Handler, and thus return a + # traceback. The encoding in `json_error` is also slightly different to allow the specific subclass + # to be identified. + found_server_error = False + try: + if ErrorCode(json_error["code"]) == ErrorCode.SERVER_ERROR: + found_server_error = True + except ValueError: + ServerError.from_json(client_context, json_error) + + if found_server_error: + return ServerError.from_json(client_context, json_error) + + return cls( + json_error["code"], + json_error["message"], + json_error.get("data", None), + client_context=client_context, + ) + + +class ServerError(JSONRPCError): + @classmethod + def from_exception(cls, exc, **kw): + to_return = cls(**kw) + to_return.set_traceback(traceback.TracebackException.from_exception(exc).format()) + return to_return + + def __init__(self, message=None, data=None, client_context=None): + if self.__class__ == ServerError: + assert message is not None, "Plain ServerError must have message=" + else: + assert ( + message is None + ), f"ServerError subclasses must not supply message=; got {message!r}" + message = self.__class__.__name__ + + super(ServerError, self).__init__(ErrorCode.SERVER_ERROR, message, data) + self.client_context = client_context + + def __str__(self): + context_str = f"{self.client_context}: " if self.client_context is not None else "" + super_str = super(ServerError, self).__str__() + return context_str + super_str + + def set_traceback(self, traceback): + if self.data is None: + self.data = {} + + if "traceback" not in self.data: + # NOTE: TVM's FFI layer reorders Python stack traces several times and strips + # intermediary lines that start with "Traceback". This logic adds a comment to the first + # stack frame to explicitly identify the first stack frame line that occurs on the server. + traceback_list = list(traceback) + + # The traceback list contains one entry per stack frame, and each entry contains 1-2 lines: + # File "path/to/file", line 123, in : + # + # We want to place a comment on the first line of the outermost frame to indicate this is the + # server-side stack frame. + first_frame_list = traceback_list[1].split("\n") + self.data["traceback"] = ( + traceback_list[0] + + f"{first_frame_list[0]} # <--- Outermost server-side stack frame\n" + + "\n".join(first_frame_list[1:]) + + "".join(traceback_list[2:]) + ) + + @classmethod + def from_json(cls, client_context, json_error): + assert json_error["code"] == ErrorCode.SERVER_ERROR + + for sub_cls in cls.__subclasses__(): + if sub_cls.__name__ == json_error["message"]: + return sub_cls( + data=json_error.get("data"), + client_context=client_context, + ) + + return cls( + json_error["message"], data=json_error.get("data"), client_context=client_context + ) + + +class TransportClosedError(ServerError): + """Raised when a transport can no longer be used due to underlying I/O problems.""" + + +class IoTimeoutError(ServerError): + """Raised when the I/O operation could not be completed before the timeout. + + Specifically: + - when no data could be read before the timeout + - when some of the write data could be written before the timeout + + Note the asymmetric behavior of read() vs write(), since in one case the total length of the + data to transfer is known. + """ + + +class UnsupportedTVMVersionError(ServerError): + """Raised when the version of TVM supplied to server_info_query is unsupported.""" + + +class ProjectAPIHandler(metaclass=abc.ABCMeta): + """The interface class for all Project API implementations. + + Extend this class in your microtvm_api_server.py and implement each function defined here. + """ + + @abc.abstractmethod + def server_info_query(self, tvm_version: str) -> ServerInfo: + """Initial request issued by TVM to retrieve metadata about this API server and project. + + Should this API server not + + Parameters + ---------- + tvm_version : str + The value of tvm.__version__. + + Returns + ------- + ServerInfo : + A ServerInfo namedtuple containing the metadata needed by TVM. + + Raises + ------ + UnsupportedTVMVersionError : + When tvm_version indicates a known-unsupported version of TVM. + """ + raise NotImplementedError() + + @abc.abstractmethod + def generate_project( + self, + model_library_format_path: pathlib.Path, + standalone_crt_dir: pathlib.Path, + project_dir: pathlib.Path, + options: dict, + ): + """Generate a project from the given artifacts, copying ourselves to that project. + + Parameters + ---------- + model_library_format_path : pathlib.Path + Path to the Model Library Format tar archive. + standalone_crt_dir : pathlib.Path + Path to the root directory of the "standalone_crt" TVM build artifact. This contains the + TVM C runtime. + project_dir : pathlib.Path + Path to a nonexistent directory which should be created and filled with the generated + project. + options : dict + Dict mapping option name to ProjectOption. + """ + raise NotImplementedError() + + @abc.abstractmethod + def build(self, options: dict): + """Build the project, enabling the flash() call to made. + + Parameters + ---------- + options : Dict[str, ProjectOption] + ProjectOption which may influence the build, keyed by option name. + """ + raise NotImplementedError() + + @abc.abstractmethod + def flash(self, options: dict): + """Program the project onto the device. + + Parameters + ---------- + options : Dict[str, ProjectOption] + ProjectOption which may influence the programming process, keyed by option name. + """ + raise NotImplementedError() + + @abc.abstractmethod + def open_transport(self, options: dict) -> TransportTimeouts: + """Open resources needed for the transport layer. + + This function might e.g. open files or serial ports needed in write_transport or read_transport. + + Calling this function enables the write_transport and read_transport calls. If the + transport is not open, this method is a no-op. + + Parameters + ---------- + options : Dict[str, ProjectOption] + ProjectOption which may influence the programming process, keyed by option name. + """ + raise NotImplementedError() + + @abc.abstractmethod + def close_transport(self): + """Close resources needed to operate the transport layer. + + This function might e.g. close files or serial ports needed in write_transport or read_transport. + + Calling this function disables the write_transport and read_transport calls. If the + transport is not open, this method is a no-op. + """ + raise NotImplementedError() + + @abc.abstractmethod + def read_transport(self, n: int, timeout_sec: typing.Union[float, type(None)]) -> bytes: + """Read data from the transport. + + Parameters + ---------- + n : int + The exact number of bytes to read from the transport. + timeout_sec : Union[float, None] + Number of seconds to wait for at least one byte to be written before timing out. If + timeout_sec is 0, write should attempt to service the request in a non-blocking fashion. + If timeout_sec is None, write should block until all `n` bytes of data can be returned. + + Returns + ------- + bytes : + Data read from the channel. Should be exactly `n` bytes long. + + Raises + ------ + TransportClosedError : + When the transport layer determines that the transport can no longer send or receive + data due to an underlying I/O problem (i.e. file descriptor closed, cable removed, etc). + + IoTimeoutError : + When `timeout_sec` elapses without receiving any data. + """ + raise NotImplementedError() + + @abc.abstractmethod + def write_transport(self, data: bytes, timeout_sec: float): + """Write data to the transport. + + This function should either write all bytes in `data` or raise an exception. + + Parameters + ---------- + data : bytes + The data to write over the channel. + timeout_sec : Union[float, None] + Number of seconds to wait for all bytes to be written before timing out. If timeout_sec + is 0, write should attempt to service the request in a non-blocking fashion. If + timeout_sec is None, write should block until it has written all data. + + Raises + ------ + TransportClosedError : + When the transport layer determines that the transport can no longer send or receive + data due to an underlying I/O problem (i.e. file descriptor closed, cable removed, etc). + + IoTimeoutError : + When `timeout_sec` elapses without receiving any data. + """ + raise NotImplementedError() + + +class ProjectAPIServer: + """Base class for Project API Servers. + + This API server implements communication using JSON-RPC 2.0: https://www.jsonrpc.org/specification + + Suggested use of this class is to import this module or copy this file into Project Generator + implementations, then instantiate it with server.start(). + + This RPC server is single-threaded, blocking, and one-request-at-a-time. Don't get anxious. + """ + + _PROTOCOL_VERSION = 1 + + def __init__( + self, read_file: typing.BinaryIO, write_file: typing.BinaryIO, handler: ProjectAPIHandler + ): + """Initialize a new ProjectAPIServer. + + Parameters + ---------- + read_file : BinaryIO + A file-like object used to read binary data from the client. + write_file : BinaryIO + A file-like object used to write binary data to the client. + handler : ProjectAPIHandler + A class which extends the abstract class ProjectAPIHandler and implements the server RPC + functions. + """ + self._read_file = io.TextIOWrapper(read_file, encoding="UTF-8", errors="strict") + self._write_file = io.TextIOWrapper( + write_file, encoding="UTF-8", errors="strict", write_through=True + ) + self._handler = handler + + def serve_forever(self): + """Serve requests until no more are available.""" + has_more = True + while has_more: + has_more = self.serve_one_request() + + def serve_one_request(self): + """Read, process, and reply to a single request from read_file. + + When errors occur reading the request line or loading the request into JSON, they are + propagated to the caller (the stream is then likely corrupted and no further requests + should be served. When errors occur past this point, they are caught and send back to the + client. + + Return + ---------- + bool : + True when more data could be read from read_file, False otherwise. + """ + try: + line = self._read_file.readline() + _LOG.debug("read request <- %s", line) + if not line: + return False + + request = json.loads(line) + + except EOFError: + _LOG.error("EOF") + return False + + except Exception as exc: + _LOG.error("Caught error reading request", exc_info=1) + return False + + did_validate = False + try: + self._validate_request(request) + did_validate = True + self._dispatch_request(request) + except JSONRPCError as exc: + if isinstance(exc, ServerError): + exc.set_traceback(traceback.TracebackException.from_exception(exc).format()) + request_id = None if not did_validate else request.get("id") + self._reply_error(request_id, exc) + return did_validate + except Exception as exc: + message = "validating request" + if did_validate: + message = f"calling method {request['method']}" + + exc = ServerError.from_exception(exc, message=message) + request_id = None if not isinstance(request, dict) else request.get("id") + self._reply_error(request_id, exc) + return did_validate + + return True + + VALID_METHOD_RE = re.compile("^[a-zA-Z0-9_]+$") + + def _validate_request(self, request): + if type(request) is not dict: + raise JSONRPCError( + ErrorCode.INVALID_REQUEST, f"request: want dict; got {request!r}", None + ) + + jsonrpc = request.get("jsonrpc") + if jsonrpc != "2.0": + raise JSONRPCError( + ErrorCode.INVALID_REQUEST, f'request["jsonrpc"]: want "2.0"; got {jsonrpc!r}', None + ) + + method = request.get("method") + if type(method) != str: + raise JSONRPCError( + ErrorCode.INVALID_REQUEST, f'request["method"]: want str; got {method!r}', None + ) + + if not self.VALID_METHOD_RE.match(method): + raise JSONRPCError( + ErrorCode.INVALID_REQUEST, + f'request["method"]: should match regex {self.VALID_METHOD_RE.pattern}; got {method!r}', + None, + ) + + params = request.get("params") + if type(params) != dict: + raise JSONRPCError( + ErrorCode.INVALID_REQUEST, f'request["params"]: want dict; got {type(params)}', None + ) + + request_id = request.get("id") + if type(request_id) not in (str, int, type(None)): + raise JSONRPCError( + ErrorCode.INVALID_REQUEST, + f'request["id"]: want str, number, null; got {request_id!r}', + None, + ) + + def _dispatch_request(self, request): + method = request["method"] + + interface_method = getattr(ProjectAPIHandler, method, None) + if interface_method is None: + raise JSONRPCError( + ErrorCode.METHOD_NOT_FOUND, f'{request["method"]}: no such method', None + ) + + has_preprocessing = True + dispatch_method = getattr(self, f"_dispatch_{method}", None) + if dispatch_method is None: + dispatch_method = getattr(self._handler, method) + has_preprocessing = False + + request_params = request["params"] + params = {} + + for var_name, var_type in typing.get_type_hints(interface_method).items(): + if var_name == "self" or var_name == "return": + continue + + # NOTE: types can only be JSON-compatible types, so var_type is expected to be of type 'type'. + if var_name not in request_params: + raise JSONRPCError( + ErrorCode.INVALID_PARAMS, + f'method {request["method"]}: parameter {var_name} not given', + None, + ) + + param = request_params[var_name] + if not has_preprocessing and not isinstance(param, var_type): + raise JSONRPCError( + ErrorCode.INVALID_PARAMS, + f'method {request["method"]}: parameter {var_name}: want {var_type!r}, got {type(param)!r}', + None, + ) + + params[var_name] = param + + extra_params = [p for p in request["params"] if p not in params] + if extra_params: + raise JSONRPCError( + ErrorCode.INVALID_PARAMS, + f'{request["method"]}: extra parameters: {", ".join(extra_params)}', + None, + ) + + return_value = dispatch_method(**params) + self._write_reply(request["id"], result=return_value) + + def _write_reply(self, request_id, result=None, error=None): + reply_dict = { + "jsonrpc": "2.0", + "id": request_id, + } + + if error is not None: + assert ( + result is None + ), f"Want either result= or error=, got result={result!r} and error={error!r})" + reply_dict["error"] = error + else: + reply_dict["result"] = result + + reply_str = json.dumps(reply_dict) + _LOG.debug("write reply -> %r", reply_dict) + self._write_file.write(reply_str) + self._write_file.write("\n") + + def _reply_error(self, request_id, exception): + self._write_reply(request_id, error=exception.to_json()) + + def _dispatch_generate_project( + self, model_library_format_path, standalone_crt_dir, project_dir, options + ): + return self._handler.generate_project( + pathlib.Path(model_library_format_path), + pathlib.Path(standalone_crt_dir), + pathlib.Path(project_dir), + options, + ) + + def _dispatch_server_info_query(self, tvm_version): + query_reply = self._handler.server_info_query(tvm_version) + to_return = query_reply._asdict() + if to_return["model_library_format_path"] is not None: + to_return["model_library_format_path"] = str(to_return["model_library_format_path"]) + to_return.setdefault("protocol_version", self._PROTOCOL_VERSION) + to_return["project_options"] = [o._asdict() for o in query_reply.project_options] + return to_return + + def _dispatch_open_transport(self, options): + reply = self._handler.open_transport(options) + return {"timeouts": reply._asdict()} + + def _dispatch_read_transport(self, n, timeout_sec): + reply_data = self._handler.read_transport(n, timeout_sec) + return {"data": str(base64.b85encode(reply_data), "utf-8")} + + def _dispatch_write_transport(self, data, timeout_sec): + self._handler.write_transport(base64.b85decode(data), timeout_sec) + + +def _await_nonblocking_ready(rlist, wlist, timeout_sec=None, end_time=None): + if end_time is None: + return True + + if timeout_sec is None: + timeout_sec = max(0, end_time - time.monotonic()) + rlist, wlist, xlist = select.select(rlist, wlist, rlist + wlist, timeout_sec) + if not rlist and not wlist and not xlist: + raise IoTimeoutError() + + return True + + +def read_with_timeout(fd, n, timeout_sec): + """Read data from a file descriptor, with timeout. + + This function is intended as a helper function for implementations of ProjectAPIHandler + read_transport. Tested on Linux and OS X. Not tested on Windows. + + Parameters + ---------- + fd : int + File descriptor to read from. Must be opened in non-blocking mode (e.g. with O_NONBLOCK) + if timeout_sec is not None. + + n : int + Maximum number of bytes to read. + + timeout_sec : float or None + If not None, maximum number of seconds to wait before raising IoTimeoutError. + + Returns + ------- + bytes : + If at least one byte was received before timeout_sec, returns a bytes object with length + in [1, n]. If timeout_sec is None, returns the equivalent of os.read(fd, n). + + Raises + ------ + IoTimeoutException : + When timeout_sec is not None and that number of seconds elapses before any data is read. + """ + end_time = None if timeout_sec is None else time.monotonic() + timeout_sec + + while True: + _await_nonblocking_ready([fd], [], end_time=end_time) + try: + to_return = os.read(fd, n) + break + except BlockingIOError: + pass + + # When EOF is reached, close the file. + if not to_return: + os.close(fd) + raise TransportClosedError() + + return to_return + + +def write_with_timeout(fd, data, timeout_sec): + """Write data to a file descriptor, with timeout. + + This function is intended as a helper function for implementations of ProjectAPIHandler + write_transport. Tested on Linux and OS X. Not tested on Windows. + + Parameters + ---------- + fd : int + File descriptor to read from. Must be opened in non-blocking mode (e.g. with O_NONBLOCK) + if timeout_sec is not None. + + data : bytes + Data to write. + + timeout_sec : float or None + If not None, maximum number of seconds to wait before raising IoTimeoutError. + + Returns + ------- + int : + The number of bytes written to the file descriptor, if any bytes were written. A value + in [1, len(data)]. If timeout_sec is None, returns the equivalent of os.write(fd, data). + + Raises + ------ + IoTimeoutException : + When timeout_sec is not None and that number of seconds elapses before any data is read. + """ + end_time = None if timeout_sec is None else time.monotonic() + timeout_sec + + num_written = 0 + while data: + try: + _await_nonblocking_ready([], [fd], end_time=end_time) + except IoTimeoutError as exc: + if num_written: + return num_written + + raise exc + + num_written_this_cycle = os.write(fd, data) + + if not num_written_this_cycle: + os.close(fd) + raise base.TransportClosedError() + + data = data[num_written_this_cycle:] + num_written += num_written_this_cycle + + return num_written + + +def main(handler: ProjectAPIHandler, argv: typing.List[str] = None): + """Start a Project API server. + + Parameters + ---------- + argv : list[str] + Command-line parameters to this program. If not given, sys.argv is used. + handler : ProjectAPIHandler + Handler class that implements the API server RPC calls. + """ + if argv is None: + argv = sys.argv[1:] + + parser = argparse.ArgumentParser(description="Generic TVM Project API server entry point") + parser.add_argument( + "--read-fd", + type=int, + required=True, + help="Numeric file descriptor where RPC requests should be read.", + ) + parser.add_argument( + "--write-fd", + type=int, + required=True, + help="Numeric file descriptor where RPC replies should be written.", + ) + parser.add_argument( + "--debug", action="store_true", help="When given, configure logging at DEBUG level." + ) + args = parser.parse_args() + + logging.basicConfig(level="DEBUG" if args.debug else "INFO", stream=sys.stderr) + + read_file = os.fdopen(args.read_fd, "rb", buffering=0) + write_file = os.fdopen(args.write_fd, "wb", buffering=0) + + server = ProjectAPIServer(read_file, write_file, handler) + server.serve_forever() diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py index 78bf03379939..d4ad5b84fb76 100644 --- a/python/tvm/micro/session.py +++ b/python/tvm/micro/session.py @@ -60,8 +60,6 @@ class Session: def __init__( self, - binary=None, - flasher=None, transport_context_manager=None, session_name="micro-rpc", timeout_override=None, @@ -70,12 +68,6 @@ def __init__( Parameters ---------- - binary : MicroBinary - If given, `flasher` must also be given. During session initialization, this binary will - be flashed to the device before the transport is created. - flasher : Flasher - If given, `binary` must also be given. Used to flash `binary` during session - initialization. transport_context_manager : ContextManager[transport.Transport] If given, `flasher` and `binary` should not be given. On entry, this context manager should establish a tarnsport between this TVM instance and the device. @@ -85,8 +77,6 @@ def __init__( If given, TransportTimeouts that govern the way Receive() behaves. If not given, this is determined by calling has_flow_control() on the transport. """ - self.binary = binary - self.flasher = flasher self.transport_context_manager = transport_context_manager self.session_name = session_name self.timeout_override = timeout_override @@ -106,12 +96,11 @@ def _wrap_transport_read(self, n, timeout_microsec): return bytes([]) def _wrap_transport_write(self, data, timeout_microsec): - try: - return self.transport.write( - data, float(timeout_microsec) / 1e6 if timeout_microsec is not None else None - ) - except IoTimeoutError: - return 0 + self.transport.write( + data, float(timeout_microsec) / 1e6 if timeout_microsec is not None else None + ) + + return len(data) # TODO(areusch): delete def __enter__(self): """Initialize this session and establish an RPC session with the on-device RPC server. @@ -121,9 +110,6 @@ def __enter__(self): Session : Returns self. """ - if self.flasher is not None: - self.transport_context_manager = self.flasher.flash(self.binary) - self.transport = TransportLogger( self.session_name, self.transport_context_manager, level=logging.DEBUG ).__enter__() diff --git a/python/tvm/micro/transport/base.py b/python/tvm/micro/transport.py similarity index 84% rename from python/tvm/micro/transport/base.py rename to python/tvm/micro/transport.py index fdc7e9b2afce..8e95ff7ea77a 100644 --- a/python/tvm/micro/transport/base.py +++ b/python/tvm/micro/transport.py @@ -18,50 +18,18 @@ """Defines abstractions and implementations of the RPC transport used with micro TVM.""" import abc -import collections import logging import string import typing -_LOG = logging.getLogger(__name__) - - -class TransportClosedError(Exception): - """Raised when a transport can no longer be used due to underlying I/O problems.""" +from .project_api.server import IoTimeoutError, TransportTimeouts +from .project_api.server import TransportClosedError -class IoTimeoutError(Exception): - """Raised when the I/O operation could not be completed before the timeout. +_ = TransportClosedError # work around pylint unused-import error - Specifically: - - when no data could be read before the timeout - - when some of the write data could be written before the timeout - Note the asymmetric behavior of read() vs write(), since in one case the total length of the - data to transfer is known. - """ - - -# Timeouts supported by the underlying C++ MicroSession. -# -# session_start_retry_timeout_sec : float -# Number of seconds to wait for the device to send a kSessionStartReply after sending the -# initial session start message. After this time elapses another -# kSessionTerminated-kSessionStartInit train is sent. 0 disables this. -# session_start_timeout_sec : float -# Total number of seconds to wait for the session to be established. After this time, the -# client gives up trying to establish a session and raises an exception. -# session_established_timeout_sec : float -# Number of seconds to wait for a reply message after a session has been established. 0 -# disables this. -TransportTimeouts = collections.namedtuple( - "TransportTimeouts", - [ - "session_start_retry_timeout_sec", - "session_start_timeout_sec", - "session_established_timeout_sec", - ], -) +_LOG = logging.getLogger(__name__) def debug_transport_timeouts(session_start_retry_timeout_sec=0): @@ -263,7 +231,7 @@ def read(self, n, timeout_sec): def write(self, data, timeout_sec): timeout_str = f"{timeout_sec:5.2f}s" if timeout_sec is not None else " None " try: - bytes_written = self.child.write(data, timeout_sec) + self.child.write(data, timeout_sec) except IoTimeoutError: self.logger.log( self.level, @@ -286,14 +254,14 @@ def write(self, data, timeout_sec): ) raise err - hex_lines = self._to_hex(data[:bytes_written]) + hex_lines = self._to_hex(data) if len(hex_lines) > 1: self.logger.log( self.level, "%s: write {%s} <- [%3d B]:\n%s", self.name, timeout_str, - bytes_written, + len(data), "\n".join(hex_lines), ) else: @@ -302,11 +270,9 @@ def write(self, data, timeout_sec): "%s: write {%s} <- [%3d B]: %s", self.name, timeout_str, - bytes_written, + len(data), hex_lines[0], ) - return bytes_written - TransportContextManager = typing.ContextManager[Transport] diff --git a/python/tvm/micro/transport/__init__.py b/python/tvm/micro/transport/__init__.py deleted file mode 100644 index dffe9ae32792..000000000000 --- a/python/tvm/micro/transport/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines abstractions and implementations related to the microTVM RPC transport layer.""" - -from .base import IoTimeoutError -from .base import Transport -from .base import TransportClosedError -from .base import TransportLogger -from .base import TransportTimeouts -from .base import debug_transport_timeouts -from .debug import DebugWrapperTransport -from .subprocess import SubprocessTransport diff --git a/python/tvm/micro/transport/debug.py b/python/tvm/micro/transport/debug.py deleted file mode 100644 index 71e12c7ed391..000000000000 --- a/python/tvm/micro/transport/debug.py +++ /dev/null @@ -1,64 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines a wrapper Transport class that launches a debugger before opening.""" - -from .base import Transport, TransportTimeouts - - -class DebugWrapperTransport(Transport): - """A Transport wrapper class that launches a debugger before opening the transport. - - This is primiarly useful when debugging the other end of a SubprocessTransport. It allows you - to pipe data through the GDB process to drive the subprocess with a debugger attached. - """ - - def __init__(self, debugger, transport, disable_session_start_retry=False): - self.debugger = debugger - self.transport = transport - self.disable_session_start_retry = disable_session_start_retry - - def timeouts(self): - child_timeouts = self.transport.timeouts() - return TransportTimeouts( - session_start_retry_timeout_sec=( - 0 - if self.disable_session_start_retry - else child_timeouts.session_start_retry_timeout_sec - ), - session_start_timeout_sec=0, - session_established_timeout_sec=0, - ) - - def open(self): - self.debugger.start() - - try: - self.transport.open() - except Exception: - self.debugger.stop() - raise - - def write(self, data, timeout_sec): - return self.transport.write(data, timeout_sec) - - def read(self, n, timeout_sec): - return self.transport.read(n, timeout_sec) - - def close(self): - self.transport.close() - self.debugger.stop() diff --git a/python/tvm/micro/transport/file_descriptor.py b/python/tvm/micro/transport/file_descriptor.py deleted file mode 100644 index 58c4026f6704..000000000000 --- a/python/tvm/micro/transport/file_descriptor.py +++ /dev/null @@ -1,119 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines an implementation of Transport that uses file descriptors.""" - -import fcntl -import os -import select -import time -from . import base - - -class FdConfigurationError(Exception): - """Raised when specified file descriptors can't be placed in non-blocking mode.""" - - -class FdTransport(base.Transport): - """A Transport implementation that implements timeouts using non-blocking I/O.""" - - @classmethod - def _validate_configure_fd(cls, file_descriptor): - file_descriptor = ( - file_descriptor if isinstance(file_descriptor, int) else file_descriptor.fileno() - ) - flag = fcntl.fcntl(file_descriptor, fcntl.F_GETFL) - if flag & os.O_NONBLOCK != 0: - return file_descriptor - - fcntl.fcntl(file_descriptor, fcntl.F_SETFL, os.O_NONBLOCK | flag) - new_flag = fcntl.fcntl(file_descriptor, fcntl.F_GETFL) - if (new_flag & os.O_NONBLOCK) == 0: - raise FdConfigurationError( - f"Cannot set file descriptor {file_descriptor} to non-blocking" - ) - return file_descriptor - - def __init__(self, read_fd, write_fd, timeouts): - self.read_fd = self._validate_configure_fd(read_fd) - self.write_fd = self._validate_configure_fd(write_fd) - self._timeouts = timeouts - - def timeouts(self): - return self._timeouts - - def open(self): - pass - - def close(self): - if self.read_fd is not None: - os.close(self.read_fd) - self.read_fd = None - - if self.write_fd is not None: - os.close(self.write_fd) - self.write_fd = None - - def _await_ready(self, rlist, wlist, timeout_sec=None, end_time=None): - if end_time is None: - return True - - if timeout_sec is None: - timeout_sec = max(0, end_time - time.monotonic()) - rlist, wlist, xlist = select.select(rlist, wlist, rlist + wlist, timeout_sec) - if not rlist and not wlist and not xlist: - raise base.IoTimeoutError() - - return True - - def read(self, n, timeout_sec): - if self.read_fd is None: - raise base.TransportClosedError() - - end_time = None if timeout_sec is None else time.monotonic() + timeout_sec - - while True: - self._await_ready([self.read_fd], [], end_time=end_time) - try: - to_return = os.read(self.read_fd, n) - break - except BlockingIOError: - pass - - if not to_return: - self.close() - raise base.TransportClosedError() - - return to_return - - def write(self, data, timeout_sec): - if self.write_fd is None: - raise base.TransportClosedError() - - end_time = None if timeout_sec is None else time.monotonic() + timeout_sec - - data_len = len(data) - while data: - self._await_ready(end_time, [], [self.write_fd]) - num_written = os.write(self.write_fd, data) - if not num_written: - self.close() - raise base.TransportClosedError() - - data = data[num_written:] - - return data_len diff --git a/python/tvm/micro/transport/serial.py b/python/tvm/micro/transport/serial.py deleted file mode 100644 index dc107d68abc2..000000000000 --- a/python/tvm/micro/transport/serial.py +++ /dev/null @@ -1,135 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines a Transport implementation using pyserial.""" - -import atexit -import time -import serial -import serial.tools.list_ports -from .base import IoTimeoutError, Transport, TransportTimeouts - - -_DEFAULT_SERIAL_TIMEOUTS = TransportTimeouts( - session_start_retry_timeout_sec=5, - session_start_timeout_sec=10.0, - session_established_timeout_sec=30.0, -) - - -class SerialTransport(Transport): - """A Transport implementation using pySerial.""" - - _OPEN_PORTS = [] - - @classmethod - def close_atexit(cls): - """Close all serial ports before exit. - - Some USB-UART kernel drivers are particularly sensitive to being left open (i.e. require - unplugging and replugging of attached hardware or reboot of machine); try very hard to - close all serial ports at exit. - """ - for port in cls._OPEN_PORTS: - try: - port.close() - except Exception: # pylint: disable=broad-except - _LOG.warn("exception closing port", exc_info=True) - - cls._OPEN_PORTS = [] - - def __init__(self, grep=None, port_path=None, timeouts=None, **kw): - self._port_path = port_path - self._grep = grep - self._timeouts = timeouts if timeouts is not None else _DEFAULT_SERIAL_TIMEOUTS - self._kw = kw - if self._port_path is None and self._grep is None: - raise SerialPortNotFoundError("Must specify one of grep= or port_path=") - - def timeouts(self): - return self._timeouts - - def open(self): - if self._port_path is not None: - port_path = self._port_path - else: - ports = list(serial.tools.list_ports.grep(self._grep)) - if len(ports) != 1: - raise SerialPortNotFoundError( - f"grep expression should find 1 serial port; found {ports!r}" - ) - - port_path = ports[0].device - - self._port = serial.Serial(port_path, timeout=0.1, exclusive=True, **self._kw) - self._port.cancel_read() - self._port.reset_input_buffer() - self._port.reset_output_buffer() - self._OPEN_PORTS.append(self._port) - - def close(self): - if self._port is None: - return - - self._port.close() - self._OPEN_PORTS.remove(self._port) - self._port = None - - def read(self, n, timeout_sec): - if timeout_sec is None: - self._port.timeout = None - in_waiting = self._port.in_waiting - if in_waiting > 0: - return self._port.read(min(n, in_waiting)) - return self._port.read(1) - - end_time = time.monotonic() + timeout_sec - to_return = bytearray() - while True: - timeout_remaining = end_time - time.monotonic() - if timeout_sec != 0 and timeout_remaining < 0: - break - - # Read until *something* can be returned. If nothing is sent within 5 chars' time, stop. - # 5 is an arbitrary number. - self._port.timeout = 1 / self._port.baudrate * 5 - try: - data = self._port.read(n if timeout_sec != 0 else 1) - if not data and to_return: - break - - to_return.extend(data) - except serial.SerialTimeoutException: - if to_return: - break - - if not to_return: - raise IoTimeoutError() - - return to_return - - def write(self, data, timeout_sec): - self._port.write_timeout = timeout_sec - try: - to_return = self._port.write(data) - self._port.flush() - return to_return - except serial.SerialTimeoutException: - raise IoTimeoutError() - - -atexit.register(SerialTransport.close_atexit) diff --git a/python/tvm/micro/transport/subprocess.py b/python/tvm/micro/transport/subprocess.py deleted file mode 100644 index 4de1fa1266d3..000000000000 --- a/python/tvm/micro/transport/subprocess.py +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines an implementation of Transport that uses subprocesses.""" - -import subprocess -from . import base -from . import file_descriptor - - -class SubprocessFdTransport(file_descriptor.FdTransport): - def timeouts(self): - raise NotImplementedError() - - -class SubprocessTransport(base.Transport): - """A Transport implementation that uses a subprocess's stdin/stdout as the channel.""" - - def __init__(self, args, max_startup_latency_sec=5.0, max_latency_sec=5.0, **kwargs): - self.max_startup_latency_sec = max_startup_latency_sec - self.max_latency_sec = max_latency_sec - self.args = args - self.kwargs = kwargs - self.popen = None - self.child_transport = None - - def timeouts(self): - return base.TransportTimeouts( - session_start_retry_timeout_sec=0, - session_start_timeout_sec=self.max_startup_latency_sec, - session_established_timeout_sec=self.max_latency_sec, - ) - - def open(self): - self.kwargs["stdout"] = subprocess.PIPE - self.kwargs["stdin"] = subprocess.PIPE - self.kwargs["bufsize"] = 0 - self.popen = subprocess.Popen(self.args, **self.kwargs) - self.child_transport = SubprocessFdTransport( - self.popen.stdout, self.popen.stdin, self.timeouts() - ) - - def write(self, data, timeout_sec): - return self.child_transport.write(data, timeout_sec) - - def read(self, n, timeout_sec): - return self.child_transport.read(n, timeout_sec) - - def close(self): - if self.child_transport is not None: - self.child_transport.close() - - self.popen.terminate() diff --git a/python/tvm/micro/transport/wakeup.py b/python/tvm/micro/transport/wakeup.py deleted file mode 100644 index 418f8bdbb27a..000000000000 --- a/python/tvm/micro/transport/wakeup.py +++ /dev/null @@ -1,79 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines an implementation of Transport that uses subprocesses.""" - -import logging -import time -from . import base - - -_LOG = logging.getLogger(__name__) - - -class WakeupTransport(base.Transport): - """A Transport implementation that waits for a "wakeup sequence" from the remote end.""" - - def __init__(self, child_transport, wakeup_sequence): - self.child_transport = child_transport - self.wakeup_sequence = bytes(wakeup_sequence) - self.wakeup_sequence_buffer = bytearray() - self.line_start_index = 0 - self.found_wakeup_sequence = False - - def open(self): - return self.child_transport.open() - - def close(self): - return self.child_transport.close() - - def timeouts(self): - return self.child_transport.timeouts() - - def _await_wakeup(self, end_time): - def _time_remaining(): - if end_time is None: - return None - return max(0, end_time - time.monotonic()) - - if not self.found_wakeup_sequence: - while self.wakeup_sequence not in self.wakeup_sequence_buffer: - x = self.child_transport.read(1, _time_remaining()) - self.wakeup_sequence_buffer.extend(x) - if x[0] in (b"\n", b"\xff"): - _LOG.debug("%s", self.wakeup_sequence_buffer[self.line_start_index : -1]) - self.line_start_index = len(self.wakeup_sequence_buffer) - - _LOG.info("remote side woke up!") - self.found_wakeup_sequence = True - time.sleep(0.2) - - return _time_remaining() - - def read(self, n, timeout_sec): - if not self.found_wakeup_sequence: - end_time = None if timeout_sec is None else time.monotonic() + timeout_sec - timeout_sec = self._await_wakeup(end_time) - - return self.child_transport.read(n, timeout_sec) - - def write(self, data, timeout_sec): - if not self.found_wakeup_sequence: - end_time = None if timeout_sec is None else time.monotonic() + timeout_sec - timeout_sec = self._await_wakeup(end_time) - - return self.child_transport.write(data, timeout_sec) diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index a8f1a993552e..c7b6c60849a1 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -433,8 +433,7 @@ def get_calibration_data(mod, data): mod = _ffi_api.get_calibrate_module(mod) mod = transform.Inline()(mod) - ref_ex = build_module.create_executor("graph", mod=mod, device=cpu(0)) - ref_res = ref_ex.evaluate()(**data) + ref_res = build_module.create_executor("graph", mod=mod, device=cpu(0)).evaluate()(**data) calib_data = {} for gvar, indices in output_map.items(): diff --git a/python/tvm/relay/analysis/sparse_conv2d.py b/python/tvm/relay/analysis/sparse_conv2d.py index 11278bddca33..1862ded831f6 100644 --- a/python/tvm/relay/analysis/sparse_conv2d.py +++ b/python/tvm/relay/analysis/sparse_conv2d.py @@ -54,7 +54,9 @@ def _search_conv2d_op_weight(expr): return _ffi_api.search_conv2d_op_weight(expr) -def process_params(expr, params, block_size, sparsity_threshold, layout): +def process_params( + expr, params, block_size, sparsity_threshold, layout, kernel_size, reg_task_input=True +): """Process parameters of conv2d from dense to sparse. Parameters @@ -86,14 +88,18 @@ def process_params(expr, params, block_size, sparsity_threshold, layout): for name in weight_names: name = str(name) w_np = params[name].numpy() - # currently only support conv2d_1*1 - if not ( - (w_np.shape[0] == 1 and w_np.shape[1] == 1) - or (w_np.shape[2] == 1 and w_np.shape[3] == 1) - ): + + if layout == "NHWC": # HWIO + weight_kernel = (w_np.shape[0], w_np.shape[1]) + elif layout == "NCHW": # OIHW + weight_kernel = (w_np.shape[2], w_np.shape[3]) + if weight_kernel[0] != weight_kernel[1]: continue - sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size) - if sparsity >= sparsity_threshold: + + if weight_kernel[0] == kernel_size == 1: + sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size) + if sparsity < sparsity_threshold: + continue if layout == "NHWC": w_np = w_np.squeeze().T elif layout == "NCHW": @@ -108,19 +114,31 @@ def process_params(expr, params, block_size, sparsity_threshold, layout): ) else: sparse_weight_data = sparse_weight.data + elif weight_kernel[0] == kernel_size == 3: + if layout == "NHWC": # HWIO + w_np = w_np.reshape((-1, w_np.shape[-1])).T + elif layout == "NCHW": # OIHW + w_np = w_np.reshape((w_np.shape[0], -1)) + sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size) + if 1 - (sparse_weight.nnz / w_np.size) < sparsity_threshold: + continue + sparse_weight_data = sparse_weight.data + else: + continue - # remove dense weight - del params[name] - memo.weight_name.append(name) - memo.weight_shape.append( - list(sparse_weight_data.shape) - + list(sparse_weight.indices.shape) - + list(sparse_weight.indptr.shape) - ) - params[name + ".data"] = tvm.nd.array(sparse_weight_data) - params[name + ".indices"] = tvm.nd.array(sparse_weight.indices) - params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr) - + # remove dense weight + del params[name] + memo.weight_name.append(name) + memo.weight_shape.append( + list(sparse_weight_data.shape) + + list(sparse_weight.indices.shape) + + list(sparse_weight.indptr.shape) + ) + params[name + ".data"] = tvm.nd.array(sparse_weight_data) + params[name + ".indices"] = tvm.nd.array(sparse_weight.indices) + params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr) + + if reg_task_input: prefix = "sparse_conv2d_bsr_%d_%d_%d_%d_%d_%d_" % ( w_np.shape[0], w_np.shape[1], diff --git a/python/tvm/relay/backend/graph_executor_codegen.py b/python/tvm/relay/backend/graph_executor_codegen.py index 11274b97197f..58717a0ab482 100644 --- a/python/tvm/relay/backend/graph_executor_codegen.py +++ b/python/tvm/relay/backend/graph_executor_codegen.py @@ -20,14 +20,14 @@ The compiler is built from a few pieces. First we define a compiler from a single Relay expression to the -graph langauge. We require the expression to be a function. +graph language. We require the expression to be a function. The function's parameters correspond to the placeholder/inputs and model parameters found in the computation graph representation. The body of the function represents the computation graph. The compiler's output is a program in the graph language, which is composed of -graph langauge is composed of Node, NodeRef, InputNode, OpNode. -This "little language" represents programs in TVM's graph format. +Node, NodeRef, InputNode, OpNode. This "little language" represents programs in +TVM's graph format. To connect to the graph executor, we use a printer that converts our graph format into TVM's JSON format. The resulting string can be loaded by diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index b62fca86668d..819e5eda41f5 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -22,10 +22,9 @@ import tvm._ffi from tvm.runtime import container, Object -from tvm.ir import IRModule from . import _backend -from .. import _make, analysis, transform +from .. import _make, analysis from ... import nd from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const from ..function import Function @@ -178,6 +177,7 @@ def evaluate(self, expr=None, binds=None): return self._make_executor(expr) # normal expression evaluated by running a function. + # TODO(mbs): This should really be type rather than syntax driven. func = Function([], expr) return self._make_executor(func)() @@ -196,6 +196,23 @@ class Interpreter(Executor): target : tvm.Target The target option to build the function. + + CAUTION: Despite the API the module is prepared upon each call to evaluate + rather than once in create_executor. + That is: + .. code-block:: python + + executor = relay.create_executor(kind="debug", mod=module) + a = executor.evaluate(expr)(args1) + b = executor.evaluate(expr)(args2) + + will prepare all the bindings in module twice. For efficiency, try to hoist + calls to evaluate as high as possible, preferably immediately after create_executor: + .. code-block:: python + + func = relay.create_executor(kind="debug", mod=module).evaluate(expr) + a = func(args1) + b = func(args2) """ def __init__(self, mod, device, target): @@ -203,57 +220,30 @@ def __init__(self, mod, device, target): self.device = device self.target = target - def optimize(self): - """Optimize functions in a module. - - Returns - ------- - opt_mod : tvm.IRModule - The optimized module. - """ - seq = tvm.transform.Sequential( - [ - # tvm.parser.AnnotateSpans(), - transform.SimplifyInference(), - transform.FuseOps(0), - transform.ToANormalForm(), - transform.InferType(), - ] - ) - mod = seq(self.mod) - return mod - def _make_executor(self, expr=None): if expr is None or isinstance(expr, GlobalVar): assert self.mod is not None - def _interp_wrapper(*args, **kwargs): - if expr is None: - args = self._convert_args(self.mod["main"], args, kwargs) + if expr is None: + # A missing expr denotes 'main' in the given module. + expr = self.mod.get_global_var("main") + + # Evaluate expr to a packed function we can efficiently re-apply + # to Relay arguments. + func = _backend.EvalFunction(self.mod, expr, self.device, self.target) + + def _apply_args(*args, **kwargs): + if isinstance(expr, GlobalVar): + # When expanding args, look inside the actual global definition so kwargs + # can be matched. + args = self._convert_args(self.mod[expr.name_hint], args, kwargs) else: args = self._convert_args(expr, args, kwargs) - + # Reflect python arguments up into Relay. relay_args = [] for arg in args: relay_args.append(_arg_to_ast(self.mod, arg)) + # Apply func to Relay args + return func(relay_args) - # Set the entry function for the module. - if expr is None: - pass - elif isinstance(expr, GlobalVar): - self.mod["main"] = self.mod[expr] - else: - assert isinstance(expr, Function) - func = Function([], Call(expr, relay_args)) - relay_args = [] - if self.mod: - self.mod["main"] = func - else: - self.mod = IRModule.from_expr(func) - - mod = self.optimize() - opt_expr = Call(mod["main"], relay_args) - _intrp = _backend.CreateInterpreter(mod, self.device, self.target) - return _intrp(opt_expr) - - return _interp_wrapper + return _apply_args diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index aa826aee57a1..c67ac1dc423d 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -511,7 +511,10 @@ def _graph_wrapper(*args, **kwargs): return _graph_wrapper -def create_executor(kind="debug", mod=None, device=None, target="llvm"): +# TODO(mbs): Collapse the create_executor/evaluate phases together since a) most callers don't +# reuse the executor for multiple expressions and b) any preparation necessary for the expression +# evaluation needs to (currently) be done along with preparation for the module. +def create_executor(kind="debug", mod=None, device=None, target="llvm", params=None): """Factory function to create an executor. Example @@ -544,6 +547,10 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm"): target : :py:class:`tvm.Target` The corresponding context + params : dict of str to NDArray + Input parameters to the graph that do not change + during inference time. + Returns ------- executor : :py:class:`~tvm.relay.backend.interpreter.Executor` @@ -555,6 +562,9 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm"): else: device = _nd.device(str(target), 0) + if params is not None: + mod = IRModule.from_expr(bind_params_by_name(mod["main"], params)) + if isinstance(target, str): target = Target(target) if kind == "debug": diff --git a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py index 6913a428b2ac..20e01da1493e 100644 --- a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py +++ b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py @@ -23,8 +23,8 @@ from .utils import _run_opt_pass -def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"): - """Convert a dense func and according parameters to block sparse +def convert(func, params, blocksize, sparsity_threshold, layout="NHWC", kernel_size=1): + """Convert a conv2d func and according parameters to block sparse Parameters ---------- @@ -49,10 +49,46 @@ def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"): params: Dict[Srting, tvm.nd.array] New params with BSR matrix for mutated Expr """ - weight_info = process_params(func, params, blocksize, sparsity_threshold, layout) + weight_info = process_params(func, params, blocksize, sparsity_threshold, layout, kernel_size) new_func = _run_opt_pass( func, - relay.transform.Conv2dToSparse(weight_info.weight_name, weight_info.weight_shape, layout), + relay.transform.Conv2dToSparse( + weight_info.weight_name, weight_info.weight_shape, layout, kernel_size + ), ) return new_func, params + + +def convert2(func, params, blocksize, sparsity_threshold, layout, kernel_size): + """Convert a freezed conv2d func to block sparse + + Parameters + ---------- + func : relay.Expr + Expr will be optimized to sparse operation, with params freezed + params : Dict[Srting, tvm.nd.array] + Parameters of the Expr (not used in this pass) + blocksize : Tuple(int, int) + Blocksize for BSR matrix + sparsity_threshold : float + Minimal sparsity requirement for converting. + If weight sparsity is lower than this threshold, + the dense operation will be kept. + layout : str + layout of network + kernel_size : int + kernel size of the conv2d, for filtering + + Returns + ------- + new_func: relay.Expr + Mutated Expr with sparse operations + + params: Dict[Srting, tvm.nd.array] + New params with BSR matrix for mutated Expr (not modified) + """ + new_func = _run_opt_pass( + func, relay.transform.Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold) + ) + return new_func, params diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 320a599d5d91..1f6d8bb9ab0b 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -796,11 +796,14 @@ class DFPatternCallback: ---------- require_type: bool Whether InferType is required to be run before the callback. + rewrite_once: bool + If True, run the callback only once. """ - def __init__(self, require_type=False): + def __init__(self, require_type=False, rewrite_once=False): self.pattern = None self.require_type = require_type + self.rewrite_once = rewrite_once def rewrite(self, expr: Expr) -> Expr: """ @@ -842,8 +845,10 @@ def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Exp class _DFPatternCallback(Object): """C++ implemenation""" - def __init__(self, pattern, callback, require_type): - self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type) + def __init__(self, pattern, callback, require_type, rewrite_once): + self.__init_handle_by_constructor__( + ffi.DFPatternCallback, pattern, callback, require_type, rewrite_once + ) def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr: @@ -870,7 +875,11 @@ def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr: tmp = [] for callback in callbacks: assert callback.pattern is not None - tmp.append(_DFPatternCallback(callback.pattern, callback.callback, callback.require_type)) + tmp.append( + _DFPatternCallback( + callback.pattern, callback.callback, callback.require_type, callback.rewrite_once + ) + ) return ffi.rewrite(tmp, expr, mod) diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 40a116ab0b43..b9ca7d0e11f2 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -152,8 +152,10 @@ def visit_let(self, let): self.visit(let.value) self.visit(let.body) - def visit_function(self, f): - self.visit(f.body) + def visit_function(self, fn): + for x in fn.params: + self.visit(x) + self.visit(fn.body) def visit_if(self, i): self.visit(i.cond) diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index aa8ac4fc7434..aa49b63203f2 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -31,4 +31,5 @@ from .darknet import from_darknet from .pytorch import from_pytorch from .caffe import from_caffe +from .paddlepaddle import from_paddle from .change_datatype import ChangeDatatype diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 713d01ef9fda..ce048105ae8b 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -263,7 +263,7 @@ def get_relay_op(op_name): The Relay operator name. """ if "." in op_name: - # explicit hierachical modules + # explicit hierarchical modules op = _op try: for opn in op_name.split("."): @@ -545,16 +545,17 @@ def infer_value(input_val, params, mod=None): mod["main"] = _function.Function(analysis.free_vars(input_val), input_val) else: mod = IRModule.from_expr(input_val) - exc = tvm.relay.create_executor("debug", mod=mod, device=tvm.cpu(), target="llvm") inputs = [] for param in mod["main"].params: inputs.append(params[param.name_hint]) - result = exc.evaluate()(*inputs) + result = tvm.relay.create_executor( + "debug", mod=mod, device=tvm.cpu(), target="llvm" + ).evaluate()(*inputs) return result def infer_value_simulated(input_val, params): - """Extention to infer_value that can be used when some input + """Extension to infer_value that can be used when some input values are missing. This function creates dummy inputs with the same shape and random values then calls infer_value. This is helpful when implementing certain onnx operators where we need to evaluate the graph @@ -624,3 +625,212 @@ def to_int_list(np_array): cause problems in relay/TOPI. """ return [int(x) for x in np_array] + + +def unbind(data, axis=0): + """ + Unbind was taken from Pytorch frontend. The operation removes a tensor dimension + and returns a tuple of all slices along a given dimension, with specified axis removed. + TODO (vvchernov): It needs such operation on relay side to reduce time consumption + on squeeze operation. + + Parameters + ---------- + data : relay.Expr + Input tensor + axis : int + Axis along which tensor is split. + Returns + ------- + result : List[relay.Expr] + The sequence of computed tensors + """ + shape = infer_shape(data) + if axis >= len(shape): + msg = "Please check input dim, it shouldn't be greater than or equal to rank." + raise AttributeError(msg) + + selections = shape[axis] + res_split = _op.split(data, selections, axis) + ret = [] + for i in range(selections): + ret.append(_op.squeeze(res_split[i], axis=[axis])) + return _expr.TupleWrapper(_expr.Tuple(ret), selections) + + +def gru_cell( + input_seqs, + hidden_state, + w_inp, + w_hid, + b_inp=None, + b_hid=None, + rz_act=_op.sigmoid, + n_act=_op.tanh, + backwards=False, + linear_before_reset=True, +): + """ + Common implementation of GRU cell for all frontends of TVM + TODO(vvchernov): currently it is used by pytorch and ONNX. Extend for other frontends + + Parameters + ---------- + input_seqs : List[relay.Expr] + The sequence of input tensors + Input tensor should be 2d while issue #8412 is not resolved + Shape = (batch, feature_size) + hidden_state : relay.Expr + Hidden state. shape = (batch_size, hidden_size) + w_inp, w_hid : relay.Expr + weight matrices. wi shape = (3 * hidden_size, feature_size) + wh shape = (3 * hidden_size, hidden_size) + NOTE: wi = (w_ir|w_iz|w_in) for reset, update and new gates. + The order is important for correct GRU calculation! + b_inp, b_hid : relay.Expr + bias matrices. The same order of internal parts as for weights. shape = (3 * hidden_size) + r_act : relay.op + activation funtion for reset gate. it is sigmoid by default + z_act : relay.op + activation funtion for update gate. it is sigmoid by default + n_act : relay.op + activation funtion for new gate. it is tanh by default + backwards : bool + Flag for reverse pass of GRU + + Returns + ------- + result : List[relay.Expr], relay.Expr, relay.Expr + The sequence of computed result, final hidden and cell state + """ + + outputs_list = [] + for x_t in input_seqs if not backwards else reversed(input_seqs): + xwt = _op.nn.dense(x_t, w_inp) + if linear_before_reset: + hwt = _op.nn.dense(hidden_state, w_hid) + if b_inp is not None and b_hid is not None: + xwt += b_inp + hwt += b_hid + i_r, i_z, i_n = _op.split(xwt, 3, axis=-1) + h_r, h_z, h_n = _op.split(hwt, 3, axis=-1) + r_gate = rz_act(i_r + h_r) + z_gate = rz_act(i_z + h_z) + n_gate = n_act(i_n + r_gate * h_n) + else: + i_r, i_z, i_n = _op.split(xwt, 3, axis=1) + w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0) + r_gate = i_r + _op.nn.dense(hidden_state, w_hr) + z_gate = i_z + _op.nn.dense(hidden_state, w_hz) + if b_inp is not None and b_hid is not None: + b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1) + b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1) + r_gate += b_ir + b_hr + z_gate += b_iz + b_hz + i_n += b_in + h_n = _op.nn.dense((r_gate * hidden_state), w_hn) + b_hn + else: + h_n = _op.nn.dense((r_gate * hidden_state), w_hn) + r_gate = rz_act(r_gate) + z_gate = rz_act(z_gate) + n_gate = n_act(i_n + h_n) + + hidden_state = (hidden_state - n_gate) * z_gate + n_gate + + outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)] + + return outputs_list, hidden_state + + +def lstm_cell( + input_seqs, + hidden_state, + cell_state, + w_inp, + w_hid, + b_inp=None, + b_hid=None, + proj=None, + p_i=None, + p_f=None, + p_o=None, + f_act=_op.sigmoid, + g_act=_op.tanh, + h_act=_op.tanh, + backwards=False, +): + """ + Common implementation of LSTM cell for all frontends of TVM + TODO (vvchernov): currently it is used by onnx and pytorch. Extend for other frontends + + Parameters + ---------- + input_seqs : List[relay.Expr] + The sequence of input tensors + Input tensor should be 2d while issue #8412 is not resolved + Shape = (batch, feature_size) + hidden_state : relay.Expr + Hidden state. shape = (batch, hidden_size) + cell_state : relay.Expr + Cell state. shape = (batch, hidden_size) + w_inp, w_hid : relay.Expr + weight matrices. wi shape = (4 * hidden_size, feature_size) + wh shape = (4 * hidden_size, hidden_size or proj_size) + NOTE: wi = (w_ii|w_if|w_ig|w_io) for input, forget, cell and output gates. + The order is important for correct LSTM calculation! + b_inp, b_hid : relay.Expr + bias matrices. The same order of internal parts as for weights. shape = (4 * hidden_size) + proj : relay.Expr + projection matrix. shape = (proj_size, hidden_size) + p_i, p_f, p_o : relay.Expr + peephole LSTM matrices. shape = (batch, hidden_size) + f_act, g_act, h_act : relay.op + activation funtions + backwards : bool + Flag for reverse pass of LSTM + + Returns + ------- + result : List[relay.Expr], relay.Expr, relay.Expr + The sequence of computed result, final hidden and cell state + """ + + outputs_list = [] + for x_t in input_seqs if not backwards else reversed(input_seqs): + # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size) + step = _op.concatenate([x_t, hidden_state], axis=1) + cat_w = _op.concatenate([w_inp, w_hid], axis=1) + # Instead of nn.dense(x_t, w_inp) + nn.dense(hidden_state, w_hid) + # nn.dense(step, cat_w) is used + # gates shape = (batch, 4 * hidden_size) + gates = _op.nn.dense(step, cat_w) + # Add biases + if b_inp is not None: + gates += b_inp + if b_hid is not None: + gates += b_hid + # any gate shape = (batch, hidden_size) + inp_gate, fgt_gate, cell_gate, otp_gate = _op.split(gates, 4, axis=-1) + + if p_i is not None and p_f is not None: + inp_gate = f_act(inp_gate + p_i * cell_state) + fgt_gate = f_act(fgt_gate + p_f * cell_state) + else: + inp_gate = f_act(inp_gate) + fgt_gate = f_act(fgt_gate) + + cell_gate = g_act(cell_gate) + cell_state = fgt_gate * cell_state + inp_gate * cell_gate + if p_o is not None: + otp_gate = f_act(otp_gate + p_o * cell_state) + else: + otp_gate = f_act(otp_gate) + + hidden_state = otp_gate * h_act(cell_state) + + if proj is not None: + hidden_state = _op.nn.dense(hidden_state, proj) + + outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)] + + return outputs_list, hidden_state, cell_state diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index f850750fad51..e515843e5fe2 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -562,7 +562,7 @@ def from_coreml(model, shape=None): etab = ExprTable() for i in spec.description.input: - input_shape = shape[i.name] if shape is not None and i.name in shape else None + input_shape = list(shape[i.name]) if shape is not None and i.name in shape else None etab.set_expr(i.name, _expr.var(i.name, shape=input_shape)) for pp in cc.preprocessing: diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f876b1d14fa1..9144d3e145c8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -34,6 +34,7 @@ from .. import qnn as _qnn from .. import ty as _ty from .. import vision as _vision +from .. import random as _random from .common import ( AttrCvt, Renamer, @@ -45,59 +46,33 @@ infer_type, infer_value, new_var, + unbind, + gru_cell, + lstm_cell, ) __all__ = ["from_onnx"] +# The default configurations of Relay ONNX frontend. +ONNX_DEFAULT_CONFIGS = { + # By default, TVM converts qualified onnx `matmul` to `transpose(weight) + nn.batch_matmul_NT`. + # Change this flag to False to directly convert to `nn.batch_matmul`. + # Note that `nn.batch_matmul` with format other than NT is in experimental, it may have some + # performance issues. + "use_nt_batch_matmul": True, +} -class onnx_input: - """Dual purpose list or dictionary access object.""" - def __init__(self): - self.input_keys = [] - self.input_dict = {} +class onnx_input(list): + """A helper extension to list that returns None for out of bound indices.""" def __getitem__(self, item): - if isinstance(item, int): - if item > (len(self.input_keys) - 1): - return None - return self.input_dict[self.input_keys[item]] - if isinstance(item, str): - if item not in self.input_keys: - return None - return self.input_dict[item] if isinstance(item, slice): - keys = self.input_keys[item] - return [self.input_dict[key] for key in keys] - - raise ValueError("Only integer, string, and slice accesses allowed.") - - def __setitem__(self, item, value): + indices = list(range(item.stop)[item]) + return [self[i] for i in indices] if isinstance(item, int): - self.input_dict[self.input_keys[item]] = value - elif isinstance(item, str): - self.input_keys.append(item) - self.input_dict[item] = value - else: - raise ValueError("Only integer and string indexed writes allowed.") - - def keys(self): - return self.input_keys - - def __len__(self): - return len(self.input_keys) - - def __iter__(self): - self.n = 0 - return self - - def __next__(self): - if self.n < len(self.input_keys): - output = self.input_dict[self.input_keys[self.n]] - self.n += 1 - return output - - raise StopIteration + return list(self)[item] if item < len(self) else None + raise TypeError("list indices must be integers or slices, not %s" % type(item).__name__) def get_numpy(tensor_proto): @@ -441,7 +416,10 @@ def autopad( # pad N and C with zeros pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) - return _op.nn.pad(data, fold_constant(pad), _op.const(pad_value), pad_type) + if isinstance(pad_value, (float, int)): + pad_value = _op.const(pad_value) + + return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) class Conv(OnnxOpConverter): @@ -581,6 +559,70 @@ def _impl_v1(cls, inputs, attr, params): out = _op.nn.bias_add(out, inputs[2]) return out + @classmethod + def _impl_v11(cls, inputs, attr, params): + # get number of channels + out_type = infer_type(inputs[1]) + out_shapes = [get_const_tuple(out_type.checked_type.shape)] + channels = out_shapes[0][1] + attr["channels"] = channels + groups = attr.get("group", 1) + + if "kernel_shape" not in attr: + attr["kernel_shape"] = out_shapes[0][2:] + + attr["groups"] = groups + # infer pads for auto_pad + data = inputs[0] + input_shape = infer_shape(data) + ndim = len(input_shape) + if "auto_pad" in attr: + attr["auto_pad"] = attr["auto_pad"].decode("utf-8") + if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): + # Warning: Convolution does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import + kernel_shape = attr["kernel_shape"] + kndim = len(kernel_shape) + dilations = attr.get("dilations", [1] * kndim) + output_padding = attr.get("output_padding", [0] * kndim) + strides = attr["strides"] + total_pad = [0] * kndim + for i in range(kndim): + total_pad[i] = ( + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - strides[i] + ) + left = [p // 2 for p in total_pad] + right = [total_pad[i] - left[i] for i in range(kndim)] + if "LOWER" in attr["auto_pad"]: + pad = left + right + else: + pad = right + left + attr["pads"] = pad + elif attr["auto_pad"] == "VALID": + attr["pads"] = tuple([0 for i in range(ndim - 2)]) + elif attr["auto_pad"] == "NOTSET": + pass + else: + msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.' + raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"])) + attr.pop("auto_pad") + + out = AttrCvt( + op_name=dimension_picker("conv", "_transpose"), + transforms={ + "kernel_shape": "kernel_size", + "dilations": ("dilation", 1), + "pads": ("padding", 0), + "group": ("groups", 1), + }, + disables=["output_shape"], + custom_check=dimension_constraint(), + )([data, inputs[1]], attr, params) + use_bias = len(inputs) == 3 + if use_bias: + out = _op.nn.bias_add(out, inputs[2]) + return out + class GlobalAveragePool(OnnxOpConverter): """Operator converter for GlobalAveragePool""" @@ -678,25 +720,38 @@ def _impl_v1(cls, inputs, attr, params): # When performing a batch matmul, we need to properly handle N-dim shapes. if a_rank > 2 or b_rank > 2: - def flatten_to_3d(x, x_shape): + def flatten_to_nd(x, x_shape, nd=3): ndims = infer_shape(x_shape)[0] + if ndims == nd: + return x newshape = _op.concatenate( [ _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype), - _op.strided_slice(x_shape, [ndims - 2], [ndims]), + _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]), ], 0, ) out = _op.reshape(x, fold_constant(newshape)) return out - # Convert a and b into 3 dimensional tensors. - a = flatten_to_3d(inputs[0], a_shape) - b = flatten_to_3d(inputs[1], b_shape) - # Transpose matrix dimensions of b. - b = _op.transpose(b, [0, 2, 1]) - # Perform a batch matmul. - output = _op.nn.batch_matmul(a, b) + b_type = infer_type(inputs[1]) + # Convert to dense if the second matrix is 2d and non-dynamic + if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type): + a = flatten_to_nd(inputs[0], a_shape, 2) + b = _op.transpose(inputs[1]) + output = _op.nn.dense(a, b) + else: + # Convert a and b into 3 dimensional tensors. + a = flatten_to_nd(inputs[0], a_shape, 3) + b = flatten_to_nd(inputs[1], b_shape, 3) + if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]: + # Transpose matrix dimensions of b. + b = _op.transpose(b, [0, 2, 1]) + # Perform a NT batch matmul. + output = _op.nn.batch_matmul(a, b) + else: + # Perform a NN batch matmul. + output = _op.nn.batch_matmul(a, b, transpose_b=False) # Determine the output batch dimension. if a_rank > b_rank: out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2]) @@ -931,7 +986,7 @@ def _impl_v11(cls, inputs, attr, params): if len(inputs) == 3: value = fold_constant(_op.take(inputs[2], _op.const(0))) else: - value = 0 + value = 0.0 pad_width_expr = fold_constant(_op.transpose(_op.reshape(pads, (2, -1)))) pad_mode = attr.get("mode", b"constant").decode("utf-8") @@ -1384,13 +1439,13 @@ def has_static_axes(): ) if axes is not None and has_static_axes(): - axes_np = axes.data.asnumpy().astype("int64") - begin_np = starts.data.asnumpy().astype("int64") - end_np = ends.data.asnumpy().astype("int64") + axes_np = axes.data.numpy().astype("int64") + begin_np = starts.data.numpy().astype("int64") + end_np = ends.data.numpy().astype("int64") if steps is None: strides_np = np.ones_like(begin_np).astype("int64") else: - strides_np = steps.data.asnumpy().astype("int64") + strides_np = steps.data.numpy().astype("int64") if all([isinstance(ishape[i], int) for i in axes_np]): return _op.strided_slice( @@ -1768,6 +1823,18 @@ def _impl_v1(cls, inputs, attr, params): e = _op.exp(x - m) return e / _op.sum(e, axes, keepdims=True) + @classmethod + def _impl_v13(cls, inputs, attr, params): + axis = attr.get("axis", -1) + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + axes = [axis] + x = inputs[0] + m = _op.max(x, axes, keepdims=True) + e = _op.exp(x - m) + return e / _op.sum(e, axes, keepdims=True) + class LogSoftmax(OnnxOpConverter): """Operator converter for Softmax.""" @@ -1785,6 +1852,19 @@ def _impl_v1(cls, inputs, attr, params): s = _op.sum(e, axes, keepdims=True) return x - m - _op.log(s) + @classmethod + def _impl_v13(cls, inputs, attr, params): + axis = attr.get("axis", -1) + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + axes = [axis] + x = inputs[0] + m = _op.max(x, axes, keepdims=True) + e = _op.exp(x - m) + s = _op.sum(e, axes, keepdims=True) + return x - m - _op.log(s) + class Hardmax(OnnxOpConverter): """Operator converter for Hardmax.""" @@ -2065,58 +2145,44 @@ class LSTM(RNN): """Operator converter for LSTM""" @classmethod - def generate_lstm( - cls, X_steps, H_t, C_t, W, R, B, p_i, p_f, p_o, f_act, g_act, h_act, backwards=False + def bidir_lstm_cell( + cls, + input_seqs, + weight_dicts, + acts, ): - """Create an unrolled lstm loop. - - See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math. """ - h_list = [] - seq_length = len(X_steps) - for i in range(seq_length): - step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)] - step = _op.squeeze(step, axis=[0]) - gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R) - if B is not None: - WB, RB = _op.split(B, 2) - gates += WB + RB - i, o, f, c = _op.split(gates, 4, axis=-1) - - if p_i != 0: - i = f_act(i + p_i * C_t) - else: - i = f_act(i) - - if p_f != 0: - f = f_act(f + p_f * C_t) - else: - f = f_act(f) - - c = g_act(c) - C = f * C_t + i * c - if p_o != 0: - o = f_act(o + p_o * C) - else: - o = f_act(o) - - H = o * h_act(C) - - H_t = H - C_t = C - h_list.append(_op.expand_dims(H, axis=0)) + Bidirectional LSTM cell + """ + seq_len = len(input_seqs) + forward_outputs, fw_H_t, fw_C_t = lstm_cell( + input_seqs, + **weight_dicts[0], + f_act=acts[0], + g_act=acts[1], + h_act=acts[2], + ) - if backwards: - # Canonical view is hidden states from the first token not last - h_list = h_list[::-1] + reverse_outputs, rev_H_t, rev_C_t = lstm_cell( + input_seqs, + **weight_dicts[1], + f_act=acts[3], + g_act=acts[4], + h_act=acts[5], + backwards=True, + ) - # Concatenate outputs and add back in direction axis. - concatenated = _op.concatenate(h_list, 0) - output = _op.expand_dims(concatenated, axis=1) - H_t = _op.expand_dims(H_t, axis=0) - C_t = _op.expand_dims(C_t, axis=0) + final_outputs = [] + for i in range(seq_len): + final_outputs.append( + _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0) + ) - return output, H_t, C_t + return ( + _op.stack(final_outputs, axis=0), + _op.stack([fw_H_t, rev_H_t], axis=0), + _op.stack([fw_C_t, rev_C_t], axis=0), + ) @classmethod def _impl_v7(cls, inputs, attr, params): @@ -2147,12 +2213,6 @@ def _impl_v7(cls, inputs, attr, params): Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) if Cp_0 is None: Cp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) - if Bp is None: - Bp = _op.zeros((num_directions, hidden_size * 8), W_dtype) - if Pp is not None: - p_i, p_o, p_f = _op.split(Pp, 3, axis=1) - else: - p_i = p_o = p_f = _op.zeros((num_directions, hidden_size), W_dtype) if "activations" in attr: activations = attr["activations"] @@ -2183,53 +2243,67 @@ def _impl_v7(cls, inputs, attr, params): else: acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions - X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0) - result_output = [] - result_H = [] - result_C = [] + # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved + X_steps = unbind(X, axis=0) H_ts = _op.split(Hp_0, num_directions) C_ts = _op.split(Cp_0, num_directions) Ws = _op.split(Wp, num_directions) Rs = _op.split(Rp, num_directions) - Bs = _op.split(Bp, num_directions) - p_is = _op.split(p_i, num_directions) - p_fs = _op.split(p_f, num_directions) - p_os = _op.split(p_o, num_directions) + + if Bp is not None: + Bs = _op.split(Bp, num_directions) + if Pp is not None: + p_i, p_o, p_f = _op.split(Pp, 3, axis=1) + + p_is = _op.split(p_i, num_directions) + p_fs = _op.split(p_f, num_directions) + p_os = _op.split(p_o, num_directions) + + weights_dicts = [] for i in range(num_directions): - H_t = _op.squeeze(H_ts[i], axis=[0]) - C_t = _op.squeeze(C_ts[i], axis=[0]) - W = _op.squeeze(Ws[i], axis=[0]) - R = _op.squeeze(Rs[i], axis=[0]) - B = _op.squeeze(Bs[i], axis=[0]) - p_i = _op.squeeze(p_is[i], axis=[0]) - p_f = _op.squeeze(p_fs[i], axis=[0]) - p_o = _op.squeeze(p_os[i], axis=[0]) - - f_act, g_act, h_act = acts[i * 3 : (i + 1) * 3] - output, H, C = LSTM.generate_lstm( - X_steps=X_steps, - H_t=H_t, - C_t=C_t, - W=W, - R=R, - B=B, - p_i=p_i, - p_f=p_f, - p_o=p_o, - f_act=f_act, - g_act=g_act, - h_act=h_act, - backwards=i == 1, + weights_dict = {} + + weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0]) + weights_dict["cell_state"] = _op.squeeze(C_ts[i], axis=[0]) + + # Weights permutation: onnx format i-o-f-c, lstm cell format i-f-c-o + mati, mato, matf, matc = _op.split(_op.squeeze(Ws[i], axis=[0]), 4) + weights_dict["w_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0) + mati, mato, matf, matc = _op.split(_op.squeeze(Rs[i], axis=[0]), 4) + weights_dict["w_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0) + if Bp is not None: + Bi, Bh = _op.split(Bs[i], 2, -1) + mati, mato, matf, matc = _op.split(_op.squeeze(Bi, axis=[0]), 4) + weights_dict["b_inp"] = _op.concatenate([mati, matf, matc, mato], axis=0) + mati, mato, matf, matc = _op.split(_op.squeeze(Bh, axis=[0]), 4) + weights_dict["b_hid"] = _op.concatenate([mati, matf, matc, mato], axis=0) + if Pp is not None: + weights_dict["p_i"] = _op.squeeze(p_is[i], axis=[0]) + weights_dict["p_f"] = _op.squeeze(p_fs[i], axis=[0]) + weights_dict["p_o"] = _op.squeeze(p_os[i], axis=[0]) + weights_dicts.append(weights_dict) + + if num_directions == 2: + output, H, C = LSTM.bidir_lstm_cell( + input_seqs=X_steps, + weight_dicts=weights_dicts, + acts=acts, + ) + else: + # outputs shape = [seqs_num, (batch_size, hidden_size)] + outputs, H, C = lstm_cell( + input_seqs=X_steps, + **weights_dicts[0], + f_act=acts[0], + g_act=acts[1], + h_act=acts[2], ) - result_output.append(output) - result_H.append(H) - result_C.append(C) - - output = _op.concatenate(result_output, axis=1) - H = _op.concatenate(result_H, axis=0) - C = _op.concatenate(result_C, axis=0) + # output shape = (seqs_num, num_directions, batch_size, hidden_size) + output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1) + H = _op.expand_dims(H, axis=0) + C = _op.expand_dims(C, axis=0) return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3) @@ -2238,56 +2312,41 @@ class GRU(RNN): """Operator convert for GRU""" @classmethod - def generate_gru( - cls, X_steps, H_t, W, R, B, linear_before_reset, f_act, g_act, W_dtype, backwards=False + def bidir_gru_cell( + cls, + input_seqs, + weight_dicts, + acts, ): - """Create an unrolled gru loop. - - See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math. """ - h_list = [] - seq_length = len(X_steps) - for i in range(seq_length): - step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)] - step = _op.squeeze(step, axis=[0]) - current = _op.nn.dense(step, W) - cz, cr, ch = _op.split(current, 3, axis=1) - rz, rr, rh = _op.split(R, 3, axis=0) - z = cz + _op.nn.dense(H_t, rz) - r = cr + _op.nn.dense(H_t, rr) - if B is not None: - WB, RB = _op.split(B, 2) - wbz, wbr, wbh = _op.split(WB, 3, axis=-1) - rbz, rbr, rbh = _op.split(RB, 3, axis=-1) - z += wbz + rbz - r += wbr + rbr - if linear_before_reset: - h = ch + (r * (_op.nn.dense(H_t, rh) + rbh)) + wbh - else: - h = ch + _op.nn.dense((r * H_t), rh) + wbh + rbh - else: - if linear_before_reset: - h = ch + (r * (_op.nn.dense(H_t, rh))) - else: - h = ch + _op.nn.dense((r * H_t), rh) - - z = f_act(z) - r = f_act(r) - h = g_act(h) - - H_t = ((_expr.const(1, dtype=W_dtype) - z) * h) + (z * H_t) - h_list.append(_op.expand_dims(H_t, axis=0)) + Bidirectional GRU cell + """ + seq_len = len(input_seqs) + forward_outputs, fw_H_t = gru_cell( + input_seqs, + **weight_dicts[0], + rz_act=acts[0], + n_act=acts[1], + ) - if backwards: - # Canonical view is hidden states from the first token not last - h_list = h_list[::-1] + reverse_outputs, rev_H_t = gru_cell( + input_seqs, + **weight_dicts[1], + rz_act=acts[2], + n_act=acts[3], + backwards=True, + ) - # Concatenate outputs and add back in direction axis. - concatenated = _op.concatenate(h_list, 0) - output = _op.expand_dims(concatenated, axis=1) - H_t = _op.expand_dims(H_t, axis=0) + final_outputs = [] + for i in range(seq_len): + final_outputs.append( + _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0) + ) - return output, H_t + return ( + _op.stack(final_outputs, axis=0), + _op.stack([fw_H_t, rev_H_t], axis=0), + ) @classmethod def _impl_v7(cls, inputs, attr, params): @@ -2305,20 +2364,14 @@ def _impl_v7(cls, inputs, attr, params): W_dtype = infer_type(Wp).checked_type.dtype if num_directions not in [1, 2]: - raise NotImplementedError( - f"Directions for GRUs should be either 1 or 2 got {num_directions}" - ) + raise ValueError("num_directions must be either 1 or 2!") X_shape = infer_shape(X) hidden_size = infer_shape(Rp)[-1] batch_size = X_shape[1] - # Initialize state if not provided. - # Otherwise remove bidirectional axis. if Hp_0 is None: Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) - if Bp is None: - Bp = _op.zeros((num_directions, hidden_size * 6), W_dtype) if "activations" in attr: activations = attr["activations"] @@ -2349,39 +2402,54 @@ def _impl_v7(cls, inputs, attr, params): else: acts = [_op.sigmoid, _op.tanh] * 2 - result_output = [] - result_H = [] + # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved + X_steps = unbind(X, axis=0) - X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0) H_ts = _op.split(Hp_0, num_directions) Ws = _op.split(Wp, num_directions) Rs = _op.split(Rp, num_directions) - Bs = _op.split(Bp, num_directions) + if Bp is not None: + Bs = _op.split(Bp, num_directions) + + weights_dicts = [] for i in range(num_directions): - H_t = _op.squeeze(H_ts[i], axis=[0]) - W = _op.squeeze(Ws[i], axis=[0]) - R = _op.squeeze(Rs[i], axis=[0]) - B = _op.squeeze(Bs[i], axis=[0]) - f_act, g_act = acts[i * 2 : (i + 1) * 2] - output, H = GRU.generate_gru( - X_steps=X_steps, - H_t=H_t, - W=W, - R=R, - B=B, - linear_before_reset=linear_before_reset, - f_act=f_act, - g_act=g_act, - W_dtype=W_dtype, - backwards=i == 1, + weights_dict = {} + + weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0]) + weights_dict["linear_before_reset"] = linear_before_reset + + # Weights permutation: onnx format i-o-f-c, lstm cell format i-f-c-o + matz, matr, matn = _op.split(_op.squeeze(Ws[i], axis=[0]), 3) + weights_dict["w_inp"] = _op.concatenate([matr, matz, matn], axis=0) + matz, matr, matn = _op.split(_op.squeeze(Rs[i], axis=[0]), 3) + weights_dict["w_hid"] = _op.concatenate([matr, matz, matn], axis=0) + if Bp is not None: + Bi, Bh = _op.split(Bs[i], 2, -1) + matz, matr, matn = _op.split(_op.squeeze(Bi, axis=[0]), 3) + weights_dict["b_inp"] = _op.concatenate([matr, matz, matn], axis=0) + matz, matr, matn = _op.split(_op.squeeze(Bh, axis=[0]), 3) + weights_dict["b_hid"] = _op.concatenate([matr, matz, matn], axis=0) + weights_dicts.append(weights_dict) + + if num_directions == 2: + output, H = GRU.bidir_gru_cell( + input_seqs=X_steps, + weight_dicts=weights_dicts, + acts=acts, + ) + else: + # outputs shape = [seqs_num, (batch_size, hidden_size)] + outputs, H = gru_cell( + input_seqs=X_steps, + **weights_dicts[0], + rz_act=acts[0], + n_act=acts[1], ) - result_output.append(output) - result_H.append(H) - - output = _op.concatenate(result_output, axis=1) - H = _op.concatenate(result_H, axis=0) + # output shape = (seqs_num, num_directions, batch_size, hidden_size) + output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1) + H = _op.expand_dims(H, axis=0) return _expr.TupleWrapper(_expr.Tuple((output, H)), 2) @@ -2566,6 +2634,19 @@ def _impl_v10(cls, inputs, attr, params): return isinf +class Celu(OnnxOpConverter): + """Operator convereter for celu""" + + @classmethod + def _impl_v12(cls, inputs, attr, params): + x = inputs[0] + dtype = infer_type(x).checked_type.dtype + alpha = _op.const(attr.get("alpha", 1.0), dtype) + zero = _op.const(0, dtype) + one = _op.const(1, dtype) + return _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one)) + + class MaxRoiPool(OnnxOpConverter): """Operator converter for MaxRoiPool.""" @@ -2628,10 +2709,10 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): if len(inputs) == 3 and isinstance(inputs[2], _expr.Constant): - attr["max"] = inputs[2].data.asnumpy().item() + attr["max"] = inputs[2].data.numpy().item() inputs = inputs[0:2] if len(inputs) >= 2 and isinstance(inputs[1], _expr.Constant): - attr["min"] = inputs[1].data.asnumpy().item() + attr["min"] = inputs[1].data.numpy().item() inputs = inputs[0:1] if "min" in attr and "max" in attr: return Clip.convert_attributes(inputs, attr, params) @@ -2729,24 +2810,24 @@ def get_var(name, val, scan=False): loop_var_names = [v.name_hint for v in loop_vars] num_scan_outputs = len(body.output) - (1 + num_deps) - # TODO (jwfromm) Test with strided slice once type unifier for this case is fixed. - if num_scan_outputs != 0 and "Slice" in [n.op_type for n in body.node]: - warnings.warn( - """ - Using scan outputs in a loop with strided slice - currently may cause errors during compilation. - """ - ) - # Construct variables and intial empty tensors for any scan outputs. + # Construct variables and initial empty tensors for any scan outputs. + # To do this, we'll figure out the output shapes of the body subgraph by importing + # it and doing type inference. scan_output_vars = [] scan_output_init = [] + if num_scan_outputs > 0: + with subgraph_scope: + loop_outputs = subgraph_scope.from_onnx( + body, graph_scope.opset, get_output_expr=True + ) + loop_outputs = _expr.TupleWrapper(loop_outputs, len(body.output)) + for i in range(num_scan_outputs): - name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps]) - if dtype is None: - dtype = infer_type(loop_deps[i]).checked_type.dtype - if dtype == "float": - dtype = "float32" + name, _, _, _ = get_info(body.output[i + 1 + num_deps]) + output_node = infer_type(loop_outputs[i + 1 + num_deps]) + shape = get_const_tuple(output_node.checked_type.shape) + dtype = output_node.checked_type.dtype scan_output_vars.append( _expr.var(name, shape=([_ty.Any()] * (len(shape) + 1)), dtype=dtype) ) @@ -3193,6 +3274,113 @@ def get_scalar(x, dtype="float32"): return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype) +class QLinearMul(OnnxOpConverter): + """Operator converter for QLinearMul from Microsoft onnxruntime contrib opset.""" + + @classmethod + def _impl_v10(cls, inputs, attr, params): + def get_scalar(x, dtype="float32"): + if isinstance(x, _expr.Var) and x.name_hint in params: + return _op.const(params[x.name_hint].numpy(), dtype) + rank = len(infer_shape(x)) + assert rank <= 1, "QLinearMul scale and zero_point input must be scalars" + if rank == 1: + x = _op.squeeze(x, [0]) + return _op.cast(x, dtype) + + a = inputs[0] + a_scale = get_scalar(inputs[1]) + a_zero_point = get_scalar(inputs[2], "int32") + b = inputs[3] + b_scale = get_scalar(inputs[4]) + b_zero_point = get_scalar(inputs[5], "int32") + y_scale = fold_constant(get_scalar(inputs[6])) + y_zero_point = get_scalar(inputs[7], "int32") + + dtype = infer_type(a).checked_type.dtype + + ## Onnxruntime doesn't actually do this op in integer, they dequantize to fp32 + ## and then requantize afer + ## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qlmul.cpp + a = _qnn.op.dequantize(inputs[0], a_scale, a_zero_point) + b = _qnn.op.dequantize(inputs[3], b_scale, b_zero_point) + out = _op.multiply(a, b) + return _qnn.op.quantize(out, y_scale, y_zero_point, out_dtype=dtype) + + +class ConvInteger(OnnxOpConverter): + """Operator converter for ConvInteger.""" + + @classmethod + def _impl_v10(cls, inputs, attr, params): + data = inputs[0] + weight = inputs[1] + data_zp = inputs[2] + weight_zp = inputs[3] + if data_zp is None: + data_zp = _expr.const(0, "int32") + if weight_zp is None: + weight_zp = _expr.const(0, "int32") + + input_type = infer_type(data) + input_shape = get_const_tuple(input_type.checked_type.shape) + + ndim = len(input_shape) + kernel_type = infer_type(weight) + kernel_shape = get_const_tuple(kernel_type.checked_type.shape) + if "kernel_shape" not in attr: + attr["kernel_shape"] = kernel_shape[2:] + + if "auto_pad" in attr: + attr["auto_pad"] = attr["auto_pad"].decode("utf-8") + if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): + # Warning: Convolution does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import + data = autopad( + data, + attr.get("strides", [1] * (ndim - 2)), + attr["kernel_shape"], + attr.get("dilations", [1] * (ndim - 2)), + ndim, + pad_value=data_zp, + mode=attr["auto_pad"], + ) + elif attr["auto_pad"] == "VALID": + attr["pads"] = tuple([0 for i in range(ndim - 2)]) + elif attr["auto_pad"] == "NOTSET": + pass + else: + msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.' + raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"])) + attr.pop("auto_pad") + + out_channels = kernel_shape[0] + dilation = attr.get("dilations", [1] * (ndim - 2)) + strides = attr.get("strides", [1] * (ndim - 2)) + padding = attr["pads"] if "pads" in attr else 0 + groups = attr["group"] if "group" in attr else 1 + + if ndim != 4: + raise tvm.error.OpAttributeInvalid( + "Only 2D kernels are supported for operator ConvInteger." + ) + + return _qnn.op.conv2d( + data, + weight, + _op.cast(data_zp, "int32"), + _op.cast(weight_zp, "int32"), + _expr.const(1.0, "float32"), + _expr.const(1.0, "float32"), + kernel_size=attr["kernel_shape"], + channels=out_channels, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + ) + + class BitShift(OnnxOpConverter): """Operator converter for NonZero""" @@ -3244,6 +3432,30 @@ def _impl_v11(cls, inputs, attr, params): return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4) +class RandomUniform(OnnxOpConverter): + """Operator converter for random_uniform""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + dtype = get_type(attr.get("dtype", 1)) + high = attr.get("high", 1.0) + low = attr.get("low", 0.0) + seed = attr.get("seed", None) + shape = attr["shape"] + + assert dtype in [ + "float32", + "float64", + ], "Only float random value generation is currently supported." + + if seed is None: + seed = np.random.randint(1e6) + key = _random.threefry_key(seed) + output = _op.random.uniform(key, shape, dtype=dtype, low=low, high=high) + _, vals = _expr.TupleWrapper(output, 2) + return vals + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -3297,6 +3509,7 @@ def _get_convert_map(opset): "IsNaN": Renamer("isnan"), "Sqrt": Renamer("sqrt"), "Relu": Renamer("relu"), + "Celu": Celu.get_converter(opset), "LeakyRelu": Renamer("leaky_relu"), "Selu": Selu.get_converter(opset), "Elu": Elu.get_converter(opset), @@ -3421,6 +3634,10 @@ def _get_convert_map(opset): "ReverseSequence": ReverseSequence.get_converter(opset), "QLinearConv": QLinearConv.get_converter(opset), "QLinearAdd": QLinearAdd.get_converter(opset), + "QLinearMul": QLinearMul.get_converter(opset), + "ConvInteger": ConvInteger.get_converter(opset), + # Random number generation. + "RandomUniform": RandomUniform.get_converter(opset), } @@ -3580,13 +3797,13 @@ def from_onnx(self, graph, opset, get_output_expr=False): for node in graph.node: op_name = node.op_type attr = self._parse_attr(node.attribute) - # Create and populate onnx input object. + # Create and populate input list. inputs = onnx_input() for i in node.input: if i != "": - inputs[i] = self._nodes[self._renames.get(i, i)] + inputs.append(self._nodes[self._renames.get(i, i)]) else: - inputs[i] = None + inputs.append(None) i_name = self._parse_value_proto(node) node_output = self._fix_outputs(op_name, node.output) attr["tvm_custom"] = {} @@ -3739,7 +3956,9 @@ def _fix_outputs(self, op_name, outputs): return outputs -def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=False): +def from_onnx( + model, shape=None, dtype="float32", opset=None, freeze_params=False, convert_config=None +): """Convert a ONNX model into an equivalent Relay Function. ONNX graphs are represented as Python Protobuf objects. @@ -3778,6 +3997,12 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals at compile time and helps in making models static if certain inputs represent attributes relay would traditionally consider compile-time constants. + convert_config : Optional[Dict[str, Any]] + Default config: + use_nt_batch_matmul : bool = True + True to convert qualified onnx `matmul` to `nn.batch_matmul` strict to NT format + (transpose_a=False, transpose_b=True). + Returns ------- mod : tvm.IRModule @@ -3786,6 +4011,10 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals params : dict of str to tvm.nd.NDArray The parameter dict to be used by relay """ + global ONNX_DEFAULT_CONFIGS + if convert_config is not None: + ONNX_DEFAULT_CONFIGS.update(convert_config) + try: import onnx diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py new file mode 100644 index 000000000000..76a12691d2bf --- /dev/null +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -0,0 +1,918 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines +# pylint: disable=import-outside-toplevel +"""Paddle: PArallel Distributed Deep LEarning.""" +import warnings + +import numpy as np + +import tvm +from tvm.ir import IRModule + +from .. import analysis +from .. import expr as _expr +from .. import function as _function +from .. import ty as _ty +from .. import op as _op +from .common import ( + fold_constant, + infer_shape, + infer_type, + infer_value, + new_var, +) + +__all__ = ["from_paddle"] + + +def shape_of(x, dtype="int32"): + """Get shape of a tensor""" + + ttype = infer_type(x).checked_type + if not _ty.is_dynamic(ttype): + shape = list(ttype.shape) + return _expr.const(shape, dtype) + return _op.shape_of(x, dtype) + + +def _get_pad_size(in_size, dilated_kernel_size, stride_size): + """calculate the paddings size""" + + if stride_size == 1 or in_size % stride_size == 0: + pad = max(dilated_kernel_size - stride_size, 0) + else: + pad = max(dilated_kernel_size - (in_size % stride_size), 0) + + pad_before = pad // 2 + pad_after = pad - pad_before + + return [pad_before, pad_after] + + +def convert_arg_max(g, op, block): + """Operator converter for arg_max.""" + + axis = op.attr("axis") + keepdims = op.attr("keepdims") + flatten = op.attr("flatten") + + x = g.get_node(op.input("X")[0]) + if axis is None or flatten: + x = _op.reshape(x, [-1]) + out = _op.argmax(x, axis=None, keepdims=True) + else: + out = _op.argmax(x, axis=axis, keepdims=keepdims) + g.add_node(op.output("Out")[0], out) + + +def convert_assign(g, op, block): + """Operator converter for assign.""" + + out = _op.copy(g.get_node(op.input("X")[0])) + g.add_node(op.output("Out")[0], out) + + +def convert_batch_norm(g, op, block): + """Operator converter for batch_norm.""" + + ipt_name = op.input("X")[0] + scale_name = op.input("Scale")[0] + bias_name = op.input("Bias")[0] + mean_name = op.input("Mean")[0] + variance_name = op.input("Variance")[0] + epsilon = op.attr("epsilon") + out = _op.nn.batch_norm( + g.get_node(ipt_name), + g.get_node(scale_name), + g.get_node(bias_name), + g.get_node(mean_name), + g.get_node(variance_name), + epsilon=epsilon, + ) + g.add_node(op.output("Y")[0], out[0]) + + +def convert_cast(g, op, block): + """Operator converter for cast.""" + + dtype = block.var(op.output("Out")[0]).dtype + dtype = str(dtype).strip().split(".")[1] + x = g.get_node(op.input("X")[0]) + out = _op.cast(x, dtype=dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_concat(g, op, block): + """Operator converter for concat.""" + + inputs = [g.get_node(op.input("X")[i]) for i in range(len(op.input("X")))] + axis = op.attr("axis") + out = _op.concatenate(inputs, axis=axis) + g.add_node(op.output("Out")[0], out) + + +def convert_conv2d(g, op, block): + """Operator converter for conv2d.""" + + dilations = op.attr("dilations") + groups = op.attr("groups") + paddings = op.attr("paddings") + padding_algorithm = op.attr("padding_algorithm") + strides = op.attr("strides") + + kernel = g.get_node(op.input("Filter")[0]) + input_x = g.get_node(op.input("Input")[0]) + out_channels, _, k_h, k_w = infer_shape(kernel) + in_h, in_w = infer_shape(input_x)[2:] + if padding_algorithm == "VALID": + paddings = [0, 0] + elif padding_algorithm == "SAME": + pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) + pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) + paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + elif padding_algorithm == "EXPLICIT": + if len(paddings) == 2: + paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] + if len(paddings) == 4: + paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] + else: + msg = 'Value {} in attribute "padding" of operator Conv is not "valid."' + raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) + + out = _op.nn.conv2d( + input_x, + kernel, + strides=strides, + padding=paddings, + dilation=dilations, + groups=groups, + channels=out_channels, + kernel_size=[k_h, k_w], + ) + g.add_node(op.output("Output")[0], out) + + +def convert_cumsum(g, op, block): + """Operator converter for cumsum.""" + + axis = op.attr("axis") + exclusive = op.attr("exclusive") + flatten = op.attr("flatten") + reverse = op.attr("reverse") + + x = g.get_node(op.input("X")[0]) + if axis is None or flatten: + x = _op.reshape(x, [-1]) + if reverse: + x = _op.reverse(x, axis=axis) + out = _op.cumsum(x, axis=axis, exclusive=exclusive) + out = _op.reverse(out, axis=axis) + else: + out = _op.cumsum(x, axis=axis, exclusive=exclusive) + g.add_node(op.output("Out")[0], out) + + +def convert_dropout(g, op, block): + """Operator converter for dropout.""" + + x = g.get_node(op.input("X")[0]) + out = _op.copy(x) + g.add_node(op.output("Out")[0], out) + + +def convert_elementwise_op(g, op, block): + """Operator converter for all the elementwise operators.""" + + op_map = { + "elementwise_div": lambda x, y: x / y, + "elementwise_add": lambda x, y: x + y, + "elementwise_mul": lambda x, y: x * y, + "elementwise_sub": lambda x, y: x - y, + "elementwise_mod": lambda x, y: x % y, + } + op_func = op_map[op.type] + ipt0 = g.get_node(op.input("X")[0]) + ipt1 = g.get_node(op.input("Y")[0]) + ipt0_shape = block.var(op.input("X")[0]).shape + ipt1_shape = block.var(op.input("Y")[0]).shape + axis = op.attr("axis") + if len(ipt0_shape) != len(ipt1_shape): + if axis < 0: + axis = axis + len(ipt0_shape) + if axis != len(ipt0_shape) - 1: + ipt1 = _op.expand_dims(ipt1, axis=axis, num_newaxis=(len(ipt0_shape) - axis - 1)) + out = op_func(ipt0, ipt1) + g.add_node(op.output("Out")[0], out) + + +def convert_equal(g, op, block): + """Operator converter for equal.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + out = _op.equal(x, y) + g.add_node(op.output("Out")[0], out) + + +def convert_activation(g, op, block): + """Operator converter for all the activation.""" + + op_map = { + "exp": _op.exp, + "relu": _op.nn.relu, + "tanh": _op.tanh, + "sqrt": _op.sqrt, + "erf": _op.erf, + "abs": _op.abs, + } + act_func = op_map[op.type] + out = act_func(g.get_node(op.input("X")[0])) + g.add_node(op.output("Out")[0], out) + + +def convert_feed(g, op, block): + """Converter for model input node.""" + + if block is not None: + ipt_name = op.output("Out")[0] + ipt_shape = block.var(ipt_name).shape + ipt_dtype = block.var(ipt_name).dtype + ipt_dtype = str(ipt_dtype).strip().split(".")[1] + else: + ipt_shape = op.shape + ipt_dtype = str(op.dtype).strip().split(".")[1] + ipt_name = op.name + if g.shape_dict is not None: + ipt_shape = g.shape_dict[ipt_name] + out = new_var(ipt_name, shape=ipt_shape, dtype=ipt_dtype) + g.add_node(ipt_name, out) + + +def convert_fill_any_like(g, op, block): + """Operator converter for fill_any_like.""" + + out_name = op.output("Out")[0] + out_dtype = block.var(out_name).dtype + out_dtype = str(out_dtype).strip().split(".")[1] + x = g.get_node(op.input("X")[0]) + ipt_type = infer_type(x).checked_type + value = op.attr("value") + if not _ty.is_dynamic(ipt_type): + shape = infer_shape(x) + const = np.ones(shape) * value + out = _expr.const(const.astype(out_dtype)) + else: + out = _op.transform.full_like(x, value).astype(out_dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_fill_constant(g, op, block): + """Operator converter for fill_constant.""" + + value = op.attr("value") + shape = block.var(op.output("Out")[0]).shape + dtype = block.var(op.output("Out")[0]).dtype + dtype = str(dtype).strip().split(".")[1] + if op.input("ValueTensor"): + shape = g.get_node(op.input("ValueTensor")[0]) + shape = infer_value(shape, g.get_params()).numpy() + if op.input("ShapeTensor"): + shape = g.get_node(op.input("ShapeTensor")[0]) + shape = infer_value(shape, g.get_params()).numpy() + value = np.full(shape, value, dtype) + out = _expr.const(value.astype(dtype)).astype(dtype) + g.add_node(op.output("Out")[0], out) + + +def convert_gelu(g, op, block): + """Operator converter for gelu.""" + + x = g.get_node(op.input("X")[0]) + out = x * ( + _expr.const(0.5, dtype="float32") + + _op.erf(x * _expr.const(0.5 ** 0.5, dtype="float32")) * _expr.const(0.5, dtype="float32") + ) + g.add_node(op.output("Out")[0], out) + + +def convert_hard_sigmoid(g, op, block): + """Operator converter for hard_sigmoid.""" + + slope = op.attr("slope") + x = g.get_node(op.input("X")[0]) + out = x * _expr.const(slope) + _expr.const(0.5) + out = _op.clip(out, 0, 1) + g.add_node(op.output("Out")[0], out) + + +def convert_hard_swish(g, op, block): + """Operator converter for hard_swish.""" + + offset = op.attr("offset") + scale = op.attr("scale") + threshold = op.attr("threshold") + assert np.isclose(offset, 3.0), "Only support offset==3.0 for PaddlePaddle's hard_swish" + assert np.isclose(scale, 6.0), "Only support scale==6.0 for PaddlePaddle's hard_swish" + assert np.isclose(threshold, 6.0), "Only support threshold==6.0 for PaddlePaddle's hard_swish" + x = g.get_node(op.input("X")[0]) + out = _op.clip(x, -1 * offset, offset) + out = out / _expr.const(threshold) + _expr.const(0.5) + out = x * out + g.add_node(op.output("Out")[0], out) + + +def convert_layer_norm(g, op, block): + """Operator converter for layer_norm.""" + + begin_norm_axis = op.attr("begin_norm_axis") + epsilon = op.attr("epsilon") + x = g.get_node(op.input("X")[0]) + bias_input = op.input("Bias") + scale_input = op.input("Scale") + + x_shape = infer_shape(x) + assert begin_norm_axis in ( + len(x_shape) - 1, + -1, + ), "Support only normalization over last one dimension." + + if bias_input: + bias = g.get_node(bias_input[0]) + else: + bias = _expr.const(np.zeros(x_shape[begin_norm_axis])) + + if scale_input: + scale = g.get_node(scale_input[0]) + else: + scale = _expr.const(np.ones(x_shape[begin_norm_axis])) + + out = _op.nn.layer_norm( + x, gamma=scale, beta=bias, axis=begin_norm_axis, epsilon=epsilon, center=True, scale=True + ) + g.add_node(op.output("Y")[0], out) + + +def convert_leaky_relu(g, op, block): + """Operator converter for leaky_relu.""" + + alpha = op.attr("alpha") + x = g.get_node(op.input("X")[0]) + out = _op.nn.leaky_relu(x, alpha=alpha) + g.add_node(op.output("Out")[0], out) + + +def convert_lookup_table(g, op, block): + """Operator converter for lookup_table_v2.""" + + indices = g.get_node(op.input("Ids")[0]) + padding_idx = op.attr("padding_idx") + if padding_idx != -1: + g.get_params[op.input("W")[0]][padding_idx] = 0.0 + g.add_node(op.input("W")[0], _expr.const(g.params[op.input("W")[0]])) + weights = g.get_node(op.input("W")[0]) + out = _op.take(weights, indices.astype("int32"), axis=0) + g.add_node(op.output("Out")[0], out) + + +def convert_matmul(g, op, block): + """Operator converter for matmul.""" + + inputs = [g.get_node(op.input("X")[0]), g.get_node(op.input("Y")[0])] + a_shape = infer_shape(inputs[0]) + b_shape = infer_shape(inputs[1]) + if op.has_attr("trans_x"): + # for matmul_v2 + trans_x = op.attr("trans_x") + trans_y = op.attr("trans_y") + else: + # for matmul + trans_x = op.attr("transpose_X") + trans_y = op.attr("transpose_Y") + if trans_x: + perm = list(range(len(a_shape))) + perm[-2] = len(a_shape) - 1 + perm[-1] = len(a_shape) - 2 + inputs[0] = _op.transpose(inputs[0], axes=perm) + if trans_y: + perm = list(range(len(b_shape))) + perm[-2] = len(b_shape) - 1 + perm[-1] = len(b_shape) - 2 + inputs[1] = _op.transpose(inputs[1], axes=perm) + + # This implemention almost keeps same with ONNX + # Need to check input shape as batch matmul must be supported. + a_shape = shape_of(inputs[0]) + a_rank = infer_shape(a_shape)[0] + b_shape = shape_of(inputs[1]) + b_rank = infer_shape(b_shape)[0] + # When performing a batch matmul, we need to properly handle N-dim shapes. + if a_rank > 2 or b_rank > 2: + + def flatten_to_nd(x, x_shape, nd=3): + ndims = infer_shape(x_shape)[0] + if ndims == nd: + return x + newshape = _op.concatenate( + [ + _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype), + _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]), + ], + 0, + ) + out = _op.reshape(x, fold_constant(newshape)) + return out + + b_type = infer_type(inputs[1]) + # Convert to dense if the second matrix is 2d and non-dynamic + if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type): + a = flatten_to_nd(inputs[0], a_shape, 2) + b = _op.transpose(inputs[1]) + output = _op.nn.dense(a, b) + else: + # Convert a and b into 3 dimensional tensors. + a = flatten_to_nd(inputs[0], a_shape, 3) + b = flatten_to_nd(inputs[1], b_shape, 3) + # Transpose matrix dimensions of b. + b = _op.transpose(b, [0, 2, 1]) + # Perform a batch matmul. + output = _op.nn.batch_matmul(a, b) + # Determine the output batch dimension. + if a_rank > b_rank: + out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2]) + elif a_rank < b_rank: + out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2]) + # If its unclear how broadcasting should be applied, the output + # shape is determined by choosing the maximum value from each input. + else: + out_batch = _op.concatenate( + [ + _op.maximum( + _op.strided_slice(a_shape, [i], [i + 1]), + _op.strided_slice(b_shape, [i], [i + 1]), + ) + for i in range(a_rank - 2) + ], + 0, + ) + # Reshape output to original dimensions. + final_shape = _op.concatenate( + [ + out_batch, + _op.strided_slice( + a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1] + ), + _op.strided_slice( + b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]] + ), + ], + 0, + ) + out = _op.reshape(output, fold_constant(final_shape)) + else: + if b_rank == 1: + inputs[1] = _op.expand_dims(inputs[1], 1, 1) + # Otherwise a simple dense op will get the job done. + input_1_t = _op.transpose(inputs[1], axes=(1, 0)) + out = _op.nn.dense(inputs[0], input_1_t) + if b_rank == 1: + out = _op.squeeze(out, axis=[-1]) + if op.has_attr("alpha"): + alpha = op.attr("alpha") + if not np.isclose(alpha, 1.0): + out = out * _expr.const(alpha).astype("float32") + g.add_node(op.output("Out")[0], out) + + +def convert_mul(g, op, block): + """Operator converter for mul.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + x_num_col_dims = op.attr("x_num_col_dims") + y_num_col_dims = op.attr("y_num_col_dims") + x_shape = shape_of(x) + y_shape = shape_of(y) + x_dim = infer_shape(x_shape)[0] + y_dim = infer_shape(y_shape)[0] + if x_num_col_dims < 0: + x_num_col_dims += x_dim + if y_num_col_dims < 0: + y_num_col_dims += y_dim + if x_num_col_dims == 1: + x = _op.nn.batch_flatten(x) + else: + pre_shape = _op.prod(_op.strided_slice(x_shape, [0], [x_num_col_dims], [1]), keepdims=True) + post_shape = _op.prod( + _op.strided_slice(x_shape, [x_num_col_dims], [x_dim], [1]), keepdims=True + ) + new_shape = _op.concatenate([pre_shape, post_shape], axis=0) + new_shape = fold_constant(new_shape) + x = _op.reshape(x, new_shape) + if y_num_col_dims == 1: + y = _op.nn.batch_flatten(y) + else: + pre_shape = _op.prod(_op.strided_slice(y_shape, [0], [y_num_col_dims], [1]), keepdims=True) + post_shape = _op.prod( + _op.strided_slice(y_shape, [y_num_col_dims], [y_dim], [1]), keepdims=True + ) + new_shape = _op.concatenate([pre_shape, post_shape], axis=0) + new_shape = fold_constant(new_shape) + y = _op.reshape(y, new_shape) + y = _op.transpose(y) + out = _op.nn.dense(x, y) + out_pre_shape = _op.strided_slice(x_shape, [0], [x_num_col_dims], [1]) + out_post_shape = _op.strided_slice(y_shape, [y_num_col_dims], [y_dim], [1]) + out_shape = _op.concatenate([out_pre_shape, out_post_shape], axis=0) + out_shape = fold_constant(out_shape) + out = _op.reshape(out, out_shape) + g.add_node(op.output("Out")[0], out) + + +def convert_pool2d(g, op, block): + """Operator converter for pool2d.""" + + adaptive = op.attr("adaptive") + ceil_mode = op.attr("ceil_mode") + global_pooling = op.attr("global_pooling") + ksize = op.attr("ksize") + paddings = op.attr("paddings") + padding_algorithm = op.attr("padding_algorithm") + pooling_type = op.attr("pooling_type") + if global_pooling: + adaptive = True + ksize = [1, 1] + + input_x = g.get_node(op.input("X")[0]) + in_h, in_w = infer_shape(input_x)[2:] + + op_map = { + "avg": "avg_pool2d", + "max": "max_pool2d", + } + strides = op.attr("strides") + if isinstance(strides, int): + strides = [strides, strides] + if isinstance(ksize, int): + ksize = [ksize, ksize] + if isinstance(paddings, int): + paddings = [paddings] * 2 + + if padding_algorithm == "VALID": + paddings = [0, 0] + elif padding_algorithm == "SAME": + pad_h = _get_pad_size(in_h, ksize[0], strides[0]) + pad_w = _get_pad_size(in_w, ksize[1], strides[1]) + paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + elif padding_algorithm == "EXPLICIT": + if len(paddings) == 2: + paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] + if len(paddings) == 4: + paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] + else: + msg = 'Value {} in attribute "padding" of operator Pool2d is not "valid."' + raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) + + if not adaptive: + out = getattr(_op.nn, op_map[pooling_type])( + input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode + ) + else: + out = getattr(_op.nn, "adaptive_" + op_map[pooling_type])(input_x, output_size=ksize) + g.add_node(op.output("Out")[0], out) + + +def convert_reshape(g, op, block): + """Operator converter for reshape.""" + + input_shape = op.input("Shape") + input_shape_tensor = op.input("ShapeTensor") + data = g.get_node(op.input("X")[0]) + if input_shape: + new_shape = g.get_node(input_shape[0]) + elif input_shape_tensor: + tmp_shape = [] + for shape_name in input_shape_tensor: + shape = g.get_node(shape_name) + if len(infer_shape(shape)) == 0: + shape = _op.reshape(shape, [-1]) + if isinstance(shape, _expr.Constant): + tmp_shape.append(shape) + elif isinstance(shape, _expr.Expr): + tmp_shape.append(shape) + else: + tmp_shape.append(_expr.const(np.array(shape).astype("int64"))) + new_shape = _op.concatenate(tmp_shape, axis=0) + else: + new_shape = op.attr("shape") + out = _op.reshape(data, new_shape) + g.add_node(op.output("Out")[0], out) + + +def convert_scale(g, op, block): + """Operator converter for scale.""" + + scale = op.attr("scale") + bias = op.attr("bias") + bias_after_scale = op.attr("bias_after_scale") + x = g.get_node(op.input("X")[0]) + if np.isclose(scale, 1.0) and np.isclose(bias, 0.0): + out = _op.copy(x) + else: + if np.isclose(bias, 0.0): + out = x * _expr.const(np.array(scale).astype("float32")) + elif np.isclose(scale, 1.0): + out = x + _expr.const(np.array(bias).astype("float32")) + else: + if bias_after_scale: + out = x * _expr.const(np.array(scale).astype("float32")) + _expr.const( + np.array(bias).astype("float32") + ) + else: + out = (x + _expr.const(np.array(bias).astype("float32"))) * _expr.const( + np.array(scale).astype("float32") + ) + g.add_node(op.output("Out")[0], out) + + +def convert_shape(g, op, block): + """Operator converter for shape.""" + + x = g.get_node(op.input("Input")[0]) + out = shape_of(x) + g.add_node(op.output("Out")[0], out) + + +def convert_slice(g, op, block): + """Operator converter for slice.""" + + def parameter_process(starts, ends, axes, dshape): + new_axes = [] + new_starts = [] + new_ends = [] + pop_index = 0 + for i in range(max(axes) + 1): + new_axes.append(i) + if i in axes: + new_starts.append(starts[pop_index]) + new_ends.append(ends[pop_index]) + pop_index += 1 + else: + new_starts.append(0) + new_ends.append(dshape[i]) + return new_starts, new_ends, new_axes + + data = g.get_node(op.input("Input")[0]) + dshape = infer_shape(data) + starts = op.attr("starts") + ends = op.attr("ends") + axes = op.attr("axes") + decrease_axis = op.attr("decrease_axis") + if isinstance(starts, int): + starts = [starts] + if isinstance(ends, int): + ends = [ends] + if isinstance(axes, int): + axes = [axes] + if isinstance(decrease_axis, int): + decrease_axis = [decrease_axis] + starts, ends, axes = parameter_process(starts, ends, axes, dshape) + out = _op.strided_slice(data, begin=starts, end=ends) + if decrease_axis: + out = _op.squeeze(out, axis=decrease_axis) + g.add_node(op.output("Out")[0], out) + + +def convert_softmax(g, op, block): + """Operator converter for softmax.""" + + axis = op.attr("axis") + input_shape = block.var(op.input("X")[0]).shape + if axis < 0: + axis = len(input_shape) + axis + x = g.get_node(op.input("X")[0]) + m = _op.max(x, axis, keepdims=True) + e = _op.exp(x - m) + out = e / _op.sum(e, axis, keepdims=True) + g.add_node(op.output("Out")[0], out) + + +def convert_unsqueeze(g, op, block): + """Operator converter for unsqueeze.""" + + x = g.get_node(op.input("X")[0]) + axes = sorted(op.attr("axes")) + for axis in axes: + x = _op.expand_dims(x, axis=axis, num_newaxis=1) + g.add_node(op.output("Out")[0], x) + + +_convert_map = { + "arg_max": convert_arg_max, + "assign": convert_assign, + "batch_norm": convert_batch_norm, + "cast": convert_cast, + "concat": convert_concat, + "conv2d": convert_conv2d, + "cumsum": convert_cumsum, + "depthwise_conv2d": convert_conv2d, + "dropout": convert_dropout, + "elementwise_add": convert_elementwise_op, + "elementwise_div": convert_elementwise_op, + "elementwise_mul": convert_elementwise_op, + "elementwise_sub": convert_elementwise_op, + "equal": convert_equal, + "exp": convert_activation, + "feed": convert_feed, + "fill_any_like": convert_fill_any_like, + "fill_constant": convert_fill_constant, + "gelu": convert_gelu, + "hard_sigmoid": convert_hard_sigmoid, + "hard_swish": convert_hard_swish, + "layer_norm": convert_layer_norm, + "leaky_relu": convert_leaky_relu, + "lookup_table_v2": convert_lookup_table, + "matmul": convert_matmul, + "matmul_v2": convert_matmul, + "mul": convert_mul, + "pool2d": convert_pool2d, + "relu": convert_activation, + "reshape2": convert_reshape, + "scale": convert_scale, + "shape": convert_shape, + "slice": convert_slice, + "softmax": convert_softmax, + "tanh": convert_activation, + "unsqueeze2": convert_unsqueeze, +} + + +class GraphProto: + """A helper class for handling relay functions from PaddlePaddle model.""" + + def __init__(self): + self.nodes = {} + self.params = {} + self.shape_dict = None + + def get_node(self, name): + """get node from graph""" + + assert name in self.nodes + return self.nodes[name] + + def add_node(self, name, node): + """add a node to graph""" + + self.nodes[name] = fold_constant(node) + + def get_params(self, name=None): + """get params from graph""" + + if name is None: + return self.params + assert name in self.params + return self.params[name] + + def extract_parameters(self, program, scope=None): + """Extract all the weights from PaddlePaddle program.""" + + self.params = {} + variables = program.global_block().vars + for name in variables: + var = program.global_block().var(name) + if name.endswith("feed") or name.endswith("fetch"): + continue + if not var.persistable: + continue + if isinstance(scope, dict): + self.params[name] = scope[name] + else: + self.params[name] = np.array(scope.var(name).get_tensor()) + self.nodes[name] = _expr.const(self.params[name]) + + def check_input_shape(self, op, block): + """Check the shape information of model's inputs, fixed shape is recommended.""" + + ipt_name = op.input(op.input_names[0]) + ipt_shape = block.var(ipt_name).shape + for i in ipt_shape: + if i < 0: + warning_msg = "Input {}(shape={}) has unkown dimension shapes. \ + Specifying static values may improve performance".format( + ipt_name, ipt_shape + ) + warnings.warn(warning_msg) + + def check_unsupported_ops(self, program): + """Check whether all the operators are supported.""" + + unsupported_ops = set() + for block in program.blocks: + for op in block.ops: + if op.type == "fetch": + continue + if op.type not in _convert_map: + unsupported_ops.add(op.type) + if len(unsupported_ops) > 0: + msg = "The following operators are not supported for frontend Paddle: " + msg += ", ".join(unsupported_ops) + raise tvm.error.OpNotImplemented(msg) + + def ops_to_relay(self, program, input_specs=None): + """Convert PaddlePaddle operators to TVM relay functions.""" + + if input_specs is not None: + for input_spec in input_specs: + convert_feed(self, input_spec, None) + for block in program.blocks: + for op in block.ops: + if op.type == "fetch": + continue + convert_func = _convert_map[op.type] + convert_func(self, op, block) + + def from_program(self, program, shape_dict, scope): + """Construct the TVM relay expression from PaddlePaddle program.""" + + self.shape_dict = shape_dict + if scope is None: + import paddle + + scope = paddle.fluid.global_scope() + self.check_unsupported_ops(program) + self.extract_parameters(program, scope) + self.ops_to_relay(program) + + output_names = list() + for block in program.blocks: + for op in block.ops: + if op.type == "fetch": + output_names.append(op.input("X")[0]) + + outputs = [self.nodes[name] for name in output_names] + outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + + free_vars = analysis.free_vars(outputs) + func = _function.Function(free_vars, outputs) + mod = IRModule.from_expr(func) + return mod, self.params + + def from_translated_layer(self, layer, shape_dict): + """Construct the TVM relay expression from PaddlePaddle TranslatedLayer.""" + + self.shape_dict = shape_dict + program = layer.program() + parameters = dict() + for param in layer.parameters(): + parameters[param.name] = np.array(param.value().get_tensor()) + self.check_unsupported_ops(program) + self.extract_parameters(program, parameters) + + input_specs = layer._input_spec() + self.ops_to_relay(program, input_specs) + + output_names = [x.name for x in layer._output_spec()] + + outputs = [self.nodes[name] for name in output_names] + outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + + free_vars = analysis.free_vars(outputs) + func = _function.Function(free_vars, outputs) + mod = IRModule.from_expr(func) + return mod, self.params + + +def from_paddle(program_or_layer, shape_dict=None, scope=None): + """Convert a PaddlePaddle model into an equivalent Relay Function. + + PaddlePaddle Program/TranslatedLayer represent the computation graph of PaddlePaddle model, + and PaddlePaddle scope stores all the weights of PaddlePaddle model. + """ + + import paddle + + g = GraphProto() + if isinstance(program_or_layer, paddle.jit.TranslatedLayer): + # model is loaded by `paddle.jit.load` + mod, params = g.from_translated_layer(program_or_layer, shape_dict) + elif isinstance(program_or_layer, paddle.static.Program): + # model is loaded by `paddle.static.load_inference_model` + mod, params = g.from_program(program_or_layer, shape_dict, scope) + else: + raise Exception("Only PaddlePaddle's Program and TranslatedLayer are supported.") + return mod, params diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 4c874672445b..c13d791cf2e2 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -39,8 +39,9 @@ from ..prelude import Prelude, StaticTensorArrayOps from ..ty import Any, TensorType, TupleType from . import qnn_torch -from .common import AttrCvt, get_relay_op +from .common import AttrCvt, get_relay_op, unbind, lstm_cell, gru_cell from .common import infer_value as _infer_value +from .common import infer_shape as _infer_shape from .common import infer_value_simulated as _infer_value_simulated from .common import try_infer_value from .pytorch_utils import is_version_greater_than @@ -572,6 +573,12 @@ def repeat_interleave(self, inputs, input_types): if isinstance(inputs[1], int): repeats = inputs[1] axis = inputs[2] + elif isinstance(inputs[1], _expr.Expr): + if isinstance(inputs[1], _expr.Constant): + repeats = int(inputs[1].data.numpy()) + else: + repeats, _ = try_infer_value(inputs[1], lambda ret: ret.tolist()) + axis = inputs[2] else: msg = "Only repeat with one value as repeat is currently supported." raise AssertionError(msg) @@ -770,7 +777,7 @@ def leaky_relu(self, inputs, input_types): def elu(self, inputs, input_types): data = inputs[0] dtype = input_types[0] - alpha = _expr.const(float(inputs[1]), dtype=dtype) + alpha = _expr.const(-float(inputs[1]), dtype=dtype) return alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data) def celu(self, inputs, input_types): @@ -803,6 +810,10 @@ def selu(self, inputs, input_types): alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data) ) + def silu(self, inputs, input_types): + data = inputs[0] + return data * _op.tensor.sigmoid(data) + def log_sigmoid(self, inputs, input_types): data = inputs[0] return _op.log(_op.tensor.sigmoid(data)) @@ -1444,7 +1455,16 @@ def linear(self, inputs, input_types): # 0 - input # 1 - weight bias = inputs[2] - mm_out = self.matmul(inputs[:2], input_types[:2]) + a_shape = self.infer_shape_with_prelude(inputs[0]) + b_shape = self.infer_shape_with_prelude(inputs[1]) + if len(a_shape) == 2 and len(b_shape) == 2: + mm_out = _op.nn.dense(inputs[0], inputs[1]) + elif len(b_shape) == 1: + mm_out = self.matmul([inputs[0], inputs[1]], input_types[:2]) + else: + mm_out = self.matmul( + [inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2] + ) if isinstance(bias, _expr.Expr): bias_ndims = len(self.infer_shape_with_prelude(bias)) if bias_ndims == 1: @@ -1584,28 +1604,11 @@ def chunk(self, inputs, input_types): else: unif_size = int(dim / num_chunks) - chunks = [] - for i in range(0, dim, unif_size): - begin = [0] * len(shape) - end = shape[:] - begin[axis] = i - end[axis] = i + unif_size - stride = [1] * len(shape) + indeces = [] + for i in range(unif_size, dim, unif_size): + indeces.append(i) - chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride) - chunks.append(chunk_out) - - if dim % num_chunks: - begin = [0] * len(shape) - end = shape[:] - begin[axis] = unif_size * (num_chunks - 1) - end[axis] = dim - stride = [1] * len(shape) - - chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride) - chunks.append(chunk_out) - - return chunks + return _op.split(data, indeces, axis) def matmul(self, inputs, input_types): @@ -1798,7 +1801,7 @@ def get_upsample_out_size(self, inputs, method): else: out_size.append(size) else: - scale_index = 3 if method == "linear" else 2 + scale_index = 3 if method != "nearest_neighbor" else 2 scales = inputs[scale_index] assert scales is not None, "neither out size nor scale provided" assert isinstance(scales, list) @@ -1813,7 +1816,7 @@ def upsample(inputs, input_types): data = inputs[0] out_size = self.get_upsample_out_size(inputs, method) - if len(inputs) > 2 and method == "linear": + if len(inputs) > 2 and method != "nearest_neighbor": align_corners = inputs[2] else: align_corners = False @@ -1826,7 +1829,9 @@ def upsample(inputs, input_types): coord_trans = "half_pixel" def func(x): - return _op.image.resize2d(x, out_size, "NCHW", method, coord_trans) + return _op.image.resize2d( + x, out_size, "NCHW", method, coord_trans, cubic_alpha=-0.75 + ) if self.is_quantized_tensor(data): # input qparams are manually appended by us @@ -2093,21 +2098,8 @@ def deform_conv2d(self, inputs, input_types): def unbind(self, inputs, input_types): data = inputs[0] - dim = int(inputs[1]) - ishapes = self.infer_shape(data) - if dim >= len(ishapes): - msg = "Please check input dim, it shouldn't be greater than or equal to rank." - raise AttributeError(msg) - - selections = ishapes[dim] - res_split = _op.split(data, selections, dim) - # squeeze each split piece to get same shape as aten::unbind - # TODO (yongwww): add new op to avoid the squeeze overhead - ret = [] - for i in range(selections): - ret.append(_op.transform.squeeze(res_split[i], axis=[dim])) - ret = _expr.TupleWrapper(_expr.Tuple(ret), selections) - return ret + axis = int(inputs[1]) + return unbind(data, axis) def shape_as_tensor(self, inputs, input_types): is_symbolic_shape = False @@ -2134,7 +2126,7 @@ def nonzero(self, inputs, input_types, is_numpy_style=False): data = inputs[0] ret = _op.transform.argwhere(data) if is_numpy_style or (len(inputs) > 1 and inputs[1]): - return self.unbind([ret, 1], None) + return unbind(ret, 1) return ret def nonzero_numpy(self, inputs, input_types): @@ -2202,7 +2194,7 @@ def interpolate(self, inputs, input_types): else: coord_trans = "half_pixel" - return _op.image.resize2d(data, out_size, "NCHW", method, coord_trans) + return _op.image.resize2d(data, out_size, "NCHW", method, coord_trans, cubic_alpha=-0.75) def numel(self, inputs, input_types): return _op.ndarray_size(inputs[0]) @@ -2329,6 +2321,448 @@ def flip(self, inputs, input_types): axis = inputs[1] return _op.transform.reverse(data, axis=axis[0]) + def bidir_gru_cell( + self, + input_seqs, + weights_dicts, + ): + """ + Bidirectional GRU cell + """ + seq_len = len(input_seqs) + forward_outputs, fw_H_t = gru_cell( + input_seqs, + **weights_dicts[0], + ) + + reverse_outputs, rev_H_t = gru_cell( + input_seqs, + **weights_dicts[1], + backwards=True, + ) + + final_outputs = [] + for i in range(seq_len): + final_outputs.append( + _op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1) + ) + + return final_outputs, _op.stack([fw_H_t, rev_H_t], axis=0) + + def gru_layers(self, input_data, layer_weights_dicts, bidirectional, dropout_p=0.0): + """ + Methods iterates layers for Stacked GRU + """ + layers_num = len(layer_weights_dicts) + # split input sequence to samples set + input_seqs = unbind(input_data, 0) # [seq_num, (batch, feature_size)] + output_hiddens = [] + for i in range(layers_num): + weights_dicts = layer_weights_dicts[i] + # input_seqs shape = [seq_num, (batch, feature_size)] or + # [seq_num, (batch, 2*feature_size)] for bidirectional + if bidirectional: + input_seqs, H_t = self.bidir_gru_cell(input_seqs, weights_dicts) + else: + input_seqs, H_t = gru_cell(input_seqs, **weights_dicts[0]) + + output_hiddens.append(H_t) + + # TODO (vvchernov): in pytorch implementation train is also checked + # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339 + # /aten/src/ATen/native/RNN.cpp#L1054 + if dropout_p != 0 and i < layers_num - 1: + # for input in input_seqs: + # input = _op.dropout(input, dropout_p) + raise NotImplementedError("Dropout for GRU has not been supported yet!") + + return _op.stack(input_seqs, 0), _op.stack(output_hiddens, 0) + + def gru(self, inputs, input_types): + """ + Description of GRU in pytorch: + https://pytorch.org/docs/stable/generated/torch.nn.GRU.html?highlight=gru#torch.nn.GRU + """ + # TODO (vvchernov): support dropout + assert len(inputs) == 9, "Input of size 9 is expected" + # Unpack inputs, note that if optional and not provided then value will be None. + _X = inputs[0] + # _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size) + + hidden_state = inputs[1] + # Hidden state shape (hidden_layers_num, batch, hidden_size) + + _weights = inputs[2] + # Wi layer[0] shape (3 * hidden_size, feature_size) + # Wh layer[0] shape (3 * hidden_size, hidden_size) + # Bi layer[0] shape (3 * hidden_size) + # Bh layer[0] shape (3 * hidden_size) + + # Wi layer[>0] shape (3 * hidden_size, hidden_size * num_directions) + # Wh layer[>0] shape (3 * hidden_size, hidden_size) + # Bi layer[>0] shape (3 * hidden_size) + # Bh layer[>0] shape (3 * hidden_size) + + # Scalar inputs + has_biases = inputs[3] + num_layers = inputs[4] + dropout_p = inputs[5] # dropout probability, if 0.0 it means there is no dropout + # train = inputs[6] + bidirectional = inputs[7] + batch_first = inputs[8] + + num_directions = 1 + if bidirectional: + num_directions = 2 + + rsd = len(_weights) % num_layers + assert rsd == 0, "The number of weights must be a multiple of the number of layers!" + rsd = (len(_weights) / num_layers) % num_directions + assert ( + rsd == 0 + ), "The number of weights in layer must be a multiple of the number of directions!" + + weights_num = int(len(_weights) / num_layers / num_directions) + if has_biases: + assert weights_num == 4, "The weights number in layer is expected equal to 4" + else: + assert weights_num == 2, "The weights number in layer is expected equal to 2" + + X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X + # TODO (vvchernov): Which data type should be used? from input or weights? + # Instead of it _infer_type(X).checked_type.dtype can be used + X_dtype = input_types[0] + X_shape = _infer_shape(X) # (seq_num, batch, feature_size) + + hidden_size = int(_infer_shape(_weights[0])[0] / 3) + batch_size = X_shape[1] + + # Initialize hidden states if not provided. + layers_h = [] + hidden_layers_num = num_directions * num_layers + if hidden_state is None: + h_0 = _op.zeros((batch_size, hidden_size), X_dtype) + for i in range(hidden_layers_num): + layers_h.append(h_0) + else: + layers_h = unbind(hidden_state, 0) + + layer_weights_dicts = [] + k = 0 # layer counter + if has_biases: + names = ["hidden_state", "w_inp", "w_hid", "b_inp", "b_hid"] + if bidirectional: + rsd = len(_weights) % (2 * weights_num) + assert rsd == 0, "got an incorrect number of GRU weights" + for i in range(0, len(_weights), 2 * weights_num): + fw_tensors = [layers_h[2 * k], *_weights[i : i + 4]] + fw_weights_dict = dict(zip(names, fw_tensors)) + j = i + weights_num + rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 4]] + rev_weights_dict = dict(zip(names, rev_tensors)) + layer_weights_dicts.append([fw_weights_dict, rev_weights_dict]) + k += 1 + else: + assert len(_weights) % weights_num == 0, "got an incorrect number of GRU weights" + for i in range(0, len(_weights), weights_num): + fw_tensors = [layers_h[k], *_weights[i : i + 4]] + fw_weights_dict = dict(zip(names, fw_tensors)) + layer_weights_dicts.append([fw_weights_dict]) + k += 1 + else: + names = ["hidden_state", "w_inp", "w_hid"] + if bidirectional: + rsd = len(_weights) % (2 * weights_num) + assert rsd == 0, "got an incorrect number of GRU weights" + for i in range(0, len(_weights), 2 * weights_num): + fw_tensors = [layers_h[2 * k], *_weights[i : i + 2]] + fw_weights_dict = dict(zip(names, fw_tensors)) + j = i + weights_num + rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 2]] + rev_weights_dict = dict(zip(names, rev_tensors)) + layer_weights_dicts.append([fw_weights_dict, rev_weights_dict]) + k += 1 + else: + assert len(_weights) % weights_num == 0, "got an incorrect number of GRU weights" + for i in range(0, len(_weights), weights_num): + fw_tensors = [layers_h[k], *_weights[i : i + 2]] + fw_weights_dict = dict(zip(names, fw_tensors)) + layer_weights_dicts.append([fw_weights_dict]) + k += 1 + assert ( + len(layer_weights_dicts) == num_layers and k == num_layers + ), "For stacked GRU number of weights sets should be the same as number of layers!" + + output, out_hidden_state = self.gru_layers( + X, + layer_weights_dicts, + bidirectional, + dropout_p=dropout_p, + ) + + # output shape = (seq_num, batch, hidden_size) or + # (seq_num, batch, 2*feature_size) for bidirectional + if batch_first: + output = _op.transpose(output, (1, 0, 2)) + + return (output, out_hidden_state) + + def bidir_lstm_cell( + self, + input_seqs, + weights_dicts, + ): + """ + Bidirectional LSTM cell + """ + seq_len = len(input_seqs) + forward_outputs, fw_H_t, fw_C_t = lstm_cell( + input_seqs, + **weights_dicts[0], + ) + + reverse_outputs, rev_H_t, rev_C_t = lstm_cell( + input_seqs, + **weights_dicts[1], + backwards=True, + ) + + final_outputs = [] + for i in range(seq_len): + final_outputs.append( + _op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1) + ) + + return final_outputs, (fw_H_t, fw_C_t), (rev_H_t, rev_C_t) + + def lstm_layers(self, input_data, layer_weights_dicts, bidirectional, dtype, dropout_p=0.0): + """ + Methods iterates layers for Stacked LSTM + """ + layers_num = len(layer_weights_dicts) + # split input sequence to samples set + input_seqs = unbind(input_data, 0) # [seq_num, (batch, feature_size)] + output_hiddens = [] + for i in range(layers_num): + weights_dicts = layer_weights_dicts[i] + # input_seqs shape = [seq_num, (batch, feature_size)] or + # [seq_num, (batch, 2*feature_size)] for bidirectional + if bidirectional: + input_seqs, H_t, C_t = self.bidir_lstm_cell(input_seqs, weights_dicts) + else: + input_seqs, H_t, C_t = lstm_cell(input_seqs, **weights_dicts[0]) + + output_hiddens.append((H_t, C_t)) + + # TODO (vvchernov): in pytorch implementation train is also checked + # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339 + # /aten/src/ATen/native/RNN.cpp#L1054 + if dropout_p != 0 and i < layers_num - 1: + # for input in input_seqs: + # input = _op.dropout(input, dropout_p) + raise NotImplementedError("Dropout for LSTM has not been supported yet!") + final_hiddens = [] + if bidirectional: + for output_hidden in output_hiddens: + final_hiddens.append(output_hidden[0]) + final_hiddens.append(output_hidden[1]) + else: + final_hiddens = output_hiddens + + return _op.stack(input_seqs, 0), final_hiddens + + def lstm(self, inputs, input_types): + """ + Description of LSTM in pytorch:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html + Native implementation for torch version less than 1.8.0 (projection is unsupported): + https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/ \ + src/ATen/native/RNN.cpp#L1396 + Native implementation for torch version from 1.8.0 and higher (projection is supported): + https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L1483 + """ + # TODO (vvchernov): support dropout + assert len(inputs) == 9, "Input of size 9 is expected" + # Unpack inputs, note that if optional and not provided then value will be None. + _X = inputs[0] + # _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size) + + hidden_states = inputs[1] + assert len(hidden_states) == 2, "lstm expects two hidden states" + h_0 = hidden_states[0] + c_0 = hidden_states[1] + # H0 shape (hidden_layers_num, batch, proj_size) if projection + # else (hidden_layers_num, batch, hidden_size) + # C0 shape (hidden_layers_num, batch, hidden_size) + + _weights = inputs[2] + # If no projection + # Wi layer[0] shape (4 * hidden_size, feature_size) + # Wh layer[0] shape (4 * hidden_size, hidden_size) + # Bi layer[0] shape (4 * hidden_size) + # Bh layer[0] shape (4 * hidden_size) + + # Wi layer[>0] shape (4 * hidden_size, hidden_size * num_directions) + # Wh layer[>0] shape (4 * hidden_size, hidden_size) + # Bi layer[>0] shape (4 * hidden_size) + # Bh layer[>0] shape (4 * hidden_size) + + # If projection + # Wi layer[0] shape (4 * hidden_size, feature_size) + # Wh layer[0] shape (4 * hidden_size, proj_size) + # Bi layer[0] shape (4 * hidden_size) + # Bh layer[0] shape (4 * hidden_size) + # P layer[0] shape (proj_size, hidden_size) + + # Wi layer[>0] shape (4 * hidden_size, proj_size * num_directions) + # Wh layer[>0] shape (4 * hidden_size, proj_size) + # Bi layer[>0] shape (4 * hidden_size) + # Bh layer[>0] shape (4 * hidden_size) + # P layer[>0] shape (proj_size, hidden_size) + + # Scalar inputs + has_biases = inputs[3] + num_layers = inputs[4] + dropout_p = inputs[5] # dropout probability, if 0.0 it means there is no dropout + # train = inputs[6] + bidirectional = inputs[7] + batch_first = inputs[8] + + num_directions = 1 + if bidirectional: + num_directions = 2 + + rsd = len(_weights) % num_layers + assert rsd == 0, "The number of weights must be a multiple of the number of layers!" + rsd = (len(_weights) / num_layers) % num_directions + assert ( + rsd == 0 + ), "The number of weights in layer must be a multiple of the number of directions!" + has_proj = False + proj_size = 0 + weights_num = int(len(_weights) / num_layers / num_directions) + if has_biases: + if weights_num == 5: + has_proj = True + proj_size = _infer_shape(_weights[4])[0] + else: + assert weights_num == 4, "The weights number in layer is expected equal to 4" + else: + if weights_num == 3: + has_proj = True + proj_size = _infer_shape(_weights[2])[0] + else: + assert weights_num == 2, "The weights number in layer is expected equal to 2" + + X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X + # TODO (vvchernov): Which data type should be used? from input or weights? + # Instead of it _infer_type(X).checked_type.dtype can be used + X_dtype = input_types[0] + X_shape = _infer_shape(X) # (seq_num, batch, feature_size) + + hidden_size = _infer_shape(_weights[0])[0] / 4 + batch_size = X_shape[1] + + # Initialize hidden states if not provided. + layers_h = [] + layers_c = [] + hidden_layers_num = num_directions * num_layers + if h_0 is None: + if has_proj: + h_0 = _op.zeros((batch_size, proj_size), X_dtype) + else: + h_0 = _op.zeros((batch_size, hidden_size), X_dtype) + for i in range(hidden_layers_num): + layers_h.append(h_0) + else: + layers_h = unbind(h_0, 0) + if c_0 is None: + c_0 = _op.zeros((batch_size, hidden_size), X_dtype) + for i in range(hidden_layers_num): + layers_c.append(c_0) + else: + layers_c = unbind(c_0, 0) + + layer_weights_dicts = [] + k = 0 # layer counter + if has_biases: + names = ["hidden_state", "cell_state", "w_inp", "w_hid", "b_inp", "b_hid"] + if bidirectional: + rsd = len(_weights) % (2 * weights_num) + assert rsd == 0, "got an incorrect number of LSTM weights" + for i in range(0, len(_weights), 2 * weights_num): + fw_tensors = [layers_h[2 * k], layers_c[2 * k], *_weights[i : i + 4]] + fw_weights_dict = dict(zip(names, fw_tensors)) + if has_proj: + fw_weights_dict["proj"] = _weights[i + 4] + j = i + weights_num + rev_tensors = [layers_h[2 * k + 1], layers_c[2 * k + 1], *_weights[j : j + 4]] + rev_weights_dict = dict(zip(names, rev_tensors)) + if has_proj: + rev_weights_dict["proj"] = _weights[j + 4] + layer_weights_dicts.append([fw_weights_dict, rev_weights_dict]) + k += 1 + else: + assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" + for i in range(0, len(_weights), weights_num): + fw_tensors = [layers_h[k], layers_c[k], *_weights[i : i + 4]] + fw_weights_dict = dict(zip(names, fw_tensors)) + if has_proj: + fw_weights_dict["proj"] = _weights[i + 4] + layer_weights_dicts.append([fw_weights_dict]) + k += 1 + else: + names = ["hidden_state", "cell_state", "w_inp", "w_hid"] + if bidirectional: + rsd = len(_weights) % (2 * weights_num) + assert rsd == 0, "got an incorrect number of LSTM weights" + for i in range(0, len(_weights), 2 * weights_num): + fw_tensors = [layers_h[2 * k], layers_c[2 * k], *_weights[i : i + 2]] + fw_weights_dict = dict(zip(names, fw_tensors)) + if has_proj: + fw_weights_dict["proj"] = _weights[i + 2] + j = i + weights_num + rev_tensors = [layers_h[2 * k + 1], layers_c[2 * k + 1], *_weights[j : j + 2]] + rev_weights_dict = dict(zip(names, rev_tensors)) + if has_proj: + rev_weights_dict["proj"] = _weights[j + 2] + layer_weights_dicts.append([fw_weights_dict, rev_weights_dict]) + k += 1 + else: + assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" + for i in range(0, len(_weights), weights_num): + fw_tensors = [layers_h[k], layers_c[k], *_weights[i : i + 2]] + fw_weights_dict = dict(zip(names, fw_tensors)) + if has_proj: + fw_weights_dict["proj"] = _weights[i + 2] + layer_weights_dicts.append([fw_weights_dict]) + k += 1 + assert ( + len(layer_weights_dicts) == num_layers and k == num_layers + ), "For stacked LSTM number of weights sets should be the same as number of layers!" + + outputs = self.lstm_layers( + X, + layer_weights_dicts, + bidirectional, + dtype=X_dtype, + dropout_p=dropout_p, + ) + + # output shape = (seq_num, batch, hidden_size) or + # (seq_num, batch, 2*feature_size) for bidirectional + output = outputs[0] + + hy = [] + cy = [] + for hidden in outputs[1]: + hy.append(hidden[0]) + cy.append(hidden[1]) + + if batch_first: + output = _op.transpose(output, (1, 0, 2)) + + return (output, _op.stack(hy, 0), _op.stack(cy, 0)) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2385,6 +2819,7 @@ def create_convert_map(self): "aten::celu": self.celu, "aten::gelu": self.gelu, "aten::selu": self.selu, + "aten::silu": self.silu, "aten::log_sigmoid": self.log_sigmoid, "aten::adaptive_avg_pool2d": self.adaptive_avg_pool_2d, "aten::adaptive_max_pool2d": self.adaptive_max_pool_2d, @@ -2414,6 +2849,7 @@ def create_convert_map(self): "aten::clone": self.clone, "aten::log_softmax": self.log_softmax, "aten::sigmoid": self.sigmoid, + "aten::sigmoid_": self.sigmoid, "aten::softplus": self.softplus, "aten::avg_pool1d": self.make_avg_pool(1), "aten::avg_pool2d": self.make_avg_pool(2), @@ -2425,6 +2861,7 @@ def create_convert_map(self): "aten::alpha_dropout": self.dropout, "aten::mean": self.mean, "aten::chunk": self.chunk, + "aten::unsafe_chunk": self.chunk, "aten::matmul": self.matmul, "aten::bmm": self.matmul, "aten::expand": self.expand, @@ -2455,6 +2892,7 @@ def create_convert_map(self): "aten::sinh": self.make_unary("sinh"), "aten::tan": self.make_unary("tan"), "aten::tanh": self.make_unary("tanh"), + "aten::tanh_": self.make_unary("tanh"), "aten::acos": self.make_unary("acos"), "aten::asin": self.make_unary("asin"), "aten::atan": self.make_unary("atan"), @@ -2478,6 +2916,7 @@ def create_convert_map(self): "aten::clamp_": self.clamp, "aten::detach": self.identity, "aten::upsample_bilinear2d": self.make_upsample("linear"), + "aten::upsample_bicubic2d": self.make_upsample("cubic"), "aten::upsample_nearest2d": self.make_upsample("nearest_neighbor"), "aten::upsample_trilinear3d": self.make_upsample3d("linear"), "aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"), @@ -2545,6 +2984,8 @@ def create_convert_map(self): "aten::nll_loss": self.nll_loss, "aten::nll_loss2d": self.nll_loss, "aten::flip": self.flip, + "aten::gru": self.gru, + "aten::lstm": self.lstm, } def update_convert_map(self, custom_map): diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 02b2484d4fb7..753b1f253a89 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -90,7 +90,7 @@ def batched_nms(boxes, scores, idxs, iou_threshold): """ one = is_constant() - # Equivelent PyTorch code from above snippet + # Equivalent PyTorch code from above snippet # offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) cast = is_op("cast")(idxs) mx = is_op("max")(boxes) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index f614982aac6c..9eafae905baf 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -163,7 +163,7 @@ def _get_quant_param_for_input(input_value): """ We want to know the input scale and zp of this input_value, since input quant params are not explicitly passed around in torch (they - are embeded in a QTensor data structure, not visible statically). + are embedded in a QTensor data structure, not visible statically). We know that it is quantized using output scale and zp of some previous quantized op. The purpose of this function is to find that pair of parameters. diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index e297398ffe5b..d35e0e1c203d 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -52,6 +52,11 @@ # However, please note that `nn.matmul` is in experimental so it may have some performance # issues. "use_dense": True, + # By default, TVM converts `tf.batch_matmul` to `transpose(weight) + nn.batch_matmul_NT`. + # Change this flag to False to directly convert to `nn.batch_matmul`. + # Note that `nn.batch_matmul` with format other than NT is in experimental, it may have some + # performance issues. + "use_nt_batch_matmul": True, } # compatible operators that do NOT require any conversion. @@ -117,7 +122,7 @@ def _in_while_loop(control_flow_node_map, op_name): Parameters ---------- control_flow_node_map : Dict[str, Set[str]] - A dictionay contains the unique control flow execution frame name to + A dictionary contains the unique control flow execution frame name to a set of primitive operators mapping. op_name : str @@ -139,7 +144,7 @@ class RewriteSubgraph(ExprMutator): Parameters ---------- rewrite_map : Dict[expr, expr] - A dictionay contains a set of expr to var mapping. + A dictionary contains a set of expr to var mapping. """ def __init__(self, rewrite_map): @@ -1214,7 +1219,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): return func, self._params -def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op=True): +def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, convert_config=None): """Load tensorflow graph which is a python tensorflow graph object into relay. The companion parameters will be handled automatically. @@ -1232,10 +1237,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op outputs : List of output tensor names (Optional) if not specified then the last node is assumed as graph output. - use_dense_op : bool (Optional) = True - Ture to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`. - The `nn.dense` op requires the data tensor to be non-transposed and weight tensor to be - transposed, may insert extra `transpose` to the original graph. + convert_config : Optional[Dict[str, Any]] + Default config: + use_dense : bool = True + Ture to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`. + The `nn.dense` op requires the data tensor to be non-transposed and weight tensor + to be transposed, may insert extra `transpose` to the original graph. + use_nt_batch_matmul : bool = True + True to convert `tf.batch_matmul` to `nn.batch_matmul` strict to NT format + (transpose_a=False, transpose_b=True). Returns ------- @@ -1246,7 +1256,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op Dict of converted parameters stored in tvm.nd.NDArray format """ global TF_DEFAULT_CONFIGS - TF_DEFAULT_CONFIGS["use_dense"] = use_dense_op + if convert_config is not None: + TF_DEFAULT_CONFIGS.update(convert_config) g = GraphProto() mod, params = g.from_tensorflow(graph, layout, shape, outputs) diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index e5339b33c4e9..465f530624b9 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except +# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except, too-many-nested-blocks """Tensorflow2.x graph to relay converter. If model is constructed using tf2.x API, then use this converter: @@ -38,12 +38,20 @@ from .common import infer_type as _infer_type from .tensorflow_ops import _convert_map as _convert_map_common -from .tensorflow_ops import _need_prelude_for_shape_inference +from .tensorflow_ops import _get_more_static_shape_rank +from .tensorflow2_ops import _convert_map as _convert_map_tf2 +from .tensorflow2_ops import _need_prelude_for_shape_inference from ..ty import Any __all__ = ["from_tensorflow"] +# A map to record tensor list write ops and input tl/tensor indices +# Value is (index of tensor list, index of written node) +_tensor_list_write_ops = { + "TensorListSetItem": (0, 2), +} + def _infer_type_with_prelude(val, prelude): body = _infer_type(val, prelude.mod) @@ -66,6 +74,11 @@ def set_span(sym, node_name): return sym +def is_tensor_list_constuctor(tf_node): + """Check whether is tensor list constructor node.""" + return tf_node.op == "TensorListReserve" + + def convert_const_node(node, shape): """convert tf const node into relay const or var""" @@ -196,6 +209,10 @@ def __init__(self, module): self._output_shapes = {} self._tf_node_map = {} self._gdef_lib = {} + self._tensor_list_shapes = {} + self._tensor_list_shape_nodes = {} + self._sub_map = {} + self._sub_input_idx_map = {} def from_tensorflow( self, graph, layout="NHWC", shape=None, outputs=None, input_types=None, gdef_lib=None @@ -215,10 +232,134 @@ def from_tensorflow( ) return func, self._params + def _analysis_tensor_list_op( + self, + graph, + node, + tl_write_nodes, + tl_stack_nodes, + tl_construct_nodes, + sub_func_name="", + root_node="", + ): + if sub_func_name and sub_func_name not in self._sub_input_idx_map: + self._sub_input_idx_map[sub_func_name] = {} + + if node.op == "Placeholder": + # record placeholder node in sub functions + self._sub_map[sub_func_name] = node + self._sub_input_idx_map[sub_func_name][node.name] = len( + self._sub_input_idx_map[sub_func_name] + ) + + if node.op.startswith("TensorList"): + if is_tensor_list_constuctor(node): + tl_construct_nodes.append(node) + else: + for tl_write_name, idx in _tensor_list_write_ops.items(): + if node.op.startswith(tl_write_name): + tl_write_nodes.append((node, idx, sub_func_name, root_node)) + if node.op.startswith("TensorListStack"): + tl_stack_nodes.append(node) + elif node.op.startswith("StatelessWhile"): + root_node = node.name + cond_fn_name, body_fn_name = [ + parse_attr(node.attr).get(x).name for x in ["cond", "body"] + ] + for fn_name in [cond_fn_name, body_fn_name]: + subfunction = self._gdef_lib[fn_name] + sub_func_name = fn_name + for sub_node in subfunction.node: + # bypass const node + if sub_node.op == "Const": + continue + self._tf_node_map[sub_node.name] = sub_node + self._analysis_tensor_list_op( + subfunction, + sub_node, + tl_write_nodes, + tl_stack_nodes, + tl_construct_nodes, + sub_func_name=sub_func_name, + root_node=root_node, + ) + + def _infer_static_shape_stack_node(self, tl_stack_nodes): + for stack_node in tl_stack_nodes: + if len(stack_node.input) < 2: + # Stack node does not have shape + continue + input_shape_name = stack_node.input[1].split(":")[0] + input_shape_node = self._tf_node_map[input_shape_name] + stack = [self._tf_node_map[stack_node.input[0].split(":")[0]]] + in_idx = -1 + while stack: + cnode = stack.pop(0) + if not cnode.op.startswith("TensorList"): + if in_idx and cnode.op.startswith("StatelessWhile"): + stack.append(self._tf_node_map[cnode.input[in_idx].split(":")[0]]) + else: + for iname in cnode.input: + if self._tf_node_map[iname.split(":")[0]].op.startswith( + "StatelessWhile" + ): + # identify input index based on output index + if iname.split(":")[1]: + in_idx = int(iname.split(":")[1]) + stack.append(self._tf_node_map[iname.split(":")[0]]) + # identify the corresponding constructor node and add shape to _tensor_list_shapes + elif cnode.name != stack_node.name: + if is_tensor_list_constuctor(cnode): + shape_attr = parse_attr(input_shape_node.attr) + if "value" not in shape_attr: + continue + raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"]) + elem_shape = [] + for dim in raw_elem_shape: + if dim < 0: + elem_shape.append(Any()) + else: + elem_shape.append(int(dim)) + self._tensor_list_shapes[cnode.name] = elem_shape + break + + def _infer_static_shape_write_node(self, tl_write_nodes): + for item in tl_write_nodes: + wnode = item[0] + ta_idx, inode_idx = item[1] + sub_func_name = item[2] + root_name = item[3] + stack = [self._tf_node_map[wnode.input[ta_idx].split(":")[0]]] + while stack: + cnode = stack.pop(0) + + if not cnode.op.startswith("TensorList"): + if cnode.op == "Placeholder" and sub_func_name: + # need to map subfunction + input_idx = self._sub_input_idx_map[sub_func_name][cnode.name] + stack.append( + self._tf_node_map[ + self._tf_node_map[root_name].input[input_idx].split(":")[0] + ] + ) + else: + for iname in cnode.input: + stack.append(self._tf_node_map[iname.split(":")[0]]) + # identify the corresponding constructor node and add it to _tensor_list_shape_nodes + elif cnode.name != wnode.name: + if is_tensor_list_constuctor(cnode): + inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]] + tn = wnode.input[inode_idx].split(":") + output_index = int(tn[1]) if len(tn) > 1 else 0 + self._tensor_list_shape_nodes[cnode.name] = (inode, wnode.op, output_index) + break + def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_types=None): if input_types is None: input_types = {} - + tl_write_nodes = [] + tl_stack_nodes = [] + tl_construct_nodes = [] self._layout = layout for node in graph.node: name = node.name @@ -235,6 +376,18 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_ self._nodes[node.name] = sym if param: self._params[node.name] = param + # recursivly iterate tensorlist op if seen while loop + else: + self._analysis_tensor_list_op( + graph, node, tl_write_nodes, tl_stack_nodes, tl_construct_nodes + ) + + # Use tensor list stack to infer static tensor list shape + self._infer_static_shape_stack_node(tl_stack_nodes) + + # Fetch node contains static tensor list shape + self._infer_static_shape_write_node(tl_write_nodes) + for node in graph.node: self._backtrack_construct(graph, node.name) @@ -321,16 +474,36 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs): gdef_lib=self._gdef_lib, ) elif op_name in _convert_map_common: + # assert op are exclusive + assert not set(_convert_map_common.keys()) & set(_convert_map_tf2.keys()) if _need_prelude_for_shape_inference(op_name): sym = _convert_map_common[op_name](inputs, attrs, self._params, self._prelude) else: sym = _convert_map_common[op_name](inputs, attrs, self._params, self._module.mod) + elif op_name in _convert_map_tf2: + if _need_prelude_for_shape_inference(op_name): + sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._prelude) + else: + sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._module.mod) else: raise NotImplementedError("Operator {} not implemented.".format(op_name)) sym = set_span(sym, node_name) return sym + def _parse_element_shape(self, elem_shape, shape_attr): + if "value" in shape_attr: + raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"]) + + if raw_elem_shape.size == 1 and raw_elem_shape == -1: + elem_shape.append(Any()) + else: + for dim in raw_elem_shape: + if dim < 0: + elem_shape.append(Any()) + else: + elem_shape.append(dim) + def _backtrack_construct(self, graph, node_name): """Convert a specific tensorflow node to relay expression. @@ -370,8 +543,8 @@ def _backtrack_construct(self, graph, node_name): CallNode(Op(add), [Var(x, ty=TensorType([], float32)), Constant(1.0)], (nullptr), []) """ - input_op_name = node_name.split(":")[0].split("^")[-1] + if input_op_name not in self._nodes: node = self._tf_node_map[input_op_name] attr = parse_attr(node.attr) @@ -386,8 +559,31 @@ def _backtrack_construct(self, graph, node_name): attr["_node_name"] = node.name attr["_target_layout"] = self._layout inputs = [self._backtrack_construct(graph, iname) for iname in node.input] - op = self._convert_operator(graph, node.op, node.name, inputs, attr) + # infer shape for TensorList op + if is_tensor_list_constuctor(node): + input_shape_name = ( + node.input[1] if "TensorListFromTensor" in node.op else node.input[0] + ) + input_shape_name = input_shape_name.split(":")[0] + input_shape_node = self._tf_node_map[input_shape_name] + shape_attr = parse_attr(input_shape_node.attr) + elem_shape = [] + + self._parse_element_shape(elem_shape, shape_attr) + + if elem_shape: + attr["shape"] = elem_shape + if ( + "identical_element_shapes" in attr and attr["identical_element_shapes"] + ) or elem_shape: + shape = elem_shape + if node.name in self._tensor_list_shapes: + preset_shape = self._tensor_list_shapes[node.name] + shape = _get_more_static_shape_rank(shape, preset_shape) + attr["shape"] = shape + + op = self._convert_operator(graph, node.op, node.name, inputs, attr) if isinstance(op, np.ndarray): self._params[node.name] = tvm.nd.array(op) op = [ @@ -512,7 +708,7 @@ def _convert_function( Examples -------- - a tf function "x+1", is implemented as a subgraph in the libary section of the graph. + a tf function "x+1", is implemented as a subgraph in the library section of the graph. this subgraph is converted to a relay function such as fn (%x: float32) { add(%x, 1f) /* Identity */ diff --git a/python/tvm/relay/frontend/tensorflow2_ops.py b/python/tvm/relay/frontend/tensorflow2_ops.py new file mode 100644 index 000000000000..17cd112878a5 --- /dev/null +++ b/python/tvm/relay/frontend/tensorflow2_ops.py @@ -0,0 +1,187 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except +"""Tensorflow2.x to relay converter ops and helper""" +import tvm +from tvm.relay.prelude import StaticTensorArrayOps, get_tensor_array_shape + +from .. import op as _op +from ..ty import Any +from .common import infer_value as _infer_value +from .common import infer_type as _infer_type +from .tensorflow_ops import _get_more_static_shape_rank + + +def _infer_type_with_prelude(val, prelude): + body = _infer_type(val, prelude.mod) + return body.checked_type + + +def _need_prelude_for_shape_inference(op): + return "TensorList" in op or "TensorArray" in op + + +def _tensorlist_reserve(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get("element_dtype").name + elem_shape = _infer_value(inputs[0], params, prelude.mod) + elem_shape = tuple(elem_shape.numpy().astype("int32").flatten()) + + if elem_shape or "shape" in attr: + shape = attr["shape"] if "shape" in attr else elem_shape + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, shape) + static_tensor_array_ops.register() + tensor_array_constructor = static_tensor_array_ops.get_global_var("tensor_array") + tensor_array = tensor_array_constructor(inputs[1]) + else: + tensor_array_constructor = prelude.get_global_var("tensor_array", dtype_str) + tensor_array = tensor_array_constructor(inputs[1]) + return tensor_array + + return _impl + + +def _tensorlist_set_item(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get("element_dtype").name + input_ta = inputs[0] + input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) + input_t_shape = _infer_type_with_prelude(inputs[2], prelude).shape + input_rank = len(input_t_shape) + + if input_ta_shape is None: + tensor_name = "tensor{}".format(input_rank) + tensor_func = prelude.get_tensor_ctor(tensor_name, dtype_str) + v = tensor_func(inputs[2]) + write_func = prelude.get_global_var("tensor_array_write", dtype_str) + out = write_func(input_ta, inputs[1], v) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape) + static_tensor_array_ops.register() + tensor_func = static_tensor_array_ops.get_ctor("tensor_constructor") + v = tensor_func(inputs[2]) + # Write tensor with more static shape + # convert shape with -1 to any() + input_ta_shape_a = [] + for dim in input_ta_shape: + if isinstance(dim, (int, tvm.tir.expr.IntImm)): + if dim < 0: + input_ta_shape_a.append(Any()) + else: + input_ta_shape_a.append(dim) + else: + input_ta_shape_a.append(dim) + actual_shape = _get_more_static_shape_rank(input_t_shape, input_ta_shape_a) + if actual_shape != input_ta_shape_a: + new_shape = [] + num_any_dim = 0 + for dim in actual_shape: + if not isinstance(dim, int): + num_any_dim += 1 + new_shape.append(dim if isinstance(dim, int) else -1) + if num_any_dim <= 1: + v = tensor_func(_op.reshape(inputs[2], new_shape)) + write_func = prelude.get_global_var_static( + "tensor_array_write", dtype_str, input_ta_shape_a + ) + out = write_func(input_ta, inputs[1], v) + return out + + return _impl + + +def _tensorlist_get_item(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr["element_dtype"].name + input_shape = get_tensor_array_shape(inputs[0], dtype_str, prelude) + + if input_shape is None: + read_func = prelude.get_global_var("tensor_array_read", dtype_str) + out = read_func(inputs[0], _op.take(inputs[1], tvm.relay.const(0))) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape) + static_tensor_array_ops.register() + read_func = static_tensor_array_ops.get_global_var("tensor_array_read") + out_tensor = read_func(inputs[0], _op.take(inputs[1], tvm.relay.const(0))) + get_data_func = static_tensor_array_ops.get_global_var("tensor_get_data") + out = get_data_func(out_tensor) + return out + + return _impl + + +def _tensorlist_stack(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr["element_dtype"].name + input_ta_shape = get_tensor_array_shape(inputs[0], dtype_str, prelude) + + if input_ta_shape is None: + stack_func = prelude.get_global_var("tensor_array_stack", dtype_str) + out = stack_func(inputs[0]) + else: + if "num_elements" in attr: + num_elements = attr["num_elements"] + static_tensor_array_ops = StaticTensorArrayOps( + prelude, dtype_str, input_ta_shape, num_elements + ) + static_tensor_array_ops.register() + stack_func = prelude.get_global_var_static( + "tensor_array_stack", dtype_str, input_ta_shape, num_elements + ) + out_tensor = stack_func(inputs[0]) + out_shape = ( + (num_elements,) + input_ta_shape + if num_elements and num_elements == 1 + else (Any(),) + input_ta_shape + ) + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape) + static_tensor_array_ops.register() + get_data_func = prelude.get_global_var_static("tensor_get_data", dtype_str, out_shape) + out = get_data_func(out_tensor) + + return out + + return _impl + + +def _tensorlist_from_tensor(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr["element_dtype"].name + input_ta_shape = _infer_type_with_prelude(inputs[0], prelude).shape + + if input_ta_shape is None: + unstack_func = prelude.get_global_var("tensor_array_unstack", dtype_str) + out = unstack_func(inputs[0]) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape) + static_tensor_array_ops.register() + unstack_func = prelude.get_global_var_static( + "tensor_array_unstack", dtype_str, input_ta_shape + ) + out = unstack_func(inputs[0]) + return out + + return _impl + + +_convert_map = { + "TensorListFromTensor": _tensorlist_from_tensor(), + "TensorListGetItem": _tensorlist_get_item(), + "TensorListReserve": _tensorlist_reserve(), + "TensorListSetItem": _tensorlist_set_item(), + "TensorListStack": _tensorlist_stack(), +} diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 797ff51ace7a..a8213d4b1c49 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -138,6 +138,18 @@ def _get_more_static_shape(shape0, shape1): return shape1 +def _get_more_static_shape_rank(shape0, shape1): + """Compare two shapes with different rank, + and return the one with fewer symbolic dimension. + """ + num_sym_dim0 = sum([not isinstance(dim, (int, tvm.tir.expr.IntImm)) for dim in list(shape0)]) + num_sym_dim1 = sum([not isinstance(dim, (int, tvm.tir.expr.IntImm)) for dim in list(shape1)]) + + if num_sym_dim0 < num_sym_dim1: + return shape0 + return shape1 + + def _rsqrt(): def _impl(inputs, attr, params, mod): inputs.append(tvm.relay.const(-0.5, attr["T"].name)) @@ -1137,6 +1149,8 @@ def _impl(inputs, attr, params, mod): def _batch_matmul(): def _impl(inputs, attr, params, mod): + from .tensorflow import TF_DEFAULT_CONFIGS + input_x = inputs[0] input_y = inputs[1] orig_shape_x = _infer_shape(input_x, mod) @@ -1173,9 +1187,16 @@ def _impl(inputs, attr, params, mod): input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1])) adj_x = attr["adj_x"] adj_y = attr["adj_y"] - input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x - input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y - ret = get_relay_op("batch_matmul")(input_x, input_y) + + if TF_DEFAULT_CONFIGS["use_nt_batch_matmul"]: + # Strictly convert all batch_matmul to NT format + input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x + input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y + ret = get_relay_op("batch_matmul")(input_x, input_y) + else: + ret = get_relay_op("batch_matmul")( + input_x, input_y, transpose_a=adj_x, transpose_b=adj_y + ) # reshape result back to n-dimensional if ndim > 3: @@ -1483,7 +1504,13 @@ def _impl(inputs, attr, params, mod): def _concatV2(): def _impl(inputs, attr, params, mod): pop_node = inputs.pop(len(inputs) - 1) - axis = int(_get_num_param(params, pop_node)) + try: + axis = int(_get_num_param(params, pop_node)) + except (IndexError, KeyError, AttributeError): + try: + axis = int(_infer_value(pop_node, params, mod).numpy()) + except Exception: + axis = int(pop_node) return AttrCvt(op_name="concatenate", ignores=["T", "N", "Tidx"], extras={"axis": axis})( [inputs], attr ) @@ -2244,7 +2271,7 @@ def _transform_mask(stride_dim, ellipsis_mask): if begin[index] < 0 else begin[index] ) - m_end[final_index] = begin[index] + 1 + m_end[final_index] = m_begin[final_index] + 1 m_stride[final_index] = 1 fshape_indices.append(-2) else: diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 42096ad9af2f..4d607e46c97f 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -91,6 +91,7 @@ def __init__(self, model, subgraph, exp_tab): "EQUAL": self.convert_equal, "EXP": self.convert_exp, "EXPAND_DIMS": self.convert_expand_dims, + "FAKE_QUANT": self.convert_fake_quant, "FILL": self.convert_fill, "FLOOR_DIV": self.convert_floor_div, "FLOOR_MOD": self.convert_floor_mod, @@ -255,23 +256,23 @@ def get_op_code_str(self, op): op_c = self.model.OperatorCodes(op_code_list_idx) # In TFlite 2.4.x there was a change where the type of the field that contained # the builtin code changed from int8 to int32 in the flat buffer representation. - # However to retain support for old flat buffers that were created, they retained - # the original 8 bit encoding for the operator but in a new field accessed by the - # DeprecatedBuiltinCode method. - # This means that the API function BuiltinCode() is used on an operator - # which was originally encoded as an 8 bit quantity it would look for the - # code in the new int32 field in the schema and this creates the need - # for the check for the magic number of 127 which is indicated by - # BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES + # However, to retain support for old flat buffers that were created, they retained + # the original 8 bit field, but named it "deprecated_builtin_code" in TFLite 2.4. + # This means that the API function BuiltinCode() which originally returned the value + # of the 8 bit field would now look for the value in the new int32 field in the + # schema and DeprecatedBuiltinCode() will look at the old 8 bit field. + # In TFLite 2.4, if the opcode value is less than 127, it can be in either field + # (however, if it is only in the "builtin_code" field, the model is not backward + # compatible), so similarly to TFLite 2.4 reader, we'll pick the higher value of the + # two fields. # Remember however that this value came into existence only after Tensorflow # lite 2.4.x and hence encase it in a try -except block. # Phew ! try: - if op_c.BuiltinCode() < BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES: - opc = op_c.DeprecatedBuiltinCode() - else: - opc = op_c.BuiltinCode() + opc = max(op_c.DeprecatedBuiltinCode(), op_c.BuiltinCode()) except AttributeError: + # In versions before 2.4 the int8 field that holds the builtin code is accessed + # by BuiltinCode() and DeprecatedBuiltinCode() doesn't exist opc = op_c.BuiltinCode() op_code_id = opc @@ -642,9 +643,12 @@ def _convert_resize(self, method, op): op_options = op.BuiltinOptions() resize_options.Init(op_options.Bytes, op_options.Pos) align_corners = resize_options.AlignCorners() + half_pixel_centers = resize_options.HalfPixelCenters() # Use layout NHWC coord_trans = "align_corners" if align_corners else "asymmetric" + coord_trans = "half_pixel" if half_pixel_centers else coord_trans + if bilinear_method and input_tensor.qnn_params: in_expr = self.dequantize(in_expr, input_tensor) out = _op.image.resize2d( @@ -3333,6 +3337,56 @@ def convert_densify(self, op): self.set_prefetched_node(output_tensor.tensor_idx, dense_weight) + def convert_fake_quant(self, op): + """Convert TFLite FAKE_QUANT""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + from tflite.BuiltinOptions import BuiltinOptions + from tflite.FakeQuantOptions import FakeQuantOptions + + assert op.BuiltinOptionsType() == BuiltinOptions.FakeQuantOptions + + op_options = op.BuiltinOptions() + fake_quant_options = FakeQuantOptions() + fake_quant_options.Init(op_options.Bytes, op_options.Pos) + + opt_min = fake_quant_options.Min() + opt_max = fake_quant_options.Max() + narrow_range = fake_quant_options.NarrowRange() + num_bits = fake_quant_options.NumBits() + + assert 2 <= num_bits <= 16 + + quant_min = 1 if narrow_range else 0 + quant_max = (1 << num_bits) - 1 + scale = (opt_max - opt_min) / (quant_max - quant_min) + + zero_point_from_min = quant_min - opt_min / scale + if zero_point_from_min <= quant_min: + nudged_zero_point = quant_min + elif zero_point_from_min >= quant_max: + nudged_zero_point = quant_max + else: + nudged_zero_point = round(zero_point_from_min) + + nudged_min = (quant_min - nudged_zero_point) * scale + nudged_max = (quant_max - nudged_zero_point) * scale + + nudged_min_expr = _op.const(nudged_min) + clamped = _op.clip(in_expr, nudged_min, nudged_max) + clamped_shifted = _op.subtract(clamped, nudged_min_expr) + + half = _op.const(0.5) + one = _op.const(1.0) + scale_expr = _op.const(scale) + inv_scale = _op.divide(one, scale_expr) + rounded = _op.floor(_op.add(_op.multiply(clamped_shifted, inv_scale), half)) + return _op.add(_op.multiply(rounded, scale_expr), nudged_min_expr) + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) @@ -3471,7 +3525,7 @@ def _get_flattened_index(indices, shape): indices_list = [] # Below function iterates through each applicable indices per dimension - # based on format type specified and finaly produce the dense matrix and the NZ indices. + # based on format type specified and finally produce the dense matrix and the NZ indices. def _def_prepare_dense_matrix_from_sparse(indices, level, prev_idx): if level == len(indices): start_pos = 0 diff --git a/python/tvm/relay/frontend/tflite_flexbuffer.py b/python/tvm/relay/frontend/tflite_flexbuffer.py index 4b5d2b9c605c..4533886d14da 100644 --- a/python/tvm/relay/frontend/tflite_flexbuffer.py +++ b/python/tvm/relay/frontend/tflite_flexbuffer.py @@ -88,7 +88,7 @@ def indirect_jump(self, offset, byte_width): def decode_keys(self, end, size, byte_width): """Decodes the flexbuffer type vector. Map keys are stored in this form""" - # Keys are strings here. The format is all strings seperated by null, followed by back + # Keys are strings here. The format is all strings separated by null, followed by back # offsets for each of the string. For example, (str1)\0(str1)\0(offset1)(offset2) The end # pointer is pointing at the end of all strings keys = list() diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index fa2772c1299f..3793f947c5cc 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -590,11 +590,59 @@ def batch_matmul_grad(orig, grad): GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk """ lhs, rhs = orig.args + if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True): + # ki, jk -> ij + # jk, ij -> ki + # ij, ki -> jk + return [ + collapse_sum_like(_nn.batch_matmul(rhs, grad, transpose_a=True, transpose_b=True), lhs), + collapse_sum_like(_nn.batch_matmul(grad, lhs, transpose_a=True, transpose_b=True), rhs), + ] + if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False): + # ki, kj -> ij + # kj, ij -> ki + # ki, ij -> kj + return [ + collapse_sum_like( + _nn.batch_matmul(rhs, grad, transpose_a=False, transpose_b=True), lhs + ), + collapse_sum_like( + _nn.batch_matmul(lhs, grad, transpose_a=False, transpose_b=False), rhs + ), + ] + if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True): + # ik, jk -> ij + # ij, jk -> ik + # ij, ik -> jk + # Keep using NT format batch_matmul here for not involving extra ops + # TODO(jcf94): Merge all to normal batch_matmul when it is finally ready + return [ + collapse_sum_like( + _nn.batch_matmul( + grad, + transpose(rhs, [0, 2, 1]), + transpose_a=False, + transpose_b=True, + ), + lhs, + ), + collapse_sum_like( + _nn.batch_matmul( + transpose(grad, [0, 2, 1]), + transpose(lhs, [0, 2, 1]), + transpose_a=False, + transpose_b=True, + ), + rhs, + ), + ] + # (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False) + # ik, kj -> ij + # ij, kj -> ik + # ik, ij -> kj return [ - collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs), - collapse_sum_like( - _nn.batch_matmul(transpose(grad, [0, 2, 1]), transpose(lhs, [0, 2, 1])), rhs - ), + collapse_sum_like(_nn.batch_matmul(grad, rhs, transpose_a=False, transpose_b=True), lhs), + collapse_sum_like(_nn.batch_matmul(lhs, grad, transpose_a=True, transpose_b=False), rhs), ] diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index cbe6a22f4a4d..cec7c4d141cb 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -23,6 +23,7 @@ from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem +from tvm.ir import Op from tvm.relay.expr_functor import ExprMutator, ExprVisitor logger = logging.getLogger("TensorRT") @@ -1035,6 +1036,7 @@ def visit_tuple_getitem(self, op): return visit if ( isinstance(visit.tuple_value, Call) + and isinstance(visit.tuple_value.op, Op) and visit.tuple_value.op.name == "nn.dropout" and visit.index == 0 ): diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 753a17605667..a9e485866381 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -198,7 +198,11 @@ def compute_sparse_transpose(attrs, inputs, out_type): @reg.register_compute("nn.sparse_conv2d") def compute_sparse_conv2d(attrs, inputs, out_type): """Compute definition of sparse_conv2d""" - return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"])] + return [ + topi.nn.sparse_conv2d( + inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"], attrs["kernel_size"] + ) + ] reg.register_strategy("nn.sparse_conv2d", strategy.sparse_conv2d_strategy) @@ -964,7 +968,8 @@ def compute_space_to_depth(attrs, inputs, out_dtype): @script -def _conv_shape_func(dshape, kshape, strides, padding, dilation): +def _conv_shape_func_nchw(dshape, kshape, strides, padding, dilation): + """Shape function for conv*d op with nchw & oihw layout.""" out = output_tensor((dshape.shape[0],), "int64") out[0] = dshape[0] out[1] = kshape[0] @@ -975,23 +980,52 @@ def _conv_shape_func(dshape, kshape, strides, padding, dilation): return out +@script +def _conv_shape_func_nhwc_hwio(dshape, kshape, strides, padding, dilation): + """Shape function for conv*d op with nhwc & hwio layout.""" + out = output_tensor((dshape.shape[0],), "int64") + out[0] = dshape[0] + out[dshape.shape[0] - 1] = kshape[kshape.shape[0] - 1] + + for i in const_range(dshape.shape[0] - 2): + dilated_k = (kshape[i] - 1) * dilation[i] + 1 + out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1 + return out + + +@script +def _conv_shape_func_nhwc_hwoi(dshape, kshape, strides, padding, dilation): + """Shape function for conv*d op with nhwc & hwoi layout.""" + out = output_tensor((dshape.shape[0],), "int64") + out[0] = dshape[0] + out[dshape.shape[0] - 1] = kshape[kshape.shape[0] - 2] + + for i in const_range(dshape.shape[0] - 2): + dilated_k = (kshape[i] - 1) * dilation[i] + 1 + out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1 + return out + + def conv_shape_func(attrs, inputs, _): - """ - Shape function for contrib_conv2d_NCHWc op. - """ + """Shape function for conv*d op.""" strides = get_const_tuple(attrs.strides) padding = get_const_tuple(attrs.padding) dilation = get_const_tuple(attrs.dilation) - return [ - _conv_shape_func( - inputs[0], - inputs[1], - convert(strides), - convert(padding), - convert(dilation), + shape_func = None + if attrs["data_layout"] == "NCHW" and attrs["kernel_layout"] == "OIHW": + shape_func = _conv_shape_func_nchw + elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO": + shape_func = _conv_shape_func_nhwc_hwio + elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWOI": + shape_func = _conv_shape_func_nhwc_hwoi + else: + raise ValueError( + "Unsupported data/kernel layout: %s, %s" + % (attrs["data_layout"], attrs["kernel_layout"]) ) - ] + + return [shape_func(inputs[0], inputs[1], convert(strides), convert(padding), convert(dilation))] reg.register_shape_func("nn.conv1d", False, conv_shape_func) @@ -1052,27 +1086,22 @@ def conv2d_NCHWc_shape_func(attrs, inputs, _): @script -def _conv2d_transpose_nchw_shape_func(dshape, kshape, strides, padding, dilation, output_padding): +def _conv_transpose_shape_func(dshape, kshape, strides, padding, dilation, output_padding): out = output_tensor((dshape.shape[0],), "int64") - kheight = kshape[2] - kwidth = kshape[3] - dilated_kh = (kheight - 1) * dilation[0] + 1 - dilated_kw = (kwidth - 1) * dilation[1] + 1 - - out_height = strides[0] * (dshape[2] - 1) + dilated_kh - 2 * padding[0] + output_padding[0] - out_width = strides[1] * (dshape[3] - 1) + dilated_kw - 2 * padding[1] + output_padding[1] - out[0] = dshape[0] out[1] = kshape[1] - out[2] = out_height - out[3] = out_width + + for i in const_range(dshape.shape[0] - 2): + dilated_k = (kshape[i + 2] - 1) * dilation[i] + 1 + out[i + 2] = ( + strides[i] * (dshape[i + 2] - 1) + dilated_k - 2 * padding[i] + output_padding[i] + ) return out -@reg.register_shape_func("nn.conv2d_transpose", False) -def conv2d_transpose_nchw_shape_func(attrs, inputs, _): +def conv_transpose_shape_func(attrs, inputs, _): """ - Shape function for conv2d_transpose op. + Shape function for contrib_conv2d_NCHWc op. """ strides = get_const_tuple(attrs.strides) padding = get_const_tuple(attrs.padding) @@ -1080,7 +1109,7 @@ def conv2d_transpose_nchw_shape_func(attrs, inputs, _): output_padding = get_const_tuple(attrs.output_padding) return [ - _conv2d_transpose_nchw_shape_func( + _conv_transpose_shape_func( inputs[0], inputs[1], convert(strides), @@ -1091,6 +1120,10 @@ def conv2d_transpose_nchw_shape_func(attrs, inputs, _): ] +reg.register_shape_func("nn.conv1d_transpose", False, conv_transpose_shape_func) +reg.register_shape_func("nn.conv2d_transpose", False, conv_transpose_shape_func) + + @script def _pool2d_shape_func(data_shape, pool_size, strides, padding, height_axis, width_axis): out = output_tensor((data_shape.shape[0],), "int64") @@ -1230,9 +1263,9 @@ def dense_shape_func(attrs, inputs, _): @script def _dense_pack_shape_func(data_shape, weight_shape): out = output_tensor((data_shape.shape[0],), "int64") - for i in const_range(out.shape[0] - 1): - out[i] = data_shape[i] - out[out.shape[0] - 1] = weight_shape[0] * weight_shape[2] + assert data_shape.shape[0] == 2, "Input data must be 2D" + out[0] = data_shape[0] + out[1] = weight_shape[0] * weight_shape[2] return out @@ -1247,14 +1280,11 @@ def dense_pack_shape_func(attrs, inputs, _): @script -def _batch_matmul_shape_func(data_shape, weight_shape): - out = output_tensor((data_shape.shape[0],), "int64") - for i in const_range(out.shape[0] - 1): - if i == 0: - out[i] = max(data_shape[i], weight_shape[i]) - else: - out[i] = data_shape[i] - out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2] +def _batch_matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b): + out = output_tensor((tensor_a_shape.shape[0],), "int64") + out[0] = max(tensor_a_shape[0], tensor_b_shape[0]) + out[1] = tensor_a_shape[2] if transpose_a else tensor_a_shape[1] + out[2] = tensor_b_shape[1] if transpose_b else tensor_b_shape[2] return out @@ -1262,9 +1292,16 @@ def _batch_matmul_shape_func(data_shape, weight_shape): @reg.register_shape_func("nn.batch_matmul", False) def batch_matmul_shape_func(attrs, inputs, _): """ - Shape function for dense op. + Shape function for batch matmul op. """ - ret = [_batch_matmul_shape_func(inputs[0], inputs[1])] + ret = [ + _batch_matmul_shape_func( + inputs[0], + inputs[1], + expr.IntImm("bool", attrs.transpose_a), + expr.IntImm("bool", attrs.transpose_b), + ) + ] return ret @@ -1307,4 +1344,7 @@ def dilate_shape_func(attrs, inputs, _): reg.register_shape_func("nn.bias_add", False, elemwise_shape_func) reg.register_shape_func("nn.softmax", False, elemwise_shape_func) +reg.register_shape_func("nn.fast_softmax", False, elemwise_shape_func) reg.register_shape_func("nn.relu", False, elemwise_shape_func) +reg.register_shape_func("nn.leaky_relu", False, elemwise_shape_func) +reg.register_shape_func("nn.prelu", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 4c94102275bb..e882bcf7e271 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1548,9 +1548,9 @@ def dense(data, weight, units=None, out_dtype=""): return _make.dense(data, weight, units, out_dtype) -def contrib_dense_pack(data, weight, units=None, out_dtype=""): +def contrib_dense_pack(data, weight, weight_layout="NK", units=None, out_dtype=""): """Dense operator. - Applies a linear transformation + Applies a linear transformation with packed weight .. math:: @@ -1560,25 +1560,27 @@ def contrib_dense_pack(data, weight, units=None, out_dtype=""): ---------- data : tvm.relay.Expr The input data to the operator, - of shape `(d_1, d_2, ..., d_n, units_in)`. + of shape `(batch, units_in)`. weight : tvm.relay.Expr The transformed weight expressions, 3-D matrix, of shape `(units // pack_weight_tile, units_in, pack_weight_tile)`. + weight_layout: str + The layout of weight, such as "NK" or "NK8n". + units : int, optional Number of hidden units of the dense transformation. out_dtype : str, optional - Specifies the output data type for mixed precision dense, - of shape `(d_1, d_2, ..., d_n, units)`. + Specifies the output data type for mixed precision dense. Returns ------- result : tvm.relay.Expr The computed result. """ - return _make.contrib_dense_pack(data, weight, units, out_dtype) + return _make.contrib_dense_pack(data, weight, weight_layout, units, out_dtype) def fifo_buffer(data, buffer, axis): @@ -2137,32 +2139,40 @@ def group_norm(data, gamma, beta, num_groups, axis=1, epsilon=1e-5, center=True, return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale) -def batch_matmul(x, y, out_dtype=""): +def batch_matmul(tensor_a, tensor_b, out_dtype="", transpose_a=False, transpose_b=True): r""" - Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data - in batch. + Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + + Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. .. math:: - \mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T) + \mbox{batch_matmul}(A, B)[i, :, :] = \mbox{matmul}(A[i, :, :], B[i, :, :]) Parameters ---------- - x : tvm.relay.Expr + tensor_a : tvm.relay.Expr The first input. - y : tvm.relay.Expr + tensor_b : tvm.relay.Expr The second input. - out_dtype : str, optional - Specifies the output data type for mixed precision batch matmul + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. Returns ------- result: tvm.relay.Expr The computed result. """ - return _make.batch_matmul(x, y, out_dtype) + return _make.batch_matmul(tensor_a, tensor_b, out_dtype, transpose_a, transpose_b) # pylint: disable=no-else-return,inconsistent-return-statements diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 0d90a5cdeafa..bbaffc469321 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -341,6 +341,23 @@ def register_convert_op_layout(op_name, convert_layout=None, level=10): return tvm.ir.register_op_attr(op_name, "FTVMConvertOpLayout", convert_layout, level) +def register_infer_correct_layout(op_name, infer_layout=None, level=10): + """Register infer op layout function for an op + + Parameters + ---------- + op_name : str + The name of the operator + + infer_layout: function (attrs: Attrs, inputs: List[Layout]) -> InferCorrectLayoutOutput + The function to infer correct layout + + level : int + The priority level + """ + return tvm.ir.register_op_attr(op_name, "FInferCorrectLayout", infer_layout, level) + + def register_legalize(op_name, legal_op=None, level=10): """Register legal transformation function for an op diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 2d185bcee798..507dd9371a97 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -74,6 +74,11 @@ class DenseAttrs(Attrs): """Attributes for nn.dense""" +@tvm._ffi.register_object("relay.attrs.BatchMatmulAttrs") +class BatchMatmulAttrs(Attrs): + """Attributes for nn.batch_matmul""" + + @tvm._ffi.register_object("relay.attrs.SoftmaxAttrs") class SoftmaxAttrs(Attrs): """Attributes for nn.softmax""" diff --git a/python/tvm/relay/op/strategy/bifrost.py b/python/tvm/relay/op/strategy/bifrost.py index 8008391fe86c..ec3edab2c8b1 100644 --- a/python/tvm/relay/op/strategy/bifrost.py +++ b/python/tvm/relay/op/strategy/bifrost.py @@ -83,6 +83,14 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target): wrap_topi_schedule(topi.bifrost.schedule_depthwise_conv2d_nchw), name="depthwise_conv2d_nchw.bifrost", ) + elif layout == "NHWC": + assert kernel_layout == "HWOI" + # For now just reuse general Mali strategy. + strategy.add_implementation( + wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc), + wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nhwc), + name="depthwise_conv2d_nchw.bifrost", + ) else: raise RuntimeError( "Unsupported depthwise_conv2d layout {} for Mali(Bifrost)".format(layout) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 1f999a810164..ba47ae7bc4f1 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -819,7 +819,13 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): """batch_matmul cuda strategy""" strategy = _op.OpStrategy() x, y = inputs - if x.dtype == "int8" and y.dtype == "int8" and out_type.dtype == "int32": + if ( + x.dtype == "int8" + and y.dtype == "int8" + and out_type.dtype == "int32" + and not attrs["transpose_a"] + and attrs["transpose_b"] + ): strategy.add_implementation( wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True), wrap_topi_schedule(topi.cuda.schedule_batch_matmul_int8), @@ -840,7 +846,12 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): name="batch_matmul_cublas.cuda", plevel=15, ) - if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target): + if ( + target.kind.name == "cuda" + and nvcc.have_tensorcore(target=target) + and not attrs["transpose_a"] + and attrs["transpose_b"] + ): x, y = inputs _, M, K = get_const_tuple(x.shape) _, N, K = get_const_tuple(y.shape) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 3348d8033904..9c756f201721 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -799,10 +799,11 @@ def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, ne def _compute_batch_matmul(attrs, inputs, out_type): args = [inputs[0], inputs[1], out_type.shape] + args.append(out_type.dtype if need_out_dtype else None) + args.append(attrs.transpose_a) + args.append(attrs.transpose_b) if need_auto_scheduler_layout: args.append(get_auto_scheduler_rewritten_layout(attrs)) - if need_out_dtype: - args.append(out_type.dtype) return [topi_compute(*args)] return _compute_batch_matmul diff --git a/python/tvm/relay/op/strategy/hls.py b/python/tvm/relay/op/strategy/hls.py index b147af06cfc3..1eebbd36b847 100644 --- a/python/tvm/relay/op/strategy/hls.py +++ b/python/tvm/relay/op/strategy/hls.py @@ -80,7 +80,7 @@ def log_softmax_strategy_hls(attrs, inputs, out_type, target): return strategy -@override_native_generic_func("conv2d_strategy") +@conv2d_strategy.register("hls") def conv2d_strategy_hls(attrs, inputs, out_type, target): """conv2d hls strategy""" strategy = _op.OpStrategy() @@ -132,7 +132,7 @@ def conv2d_strategy_hls(attrs, inputs, out_type, target): return strategy -@override_native_generic_func("conv2d_NCHWc_strategy") +@conv2d_NCHWc_strategy.register("hls") def conv2d_NCHWc_strategy_hls(attrs, inputs, out_type, target): """conv2d_NCHWc hls strategy""" strategy = _op.OpStrategy() diff --git a/python/tvm/relay/op/strategy/mali.py b/python/tvm/relay/op/strategy/mali.py index d38fe0d82758..e5f4b4e58562 100644 --- a/python/tvm/relay/op/strategy/mali.py +++ b/python/tvm/relay/op/strategy/mali.py @@ -120,14 +120,17 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target): elif layout == "NHWC": assert kernel_layout == "HWOI" if not is_auto_scheduler_enabled(): - raise RuntimeError( - "depthwise_conv2d NHWC layout is not enabled for mali without auto_scheduler." + strategy.add_implementation( + wrap_compute_conv2d(topi.mali.depthwise_conv2d_nhwc), + wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nhwc), + name="depthwise_conv2d_nhwc.mali", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), + naive_schedule, + name="depthwise_conv2d_nhwc.mali", ) - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), - naive_schedule, - name="depthwise_conv2d_nhwc.mali", - ) else: raise RuntimeError("Unsupported depthwise_conv2d layout {} for mali".format(layout)) else: # group_conv2d diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index f4538071e11e..64373dcdd7bf 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -20,6 +20,7 @@ from tvm.auto_scheduler import is_auto_scheduler_enabled from tvm.te import SpecializedCondition from tvm.contrib.thrust import can_use_rocthrust +from tvm.contrib import miopen from .generic import * from .. import op as _op @@ -304,3 +305,41 @@ def topk_strategy_cuda(attrs, inputs, out_type, target): plevel=15, ) return strategy + + +@softmax_strategy.register(["rocm"]) +def softmax_strategy_rocm(attrs, inputs, out_type, target): + """rocm strategy for softmax""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_softmax(topi.nn.softmax), + wrap_topi_schedule(topi.cuda.schedule_softmax), + name="softmax.rocm", + ) + if "miopen" in target.libs: + strategy.add_implementation( + wrap_compute_softmax(miopen.softmax), + wrap_topi_schedule(topi.generic.schedule_extern), + name="softmax.miopen", + plevel=15, + ) + return strategy + + +@log_softmax_strategy.register(["rocm"]) +def log_softmax_strategy_rocm(attrs, inputs, out_type, target): + """rocm strategy for log softmax""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_softmax(topi.nn.log_softmax), + wrap_topi_schedule(topi.cuda.schedule_softmax), + name="log_softmax.rocm", + ) + if "miopen" in target.libs: + strategy.add_implementation( + wrap_compute_softmax(miopen.log_softmax), + wrap_topi_schedule(topi.generic.schedule_extern), + name="log_softmax.miopen", + plevel=15, + ) + return strategy diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 6a4030514580..1c8d1b478cb1 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -521,14 +521,16 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): strategy = _op.OpStrategy() if is_dynamic(out_type) or is_auto_scheduler_enabled(): strategy.add_implementation( - wrap_compute_batch_matmul(topi.nn.batch_matmul, need_auto_scheduler_layout=True), + wrap_compute_batch_matmul( + topi.nn.batch_matmul, need_auto_scheduler_layout=True, need_out_dtype=True + ), wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul), name="batch_matmul.generic", plevel=10, ) else: strategy.add_implementation( - wrap_compute_batch_matmul(topi.x86.batch_matmul), + wrap_compute_batch_matmul(topi.x86.batch_matmul, need_out_dtype=True), wrap_topi_schedule(topi.x86.schedule_batch_matmul), name="batch_matmul.x86", plevel=10, @@ -563,6 +565,31 @@ def sparse_dense_strategy_cpu(attrs, inputs, out_type, target): return strategy +@sparse_conv2d_strategy.register("cpu") +def sparse_conv2d_strategy_cpu(attrs, inputs, out_type, target): + """sparse conv2d x86 strategy""" + strategy = _op.OpStrategy() + if attrs["kernel_size"][0] == 1: + strategy.add_implementation( + wrap_compute_sparse_conv2d(topi.nn.sparse_conv2d), + wrap_topi_schedule(topi.generic.schedule_sparse_conv2d), + name="sparse_conv2d.generic", + ) + elif attrs["kernel_size"][0] == 3: + if attrs["layout"] == "NHWC": + strategy.add_implementation( + wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nhwc), + wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nhwc), + name="conv3x3_spNHWC.x86", + ) + elif attrs["layout"] == "NCHW": + strategy.add_implementation( + wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nchw), + wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nchw), + ) + return strategy + + @roi_align_strategy.register("cpu") def roi_align_strategy_cpu(attrs, inputs, out_type, target): """roi_align x86 strategy""" diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 376bf4a4804d..542980561e78 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -73,9 +73,33 @@ def get_tensor_array_shape(expr, dtype, prelude): return None -def _get_name_static(canonical, dtype, shape): - """Get name for static shape tensor array op corresponding - to the canonical name""" +def _get_name_static(canonical, dtype, shape, batch_dim=None): + """Get name for static shape tensor array op + + By design, static ADT tensor in TVM has type name in the format + of static_tensor_dim0_dim1_..._dimN_t + or static_tensor_batch1_dim0_dim1_..._dimN_t if tensorlist stack only have one item. + + Parameters + ---------- + canonical : String + Tensor array op name + + dtype : str + Data type. + + shape : tuple of (int, Any) or None + Tensor array shape + + batch_dim: None or int + 1 if tensorlist stack only have one item. + None by default + + Returns + ------- + name : String + The tensor array op name + """ dim_names = [] for dim in shape: if isinstance(dim, Any): @@ -89,26 +113,31 @@ def _get_name_static(canonical, dtype, shape): shape_str = "scalar" if canonical == "tensor_t": return "static_tensor_{}_{}_t".format(dtype, shape_str) - return "{}_{}_{}".format(canonical, dtype, shape_str) + if batch_dim is None or canonical in ["tensor_constructor", "tensor_nil"]: + return "{}_{}_{}".format(canonical, dtype, shape_str) + if batch_dim != 1: + return "{}_{}_{}".format(canonical, dtype, shape_str) + return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim), shape_str) class StaticTensorArrayOps(object): """Contains tensor array related ops for fixed rank tensor array""" - def __init__(self, prelude, dtype, shape): + def __init__(self, prelude, dtype, shape, batch_dim=None): """Create tensor array ops registry""" self.prelude = prelude self.dtype = dtype self.shape = shape + self.batch_dim = batch_dim self.list, self.cons, self.nil = self.prelude.mod.get_type("List") def get_name(self, canonical): """Get name corresponding to the canonical name""" - return _get_name_static(canonical, self.dtype, self.shape) + return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim) def get_global_var(self, canonical): """Get global corresponding to the canonical name""" - return self.prelude.get_global_var_static(canonical, self.dtype, self.shape) + return self.prelude.get_global_var_static(canonical, self.dtype, self.shape, self.batch_dim) def get_type(self, canonical): """Get type corresponding to the canonical name""" @@ -262,9 +291,10 @@ def define_tensor_expand_dims(self): # Note: we set the added axis to be Any() instead of 1 due to # in stack op, we need to recursively concatenate. + new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim tensor_type_var, tensor_constructor, _ = self._get_adt_by_shape( [ - Any(), + new_axis, ] + list(self.shape) ) @@ -573,20 +603,27 @@ def define_tensor_array_stack(self): expand_dims_var = self.get_global_var("tensor_expand_dims") # Register tensor_concatenate for output_shape + new_axis = Any() if not self.batch_dim or self.batch_dim != 1 else self.batch_dim output_shape = [ - Any(), + new_axis, ] + list(self.shape) - _, _, output_ops = self._get_adt_by_shape(output_shape) output_ops.define_tensor_concatenate() concat_var = output_ops.get_global_var("tensor_concatenate") tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array) - tensors = self.prelude.foldl( - concat_var, - self.prelude.hd(tensor_array_expand_dims), - self.prelude.tl(tensor_array_expand_dims), - ) + if self.batch_dim is not None and self.batch_dim == 1: + # only one element + tensors = self.prelude.id( + self.prelude.hd(tensor_array_expand_dims), + ) + else: + tensors = self.prelude.foldl( + concat_var, + self.prelude.hd(tensor_array_expand_dims), + self.prelude.tl(tensor_array_expand_dims), + ) + output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape) self.prelude.mod[stack_var] = Function( [tensor_array], tensors, output_tensor_type_var(), [] @@ -599,8 +636,9 @@ def define_tensor_array_gather(self): helper_name = self.get_name("tensor_array_gather_helper") helper_var = self._create_global_var(helper_name) + new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim output_shape = [ - Any(), + new_axis, ] + list(self.shape) output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape) stack_var = self.get_global_var("tensor_array_stack") @@ -668,7 +706,7 @@ def register(self): def _get_adt_by_shape(self, shape): """Get ADT type and constructor with given shape.""" - adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape) + adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape, self.batch_dim) adt_ops.define_tensor_adt() tensor_type_var = adt_ops.get_type("tensor_t") tensor_constructor = adt_ops.get_ctor("tensor_constructor") @@ -1482,13 +1520,13 @@ def get_tensor_ctor(self, canonical, dtype): ty = self.get_type("tensor_t", dtype) return self.get_ctor(ty.name_hint, canonical, dtype) - def get_name_static(self, canonical, dtype, shape): + def get_name_static(self, canonical, dtype, shape, batch_dim=None): """Get name corresponding to the canonical name""" - return _get_name_static(canonical, dtype, shape) + return _get_name_static(canonical, dtype, shape, batch_dim) - def get_global_var_static(self, canonical, dtype, shape): + def get_global_var_static(self, canonical, dtype, shape, batch_dim=None): """Get var corresponding to the canonical name""" - name = self.get_name_static(canonical, dtype, shape) + name = self.get_name_static(canonical, dtype, shape, batch_dim) return self.mod.get_global_var(name) def get_type_static(self, canonical, dtype, shape): diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 961517f863fb..3226240fbe39 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -94,6 +94,25 @@ def get_scalar_from_constant(expr): return value.item(0) +def _shift(data, zero_point, out_dtype): + """Shifts (add/subtracts) the qnn tensor with +/-128)""" + if out_dtype == "uint8": + shift = 128 + elif out_dtype == "int8": + shift = -128 + else: + raise ValueError("Unsupported out dtype.") + data_modified = relay.cast(data, "int32") + data_modified = relay.add(data_modified, relay.const(shift, "int32")) + data_modified = relay.cast(data_modified, out_dtype) + if isinstance(zero_point, relay.Constant): + zero_point_val = get_scalar_from_constant(zero_point) + zero_point_modified = relay.const(zero_point_val + shift, "int32") + else: + zero_point_modified = zero_point + relay.const(shift, "int32") + return (data_modified, zero_point_modified) + + # Helper function for lowering in the abscence of fast Int8 arithmetic units. def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): """Converts QNN operators into a sequence of Relay operators that are friendly to HW that do @@ -161,22 +180,6 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op): result : tvm.relay.Expr The legalized expr """ - - def _shift(data, zero_point, out_dtype): - """Shifts (add/subtracts) the qnn tensor with +/-128)""" - if out_dtype == "uint8": - shift = 128 - elif out_dtype == "int8": - shift = -128 - else: - raise ValueError("Unsupported out dtype.") - data_modified = relay.cast(data, "int32") - data_modified = relay.add(data_modified, relay.const(shift, "int32")) - data_modified = relay.cast(data_modified, out_dtype) - zero_point_val = get_scalar_from_constant(zero_point) - zero_point_modified = relay.const(zero_point_val + shift, "int32") - return (data_modified, zero_point_modified) - # Collect the dtypes. data_dtype = types[0].dtype kernel_dtype = types[1].dtype @@ -205,6 +208,54 @@ def _shift(data, zero_point, out_dtype): ) +# Helper function to change dtypes to int8 x int8. Cuda dp4a instructions prefer this setting. +def helper_change_dtypes_to_int8(attrs, inputs, types, relay_op): + """Legalizes QNN conv2d/dense op for Nvidia HW. dp4a supports i8 x i8 fast conv/MM. If the + dtypes are already good, we dont transform. Else, we shift the tensor values and zero points + to change the dtype. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + # Collect the dtypes. + data_dtype = types[0].dtype + kernel_dtype = types[1].dtype + + # Collect the input exprs. + data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs + + # dp4a supports i8 x i8 fast conv/MM. Don't do anything if it is already satisfied. + if data_dtype == "int8" and kernel_dtype == "int8": + return None + + # Shift input if necessary. + if data_dtype == "uint8": + # Compute (QA + 128) and (zp_a + 128) + data, input_zero_point = _shift(data, input_zero_point, "int8") + + # Shift kernel if necessary. + if kernel_dtype == "uint8": + # Compute (QA - 128) and (zp_a - 128) + kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, "int8") + + # Call qnn.conv2d with modified inputs and zero points. + new_attrs = {k: attrs[k] for k in attrs.keys()} + return relay_op( + data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale, **new_attrs + ) + + # Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting. def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op): """Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However, @@ -339,11 +390,11 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types): @qnn_conv2d_legalize.register("cuda") def _qnn_conv2d_legalize_cuda(attrs, inputs, types): - # CUDA prefers the dtypes to be same. - return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d) + # CUDA prefers both datatypes to be int8. + return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.conv2d) @qnn_dense_legalize.register("cuda") def _qnn_dense_legalize_cuda(attrs, inputs, types): - # CUDA prefers the dtypes to be same. - return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense) + # CUDA prefers both datatypes to be the int8. + return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.dense) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index f02f8227e14a..e74256ec74c3 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -682,6 +682,44 @@ def subtract( ) +def batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype="int32"): + r""" + Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data + in batch. + + .. math:: + + \mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T) + + Parameters + ---------- + x : tvm.relay.Expr + The first quantized input. + A quantized tensor is represented in following manner + `A = scale_a x (QA - zp_A)` + where QA is quantized tensor, scale_a and zp_A are quantization + params. + y : tvm.relay.Expr + The second quantized input. + x_zero_point: tvm.relay.Expr + The first input zero point. + y_zero_point: tvm.relay.Expr + The second input zero point. + x_scale: tvm.relay.Expr + The scale for the first input tensor. + y_scale: tvm.relay.Expr + The scale for the second input tensor. + out_dtype : str, optional + Specifies the output data type for mixed precision dense can be int32 or int16. + + Returns + ------- + result: tvm.relay.Expr + The computed result. + """ + return _make.batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype) + + # register fuse pattern for qnn ops reg.register_pattern("qnn.quantize", OpPattern.OPAQUE) reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE) diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index bfe797d844a8..8eb07d7b583b 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -134,10 +134,13 @@ def check_grad( test_inputs = inputs for target, dev in enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) + # Eval the backward and forward functions + # TODO(mbs): Evaluate a pair of functions so can share preparation between them. + bwd_func_compiled = relay.create_executor(device=dev, target=target).evaluate(bwd_func) + fwd_func_compiled = relay.create_executor(device=dev, target=target).evaluate(fwd_func) # Get analytic gradients. - _, grads = intrp.evaluate(bwd_func)(*inputs) + _, grads = bwd_func_compiled(*inputs) grads = [grad.numpy().astype("float64") for grad in grads] # Throw out gradients we aren't testing @@ -160,9 +163,9 @@ def check_grad( for i in np.ndindex(*x.shape): x_i = x[i] x[i] = x_i + eps - fwd_plus = intrp.evaluate(fwd_func)(*inputs).numpy().astype("float64") + fwd_plus = fwd_func_compiled(*inputs).numpy().astype("float64") x[i] = x_i - eps - fwd_minus = intrp.evaluate(fwd_func)(*inputs).numpy().astype("float64") + fwd_minus = fwd_func_compiled(*inputs).numpy().astype("float64") x[i] = x_i approx_grad[i] = np.sum((fwd_plus - fwd_minus) / (2 * eps)) approx_grads.append(approx_grad) diff --git a/python/tvm/relay/testing/byoc.py b/python/tvm/relay/testing/byoc.py new file mode 100644 index 000000000000..619c9b99ca1d --- /dev/null +++ b/python/tvm/relay/testing/byoc.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines test utilties useful for testing BYOC flows.""" + +from tvm import relay +from tvm.relay.expr_functor import ExprMutator +from tvm.relay.op.annotation import compiler_begin, compiler_end + + +class CcompilerAnnotator(ExprMutator): + """ + This is used to create external functions for ccompiler. + A simple annotator that creates the following program: + | + -- begin -- + | + add + | + subtract + | + multiply + | + -- end -- + | + """ + + def __init__(self): + super(CcompilerAnnotator, self).__init__() + self.in_compiler = 0 + + def visit_call(self, call): + if call.op.name == "add": # Annotate begin at args + if self.in_compiler == 1: + lhs = compiler_begin(super().visit(call.args[0]), "ccompiler") + rhs = compiler_begin(super().visit(call.args[1]), "ccompiler") + op = relay.add(lhs, rhs) + self.in_compiler = 2 + return op + elif call.op.name == "subtract": + if self.in_compiler == 1: + lhs = super().visit(call.args[0]) + rhs = super().visit(call.args[1]) + if isinstance(lhs, relay.expr.Var): + lhs = compiler_begin(lhs, "ccompiler") + if isinstance(rhs, relay.expr.Var): + rhs = compiler_begin(rhs, "ccompiler") + return relay.subtract(lhs, rhs) + elif call.op.name == "multiply": # Annotate end at output + self.in_compiler = 1 + lhs = super().visit(call.args[0]) + rhs = super().visit(call.args[1]) + if isinstance(lhs, relay.expr.Var): + lhs = compiler_begin(lhs, "ccompiler") + if isinstance(rhs, relay.expr.Var): + rhs = compiler_begin(rhs, "ccompiler") + op = relay.multiply(lhs, rhs) + if self.in_compiler == 2: + op = compiler_end(op, "ccompiler") + self.in_compiler = 0 + return op + return super().visit_call(call) diff --git a/python/tvm/relay/testing/densenet.py b/python/tvm/relay/testing/densenet.py index 1ceb6267d355..6b8d0098a5c6 100644 --- a/python/tvm/relay/testing/densenet.py +++ b/python/tvm/relay/testing/densenet.py @@ -44,9 +44,12 @@ def _make_dense_layer(data, growth_rate, bn_size, index): def _make_dense_block(data, num_layers, bn_size, growth_rate, index): """Makes a block of dense layers of the specified size.""" layer_out = data + blocks = [] for i in range(num_layers): layer_out = _make_dense_layer(layer_out, growth_rate, bn_size, "%s_%s" % (index, i)) - return layer_out + blocks.append(layer_out) + block_out = relay.concatenate(blocks, 1) + return block_out def _make_transition(data, num_output_features, index): @@ -63,7 +66,9 @@ def _make_dense_net( num_init_features, growth_rate, block_config, data_shape, data_dtype, bn_size=4, classes=1000 ): """Builds up a densenet.""" - data = relay.Var("data", relay.TensorType(data_shape, data_dtype)) # (bn_size, 3, 224, 224))) + data = relay.Var( + "data", relay.TensorType(data_shape, data_dtype) + ) # (batch_size, 3, 224, 224))) conv1 = layers.conv2d( data, channels=num_init_features, @@ -79,7 +84,7 @@ def _make_dense_net( num_features = num_init_features layer_out = mp for i, num_layers in enumerate(block_config): - layer_out = _make_dense_block(layer_out, num_layers, growth_rate, bn_size, i) + layer_out = _make_dense_block(layer_out, num_layers, bn_size, growth_rate, i) num_features = num_features + num_layers * growth_rate if i != len(block_config) - 1: layer_out = _make_transition(layer_out, num_features // 2, i) @@ -131,10 +136,10 @@ def get_workload( 169: (69, 32, [6, 12, 32, 32]), 201: (64, 32, [6, 12, 48, 32]), } - + bn_size = 4 num_init_features, growth_rate, block_config = specs[densenet_size] data_shape = tuple([batch_size] + list(image_shape)) net = _make_dense_net( - num_init_features, growth_rate, block_config, data_shape, dtype, batch_size, classes + num_init_features, growth_rate, block_config, data_shape, dtype, bn_size, classes ) return create_workload(net) diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index 9fb3f1102137..3680c4b9805e 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -91,7 +91,7 @@ def vmobj_to_list(o): """ if isinstance(o, tvm.nd.NDArray): - result = [o.asnumpy()] + result = [o.numpy()] elif isinstance(o, tvm.runtime.container.ADT): result = [] for f in o: @@ -107,7 +107,7 @@ def vmobj_to_list(o): elif "tensor_nil" in o.constructor.name_hint: result = [0] elif "tensor" in o.constructor.name_hint: - result = [o.fields[0].asnumpy()] + result = [o.fields[0].numpy()] else: raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint) else: diff --git a/python/tvm/relay/testing/yolo_detection.py b/python/tvm/relay/testing/yolo_detection.py index a387f3076bf5..949d024bd86f 100644 --- a/python/tvm/relay/testing/yolo_detection.py +++ b/python/tvm/relay/testing/yolo_detection.py @@ -103,8 +103,8 @@ def _get_yolo_detections(l, im_shape, net_shape, thresh, relative, dets): l["biases"], np.asarray(l["mask"])[location[0]], location, - data.shape[2], data.shape[3], + data.shape[2], net_shape[0], net_shape[1], ) @@ -139,10 +139,10 @@ def _get_region_detections(l, im_shape, net_shape, thresh, relative, dets): l["biases"], n, location, - data.shape[2], data.shape[3], data.shape[2], data.shape[3], + data.shape[2], ) objectness = scale if scale > thresh else 0 if objectness: diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index 9ed40f85c3bc..378b0c38ff64 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -19,4 +19,4 @@ # transformation passes from .transform import * from .recast import recast -from . import fake_quantization_to_integer +from . import fake_quantization_to_integer, mixed_precision diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 5f4c53772eec..cf55c67c8083 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -17,13 +17,12 @@ """Relay functions for rewriting fake quantized ops.""" import tvm from tvm import relay +from tvm.ir import TensorAffineType, TupleAffineType from ..op import register_fake_quantization_to_integer def fold_constant(expr): - mod = tvm.IRModule.from_expr(expr) - mod = relay.transform.FoldConstant()(mod) - return mod["main"].body + return relay.transform.FoldConstantExpr(expr, tvm.IRModule()) @register_fake_quantization_to_integer("qnn.dequantize") @@ -31,7 +30,7 @@ def dequantize(expr, type_map): """Remove dequantize op""" out = expr.args[0] t = type_map[expr] - return [out, t.scale, t.zero_point, t.dtype] + return [out, t] @register_fake_quantization_to_integer("qnn.quantize") @@ -54,23 +53,26 @@ def quantize(expr, type_map): expr.args[2], out_dtype=expr.attrs.out_dtype, ) - return [out, expr.args[1], expr.args[2], expr.attrs.out_dtype] + return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)] -def register_unary_identity(op_name, op): +def register_unary_identity(op_name): def identity(expr, type_map): assert len(expr.args) == 1 arg = expr.args[0] t = type_map[arg] - out = op(arg, **expr.attrs) - return [out, t.scale, t.zero_point, t.dtype] + return [expr, t] return register_fake_quantization_to_integer(op_name, identity) -register_unary_identity("reshape", relay.op.reshape) -register_unary_identity("transpose", relay.op.transpose) -register_unary_identity("nn.max_pool2d", relay.op.nn.max_pool2d) +register_unary_identity("reshape") +register_unary_identity("squeeze") +register_unary_identity("strided_slice") +register_unary_identity("transpose") +register_unary_identity("expand_dims") +register_unary_identity("nn.max_pool2d") +register_unary_identity("nn.batch_flatten") @register_fake_quantization_to_integer("nn.avg_pool2d") @@ -81,7 +83,7 @@ def avgpool2d(expr, type_map): arg = relay.op.cast(arg, "int32") out = relay.op.nn.avg_pool2d(arg, **expr.attrs) out = relay.op.cast(out, t.dtype) - return [out, t.scale, t.zero_point, t.dtype] + return [out, t] @register_fake_quantization_to_integer("nn.bias_add") @@ -99,10 +101,10 @@ def bias_add(expr, type_map): b_t.zero_point, in_scale, in_zero_point, - out_dtype=xt.dtype, + out_dtype=x_t.dtype, ) out = relay.op.nn.bias_add(x, b, **expr.attrs) - return [out, x_t.scale, x_t.zero_point, x_t.dtype] + return [out, x_t] @register_fake_quantization_to_integer("nn.conv2d") @@ -118,7 +120,35 @@ def conv2d(expr, type_map): out = relay.qnn.op.conv2d( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) - return [out, conv_scale, conv_zp, out.attrs.out_dtype] + return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype)] + + +@register_fake_quantization_to_integer("nn.dense") +def dense(expr, type_map): + """Rewrite a dense op""" + attrs = {**expr.attrs} + attrs.pop("out_dtype") + x, weight = expr.args + x_t = type_map[x] + w_t = type_map[weight] + dense_scale = fold_constant(x_t.scale * w_t.scale) + dense_zp = relay.const(0) + out = relay.qnn.op.dense( + x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs + ) + return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype)] + + +@register_fake_quantization_to_integer("nn.batch_matmul") +def batch_matmul(expr, type_map): + """Rewrite a batch_matmul op""" + x, y = expr.args + x_t = type_map[x] + y_t = type_map[y] + matmul_scale = fold_constant(x_t.scale * y_t.scale) + matmul_zp = relay.const(0) + out = relay.qnn.op.batch_matmul(x, y, x_t.zero_point, y_t.zero_point, x_t.scale, y_t.scale) + return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype)] @register_fake_quantization_to_integer("concatenate") @@ -126,8 +156,9 @@ def concat(expr, type_map): """Rewrite a concat op""" scales = [] zps = [] - for arg in expr.args[0].fields: - t = type_map[arg] + + tuple_type = type_map[expr.args[0]] + for t in tuple_type.types: scales.append(t.scale) zps.append(t.zero_point) @@ -141,7 +172,21 @@ def concat(expr, type_map): out_type.zero_point, **expr.attrs, ) - return [out, out_type.scale, out_type.zero_point, out_type.dtype] + return [out, out_type] + + +@register_fake_quantization_to_integer("split") +def split(expr, type_map): + """Rewrite a split op""" + arg = expr.args[0] + t = type_map[arg] + attrs = {**expr.attrs} + if isinstance(attrs["indices_or_sections"], tvm.tir.IntImm): + num_split = attrs["indices_or_sections"].value + attrs["indices_or_sections"] = num_split + else: + num_split = len(attrs["indices_or_sections"]) + 1 + return [expr, TupleAffineType([t] * num_split)] @register_fake_quantization_to_integer("clip") @@ -163,4 +208,133 @@ def clip(expr, type_map): amin = relay.op.round(relay.op.const(amin) / scale + z_p) amax = relay.op.round(relay.op.const(amax) / scale + z_p) out = relay.op.minimum(relay.op.maximum(arg, amin), amax) - return [out, t.scale, t.zero_point, t.dtype] + return [out, t] + + +@register_fake_quantization_to_integer("nn.pad") +def pad(expr, type_map): + """Rewite an nn.pad op""" + arg = expr.args[0] + t = type_map[arg] + pad_value = expr.args[1] + ## TF2ONNX will sometimes implement the pad_value as a constant without a quantize + ## To support that, the pass lets branches that terminate in a constant through + if pad_value in type_map: + ## if the pad value is calcuated from a dequantize op, it should be in the type map + ## and we need to make sure it's affine type matches the arg + pad_t = type_map[pad_value] + if not tvm.ir.structural_equal(t, pad_t): + pad_value = relay.qnn.op.requantize( + pad_value, + pad_t.scale, + pad_t.zero_point, + t.scale, + t.zero_point, + out_dtype=t.dtype, + ) + else: + ## If the pad-value is a constant, we need to quantize it + assert isinstance(pad_value, relay.expr.Constant) + pad_value = relay.qnn.op.quantize(pad_value, t.scale, t.zero_point) + + out = relay.op.nn.pad(arg, pad_value=pad_value, **expr.attrs) + return [out, t] + + +def get_binary_types(expr, type_map): + """Get Affine types of a binary op's inputs and unify them""" + ##Support the case where one input is quantized and the other is a constant float + left = expr.args[0] + right = expr.args[1] + left_t = None + right_t = None + + if left in type_map: + left_t = type_map[left] + if right in type_map: + right_t = type_map[right] + + out_t = type_map[expr] + if left_t is None and right_t is None: + raise TypeError("neither input is quantized!") + if left_t is None: + assert isinstance(left, relay.expr.Constant) + left = relay.qnn.op.quantize( + left, right_t.scale, right_t.zero_point, out_dtype=right_t.dtype + ) + left_t = right_t + out_t = right_t + if right_t is None: + assert isinstance(right, relay.expr.Constant) + right = relay.qnn.op.quantize( + right, left_t.scale, left_t.zero_point, out_dtype=left_t.dtype + ) + right_t = left_t + out_t = left_t + + # Handle the case of mismatched inputs + if not left_t.dtype == out_t.dtype: + out_t = left_t + + return left, right, left_t, right_t, out_t + + +def register_binary_qnn(op_name, op): + """Register a Binary Op that converts to QNN""" + + def binary(expr, type_map): + left, right, left_t, right_t, out_t = get_binary_types(expr, type_map) + out = op( + left, + right, + left_t.scale, + left_t.zero_point, + right_t.scale, + right_t.zero_point, + out_t.scale, + out_t.zero_point, + ) + return [out, out_t] + + return register_fake_quantization_to_integer(op_name, binary) + + +# Use lambdas here to avoid a circular import problem +# pylint: disable=unnecessary-lambda +register_binary_qnn("add", lambda *args: relay.qnn.op.add(*args)) +register_binary_qnn("multiply", lambda *args: relay.qnn.op.mul(*args)) +register_binary_qnn("subtract", lambda *args: relay.qnn.op.subtract(*args)) + + +def register_binary_identity(op_name, op): + """Register a binary op that works directly on int8""" + + def binary(expr, type_map): + left, right, left_t, right_t, out_t = get_binary_types(expr, type_map) + if left_t != out_t: + left = relay.qnn.op.requantize( + left, + left_t.scale, + left_t.zero_point, + out_t.scale, + out_t.zero_point, + out_dtype=out_t.dtype, + ) + + if right_t != out_t: + right = relay.qnn.op.requantize( + right, + right_t.scale, + right_t.zero_point, + out_t.scale, + out_t.zero_point, + out_dtype=out_t.dtype, + ) + out = op(left, right) + return [out, out_t] + + return register_fake_quantization_to_integer(op_name, binary) + + +register_binary_identity("minimum", relay.op.minimum) +register_binary_identity("maximum", relay.op.maximum) diff --git a/apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf b/python/tvm/relay/transform/infer_layout_utils.py old mode 100644 new mode 100755 similarity index 57% rename from apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf rename to python/tvm/relay/transform/infer_layout_utils.py index d298325eb4a4..2dc0d25e2dcd --- a/apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf +++ b/python/tvm/relay/transform/infer_layout_utils.py @@ -14,18 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -# This file is specific to the nRF5340 DK board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y +# pylint: disable=invalid-name, unused-argument, missing-docstring, unused-import +""" +Relay infer correct layout pass. +""" +import tvm +from tvm.runtime import Object +from . import _ffi_api -# For AOT runtime which requires lots of function call. -CONFIG_MAIN_STACK_SIZE=2000 -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y +@tvm._ffi.register_object("relay._transform.InferCorrectLayoutOutput") +class InferCorrectLayoutOutput(Object): + """An output structure to hold results from FInferCorrectLayout calls.""" -# For debugging. -CONFIG_LED=y + def __init__(self, input_layouts, output_layouts, new_attrs): + self.__init_handle_by_constructor__( + _ffi_api.InferCorrectLayoutOutput, input_layouts, output_layouts, new_attrs + ) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 6f8ecb970221..fb4d3fa208a8 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -18,7 +18,6 @@ """Default behavior for ops in mixed_precision pass. Import this file to use.""" from typing import List -from tvm import relay from tvm.relay.op import register_mixed_precision_conversion # MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory @@ -82,8 +81,6 @@ "divide", "nn.bias_add", "nn.batch_norm", - "sum", - "mean", "sqrt", "shape_of", # Simple activations @@ -108,15 +105,9 @@ # "nn.global_max_pool1d", # does not exist yet "nn.global_max_pool2d", # "nn.global_max_pool3d", # does not exist yet - # "nn.global_avg_pool1d", # does not exist yet - "nn.global_avg_pool2d", - # "nn.global_avg_pool3d", # does not exist yet "nn.adaptive_max_pool1d", "nn.adaptive_max_pool2d", "nn.adaptive_max_pool3d", - "nn.adaptive_avg_pool1d", - "nn.adaptive_avg_pool2d", - "nn.adaptive_avg_pool3d", ] DEFAULT_NEVER_LIST = [ # In general if |f(x)| >> |x| for expected inputs then put the op here. @@ -129,6 +120,16 @@ # Error function doesn't seem to be able to be lowered into fp16 version in llvm. # Move to follow list when it does. "erf", + # Do not allow arange arguments (begin/end) to be fp16. "end" can be a big fp32 number + # not representable in fp16. + "arange", + # Ops that could involve a large summation are not allowed in fp16. + "nn.global_avg_pool2d", + "nn.adaptive_avg_pool1d", + "nn.adaptive_avg_pool2d", + "nn.adaptive_avg_pool3d", + "sum", + "mean", ] @@ -141,7 +142,7 @@ def decorator(func): return decorator -def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]: +def get_generic_out_dtypes(call_node: "relay.Call", mixed_precision_type: str) -> List[str]: """A function which returns output dtypes in a way which works for most ops. Parameters @@ -174,15 +175,15 @@ def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> # Take in CallNodes and a DType and returns a conversion type, # an accumulation dtype, and an output_dtype. @register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST) -def generic_always_op(call_node: relay.Call, mixed_precision_type: str) -> List: +def generic_always_op(call_node: "relay.Call", mixed_precision_type: str) -> List: return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type) @register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST) -def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List: +def generic_follow_op(call_node: "relay.Call", mixed_precision_type: str) -> List: return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type) @register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST) -def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List: +def generic_never_op(call_node: "relay.Call", mixed_precision_type: str) -> List: return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 6294e7acea15..9a7857a01fe6 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1093,7 +1093,7 @@ def DenseToSparse(weight_name, weight_shape): return _ffi_api.DenseToSparse(weight_name, weight_shape) -def Conv2dToSparse(weight_name, weight_shape, layout): +def Conv2dToSparse(weight_name, weight_shape, layout, kernel_size): """ Rewrite qualified ```nn.conv2d operation``` to ```nn.sparse_conv2d``` @@ -1113,7 +1113,27 @@ def Conv2dToSparse(weight_name, weight_shape, layout): ret : tvm.transform.Pass The registered DenseToSparse pass. """ - return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout) + return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout, kernel_size) + + +def Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold): + """ + Rewrite freezed ```nn.conv2d``` operation to ```nn.sparse_conv2d``` + + Parameters + ---------- + layout : str + layout of data + + kernel_size : int + kernel size of conv2d + + Returns + ------- + ret : tvm.transform.Pass + The registered DenseToSparse pass. + """ + return _ffi_api.Conv2dToSparse2(layout, kernel_size, *blocksize, sparsity_threshold) def SimplifyFCTranspose(target_weight_name): diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index 4531ceca2ce9..d8199c4c93a6 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -327,7 +327,7 @@ def text_summary(self): res += "----------------------------\n" for item in data["server_info"]: addr = item["addr"] - res += addr[0] + ":" + str(addr[1]) + "\t" + res += str(addr[0]) + ":" + str(addr[1]) + "\t" res += item["key"] + "\n" key = item["key"].split(":")[1] # 'server:rasp3b` -> 'rasp3b' if key not in total_ct: diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 0b49b675d77d..52a7a898269a 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -365,9 +365,14 @@ def _popen_start_rpc_server( custom_addr=None, silent=False, no_fork=False, + server_init_callback=None, ): if no_fork: multiprocessing.set_start_method("spawn") + + if server_init_callback: + server_init_callback() + # This is a function that will be sent to the # Popen worker to run on a separate process. # Create and start the server in a different thread @@ -420,6 +425,25 @@ class Server(object): no_fork: bool, optional Whether forbid fork in multiprocessing. + + server_init_callback: Callable, optional + Additional initialization function when starting the server. + + Note + ---- + The RPC server only sees functions in the tvm namespace. + To bring additional custom functions to the server env, you can use server_init_callback. + + .. code:: python + + def server_init_callback(): + import tvm + # must import mypackage here + import mypackage + + tvm.register_func("function", mypackage.func) + + server = rpc.Server(host, server_init_callback=server_init_callback) """ def __init__( @@ -434,6 +458,7 @@ def __init__( custom_addr=None, silent=False, no_fork=False, + server_init_callback=None, ): try: if _ffi_api.ServerLoop is None: @@ -455,6 +480,7 @@ def __init__( custom_addr, silent, no_fork, + server_init_callback, ], ) # receive the port diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 8107ab5b87d2..25a57bbb1c36 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -20,7 +20,8 @@ import os import ctypes import struct -from collections import namedtuple +from typing import Sequence +import numpy as np import tvm._ffi from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY @@ -30,8 +31,69 @@ from . import _ffi_api -# profile result of time evaluator -ProfileResult = namedtuple("ProfileResult", ["mean", "results"]) +class BenchmarkResult: + """Runtimes from benchmarking""" + + def __init__(self, results: Sequence[float]): + """Construct a new BenchmarkResult from a sequence of runtimes. + + Parameters + ---------- + results : Sequence[float] + Raw times from benchmarking + + Attributes + ---------- + min : float + Minimum runtime in seconds of all results. + mean : float + Mean runtime in seconds of all results. If py:meth:`Module.time_evaluator` or + `benchmark` is called with `number` > 0, then each result is already the mean of a + `number` of runtimes, so this becomes the mean of means. + median : float + Median runtime in seconds of all results. If py:meth:`Module.time_evaluator` is called + with `number` > 0, then each result is already the mean of a `number` of runtimes, so + this becomes the median of means. + max : float + Maximum runtime in seconds of all results. If py:meth:`Module.time_evaluator` is called + with `number` > 0, then each result is already the mean of a `number` of runtimes, so + this becomes the maximum of those means. + std : float + Standard deviation in seconds of runtimes. If py:meth:`Module.time_evaluator` is called + with `number` > 0, then each result is already the mean of a `number` of runtimes, so + this becomes the standard deviation of means. + results : Sequence[float] + The collected runtimes (in seconds). This may be a series of mean runtimes if + py:meth:`Module.time_evaluator` or `benchmark` was run with `number` > 1. + """ + self.results = results + self.mean = np.mean(self.results) + self.std = np.std(self.results) + self.median = np.median(self.results) + self.min = np.min(self.results) + self.max = np.max(self.results) + + def __repr__(self): + return "BenchmarkResult(min={}, mean={}, median={}, max={}, std={}, results={})".format( + self.min, self.mean, self.median, self.max, self.std, self.results + ) + + def __str__(self): + return """Execution time summary: +{:^12} {:^12} {:^12} {:^12} {:^12} +{:^12.4f} {:^12.4f} {:^12.4f} {:^12.4f} {:^12.4f} + """.format( + "mean (ms)", + "median (ms)", + "max (ms)", + "min (ms)", + "std (ms)", + self.mean * 1000, + self.median * 1000, + self.max * 1000, + self.min * 1000, + self.std * 1000, + ) class Module(object): @@ -209,7 +271,7 @@ def time_evaluator(self, func_name, dev, number=10, repeat=1, min_repeat_ms=0, f Returns ------- ftimer : function - The function that takes same argument as func and returns a ProfileResult. + The function that takes same argument as func and returns a BenchmarkResult. The ProfileResult reports `repeat` time costs in seconds. """ try: @@ -230,12 +292,11 @@ def evaluator(*args): blob = feval(*args) fmt = "@" + ("d" * repeat) results = struct.unpack(fmt, blob) - mean = sum(results) / float(repeat) - return ProfileResult(mean=mean, results=results) + return BenchmarkResult(results) return evaluator except NameError: - raise NameError("time_evaluate is only supported when RPC is enabled") + raise NameError("time_evaluator is only supported when RPC is enabled") def _collect_from_import_tree(self, filter_func): """Helper function to collect modules from the tree matching a filter_func, then return it. diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index 0c2abd296b42..cac7c6d8fdae 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -56,8 +56,10 @@ def __dir__(self): return sorted([fnames(i) for i in range(size)] + class_names) def __getattr__(self, name): - if name in self.__slots__: - raise AttributeError(f"{name} is not set") + # specially check handle since + # this is required for PackedFunc calls + if name == "handle": + raise AttributeError("handle is not set") try: return _ffi_node_api.NodeGetAttr(self, name) diff --git a/python/tvm/runtime/profiler_vm.py b/python/tvm/runtime/profiler_vm.py index e1c3dc66a360..b3043d8b8760 100644 --- a/python/tvm/runtime/profiler_vm.py +++ b/python/tvm/runtime/profiler_vm.py @@ -50,7 +50,7 @@ def get_stat(self, sort_by_time=True): # pylint: disable=unused-argument warnings.warn("get_stat has been removed, use profile instead") return "" - def profile(self, *args, func_name="main", **kwargs): + def profile(self, *args, func_name="main", collectors=None, **kwargs): """Profile a function call. Parameters @@ -58,6 +58,9 @@ def profile(self, *args, func_name="main", **kwargs): func_name : str The name of the function. + collectors : Optional[Sequence[MetricCollector]] + Extra metrics to collect. + args : list[tvm.runtime.NDArray] or list[np.ndarray] The arguments to the function. @@ -69,6 +72,7 @@ def profile(self, *args, func_name="main", **kwargs): timing_results : str Overall and per-op timing results formatted in a table. """ + collectors = [] if collectors is None else collectors if args or kwargs: self.set_input(func_name, *args, **kwargs) - return self._profile(func_name) + return self._profile(func_name, collectors) diff --git a/python/tvm/runtime/profiling.py b/python/tvm/runtime/profiling.py deleted file mode 100644 index 5a1cd6796b64..000000000000 --- a/python/tvm/runtime/profiling.py +++ /dev/null @@ -1,48 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Registration of profiling objects in python.""" - -from .. import _ffi -from . import Object - -_ffi._init_api("runtime.profiling", __name__) - - -@_ffi.register_object("runtime.profiling.Report") -class Report(Object): - """A container for information gathered during a profiling run. - - Attributes - ---------- - calls : Array[Dict[str, Object]] - Per-call profiling metrics (function name, runtime, device, ...). - - device_metrics : Dict[Device, Dict[str, Object]] - Per-device metrics collected over the entire run. - """ - - def csv(self): - """Convert this profiling report into CSV format. - - This only includes calls and not overall metrics. - - Returns - ------- - csv : str - `calls` in CSV format. - """ - return AsCSV(self) diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py new file mode 100644 index 000000000000..881691609398 --- /dev/null +++ b/python/tvm/runtime/profiling/__init__.py @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Registration of profiling objects in python.""" + +from typing import Dict, Sequence, Optional +from ... import _ffi +from . import _ffi_api +from .. import Object, Device + + +@_ffi.register_object("runtime.profiling.Report") +class Report(Object): + """A container for information gathered during a profiling run. + + Attributes + ---------- + calls : Array[Dict[str, Object]] + Per-call profiling metrics (function name, runtime, device, ...). + + device_metrics : Dict[Device, Dict[str, Object]] + Per-device metrics collected over the entire run. + """ + + def csv(self): + """Convert this profiling report into CSV format. + + This only includes calls and not overall metrics. + + Returns + ------- + csv : str + `calls` in CSV format. + """ + return _ffi_api.AsCSV(self) + + def json(self): + """Convert this profiling report into JSON format. + + Example output: + + .. code-block: + + { + "calls": [ + { + "Duration (us)": { + "microseconds": 12.3 + }, + "Name": "fused_dense", + "Count": { + "count": 1 + }, + "Percent": { + "percent": 10.3 + } + } + ], + "device_metrics": { + "cpu": { + "Duration (us)": { + "microseconds": 334.2 + }, + "Percent": { + "percent": 100 + } + } + } + } + + {"calls": + [ + {"Duration (us)": {"microseconds": 12.3} + ,"Name": "fused_dense" + ,"Count": {"count":1} + ,"Percent": {"percent": 10.3} + } + ], + "device_metrics": + {"cpu": + {"Duration (us)": {"microseconds": 334.2} + ,"Percent": {"percent": 100.0} + } + } + } + + Returns + ------- + json : str + Formatted JSON + """ + return _ffi_api.AsJSON(self) + + +@_ffi.register_object("runtime.profiling.MetricCollector") +class MetricCollector(Object): + """Interface for user defined profiling metric collection.""" + + +@_ffi.register_object("runtime.profiling.DeviceWrapper") +class DeviceWrapper(Object): + """Wraps a tvm.runtime.Device""" + + def __init__(self, dev: Device): + self.__init_handle_by_constructor__(_ffi_api.DeviceWrapper, dev) + + +# We only enable this class when TVM is build with PAPI support +if _ffi.get_global_func("runtime.profiling.PAPIMetricCollector", allow_missing=True) is not None: + + @_ffi.register_object("runtime.profiling.PAPIMetricCollector") + class PAPIMetricCollector(MetricCollector): + """Collects performance counter information using the Performance + Application Programming Interface (PAPI). + """ + + def __init__(self, metric_names: Optional[Dict[Device, Sequence[str]]] = None): + """ + Parameters + ---------- + metric_names : Optional[Dict[Device, Sequence[str]]] + List of per-device metrics to collect. You can find a list of valid + metrics by runing `papi_native_avail` from the command line. + """ + metric_names = {} if metric_names is None else metric_names + wrapped = dict() + for dev, names in metric_names.items(): + wrapped[DeviceWrapper(dev)] = names + self.__init_handle_by_constructor__(_ffi_api.PAPIMetricCollector, wrapped) diff --git a/apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf b/python/tvm/runtime/profiling/_ffi_api.py similarity index 76% rename from apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf rename to python/tvm/runtime/profiling/_ffi_api.py index 3916b17c49cf..d26b847a699f 100644 --- a/apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf +++ b/python/tvm/runtime/profiling/_ffi_api.py @@ -14,15 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -# This file is specific to the MPS2-AN512 board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y +"""FFI for profiling""" +from ... import _ffi -# For debugging. -CONFIG_LED=n +_ffi._init_api("runtime.profiling", __name__) diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 429da5892628..6416ad7814e1 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -45,7 +45,7 @@ def _convert(arg, cargs): _convert(field, field_args) cargs.append(container.tuple_object(field_args)) elif isinstance(arg, (_base.numeric_types, bool)): - dtype = "int32" if isinstance(arg, (int, bool)) else "float32" + dtype = "int32" if isinstance(arg, (_base.integer_types, bool)) else "float32" value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0)) cargs.append(value) else: @@ -345,6 +345,7 @@ def __init__(self, exe, device, memory_cfg=None): self._invoke_stateful = self.module["invoke_stateful"] self._get_output = self.module["get_output"] self._get_num_outputs = self.module["get_num_outputs"] + self._get_input_index = self.module["get_input_index"] self._set_input = self.module["set_input"] self._setup_device(device, memory_cfg) @@ -490,3 +491,114 @@ def get_outputs(self): outputs : List[NDArray] """ return [self._get_output(i) for i in range(self._get_num_outputs())] + + def get_input_index(self, input_name, func_name="main"): + """Get inputs index via input name. + Parameters + ---------- + name : str + The input key name + func_name : str + The function name + + Returns + ------- + index: int + The input index. -1 will be returned if the given input name is not found. + """ + return self._get_input_index(input_name, func_name) + + def benchmark( + self, + device, + *args, + func_name="main", + repeat=5, + number=5, + min_repeat_ms=None, + end_to_end=False, + **kwargs, + ): + """Calculate runtime of a function by repeatedly calling it. + + Use this function to get an accurate measurement of the runtime of a function. The function + is run multiple times in order to account for variability in measurements, processor speed + or other external factors. Mean, median, standard deviation, min and max runtime are all + reported. On GPUs, CUDA and ROCm specifically, special on-device timers are used so that + synchonization and data transfer operations are not counted towards the runtime. This allows + for fair comparison of runtimes across different functions and models. The `end_to_end` flag + switches this behavior to include data transfer operations in the runtime. + + The benchmarking loop looks approximately like so: + + .. code-block:: python + + for r in range(repeat): + time_start = now() + for n in range(number): + func_name() + time_end = now() + total_times.append((time_end - time_start)/number) + + + Parameters + ---------- + func_name : str + The function to benchmark + + repeat : int + Number of times to run the outer loop of the timing code (see above). The output will + contain `repeat` number of datapoints. + + number : int + Number of times to run the inner loop of the timing code. This inner loop is run in + between the timer starting and stopping. In order to amortize any timing overhead, + `number` should be increased when the runtime of the function is small (less than a 1/10 + of a millisecond). + + min_repeat_ms : Optional[float] + If set, the inner loop will be run until it takes longer than `min_repeat_ms` + milliseconds. This can be used to ensure that the function is run enough to get an + accurate measurement. + + end_to_end : bool + If set, include time to transfer input tensors to the device and time to transfer + returned tensors in the total runtime. This will give accurate timings for end to end + workloads. + + args : Sequence[Object] + Arguments to the function. These are cached before running timing code, so that data + transfer costs are not counted in the runtime. + + kwargs : Dict[str, Object] + Named arguments to the function. These are cached like `args`. + + Returns + ------- + timing_results : BenchmarkResult + Runtimes of the function. Use `.mean` to access the mean runtime, use `.results` to + access the individual runtimes (in seconds). + """ + min_repeat_ms = 0 if min_repeat_ms is None else min_repeat_ms + if end_to_end: + # We need to unpack keyword arguments into positional arguments + packed_args = list(args) + for k, v in kwargs.items(): + i = self.get_input_index(k, func_name) + if i < 0: + raise TypeError(f"{func_name}() got an unexpected keyword argument '{k}'") + while i >= len(packed_args): + packed_args.append(None) + packed_args[i] = v + return self.module.time_evaluator( + "invoke_return_to_device", + device, + repeat=repeat, + number=number, + min_repeat_ms=min_repeat_ms, + )(func_name, device.device_type, device.device_id, *packed_args) + if args or kwargs: + self.set_input(func_name, *args, **kwargs) + return self.module.time_evaluator( + "invoke", device, repeat=repeat, number=number, min_repeat_ms=min_repeat_ms + )(func_name) diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index ae3e9d885f1a..44c92b792f12 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -57,7 +57,7 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # match_buffers of the block, # which bind a sub-region of source buffer into a new buffer - D = tir.match_buffer_region(C[vi, vj]) + D = tir.match_buffer(C[vi, vj], ()) # init part of the block, executed when all reduce axes are the beginning value with tir.init(): @@ -65,13 +65,13 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # block body CC[0, 0] = A[vi, vk] * B[vj, vk] - D[0, 0] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0] + D[()] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0] """ alloc_buffers: List[Buffer] = [] """List[Buffer]: list of tir.alloc_buffer statements in the block signature""" match_buffers: List[MatchBufferRegion] = [] - """List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature""" + """List[MatchBufferRegion]: list of tir.match_buffer statements in the block signature""" iter_bindings: Mapping[Var, PrimExpr] = {} """Mapping[Var, PrimExpr]: map of block iter var to its values""" reads: Optional[List[BufferSlice]] = None diff --git a/python/tvm/script/intrin.py b/python/tvm/script/intrin.py index 38ff1b71f07d..e2d44440e2ac 100644 --- a/python/tvm/script/intrin.py +++ b/python/tvm/script/intrin.py @@ -37,67 +37,67 @@ def handle(self, arg_list: List[Any], span: tvm.ir.Span): @register def bool(imm, span): - return tvm.tir.Cast("bool", imm, span) + return imm.astype("bool", span) @register def int8(imm, span): - return tvm.tir.Cast("int8", imm, span) + return imm.astype("int8", span) @register def int16(imm, span): - return tvm.tir.Cast("int16", imm, span) + return imm.astype("int16", span) @register def int32(imm, span): - return tvm.tir.Cast("int32", imm, span) + return imm.astype("int32", span) @register def int64(imm, span): - return tvm.tir.Cast("int64", imm, span) + return imm.astype("int64", span) @register def uint8(imm, span): - return tvm.tir.Cast("uint8", imm, span) + return imm.astype("uint8", span) @register def uint16(imm, span): - return tvm.tir.Cast("uint16", imm, span) + return imm.astype("uint16", span) @register def uint32(imm, span): - return tvm.tir.Cast("uint32", imm, span) + return imm.astype("uint32", span) @register def uint64(imm, span): - return tvm.tir.Cast("uint64", imm, span) + return imm.astype("uint64", span) @register def float8(imm, span): - return tvm.tir.Cast("float8", imm, span) + return imm.astype("float8", span) @register def float16(imm, span): - return tvm.tir.Cast("float16", imm, span) + return imm.astype("float16", span) @register def float32(imm, span): - return tvm.tir.Cast("float32", imm, span) + return imm.astype("float32", span) @register def float64(imm, span): - return tvm.tir.Cast("float64", imm, span) + return imm.astype("float64", span) @register @@ -120,6 +120,11 @@ def floormod(x, y, span): return tvm.tir.floormod(x, y, span) +@register +def abs(x, span): + return tvm.tir.abs(x, span) + + @register def load(dtype, var, index, predicate=None, span=None): return tvm.tir.Load(dtype, var, index, predicate, span) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 49f71041590b..9acf21b6ba3a 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -784,10 +784,11 @@ def transform_Slice(self, node): def transform_Subscript(self, node): """Array access visitor. - By now only 2 types of Subscript are supported: + By now only 3 types of Subscript are supported: 1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore) Var[index] Buffer element access() 2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...])) + 3. Array[index], Buffer element access """ symbol = self.transform(node.params[0]) @@ -812,6 +813,25 @@ def transform_Subscript(self, node): return BufferSlice( symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span) ) + elif isinstance(symbol, tvm.container.Array): + if len(indexes) > 1: + self.report_error( + "Array access should be one-dimension access, but the indices are " + + str(indexes), + node.span, + ) + index = indexes[0] + if not isinstance(index, (int, tvm.tir.expr.IntImm)): + self.report_error( + "Array access index expected int or IntImm, but got " + type(index), + node.span, + ) + if int(index) >= len(symbol): + self.report_error( + f"Array access out of bound, size: {len(symbol)}, got index {index}.", + node.span, + ) + return symbol[int(index)] else: self.report_error( f"Cannot subscript from a {type(symbol).__name__}. Only variables and " diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index a23401d926e9..bb408f6cdc8f 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -110,10 +110,9 @@ def __init__(self): def allocate(extents, dtype, scope, condition=True, span=None): condition = tvm.runtime.convert(condition) scope = tvm.runtime.convert(scope) - body = tvm.tir.Allocate( + return tvm.tir.Allocate( self.buffer_var, dtype, extents, condition, self.body, span=span ) - return tvm.tir.AttrStmt(self.buffer_var, "storage_scope", scope, body, span=span) super().__init__(allocate, concise_scope=True, def_symbol=True) self.buffer_var = None @@ -140,7 +139,7 @@ def enter_scope( def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None): """Setup buffer var for a given type.""" - buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype)) + buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope) self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index 7eb938c58f96..25af7635742b 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -96,12 +96,24 @@ def handle( @register class MatchBuffer(SpecialStmt): - """Special Stmt match_buffer(var, shape, dtype, data, strides, elem_offset, scope, align, + """Special Stmt match_buffer(param, shape, dtype, data, strides, elem_offset, scope, align, offset_factor, buffer_type) + + Note + ---- + This Special Stmt will perform different behavior depends on the type of param. + If the param is a var in function parameter, it will create a buffer from DLTensor. + Else if the param is a subregion of other buffers, then create a subregion match inside a block. + Example ------- + Match buffer from function parameter .. code-block:: python A = tir.match_buffer(a, (128, 128), dtype="float32") + + Match buffer from Buffer subregion + .. code-block:: python + A = tir.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32") """ def __init__(self): @@ -123,10 +135,6 @@ def match_buffer( "match_buffer must be assigned to a buffer, e.g. A = match_buffer(...)", self.node.span, ) - if param not in self.context.func_params: - self.context.report_error( - "Can not bind non-input param to buffer", self.node.rhs.params[0].span - ) if strides is None: strides = [] align = convert_to_int(align, "align", self.context.report_error, self.node.span) @@ -146,7 +154,23 @@ def match_buffer( buffer_type, span=span, ) - self.context.func_buffer_map[param] = buffer + if isinstance(param, tvm.tir.Var): + if param not in self.context.func_params: + self.context.report_error( + "Can not bind non-input param to buffer", self.node.rhs.params[0].span + ) + self.context.func_buffer_map[param] = buffer + elif isinstance(param, BufferSlice): + buffer_region = buffer_slice_to_region(param) + self.context.current_block_scope().match_buffers.append( + tvm.tir.MatchBufferRegion(buffer, buffer_region) + ) + else: + self.context.report_error( + "The source of match_buffer expected Var or BufferSlice, but got " + + str(type(param)), + self.node.rhs.params[0].span, + ) self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) super().__init__(match_buffer, def_symbol=True) @@ -225,7 +249,7 @@ def alloc_buffer( data=None, strides=None, elem_offset=None, - scope="", + scope="global", align=-1, offset_factor=0, buffer_type="default", @@ -414,68 +438,6 @@ def where(predicate, span=None): super().__init__(where, def_symbol=False) -@register -class BlockMatchBufferRegion(SpecialStmt): - """Special function match_buffer_region(source, strides, elem_offset, align, offset_factor) - - Example - ------- - .. code-block:: python - - B = tir.match_buffer_region(A[0: 4]) - """ - - def __init__(self): - def match_buffer_region( - source, - strides=None, - elem_offset=None, - align=-1, - offset_factor=0, - span=None, - ): - assert self.context, "call 'exit_scope' before 'enter_scope'" - if not isinstance(self.node, ast.Assign): - self.context.report_error( - "match_buffer_region must be assigned to a buffer, " - + "e.g. A = match_buffer_region(...)", - self.node.span, - ) - - if strides is None: - strides = [] - align = convert_to_int(align, "align", self.context.report_error, self.node.span) - offset_factor = convert_to_int( - offset_factor, "offset_factor", self.context.report_error, self.node.span - ) - - if not isinstance(source, BufferSlice): - self.context.report_error( - "match_buffer_region needs a buffer region as source", - span=span, - ) - buffer_region = buffer_slice_to_region(source) - shape = [r.extent for r in buffer_region.region] - buffer = tvm.tir.decl_buffer( - shape, - buffer_region.buffer.dtype, - self.node.lhs.id.name, - data=None, - strides=strides, - elem_offset=elem_offset, - scope=buffer_region.buffer.scope, - data_alignment=align, - offset_factor=offset_factor, - span=span, - ) - self.context.current_block_scope().match_buffers.append( - tvm.tir.MatchBufferRegion(buffer, buffer_region) - ) - self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) - - super().__init__(match_buffer_region, def_symbol=True) - - @register class VarDef(SpecialStmt): """Special function for defining a Var""" @@ -491,6 +453,22 @@ def var(dtype, span): super().__init__(var, def_symbol=True) +@register +class BufferVarDef(SpecialStmt): + """Special function for defining a variable of pointer type""" + + def __init__(self): + def buffer_var(dtype, storage_scope, span): + assert isinstance( + self.node, ast.Assign + ), f"BufferVarDef expected ast.Assign but got {type(self.node)}" + ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) + v = te.var(self.node.lhs.id.name, ptr_type, span=span) + self.context.update_symbol(v.name, v, self.node) + + super().__init__(buffer_var, def_symbol=True) + + @register class EnvThread(SpecialStmt): """Bind a var to thread env""" diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 0d3feaaadbc2..aa9226101b52 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -281,10 +281,17 @@ def intel_graphics(model="unknown", options=None): MICRO_SUPPORTED_MODELS = { "host": [], + "atsamd51": ["-mcpu=cortex-m4"], + "cxd5602gg": ["-mcpu=cortex-m4"], + "esp32": [], + "imxrt1060": ["-mcpu=cortex-m7"], "mps2_an521": ["-mcpu=cortex-m33"], + "nrf52840": ["-mcpu=cortex-m4"], "nrf5340dk": ["-mcpu=cortex-m33"], + "sam3x8e": ["-mcpu=cortex-m3"], "stm32f746xx": ["-mcpu=cortex-m7", "-march=armv7e-m"], "stm32l4r5zi": ["-mcpu=cortex-m4"], + "zynq_mp_r5": ["-mcpu=cortex-r5"], } @@ -306,7 +313,7 @@ def micro(model="unknown", options=None): options, ) - if (not options) or (options and "--executor=aot" not in options): + if (not options) or (options and not any("-executor=aot" in o for o in options)): opts = _merge_opts(opts, "--system-lib") # NOTE: in the future, the default micro target will be LLVM except when @@ -406,79 +413,118 @@ def bifrost(model="unknown", options=None): return Target(" ".join(["opencl"] + opts)) -def hexagon(cpu_ver="v66", sim_args=None, llvm_args=None, hvx=128): +def hexagon(cpu_ver="v66", **kwargs): """Returns a Hexagon target. Parameters ---------- - cpu_ver : str + cpu_ver : str (default: "v66") CPU version used for code generation. Not all allowed cpu str will be valid, LLVM will throw an error. - sim_args : str or list of str + + Recognized keyword parameters + ----------------------------- + hvx : int (default: 128) + Size of HVX vector in bytes. Value of 0 disables HVX codegen. + sim_options : str or list of str (default: None) User defined sim arguments. CPU version defaults to cpu_ver. Otherwise, separate versions are used for codegen and sim. Not all allowed cpu strings will be valid, simulator will throw an error if invalid. Does not affect codegen. - llvm_args : str or list of str + llvm_options : str or list of str (default: None) User defined compiler arguments. - hvx : int - Size of hvx register. Value of 0 indicates disabled hvx. """ + + # Some of the target parameters correspond to target kind attributes + # listed in src/target/target_kind.cc. For those parameters, their + # names follow the attribute names with the exception of '_' being used + # in place of '-'. + # Example compiler arguments # llvm -mtriple=hexagon -mcpu=hexagonv66 -mattr=+hvxv66,+hvx-length128b # Check for valid codegen cpu - valid_hex = ["v60", "v62", "v65", "v66", "v67", "v67t"] + valid_hex = ["v60", "v62", "v65", "v66", "v67", "v67t", "v68"] try: cpu_ver = cpu_ver[cpu_ver.index("v") :].lower() - assert 3 <= len(cpu_ver) <= 4 + assert cpu_ver in valid_hex except: msg = "{} is not a valid Hexagon version\nvalid versions include {}" raise ValueError(msg.format(cpu_ver, valid_hex)) from None - assert hvx in [0, 64, 128] + # Target configuration: + config = { + "hvx": 128, + "sim_options": None, + "llvm_options": None, + } + config.update(kwargs) + + # Warn about obsolete parameter names. + if config.get("sim_args"): + msg = "The keyword parameter 'sim_args' is deprecated, use 'sim_options' instead" + warnings.warn(msg, stacklevel=2) + config.update({"sim_options": config["sim_args"]}) + if config.get("llvm_args"): + msg = "The keyword parameter 'llvm_args' is deprecated, use 'llvm_options' instead" + warnings.warn(msg, stacklevel=2) + config.update({"llvm_options": config["llvm_args"]}) + + # LLVM target string + def create_llvm_target(cpu_ver, config): + """ Create LLVM target string. """ - # Target string - def create_target(cpu_ver): target = " -mtriple=hexagon" mcpu = " -mcpu=hexagon" + cpu_ver - mattr = "" - # HVX enable - if hvx: - mattr = " -mattr=+hvx" + cpu_ver + ",+hvx-length" + str(hvx) + "b" - return target + mcpu + mattr - - # Simulator string - def create_sim(cpu_ver, sim_args): - def validate_hvx_length(codegen_hvx, sim_args): - if sim_args and "--hvx_length" in sim_args: + + # Process the options that affect target features and return the + # target feature string. + def create_target_features(config): + tfs = [] + if config["hvx"] > 0: + valid_hvx = [0, 64, 128] + if not config["hvx"] in valid_hvx: + raise ValueError("Invalid hvx value, should be one of " + str(valid_hvx)) + tfs += ["+hvx" + cpu_ver, "+hvx-length" + str(config["hvx"]) + "b"] + else: + tfs += ["-hvx"] + return "-mattr=" + ",".join(tfs) if tfs else "" + + return target + mcpu + " " + create_target_features(config) + + # Simulator options string + def create_sim_options(cpu_ver, config): + """ Create simulator option string. """ + + def validate_hvx_length(codegen_hvx, sim_options): + if sim_options and "--hvx_length" in sim_options: # If --hvx_length was specified, check HVX length of sim # vs codegen - i = sim_args.index("hvx_length") + len("hvx_length") + 1 - sim_hvx = sim_args[i : i + 3] + i = sim_options.index("hvx_length") + len("hvx_length") + 1 + sim_hvx = sim_options[i : i + 3] if sim_hvx != str(codegen_hvx): - print( - "WARNING: sim hvx {} and codegen hvx {} mismatch!".format( - sim_hvx, codegen_hvx - ) - ) + msg = "sim hvx {} and codegen hvx {} mismatch!".format(sim_hvx, codegen_hvx) + # Set the stacklevel to the tvm.target.hexagon() call. + warnings.warn(msg, stacklevel=4) elif codegen_hvx != 0: # If --hvx_length was not given, add it if HVX is enabled - sim_args = sim_args + " " if isinstance(sim_args, str) else "" - sim_args += "--hvx_length " + str(codegen_hvx) - return sim_args or "" + sim_options = sim_options + " " if isinstance(sim_options, str) else "" + sim_options += "--hvx_length " + str(codegen_hvx) + return sim_options or "" - if not sim_args: - return cpu_ver + " " + validate_hvx_length(hvx, sim_args) + hvx = config["hvx"] + sim_options = config["sim_options"] + if not sim_options: + return cpu_ver + " " + validate_hvx_length(hvx, sim_options) sim_cpu = cpu_ver + " " # Add user defined args - if isinstance(sim_args, list): - sim_args = " ".join(sim_args) + if isinstance(sim_options, list): + sim_options = " ".join(sim_options) # Check for supplied sim cpu version - if "v6" in sim_args: + if "v6" in sim_options: sim_cpu = "" # Regex match for allowed cpus @@ -487,13 +533,13 @@ def validate_hvx_length(codegen_hvx, sim_args): + r"(?Pv6[25678])(?P[a-z])?" + r"(?P_[0-9]+)?(?P_rev[0-9])?\s?(?P--.*)?" ) - m = re.match(valid_cpu_str_regex, sim_args.lower()) + m = re.match(valid_cpu_str_regex, sim_options.lower()) if not m: - raise ValueError('Invalid simulator argument string "{}"'.format(sim_args)) + raise ValueError('Invalid simulator argument string "{}"'.format(sim_options)) # Parse options into correct order cpu_attr = {x: str(m.groupdict()[x] or "") for x in m.groupdict()} - sim_args = ( + sim_options = ( cpu_attr["base_version"] + cpu_attr["sub_version"] + cpu_attr["l2_size"] @@ -503,23 +549,27 @@ def validate_hvx_length(codegen_hvx, sim_args): + cpu_attr["post"] ) - return sim_cpu + " " + validate_hvx_length(hvx, sim_args) + return sim_cpu + " " + validate_hvx_length(hvx, sim_options) + + # LLVM options string + def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument + """ Create LLVM options string. """ + + llvm_options = config["llvm_options"] - # LLVM string - def create_llvm(llvm_args): # TVM's option parser doesn't allow '=' in values, but '=' can # appear in LLVM flags. Replace it with '@', since it's unlikely # that '@' will be used in another context. - if llvm_args is None or len(llvm_args.replace(" ", "")) == 0: + if llvm_options is None or len(llvm_options.strip()) == 0: return "" - args = [s.replace("=", "@") for s in llvm_args.split()] + args = [s.replace("=", "@") for s in llvm_options.split()] return "--llvm-options=" + ",".join(args) # Sim args - os.environ["HEXAGON_SIM_ARGS"] = create_sim(cpu_ver, sim_args) + os.environ["HEXAGON_SIM_ARGS"] = create_sim_options(cpu_ver, config) - target_str = create_target(cpu_ver) - llvm_str = create_llvm(llvm_args) + target_str = create_llvm_target(cpu_ver, config) + llvm_str = create_llvm_options(cpu_ver, config) args_list = target_str.split() + llvm_str.split() return Target(" ".join(["hexagon"] + args_list)) diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 7bb85e3da83c..442aeb6f1027 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -207,8 +207,7 @@ def wrap_up_realize(self, node, body): _domain = [Range.from_min_extent(0, i) for i in _buf.shape] _dtype = _buf.dtype _true = tvm.runtime.convert(True) - body = tvm.tir.ProducerRealize(_buf, _domain, _true, body) - body = tvm.tir.AttrStmt(_buf.op, "realize_scope", tvm.runtime.convert(_scope), body) + body = tvm.tir.ProducerRealize(_buf, _domain, _true, body, tvm.runtime.convert(_scope)) for elem in to_pop: self.symbols.pop(elem) diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py new file mode 100644 index 000000000000..f610c6ecc0db --- /dev/null +++ b/python/tvm/testing/__init__.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=redefined-builtin, wildcard-import +"""Utility Python functions for TVM testing""" +from .utils import assert_allclose, assert_prim_expr_equal, check_bool_expr_is_true +from .utils import check_int_constraints_trans_consistency, check_numerical_grads +from .utils import device_enabled, enabled_targets, exclude_targets +from .utils import fixture, parameter, parameters, parametrize_targets, uses_gpu +from .utils import known_failing_targets, requires_cuda, requires_cudagraph +from .utils import requires_gpu, requires_llvm, requires_rocm, requires_rpc +from .utils import requires_tensorcore, requires_metal, requires_micro, requires_opencl +from .utils import identity_after, terminate_self + +from ._ffi_api import nop, echo, device_test, run_check_signal, object_use_count +from ._ffi_api import test_wrap_callback, test_raise_error_callback, test_check_eq_callback +from ._ffi_api import ErrorTest, FrontendTestModule, identity_cpp + +from .popen_pool import initializer, after_initializer, register_ffi, call_cpp_ffi +from .popen_pool import call_py_ffi, call_cpp_py_ffi + +from . import auto_scheduler diff --git a/apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf b/python/tvm/testing/_ffi_api.py similarity index 76% rename from apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf rename to python/tvm/testing/_ffi_api.py index 3916b17c49cf..56a77223b767 100644 --- a/apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf +++ b/python/tvm/testing/_ffi_api.py @@ -14,15 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -# This file is specific to the MPS2-AN512 board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y +"""FFI APIs for tvm.testing""" +import tvm._ffi -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y -# For debugging. -CONFIG_LED=n +tvm._ffi._init_api("testing", __name__) diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/python/tvm/testing/auto_scheduler.py similarity index 99% rename from tests/python/unittest/test_auto_scheduler_common.py rename to python/tvm/testing/auto_scheduler.py index 4890268c907b..bc335c82d324 100644 --- a/tests/python/unittest/test_auto_scheduler_common.py +++ b/python/tvm/testing/auto_scheduler.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +# pylint: disable=invalid-name, missing-function-docstring """Common functions for auto_scheduler test cases""" import tvm from tvm import auto_scheduler, te, topi diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py new file mode 100644 index 000000000000..06b4fa4f65eb --- /dev/null +++ b/python/tvm/testing/plugin.py @@ -0,0 +1,294 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pytest plugin for using tvm testing extensions. + +TVM provides utilities for testing across all supported targets, and +to more easily parametrize across many inputs. For more information +on usage of these features, see documentation in the tvm.testing +module. + +These are enabled by default in all pytests provided by tvm, but may +be useful externally for one-off testing. To enable, add the +following line to the test script, or to the conftest.py in the same +directory as the test scripts. + + pytest_plugins = ['tvm.testing.plugin'] + +""" + +import collections + +import pytest +import _pytest + +import tvm +from tvm.testing import utils + + +MARKERS = { + "gpu": "mark a test as requiring a gpu", + "tensorcore": "mark a test as requiring a tensorcore", + "cuda": "mark a test as requiring cuda", + "opencl": "mark a test as requiring opencl", + "rocm": "mark a test as requiring rocm", + "vulkan": "mark a test as requiring vulkan", + "metal": "mark a test as requiring metal", + "llvm": "mark a test as requiring llvm", +} + + +def pytest_configure(config): + """Runs at pytest configure time, defines marks to be used later.""" + + for markername, desc in MARKERS.items(): + config.addinivalue_line("markers", "{}: {}".format(markername, desc)) + + print("enabled targets:", "; ".join(map(lambda x: x[0], utils.enabled_targets()))) + print("pytest marker:", config.option.markexpr) + + +def pytest_generate_tests(metafunc): + """Called once per unit test, modifies/parametrizes it as needed.""" + _parametrize_correlated_parameters(metafunc) + _auto_parametrize_target(metafunc) + + +def pytest_collection_modifyitems(config, items): + """Called after all tests are chosen, currently used for bookkeeping.""" + # pylint: disable=unused-argument + _count_num_fixture_uses(items) + _remove_global_fixture_definitions(items) + + +@pytest.fixture +def dev(target): + """Give access to the device to tests that need it.""" + return tvm.device(target) + + +def pytest_sessionfinish(session, exitstatus): + # Don't exit with an error if we select a subset of tests that doesn't + # include anything + if session.config.option.markexpr != "": + if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED: + session.exitstatus = pytest.ExitCode.OK + + +def _auto_parametrize_target(metafunc): + """Automatically applies parametrize_targets + + Used if a test function uses the "target" fixture, but isn't + already marked with @tvm.testing.parametrize_targets. Intended + for use in the pytest_generate_tests() handler of a conftest.py + file. + + """ + + def update_parametrize_target_arg( + argnames, + argvalues, + *args, + **kwargs, + ): + args = [arg.strip() for arg in argnames.split(",") if arg.strip()] + if "target" in args: + target_i = args.index("target") + + new_argvalues = [] + for argvalue in argvalues: + + if isinstance(argvalue, _pytest.mark.structures.ParameterSet): + # The parametrized value is already a + # pytest.param, so track any marks already + # defined. + param_set = argvalue.values + target = param_set[target_i] + additional_marks = argvalue.marks + elif len(args) == 1: + # Single value parametrization, argvalue is a list of values. + target = argvalue + param_set = (target,) + additional_marks = [] + else: + # Multiple correlated parameters, argvalue is a list of tuple of values. + param_set = argvalue + target = param_set[target_i] + additional_marks = [] + + new_argvalues.append( + pytest.param( + *param_set, marks=_target_to_requirement(target) + additional_marks + ) + ) + + try: + argvalues[:] = new_argvalues + except TypeError as err: + pyfunc = metafunc.definition.function + filename = pyfunc.__code__.co_filename + line_number = pyfunc.__code__.co_firstlineno + msg = ( + f"Unit test {metafunc.function.__name__} ({filename}:{line_number}) " + "is parametrized using a tuple of parameters instead of a list " + "of parameters." + ) + raise TypeError(msg) from err + + if "target" in metafunc.fixturenames: + # Update any explicit use of @pytest.mark.parmaetrize to + # parametrize over targets. This adds the appropriate + # @tvm.testing.requires_* markers for each target. + for mark in metafunc.definition.iter_markers("parametrize"): + update_parametrize_target_arg(*mark.args, **mark.kwargs) + + # Check if any explicit parametrizations exist, and apply one + # if they do not. If the function is marked with either + # excluded or known failing targets, use these to determine + # the targets to be used. + parametrized_args = [ + arg.strip() + for mark in metafunc.definition.iter_markers("parametrize") + for arg in mark.args[0].split(",") + ] + if "target" not in parametrized_args: + excluded_targets = getattr(metafunc.function, "tvm_excluded_targets", []) + xfail_targets = getattr(metafunc.function, "tvm_known_failing_targets", []) + metafunc.parametrize( + "target", + _pytest_target_params(None, excluded_targets, xfail_targets), + scope="session", + ) + + +def _count_num_fixture_uses(items): + # Helper function, counts the number of tests that use each cached + # fixture. Should be called from pytest_collection_modifyitems(). + for item in items: + is_skipped = item.get_closest_marker("skip") or any( + mark.args[0] for mark in item.iter_markers("skipif") + ) + if is_skipped: + continue + + for fixturedefs in item._fixtureinfo.name2fixturedefs.values(): + # Only increment the active fixturedef, in a name has been overridden. + fixturedef = fixturedefs[-1] + if hasattr(fixturedef.func, "num_tests_use_this_fixture"): + fixturedef.func.num_tests_use_this_fixture[0] += 1 + + +def _remove_global_fixture_definitions(items): + # Helper function, removes fixture definitions from the global + # variables of the modules they were defined in. This is intended + # to improve readability of error messages by giving a NameError + # if a test function accesses a pytest fixture but doesn't include + # it as an argument. Should be called from + # pytest_collection_modifyitems(). + + modules = set(item.module for item in items) + + for module in modules: + for name in dir(module): + obj = getattr(module, name) + if hasattr(obj, "_pytestfixturefunction") and isinstance( + obj._pytestfixturefunction, _pytest.fixtures.FixtureFunctionMarker + ): + delattr(module, name) + + +def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None): + # Include unrunnable targets here. They get skipped by the + # pytest.mark.skipif in _target_to_requirement(), showing up as + # skipped tests instead of being hidden entirely. + if targets is None: + if excluded_targets is None: + excluded_targets = set() + + if xfail_targets is None: + xfail_targets = set() + + target_marks = [] + for t in utils._get_targets(): + # Excluded targets aren't included in the params at all. + if t["target_kind"] not in excluded_targets: + + # Known failing targets are included, but are marked + # as expected to fail. + extra_marks = [] + if t["target_kind"] in xfail_targets: + extra_marks.append( + pytest.mark.xfail( + reason='Known failing test for target "{}"'.format(t["target_kind"]) + ) + ) + + target_marks.append((t["target"], extra_marks)) + + else: + target_marks = [(target, []) for target in targets] + + return [ + pytest.param(target, marks=_target_to_requirement(target) + extra_marks) + for target, extra_marks in target_marks + ] + + +def _target_to_requirement(target): + if isinstance(target, str): + target = tvm.target.Target(target) + + # mapping from target to decorator + if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []): + return utils.requires_cudnn() + if target.kind.name == "cuda": + return utils.requires_cuda() + if target.kind.name == "rocm": + return utils.requires_rocm() + if target.kind.name == "vulkan": + return utils.requires_vulkan() + if target.kind.name == "nvptx": + return utils.requires_nvptx() + if target.kind.name == "metal": + return utils.requires_metal() + if target.kind.name == "opencl": + return utils.requires_opencl() + if target.kind.name == "llvm": + return utils.requires_llvm() + return [] + + +def _parametrize_correlated_parameters(metafunc): + parametrize_needed = collections.defaultdict(list) + + for name, fixturedefs in metafunc.definition._fixtureinfo.name2fixturedefs.items(): + fixturedef = fixturedefs[-1] + if hasattr(fixturedef.func, "parametrize_group") and hasattr( + fixturedef.func, "parametrize_values" + ): + group = fixturedef.func.parametrize_group + values = fixturedef.func.parametrize_values + parametrize_needed[group].append((name, values)) + + for parametrize_group in parametrize_needed.values(): + if len(parametrize_group) == 1: + name, values = parametrize_group[0] + metafunc.parametrize(name, values, indirect=True) + else: + names = ",".join(name for name, values in parametrize_group) + value_sets = zip(*[values for name, values in parametrize_group]) + metafunc.parametrize(names, value_sets, indirect=True) diff --git a/python/tvm/testing/popen_pool.py b/python/tvm/testing/popen_pool.py new file mode 100644 index 000000000000..20345a2218fe --- /dev/null +++ b/python/tvm/testing/popen_pool.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, missing-function-docstring +"""Common functions for popen_pool test cases""" +import tvm + +TEST_GLOBAL_STATE_1 = 0 +TEST_GLOBAL_STATE_2 = 0 +TEST_GLOBAL_STATE_3 = 0 + + +def initializer(test_global_state_1, test_global_state_2, test_global_state_3): + global TEST_GLOBAL_STATE_1, TEST_GLOBAL_STATE_2, TEST_GLOBAL_STATE_3 + TEST_GLOBAL_STATE_1 = test_global_state_1 + TEST_GLOBAL_STATE_2 = test_global_state_2 + TEST_GLOBAL_STATE_3 = test_global_state_3 + + +def after_initializer(): + global TEST_GLOBAL_STATE_1, TEST_GLOBAL_STATE_2, TEST_GLOBAL_STATE_3 + return TEST_GLOBAL_STATE_1, TEST_GLOBAL_STATE_2, TEST_GLOBAL_STATE_3 + + +@tvm._ffi.register_func("testing.identity_py") +def identity_py(arg): + return arg + + +def register_ffi(): + @tvm._ffi.register_func("testing.nested_identity_py") + def _identity_py(arg): # pylint: disable=unused-variable + return arg + + +def call_py_ffi(arg): + _identity_py = tvm._ffi.get_global_func("testing.nested_identity_py") + return _identity_py(arg) + + +def call_cpp_ffi(arg): + return tvm.testing.echo(arg) + + +def call_cpp_py_ffi(arg): + return tvm.testing.identity_cpp(arg) diff --git a/python/tvm/testing.py b/python/tvm/testing/utils.py similarity index 83% rename from python/tvm/testing.py rename to python/tvm/testing/utils.py index 4721c0050656..6f115f8da58c 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing/utils.py @@ -16,7 +16,14 @@ # under the License. # pylint: disable=invalid-name,unnecessary-comprehension -""" TVM testing utilities +"""TVM testing utilities + +Organization +************ + +This file contains functions expected to be called directly by a user +while writing unit tests. Integrations with the pytest framework +are in plugin.py. Testing Markers *************** @@ -53,9 +60,11 @@ def test_something(): fpgas), we need to add a new marker in `tests/python/pytest.ini` and a new function in this module. Then targets using this node should be added to the `TVM_TEST_TARGETS` environment variable in the CI. + """ -import collections import copy +import copyreg +import ctypes import functools import logging import os @@ -63,14 +72,14 @@ def test_something(): import time import pickle import pytest -import _pytest import numpy as np import tvm import tvm.arith import tvm.tir import tvm.te import tvm._ffi -from tvm.contrib import nvcc + +from tvm.contrib import nvcc, cudnn from tvm.error import TVMError @@ -78,7 +87,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): """Version of np.testing.assert_allclose with `atol` and `rtol` fields set in reasonable defaults. - Arguments `actual` and `desired` are not interchangable, since the function + Arguments `actual` and `desired` are not interchangeable, since the function compares the `abs(actual-desired)` with `atol+rtol*abs(desired)`. Since we often allow `desired` to be close to zero, we generally want non-zero `atol`. """ @@ -375,17 +384,24 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): def _get_targets(target_str=None): if target_str is None: target_str = os.environ.get("TVM_TEST_TARGETS", "") + # Use dict instead of set for de-duplication so that the + # targets stay in the order specified. + target_names = list({t.strip(): None for t in target_str.split(";") if t.strip()}) - if len(target_str) == 0: - target_str = DEFAULT_TEST_TARGETS - - target_names = set(t.strip() for t in target_str.split(";") if t.strip()) + if not target_names: + target_names = DEFAULT_TEST_TARGETS targets = [] for target in target_names: target_kind = target.split()[0] - is_enabled = tvm.runtime.enabled(target_kind) - is_runnable = is_enabled and tvm.device(target_kind).exist + + if target_kind == "cuda" and "cudnn" in tvm.target.Target(target).attrs.get("libs", []): + is_enabled = tvm.support.libinfo()["USE_CUDNN"].lower() in ["on", "true", "1"] + is_runnable = is_enabled and cudnn.exists() + else: + is_enabled = tvm.runtime.enabled(target_kind) + is_runnable = is_enabled and tvm.device(target_kind).exist + targets.append( { "target": target, @@ -413,10 +429,19 @@ def _get_targets(target_str=None): return targets -DEFAULT_TEST_TARGETS = ( - "llvm;cuda;opencl;metal;rocm;vulkan -from_device=0;nvptx;" - "llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu" -) +DEFAULT_TEST_TARGETS = [ + "llvm", + "llvm -device=arm_cpu", + "cuda", + "cuda -model=unknown -libs=cudnn", + "nvptx", + "vulkan -from_device=0", + "opencl", + "opencl -device=mali,aocl_sw_emu", + "opencl -device=intel_graphics", + "metal", + "rocm", +] def device_enabled(target): @@ -548,6 +573,26 @@ def requires_cuda(*args): return _compose(args, _requires_cuda) +def requires_cudnn(*args): + """Mark a test as requiring the cuDNN library. + + This also marks the test as requiring a cuda gpu. + + Parameters + ---------- + f : function + Function to mark + """ + + requirements = [ + pytest.mark.skipif( + not cudnn.exists(), reason="cuDNN library not enabled, or not installed" + ), + *requires_cuda(), + ] + return _compose(args, requirements) + + def requires_nvptx(*args): """Mark a test as requiring the NVPTX compilation on the CUDA runtime @@ -729,90 +774,6 @@ def requires_rpc(*args): return _compose(args, _requires_rpc) -def _target_to_requirement(target): - # mapping from target to decorator - if target.startswith("cuda"): - return requires_cuda() - if target.startswith("rocm"): - return requires_rocm() - if target.startswith("vulkan"): - return requires_vulkan() - if target.startswith("nvptx"): - return requires_nvptx() - if target.startswith("metal"): - return requires_metal() - if target.startswith("opencl"): - return requires_opencl() - if target.startswith("llvm"): - return requires_llvm() - return [] - - -def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None): - # Include unrunnable targets here. They get skipped by the - # pytest.mark.skipif in _target_to_requirement(), showing up as - # skipped tests instead of being hidden entirely. - if targets is None: - if excluded_targets is None: - excluded_targets = set() - - if xfail_targets is None: - xfail_targets = set() - - target_marks = [] - for t in _get_targets(): - # Excluded targets aren't included in the params at all. - if t["target_kind"] not in excluded_targets: - - # Known failing targets are included, but are marked - # as expected to fail. - extra_marks = [] - if t["target_kind"] in xfail_targets: - extra_marks.append( - pytest.mark.xfail( - reason='Known failing test for target "{}"'.format(t["target_kind"]) - ) - ) - - target_marks.append((t["target"], extra_marks)) - - else: - target_marks = [(target, []) for target in targets] - - return [ - pytest.param(target, marks=_target_to_requirement(target) + extra_marks) - for target, extra_marks in target_marks - ] - - -def _auto_parametrize_target(metafunc): - """Automatically applies parametrize_targets - - Used if a test function uses the "target" fixture, but isn't - already marked with @tvm.testing.parametrize_targets. Intended - for use in the pytest_generate_tests() handler of a conftest.py - file. - - """ - if "target" in metafunc.fixturenames: - parametrized_args = [ - arg.strip() - for mark in metafunc.definition.iter_markers("parametrize") - for arg in mark.args[0].split(",") - ] - - if "target" not in parametrized_args: - # Check if the function is marked with either excluded or - # known failing targets. - excluded_targets = getattr(metafunc.function, "tvm_excluded_targets", []) - xfail_targets = getattr(metafunc.function, "tvm_known_failing_targets", []) - metafunc.parametrize( - "target", - _pytest_target_params(None, excluded_targets, xfail_targets), - scope="session", - ) - - def parametrize_targets(*args): """Parametrize a test over a specific set of targets. @@ -849,17 +810,14 @@ def parametrize_targets(*args): >>> ... # do something """ - def wrap(targets): - def func(f): - return pytest.mark.parametrize( - "target", _pytest_target_params(targets), scope="session" - )(f) - - return func - + # Backwards compatibility, when used as a decorator with no + # arguments implicitly parametrizes over "target". The + # parametrization is now handled by _auto_parametrize_target, so + # this use case can just return the decorated function. if len(args) == 1 and callable(args[0]): - return wrap(None)(args[0]) - return wrap(args) + return args[0] + + return pytest.mark.parametrize("target", list(args), scope="session") def exclude_targets(*args): @@ -1065,28 +1023,6 @@ def fixture_func(*_cls, request): return outputs -def _parametrize_correlated_parameters(metafunc): - parametrize_needed = collections.defaultdict(list) - - for name, fixturedefs in metafunc.definition._fixtureinfo.name2fixturedefs.items(): - fixturedef = fixturedefs[-1] - if hasattr(fixturedef.func, "parametrize_group") and hasattr( - fixturedef.func, "parametrize_values" - ): - group = fixturedef.func.parametrize_group - values = fixturedef.func.parametrize_values - parametrize_needed[group].append((name, values)) - - for parametrize_group in parametrize_needed.values(): - if len(parametrize_group) == 1: - name, values = parametrize_group[0] - metafunc.parametrize(name, values, indirect=True) - else: - names = ",".join(name for name, values in parametrize_group) - value_sets = zip(*[values for name, values in parametrize_group]) - metafunc.parametrize(names, value_sets, indirect=True) - - def fixture(func=None, *, cache_return_value=False): """Convenience function to define pytest fixtures. @@ -1160,13 +1096,69 @@ def wraps(func): return wraps(func) +class _DeepCopyAllowedClasses(dict): + def __init__(self, allowed_class_list): + self.allowed_class_list = allowed_class_list + super().__init__() + + def get(self, key, *args, **kwargs): + """Overrides behavior of copy.deepcopy to avoid implicit copy. + + By default, copy.deepcopy uses a dict of id->object to track + all objects that it has seen, which is passed as the second + argument to all recursive calls. This class is intended to be + passed in instead, and inspects the type of all objects being + copied. + + Where copy.deepcopy does a best-effort attempt at copying an + object, for unit tests we would rather have all objects either + be copied correctly, or to throw an error. Classes that + define an explicit method to perform a copy are allowed, as + are any explicitly listed classes. Classes that would fall + back to using object.__reduce__, and are not explicitly listed + as safe, will throw an exception. + + """ + obj = ctypes.cast(key, ctypes.py_object).value + cls = type(obj) + if ( + cls in copy._deepcopy_dispatch + or issubclass(cls, type) + or getattr(obj, "__deepcopy__", None) + or copyreg.dispatch_table.get(cls) + or cls.__reduce__ is not object.__reduce__ + or cls.__reduce_ex__ is not object.__reduce_ex__ + or cls in self.allowed_class_list + ): + return super().get(key, *args, **kwargs) + + rfc_url = ( + "https://github.com/apache/tvm-rfcs/blob/main/rfcs/0007-parametrized-unit-tests.md" + ) + raise TypeError( + ( + f"Cannot copy fixture of type {cls.__name__}. TVM fixture caching " + "is limited to objects that explicitly provide the ability " + "to be copied (e.g. through __deepcopy__, __getstate__, or __setstate__)," + "and forbids the use of the default `object.__reduce__` and " + "`object.__reduce_ex__`. For third-party classes that are " + "safe to use with copy.deepcopy, please add the class to " + "the arguments of _DeepCopyAllowedClasses in tvm.testing._fixture_cache.\n" + "\n" + f"For discussion on this restriction, please see {rfc_url}." + ) + ) + + def _fixture_cache(func): cache = {} # Can't use += on a bound method's property. Therefore, this is a # list rather than a variable so that it can be accessed from the # pytest_collection_modifyitems(). - num_uses_remaining = [0] + num_tests_use_this_fixture = [0] + + num_times_fixture_used = 0 # Using functools.lru_cache would require the function arguments # to be hashable, which wouldn't allow caching fixtures that @@ -1191,6 +1183,14 @@ def get_cache_key(*args, **kwargs): @functools.wraps(func) def wrapper(*args, **kwargs): + if num_tests_use_this_fixture[0] == 0: + raise RuntimeError( + "Fixture use count is 0. " + "This can occur if tvm.testing.plugin isn't registered. " + "If using outside of the TVM test directory, " + "please add `pytest_plugins = ['tvm.testing.plugin']` to your conftest.py" + ) + try: cache_key = get_cache_key(*args, **kwargs) @@ -1199,68 +1199,29 @@ def wrapper(*args, **kwargs): except KeyError: cached_value = cache[cache_key] = func(*args, **kwargs) - try: - yield copy.deepcopy(cached_value) - except TypeError as e: - rfc_url = ( - "https://github.com/apache/tvm-rfcs/blob/main/rfcs/" - "0007-parametrized-unit-tests.md#unresolved-questions" - ) - message = ( - "TVM caching of fixtures can only be used on serializable data types, not {}.\n" - "Please see {} for details/discussion." - ).format(type(cached_value), rfc_url) - raise TypeError(message) from e + yield copy.deepcopy( + cached_value, + # allowed_class_list should be a list of classes that + # are safe to copy using copy.deepcopy, but do not + # implement __deepcopy__, __reduce__, or + # __reduce_ex__. + _DeepCopyAllowedClasses(allowed_class_list=[]), + ) finally: # Clear the cache once all tests that use a particular fixture # have completed. - num_uses_remaining[0] -= 1 - if not num_uses_remaining[0]: + nonlocal num_times_fixture_used + num_times_fixture_used += 1 + if num_times_fixture_used >= num_tests_use_this_fixture[0]: cache.clear() - # Set in the pytest_collection_modifyitems() - wrapper.num_uses_remaining = num_uses_remaining + # Set in the pytest_collection_modifyitems(), by _count_num_fixture_uses + wrapper.num_tests_use_this_fixture = num_tests_use_this_fixture return wrapper -def _count_num_fixture_uses(items): - # Helper function, counts the number of tests that use each cached - # fixture. Should be called from pytest_collection_modifyitems(). - for item in items: - is_skipped = item.get_closest_marker("skip") or any( - mark.args[0] for mark in item.iter_markers("skipif") - ) - if is_skipped: - continue - - for fixturedefs in item._fixtureinfo.name2fixturedefs.values(): - # Only increment the active fixturedef, in a name has been overridden. - fixturedef = fixturedefs[-1] - if hasattr(fixturedef.func, "num_uses_remaining"): - fixturedef.func.num_uses_remaining[0] += 1 - - -def _remove_global_fixture_definitions(items): - # Helper function, removes fixture definitions from the global - # variables of the modules they were defined in. This is intended - # to improve readability of error messages by giving a NameError - # if a test function accesses a pytest fixture but doesn't include - # it as an argument. Should be called from - # pytest_collection_modifyitems(). - - modules = set(item.module for item in items) - - for module in modules: - for name in dir(module): - obj = getattr(module, name) - if hasattr(obj, "_pytestfixturefunction") and isinstance( - obj._pytestfixturefunction, _pytest.fixtures.FixtureFunctionMarker - ): - delattr(module, name) - - def identity_after(x, sleep): """Testing function to return identity after sleep @@ -1285,6 +1246,3 @@ def identity_after(x, sleep): def terminate_self(): """Testing function to terminate the process.""" sys.exit(-1) - - -tvm._ffi._init_api("testing", __name__) diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 030918a5e18e..500195ac9a13 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -16,13 +16,17 @@ # under the License. """Wrapping existing analysis utils.""" # pylint: disable=invalid-name -from typing import Dict +from typing import Dict, List + +from tvm.tir.stmt import Block, BufferRegion +from tvm.tir.stmt import PrimExpr +from tvm.tir.expr import Var from . import _ffi_api from ..function import PrimFunc from .. import Buffer, Stmt -def expr_deep_equal(lhs, rhs): +def expr_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool: """Deeply compare two nested expressions. Parameters @@ -56,10 +60,10 @@ def expr_deep_equal(lhs, rhs): -------- tvm.ir.structural_equal """ - return _ffi_api.expr_deep_equal(lhs, rhs) + return _ffi_api.expr_deep_equal(lhs, rhs) # type: ignore -def verify_ssa(func): +def verify_ssa(func: PrimFunc) -> bool: """Verify if the func is in SSA form. Parameters @@ -72,10 +76,10 @@ def verify_ssa(func): result : bool The result of verification. """ - return _ffi_api.verify_ssa(func) + return _ffi_api.verify_ssa(func) # type: ignore -def verify_memory(func): +def verify_memory(func: PrimFunc) -> bool: """Verify if func contains illegal host side direct memory access. Parameters @@ -88,10 +92,10 @@ def verify_memory(func): result : bool The result of verification. """ - return _ffi_api.verify_memory(func) + return _ffi_api.verify_memory(func) # type: ignore -def verify_gpu_code(func, constraints): +def verify_gpu_code(func: PrimFunc, constraints: Dict[str, int]) -> None: """Verify if module contains illegal host side direct memory access. Parameters @@ -107,10 +111,12 @@ def verify_gpu_code(func, constraints): result : bool The result of verification. """ - return _ffi_api.verify_gpu_code(func, constraints) + return _ffi_api.verify_gpu_code(func, constraints) # type: ignore -def get_block_access_region(block, buffer_var_map): +def get_block_access_region( + block: Block, buffer_var_map: Dict[Var, Buffer] +) -> List[List[BufferRegion]]: """Detect which regions of tensors in this block are read or written to. Regions are sorted by order of appearance in the AST. @@ -130,10 +136,10 @@ def get_block_access_region(block, buffer_var_map): - second: write regions - third: opaque regions """ - return _ffi_api.get_block_access_region(block, buffer_var_map) + return _ffi_api.get_block_access_region(block, buffer_var_map) # type: ignore -def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int): +def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int) -> int: """Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc. @@ -149,7 +155,7 @@ def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int): result : int Workspace size in bytes. """ - return _ffi_api.calculate_workspace_bytes(func, workspace_byte_alignment) + return _ffi_api.calculate_workspace_bytes(func, workspace_byte_alignment) # type: ignore def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]: @@ -167,4 +173,4 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]: result : Dict[Buffer, Stmt] Map from buffer to the LCA of all access to it. """ - return _ffi_api.detect_buffer_access_lca(func) # pylint: disable=no-member + return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index d905a53b3303..6dddd7b119a0 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -90,7 +90,9 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0): raise ValueError("Unknown access_mask %s" % access_mask) access_mask = mask offset = convert(offset) - return _ffi_api.BufferAccessPtr(self, access_mask, ptr_type, content_lanes, offset) + return _ffi_api.BufferAccessPtr( + self, access_mask, ptr_type, content_lanes, offset # type: ignore + ) def vload(self, begin, dtype=None): """Generate an Expr that loads dtype from begin index. @@ -111,7 +113,7 @@ def vload(self, begin, dtype=None): """ begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin dtype = dtype if dtype else self.dtype - return _ffi_api.BufferVLoad(self, begin, dtype) + return _ffi_api.BufferVLoad(self, begin, dtype) # type: ignore def vstore(self, begin, value): """Generate a Stmt that store value into begin index. @@ -130,7 +132,16 @@ def vstore(self, begin, value): The corresponding store stmt. """ begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin - return _ffi_api.BufferVStore(self, begin, value) + return _ffi_api.BufferVStore(self, begin, value) # type: ignore + + def scope(self): + """Return the storage scope associated with this buffer. + Returns + ------- + scope : str + The storage scope associated with this buffer. + """ + return _ffi_api.BufferStorageScope(self) # type: ignore def decl_buffer( @@ -244,21 +255,20 @@ def decl_buffer( dtype = "float32" if dtype is None else dtype strides = () if strides is None else strides if offset_factor != 0 and elem_offset is None: - shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" + shape_dtype = shape[0].dtype if shape and hasattr(shape[0], "dtype") else "int32" elem_offset = Var("%s_elem_offset" % name, shape_dtype) if data is None: # Bool is represented as uint1 in the IR, but stored as int8 storage_type = PrimType(dtype) storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type - data = Var(name, PointerType(storage_type), span) - return _ffi_api.Buffer( + data = Var(name, PointerType(storage_type, scope), span) + return _ffi_api.Buffer( # type: ignore data, dtype, shape, strides, elem_offset, name, - scope, data_alignment, offset_factor, buffer_type, diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py index 40805d9beb8e..f46a154612e1 100644 --- a/python/tvm/tir/data_layout.py +++ b/python/tvm/tir/data_layout.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Data layout.""" +from typing import Union + import tvm._ffi from tvm.runtime import Object @@ -36,7 +38,7 @@ class Layout(Object): """ def __len__(self): - return _ffi_api.LayoutNdim(self) + return _ffi_api.LayoutNdim(self) # type: ignore def __contains__(self, axis): return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name @@ -44,7 +46,7 @@ def __contains__(self, axis): def __getitem__(self, index): if index >= len(self): raise IndexError("Layout index out of range") - return _ffi_api.LayoutGetItem(self, index) + return _ffi_api.LayoutGetItem(self, index) # type: ignore def index_of(self, axis): """Get the index of an axis @@ -59,7 +61,7 @@ def index_of(self, axis): index : int The index of the axis, -1 if not found. """ - return _ffi_api.LayoutIndexOf(self, axis) + return _ffi_api.LayoutIndexOf(self, axis) # type: ignore def factor_of(self, axis): """Get the factor size of the subordinate axis. @@ -76,7 +78,7 @@ def factor_of(self, axis): or the size of axis itself (if axis is a subordinate-axis). Return -1 if axis is not in the layout. """ - return _ffi_api.LayoutFactorOf(self, axis) + return _ffi_api.LayoutFactorOf(self, axis) # type: ignore @tvm._ffi.register_object("tir.BijectiveLayout") @@ -113,7 +115,7 @@ def forward_index(self, index): dst_index: Array of Expr The inferred indices in dst-layout. """ - return _ffi_api.BijectiveLayoutForwardIndex(self, index) + return _ffi_api.BijectiveLayoutForwardIndex(self, index) # type: ignore def backward_index(self, index): """Given the indices of the dst-layout, infer the src index. @@ -128,7 +130,7 @@ def backward_index(self, index): src_index: Array of Expr The inferred indices in src-layout. """ - return _ffi_api.BijectiveLayoutBackwardIndex(self, index) + return _ffi_api.BijectiveLayoutBackwardIndex(self, index) # type: ignore def forward_shape(self, shape): """Given the shape of the src-layout, infer the dst shape. @@ -143,7 +145,7 @@ def forward_shape(self, shape): dst_shape: Array of Expr The inferred shape in dst-layout. """ - return _ffi_api.BijectiveLayoutForwardShape(self, shape) + return _ffi_api.BijectiveLayoutForwardShape(self, shape) # type: ignore def backward_shape(self, shape): """Given the shape of the dst-layout, infer the src shape. @@ -158,10 +160,10 @@ def backward_shape(self, shape): src_shape: Array of Expr The inferred shape in src-layout. """ - return _ffi_api.BijectiveLayoutBackwardShape(self, shape) + return _ffi_api.BijectiveLayoutBackwardShape(self, shape) # type: ignore -def layout(layout_str): +def layout(layout_str: str) -> Layout: """Create a layout node from a string. Parameters @@ -180,10 +182,12 @@ def layout(layout_str): layout : Layout The created layout """ - return _ffi_api.Layout(layout_str) + return _ffi_api.Layout(layout_str) # type: ignore -def bijective_layout(src_layout, dst_layout): +def bijective_layout( + src_layout: Union[str, Layout], dst_layout: Union[str, Layout] +) -> BijectiveLayout: """Create a bijective layout mapping. Parameters @@ -203,4 +207,4 @@ def bijective_layout(src_layout, dst_layout): src_layout = layout(src_layout) if isinstance(dst_layout, str): dst_layout = layout(dst_layout) - return _ffi_api.BijectiveLayout(src_layout, dst_layout) + return _ffi_api.BijectiveLayout(src_layout, dst_layout) # type: ignore diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 286e4051da51..4ba8c5471b5d 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -27,7 +27,10 @@ assert(isinstance(y, tvm.tir.Add)) assert(y.a == x) """ +from typing import Optional, Union +from tvm import ir import tvm._ffi +from tvm.ir.base import Span from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const from tvm.ir import PrimExpr, Op @@ -47,13 +50,17 @@ def div_ambiguity_error(): def _dtype_is_int(value): if isinstance(value, int): return True - return isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.INT + return ( + isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.INT + ) # type: ignore def _dtype_is_float(value): if isinstance(value, float): return True - return isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.FLOAT + return ( + isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.FLOAT + ) # type: ignore class ExprOp(object): @@ -106,55 +113,55 @@ def __rfloordiv__(self, other): return _generic.floordiv(other, self, None) def __mod__(self, other): - return _ffi_api._OpFloorMod(self, other, None) + return _ffi_api._OpFloorMod(self, other, None) # type: ignore def __rmod__(self, other): - return _ffi_api._OpFloorMod(other, self, None) + return _ffi_api._OpFloorMod(other, self, None) # type: ignore def __neg__(self): - neg_one = const(-1, self.dtype) + neg_one = const(-1, self.dtype) # type: ignore return self.__mul__(neg_one) def __lshift__(self, other): - return _ffi_api.left_shift(self, other, None) + return _ffi_api.left_shift(self, other, None) # type: ignore def __rlshift__(self, other): - return _ffi_api.left_shift(other, self, None) + return _ffi_api.left_shift(other, self, None) # type: ignore def __rshift__(self, other): - return _ffi_api.right_shift(self, other, None) + return _ffi_api.right_shift(self, other, None) # type: ignore def __rrshift__(self, other): - return _ffi_api.right_shift(other, self, None) + return _ffi_api.right_shift(other, self, None) # type: ignore def __and__(self, other): - return _ffi_api.bitwise_and(self, other, None) + return _ffi_api.bitwise_and(self, other, None) # type: ignore def __rand__(self, other): - return _ffi_api.bitwise_and(other, self, None) + return _ffi_api.bitwise_and(other, self, None) # type: ignore def __or__(self, other): - return _ffi_api.bitwise_or(self, other, None) + return _ffi_api.bitwise_or(self, other, None) # type: ignore def __ror__(self, other): - return _ffi_api.bitwise_or(other, self, None) + return _ffi_api.bitwise_or(other, self, None) # type: ignore def __xor__(self, other): - return _ffi_api.bitwise_xor(self, other, None) + return _ffi_api.bitwise_xor(self, other, None) # type: ignore def __rxor__(self, other): - return _ffi_api.bitwise_xor(other, self, None) + return _ffi_api.bitwise_xor(other, self, None) # type: ignore def __invert__(self): if _dtype_is_float(self): raise RuntimeError("Cannot use ~ operator on float type Expr.") - return _ffi_api.bitwise_not(self, None) + return _ffi_api.bitwise_not(self, None) # type: ignore def __lt__(self, other): - return _ffi_api._OpLT(self, other, None) + return _ffi_api._OpLT(self, other, None) # type: ignore def __le__(self, other): - return _ffi_api._OpLE(self, other, None) + return _ffi_api._OpLE(self, other, None) # type: ignore def __eq__(self, other): return EqualOp(self, other) @@ -163,10 +170,10 @@ def __ne__(self, other): return NotEqualOp(self, other) def __gt__(self, other): - return _ffi_api._OpGT(self, other, None) + return _ffi_api._OpGT(self, other, None) # type: ignore def __ge__(self, other): - return _ffi_api._OpGE(self, other, None) + return _ffi_api._OpGE(self, other, None) # type: ignore def __nonzero__(self): raise ValueError( @@ -193,9 +200,9 @@ def equal(self, other, span=None): ret : PrimExpr The equality expression. """ - return _ffi_api._OpEQ(self, other, span) + return _ffi_api._OpEQ(self, other, span) # type: ignore - def astype(self, dtype, span=None): + def astype(self, dtype: str, span: Optional[Span] = None): """Cast the expression to other type. Parameters @@ -248,7 +255,7 @@ def __bool__(self): def asobject(self): """Convert object.""" - return _ffi_api._OpEQ(self.a, self.b, self.span) + return _ffi_api._OpEQ(self.a, self.b, self.span) # type: ignore class NotEqualOp(ObjectGeneric, ExprOp): @@ -285,7 +292,7 @@ def __bool__(self): def asobject(self): """Convert object.""" - return _ffi_api._OpNE(self.a, self.b, self.span) + return _ffi_api._OpNE(self.a, self.b, self.span) # type: ignore class IntImmEnum(ObjectGeneric): @@ -307,7 +314,7 @@ def __init__(self, value, span=None): def asobject(self): """Convert object.""" - return IntImm("int32", self.value, self.span) + return IntImm("int32", self.value, self.span) # type: ignore class PrimExprWithOp(ExprOp, PrimExpr): @@ -350,8 +357,8 @@ class Var(PrimExprWithOp): The location of this itervar in the source code. """ - def __init__(self, name, dtype, span=None): - self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype, span) + def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = None): + self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype, span) # type: ignore @tvm._ffi.register_object("tir.SizeVar") @@ -373,7 +380,7 @@ class SizeVar(Var): # pylint: disable=super-init-not-called def __init__(self, name, dtype, span=None): - self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, span) + self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, span) # type: ignore @tvm._ffi.register_object("tir.IterVar") @@ -428,7 +435,9 @@ def __init__(self, dom, var, iter_type, thread_tag="", span=None): name = var if var is not None else "iter" dtype = "int32" if dom is None else dom.extent.dtype var = Var(name, dtype=dtype, span=span) if not isinstance(var, Var) else var - self.__init_handle_by_constructor__(_ffi_api.IterVar, dom, var, iter_type, thread_tag, span) + self.__init_handle_by_constructor__( + _ffi_api.IterVar, dom, var, iter_type, thread_tag, span # type: ignore + ) @tvm._ffi.register_object("tir.CommReducer") @@ -455,7 +464,7 @@ class CommReducer(Object): def __init__(self, lhs, rhs, result, identity_element, span=None): self.__init_handle_by_constructor__( - _ffi_api.CommReducer, lhs, rhs, result, identity_element, span + _ffi_api.CommReducer, lhs, rhs, result, identity_element, span # type: ignore ) @@ -489,7 +498,7 @@ class Reduce(PrimExprWithOp): def __init__(self, combiner, src, rdom, condition, value_index, init=None, span=None): self.__init_handle_by_constructor__( - _ffi_api.Reduce, combiner, src, rdom, condition, value_index, init, span + _ffi_api.Reduce, combiner, src, rdom, condition, value_index, init, span # type: ignore ) @@ -510,7 +519,9 @@ class FloatImm(ConstExpr): """ def __init__(self, dtype, value, span=None): - self.__init_handle_by_constructor__(tvm.ir._ffi_api.FloatImm, dtype, value, span) + self.__init_handle_by_constructor__( + tvm.ir._ffi_api.FloatImm, dtype, value, span # type: ignore + ) @tvm._ffi.register_object @@ -530,7 +541,9 @@ class IntImm(ConstExpr): """ def __init__(self, dtype, value, span=None): - self.__init_handle_by_constructor__(tvm.ir._ffi_api.IntImm, dtype, value, span) + self.__init_handle_by_constructor__( + tvm.ir._ffi_api.IntImm, dtype, value, span # type: ignore + ) def __hash__(self): return self.value @@ -542,16 +555,16 @@ def __nonzero__(self): return self.value != 0 def __eq__(self, other): - return _ffi_api._OpEQ(self, other, None) + return _ffi_api._OpEQ(self, other, None) # type: ignore def __ne__(self, other): - return _ffi_api._OpNE(self, other, None) + return _ffi_api._OpNE(self, other, None) # type: ignore def __bool__(self): return self.__nonzero__() -@tvm._ffi.register_object("tir.StringImm") +@tvm._ffi.register_object("tir.StringImm") # type: ignore class StringImm(ConstExpr): """String constant. @@ -565,7 +578,7 @@ class StringImm(ConstExpr): """ def __init__(self, value, span=None): - self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span) + self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span) # type: ignore def __eq__(self, other): if isinstance(other, ConstExpr): @@ -577,6 +590,9 @@ def __ne__(self, other): return self.value != other.value return self.value != other + def __hash__(self): + return PrimExpr.__hash__(self) + @tvm._ffi.register_object("tir.Cast") class Cast(PrimExprWithOp): @@ -595,7 +611,7 @@ class Cast(PrimExprWithOp): """ def __init__(self, dtype, value, span=None): - self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) + self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) # type: ignore @tvm._ffi.register_object("tir.Add") @@ -615,7 +631,7 @@ class Add(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Sub") @@ -635,7 +651,7 @@ class Sub(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Mul") @@ -655,7 +671,7 @@ class Mul(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Div") @@ -675,7 +691,7 @@ class Div(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Mod") @@ -695,7 +711,7 @@ class Mod(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span) # type: ignore @tvm._ffi.register_object("tir.FloorDiv") @@ -715,7 +731,7 @@ class FloorDiv(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) # type: ignore @tvm._ffi.register_object("tir.FloorMod") @@ -735,7 +751,7 @@ class FloorMod(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Min") @@ -755,7 +771,7 @@ class Min(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Max") @@ -775,7 +791,7 @@ class Max(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span) # type: ignore @tvm._ffi.register_object("tir.EQ") @@ -795,7 +811,7 @@ class EQ(CmpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span) # type: ignore @tvm._ffi.register_object("tir.NE") @@ -815,7 +831,7 @@ class NE(CmpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span) # type: ignore @tvm._ffi.register_object("tir.LT") @@ -835,7 +851,7 @@ class LT(CmpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span) # type: ignore @tvm._ffi.register_object("tir.LE") @@ -855,7 +871,7 @@ class LE(CmpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) # type: ignore @tvm._ffi.register_object("tir.GT") @@ -875,7 +891,7 @@ class GT(CmpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span) # type: ignore @tvm._ffi.register_object("tir.GE") @@ -895,7 +911,7 @@ class GE(CmpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span) # type: ignore @tvm._ffi.register_object("tir.And") @@ -915,7 +931,7 @@ class And(LogicalExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.And, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.And, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Or") @@ -935,7 +951,7 @@ class Or(LogicalExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Or, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Or, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Not") @@ -952,7 +968,7 @@ class Not(LogicalExpr): """ def __init__(self, a, span=None): - self.__init_handle_by_constructor__(_ffi_api.Not, a, span) + self.__init_handle_by_constructor__(_ffi_api.Not, a, span) # type: ignore @tvm._ffi.register_object("tir.Select") @@ -983,7 +999,7 @@ class Select(PrimExprWithOp): def __init__(self, condition, true_value, false_value, span=None): self.__init_handle_by_constructor__( - _ffi_api.Select, condition, true_value, false_value, span + _ffi_api.Select, condition, true_value, false_value, span # type: ignore ) @@ -1011,9 +1027,9 @@ class Load(PrimExprWithOp): def __init__(self, dtype, buffer_var, index, predicate=None, span=None): if predicate is None: - predicate = _ffi_api.const_true(dtype, span) + predicate = _ffi_api.const_true(dtype, span) # type: ignore self.__init_handle_by_constructor__( - _ffi_api.Load, dtype, buffer_var, index, predicate, span + _ffi_api.Load, dtype, buffer_var, index, predicate, span # type: ignore ) @@ -1034,7 +1050,9 @@ class BufferLoad(PrimExprWithOp): """ def __init__(self, buffer, indices, span=None): - self.__init_handle_by_constructor__(_ffi_api.BufferLoad, buffer, indices, span) + self.__init_handle_by_constructor__( + _ffi_api.BufferLoad, buffer, indices, span # type: ignore + ) @tvm._ffi.register_object("tir.ProducerLoad") @@ -1054,7 +1072,9 @@ class ProducerLoad(PrimExprWithOp): """ def __init__(self, producer, indices, span=None): - self.__init_handle_by_constructor__(_ffi_api.ProducerLoad, producer, indices, span) + self.__init_handle_by_constructor__( + _ffi_api.ProducerLoad, producer, indices, span # type: ignore + ) @tvm._ffi.register_object("tir.Ramp") @@ -1077,7 +1097,9 @@ class Ramp(PrimExprWithOp): """ def __init__(self, base, stride, lanes, span=None): - self.__init_handle_by_constructor__(_ffi_api.Ramp, base, stride, lanes, span) + self.__init_handle_by_constructor__( + _ffi_api.Ramp, base, stride, lanes, span # type: ignore + ) @tvm._ffi.register_object("tir.Broadcast") @@ -1097,7 +1119,7 @@ class Broadcast(PrimExprWithOp): """ def __init__(self, value, lanes, span=None): - self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes, span) + self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes, span) # type: ignore @tvm._ffi.register_object("tir.Shuffle") @@ -1117,7 +1139,9 @@ class Shuffle(PrimExprWithOp): """ def __init__(self, vectors, indices, span=None): - self.__init_handle_by_constructor__(_ffi_api.Shuffle, vectors, indices, span) + self.__init_handle_by_constructor__( + _ffi_api.Shuffle, vectors, indices, span # type: ignore + ) class CallEffectKind: @@ -1163,7 +1187,7 @@ def __init__(self, dtype, op, args, span=None): % op ) op = Op.get(op) - self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, span) + self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, span) # type: ignore @tvm._ffi.register_object("tir.Let") @@ -1186,11 +1210,11 @@ class Let(PrimExprWithOp): """ def __init__(self, var, value, body, span=None): - self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body, span) + self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body, span) # type: ignore @tvm._ffi.register_object("tir.Any") -class Any(PrimExpr): +class Any(PrimExprWithOp): """Any node. span : Optional[Span] @@ -1198,4 +1222,4 @@ class Any(PrimExpr): """ def __init__(self, span=None): - self.__init_handle_by_constructor__(_ffi_api.Any, span) + self.__init_handle_by_constructor__(_ffi_api.Any, span) # type: ignore diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index b1081d436150..68d967aa497d 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -67,7 +67,7 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa raise TypeError("params can only contain Var or Buffer") self.__init_handle_by_constructor__( - _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span + _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span # type: ignore ) def with_body(self, new_body, span=None): @@ -137,4 +137,4 @@ def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None: func : PrimFunc The new function with parameter specialized """ - return _ffi_api.Specialize(self, param_map) + return _ffi_api.Specialize(self, param_map) # type: ignore diff --git a/python/tvm/tir/generic.py b/python/tvm/tir/generic.py index 1194eaa6b462..58efc0985970 100644 --- a/python/tvm/tir/generic.py +++ b/python/tvm/tir/generic.py @@ -43,7 +43,7 @@ def add(lhs, rhs, span=None): op : tvm.Expr The result Expr of add operaton. """ - return _ffi_api._OpAdd(lhs, rhs, span) + return _ffi_api._OpAdd(lhs, rhs, span) # type: ignore def subtract(lhs, rhs, span=None): @@ -63,7 +63,7 @@ def subtract(lhs, rhs, span=None): op : tvm.Expr The result Expr of subtract operaton. """ - return _ffi_api._OpSub(lhs, rhs, span) + return _ffi_api._OpSub(lhs, rhs, span) # type: ignore def multiply(lhs, rhs, span=None): @@ -83,7 +83,7 @@ def multiply(lhs, rhs, span=None): op : tvm.Expr The result Expr of multiply operaton. """ - return _ffi_api._OpMul(lhs, rhs, span) + return _ffi_api._OpMul(lhs, rhs, span) # type: ignore def divide(lhs, rhs, span=None): @@ -103,7 +103,7 @@ def divide(lhs, rhs, span=None): op : tvm.Expr The result Expr of divide operaton. """ - return _ffi_api._OpDiv(lhs, rhs, span) + return _ffi_api._OpDiv(lhs, rhs, span) # type: ignore def floordiv(lhs, rhs, span=None): @@ -123,7 +123,7 @@ def floordiv(lhs, rhs, span=None): op : tvm.Expr The result Expr of divide operaton. """ - return _ffi_api._OpFloorDiv(lhs, rhs, span) + return _ffi_api._OpFloorDiv(lhs, rhs, span) # type: ignore def cast(src, dtype, span=None): @@ -141,4 +141,4 @@ def cast(src, dtype, span=None): op : tvm.Expr The result Expr of divide operaton. """ - return _ffi_api._cast(dtype, src, span) + return _ffi_api._cast(dtype, src, span) # type: ignore diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 4934bf04727f..978c630b17ad 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -141,7 +141,7 @@ class IRBuilder(object): """ def __init__(self): - self._seq_stack = [[]] + self._seq_stack = [[]] # type: ignore self.nidx = 0 def _pop_seq(self): @@ -394,7 +394,7 @@ def let(self, var_name, value): self.emit(lambda x: _stmt.LetStmt(var, value, x)) return var - def allocate(self, dtype, shape, name="buf", scope=None): + def allocate(self, dtype, shape, name="buf", scope=""): """Create a allocate statement. Parameters @@ -416,15 +416,13 @@ def allocate(self, dtype, shape, name="buf", scope=None): buffer : BufferVar The buffer var representing the buffer. """ - buffer_var = _expr.Var(name, PointerType(PrimType(dtype))) + buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] - if scope: - self.scope_attr(buffer_var, "storage_scope", scope) self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) return BufferVar(self, buffer_var, shape, dtype) - def pointer(self, content_type, name="ptr"): + def pointer(self, content_type, name="ptr", scope=""): """Create pointer variable with content type. Parameters @@ -435,12 +433,15 @@ def pointer(self, content_type, name="ptr"): name : str, optional The name of the pointer. + scope : str, optional + The scope of the pointer. + Returns ------- ptr : BufferVar The buffer var representing the buffer. """ - buffer_var = _expr.Var(name, dtype="handle") + buffer_var = _expr.Var(name, PointerType(PrimType(content_type), scope)) return BufferVar(self, buffer_var, None, content_type) def buffer_ptr(self, buf, shape=None): diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 874724b767cb..de3ca5fa8d5b 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -16,12 +16,14 @@ # under the License. # pylint: disable=redefined-builtin, invalid-name """Operators used in TIR expression.""" +from typing import Any, Optional import tvm._ffi +from tvm.ir.base import Span from tvm.runtime import convert, const from tvm.ir import Array, Op from .buffer import Buffer -from .expr import Call, StringImm, Var, CommReducer +from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer from . import _ffi_api @@ -257,9 +259,9 @@ def any(*args, span=None): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - val = _ffi_api._OpOr(args[0], args[1], span) + val = _ffi_api._OpOr(args[0], args[1], span) # type: ignore for i in range(2, len(args)): - val = _ffi_api._OpOr(val, args[i], span) + val = _ffi_api._OpOr(val, args[i], span) # type: ignore return val @@ -284,9 +286,9 @@ def all(*args, span=None): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - val = _ffi_api._OpAnd(args[0], args[1], span) + val = _ffi_api._OpAnd(args[0], args[1], span) # type: ignore for i in range(2, len(args)): - val = _ffi_api._OpAnd(val, args[i], span) + val = _ffi_api._OpAnd(val, args[i], span) # type: ignore return val @@ -343,10 +345,10 @@ def min_value(dtype, span=None): value : tvm.Expr The minimum value of dtype. """ - return _ffi_api.min_value(dtype, span) + return _ffi_api.min_value(dtype, span) # type: ignore -def max_value(dtype, span=None): +def max_value(dtype: str, span: Optional[Span] = None) -> Any: """maximum value of dtype Parameters @@ -362,11 +364,11 @@ def max_value(dtype, span=None): value : tvm.Expr The maximum value of dtype. """ - return _ffi_api.max_value(dtype, span) + return _ffi_api.max_value(dtype, span) # type: ignore def exp(x): - """Take exponetial of input x. + """Take exponential of input x. Parameters ---------- @@ -769,7 +771,7 @@ def clz(x): return call_intrin("int32", "tir.clz", x) -def floor(x, span=None): +def floor(x: PrimExprWithOp, span=None): """Take floor of float input x. Parameters @@ -785,7 +787,7 @@ def floor(x, span=None): y : PrimExpr The result. """ - return _ffi_api.floor(x, span) + return _ffi_api.floor(x, span) # type: ignore def ceil(x, span=None): @@ -804,7 +806,7 @@ def ceil(x, span=None): y : PrimExpr The result. """ - return _ffi_api.ceil(x, span) + return _ffi_api.ceil(x, span) # type: ignore def trunc(x, span=None): @@ -826,7 +828,7 @@ def trunc(x, span=None): y : PrimExpr The result. """ - return _ffi_api.trunc(x, span) + return _ffi_api.trunc(x, span) # type: ignore def abs(x, span=None): @@ -845,7 +847,7 @@ def abs(x, span=None): y : PrimExpr The result. """ - return _ffi_api.abs(x, span) + return _ffi_api.abs(x, span) # type: ignore def round(x, span=None): @@ -864,7 +866,7 @@ def round(x, span=None): y : PrimExpr The result. """ - return _ffi_api.round(x, span) + return _ffi_api.round(x, span) # type: ignore def nearbyint(x, span=None): @@ -890,7 +892,7 @@ def nearbyint(x, span=None): y : PrimExpr The result. """ - return _ffi_api.nearbyint(x, span) + return _ffi_api.nearbyint(x, span) # type: ignore def nextafter(x1, x2): @@ -909,7 +911,7 @@ def nextafter(x1, x2): y : PrimExpr The result. """ - return call_intrin(x1.dtype, "tir.nextafter", x1, x2) + return call_intrin(x1.dtype, "tir.nextafter", x1, x2) # type: ignore def hypot(x1, x2): @@ -928,7 +930,7 @@ def hypot(x1, x2): y : PrimExpr The result. """ - return call_intrin(x1.dtype, "tir.hypot", x1, x2) + return call_intrin(x1.dtype, "tir.hypot", x1, x2) # type: ignore def copysign(x1, x2): @@ -947,7 +949,7 @@ def copysign(x1, x2): y : PrimExpr The result. """ - return call_intrin(x1.dtype, "tir.copysign", x1, x2) + return call_intrin(x1.dtype, "tir.copysign", x1, x2) # type: ignore def ldexp(x1, x2): @@ -966,7 +968,7 @@ def ldexp(x1, x2): y : PrimExpr The result. """ - return call_intrin(x1.dtype, "tir.ldexp", x1, x2) + return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore def isnan(x, span=None): @@ -985,7 +987,7 @@ def isnan(x, span=None): y : PrimExpr The result. """ - return _ffi_api.isnan(x, span) + return _ffi_api.isnan(x, span) # type: ignore def isfinite(x, span=None): @@ -1004,7 +1006,7 @@ def isfinite(x, span=None): y : PrimExpr The result. """ - return _ffi_api.isfinite(x, span) + return _ffi_api.isfinite(x, span) # type: ignore def isinf(x, span=None): @@ -1023,7 +1025,7 @@ def isinf(x, span=None): y : PrimExpr The result. """ - return _ffi_api.isinf(x, span) + return _ffi_api.isinf(x, span) # type: ignore def power(x, y, span=None): @@ -1045,7 +1047,7 @@ def power(x, y, span=None): z : PrimExpr The result. """ - return _ffi_api._OpPow(convert(x), convert(y), span) + return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore def popcount(x): @@ -1141,7 +1143,7 @@ def if_then_else(cond, t, f, span=None): Unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions. """ - return _ffi_api._OpIfThenElse(convert(cond), convert(t), convert(f), span) + return _ffi_api._OpIfThenElse(convert(cond), convert(t), convert(f), span) # type: ignore def div(a, b, span=None): @@ -1166,7 +1168,7 @@ def div(a, b, span=None): ---- When operands are integers, returns truncdiv(a, b, span). """ - return _ffi_api._OpDiv(a, b, span) + return _ffi_api._OpDiv(a, b, span) # type: ignore def indexdiv(a, b, span=None): @@ -1194,7 +1196,7 @@ def indexdiv(a, b, span=None): This function may take advantage of operands' non-negativeness. """ - return _ffi_api._OpIndexDiv(a, b, span) + return _ffi_api._OpIndexDiv(a, b, span) # type: ignore def indexmod(a, b, span=None): @@ -1222,7 +1224,7 @@ def indexmod(a, b, span=None): This function may take advantage of operands' non-negativeness. """ - return _ffi_api._OpIndexMod(a, b, span) + return _ffi_api._OpIndexMod(a, b, span) # type: ignore def truncdiv(a, b, span=None): @@ -1248,7 +1250,7 @@ def truncdiv(a, b, span=None): ---- This is the default integer division behavior in C. """ - return _ffi_api._OpTruncDiv(a, b, span) + return _ffi_api._OpTruncDiv(a, b, span) # type: ignore def truncmod(a, b, span=None): @@ -1274,7 +1276,7 @@ def truncmod(a, b, span=None): ---- This is the default integer division behavior in C. """ - return _ffi_api._OpTruncMod(a, b, span) + return _ffi_api._OpTruncMod(a, b, span) # type: ignore def floordiv(a, b, span=None): @@ -1296,7 +1298,7 @@ def floordiv(a, b, span=None): res : PrimExpr The result expression. """ - return _ffi_api._OpFloorDiv(a, b, span) + return _ffi_api._OpFloorDiv(a, b, span) # type: ignore def floormod(a, b, span=None): @@ -1318,7 +1320,7 @@ def floormod(a, b, span=None): res : PrimExpr The result expression. """ - return _ffi_api._OpFloorMod(a, b, span) + return _ffi_api._OpFloorMod(a, b, span) # type: ignore def comm_reducer(fcombine, fidentity, name="reduce"): @@ -1476,5 +1478,5 @@ def reducer(expr, axis, where=None, init=None, *args): # pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum") -min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") -max = comm_reducer(lambda x, y: _ffi_api._OpMax(x, y, None), min_value, name="max") +min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore +max = comm_reducer(lambda x, y: _ffi_api._OpMax(x, y, None), min_value, name="max") # type: ignore diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index ef1cab1fb663..5f0e169c43e3 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -18,5 +18,7 @@ """Namespace for the TensorIR schedule API.""" from .block_scope import BlockScope, Dependency, DepKind, StmtSRef +from .instruction import Instruction, InstructionKind +from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError from .state import ScheduleDebugMask, ScheduleState -from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule, ScheduleError +from .trace import Trace diff --git a/python/tvm/tir/schedule/_ffi_api_schedule.py b/python/tvm/tir/schedule/_ffi_api.py similarity index 100% rename from python/tvm/tir/schedule/_ffi_api_schedule.py rename to python/tvm/tir/schedule/_ffi_api.py diff --git a/python/tvm/tir/schedule/block_scope.py b/python/tvm/tir/schedule/block_scope.py index 061a472ad9a9..30e047b4f78a 100644 --- a/python/tvm/tir/schedule/block_scope.py +++ b/python/tvm/tir/schedule/block_scope.py @@ -22,7 +22,7 @@ from tvm.runtime import Object from tvm.tir import Block, For -from . import _ffi_api_schedule +from . import _ffi_api @register_object("tir.StmtSRef") @@ -45,24 +45,24 @@ class StmtSRef(Object): @property def stmt(self) -> Optional[Union[Block, For]]: """The block/for stmt the object refers to""" - return _ffi_api_schedule.StmtSRefStmt(self) # type: ignore # pylint: disable=no-member + return _ffi_api.StmtSRefStmt(self) # type: ignore # pylint: disable=no-member @property def parent(self) -> Optional["StmtSRef"]: """The parent sref""" - return _ffi_api_schedule.StmtSRefParent(self) # type: ignore # pylint: disable=no-member + return _ffi_api.StmtSRefParent(self) # type: ignore # pylint: disable=no-member @staticmethod def inline_mark() -> "StmtSRef": """A special StmtSRef, which doesn't point to any stmt in the AST, only serving as a "mark" to hint compute-at to do the work of compute-inline""" - return _ffi_api_schedule.StmtSRefInlineMark() # type: ignore # pylint: disable=no-member + return _ffi_api.StmtSRefInlineMark() # type: ignore # pylint: disable=no-member @staticmethod def root_mark() -> "StmtSRef": """A special StmtSRef, which doesn't point to any stmt in the AST, only serving as a "mark" to hint compute-at to do nothing""" - return _ffi_api_schedule.StmtSRefRootMark() # type: ignore # pylint: disable=no-member + return _ffi_api.StmtSRefRootMark() # type: ignore # pylint: disable=no-member class DepKind(IntEnum): @@ -137,7 +137,7 @@ def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]: blocks: List[Dependency] The dependencies """ - return _ffi_api_schedule.BlockScopeGetDepsBySrc(self, block) # type: ignore # pylint: disable=no-member + return _ffi_api.BlockScopeGetDepsBySrc(self, block) # type: ignore # pylint: disable=no-member def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]: """Get all dependencies whose `dst` is the target `block`. @@ -152,4 +152,4 @@ def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]: blocks: List[Dependency] The dependencies """ - return _ffi_api_schedule.BlockScopeGetDepsByDst(self, block) # type: ignore # pylint: disable=no-member + return _ffi_api.BlockScopeGetDepsByDst(self, block) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/tir/schedule/instruction.py b/python/tvm/tir/schedule/instruction.py new file mode 100644 index 000000000000..09b2d70dc321 --- /dev/null +++ b/python/tvm/tir/schedule/instruction.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Schedule instructions each corresponds to a schedule primitive""" +from typing import TYPE_CHECKING, Any, List, Union + +from tvm._ffi import register_object as _register_object +from tvm.runtime import Object + +from . import _ffi_api + +if TYPE_CHECKING: + from .schedule import RAND_VAR_TYPE + + INPUT_RV_TYPE = Union[RAND_VAR_TYPE, float, int, str, None] # pylint: disable=invalid-name + OUTPUT_RV_TYPE = Union[RAND_VAR_TYPE] # pylint: disable=invalid-name + ATTR_TYPE = Any +else: + INPUT_RV_TYPE = OUTPUT_RV_TYPE = ATTR_TYPE = Any + + +@_register_object("tir.InstructionKind") +class InstructionKind(Object): + """Kind of an instruction, e.g. Split, Reorder, etc. + Besides the name, every kind of instruction has its own properties, including: + 1) A boolean indicating if the instruction is pure, i.e. change nothing in the schedule state + 2) A functor that applies the instruction to a TensorIR schedule + 3) A functor that converts the instruction to a statement in python syntax + 4) A functor that serialize its attributes to JSON + 5) A functor that deserialize its attributes from JSON + + Unlike `tvm.ir.op`, `InstructionKind` doesn't support unstructured properties, + mainly because there is no such usecase yet to add any other property. + + Attributes + ---------- + name : str + The name of a kind of instructions + + Note + ---- + The functor properties are not exposed on python side at the moment + """ + + name: str + + @property + def is_pure(self) -> bool: + """Indicates if the instruction is pure, i.e. removing it alone doesn't mutate the schedule + state. For example, the instruction `GetBlock` is pure because it changes + nothing, while `ComputeInline` is not because removing it leads to a different resulting + schedule. + + Returns + ------- + pure : bool + The boolean flag indicating if the instruction is pure + """ + return bool(self._is_pure) + + @staticmethod + def get(name: str) -> "InstructionKind": + """Retrieve an InstructionKind using its name + + Parameters + ---------- + name : str + The registered name of the InstructionKind + + Returns + ------- + kind : InstructionKind + The InstructionKind retrieved + """ + return _ffi_api.InstructionKindGet(name) # type: ignore # pylint: disable=no-member + + +@_register_object("tir.Instruction") +class Instruction(Object): + """Schedule instructions each corresponds to a schedule primitive + + Attributes + ---------- + kind : InstructionKind + The kind of the instruction + inputs : List[INPUT_RV_TYPE] + The input random variables of the instruction, + and the type of each element can be one of the following: + - BlockRV + - LoopRV + - ExprRV + - float + - int + - str + - None + attrs : List[ATTR_TYPE] + The attributes of the instruction. Similar to attributes of an operator, + attributes of an instruction are arbitrary constant metadata required by the instructions. + For example, the name of the block to be retrieved in `GetBlock`. + outputs : List[OUTPUT_RV_TYPE] + The output random variables of the instruction, + and the type of each element can be one of the following: + - BlockRV + - LoopRV + - ExprRV, atomic variables only, won't be constants or composite PrimExpr + """ + + kind: InstructionKind + inputs: List[INPUT_RV_TYPE] + attrs: List[ATTR_TYPE] + outputs: List[OUTPUT_RV_TYPE] + + def __init__( + self, + kind: InstructionKind, + inputs: List[INPUT_RV_TYPE], + attrs: List[ATTR_TYPE], + outputs: List[OUTPUT_RV_TYPE], + ) -> None: + """Constructor + + Parameters + ---------- + kind : InstructionKind + The kind of the instruction + inputs : List[INPUT_RV_TYPE] + The input random variables of the instruction, + and the type of each element can be one of the following: + - BlockRV + - LoopRV + - ExprRV + - float + - int + - str + - None + attrs : List[ATTR_TYPE] + The attributes of the instruction. Similar to attributes of an operator, + attributes of an instruction are arbitrary constant metadata required by the + instructions. For example, the name of the block to be retrieved in `GetBlock`. + outputs : List[OUTPUT_RV_TYPE] + The output random variables of the instruction, + and the type of each element can be one of the following: + - BlockRV + - LoopRV + - ExprRV, atomic variables only, won't be constants or composite PrimExpr + """ + self.__init_handle_by_constructor__( + _ffi_api.Instruction, # type: ignore # pylint: disable=no-member + kind, + inputs, + attrs, + outputs, + ) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 2091f4d80ab3..9433d019f9a5 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -14,18 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-import """The TensorIR schedule class""" -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object -from tvm.tir import Block, For, IntImm, PrimFunc, Var +from tvm.tir import Block, For, IntImm, PrimFunc -from . import _ffi_api_schedule -from .state import ScheduleState, StmtSRef +from . import _ffi_api +from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod +from .trace import Trace @register_error @@ -37,15 +37,56 @@ class ScheduleError(TVMError): class LoopRV(Object): """A random variable that refers to a loop""" + def __init__(self) -> None: + """Construct a new LoopRV.""" + self.__init_handle_by_constructor__( + _ffi_api.LoopRV # type: ignore # pylint: disable=no-member + ) + @_register_object("tir.BlockRV") class BlockRV(Object): """A random variable that refers to a block""" + def __init__(self) -> None: + """Construct a new BlockRV.""" + self.__init_handle_by_constructor__( + _ffi_api.BlockRV # type: ignore # pylint: disable=no-member + ) + + +# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370 +# This feature is not supported until python 3.10: +# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias +ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer + +RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name + +# Update to `Literal["detail", "fast", "none"]` once upgraded to python3.8 +_ERROR_RENDER_LEVEL: Dict[str, int] = { + "detail": 0, + "fast": 1, + "none": 2, +} + + +def _parse_error_render_level(error_render_level: str) -> int: + if error_render_level not in _ERROR_RENDER_LEVEL: + raise ValueError( + 'error_render_level can be "detail", "fast", or "none", but got: ' + + f"{error_render_level}" + ) + return _ERROR_RENDER_LEVEL.get(error_render_level) -ExprRV = PrimExpr # A random variable that evaluates to an integer -RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # type: ignore # pylint: disable=invalid-name +def _parse_seed(seed: Optional[int]) -> int: + if seed is None: + return -1 + if not isinstance(seed, int): + raise TypeError(f"Expected `seed` to be int or None, but gets: {seed}") + if seed < 1 or seed > 2147483647: + raise ValueError(f"seed must be in the range [1, 2147483647], but gets: {seed}") + return seed @_register_object("tir.Schedule") @@ -63,54 +104,66 @@ class Schedule(Object): Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html """ - ERROR_RENDER_LEVEL = {"detail": 0, "fast": 1, "none": 2} - def __init__( self, - func_or_mod: Union[PrimFunc, IRModule], + mod: Union[PrimFunc, IRModule], *, - debug_mode: Union[bool, int] = False, + seed: Optional[int] = None, + debug_mask: Union[str, int] = "none", error_render_level: str = "detail", - ): - """Construct a concrete TensorIR schedule from an IRModule or a PrimFunc + ) -> None: + """Construct a TensorIR schedule class from an IRModule Parameters ---------- - func_or_mod : Union[PrimFunc, IRModule] + mod : Union[PrimFunc, IRModule] The IRModule or PrimFunc to be scheduled - debug_mode : Union[bool, int] + seed: Optional[int] + The seed value for schedule's random state + Note that None and -1 means use device random, otherwise only integer between 1 and + 2147483647 is allowed. + debug_mask : Union[str, int] Do extra correctness checking after the class creation and each time - scheduling primitive + after calling the Replace method. + Possible choices of `debug_mask`: + 1) "all" - Turn on all the checks + 2) "none" - Turn off all the checks + 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask error_render_level : str = "detail" The level of error rendering. Choices: "detail", "fast", "none". - "detail": Render a detailed error message, with the TIR and error locations printed - "fast: Show a simple error message without rendering or string manipulation - "none": Do not show any error message. + - "detail": Render a detailed error message, with the TIR and error locations printed + - "fast: Show a simple error message without rendering or string manipulation + - "none": Do not show any error message. Note - ---------- + ---- The checks performed includes: 1) VerifySRefTree 2) VerifyCachedFlags """ - if isinstance(debug_mode, bool): - if debug_mode: - debug_mode = -1 - else: - debug_mode = 0 - if not isinstance(debug_mode, int): - raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}") - if error_render_level not in Schedule.ERROR_RENDER_LEVEL: - raise ValueError( - 'error_render_level can be "detail", "fast", or "none", but got: ' - + f"{error_render_level}" - ) - error_render_level = Schedule.ERROR_RENDER_LEVEL.get(error_render_level) # type: ignore + # call the constructor self.__init_handle_by_constructor__( - _ffi_api_schedule.ConcreteSchedule, # type: ignore # pylint: disable=no-member - func_or_mod, - debug_mode, - error_render_level, + _ffi_api.TracedSchedule, # type: ignore # pylint: disable=no-member + _parse_mod(mod), + _parse_seed(seed), + _parse_debug_mask(debug_mask), + _parse_error_render_level(error_render_level), + ) + + @staticmethod + def _create_non_traced( + mod: Union[PrimFunc, IRModule], + *, + seed: Optional[int] = None, + debug_mask: Union[str, int] = "none", + error_render_level: str = "detail", + ) -> "Schedule": + """Construct a non-traced TensorIR schedule class from an IRModule.""" + return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member + _parse_mod(mod), + _parse_seed(seed), + _parse_debug_mask(debug_mask), + _parse_error_render_level(error_render_level), ) ########## Utilities ########## @@ -118,44 +171,63 @@ def __init__( @property def mod(self) -> IRModule: """Returns the AST of the module being scheduled""" - return _ffi_api_schedule.ScheduleModule(self) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleGetMod(self) # type: ignore # pylint: disable=no-member @property def state(self) -> ScheduleState: """Returns the ScheduleState in the current schedule class""" - return _ffi_api_schedule.ScheduleGetState(self) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleGetState(self) # type: ignore # pylint: disable=no-member + + @property + def trace(self) -> Optional[Trace]: + """Returns the internally maintained trace of scheduling program execution""" + return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member def copy(self) -> "Schedule": """Returns a copy of the schedule, including both the state and the symbol table, * guaranteeing that * 1) SRef tree is completely reconstructed; * 2) The IRModule being scheduled is untouched; - * 3) All the random variables are valid in the copy, pointing to the correpsonding sref + * 3) All the random variables are valid in the copy, pointing to the corresponding sref * reconstructed + Returns ------- copy : Schedule A new copy of the schedule """ - return _ffi_api_schedule.ScheduleCopy(self) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleCopy(self) # type: ignore # pylint: disable=no-member def seed(self, seed: int) -> None: """Seed the randomness + Parameters ---------- seed : int The new random seed, -1 if use device random, otherwise non-negative """ - return _ffi_api_schedule.ScheduleSeed(self, seed) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleSeed(self, seed) # type: ignore # pylint: disable=no-member + + def fork_seed(self) -> int: + """Returns a forked random state as seed for new schedules + + Returns + ------- + seed : int + The forked random state, not the same as the current random state + """ + return _ffi_api.ScheduleForkSeed(self) # type: ignore # pylint: disable=no-member def show(self, rand_var: RAND_VAR_TYPE) -> str: """Returns a string representation of the value that the random variable evaluates to + Parameters ---------- rand_var : Union[ExprRV, BlockRV, LoopRV] The random variable to be evaluated + Returns - ---------- + ------- str_repr : str The string representation """ @@ -173,71 +245,108 @@ def get( - the corresponding integer that a ExprRV evaluates to; - the corresponding Block that a block sref points to; - the corresponding For that a loop sref points to; + Parameters ---------- rand_var_or_sref : Union[ExprRV, BlockRV, LoopRV, StmtSRef] The random variable / sref to be evaluated + Returns - ---------- + ------- result : Optional[Union[int, Block, For]] - The correpsonding result + The corresponding result """ if isinstance(rand_var_or_sref, StmtSRef): return rand_var_or_sref.stmt - result = _ffi_api_schedule.ScheduleGet(self, rand_var_or_sref) # type: ignore # pylint: disable=no-member + result = _ffi_api.ScheduleGet(self, rand_var_or_sref) # type: ignore # pylint: disable=no-member if isinstance(result, IntImm): result = result.value return result def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Optional[StmtSRef]: - """Returns the correpsonding sref to the given + """Returns the corresponding sref to the given 1) LoopRV 2) BlockRV 3) Block 4) For + Parameters ---------- rand_var_or_stmt : Union[BlockRV, LoopRV, Block, For] The random variable / sref to be evaluated + Returns - ---------- + ------- result : Optional[StmtSRef] - The correpsonding result + The corresponding result """ - return _ffi_api_schedule.ScheduleGetSRef( # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleGetSRef( # type: ignore # pylint: disable=no-member self, rand_var_or_stmt ) def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None: """Remove a random variable from the symbol table + Parameters ---------- rand_var : Union[BlockRV, LoopRV, ExprRV] The random variable to be removed """ - return _ffi_api_schedule.ScheduleRemoveRV(self, rand_var) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleRemoveRV(self, rand_var) # type: ignore # pylint: disable=no-member + + ########## Schedule: Sampling ########## + + def sample_categorical( + self, + candidates: List[int], + probs: List[float], + decision: Optional[int] = None, + ) -> ExprRV: + """Sample an integer given the probability distribution - ########## Block/Loop relation ########## + Parameters + ---------- + candidates : List[int] + The candidates to be sampled from + probs : List[float] + The probability of each candidate + decision : Optional[int] + The sampling decision, if any + Returns + ------- + result : ExprRV + The random variable sampled from candidates + """ + return _ffi_api.ScheduleSampleCategorical( # type: ignore # pylint: disable=no-member + self, + candidates, + probs, + decision, + ) + + ########## Schedule: Get blocks & loops ########## def get_block( self, name: str, func_name: str = "main", ) -> BlockRV: """Retrieve a block in a specific function with its name + Parameters ---------- name : str The name of the block func_name : str = "main" The name of the function + Returns - ---------- + ------- block : BlockRV The block retrieved IndexError is raised if 0 or multiple blocks exist with the specific name. """ - return _ffi_api_schedule.ScheduleGetBlock( # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleGetBlock( # type: ignore # pylint: disable=no-member self, name, func_name, @@ -245,19 +354,444 @@ def get_block( def get_loops(self, block: BlockRV) -> List[LoopRV]: """Get the parent loops of the block in its scope, from outer to inner + Parameters ---------- block : BlockRV The query block + Returns - ---------- + ------- loops : List[LoopRV] A list of loops above the given block in its scope, from outer to inner """ - return _ffi_api_schedule.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member + + ########## Schedule: Transform loops ########## + def fuse(self, *loops: List[LoopRV]) -> LoopRV: + """Fuse a list of consecutive loops into one. It requires: + 1) The loops can't have annotations or thread bindings. + 2) The (i+1)-th loop must be the only child of the i-th loop. + 3) All loops must start with 0. + + Parameters + ---------- + *loops : List[LoopRV] + The loops to be fused + + Returns + ------- + fused_loop : LoopRV + The new loop after fusion + + Examples + -------- + + Before applying fuse, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_fuse(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do fuse: + + .. code-block:: python + + sch = tir.Schedule(before_fuse) + i, j = sch.get_loops(sch.get_block("B")) + sch.fuse(i, j) + print(tvm.script.asscript(sch.mod["main"])) + + After applying fuse, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_fuse(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + # the 2 loops are fused into 1 + for i_j_fused in tir.serial(0, 16384): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 128)) + tir.bind(vj, tir.floormod(i_j_fused, 128)) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + return _ffi_api.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member + + def split( + self, + loop: LoopRV, + factors: List[Union[ExprRV, None]], + ) -> List[LoopRV]: + """Split a loop into a list of consecutive loops. It requires: + 1) The loop can't have annotation or thread binding. + 2) The loop must start with 0. + Predicates may be added to ensure the total loop numbers keeps unchanged. + In `factors`, at most one of the factors can be None, + which will be automatically inferred. + + Parameters + ---------- + loop : LoopRV + The loop to be split + + factors: List[Union[ExprRV, None]] + The splitting factors + Potential inputs are: + - None + - ExprRV + - Nonnegative constant integers + + Returns + ------- + split_loops : List[LoopRV] + The new loops after split + + Examples + -------- + + Before split, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do split: + + .. code-block:: python + + sch = tir.Schedule(before_split) + i, j = sch.get_loops(sch.get_block("B")) + sch.split(i, factors=[2, 64]) + print(tvm.script.asscript(sch.mod["main"])) + + After applying split, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + # the original loop is split into 2 loops + for i0, i1, j in tir.grid(2, 64, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, ((i0*64) + i1)) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + # it will be checked later in C++ implementation + # that there is at most one None in `factors` + return _ffi_api.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member + + def reorder(self, *ordered_loops: List[LoopRV]) -> None: + """ + Reorder a list of loops. It doesn't require the loops to be consecutive. + It requires: + 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... , + l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between + l_1 and l_n (which also indicates they are under the same scope). + 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. + 3) For every block under the loop nests, its block binding must be affine, and the block + variables must be either data parallel or reduction. + 4) No duplicated loops are allowed in the arguments. + + Parameters + ---------- + *ordered_loops : List[LoopRV] + The loops in the new order + + Examples + -------- + + Before reorder, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_reorder(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do reorder: + + .. code-block:: python + + sch = tir.Schedule(before_reorder) + i, j = sch.get_loops(sch.get_block("B")) + sch.reorder(j, i) + print(tvm.script.asscript(sch.mod["main"])) + + After applying reorder, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_reorder(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + # Here j and i are reordered + for j, i in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleReorder(self, ordered_loops) # type: ignore # pylint: disable=no-member + + ########## Schedule: Manipulate ForKind ########## + + def parallel(self, loop: LoopRV) -> None: + """Parallelize the input loop. It requires: + 1) The scope block that the loop is in should have stage-pipeline property + 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + bindings + 3) For each block under the loop, the loop can only be contained in data-parallel block + iters' bindings + + Parameters + ---------- + loop : LoopRV + The loop to be parallelized + + Examples + -------- + + Before parallel, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_parallel(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do parallel: + + .. code-block:: python + + sch = tir.Schedule(before_parallel) + i, j = sch.get_loops(sch.get_block("B")) + sch.parallel(i) + + After applying parallel, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_parallel(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i in tir.parallel(0, 128): + for j in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleParallel(self, loop) # type: ignore # pylint: disable=no-member + + def vectorize(self, loop: LoopRV) -> None: + """Vectorize the input loop. It requires: + 1) The scope block that the loop is in should have stage-pipeline property + 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + bindings + 3) For each block under the loop, the loop can only be contained in data-parallel block + iters' bindings + + Parameters + ---------- + loop : LoopRV + The loop to be vectorized + + Examples + -------- + + Before vectorize, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_vectorize(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do vectorize: + + .. code-block:: python + + sch = tir.Schedule(before_vectorize) + i, j = sch.get_loops(sch.get_block("B")) + sch.vectorize(j) + + After applying vectorize, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_vectorize(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i in tir.serial(0, 128): + for j in tir.vectorized(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleVectorize(self, loop) # type: ignore # pylint: disable=no-member + + def bind(self, loop: LoopRV, thread_axis: str) -> None: + """Bind the input loop to the given thread axis. It requires: + 1) The scope block that the loop is in should have stage-pipeline property + 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + bindings + 3) For each block under the loop, if the thread axis starts with "threadIdx`, the loop can + only be contained in data-parallel block iter and reduction block iters' bindings. Otherwise + the loop can only be contained in data-parallel block iters' bindings + + Parameters + ---------- + loop : LoopRV + The loop to be bound to the thread axis + thread_axis : str + The thread axis to be bound to the loop. Possible candidates: + - blockIdx.x/y/z + - threadIdx.x/y/z + - vthread.x/y/z + - vthread (It is a legacy behavior that will be deprecated. Please use `vthread.x/y/z` + instead.) + + Examples + -------- + + Before bind, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_bind(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do bind: + + .. code-block:: python + + sch = tir.Schedule(before_bind) + i, j = sch.get_loops(sch.get_block("B")) + sch.bind(i, "blockIdx.x") + sch.bind(j, "threadIdx.x") + + After applying bind, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_bind(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i in tir.thread_binding(0, 128, thread = "blockIdx.x"): + for j in tir.thread_binding(0, 128, thread = "threadIdx.x"): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleBind(self, loop, thread_axis) # type: ignore # pylint: disable=no-member + + def unroll(self, loop: LoopRV) -> None: + """Unroll the input loop. It requires nothing + + Parameters + ---------- + loop : LoopRV + The loop to be unrolled + + Examples + -------- + + Before unroll, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_unroll(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do unroll: + + .. code-block:: python + + sch = tir.Schedule(before_unroll) + i, j = sch.get_loops(sch.get_block("B")) + sch.unroll(i) + + After applying unroll, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_unroll(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i in tir.unroll(0, 128): + for j in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleUnroll(self, loop) # type: ignore # pylint: disable=no-member + + ########## Schedule: Insert cache stages ########## + + ########## Schedule: Compute location ########## - ########## Schedule: loops manipulation ########## - ########## Schedule: compute location ########## def compute_inline(self, block: BlockRV) -> None: """Inline a block into its consumer(s). It requires: @@ -297,7 +831,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None: .. code-block:: python - sch = tir.Schedule(before_inline, debug_mode=True) + sch = tir.Schedule(before_inline) sch.compute_inline(sch.get_block("B")) print(tvm.script.asscript(sch.mod["main"])) @@ -313,7 +847,7 @@ def after_inline(a: ty.handle, c: ty.handle) -> None: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ - _ffi_api_schedule.ScheduleComputeInline(self, block) # type: ignore # pylint: disable=no-member + _ffi_api.ScheduleComputeInline(self, block) # type: ignore # pylint: disable=no-member def reverse_compute_inline(self, block: BlockRV) -> None: """Inline a block into its only producer. It requires: @@ -357,7 +891,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None: .. code-block:: python - sch = tir.Schedule(before_inline, debug_mode=True) + sch = tir.Schedule(before_inline) sch.reverse_compute_inline(sch.get_block("C")) print(tvm.script.asscript(sch.mod["main"])) @@ -373,14 +907,232 @@ def after_inline(a: ty.handle, c: ty.handle) -> None: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ - _ffi_api_schedule.ScheduleReverseComputeInline(self, block) # type: ignore # pylint: disable=no-member + _ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore # pylint: disable=no-member + + ########## Schedule: Reduction ########## + + def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV: + """Factorize an associative reduction block by the specified loop. + + An associative reduction cannot be parallelized directly, + because it leads to potential race condition during accumulation. + Alternatively, the reduction could be factorized on a loop with the following steps: + - Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent + - Step 2: compute the chunks separately and write the result into `n` intermediate buffers; + - Step 3: accumulate the `n` separate buffer into the result buffer. + Note that the Step 2 above introduces opportunities for parallelization. + + RFactor is a schedule primitive that implements the transformation described above: + Given a block that writes to buffer `B`, it factorizes a loop of extent `n`. + + For example, the pseudocode below accumulates `B[i] = sum(A[i, : , : ])`: + + .. code-block:: python + + for i in range(128): # loop i is a data parallel loop + for j in range(128): # loop j is a reduction loop + for k in range(128): # loop k is a reduction loop + B[i] = B[i] + A[i, j, k] + + Suppose RFactor is applied on the innermost loop `k` and `factor_axis = 1`. + RFactor then creates an intermediate buffer and two blocks. + + 1. The intermediate buffer, or "rf-buffer" is a buffer of rank `ndim(B) + 1` and + size `size(B) * n`, whose shape expands from `shape(B)` by adding an axis of `n` + at the position specified by `factor_axis`. For example, + + * shape(B) = [1, 2, 3], factor_axis = 0 => shape(B_rf) = [n, 1, 2, 3] + * shape(B) = [1, 2, 3], factor_axis = 1 => shape(B_rf) = [1, n, 2, 3] + * shape(B) = [1, 2, 3], factor_axis = 2 => shape(B_rf) = [1, 2, n, 3] + * shape(B) = [1, 2, 3], factor_axis = 3 => shape(B_rf) = [1, 2, 3, n] + + 2. The rfactor block, or "rf-block", is a block that writes to the `rf-buffer` without + accumulating over the loop `k`, i.e. the loop `k` is converted from a reduction loop + to a data parallel loop. In our example, the rf-block is: + + .. code-block:: python + + B_rf = np.zeros((128, 128)) # the rf-buffer + for k in range(128): # loop k is converted to a data parallel loop + for i in range(128): # loop i is a data parallel loop (unchanged) + for j in range(128): # loop j is a reduction loop (unchanged) + B_rf[i, k] = B_rf[i, k] + A[i, j, k] + + + 3. The write-back block, or `wb-block`, is a block that accumulates the rf-buffer into + the result buffer. All the reduction loops are removed except the loop `k` for accumulation. + In our example, the wb-block is: + + .. code-block:: python + + for i in range(128): # loop i is a data parallel loop (unchanged) + # loop j is removed because it is a reduction loop + for k in range(128): # loop k is a reduction loop (unchanged) + B[i] = B[i] + B_rf[i, k] + + + Parameters + ---------- + loop : LoopRV + The loop outside block for which we want to do rfactor + factor_axis : int + The position where the new dimension is placed in the new introduced rfactor buffer + + Returns + ------- + rf_block : BlockRV + The block which computes partial results over each slices (i.e., the first block + as described in the above illustration) + + Examples + -------- + + Before rfactor, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_rfactor(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128,)) + with tir.block([128, tir.reduce_axis(0, 128), + tir.reduce_axis(0, 128)], "B") as [vii, vi, vj]: + with tir.init(): + B[vii] = 0.0 + B[vii] = B[vii] + A[vii, vi, vj] + + Create the schedule and do rfactor: + + .. code-block:: python + + sch = tir.Schedule(before_rfactor) + _, _, k = sch.get_loops(sch.get_block("B")) + sch.rfactor(k, 0) + print(tvm.script.asscript(sch.mod["main"])) + + After applying rfactor, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_rfactor(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128]) + B_rf = tir.alloc_buffer([128, 128]) + with tir.block([128, 128, tir.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]: + with tir.init(): + B_rf[vi2, vii] = 0.0 + B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2]) + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]: + with tir.init(): + B[vii_1] = 0.0 + B[vii_1] = (B[vii_1] + B_rf[vi2_1, vii_1]) + + + Note + ---- + + Rfactor requires: + 1) `loop` has only one child block, and it is a reduction block; + 2) `loop` is a reduction loop, i.e. the loop variable is bound to only reduction variables + in the block binding; + 3) `loop` is not parallelized, vectorized, unrolled or bound to any thread axis; + 4) The block scope that `loop` is in is a staged-pipeline; + 5) The outermost loop outside the reduction block should has the reduction block as its + first child block; + 6) The outermost reduction loop should have only one child block; + 7) An unary extent loop that is not bound to any reduction or data parallel variables in + the block binding should not appear under some reduction loop; + 8) The reduction block should write to only one buffer, and its init and body are both + simple `BufferStore`s, and the pattern is registered as an associative reducer. + The pre-defined patterns include: plus, multiplication, min and max; + 9) Each of the loops on top of the block cannot be bound to a data parallel and a + reduction block binding at the same time; + 10) `factor_axis` should be in range `[-ndim(B) - 1, ndim(B)]`, + where `B` is the buffer that the reduction block writes to. + Negative indexing is normalized according to numpy convention. + """ + return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member + + ######## Schedule: Block annotatoin ######## + + def storage_align( # pylint: disable=too-many-arguments + self, block: BlockRV, buffer_index: int, axis: int, factor: int, offset: int + ) -> None: + """Set alignment requirement for specific dimension such that + stride[axis] == k * factor + offset for some k. This is useful to set memory layout for more + friendly memory access pattern. For example, we can set alignment to be factor=2, offset=1 + to avoid bank conflict for thread access on higher dimension in GPU shared memory. + + Parameters + ---------- + block : BlockRV + The producer block of the buffer. + buffer_index : int + The index of the buffer in block's write region. + axis : int + The dimension to be specified for alignment. + factor : int + The factor multiple of alignment. + offset : int + The required offset factor. + + Examples + -------- + + Before storage_align, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_storage_align(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do storage_align: + + .. code-block:: python + + sch = tir.Schedule(before_storage_align) + sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1) + print(tvm.script.asscript(sch.mod["main"])) + + After applying rfactor, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_storage_align(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + tir.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + After lowering passes, buffer B will have strides as [129, 1]. + + Note + ---- + Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`. + """ + _ffi_api.ScheduleStorageAlign( # type: ignore # pylint: disable=no-member + self, block, buffer_index, axis, factor, offset + ) + + ########## Schedule: Blockize & Tensorize ########## - ########## Schedule: loop binding/annotation ########## - ########## Schedule: cache read/write ########## - ########## Schedule: reduction ########## - ########## Schedule: blockize & tensorize ########## + ########## Schedule: Annotation ########## + ########## Schedule: Misc ########## -@_register_object("tir.ConcreteSchedule") -class ConcreteSchedule(Schedule): - """A concrete schedule class of TensorIR. Do not use directly, use tvm.tir.Schedule instead.""" + def enter_postproc(self) -> None: + """A no-op that marks the start of postprocessing phase of scheduling""" + _ffi_api.ScheduleEnterPostproc(self) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py index 845e1db5cb83..a371897abd13 100644 --- a/python/tvm/tir/schedule/state.py +++ b/python/tvm/tir/schedule/state.py @@ -24,17 +24,17 @@ from tvm.runtime import Object from tvm.tir import Block, BlockRealize, For, PrimFunc -from . import _ffi_api_schedule +from . import _ffi_api from .block_scope import BlockScope, StmtSRef CachedFlags = namedtuple("CachedFlags", ["affine_binding", "region_cover", "stage_pipeline"]) class ScheduleDebugMask(IntEnum): - """The bitmask of the `debug_mode` flag in the ScheduleState class. + """The bitmask of the `debug_mask` flag in the ScheduleState class. - If the `debug_mode` flag has a certain bit on, then the correpsonding - verification pass will be conducted. For example, if `(debug_mode & VERIFY_SREF_TREE) != 0`, + If the `debug_mask` flag has a certain bit on, then the correpsonding + verification pass will be conducted. For example, if `(debug_mask & VERIFY_SREF_TREE) != 0`, then the correctness of the sref tree will be verified after each schedule instruction. Attributes @@ -49,6 +49,27 @@ class ScheduleDebugMask(IntEnum): VERIFY_CACHED_FLAGS = 2 +def _parse_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: + if isinstance(mod, PrimFunc): + mod = IRModule({"main": mod}) + if not isinstance(mod, IRModule): + raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") + return mod + + +def _parse_debug_mask(debug_mask: Union[str, int]) -> int: + if isinstance(debug_mask, str): + if debug_mask == "all": + debug_mask = ScheduleDebugMask.VERIFY_SREF_TREE | ScheduleDebugMask.VERIFY_CACHED_FLAGS + elif debug_mask == "none": + debug_mask = 0 + else: + raise ValueError(f"Unrecognizable `debug_mask`: {debug_mask}") + if isinstance(debug_mask, bool) or not isinstance(debug_mask, int): + raise TypeError(f"`debug_mask` should be integer or boolean, but gets: {debug_mask}") + return debug_mask + + @register_object("tir.ScheduleState") class ScheduleState(Object): """The state of scheduling, which exposes a `Replace` method as @@ -59,50 +80,44 @@ class ScheduleState(Object): 2) The sref tree of schedulable statements (indicated by the srefs) 3) The dependency information of each block scope (block_info) 4) A reverse mapping from the AST nodes to that in the sref tree (get_sref) - 5) A debug flag, if set, extra checking is enabled (debug_mode) + 5) A debug flag, if set, extra checking is enabled (debug_mask) Parameters ---------- mod : IRModule The AST of the module being scheduled - debug_mode : int + debug_mask : int Do extra correctness checking after the object construction and each time after calling the Replace method. """ mod: IRModule - debug_mode: int + debug_mask: int def __init__( self, - func_or_mod: Union[PrimFunc, IRModule], - debug_mode: Union[bool, int] = False, - ): + mod: Union[PrimFunc, IRModule], + *, + debug_mask: Union[str, int] = "none", + ) -> None: """Construct a schedule state from an IRModule or a PrimFunc Parameters ---------- - func_or_mod : Union[PrimFunc, IRModule] + mod : Union[PrimFunc, IRModule] The IRModule or PrimFunc to be scheduled - debug_mode : Union[bool, int] + debug_mask : Union[str, int] Do extra correctness checking after the class creation and each time after calling the Replace method. - Possible choices of `debug_mode`: - 1) True - Turn on all the checks - 2) False - Turn off all the checks + Possible choices of `debug_mask`: + 1) "all" - Turn on all the checks + 2) "none" - Turn off all the checks 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask """ - if isinstance(debug_mode, bool): - if debug_mode: - debug_mode = -1 - else: - debug_mode = 0 - if not isinstance(debug_mode, int): - raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}") self.__init_handle_by_constructor__( - _ffi_api_schedule.ScheduleState, # type: ignore # pylint: disable=no-member - func_or_mod, - debug_mode, + _ffi_api.ScheduleState, # type: ignore # pylint: disable=no-member + _parse_mod(mod), + _parse_debug_mask(debug_mask), ) def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]: @@ -118,7 +133,7 @@ def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]: sref : StmtSRef The corresponding sref """ - return _ffi_api_schedule.ScheduleStateGetSRef(self, stmt) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleStateGetSRef(self, stmt) # type: ignore # pylint: disable=no-member def get_block_scope(self, block_sref: StmtSRef) -> BlockScope: """Get the BlockScope correpsonding to the block sref @@ -133,7 +148,7 @@ def get_block_scope(self, block_sref: StmtSRef) -> BlockScope: sref : StmtSRef The corresponding sref """ - return _ffi_api_schedule.ScheduleStateGetBlockScope( # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleStateGetBlockScope( # type: ignore # pylint: disable=no-member self, block_sref ) @@ -151,14 +166,14 @@ def _get_cached_flags(self, block_sref: StmtSRef) -> CachedFlags: Three flags: affine_binding, region_cover, stage_pipeline Note - ------- + ---- It is an API intended for internal testing use. """ ( affine_binding, region_cover, stage_pipeline, - ) = _ffi_api_schedule.ScheduleStateGetCachedFlags( # type: ignore # pylint: disable=no-member + ) = _ffi_api.ScheduleStateGetCachedFlags( # type: ignore # pylint: disable=no-member self, block_sref ) return CachedFlags( @@ -199,12 +214,12 @@ def replace( the sref that points to the old block will point to the new one Note - ---------- + ---- The reuse of loop srefs are detected automatically according to the reuse of loop vars. """ if block_sref_reuse is None: block_sref_reuse = {} - _ffi_api_schedule.ScheduleStateReplace( # type: ignore # pylint: disable=no-member + _ffi_api.ScheduleStateReplace( # type: ignore # pylint: disable=no-member self, src_sref, tgt_stmt, diff --git a/python/tvm/tir/schedule/testing.py b/python/tvm/tir/schedule/testing.py new file mode 100644 index 000000000000..66ede31f4103 --- /dev/null +++ b/python/tvm/tir/schedule/testing.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Testing utilities for the TensorIR schedule API""" +from typing import Union + +from tvm import tir +from tvm.ir import IRModule, structural_equal +from tvm.tir import PrimFunc +from tvm.tir.schedule import Trace + + +def verify_trace_roundtrip( + sch: tir.Schedule, + mod: Union[PrimFunc, IRModule], + *, + debug_mask: Union[str, int] = "all", +) -> tir.Schedule: + """Serialize a traced schedule to JSON, then replay the JSON trace by applying to + a fresh new schedule, verifying the reproducibility of scheduling. + + Parameters + ---------- + sch : tir.Schedule + The traced TensorIR schedule to be verified + mod : Union[PrimFunc, IRModule] + The IRModule or PrimFunc to construct the fresh new schedule + debug_mask : Union[str, int] + Do extra correctness checking after the class creation and each time + after calling the Replace method. + Possible choices of `debug_mask`: + 1) "all" - Turn on all the checks + 2) "none" - Turn off all the checks + 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask + """ + # Step 1. Serialize the trace to JSON + trace = sch.trace + assert trace is not None + json_obj = trace.as_json() + # Step 2. Apply the JSON trace to a new schedule, then check if it reproduces the scheduling + new_sch = tir.Schedule(mod=mod, debug_mask=debug_mask) + Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch) + assert structural_equal(new_sch.mod, sch.mod) + # Step 3. Check the consistency of the text format between the old and new traces + py_repr = "\n".join(trace.as_python()) + new_py_repr = "\n".join(new_sch.trace.as_python()) + assert py_repr == new_py_repr + # Step 4. Return the new schedule in case it could be useful + return new_sch diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py new file mode 100644 index 000000000000..18bcca373dbb --- /dev/null +++ b/python/tvm/tir/schedule/trace.py @@ -0,0 +1,260 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""An execution trace of a scheduling program""" +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional + +from tvm._ffi import register_object as _register_object +from tvm.runtime import Object + +from ...ir import Array, Map +from ...runtime import String +from ..expr import FloatImm, IntImm +from . import _ffi_api +from .instruction import ATTR_TYPE, INPUT_RV_TYPE, Instruction + +if TYPE_CHECKING: + from .schedule import Schedule + + +DECISION_TYPE = Any +JSON_TYPE = Any + + +def _json_from_tvm(obj): + if obj is None: + return None + if isinstance(obj, Array): + return [_json_from_tvm(i) for i in obj] + if isinstance(obj, Map): + return {_json_from_tvm(k): _json_from_tvm(v) for k, v in obj.items()} + if isinstance(obj, String): + return str(obj) + if isinstance(obj, (IntImm, FloatImm)): + return obj.value + raise TypeError("Not supported type: " + str(type(obj))) + + +@_register_object("tir.Trace") +class Trace(Object): + """An execution trace of a scheduling program. + + A trace has two parts: + 1) The instructions invoked so far + 2) The random decisions made upon those instructions, if any + + A trace can be serialized to: + 1) Roundtrippable JSON format: can be saved to file and loaded back + 2) Python syntax: allows users to copy-paste the trace to reproduce the scheduling process + + A trace can be applied to a TensorIR schedule by re-applying all its instructions possibly with + their decisions accordingly. Re-sampling is invoked if a sampling instruction doesn't have its + corresponding decision; Otherwise the existing decision will be reused accordingly. + + Attributes + ---------- + insts : List[Instruction] + The instructions invoked so far in the program execution + decisions : Dict[Instruction, DECISION_TYPE] + The random decisions made upon those instructions + """ + + insts: List[Instruction] + decisions: Dict[Instruction, DECISION_TYPE] + + def __init__( + self, + insts: List[Instruction], + decisions: Dict[Instruction, DECISION_TYPE], + ) -> None: + """Constructor + + Parameters + ---------- + insts : List[Instruction] + The instructions invoked so far in the program execution + decisions : Dict[Instruction, DECISION_TYPE] + The random decisions made upon those instructions + """ + self.__init_handle_by_constructor__( + _ffi_api.Trace, # type: ignore # pylint: disable=no-member + insts, + decisions, + ) + + def get_decision(self, inst: Instruction) -> Optional[DECISION_TYPE]: + """Retrieve the decision made on a specific instruction + + Parameters + ---------- + insts : Instruction + The instruction whose decision is to be retrieved + + Returns + ------- + decision : Optional[DECISION_TYPE] + The corresponding decision; None if there is no decision made on the instruction + """ + return _ffi_api.TraceGetDecision(self, inst) # type: ignore # pylint: disable=no-member + + def append( + self, + inst: Instruction, + decision: Optional[DECISION_TYPE] = None, + ) -> None: + """Append a new instruction to the trace + + Parameters + ---------- + insts : Instruction + The new instruction to be appended + decision : Optional[DECISION_TYPE] = None + The random decision made on this instruction + """ + _ffi_api.TraceAppend(self, inst, decision) # type: ignore # pylint: disable=no-member + + def pop(self) -> Optional[Instruction]: + """Remove the last instruction, along with the decision made on that instruction, if any + + Returns + ------- + popped_inst : Instruction + Returns the instruction removed; NullOpt if the trace is empty + """ + return _ffi_api.TracePop(self) # type: ignore # pylint: disable=no-member + + def apply_to_schedule( + self, + sch: "Schedule", + remove_postproc: bool, + decision_provider: Optional[ + Callable[ + [Instruction, List[INPUT_RV_TYPE], List[ATTR_TYPE], DECISION_TYPE], DECISION_TYPE + ] + ] = None, + ) -> None: + """Apply the trace to a TensorIR schedule + + Parameters + ---------- + sch : Schedule + The schedule to be applied onto + remove_postproc : bool + If postprocessing instructions are removed + decision_provider: Optional[Callable] = None + A callback that allows users to mutate decisions on the fly when applying instructions. + The signature of the callback is: + - The 1st argument: The instruction + - The 2nd argument: The input random variables + - The 3rd argument: The attributes + - The 4th argument: The decision + - Return: A new decision + """ + _ffi_api.TraceApplyToSchedule( # type: ignore # pylint: disable=no-member + self, + sch, + remove_postproc, + decision_provider, + ) + + def as_json(self, remove_postproc: bool = False) -> JSON_TYPE: + """Serialize the trace as a JSON-style object + + Parameters + ---------- + remove_postproc : bool = False + If postprocessing instructions are removed + + Returns + ------- + json: JSON_TYPE + The JSON-style object + """ + obj = _ffi_api.TraceAsJSON(self, remove_postproc) # type: ignore # pylint: disable=no-member + return _json_from_tvm(obj) + + def as_python(self, remove_postproc: bool = False) -> List[str]: + """Serialize the trace as a sequence of python statements + + Parameters + ---------- + remove_postproc : bool = False + If postprocessing instructions are removed + + Returns + ------- + py_stmts: List[str] + A sequence of python statements + """ + return _ffi_api.TraceAsPython(self, remove_postproc) # type: ignore # pylint: disable=no-member + + def with_decision( + self, + inst: Instruction, + decision: DECISION_TYPE, + remove_postproc: bool, + ) -> "Trace": + """Create a new trace with an instruction whose decision is changed, + assuming this instruction exists in the resulting trace + + Parameters + ---------- + inst : Instruction + The instruction whose decision is to be changed + decision : DECISION_TYPE + The decision to be changed to + remove_postproc : bool + If postprocessing instructions are removed + + Returns + ------- + trace: Trace + The new trace with the decision changed + """ + return _ffi_api.TraceWithDecision( # type: ignore # pylint: disable=no-member + self, + inst, + decision, + remove_postproc, + ) + + def simplified(self, remove_postproc: bool) -> "Trace": + """Simplify the trace with dead-code elimination + + Parameters + ---------- + remove_postproc : bool + If postprocessing instructions are removed + + Returns + ------- + trace: Trace + A simplified trace + """ + return _ffi_api.TraceSimplified(self, remove_postproc) # type: ignore # pylint: disable=no-member + + @staticmethod + def apply_json_to_schedule(json_obj: JSON_TYPE, sch: "Schedule") -> None: + """Apply a JSON-serialized trace to a TensorIR schedule + + Parameters + ---------- + json_obj : JSON_TYPE + The JSON-serialized trace + sch : Schedule + The TensorIR schedule + """ + _ffi_api.TraceApplyJSONToSchedule(json_obj, sch) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index dd7665a56692..d57077f08b52 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -62,7 +62,9 @@ class LetStmt(Stmt): """ def __init__(self, var, value, body, span=None): - self.__init_handle_by_constructor__(_ffi_api.LetStmt, var, value, body, span) + self.__init_handle_by_constructor__( + _ffi_api.LetStmt, var, value, body, span # type: ignore + ) @tvm._ffi.register_object("tir.AssertStmt") @@ -85,7 +87,9 @@ class AssertStmt(Stmt): """ def __init__(self, condition, message, body, span=None): - self.__init_handle_by_constructor__(_ffi_api.AssertStmt, condition, message, body, span) + self.__init_handle_by_constructor__( + _ffi_api.AssertStmt, condition, message, body, span # type: ignore + ) class ForKind(IntEnum): @@ -148,7 +152,7 @@ def __init__( span=None, ): self.__init_handle_by_constructor__( - _ffi_api.For, + _ffi_api.For, # type: ignore loop_var, min_val, extent, @@ -178,7 +182,7 @@ class While(Stmt): def __init__(self, condition, body, span=None): self.__init_handle_by_constructor__( - _ffi_api.While, + _ffi_api.While, # type: ignore condition, body, span, @@ -209,9 +213,9 @@ class Store(Stmt): def __init__(self, buffer_var, value, index, predicate=None, span=None): if predicate is None: - predicate = _ffi_api.const_true(value.dtype, span) + predicate = _ffi_api.const_true(value.dtype, span) # type: ignore self.__init_handle_by_constructor__( - _ffi_api.Store, buffer_var, value, index, predicate, span + _ffi_api.Store, buffer_var, value, index, predicate, span # type: ignore ) @@ -235,7 +239,9 @@ class BufferStore(Stmt): """ def __init__(self, buffer, value, indices, span=None): - self.__init_handle_by_constructor__(_ffi_api.BufferStore, buffer, value, indices, span) + self.__init_handle_by_constructor__( + _ffi_api.BufferStore, buffer, value, indices, span # type: ignore + ) @tvm._ffi.register_object("tir.BufferRealize") @@ -262,7 +268,7 @@ class BufferRealize(Stmt): def __init__(self, buffer, bounds, condition, body, span=None): self.__init_handle_by_constructor__( - _ffi_api.BufferRealize, buffer, bounds, condition, body, span + _ffi_api.BufferRealize, buffer, bounds, condition, body, span # type: ignore ) @@ -286,7 +292,9 @@ class ProducerStore(Stmt): """ def __init__(self, producer, value, indices, span=None): - self.__init_handle_by_constructor__(_ffi_api.ProducerStore, producer, value, indices, span) + self.__init_handle_by_constructor__( + _ffi_api.ProducerStore, producer, value, indices, span # type: ignore + ) @tvm._ffi.register_object("tir.Allocate") @@ -316,7 +324,7 @@ class Allocate(Stmt): def __init__(self, buffer_var, dtype, extents, condition, body, span=None): self.__init_handle_by_constructor__( - _ffi_api.Allocate, buffer_var, dtype, extents, condition, body, span + _ffi_api.Allocate, buffer_var, dtype, extents, condition, body, span # type: ignore ) @@ -343,7 +351,9 @@ class AttrStmt(Stmt): """ def __init__(self, node, attr_key, value, body, span=None): - self.__init_handle_by_constructor__(_ffi_api.AttrStmt, node, attr_key, value, body, span) + self.__init_handle_by_constructor__( + _ffi_api.AttrStmt, node, attr_key, value, body, span # type: ignore + ) @tvm._ffi.register_object("tir.ProducerRealize") @@ -364,13 +374,22 @@ class ProducerRealize(Stmt): body : Stmt The realize body + storage_scope : str + The storage scope associated with this realization + span : Optional[Span] The location of this itervar in the source code. """ - def __init__(self, producer, bounds, condition, body, span=None): + def __init__(self, producer, bounds, condition, body, storage_scope="", span=None): self.__init_handle_by_constructor__( - _ffi_api.ProducerRealize, producer, bounds, condition, body, span + _ffi_api.ProducerRealize, + producer, + bounds, + condition, + body, + storage_scope, + span, # type: ignore ) @@ -388,7 +407,7 @@ class SeqStmt(Stmt): """ def __init__(self, seq, span=None): - self.__init_handle_by_constructor__(_ffi_api.SeqStmt, seq, span) + self.__init_handle_by_constructor__(_ffi_api.SeqStmt, seq, span) # type: ignore def __getitem__(self, i): return self.seq[i] @@ -418,7 +437,7 @@ class IfThenElse(Stmt): def __init__(self, condition, then_case, else_case, span=None): self.__init_handle_by_constructor__( - _ffi_api.IfThenElse, condition, then_case, else_case, span + _ffi_api.IfThenElse, condition, then_case, else_case, span # type: ignore ) @@ -436,7 +455,7 @@ class Evaluate(Stmt): """ def __init__(self, value, span=None): - self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span) + self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span) # type: ignore @tvm._ffi.register_object("tir.Prefetch") @@ -456,7 +475,7 @@ class Prefetch(Stmt): """ def __init__(self, buffer, bounds, span=None): - self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds, span) + self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds, span) # type: ignore @tvm._ffi.register_object("tir.BufferRegion") @@ -476,7 +495,7 @@ class BufferRegion(Object): region: List[Range] def __init__(self, buffer: Buffer, region: List[Range]): - self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer, region) + self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer, region) # type: ignore @tvm._ffi.register_object("tir.MatchBufferRegion") @@ -496,7 +515,9 @@ class MatchBufferRegion(Object): source: BufferRegion def __init__(self, buffer: Buffer, source: BufferRegion): - self.__init_handle_by_constructor__(_ffi_api.MatchBufferRegion, buffer, source) + self.__init_handle_by_constructor__( + _ffi_api.MatchBufferRegion, buffer, source # type: ignore + ) @tvm._ffi.register_object("tir.Block") @@ -567,7 +588,7 @@ def __init__( if annotations is None: annotations = {} self.__init_handle_by_constructor__( - _ffi_api.Block, + _ffi_api.Block, # type: ignore iter_vars, reads, writes, @@ -578,7 +599,7 @@ def __init__( match_buffers, annotations, span, - ) + ) # type: ignore @tvm._ffi.register_object("tir.BlockRealize") @@ -615,12 +636,12 @@ def __init__( if isinstance(predicate, bool): predicate = const(predicate, "bool") self.__init_handle_by_constructor__( - _ffi_api.BlockRealize, + _ffi_api.BlockRealize, # type: ignore iter_values, predicate, block, span, - ) + ) # type: ignore def stmt_seq(*args): diff --git a/python/tvm/tir/stmt_functor.py b/python/tvm/tir/stmt_functor.py index 4ec755cdf922..56dc1c20c2b3 100644 --- a/python/tvm/tir/stmt_functor.py +++ b/python/tvm/tir/stmt_functor.py @@ -43,7 +43,7 @@ def ir_transform(stmt, preorder, postorder, only_enable=None): result : tvm.tir.Stmt The result. """ - return _ffi_api.IRTransform(stmt, preorder, postorder, only_enable) + return _ffi_api.IRTransform(stmt, preorder, postorder, only_enable) # type: ignore def post_order_visit(stmt, fvisit): @@ -55,7 +55,7 @@ def post_order_visit(stmt, fvisit): fvisit: function The visitor function. """ - return _ffi_api.PostOrderVisit(stmt, fvisit) + return _ffi_api.PostOrderVisit(stmt, fvisit) # type: ignore def substitute(node, vmap): @@ -74,4 +74,4 @@ def substitute(node, vmap): result : tvm.tir.Stmt The result. """ - return _ffi_api.Substitute(node, vmap) + return _ffi_api.Substitute(node, vmap) # type: ignore diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py index 374e731725be..9450ade34e67 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tir/transform/function_pass.py @@ -18,6 +18,7 @@ import inspect import types import functools +from typing import Callable, List, Optional, Union import tvm._ffi from tvm.ir.transform import Pass, PassInfo @@ -47,7 +48,10 @@ def __init__(self, *args, **kwargs): def _pass_func(func, mod, ctx): return inst.transform_function(func, mod, ctx) - self.__init_handle_by_constructor__(_ffi_api.CreatePrimFuncPass, _pass_func, pass_info) + self.__init_handle_by_constructor__( + _ffi_api.CreatePrimFuncPass, _pass_func, pass_info # type: ignore + ) + self._inst = inst def __getattr__(self, name): @@ -61,7 +65,12 @@ def __getattr__(self, name): return PyFunctionPass -def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None): +def prim_func_pass( + pass_func=None, + opt_level: int = None, + name: Optional[str] = None, + required: Optional[List[str]] = None, +) -> Union[Callable, PrimFuncPass]: """Decorate a function pass. This function returns a callback when pass_func @@ -123,7 +132,7 @@ def transform(func, mod, ctx): assert isinstance(function_pass, transform.FunctionPass) assert function_pass.info.opt_level == 2 - # Given a module m, the optimization could be invoked as the follwoing: + # Given a module m, the optimization could be invoked as the following: updated_mod = function_pass(m) # Now constant folding should have been applied to every function in # the provided module m. And the updated module will be returned. @@ -144,7 +153,7 @@ def create_function_pass(pass_arg): return _wrap_class_function_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): raise TypeError("pass_func must be a callable for Module pass") - return _ffi_api.CreatePrimFuncPass(pass_arg, info) + return _ffi_api.CreatePrimFuncPass(pass_arg, info) # type: ignore if pass_func: return create_function_pass(pass_func) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 51330f80afc6..2183319a006f 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -16,6 +16,7 @@ # under the License. """Wrapping existing transformations.""" # pylint: disable=invalid-name +from typing import Optional from . import _ffi_api from . import function_pass as _fpass @@ -39,7 +40,7 @@ def Apply(ftransform): def _transform(func, mod, ctx): return ftransform(func) - return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") + return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") # type: ignore def Filter(fcond): @@ -59,7 +60,7 @@ def Filter(fcond): def _transform(func, mod, ctx): return func if fcond(func) else None - return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") + return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") # type: ignore def InjectPrefetch(): @@ -70,10 +71,10 @@ def InjectPrefetch(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.InjectPrefetch() + return _ffi_api.InjectPrefetch() # type: ignore -def StorageFlatten(cache_line_size, create_bound_attribute=False): +def StorageFlatten(cache_line_size, create_bound_attribute: bool = False): """Flatten the multi-dimensional read/write to 1D. @@ -91,10 +92,25 @@ def StorageFlatten(cache_line_size, create_bound_attribute=False): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.StorageFlatten(cache_line_size, create_bound_attribute) + return _ffi_api.StorageFlatten(cache_line_size, create_bound_attribute) # type: ignore -def InjectCopyIntrin(pragma_key, fintrin): +def TextureFlatten(): + """Flatten the multi-dimensional read/write to 2D. + + + Parameters + ---------- + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.TextureFlatten() # type: ignore + + +def InjectCopyIntrin(pragma_key: str, fintrin): """Inject virtual thread loops. Parameters @@ -110,7 +126,7 @@ def InjectCopyIntrin(pragma_key, fintrin): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.InjectCopyIntrin(pragma_key, fintrin) + return _ffi_api.InjectCopyIntrin(pragma_key, fintrin) # type: ignore def CoProcSync(): @@ -121,10 +137,10 @@ def CoProcSync(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.CoProcSync() + return _ffi_api.CoProcSync() # type: ignore -def LiftAttrScope(attr_key): +def LiftAttrScope(attr_key: str): """Lift common attrs with attr_key to outer scope. Parameters @@ -137,7 +153,7 @@ def LiftAttrScope(attr_key): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LiftAttrScope(attr_key) + return _ffi_api.LiftAttrScope(attr_key) # type: ignore def LoopPartition(): @@ -148,10 +164,10 @@ def LoopPartition(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LoopPartition() + return _ffi_api.LoopPartition() # type: ignore -def VectorizeLoop(enable_vectorize=True): +def VectorizeLoop(enable_vectorize: bool = True): """Lower vectorization loops. Parameters @@ -165,7 +181,7 @@ def VectorizeLoop(enable_vectorize=True): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.VectorizeLoop(enable_vectorize) + return _ffi_api.VectorizeLoop(enable_vectorize) # type: ignore def InjectVirtualThread(): @@ -176,7 +192,7 @@ def InjectVirtualThread(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.InjectVirtualThread() + return _ffi_api.InjectVirtualThread() # type: ignore def InjectDoubleBuffer(): @@ -187,7 +203,7 @@ def InjectDoubleBuffer(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.InjectDoubleBuffer() + return _ffi_api.InjectDoubleBuffer() # type: ignore def StorageRewrite(): @@ -202,7 +218,7 @@ def StorageRewrite(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.StorageRewrite() + return _ffi_api.StorageRewrite() # type: ignore def UnrollLoop(): @@ -215,7 +231,7 @@ def UnrollLoop(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.UnrollLoop() + return _ffi_api.UnrollLoop() # type: ignore def RemoveNoOp(): @@ -226,7 +242,7 @@ def RemoveNoOp(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.RemoveNoOp() + return _ffi_api.RemoveNoOp() # type: ignore def BF16Legalize(): @@ -238,7 +254,7 @@ def BF16Legalize(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.BF16Legalize() + return _ffi_api.BF16Legalize() # type: ignore def BF16Promote(): @@ -250,7 +266,7 @@ def BF16Promote(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.BF16Promote() + return _ffi_api.BF16Promote() # type: ignore def BF16CastElimination(): @@ -269,7 +285,7 @@ def BF16CastElimination(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.BF16CastElimination() + return _ffi_api.BF16CastElimination() # type: ignore def BF16TypeLowering(): @@ -281,7 +297,7 @@ def BF16TypeLowering(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.BF16TypeLowering() + return _ffi_api.BF16TypeLowering() # type: ignore def RewriteUnsafeSelect(): @@ -292,7 +308,7 @@ def RewriteUnsafeSelect(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.RewriteUnsafeSelect() + return _ffi_api.RewriteUnsafeSelect() # type: ignore def Simplify(): @@ -303,7 +319,7 @@ def Simplify(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.Simplify() + return _ffi_api.Simplify() # type: ignore def InstrumentBoundCheckers(): @@ -314,7 +330,7 @@ def InstrumentBoundCheckers(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.InstrumentBoundCheckers() + return _ffi_api.InstrumentBoundCheckers() # type: ignore def LowerCustomDatatypes(): @@ -327,24 +343,25 @@ def LowerCustomDatatypes(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerCustomDatatypes() + return _ffi_api.LowerCustomDatatypes() # type: ignore -def MakePackedAPI(num_unpacked_params=0): +def MakePackedAPI(num_unpacked_params: int = -1): """Transform the PrimFuncs in the module to a packed func API. Parameters ---------- num_unpacked_params : int Number of parameters that we hope to directly pass via normal arguments - following the PackedFunc input signature. + following the PackedFunc input signature. If it is specified as -1 or it + is less than the number of arguments, the pass will packed arguments still. Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.MakePackedAPI(num_unpacked_params) + return _ffi_api.MakePackedAPI(num_unpacked_params) # type: ignore def MakeUnpackedAPI(): @@ -355,7 +372,7 @@ def MakeUnpackedAPI(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.MakeUnpackedAPI() + return _ffi_api.MakeUnpackedAPI() # type: ignore def SplitHostDevice(): @@ -366,7 +383,7 @@ def SplitHostDevice(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.SplitHostDevice() + return _ffi_api.SplitHostDevice() # type: ignore def DecorateDeviceScope(): @@ -377,7 +394,7 @@ def DecorateDeviceScope(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.DecorateDeviceScope() + return _ffi_api.DecorateDeviceScope() # type: ignore def SkipAssert(): @@ -388,10 +405,10 @@ def SkipAssert(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.SkipAssert() + return _ffi_api.SkipAssert() # type: ignore -def ThreadSync(storage_scope): +def ThreadSync(storage_scope: str): """Insert sync between parallel read/write of shared buffers. Parameters @@ -404,7 +421,7 @@ def ThreadSync(storage_scope): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.ThreadSync(storage_scope) + return _ffi_api.ThreadSync(storage_scope) # type: ignore def LowerThreadAllreduce(): @@ -415,7 +432,7 @@ def LowerThreadAllreduce(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerThreadAllreduce() + return _ffi_api.LowerThreadAllreduce() # type: ignore def InferFragment(): @@ -426,7 +443,7 @@ def InferFragment(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.InferFragment() + return _ffi_api.InferFragment() # type: ignore def LowerWarpMemory(): @@ -437,7 +454,7 @@ def LowerWarpMemory(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerWarpMemory() + return _ffi_api.LowerWarpMemory() # type: ignore def LowerTVMBuiltin(): @@ -448,7 +465,7 @@ def LowerTVMBuiltin(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerTVMBuiltin() + return _ffi_api.LowerTVMBuiltin() # type: ignore def LegalizePackedCalls(): @@ -459,7 +476,7 @@ def LegalizePackedCalls(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LegalizePackedCalls() + return _ffi_api.LegalizePackedCalls() # type: ignore def LowerIntrin(): @@ -470,7 +487,7 @@ def LowerIntrin(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerIntrin() + return _ffi_api.LowerIntrin() # type: ignore def LowerDeviceStorageAccessInfo(): @@ -485,7 +502,7 @@ def LowerDeviceStorageAccessInfo(): ---- Run this pass after all storage access analysis finish. """ - return _ffi_api.LowerDeviceStorageAccessInfo() + return _ffi_api.LowerDeviceStorageAccessInfo() # type: ignore def CombineContextCall(): @@ -496,10 +513,10 @@ def CombineContextCall(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.CombineContextCall() + return _ffi_api.CombineContextCall() # type: ignore -def NarrowDataType(target_bits): +def NarrowDataType(target_bits: int): """Narrow down PrimExpr datatype in stmt to target_bits. Parameters @@ -516,7 +533,7 @@ def NarrowDataType(target_bits): ---- Run this pass after StorageFlatten. """ - return _ffi_api.NarrowDataType(target_bits) + return _ffi_api.NarrowDataType(target_bits) # type: ignore def VerifyMemory(): @@ -527,12 +544,12 @@ def VerifyMemory(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.VerifyMemory() + return _ffi_api.VerifyMemory() # type: ignore # pylint: disable=no-else-return,inconsistent-return-statements -def HoistIfThenElse(variant=None): - """Hoist loop-invariant IfThenElse nodes to outside the elligible loops. +def HoistIfThenElse(variant: Optional[str] = None): + """Hoist loop-invariant IfThenElse nodes to outside the eligible loops. Parameters ---------- @@ -540,7 +557,7 @@ def HoistIfThenElse(variant=None): The variant of the pass. variant can have any one of following values ["basic", None(Default)]. - The basic variant supports basic hoisting scenarios where it exepects + The basic variant supports basic hoisting scenarios where it expects the For & If Nodes are in place consecutively and does not involve global scope variables or more advanced scenarios. @@ -555,20 +572,20 @@ def HoistIfThenElse(variant=None): The result pass """ if variant == "basic": - return _ffi_api.HoistIfThenElseBasic() + return _ffi_api.HoistIfThenElseBasic() # type: ignore elif variant is None: - return _ffi_api.HoistIfThenElse() + return _ffi_api.HoistIfThenElse() # type: ignore def LowerInitBlock(): - """Lower block init stmt into IfThenElse stmts + """Lower block init stmt into IfThenElse statements. Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerInitBlock() + return _ffi_api.LowerInitBlock() # type: ignore def PlanAndUpdateBufferAllocationLocation(): @@ -581,7 +598,7 @@ def PlanAndUpdateBufferAllocationLocation(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.PlanAndUpdateBufferAllocationLocation() + return _ffi_api.PlanAndUpdateBufferAllocationLocation() # type: ignore def ConvertBlocksToOpaque(): @@ -594,7 +611,7 @@ def ConvertBlocksToOpaque(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.ConvertBlocksToOpaque() + return _ffi_api.ConvertBlocksToOpaque() # type: ignore def CompactBufferAllocation(): @@ -639,7 +656,18 @@ def CompactBufferAllocation(): The result pass """ - return _ffi_api.CompactBufferAllocation() + return _ffi_api.CompactBufferAllocation() # type: ignore + + +def LowerMatchBuffer(): + """Remove match buffers inside the block. Also, it will validate the binding. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerMatchBuffer() # type: ignore def FlattenBuffer(): @@ -652,4 +680,38 @@ def FlattenBuffer(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.FlattenBuffer() + return _ffi_api.FlattenBuffer() # type: ignore + + +def UnifyThreadBinding(): + """Unify all the thread bindings for "blockIdx.x/y/z", + "threadIdx.x/y/z", and "vthread.x/y/z". Before the unification, + two vars that are bound to a thread axis (e.g., "threadIdx.x") + use different IterVars and variables in their AttrStmts. After + the unification, we use a consolidated IterVar and a variable + for them. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + + Note + ---- + `vthread` is a legacy behavior that will be deprecated, though + thread bindings of `vthread` are still also unified in this + pass. Please use `vthread.x`, `vthread.y` and `vthread.z` instead. + """ + return _ffi_api.UnifyThreadBinding() # type: ignore + + +def MergeDynamicSharedMemoryAllocations(): + """This pass merges multiple TIR-level dynamic shared memory allocations + into one allocation. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.MergeDynamicSharedMemoryAllocations() # type: ignore diff --git a/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py b/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py index 7651305ab2dd..4eed56a22572 100644 --- a/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py +++ b/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py @@ -273,7 +273,9 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_ data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0]) # ==================== define configuration space ==================== - n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW) + # If it has dynamic shape in batch, we fix the split factor to 1 + n = cfg.axis(N) if isinstance(N, int) else cfg.axis(1) + oc, oh, ow = cfg.axis(OC), cfg.axis(OH), cfg.axis(OW) ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW) if num_tile == 2: # for arm cpu diff --git a/python/tvm/topi/arm_cpu/group_conv2d.py b/python/tvm/topi/arm_cpu/group_conv2d.py index d852b9acef66..81b2c7260f05 100644 --- a/python/tvm/topi/arm_cpu/group_conv2d.py +++ b/python/tvm/topi/arm_cpu/group_conv2d.py @@ -42,7 +42,9 @@ def schedule_group_conv2d_nchw(outs): return schedule_group_conv2d_nchwc(outs) -def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, layout="NCHW"): +def _get_default_config( + cfg, data, kernel, strides, padding, dilation, groups, out_dtype, layout="NCHW" +): """ Get default schedule config for the workload """ @@ -54,7 +56,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, static_data_shape.append(dim) data = te.placeholder(static_data_shape, dtype=data.dtype) - wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout) + wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype, layout) _fallback_schedule(cfg, wkl) @@ -158,6 +160,7 @@ def group_conv2d_nchw_spatial_pack( ), strides, padding, + dilation, groups, out_dtype, ) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index fb91912f29a0..bd556d2976da 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -27,9 +27,49 @@ @autotvm.register_topi_compute("batch_matmul.cuda") -def batch_matmul(cfg, x, y, out_shape=None): - """Compute conv2d with NCHW layout""" - return nn.batch_matmul(x, y) +def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True): + """Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + + Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. + + Parameters + ---------- + cfg : ConfigSpace + Autotvm tuning space config file. + + tensor_a : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. + + tensor_b : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. + + Returns + ------- + output : tvm.te.Tensor + 3-D with shape [batch, M, N] + """ + return nn.batch_matmul( + x, + y, + oshape=out_shape, + out_dtype=out_dtype, + transpose_a=transpose_a, + transpose_b=transpose_b, + ) @autotvm.register_topi_schedule("batch_matmul.cuda") @@ -140,31 +180,54 @@ def _callback(op): @autotvm.register_topi_compute("batch_matmul_cublas.cuda") -def batch_matmul_cublas(cfg, x, y, out_shape=None): - """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. +def batch_matmul_cublas( + cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): + """Compute batch matrix multiplication of `x` and `y`. + + Both `x` and `y` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. Parameters ---------- + cfg : ConfigSpace + Autotvm tuning space config file. + x : tvm.te.Tensor - 3-D with shape [batch, M, K] + 3-D with shape [batch, M, K] or [batch, K, M]. y : tvm.te.Tensor - 3-D with shape [batch, N, K] + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. - out_shape : None - The output shape + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. Returns ------- output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - b, m, k = get_const_tuple(x.shape) - b, n, k = get_const_tuple(y.shape) + if transpose_a: + b, k, m = get_const_tuple(x.shape) + else: + b, m, k = get_const_tuple(x.shape) + if transpose_b: + b, n, k = get_const_tuple(y.shape) + else: + b, k, n = get_const_tuple(y.shape) if all([isinstance(s, int) for s in [b, m, n, k]]): cfg.add_flop(b * m * k * n * 2) - return cublas.batch_matmul(x, y, False, True) + return cublas.batch_matmul(x, y, transa=transpose_a, transb=transpose_b) @autotvm.register_topi_schedule("batch_matmul_cublas.cuda") @@ -174,8 +237,43 @@ def schedule_batch_matmul_cublas(_, outs): @autotvm.register_topi_compute("batch_matmul_int8.cuda") -def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None): - """Batch Matmul operator for int8 on CUDA""" +def batch_matmul_int8( + cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): + """Batch Matmul operator for int8 on CUDA. + + Parameters + ---------- + cfg : ConfigSpace + Autotvm tuning space config file. + + x : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. + + y : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. + + Returns + ------- + output : tvm.te.Tensor + 3-D with shape [batch, M, N] + """ + del out_shape + # TODO(jcf94): Deal with different transpose combinations + assert not transpose_a and transpose_b if out_dtype is None: out_dtype = x.dtype diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py index a56d3c36ba33..5324302051ba 100644 --- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py +++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py @@ -29,9 +29,14 @@ @autotvm.register_topi_compute("batch_matmul_tensorcore.cuda") -def batch_matmul_tensorcore(cfg, x, y, out_shape=None, out_dtype=None): +def batch_matmul_tensorcore( + cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """batch matmul tensorcore operator on cuda""" - # todo: deal with out_shape for broadcast, liuxin.ai + # TODO(jcf94): Deal with different transpose combinations + assert not transpose_a and transpose_b + # TODO(liuxin.ai): Deal with out_shape for broadcast + del out_shape return batch_matmul_tensorcore_cuda(x, y, out_dtype) diff --git a/python/tvm/topi/cuda/conv1d_transpose_ncw.py b/python/tvm/topi/cuda/conv1d_transpose_ncw.py index 58f53eab20ac..2098aa9089c6 100644 --- a/python/tvm/topi/cuda/conv1d_transpose_ncw.py +++ b/python/tvm/topi/cuda/conv1d_transpose_ncw.py @@ -142,7 +142,7 @@ def _callback(op): ##### space definition begin ##### n, f, x = s[conv].op.axis rc = s[conv].op.reduce_axis[0] - cfg.define_split("tile_n", cfg.axis(n), num_outputs=4) + cfg.define_split("tile_n", cfg.axis(n if isinstance(n, int) else 1), num_outputs=4) cfg.define_split("tile_f", cfg.axis(f), num_outputs=4) cfg.define_split("tile_x", cfg.axis(x), num_outputs=4) cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py index a199534ccb51..8338208dd968 100644 --- a/python/tvm/topi/cuda/conv2d.py +++ b/python/tvm/topi/cuda/conv2d.py @@ -86,14 +86,13 @@ def conv2d_cudnn( # handle dilation stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation + KH_dilated = (KH - 1) * dilation_h + 1 + KW_dilated = (KW - 1) * dilation_h + 1 - if ( - isinstance(padding, (list, tuple)) - and len(padding) == 4 - and (padding[0] != padding[2] or padding[1] != padding[3]) - ): + pt, pl, pb, pr = get_pad_tuple(padding, (KH_dilated, KW_dilated)) + if (pt != pb) or (pl != pr): raise ValueError("Cudnn doesn't support asymmetric padding.") - pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) + OH = (H + pt + pb - KH) // stride_h + 1 OW = (W + pl + pr - KW) // stride_w + 1 diff --git a/python/tvm/topi/cuda/conv2d_int8.py b/python/tvm/topi/cuda/conv2d_int8.py index 001411d6e4c9..02470bab5228 100644 --- a/python/tvm/topi/cuda/conv2d_int8.py +++ b/python/tvm/topi/cuda/conv2d_int8.py @@ -98,9 +98,9 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_ ) out_channels, in_channels, kernel_h, kernel_w = get_const_tuple(kernel.shape) - assert out_channels % 4 == 0, "Number of output channels should be multiple of {}".format( - oc_block_factor - ) + assert ( + out_channels % oc_block_factor == 0 + ), "Number of output channels should be multiple of {}".format(oc_block_factor) packed_kernel = te.compute( ( out_channels // oc_block_factor, diff --git a/python/tvm/topi/cuda/conv2d_nhwc.py b/python/tvm/topi/cuda/conv2d_nhwc.py index 991585587bbf..e4361e30b5c3 100644 --- a/python/tvm/topi/cuda/conv2d_nhwc.py +++ b/python/tvm/topi/cuda/conv2d_nhwc.py @@ -43,12 +43,15 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): AL = s.cache_read(AA, "local", [OL]) WL = s.cache_read(WW, "local", [OL]) + # Currently Conv2d NHWC only support dynamic shpe in batch + dynamic_batch = isinstance(s[output].op.axis[0].dom.extent, tvm.tir.expr.Var) + # Schedule for autotvm - cfg.define_knob("tile_n", [2, 4, 8]) + cfg.define_knob("tile_n", [1] if dynamic_batch else [2, 4, 8]) cfg.define_knob("tile_c", [2, 4, 8]) - cfg.define_knob("num_thread_n", [4, 8, 16]) + cfg.define_knob("num_thread_n", [1] if dynamic_batch else [4, 8, 16]) cfg.define_knob("num_thread_c", [4, 8, 16]) - cfg.define_knob("vthread_n", [1, 2]) + cfg.define_knob("vthread_n", [1] if dynamic_batch else [1, 2]) cfg.define_knob("vthread_c", [1, 2]) cfg.define_knob("step", [16, 3, 32, 64]) diff --git a/python/tvm/topi/cuda/injective.py b/python/tvm/topi/cuda/injective.py index cce56b796cea..0faddc31c25a 100644 --- a/python/tvm/topi/cuda/injective.py +++ b/python/tvm/topi/cuda/injective.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name, unused-variable, """Schedule for composition of injective operator""" +import numpy as np + import tvm from tvm import te from .. import utils @@ -36,13 +38,21 @@ def schedule_injective_from_existing(sch, out): sch: Schedule The updated schedule. """ + + def find_nearest_small_factor(num, target): + """Find the nearest factor of the given number that is smaller than the target.""" + for i in range(target, 0, -1): + if num % i == 0: + return i + # Unreachable because i=1 must hold. + return -1 + fused = sch[out].fuse(*sch[out].op.axis) num_thread = tvm.target.Target.current(allow_none=False).max_num_threads max_block = 256 - # vectorize on fp16 data type. This allows to better utilize the memory - # bandwidth. - vector_width = 4 if out.dtype == "float16" else 1 + # Vectorize on fp16 data type to enable half2 for better memory bandwidth utilization. + vector_width = 2 if out.dtype == "float16" else 1 is_dynamic_output = False for dim in out.shape: @@ -54,6 +64,26 @@ def schedule_injective_from_existing(sch, out): try: const_size = utils.get_const_int(out_len) + + # Adjust block and thread to make sure they are dividable so that vectorize can be + # correctly applied. + if vector_width > 1 and const_size % vector_width == 0: + remain_total_size = const_size // vector_width + cand_sizes = [] + for max_size in [num_thread, max_block]: + cand_sizes.append( + max_size + if remain_total_size % max_size == 0 + else find_nearest_small_factor(remain_total_size, max_size) + ) + remain_total_size //= cand_sizes[-1] + + # If the product of candidate dividable (block * thread) is too small, + # then the performance may be worse even half2 is enabled. Note that 0.7 + # is just a heuristic ratio and may not be optimal for all workloads. + if np.prod(cand_sizes) / (max_block * num_thread) >= 0.7: + num_thread, max_block = cand_sizes + need_block_split = const_size > max_block * num_thread * vector_width except ValueError: need_block_split = False diff --git a/python/tvm/topi/cuda/reduction.py b/python/tvm/topi/cuda/reduction.py index b9d02d9c81d8..fb6a3bfd5174 100644 --- a/python/tvm/topi/cuda/reduction.py +++ b/python/tvm/topi/cuda/reduction.py @@ -17,6 +17,8 @@ # pylint: disable=invalid-name,unused-variable,too-many-locals,len-as-condition """Schedule for reduce operators""" from __future__ import absolute_import as _abs +from operator import mul +from functools import reduce import tvm from tvm import te from .. import tag @@ -80,16 +82,40 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): if is_idx_reduce: sch[temp_idx_input].compute_at(sch[real_output], outer_in) sch[temp_val_input].compute_at(sch[real_output], outer_in) + sch[real_output].set_store_predicate( + tvm.tir.all( + thread_x.equal(0), block_x * num_thread + thread_y < reduce(mul, real_output.shape) + ) + ) else: if is_idx_reduce: spatial_axis = sch[real_output].fuse(*(sch[real_output].op.axis)) sch[real_output].bind(spatial_axis, te.thread_axis("blockIdx.x")) sch[temp_idx_input].compute_at(sch[real_output], spatial_axis) sch[temp_val_input].compute_at(sch[real_output], spatial_axis) - sch[real_output].set_store_predicate(thread_x.equal(0)) + sch[real_output].set_store_predicate(thread_x.equal(0)) return sch +def _enable_auto_inline(sch): + def is_scheduled(stage): + # auto inline requires the attach type is AttachType.kGroupRoot + conds = [ + len(stage.relations) == 0, + stage.attach_type == 1, + stage.all_iter_vars == stage.leaf_iter_vars, + ] + if not all(conds): + return True + return False + + for s in sch.stages: + if not s.is_output and isinstance(s.op, tvm.te.ComputeOp): + if is_scheduled(s) or len(s.op.reduce_axis) != 0: + return False + return True + + def schedule_reduce(outs): """Schedule for inject->reduce->bcast ops. @@ -107,6 +133,7 @@ def schedule_reduce(outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs sch = te.create_schedule([x.op for x in outs]) scheduled_ops = [] + enable_auto_inline = _enable_auto_inline(sch) def traverse_before_reduce(operator): """Internal traverse function""" @@ -128,7 +155,11 @@ def traverse_after_reduce(operator): if operator not in scheduled_ops: schedule_injective_from_existing(sch, operator.output(0)) for tensor in operator.input_tensors: - traverse_after_reduce(tensor.op) + if tensor.op not in scheduled_ops: + if enable_auto_inline: + traverse_before_reduce(tensor.op) + else: + traverse_after_reduce(tensor.op) elif operator.tag == "comm_reduce": if operator not in scheduled_ops: _schedule_reduce(operator, sch, is_idx_reduce=False) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index cee13d7e01a2..fa7545cd323a 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -772,9 +772,10 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) - # We combine all the indices dimensions but the first one into a single - # dimension so we can iterate it in single loop instead of an arbitrary - # number of loops. We do the same thing for all the update dimensions. + atomic_add_return = ib.allocate( + updates.dtype, (1,), name="atomic_add_return", scope="local" + ) + fused_indices_dimension = 1 for i in indices_ptr.shape[1:]: fused_indices_dimension *= i @@ -787,42 +788,91 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): for i in data_ptr.shape: fused_shape *= i - # For now we avoid parallizing over dimensions indexed by `indices` as - # there may be repeated indices and hadling parallel accumulation can - # be hard. So we parallelize over X_M .. X_{N-1} instead. This will - # work well when these dimensions are large enough to saturate memory - # bandwidth, but performance will be bad when these dimensions are - # small. - bx = te.thread_axis("blockIdx.x") - tx = te.thread_axis("threadIdx.x") max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) tdim = min(max_threads, fused_updates_dimension) - ib.scope_attr(tx, "thread_extent", tdim) - bdim = ceil_div(fused_updates_dimension, tdim) - ib.scope_attr(bx, "thread_extent", bdim) - with ib.for_range(0, ceil_div(fused_shape, bdim)) as i: - index = i * fused_updates_dimension + bx * tdim + tx + with ib.new_scope(): + bdim = ceil_div(fused_shape, tdim) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", bdim) + ib.scope_attr(tx, "thread_extent", tdim) + + index = bx * tdim + tx with ib.if_scope(index < fused_shape): out[index] = data[index] - with ib.for_range(0, fused_indices_dimension) as i: - j = bx * tdim + tx - with ib.if_scope(j < fused_updates_dimension): - offset = fused_updates_dimension - index = j # This is x_M, .. x_{N-1} part of the index into out. - # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part - # of the index into out. - for l in reversed(range(indices_ptr.shape[0].value)): - # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] - index += offset * indices[i + l * fused_indices_dimension] - offset *= data_ptr.shape[l] - if mode == "update": - out[index] = updates[i * fused_updates_dimension + j] - elif mode == "add": - out[index] += updates[i * fused_updates_dimension + j] - else: - raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) + # For better performance, we introduce blockIdx.y to implement for-loops + # within one thread. + # The code is parallel over the scattered indices, so we use atomic_add + # to guarantee correctness when mode=="add" + + # For now, atomic is not supported by target "vulkan", "metal", or "cuda" with "int64" + # So we fallback to normal algorithm, using "+=" rather than atomic_add + + # TODO (CaptainDuke): + # Since multiple threads compete for the same write index, which leads to + # non-determinstic output for update mode. We could add a new attribute, + # "allow_non_deterministic", which can be conditionally set to True by + # each frontend when non-determinsm is allowed. + cur_target_kind = str(tvm.target.Target.current(allow_none=False).kind) + with ib.new_scope(): + if ( + mode == "add" + and cur_target_kind not in ["vulkan", "metal"] + and updates.dtype in ["int32", "float32"] + ): + bdim_x = fused_indices_dimension + bdim_y = ceil_div(fused_updates_dimension, tdim) + # In case of large input sizes, fused_indices_dimension might be too large. + # So we use blockIdx.x because holds larger scales. + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", bdim_x) + ib.scope_attr(by, "thread_extent", bdim_y) + ib.scope_attr(tx, "thread_extent", tdim) + + j = by * tdim + tx + with ib.if_scope(j < fused_updates_dimension): + offset = fused_updates_dimension + index = j # This is x_M, .. x_{N-1} part of the index into out. + # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] + # part of the index into out. + up_index = bx * fused_updates_dimension + j + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[bx * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] + index += offset * indices[bx + l * fused_indices_dimension] + offset *= data_ptr.shape[l] + atomic_add_return[0] = atomic_add( + tvm.tir.call_intrin("handle", "tir.address_of", out[index]), + updates[up_index], + ) + else: + bdim_x = ceil_div(fused_updates_dimension, tdim) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", bdim_x) + ib.scope_attr(tx, "thread_extent", tdim) + with ib.for_range(0, fused_indices_dimension) as i: + j = bx * tdim + tx + with ib.if_scope(j < fused_updates_dimension): + offset = fused_updates_dimension + index = j # This is x_M, .. x_{N-1} part of the index into out. + # Build up the + # indices[0, y_0, .. y_{K-1}], ... indices[M-1, y_0, .. y_{K-1}] + # part of the index into out. + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[i * l * fused_indices_dimension] = indices[l, y_0, + # ... y_{k-1}] + index += offset * indices[i + l * fused_indices_dimension] + offset *= data_ptr.shape[l] + if mode == "update": + out[index] = updates[i * fused_updates_dimension + j] + elif mode == "add": + out[index] += updates[i * fused_updates_dimension + j] + else: + raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) return ib.get() diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index fffb0d6d48fc..50bcafd9f9a7 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -19,7 +19,7 @@ import logging import math -from tvm import relay +from tvm import relay, tir from .. import nn @@ -56,6 +56,15 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): B, M, K = x_tensor.shape B, N, K = y_tensor.shape + if ( + isinstance(B, tir.expr.Any) + or isinstance(M, tir.expr.Any) + or isinstance(K, tir.expr.Any) + or isinstance(N, tir.expr.Any) + ): + # Dynamic shape do not support alter op layout now + return None + M = M.value K = K.value N = N.value diff --git a/python/tvm/topi/mali/depthwise_conv2d.py b/python/tvm/topi/mali/depthwise_conv2d.py index b292f694b995..98109ab4535f 100644 --- a/python/tvm/topi/mali/depthwise_conv2d.py +++ b/python/tvm/topi/mali/depthwise_conv2d.py @@ -30,7 +30,7 @@ def depthwise_conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dty return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype) -# register customized schedule for arm cpu. +# register customized schedule for Mali. @autotvm.register_topi_schedule("depthwise_conv2d_nchw.mali") def schedule_depthwise_conv2d_nchw(cfg, outs): """Schedule depthwise conv2d @@ -51,86 +51,158 @@ def schedule_depthwise_conv2d_nchw(cfg, outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) - def _schedule(pad_data, kernel, conv): - """schedule depthwise_conv2d""" - max_unroll = 16 - vec_size = [1, 2, 4, 8, 16] + def _callback(op): + """traverse to find op to schedule""" + # schedule depthwise_conv2d + if op.tag == "depthwise_conv2d_nchw": + pad_data = op.input_tensors[0] + kernel = op.input_tensors[1] + conv = op.output(0) + _schedule(cfg, s, pad_data, kernel, conv, "NCHW") - ##### space definition begin ##### - n, c, y, x = s[conv].op.axis - bc, tc, ci = cfg.define_split("tile_c", c, num_outputs=3) - by, ty, yi = cfg.define_split("tile_y", y, num_outputs=3) - bx, tx, xi = cfg.define_split("tile_x", x, num_outputs=3) - cfg.define_annotate("ann_spatial", [ci, yi, xi], policy="try_unroll_vec") + traverse_inline(s, outs[0].op, _callback) + return s - # fallback support - if cfg.is_fallback: - ref_log = autotvm.tophub.load_reference_log( - "mali", "rk3399", "depthwise_conv2d_nchw.mali" - ) - cfg.fallback_with_reference_log(ref_log) - ###### space definition end ###### - # schedule padding - n, c, y, x = s[pad_data].op.axis - tile_and_bind3d(s, pad_data, c, y, x, cfg["tile_c"].size[1], 1, 1) +# register original implementation of depthwise_conv2d_nhwc since we don't need to change this part +@autotvm.register_topi_compute("depthwise_conv2d_nhwc.mali") +def depthwise_conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype): + return nn.depthwise_conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) - # schedule dilation - if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: - s[kernel].compute_inline() - # schedule conv - if conv.op not in s.outputs: - s[conv].set_scope("local") - OL = conv - output = s.outputs[0].output(0) - else: - OL = s.cache_write(conv, "local") - output = conv - - n, c, y, x = s[output].op.axis - bc, tc, ci = cfg["tile_c"].apply(s, output, c) - by, ty, yi = cfg["tile_y"].apply(s, output, y) - bx, tx, xi = cfg["tile_x"].apply(s, output, x) - - bc = s[output].fuse(n, bc) - s[output].bind(bc, te.thread_axis("blockIdx.z")) - s[output].bind(tc, te.thread_axis("threadIdx.z")) - s[output].bind(by, te.thread_axis("blockIdx.y")) - s[output].bind(ty, te.thread_axis("threadIdx.y")) - s[output].bind(bx, te.thread_axis("blockIdx.x")) - s[output].bind(tx, te.thread_axis("threadIdx.x")) - - di, dj = s[OL].op.reduce_axis - s[OL].unroll(di) - s[OL].unroll(dj) - - s[OL].compute_at(s[output], tx) - n, ci, yi, xi = s[OL].op.axis - - cfg["ann_spatial"].apply( - s, - OL, - [ci, yi, xi], - axis_lens=[cfg["tile_c"].size[2], cfg["tile_y"].size[2], cfg["tile_x"].size[2]], - max_unroll=max_unroll, - vec_size=vec_size, - cfg=cfg, - ) +# register customized schedule for Mali. +@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.mali") +def schedule_depthwise_conv2d_nhwc(cfg, outs): + """Schedule depthwise conv2d + + Parameters + ---------- + cfg: ConfigEntity + The configuration of this template + outs: Array of Tensor + The computation graph description of depthwise convolution2d + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for depthwise_conv2d nchw. + """ + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) def _callback(op): """traverse to find op to schedule""" # schedule depthwise_conv2d - if op.tag == "depthwise_conv2d_nchw": + if op.tag == "depthwise_conv2d_nhwc": pad_data = op.input_tensors[0] kernel = op.input_tensors[1] conv = op.output(0) - _schedule(pad_data, kernel, conv) + _schedule(cfg, s, pad_data, kernel, conv, "NHWC") traverse_inline(s, outs[0].op, _callback) return s +def _schedule(cfg, s, pad_data, kernel, conv, layout): + """schedule depthwise_conv2d""" + assert layout in ("NCHW", "NHWC") + + max_unroll = 16 + vec_size = [1, 2, 4, 8, 16] + + ##### space definition begin ##### + if layout == "NCHW": + n, c, h, w = s[conv].op.axis + else: + n, h, w, c = s[conv].op.axis + + bc, tc, ci = cfg.define_split("tile_c", c, num_outputs=3) + bh, th, hi = cfg.define_split("tile_y", h, num_outputs=3) + bw, tw, wi = cfg.define_split("tile_x", w, num_outputs=3) + cfg.define_annotate("ann_spatial", [ci, hi, wi], policy="try_unroll_vec") + + # fallback support + if cfg.is_fallback: + if layout == "NCHW": + ref_log = autotvm.tophub.load_reference_log( + "mali", "rk3399", "depthwise_conv2d_nchw.mali" + ) + cfg.fallback_with_reference_log(ref_log) + else: + cfg.fallback_split("tile_c", [-1, 4, 2]) + cfg.fallback_split("tile_y", [-1, 4, 2]) + cfg.fallback_split("tile_x", [-1, 4, 2]) + ###### space definition end ###### + + # schedule padding + if layout == "NCHW": + n, c, h, w = s[pad_data].op.axis + z, y, x = c, h, w + z_factor, y_factor, x_factor = cfg["tile_c"].size[1], 1, 1 + else: + n, h, w, c = s[pad_data].op.axis + z, y, x = h, w, c + z_factor, y_factor, x_factor = 1, 1, cfg["tile_c"].size[1] + tile_and_bind3d(s, pad_data, z, y, x, z_factor, y_factor, x_factor) + + # schedule dilation + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + + # schedule conv + if conv.op not in s.outputs: + s[conv].set_scope("local") + OL = conv + output = s.outputs[0].output(0) + else: + OL = s.cache_write(conv, "local") + output = conv + + if layout == "NCHW": + n, c, h, w = s[output].op.axis + else: + n, h, w, c = s[output].op.axis + + bc, tc, ci = cfg["tile_c"].apply(s, output, c) + bh, th, hi = cfg["tile_y"].apply(s, output, h) + bw, tw, wi = cfg["tile_x"].apply(s, output, w) + + if layout == "NCHW": + bz, tz, by, ty, bx, tx = bc, tc, bh, th, bw, tw + else: + bz, tz, by, ty, bx, tx = bh, th, bw, tw, bc, tc + + bz = s[output].fuse(n, bz) + s[output].bind(bz, te.thread_axis("blockIdx.z")) + s[output].bind(tz, te.thread_axis("threadIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + + di, dj = s[OL].op.reduce_axis + s[OL].unroll(di) + s[OL].unroll(dj) + + s[OL].compute_at(s[output], tx) + + if layout == "NCHW": + n, ci, hi, wi = s[OL].op.axis + else: + n, hi, wi, ci = s[OL].op.axis + + cfg["ann_spatial"].apply( + s, + OL, + [ci, hi, wi], + axis_lens=[cfg["tile_c"].size[2], cfg["tile_y"].size[2], cfg["tile_x"].size[2]], + max_unroll=max_unroll, + vec_size=vec_size, + cfg=cfg, + ) + + def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None): """tile and bind 3d""" y_factor = y_factor or z_factor diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index b6ed5a373e81..26d45feb0387 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -16,28 +16,50 @@ # under the License. """Batch matrix multiplication""" # pylint: disable=invalid-name +import logging import tvm from tvm import te, auto_scheduler from ..utils import get_const_tuple +logger = logging.getLogger("topi") -def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""): - """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. Supports broadcasting for batch dimension. + +def batch_matmul( + tensor_a, + tensor_b, + oshape=None, + out_dtype=None, + transpose_a=False, + transpose_b=True, + auto_scheduler_rewritten_layout="", +): + """Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + + Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. Parameters ---------- - x : tvm.te.Tensor - 3-D with shape [batch, M, K] + tensor_a : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. - y : tvm.te.Tensor - 3-D with shape [batch, N, K] + tensor_b : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. oshape : List[Optional] Explicit intended output shape of the computation. Can be useful in cases with dynamic input shapes. - auto_scheduler_rewritten_layout: str = "" + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. + + auto_scheduler_rewritten_layout: Optional[str] = "" The layout after auto-scheduler's layout rewrite pass. Returns @@ -45,35 +67,79 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - x_shape = get_const_tuple(x.shape) + assert len(tensor_a.shape) == 3, "tensor_a only support 3-dim" + if transpose_a: + XB, XK, XI = get_const_tuple(tensor_a.shape) + else: + XB, XI, XK = get_const_tuple(tensor_a.shape) if auto_scheduler_rewritten_layout: # Infer shape for the rewritten layout - y_shape = auto_scheduler.get_shape_from_rewritten_layout( - auto_scheduler_rewritten_layout, ["b", "j", "k"] + YB, YK, YJ = auto_scheduler.get_shape_from_rewritten_layout( + auto_scheduler_rewritten_layout, ["b", "k", "j"] ) - auto_scheduler.remove_index_check(y) + auto_scheduler.remove_index_check(tensor_b) else: - y_shape = get_const_tuple(y.shape) - assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim batch_matmul" + assert len(tensor_b.shape) == 3, "tensor_b only support 3-dim" + if transpose_b: + YB, YJ, YK = get_const_tuple(tensor_b.shape) + else: + YB, YK, YJ = get_const_tuple(tensor_b.shape) - XB = x_shape[0] - YB = y_shape[0] - _, M, K = x.shape - k = te.reduce_axis((0, K), name="k") + assert XK == YK or isinstance(YK, tvm.tir.expr.Var), "shapes of x and y are inconsistent" + k = te.reduce_axis((0, XK), name="k") if oshape is None: assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match" - assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent" - batch = te.max(XB, YB) - N = y.shape[1] - oshape = (batch, M, N) + batch = ( + tvm.tir.expr.SizeVar("batch", "int32") + if isinstance(XB, tvm.tir.expr.Var) or isinstance(YB, tvm.tir.expr.Var) + else te.max(XB, YB) + ) + oshape = (batch, XI, YJ) + if out_dtype is None: + out_dtype = tensor_a.dtype + if tensor_a.dtype != tensor_b.dtype: + logger.warning( + "tensor_a has different data type with tensor_b: %s, %s", + tensor_a.dtype, + tensor_b.dtype, + ) + + if (transpose_a, transpose_b) == (True, True): + compute_lambda = lambda b, i, j: te.sum( + tensor_a[b if XB != 1 else 0, k, i].astype(out_dtype) + * tensor_b[b if YB != 1 else 0, j, k].astype(out_dtype), + axis=k, + ) + compute_name = "T_batch_matmul_TT" + elif (transpose_a, transpose_b) == (True, False): + compute_lambda = lambda b, i, j: te.sum( + tensor_a[b if XB != 1 else 0, k, i].astype(out_dtype) + * tensor_b[b if YB != 1 else 0, k, j].astype(out_dtype), + axis=k, + ) + compute_name = "T_batch_matmul_TN" + elif (transpose_a, transpose_b) == (False, True): + compute_lambda = lambda b, i, j: te.sum( + tensor_a[b if XB != 1 else 0, i, k].astype(out_dtype) + * tensor_b[b if YB != 1 else 0, j, k].astype(out_dtype), + axis=k, + ) + compute_name = "T_batch_matmul_NT" + else: # (transpose_a, transpose_b) == (False, False): + compute_lambda = lambda b, i, j: te.sum( + tensor_a[b if XB != 1 else 0, i, k].astype(out_dtype) + * tensor_b[b if YB != 1 else 0, k, j].astype(out_dtype), + axis=k, + ) + compute_name = "T_batch_matmul_NN" output = te.compute( oshape, - lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k), + compute_lambda, + name=compute_name, tag="batch_matmul", - attrs={"layout_free_placeholders": [y]}, + attrs={"layout_free_placeholders": [tensor_b]}, ) - if auto_scheduler_rewritten_layout: output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout) diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 3f72bdc4b667..7cb4b09b8805 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -176,10 +176,13 @@ def _get_workload(data, kernel, stride, padding, dilation, out_dtype, data_layou else: KH, KW, CIG, CO = get_const_tuple(kernel.shape) - pt, pl, pb, pr = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW))) dilation_h, dilation_w = ( dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) ) + pt, pl, pb, pr = get_pad_tuple( + padding, + (get_const_int((KH - 1) * dilation_h + 1), get_const_int((KW - 1) * dilation_w + 1)), + ) GRPS = CI // CIG if isinstance(stride, (tuple, list)): HSTR, WSTR = stride diff --git a/python/tvm/topi/nn/depthwise_conv2d.py b/python/tvm/topi/nn/depthwise_conv2d.py index a3639b57e7e0..48ffb8c6d9ff 100644 --- a/python/tvm/topi/nn/depthwise_conv2d.py +++ b/python/tvm/topi/nn/depthwise_conv2d.py @@ -24,7 +24,7 @@ from .dilate import dilate from .pad import pad from .utils import get_pad_tuple -from ..utils import simplify +from ..utils import simplify, get_const_tuple # workload description of depthwise-conv2d Workload = namedtuple( @@ -50,11 +50,47 @@ ) -def _get_workload(data, kernel, stride, padding, dilation, out_dtype): - """Get the workload structure.""" - _, in_channel, height, width = [x.value for x in data.shape] - channel, channel_multiplier, kh, kw = [x.value for x in kernel.shape] - out_channel = channel * channel_multiplier +def _get_workload(data, kernel, stride, padding, dilation, out_dtype, data_layout="NCHW"): + """Get the workload structure for a depthwise conv2d. + + Input data and filter should use NCHW layout. + """ + if data_layout == "NCHW": + _, in_channel, height, width = get_const_tuple(data.shape) + filter_channel, channel_multiplier, kh, kw = get_const_tuple(kernel.shape) + elif data_layout == "NHWC": + _, height, width, in_channel = get_const_tuple(data.shape) + kh, kw, filter_channel, channel_multiplier = get_const_tuple(kernel.shape) + elif data_layout == "NCHWc": + _, in_channel_chunk, height, width, in_channel_block = get_const_tuple(data.shape) + in_channel = in_channel_chunk * in_channel_block + ( + filter_channel_chunk, + cm_chunk, + kh, + kw, + cm_block, + filter_channel_block, + ) = get_const_tuple(kernel.shape) + filter_channel = filter_channel_chunk * filter_channel_block + channel_multiplier = cm_chunk * cm_block + + assert ( + in_channel_block == filter_channel_block + ), "Incorrect dimensions, data has block size {}, but filter has block size {}".format( + in_channel_block, filter_channel_block + ) + + else: + raise ValueError("Data layout {} not supported".format(data_layout)) + + assert ( + in_channel == filter_channel + ), "Incorrect dimensions, data has {} channels but filter expects {} channels".format( + in_channel, filter_channel + ) + + out_channel = filter_channel * channel_multiplier dilation_h, dilation_w = ( dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) ) @@ -102,8 +138,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No Filter : tvm.te.Tensor 4-D with shape [in_channel, channel_multiplier, filter_height, filter_width] - stride : tuple of two ints - The spatial stride along height and width + stride : int or a list/tuple of two ints + The spatial stride, or (stride_height, stride_width). padding : int or str Padding size, or ['VALID', 'SAME'] diff --git a/python/tvm/topi/nn/mapping.py b/python/tvm/topi/nn/mapping.py index c048fc86d4d5..0e0b1825df30 100644 --- a/python/tvm/topi/nn/mapping.py +++ b/python/tvm/topi/nn/mapping.py @@ -29,7 +29,7 @@ def scale_shift_nchw(Input, Scale, Shift): Parameters ---------- Input : tvm.te.Tensor - Input tensor, layout is NCHW + 4-D input tensor, NCHW layout [batch, channel, height, width] Scale : tvm.te.Tensor Scale tensor, 1-D of size channel number @@ -54,7 +54,7 @@ def scale_shift_nhwc(Input, Scale, Shift): Parameters ---------- Input : tvm.te.Tensor - Input tensor, layout is NHWC + 4-D input tensor, NHWC layout [batch, height, width, channel] Scale : tvm.te.Tensor Scale tensor, 1-D of size channel number @@ -70,3 +70,30 @@ def scale_shift_nhwc(Input, Scale, Shift): return te.compute( Input.shape, lambda b, i, j, c: Input[b, i, j, c] * Scale[c] + Shift[c], name="ScaleShift" ) + + +@tvm.te.tag_scope(tag=tag.BROADCAST) +def scale_shift_nchwc(Input, Scale, Shift): + """Batch normalization operator in inference. + + Parameters + ---------- + Input : tvm.te.Tensor + 5-D input tensor, NCHWc layout [batch, channel_chunk, height, width, channel_block] + + Scale : tvm.te.Tensor + Scale tensor, 2-D of size [channel_chunk, channel_block] + + Shift : tvm.te.Tensor + Shift tensor, 2-D of size [channel_chunk, channel_block] + + Returns + ------- + Output : tvm.te.Tensor + Output tensor, layout is NHWC + """ + return te.compute( + Input.shape, + lambda b, cc, i, j, cb: Input[b, cc, i, j, cb] * Scale[cc, cb] + Shift[cc, cb], + name="ScaleShift", + ) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 73998db6f162..e577104c3ddc 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -31,7 +31,7 @@ def sparse_dense_sp_rhs(data, weight_data, weight_indices, weight_indptr): Parameters ---------- data : tvm.te.Tensor - 2-D with shape [M, K], float32 + 2-D with shape [M, K] weight_data : tvm.te.Tensor 1-D with shape [nnz] (CSR) or @@ -78,7 +78,7 @@ def sparse_dense_sp_lhs(data_data, data_indices, data_indptr, weight): 1-D with shape [(M + 1) // bs_r] (BSR) weight: - 2-D with shape [N, K], float32 + 2-D with shape [N, K] Returns ------- @@ -105,7 +105,7 @@ def sparse_dense(dense_data, sparse_data, sparse_indices, sparse_indptr, sparse_ Parameters ---------- dense_data : tvm.te.Tensor - 2-D with shape [M, K], float32 + 2-D with shape [M, K] sparse_data : tvm.te.Tensor 1-D with shape [nnz] (CSR) or @@ -239,7 +239,7 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr): Parameters ---------- sparse_data : tvm.te.Tensor - 1-D with shape [nonzeros], dtype of 'float32' + 1-D with shape [nonzeros] sparse_indices : tvm.te.Tensor 1-D with shape [nonzeros], dtype of 'int32' @@ -250,7 +250,7 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr): Returns ------- out_data : tvm.te.Tensor - 1-D with shape [nonzeros], dtype of 'float32' + 1-D with shape [nonzeros] out_indices : tvm.te.Tensor 1-D with shape [nonzeros], dtype of 'int32' @@ -275,7 +275,7 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr): ins[0], ins[1], ins[2], outs[0], outs[1], outs[2] ), tag="sparse_transpose_csr", - dtype=["float32", "int32", "int32"], + dtype=[sparse_data.dtype, "int32", "int32"], name="out", ) @@ -566,7 +566,9 @@ def _compute_block(i, nb_j, j, h, w): # pylint: disable=C0103 ) -def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC"): +def sparse_conv2d( + dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC", kernel_size=1 +): """ Computes sparse-conv2d(1*1) of ``data`` and ``(weight_data, weight_indices, weight_indptr)`` @@ -598,14 +600,15 @@ def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout 4-D with shape [M, H, W, N] (layout=NHWC) 4-D with shape [M, N, H ,W] (layout=NCHW) """ - if layout == "NHWC": - return _sparse_conv2d_bsr_compute_nhwc( - dense_data, sparse_data, sparse_indices, sparse_indptr - ) - elif layout == "NCHW": - return _sparse_conv2d_bsr_compute_nchw( - dense_data, sparse_data, sparse_indices, sparse_indptr - ) + if kernel_size == 1: + if layout == "NHWC": + return _sparse_conv2d_bsr_compute_nhwc( + dense_data, sparse_data, sparse_indices, sparse_indptr + ) + elif layout == "NCHW": + return _sparse_conv2d_bsr_compute_nchw( + dense_data, sparse_data, sparse_indices, sparse_indptr + ) else: raise ValueError("Unsupport Layout %s" % layout) diff --git a/python/tvm/topi/rocm/batch_matmul.py b/python/tvm/topi/rocm/batch_matmul.py index 7f35f4b55620..53b51eedf6d9 100644 --- a/python/tvm/topi/rocm/batch_matmul.py +++ b/python/tvm/topi/rocm/batch_matmul.py @@ -23,7 +23,9 @@ @autotvm.register_topi_compute("batch_matmul_rocblas.rocm") -def batch_matmul_rocblas(cfg, x, y, out_shape=None): +def batch_matmul_rocblas( + cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """Computes matrix multiplication of `x` and `y` via rocblas when `x` and `y` are batched matrices. @@ -40,12 +42,13 @@ def batch_matmul_rocblas(cfg, x, y, out_shape=None): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ + del out_dtype batch, M, K = get_const_tuple(x.shape) _, N, _ = get_const_tuple(y.shape) if out_shape is not None: assert out_shape[0] == batch, "Input and output batch sizes must match" assert out_shape[1] == M and out_shape[2] == N, "Invalid output shape" - result = rocblas.batch_matmul(x, y, False, True) + result = rocblas.batch_matmul(x, y, transpose_a, transpose_b) cfg.add_flop(batch * M * N * K * 2) return result diff --git a/python/tvm/topi/sparse/csrmm.py b/python/tvm/topi/sparse/csrmm.py index 39ba3332fc72..4d659c801103 100644 --- a/python/tvm/topi/sparse/csrmm.py +++ b/python/tvm/topi/sparse/csrmm.py @@ -20,6 +20,7 @@ from tvm import te from .. import tag from ..utils import simplify +from ...tir.generic import cast def csrmm_default(data, indices, indptr, weight, bias=None): @@ -57,6 +58,12 @@ def csrmm_default(data, indices, indptr, weight, bias=None): assert isinstance( weight, te.tensor.Tensor ), "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight)) + assert ( + data.dtype == weight.dtype + ), "Data and weight must have the same dtype, but they have %s and %s" % ( + data.dtype, + weight.dtype, + ) if bias is not None: assert len(bias.shape) == 1 M = simplify(indptr.shape[0] - 1) @@ -74,9 +81,9 @@ def csrmm_default_ir(data, indices, indptr, weight, out): _, N = weight.shape with irb.for_range(0, N, kind="vectorize", name="n") as n: with irb.for_range(0, M, kind="parallel", name="row") as row: - dot = irb.allocate("float32", (1,), name="dot", scope="local") - out_ptr[row * N + n] = 0.0 - dot[0] = 0.0 + dot = irb.allocate(data.dtype, (1,), name="dot", scope="local") + out_ptr[row * N + n] = cast(0, data.dtype) + dot[0] = cast(0, data.dtype) row_start = indptr_ptr[row] row_end = indptr_ptr[row + 1] row_elems = row_end - row_start @@ -92,7 +99,7 @@ def csrmm_default_ir(data, indices, indptr, weight, out): [data, indices, indptr, weight], lambda ins, outs: csrmm_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]), tag="csrmm", - dtype="float32", + dtype=data.dtype, name="out", ) if bias is not None: diff --git a/python/tvm/topi/sparse/csrmv.py b/python/tvm/topi/sparse/csrmv.py index a2d22afe01e0..3c2016c6513a 100644 --- a/python/tvm/topi/sparse/csrmv.py +++ b/python/tvm/topi/sparse/csrmv.py @@ -19,6 +19,7 @@ import tvm from tvm import te from .. import tag +from ...tir.generic import cast def csrmv_default(data, indices, indptr, weight, bias=None): @@ -50,6 +51,12 @@ def csrmv_default(data, indices, indptr, weight, bias=None): assert isinstance( weight, te.tensor.Tensor ), "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight)) + assert ( + data.dtype == weight.dtype + ), "Data and weight must have the same dtype, but they have %s and %s" % ( + data.dtype, + weight.dtype, + ) if bias is not None: assert len(bias.shape) == 1 batch = indptr.shape[0] - 1 @@ -64,9 +71,9 @@ def csrmv_default_ir(data, indices, indptr, weight, out): out_ptr = irb.buffer_ptr(out) num_rows = indptr.shape[0] - 1 with irb.for_range(0, num_rows, kind="parallel", name="row") as row: - dot = irb.allocate("float32", (1,), name="dot", scope="local") - out_ptr[row] = 0.0 - dot[0] = 0.0 + dot = irb.allocate(data.dtype, (1,), name="dot", scope="local") + out_ptr[row] = cast(0, data.dtype) + dot[0] = cast(0, data.dtype) row_start = indptr_ptr[row] row_end = indptr_ptr[row + 1] row_elems = row_end - row_start @@ -82,7 +89,7 @@ def csrmv_default_ir(data, indices, indptr, weight, out): [data, indices, indptr, weight], lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]), tag="csrmv", - dtype="float32", + dtype=data.dtype, name="csrmv", ) if bias is not None: diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 610c51668835..d10c49f5c084 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -32,7 +32,11 @@ from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python from .correlation_nchw_python import correlation_nchw_python from .deformable_conv2d_python import deformable_conv2d_nchw_python, deformable_conv2d_nhwc_python -from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc +from .depthwise_conv2d_python import ( + depthwise_conv2d_python_nchw, + depthwise_conv2d_python_nhwc, + depthwise_conv2d_python_nchwc, +) from .dilate_python import dilate_python from .softmax_python import softmax_python, log_softmax_python from .resize_python import resize1d_python, resize2d_python, resize3d_python diff --git a/python/tvm/topi/testing/batch_matmul.py b/python/tvm/topi/testing/batch_matmul.py index 96d1fcbb5bc3..18fa7e8c4b33 100644 --- a/python/tvm/topi/testing/batch_matmul.py +++ b/python/tvm/topi/testing/batch_matmul.py @@ -19,7 +19,7 @@ import numpy as np -def batch_matmul(x, y, out_dtype=None): +def batch_matmul(x, y, out_dtype=None, trans_x=False, trans_y=True): """batch_matmul operator implemented in numpy. Parameters @@ -38,13 +38,22 @@ def batch_matmul(x, y, out_dtype=None): out : numpy.ndarray 3-D with shape [batch, M, N] """ - XB, M, _ = x.shape - YB, N, _ = y.shape + if trans_x: + XB, _, M = x.shape + else: + XB, M, _ = x.shape + if trans_y: + YB, N, _ = y.shape + else: + YB, _, N = y.shape batch = max(XB, YB) dtype = x.dtype if out_dtype is None else out_dtype out = np.zeros((batch, M, N)).astype(dtype) for i in range(batch): + xx = x[i if XB != 1 else 0].astype(dtype) + yy = y[i if YB != 1 else 0].astype(dtype) out[i] = np.dot( - x[i if XB != 1 else 0].astype(dtype), y[i if YB != 1 else 0].T.astype(dtype) + xx.T if trans_x else xx, + yy.T if trans_y else yy, ) return out diff --git a/python/tvm/topi/testing/common.py b/python/tvm/topi/testing/common.py index 785a6d11d8a7..d040310ccc8f 100644 --- a/python/tvm/topi/testing/common.py +++ b/python/tvm/topi/testing/common.py @@ -18,6 +18,8 @@ """Common utility for topi test""" import numpy as np +import scipy.signal + import tvm from tvm import topi from tvm.testing import assert_allclose @@ -108,3 +110,52 @@ def compare_numpy_tvm(inputs, output, target, device, compute, schedule): arys = [tvm.nd.array(x, device=device) for x in inputs] func(*(arys + [te_out])) assert_allclose(te_out.numpy(), output, atol=1e-4, rtol=1e-4) + + +def _convolve2d(data, weights): + """2d convolution operator in HW layout. + + This is intended to be used as a replacement for + scipy.signals.convolve2d, with wider support for different dtypes. + scipy.signal.convolve2d does not support all TVM-supported + dtypes (e.g. float16). Where possible, this function uses + scipy.signal.convolve2d to take advantage of compiled scipy + routines, falling back to an explicit loop only where needed. + + Parameters + ---------- + data : numpy.ndarray + 2-D with shape [in_height, in_width] + + weights : numpy.ndarray + 2-D with shape [filter_height, filter_width]. + + Returns + ------- + b_np : np.ndarray + 2-D with shape [out_height, out_width] + + Return value and layout conventions are matched to + ``scipy.signal.convolve2d(data, weights, mode="valid")`` + """ + + try: + return scipy.signal.convolve2d(data, weights, mode="valid") + except ValueError: + pass + + weights = np.rot90(weights, k=2) + + assert len(data.shape) == len(weights.shape) == 2 + + dtype = data.dtype + kernel_h, kernel_w = weights.shape + + output_shape = [a_dim - w_dim + 1 for a_dim, w_dim in zip(data.shape, weights.shape)] + output = np.zeros(output_shape, dtype=dtype) + + for y in range(output_shape[0]): + for x in range(output_shape[1]): + output[y][x] = np.sum(data[y : y + kernel_h, x : x + kernel_w] * weights) + + return output diff --git a/python/tvm/topi/testing/conv2d_nchw_python.py b/python/tvm/topi/testing/conv2d_nchw_python.py index ce5d981cc651..4214ee4a2459 100644 --- a/python/tvm/topi/testing/conv2d_nchw_python.py +++ b/python/tvm/topi/testing/conv2d_nchw_python.py @@ -17,7 +17,8 @@ # pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-branches """Convolution in python""" import numpy as np -import scipy.signal +import scipy + from tvm.topi.nn.utils import get_pad_tuple @@ -58,21 +59,67 @@ def _conv2d_nchw_python(a_np, w_np, stride, padding): out_channel = num_filter out_height = (in_height - kernel_h + pad_h) // stride_h + 1 out_width = (in_width - kernel_w + pad_w) // stride_w + 1 - b_np = np.zeros((batch, out_channel, out_height, out_width)) + b_np = np.zeros((batch, out_channel, out_height, out_width), dtype=a_np.dtype) # computation for n in range(batch): for f in range(out_channel): for c in range(in_channel): if pad_h > 0 or pad_w > 0: - apad = np.zeros((in_height + pad_h, in_width + pad_w)) + apad = np.zeros((in_height + pad_h, in_width + pad_w), dtype=a_np.dtype) apad[pad_top : pad_top + in_height, pad_left : pad_left + in_width] = a_np[n, c] else: apad = a_np[n, c] - out = scipy.signal.convolve2d(apad, np.rot90(np.rot90(w_np[f, c])), mode="valid") + + out = _conv2d_hw(apad, w_np[f, c]) b_np[n, f] += out[::stride_h, ::stride_w] return b_np +def _conv2d_hw(apad, w_np_fc): + """2d convolution operator in HW layout. + + This is intended to be used as a subroutine from + _conv2d_nchw_python. Using scipy.signal.convolve2d directly does + not work for all dtypes (e.g. float16). Where possible, this + function uses scipy.signal.convolve2d to take advantage of + compiled scipy routines, falling back to an explicit loop only + where needed + + Parameters + ---------- + a_np : numpy.ndarray + 2-D with shape [in_height, in_width] + + w_np : numpy.ndarray + 2-D with shape [filter_height, filter_width]. + + Returns + ------- + b_np : np.ndarray + 2-D with shape [out_height, out_width] + """ + + try: + return scipy.signal.convolve2d(apad, np.rot90(np.rot90(w_np_fc)), mode="valid") + except ValueError: + pass + + assert len(apad.shape) == len(w_np_fc.shape) == 2 + + dtype = apad.dtype + in_height, in_width = apad.shape + kernel_h, kernel_w = w_np_fc.shape + + output_shape = [a_dim - w_dim + 1 for a_dim, w_dim in zip(apad.shape, w_np_fc.shape)] + output = np.zeros(output_shape, dtype=apad.dtype) + + for y in range(output_shape[0]): + for x in range(output_shape[1]): + output[y][x] = np.sum(apad[y : y + kernel_h, x : x + kernel_w] * w_np_fc) + + return output + + def conv2d_nchw_python(a_np, w_np, stride, padding, groups=1): """Convolution operator in NCHW layout. diff --git a/python/tvm/topi/testing/depthwise_conv2d_python.py b/python/tvm/topi/testing/depthwise_conv2d_python.py index 02964ecfae3b..a6247e9f92cc 100644 --- a/python/tvm/topi/testing/depthwise_conv2d_python.py +++ b/python/tvm/topi/testing/depthwise_conv2d_python.py @@ -17,7 +17,9 @@ # pylint: disable=invalid-name, unused-variable, line-too-long """Depthwise convolution in python""" import numpy as np -from scipy import signal + +from tvm.topi.nn.utils import get_pad_tuple +from .common import _convolve2d def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding): @@ -49,48 +51,94 @@ def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding): else: stride_h, stride_w = stride - # calculate output shape - if padding == "VALID": - out_channel = in_channel * channel_multiplier - out_height = (in_height - filter_height) // stride_h + 1 - out_width = (in_width - filter_width) // stride_w + 1 - output_np = np.zeros((batch, out_channel, out_height, out_width)) - for i in range(batch): - for j in range(out_channel): - output_np[i, j, :, :] = signal.convolve2d( - input_np[i, j // channel_multiplier, :, :], - np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], 2), - mode="valid", - )[ - 0 : (in_height - filter_height + 1) : stride_h, - 0 : (in_width - filter_width + 1) : stride_w, - ] - elif padding == "SAME": - out_channel = in_channel * channel_multiplier - out_height = int(np.ceil(float(in_height) / float(stride_h))) - out_width = int(np.ceil(float(in_width) / float(stride_w))) - output_np = np.zeros((batch, out_channel, out_height, out_width)) - pad_along_height = int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) - pad_along_width = int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) - pad_top_tvm = int(np.ceil(float(pad_along_height) / 2)) - pad_left_tvm = int(np.ceil(float(pad_along_width) / 2)) - pad_top_scipy = int(np.ceil(float(filter_height - 1) / 2)) - pad_left_scipy = int(np.ceil(float(filter_width - 1) / 2)) - index_h = pad_top_scipy - pad_top_tvm - index_w = pad_left_scipy - pad_left_tvm - for i in range(batch): - for j in range(out_channel): - output_np[i, j, :, :] = signal.convolve2d( - input_np[i, j // channel_multiplier, :, :], - np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], 2), - mode="same", - )[index_h:in_height:stride_h, index_w:in_width:stride_w] + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (filter_height, filter_width)) + pad_h = pad_top + pad_bottom + pad_w = pad_left + pad_right + + out_channel = in_channel * channel_multiplier + out_height = (in_height - filter_height + pad_h) // stride_h + 1 + out_width = (in_width - filter_width + pad_w) // stride_w + 1 + output_np = np.zeros((batch, out_channel, out_height, out_width)) + + for i in range(batch): + for j in range(out_channel): + apad = input_np[i, j // channel_multiplier, :, :] + if pad_h or pad_w: + apad = np.pad(apad, [(pad_top, pad_bottom), (pad_left, pad_right)]) + + conv = _convolve2d( + apad, + np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], k=2), + ) + output_np[i, j, :, :] = conv[ + ::stride_h, + ::stride_w, + ] return output_np +def depthwise_conv2d_python_nchwc(input_np, filter_np, stride, padding): + """Depthwise convolution operator in NCHWc layout. + + Parameters + ---------- + input_np : numpy.ndarray + 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] + + filter_np : numpy.ndarray + 6-D with shape [out_channel_chunk, channel_multiplier_chunk, + filter_height, filter_width, + channel_multiplier_block, out_channel_block] + + stride : list / tuple of 2 ints + [stride_height, stride_width] + + padding : str + 'VALID' or 'SAME' + + Returns + ------- + output_np : np.ndarray + 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] + """ + # Transform to NCHW + batch_size, in_channel_chunk, in_height, in_width, in_channel_block = input_np.shape + input_nchw = input_np.transpose(0, 1, 4, 2, 3).reshape( + (batch_size, in_channel_chunk * in_channel_block, in_height, in_width) + ) + + ( + out_channel_chunk, + channel_multiplier_chunk, + filter_height, + filter_width, + channel_multiplier_block, + out_channel_block, + ) = filter_np.shape + filter_nchw = filter_np.transpose(0, 5, 1, 4, 2, 3).reshape( + ( + out_channel_chunk * out_channel_block, + channel_multiplier_chunk * channel_multiplier_block, + filter_height, + filter_width, + ) + ) + + # Perform conv2d + output_np = depthwise_conv2d_python_nchw(input_nchw, filter_nchw, stride, padding) + + # Transform back to NCHWc + + # pylint: disable=unpacking-non-sequence + batch_size, out_channel, out_height, out_width = output_np.shape + return output_np.reshape( + (batch_size, out_channel_chunk, out_channel_block, out_height, out_width) + ).transpose(0, 1, 3, 4, 2) + + def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding): - """Depthwise convolution operator in nchw layout. + """Depthwise convolution operator in nhwc layout. Parameters ---------- @@ -111,48 +159,7 @@ def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding): output_np : np.ndarray 4-D with shape [batch, out_height, out_width, out_channel] """ - batch, in_height, in_width, in_channel = input_np.shape - filter_height, filter_width, _, channel_multiplier = filter_np.shape - if isinstance(stride, int): - stride_h = stride_w = stride - else: - stride_h, stride_w = stride - - # calculate output shape - if padding == "VALID": - out_channel = in_channel * channel_multiplier - out_height = (in_height - filter_height) // stride_h + 1 - out_width = (in_width - filter_width) // stride_w + 1 - output_np = np.zeros((batch, out_height, out_width, out_channel)) - for i in range(batch): - for j in range(out_channel): - output_np[i, :, :, j] = signal.convolve2d( - input_np[i, :, :, j // channel_multiplier], - np.rot90(filter_np[:, :, j // channel_multiplier, j % channel_multiplier], 2), - mode="valid", - )[ - 0 : (in_height - filter_height + 1) : stride_h, - 0 : (in_width - filter_width + 1) : stride_w, - ] - if padding == "SAME": - out_channel = in_channel * channel_multiplier - out_height = int(np.ceil(float(in_height) / float(stride_h))) - out_width = int(np.ceil(float(in_width) / float(stride_w))) - output_np = np.zeros((batch, out_height, out_width, out_channel)) - pad_along_height = int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) - pad_along_width = int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) - pad_top_tvm = int(np.ceil(float(pad_along_height) / 2)) - pad_left_tvm = int(np.ceil(float(pad_along_width) / 2)) - pad_top_scipy = int(np.ceil(float(filter_height - 1) / 2)) - pad_left_scipy = int(np.ceil(float(filter_width - 1) / 2)) - index_h = pad_top_scipy - pad_top_tvm - index_w = pad_left_scipy - pad_left_tvm - for i in range(batch): - for j in range(out_channel): - output_np[i, :, :, j] = signal.convolve2d( - input_np[i, :, :, j // channel_multiplier], - np.rot90(filter_np[:, :, j // channel_multiplier, j % channel_multiplier], 2), - mode="same", - )[index_h:in_height:stride_h, index_w:in_width:stride_w] - - return output_np + input_nchw = input_np.transpose(0, 3, 1, 2) + filter_nchw = filter_np.transpose(2, 3, 0, 1) + output_nchw = depthwise_conv2d_python_nchw(input_nchw, filter_nchw, stride, padding) + return output_nchw.transpose(0, 2, 3, 1) diff --git a/python/tvm/topi/testing/dilate_python.py b/python/tvm/topi/testing/dilate_python.py index 0ae611559729..43559e3cee12 100644 --- a/python/tvm/topi/testing/dilate_python.py +++ b/python/tvm/topi/testing/dilate_python.py @@ -19,7 +19,7 @@ import numpy as np -def dilate_python(input_np, strides, dilation_value=0.0): +def dilate_python(input_np, strides, dilation_value=0.0, out_dtype=None): """Dilate operation. Parameters @@ -33,23 +33,34 @@ def dilate_python(input_np, strides, dilation_value=0.0): dilation_value : int/float, optional Value used to dilate the input. + out_dtype : Option[str] + The datatype of the dilated array. If unspecified, will use + the same dtype as the input array. + Returns ------- output_np : numpy.ndarray n-D, the same layout as Input. + """ - n = len(input_np.shape) - assert len(strides) == n, "Input dimension and strides size dismatch : %d vs %d" % ( - n, + assert len(input_np.shape) == len( + strides + ), "Input dimension and strides size dismatch : %d vs %d" % ( + len(input_np.shape), len(strides), ) - output_size = () - no_zero = () - for i in range(n): - output_size += ((input_np.shape[i] - 1) * strides[i] + 1,) - no_zero += ((range(0, output_size[i], strides[i])),) - output_np = np.ones(shape=output_size) - output_np = dilation_value * output_np - output_np[np.ix_(*no_zero)] = input_np + + if out_dtype is None: + out_dtype = input_np.dtype + + output_size = [ + (input_dim - 1) * stride + 1 for input_dim, stride in zip(input_np.shape, strides) + ] + non_zero_elements = np.ix_( + *[range(0, output_dim, stride) for output_dim, stride in zip(output_size, strides)] + ) + + output_np = np.full(shape=output_size, fill_value=dilation_value, dtype=out_dtype) + output_np[non_zero_elements] = input_np return output_np diff --git a/python/tvm/topi/testing/poolnd_python.py b/python/tvm/topi/testing/poolnd_python.py index 43440d32f44e..28bf5fc26497 100644 --- a/python/tvm/topi/testing/poolnd_python.py +++ b/python/tvm/topi/testing/poolnd_python.py @@ -18,12 +18,60 @@ """Ground truth max and average pooling operators in python.""" import itertools import math -from typing import List, Tuple +from typing import List, Tuple, Optional import numpy as np import tvm +def _get_supported_layout(dims: int): + """ + Returns layout that is supported by poolnd_python based on number of + dimensions of input tensor + """ + assert dims in [3, 4, 5], f"{dims}-dimensional tensor is not supported" + if dims == 3: + return "NCW" + if dims == 4: + return "NCHW" + # dims == 5 + return "NCDHW" + + +def _convert_to_layout( + input_tensor: np.ndarray, + layout: str, +) -> np.ndarray: + """ + Converts back to original layout after the algorithm is finished + """ + supported_layout = _get_supported_layout(input_tensor.ndim) + if layout is not None and supported_layout != layout: + # Generate transpose list + transpose_list = [] + for d in layout: + transpose_list.append(supported_layout.index(d)) + return input_tensor.transpose(transpose_list) + return input_tensor + + +def _convert_from_layout( + input_tensor: np.ndarray, + layout: str, +) -> np.ndarray: + """ + Converts tensor to one of suppored layouts + """ + supported_layout = _get_supported_layout(input_tensor.ndim) + if layout is not None and supported_layout != layout: + # Generate transpose list + transpose_list = [] + for d in supported_layout: + transpose_list.append(layout.index(d)) + return input_tensor.transpose(transpose_list) + return input_tensor + + def get_slice( spatial_dimensions: int, pad_np: np.array, @@ -90,8 +138,12 @@ def poolnd_python( count_include_pad: bool = True, ceil_mode: bool = False, dtype: str = "float32", + layout: Optional[str] = None, ) -> np.array: """Ground truth pooling operator impelmented in numpy.""" + + np_data = _convert_from_layout(np_data, layout) + out_shape = [np_data.shape[0], np_data.shape[1]] for dim in range(2, len(np_data.shape)): i = dim - 2 @@ -158,4 +210,4 @@ def poolnd_python( else: raise ValueError("Pool type {} is not supported".format(pool_type)) - return ret_np + return _convert_to_layout(ret_np, layout) diff --git a/python/tvm/topi/testing/resize_python.py b/python/tvm/topi/testing/resize_python.py index e8d5c0599887..13b460f07e1d 100644 --- a/python/tvm/topi/testing/resize_python.py +++ b/python/tvm/topi/testing/resize_python.py @@ -66,51 +66,52 @@ def resize3d_nearest(arr, scale, coordinate_transformation_mode): def resize3d_linear(data_in, scale, coordinate_transformation_mode): """Trilinear 3d scaling using python""" + dtype = data_in.dtype d, h, w = data_in.shape new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)] data_out = np.ones((new_d, new_h, new_w)) - def _lerp(A, B, t): - return A * (1.0 - t) + B * t + indexes = np.mgrid[0:2, 0:2, 0:2] - def _in_coord(new_coord, in_shape, out_shape): - in_coord = get_inx(new_coord, in_shape, out_shape, coordinate_transformation_mode) - coord0 = int(math.floor(in_coord)) - coord1 = max(min(coord0 + 1, in_shape - 1), 0) - coord0 = max(coord0, 0) - coord_lerp = in_coord - math.floor(in_coord) - return coord0, coord1, coord_lerp + def _get_patch(zint, yint, xint): + # Get the surrounding values + indices = indexes.copy() + indices[0] = np.maximum(np.minimum(indexes[0] + zint, d - 1), 0) + indices[1] = np.maximum(np.minimum(indexes[1] + yint, h - 1), 0) + indices[2] = np.maximum(np.minimum(indexes[2] + xint, w - 1), 0) + p = data_in[indices[0], indices[1], indices[2]] + return p for m in range(new_d): for j in range(new_h): for k in range(new_w): - z0, z1, z_lerp = _in_coord(m, d, new_d) - y0, y1, y_lerp = _in_coord(j, h, new_h) - x0, x1, x_lerp = _in_coord(k, w, new_w) - - A0 = data_in[z0][y0][x0] - B0 = data_in[z0][y0][x1] - C0 = data_in[z0][y1][x0] - D0 = data_in[z0][y1][x1] - A1 = data_in[z1][y0][x0] - B1 = data_in[z1][y0][x1] - C1 = data_in[z1][y1][x0] - D1 = data_in[z1][y1][x1] - - A = _lerp(A0, A1, z_lerp) - B = _lerp(B0, B1, z_lerp) - C = _lerp(C0, C1, z_lerp) - D = _lerp(D0, D1, z_lerp) - top = _lerp(A, B, x_lerp) - bottom = _lerp(C, D, x_lerp) - - data_out[m][j][k] = np.float32(_lerp(top, bottom, y_lerp)) + in_z = get_inx(m, d, new_d, coordinate_transformation_mode) + in_y = get_inx(j, h, new_h, coordinate_transformation_mode) + in_x = get_inx(k, w, new_w, coordinate_transformation_mode) + zint = math.floor(in_z) + zfract = in_z - math.floor(in_z) + + yint = math.floor(in_y) + yfract = in_y - math.floor(in_y) + + xint = math.floor(in_x) + xfract = in_x - math.floor(in_x) + + wz = np.array([1.0 - zfract, zfract], dtype=dtype) + wy = np.array([1.0 - yfract, yfract], dtype=dtype) + wx = np.array([1.0 - xfract, xfract], dtype=dtype) + + p = _get_patch(zint, yint, xint) + l = np.sum(p * wx, axis=-1) + col = np.sum(l * wy, axis=-1) + data_out[m, j, k] = np.sum(col * wz) return data_out def resize3d_cubic(data_in, scale, coordinate_transformation_mode): """Tricubic 3d scaling using python""" + dtype = data_in.dtype d, h, w = data_in.shape new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)] data_out = np.ones((new_d, new_h, new_w)) @@ -123,29 +124,17 @@ def _cubic_spline_weights(t, alpha=-0.5): w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t w4 = -alpha * t3 + alpha * t2 - return [w1, w2, w3, w4] + return np.array([w1, w2, w3, w4]) - def _cubic_kernel(inputs, w): - """perform cubic interpolation in 1D""" - return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) - - def _get_input_value(z, y, x): - z = max(min(z, d - 1), 0) - y = max(min(y, h - 1), 0) - x = max(min(x, w - 1), 0) - return data_in[z][y][x] + indexes = np.mgrid[-1:3, -1:3, -1:3] def _get_patch(zint, yint, xint): # Get the surrounding values - p = [[[0 for i in range(4)] for j in range(4)] for k in range(4)] - for kk in range(4): - for jj in range(4): - for ii in range(4): - p[kk][jj][ii] = _get_input_value( - zint + kk - 1, - yint + jj - 1, - xint + ii - 1, - ) + indices = indexes.copy() + indices[0] = np.maximum(np.minimum(indexes[0] + zint, d - 1), 0) + indices[1] = np.maximum(np.minimum(indexes[1] + yint, h - 1), 0) + indices[2] = np.maximum(np.minimum(indexes[2] + xint, w - 1), 0) + p = data_in[indices[0], indices[1], indices[2]] return p for m in range(new_d): @@ -169,16 +158,9 @@ def _get_patch(zint, yint, xint): p = _get_patch(zint, yint, xint) - l = [[0 for i in range(4)] for j in range(4)] - for jj in range(4): - for ii in range(4): - l[jj][ii] = _cubic_kernel(p[jj][ii], wx) - - col0 = _cubic_kernel(l[0], wy) - col1 = _cubic_kernel(l[1], wy) - col2 = _cubic_kernel(l[2], wy) - col3 = _cubic_kernel(l[3], wy) - data_out[m][j][k] = _cubic_kernel([col0, col1, col2, col3], wz) + l = np.sum(p * wx, axis=-1) + col = np.sum(l * wy, axis=-1) + data_out[m, j, k] = np.sum(col * wz) return data_out diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 37bdd09d6ca6..13ca851f0e38 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -20,52 +20,66 @@ from tvm import autotvm from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, mkl -from .. import generic +from .. import generic, nn from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor @autotvm.register_topi_compute("batch_matmul.x86") -def batch_matmul(cfg, x, y, out_shape=None): - """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. Supports broadcasting in batch dimension. +def batch_matmul( + cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): + """Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + + Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. Parameters ---------- cfg : ConfigSpace - Autotvm tuning space config file - x : tvm.te.Tensor - 3-D with shape [batch, M, K] - y : tvm.te.Tensor - 3-D with shape [batch, N, K] - out_shape : tuple or None - Shape of the outputs + Autotvm tuning space config file. + + tensor_a : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. + + tensor_b : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. Returns ------- output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" - XB, M, XK = get_const_tuple(x.shape) - YB, N, YK = get_const_tuple(y.shape) - assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match" - assert XK == YK, "shapes of x and y is inconsistent" - B = te.max(XB, YB) - K = XK - if out_shape is not None: - assert out_shape[0] == B, "got invalid output shape" - assert out_shape[1] == M, "got invalid output shape" - assert out_shape[2] == N, "got invalid output shape" if cfg.is_fallback: + if transpose_a: + _, K, M = get_const_tuple(tensor_a.shape) + else: + _, M, K = get_const_tuple(tensor_a.shape) + if transpose_b: + _, N, _ = get_const_tuple(tensor_b.shape) + else: + _, _, N = get_const_tuple(tensor_b.shape) _default_batch_matmul_config(cfg, M, N, K) - - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (B, M, N), - lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k), - tag="batch_matmul", + return nn.batch_matmul( + tensor_a, + tensor_b, + out_shape, + out_dtype, + transpose_a, + transpose_b, ) - return C @autotvm.register_topi_schedule("batch_matmul.x86") @@ -137,20 +151,32 @@ def _default_batch_matmul_config(cfg, M, N, K): cfg["tile_y"] = SplitEntity([M // y_bn, y_bn]) -def batch_matmul_blas_common(cfg, x, y, out_shape, lib): - """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch, using one of BLAS libraries. Supports broadcasting in batch dimension. +def batch_matmul_blas_common(cfg, tensor_a, tensor_b, out_shape, trans_a, trans_b, lib): + """Computes batch matrix multiplication of `tensor_a` and `tensor_b` when `tensor_a` and + `tensor_b` are data in batch, using one of BLAS libraries. Supports broadcasting in batch + dimension. Parameters ---------- cfg : ConfigSpace Autotvm tuning space config file - x : tvm.te.Tensor - 3-D with shape [batch, M, K] - y : tvm.te.Tensor - 3-D with shape [batch, N, K] - out_shape : tuple or None - Shape of the output + + tensor_a : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. + + tensor_b : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + trans_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + trans_b : Optional[bool] = True + Whether the second tensor is in transposed format. + lib : A contrib module which implements batch_matmul function cblas and mkl are supported @@ -159,9 +185,15 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" - XB, M, XK = get_const_tuple(x.shape) - YB, N, YK = get_const_tuple(y.shape) + assert len(tensor_a.shape) == 3 and len(tensor_b.shape) == 3, "only support 3-dim batch_matmul" + if trans_a: + XB, XK, M = get_const_tuple(tensor_a.shape) + else: + XB, M, XK = get_const_tuple(tensor_a.shape) + if trans_b: + YB, N, YK = get_const_tuple(tensor_b.shape) + else: + YB, YK, N = get_const_tuple(tensor_a.shape) assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match" assert XK == YK, "shapes of x and y is inconsistent" if out_shape is not None: @@ -169,13 +201,18 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib): assert out_shape[1] == M, "got invalid output shape" assert out_shape[2] == N, "got invalid output shape" cfg.add_flop(XB * M * N * XK * 2) - return lib.batch_matmul(x, y, False, True) + return lib.batch_matmul(tensor_a, tensor_b, trans_a, trans_b) @autotvm.register_topi_compute("batch_matmul_cblas.x86") -def batch_matmul_cblas(cfg, x, y, out_shape=None): +def batch_matmul_cblas( + cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """Compute batch_matmul using cblas""" - return batch_matmul_blas_common(cfg, x, y, out_shape, cblas) + del out_dtype # Unused argument + return batch_matmul_blas_common( + cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, cblas + ) @autotvm.register_topi_schedule("batch_matmul_cblas.x86") @@ -185,9 +222,14 @@ def schedule_batch_matmul_cblas(_, outs): @autotvm.register_topi_compute("batch_matmul_mkl.x86") -def batch_matmul_mkl(cfg, x, y, out_shape=None): +def batch_matmul_mkl( + cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """Compute batch_matmul using mkl""" - return batch_matmul_blas_common(cfg, x, y, out_shape, mkl) + del out_dtype # Unused argument + return batch_matmul_blas_common( + cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, mkl + ) @autotvm.register_topi_schedule("batch_matmul_mkl.x86") diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 6f2c202e3f61..29c378dda30f 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -154,7 +154,6 @@ def _default_dense_nopack_config(cfg, M, N, K): cfg["tile_k"] = SplitEntity([K // tilek_bn, tilek_bn]) cfg["tile_x"] = SplitEntity([N, 1]) cfg["tile_y"] = SplitEntity([1, M]) - return M, N, K @autotvm.register_topi_compute("dense_nopack.x86") @@ -175,7 +174,7 @@ def dense_nopack(cfg, data, weight, bias=None, out_dtype=None): "tile_k", 32 if isinstance(K, (tvm.tir.Var, tvm.tir.Any)) else K, num_outputs=2 ) if cfg.is_fallback: - M, N, K = _default_dense_nopack_config(cfg, M, N, K) + _default_dense_nopack_config(cfg, M, N, K) vec = cfg["tile_k"].size[-1] k = te.reduce_axis((0, K // vec), "k") diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 5e15c8bf5368..cb2f1929d395 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -39,6 +39,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): relay.op.get("nn.dense"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) + if workload: cfg = dispatch_ctx.query(target, workload) topi_impl = workload[0] @@ -62,7 +63,6 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): topi_impl, ) dispatch_ctx.update(target, new_workload, cfg) - weight_transform = relay.layout_transform(inputs[1], "NK", weight_layout) - return relay.nn.contrib_dense_pack(inputs[0], weight_transform, None, out_dtype) + return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype) return None diff --git a/python/tvm/topi/x86/group_conv2d.py b/python/tvm/topi/x86/group_conv2d.py index 0501c5534cf2..0e10052e2428 100644 --- a/python/tvm/topi/x86/group_conv2d.py +++ b/python/tvm/topi/x86/group_conv2d.py @@ -43,7 +43,9 @@ def schedule_group_conv2d_nchw(outs): return schedule_group_conv2d_nchwc(outs) -def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, layout="NCHW"): +def _get_default_config( + cfg, data, kernel, strides, padding, dilation, groups, out_dtype, layout="NCHW" +): """ Get default schedule config for the workload """ @@ -55,7 +57,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, static_data_shape.append(dim) data = te.placeholder(static_data_shape, dtype=data.dtype) - wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout) + wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype, layout) _fallback_schedule(cfg, wkl) @@ -159,6 +161,7 @@ def group_conv2d_nchw_spatial_pack( ), strides, padding, + dilation, groups, out_dtype, ) diff --git a/python/tvm/topi/x86/pooling.py b/python/tvm/topi/x86/pooling.py index 91108ac7485d..db0f9faf1970 100644 --- a/python/tvm/topi/x86/pooling.py +++ b/python/tvm/topi/x86/pooling.py @@ -26,8 +26,8 @@ def vectorize(fused_axis, num_parallel_axis, vectorize_limit=64): reorder_axis = [fused_axis] for i in range(num_parallel_axis, len(sch.op.axis) - 1): reorder_axis.append(sch.op.axis[i]) - kw, kh = sch.op.reduce_axis - fuse_k = sch.fuse(kw, kh) + k = sch.op.reduce_axis + fuse_k = sch.fuse(*k) c = sch.op.axis[len(sch.op.axis) - 1] reorder_axis += [fuse_k, c] sch.reorder(*reorder_axis) @@ -83,7 +83,7 @@ def schedule_pool(outs, layout): def _schedule(PaddedInput, Pool): if isinstance(PaddedInput.op, te.tensor.ComputeOp): s[PaddedInput].compute_inline() - do_vectorize = layout[-1] not in "HWhw" + do_vectorize = layout[-1] not in "DHWdhw" _parallel_sch(s[Pool], outs[0].shape, do_vectorize) def traverse(OP): diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index c6300f6701e0..48ec233fa4bb 100644 --- a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -16,8 +16,10 @@ # under the License. """sparse_dense schedule on x86""" -from tvm import te +from functools import partial, reduce +from tvm import te, tir, autotvm +from ..transform import reshape from ..utils import traverse_inline, get_const_int from .utils import get_fp32_len @@ -60,3 +62,161 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + + +@autotvm.register_topi_compute("conv3x3_spNHWC.x86") +def spconv2d_3x3_nhwc(cfg, data, wdat, wind, wptr, layout="NHWC"): + """Sparse Conv2d 3x3 compute (NHWC).""" + assert layout == "NHWC" + nsamples, imh, imw, chanin = [i.value for i in data.shape] + nelems, bsrr, bsrc = [i.value for i in wdat.shape] + chanout = (wptr.shape[0].value - 1) * bsrr + + imglen, chanlen = nsamples * imh * imw, 9 * chanin + cfg.define_split("tile_y", imglen, num_outputs=3) + cfg.define_split("tile_x", chanout // bsrr, num_outputs=2) + cfg.add_flop(imglen * (nelems * bsrc * bsrr * 2 - chanout)) + if cfg.is_fallback: + cfg["tile_y"] = autotvm.task.space.SplitEntity([-1, 160, 8]) + cfg["tile_x"] = autotvm.task.space.SplitEntity([-1, 4]) + + idxsplit = lambda x, y: reduce(lambda a, b: a[:-1] + [a[-1] % b, a[-1] // b], y, [x]) + + @partial(te.compute, (imglen, chanlen), name="Im2Col") + def im2col(row, col): + j_w, j_h, j_n = idxsplit(row, [imw, imh]) + j_c, k_w, k_h = idxsplit(col, [chanin, 3]) + i_h, i_w = j_h + k_h - 1, j_w + k_w - 1 + return tir.if_then_else( + tir.all(i_h >= 0, i_h < imh, i_w >= 0, i_w < imw), data[j_n, i_h, i_w, j_c], 0 + ) + + @partial(te.compute, (imglen, chanout // bsrr, bsrr, bsrc), name="CC") + def matmul(drow, wrow, brow, bcol): + row_start, row_end = wptr[wrow], wptr[wrow + 1] + elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx") + elem = row_start + elem_idx + return te.sum( + im2col[drow, wind[elem] * bsrc + bcol] * wdat[elem, brow, bcol], axis=elem_idx + ) + + sum_bsrc = te.reduce_axis((0, bsrc), name="k") + ret = te.compute( + (imglen, chanout), + lambda y, x: te.sum(matmul[y, x // bsrr, x % bsrr, sum_bsrc], axis=sum_bsrc), + name="C", + tag="conv3x3_spNHWC", + ) + return reshape(ret, (nsamples, imh, imw, chanout)) + + +@autotvm.register_topi_schedule("conv3x3_spNHWC.x86") +def schedule_spconv2d_3x3_nhwc(cfg, outs): + """Sparse Conv2d 3x3 schedule (NHWC).""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "conv3x3_spNHWC": + (matmul,) = op.input_tensors + # wptr, wind, im2col, wdat + _, _, im2col, _ = matmul.op.input_tensors + (data,) = im2col.op.input_tensors + bsrr = matmul.shape[-2].value + chanin = data.shape[-1].value + + mm_y, mm_x = s[op].op.axis + y_t, y_o, y_i = cfg["tile_y"].apply(s, op, mm_y) + x_o, x_i = s[op].split(mm_x, factor=bsrr) + x_t, x_o = cfg["tile_x"].apply(s, op, x_o) + (sum_ax,) = s[op].op.reduce_axis + s[op].reorder(y_t, x_t, y_o, x_o, y_i, x_i, sum_ax) + s[op].unroll(sum_ax) + s[op].vectorize(x_i) + s[op].unroll(y_i) + + s[matmul].compute_at(s[op], x_o) + y_i, x_i, bsrr, bsrc = s[matmul].op.axis + (sum_ax,) = s[matmul].op.reduce_axis + s[matmul].reorder(x_i, sum_ax, y_i, bsrr, bsrc) + s[matmul].unroll(bsrc) + s[matmul].vectorize(bsrr) + s[matmul].unroll(y_i) + + s[im2col].compute_at(s[op], y_o) + y_i, sum_ax = s[im2col].op.axis + _, k_i = s[im2col].split(sum_ax, factor=chanin) + s[im2col].vectorize(k_i) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv3x3_spNCHW.x86") +def spconv2d_3x3_nchw(cfg, data, wdat, wind, wptr, layout="NCHW"): + """Sparse Conv2d 3x3 compute (NCHW).""" + nsamples, chanin, imgh, imgw = [i.value for i in data.shape] + nelems, veclen, bsrc = [i.value for i in wdat.shape] + chanout = (wptr.shape[0].value - 1) * veclen + assert bsrc == 1 and layout == "NCHW" + + cfg.add_flop(nsamples * imgh * imgw * (nelems * veclen * bsrc * 2 - chanout)) + cfg.define_split("tile_hw", imgh * imgw, num_outputs=3) + cfg.define_split("tile_ckk", chanin * 9, num_outputs=3) + + @partial(te.compute, (nsamples, chanin * 3 * 3, imgh * imgw), name="im2col") + def im2col(nsamples, ckk, imglen): + j_h, j_w = imglen // imgw, imglen % imgw + i_c, k_h, k_w = ckk // 9, ckk // 3 % 3, ckk % 3 + i_h, i_w = j_h + k_h - 1, j_w + k_w - 1 + return tir.if_then_else( + tir.all(i_h >= 0, i_h < imgh, i_w >= 0, i_w < imgw), data[nsamples, i_c, i_h, i_w], 0 + ) + + @partial( + te.compute, + (nsamples, chanout // veclen, veclen, bsrc, imgh * imgw), + name="CC", + tag="conv3x3_spNCHW", + ) + def matmul(nsamples, f_o, f_i, bsrk, imglen): + row_start, row_end = wptr[f_o], wptr[f_o + 1] + elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx") + elem = row_start + elem_idx + return te.sum( + im2col[nsamples, wind[elem] * bsrc + bsrk, imglen] * wdat[elem, f_i, bsrk], + axis=elem_idx, + ) + + return reshape(matmul, [nsamples, chanout, imgh, imgw]) + + +@autotvm.register_topi_schedule("conv3x3_spNCHW.x86") +def schedule_spconv2d_3x3_nchw(cfg, outs): + """Sparse Conv2d 3x3 schedule (NCHW).""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "conv3x3_spNCHW": + # wptr, wind, im2col, wdat + _, _, im2col, _ = op.input_tensors + + n_samples, f_o, f_i, b_c, imglen = s[op].op.axis + (sum_ax,) = s[op].op.reduce_axis + hw1, hw2, hw3 = cfg["tile_hw"].apply(s, op, imglen) + s[op].reorder(n_samples, hw1, f_o, hw2, sum_ax, f_i, b_c, hw3) + s[op].unroll(f_i) + s[op].unroll(b_c) + s[op].vectorize(hw3) + + s[im2col].compute_at(s[op], hw1) + n_samples, ckk, imglen = s[im2col].op.axis + ckk1, ckk2, ckk3 = cfg["tile_ckk"].apply(s, im2col, ckk) + hw2, hw3 = s[im2col].split(imglen, factor=cfg["tile_hw"].size[-1]) + s[im2col].reorder(n_samples, ckk1, ckk2, hw2, ckk3, hw3) + s[im2col].unroll(ckk3) + s[im2col].vectorize(hw3) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/rust/tvm-graph-rt/src/allocator.rs b/rust/tvm-graph-rt/src/allocator.rs index 81499af5f8b8..fe741aa69c23 100644 --- a/rust/tvm-graph-rt/src/allocator.rs +++ b/rust/tvm-graph-rt/src/allocator.rs @@ -17,7 +17,7 @@ * under the License. */ -use std::alloc::{self, Layout, LayoutErr}; +use std::alloc::{self, Layout, LayoutError}; const DEFAULT_ALIGN_BYTES: usize = 4; @@ -29,7 +29,7 @@ pub struct Allocation { impl Allocation { /// Allocates a chunk of memory of `size` bytes with optional alignment. - pub fn new(size: usize, align: Option) -> Result { + pub fn new(size: usize, align: Option) -> Result { let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); let layout = Layout::from_size_align(size, alignment)?; let ptr = unsafe { alloc::alloc(layout) }; diff --git a/rust/tvm-graph-rt/src/array.rs b/rust/tvm-graph-rt/src/array.rs index 8ae716a3266f..1a8ff81f56c4 100644 --- a/rust/tvm-graph-rt/src/array.rs +++ b/rust/tvm-graph-rt/src/array.rs @@ -24,7 +24,7 @@ use tvm_sys::{ffi::DLTensor, DataType, Device}; use crate::allocator::Allocation; use crate::errors::ArrayError; -use std::alloc::LayoutErr; +use std::alloc::LayoutError; /// A `Storage` is a container which holds `Tensor` data. #[derive(PartialEq)] @@ -37,7 +37,7 @@ pub enum Storage<'a> { } impl<'a> Storage<'a> { - pub fn new(size: usize, align: Option) -> Result, LayoutErr> { + pub fn new(size: usize, align: Option) -> Result, LayoutError> { Ok(Storage::Owned(Allocation::new(size, align)?)) } diff --git a/rust/tvm-graph-rt/src/graph.rs b/rust/tvm-graph-rt/src/graph.rs index de2e7dddff5c..058e55b0261c 100644 --- a/rust/tvm-graph-rt/src/graph.rs +++ b/rust/tvm-graph-rt/src/graph.rs @@ -233,7 +233,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { let mut storages: Vec = storage_num_bytes .into_iter() .map(|nbytes| Storage::new(nbytes, align)) - .collect::, std::alloc::LayoutErr>>()?; + .collect::, std::alloc::LayoutError>>()?; let tensors = izip!(storage_ids, shapes, dtypes) .map(|(storage_id, shape, dtype)| { diff --git a/rust/tvm-graph-rt/src/module/mod.rs b/rust/tvm-graph-rt/src/module/mod.rs index 511ba4b37132..a345758deca1 100644 --- a/rust/tvm-graph-rt/src/module/mod.rs +++ b/rust/tvm-graph-rt/src/module/mod.rs @@ -52,6 +52,7 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box< values.len() as i32, &mut ret_val, &mut ret_type_code, + std::ptr::null_mut(), ); if exit_code == 0 { Ok(RetValue::from_tvm_value(ret_val, ret_type_code)) diff --git a/rust/tvm-graph-rt/src/workspace.rs b/rust/tvm-graph-rt/src/workspace.rs index cf264974bc03..82bbfddcf261 100644 --- a/rust/tvm-graph-rt/src/workspace.rs +++ b/rust/tvm-graph-rt/src/workspace.rs @@ -26,7 +26,7 @@ use std::{ use crate::allocator::Allocation; use crate::errors::InvalidPointer; -use std::alloc::LayoutErr; +use std::alloc::LayoutError; const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h` @@ -50,13 +50,13 @@ impl WorkspacePool { } } - fn alloc_new(&mut self, size: usize) -> Result<*mut u8, LayoutErr> { + fn alloc_new(&mut self, size: usize) -> Result<*mut u8, LayoutError> { self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?); self.in_use.push(self.workspaces.len() - 1); Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) } - fn alloc(&mut self, size: usize) -> Result<*mut u8, LayoutErr> { + fn alloc(&mut self, size: usize) -> Result<*mut u8, LayoutError> { if self.free.is_empty() { return self.alloc_new(size); } diff --git a/rust/tvm-graph-rt/tests/test_graph_serde.rs b/rust/tvm-graph-rt/tests/test_graph_serde.rs index 7d8e867a151f..aaa33ef6dd4f 100644 --- a/rust/tvm-graph-rt/tests/test_graph_serde.rs +++ b/rust/tvm-graph-rt/tests/test_graph_serde.rs @@ -64,7 +64,7 @@ fn test_load_graph() { .unwrap() .get("func_name") .unwrap(), - "fused_nn_dense_nn_bias_add" + "tvmgen_default_fused_nn_dense_nn_bias_add" ); assert_eq!(graph.nodes[3].inputs[0].index, 0); assert_eq!(graph.nodes[4].inputs[0].index, 0); diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index c84d0aab612f..4134da5fe6d9 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -147,8 +147,8 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } } - impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> { + impl<'a> From<&'a #ref_id> for #tvm_rt_crate::ArgValue<'a> { + fn from(object_ref: &'a #ref_id) -> #tvm_rt_crate::ArgValue<'a> { use std::ffi::c_void; let object_ptr = &object_ref.0; match object_ptr { @@ -156,18 +156,11 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { #tvm_rt_crate::ArgValue:: ObjectHandle(std::ptr::null::() as *mut c_void) } - Some(value) => value.clone().into() + Some(value) => value.into() } } } - impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> { - let oref: #ref_id = object_ref.clone(); - #tvm_rt_crate::ArgValue::<'a>::from(oref) - } - } - impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id { type Error = #error; diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml index eb49558ec6ce..24d9061a213f 100644 --- a/rust/tvm-rt/Cargo.toml +++ b/rust/tvm-rt/Cargo.toml @@ -32,7 +32,56 @@ edition = "2018" default = ["dynamic-linking"] dynamic-linking = ["tvm-sys/dynamic-linking"] static-linking = ["tvm-sys/static-linking"] +standalone = ["tvm-sys/runtime-only"] +runtime-only = ["tvm-sys/runtime-only"] blas = ["ndarray/blas"] +# Enabling any of the following features is like setting the value to "ON" in config.cmake. +use-cuda = ["tvm-sys/use-cuda"] +use-opencl = ["tvm-sys/use-opencl"] +use-vulkan = ["tvm-sys/use-vulkan"] +use-metal = ["tvm-sys/use-metal"] +use-rocm = ["tvm-sys/use-rocm"] +use-hexagon-device = ["tvm-sys/use-hexagon-device"] +use-rpc = ["tvm-sys/use-rpc"] +use-threads = ["tvm-sys/use-threads"] +use-llvm = ["tvm-sys/use-llvm"] +use-stackvm-runtime = ["tvm-sys/use-stackvm-runtime"] +use-graph-runtime = ["tvm-sys/use-graph-runtime"] +use-graph-runtime-debug = ["tvm-sys/use-graph-runtime-debug"] +use-openmp = ["tvm-sys/use-openmp"] +use-relay-debug = ["tvm-sys/use-relay-debug"] +use-rtti = ["tvm-sys/use-rtti"] +use-mscv-mt = ["tvm-sys/use-mscv-mt"] +use-micro = ["tvm-sys/use-micro"] +use-install-dev = ["tvm-sys/use-install-dev"] +hide-private-symbols = ["tvm-sys/hide-private-symbols"] +use-fallback-stl-map = ["tvm-sys/use-fallback-stl-map"] +use-ethosn = ["tvm-sys/use-ethosn"] +use-index-default-i64 = ["tvm-sys/use-index-default-i64"] +use-tf-tvmdsoop = ["tvm-sys/use-tf-tvmdsoop"] +use-byodt-posit = ["tvm-sys/use-byodt-posit"] +use-mkl = ["tvm-sys/use-mkl"] +use-mkldnn = ["tvm-sys/use-mkldnn"] +use-dnnl-codegen = ["tvm-sys/use-dnnl-codegen"] +use-cudnn = ["tvm-sys/use-cudnn"] +use-cublas = ["tvm-sys/use-cublas"] +use-thrust = ["tvm-sys/use-thrust"] +use-miopen = ["tvm-sys/use-miopen"] +use-rocblas = ["tvm-sys/use-rocblas"] +use-sort = ["tvm-sys/use-sort"] +use-nnpack = ["tvm-sys/use-nnpack"] +use-random = ["tvm-sys/use-random"] +use-micro-standalone-runtime = ["tvm-sys/use-micro-standalone-runtime"] +use-cpp-rpc = ["tvm-sys/use-cpp-rpc"] +use-tflite = ["tvm-sys/use-tflite"] +use-coreml = ["tvm-sys/use-coreml"] +use-target-onnx = ["tvm-sys/use-target-onnx"] +use-arm-compute-lib = ["tvm-sys/use-arm-compute-lib"] +use-arm-compute-lib-graph-runtime = ["tvm-sys/use-arm-compute-lib-graph-runtime"] +use-tensorrt-codegen = ["tvm-sys/use-tensorrt-codegen"] +use-tensorrt-runtime = ["tvm-sys/use-tensorrt-runtime"] +use-vitis-ai = ["tvm-sys/use-vitis-ai"] +build-static-runtime = ["tvm-sys/build-static-runtime"] [dependencies] thiserror = "^1.0" diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index e8902b54f6ef..02c34a1d133f 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -45,19 +45,22 @@ external! { fn array_size(array: ObjectRef) -> i64; } -impl IsObjectRef for Array { +impl IsObjectRef for Array { type Object = Object; fn as_ptr(&self) -> Option<&ObjectPtr> { self.object.as_ptr() } + fn into_ptr(self) -> Option> { self.object.into_ptr() } + fn from_ptr(object_ptr: Option>) -> Self { let object_ref = match object_ptr { Some(o) => o.into(), _ => panic!(), }; + Array { object: object_ref, _data: PhantomData, @@ -67,7 +70,7 @@ impl IsObjectRef for Array { impl Array { pub fn from_vec(data: Vec) -> Result> { - let iter = data.into_iter().map(T::into_arg_value).collect(); + let iter = data.iter().map(T::into_arg_value).collect(); let func = Function::get("runtime.Array").expect( "runtime.Array function is not registered, this is most likely a build or linking error", @@ -151,9 +154,9 @@ impl FromIterator for Array { } } -impl<'a, T: IsObjectRef> From> for ArgValue<'a> { - fn from(array: Array) -> ArgValue<'a> { - array.object.into() +impl<'a, T: IsObjectRef> From<&'a Array> for ArgValue<'a> { + fn from(array: &'a Array) -> ArgValue<'a> { + (&array.object).into() } } diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index aec4a8ad44de..62474e6650d4 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -26,6 +26,7 @@ //! See the tests and examples repository for more examples. use std::convert::{TryFrom, TryInto}; +use std::sync::Arc; use std::{ ffi::CString, os::raw::{c_char, c_int}, @@ -34,41 +35,49 @@ use std::{ use crate::errors::Error; -pub use super::to_function::{ToFunction, Typed}; +pub use super::to_function::{RawArgs, ToFunction, Typed}; +use crate::object::AsArgValue; pub use tvm_sys::{ffi, ArgValue, RetValue}; pub type Result = std::result::Result; -/// Wrapper around TVM function handle which includes `is_global` -/// indicating whether the function is global or not, and `is_cloned` showing -/// not to drop a cloned function from Rust side. -/// The value of these fields can be accessed through their respective methods. #[derive(Debug, Hash)] -pub struct Function { - pub(crate) handle: ffi::TVMFunctionHandle, - // whether the registered function is global or not. - is_global: bool, - from_rust: bool, +struct FunctionPtr { + handle: ffi::TVMFunctionHandle, +} + +// NB(@jroesch): I think this is ok, need to double check, +// if not we should mutex the pointer or move to Rc. +unsafe impl Send for FunctionPtr {} +unsafe impl Sync for FunctionPtr {} + +impl FunctionPtr { + fn from_raw(handle: ffi::TVMFunctionHandle) -> Self { + FunctionPtr { handle } + } } -unsafe impl Send for Function {} -unsafe impl Sync for Function {} +impl Drop for FunctionPtr { + fn drop(&mut self) { + check_call!(ffi::TVMFuncFree(self.handle)); + } +} + +/// An owned thread-safe version of `tvm::PackedFunc` for consumption in Rust. +#[derive(Debug, Hash)] +pub struct Function { + inner: Arc, +} impl Function { - pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self { + pub(crate) fn from_raw(handle: ffi::TVMFunctionHandle) -> Self { Function { - handle, - is_global: false, - from_rust: false, + inner: Arc::new(FunctionPtr::from_raw(handle)), } } pub unsafe fn null() -> Self { - Function { - handle: std::ptr::null_mut(), - is_global: false, - from_rust: false, - } + Function::from_raw(std::ptr::null_mut()) } /// For a given function, it returns a function by name. @@ -84,11 +93,7 @@ impl Function { if handle.is_null() { None } else { - Some(Function { - handle, - is_global: true, - from_rust: false, - }) + Some(Function::from_raw(handle)) } } @@ -103,12 +108,7 @@ impl Function { /// Returns the underlying TVM function handle. pub fn handle(&self) -> ffi::TVMFunctionHandle { - self.handle - } - - /// Returns `true` if the underlying TVM function is global and `false` otherwise. - pub fn is_global(&self) -> bool { - self.is_global + self.inner.handle } /// Calls the function that created from `Builder`. @@ -122,7 +122,7 @@ impl Function { let ret_code = unsafe { ffi::TVMFuncCall( - self.handle, + self.handle(), values.as_mut_ptr() as *mut ffi::TVMValue, type_codes.as_mut_ptr() as *mut c_int, num_args as c_int, @@ -154,12 +154,12 @@ macro_rules! impl_to_fn { where Error: From, Out: TryFrom, - $($t: Into>),* + $($t: for<'a> AsArgValue<'a>),* { fn from(func: Function) -> Self { #[allow(non_snake_case)] Box::new(move |$($t : $t),*| { - let args = vec![ $($t.into()),* ]; + let args = vec![ $((&$t).as_arg_value()),* ]; Ok(func.invoke(args)?.try_into()?) }) } @@ -171,25 +171,15 @@ impl_to_fn!(T1, T2, T3, T4, T5, T6,); impl Clone for Function { fn clone(&self) -> Function { - Self { - handle: self.handle, - is_global: self.is_global, - from_rust: true, + Function { + inner: self.inner.clone(), } } } -// impl Drop for Function { -// fn drop(&mut self) { -// if !self.is_global && !self.is_cloned { -// check_call!(ffi::TVMFuncFree(self.handle)); -// } -// } -// } - impl From for RetValue { fn from(func: Function) -> RetValue { - RetValue::FuncHandle(func.handle) + RetValue::FuncHandle(func.handle()) } } @@ -198,7 +188,7 @@ impl TryFrom for Function { fn try_from(ret_value: RetValue) -> Result { match ret_value { - RetValue::FuncHandle(handle) => Ok(Function::new(handle)), + RetValue::FuncHandle(handle) => Ok(Function::from_raw(handle)), _ => Err(Error::downcast( format!("{:?}", ret_value), "FunctionHandle", @@ -207,12 +197,12 @@ impl TryFrom for Function { } } -impl<'a> From for ArgValue<'a> { - fn from(func: Function) -> ArgValue<'a> { - if func.handle.is_null() { +impl<'a> From<&'a Function> for ArgValue<'a> { + fn from(func: &'a Function) -> ArgValue<'a> { + if func.handle().is_null() { ArgValue::Null } else { - ArgValue::FuncHandle(func.handle) + ArgValue::FuncHandle(func.handle()) } } } @@ -222,7 +212,7 @@ impl<'a> TryFrom> for Function { fn try_from(arg_value: ArgValue<'a>) -> Result { match arg_value { - ArgValue::FuncHandle(handle) => Ok(Function::new(handle)), + ArgValue::FuncHandle(handle) => Ok(Function::from_raw(handle)), _ => Err(Error::downcast( format!("{:?}", arg_value), "FunctionHandle", @@ -236,7 +226,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function { fn try_from(arg_value: &ArgValue<'a>) -> Result { match arg_value { - ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)), + ArgValue::FuncHandle(handle) => Ok(Function::from_raw(*handle)), _ => Err(Error::downcast( format!("{:?}", arg_value), "FunctionHandle", @@ -302,12 +292,12 @@ where } pub fn register_untyped>( - f: fn(Vec>) -> Result, + f: for<'a> fn(Vec>) -> Result, name: S, override_: bool, ) -> Result<()> { - // TODO(@jroesch): can we unify all the code. - let func = f.to_function(); + //TODO(@jroesch): can we unify the untpyed and typed registration functions. + let func = ToFunction::::to_function(f); let name = name.into(); // Not sure about this code let handle = func.handle(); diff --git a/rust/tvm/src/runtime/graph_rt.rs b/rust/tvm-rt/src/graph_rt.rs similarity index 93% rename from rust/tvm/src/runtime/graph_rt.rs rename to rust/tvm-rt/src/graph_rt.rs index 421a00386cf5..53f3210aa742 100644 --- a/rust/tvm/src/runtime/graph_rt.rs +++ b/rust/tvm-rt/src/graph_rt.rs @@ -19,8 +19,8 @@ use std::convert::TryInto; -use crate::runtime::Function; -use crate::{runtime::function::Result, runtime::ByteArray, Device, Module, NDArray}; +use crate::Function; +use crate::{function::Result, ByteArray, Device, Module, NDArray}; /// An instance of the C++ graph executor. /// @@ -50,11 +50,12 @@ impl GraphRt { let runtime_create_fn_ret = runtime_create_fn.invoke(vec![ graph.into(), - lib.into(), + (&lib).into(), (&dev.device_type).into(), // NOTE you must pass the device id in as i32 because that's what TVM expects (dev.device_id as i32).into(), ]); + let graph_executor_module: Module = runtime_create_fn_ret?.try_into()?; Ok(Self { module: graph_executor_module, @@ -79,7 +80,7 @@ impl GraphRt { pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> { let ref set_input_fn = self.module.get_function("set_input", false)?; - set_input_fn.invoke(vec![name.into(), input.into()])?; + set_input_fn.invoke(vec![name.into(), (&input).into()])?; Ok(()) } @@ -101,7 +102,7 @@ impl GraphRt { /// Extract the ith output from the graph executor and write the results into output. pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> { let get_output_fn = self.module.get_function("get_output", false)?; - get_output_fn.invoke(vec![i.into(), output.into()])?; + get_output_fn.invoke(vec![i.into(), (&output).into()])?; Ok(()) } } diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index ce2d709c2a6c..3b7d066e7b78 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -26,8 +26,40 @@ //! The TVM object system enables cross-language interoperability including that of closures for all //! supported languages including C++, and Python. +// Macro to check the return call to TVM runtime shared library. + +#[macro_export] +macro_rules! tvm_call { + ($e:expr) => {{ + if unsafe { $e } != 0 { + Err($crate::get_last_error().into()) + } else { + Ok(()) + } + }}; +} + +#[macro_export] +macro_rules! check_call { + ($e:expr) => {{ + if unsafe { $e } != 0 { + panic!("{}", $crate::get_last_error()); + } + }}; +} + +// Define all sumodules. +pub mod array; +pub mod device; +pub mod errors; +pub mod function; +pub mod graph_rt; +pub mod map; +pub mod module; +pub mod ndarray; pub mod object; pub mod string; +mod to_function; pub use object::*; pub use string::*; @@ -52,28 +84,6 @@ use tvm_sys::ffi; pub use tvm_macros::external; -// Macro to check the return call to TVM runtime shared library. - -#[macro_export] -macro_rules! tvm_call { - ($e:expr) => {{ - if unsafe { $e } != 0 { - Err($crate::get_last_error().into()) - } else { - Ok(()) - } - }}; -} - -#[macro_export] -macro_rules! check_call { - ($e:expr) => {{ - if unsafe { $e } != 0 { - panic!("{}", $crate::get_last_error()); - } - }}; -} - /// Gets the last error message. pub fn get_last_error() -> &'static str { unsafe { @@ -91,15 +101,6 @@ pub(crate) fn set_last_error(err: &E) { } } -pub mod array; -pub mod device; -pub mod errors; -pub mod function; -pub mod map; -pub mod module; -pub mod ndarray; -mod to_function; - /// Outputs the current TVM version. pub fn version() -> &'static str { match str::from_utf8(ffi::TVM_VERSION) { @@ -129,16 +130,17 @@ mod tests { ); } - #[test] - fn bytearray() { - let w = vec![1u8, 2, 3, 4, 5]; - let v = ByteArray::from(w.as_slice()); - let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); - assert_eq!( - tvm.data(), - w.iter().copied().collect::>().as_slice() - ); - } + // todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. + // #[test] + // fn bytearray() { + // let w = vec![1u8, 2, 3, 4, 5]; + // let v = ByteArray::from(w.as_slice()); + // let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); + // assert_eq!( + // tvm.data(), + // w.iter().copied().collect::>().as_slice() + // ); + // } #[test] fn ty() { diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index d6dfaf3641b8..5594a91dc0f0 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -58,18 +58,18 @@ external! { fn map_items(map: ObjectRef) -> Array; } -impl FromIterator<(K, V)> for Map +impl<'a, K: 'a, V: 'a> FromIterator<(&'a K, &'a V)> for Map where K: IsObjectRef, V: IsObjectRef, { - fn from_iter>(iter: T) -> Self { + fn from_iter>(iter: T) -> Self { let iter = iter.into_iter(); let (lower_bound, upper_bound) = iter.size_hint(); let mut buffer: Vec = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2); for (k, v) in iter { - buffer.push(k.into()); - buffer.push(v.into()) + buffer.push(k.into_arg_value()); + buffer.push(v.into_arg_value()); } Self::from_data(buffer).expect("failed to convert from data") } @@ -202,13 +202,13 @@ where } } -impl<'a, K, V> From> for ArgValue<'a> +impl<'a, K, V> From<&'a Map> for ArgValue<'a> where K: IsObjectRef, V: IsObjectRef, { - fn from(map: Map) -> ArgValue<'a> { - map.object.into() + fn from(map: &'a Map) -> ArgValue<'a> { + (&map.object).into() } } @@ -268,7 +268,7 @@ mod test { let mut std_map: HashMap = HashMap::new(); std_map.insert("key1".into(), "value1".into()); std_map.insert("key2".into(), "value2".into()); - let tvm_map = Map::from_iter(std_map.clone().into_iter()); + let tvm_map = Map::from_iter(std_map.iter()); let back_map = tvm_map.into(); assert_eq!(std_map, back_map); } diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs index 343f0dce8f98..8d59c2a035a9 100644 --- a/rust/tvm-rt/src/module.rs +++ b/rust/tvm-rt/src/module.rs @@ -82,7 +82,7 @@ impl Module { return Err(errors::Error::NullHandle(name.into_string()?.to_string())); } - Ok(Function::new(fhandle)) + Ok(Function::from_raw(fhandle)) } /// Imports a dependent module such as `.ptx` for cuda gpu. diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 0e2d2830615f..80f8f184140c 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -61,7 +61,7 @@ use num_traits::Num; use crate::errors::NDArrayError; -use crate::object::{Object, ObjectPtr}; +use crate::object::{Object, ObjectPtr, ObjectRef}; /// See the [`module-level documentation`](../ndarray/index.html) for more details. #[repr(C)] @@ -73,7 +73,7 @@ pub struct NDArrayContainer { // Container Base dl_tensor: DLTensor, manager_ctx: *mut c_void, - // TOOD: shape? + shape: ObjectRef, } impl NDArrayContainer { @@ -101,6 +101,21 @@ impl NDArrayContainer { .cast::() } } + + pub fn as_mut_ptr<'a>(object_ptr: &ObjectPtr) -> *mut NDArrayContainer + where + NDArrayContainer: 'a, + { + let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; + unsafe { + object_ptr + .ptr + .as_ptr() + .cast::() + .offset(base_offset) + .cast::() + } + } } fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> { diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 8c07ed9f0853..f5832fcb3ab8 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -29,6 +29,19 @@ mod object_ptr; pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef}; +pub trait AsArgValue<'a> { + fn as_arg_value(&'a self) -> ArgValue<'a>; +} + +impl<'a, T: 'static> AsArgValue<'a> for T +where + &'a T: Into>, +{ + fn as_arg_value(&'a self) -> ArgValue<'a> { + self.into() + } +} + // TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we // can't because of coherence rules. Instead, we generate them in the macro, and // add what we can (including Into instead of From) as subtraits. @@ -37,8 +50,8 @@ pub trait IsObjectRef: Sized + Clone + Into + + for<'a> AsArgValue<'a> + TryFrom - + for<'a> Into> + for<'a> TryFrom, Error = Error> + std::fmt::Debug { @@ -51,8 +64,8 @@ pub trait IsObjectRef: Self::from_ptr(None) } - fn into_arg_value<'a>(self) -> ArgValue<'a> { - self.into() + fn into_arg_value<'a>(&'a self) -> ArgValue<'a> { + self.as_arg_value() } fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 64fd6a2218aa..09d6068f1a88 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -20,11 +20,14 @@ use std::convert::TryFrom; use std::ffi::CString; use std::fmt; +use std::os::raw::c_char; use std::ptr::NonNull; use std::sync::atomic::AtomicI32; use tvm_macros::Object; -use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index}; +use tvm_sys::ffi::{ + self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeIndex2Key, TVMObjectTypeKey2Index, +}; use tvm_sys::{ArgValue, RetValue}; use crate::errors::Error; @@ -62,10 +65,12 @@ pub struct Object { /// "subtype". /// /// This function just converts the pointer to the correct type -/// and invokes the underlying typed delete function. +/// and reconstructs a Box which then is dropped to deallocate +/// the underlying allocation. unsafe extern "C" fn delete(object: *mut Object) { let typed_object: *mut T = object as *mut T; - T::typed_delete(typed_object); + let boxed: Box = Box::from_raw(typed_object); + drop(boxed); } fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { @@ -98,6 +103,18 @@ impl Object { } } + fn get_type_key(&self) -> String { + let mut cstring: *mut c_char = std::ptr::null_mut(); + unsafe { + if TVMObjectTypeIndex2Key(self.type_index, &mut cstring as *mut _) != 0 { + panic!("{}", crate::get_last_error()); + } + return CString::from_raw(cstring) + .into_string() + .expect("type keys should be valid utf-8"); + } + } + fn get_type_index() -> u32 { let type_key = T::TYPE_KEY; let cstring = CString::new(type_key).expect("type key must not contain null characters"); @@ -148,18 +165,6 @@ impl Object { } } -// impl fmt::Debug for Object { -// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { -// let index = -// format!("{} // key: {}", self.type_index, "the_key"); - -// f.debug_struct("Object") -// .field("type_index", &index) -// // TODO(@jroesch: do we expose other fields?) -// .finish() -// } -// } - /// An unsafe trait which should be implemented for an object /// subtype. /// @@ -169,11 +174,6 @@ impl Object { /// to the subtype. pub unsafe trait IsObject: AsRef + std::fmt::Debug { const TYPE_KEY: &'static str; - - unsafe extern "C" fn typed_delete(object: *mut Self) { - let object = Box::from_raw(object); - drop(object) - } } /// A smart pointer for types which implement IsObject. @@ -264,13 +264,18 @@ impl ObjectPtr { if is_derived { Ok(unsafe { self.cast() }) } else { - Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) + let type_key = self.as_ref().get_type_key(); + Err(Error::downcast(type_key.into(), U::TYPE_KEY)) } } pub unsafe fn into_raw(self) -> *mut T { self.ptr.as_ptr() } + + pub unsafe fn as_ptr(&self) -> *mut T { + self.ptr.as_ptr() + } } impl std::ops::Deref for ObjectPtr { @@ -320,26 +325,25 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { } } -impl<'a, T: IsObject> From> for ArgValue<'a> { - fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { +impl<'a, T: IsObject> From<&'a ObjectPtr> for ArgValue<'a> { + fn from(object_ptr: &'a ObjectPtr) -> ArgValue<'a> { debug_assert!(object_ptr.count() >= 1); - let object_ptr = object_ptr.upcast::(); + let object_ptr = object_ptr.clone().upcast::(); match T::TYPE_KEY { "runtime.NDArray" => { use crate::ndarray::NDArrayContainer; - // TODO(this is probably not optimal) - let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap()) - as *mut NDArrayContainer as *mut std::ffi::c_void; + let dcast_ptr = object_ptr.downcast().unwrap(); + let raw_ptr = NDArrayContainer::as_mut_ptr(&dcast_ptr) as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::NDArrayHandle(raw_ptr) } "runtime.Module" => { - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::ModuleHandle(raw_ptr) } _ => { - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::ObjectHandle(raw_ptr) } @@ -357,14 +361,22 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { match arg_value { ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); + optr.inc_ref(); + // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must + // bump the reference count by one. + assert!(optr.count() >= 1); optr.downcast() } ArgValue::NDArrayHandle(handle) => { let optr = NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); - optr.upcast::().downcast() + // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must + // bump the reference count by one. + assert!(optr.count() >= 1); + // TODO(@jroesch): figure out if there is a more optimal way to do this + let object = optr.upcast::(); + object.inc_ref(); + object.downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), } @@ -452,11 +464,12 @@ mod tests { assert_eq!(ptr.count(), 1); let ptr_clone = ptr.clone(); assert_eq!(ptr.count(), 2); - let arg_value: ArgValue = ptr_clone.into(); + let arg_value: ArgValue = (&ptr_clone).into(); assert_eq!(ptr.count(), 2); let ptr2: ObjectPtr = arg_value.try_into()?; - assert_eq!(ptr2.count(), 2); + assert_eq!(ptr2.count(), 3); assert_eq!(ptr.count(), ptr2.count()); + drop(ptr_clone); assert_eq!(ptr.count(), 2); ensure!( ptr.type_index == ptr2.type_index, @@ -472,26 +485,71 @@ mod tests { Ok(()) } - fn test_fn(o: ObjectPtr) -> ObjectPtr { - // The call machinery adds at least 1 extra count while inside the call. + fn test_fn_raw<'a>( + mut args: crate::to_function::ArgList<'a>, + ) -> crate::function::Result { + let v: ArgValue = args.remove(0); + let v2: ArgValue = args.remove(0); + // assert_eq!(o.count(), 2); + let o: ObjectPtr = v.try_into().unwrap(); + assert_eq!(o.count(), 2); + let o2: ObjectPtr = v2.try_into().unwrap(); + assert_eq!(o2.count(), 3); + drop(o2); + assert_eq!(o.count(), 2); + Ok(o.into()) + } + + #[test] + fn test_ref_count_raw_fn() { + use super::*; + use crate::function::{register_untyped, Function}; + let ptr = ObjectPtr::new(Object::base::()); + // Call the function without the wrapping for TVM. + assert_eq!(ptr.count(), 1); + let same = test_fn_raw(vec![(&ptr).into(), (&ptr).into()]).unwrap(); + let output: ObjectPtr = same.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); + assert_eq!(ptr.count(), 1); + + register_untyped(test_fn_raw, "test_fn_raw", true).unwrap(); + let raw_func = Function::get("test_fn_raw").unwrap(); + let output = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); + let output: ObjectPtr = output.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); + assert_eq!(ptr.count(), 1); + } + + fn test_fn_typed(o: ObjectPtr, o2: ObjectPtr) -> ObjectPtr { assert_eq!(o.count(), 3); + assert_eq!(o2.count(), 3); + drop(o2); + assert_eq!(o.count(), 2); return o; } #[test] - fn test_ref_count_boundary3() { + fn test_ref_count_typed() { use super::*; use crate::function::{register, Function}; let ptr = ObjectPtr::new(Object::base::()); + // Call the function without the wrapping for TVM. + assert_eq!(ptr.count(), 1); + let output = test_fn_typed(ptr.clone(), ptr.clone()); + assert_eq!(output.count(), 2); + drop(output); + assert_eq!(ptr.count(), 1); + + register(test_fn_typed, "test_fn_typed").unwrap(); + let typed_func = Function::get("test_fn_typed").unwrap(); + let output = typed_func + .invoke(vec![(&ptr).into(), (&ptr).into()]) + .unwrap(); + let output: ObjectPtr = output.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); assert_eq!(ptr.count(), 1); - let stay = ptr.clone(); - assert_eq!(ptr.count(), 2); - register(test_fn, "my_func2").unwrap(); - let func = Function::get("my_func2").unwrap(); - let same = func.invoke(vec![ptr.into()]).unwrap(); - let same: ObjectPtr = same.try_into().unwrap(); - // TODO(@jroesch): normalize RetValue ownership assert_eq!(same.count(), 2); - drop(same); - assert_eq!(stay.count(), 3); } } diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index c5ede7d224ce..67fbfc996af0 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -44,8 +44,16 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; /// conversion of inputs and outputs to this trait. /// /// And the implementation of it to `ToFunction`. + +pub type ArgList<'a> = Vec>; + +pub enum Args<'a, I> { + Typed(I), + Raw(ArgList<'a>), +} + pub trait Typed { - fn args(i: Vec>) -> Result; + fn args<'arg>(i: Vec>) -> Result>; fn ret(o: O) -> Result; } @@ -54,7 +62,7 @@ pub trait ToFunction: Sized { fn into_raw(self) -> *mut Self::Handle; - fn call(handle: *mut Self::Handle, args: Vec>) -> Result + fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result where Self: Typed; @@ -70,11 +78,11 @@ pub trait ToFunction: Sized { check_call!(ffi::TVMFuncCreateFromCFunc( Some(Self::tvm_callback), resource_handle as *mut _, - None, // Some(Self::tvm_finalizer), + Some(Self::tvm_finalizer), &mut fhandle as *mut ffi::TVMFunctionHandle, )); - Function::new(fhandle) + Function::from_raw(fhandle) } /// The callback function which is wrapped converted by TVM @@ -102,22 +110,28 @@ pub trait ToFunction: Sized { for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int - || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int - || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int - || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int - || tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int - { - check_call!(ffi::TVMCbArgToReturn( - &mut value as *mut _, - &mut tcode as *mut _ - )); - } + // TODO(@jroesch): I believe it is sound to disable this specialized move rule. + // + // This is used in C++ to deal with moving an RValue or reference to a return value + // directly so you can skip copying. + // + // I believe this is not needed as the move directly occurs into the Rust function. + + // if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int + // { + // check_call!(ffi::TVMCbArgToReturn( + // &mut value as *mut _, + // &mut tcode as *mut _ + // )); + // } let arg_value = ArgValue::from_tvm_value(value, tcode as u32); local_args.push(arg_value); } - // Ref-count be 2. let rv = match Self::call(resource_handle, local_args) { Ok(v) => v, Err(msg) => { @@ -125,6 +139,12 @@ pub trait ToFunction: Sized { } }; + // TODO(@jroesch): clean up the handling of the is dec_ref + match rv.clone().try_into() as Result> { + Err(_) => {} + Ok(v) => drop(v), + }; + let (mut ret_val, ret_tcode) = rv.to_tvm_value(); let mut ret_type_code = ret_tcode as c_int; @@ -165,9 +185,11 @@ pub trait ToFunction: Sized { } } -impl Typed>, RetValue> for fn(Vec>) -> Result { - fn args(args: Vec>) -> Result>> { - Ok(args) +pub struct RawArgs; + +impl Typed for for<'a> fn(Vec>) -> Result { + fn args<'arg>(args: Vec>) -> Result> { + Ok(Args::Raw(args)) } fn ret(o: RetValue) -> Result { @@ -175,43 +197,59 @@ impl Typed>, RetValue> for fn(Vec>) -> R } } -impl ToFunction>, RetValue> - for fn(Vec>) -> Result -{ - type Handle = fn(Vec>) -> Result; +impl ToFunction for for<'arg> fn(Vec>) -> Result { + type Handle = for<'arg> fn(Vec>) -> Result; fn into_raw(self) -> *mut Self::Handle { let ptr: Box = Box::new(self); Box::into_raw(ptr) } - fn call(handle: *mut Self::Handle, args: Vec>) -> Result { - unsafe { (*handle)(args) } + fn call<'arg>(handle: *mut Self::Handle, args: Vec>) -> Result { + unsafe { + let func = *handle; + func(args) + } } fn drop(_: *mut Self::Handle) {} } +/// A helper trait which correctly captures the complex conversion and lifetime semantics needed +/// to coerce an ordinary Rust value into `ArgValue`. +pub trait TryFromArgValue: TryFrom { + fn from_arg_value(f: F) -> std::result::Result; +} + +impl<'a, T> TryFromArgValue> for T +where + Self: TryFrom>, + Error: From<>>::Error>, +{ + fn from_arg_value(f: ArgValue<'a>) -> std::result::Result { + Ok(TryFrom::try_from(f)?) + } +} + macro_rules! impl_typed_and_to_function { ($len:literal; $($t:ident),*) => { - impl Typed<($($t,)*), Out> for F + impl Typed<($($t,)*), Out> for Fun where - F: Fn($($t),*) -> Out, + Fun: Fn($($t),*) -> Out, Out: TryInto, Error: From, - $( $t: TryFrom>, - Error: From<$t::Error>, )* + $( for<'a> $t: TryFromArgValue>, )* { #[allow(non_snake_case, unused_variables, unused_mut)] - fn args(args: Vec>) -> Result<($($t,)*)> { + fn args<'arg>(args: Vec>) -> Result> { if args.len() != $len { return Err(Error::CallFailed(format!("{} expected {} arguments, got {}.\n", std::any::type_name::(), $len, args.len()))) } let mut args = args.into_iter(); - $(let $t = args.next().unwrap().try_into()?;)* - Ok(($($t,)*)) + $(let $t = TryFromArgValue::from_arg_value(args.next().unwrap())?;)* + Ok(Args::Typed(($($t,)*))) } fn ret(out: Out) -> Result { @@ -220,9 +258,9 @@ macro_rules! impl_typed_and_to_function { } - impl ToFunction<($($t,)*), Out> for F + impl ToFunction<($($t,)*), Out> for Fun where - F: Fn($($t,)*) -> Out + 'static + Fun: Fn($($t,)*) -> Out + 'static { type Handle = Box Out + 'static>; @@ -232,13 +270,18 @@ macro_rules! impl_typed_and_to_function { } #[allow(non_snake_case)] - fn call(handle: *mut Self::Handle, args: Vec>) -> Result + fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result where - F: Typed<($($t,)*), Out> + Fun: Typed<($($t,)*), Out> { - let ($($t,)*) = F::args(args)?; - let out = unsafe { (*handle)($($t),*) }; - F::ret(out) + let ($($t,)*) = match Fun::args(args)? { + Args::Raw(_) => panic!("impossible case"), + Args::Typed(typed) => typed, + }; + + let fn_ptr = unsafe { &*handle }; + let out = fn_ptr($($t),*); + Fun::ret(out) } fn drop(ptr: *mut Self::Handle) { @@ -255,13 +298,15 @@ impl_typed_and_to_function!(2; A, B); impl_typed_and_to_function!(3; A, B, C); impl_typed_and_to_function!(4; A, B, C, D); impl_typed_and_to_function!(5; A, B, C, D, E); -impl_typed_and_to_function!(6; A, B, C, D, E, G); +impl_typed_and_to_function!(6; A, B, C, D, E, F); +impl_typed_and_to_function!(7; A, B, C, D, E, F, G); +impl_typed_and_to_function!(8; A, B, C, D, E, F, G, H); #[cfg(test)] mod tests { use super::*; - fn call(f: F, args: Vec>) -> Result + fn call<'a, F, I, O>(f: F, args: Vec>) -> Result where F: ToFunction, F: Typed, diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml index c7ee98fc455a..ccc104ba9223 100644 --- a/rust/tvm-sys/Cargo.toml +++ b/rust/tvm-sys/Cargo.toml @@ -27,6 +27,54 @@ description = "Low level bindings to TVM's cross language API." default = ["dynamic-linking"] static-linking = [] dynamic-linking = [] +runtime-only = [] +# Enabling any of the following features is like setting the value to "ON" in config.cmake. +use-cuda = [] +use-opencl = [] +use-vulkan = [] +use-metal = [] +use-rocm = [] +use-hexagon-device = [] +use-rpc = [] +use-threads = [] +use-llvm = [] +use-stackvm-runtime = [] +use-graph-runtime = [] +use-graph-runtime-debug = [] +use-openmp = [] +use-relay-debug = [] +use-rtti = [] +use-mscv-mt = [] +use-micro = [] +use-install-dev = [] +hide-private-symbols = [] +use-fallback-stl-map = [] +use-ethosn = [] +use-index-default-i64 = [] +use-tf-tvmdsoop = [] +use-byodt-posit = [] +use-mkl = [] +use-mkldnn = [] +use-dnnl-codegen = [] +use-cudnn = [] +use-cublas = [] +use-thrust = [] +use-miopen = [] +use-rocblas = [] +use-sort = [] +use-nnpack = [] +use-random = [] +use-micro-standalone-runtime = [] +use-cpp-rpc = [] +use-tflite = [] +use-coreml = [] +use-target-onnx = [] +use-arm-compute-lib = [] +use-arm-compute-lib-graph-runtime = [] +use-tensorrt-codegen = [] +use-tensorrt-runtime = [] +use-vitis-ai = [] +build-static-runtime = [] [dependencies] thiserror = "^1.0" @@ -37,4 +85,4 @@ enumn = "^0.1" [build-dependencies] bindgen = { version="0.57", default-features = false, features = ["runtime"] } anyhow = "^1.0" -tvm-build = "0.1" +tvm-build = "0.2.1" diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs index d80bd9598246..7793f9f6962e 100644 --- a/rust/tvm-sys/build.rs +++ b/rust/tvm-sys/build.rs @@ -19,10 +19,13 @@ extern crate bindgen; -use std::path::{Path, PathBuf}; +use std::{ + path::{Path, PathBuf}, + str::FromStr, +}; use anyhow::{Context, Result}; -use tvm_build::BuildConfig; +use tvm_build::{BuildConfig, CMakeSetting}; /// The necessary information for detecting a TVM installation. struct TVMInstall { @@ -59,6 +62,149 @@ fn find_using_tvm_build() -> Result { let mut build_config = BuildConfig::default(); build_config.repository = Some("https://github.com/apache/tvm".to_string()); build_config.branch = Some(option_env!("TVM_BRANCH").unwrap_or("main").into()); + + if cfg!(feature = "use-cuda") { + build_config.settings.use_cuda = CMakeSetting::from_str("on").ok(); + } + if cfg!(feature = "use-opencl") { + build_config.settings.use_opencl = CMakeSetting::from_str("on").ok(); + } + if cfg!(feature = "use-vulkan") { + build_config.settings.use_vulkan = CMakeSetting::from_str("on").ok(); + } + if cfg!(feature = "use-rocm") { + build_config.settings.use_rocm = CMakeSetting::from_str("on").ok(); + } + if cfg!(feature = "use-metal") { + build_config.settings.use_metal = CMakeSetting::from_str("on").ok(); + } + if cfg!(feature = "use-hexagon-device") { + build_config.settings.use_hexagon_device = Some(true); + } + if cfg!(feature = "use-rpc") { + build_config.settings.use_rpc = Some(true); + } + if cfg!(feature = "use-threads") { + build_config.settings.use_threads = Some(true); + } + if cfg!(feature = "use-llvm") { + build_config.settings.use_llvm = CMakeSetting::from_str("on").ok(); + } + if cfg!(feature = "use-stackvm-runtime") { + build_config.settings.use_stackvm_runtime = Some(true); + } + if cfg!(feature = "use-graph-runtime") { + build_config.settings.use_graph_runtime = Some(true); + } + if cfg!(feature = "use-graph-runtime-debug") { + build_config.settings.use_graph_runtime_debug = Some(true); + } + if cfg!(feature = "use-openmp") { + build_config.settings.use_openmp = Some(true); + } + if cfg!(feature = "use-relay-debug") { + build_config.settings.use_relay_debug = Some(true); + } + if cfg!(feature = "use-rtti") { + build_config.settings.use_rtti = Some(true); + } + if cfg!(feature = "use-mscv-mt") { + build_config.settings.use_mscv_mt = Some(true); + } + if cfg!(feature = "use-micro") { + build_config.settings.use_micro = Some(true); + } + if cfg!(feature = "use-install-dev") { + build_config.settings.use_install_dev = Some(true); + } + if cfg!(feature = "hide_private-symbols") { + build_config.settings.hide_private_symbols = Some(true); + } + if cfg!(feature = "use-fallback-stl-map") { + build_config.settings.use_fallback_stl_map = Some(true); + } + if cfg!(feature = "use-ethosn") { + build_config.settings.use_ethosn = Some(true); + } + if cfg!(feature = "use-index_default-i64") { + build_config.settings.use_index_default_i64 = Some(true); + } + if cfg!(feature = "use-tf-tvmdsoop") { + build_config.settings.use_tf_tvmdsoop = Some(true); + } + if cfg!(feature = "use-byodt-posit") { + build_config.settings.use_byodt_posit = Some(true); + } + if cfg!(feature = "use-mkl") { + build_config.settings.use_mkl = CMakeSetting::from_str("on").ok(); + } + if cfg!(feature = "use-mkldnn") { + build_config.settings.use_mkldnn = CMakeSetting::from_str("on").ok(); + } + if cfg!(feature = "use-dnnl-codegen") { + build_config.settings.use_dnnl_codegen = Some(true); + } + if cfg!(feature = "use-cudnn") { + build_config.settings.use_cudnn = Some(true); + } + if cfg!(feature = "use-cublas") { + build_config.settings.use_cublas = Some(true); + } + if cfg!(feature = "use-thrust") { + build_config.settings.use_thrust = Some(true); + } + if cfg!(feature = "use-miopen") { + build_config.settings.use_miopen = Some(true); + } + if cfg!(feature = "use-rocblas") { + build_config.settings.use_rocblas = Some(true); + } + if cfg!(feature = "use-sort") { + build_config.settings.use_sort = Some(true); + } + if cfg!(feature = "use-nnpack") { + build_config.settings.use_nnpack = Some(true); + } + if cfg!(feature = "use-random") { + build_config.settings.use_random = Some(true); + } + if cfg!(feature = "use-micro-standalone-runtime") { + build_config.settings.use_micro_standalone_runtime = Some(true); + } + if cfg!(feature = "use-cpp-rpc") { + build_config.settings.use_cpp_rpc = Some(true); + } + if cfg!(feature = "use-tflite") { + build_config.settings.use_tflite = Some(true); + } + if cfg!(feature = "use-coreml") { + build_config.settings.use_coreml = Some(true); + } + if cfg!(feature = "use-target-onnx") { + build_config.settings.use_target_onnx = Some(true); + } + if cfg!(feature = "use-arm-compute-lib") { + build_config.settings.use_arm_compute_lib = Some(true); + } + if cfg!(feature = "use-arm-compute-lib-graph-runtime") { + build_config.settings.use_arm_compute_lib_graph_runtime = CMakeSetting::from_str("on").ok(); + } + if cfg!(feature = "use-tensorrt-codegen") { + build_config.settings.use_tensorrt_codegen = Some(true); + } + if cfg!(feature = "use-tensorrt-runtime") { + build_config.settings.use_tensorrt_runtime = CMakeSetting::from_str("on").ok(); + } + if cfg!(feature = "use-vitis-ai") { + build_config.settings.use_vitis_ai = Some(true); + } + if cfg!(any( + feature = "static-linking", + feature = "build-static-runtime" + )) { + build_config.settings.build_static_runtime = Some(true); + } + let build_result = tvm_build::build(build_config)?; let source_path = build_result.revision.source_path(); let build_path = build_result.revision.build_path(); @@ -84,22 +230,35 @@ fn main() -> Result<()> { println!("cargo:rerun-if-changed={}", build_path.display()); println!("cargo:rerun-if-changed={}/include", source_path.display()); - if cfg!(feature = "static-linking") { - println!("cargo:rustc-link-lib=static=tvm"); - // TODO(@jroesch): move this to tvm-build as library_path? - println!( - "cargo:rustc-link-search=native={}/build", - build_path.display() - ); - } + let library_name = if cfg!(feature = "runtime-only") { + "tvm_runtime" + } else { + "tvm" + }; - if cfg!(feature = "dynamic-linking") { - println!("cargo:rustc-link-lib=dylib=tvm"); - println!( - "cargo:rustc-link-search=native={}/build", - build_path.display() - ); - } + match &std::env::var("CARGO_CFG_TARGET_ARCH") + .expect("CARGO_CFG_TARGET_ARCH must be set by CARGO")[..] + { + "wasm32" => {} + _ => { + if cfg!(feature = "static-linking") { + println!("cargo:rustc-link-lib=static={}", library_name); + // TODO(@jroesch): move this to tvm-build as library_path? + println!( + "cargo:rustc-link-search=native={}/build", + build_path.display() + ); + } + + if cfg!(feature = "dynamic-linking") { + println!("cargo:rustc-link-lib=dylib={}", library_name); + println!( + "cargo:rustc-link-search=native={}/build", + build_path.display() + ); + } + } + }; let runtime_api = source_path.join("include/tvm/runtime/c_runtime_api.h"); let backend_api = source_path.join("include/tvm/runtime/c_backend_api.h"); diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs index 4b005abee7ef..2903a81d9c36 100644 --- a/rust/tvm-sys/src/byte_array.rs +++ b/rust/tvm-sys/src/byte_array.rs @@ -17,10 +17,9 @@ * under the License. */ use std::convert::TryFrom; -use std::os::raw::c_char; use crate::errors::ValueDowncastError; -use crate::ffi::TVMByteArray; +use crate::ffi::{TVMByteArray, TVMByteArrayFree}; use crate::{ArgValue, RetValue}; /// A newtype wrapping a raw TVM byte-array. @@ -33,20 +32,45 @@ use crate::{ArgValue, RetValue}; /// assert_eq!(barr.len(), v.len()); /// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); /// ``` -pub struct ByteArray { - /// The raw FFI ByteArray. - array: TVMByteArray, +pub enum ByteArray { + Rust(TVMByteArray), + External(TVMByteArray), +} + +impl Drop for ByteArray { + fn drop(&mut self) { + match self { + ByteArray::Rust(bytes) => { + let ptr = bytes.data; + let len = bytes.size as _; + let cap = bytes.size as _; + let data: Vec = unsafe { Vec::from_raw_parts(ptr as _, len, cap) }; + drop(data); + } + ByteArray::External(byte_array) => unsafe { + if TVMByteArrayFree(byte_array as _) != 0 { + panic!("error"); + } + }, + } + } } impl ByteArray { /// Gets the underlying byte-array - pub fn data(&self) -> &'static [u8] { - unsafe { std::slice::from_raw_parts(self.array.data as *const u8, self.array.size as _) } + pub fn data(&self) -> &[u8] { + match self { + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => unsafe { + std::slice::from_raw_parts(byte_array.data as *const u8, byte_array.size as _) + }, + } } /// Gets the length of the underlying byte-array pub fn len(&self) -> usize { - self.array.size as _ + match self { + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => byte_array.size as _, + } } /// Converts the underlying byte-array to `Vec` @@ -59,50 +83,49 @@ impl ByteArray { } } -// Needs AsRef for Vec -impl> From for ByteArray { +impl>> From for ByteArray { fn from(arg: T) -> Self { - let arg = arg.as_ref(); - ByteArray { - array: TVMByteArray { - data: arg.as_ptr() as *const c_char, - size: arg.len() as _, - }, - } + let mut incoming_bytes: Vec = arg.into(); + let mut bytes = Vec::with_capacity(incoming_bytes.len()); + bytes.append(&mut incoming_bytes); + + let mut bytes = std::mem::ManuallyDrop::new(bytes); + let ptr = bytes.as_mut_ptr(); + assert_eq!(bytes.len(), bytes.capacity()); + ByteArray::Rust(TVMByteArray { + data: ptr as _, + size: bytes.len() as _, + }) } } impl<'a> From<&'a ByteArray> for ArgValue<'a> { fn from(val: &'a ByteArray) -> ArgValue<'a> { - ArgValue::Bytes(&val.array) - } -} - -impl TryFrom> for ByteArray { - type Error = ValueDowncastError; - - fn try_from(val: ArgValue<'static>) -> Result { match val { - ArgValue::Bytes(array) => Ok(ByteArray { array: *array }), - _ => Err(ValueDowncastError { - expected_type: "ByteArray", - actual_type: format!("{:?}", val), - }), + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { + ArgValue::Bytes(byte_array) + } } } } -impl From for RetValue { - fn from(val: ByteArray) -> RetValue { - RetValue::Bytes(val.array) - } -} +// todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. +// impl From for RetValue { +// fn from(val: ByteArray) -> RetValue { +// match val { +// ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { +// // TODO(@jroesch): This requires a little more work, going to land narratives +// RetValue::Bytes(byte_array) +// } +// } +// } +// } impl TryFrom for ByteArray { type Error = ValueDowncastError; fn try_from(val: RetValue) -> Result { match val { - RetValue::Bytes(array) => Ok(ByteArray { array }), + RetValue::Bytes(array) => Ok(ByteArray::External(array)), _ => Err(ValueDowncastError { expected_type: "ByteArray", actual_type: format!("{:?}", val), @@ -118,11 +141,11 @@ mod tests { #[test] fn convert() { let v = vec![1u8, 2, 3]; - let barr = ByteArray::from(&v); + let barr = ByteArray::from(v.to_vec()); assert_eq!(barr.len(), v.len()); assert_eq!(barr.to_vec(), vec![1u8, 2, 3]); let v = b"hello"; - let barr = ByteArray::from(&v); + let barr = ByteArray::from(v.to_vec()); assert_eq!(barr.len(), v.len()); assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); } diff --git a/rust/tvm-sys/src/device.rs b/rust/tvm-sys/src/device.rs index 1da64fd60483..1ebac09bf611 100644 --- a/rust/tvm-sys/src/device.rs +++ b/rust/tvm-sys/src/device.rs @@ -65,14 +65,14 @@ use thiserror::Error; #[repr(i64)] pub enum DeviceType { CPU = 1, - CUDA, - CUDAHost, - OpenCL, - Vulkan, - Metal, - VPI, - ROCM, - ExtDev, + CUDA = 2, + CUDAHost = 3, + OpenCL = 4, + Vulkan = 7, + Metal = 8, + VPI = 9, + ROCM = 10, + ExtDev = 12, } impl Default for DeviceType { diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs index f874e672bb66..f9ac3b461c69 100644 --- a/rust/tvm-sys/src/lib.rs +++ b/rust/tvm-sys/src/lib.rs @@ -40,6 +40,7 @@ pub mod ffi { num_args: c_int, out_ret_value: *mut TVMValue, out_ret_tcode: *mut u32, + resource_handle: *mut c_void, ) -> c_int; } diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index 6f43b786780a..a74cbe318e2d 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -224,7 +224,7 @@ macro_rules! impl_pod_value { } } - impl<'a, 'v> From<&'a $type> for ArgValue<'v> { + impl<'a> From<&'a $type> for ArgValue<'a> { fn from(val: &'a $type) -> Self { Self::$variant(*val as $inner_ty) } @@ -284,9 +284,9 @@ impl<'a> From<&'a CStr> for ArgValue<'a> { } } -impl<'a> From for ArgValue<'a> { - fn from(s: CString) -> Self { - Self::String(s.into_raw()) +impl<'a> From<&'a CString> for ArgValue<'a> { + fn from(s: &'a CString) -> Self { + Self::String(s.as_ptr() as _) } } @@ -311,14 +311,14 @@ impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str { } /// Converts an unspecialized handle to a ArgValue. -impl From<*const T> for ArgValue<'static> { +impl<'a, T> From<*const T> for ArgValue<'a> { fn from(ptr: *const T) -> Self { Self::Handle(ptr as *mut c_void) } } /// Converts an unspecialized mutable handle to a ArgValue. -impl From<*mut T> for ArgValue<'static> { +impl<'a, T> From<*mut T> for ArgValue<'a> { fn from(ptr: *mut T) -> Self { Self::Handle(ptr as *mut c_void) } @@ -382,9 +382,9 @@ impl TryFrom for std::ffi::CString { // Implementations for bool. -impl<'a> From for ArgValue<'a> { - fn from(s: bool) -> Self { - (s as i64).into() +impl<'a> From<&bool> for ArgValue<'a> { + fn from(s: &bool) -> Self { + (*s as i64).into() } } diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml index ca32226e0ac5..8d9b23f7616b 100644 --- a/rust/tvm/Cargo.toml +++ b/rust/tvm/Cargo.toml @@ -34,6 +34,52 @@ dynamic-linking = ["tvm-rt/dynamic-linking"] static-linking = ["tvm-rt/static-linking"] blas = ["ndarray/blas"] python = ["pyo3"] +# Enabling any of the following features is like setting the value to "ON" in config.cmake. +use-cuda = ["tvm-rt/use-cuda"] +use-opencl = ["tvm-rt/use-opencl"] +use-vulkan = ["tvm-rt/use-vulkan"] +use-metal = ["tvm-rt/use-metal"] +use-rocm = ["tvm-rt/use-rocm"] +use-hexagon-device = ["tvm-rt/use-hexagon-device"] +use-rpc = ["tvm-rt/use-rpc"] +use-threads = ["tvm-rt/use-threads"] +use-llvm = ["tvm-rt/use-llvm"] +use-stackvm-runtime = ["tvm-rt/use-stackvm-runtime"] +use-graph-runtime = ["tvm-rt/use-graph-runtime"] +use-graph-runtime-debug = ["tvm-rt/use-graph-runtime-debug"] +use-openmp = ["tvm-rt/use-openmp"] +use-relay-debug = ["tvm-rt/use-relay-debug"] +use-rtti = ["tvm-rt/use-rtti"] +use-mscv-mt = ["tvm-rt/use-mscv-mt"] +use-micro = ["tvm-rt/use-micro"] +use-install-dev = ["tvm-rt/use-install-dev"] +hide-private-symbols = ["tvm-rt/hide-private-symbols"] +use-fallback-stl-map = ["tvm-rt/use-fallback-stl-map"] +use-ethosn = ["tvm-rt/use-ethosn"] +use-index-default-i64 = ["tvm-rt/use-index-default-i64"] +use-tf-tvmdsoop = ["tvm-rt/use-tf-tvmdsoop"] +use-byodt-posit = ["tvm-rt/use-byodt-posit"] +use-mkl = ["tvm-rt/use-mkl"] +use-mkldnn = ["tvm-rt/use-mkldnn"] +use-dnnl-codegen = ["tvm-rt/use-dnnl-codegen"] +use-cudnn = ["tvm-rt/use-cudnn"] +use-cublas = ["tvm-rt/use-cublas"] +use-thrust = ["tvm-rt/use-thrust"] +use-miopen = ["tvm-rt/use-miopen"] +use-rocblas = ["tvm-rt/use-rocblas"] +use-sort = ["tvm-rt/use-sort"] +use-nnpack = ["tvm-rt/use-nnpack"] +use-random = ["tvm-rt/use-random"] +use-micro-standalone-runtime = ["tvm-rt/use-micro-standalone-runtime"] +use-cpp-rpc = ["tvm-rt/use-cpp-rpc"] +use-tflite = ["tvm-rt/use-tflite"] +use-coreml = ["tvm-rt/use-coreml"] +use-target-onnx = ["tvm-rt/use-target-onnx"] +use-arm-compute-lib = ["tvm-rt/use-arm-compute-lib"] +use-arm-compute-lib-graph-runtime = ["tvm-rt/use-arm-compute-lib-graph-runtime"] +use-tensorrt-codegen = ["tvm-rt/use-tensorrt-codegen"] +use-tensorrt-runtime = ["tvm-rt/use-tensorrt-runtime"] +use-vitis-ai = ["tvm-rt/use-vitis-ai"] [dependencies.tvm-rt] version = "0.1.0-alpha" diff --git a/rust/tvm/examples/resnet/Cargo.toml b/rust/tvm/examples/resnet/Cargo.toml index 646385a6373e..1e45739dd93d 100644 --- a/rust/tvm/examples/resnet/Cargo.toml +++ b/rust/tvm/examples/resnet/Cargo.toml @@ -25,7 +25,7 @@ edition = "2018" [dependencies] ndarray = "0.12" -tvm = { path = "../../" } +tvm-rt = { path = "../../../tvm-rt", features = ["standalone"] } image = "0.20" csv = "1.1" anyhow = "^1.0" diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs index 9bf7d867e50f..9e3a76433ffc 100644 --- a/rust/tvm/examples/resnet/build.rs +++ b/rust/tvm/examples/resnet/build.rs @@ -22,17 +22,25 @@ use std::{io::Write, path::Path, process::Command}; fn main() -> Result<()> { let out_dir = std::env::var("CARGO_MANIFEST_DIR")?; + let python_script = concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"); + let synset_txt = concat!(env!("CARGO_MANIFEST_DIR"), "/synset.txt"); + + println!("cargo:rerun-if-changed={}", python_script); + println!("cargo:rerun-if-changed={}", synset_txt); + let output = Command::new("python3") - .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .arg(python_script) .arg(&format!("--build-dir={}", out_dir)) .output() .with_context(|| anyhow::anyhow!("failed to run python3"))?; + if !output.status.success() { std::io::stdout() .write_all(&output.stderr) .context("Failed to write error")?; panic!("Failed to execute build script"); } + assert!( Path::new(&format!("{}/deploy_lib.o", out_dir)).exists(), "Could not prepare demo: {}", diff --git a/rust/tvm/examples/resnet/src/build_resnet.py b/rust/tvm/examples/resnet/src/build_resnet.py index 277555eeb409..df02dd78f57c 100644 --- a/rust/tvm/examples/resnet/src/build_resnet.py +++ b/rust/tvm/examples/resnet/src/build_resnet.py @@ -115,6 +115,9 @@ def download_img_labels(): f.write(synset[key]) f.write("\n") + print(synset_path) + print(synset_name) + return synset diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index 7f5fcd458c26..c22d55f2e4da 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -27,8 +27,8 @@ use ::ndarray::{Array, ArrayD, Axis}; use image::{FilterType, GenericImageView}; use anyhow::Context as _; -use tvm::runtime::graph_rt::GraphRt; -use tvm::*; +use tvm_rt::graph_rt::GraphRt; +use tvm_rt::*; fn main() -> anyhow::Result<()> { let dev = Device::cpu(0); @@ -78,24 +78,40 @@ fn main() -> anyhow::Result<()> { "/deploy_lib.so" )))?; - let mut graph_rt = GraphRt::create_from_parts(&graph, lib, dev)?; - // parse parameters and convert to TVMByteArray let params: Vec = fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params"))?; - println!("param bytes: {}", params.len()); - graph_rt.load_params(¶ms)?; + // If you want an easy way to test a memory leak simply replace the program below with: + // let mut output: Vec; + + // loop { + // let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?; + // graph_rt.load_params(params.clone())?; + // graph_rt.set_input("data", input.clone())?; + // graph_rt.run()?; + + // // prepare to get the output + // let output_shape = &[1, 1000]; + // let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); + // graph_rt.get_output_into(0, output_nd.clone())?; + + // // flatten the output as Vec + // output = output_nd.to_vec::()?; + // } + + let mut graph_rt = GraphRt::create_from_parts(&graph, lib, dev)?; + graph_rt.load_params(params)?; graph_rt.set_input("data", input)?; graph_rt.run()?; // prepare to get the output let output_shape = &[1, 1000]; - let output = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); - graph_rt.get_output_into(0, output.clone())?; + let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); + graph_rt.get_output_into(0, output_nd.clone())?; // flatten the output as Vec - let output = output.to_vec::()?; + let output: Vec = output_nd.to_vec::()?; // find the maximum entry in the output and its index let (argmax, max_prob) = output @@ -107,7 +123,7 @@ fn main() -> anyhow::Result<()> { // create a hash map of (class id, class name) let file = File::open("synset.txt").context("failed to open synset")?; - let synset: Vec = BufReader::new(file) + let synset: Vec = BufReader::new(file) .lines() .into_iter() .map(|x| x.expect("readline failed")) diff --git a/rust/tvm/src/compiler/graph_rt.rs b/rust/tvm/src/compiler/graph_rt.rs index 6b5873398cab..8313e47bea20 100644 --- a/rust/tvm/src/compiler/graph_rt.rs +++ b/rust/tvm/src/compiler/graph_rt.rs @@ -51,11 +51,11 @@ fn _compile_module( ) -> Result { // The RAW API is Fn(IRModule, String, String, Map, String); let module = TVM_BUILD.invoke(vec![ - module.into(), - target.into(), - target_host.into(), - params.into(), - module_name.into(), + (&module).into(), + (&target).into(), + (&target_host).into(), + (¶ms).into(), + (&module_name).into(), ])?; let module: RtModule = module.try_into().unwrap(); Ok(module) diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 513a906f6db4..ea257af1ebc0 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -99,10 +99,10 @@ external! { // Note: we don't expose update here as update is going to be removed. impl IRModule { - pub fn new(funcs: F, types: T) -> Result + pub fn new<'a, F, T>(funcs: F, types: T) -> Result where - F: IntoIterator, - T: IntoIterator, + F: IntoIterator, + T: IntoIterator, { module_new(Map::from_iter(funcs), Map::from_iter(types)) } @@ -110,7 +110,7 @@ impl IRModule { pub fn empty() -> Result { let funcs = HashMap::::new(); let types = HashMap::::new(); - IRModule::new(funcs, types) + IRModule::new(funcs.iter(), types.iter()) } pub fn parse(file_name: N, source: S) -> Result @@ -206,10 +206,10 @@ impl IRModule { Self::from_expr_with_items(expr, HashMap::new(), HashMap::new()) } - pub fn from_expr_with_items(expr: E, funcs: F, types: T) -> Result + pub fn from_expr_with_items<'a, E, F, T>(expr: E, funcs: F, types: T) -> Result where - F: IntoIterator, - T: IntoIterator, + F: IntoIterator, + T: IntoIterator, E: IsObjectRef, E::Object: AsRef<::Object>, { diff --git a/rust/tvm/src/runtime/mod.rs b/rust/tvm/src/runtime/mod.rs index 84da186557f7..69fbb371824a 100644 --- a/rust/tvm/src/runtime/mod.rs +++ b/rust/tvm/src/runtime/mod.rs @@ -18,5 +18,3 @@ */ pub use tvm_rt::*; - -pub mod graph_rt; diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs index 2e0f5b5255a1..b7c30364f294 100644 --- a/rust/tvm/tests/basics/src/main.rs +++ b/rust/tvm/tests/basics/src/main.rs @@ -35,7 +35,7 @@ fn main() { let mut arr = NDArray::empty(shape, dev, dtype); arr.copy_from_buffer(data.as_mut_slice()); let ret = NDArray::empty(shape, dev, dtype); - let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap(); + let fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap(); if !fadd.enabled(dev_name) { return; } diff --git a/rust/tvm/tests/callback/src/bin/array.rs b/rust/tvm/tests/callback/src/bin/array.rs index 81ee426d3967..8deae30c076d 100644 --- a/rust/tvm/tests/callback/src/bin/array.rs +++ b/rust/tvm/tests/callback/src/bin/array.rs @@ -35,7 +35,7 @@ use tvm::{ }; fn main() { - fn sum(args: Vec>) -> Result { + fn sum<'a>(args: Vec>) -> Result { let mut ret = 0.0; for arg in args { let arg: NDArray = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/error.rs b/rust/tvm/tests/callback/src/bin/error.rs index 37027af0ca37..f8886a55c3a2 100644 --- a/rust/tvm/tests/callback/src/bin/error.rs +++ b/rust/tvm/tests/callback/src/bin/error.rs @@ -26,7 +26,7 @@ use tvm::{ }; fn main() { - fn error(_args: Vec>) -> Result { + fn error<'a>(_args: Vec>) -> Result { Err(errors::NDArrayError::DataTypeMismatch { expected: DataType::int(64, 1), actual: DataType::float(64, 1), diff --git a/rust/tvm/tests/callback/src/bin/float.rs b/rust/tvm/tests/callback/src/bin/float.rs index 6fd4f868dc79..d575f47c87cd 100644 --- a/rust/tvm/tests/callback/src/bin/float.rs +++ b/rust/tvm/tests/callback/src/bin/float.rs @@ -27,7 +27,7 @@ use tvm::{ }; fn main() { - fn sum(args: Vec>) -> Result { + fn sum<'a>(args: Vec>) -> Result { let mut ret = 0.0; for arg in args.into_iter() { let val: f64 = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/int.rs b/rust/tvm/tests/callback/src/bin/int.rs index cdea2e1044c4..fc2e40d8de4d 100644 --- a/rust/tvm/tests/callback/src/bin/int.rs +++ b/rust/tvm/tests/callback/src/bin/int.rs @@ -25,7 +25,7 @@ use tvm::{ }; fn main() { - fn sum(args: Vec>) -> Result { + fn sum<'a>(args: Vec>) -> Result { let mut ret = 0i64; for arg in args.iter() { let val: i64 = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/string.rs b/rust/tvm/tests/callback/src/bin/string.rs index dbe65ba4c631..4f3d67e95d64 100644 --- a/rust/tvm/tests/callback/src/bin/string.rs +++ b/rust/tvm/tests/callback/src/bin/string.rs @@ -26,7 +26,7 @@ use tvm::{ // FIXME fn main() { - fn concat_str(args: Vec>) -> Result { + fn concat_str<'a>(args: Vec>) -> Result { let mut ret = "".to_string(); for arg in args.iter() { let val: &str = arg.try_into()?; diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index ba549959ac98..94db659e25c9 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1137,8 +1137,10 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) // and recursively mark the corresponding components for (size_t i = 0; i < simplified_result.size(); ++i) if (!used[i]) { - if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) || - ExprUseVar(simplified_result[idx], op->combiner->rhs[i])) + if (UsesVar(simplified_result[idx], + [v = op->combiner->lhs[i].get()](const VarNode* var) { return var == v; }) || + UsesVar(simplified_result[idx], + [v = op->combiner->rhs[i].get()](const VarNode* var) { return var == v; })) mark_used(i); } }; diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index f0634feac083..d81159bf05c9 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -108,7 +108,7 @@ class LinearEqDetector : public ExprFunctor DetectLinearEquation(const PrimExpr& e, const Array& vars) for (size_t i = vars.size(); i > 1; --i) { vset.insert(vars[i - 1].get()); // The previous coeff contains the variable - if (ExprUseVar(coeff[i - 2], vset_contains)) { + if (UsesVar(coeff[i - 2], vset_contains)) { return Array(); } } diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 96cb92850d5a..ac78c55ed610 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -515,7 +515,6 @@ class IterMapRewriter : public ExprMutator { */ Optional TryFuseIters(IterSumExpr expr) { if (!is_zero(expr->base)) return NullOpt; - if (expr->args.size() == 1) return expr->args[0]; // select the iterators in order std::vector visited(expr->args.size(), false); std::vector flattened_iters, grouped_iters; @@ -1086,6 +1085,21 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter return NormalizeIterMapToExpr(expr); }); +Array IterMapSimplify(const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, bool require_bijective) { + Analyzer analyzer; + Array rewrite = + DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer); + if (rewrite.empty()) { + return indices; + } + Array res; + res.reserve(rewrite.size()); + IterMapToExprNormalizer converter(&analyzer); + for (const auto& expr : rewrite) res.push_back(converter.Convert(expr)); + return res; +} + /*! * \brief Divider to divide the bindings into two sets of bindings(outer and inner) * such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X. diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index a58e4433dadd..1d3475b13dad 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -799,6 +799,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1, x * c2), floordiv(c1, c2), c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); @@ -856,14 +858,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { ModularSet bmod = analyzer_->modular_set(b1.Eval()); int64_t ramp_min = floordiv(bmod->base, c2val); int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val); - if (bmod->coeff % c2val == 0) { - if (ramp_min == ramp_max) { + if (ramp_min == ramp_max) { + // If b1 can devide c2 + if (bmod->coeff % c2val == 0) { return ramp(floormod(bmod->base, c2), c1, lanes).Eval(); - } else { - return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } - } else if (c2val % bmod->coeff == 0 && ramp_min == ramp_max) { - return ramp(floormod(b1, c2), c1, lanes).Eval(); + // If all indices can be guaranteed to settle inside a coeff range + if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) { + return ramp(floormod(b1, c2), c1, lanes).Eval(); + } + } + if (bmod->coeff % c2val == 0) { + return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } } } @@ -882,6 +888,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); + TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x)); TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y)); diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index abbcba234848..e82830fa4d06 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -611,10 +611,14 @@ class FlopEstimator : public ExprFunctor { std::max(VisitExpr(op->true_value), VisitExpr(op->false_value)); } -#define VisitBinary(Node) \ - double VisitExpr_(const Node* op) final { \ - double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \ - return base + VisitExpr(op->a) + VisitExpr(op->b); \ +// Index calculations (e.g., the "i + j" expression in A[i + j]) are not counted in FLOPS. +#define VisitBinary(Node) \ + double VisitExpr_(const Node* op) final { \ + double base = 1.0; \ + if ((op->a->dtype.code() != cur_type_code_) && (op->b->dtype.code() != cur_type_code_)) { \ + base = 0.0; \ + } \ + return base + VisitExpr(op->a) + VisitExpr(op->b); \ } #define VisitUnary(Node) \ diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 7522f20523c8..54edbaee35cd 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -315,10 +315,6 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { indent_ += tab_; PrintStmt(op->body); indent_ -= tab_; - } else if (op->attr_key == tir::attr::realize_scope) { - auto v = Downcast(op->node); - alloc_storage_scope_[v] = op->value.as()->value; - PrintStmt(op->body); } else { // For now we ignore the unsupported AttrStmt PrintStmt(op->body); @@ -327,8 +323,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { auto tensor = Downcast(op->producer); - ICHECK(alloc_storage_scope_.count(tensor->op)); - if (!alloc_storage_scope_[tensor->op].empty()) { + if (!op->storage_scope.empty()) { PrintIndent(); stream << GetTensorID(tensor) << " = allocate(("; for (size_t i = 0; i < op->bounds.size(); ++i) { @@ -339,7 +334,7 @@ void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { stream << "), '"; PrintType(tensor->dtype, stream); stream << "', '"; - stream << alloc_storage_scope_[tensor->op] << "')\n"; + stream << op->storage_scope << "')\n"; } PrintStmt(op->body); } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index b01ca2763e28..47c13f73022f 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -168,8 +168,6 @@ class CodeGenHybrid : public ExprFunctor, * \param tensor The tensor to allocate a name. */ std::string GetTensorID(const Tensor& tensor); - /*! \brief the storage scope of allocation */ - std::map alloc_storage_scope_; }; } // namespace contrib diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 50f00140df9b..bfea3e7b67c0 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -88,7 +88,7 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std elem_offset = PrimExpr(); } - return tir::Buffer(data, dtype, shape, Array(), elem_offset, name, "", data_alignment, + return tir::Buffer(data, dtype, shape, Array(), elem_offset, name, data_alignment, offset_factor, buffer_type); } @@ -167,7 +167,7 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } -Array CreatePassList(bool disable_loop_partition, bool for_te_schedule) { +Array CreatePassList(bool disable_loop_partition) { transform::PassContext pass_ctx = transform::PassContext::Current(); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); @@ -214,16 +214,16 @@ Array CreatePassList(bool disable_loop_partition, bool for Array pass_list = user_lower_phase0; // PHASE 1 - if (for_te_schedule) { - pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); - } else { - pass_list.push_back(tir::transform::LowerInitBlock()); - pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); - pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::CompactBufferAllocation()); - pass_list.push_back(tir::transform::FlattenBuffer()); - } + pass_list.push_back(tir::transform::InjectPrefetch()); + pass_list.push_back(tir::transform::TextureFlatten()); + pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + pass_list.push_back(tir::transform::LowerInitBlock()); + pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::LowerMatchBuffer()); + pass_list.push_back(tir::transform::FlattenBuffer()); + pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -287,6 +287,10 @@ IRModule ScheduleToModule(te::Schedule sch, const Array& args, const tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + // Mark this schedule as being converted from an TE schedule. Makes sure that + // the correct TE passes are run. + f = WithAttr(std::move(f), "from_legacy_te_schedule", Bool(true)); + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); if (noalias) { @@ -310,7 +314,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") }); IRModule LowerModule(IRModule mod, bool simple_mode) { - Array pass_list = CreatePassList(simple_mode, false); + Array pass_list = CreatePassList(simple_mode); return LowerWithPassList(std::move(mod), pass_list); } @@ -330,7 +334,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_ IRModule mod = IRModule(Map({{GlobalVar(name), f}})); // Get the pass list - Array pass_list = CreatePassList(simple_mode, false); + Array pass_list = CreatePassList(simple_mode); return LowerWithPassList(std::move(mod), pass_list); } @@ -352,7 +356,7 @@ IRModule LowerSchedule(te::Schedule sch, const Array& args, const std const std::unordered_map& binds, bool simple_mode) { IRModule mod = ScheduleToModule(std::move(sch), args, name, binds); // Get the legacy TE pass list - Array pass_list = CreatePassList(simple_mode, true); + Array pass_list = CreatePassList(simple_mode); return LowerWithPassList(mod, pass_list); } @@ -377,6 +381,7 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target Array mixed_pass_list = {BindTarget(target), tir::transform::VerifyMemory()}; + mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); if (pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value()) { mixed_pass_list.push_back(tir::transform::ThreadSync("global")); } @@ -388,7 +393,7 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target if (target->GetAttr("unpacked-api").value_or(Bool(false))) { mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); } else { - mixed_pass_list.push_back(tir::transform::MakePackedAPI(0)); + mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1)); } mixed_pass_list.push_back(tir::transform::SplitHostDevice()); diff --git a/src/ir/affine_type.cc b/src/ir/affine_type.cc new file mode 100644 index 000000000000..3454b6011c9b --- /dev/null +++ b/src/ir/affine_type.cc @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/ir/affine_type.cc + * \brief The Type information for quantized nodes. + */ +#include +#include +#include + +namespace tvm { + +using tvm::ReprPrinter; +using namespace tvm::runtime; + +TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype) { + ObjectPtr n = make_object(); + n->scale = std::move(scale); + n->zero_point = std::move(zero_point); + n->dtype = std::move(dtype); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TensorAffineTypeNode); + +TVM_REGISTER_GLOBAL("ir.TensorAffineType") + .set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype) { + return TensorAffineType(scale, zero_point, dtype); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TensorAffineType(" << node->scale << ", " << node->zero_point << ", " + << node->dtype << ")"; + }); + +TupleAffineType::TupleAffineType(Array types) { + ObjectPtr n = make_object(); + n->types = std::move(types); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TupleAffineTypeNode); + +TVM_REGISTER_GLOBAL("ir.TupleAffineType").set_body_typed([](Array types) { + return TupleAffineType(types); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleAffineType(["; + for (size_t i = 0; i < node->types.size(); ++i) { + p->stream << node->types[i]; + if (i < node->types.size() - 1) { + p->stream << ", "; + } + } + p->stream << "])"; + }); + +} // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index 7990b281fb04..d4129c84ccf5 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -284,6 +284,20 @@ Constructor IRModuleNode::LookupTag(const int32_t tag) { return (*it).second; } +String IRModuleNode::GetUniqueName(const String& name) { + String result = name; + int suffix = 0; + while (true) { + auto it = global_var_map_.find(result); + if (it == global_var_map_.end()) { + return result; + } + std::ostringstream os; + os << name << "_" << ++suffix; + result = os.str(); + } +} + struct Renamer : relay::ExprMutator, TypeMutator { Map defs; Map types; @@ -347,25 +361,38 @@ void IRModuleNode::Update(const IRModule& mod) { } } -IRModule IRModule::FromExpr(const RelayExpr& expr, - const tvm::Map& global_funcs, - const tvm::Map& type_definitions) { - auto mod = IRModule(global_funcs, type_definitions); - BaseFunc func; - std::string gv_name = "main"; +std::pair IRModule::FromExprInContext( + const RelayExpr& expr, const tvm::Map& global_funcs, + const tvm::Map& type_definitions, + std::unordered_set import_set) { + auto mod = IRModule(global_funcs, type_definitions, std::move(import_set)); + String gv_name; + // All global definitions must be functions. + BaseFunc func; if (auto* func_node = expr.as()) { func = GetRef(func_node); if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { + // Function literal has been annotated with it's required global symbol. gv_name = opt.value(); } - } else { func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); } - auto main_gv = GlobalVar(gv_name); + + if (gv_name.empty()) { + // Bind function to 'main' (though rename if would clash with existing 'main'). + gv_name = mod->GetUniqueName("main"); + } + + GlobalVar main_gv(gv_name); mod->Add(main_gv, func); - return mod; + return {mod, main_gv}; +} + +IRModule IRModule::FromExpr(const RelayExpr& expr, const Map& global_funcs, + const Map& type_definitions) { + return FromExprInContext(expr, global_funcs, type_definitions).first; } void IRModuleNode::Import(const String& path) { @@ -465,11 +492,7 @@ TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32 return mod->LookupTag(tag); }); -TVM_REGISTER_GLOBAL("ir.Module_FromExpr") - .set_body_typed([](RelayExpr e, tvm::Map funcs, - tvm::Map type_defs) { - return IRModule::FromExpr(e, funcs, type_defs); - }); +TVM_REGISTER_GLOBAL("ir.Module_FromExpr").set_body_typed(&IRModule::FromExpr); TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { mod->Update(from); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 8120ca798ab2..4c37f0f1a6e9 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -435,7 +435,7 @@ Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { Sequential::Sequential(tvm::Array passes, String name) { auto n = make_object(); n->passes = std::move(passes); - PassInfo pass_info = PassInfo(2, std::move(name), {}); + PassInfo pass_info = PassInfo(0, std::move(name), {}); n->pass_info = std::move(pass_info); data_ = std::move(n); } @@ -466,7 +466,7 @@ Pass GetPass(const String& pass_name) { return (*f)(); } -// TODO(zhiics): we currenlty only sequentially execute each pass in +// TODO(zhiics): we currently only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const { diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index 7340f6977943..4e79d0e74c59 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -40,7 +40,6 @@ Source::Source(SourceName src_name, std::string source) { // NB(@jroesch): std::string source_str = n->source; for (auto c : source_str) { - DLOG(INFO) << "char=" << c; if (c == '\n') { // Record the length of the line. n->line_map.back().second = length; diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 0fefb0515e49..f232994480f8 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -35,6 +35,7 @@ #include #include +#include "../tir/transforms/ir_utils.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" @@ -204,8 +205,8 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { if (!is_zero(buf->elem_offset)) { doc << ", elem_offset=" << Print(buf->elem_offset); } - if (buf->scope != "global") { - doc << ", scope=" << Doc::StrLiteral(buf->scope); + if (GetRef(buf).scope() != "global") { + doc << ", scope=" << Doc::StrLiteral(GetRef(buf).scope()); } if (buf->data_alignment != 128) { doc << ", align=" << buf->data_alignment; @@ -447,8 +448,9 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; + auto scope = GetPtrStorageScope(op->buffer_var); doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " - << Print(op->extents) << ")"; + << Print(op->extents) << "), storage_scope = " << scope; if (!is_one(op->condition)) { doc << " if " << Print(op->condition); } @@ -485,23 +487,6 @@ Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) { return doc; } -inline const char* ForKind2String(ForKind t) { - switch (t) { - case ForKind::kSerial: - return "serial"; - case ForKind::kParallel: - return "parallel"; - case ForKind::kVectorized: - return "vectorized"; - case ForKind::kUnrolled: - return "unroll"; - case ForKind::kThreadBinding: - return "thread_binding"; - } - LOG(FATAL) << "Unknown ForKind"; - return "Unknown"; -} - Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { Doc doc; doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " @@ -598,8 +583,8 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) { << Print(alloc_buf->shape) << ")" << Doc::NewLine(); } for (const auto& match_buf : block_op->match_buffers) { - body << AllocBuf(match_buf->buffer) << " = match_buffer_region(" << Print(match_buf->source) - << ")" << Doc::NewLine(); + body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")" + << Doc::NewLine(); } if (block_op->init.defined()) { Doc init_block; diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 4bbe17064c87..cc7536b48cfd 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -36,6 +37,7 @@ #include #include +#include "../tir/transforms/ir_utils.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" @@ -301,8 +303,8 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) { } else { doc << ", elem_offset=" << Print(buf->elem_offset); } - if (buf->scope != "global") { - doc << ", scope=" << Doc::StrLiteral(buf->scope); + if (buf.scope() != "global") { + doc << ", scope=" << Doc::StrLiteral(buf.scope()); } if (buf->data_alignment != -1) { doc << ", align=" << buf->data_alignment; @@ -335,29 +337,8 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) { const Buffer& buf = op->buffer; buf_not_in_headers.insert(buf.get()); - Doc doc = Print(op->buffer) << " = tir.match_buffer_region(" << Print(op->source); - if (!buf->strides.empty()) { - doc << ", strides=" << Print(buf->strides); - } - if (buf->offset_factor != 0 && buf->elem_offset->IsInstance()) { - Var elem_offset = Downcast(buf->elem_offset); - if (memo_var_.find(elem_offset) != memo_var_.end()) { - doc << ", elem_offset=" << Print(buf->elem_offset); - } else { - // implicitly define elem_offset - memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset"); - var_not_in_headers.insert(elem_offset.get()); - } - } else { - doc << ", elem_offset=" << Print(buf->elem_offset); - } - if (buf->data_alignment != -1) { - doc << ", align=" << buf->data_alignment; - } - if (buf->offset_factor != 0) { - doc << ", offset_factor=" << buf->offset_factor; - } - doc << ")"; + Doc doc = Print(op->buffer) << " = tir.match_buffer(" << Print(op->source) << ", " + << memo_buf_decl_[op->buffer] << ")"; return doc; } @@ -578,31 +559,6 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) { Doc doc; - // merge attr with allocate when possible - if (op->node->IsInstance() && op->attr_key == "storage_scope" && - op->body->IsInstance()) { - const auto* alloc = Downcast(op->body).get(); - if (alloc->buffer_var.same_as(op->node)) { - var_not_in_headers.insert(alloc->buffer_var.get()); - if (current_num_ != num_child_ - 1) { - doc << "with tir.allocate(" << Print(alloc->extents) << ", " << PrintDType(alloc->dtype) - << ", " << Print(op->value); - if (!is_one(alloc->condition)) { - doc << ", " << Print(alloc->condition); - } - doc << ") as " << Print(op->node) << ":"; - doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body)); - } else { - doc << Print(op->node) << " = tir.allocate(" << Print(alloc->extents) << ", " - << PrintDType(alloc->dtype) << ", " << Print(op->value); - if (!is_one(alloc->condition)) { - doc << ", " << Print(alloc->condition); - } - doc << ")" << Doc::NewLine() << PrintBody(alloc->body); - } - return doc; - } - } // merge attr with realize when possible if (op->node->IsInstance() && op->attr_key == "realize_scope" && op->body->IsInstance()) { @@ -680,8 +636,26 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { } Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { - LOG(FATAL) << "TVM Script Printer Internal Error: All the Allocate should be folded with Attr"; - return Doc(); + var_not_in_headers.insert(op->buffer_var.get()); + Doc doc; + auto storage_scope = GetPtrStorageScope(op->buffer_var); + if (current_num_ != num_child_ - 1) { + doc << "with tir.allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype) << ", " + << Print(storage_scope); + if (!is_one(op->condition)) { + doc << ", " << Print(op->condition); + } + doc << ") as " << Print(op->buffer_var) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } else { + doc << Print(op->buffer_var) << " = tir.allocate(" << Print(op->extents) << ", " + << PrintDType(op->dtype) << ", " << Print(storage_scope); + if (!is_one(op->condition)) { + doc << ", " << Print(op->condition); + } + doc << ")" << Doc::NewLine() << PrintBody(op->body); + } + return doc; } Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { @@ -709,23 +683,6 @@ Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) { return doc; } -inline const char* ForKind2String(ForKind t) { - switch (t) { - case ForKind::kSerial: - return "serial"; - case ForKind::kParallel: - return "parallel"; - case ForKind::kVectorized: - return "vectorized"; - case ForKind::kUnrolled: - return "unroll"; - case ForKind::kThreadBinding: - return "thread_binding"; - } - LOG(FATAL) << "Unknown ForKind"; - return "Unknown"; -} - Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { Doc doc; var_not_in_headers.insert(op->loop_var.get()); @@ -1013,8 +970,17 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { return memo_var_[GetRef(a)].str() < memo_var_[GetRef(b)].str(); }); for (const auto& var : vars) { - header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.var("; - header_var << PrintDType(var->dtype) << ")"; + auto type = GetRef(var)->type_annotation; + if (auto* ptr_type = type.as()) { + auto* prim_type = ptr_type->element_type.as(); + ICHECK(prim_type); + header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.buffer_var("; + header_var << PrintDType(prim_type->dtype) << ", " + << Doc::StrLiteral(ptr_type->storage_scope) << ")"; + } else { + header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.var("; + header_var << PrintDType(var->dtype) << ")"; + } } } doc << Doc::Indent(4, header_attr << header_var << header_buf << body); diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index 840878390018..53c680b722cd 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -76,17 +76,20 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) } } -AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) { +AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& func_name, + const std::string& target) { auto ret = regions_.emplace(AnnotatedRegion()); (*ret.first)->id_ = region_id_++; (*ret.first)->target_ = target; + (*ret.first)->func_name_ = func_name; return *ret.first; } class AnnotatedRegionSet::Creator : protected MixedModeVisitor { public: - Creator(const Op& region_begin_op, const Op& region_end_op) - : begin_op_(region_begin_op), end_op_(region_end_op) {} + Creator(const Op& region_begin_op, const Op& region_end_op, + const std::string& func_name = "default") + : begin_op_(region_begin_op), end_op_(region_end_op), func_name_(func_name) {} AnnotatedRegionSet Create(const Expr& expr) { VisitExpr(expr); @@ -144,7 +147,7 @@ class AnnotatedRegionSet::Creator : protected MixedModeVisitor { ICHECK(!region.defined()); // Create a new region. - region = region_set_->MakeRegion(target); + region = region_set_->MakeRegion(func_name_, target); region->nodes_.insert(GetRef(call)); region->ins_.push_back(GetRef(call)); } else { @@ -213,10 +216,13 @@ class AnnotatedRegionSet::Creator : protected MixedModeVisitor { const Op begin_op_; /*! \brief Region 'end' annotation operator. */ const Op end_op_; + /*! \brief The unique function name that is used to be the name of this region set. */ + const std::string func_name_; }; -AnnotatedRegionSet AnnotatedRegionSet::Create(const Expr& expr, const Op& begin, const Op& end) { - return Creator(begin, end).Create(expr); +AnnotatedRegionSet AnnotatedRegionSet::Create(const Expr& expr, const Op& begin, const Op& end, + const std::string& func_name) { + return Creator(begin, end, func_name).Create(expr); } TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode); diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index 2e4eec23f733..aca42397916c 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -62,6 +62,9 @@ class AnnotatedRegionNode : public Object { /*! \brief Get the region ID. */ int GetID() const { return id_; } + /*! \brief Get the region name. */ + std::string GetName() const { return func_name_; } + /*! \brief Get the region target. */ std::string GetTarget() const { return target_; } @@ -80,6 +83,8 @@ class AnnotatedRegionNode : public Object { protected: /*! \brief The region ID. */ int id_{-1}; + /*! \brief The func name. */ + std::string func_name_ = "default"; /*! \brief The target for this region. */ std::string target_ = "default"; /*! \brief The inputs to this region. */ @@ -177,7 +182,7 @@ class AnnotatedRegionSetNode : public Object { * * \return The new region. */ - AnnotatedRegion MakeRegion(const std::string& target); + AnnotatedRegion MakeRegion(const std::string& func_name, const std::string& target); std::unordered_set regions_; /*! \brief The next region ID to assign. */ @@ -256,10 +261,12 @@ class AnnotatedRegionSet : public ObjectRef { * \param expr The relay expr from which to construct the set. * \param begin Region begin annotation operator. * \param end Region end annotation operator. + * \param func_name function name * * \return The created RegionSet for the expression. */ - static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end); + static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end, + const std::string& func_name = "default"); private: /*! \brief Helper class to construct a RegionSet from an expr.*/ diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 84c17b53c83e..2fb35f3a2e27 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -38,7 +38,7 @@ #include #include -#include "compile_engine.h" +#include "te_compiler.h" #include "utils.h" namespace tvm { @@ -46,7 +46,6 @@ namespace relay { namespace backend { using IntegerArray = Array; -using TargetsMap = std::unordered_map; using StorageMap = std::unordered_map; @@ -54,7 +53,7 @@ using StorageMap = * This is an on demand allocator for AOT. A new temporary * (storage allocator identifier) is allocated for each operation. */ -class AOTOnDemandAllocator : public ExprVisitor { +class AOTOnDemandAllocator : public MixedModeVisitor { public: // run the visitor on a function. void Run(const Function& func) { @@ -85,10 +84,7 @@ class AOTOnDemandAllocator : public ExprVisitor { AssignReturnSid(GetRef(op)); } - void VisitExpr_(const VarNode* op) final { - ExprVisitor::VisitExpr_(op); - AssignReturnSid(GetRef(op)); - } + void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef(op)); } void VisitExpr_(const FunctionNode* op) final { // do not recurse into sub function. @@ -219,7 +215,7 @@ class AOTOnDemandAllocator : public ExprVisitor { }; /*! \brief Code generator for AOT executor */ -class AOTExecutorCodegen : public ExprVisitor { +class AOTExecutorCodegen : public MixedModeVisitor { protected: /*! * \brief Utility function to allocate a DLTensor or TVMValue @@ -287,7 +283,6 @@ class AOTExecutorCodegen : public ExprVisitor { void CreateFuncCall(Call call, std::string func_name) { tvm::Array args{tvm::tir::StringImm(func_name)}; std::vector create_func_call_stmts; - // Pack the inputs for (Expr arg : call->args) { if (params_by_expr_.find(arg) != params_by_expr_.end()) { @@ -365,155 +360,21 @@ class AOTExecutorCodegen : public ExprVisitor { return ss.str(); } - /*! - * \brief Update the "main" control function's metadata - * - * \param func The main function that contains calls to operator tir primitive functions - */ - void UpdateMainWorkspaceSize(const tir::PrimFunc& primfunc, const relay::Function& func) { - auto workspace_byte_alignment = target_host_->GetAttr("workspace-byte-alignment") - .value_or(tvm::runtime::kDefaultWorkspaceAlignment); - Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment); - // Populate FunctionInfo - auto fi_node = make_object(); - // Initialize all target workspaces to zero - for (const auto& kv : targets_) { - auto tgt = kv.second; - fi_node->workspace_sizes.Set(tgt, 0); - } - fi_node->workspace_sizes.Set(target_host_, workspace_size); - fi_node->relay_primfuncs.Set(target_host_, func); - - int64_t io_size = 0; - for (const auto& input : input_vars_) { - io_size += CalculateRelayExprSizeBytes(input->checked_type()); - } - io_size += CalculateRelayExprSizeBytes(func->body->checked_type()); - fi_node->io_sizes.Set(target_host_, io_size); - - int64_t const_size = 0; - for (const auto& kv : params_by_expr_) { - const_size += CalculateRelayExprSizeBytes(kv.first->checked_type()); - } - fi_node->constant_sizes.Set(target_host_, const_size); - function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node)); - } - - /*! - * \brief Update the function metadata for a given cached function and its relay - * primitive function. - * - * \param cfunc The cached function as provided the by the compile engine - * \param relay_func The source relay primitive function - * \param relay_target The target associated with relay primitive function - */ - void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func, - const Target& relay_target) { - auto fi_node = make_object(); - for (const auto& kv : cfunc->funcs->functions) { - auto primfunc = Downcast(kv.second); - auto workspace_byte_alignment = - target_host_->GetAttr("workspace-byte-alignment").value_or(16); - Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment); - Target primfunc_target = relay_target; - if (primfunc->attrs->dict.count("target")) { - primfunc_target = Downcast(primfunc->attrs->dict["target"]); - } - fi_node->workspace_sizes.Set(primfunc_target, workspace_size); - // Calculating size for I/O - for (auto const& param : primfunc->params) { - auto p_shape = primfunc->buffer_map[param]->shape; - int num_of_elements = 1; - for (const auto& dim_index_expr : p_shape) { - if (dim_index_expr->IsInstance()) { - num_of_elements *= dim_index_expr.as()->value; - } else { - // If shape is dynamic, we cannot calculate workspace in compile time. - num_of_elements = 0; - } - } - int element_size = primfunc->buffer_map[param]->dtype.bytes(); - fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements); - } - fi_node->constant_sizes.Set(primfunc_target, 0); - fi_node->tir_primfuncs.Set(primfunc_target, primfunc); - fi_node->relay_primfuncs.Set(primfunc_target, relay_func); - } - function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node)); - } - void VisitExpr_(const CallNode* op) override { // Descend the call tree for (auto arg : op->args) { VisitExpr(arg); } - Expr expr = GetRef(op); - Function func; if (op->op.as()) { LOG(FATAL) << "Operators should be transformed away; try applying" << "the fuse_ops transformation to the expression."; } else if (op->op.as()) { - LOG(FATAL) << "Not implemented"; - } else if (op->op.as()) { - func = GetRef(op->op.as()); + GlobalVar node = GetRef(op->op.as()); + CreateFuncCall(GetRef(op), node->name_hint); } else { LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey(); } - if (!func->HasNonzeroAttr(attr::kPrimitive)) { - LOG(FATAL) << "TVM only support calls to primitive functions " - << "(i.e functions composed of fusable operator invocations)"; - } - - Target target; - - // Handle external function - if (func->GetAttr(attr::kCompiler).defined()) { - target = Target("ext_dev"); - CCacheKey key = CCacheKey(func, target); - CachedFunc ext_func = compile_engine_->Lower(key, mod_name_); - ICHECK(ext_func.defined()) << "External function is not defined."; - UpdateConstants(func, ¶ms_); - - // Generate the TIR function call - CreateFuncCall(GetRef(op), ext_func->prim_fn_var->name_hint); - return; - } - - ICHECK_GE(storage_device_map_.count(expr), 0); - StorageInfo& sinfo = storage_device_map_[expr]; - auto call_dev_type = sinfo->device_types[0]; - // Normal Relay Function - if (targets_.size() == 1) { - // homogeneous execution. - const auto& it = targets_.begin(); - target = (*it).second; - } else { - // heterogeneous execution. - std::string call_dev_name; - if (call_dev_type == 0) { - call_dev_name = "llvm"; - } else { - call_dev_name = runtime::DeviceName(call_dev_type); - } - if (targets_.count(call_dev_type) == 0) { - LOG(FATAL) << "No target is provided for device " << call_dev_name; - } - target = targets_[call_dev_type]; - } - - CCacheKey key = CCacheKey(func, target); - CachedFunc lowered_func = compile_engine_->Lower(key, mod_name_); - - if (!lowered_funcs_.count(target->str())) { - lowered_funcs_[target->str()] = IRModule(Map({})); - } - lowered_funcs_[target->str()]->Update(lowered_func->funcs); - // Update function metadata via looking at all primfuncs - UpdateFunctionMetadata(lowered_func, func, target); - - // Generate the TIR function call - CreateFuncCall(GetRef(op), lowered_func->prim_fn_var->name_hint); } void VisitExpr_(const VarNode* op) override { @@ -573,7 +434,6 @@ class AOTExecutorCodegen : public ExprVisitor { void VisitExpr_(const OpNode* op) override { throw std::runtime_error("can not compile op in non-eta expanded form"); } - void VisitExpr_(const GlobalVarNode* op) override { throw std::runtime_error(""); } void VisitExpr_(const IfNode* op) override { throw std::invalid_argument("if not supported"); } void VisitExpr_(const FunctionNode* op) override { ICHECK(op->GetAttr(attr::kCompiler).defined()) @@ -598,7 +458,7 @@ class AOTExecutorCodegen : public ExprVisitor { // Create the main PrimFunc to execute the graph. Please note that // the packed function calls don't pack their arguments. The AOT // runner function needs to be legalized by the LegalizePackedCalls pass. - tir::PrimFunc CreateMainFunc(unsigned int relay_params) { + tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) { tir::Stmt body = tir::SeqStmt(stmts_); // Allocate the sids @@ -625,8 +485,6 @@ class AOTExecutorCodegen : public ExprVisitor { // so we don't pay the price of allocation for every inference if (!allocated[sid]) { body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body); - body = tir::AttrStmt(sids_table_[sid], tir::attr::storage_scope, tir::StringImm("global"), - body); } allocated[sid] = true; } @@ -639,7 +497,7 @@ class AOTExecutorCodegen : public ExprVisitor { // Define the PrimFunc attributes Map dict_attrs; String run_func_name = - runtime::get_name_mangled(mod_name_, runtime::symbol::tvm_run_func_suffix); + runtime::get_name_mangled(mod_name, runtime::symbol::tvm_run_func_suffix); dict_attrs.Set("global_symbol", run_func_name); dict_attrs.Set("runner_function", Bool(true)); @@ -652,11 +510,11 @@ class AOTExecutorCodegen : public ExprVisitor { /*! \brief mod */ runtime::Module* mod_; /*! \brief list of input expressions (i.e., variable passed by the user) */ - std::vector input_vars_; + std::vector input_vars_; /*! \brief input and output variables belonging to the main function signature */ Array main_signature_; /*! \brief target device */ - TargetsMap targets_; + tec::TargetMap targets_; /*! \brief target host */ Target target_host_; /*! @@ -686,35 +544,72 @@ class AOTExecutorCodegen : public ExprVisitor { /*! \brief mapping sid -> tir::Var */ std::unordered_map sids_table_; /*! \brief lowered funcs */ - std::unordered_map lowered_funcs_; - /*! \brief lowered funcs */ Map function_metadata_; - /*! \brief compile engine */ - CompileEngine compile_engine_; /*! \brief the set of statements that make the program */ std::vector stmts_; /*! \brief the list of return sids (note that the function might return more then one output */ std::vector return_sid_; - /*! \brief the module name we use to mangle the function names */ - String mod_name_; public: - AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host) + AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host) : mod_(mod), targets_(targets), target_host_(target_host), - use_unpacked_api_(target_host->GetAttr("unpacked-api").value_or(Bool(false))), - compile_engine_(CompileEngine::Global()) {} + use_unpacked_api_(target_host->GetAttr("unpacked-api").value_or(Bool(false))) {} LoweredOutput Codegen(relay::Function func, String mod_name) { auto aot_allocator = AOTOnDemandAllocator(); aot_allocator.Run(func); - // Retrieve the storage map - storage_device_map_ = aot_allocator.GetStorageMap(); - mod_name_ = mod_name; + // Pre-lowering storage map and memory plan + StorageMap initial_storage_map = aot_allocator.GetStorageMap(); + StaticMemoryPlan memory_plan(initial_storage_map); + + // Build a map from each operation to device. + tec::DeviceMap device_context_map; + for (const auto& it : memory_plan->expr_to_storage_info) { + auto expr = it.first; + auto storage_info = it.second; + auto device_types = storage_info->device_types; + // CHECK_EQ(device_types.size(), 1); + tvm::Device dev; + dev.device_id = 0; + dev.device_type = device_types[0]; + device_context_map.insert({expr, dev}); + } + + // This first phase moves from implicit use of compile engine, + // to instead explicitly lowering the incoming IRModule, and then + // performing the preexisting AOT executor code generation phase. + IRModule mod = IRModule::FromExpr(func); + + IRModule new_mod = + LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) { + // We need to maintain the constant map for external + // functions so we pass this processing function which + // allows us to process each function as we lower it. + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, ¶ms_); + } + + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. + tec::UpdateFunctionMetadata(func, this->function_metadata_); + })(mod); + + tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod); + function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); + auto lowered_main = lowered_module.main_module->Lookup("main"); + auto lowered_main_func = GetRef(lowered_main.as()); - for (auto input : func->params) { + // Post-lowering storage map for writing main func - this should be the same map as previously + // created, just referencing the new expressions created from lowering + auto new_allocator = AOTOnDemandAllocator(); + new_allocator.Run(lowered_main_func); + storage_device_map_ = new_allocator.GetStorageMap(); + + for (auto input : lowered_main_func->params) { input_vars_.push_back(input); main_signature_.push_back(tir::Var("input", DataType::Handle())); } @@ -722,7 +617,8 @@ class AOTExecutorCodegen : public ExprVisitor { // Define the storage allocator ids for (auto kv : storage_device_map_) { for (auto sid : kv.second->storage_ids) { - te::Var buffer_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8)))); + te::Var buffer_var(MakeString("sid_", sid), + PointerType(PrimType(DataType::Int(8)), "global")); sids_table_[sid] = buffer_var; } } @@ -733,13 +629,12 @@ class AOTExecutorCodegen : public ExprVisitor { main_signature_.push_back(tir::Var("output", DataType::Handle())); } - VisitExpr(func->body); + VisitExpr(lowered_main_func->body); // Create the runner function. Please note that the function is not legal yet // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need // to run the LegalizePackedCalls pass. - auto prim_func = CreateMainFunc(func->params.size()); - UpdateMainWorkspaceSize(prim_func, func); + auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size()); LoweredOutput ret; ret.params = std::unordered_map>(); @@ -749,17 +644,7 @@ class AOTExecutorCodegen : public ExprVisitor { std::make_pair(static_cast(param_storage_ids_[param.first]), param.second))); } - for (auto& kv : lowered_funcs_) { - if (ret.lowered_funcs.count(kv.first) == 0) { - ret.lowered_funcs.Set(kv.first, IRModule(Map({}))); - } - auto& mod = ret.lowered_funcs[kv.first]; - mod->Update(kv.second); - ret.lowered_funcs.Set(kv.first, mod); - } - ret.external_mods = compile_engine_->LowerExternalFunctions(); - - // Build the TIR IRModule + // Build the TIR IRModule for the AOT function Map symbol_map; symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func); IRModule mod_run(symbol_map); @@ -775,16 +660,23 @@ class AOTExecutorCodegen : public ExprVisitor { mod_run = pack_calls(mod_run); } - // Update the lowered functions + ret.function_metadata = std::move(function_metadata_); + + ret.lowered_funcs = lowered_module.per_target_module; + ret.external_mods = lowered_module.external_mods; + auto target_host_str = target_host_->str(); if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { ret.lowered_funcs[target_host_str]->Update(mod_run); } else { ret.lowered_funcs.Set(target_host_str, mod_run); } - ret.function_metadata = std::move(function_metadata_); - ret.metadata = runtime::Metadata(input_vars_.size(), return_sid_.size(), - runtime::kTvmExecutorAot, mod_name); + + std::vector input_var_names(input_vars_.size()); + std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(), + [](Var input_var) -> String { return input_var->name_hint(); }); + ret.metadata = + runtime::Metadata(input_var_names, return_sid_.size(), runtime::kTvmExecutorAot, mod_name); return ret; } }; @@ -842,7 +734,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { private: void init(void* mod, Map tmp) { - TargetsMap targets; + tec::TargetMap targets; Target target_host; for (const auto& it : tmp) { auto dev_type = it.first.as(); @@ -850,7 +742,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { target_host = it.second; } ICHECK(dev_type); - targets[dev_type->value] = it.second; + targets[static_cast(dev_type->value)] = it.second; } codegen_ = std::make_shared(reinterpret_cast(mod), targets, target_host); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ea53c34c793b..b2b73e9bad02 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -313,57 +313,7 @@ class RelayBuildModule : public runtime::ModuleNode { relay_module_ptr->Update(main_glb_var, new_main); } - Array pass_seqs; - Array entry_functions{"main"}; - pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); - pass_seqs.push_back(transform::ToBasicBlockNormalForm()); - - // Run all dialect legalization passes. - pass_seqs.push_back(relay::qnn::transform::Legalize()); - - // Legalize pass is restricted to homogeneous execution for now. - if (targets.size() == 1) { - pass_seqs.push_back(transform::Legalize()); - } - - pass_seqs.push_back(transform::SimplifyInference()); - - // Convert Dynamic ops to static versions - pass_seqs.push_back(transform::DynamicToStatic()); - - PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - *rv = false; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == DataType::Int(32)) { - *rv = true; - } - } - } - }); - pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); - pass_seqs.push_back(transform::SimplifyExpr()); - pass_seqs.push_back(transform::CombineParallelConv2D(3)); - pass_seqs.push_back(transform::CombineParallelDense(3)); - pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); - pass_seqs.push_back(transform::FoldConstant()); - pass_seqs.push_back(transform::FoldScaleAxis()); - pass_seqs.push_back(transform::CanonicalizeCast()); - pass_seqs.push_back(transform::CanonicalizeOps()); - - // Alter layout transformation is only applied to homogeneous execution yet. - if (targets.size() == 1) { - pass_seqs.push_back(transform::InferType()); - pass_seqs.push_back(transform::AlterOpLayout()); - } - - // Fast math optimizations. - pass_seqs.push_back(transform::FastMath()); - pass_seqs.push_back(transform::FoldConstant()); + Array pass_seqs = GetPassPrefix(targets, false); if (targets.size() == 1) { const auto& target = (*targets.begin()).second; @@ -540,6 +490,11 @@ class RelayBuildModule : public runtime::ModuleNode { auto lowered_funcs = executor_codegen_->GetIRModule(); + // No need to build for external functions. + if (lowered_funcs.find("ext_dev") != lowered_funcs.end()) { + lowered_funcs.Set("ext_dev", IRModule()); + } + // Generate a placeholder function that attaches linked params as its arguments. if (target_host->GetAttr("link-params").value_or(Bool(false))) { CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen."; diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 9d59e8e5f3a8..486a6dcd7d87 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -36,7 +36,6 @@ #include #include -#include "compile_engine.h" #include "te_compiler.h" #include "utils.h" @@ -184,7 +183,7 @@ class GraphOpNode : public GraphNode { */ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> { public: - GraphExecutorCodegen(runtime::Module* mod, const TargetMap& targets) : mod_(mod) { + GraphExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets) : mod_(mod) { targets_ = targets; } @@ -222,21 +221,22 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorGetAttr(attr::kCompiler).defined()) { - UpdateConstants(func, ¶ms_); - } - - // TODO(@areusch, @jroesch): We should refactor this to - // execute as a further pass, instead writing data to the - // lowering process directly. - UpdateFunctionMetadata(func, this->function_metadata_); - }); + IRModule new_mod = + LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) { + // We need to maintain the constant map for external + // functions so we pass this processing function which + // allows us to process each function as we lower it. + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, ¶ms_); + } + + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. + tec::UpdateFunctionMetadata(func, this->function_metadata_); + })(mod); + tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod); function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); auto main_module = lowered_module.main_module; main_module = relay::transform::InferType()(main_module); @@ -580,7 +580,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> var_map_; /*! \brief target device */ - TargetMap targets_; + tec::TargetMap targets_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). * These are take as inputs to the GraphExecutor. @@ -593,9 +593,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator lowered_funcs_; - /*! \brief lowered funcs */ + /*! \brief function metadata */ Map function_metadata_; /*! \brief name map */ std::unordered_map name_map_; @@ -611,7 +609,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { << "runtime::Module mod and Map targets"; void* mod = args[0]; Map tmp = args[1]; - TargetMap targets; + tec::TargetMap targets; for (const auto& it : tmp) { auto dev_type = it.first.as(); ICHECK(dev_type); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 53985c78a33c..af2cbae1f72d 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -21,26 +21,102 @@ * \file src/relay/interpreter.cc * \brief An interpreter for the Relay IR. */ + #include #include +#include #include #include #include #include #include #include +#include #include #include #include "../transforms/pass_utils.h" #include "compile_engine.h" +#include "te_compiler.h" namespace tvm { namespace relay { -using namespace runtime; +using runtime::ADT; +using runtime::ADTObj; +using runtime::NDArray; +using runtime::TVMArgsSetter; +using runtime::operator<<; + +namespace { +// TODO(mbs): Centralize. +struct PairHash { + template + std::size_t operator()(const std::pair& k) const { + return std::hash()(k.first) ^ std::hash()(k.second); + } +}; + +// Analogue of FlattenTupleType for runtime ADT vs NDArray values. +// TODO(mbs): Hoist somewhere sensible, maybe op/memory.h? +void FlattenADTAux(const ObjectRef& object_ref, std::vector* out) { + if (const NDArray::ContainerType* ndarray = object_ref.as()) { + out->push_back(GetRef(ndarray)); + } else if (const ADTObj* adt = object_ref.as()) { + for (size_t i = 0; i < adt->size; ++i) { + FlattenADTAux((*adt)[i], out); + } + } else { + LOG(FATAL) << "unsupported " << object_ref; + } +} + +std::vector FlattenADT(const ObjectRef& object_ref) { + std::vector out; + FlattenADTAux(object_ref, &out); + return out; +} -InterpreterClosure::InterpreterClosure(tvm::Map env, Function func) { +std::vector FlattenADTs(const std::vector& object_refs) { + std::vector out; + for (const auto& object_ref : object_refs) { + FlattenADTAux(object_ref, &out); + } + return out; +} + +// Analogue of ToTupleType for runtime ADT vs NDArray values. +// TODO(mbs): Hoist somewhere sensible, maybe op/memory.h? +void ToADTOrNDArrayAux(const Type& type, const std::vector& nd_arrays, int* index, + std::vector* out) { + if (type.as()) { + out->push_back(nd_arrays[*index]); + *index += 1; + } else if (const TupleTypeNode* ttn = type.as()) { + std::vector tuple_out; + for (size_t i = 0; i < ttn->fields.size(); i++) { + ToADTOrNDArrayAux(ttn->fields[i], nd_arrays, index, &tuple_out); + } + out->push_back(ADT::Tuple(tuple_out)); + } else { + LOG(FATAL) << "unsupported " << type; + } +} + +ObjectRef ToADTOrNDArray(const Type& type, const std::vector& nd_arrays) { + if (type.as() && nd_arrays.size() == 1) { + return nd_arrays[0]; + } else { + std::vector out; + int index = 0; + ToADTOrNDArrayAux(type, nd_arrays, &index, &out); + return out[0]; + } +} + +} // namespace + +InterpreterClosure::InterpreterClosure(Map env, Function func) { ObjectPtr n = make_object(); n->env = std::move(env); n->func = std::move(func); @@ -54,7 +130,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); inline const PackedFunc& GetPackedFunc(const std::string& name) { - const PackedFunc* pf = tvm::runtime::Registry::Get(name); + const PackedFunc* pf = runtime::Registry::Get(name); ICHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; return *pf; } @@ -92,8 +168,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RefValueObj(" << node->value << ")"; }); -ConstructorValue::ConstructorValue(int32_t tag, tvm::Array fields, - Constructor constructor) { +ConstructorValue::ConstructorValue(int32_t tag, Array fields, Constructor constructor) { ObjectPtr n = make_object(); n->tag = tag; n->fields = fields; @@ -102,7 +177,7 @@ ConstructorValue::ConstructorValue(int32_t tag, tvm::Array fields, } TVM_REGISTER_GLOBAL("relay._make.ConstructorValue") - .set_body_typed([](int32_t tag, tvm::Array fields, Constructor constructor) { + .set_body_typed([](int32_t tag, Array fields, Constructor constructor) { return ConstructorValue(tag, fields, constructor); }); @@ -121,9 +196,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) */ struct Frame { /*! \brief The set of local variables and arguments for the frame. */ - tvm::Map locals; + Map locals; - explicit Frame(tvm::Map locals) : locals(locals) {} + explicit Frame(Map locals) : locals(locals) {} }; /*! @@ -168,8 +243,8 @@ class InterpreterState; /*! \brief A container capturing the state of the interpreter. */ class InterpreterStateObj : public Object { public: - using Frame = tvm::Map; - using Stack = tvm::Array; + using Frame = Map; + using Stack = Array; /*! \brief The current expression under evaluation. */ Expr current_expr; @@ -177,7 +252,7 @@ class InterpreterStateObj : public Object { /*! \brief The call stack of the interpreter. */ Stack stack; - void VisitAttrs(tvm::AttrVisitor* v) { + void VisitAttrs(AttrVisitor* v) { v->Visit("current_expr", ¤t_expr); v->Visit("stack", &stack); } @@ -188,8 +263,8 @@ class InterpreterStateObj : public Object { class InterpreterState : public ObjectRef { public: - using Frame = tvm::Map; - using Stack = tvm::Array; + using Frame = Map; + using Stack = Array; InterpreterState(Expr current_expr, Stack stack); @@ -213,10 +288,13 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st class Interpreter : public ExprFunctor, PatternFunctor { public: - Interpreter(IRModule mod, Device device, Target target) - : mod_(mod), device_(device), target_(target), debug_op_(Op::Get("debug")) { - engine_ = CompileEngine::Global(); - } + // TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule. + Interpreter(IRModule mod, Map per_target_module, Device device, Target target) + : mod_(mod), + per_target_module_(per_target_module), + device_(device), + target_(target), + debug_op_(Op::Get("debug")) {} template T WithFrame(const Frame& fr, const std::function& f) { @@ -239,8 +317,7 @@ class Interpreter : public ExprFunctor, ObjectRef VisitExpr_(const OpNode* id) override { // TODO(@jroesch): Eta-expand and return in this case. LOG(FATAL) << "internal error, need to wrap intrinsic into call synthetic call node " - << "in " - << "this case, eta expand"; + << "in this case, eta expand"; return ObjectRef(); } @@ -258,7 +335,7 @@ class Interpreter : public ExprFunctor, } ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) { - tvm::Map captured_mod; + Map captured_mod; Array free_vars = FreeVars(func); for (const auto& var : free_vars) { @@ -284,251 +361,301 @@ class Interpreter : public ExprFunctor, return MakeClosure(func); } - Array ComputeDynamicShape(const Function& func, const Array& args) { - CCacheKey key(func, Target("llvm")); - auto cfunc = engine_->LowerShapeFunc(key); - size_t arity = cfunc->inputs.size() + cfunc->outputs.size(); + /*! + * \brief Returns the packed function implementing the TIR function bound to \p tir_fn_var. + * + * \param tir_fn_var Global var for the already lowered TIR function. + * \param all_tir_fn_vars Global vars for all lowered TIR functions the above + * may reference, plus \p tir_fn_var itself. + * \param target Target for which the TIR function should be compiled. For primitives this + * will be the interpreter's target_. However for shape functions this will be the generic + * 'cpu' target, since shape functions are always executed on the host cpu. + */ + PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array& all_tir_fn_vars, + Target target) { + std::pair packed_func_key(target->str(), tir_fn_var->name_hint); + auto packed_itr = compiled_packed_funcs_.find(packed_func_key); + if (packed_itr != compiled_packed_funcs_.end()) { + // Already compiled. + return packed_itr->second; + } + + // Project out just the function(s) we need. + IRModule lowered_projected_mod; + auto mod_itr = per_target_module_.find(target->str()); + ICHECK(mod_itr != per_target_module_.end()) + << "No target module for target '" << target->str() << "'"; + const IRModule& target_module = (*mod_itr).second; + for (const auto& var : all_tir_fn_vars) { + ICHECK(target_module->ContainGlobalVar(var->name_hint)) + << "No global var for '" << var->name_hint << "' in module for target '" << target->str() + << "'"; + lowered_projected_mod->Add(var, target_module->Lookup(var->name_hint)); + } + + // Compile (aka 'build') the projected module into a runtime module of packed functions. + runtime::Module runtime_module; + if (const auto* f = runtime::Registry::Get("relay.backend.build")) { + // TODO(mbs): Cleanup hooks. + runtime_module = (*f)(lowered_projected_mod, target); + } else { + runtime_module = build(lowered_projected_mod, target, /*target_host=*/Target(nullptr)); + } + + // Extract all the packed functions. + for (const auto& var : all_tir_fn_vars) { + PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); + ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint + << "' in compiled module for target '" << target->str() << "'"; + compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func); + } + // Return just what we need for this call. + packed_itr = compiled_packed_funcs_.find(packed_func_key); + ICHECK(packed_itr != compiled_packed_funcs_.end()) << " " << tir_fn_var->name_hint; + ICHECK_NOTNULL(packed_itr->second); + return packed_itr->second; + } + + /*! + * \brief Call the dynamic shape function bound to \p prim_shape_fn_var passing the + * shapes of args, and return the resulting shapes. + * + * \param prim_shape_fn_var Global var bound to lowered shape function. + * \param all_prim_shape_fn_vars All the global vars needed to build the above, including + * the shape function itself. + * \param prim_shape_fn_states For each primitive arg, indicate whether the primitive shape + * function requires the shape of the argument and/or the actual argument tensor. + * \param num_shape_inputs The number of inputs, after accounting for both shapes vs data + * inputs and unfolding of tuple types. + * \param num_shape_outputs The number of outputs, after accounting for flattening of + * tuple types. + * \param args Arguments to the primitive this shape function is for. + * \return Expected shapes of the underlying primitive's flattened outputs. + */ + Array ComputeDynamicShape(const GlobalVar& prim_shape_fn_var, + const Array& all_prim_shape_fn_vars, + const Array& prim_shape_fn_states, + size_t num_shape_inputs, size_t num_shape_outputs, + const std::vector& args) { + ICHECK(prim_shape_fn_var.defined()); + ICHECK(prim_shape_fn_states.defined()); + ICHECK(prim_shape_fn_var->checked_type().defined()); + // The function type is that of the original primitive rather than the shape function + // itself. We currently can't express shape function types in Relay. + const FuncTypeNode* ftn = prim_shape_fn_var->checked_type().as(); + ICHECK(ftn); + // The primitive shape function states are w.r.t. the primitive's arguments in + // non-flattened form. + // TODO(mbs): Clean this up so we don't mix flattened vs original conventions. + ICHECK_EQ(prim_shape_fn_states.size(), ftn->arg_types.size()); + ICHECK_EQ(args.size(), ftn->arg_types.size()); + // num_shape_inputs will account for which primitive function arguments are dynamic, + // whether the shape and or data needs to be passed, and flattening of tuples. + // Similarly, num_shape_outputs will account for flattening of tuples. + + // Shape functions always run on the cpu + Device shape_device; + shape_device.device_type = kDLCPU; + shape_device.device_id = 0; + Target shape_target("llvm"); + + // 'Compile' the TIR shape function to appropriate callable form. + PackedFunc packed_shape_func = + TIRToPackedFunc(prim_shape_fn_var, all_prim_shape_fn_vars, shape_target); + + size_t arity = num_shape_inputs + num_shape_outputs; std::vector values(arity); std::vector codes(arity); TVMArgsSetter setter(values.data(), codes.data()); - std::vector inputs(cfunc->inputs.size()); - std::vector outputs(cfunc->outputs.size()); - - Device cpu_dev; - cpu_dev.device_type = kDLCPU; - cpu_dev.device_id = 0; - - auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) { - auto nd_array = Downcast(val); - if (need_shape) { - int64_t ndim = nd_array.Shape().size(); - NDArray shape_arr; - if (ndim == 0) { - shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_dev); - } else { - shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_dev); - int64_t* data = reinterpret_cast(shape_arr->data); - for (auto j = 0; j < ndim; ++j) { - data[j] = nd_array.Shape()[j]; - } - } - inputs[i] = shape_arr; - setter(i, shape_arr); - } else { - auto arr = nd_array.CopyTo(cpu_dev); - inputs[i] = arr; - setter(i, arr); - } - }; + std::vector inputs(num_shape_inputs); + std::vector outputs(num_shape_outputs); + // Collect the shapes and/or data needed by the shape function from + // the primitive's arguments. size_t arg_counter = 0; for (size_t i = 0; i < args.size(); ++i) { - auto arg = args[i]; - auto param = func->params[i]; - int state = cfunc->shape_func_param_states[i]->value; - if (arg->IsInstance()) { - if (state & kNeedInputData) { - fset_input(arg_counter++, arg, false); - } - if (state & kNeedInputShape) { - fset_input(arg_counter++, arg, true); - } - } else { - const ADT adt = Downcast(arg); + // TODO(mbs): The same need data/need shape arg state applies to everything in the + // flattened form of this arg. Does that match what lowering actually does? + int64_t state = prim_shape_fn_states[i]->value; + for (const auto& nd_array : FlattenADT(args[i])) { if (state & kNeedInputData) { - for (size_t i = 0; i < adt.size(); ++i) { - fset_input(arg_counter++, adt[i], false); - } + auto arr = nd_array.CopyTo(shape_device); + inputs[arg_counter] = arr; + setter(arg_counter, arr); + ++arg_counter; } if (state & kNeedInputShape) { - for (size_t i = 0; i < adt.size(); ++i) { - fset_input(arg_counter++, adt[i], true); + int64_t ndim = nd_array.Shape().size(); + NDArray shape_arr; + if (ndim == 0) { + shape_arr = NDArray::Empty({}, DataType::Int(64), shape_device); + } else { + shape_arr = NDArray::Empty({ndim}, DataType::Int(64), shape_device); + int64_t* data = reinterpret_cast(shape_arr->data); + for (auto j = 0; j < ndim; ++j) { + data[j] = nd_array.Shape()[j]; + } } + inputs[arg_counter] = shape_arr; + setter(arg_counter, shape_arr); + ++arg_counter; } } } - ICHECK_EQ(arg_counter, cfunc->inputs.size()) << "Shape function input sizes mismatch"; - - auto fset_shape_output = [&](size_t i, Type val_type) { - // TODO(@icemelon): allow recursive tuple - const TensorTypeNode* rtype = val_type.as(); - ICHECK(rtype != nullptr); - int64_t ndim = rtype->shape.size(); - auto arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_dev); - outputs[i] = arr; - setter(arg_counter + i, arr); - }; + ICHECK_EQ(arg_counter, num_shape_inputs) << "Shape function input sizes mismatch"; - auto ret_type = func->body->checked_type(); + // Prepare NDArrays to hold the output shapes. size_t out_cnt = 0; - if (auto rtype = ret_type.as()) { - out_cnt = rtype->fields.size(); - for (size_t i = 0; i < out_cnt; ++i) { - fset_shape_output(i, rtype->fields[i]); - } - } else { - out_cnt = 1; - auto tt = Downcast(ret_type); - fset_shape_output(0, tt); + for (const auto& ttype : FlattenTupleType(ftn->ret_type)) { + ICHECK(out_cnt < num_shape_outputs); + int64_t ndim = ttype->shape.size(); + auto arr = NDArray::Empty({ndim}, DataType::Int(64), shape_device); + outputs[out_cnt] = arr; + setter(arg_counter + out_cnt, arr); + ++out_cnt; } - ICHECK_EQ(cfunc->outputs.size(), out_cnt) << "Shape function output sizes mismatch"; + ICHECK_EQ(out_cnt, num_shape_outputs) << "Shape function output sizes mismatch"; - PackedFunc shape_func; - Module m; - TVMRetValue rv; - if (const auto* f = runtime::Registry::Get("relay.backend.build")) { - m = (*f)(cfunc->funcs, cfunc->target); - } else { - m = build(cfunc->funcs, cfunc->target, Target(nullptr)); - } - shape_func = m.GetFunction(cfunc->prim_fn_var->name_hint); - shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); + // Call the dynamic shape function. + TVMRetValue rv; // ignored + packed_shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); - // Get output shapes + // Convert result tensors back to shapes. Array out_shapes; for (auto out_tensor : outputs) { int64_t* shape_data = reinterpret_cast(out_tensor->data); Shape out_shape; for (int i = 0; i < out_tensor->shape[0]; ++i) { - out_shape.push_back(tvm::Integer(shape_data[i])); + out_shape.push_back(Integer(shape_data[i])); } out_shapes.push_back(out_shape); } return out_shapes; } - ObjectRef InvokePrimitiveOp(const Function& func, const Array& args) { - const auto* call_node = func->body.as(); - - if (call_node && call_node->op == debug_op_) { - auto dattrs = call_node->attrs.as(); - auto interp_state = this->get_state(call_node->args[0]); - - if (dattrs->debug_func.defined()) { - dattrs->debug_func(interp_state); - } else { - RELAY_DEBUG_INTERP(interp_state); - } - - return args[0]; - } + /*! + * \brief Call primitive op bound to \p prim_fn_var with \p args. If necessary, evaluate dynamic + * shape function bound to \p prim_shape_fn_var to calculate shapes of result tensors. + * + * @param prim_fn_var Global bound to lowered primitive. + * @param all_prim_fn_vars All globals references by lowered primitive, plus prim_fn_var itself. + * @param prim_shape_fn_var Global bound to lowered shape function for primitive, if neeeded. + * @param all_prim_shape_fn_vars All globals references by lowered shape function, plus + * prim_shape_fn_var itself. + * @param prim_shape_fn_states Records whether shape and/or data is needed by the dynamic + * shape function (if any) for each (flattened) argument. + * @param num_shape_inputs Number of arguments to the dynamic shape function (if any). + * @param num_shape_outputs Number of outputs from the dynamic shape function (if any). + * @param args Already evaluated arguments to primitive. + * @return Result of primitive. + */ + ObjectRef InvokePrimitiveOp(const GlobalVar& prim_fn_var, const Array all_prim_fn_vars, + const GlobalVar& prim_shape_fn_var, + const Array& all_prim_shape_fn_vars, + const Array& prim_shape_fn_states, size_t num_shape_inputs, + size_t num_shape_outputs, const std::vector& args) { + ICHECK(prim_fn_var->checked_type().defined()); + const FuncTypeNode* ftn = prim_fn_var->checked_type().as(); + ICHECK(ftn); + + // 'Compile' the TIR primitive to appropriate callable form (on the desired target). + PackedFunc packed_func = TIRToPackedFunc(prim_fn_var, all_prim_fn_vars, target_); + + // Argument tuples are flattened. + std::vector arg_nd_arrays = FlattenADTs(args); + const size_t num_inputs = arg_nd_arrays.size(); + // num_inputs should equal size(concat(map(FlattenTupleType, function arg types))) + + // TVM's primitive calling convention is for the final arguments to be for output + // buffers. We must allocate space for those buffers based on the return type. + std::vector result_tensor_types = FlattenTupleType(ftn->ret_type); + const size_t arg_len = num_inputs + result_tensor_types.size(); - // Marshal the arguments. - // Handle adt input/output by flattening them. - size_t arg_len = 0; - for (size_t i = 0; i < args.size(); ++i) { - if (args[i]->IsInstance()) { - ++arg_len; - } else { - auto adt = Downcast(args[i]); - arg_len += adt.size(); - } - } - size_t num_inputs = arg_len; - if (const auto* tuple_type = func->body->checked_type().as()) { - arg_len += tuple_type->fields.size(); - } else { - ICHECK(func->body->checked_type().as()) << func->body->checked_type(); - arg_len += 1; - } std::vector values(arg_len); std::vector codes(arg_len); TVMArgsSetter setter(values.data(), codes.data()); - auto fset_input = [&](size_t i, ObjectRef val) { - const auto nd_array = Downcast(val); - setter(i, nd_array); + // Marshall the call's arguments in flattened form. + int arg_counter = 0; + for (const auto& nd_array : arg_nd_arrays) { + setter(arg_counter++, nd_array); Device arg_dev = nd_array->device; ICHECK(arg_dev.device_type == device_.device_type && arg_dev.device_id == device_.device_id) - << "Interpreter expect device to be " << device_ << ", but get " << arg_dev; - }; + << "Interpreter expect device to be " << device_ << ", but got " << arg_dev; + } - int arg_counter = 0; - for (ObjectRef arg : args) { - if (arg->IsInstance()) { - fset_input(arg_counter++, arg); - } else { - auto adt = Downcast(arg); - for (size_t i = 0; i < adt.size(); ++i) { - fset_input(arg_counter++, adt[i]); - } - } + // If necessary, retrieve concrete shapes for outputs from shape function rather + // than relying on TensorType shapes. + Array runtime_shapes; + bool is_dyn = IsDynamic(ftn->ret_type); + if (is_dyn) { + ICHECK(prim_shape_fn_var.defined()); + ICHECK(prim_shape_fn_states.defined()); + runtime_shapes = + ComputeDynamicShape(prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states, + num_shape_inputs, num_shape_outputs, args); + ICHECK_EQ(runtime_shapes.size(), result_tensor_types.size()); } - // TVM's calling convention is that the final argument is the output - // buffer. To preserve the illusion of being a functional language - // we need to allocate space for the output buffer based on the - // return type. - auto fset_output = [&](size_t i, Type val_type) { - const TensorTypeNode* rtype = val_type.as(); - ICHECK(rtype != nullptr); - // Allocate output tensor. - std::vector shape; - for (auto dim : rtype->shape) { + // Prepare the result tensors for the call. + TVMRetValue rv; // ignored + std::vector result_nd_arrays; + for (size_t i = 0; i < result_tensor_types.size(); ++i) { + const auto& ttype = result_tensor_types[i]; + const Shape& shape = is_dyn ? runtime_shapes[i] : ttype->shape; + // Allocate output tensor of appropriate shape. + std::vector concrete_shape; + for (const auto& dim : shape) { const auto* ivalue = tir::as_const_int(dim); ICHECK(ivalue) << "expected concrete dimensions"; - shape.push_back(ivalue[0]); + concrete_shape.push_back(ivalue[0]); } - DLDataType dtype = rtype->dtype; - NDArray nd_array = NDArray::Empty(shape, dtype, device_); + NDArray nd_array = NDArray::Empty(concrete_shape, ttype->dtype, device_); setter(num_inputs + i, nd_array); - return nd_array; - }; + result_nd_arrays.emplace_back(nd_array); + } - Array out_shapes; - auto ret_type = func->body->checked_type(); - bool is_dyn = IsDynamic(ret_type); + // Call the primitive. + packed_func.CallPacked(TVMArgs(values.data(), codes.data(), static_cast(arg_len)), &rv); - if (is_dyn) { - ICHECK(func->HasNonzeroAttr(attr::kPrimitive)); - out_shapes = ComputeDynamicShape(func, args); - } - - PackedFunc packed_func = engine_->JIT(CCacheKey(func, target_)); - TVMRetValue rv; - if (const TupleTypeNode* rtype = func->body->checked_type().as()) { - ICHECK(!is_dyn || out_shapes.size() == rtype->fields.size()); - std::vector fields; - for (size_t i = 0; i < rtype->fields.size(); ++i) { - if (is_dyn) { - auto sh = out_shapes[i]; - auto tt = Downcast(rtype->fields[i]); - fields.push_back(fset_output(i, TensorType(sh, tt->dtype))); - } else { - fields.push_back(fset_output(i, rtype->fields[i])); - } - } - packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv); - return ADT::Tuple(fields); - } else { - ObjectRef out_tensor; - if (is_dyn) { - ICHECK_EQ(out_shapes.size(), 1); - auto sh = out_shapes[0]; - auto tt = Downcast(ret_type); - out_tensor = fset_output(0, TensorType(sh, tt->dtype)); - } else { - out_tensor = fset_output(0, ret_type); - } - packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv); - return out_tensor; - } + // Unflatten the results. + return ToADTOrNDArray(ftn->ret_type, result_nd_arrays); } - // Invoke the closure - ObjectRef Invoke(const InterpreterClosure& closure, const tvm::Array& args, + /*! + * \brief Invoke \p closure with \p args. If \p bind is defined then this is a recursive + * closure and \p bind should refer to itself. + */ + ObjectRef Invoke(const InterpreterClosure& closure, const Array& args, const Var& bind = Var()) { // Get a reference to the function inside the closure. - if (closure->func->HasNonzeroAttr(attr::kPrimitive)) { - return InvokePrimitiveOp(closure->func, args); + Function func = closure->func; + ICHECK_EQ(func->params.size(), args.size()); + + if (func->HasNonzeroAttr(attr::kPrimitive)) { + if (const CallNode* call_node = closure->func->body.as()) { + if (call_node->op == debug_op_) { + // Special case: Calling the debug tracing function. + auto dattrs = call_node->attrs.as(); + auto interp_state = get_state(call_node->args[0]); + + if (dattrs->debug_func.defined()) { + dattrs->debug_func(interp_state); + } else { + RELAY_DEBUG_INTERP(interp_state); + } + + return args[0]; + } + } } - auto func = closure->func; - // Allocate a frame with the parameters and free variables. - tvm::Map locals; - ICHECK_EQ(func->params.size(), args.size()); + ICHECK(!func->HasNonzeroAttr(attr::kPrimitive)) + << "Calls to primitive functions should have been removed by lowering"; + // Allocate a frame with the parameters and free variables. + Map locals; for (size_t i = 0; i < func->params.size(); i++) { ICHECK_EQ(locals.count(func->params[i]), 0); locals.Set(func->params[i], args[i]); @@ -548,30 +675,70 @@ class Interpreter : public ExprFunctor, } ObjectRef VisitExpr_(const CallNode* call) final { - tvm::Array args; + std::vector args; for (auto arg : call->args) { args.push_back(Eval(arg)); } - // We should not find operators after running fusion, - // and operator lowering. - // - // We have some functions cotaining chunks of operators - // which will be loaded into operator map. - if (const auto* op_node = call->op.as()) { + + // We should not find calls to operators after running fusion and lowering. + if (const OpNode* op_node = call->op.as()) { LOG(FATAL) << "found " << op_node->name - << "; operators should be removed by future passes; try " + << "; operators should have been removed by previous passes; try " "fusing and lowering"; } - if (auto con = call->op.as()) { + + if (const ConstructorNode* con = call->op.as()) { + // Special case: ADT constructor return ConstructorValue(con->tag, args, GetRef(con)); } + + if (const GlobalVarNode* gvn = call->op.as()) { + if (const TIRCallAttrs* attrs = call->attrs.as()) { + // Special case: Call a lowered TIR function. + // TODO(mbs): Make calling convention first-class in Relay. + Array all_prim_fn_vars; + if (attrs->metadata.count("all_prim_fn_vars")) { + all_prim_fn_vars = Downcast>(attrs->metadata.at("all_prim_fn_vars")); + } + GlobalVar prim_shape_fn_var; + if (attrs->metadata.count("prim_shape_fn_var")) { + prim_shape_fn_var = Downcast(attrs->metadata.at("prim_shape_fn_var")); + } + Array all_prim_shape_fn_vars; + if (attrs->metadata.count("all_prim_shape_fn_vars")) { + all_prim_shape_fn_vars = + Downcast>(attrs->metadata.at("all_prim_shape_fn_vars")); + } + Array prim_shape_fn_states; + if (attrs->metadata.count("prim_shape_fn_states")) { + prim_shape_fn_states = + Downcast>(attrs->metadata.at("prim_shape_fn_states")); + } + size_t num_shape_inputs = 0; + if (attrs->metadata.count("prim_shape_fn_num_inputs")) { + num_shape_inputs = static_cast( + Downcast(attrs->metadata.at("prim_shape_fn_num_inputs"))->value); + } + size_t num_shape_outputs = 0; + if (attrs->metadata.count("prim_shape_fn_num_outputs")) { + num_shape_outputs = static_cast( + Downcast(attrs->metadata.at("prim_shape_fn_num_outputs"))->value); + } + + // Special case: Call TIR primitive. + return InvokePrimitiveOp(GetRef(gvn), all_prim_fn_vars, prim_shape_fn_var, + all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs, + num_shape_outputs, args); + } + } + // Now we just evaluate and expect to find a closure. ObjectRef fn_val = Eval(call->op); if (const InterpreterClosureObj* closure_node = fn_val.as()) { auto closure = GetRef(closure_node); - return this->Invoke(closure, args); + return Invoke(closure, args); } else if (const RecClosureObj* closure_node = fn_val.as()) { - return this->Invoke(closure_node->clos, args, closure_node->bind); + return Invoke(closure_node->clos, args, closure_node->bind); } else { LOG(FATAL) << "internal error: type error, expected function value in the call " << "position"; @@ -701,43 +868,211 @@ class Interpreter : public ExprFunctor, } private: - // Module + // Main module. All expressions are eval'ed w.r.t. the definitions in this module. This module + // may contain calls to TIR functions bound in a per_target_module_ below. IRModule mod_; - // For simplicity we only run the interpreter on a single context. - // Context to run the interpreter on. + // Map from target key to lowered TIR functions derived from mod_. + // Note that primitives are implicitly executed on target_, while shape functions are implicitly + // executed on the default 'cpu' host. Thus this map has at most two entries. + Map per_target_module_; + // Cached packed functions for the primitives and shape functions, keyed by target and + // global var name. + std::unordered_map, PackedFunc, PairHash> + compiled_packed_funcs_; + // Unique device on which primitives (but not shape functions) will be executed. + // (For simplicity we only run the interpreter on a single device.) Device device_; - // Target parameter being used by the interpreter. + // Unique target describing how to compile for primitives (but not shape functions). Target target_; - // Object stack. + // Call stack. Stack stack_; - // Backend compile engine. - CompileEngine engine_; - // Cache ops that need to be frequently used later to reduce lookup overhead. + // The distinguished 'debug' operator, which is handled specially. const Op& debug_op_; }; -TypedPackedFunc CreateInterpreter(IRModule mod, Device device, Target target) { - if (mod.defined()) { - // eta expand to support constructors in argument position - transform::Sequential seq({transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false), - transform::InferType()}); +/*! + * Lowers all calls to primitives in \p mod appropriate for device and target. Returns the + * rewritten \p mod and target-specific modules containing bindings for all TIR primitive + * functions needed by the rewritten module. + */ +std::pair> Prepare(IRModule mod, Device device, Target target) { + // Run minimal transforms on module to establish invariants needed by interpreter. + transform::Sequential seq({transform::SimplifyInference(), + // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' + // attribute. + transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(), + // eta expand to support constructors in argument position + transform::EtaExpand( + /*expand_constructor=*/true, /*expand_global_var=*/false), + transform::InferType()}); + + transform::PassContext pass_ctx = transform::PassContext::Current(); + With ctx(pass_ctx); + mod = seq(mod); + + // We only have one device-specific target. + tec::TargetMap targets = {{device.device_type, target}}; + + // All calls to primitives will use the unique target. + tec::DeviceMap device_map; + + // No need for a memory plan. + backend::StaticMemoryPlan memory_plan; /*=nullptr*/ + + // Lower all primitive functions reachable from expr. + // TODO(mbs): This should be just another pass in seq above, which requires LoweredModule to + // be merged into IRModule. + LoweredModule lowered_module = + tec::LowerTE(mod, targets, device_map, memory_plan, /*module_name=*/"intrp", + [](Function func) { /* no-op */ }); + return {lowered_module.main_module, lowered_module.per_target_module}; +} + +/*! \brief Check if an expression could be changed by \p Prepare. + * + * If not we can evaluate it directly and don't need to bind it into a fresh module. + */ +class NeedsPreparationVisitor : public ExprVisitor { + public: + bool needs_preparation = false; + + private: + void VisitExpr_(const VarNode* vn) override { + // Could be prim. + needs_preparation = true; + } + // ConstantNode ok + // GlobalVarNode ok + void VisitExpr_(const OpNode* op) override { + // Could be prim. + needs_preparation = true; + } + // TupleNode recurse + void VisitExpr_(const FunctionNode* op) override { + // Could be prim. + needs_preparation = true; + } + // CallNode recurse + void VisitExpr_(const LetNode* ln) override { + // May bind prim. + needs_preparation = true; + } + // IfNode recurse + // TupleGetItemNode recurse + // RefCreateNode recurse + // RefReadNode recurse + // RefWriteNode recurse + // ConstructorNode ok + void VisitExpr_(const MatchNode* op) override { + // Needs eta-expansion. + needs_preparation = true; + } +}; - transform::PassContext pass_ctx = transform::PassContext::Current(); - tvm::With ctx(pass_ctx); - mod = seq(mod); +TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, Device device, + Target target) { + // + // Step 1: Prepare mod. + // + + // If expr is simple enough we can avoid binding it into the module and + // just eval it directly. + NeedsPreparationVisitor visitor; + visitor.VisitExpr(expr); + + Expr expr_to_eval; + IRModule mod_with_expr; // default empty + if (visitor.needs_preparation) { + GlobalVar main; + // Bind expr to a new zero-argument function so it can be prepared along with the module + // (if any). + std::pair mod_and_global; + if (mod.defined()) { + // TODO(mbs): Type inference currently assumes all global functions in modules have + // known result types, and so each global function has it's body types inferred independently + // and in arbitrary order. However, the interpreter may be called with an expression relative + // to a 'main' which has no result type annotation, and that expressions will be bound into a + // fresh global below. Type inference then fails since 'main' has unknown type. We should + // allow inference on mutually recursive global functions. To workaround, infer the type + // of mod now. Obviously that won't work if 'main' itself calls other global functions of + // partial type, but it at least maintains legacy behavior. + transform::PassContext pass_ctx = transform::PassContext::Current(); + With ctx(pass_ctx); + mod = transform::InferType()(mod); + mod_and_global = + IRModule::FromExprInContext(expr, mod->functions, mod->type_definitions, mod->Imports()); + } else { + mod_and_global = IRModule::FromExprInContext(expr); + } + mod_with_expr = mod_and_global.first; + expr_to_eval = mod_and_global.second; + } else { + if (mod.defined()) { + mod_with_expr = mod; + } + // Prepare won't change expr, so we don't need to worry about binding it into a module + // and can just eval it directly. + expr_to_eval = expr; + } + std::pair> main_and_lowered = + Prepare(mod_with_expr, device, target); + std::shared_ptr intrp = std::make_shared( + /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, + target); + + // + // Step 2: Evaluate target function to a closure. + // + ObjectRef object_ref = intrp->Eval(expr_to_eval); + if (const InterpreterClosureObj* closure_obj = object_ref.as()) { + InterpreterClosure closure = GetRef(closure_obj); + ICHECK(closure.defined()); + ICHECK(closure->func.defined()); + + return TypedPackedFunc)>([intrp, closure](Array args) { + // + // Step 3: Apply closure to arguments. + // + ICHECK_NOTNULL(intrp); + ICHECK(closure.defined()); + ICHECK(closure->func.defined()); + Array evaled_args; + for (auto arg : args) { + NeedsPreparationVisitor visitor; + visitor.VisitExpr(arg); + ICHECK(!visitor.needs_preparation) + << "attempting to apply closure to expression which needs preparation: " + << PrettyPrint(arg); + evaled_args.push_back(intrp->Eval(arg)); + } + return intrp->Invoke(closure, evaled_args); + }); + } else { + LOG(FATAL) << "expecting expression to have function type and evaluate to a closure"; + return nullptr; } +} - auto intrp = std::make_shared(mod, device, target); - auto packed = [intrp](Expr expr) { - auto f = DetectFeature(expr); - ICHECK(f.is_subset_of(FeatureSet::All() - fGraph)); - return intrp->Eval(expr); - }; - return TypedPackedFunc(packed); +ObjectRef Eval(Expr expr, Map type_definitions, + std::unordered_set import_set, Device device, Target target) { + std::pair mod_and_global = + IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); + std::pair> main_and_lowered = + Prepare(mod_and_global.first, device, target); + Interpreter intrp( + /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, + target); + Expr expr_to_eval = main_and_lowered.first->GetGlobalVar(mod_and_global.second->name_hint); + if (expr.as() == nullptr) { + // TODO(mbs): IRModule::FromExpr will implicitly close over the free vars of expr + // unless it is a function, so we must reverse that in the expression to eval. + // This should done more systematically. + expr_to_eval = Call(expr_to_eval, {}); + } + return intrp.Eval(expr_to_eval); } -TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter").set_body_typed(CreateInterpreter); +TVM_REGISTER_GLOBAL("relay.backend.EvalFunction").set_body_typed(EvalFunction); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 93b9c6fc1827..71ac752ec680 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -20,17 +20,16 @@ #include "te_compiler.h" #include -#include +#include +#include #include #include #include #include #include #include -#include #include #include -#include #include #include #include @@ -43,8 +42,6 @@ #include #include -#include "../transforms/pass_utils.h" -#include "te_compiler.h" #include "te_compiler_cache.h" #include "utils.h" @@ -101,6 +98,18 @@ class TECompilerImpl : public TECompilerNode { lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); } + + for (const auto& it : shape_func_cache_) { + auto source_func = it.first; + auto lowered_func = it.second; + auto target = source_func->target; + + if (!lowered_functions.count(target->str())) { + lowered_functions.Set(target->str(), IRModule(Map({}))); + } + + lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + } return lowered_functions; } @@ -195,6 +204,7 @@ class TECompilerImpl : public TECompilerNode { auto target = Target("ext_dev"); auto global_var = GlobalVar(func_name); global_var->checked_type_ = key->source_func->checked_type(); + ir_module->Add(global_var, key->source_func); value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); return value; } @@ -303,130 +313,141 @@ std::tuple IsDeviceCopy(const Function& func) { return std::tuple(false, -1, -1); } -class LowerTensorExpr : public ExprMutator { +/*! + * \brief Rewrites call expressions to Relay functions marked as 'primitive' + * to calls to the corresponding TIR primitive for the appropriate target. + * + * \code + * let %p = fn(...) { prim_op(...) } + * ... %p(...) ... + * ==> + * (in target-specific module) def @p' = fn (...) { } + * let %p = fn(...) { prim_op(...) } + * ... @p'(...) ... + * \endcode + * + * Requires FuseOps, ToANormalForm, EtaExpand and InferType to have run. + * + * FuseOps is needed to identify and lift all prim op calls: + * \code + * ... prim_op(...) ... + * ==> + * let %p = fn(...) { prim_op(...) } + * ... %p(...) ... + * \endcode + * + * ToANormalForm is needed so we only need to consider vars as the call target. + * (However we'll also allow function literals.) + * + * EtaExpand is needed to ensures all calls to primitives are direct: + * \code + * let %p1 = fn(...) { prim_op1(...) } + * let %p2 = fn(...) { prim_op2(...) } + * let %p = if (...) { %p1 } else { %p2 } + * ... %p(...) ... + * ==> + * let %p1 = fn(...) { prim_op1(...) } + * let %p2 = fn(...) { prim_op2(...) } + * let %p = fn(...) { if (...) { %p1(...) } else { %p2(...) } } + * ... %p(...) ... + * \endcode + */ +class LowerTensorExprMutator : public ExprMutator { public: - LowerTensorExpr(const IRModule& module, const TargetMap& targets, const DeviceMap& device_ctx_map, - ProcessFn process_fn, const String& module_name, TECompiler compiler) + LowerTensorExprMutator(const IRModule& module, const TargetMap& targets, + const DeviceMap& device_ctx_map, ProcessFn process_fn, + const String& module_name, TECompiler compiler) : module_(module), targets_(targets), device_context_map_(device_ctx_map), - process_fn(process_fn), + process_fn_(process_fn), module_name_(module_name), - compiler_(compiler) {} - - Expr VisitExpr_(const CallNode* call) override { - Call expr = GetRef(call); - Function func; - - if (call->op.as()) { - func = GetRef(call->op.as()); - } else { - return ExprMutator::VisitExpr_(call); - } - - if (!func->HasNonzeroAttr(attr::kPrimitive)) { - // Provide a callback hook which allows one-level up code generators to - // act when we process a function. - this->process_fn(func); - return ExprMutator::VisitExpr_(call); - } + compiler_(compiler), + debug_op_(Op::Get("debug")) {} - // Process inputs. - Array args; - for (size_t i = 0; i < expr->args.size(); i++) { - args.push_back(VisitExpr(expr->args[i])); + /*! + * \brief Returns the primitive function associated with \p expr, or + * nullptr if none. + */ + Function ResolveToPrimitive(Expr expr) { + if (const GlobalVarNode* gvn = expr.as()) { + BaseFunc base_func = module_->Lookup(GetRef(gvn)); + return ResolveToPrimitive(base_func); + } else if (const VarNode* vn = expr.as()) { + auto itr = primitive_functions_.find(GetRef(vn)); + return itr == primitive_functions_.end() ? Function() : itr->second; + } else if (const FunctionNode* fn = expr.as()) { + if (!fn->HasNonzeroAttr(attr::kPrimitive)) { + // Not marked as primitive by FuseOps. + return Function(); + } + if (const CallNode* cn = fn->body.as()) { + if (cn->op == debug_op_) { + // Debug 'primitives' are not lowered. + return Function(); + } + } + return GetRef(fn); } + return Function(); + } - Target target; - + /*! + * \brief Lowers the primitive function \p func to TIR for ultimate execution + * on a device with configuration \p target. Returns the global var bound + * to the TIR implementation, and attributes to attach to the call to identify it as + * a TIR call. + */ + std::pair LowerFunction(Function func, Target target) { if (func->GetAttr(attr::kCompiler).defined()) { - target = Target("ext_dev"); + // BYOC flow. CCacheKey key = CCacheKey(func, target); CachedFunc ext_func = compiler_->Lower(key, module_name_); ICHECK(ext_func.defined()) << "Lowering returned undefined function for " << ext_func->prim_fn_var->name_hint; Map prim_fns; - - for (auto prim_fn : ext_func->funcs->functions) { - CHECK(prim_fn.second.as()) << "must be a prim fn"; - prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); - } - relay::Function func_with_metadata = func; func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var); func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); - func_with_metadata = WithAttr(func_with_metadata, "target", ext_func->target); + func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, ext_func->target); // Provide a callback hook which allows one-level up code generators to // act when we process a function. - this->process_fn(func_with_metadata); + this->process_fn_(func_with_metadata); - auto ret_call = Call(ext_func->prim_fn_var, args, {}); - return std::move(ret_call); + // TODO(mbs): Need TIRCallAttrs or equiv so targets know this is an extern. + // TODO(mbs): Dynamic shapes? + return {ext_func->prim_fn_var, Attrs()}; } - ICHECK_GE(device_context_map_.count(expr), 0) - << "Could not find an entry in the device context map for " << PrettyPrint(expr) - << "The memory planning was either not performed for this precise node, or there is bug " - "in the memory planner."; - - auto& device_context = this->device_context_map_[expr]; - auto call_dev_type = device_context.device_type; - // Non-External Relay Function - if (targets_.size() == 1) { - // The homogeneous execution case, we should only have one target - // so we just grab it. - const auto& it = targets_.begin(); - target = (*it).second; - } else { - // The heterogeneous execution case we have multiple targets - // in this case. - // - // We need to identify the target and translate. - std::string call_dev_name; - if (call_dev_type == 0) { - call_dev_name = "llvm"; - call_dev_type = kDLCPU; - } else { - call_dev_name = ::tvm::runtime::DeviceName(call_dev_type); - } - - if (targets_.count(call_dev_type) == 0) { - std::stringstream msg; - msg << "No target is specified for provided device name: `" << call_dev_name << "`\n\n"; - msg << call_dev_name << " mapped to device type (" << call_dev_type - << ") which was not found in the target map.\n"; - msg << "Availible targets: \n"; - for (auto target : targets_) { - msg << " " << target.first << "-> " << target.second << "\n"; - } - LOG(FATAL) << msg.str(); - } - - target = targets_[call_dev_type]; - } - + DLOG(INFO) << "lowering to target '" << target->str() << "' for primitive:\n" + << PrettyPrint(func); CCacheKey key = CCacheKey(func, target); CachedFunc lowered_func = compiler_->Lower(key, module_name_); + DLOG(INFO) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; + // Collect all the lowered functions produced for this primitive function. Map prim_fns; - + Array all_prim_fn_vars; for (auto prim_fn : lowered_func->funcs->functions) { CHECK(prim_fn.second.as()) << "must be a prim fn"; prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); + all_prim_fn_vars.push_back(prim_fn.first); + DLOG(INFO) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) + << "'"; } // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT relay::Function func_with_metadata = func; func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var); func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); - func_with_metadata = WithAttr(func_with_metadata, "target", lowered_func->target); + func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, lowered_func->target); // Provide a callback hook which allows one-level up code generators to // act when we process a function. - this->process_fn(func_with_metadata); + this->process_fn_(func_with_metadata); auto tir_call_attrs = make_object(); if (func->HasNonzeroAttr(attr::kReshapeOnly)) { @@ -442,42 +463,162 @@ class LowerTensorExpr : public ExprMutator { } tir_call_attrs->metadata.Set("relay_attrs", func->attrs); + tir_call_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars); + + if (IsDynamic(func->ret_type)) { + // Also lower the dynamic shape function. + // Shape function keys use the underlying primitive function as their 'function', + // but the generic 'cpu' target as the target since all shape functions run + // on the host cpu irrespective of where the primitive runs. + // TODO(mbs): Cleanup target handling. + Target shape_target("llvm"); + DLOG(INFO) << "lowering to target '" << shape_target->str() + << "' for dynamic shape function for primitive"; + CCacheKey shape_key(func, shape_target); + CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); + // Capture the shape function's global var and parameters 'states' in call + // annotations so calling convention can be recovered. + // TODO(mbs): Capture all this as part of a 'call into TIR' construct once available. + // The way the shape function calling convention is derived and passed to call sites + // via the 'parameter states' could be improved. + tir_call_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); + tir_call_attrs->metadata.Set("prim_shape_fn_states", + lowered_shape_func->shape_func_param_states); + tir_call_attrs->metadata.Set("prim_shape_fn_num_inputs", + Integer(static_cast(lowered_shape_func->inputs.size()))); + tir_call_attrs->metadata.Set("prim_shape_fn_num_outputs", + Integer(static_cast(lowered_shape_func->outputs.size()))); + Array all_prim_shape_fn_vars; + for (auto prim_shape_fn : lowered_shape_func->funcs->functions) { + CHECK(prim_shape_fn.second.as()) << "must be a prim fn"; + all_prim_shape_fn_vars.push_back(prim_shape_fn.first); + } + tir_call_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); + } - Expr ret_call = Call(lowered_func->prim_fn_var, args, Attrs(tir_call_attrs)); - return ret_call; + return {lowered_func->prim_fn_var, Attrs(tir_call_attrs)}; + } + + Expr VisitExpr_(const LetNode* let) override { + Var var = Downcast(Mutate(let->var)); + Expr value = Mutate(let->value); + Function prim_func = ResolveToPrimitive(value); + if (prim_func.defined()) { + // Remember let var is bound to (possibly indirectly) to a primitive. + primitive_functions_.emplace(let->var, prim_func); + } + Expr body = Mutate(let->body); + if (prim_func.defined()) { + // Leaving let var scope. + primitive_functions_.erase(let->var); + } + if (var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body)) { + return GetRef(let); + } else { + return Let(var, value, body, let->span); + } + } + + Expr VisitExpr_(const CallNode* call) override { + Call expr = GetRef(call); + + // Look for (indirect) calls to primitives. + Function prim_func = ResolveToPrimitive(call->op); + if (!prim_func.defined()) { + // Not a call to a primitive function. + if (const FunctionNode* fn = call->op.as()) { + this->process_fn_(GetRef(fn)); + } + return ExprMutator::VisitExpr_(call); + } + + // Find the desired target device. + Target target; + if (prim_func->GetAttr(attr::kCompiler).defined()) { + // The generic 'external device' target. + target = Target("ext_dev"); + } else if (device_context_map_.empty() && targets_.size() == 1) { + // The unique target. + target = GetTargetFromInteger(kDLCPU, targets_); + } else { + // The target corresponding to the call expression's annotation. + auto itr = device_context_map_.find(expr); + ICHECK(itr != device_context_map_.end()) + << "Could not find an entry in the device context map for " << expr + << "The memory planning was either not performed for this precise node, or there is " + "bug in the memory planner."; + target = GetTargetFromInteger(itr->second.device_type, targets_); + } + + // Lower the primitive function for that target. + std::pair pair = LowerFunction(prim_func, target); + + // Similarly transform arguments. + Array args; + for (const auto& arg : call->args) { + args.push_back(VisitExpr(arg)); + } + + // Replace with direct call to lowered primitive, and attach annotations to record calling + // convention. + return Call(pair.first, args, pair.second); } IRModule module_; TargetMap targets_; DeviceMap device_context_map_; - ProcessFn process_fn; + ProcessFn process_fn_; + // Map from in-scope let-bound variables to Relay functions known to be + // primitive. We'll rewrite these to the fresh global vars bound to the lowered + // primitive function as we go. Those vars will be bound in the + // target device-type specific module we'll ultimately emit for each required + // device-type. Note that a primitive may be lowered for multiple device + // types, each which will be assigned a fresh var. + std::unordered_map + primitive_functions_; String module_name_; TECompiler compiler_; + // Cache ops that need to be frequently used later to reduce lookup overhead. + const Op& debug_op_; }; -/*! - * \brief Obtain the Target from the device type. - * If homogenous compilation, this will return the only target. - * If heteregenous compilation, this will select associated using the targets_ Map. - * - * \param dev_type - * \return Target - */ +Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + TECompiler compiler, std::function process_fn) { + runtime::TypedPackedFunc pass_func = + [=](Function func, IRModule module, PassContext ctx) { + LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn, + module_name, compiler); + return Downcast(lower_te.Mutate(func)); + }; + return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); +} + Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) { if (targets.size() == 1) { - // homogeneous execution. + // The homogeneous execution case, return the only target. const auto& it = targets.begin(); return (*it).second; } else { - // heterogeneous execution. - std::string call_dev_name; - if (dev_type == 0) { - call_dev_name = "llvm"; - } else { - call_dev_name = runtime::DeviceName(dev_type); + // The heterogeneous execution case, return the target associated with the + // given device type. + // If "dev_type" equals to 0, the device name only can be got from + // "targets", and it may not be "llvm", so here just set it to "unknown". + std::string dev_name = "unknown"; + if (dev_type != 0) { + dev_name = runtime::DeviceName(dev_type); } + if (targets.count(dev_type) == 0) { - LOG(FATAL) << "No target is provided for device " << call_dev_name; + std::stringstream msg; + msg << "No target is specified for provided device name: `" << dev_name << "`\n\n" + << dev_name << " mapped to device type (" << dev_type + << ") which was not found in the target map.\n" + << "Availible targets: \n"; + for (auto target : targets) { + msg << " " << target.first << "-> " << target.second << "\n"; + } + LOG(FATAL) << msg.str(); } return targets[dev_type]; } @@ -609,8 +750,6 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar relay_primfuncs); } -// TODO(@electriclilies): Is the function passed in here relay_func?? -// Also should this be inlined? /*! * \brief A function to create the function metadata for an input function (ie calculate buffer * input/output sizes) @@ -639,7 +778,7 @@ void UpdateFunctionMetadata(Function relay_func, Optional prim_fn_var = relay_func->GetAttr("prim_fn_var"); CHECK(prim_fn_var) << "prim_fn_var must be set on Relay functions by TECompiler."; - Optional relay_target = relay_func->GetAttr("target"); + Optional relay_target = relay_func->GetAttr(tvm::attr::kTarget); CHECK(relay_target) << "target must be set on Relay functions by the TECompiler."; for (const auto& kv : prim_fns.value()) { @@ -653,8 +792,8 @@ void UpdateFunctionMetadata(Function relay_func, // Workspace sizes Target prim_fn_target; - if (prim_fn->attrs->dict.count("target")) { - prim_fn_target = Downcast(prim_fn->attrs->dict["target"]); + if (prim_fn->attrs->dict.count(tvm::attr::kTarget)) { + prim_fn_target = Downcast(prim_fn->attrs->dict[tvm::attr::kTarget]); } else { prim_fn_target = relay_target.value(); } @@ -693,24 +832,18 @@ void UpdateFunctionMetadata(Function relay_func, LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map, backend::StaticMemoryPlan memory_plan, const String& module_name, std::function process_fn) { - TECompiler compiler; + DLOG(INFO) << "lowering module:\n" << PrettyPrint(module); - CHECK_EQ(module->functions.size(), 1) - << "There should only be one function in the module passed to LowerTE"; - - auto pass = CreateFunctionPass( - [=](Function func, IRModule module, PassContext ctx) { - LowerTensorExpr lower_te(module, targets, device_context_map, process_fn, module_name, - compiler); - return Downcast(lower_te.VisitExpr(func)); - }, - 0, "LowerTensorExpr", {}); + TECompiler compiler; - // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize - backend::FunctionInfo func_info = - UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info); + backend::FunctionInfo func_info; + if (memory_plan.defined()) { + // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize + func_info = UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info); + } - auto updated_module = pass(module); + auto updated_module = LowerTensorExpr(targets, device_context_map, memory_plan, module_name, + compiler, process_fn)(module); // A temporary solution until we can rewrite the auto-scheduler task extraction code to work // in a more reasonable way. @@ -738,6 +871,117 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic return lowered_module; } +IRModule LoweredModuleToIRModule(LoweredModule mod) { + IRModule unified_module; + + // Copy the main module and its typedefs + for (const auto& kv : mod.main_module->functions) { + unified_module->Add(kv.first, kv.second); + } + for (const auto& kv : mod.main_module->type_definitions) { + unified_module->AddTypeDef(kv.first, kv.second); + } + + // Annotate the per-target functions with their target and add them to the unified module + for (const auto& kv : mod.per_target_module) { + const String target = kv.first; + const IRModule target_module = kv.second; + + // Right now, per-target functions are TIR functions, which don't have type definitions, so + // there should be no type defs in the per_target_modules + size_t ty_def_size = target_module->type_definitions.size(); + ICHECK(ty_def_size == 0) + << "Expected there to be no type definitions in the per_target_modules, but found " + << ty_def_size; + + for (const auto& kv : target_module->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + if (func->IsInstance()) { + tir::PrimFunc primFunc = + WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); + unified_module->Add(var, primFunc); + } else if (func->IsInstance()) { + relay::Function relayFunc = + WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); + unified_module->Add(var, relayFunc); + } else { + LOG(FATAL) + << "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found " + << func->GetTypeKey(); + } + } + } + + IRModule ret_mod = WithAttr(unified_module, "external_mods", mod.external_mods); + ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info); + return ret_mod; +} + +LoweredModule IRModuleToLoweredModule(IRModule mod) { + IRModule main_mod; + // Copy just the TypeDefs from the IRModule to the LoweredModule's main module + // This is the only time we need to do this since there are no TypeDefs in TIR + for (const auto& kv : mod->type_definitions) { + main_mod->AddTypeDef(kv.first, kv.second); + } + + Map per_target_modules; + for (const auto& kv : mod->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + if (func->IsInstance()) { + main_mod->Add(var, func); + } else if (func->IsInstance()) { + // Extract target + Optional target = func->GetAttr(tvm::attr::kTarget); + ICHECK(target) << "Target should be set at this point"; + + // Put the function in per_target_modules + if (!per_target_modules.count(target.value())) { + // Initialize the IRModule for this target and add the function + IRModule target_module; + target_module->Add(var, func); + per_target_modules.Set(target.value(), target_module); + } else { + // The IRModule for this target is initialized, so just add the function. + IRModule target_module = per_target_modules.at(target.value()); + target_module->Add(var, func); + } + } else { + LOG(FATAL) + << "The function types in the IRModule should be RelayFunction or PrimFunc, but got " + << func->GetTypeKey(); + } + } + + // Put the LoweredModule together + LoweredModule lowered_module; + lowered_module.main_module = main_mod; + lowered_module.per_target_module = per_target_modules; + + // Extract external modules and main func info, add to lowered module if they exist + auto external_mods = mod->GetAttr>("external_mods"); + if (external_mods) { + lowered_module.external_mods = external_mods.value(); + } + auto main_func_info = mod->GetAttr("main_func_info"); + if (main_func_info) { + lowered_module.main_func_info = main_func_info.value(); + } + return lowered_module; +} + +Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + std::function process_fn) { + runtime::TypedPackedFunc pass_func = [=](IRModule module, + PassContext ctx) { + return LoweredModuleToIRModule( + LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn)); + }; + return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {}); +} } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index a32eefb5127f..e9cfb0d62e66 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -76,7 +76,7 @@ using ProcessFn = std::function; /*! * \brief A compiler which lowers primitive Relay functions to tensor expressions - * and schdules them into TIR functions. + * and schedules them into TIR functions. */ class TECompilerNode : public Object { public: @@ -166,29 +166,73 @@ void UpdateFunctionMetadata(Function relay_func, /*! * \brief Obtain the Target from the device type. * If homogenous compilation, this will return the only target. - * If heteregenous compilation, this will select associated using the targets_ Map. + * If heterogeneous compilation, this will select the associated target using the + * targets_ Map. * * \param dev_type * \return Target */ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); +/*! \brief Utility to convert a LoweredModule to an IRModule. + * + * This function takes all the target specific modules in LoweredModule and + * annotates their functions with the correct target, and puts all those functions + * in one IRModule. + * The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase. + * + * \param mod The LoweredModule to convert. + * \return The IRModule form of the input LoweredModule. + */ +IRModule LoweredModuleToIRModule(LoweredModule mod); + +/*! \brief Utility to convert an IRModule to a LoweredModule. + * + * This function takes all the functions in the IRModule and moves them into target-specific + * IRModules stored inside a LoweredModule. + * The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase. + * \param mod The IRModule to convert. + * \return The LoweredModule form of the input IRModule. + */ +LoweredModule IRModuleToLoweredModule(IRModule mod); + /*! \brief Lower an IRModule's primitive functions to TIR. * * This is the "back half" of the Relay compiler which lowers "primitive functions" * to TE expressions, schedules them, and then to TIR. * - * /param module The IRModule. - * /param targets The mapping for devices to targets. - * /param device_map An analysis result mapping each sub-expression to a device. - * /return The lowered module, see above. + * \param module The IRModule. + * \param targets The mapping for devices to targets. + * \param device_map An analysis result mapping each sub-expression to a device. + * \param memory_plan The memory plan used during lowering + * \param module_name The name of this module + * \param process_fn Callback allowing one-level up code generators to process + * each function that we lower + * \return The lowered module, see above. */ -// TODO(@electriclilies): Not sure if this default initialization is correct... LoweredModule LowerTE( const IRModule& module, TargetMap targets, DeviceMap device_map, backend::StaticMemoryPlan memory_plan, const String& module_name, ProcessFn process_fn = [](Function f) {}); +/*! \brief Pass to lower an IRModule's primitive functions to TIR. + * + * This is the "back half" of the Relay compiler which lowers "primitive functions" + * to TE expressions, schedules them, and then to TIR. This Pass calls LowerTE, and + * uses LoweredModuleToIRModule utility to convert the output LowerTE's output + * LoweredModule into an IRModule before returning it. + * + * \param targets The mapping for devices to targets. + * \param device_context_map An analysis result mapping each sub-expression to a device. + * \param memory_plan The memory plan used during lowering + * \param module_name The name of this module + * \param process_fn Callback allowing one-level up code generators to process + * each function that we lower + * \returns The pass which lowers primative functions to TIR + */ +transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + std::function process_fn); } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index bbe38f0426b4..d0e83765928a 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -41,6 +41,7 @@ #include #include +#include "../op/memory/memory.h" #include "../transforms/pass_utils.h" #include "utils.h" @@ -98,7 +99,8 @@ Array GetShape(const Array& shape) { res.push_back(val); #endif // TVM_INDEX_DEFAULT_I64 } else if (val->IsInstance()) { - res.push_back(val.as()->ToVar()); + // currently all 'any' we meet in shape function are non-negative. + res.push_back(val.as()->ToSizeVar()); } else { res.push_back(val); } @@ -119,21 +121,10 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator Array fn_inputs; for (Var param : prim_func->params) { Array inputs; - if (const auto* ttype = param->checked_type().as()) { + for (const auto& ttype : FlattenTupleType(param->checked_type())) { tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); fn_inputs.push_back(tensor); inputs.push_back(tensor); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - // TODO(@icemelon): Allow recursive tuple - ICHECK(ttype != nullptr); - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - fn_inputs.push_back(tensor); - inputs.push_back(tensor); - } } memo_[param] = inputs; } @@ -313,6 +304,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator Array VisitExpr_(const TupleNode* op) final { Array fields; for (Expr field : op->fields) { + // TODO(mbs): Generalize to be equivalent to FlattenTupleType. ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; Array res = VisitExpr(field); ICHECK_EQ(res.size(), 1); @@ -371,7 +363,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> Array data_inputs; Array shape_inputs; - auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) { + for (const auto& ttype : FlattenTupleType(param->checked_type())) { // Add data placeholder Shape shape = GetShape(ttype->shape); tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype); @@ -384,20 +376,6 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> } tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64)); shape_inputs.push_back(shape_tensor); - }; - - if (const auto* ttype = param->checked_type().as()) { - add_placeholder(ttype); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - // TODO(@icemelon): Support recursive tuple - ICHECK(tuple_type); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - ICHECK(ttype); - add_placeholder(ttype); - } } param_data_[param] = data_inputs; param_shapes_[param] = shape_inputs; diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 1c7511ffd7d2..47ba96b2c77e 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -213,6 +213,7 @@ CachedFunc PrimFuncFor(const Function& source_func, const Target& target, CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, std::function renamer); +// TODO(mbs): Bring name uniqification under control -- this is replicated in quite a few places. std::string GetUniqueName(std::string name, std::unordered_map* name_map); // implementations diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index f0c543f1244b..4b4844599e29 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -24,6 +24,8 @@ #include "utils.h" +#include + namespace tvm { namespace relay { namespace backend { @@ -120,6 +122,71 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ",\n relay_primfuncs=" << node->relay_primfuncs << ")"; }); +Array GetPassPrefix(const Map& targets, bool is_vm) { + Array pass_seqs; + Array entry_functions{"main"}; + pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); + pass_seqs.push_back(transform::ToBasicBlockNormalForm()); + // Run all dialect legalization passes. + pass_seqs.push_back(relay::qnn::transform::Legalize()); + + // Legalize pass is restricted to homogeneous execution for now. + if (targets.size() == 1) { + pass_seqs.push_back(transform::Legalize()); + } + + pass_seqs.push_back(transform::SimplifyInference()); + + if (is_vm) { + // eta expand to support constructors in argument position + pass_seqs.push_back(transform::EtaExpand( + /* expand_constructor */ true, /* expand_global_var */ false)); + } else { + // Convert Dynamic ops to static versions + pass_seqs.push_back(transform::DynamicToStatic()); + } + + PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") { + auto attrs = call_node->attrs.as(); + if (attrs->dtype == DataType::Int(32)) { + *rv = true; + } + } + } + *rv = false; + }); + pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::SimplifyExpr()); + if (is_vm) { + pass_seqs.push_back(transform::InlinePrimitives()); + } + pass_seqs.push_back(transform::CombineParallelConv2D(3)); + pass_seqs.push_back(transform::CombineParallelDense(3)); + pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); + pass_seqs.push_back(transform::FoldConstant()); + pass_seqs.push_back(transform::FoldScaleAxis()); + pass_seqs.push_back(transform::CanonicalizeCast()); + pass_seqs.push_back(transform::CanonicalizeOps()); + + // Alter layout transformation is only applied to homogeneous execution yet. + if (targets.size() == 1) { + if (!is_vm) { + pass_seqs.push_back(transform::InferType()); + } + pass_seqs.push_back(transform::AlterOpLayout()); + } + + // Fast math optimizations. + pass_seqs.push_back(transform::FastMath()); + pass_seqs.push_back(transform::FoldConstant()); + return pass_seqs; +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index d2a173a43f46..a0c7a5aad26d 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -44,7 +44,12 @@ namespace tvm { namespace relay { +namespace transform { +Pass InlinePrimitives(); +} + namespace backend { +using Pass = tvm::transform::Pass; /*! * \brief The static storage information produced by memory planning. @@ -410,6 +415,18 @@ inline bool IsCompileEngineCacheDisabled() { .value(); } +/*! + * \brief Get the sequence of Relay optimization passes based on backend type. + * The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight + * difference. This function unifies the shared optimization pass prefix between vm and graph + * runtime, and returns the pass prefix given the backend type. + * + * \param targets The device type to `Target` mapping. + * \param is_vm A boolean indicating if the passes are used for vm or graph runtime. + * \return An array of passes. + */ +Array GetPassPrefix(const Map& targets, bool is_vm); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 96aa77f286a9..b3eab91d202c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -45,7 +45,6 @@ #include #include "../../../target/source/codegen_source_base.h" -#include "../../backend/compile_engine.h" #include "../../op/op_common.h" #include "../../transforms/pass_utils.h" #include "../utils.h" @@ -79,6 +78,7 @@ namespace vm { using namespace tvm::runtime; using namespace tvm::runtime::vm; using namespace relay::transform; +using namespace tec; // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); @@ -253,7 +253,6 @@ class VMFunctionCompiler : ExprFunctor { ExprDeviceMap expr_device_map) : last_register_(0), registers_num_(0), - engine_(CompileEngine::Global()), context_(context), target_host_(target_host), expr_device_map_(std::move(expr_device_map)) { @@ -465,7 +464,7 @@ class VMFunctionCompiler : ExprFunctor { void EmitShapeFunc(Function func, Array inputs, Array outputs) { // Lower shape function CCacheKey key(func, target_host_); - auto cfunc = engine_->LowerShapeFunc(key); + auto cfunc = context_->compiler->LowerShapeFunc(key); int op_index = -1; // pick the only function inside the context ICHECK_EQ(cfunc->funcs->functions.size(), 1); @@ -551,7 +550,7 @@ class VMFunctionCompiler : ExprFunctor { CCacheKey key(func, target); auto mangle_fn = [](String name) { return name; }; - auto cfunc = engine_->Lower(key, mangle_fn); + auto cfunc = context_->compiler->Lower(key, mangle_fn); auto op_index = -1; if (func->GetAttr(attr::kCompiler).defined()) { @@ -857,8 +856,6 @@ class VMFunctionCompiler : ExprFunctor { size_t last_register_; /*! \brief Total number of virtual registers allocated. */ size_t registers_num_; - /*! \brief Compiler engine to lower primitive functions. */ - CompileEngine engine_; /*! \brief Global shared meta data */ VMCompilerContext* context_; /*! \brief Target devices. */ @@ -1042,57 +1039,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, mod->Add(gvar, f); } - Array pass_seqs; - Array entry_functions{"main"}; - pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); - pass_seqs.push_back(transform::ToBasicBlockNormalForm()); - // Run all dialect legalization passes. - pass_seqs.push_back(relay::qnn::transform::Legalize()); - - // Legalize pass is restricted to homogeneous execution for now. - if (targets.size() == 1) { - pass_seqs.push_back(transform::Legalize()); - } - - // eta expand to support constructors in argument position - pass_seqs.push_back(transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false)); - - pass_seqs.push_back(transform::SimplifyInference()); - PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == DataType::Int(32)) { - *rv = true; - } - } - } - *rv = false; - }); - pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); - pass_seqs.push_back(transform::SimplifyExpr()); - pass_seqs.push_back(transform::InlinePrimitives()); - - pass_seqs.push_back(transform::CombineParallelConv2D(3)); - pass_seqs.push_back(transform::CombineParallelDense(3)); - pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); - pass_seqs.push_back(transform::FoldConstant()); - pass_seqs.push_back(transform::FoldScaleAxis()); - pass_seqs.push_back(transform::CanonicalizeCast()); - pass_seqs.push_back(transform::CanonicalizeOps()); - - // Alter layout transformation is only applied to homogeneous execution yet. - if (targets.size() == 1) { - pass_seqs.push_back(transform::AlterOpLayout()); - } - - // Fast math optimizations. - pass_seqs.push_back(transform::FastMath()); - pass_seqs.push_back(transform::FoldConstant()); + Array pass_seqs = relay::backend::GetPassPrefix(targets, true); if (targets_.size() > 1) { // Handle heterogeneous compilation. @@ -1184,8 +1131,8 @@ void VMCompiler::Codegen() { } } - auto compile_engine = CompileEngine::Global(); - auto ext_mods = compile_engine->LowerExternalFunctions(); + auto ext_mods = context_.compiler->LowerExternalFunctions(); + runtime::Module lib; if (funcs.size() > 0) { lib = tvm::build(funcs, target_host_); @@ -1196,7 +1143,6 @@ void VMCompiler::Codegen() { } lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_, runtime::Metadata()); exec_->SetLib(lib); - CompileEngine::Global()->Clear(); } ExprDeviceMap VMCompiler::AnalyzeContext() const { diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 3a3796373a61..a05c52ced07f 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -43,8 +43,9 @@ #include "../../../runtime/vm/naive_allocator.h" #include "../../../runtime/vm/profiler/vm.h" -#include "../../backend/compile_engine.h" #include "../../transforms/pass_utils.h" +#include "../te_compiler.h" +#include "../te_compiler_cache.h" namespace tvm { namespace relay { @@ -75,12 +76,14 @@ struct VMCompilerContext { TagMap tag_map; // Map from global var to a unique integer GlobalMap global_map; + // TEcompiler for lowering + tec::TECompiler compiler; // List of constants std::vector constants; // Device type for constants std::vector const_device_type; // List of cached functions - std::vector cached_funcs; + std::vector cached_funcs; // The functions that have been lowered. std::unordered_map seen_funcs; }; diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 5ce06d9fefaa..851a498377b2 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -29,50 +29,12 @@ #include -#include "indexed_graph.h" +#include "dataflow_matcher_impl.h" namespace tvm { namespace relay { // Pattern Matcher - -class DominatorMatcher; - -class DFPatternMatcher : public DFPatternFunctor { - public: - explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {} - bool Match(const DFPattern& pattern, const Expr& expr); - Map> GetMemo() { return Map>(memo_); } - const IndexedGraph expr_graph_; - - protected: - bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; - bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override; - bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override; - - void ClearMap(size_t watermark); - bool MatchesPath(const DominatorPatternNode* op, const Expr& expr); - bool DominatesParent(const DominatorPatternNode* op, const Expr& expr); - - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> memo_; - std::vector matched_nodes_; - bool memoize_ = true; -}; - bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { memo_.clear(); matched_nodes_.clear(); @@ -542,388 +504,320 @@ bool MatchPattern(DFPattern pattern, Expr expr) { TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern); -/*! - * \brief PatternGrouper does pre-rewriting pattern matching and analysis - * - * This class creates a number of groups of matched expressions, ensures they don't overlap, and - * returns them to the caller for post-analysis rewriting. - * - * This is primarily needed to support the post-dominator analysis required for dominator pattern - * matching. - */ -class PatternGrouper { +/*! \brief Creates a new set of nodes based on Group inputs, used to create functions and perform + * group overlap analysis */ +class MatchExtractor : public ExprMutator { public: - /*! \brief Internal Group class for storing analysis */ - struct Group { - Expr root_node; - int gid; - Map> matched_nodes; - std::string name; - Function function; - Array args; - }; - - /*! \brief Return the group assignments of expressions */ - const std::unordered_map& GetGIDAssignments() { - return gid_assignments_; + explicit MatchExtractor( + const std::unordered_map& inputs) + : inputs_(inputs) {} + const std::unordered_map& GetMemo() { + return this->memo_; } - /*! \brief Group expressions that match the pattern */ - const std::unordered_map& GroupMatches(const DFPattern& pattern, const Expr& pre) { - groups_.clear(); - gid_assignments_.clear(); + const std::string& GetName() { return name_; } - pattern_ = pattern; - pattern_graph_ = CreateIndexedGraph(pattern_); - auto matcher = DFPatternMatcher(pre); - matcher_ = &matcher; - this->VisitExprs(); - return this->groups_; + protected: + Expr VisitExpr(const Expr& pre) override { + if (inputs_.count(pre)) { + return inputs_.at(pre); + } + return ExprMutator::VisitExpr(pre); } + Expr VisitExpr_(const TupleNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "Tuple_"; + return out; + }; + Expr VisitExpr_(const FunctionNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "Function"; + return out; + }; + Expr VisitExpr_(const CallNode* call_node) override { + auto out = ExprMutator::VisitExpr_(call_node); + if (auto operation = call_node->op.as()) { + name_ += operation->name + "_"; + } else { + name_ += "Call_"; + } + return out; + }; + Expr VisitExpr_(const LetNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "Let_"; + return out; + }; + Expr VisitExpr_(const IfNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "If_"; + return out; + }; + Expr VisitExpr_(const TupleGetItemNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "TupleGetItem" + std::to_string(op->index) + "_"; + return out; + }; + Expr VisitExpr_(const MatchNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "Match_"; + return out; + }; + std::string name_; + const std::unordered_map inputs_; +}; - protected: - /*! \brief Iteratively traverse the Expression in pre-order to find subgraphs - * - * If we traverse the graph in post-order, we can run into situtations where a small subgraph will - * match the pattern. Due to options like AltPattern, a larger subgraph with more nodes later in - * the graph may also match the pattern. With post-order traversal, we mark the smaller subgraph - * as matched and fail to catch the larger subgraph. This problem is fixed by using pre-order - * traversal. - */ - void VisitExprs() { - std::unordered_set pre_partitioned; - for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) { - size_t index = i - 1; - Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_; - if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped - if (auto op = current.as()) { - if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { - pre_partitioned.insert(current); - PostOrderVisit(op->body, - [&pre_partitioned](const Expr& expr) { pre_partitioned.insert(expr); }); - } - } - if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, current)) { - CreateGroup(current); +/*! \brief Group expressions that match the pattern */ +const std::unordered_map& PatternGrouper::GroupMatches( + const DFPattern& pattern, const Expr& pre) { + groups_.clear(); + gid_assignments_.clear(); + + pattern_ = pattern; + pattern_graph_ = CreateIndexedGraph(pattern_); + auto matcher = DFPatternMatcher(pre); + matcher_ = &matcher; + this->VisitExprs(); + return this->groups_; +} + +void PatternGrouper::VisitExprs() { + std::unordered_set pre_partitioned; + for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) { + size_t index = i - 1; + Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_; + if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped + if (auto op = current.as()) { + if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { + pre_partitioned.insert(current); + PostOrderVisit(op->body, + [&pre_partitioned](const Expr& expr) { pre_partitioned.insert(expr); }); } } + if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, current)) { + CreateGroup(current); + } } } - /*! \brief Creates a new set of nodes based on Group inputs, used to create functions and perform - * group overlap analysis */ - class MatchExtractor : public ExprMutator { - public: - explicit MatchExtractor( - const std::unordered_map& inputs) - : inputs_(inputs) {} - const std::unordered_map& GetMemo() { - return this->memo_; - } - const std::string& GetName() { return name_; } +} - protected: - Expr VisitExpr(const Expr& pre) override { - if (inputs_.count(pre)) { - return inputs_.at(pre); +void PatternGrouper::CreateGroup(const Expr& expr) { + int var_number = 0; + + auto node_map = matcher_->GetMemo(); + // Get fuzzy patterns + std::unordered_set fuzzy_matches; + for (auto node : pattern_graph_.topological_order_) { + // Don't treat fuzzy Dominator patterns input variables for partition + if (auto op = node->ref_.as()) { + for (auto fuzzy_op : {op->parent, op->path}) { + for (auto match : node_map[fuzzy_op]) { + fuzzy_matches.insert(match); + } } - return ExprMutator::VisitExpr(pre); } - Expr VisitExpr_(const TupleNode* op) override { - auto out = ExprMutator::VisitExpr_(op); - name_ += "Tuple_"; - return out; - }; - Expr VisitExpr_(const FunctionNode* op) override { - auto out = ExprMutator::VisitExpr_(op); - name_ += "Function"; - return out; - }; - Expr VisitExpr_(const CallNode* call_node) override { - auto out = ExprMutator::VisitExpr_(call_node); - if (auto operation = call_node->op.as()) { - name_ += operation->name + "_"; - } else { - name_ += "Call_"; + // Don't treat Function params or body as input variables for partition + if (node->ref_.as()) { + auto matches = node_map[node->ref_]; + for (auto match : matches) { + auto graph = CreateIndexedGraph(match.as()->body); + for (auto node : graph.topological_order_) { + fuzzy_matches.insert(node->ref_); + } } - return out; - }; - Expr VisitExpr_(const LetNode* op) override { - auto out = ExprMutator::VisitExpr_(op); - name_ += "Let_"; - return out; - }; - Expr VisitExpr_(const IfNode* op) override { - auto out = ExprMutator::VisitExpr_(op); - name_ += "If_"; - return out; - }; - Expr VisitExpr_(const TupleGetItemNode* op) override { - auto out = ExprMutator::VisitExpr_(op); - name_ += "TupleGetItem" + std::to_string(op->index) + "_"; - return out; - }; - Expr VisitExpr_(const MatchNode* op) override { - auto out = ExprMutator::VisitExpr_(op); - name_ += "Match_"; - return out; - }; - std::string name_; - const std::unordered_map inputs_; - }; + } + } - /*! \brief Create a group based on a matched expression */ - void CreateGroup(const Expr& expr) { - int var_number = 0; - - auto node_map = matcher_->GetMemo(); - // Get fuzzy patterns - std::unordered_set fuzzy_matches; - for (auto node : pattern_graph_.topological_order_) { - // Don't treat fuzzy Dominator patterns input variables for partition - if (auto op = node->ref_.as()) { - for (auto fuzzy_op : {op->parent, op->path}) { - for (auto match : node_map[fuzzy_op]) { - fuzzy_matches.insert(match); - } - } + // Create input variables + Group group; + group.root_node = expr; + group.matched_nodes = node_map; + + std::unordered_map inputs; + Array params; + + for (auto node : pattern_graph_.topological_order_) { + auto make_input = [&](const Expr& input) { + if (fuzzy_matches.count(input) == 0 && input.as() == nullptr && + input.as() == nullptr && !EmbedConst(input, node->ref_)) { + inputs[input] = + Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), + NullValue()); + group.args.push_back(input); + params.push_back(inputs[input]); + var_number++; } - // Don't treat Function params or body as input variables for partition - if (node->ref_.as()) { + }; + auto tuple = node->ref_.as(); + auto call = node->ref_.as(); + if (tuple && !tuple->fields.defined()) { + if (node_map.count(node->ref_)) { auto matches = node_map[node->ref_]; for (auto match : matches) { - auto graph = CreateIndexedGraph(match.as()->body); - for (auto node : graph.topological_order_) { - fuzzy_matches.insert(node->ref_); + for (auto input : match.as()->fields) { + make_input(input); } } } - } - - // Create input variables - Group group; - group.root_node = expr; - group.matched_nodes = node_map; - - std::unordered_map inputs; - Array params; - - for (auto node : pattern_graph_.topological_order_) { - auto make_input = [&](const Expr& input) { - if (fuzzy_matches.count(input) == 0 && input.as() == nullptr && - input.as() == nullptr && !EmbedConst(input, node->ref_)) { - inputs[input] = - Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), - NullValue()); - group.args.push_back(input); - params.push_back(inputs[input]); - var_number++; - } - }; - auto tuple = node->ref_.as(); - auto call = node->ref_.as(); - if (tuple && !tuple->fields.defined()) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; - for (auto match : matches) { - for (auto input : match.as()->fields) { - make_input(input); - } - } - } - } else if (call && !call->args.defined()) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; - for (auto match : matches) { - for (auto input : match.as()->args) { - make_input(input); - } + } else if (call && !call->args.defined()) { + if (node_map.count(node->ref_)) { + auto matches = node_map[node->ref_]; + for (auto match : matches) { + for (auto input : match.as()->args) { + make_input(input); } } - } else if (node->inputs_.size() == 0) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; - for (auto match : matches) { - make_input(match); - } + } + } else if (node->inputs_.size() == 0) { + if (node_map.count(node->ref_)) { + auto matches = node_map[node->ref_]; + for (auto match : matches) { + make_input(match); } } } + } - graph_number_++; - - // Extract a Function. Used in Partition directly, - // used to determine Group overlap in other passes - auto extractor = MatchExtractor(inputs); - auto body = extractor.Mutate(expr); - - group.function = Function(params, body, NullValue(), Array()); - group.name = extractor.GetName(); - // Check to make sure we aren't overlapping with another group or creating an invalid fusion - // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the - // pattern with the input FunctionVar* Variables. The resulting memoization map will only - // contain nodes in the expression that matched the pattern. If a non-input node of the pattern - // (i.e., some piece of computation) overlaps with the nodes in a previous group, we'll have a - // situation where we try to rewrite the same node twice in the second rewriting or parition - // pass. This isn't valid, so we check for it here. We ignore Ops, functions, and constants - // because they exist more globally outside of the fusion. - // Similiarly, if interior nodes in a group are used outside of the group fusing to a single - // output would create an invalid graph tranformation, so we block the creation of such groups. - auto memo = extractor.GetMemo(); - for (auto kv : memo) { - // Check to ensure that this node isn't an input or a global - if (inputs.count(kv.first) == 0 && kv.first.as() == nullptr && - kv.first.as() == nullptr && kv.first.as() == nullptr) { - if (gid_assignments_.count(kv.first) != 0) { - // check to see if the node is use in other groups - // Exit due to overlapping partitions - return; - } else if (kv.second != body) { - // if the node isn't the output of the group - auto node = matcher_->expr_graph_.node_map_.at(kv.first); - for (auto* output : node->outputs_) { - // and the node is used by nodes outside of the group - if (memo.count(output->ref_) == 0 && - !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) { - // Exit because nodes in this pattern's body are used outside the pattern - // fusing it would be invalid - return; - } + graph_number_++; + + // Extract a Function. Used in Partition directly, + // used to determine Group overlap in other passes + auto extractor = MatchExtractor(inputs); + auto body = extractor.Mutate(expr); + + group.function = Function(params, body, NullValue(), Array()); + group.name = extractor.GetName(); + // Check to make sure we aren't overlapping with another group or creating an invalid fusion + // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the + // pattern with the input FunctionVar* Variables. The resulting memoization map will only + // contain nodes in the expression that matched the pattern. If a non-input node of the pattern + // (i.e., some piece of computation) overlaps with the nodes in a previous group, we'll have a + // situation where we try to rewrite the same node twice in the second rewriting or parition + // pass. This isn't valid, so we check for it here. We ignore Ops, functions, and constants + // because they exist more globally outside of the fusion. + // Similiarly, if interior nodes in a group are used outside of the group fusing to a single + // output would create an invalid graph tranformation, so we block the creation of such groups. + auto memo = extractor.GetMemo(); + for (auto kv : memo) { + // Check to ensure that this node isn't an input or a global + if (inputs.count(kv.first) == 0 && kv.first.as() == nullptr && + kv.first.as() == nullptr && kv.first.as() == nullptr) { + if (gid_assignments_.count(kv.first) != 0) { + // check to see if the node is use in other groups + // Exit due to overlapping partitions + return; + } else if (kv.second != body) { + // if the node isn't the output of the group + auto node = matcher_->expr_graph_.node_map_.at(kv.first); + for (auto* output : node->outputs_) { + // and the node is used by nodes outside of the group + if (memo.count(output->ref_) == 0 && + !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) { + // Exit because nodes in this pattern's body are used outside the pattern + // fusing it would be invalid + return; } } } } - // Assign Group Ids - group.gid = ++gid_; - for (auto kv : extractor.GetMemo()) { - gid_assignments_[kv.first] = gid_; - } + } + // Assign Group Ids + group.gid = ++gid_; + for (auto kv : extractor.GetMemo()) { + gid_assignments_[kv.first] = gid_; + } + + // Save Group + groups_[group.gid] = std::move(group); +} - // Save Group - groups_[group.gid] = std::move(group); - } - - /*! \brief EmbedConst implements rules for embedding constants into partitioned functions or - * lifting them into the function arguments. - * - * The rules depend on what pattern the ConstantNode matched. - * - * The basic rules are: - * If the constant matches ExprPattern(relay.const(*)) or a ConstantPattern(), embed the constant - * in the partitioned function. If the constant matched an AltPattern, recursively check the - * matched side of the pattern. For any other matching pattern (i.e, wildcard, VarPattern, etc), - * lift the constant into the arguments of the partitioned function. - */ - bool EmbedConst(const Expr& expr, const DFPattern pattern) { - bool embed = false; - if (expr.as()) { - if (pattern.as() != nullptr) { +bool PatternGrouper::EmbedConst(const Expr& expr, const DFPattern pattern) { + bool embed = false; + if (expr.as()) { + if (pattern.as() != nullptr) { + embed = true; + } else if (auto expr_pat = pattern.as()) { + if (expr_pat->expr.as()) { embed = true; - } else if (auto expr_pat = pattern.as()) { - if (expr_pat->expr.as()) { - embed = true; - } - } else if (auto alt_pat = pattern.as()) { - if (matcher_->Match(alt_pat->left, expr)) { - embed = EmbedConst(expr, alt_pat->left); - } else { - embed = EmbedConst(expr, alt_pat->right); - } + } + } else if (auto alt_pat = pattern.as()) { + if (matcher_->Match(alt_pat->left, expr)) { + embed = EmbedConst(expr, alt_pat->left); + } else { + embed = EmbedConst(expr, alt_pat->right); } } - return embed; } - // Internal State - DFPattern pattern_; - std::unordered_map groups_; - std::unordered_map gid_assignments_; - DFPatternMatcher* matcher_ = nullptr; - IndexedGraph pattern_graph_; - int gid_ = 0; - int graph_number_ = 0; -}; + return embed; +} // Rewrite -DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type) { +DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type, + bool rewrite_once) { ObjectPtr n = make_object(); n->pattern = std::move(pattern); n->function = std::move(function); n->require_type = require_type; + n->rewrite_once = rewrite_once; data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback") - .set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type) { - return DFPatternCallback(pattern, function, require_type); + .set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type, + bool rewrite_once) { + return DFPatternCallback(pattern, function, require_type, rewrite_once); }); -/*! - * \brief PatternRewriter rewrites the expression by finding matches and allowing user callback - * function to rewrite those matches - * - * The class uses PatternGrouper to support the dominator pattern. - */ -class PatternRewriter : protected MixedModeMutator { - public: - PatternRewriter(IRModule mod) : mod_(mod) {} - /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the - * callbacks until it stops changing */ - Expr Rewrite(const Array& callbacks, const Expr& pre) { - auto post = pre; - auto last = post; - // rewrite the graph until it stops changing to make sure all rewrites are complete - int count = 0; - bool equal = true; - static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); - ICHECK(structural_equal) << "node.StructuralEqual is not registered."; - do { - last = post; - for (auto callback : callbacks) { - callback_ = callback; - if (callback_->require_type) { - post = InferTypeWithModule(post, mod_); - } - auto grouper = PatternGrouper(); - groups_ = grouper.GroupMatches(callback_->pattern, post); - gid_assignments_ = grouper.GetGIDAssignments(); - memo_.clear(); - post = this->VisitExpr(post); - count++; - } - equal = (*structural_equal)(last, post, false, true); - } while (!equal && count < 100); - if (count >= 100) { - LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?"; +Expr PatternRewriter::Rewrite(const Array& callbacks, const Expr& pre) { + auto post = pre; + auto last = post; + // rewrite the graph until it stops changing to make sure all rewrites are complete + int count = 0; + bool equal = true; + static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); + ICHECK(structural_equal) << "node.StructuralEqual is not registered."; + do { + last = post; + for (auto callback : callbacks) { + callback_ = callback; + if (callback_->require_type) { + post = InferTypeWithModule(post, mod_); + } + auto grouper = PatternGrouper(); + groups_ = grouper.GroupMatches(callback_->pattern, post); + gid_assignments_ = grouper.GetGIDAssignments(); + memo_.clear(); + post = this->VisitExpr(post); + count++; } - return post; + equal = (*structural_equal)(last, post, false, true); + } while (!equal && count < 100 && !callback_->rewrite_once); + if (count >= 100) { + LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?"; } + return post; +} - protected: - Expr DispatchVisitExpr(const Expr& pre) override { - auto post = MixedModeMutator::DispatchVisitExpr(pre); - if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) { - // Convert the pre-rewrite node map to a post-rewrite node map - auto group = groups_[gid_assignments_[pre]]; - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_map; - for (auto kv : group.matched_nodes) { - Array tmp; - for (size_t i = 0; i < kv.second.size(); ++i) { - tmp.push_back(this->memo_[kv.second[i]]); - } - node_map.insert({kv.first, tmp}); +Expr PatternRewriter::DispatchVisitExpr(const Expr& pre) { + auto post = MixedModeMutator::DispatchVisitExpr(pre); + if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) { + // Convert the pre-rewrite node map to a post-rewrite node map + auto group = groups_[gid_assignments_[pre]]; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_map; + for (auto kv : group.matched_nodes) { + Array tmp; + for (size_t i = 0; i < kv.second.size(); ++i) { + tmp.push_back(this->memo_[kv.second[i]]); } - // run the user callback function - return callback_->function(pre, post, Map>(node_map)); + node_map.insert({kv.first, tmp}); } - return post; + // run the user callback function + return callback_->function(pre, post, Map>(node_map)); } - - IRModule mod_; - DFPatternCallback callback_; - std::unordered_map groups_; - std::unordered_map gid_assignments_; -}; + return post; +} Expr RewritePatterns(Array callbacks, Expr expr, IRModule mod) { return PatternRewriter(mod).Rewrite(callbacks, expr); diff --git a/src/relay/ir/dataflow_matcher_impl.h b/src/relay/ir/dataflow_matcher_impl.h new file mode 100644 index 000000000000..d993d4720e4e --- /dev/null +++ b/src/relay/ir/dataflow_matcher_impl.h @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/dataflow_matcher_impl.h + * \brief The auxiliary data structure for dataflow matcher. + */ +#ifndef TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_ +#define TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_ + +#include +#include +#include + +#include +#include +#include + +#include "indexed_graph.h" + +namespace tvm { +namespace relay { + +class DFPatternMatcher : public DFPatternFunctor { + public: + explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {} + bool Match(const DFPattern& pattern, const Expr& expr); + Map> GetMemo() { return Map>(memo_); } + const IndexedGraph expr_graph_; + + protected: + bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; + bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override; + + void ClearMap(size_t watermark); + bool MatchesPath(const DominatorPatternNode* op, const Expr& expr); + bool DominatesParent(const DominatorPatternNode* op, const Expr& expr); + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> memo_; + std::vector matched_nodes_; + bool memoize_ = true; +}; + +/*! + * \brief PatternGrouper does pre-rewriting pattern matching and analysis + * + * This class creates a number of groups of matched expressions, ensures they don't overlap, and + * returns them to the caller for post-analysis rewriting. + * + * This is primarily needed to support the post-dominator analysis required for dominator pattern + * matching. + */ +class PatternGrouper { + public: + /*! \brief Internal Group class for storing analysis */ + struct Group { + Expr root_node; + int gid; + Map> matched_nodes; + std::string name; + Function function; + Array args; + }; + + /*! \brief Return the group assignments of expressions */ + inline const std::unordered_map& GetGIDAssignments() { + return gid_assignments_; + } + /*! \brief Group expressions that match the pattern */ + const std::unordered_map& GroupMatches(const DFPattern& pattern, const Expr& pre); + + protected: + /*! \brief Iteratively traverse the Expression in pre-order to find subgraphs + * + * If we traverse the graph in post-order, we can run into situtations where a small subgraph will + * match the pattern. Due to options like AltPattern, a larger subgraph with more nodes later in + * the graph may also match the pattern. With post-order traversal, we mark the smaller subgraph + * as matched and fail to catch the larger subgraph. This problem is fixed by using pre-order + * traversal. + */ + void VisitExprs(); + + /*! \brief Create a group based on a matched expression */ + void CreateGroup(const Expr& expr); + + /*! \brief EmbedConst implements rules for embedding constants into partitioned functions or + * lifting them into the function arguments. + * + * The rules depend on what pattern the ConstantNode matched. + * + * The basic rules are: + * If the constant matches ExprPattern(relay.const(*)) or a ConstantPattern(), embed the constant + * in the partitioned function. If the constant matched an AltPattern, recursively check the + * matched side of the pattern. For any other matching pattern (i.e, wildcard, VarPattern, etc), + * lift the constant into the arguments of the partitioned function. + */ + bool EmbedConst(const Expr& expr, const DFPattern pattern); + // Internal State + DFPattern pattern_; + std::unordered_map groups_; + std::unordered_map gid_assignments_; + DFPatternMatcher* matcher_ = nullptr; + IndexedGraph pattern_graph_; + int gid_ = 0; + int graph_number_ = 0; +}; + +/*! + * \brief PatternRewriter rewrites the expression by finding matches and allowing user callback + * function to rewrite those matches + * + * The class uses PatternGrouper to support the dominator pattern. + */ +class PatternRewriter : protected MixedModeMutator { + public: + explicit PatternRewriter(IRModule mod) : mod_(mod) {} + /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the + * callbacks until it stops changing */ + virtual Expr Rewrite(const Array& callbacks, const Expr& pre); + + protected: + virtual Expr DispatchVisitExpr(const Expr& pre); + + IRModule mod_; + DFPatternCallback callback_; + std::unordered_map groups_; + std::unordered_map gid_assignments_; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_ diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 1a47193bb91a..43ce6656cdc0 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -49,7 +49,7 @@ Expr MakeMatmul(Expr tensor_a, Expr tensor_b, IndexExpr units, DataType out_dtyp Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype); -Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype); +Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype, bool transpose_a, bool transpose_b); Expr MakeExpandDims(Expr data, int axis, int num_newaxis); diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index d09a8495b549..a05e460dc680 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -206,6 +206,13 @@ Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { return Call(op, {data, weight}, Attrs(attrs), {}); } +InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { + return InferCorrectLayoutOutput({"NC", "NK"}, {"NC"}, attrs); +} + TVM_REGISTER_GLOBAL("relay.op.nn._make.dense").set_body_typed(MakeDense); RELAY_REGISTER_OP("nn.dense") @@ -221,35 +228,75 @@ RELAY_REGISTER_OP("nn.dense") .add_argument("data", "nD Tensor", "Input data.") .add_argument("weight", "2D Tensor", "Weight matrix.") .set_support_level(1) + .set_attr("FInferCorrectLayout", DenseInferCorrectLayout) .add_type_rel("Dense", MatmulRel); // ------------------- relay.nn.dense // ------------------- relay.nn.contrib_dense_pack +TVM_REGISTER_NODE_TYPE(DensePackAttrs); + // Positional relay function to create dense_pack operator used by frontend FFI. -Expr MakeDensePack(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { - auto attrs = make_object(); +Expr MakeDensePack(Expr data, Expr weight, tvm::String weight_layout, IndexExpr units, + DataType out_dtype) { + auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; + attrs->weight_layout = std::move(weight_layout); static const Op& op = Op::Get("nn.contrib_dense_pack"); return Call(op, {data, weight}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_dense_pack").set_body_typed(MakeDensePack); +bool DensePackRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr || weight == nullptr) return false; + + const DensePackAttrs* param = attrs.as(); + ICHECK(param != nullptr); + + ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported"; + ICHECK_EQ(weight->shape.size(), 3) << "Weight is not packed"; + + Array oshape = data->shape; + oshape.Set(1, weight->shape[0] * weight->shape[2]); + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + // assign output type + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + +InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { + auto params = attrs.as(); + ICHECK(params); + return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs); +} + RELAY_REGISTER_OP("nn.contrib_dense_pack") .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. -- **data**: `(x1, x2, ..., xn, input_dim)` +- **data**: `(batch, input_dim)` - **weight**: `(units // pack_weight_tile, input_dim, pack_weight_tile)` -- **out**: `(x1, x2, ..., xn, units)`. +- **out**: `(batch, units)`. )code" TVM_ADD_FILELINE) .set_attrs_type() .set_num_inputs(2) - .add_argument("data", "nD Tensor", "Input data.") + .add_argument("data", "2D Tensor", "Input data.") .add_argument("weight", "3D Tensor", "Packed weight matrix.") .set_support_level(10) - .add_type_rel("DensePack", DensePackRel); + .set_attr("FInferCorrectLayout", DensePackInferCorrectLayout) + .add_type_rel("DensePack", DensePackRel); // ------------------- relay.nn.contrib_dense_pack // relay.leaky_relu @@ -307,7 +354,6 @@ bool PReluRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -template InferCorrectLayoutOutput PReluInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, @@ -343,7 +389,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. .add_argument("alpha", "Tensor", "Input channelwise alpha.") .set_support_level(3) .add_type_rel("PRelu", PReluRel) - .set_attr("FInferCorrectLayout", PReluInferCorrectLayout) + .set_attr("FInferCorrectLayout", PReluInferCorrectLayout) .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); @@ -932,88 +978,43 @@ If the input has size k on axis 1, then both gamma and beta have shape (k,). .set_support_level(1) .add_type_rel("GroupNorm", GroupNormRel); -// relay.nn.batch_matmul +// ------------------- relay.nn.batch_matmul TVM_REGISTER_NODE_TYPE(BatchMatmulAttrs); -bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 3); - const auto* x = types[0].as(); - const auto* y = types[1].as(); - if (x == nullptr || y == nullptr) return false; - - const auto* param = attrs.as(); - Array y_shape; - if (param->auto_scheduler_rewritten_layout.size() == 0) { - y_shape = y->shape; - } else { - y_shape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout, - {"b", "j", "k"}); - } - - ICHECK(x->shape.size() == 3 && y_shape.size() == 3); - bool is_dyn = false; - Array oshape; - for (size_t i = 0; i < 3; ++i) { - if (x->shape[i].as() != nullptr || y_shape[i].as() != nullptr) { - is_dyn = true; - oshape.push_back(Any()); - } else { - if (i == 0) { - oshape.push_back(max(x->shape[i], y_shape[i])); - } else { - oshape.push_back(x->shape[i]); - } - } - } - if (!is_dyn) { - ICHECK(reporter->AssertEQ(x->shape[0], y_shape[0]) || reporter->AssertEQ(x->shape[0], 1) || - reporter->AssertEQ(y_shape[0], 1)) - << "BatchDot: batch dimensions don't match, " - << " x shape=" << x->shape << ", y shape=" << y_shape; - ICHECK(reporter->AssertEQ(x->shape[2], y_shape[2])) - << "BatchDot: shapes of x and y is inconsistent, " - << " x shape=" << x->shape << ", y shape=" << y_shape; - } - oshape.Set(2, y_shape[1]); - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = x->dtype; - } - // assign output type - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - // Positional relay function to create batch_matmul operator used by frontend FFI. -Expr MakeBatchMatmul(Expr x, Expr y, DataType out_dtype) { +Expr MakeBatchMatmul(Expr tensor_a, Expr tensor_b, DataType out_dtype, bool transpose_a, + bool transpose_b) { auto attrs = make_object(); attrs->out_dtype = out_dtype; + attrs->transpose_a = transpose_a; + attrs->transpose_b = transpose_b; static const Op& op = Op::Get("nn.batch_matmul"); - return Call(op, {x, y}, Attrs(attrs), {}); + return Call(op, {tensor_a, tensor_b}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul").set_body_typed(MakeBatchMatmul); RELAY_REGISTER_OP("nn.batch_matmul") - .describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y` -are data in batch. + .describe(R"code(Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + +Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format +(transpose_a=False, transpose_b=True) by default. .. math:: - batch\_matmul(x, y)[i, :, :] = matmul(x[i, :, :], y[i, :, :]^T) + batch\_matmul(A, B)[i, :, :] = matmul(A[i, :, :], B[i, :, :]^T) -- **x**: `(b, m, k)` -- **y**: `(b, n, k)` +- **tensor_a**: `(b, m, k)` or `(b, k, m)` +- **tensor_b**: `(b, k, n)` or `(b, n, k)` - **out**: `(b, m, n)`. )code" TVM_ADD_FILELINE) .set_num_inputs(2) - .add_argument("x", "3D Tensor", "First input.") - .add_argument("y", "3D Tensor", "Second input.") + .add_argument("tensor_a", "3D Tensor", "The first input.") + .add_argument("tensor_b", "3D Tensor", "The second input.") .set_support_level(10) - .add_type_rel("BatchMatmul", BatchMatmulRel); + .add_type_rel("BatchMatmul", BatchMatmulRel); +// ------------------- relay.nn.batch_matmul // relay.nn.cross_entropy bool CrossEntropyRel(const Array& types, int num_inputs, const Attrs& attrs, diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 29f200c67c59..6bc21473af18 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -24,10 +24,12 @@ #ifndef TVM_RELAY_OP_NN_NN_H_ #define TVM_RELAY_OP_NN_NN_H_ +#include #include #include #include +#include #include #include "../op_common.h" @@ -115,25 +117,55 @@ bool MatmulRel(const Array& types, int num_inputs, const Attrs& attrs, } template -bool DensePackRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { +bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { ICHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr || weight == nullptr) return false; + const auto* x = types[0].as(); + const auto* y = types[1].as(); + if (x == nullptr || y == nullptr) return false; const AttrType* param = attrs.as(); ICHECK(param != nullptr); + bool transpose_a = param->transpose_a; + bool transpose_b = param->transpose_b; + const Array& y_shape = + param->auto_scheduler_rewritten_layout.size() == 0 + ? y->shape + : auto_scheduler::GetShapeFromRewrittenLayout( + param->auto_scheduler_rewritten_layout, + transpose_b ? tvm::runtime::Array({"b", "j", "k"}) + : tvm::runtime::Array({"b", "k", "j"})); + ICHECK(x->shape.size() == 3 && y_shape.size() == 3); + const PrimExpr& xb = x->shape[0]; + const PrimExpr& xi = x->shape[transpose_a ? 2 : 1]; + const PrimExpr& xk = x->shape[transpose_a ? 1 : 2]; + const PrimExpr& yb = y_shape[0]; + const PrimExpr& yk = y_shape[transpose_b ? 2 : 1]; + const PrimExpr& yj = y_shape[transpose_b ? 1 : 2]; - Array oshape = data->shape; - oshape.Set((oshape.size() - 1), weight->shape[0] * weight->shape[2]); + bool is_dyn = false; + for (size_t i = 0; i < 3; ++i) { + if (x->shape[i].as() != nullptr || y_shape[i].as() != nullptr) { + is_dyn = true; + break; + } + } + if (!is_dyn) { + ICHECK(reporter->AssertEQ(xb, yb) || reporter->AssertEQ(xb, 1) || reporter->AssertEQ(yb, 1)) + << "BatchDot: batch dimensions don't match, " + << " x shape=" << x->shape << ", y shape=" << y_shape; + ICHECK(reporter->AssertEQ(xk, yk)) << "BatchDot: shapes of x and y is inconsistent, " + << " x shape=" << x->shape << ", y shape=" << y_shape; + } DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { - out_dtype = data->dtype; + out_dtype = x->dtype; } // assign output type - reporter->Assign(types[2], TensorType(oshape, out_dtype)); + const auto& out_b = + xb->IsInstance() || yb->IsInstance() ? tir::Any() : max(xb, yb); + reporter->Assign(types[2], TensorType(Array({out_b, xi, yj}), out_dtype)); return true; } diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index 32b0811b48ac..7d21005cb4db 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -274,10 +274,11 @@ bool SparseConv2dRel(const Array& types, int num_inputs, const Attrs& attr } Expr MakeSparseConv2d(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr, - std::string layout) { + std::string layout, Array kernel_size) { static const Op& op = Op::Get("nn.sparse_conv2d"); auto attrs = make_object(); attrs->layout = std::move(layout); + attrs->kernel_size = std::move(kernel_size); return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 2d14ee37fde0..9f9ed1c075cd 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2353,7 +2353,15 @@ bool BroadCastToRel(const Array& types, int num_inputs, const Attrs& attrs const InitOpAttrs* param = attrs.as(); ICHECK(param); - DataType out_dtype = types[0].as()->dtype; + DataType out_dtype; + if (auto ttype = types[0].as()) { + out_dtype = ttype->dtype; + } else { + ICHECK(types[0].as()) + << "Broadcast: expect to be TensorType but get " << types[0]; + return false; + } + std::vector oshape; const Array& cshape_array = param->shape.value(); diff --git a/src/relay/qnn/op/batch_matmul.cc b/src/relay/qnn/op/batch_matmul.cc new file mode 100644 index 000000000000..4b0bcacacaa1 --- /dev/null +++ b/src/relay/qnn/op/batch_matmul.cc @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/batch_matmul.cc + * \brief Property def of qnn batch_matmul operator. + */ + +#include +#include +#include +#include + +#include "../../op/nn/nn.h" +#include "../../transforms/pattern_utils.h" +#include "../utils.h" + +namespace tvm { +namespace relay { +namespace qnn { + +// relay.op.qnn.batch_matmul + +bool QnnBatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // Expected Types: x, y, x_zero_point, y_zero_point, x_scale, y_scale, + // out_type + ICHECK_EQ(types.size(), 7); + const auto* x = types[0].as(); + const auto* y = types[1].as(); + if (x == nullptr || y == nullptr) return false; + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "BatchMatmulAttrs cannot be nullptr."; + ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8)) + << "Expected quantized batch_matmul type(int8, uint8) for input but was " << x->dtype; + ICHECK(y->dtype == DataType::Int(8) || y->dtype == DataType::UInt(8)) + << "Expected quantized batch_matmul type(int8, uint8) for weight but was " << y->dtype; + ICHECK(param->out_dtype == DataType::Int(32)) + << "Expected quantized batch_matmul type(int32) for output but was " << param->out_dtype; + + // Check the types of scale and zero points. + for (size_t i = 2; i < 5; ++i) { + if (types[i].as()) { + return false; + } + } + ICHECK(IsScalarType(types[2], DataType::Int(32))); // x_zero_point + ICHECK(IsScalarType(types[3], DataType::Int(32))); // y_zero_point + ICHECK(IsScalarType(types[4], DataType::Float(32))); // x_scale + ICHECK(IsScalarType(types[5], DataType::Float(32))); // y_scale + + ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; + + // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay + // BatchMatmul infer type function. + Array tensor_types = {types[0], types[1], types[6]}; + return BatchMatmulRel(tensor_types, 3, attrs, reporter); +} + +// Positional relay function to create quantized batch_matmul operator used by frontend FFI. +Expr MakeQuantizedBatchMatmul(Expr x, Expr y, Expr x_zero_point, Expr y_zero_point, Expr x_scale, + Expr y_scale, DataType out_dtype) { + auto attrs = make_object(); + attrs->out_dtype = out_dtype; + // For legacy reason, currently `qnn.batch_matmul` only supports + // (transpose_a=false, transpose_b=true) + // TODO(jcf94): extent to support all tensor format + attrs->transpose_a = false; + attrs->transpose_b = true; + static const Op& op = Op::Get("qnn.batch_matmul"); + return Call(op, {x, y, x_zero_point, y_zero_point, x_scale, y_scale}, Attrs(attrs), {}); +} + +Expr BatchMatmulFirstTerm(const Expr& quantized_x, const Expr& quantized_y, + const BatchMatmulAttrs* attrs) { + ICHECK(attrs->transpose_a == false && attrs->transpose_b == true) + << "Currently qnn.batch_matmul only supports (transpose_a=false, transpose_b=true)."; + return MakeBatchMatmul(quantized_x, quantized_y, attrs->out_dtype, attrs->transpose_a, + attrs->transpose_b); +} + +Expr BatchMatmulSecondTerm(const Expr& x_quantized_data, const Expr& y_zero_point) { + Array axes = {2}; + return Multiply(y_zero_point, Sum(Cast(x_quantized_data, DataType::Int(32)), axes, true, false)); +} + +Expr BatchMatmulThirdTerm(const Expr& y_quantized_data, const Expr& x_zero_point, + int broadcast_dim_size) { + Array axes = {2}; + auto reducemult = + Multiply(x_zero_point, Sum(Cast(y_quantized_data, DataType::Int(32)), axes, true, false)); + Array newshape; + newshape = {1, 1, broadcast_dim_size}; + return Reshape(reducemult, newshape); +} + +Expr BatchMatmulFourthTerm(int x_zero_point_int, int y_zero_point_int, int reduction_dim_size) { + int32_t scalar_term = x_zero_point_int * y_zero_point_int * reduction_dim_size; + return MakeConstantScalar(DataType::Int(32), scalar_term); +} + +Expr BatchMatmulCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, + const Expr& term4) { + auto data1_term = Subtract(term1, term2); + auto data2_term = Subtract(term4, term3); + return Add(data1_term, data2_term); +} + +/* + * \brief Forward rewrite the qnn batch_matmul op. + * \param attrs The QNN batch_matmul attrs. + * \param new_args The new mutated args to the call node. + * \param arg_types The types of input and output. + * \return The sequence of Relay ops for qnn batch_matmul op. + * \note Lowering of the qnn.batch_matmul operator + * A quantized tensor is represented in following manner + * A = scale_a x (QA - zp_A) + * where QA is quantized tensor, scale_a and zp_A are quantization + * params. + * + * Quantized batch_matmul multiplies two quantized tensors and returns a + * quantized tensor of default dtype of int32, with scale equaling to the + * product of scales of input tensors, and a zero point of zero. + * + * The lowering for asymmetric quantized batch_matmul looks similar to + * quantized conv2d and dense and originally was discussed here: + * https://discuss.tvm.apache.org/t/tf-lite-quantized-conv2d-operator-conversion/2651/7 + * + * The computation gets unrolled into following 4 terms + * C(m, n) = Sigma(k) (X(m, k) * Y(n, k)) + * + * RHS becomes + * Sigma(k) ([QX(m, k) - zp_x] * [QY(n, k) - zp_y]) + * + * Unrolling leads to following sequence + * Sigma(k) QX(m, k) * QX(n, k) // Term1 + * - Sigma(k) zp_y * QX(m, k) // Term2 + * - Sigma(k) zp_x * QY(n, k) // Term3 + * - Sigma(k) * zp_x * zp_y // Term4 + * + * Term4 can be computed at compile time, everything else depending on the + * input type. + */ +Expr QnnBatchMatmulCanonicalize(const Attrs& attrs, const Array& new_args, + const Array& arg_types) { + ICHECK_EQ(new_args.size(), 6); + Expr quantized_x = new_args[0]; + Expr quantized_y = new_args[1]; + Expr x_zero_point = new_args[2]; + Expr y_zero_point = new_args[3]; + + const auto in_shape = get_shape(arg_types[0]); + const int reduction_dim_size = get_const_int(in_shape[2]); + + const auto y_shape = get_shape(arg_types[1]); + const int broadcast_dim_size = get_const_int(y_shape[1]); + + const auto* qnn_batch_matmul_attrs = attrs.as(); + + // Extract the integer zero points. + auto y_zero_point_int = GetScalarFromConstant(y_zero_point); + auto x_zero_point_int = GetScalarFromConstant(x_zero_point); + + // Get all the terms as described in the comments. + auto term1 = BatchMatmulFirstTerm(quantized_x, quantized_y, qnn_batch_matmul_attrs); + auto term2 = BatchMatmulSecondTerm(quantized_x, y_zero_point); + auto term3 = BatchMatmulThirdTerm(quantized_y, x_zero_point, broadcast_dim_size); + auto term4 = BatchMatmulFourthTerm(x_zero_point_int, y_zero_point_int, reduction_dim_size); + + // Combine those 4 terms depending on the zero points to get the best lowering. + if (x_zero_point_int == 0 && y_zero_point_int == 0) { + // term 2, 3 and 4 become zero. + return term1; + } else if (x_zero_point_int == 0 && y_zero_point_int != 0) { + // term 3 and term 4 become zero. + return Subtract(term1, term2); + } else if (x_zero_point_int != 0 && y_zero_point_int == 0) { + // term 2 and term 4 become zero. + return Subtract(term1, term3); + } else { + return BatchMatmulCombineTerms(term1, term2, term3, term4); + } +} + +RELAY_REGISTER_OP("qnn.batch_matmul") + .describe(R"code(Applies a linear transformation: :math:`Z = XY`. +- **data**: quantized(int8, unit8) `(x1, x2, ..., xn, input_dim)` +- **weight**: quantized(int8, unit8) `(units, input_dim)` +- **out**: quantized(int32) `(x1, x2, ..., xn, units)`. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(6) + .add_argument("x", "quantized 2D Tensor", "First input data.") + .add_argument("y", "quantized 2D Tensor", "Second input data.") + .add_argument("x_scale", "Tensor", "The quantization scale of the x input tensor.") + .add_argument("x_zero_point", "Tensor", "The quantization zero_point of the x input tensor.") + .add_argument("y_scale", "Tensor", "The quantization scale of the y input tensor.") + .add_argument("y_zero_point", "Tensor", "The quantization zero_point of the y input tensor.") + .set_support_level(11) + .add_type_rel("QBatchMatmul", QnnBatchMatmulRel) + .set_attr("TNonComputational", true) + .set_attr("FTVMQnnCanonicalize", QnnBatchMatmulCanonicalize); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.batch_matmul").set_body_typed(MakeQuantizedBatchMatmul); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index a5161358865a..cf5266485f2e 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -65,7 +65,6 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, } } ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point - ICHECK(IsScalarType(types[3], DataType::Int(32))); // weight_zero_point ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale // Kernel scale can be a vector of length output_channels or a scalar. if (param->groups == 1) { @@ -293,7 +292,11 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_ auto multiplied_t2 = reduced_t2; auto one_scalar = MakeConstantScalar(DataType::Int(32), 1); if (!IsEqualScalar(kernel_zero_point, one_scalar)) { - multiplied_t2 = Multiply(kernel_zero_point, reduced_t2); + if (!IsConstScalar(kernel_zero_point)) { + multiplied_t2 = Multiply(MakeRepeat(kernel_zero_point, channel_multiplier, 0), reduced_t2); + } else { + multiplied_t2 = Multiply(kernel_zero_point, reduced_t2); + } } // Reduce the C dimension. Find the dimension. @@ -378,6 +381,25 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i return MakeConstantScalar(DataType::Int(32), scalar_term4); } +/* + * \brief Calculates the fourth term in the qnn.conv2d depthwise lowering sequence + for non-constant zero_points. + * \param input_zero_point The Expr for the input zero point. + * \param kernel_zero_point The Expr for the kernel zero point. + * \param kernel_h The height of kernel. + * \param kernel_w The width of kernel. + * \return The sequence of Relay operators for term4. + * \note The term4 looks like this + * + * Sigma(r, s) zp_a * zp_w + */ +Expr DepthwiseConv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, + int kernel_h, int kernel_w) { + Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w); + Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point); + return Multiply(scalar_term4, variable_term4); +} + /* * \brief Calculates the first term in the qnn.conv2d lowering sequence. * \param data The input expr. @@ -457,6 +479,11 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, auto multiplied_t2 = reduced_t2; auto one_scalar = MakeConstantScalar(DataType::Int(32), 1); if (!IsEqualScalar(kernel_zero_point, one_scalar)) { + if (!IsConstScalar(kernel_zero_point)) { + Layout layout(param->data_layout); + int channel_axis = layout.IndexOf(LayoutAxis::Get('C')); + reduced_t2 = MakeRepeat(reduced_t2, out_channels, channel_axis); + } multiplied_t2 = Multiply(kernel_zero_point, reduced_t2); } return multiplied_t2; @@ -531,6 +558,27 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i return MakeConstantScalar(DataType::Int(32), scalar_term4); } +/* + * \brief Calculates the fourth term in the qnn.conv2d lowering sequence + for non-constant zero_points. + * \param input_zero_point The Expr for the input zero point. + * \param kernel_zero_point The Expr for the kernel zero point. + * \param in_channels The number of input channels. + * \param kernel_h The height of kernel. + * \param kernel_w The width of kernel. + * \return The sequence of Relay operators for term4. + * \note The term4 looks like this + * + * Sigma(c,r,s) zp_a * zp_w + * + */ +Expr Conv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels, + int kernel_h, int kernel_w) { + Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), in_channels * kernel_h * kernel_w); + Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point); + return Multiply(scalar_term4, variable_term4); +} + /* * \brief Combines different terms of qnn conv2d lowering. * \param term1 The term1 of qnn conv2d lowering. @@ -656,9 +704,24 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) = GetWorkload(arg_types, param); - // Extract the integer zero points. - auto input_zero_point_int = GetScalarFromConstant(input_zero_point); - auto kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); + // zero points are allowed to be non-scalar. Let's check if that's the case. + bool dynamic_zp = false; + // Use -1 zero point as a default for dynamic. + int input_zero_point_int = -1; + int kernel_zero_point_int = -1; + + // Input zero point can either be a constant or a scalar expression. + if (IsConstScalar(input_zero_point) && (IsConstScalar(kernel_zero_point))) { + // Extract the integer zero points. + input_zero_point_int = GetScalarFromConstant(input_zero_point); + kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); + } else { + // Make kernel_zero_point expression a 1-D tensor for consistent shape. + kernel_zero_point = Reshape(kernel_zero_point, { + -1, + }); + dynamic_zp = true; + } // Fallback to int32 conv if there is dilation with non-zero kernel point or grouped conv2d // For dilated conv, if the kernel zero point is non-zero, the pooling operator also has to @@ -668,8 +731,26 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, ICHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation"; auto dilation_h = get_const_int(param->dilation[0]); auto dilation_w = get_const_int(param->dilation[1]); - if ((kernel_zero_point_int != 0 && (dilation_h != 1 || dilation_w != 1)) || - (param->groups != 1 && !is_depthwise(param))) { + // Check if qnn supports the conv2d parameters. If not, fallback to regular conv2d. + bool supported_dilation = (kernel_zero_point_int == 0) || (dilation_h == 1 && dilation_w == 1); + bool supported_groups = (param->groups == 1 || is_depthwise(param)); + bool conv2d_params_supported = supported_dilation && supported_groups; + + // If we need to fall back to default conv2d, kernel zp may need to be broadcast to kernel_layout. + // Otherwise, we broadcast it to data_layout for qnn lowering. + if (dynamic_zp) { + if (!conv2d_params_supported) { + Layout kernel_layout(param->kernel_layout); + int kernel_axis = kernel_layout.IndexOf(LayoutAxis::Get("O")); + kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {kernel_axis}); + } else { + Layout data_layout(param->data_layout); + int channel_axis = data_layout.IndexOf(LayoutAxis::Get("C")); + kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {channel_axis}); + } + } + + if (!conv2d_params_supported) { return Conv2DFallBack(data, weight, input_zero_point, kernel_zero_point, param); } else if (is_depthwise(param)) { ICHECK_NE(channel_multiplier, -1); @@ -679,8 +760,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, kernel_w, channel_multiplier); auto term3 = DepthwiseConv2DThirdTerm(weight, input_zero_point, param, out_channels, channel_multiplier); - auto term4 = - DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h, kernel_w); + Expr term4; + if (dynamic_zp) { + term4 = DepthwiseConv2DFourthTerm(input_zero_point, kernel_zero_point, kernel_h, kernel_w); + } else { + term4 = DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h, + kernel_w); + } return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int, kernel_zero_point_int); } @@ -690,8 +776,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, auto term2 = Conv2DSecondTerm(padded_data, kernel_zero_point, param, kernel_h, kernel_w, out_channels); auto term3 = Conv2DThirdTerm(weight, input_zero_point, param, out_channels); - auto term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h, - kernel_w); + Expr term4; + if (dynamic_zp) { + term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w); + } else { + term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h, + kernel_w); + } return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int, kernel_zero_point_int); } diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 751abfc5ca81..2f1d7d8da16c 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -51,14 +51,20 @@ bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* quantize_attrs = attrs.as(); int axis = quantize_attrs->axis; - axis = (axis < 0) ? data->shape.size() + axis : axis; - ICHECK_LT(axis, static_cast(data->shape.size())) - << "axis " << quantize_attrs->axis << " is out of range"; + auto rank = static_cast(data->shape.size()); + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; + ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << quantize_attrs->axis << " is out of range"; ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; + PrimExpr axis_shape; + if (rank > 0) { + axis_shape = data->shape[axis]; + } else { + axis_shape = Integer(1); + } // Check and assign types for scale and zero points. - AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale - AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point + AssignType(types[1], DataType::Float(32), axis_shape, reporter); // scale + AssignType(types[2], DataType::Int(32), axis_shape, reporter); // zero point const Array oshape = data->shape; const DataType out_dtype = quantize_attrs->out_dtype; diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 769f37205790..46de3522061b 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -279,14 +279,20 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const RequantizeAttrs* requantize_attrs = attrs.as(); int axis = requantize_attrs->axis; - axis = (axis == -1) ? data->shape.size() - 1 : axis; - ICHECK_LT(axis, static_cast(data->shape.size())) - << "axis " << requantize_attrs->axis << " is out of range"; + auto rank = static_cast(data->shape.size()); + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; + ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << requantize_attrs->axis << " is out of range"; ICHECK_GE(axis, 0) << "axis " << requantize_attrs->axis << " is out of range"; + PrimExpr axis_shape; + if (rank > 0) { + axis_shape = data->shape[axis]; + } else { + axis_shape = Integer(1); + } // Check and assign types for scale and zero points. - AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // input_scale - AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // input_zero_pt + AssignType(types[1], DataType::Float(32), axis_shape, reporter); // input_scale + AssignType(types[2], DataType::Int(32), axis_shape, reporter); // input_zero_pt // For now, requantize output tensor is limited to full tensor uniform quantization. ICHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale ICHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point diff --git a/src/relay/transforms/combine_parallel_batch_matmul.cc b/src/relay/transforms/combine_parallel_batch_matmul.cc index f8c46d93c675..ddab87a4893e 100644 --- a/src/relay/transforms/combine_parallel_batch_matmul.cc +++ b/src/relay/transforms/combine_parallel_batch_matmul.cc @@ -68,6 +68,16 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner { // shape[2] is the contraction axis and automatically consistent // if it were valid batch_matmul ops + // TODO(jcf94): Add full support of layout format + if (!(attrs_a->transpose_a == false && attrs_a->transpose_b == true && + attrs_b->transpose_a == false && attrs_b->transpose_b == true)) { + LOG(WARNING) << "For legacy reason, this pass only supports" + << " (transpose_a=false, transpose_b=true) now, skip combining these two with:" + << " batch_matmul_a: " << attrs_a->transpose_a << ", " << attrs_a->transpose_b + << " batch_matmul_b: " << attrs_b->transpose_a << ", " << attrs_b->transpose_b; + return false; + } + auto res = eq(rhs_a->dtype, rhs_b->dtype) && eq(restype_a->dtype, restype_b->dtype) && (rhs_a->shape.size() == 3) && (rhs_b->shape.size() == 3) && eq(rhs_a->shape[0], rhs_b->shape[0]) && eq(attrs_a->out_dtype, attrs_b->out_dtype); @@ -86,7 +96,8 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner { const auto* origin_attrs = branches[0][0]->attrs.as(); ICHECK(origin_attrs); - return Downcast(MakeBatchMatmul(data, new_weight, origin_attrs->out_dtype)); + return Downcast(MakeBatchMatmul(data, new_weight, origin_attrs->out_dtype, + origin_attrs->transpose_a, origin_attrs->transpose_b)); } bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { return true; } diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index 3cd9cca4fec4..d5404ba30f90 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -72,7 +72,8 @@ class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner { CHECK_EQ(num_args, 2); const auto* origin_attrs = branches[0][0]->attrs.as(); ICHECK(origin_attrs); - return Downcast(MakeBatchMatmul(new_args[0], new_args[1], origin_attrs->out_dtype)); + return Downcast( + MakeBatchMatmul(new_args[0], new_args[1], origin_attrs->out_dtype, false, true)); } virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index a29fdeb37832..e74ea0115857 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -155,6 +155,13 @@ Pass ConvertLayout(const Map>& desired_layouts) { TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout); +TVM_REGISTER_GLOBAL("relay._transform.InferCorrectLayoutOutput") + .set_body_typed([](Array input_layouts, Array output_layouts, Attrs new_attrs) { + return InferCorrectLayoutOutput(input_layouts, output_layouts, new_attrs); + }); + +TVM_REGISTER_NODE_TYPE(InferCorrectLayoutOutputNode); + } // namespace transform } // namespace relay diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index 6e4c03b0fcbc..3f2c25e988f9 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -73,10 +73,12 @@ TVM_REGISTER_GLOBAL("relay.analysis.search_conv2d_op_weight").set_body_typed(Sea class Conv2dToSparseConv2dMutator : public ExprRewriter { public: Conv2dToSparseConv2dMutator(const Array& weight_name, - const Array>& weight_shape, const String& layout) + const Array>& weight_shape, const String& layout, + int kernel_size) : conv2d_op_(Op::Get("nn.conv2d")), sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")) { ICHECK_EQ(weight_name.size(), weight_shape.size()); layout_ = layout; + kernel_size_ = kernel_size; for (size_t i = 0; i < weight_name.size(); ++i) { ICHECK(weight_name[i]->IsInstance()); std::string k = weight_name[i].as()->data; @@ -112,6 +114,7 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter { Var weight_indptr(prefix + ".indptr", ws_indptr_type); auto attrs = make_object(); attrs->layout = std::move(layout_); + attrs->kernel_size = Array{kernel_size_, kernel_size_}; return Call(sparse_conv2d_op_, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs)); } @@ -126,22 +129,168 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter { const Op& sparse_conv2d_op_; std::unordered_map> target_weights_; String layout_; + int kernel_size_; }; // class Conv2dToSparseConv2dAlter Expr Conv2dToSparse(const Expr& e, const Array& weight_name, - const Array>& weight_shape, const String& layout) { - auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout); + const Array>& weight_shape, const String& layout, + int kernel_size) { + auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout, kernel_size); + return PostOrderRewrite(e, &rewriter); +} + +template +auto unpack_to_tuple_internal(elemTy* arr, std::index_sequence) { + return std::make_tuple(arr[Is]...); +} + +template +auto unpack_to_tuple(elemTy* arr) { + return unpack_to_tuple_internal(arr, std::make_index_sequence{}); +} + +struct Range { + size_t dim; + explicit Range(size_t d) : dim(d) {} + + struct iterpoint { + size_t val, lim; + iterpoint(size_t v1, size_t v2) : val(v1), lim(v2) {} + + size_t operator*() const { return val; } + + iterpoint operator/(const iterpoint& rhs) const { + return iterpoint(val * rhs.lim + rhs.val, lim * rhs.lim); + } + }; + + struct iterator { + size_t val, lim; + iterator(size_t v1, size_t v2) : val(v1), lim(v2) {} + + bool operator!=(const iterator& rhs) const { return val != rhs.val; } + + void operator++() { ++val; } + + iterpoint operator*() const { return iterpoint(val, lim); } + }; + + iterator begin() { return iterator(0, dim); } + + iterator end() { return iterator(dim, dim); } +}; + +// Mutate ```nn.conv2d``` to ```nn.sparse_conv2d``` +class Conv2dToSparseConv2dMutator2 : public ExprRewriter { + public: + Conv2dToSparseConv2dMutator2(const String& layout, int kernel_size, int blockH, int blockW, + double sparse_thresh) + : sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")), + dev_cpu0_{DLDeviceType::kDLCPU, 0}, + layout_(layout), + kernel_size_(kernel_size), + blockH_(blockH), + blockW_(blockW), + sparse_thresh_(sparse_thresh) {} + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + // check op type & attrs + const auto pre_attrs = pre->attrs.as(); + if (!pre_attrs || pre_attrs->data_layout != layout_ || + pre_attrs->strides[0].as()->value != 1 || + pre_attrs->kernel_size[0].as()->value != kernel_size_) + return post; + // check constant weight + const auto pre_weight_node = pre->args[1].as(); + if (!pre_weight_node) return post; + + // check weight dtype & shape + auto&& pre_weight = pre_weight_node->data; + auto dtype = pre_weight.DataType(), itype = runtime::DataType::Int(32); + ICHECK(dtype.code() == DataType::kFloat && dtype.bits() == 32); // float32 only + auto pre_weight_shape = unpack_to_tuple<4>(pre_weight.Shape().data()); + int O, I, H, W; + if (layout_ == "NCHW") { + std::tie(O, I, H, W) = pre_weight_shape; + } else { // NHWC + std::tie(H, W, I, O) = pre_weight_shape; + } + int CO = O, CI = H * W * I; + + // copy to vector + std::vector pre_weight_data(CO * CI); + pre_weight.CopyToBytes(pre_weight_data.data(), pre_weight_data.size() * sizeof(float)); + if (layout_ == "NHWC") { + std::vector tmp(pre_weight_data.size()); + for (auto i : Range(CO)) + for (auto j : Range(CI)) tmp[*(i / j)] = pre_weight_data[*(j / i)]; + std::swap(tmp, pre_weight_data); + } + // convert to BSR + std::vector wdata, block(blockH_ * blockW_); + std::vector windices, windptr; + for (auto bh : Range(CO / blockH_)) { + windptr.push_back(windices.size()); + for (auto bw : Range(CI / blockW_)) { + int cntnnz = 0; + for (auto i : Range(blockH_)) + for (auto j : Range(blockW_)) { + auto tmp = pre_weight_data[*(bh / i / bw / j)]; + if (tmp) cntnnz++; + block[*(i / j)] = tmp; + } + if (cntnnz) { + wdata.insert(wdata.end(), block.begin(), block.end()); + windices.push_back(*bw); + } + } + } + windptr.push_back(windices.size()); + double sprate = 1 - 1.0 * wdata.size() / pre_weight_data.size(); + if (sprate < sparse_thresh_) return post; + + // constrct return data + int nnz = windices.size(); + auto weight_data = runtime::NDArray::Empty({nnz, blockH_, blockW_}, dtype, dev_cpu0_); + auto weight_indices = runtime::NDArray::Empty({nnz}, itype, dev_cpu0_); + auto weight_indptr = runtime::NDArray::Empty({CO / blockH_ + 1}, itype, dev_cpu0_); + weight_data.CopyFromBytes(wdata.data(), wdata.size() * sizeof(float)); + weight_indices.CopyFromBytes(windices.data(), windices.size() * sizeof(int32_t)); + weight_indptr.CopyFromBytes(windptr.data(), windptr.size() * sizeof(int32_t)); + + // construct return call + auto args = runtime::Array{post.as()->args[0], Constant(weight_data), + Constant(weight_indices), Constant(weight_indptr)}; + auto attrs = make_object(); + attrs->layout = layout_; + attrs->kernel_size = Array{kernel_size_, kernel_size_}; + return Call(sparse_conv2d_op_, args, Attrs(attrs)); + } + + private: + const Op& sparse_conv2d_op_; + DLDevice dev_cpu0_; + String layout_; + int kernel_size_, blockH_, blockW_; + double sparse_thresh_; +}; // class Conv2dToSparseConv2dMutator2 + +Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, int blockH, int blockW, + double sparse_thresh) { + auto rewriter = Conv2dToSparseConv2dMutator2(layout, kernel_size, blockH, blockW, sparse_thresh); return PostOrderRewrite(e, &rewriter); } namespace transform { +// Convert a model with seperate weight info (already sparsified). Pass Conv2dToSparse(const Array& weight_name, const Array>& weight_shape, - const String& layout) { + const String& layout, int kernel_size) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { // Remove FreeVar warnings - auto f0 = Downcast(Conv2dToSparse(f, weight_name, weight_shape, layout)); + auto f0 = + Downcast(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); Array sparse_params = FreeVars(f0); auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); Array params = FreeVars(f1); @@ -155,6 +304,20 @@ Pass Conv2dToSparse(const Array& weight_name, const Array pass_func = + [=](Function f, IRModule m, PassContext pc) { + auto f0 = Downcast( + Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh)); + return f0; + }; + return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2", {"DeadCodeElimination"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse2").set_body_typed(Conv2dToSparse2); + } // namespace transform } // namespace relay diff --git a/src/relay/transforms/defuse_ops.cc b/src/relay/transforms/defuse_ops.cc index 6abf4c31d359..d7a9bfde57c3 100644 --- a/src/relay/transforms/defuse_ops.cc +++ b/src/relay/transforms/defuse_ops.cc @@ -41,19 +41,13 @@ class DefuseOpsMutator : public ExprMutator { public: class FuncBodyMutator : public ExprMutator { public: - explicit FuncBodyMutator(const Array& args) : ExprMutator() { args_ = args; } - - Expr VisitExpr_(const VarNode* n) { - const std::string& name = n->name_hint(); - ICHECK(!name.empty() && (name[0] == 'p')); - std::string id_str = name.substr(1); - int id = std::stoi(id_str); - ICHECK(id >= 0 && size_t(id) < args_.size()); - return args_[id]; - } + explicit FuncBodyMutator(std::unordered_map args) + : ExprMutator(), name_to_args_(std::move(args)) {} + + Expr VisitExpr_(const VarNode* n) { return name_to_args_[n->name_hint()]; } private: - Array args_; + std::unordered_map name_to_args_; }; Expr VisitExpr_(const CallNode* n) { @@ -62,7 +56,15 @@ class DefuseOpsMutator : public ExprMutator { if (const auto* call = new_n.as()) { if (const auto* func = call->op.as()) { if (func->body->IsInstance()) { - return FuncBodyMutator(call->args).Mutate(func->body); + std::unordered_map name_to_args; + for (size_t i = 0; i < func->params.size(); ++i) { + const std::string& pname = func->params[i]->name_hint(); + ICHECK(name_to_args.cend() == name_to_args.find(pname)) + << "Found multiple parameters share the same variable name `" << pname + << "` which introduces uncertainty in DefuseOps pass"; + name_to_args[pname] = call->args[i]; + } + return FuncBodyMutator(std::move(name_to_args)).Mutate(func->body); } } } diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index f883b4113656..b5f434e74c43 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -23,10 +23,14 @@ * to actual integer operations. */ +#include #include #include #include +namespace tvm { +namespace relay { + /* Description of FakeQuantizationToInteger * * The purpose of this pass is to find regions of the graph that follow @@ -63,65 +67,6 @@ * rewritten subgraph and the processing continues */ -namespace tvm { -namespace relay { - -/*! - * \brief AffineType representation - * \sa AffineType - */ -class AffineTypeNode : public Object { - public: - /*! \brief The scale of this type */ - Expr scale; - /*! \brief The zero point of this type */ - Expr zero_point; - /*! \brief The data type of this type */ - DataType dtype; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("scale", &scale); - v->Visit("zero_point", &zero_point); - v->Visit("dtype", &dtype); - } - - bool SEqualReduce(const AffineTypeNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); - return equal(scale, other->scale) && equal(zero_point, other->zero_point) && - equal(dtype, other->dtype); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce->MarkGraphNode(); - hash_reduce(scale); - hash_reduce(zero_point); - hash_reduce(dtype); - } - - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; - static constexpr const char* _type_key = "AffineTypeNode"; - TVM_DECLARE_BASE_OBJECT_INFO(AffineTypeNode, Object); -}; - -/*! - * \brief Managed reference to AffineTypes. - * \sa AffineTypeNode - */ -class AffineType : public ObjectRef { - public: - TVM_DLL AffineType(Expr scale, Expr zero_point, DataType dtype) { - ObjectPtr n = make_object(); - n->scale = std::move(scale); - n->zero_point = std::move(zero_point); - n->dtype = std::move(dtype); - data_ = std::move(n); - } - TVM_DEFINE_OBJECT_REF_METHODS(AffineType, ObjectRef, AffineTypeNode); -}; - -TVM_REGISTER_NODE_TYPE(AffineTypeNode); - using ExprSet = std::unordered_set; using ExprMap = std::unordered_map; using AffineTypeMap = Map; @@ -147,8 +92,14 @@ class SubgraphExtractor : public ExprVisitor { } const AffineTypeMap GetAffineTypes() { return affine_types_; } void VisitExpr(const Expr& expr) override { + // When looking for fake quantized subgraphs, we only support data-flow regions of the graph, + // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we + // abort the rewrite. if (expr.as() == nullptr && expr.as() == nullptr && - expr.as() == nullptr) { + expr.as() == nullptr && expr.as() == nullptr && + expr.as() == nullptr) { + LOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside" + << " a fake quantize region, aborting this rewrite"; is_fake_quantized_ = false; } else { ExprVisitor::VisitExpr(expr); @@ -162,13 +113,14 @@ class SubgraphExtractor : public ExprVisitor { VisitExpr(call_node->args[0]); // Collect type of quantize ops affine_types_.Set(GetRef(call_node), - AffineType(call_node->args[1], call_node->args[2], - call_node->checked_type().as()->dtype)); + TensorAffineType(call_node->args[1], call_node->args[2], + call_node->checked_type().as()->dtype)); } else if (call_node->op == dequantize_op_) { // Collect type of dequantize ops - affine_types_.Set(GetRef(call_node), - AffineType(call_node->args[1], call_node->args[2], - call_node->args[0]->checked_type().as()->dtype)); + affine_types_.Set( + GetRef(call_node), + TensorAffineType(call_node->args[1], call_node->args[2], + call_node->args[0]->checked_type().as()->dtype)); } else { // run normally on everything else. ExprVisitor::VisitExpr_(call_node); @@ -225,19 +177,38 @@ class SubgraphMutator : public ExprMutator { } // Call the rewrite Array vals = fqfq[op](expr, affine_types_); - // Save teh outputs of the rewrite - ICHECK(vals.size() == 4) + // Save the outputs of the rewrite + ICHECK(vals.size() == 2) << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for " << AsText(op, false); out = Downcast(vals[0]); - affine_types_.Set(out, AffineType(Downcast(vals[1]), Downcast(vals[2]), - DataType(String2DLDataType(Downcast(vals[3]))))); + affine_types_.Set(out, Downcast(vals[1])); } else { ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node " << AsText(GetRef(call_node), false); } return out; } + + Expr VisitExpr_(const TupleNode* node) { + Expr expr = ExprMutator::VisitExpr_(node); + auto new_node = expr.as(); + Array types; + for (Expr field : new_node->fields) { + ICHECK(affine_types_[field].as()); + types.push_back(Downcast(affine_types_[field])); + } + affine_types_.Set(expr, TupleAffineType(types)); + return expr; + } + + Expr VisitExpr_(const TupleGetItemNode* node) { + Expr expr = ExprMutator::VisitExpr_(node); + auto tuple_type = affine_types_[expr.as()->tuple].as(); + affine_types_.Set(expr, tuple_type->types[node->index]); + return expr; + } + ExprSet subgraph_; AffineTypeMap affine_types_; AffineType out_type_; diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 57603035b848..d545518c1c3c 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -229,35 +229,16 @@ class ConstantFolder : public MixedModeMutator { } // Constant evaluate an expression. Expr ConstEvaluate(Expr expr) { - std::vector passes = {transform::FuseOps(0), transform::ToANormalForm(), - transform::InferType()}; - Function func; - if (expr.as()) { - func = Downcast(expr); - } else { - // TODO(@jroesch): fix this - func = Function(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {}); - } - auto mod = IRModule({}, module_->type_definitions, module_->Imports()); - auto global = GlobalVar("main"); - mod->Add(global, func); - auto seq = transform::Sequential(passes); - mod = seq(mod); - auto entry_func = Downcast(mod->Lookup("main")); - expr = expr.as() == nullptr ? entry_func->body : entry_func; - - using tvm::transform::PassContext; Device dev; dev.device_type = kDLCPU; dev.device_id = 0; Target target = Target("llvm"); - // use a fresh build context - // in case we are already in a build context. + + // use a fresh build context in case we are already in a build context. // needed for both execution and creation(due to JIT) - With fresh_build_ctx(PassContext::Create()); + With fresh_build_ctx(transform::PassContext::Create()); - FInterpreter executor = CreateInterpreter(mod, dev, target); - return ObjectToExpr(executor(expr)); + return ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), dev, target)); } // Evaluate a call to the shape_of operator for tensors with constant diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index a93532895b5a..7056dfe79fee 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -202,7 +202,7 @@ using FForwardRewrite = TypedPackedFunc Prepare(const Expr& body) { this->Update(body, NullValue()); @@ -585,15 +585,22 @@ RELAY_REGISTER_OP("nn.conv2d") Expr ForwardFoldScaleAxis(const Expr& data) { auto message = ForwardPrep().Prepare(data); - auto fcontext = [&](const Call& call) -> ObjectRef { - auto it = message.find(call.get()); - if (it != message.end()) { - return it->second; - } else { - return ObjectRef(nullptr); + for (const auto& m : message) { + if (m.second.defined()) { + // run optimization + auto fcontext = [&](const Call& call) -> ObjectRef { + auto it = message.find(call.get()); + if (it != message.end()) { + return it->second; + } else { + return ObjectRef(nullptr); + } + }; + return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext); } - }; - return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext); + } + // no messages - no optimization + return data; } //---------------------------------------- @@ -618,7 +625,7 @@ using FBackwardTransform = // Generic Visitors for FScaleAxisBackward //---------------------------------------------- -class BackwardPrep : private ExprVisitor { +class BackwardPrep : private MixedModeVisitor { public: // The message on each node. std::unordered_map Prepare(const Expr& body) { @@ -643,6 +650,14 @@ class BackwardPrep : private ExprVisitor { // We only allow propagation of scale backward // if the expression is only referred by a single parent. if (rit->second != 1) return; + Array in_messages = GetInMessages(call); + Message out_message = f(GetRef(call), in_messages); + if (out_message.defined()) { + message_[call] = out_message; + } + } + + Array GetInMessages(const CallNode* call) { Array in_messages; for (Expr arg : call->args) { auto it = message_.find(arg.get()); @@ -652,52 +667,34 @@ class BackwardPrep : private ExprVisitor { in_messages.push_back(NullValue()); } } - Message out_message = f(GetRef(call), in_messages); - if (out_message.defined()) { - message_[call] = out_message; - } + return in_messages; } }; -class BackwardTransformerNode : public Object, private ExprMutator { +/* + * Hybrid apporach is used with the transformation + * itself is recursive but the traversal is non-recursive + */ +class BackwardTransformerNode : public Object, private MixedModeMutator { public: + using MixedModeMutator::Mutate; // Run forward transform. Expr Fold(Expr expr) { message_ = BackwardPrep().Prepare(expr); - return this->Mutate(expr); - } - /*! - * \brief Transform the expr to consider the scaling. - * - * \param expr The input expression. - * \param axes The axes to scale. - * \param scale The scale applied to the axes. - * \return The result of transformation. - */ - Expr Transform(const Expr& expr, Message message, Expr scale) { - // NOTE: the result of Transform is memoized. - if (const CallNode* call_node = expr.as()) { - return Transform(call_node, message, scale); - } else { - ICHECK(!message.defined()) << "outstanding scale"; - return ExprMutator::VisitExpr(expr); + for (const auto& m : message_) { + if (m.second.defined()) { + // run optimization + return this->Mutate(expr); + } } + // no messages - no optimization + return expr; } + /*! - * \brief Normal way of mutating call node. - * \param call_node The call node to be mutated. - * \return the result of the call Mutation. + * \brief Transform the expr to consider the scaling. */ - Expr NormalCallTransform(const CallNode* call_node) { - const Call call = GetRef(call_node); - const auto it = memo_.find(call); - if (it != memo_.end()) { - return it->second; - } - Expr new_expr = ExprMutator::VisitExpr_(call_node); - memo_[call] = new_expr; - return new_expr; - } + Expr Transform(const Expr& expr, Message message, Expr scale); /*! * \brief Get the message propogated to the expr. * \param expr The expresison. @@ -719,11 +716,12 @@ class BackwardTransformerNode : public Object, private ExprMutator { // Valid axes on each node. std::unordered_map message_; // Override mutation of call. - Expr VisitExpr_(const CallNode* call_node) final { - return Transform(call_node, NullValue(), NullValue()); + Expr Rewrite_(const CallNode* call_node, const Expr& post) final { + return Transform(GetRef(call_node), NullValue(), NullValue()); } - // Transform of CallNode. - Expr Transform(const CallNode* call_node, Message message, Expr scale); + + public: + Expr NormalCallTransform(const CallNode* call_node) { return ExprMutator::VisitExpr_(call_node); } }; class BackwardTransformer : public ObjectRef { @@ -736,21 +734,39 @@ class BackwardTransformer : public ObjectRef { using ContainerType = BackwardTransformerNode; }; -Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message message, Expr scale) { - static const auto& ftransform = Op::GetAttrMap("FScaleAxisBackwardTransform"); - auto f = ftransform.get(call_node->op, nullptr); - if (f != nullptr) { +/*! + * \brief Transform the expr to consider the scaling. + * + * \param expr The input expression. + * \param message The axes to scale. + * \param scale The scale applied to the axes. + * \return The result of transformation. + */ +Expr BackwardTransformerNode::Transform(const Expr& expr, Message message, Expr scale) { + if (const CallNode* call_node = expr.as()) { + static const auto& ftransform = + Op::GetAttrMap("FScaleAxisBackwardTransform"); + auto f = ftransform.get(call_node->op, nullptr); const Call call = GetRef(call_node); - const auto it = memo_.find(call); - if (it != memo_.end()) { - return it->second; + // ignore if there is a message + if (!message.defined()) { + const auto it = memo_.find(call); + if (it != memo_.end()) { + return it->second; + } + } + Expr new_expr = NullValue(); + if (f != nullptr) { + new_expr = f(call, message, scale, GetRef(this)); + } else { + ICHECK(!message.defined()) << "outstanding scale"; + new_expr = NormalCallTransform(call.operator->()); } - Expr new_expr = f(GetRef(call_node), message, scale, GetRef(this)); memo_[call] = new_expr; return new_expr; } else { ICHECK(!message.defined()) << "outstanding scale"; - return NormalCallTransform(call_node); + return this->Mutate(expr); } } @@ -813,6 +829,7 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } + Message lhs_message = transformer->GetMessage(call->args[0]); Message rhs_message = transformer->GetMessage(call->args[1]); StructuralEqual equal; @@ -959,7 +976,9 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp } else { wscale = ReshapeToMatchAxis(scale, weight->type_as()->shape, {big_ko_axis, small_ko_axis}); - if (!wscale.defined()) return transformer->NormalCallTransform(call.operator->()); + if (!wscale.defined()) { + return transformer->NormalCallTransform(call.operator->()); + } } weight = Multiply(weight, wscale); return Call(call->op, {data, weight}, call->attrs, call->type_args); diff --git a/src/relay/transforms/infer_layout_utils.h b/src/relay/transforms/infer_layout_utils.h index 5aedb9ff75d4..76d6aa646f4c 100644 --- a/src/relay/transforms/infer_layout_utils.h +++ b/src/relay/transforms/infer_layout_utils.h @@ -97,7 +97,16 @@ class InferCorrectLayoutOutputNode : public Object { Array input_layouts; Array output_layouts; Attrs new_attrs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("input_layouts", &input_layouts); + v->Visit("output_layouts", &output_layouts); + v->Visit("new_attrs", &new_attrs); + } + TVM_DECLARE_BASE_OBJECT_INFO(InferCorrectLayoutOutputNode, Object); + + static constexpr const char* _type_key = "relay._transform.InferCorrectLayoutOutput"; }; class InferCorrectLayoutOutput : public ObjectRef { diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index b61567d0bae0..657e2c392455 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -41,7 +41,8 @@ #include #include -#include "../backend/compile_engine.h" +#include "../backend/te_compiler.h" +#include "../backend/te_compiler_cache.h" #include "../op/memory/memory.h" #include "../op/vm/vm.h" #include "./pass_utils.h" @@ -49,6 +50,7 @@ #include "pattern_utils.h" using namespace tvm::runtime; +using namespace tvm::relay::tec; namespace tvm { namespace relay { @@ -271,9 +273,11 @@ class DialectRewriter : public ExprMutator { Array EmitShapeFunc(LetList* scope, const Function& func, const std::vector& new_args) { Array shape_func_ins; - auto engine = CompileEngine::Global(); + + TECompiler compiler; + CCacheKey key(func, target_host_); - auto cfunc = engine->LowerShapeFunc(key); + auto cfunc = compiler->LowerShapeFunc(key); auto input_states = cfunc->shape_func_param_states; Array is_inputs; diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 9572faf08714..ccdd9c92cc27 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -526,6 +526,8 @@ bool StatefulOp(const Expr& e) { using FInterpreter = runtime::TypedPackedFunc; +Target CPUTarget() { return Target("llvm"); } + Device CPUDevice() { Device dev; dev.device_type = kDLCPU; @@ -533,17 +535,6 @@ Device CPUDevice() { return dev; } -FInterpreter CPUInterpreter() { - using tvm::transform::PassContext; - - Target target = Target("llvm"); - // use a fresh build context - // in case we are already in a build context. - With fresh_build_ctx(PassContext::Create()); - - return CreateInterpreter(IRModule(nullptr), CPUDevice(), target); -} - using FuncId = int; /*! @@ -904,13 +895,9 @@ class PartialEvaluator : public ExprFunctor // Constant evaluate an expression. PStatic ConstEvaluate(const Expr& expr, LetList* ll) { - std::vector passes = {transform::FuseOps(0), transform::InferType()}; - auto mod = IRModule::FromExpr(expr); - auto seq = transform::Sequential(passes); - mod = seq(mod); - auto entry_func = Downcast(mod->Lookup("main")); - auto fused_infered = expr.as() == nullptr ? entry_func->body : entry_func; - return Reify(executor_(fused_infered), ll); + // use a fresh build context in case we are already in a build context. + With fresh_build_ctx(transform::PassContext::Create()); + return Reify(Eval(expr, mod_->type_definitions, mod_->Imports(), CPUDevice(), CPUTarget()), ll); } Func ConstEvaluateFunc(const Expr& expr) { @@ -1137,7 +1124,6 @@ class PartialEvaluator : public ExprFunctor std::unordered_map fuel_map_; Store store_; Device device_ = CPUDevice(); - FInterpreter executor_ = CPUInterpreter(); }; /*! \brief Remap multiple Var sharing the same Id into the same Var. */ diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 68f31a17ab1b..b48fbe44bd11 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -113,12 +113,19 @@ struct RegionFuncMetadata { class Partitioner : public MixedModeMutator { public: explicit Partitioner(const IRModule& module) : module_(module) { + std::set func_names; for (auto f : module->functions) { GlobalVar f_var = f.first; BaseFunc f_func = f.second; + std::string f_name = f_var.as()->name_hint; + while (func_names.find(f_name) != func_names.end()) { + f_name += "_a"; + } + func_names.insert(f_name); // Creating regionset per function in the module. - auto region_set = AnnotatedRegionSet::Create(f_func, CompilerBeginOp(), CompilerEndOp()); + auto region_set = + AnnotatedRegionSet::Create(f_func, CompilerBeginOp(), CompilerEndOp(), f_name); regions_sets_[region_set] = f_func; } } @@ -301,7 +308,7 @@ class Partitioner : public MixedModeMutator { } std::string target = end_node->attrs.as()->compiler; - std::string name = target + "_" + std::to_string(region->GetID()); + std::string name = target + "_" + region->GetName() + "_" + std::to_string(region->GetID()); // Constant propagation if (!params_bind.empty()) { @@ -502,7 +509,7 @@ class NameMangleExtFuncs : public MixedModeMutator { // Walk the tree and mangle the functions. Then replace compiler functions // with mangled functions in the module - IRModule new_module; + IRModule new_module = IRModule({}, module_->type_definitions, module_->Imports()); for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc index 79157bba1918..d03fc1488aea 100644 --- a/src/relay/transforms/to_basic_block_normal_form.cc +++ b/src/relay/transforms/to_basic_block_normal_form.cc @@ -51,8 +51,11 @@ Expr ToBasicBlockNormalFormAux(const Expr& e) { IRModule ToBasicBlockNormalForm(const IRModule& mod) { DLOG(INFO) << "ToBBlock:" << std::endl << mod; + // Create a new module by shallow copy. + auto mod_ = IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map); + tvm::Map updates; - auto funcs = mod->functions; + auto funcs = mod_->functions; for (const auto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables"; if (const auto* n = it.second.as()) { @@ -63,12 +66,12 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) { } for (auto pair : updates) { - mod->Add(pair.first, pair.second, true); + mod_->Add(pair.first, pair.second, true); } - DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod; + DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod_; - return mod; + return mod_; } bool BasicBlockNormalFormCheck(const Expr& e) { diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc index 426d2f24ddf5..ff748b8826de 100644 --- a/src/runtime/contrib/miopen/miopen_utils.cc +++ b/src/runtime/contrib/miopen/miopen_utils.cc @@ -89,6 +89,10 @@ void ConvEntry::CleanWorkspace() { workspace_size = 0; } +SoftmaxEntry::SoftmaxEntry() { MIOPEN_CALL(miopenCreateTensorDescriptor(&shape_desc)); } + +SoftmaxEntry::~SoftmaxEntry() { MIOPEN_CALL(miopenDestroyTensorDescriptor(shape_desc)); } + } // namespace miopen } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/miopen/miopen_utils.h b/src/runtime/contrib/miopen/miopen_utils.h index d3a8c7b9ad64..76913696b0b9 100644 --- a/src/runtime/contrib/miopen/miopen_utils.h +++ b/src/runtime/contrib/miopen/miopen_utils.h @@ -62,11 +62,18 @@ struct ConvEntry { void CleanWorkspace(); }; // ConvThreadEntry +struct SoftmaxEntry { + miopenTensorDescriptor_t shape_desc; + SoftmaxEntry(); + ~SoftmaxEntry(); +}; // SoftmaxEntry + struct MIOpenThreadEntry { MIOpenThreadEntry(); ~MIOpenThreadEntry(); miopenHandle_t handle{nullptr}; ConvEntry conv_entry; + SoftmaxEntry softmax_entry; runtime::DeviceAPI* rocm_api{nullptr}; static MIOpenThreadEntry* ThreadLocal(); }; // MIOpenThreadEntry diff --git a/src/runtime/contrib/miopen/softmax.cc b/src/runtime/contrib/miopen/softmax.cc new file mode 100644 index 000000000000..5a0f24ed7a84 --- /dev/null +++ b/src/runtime/contrib/miopen/softmax.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/miopen/softmax.cc + * \brief Use external miopen softmax function + */ +#include +#include + +#include "miopen_utils.h" + +namespace tvm { +namespace contrib { +namespace miopen { + +using namespace runtime; + +void softmax_impl(TVMArgs args, TVMRetValue* ret, miopenSoftmaxAlgorithm_t alg) { + DLTensor* x = args[0]; + DLTensor* y = args[1]; + int axis = args[2]; + int ndim = x->ndim; + int64_t* shape = x->shape; + if (axis < 0) axis += ndim; + ICHECK(axis >= 0 && axis < ndim); + // just fp32 for now + ICHECK(TypeMatch(x->dtype, kDLFloat, 32)); + ICHECK(TypeMatch(y->dtype, kDLFloat, 32)); + + MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); + + miopenSoftmaxMode_t mode; + if (axis == ndim - 1) { + int64_t N = 1; + for (int i = 0; i < ndim - 1; ++i) { + N *= shape[i]; + } + mode = MIOPEN_SOFTMAX_MODE_INSTANCE; + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->softmax_entry.shape_desc, miopenFloat, + static_cast(N), static_cast(shape[ndim - 1]), + 1, 1)); + } else { + int64_t pre_axis_dim = 1; + int64_t post_axis_dim = 1; + for (int i = 0; i < ndim; ++i) { + if (i < axis) { + pre_axis_dim *= shape[i]; + } else if (i > axis) { + post_axis_dim *= shape[i]; + } + } + mode = MIOPEN_SOFTMAX_MODE_CHANNEL; + MIOPEN_CALL(miopenSet4dTensorDescriptor( + entry_ptr->softmax_entry.shape_desc, miopenFloat, static_cast(pre_axis_dim), + static_cast(shape[axis]), static_cast(post_axis_dim), 1)); + } + + const float alpha = 1.f; + const float beta = 0.f; + MIOPEN_CALL(miopenSoftmaxForward_V2(entry_ptr->handle, &alpha, + entry_ptr->softmax_entry.shape_desc, x->data, &beta, + entry_ptr->softmax_entry.shape_desc, y->data, alg, mode)); +} + +TVM_REGISTER_GLOBAL("tvm.contrib.miopen.softmax.forward") + .set_body([](TVMArgs args, TVMRetValue* ret) { + softmax_impl(args, ret, MIOPEN_SOFTMAX_ACCURATE); + }); + +TVM_REGISTER_GLOBAL("tvm.contrib.miopen.log_softmax.forward") + .set_body([](TVMArgs args, TVMRetValue* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_LOG); }); + +} // namespace miopen +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/papi/papi.cc b/src/runtime/contrib/papi/papi.cc new file mode 100644 index 000000000000..b9ba8f9984e9 --- /dev/null +++ b/src/runtime/contrib/papi/papi.cc @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { +namespace profiling { + +#define PAPI_CALL(func) \ + { \ + int e = (func); \ + if (e != PAPI_OK) { \ + LOG(FATAL) << "PAPIError: in function " #func " " << e << " " \ + << std::string(PAPI_strerror(e)); \ + } \ + } + +static const std::unordered_map> default_metric_names = { + {kDLCPU, + {"perf::CYCLES", "perf::STALLED-CYCLES-FRONTEND", "perf::STALLED-CYCLES-BACKEND", + "perf::INSTRUCTIONS", "perf::CACHE-MISSES"}}, + {kDLCUDA, {"cuda:::event:elapsed_cycles_sm:device=0"}}}; + +/*! \brief Object that holds the values of counters at the start of a function call. */ +struct PAPIEventSetNode : public Object { + /*! \brief The starting values of counters for all metrics of a specific device. */ + std::vector start_values; + /*! \brief The device these counters are for. */ + Device dev; + + explicit PAPIEventSetNode(std::vector start_values, Device dev) + : start_values(start_values), dev(dev) {} + + static constexpr const char* _type_key = "PAPIEventSetNode"; + TVM_DECLARE_FINAL_OBJECT_INFO(PAPIEventSetNode, Object); +}; + +/* Get the PAPI component id for the given device. + * \param dev The device to get the component for. + * \returns PAPI component id for the device. Returns -1 if the device is not + * supported by PAPI. + */ +int component_for_device(Device dev) { + std::string component_name; + switch (dev.device_type) { + case kDLCPU: + component_name = "perf_event"; + break; + case kDLCUDA: + component_name = "cuda"; + break; + case kDLROCM: + component_name = "rocm"; + break; + default: + LOG(WARNING) << "PAPI does not support device " << DeviceName(dev.device_type); + return -1; + } + int cidx = PAPI_get_component_index(component_name.c_str()); + if (cidx < 0) { + LOG(FATAL) << "Cannot find PAPI component \"" << component_name + << "\". Maybe you need to build PAPI with support for this component (use " + "`./configure --with-components=" + << component_name << "`)."; + } + return cidx; +} + +/*! \brief MetricCollectorNode for PAPI metrics. + * + * PAPI (Performance Application Programming Interface) collects metrics on a + * variety of platforms including cpu, cuda and rocm. + * + * PAPI is avaliable at https://bitbucket.org/icl/papi/src/master/. + */ +struct PAPIMetricCollectorNode final : public MetricCollectorNode { + /*! \brief Construct a metric collector that collects a specific set of metrics. + * + * \param metrics A mapping from a device type to the metrics that should be + * collected on that device. You can find the names of available metrics by + * running `papi_native_avail`. + */ + explicit PAPIMetricCollectorNode(Map> metrics) { + for (auto& p : metrics) { + papi_metric_names[p.first->device] = {}; + for (auto& metric : p.second) { + papi_metric_names[p.first->device].push_back(metric); + } + } + } + explicit PAPIMetricCollectorNode() {} + + /*! \brief Initialization call. + * \param devices The devices this collector will be running on + */ + void Init(Array devices) { + if (!PAPI_is_initialized()) { + if (sizeof(long_long) > sizeof(int64_t)) { + LOG(WARNING) << "PAPI's long_long is larger than int64_t. Overflow may occur when " + "reporting metrics."; + } + CHECK_EQ(PAPI_library_init(PAPI_VER_CURRENT), PAPI_VER_CURRENT) + << "Error while initializing PAPI"; + } + + // If no metrics were provided we use the default set. The names were not + // initialized in the constructor because we did not know which devices we + // were running on. + if (papi_metric_names.size() == 0) { + for (auto wrapped_device : devices) { + Device device = wrapped_device->device; + auto it = default_metric_names.find(device.device_type); + if (it != default_metric_names.end()) { + papi_metric_names[device] = it->second; + } + } + } + + // create event sets for each device + for (auto wrapped_device : devices) { + Device device = wrapped_device->device; + int cidx = component_for_device(device); + // unknown device, skipping + if (cidx < 0) { + continue; + } + + auto it = papi_metric_names.find(device); + // skip devices with no metrics defined + if (it == papi_metric_names.end() || it->second.size() == 0) { + continue; + } + auto& metric_names = it->second; + + const PAPI_component_info_t* component = PAPI_get_component_info(cidx); + if (component->disabled) { + std::string help_message = ""; + switch (device.device_type) { + case kDLCPU: + help_message = + "Try setting `sudo sh -c 'echo 1 >/proc/sys/kernel/perf_event_paranoid'`"; + break; + case kDLCUDA: + help_message = + "Try enabling gpu profiling with `modprobe nvidia " + "NVreg_RestrictProfilingToAdminUsers=0`. If that does not work, try adding " + "`options nvidia \"NVreg_RestrictProfilingToAdminUsers=0\"` to " + "`/etc/modprobe.d/nvidia-kernel-common.conf`."; + break; + default: + break; + } + LOG(WARNING) << "PAPI could not initialize counters for " << DeviceName(device.device_type) + << ": " << component->disabled_reason << "\n" + << help_message; + continue; + } + + int event_set = PAPI_NULL; + PAPI_CALL(PAPI_create_eventset(&event_set)); + PAPI_CALL(PAPI_assign_eventset_component(event_set, cidx)); + if (device.device_type == kDLCPU) { + // we set PAPI_INHERIT to make it so threads created after this inherit the event_set. + PAPI_option_t opt; + memset(&opt, 0x0, sizeof(PAPI_option_t)); + opt.inherit.inherit = PAPI_INHERIT_ALL; + opt.inherit.eventset = event_set; + PAPI_CALL(PAPI_set_opt(PAPI_INHERIT, &opt)); + } + + if (static_cast(metric_names.size()) > PAPI_num_cmp_hwctrs(cidx)) { + PAPI_CALL(PAPI_set_multiplex(event_set)); + } + + // add all the metrics + for (auto metric : metric_names) { + int e = PAPI_add_named_event(event_set, metric.c_str()); + if (e != PAPI_OK) { + LOG(FATAL) << "PAPIError: " << e << " " << std::string(PAPI_strerror(e)) << ": " << metric + << "."; + } + } + // Because we may have multiple calls in flight at the same time, we + // start all the timers when we initialize. Then we calculate the metrics + // counts for a call by comparing counter values at the start vs end of + // the call. + PAPI_CALL(PAPI_start(event_set)); + event_sets[device] = event_set; + } + } + /*! \brief Called right before a function call. Reads starting values of the + * measured metrics. + * + * \param dev The device the function will be run on. + * \returns A `PAPIEventSetNode` containing values for the counters at the + * start of the call. Passed to a corresponding `Stop` call. + */ + ObjectRef Start(Device dev) final { + // Record counter values at the start of the call, so we can calculate the + // metrics for the call by comparing the values at the end of the call. + auto it = event_sets.find(dev); + if (it != event_sets.end()) { + int event_set = it->second; + std::vector values(papi_metric_names[dev].size()); + PAPI_CALL(PAPI_read(event_set, values.data())); + return ObjectRef(make_object(values, dev)); + } else { + return ObjectRef(nullptr); + } + } + /*! \brief Called right after a function call. Reads ending values of the + * measured metrics. Computes the change in each metric from the + * corresponding `Start` call. + * + * \param obj `PAPIEventSetNode` created by a call to `Start`. + * \returns A mapping from metric name to value. + */ + Map Stop(ObjectRef obj) final { + const PAPIEventSetNode* event_set_node = obj.as(); + std::vector end_values(papi_metric_names[event_set_node->dev].size()); + PAPI_CALL(PAPI_read(event_sets[event_set_node->dev], end_values.data())); + std::unordered_map reported_metrics; + for (size_t i = 0; i < end_values.size(); i++) { + if (end_values[i] < event_set_node->start_values[i]) { + LOG(WARNING) << "Detected overflow when reading performance counter, setting value to -1."; + reported_metrics[papi_metric_names[event_set_node->dev][i]] = + ObjectRef(make_object(-1)); + } else { + reported_metrics[papi_metric_names[event_set_node->dev][i]] = + ObjectRef(make_object(end_values[i] - event_set_node->start_values[i])); + } + } + return reported_metrics; + } + + ~PAPIMetricCollectorNode() final { + for (auto p : event_sets) { + PAPI_CALL(PAPI_stop(p.second, NULL)); + PAPI_CALL(PAPI_cleanup_eventset(p.second)); + PAPI_CALL(PAPI_destroy_eventset(&p.second)); + } + } + + /*! \brief Device-specific event sets. Contains the running counters (the int values) for that + * device. */ + std::unordered_map event_sets; + /*! \brief Device-specific metric names. Order of names matches the order in the corresponding + * `event_set`. */ + std::unordered_map> papi_metric_names; + + static constexpr const char* _type_key = "runtime.profiling.PAPIMetricCollector"; + TVM_DECLARE_FINAL_OBJECT_INFO(PAPIMetricCollectorNode, MetricCollectorNode); +}; + +/*! \brief Wrapper for `PAPIMetricCollectorNode`. */ +class PAPIMetricCollector : public MetricCollector { + public: + explicit PAPIMetricCollector(Map> metrics) { + data_ = make_object(metrics); + } + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PAPIMetricCollector, MetricCollector, + PAPIMetricCollectorNode); +}; + +MetricCollector CreatePAPIMetricCollector(Map> metrics) { + return PAPIMetricCollector(metrics); +} + +TVM_REGISTER_OBJECT_TYPE(PAPIEventSetNode); +TVM_REGISTER_OBJECT_TYPE(PAPIMetricCollectorNode); + +TVM_REGISTER_GLOBAL("runtime.profiling.PAPIMetricCollector") + .set_body_typed([](Map> metrics) { + return PAPIMetricCollector(metrics); + }); + +} // namespace profiling +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 66f36ffa50d6..cee723c510f4 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -27,6 +27,8 @@ #include #include +#include "../../../../3rdparty/compiler-rt/builtin_fp16.h" + namespace tvm { namespace contrib { @@ -42,6 +44,24 @@ bool CompareDescend(const std::pair& lhs, const std::pair rhs.second; } +struct float16 { + uint16_t bits; + float to_float() const { + return __extendXfYf2__(bits); + } +}; + +template <> +bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { + return lhs.second.to_float() < rhs.second.to_float(); +} + +template <> +bool CompareDescend(const std::pair& lhs, + const std::pair& rhs) { + return lhs.second.to_float() > rhs.second.to_float(); +} + // Argsort implemented C library sort for nms. // Return indices of sorted tensor. // By default, the last axis will be used to sort. @@ -125,7 +145,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TV }); template -void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, bool is_argsort) { +void sort_impl( + DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, + std::function&)> epilogue) { auto data_ptr = static_cast(input->data); auto out_ptr = static_cast(output->data); std::vector> sorter; @@ -153,14 +175,8 @@ void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, } else { std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); } - if (is_argsort) { - for (int64_t k = 0; k < input->shape[axis]; ++k) { - out_ptr[base_idx + k * axis_mul_after] = static_cast(sorter[k].first); - } - } else { - for (int64_t k = 0; k < input->shape[axis]; ++k) { - out_ptr[base_idx + k * axis_mul_after] = static_cast(sorter[k].second); - } + for (int64_t k = 0; k < input->shape[axis]; ++k) { + epilogue(out_ptr, base_idx + k * axis_mul_after, sorter[k]); } } } @@ -168,12 +184,20 @@ void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, template void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { - return sort_impl(input, output, axis, is_ascend, true); + return sort_impl( + input, output, axis, is_ascend, + [](OutType* out_ptr, size_t index, const std::pair& sort_pair) { + out_ptr[index] = static_cast(sort_pair.first); + }); } template void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { - return sort_impl(input, output, axis, is_ascend, false); + return sort_impl( + input, output, axis, is_ascend, + [](DataType* out_ptr, size_t index, const std::pair& sort_pair) { + out_ptr[index] = sort_pair.second; + }); } // Argsort implemented C library sort. @@ -254,6 +278,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRet } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } + } else if (data_dtype == "float16") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } } else { LOG(FATAL) << "Unsupported input dtype: " << data_dtype; } @@ -295,6 +331,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort").set_body([](TVMArgs args, TVMRetVal sort(input, output, axis, is_ascend); } else if (data_dtype == "int64") { sort(input, output, axis, is_ascend); + } else if (data_dtype == "float16") { + sort(input, output, axis, is_ascend); } else { LOG(FATAL) << "Unsupported input dtype: " << data_dtype; } @@ -432,6 +470,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetVal } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } + } else if (data_dtype == "float16") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } } else { LOG(FATAL) << "Unsupported input dtype: " << data_dtype; } diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index d8182b0e8378..08ac2ae0ec45 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -163,10 +163,19 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { auto profile = builder_->createOptimizationProfile(); for (int i = 0; i < network_->getNbInputs(); ++i) { auto name = network_->getInput(i)->getName(); - auto dims = network_->getInput(i)->getDimensions(); - profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, dims); + const uint32_t entry_id = entry_id_map_[name]; + std::vector shape(data_entry_[entry_id]->shape, + data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim); + auto dims = VectorToTrtDims(shape); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, dims); profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, dims); + // Set minimum batch size to 1 when dynamic batching is used. + if (network_->getInput(i)->getDimensions().nbDims >= 1 && + network_->getInput(i)->getDimensions().d[0] == -1) { + dims.d[0] = 1; + } + profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, dims); } config_->addOptimizationProfile(profile); } diff --git a/src/runtime/contrib/tensorrt/tensorrt_logger.h b/src/runtime/contrib/tensorrt/tensorrt_logger.h index eb0164210dbb..5406f4c57d66 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_logger.h +++ b/src/runtime/contrib/tensorrt/tensorrt_logger.h @@ -39,7 +39,7 @@ class TensorRTLogger : public nvinfer1::ILogger { public: TensorRTLogger() : TensorRTLogger(Severity::kWARNING) {} explicit TensorRTLogger(Severity severity) : reportable_severity(severity) {} - void log(Severity severity, const char* msg) override { + void log(Severity severity, const char* msg) noexcept override { // suppress messages with severity enum value greater than the reportable if (severity > reportable_severity) return; diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 7197172d73db..94bbae1559d9 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -835,7 +835,7 @@ class SplitOpConverter : public TensorRTOpConverter { std::vector start(input_dims.size(), 0); std::vector size(input_dims.begin(), input_dims.end()); std::vector strides(input_dims.size(), 1); - for (int i = 0; i < split_sizes.size(); ++i) { + for (size_t i = 0; i < split_sizes.size(); ++i) { start[axis] = split_starts[i]; size[axis] = split_sizes[i]; auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start), @@ -1174,9 +1174,14 @@ class BatchMatmulOpConverter : public TensorRTOpConverter { BatchMatmulOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {} void Convert(TensorRTOpConverterParams* params) const { + auto transa = std::stoi(params->node.GetAttr>("transpose_a")[0]); + auto transb = std::stoi(params->node.GetAttr>("transpose_b")[0]); + nvinfer1::MatrixOperation trt_transa = + transa ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE; + nvinfer1::MatrixOperation trt_transb = + transb ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE; nvinfer1::IMatrixMultiplyLayer* matmul_layer = params->network->addMatrixMultiply( - *params->inputs.at(0).tensor, nvinfer1::MatrixOperation::kNONE, - *params->inputs.at(1).tensor, nvinfer1::MatrixOperation::kTRANSPOSE); + *params->inputs.at(0).tensor, trt_transa, *params->inputs.at(1).tensor, trt_transb); ICHECK(matmul_layer != nullptr); params->outputs.push_back(matmul_layer->getOutput(0)); } diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index 6358e59ce3bc..5562f853383c 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -140,6 +140,12 @@ class TensorRTRuntime : public JSONRuntimeBase { const std::string name = nodes_[nid].GetOpName() + "_" + std::to_string(j); int binding_index = engine->getBindingIndex(name.c_str()); ICHECK_NE(binding_index, -1); + if (!use_implicit_batch_) { + std::vector shape(data_entry_[eid]->shape, + data_entry_[eid]->shape + data_entry_[eid]->ndim); + auto dims = VectorToTrtDims(shape); + ICHECK(context->setBindingDimensions(binding_index, dims)); + } if (data_entry_[eid]->device.device_type == kDLCUDA) { bindings[binding_index] = data_entry_[eid]->data; } else { @@ -300,7 +306,7 @@ class TensorRTRuntime : public JSONRuntimeBase { helper.DeclareField("inputs", &engine_and_context.inputs); helper.DeclareField("outputs", &engine_and_context.outputs); helper.ReadAllFields(&reader); - const int batch_size = 1; + const int batch_size = GetBatchSize(); trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = engine_and_context; return true; } diff --git a/src/runtime/crt/Makefile b/src/runtime/crt/Makefile index f458a2f08002..99efdda62ee9 100644 --- a/src/runtime/crt/Makefile +++ b/src/runtime/crt/Makefile @@ -68,7 +68,6 @@ endef LIBS = \ src/runtime/crt/common \ src/runtime/crt/graph_executor \ - src/runtime/crt/aot_executor \ src/runtime/crt/graph_executor_module \ src/runtime/crt/memory \ src/runtime/crt/microtvm_rpc_common \ diff --git a/src/runtime/crt/aot_executor/aot_executor.c b/src/runtime/crt/aot_executor/aot_executor.c deleted file mode 100644 index d34639bc30de..000000000000 --- a/src/runtime/crt/aot_executor/aot_executor.c +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Main entry point for - * \param model Model descriptor structure to reference for runtime information - * \param inputs Pointer to input pointer(s) - * \param outputs Pointer to output pointer(s) - * \param context Context information to be passed through to operators - * \return tvm_status_t containing success or errors from the model run - */ -#include -#include - -tvm_crt_error_t tvm_runtime_run(const tvm_model_t* model, void** inputs, void** outputs) { - static DLDevice fake_device = {kDLCPU, 0}; - static int64_t fake_dims = 0; - static int64_t fake_shape = {0}; - - DLTensor tensors[model->num_input_tensors + model->num_output_tensors]; // NOLINT - TVMValue tvm_values[model->num_input_tensors + model->num_output_tensors]; // NOLINT - int32_t tvm_typeids[model->num_input_tensors + model->num_output_tensors]; // NOLINT - - for (size_t i = 0; i < model->num_input_tensors; i++) { - tensors[i].device = fake_device; - tensors[i].data = inputs[i]; - tensors[i].shape = &fake_shape; - tensors[i].ndim = fake_dims; - tensors[i].byte_offset = 0; - tensors[i].strides = NULL; - tvm_values[i].v_handle = &tensors[i]; - } - - for (size_t i = 0; i < model->num_output_tensors; i++) { - size_t j = model->num_input_tensors + i; - tensors[j].device = fake_device; - tensors[j].data = outputs[i]; - tensors[j].shape = &fake_shape; - tensors[j].ndim = fake_dims; - tensors[j].byte_offset = 0; - tensors[j].strides = NULL; - tvm_values[j].v_handle = &tensors[j]; - } - - return (tvm_crt_error_t)model->run_func(tvm_values, tvm_typeids, 0, NULL, 0, NULL); -} diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index f34bbd4fec95..04721ee6d705 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -398,20 +398,13 @@ int RPCGetCRTMaxPacketSize(TVMValue* args, int* type_codes, int num_args, TVMVal tvm_crt_error_t TVMInitializeRuntime() { int idx = 0; tvm_crt_error_t error = kTvmErrorNoError; - void* func_registry_memory = NULL; DLDevice dev = {kDLCPU, 0}; - error = TVMPlatformMemoryAllocate(TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES, dev, - &func_registry_memory); - if (error != kTvmErrorNoError) { - return error; - } void* registry_backing_memory; error = TVMPlatformMemoryAllocate(TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES, dev, ®istry_backing_memory); if (error != kTvmErrorNoError) { - TVMPlatformMemoryFree(func_registry_memory, dev); return error; } @@ -441,7 +434,6 @@ tvm_crt_error_t TVMInitializeRuntime() { if (error != kTvmErrorNoError) { TVMPlatformMemoryFree(registry_backing_memory, dev); - TVMPlatformMemoryFree(func_registry_memory, dev); } return error; diff --git a/src/runtime/crt/common/func_registry.c b/src/runtime/crt/common/func_registry.c index cc1ba56c8cdc..116a5c496f1b 100644 --- a/src/runtime/crt/common/func_registry.c +++ b/src/runtime/crt/common/func_registry.c @@ -99,7 +99,6 @@ tvm_crt_error_t TVMMutableFuncRegistry_Create(TVMMutableFuncRegistry* reg, uint8 return kTvmErrorBufferTooSmall; } - memset(reg, 0, sizeof(*reg)); reg->registry.names = (const char*)buffer; buffer[0] = 0; // number of functions present in buffer. buffer[1] = 0; // end of names list marker. diff --git a/src/runtime/crt/crt_config-template.h b/src/runtime/crt/crt_config-template.h index 907559421e5d..7949aea6f171 100644 --- a/src/runtime/crt/crt_config-template.h +++ b/src/runtime/crt/crt_config-template.h @@ -24,6 +24,12 @@ #ifndef TVM_RUNTIME_CRT_CRT_CONFIG_TEMPLATE_H_ #define TVM_RUNTIME_CRT_CRT_CONFIG_TEMPLATE_H_ +/*! Log level of the CRT runtime */ +#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG + +/*! Support low-level debugging in MISRA-C runtime */ +#define TVM_CRT_DEBUG 0 + /*! Maximum supported dimension in NDArray */ #define TVM_CRT_MAX_NDIM 6 @@ -31,7 +37,7 @@ #define TVM_CRT_MAX_ARGS 10 /*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200 +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 250 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 @@ -48,9 +54,6 @@ /*! \brief Maximum length of a PackedFunc function name. */ #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 -/*! \brief DLDataType for the return value from strlen */ -#define TVM_CRT_STRLEN_DLTYPE 10 - /*! \brief Enable checks to enforce the stack allocator with a FIFO ordering. Off by default */ // #define TVM_CRT_STACK_ALLOCATOR_ENABLE_FIFO_CHECK diff --git a/src/runtime/crt/graph_executor/graph_executor.c b/src/runtime/crt/graph_executor/graph_executor.c index bf64096441be..c2e465361651 100644 --- a/src/runtime/crt/graph_executor/graph_executor.c +++ b/src/runtime/crt/graph_executor/graph_executor.c @@ -100,8 +100,10 @@ void TVMGraphExecutorNode_LoadAttrs(TVMGraphExecutorNode* node, JSONReader* read } else if (!strcmp(key, "flatten_data")) { param->flatten_data = strtoul(value, 0, 10); bitmask |= 8; +#if TVM_CRT_DEBUG } else { - fprintf(stderr, "do not support key %s", key); + printf("do not support key %s", key); +#endif // TVM_CRT_DEBUG } } if (bitmask != (1 | 2 | 4 | 8)) { @@ -130,7 +132,7 @@ int TVMGraphExecutorNode_Load(TVMGraphExecutorNode* node, JSONReader* reader) { } bitmask |= 2; } else if (!strcmp(key, "inputs")) { - size_t count = node->inputs_count; + size_t count = 0; reader->BeginArray(reader); size_t num_inputs = 0; if (reader->ArrayLength(reader, &num_inputs) != 0) { @@ -265,8 +267,8 @@ int TVMGraphExecutorGraphAttr_Load(TVMGraphExecutorGraphAttr* attr, JSONReader* break; } DLDevice dev = {kDLCPU, 0}; - tvm_crt_error_t err = - TVMPlatformMemoryAllocate(TVM_CRT_STRLEN_DLTYPE * num_items, dev, (void**)&attr->dltype); + tvm_crt_error_t err = TVMPlatformMemoryAllocate(TVM_CRT_MAX_STRLEN_DLTYPE * num_items, dev, + (void**)&attr->dltype); if (err != kTvmErrorNoError) { fprintf(stderr, "memory allocate error: %08x", err); return -1; @@ -278,8 +280,8 @@ int TVMGraphExecutorGraphAttr_Load(TVMGraphExecutorGraphAttr* attr, JSONReader* status = -1; return status; } - status = reader->ReadString(reader, attr->dltype + dltype_count * TVM_CRT_STRLEN_DLTYPE, - TVM_CRT_STRLEN_DLTYPE); + status = reader->ReadString(reader, attr->dltype + dltype_count * TVM_CRT_MAX_STRLEN_DLTYPE, + TVM_CRT_MAX_STRLEN_DLTYPE); if (status != 0) { fprintf(stderr, "error reading dltype array item"); break; @@ -792,14 +794,14 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl // read names char* names = NULL; DLDevice dev = {kDLCPU, 0}; - tvm_crt_error_t err = - TVMPlatformMemoryAllocate(TVM_CRT_STRLEN_NAME * executor->nodes_count, dev, (void**)&names); + tvm_crt_error_t err = TVMPlatformMemoryAllocate( + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * executor->nodes_count, dev, (void**)&names); if (err != kTvmErrorNoError) { fprintf(stderr, "memory allocate error: %08x", err); status = -1; return status; } - memset(names, 0, TVM_CRT_STRLEN_NAME * executor->nodes_count); + memset(names, 0, TVM_CRT_MAX_STRLEN_FUNCTION_NAME * executor->nodes_count); uint64_t names_count; int idx; memcpy(&names_count, bptr, sizeof(names_count)); @@ -808,11 +810,11 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl uint64_t name_length; memcpy(&name_length, bptr, sizeof(name_length)); bptr += sizeof(name_length); - if (name_length >= TVM_CRT_STRLEN_NAME) { + if (name_length >= TVM_CRT_MAX_STRLEN_FUNCTION_NAME) { fprintf(stderr, "Error: function name longer than expected.\n"); status = -1; } - memcpy(names + TVM_CRT_STRLEN_NAME * idx, bptr, name_length); + memcpy(names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx, bptr, name_length); bptr += name_length; } @@ -827,9 +829,10 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl } for (idx = 0; idx < size; idx++) { - int32_t in_idx = TVMGraphExecutor_GetInputIndex(executor, names + TVM_CRT_STRLEN_NAME * idx); + int32_t in_idx = + TVMGraphExecutor_GetInputIndex(executor, names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx); CHECK_GT(in_idx, 0, "Found param for non-existent input: %s\n", - names + TVM_CRT_STRLEN_NAME * idx); + names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx); uint32_t eid = TVMGraphExecutor_GetEntryId(executor, executor->input_nodes[in_idx], 0); if (!(eid < executor->data_entry_count)) { fprintf(stderr, "`entry_id`=%d is greater than expected(%d).\n", eid, @@ -855,7 +858,7 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl #if TVM_CRT_DEBUG TVMNDArray* entry = &(executor->data_entry[eid]); printf("loading: param %s loaded, in_idx=%d, eid=%d, ndim=%d, data[0]=%f\n", - names + TVM_CRT_STRLEN_NAME * idx, in_idx, eid, entry->dl_tensor.ndim, + names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx, in_idx, eid, entry->dl_tensor.ndim, ((float*)entry->dl_tensor.data)[0]); // NOLINT(*) #endif // TVM_CRT_DEBUG } @@ -937,7 +940,7 @@ int TVMGraphExecutor_SetupStorage(TVMGraphExecutor* executor) { return -1; } for (idx = 0; idx < attrs->dltype_count; idx++) { - vtype[idx] = String2DLDataType(attrs->dltype + idx * TVM_CRT_STRLEN_DLTYPE); + vtype[idx] = String2DLDataType(attrs->dltype + idx * TVM_CRT_MAX_STRLEN_DLTYPE); } // Size and device type of each storage pool entry. @@ -1088,9 +1091,10 @@ int TVMGraphExecutor_SetupOpExecs(TVMGraphExecutor* executor) { printf("tvm_op: creating %s with node_id=%d\n", inode->param.func_name, nid); #endif // TVM_CRT_DEBUG TVMPackedFunc pf; - TVMGraphExecutor_CreateTVMOp(executor, &(inode->param), args, args_count, inode->inputs_count, - &pf); + TVMGraphExecutor_CreateTVMOp(executor, &(inode->param), args, args_count, &pf); executor->op_execs[nid] = pf; + } else { + memset(&executor->op_execs[nid], 0, sizeof(TVMPackedFunc)); } } return status; @@ -1109,7 +1113,7 @@ typedef struct TVMOpArgs { int32_t TVMGraphExecutor_CreateTVMOp(TVMGraphExecutor* executor, const TVMOpParam* param, DLTensorPtr* args, const uint32_t args_count, - uint32_t num_inputs, TVMPackedFunc* pf) { + TVMPackedFunc* pf) { int status = 0; uint32_t idx; TVMOpArgs arg_ptr; diff --git a/src/runtime/crt/host/Makefile b/src/runtime/crt/host/Makefile new file mode 100644 index 000000000000..efed3c438699 --- /dev/null +++ b/src/runtime/crt/host/Makefile @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +INCLUDES ?= -isystem crt/include -Icrt_config +CFLAGS ?= -Werror -Wall +CXXFLAGS ?= -Werror -Wall -std=c++11 +LDFLAGS ?= -Werror -Wall + +# Codegen produces spurious lines like: int32_t arg2_code = ((int32_t*)arg_type_ids)[(2)]; +MODEL_CFLAGS ?= -Wno-error=unused-variable + +AR ?= ${PREFIX}ar +CC ?= ${PREFIX}gcc +CXX ?= ${PREFIX}g++ +RANLIB ?= ${PREFIX}ranlib + +QUIET ?= @ + +PWD = $(shell pwd) +BUILD_DIR = build +CRT_LIB_NAMES = microtvm_rpc_server microtvm_rpc_common graph_executor graph_executor_module common memory +CRT_LIBS = $(patsubst %, $(BUILD_DIR)/crt/lib%.a, $(CRT_LIB_NAMES)) + +CRT_INCLUDES = $(glob crt/include/**) + +$(BUILD_DIR)/crt/lib%.a: $(glob crt/src/runtime/%/*.c) + ${QUIET}cd crt && $(MAKE) \ + BUILD_DIR=../$(BUILD_DIR)/crt \ + CRT_CONFIG=$(PWD)/crt_config/crt_config.h \ + EXTRA_CFLAGS="$(CFLAGS)" \ + EXTRA_CXXFLAGS="$(CXXFLAGS)" \ + EXTRA_LDFLAGS="$(EXTRA_LDFLAGS)" \ + $(patsubst $(BUILD_DIR)/crt/lib%.a,%,$@) + +crt: $(CRT_LIBS) +.PHONY: crt + +# Compile codegen files +$(BUILD_DIR)/model/codegen/host/%.o: model/codegen/host/%.c + ${QUIET}mkdir -p $(dir $@) + ${QUIET}$(CC) $(INCLUDES) $(CFLAGS) $(MODEL_CFLAGS) -c -o "$@" "$<" + +MODEL_LIBS = \ + $(patsubst model/codegen/host/src/%.c, $(BUILD_DIR)/model/codegen/host/src/%.o, $(wildcard model/codegen/host/src/*.c)) \ + $(wildcard model/codegen/host/lib/*.o) + +# Compile src/ files +build/%.o: src/%.cc + ${QUIET}mkdir -p $(dir $@) + ${QUIET}$(CXX) $(INCLUDES) $(CXXFLAGS) -c -o "$@" "$<" + +SRCS = $(wildcard src/*.cc) +OBJS = $(patsubst src/%.cc,build/%.o,$(SRCS)) + +build/main: ${OBJS} ${MODEL_LIBS} ${CRT_LIBS} + ${QUIET}mkdir -p $(dir $@) + ${QUIET}$(CXX) $(LDFLAGS) -o "$@" $^ + +all: build/main +.PHONY = all + +.DEFAULT_GOAL = all diff --git a/src/runtime/crt/host/microtvm_api_server.py b/src/runtime/crt/host/microtvm_api_server.py new file mode 100644 index 000000000000..5f9019817e82 --- /dev/null +++ b/src/runtime/crt/host/microtvm_api_server.py @@ -0,0 +1,200 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import fcntl +import os +import os.path +import pathlib +import select +import shutil +import subprocess +import tarfile +import time +from tvm.micro.project_api import server + + +PROJECT_DIR = pathlib.Path(os.path.dirname(__file__) or os.path.getcwd()) + + +MODEL_LIBRARY_FORMAT_RELPATH = "model.tar" + + +IS_TEMPLATE = not os.path.exists(os.path.join(PROJECT_DIR, MODEL_LIBRARY_FORMAT_RELPATH)) + + +class Handler(server.ProjectAPIHandler): + + BUILD_TARGET = "build/main" + + def __init__(self): + super(Handler, self).__init__() + self._proc = None + + def server_info_query(self, tvm_version): + return server.ServerInfo( + platform_name="host", + is_template=IS_TEMPLATE, + model_library_format_path="" + if IS_TEMPLATE + else PROJECT_DIR / MODEL_LIBRARY_FORMAT_RELPATH, + project_options=[server.ProjectOption("verbose", help="Run make with verbose output")], + ) + + # These files and directories will be recursively copied into generated projects from the CRT. + CRT_COPY_ITEMS = ("include", "Makefile", "src") + + # The build target given to make + BUILD_TARGET = "build/main" + + def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): + # Make project directory. + project_dir.mkdir(parents=True) + + # Copy ourselves to the generated project. TVM may perform further build steps on the generated project + # by launching the copy. + shutil.copy2(__file__, project_dir / os.path.basename(__file__)) + + # Place Model Library Format tarball in the special location, which this script uses to decide + # whether it's being invoked in a template or generated project. + project_model_library_format_path = project_dir / MODEL_LIBRARY_FORMAT_RELPATH + shutil.copy2(model_library_format_path, project_model_library_format_path) + + # Extract Model Library Format tarball.into /model. + extract_path = project_dir / project_model_library_format_path.stem + with tarfile.TarFile(project_model_library_format_path) as tf: + os.makedirs(extract_path) + tf.extractall(path=extract_path) + + # Populate CRT. + crt_path = project_dir / "crt" + os.mkdir(crt_path) + for item in self.CRT_COPY_ITEMS: + src_path = standalone_crt_dir / item + dst_path = crt_path / item + if os.path.isdir(src_path): + shutil.copytree(src_path, dst_path) + else: + shutil.copy2(src_path, dst_path) + + # Populate Makefile. + shutil.copy2(pathlib.Path(__file__).parent / "Makefile", project_dir / "Makefile") + + # Populate crt-config.h + crt_config_dir = project_dir / "crt_config" + crt_config_dir.mkdir() + shutil.copy2( + os.path.join(os.path.dirname(__file__), "..", "crt_config-template.h"), + os.path.join(crt_config_dir, "crt_config.h"), + ) + + # Populate src/ + src_dir = os.path.join(project_dir, "src") + os.mkdir(src_dir) + shutil.copy2( + os.path.join(os.path.dirname(__file__), "main.cc"), os.path.join(src_dir, "main.cc") + ) + + def build(self, options): + args = ["make"] + if options.get("verbose"): + args.append("QUIET=") + + args.append(self.BUILD_TARGET) + + subprocess.check_call(args, cwd=PROJECT_DIR) + + def flash(self, options): + pass # Flashing does nothing on host. + + def _set_nonblock(self, fd): + flag = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flag | os.O_NONBLOCK) + new_flag = fcntl.fcntl(fd, fcntl.F_GETFL) + assert (new_flag & os.O_NONBLOCK) != 0, "Cannot set file descriptor {fd} to non-blocking" + + def open_transport(self, options): + self._proc = subprocess.Popen( + [self.BUILD_TARGET], stdin=subprocess.PIPE, stdout=subprocess.PIPE, bufsize=0 + ) + self._set_nonblock(self._proc.stdin.fileno()) + self._set_nonblock(self._proc.stdout.fileno()) + return server.TransportTimeouts( + session_start_retry_timeout_sec=0, + session_start_timeout_sec=0, + session_established_timeout_sec=0, + ) + + def close_transport(self): + if self._proc is not None: + proc = self._proc + self._proc = None + proc.terminate() + proc.wait() + + def _await_ready(self, rlist, wlist, timeout_sec=None, end_time=None): + if timeout_sec is None and end_time is not None: + timeout_sec = max(0, end_time - time.monotonic()) + + rlist, wlist, xlist = select.select(rlist, wlist, rlist + wlist, timeout_sec) + if not rlist and not wlist and not xlist: + raise server.IoTimeoutError() + + return True + + def read_transport(self, n, timeout_sec): + if self._proc is None: + raise server.TransportClosedError() + + fd = self._proc.stdout.fileno() + end_time = None if timeout_sec is None else time.monotonic() + timeout_sec + + try: + self._await_ready([fd], [], end_time=end_time) + to_return = os.read(fd, n) + except BrokenPipeError: + to_return = 0 + + if not to_return: + self.disconnect_transport() + raise server.TransportClosedError() + + return to_return + + def write_transport(self, data, timeout_sec): + if self._proc is None: + raise server.TransportClosedError() + + fd = self._proc.stdin.fileno() + end_time = None if timeout_sec is None else time.monotonic() + timeout_sec + + data_len = len(data) + while data: + self._await_ready([], [fd], end_time=end_time) + try: + num_written = os.write(fd, data) + except BrokenPipeError: + num_written = 0 + + if not num_written: + self.disconnect_transport() + raise server.TransportClosedError() + + data = data[num_written:] + + +if __name__ == "__main__": + server.main(Handler()) diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/aot_executor/aot_executor.h b/src/runtime/crt/include/tvm/runtime/crt/internal/aot_executor/aot_executor.h deleted file mode 100644 index edfd0ecd4f54..000000000000 --- a/src/runtime/crt/include/tvm/runtime/crt/internal/aot_executor/aot_executor.h +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief TVM Executor for the Ahead-of-Time Runtime - * - * AOT models are described by the TVM model descriptor format - * which can be passed to tvm_runtime_run. These descriptors will be - * generated by the AOT compilation process. This can optionally be - * augmented with platform specific context to be passed to the TVM - * operators. - * - * Example: - * extern tvm_model_t my_network; - * int main() { - * void* data = get_data(); - * void* output[4] = {0, 0, 0, 0}; - * void* inputs = {data}; - * void* outputs = {output}; - * tvm_context_t my_context = { - * .driver = ...; - * }; - * tvm_runtime_run( - * &my_network, - * inputs, - * outputs - * &my_context - * ); - * return 0; - * } - */ - -#ifndef TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_AOT_EXECUTOR_AOT_EXECUTOR_H_ -#define TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_AOT_EXECUTOR_AOT_EXECUTOR_H_ - -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -/*! - * \brief TVM Model descriptor to describe the - * model to the runtime. - */ -typedef struct { - size_t num_input_tensors; /** Number of expected input tensors */ - size_t num_output_tensors; /** Number of expected output tensors */ - TVMBackendPackedCFunc run_func; /** Generated model function, called through tvm_runtime_run */ -} tvm_model_t; - -/*! - * \brief Main entry point to execute the AOT runner function - * \param model Model descriptor structure to reference for runtime information - * \param inputs Pointer to input pointer(s) - * \param outputs Pointer to output pointer(s) - * \return tvm_status_t containing success or errors from the model run - */ -tvm_crt_error_t tvm_runtime_run(const tvm_model_t* model, void** inputs, void** outputs); - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_AOT_EXECUTOR_AOT_EXECUTOR_H_ diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h index 47ef474778e0..c9b3ebe5c643 100644 --- a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h +++ b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h @@ -116,6 +116,6 @@ int TVMGraphExecutor_GetOutput(TVMGraphExecutor* executor, const int32_t idx, DL int32_t TVMGraphExecutor_CreateTVMOp(TVMGraphExecutor* executor, const TVMOpParam* param, DLTensorPtr* args, const uint32_t args_count, - uint32_t num_inputs, TVMPackedFunc* pf); + TVMPackedFunc* pf); #endif // TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_ diff --git a/src/runtime/crt/memory/stack_allocator.c b/src/runtime/crt/memory/stack_allocator.c index 7a41ca4241ab..ba205f8f209b 100644 --- a/src/runtime/crt/memory/stack_allocator.c +++ b/src/runtime/crt/memory/stack_allocator.c @@ -79,8 +79,15 @@ tvm_crt_error_t StackMemoryManager_Free(tvm_workspace_t* tvm_runtime_workspace, tvm_crt_error_t StackMemoryManager_Init(tvm_workspace_t* tvm_runtime_workspace, uint8_t* g_aot_memory, size_t workspace_size) { - tvm_runtime_workspace->next_alloc = g_aot_memory; - tvm_runtime_workspace->workspace = g_aot_memory; - tvm_runtime_workspace->workspace_size = workspace_size; + // We need to round up g_aot_memory in case it is not aligned to + // TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES. + uintptr_t unaligned_mask = TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES - 1; + uint8_t* memory_aligned = + (uint8_t*)(((uintptr_t)g_aot_memory + unaligned_mask) & ~unaligned_mask); + uint32_t offset = (uintptr_t)(memory_aligned - g_aot_memory); + + tvm_runtime_workspace->next_alloc = memory_aligned; + tvm_runtime_workspace->workspace = memory_aligned; + tvm_runtime_workspace->workspace_size = workspace_size - offset; return kTvmErrorNoError; } diff --git a/src/runtime/crt/microtvm_rpc_common/framing.cc b/src/runtime/crt/microtvm_rpc_common/framing.cc index f89c6e5688c0..47e4a33a718c 100644 --- a/src/runtime/crt/microtvm_rpc_common/framing.cc +++ b/src/runtime/crt/microtvm_rpc_common/framing.cc @@ -66,6 +66,26 @@ void Unframer::Reset() { num_buffer_bytes_valid_ = 0; } +size_t Unframer::BytesNeeded() { + size_t bytes_needed = 0; + switch (state_) { + case State::kFindPacketStart: + return 1; + case State::kFindPacketLength: + bytes_needed = PacketFieldSizeBytes::kPayloadLength; + break; + case State::kFindPacketCrc: + return num_payload_bytes_remaining_; + case State::kFindCrcEnd: + bytes_needed = PacketFieldSizeBytes::kCrc; + break; + default: + CHECK(false); + } + + return bytes_needed > num_buffer_bytes_valid_ ? bytes_needed - num_buffer_bytes_valid_ : 0; +} + tvm_crt_error_t Unframer::Write(const uint8_t* data, size_t data_size_bytes, size_t* bytes_consumed) { tvm_crt_error_t return_code = kTvmErrorNoError; diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 47f038b5c612..33a87c9a2be2 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -112,9 +112,14 @@ class CUDADeviceAPI final : public DeviceAPI { ICHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes"; void* ret; if (dev.device_type == kDLCUDAHost) { + DLOG(INFO) << "allocating " << nbytes << "bytes on host"; CUDA_CALL(cudaMallocHost(&ret, nbytes)); } else { CUDA_CALL(cudaSetDevice(dev.device_id)); + size_t free_mem, total_mem; + CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem)); + DLOG(INFO) << "allocating " << nbytes << " bytes on device, with " << free_mem + << " bytes currently free out of " << total_mem << " bytes available"; CUDA_CALL(cudaMalloc(&ret, nbytes)); } return ret; @@ -122,9 +127,11 @@ class CUDADeviceAPI final : public DeviceAPI { void FreeDataSpace(Device dev, void* ptr) final { if (dev.device_type == kDLCUDAHost) { + DLOG(INFO) << "freeing host memory"; CUDA_CALL(cudaFreeHost(ptr)); } else { CUDA_CALL(cudaSetDevice(dev.device_id)); + DLOG(INFO) << "freeing device memory"; CUDA_CALL(cudaFree(ptr)); } } @@ -280,5 +287,16 @@ TVM_REGISTER_GLOBAL("profiling.timer.gpu").set_body_typed([](Device dev) { return Timer(make_object()); }); +TVM_DLL String GetCudaFreeMemory() { + size_t free_mem, total_mem; + CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem)); + std::stringstream ss; + ss << "Current CUDA memory is " << free_mem << " bytes free out of " << total_mem + << " bytes on device"; + return ss.str(); +} + +TVM_REGISTER_GLOBAL("runtime.GetCudaFreeMemory").set_body_typed(GetCudaFreeMemory); + } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index a877bc634300..7d6879a62aba 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -153,12 +153,12 @@ class CUDAWrappedFunc { public: // initialize the CUDA function. void Init(CUDAModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& thread_axis_tags) { + size_t num_void_args, const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - thread_axis_cfg_.Init(num_void_args, thread_axis_tags); + launch_param_config_.Init(num_void_args, launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { @@ -168,10 +168,10 @@ class CUDAWrappedFunc { fcache_[device_id] = m_->GetFunc(device_id, func_name_); } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), - wl.block_dim(2), 0, strm, void_args, nullptr); + wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { const char* msg; cuGetErrorName(result, &msg); @@ -201,8 +201,8 @@ class CUDAWrappedFunc { // Device function cache per device. // mark as mutable, to enable lazy initialization mutable std::array fcache_; - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; }; class CUDAPrepGlobalBarrier { @@ -241,7 +241,7 @@ PackedFunc CUDAModuleNode::GetFunction(const std::string& name, if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; CUDAWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types); } diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 32dd1d8020c9..35832e83f59c 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -43,7 +43,7 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("arg_types", sarg_types); - writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags); + writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags); writer->EndObject(); } @@ -52,7 +52,9 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { std::vector sarg_types; helper.DeclareField("name", &name); helper.DeclareField("arg_types", &sarg_types); - helper.DeclareField("thread_axis_tags", &thread_axis_tags); + helper.DeclareOptionalField("launch_param_tags", &launch_param_tags); + helper.DeclareOptionalField("thread_axis_tags", + &launch_param_tags); // for backward compatibility helper.ReadAllFields(reader); arg_types.resize(sarg_types.size()); for (size_t i = 0; i < arg_types.size(); ++i) { @@ -63,13 +65,13 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { void FunctionInfo::Save(dmlc::Stream* writer) const { writer->Write(name); writer->Write(arg_types); - writer->Write(thread_axis_tags); + writer->Write(launch_param_tags); } bool FunctionInfo::Load(dmlc::Stream* reader) { if (!reader->Read(&name)) return false; if (!reader->Read(&arg_types)) return false; - if (!reader->Read(&thread_axis_tags)) return false; + if (!reader->Read(&launch_param_tags)) return false; return true; } diff --git a/src/runtime/graph_executor/debug/graph_executor_debug.cc b/src/runtime/graph_executor/debug/graph_executor_debug.cc index 1ea01b19e8aa..2fa73971d000 100644 --- a/src/runtime/graph_executor/debug/graph_executor_debug.cc +++ b/src/runtime/graph_executor/debug/graph_executor_debug.cc @@ -276,16 +276,20 @@ class GraphExecutorDebug : public GraphExecutor { * the module compared to GraphRuntimeDebug::RunIndividual as it runs the * entire graph in order. * + * \param collectors Optional user defined `MetricCollector`s to use with this profiling run. + * * \returns A table of per-op runtimes and total times. */ - profiling::Report Profile() { + profiling::Report Profile(Array collectors) { + std::vector cs(collectors.begin(), collectors.end()); + profiling::Profiler prof(devices_, cs); + // warm up. 1 iteration does not seem enough. for (int i = 0; i < 3; i++) { GraphExecutor::Run(); } - profiling::Profiler prof; - prof.Start(devices_); + prof.Start(); for (size_t i = 0; i < op_execs_.size(); ++i) { if (op_execs_[i]) { // get argument shapes @@ -359,7 +363,10 @@ PackedFunc GraphExecutorDebug::GetFunction(const std::string& name, *rv = this->RunIndividual(number, repeat, min_repeat_ms); }); } else if (name == "profile") { - return TypedPackedFunc([sptr_to_self, this]() { return this->Profile(); }); + return TypedPackedFunc)>( + [sptr_to_self, this](Array collectors) { + return this->Profile(collectors); + }); } else { return GraphExecutor::GetFunction(name, sptr_to_self); } diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index 1084b4ee3ec4..6fe640e87404 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -91,6 +91,11 @@ void GraphExecutor::Init(const std::string& graph_json, tvm::runtime::Module mod std::string& name = nodes_[nid].name; input_map_[name] = i; } + for (size_t i = 0; i < outputs_.size(); i++) { + const uint32_t nid = outputs_[i].node_id; + std::string& name = nodes_[nid].name; + output_map_[name] = i; + } } /*! * \brief Get the input index given the name of input. @@ -104,6 +109,18 @@ int GraphExecutor::GetInputIndex(const std::string& name) { } return -1; } +/*! + * \brief Get the output index given the name of output. + * \param name The name of the output. + * \return The index of output. + */ +int GraphExecutor::GetOutputIndex(const std::string& name) { + auto it = output_map_.find(name); + if (it != output_map_.end()) { + return it->second; + } + return -1; +} /*! * \brief set index-th input to the graph. * \param index The input index. @@ -114,6 +131,23 @@ void GraphExecutor::SetInput(int index, DLTensor* data_in) { uint32_t eid = this->entry_id(input_nodes_[index], 0); data_entry_[eid].CopyFrom(data_in); } +/*! + * \brief Check the legality of external DLTensor*. + * \param external The external DLTensor*. + * \param eid The data_enrty_ index. + */ +void GraphExecutor::CheckExternalDLTensor(const DLTensor* external, uint32_t eid) const { + const DLTensor* internal = data_entry_[eid].operator->(); + + ICHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*external)); + ICHECK_EQ(reinterpret_cast(external->data) % kAllocAlignment, 0); + ICHECK_EQ(internal->ndim, static_cast(external->ndim)); + ICHECK_EQ(internal->device.device_type, external->device.device_type); + ICHECK_EQ(internal->device.device_id, external->device.device_id); + for (auto i = 0; i < external->ndim; ++i) { + ICHECK_EQ(internal->shape[i], external->shape[i]); + } +} /*! * \brief set index-th input to the graph without copying the data. * \param index The input index. @@ -122,23 +156,37 @@ void GraphExecutor::SetInput(int index, DLTensor* data_in) { void GraphExecutor::SetInputZeroCopy(int index, DLTensor* data_ref) { ICHECK_LT(static_cast(index), input_nodes_.size()); uint32_t eid = this->entry_id(input_nodes_[index], 0); - const DLTensor* old_t = data_entry_[eid].operator->(); - // check the consistency of input - ICHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*data_ref)); - ICHECK_EQ(reinterpret_cast(data_ref->data) % kAllocAlignment, 0); - ICHECK_EQ(old_t->ndim, static_cast(data_ref->ndim)); - ICHECK_EQ(old_t->device.device_type, data_ref->device.device_type); - ICHECK_EQ(old_t->device.device_id, data_ref->device.device_id); - for (auto i = 0; i < data_ref->ndim; ++i) { - ICHECK_EQ(old_t->shape[i], data_ref->shape[i]); - } - + CheckExternalDLTensor(data_ref, eid); // Update the data pointer for each argument of each op for (DLTensor* t : input_dltensors_[eid]) { t->data = data_ref->data; } } +/*! + * \brief set index-th output to the graph without copying the data. + * \param index The output index. + * \param data_ref The output data that is referred. + */ +void GraphExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) { + ICHECK_LT(static_cast(index), outputs_.size()); + ICHECK_LT(static_cast(index), output_dltensors_.size()); + const NodeEntry& output_node = outputs_[index]; + uint32_t output_node_eid = this->entry_id(output_node); + + // check the consistency of output + CheckExternalDLTensor(data_ref, output_node_eid); + + // Update the data pointer for output op + for (DLTensor* t : output_dltensors_[output_node_eid]) { + t->data = data_ref->data; + } + + // Update the input of the op connected to the output + for (DLTensor* t : both_output_opinput_dltensors_[output_node_eid]) { + t->data = data_ref->data; + } +} /*! * \brief Get the number of outputs * @@ -358,11 +406,17 @@ void GraphExecutor::SetupStorage() { void GraphExecutor::SetupOpExecs() { op_execs_.resize(this->GetNumOfNodes()); input_dltensors_.resize(num_node_entries()); + output_dltensors_.resize(num_node_entries()); + both_output_opinput_dltensors_.resize(num_node_entries()); std::unordered_set input_node_eids; for (size_t i = 0; i < input_nodes_.size(); i++) { uint32_t nid = input_nodes_[i]; input_node_eids.insert(entry_id(nid, 0)); } + std::unordered_set output_node_eids; + for (size_t i = 0; i < outputs_.size(); i++) { + output_node_eids.insert(entry_id(outputs_[i])); + } // setup the array and requirements. for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) { @@ -380,21 +434,35 @@ void GraphExecutor::SetupOpExecs() { ICHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op"; std::shared_ptr op_args = nullptr; - std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args, inode.inputs.size()); + std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args); for (size_t i = 0; i < inode.inputs.size(); i++) { - uint32_t eid = this->entry_id(inode.inputs[i]); + uint32_t input_eid = this->entry_id(inode.inputs[i]); // check if op input is model input - if (input_node_eids.count(eid) > 0) { - input_dltensors_[eid].push_back(static_cast(op_args->arg_values[i].v_handle)); + if (input_node_eids.count(input_eid) > 0) { + input_dltensors_[input_eid].push_back( + static_cast(op_args->arg_values[i].v_handle)); + } + // check if any model output is the input of the op + if (output_node_eids.count(input_eid) > 0) { + both_output_opinput_dltensors_[input_eid].push_back( + static_cast(op_args->arg_values[i].v_handle)); + } + } + + for (uint32_t i = inode.inputs.size(); i < inode.inputs.size() + inode.param.num_outputs; ++i) { + uint32_t output_eid = this->entry_id(nid, i - inode.inputs.size()); + // check if op output is model output + if (output_node_eids.count(output_eid) > 0) { + output_dltensors_[output_eid].push_back( + static_cast(op_args->arg_values[i].v_handle)); } } } } std::pair, std::shared_ptr > -GraphExecutor::CreateTVMOp(const TVMOpParam& param, const std::vector& args, - size_t num_inputs) { +GraphExecutor::CreateTVMOp(const TVMOpParam& param, const std::vector& args) { std::shared_ptr arg_ptr = std::make_shared(); // setup address. arg_ptr->args = args; @@ -463,6 +531,15 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name, this->SetInputZeroCopy(args[0], args[1]); } }); + } else if (name == "set_output_zero_copy") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + int out_idx = this->GetOutputIndex(args[0].operator String()); + if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]); + } else { + this->SetOutputZeroCopy(args[0], args[1]); + } + }); } else if (name == "get_output") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { if (args.num_args == 2) { @@ -491,6 +568,34 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name, [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); }); } else if (name == "run") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); }); + } else if (name == "run_from_inputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(args.size() % 2 == 0) + << "Number of arguments to run_from_inputs must be an even number of key-value pairs"; + Device host{static_cast(args[0].operator int()), args[1].operator int()}; + for (int i = 2; i < args.size(); i += 2) { + if (String::CanConvertFrom(args[i])) { + int in_idx = this->GetInputIndex(args[i].operator String()); + if (in_idx >= 0) { + this->SetInput(in_idx, args[i + 1]); + } else { + LOG(FATAL) << args[i].operator String() << " is not a valid input name"; + } + } else { + this->SetInput(args[i], args[i + 1]); + } + } + this->Run(); + Array outputs; + for (int i = 0; i < this->NumOutputs(); i++) { + NDArray out = this->GetOutput(i); + NDArray a = NDArray::Empty(out.Shape(), out.DataType(), host); + a.CopyFrom(out); + outputs.push_back(a); + } + *rv = outputs; + }); } else if (name == "load_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->LoadParams(args[0].operator std::string()); @@ -503,6 +608,11 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name, dmlc::MemoryStringStream strm(const_cast(¶m_blob)); this->ShareParams(dynamic_cast(*module.operator->()), &strm); }); + } else if (name == "get_input_index") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string"; + *rv = this->GetInputIndex(args[0].operator String()); + }); } else { return PackedFunc(); } diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h index 631605f630da..87e8aa3cee34 100644 --- a/src/runtime/graph_executor/graph_executor.h +++ b/src/runtime/graph_executor/graph_executor.h @@ -107,6 +107,13 @@ class TVM_DLL GraphExecutor : public ModuleNode { */ int GetInputIndex(const std::string& name); + /*! + * \brief Get the output index given the name of output. + * \param name The name of the output. + * \return The index of output. + */ + int GetOutputIndex(const std::string& name); + /*! * \brief set index-th input to the graph. * \param index The input index. @@ -119,6 +126,12 @@ class TVM_DLL GraphExecutor : public ModuleNode { * \param data_ref The input data that is referred. */ void SetInputZeroCopy(int index, DLTensor* data_ref); + /*! + * \brief set index-th output to the graph without copying the data. + * \param index The output index. + * \param data_ref The output data that is referred. + */ + void SetOutputZeroCopy(int index, DLTensor* data_ref); /*! * \brief Get the number of outputs * @@ -193,6 +206,9 @@ class TVM_DLL GraphExecutor : public ModuleNode { uint32_t node_id; uint32_t index; uint32_t version; + inline bool operator==(const NodeEntry& other) const { + return node_id == other.node_id && index == other.index && version == other.version; + } // JSON Loader void Load(dmlc::JSONReader* reader) { reader->BeginArray(); @@ -377,15 +393,20 @@ class TVM_DLL GraphExecutor : public ModuleNode { void SetupStorage(); /*! \brief Setup the executors. */ void SetupOpExecs(); + /*! + * \brief Check the legality of external DLTensor*. + * \param external The external DLTensor*. + * \param eid The data_enrty_ index. + */ + void CheckExternalDLTensor(const DLTensor* external, uint32_t eid) const; /*! * \brief Create an execution function given input. * \param attrs The node attributes. * \param args The arguments to the functor, including inputs and outputs. - * \param num_inputs Number of inputs. * \return The created executor. */ std::pair, std::shared_ptr> CreateTVMOp( - const TVMOpParam& attrs, const std::vector& args, size_t num_inputs); + const TVMOpParam& attrs, const std::vector& args); // Get node entry index. uint32_t entry_id(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; } // Get node entry index. @@ -398,8 +419,14 @@ class TVM_DLL GraphExecutor : public ModuleNode { std::vector input_nodes_; /*! \brief Map of input names to input indices. */ std::unordered_map input_map_; + /*! \brief Map of output names to output indices. */ + std::unordered_map output_map_; /*! \brief Used for quick node input DLTensor* lookup given an input eid. */ std::vector> input_dltensors_; + /*! \brief Used for quick node output DLTensor* lookup given an output eid. */ + std::vector> output_dltensors_; + /*! \brief Used for quick node(both model output and op input) DLTensor* lookup given an eid. */ + std::vector> both_output_opinput_dltensors_; /*! \brief Used for quick entry indexing. */ std::vector node_row_ptr_; /*! \brief Output entries. */ diff --git a/src/runtime/hexagon/sim/driver/CMakeLists.txt b/src/runtime/hexagon/sim/driver/CMakeLists.txt index bed23c1f94b2..dbac99534383 100644 --- a/src/runtime/hexagon/sim/driver/CMakeLists.txt +++ b/src/runtime/hexagon/sim/driver/CMakeLists.txt @@ -24,11 +24,17 @@ if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake) include(${CMAKE_CURRENT_BINARY_DIR}/config.cmake) endif() +if("${HEXAGON_ARCH}" STREQUAL "") + set(DEFAULT_HEXAGON_ARCH "v66") + message(STATUS "HEXAGON_ARCH not defined, defaulting to ${DEFAULT_HEXAGON_ARCH}") + set(HEXAGON_ARCH "${DEFAULT_HEXAGON_ARCH}") +endif() + set(EXTRA_CXX_FLAGS "-O2" "-Wno-format" "-mhvx -mhvx-length=128b" - "-mv65" + "-m${HEXAGON_ARCH}" "-stdlib=libc++" ) diff --git a/src/runtime/hexagon/sim/hexagon_device_sim.cc b/src/runtime/hexagon/sim/hexagon_device_sim.cc index 1d3f0fd1006f..14ab4c30e2f2 100644 --- a/src/runtime/hexagon/sim/hexagon_device_sim.cc +++ b/src/runtime/hexagon/sim/hexagon_device_sim.cc @@ -17,12 +17,10 @@ * under the License. */ -#include -#include -#include -#include -#include +#include +#include #include +#include #include #include @@ -31,6 +29,7 @@ #include #include #include +#include #include #include "../hexagon_module.h" @@ -84,6 +83,18 @@ std::unique_ptr make_unique(size_t size) { return std::unique_ptr(new U[size]()); } +// An "Optional" class, originally a replacement for llvm::Optional, then an +// extension of dmlc::optional to make it compatible with C++17's std::optional. +template +struct Optional : public dmlc::optional { + using dmlc::optional::optional; + using dmlc::optional::operator=; + Optional(const T& val) : dmlc::optional(val) {} // NOLINT(*) + + T* operator->() { return &this->operator*(); } + const T* operator->() const { return &this->operator*(); } +}; + // Converter class to translate vector to char**. This relieves the // user from memory reallocation and copying. struct non_const_str { @@ -117,7 +128,7 @@ struct non_const_str { std::vector> storage_; }; -using MaybeString = llvm::Optional; +using MaybeString = Optional; MaybeString front(const string_list& deq) { return !deq.empty() ? MaybeString(deq.front()) : MaybeString(); @@ -130,47 +141,47 @@ MaybeString pop_front(string_list& deq) { // NOLINT(*) return MaybeString(f); } -llvm::Optional to_int(const MaybeString& str) { - auto none = llvm::Optional(); - if (str.hasValue()) { +Optional to_int(const MaybeString& str) { + auto none = Optional(); + if (str.has_value()) { try { size_t pos; int64_t val = std::stoll(*str, &pos, 0); - return pos == str->size() ? llvm::Optional(val) : none; + return pos == str->size() ? Optional(val) : none; } catch (std::invalid_argument) { } } return none; } -llvm::Optional to_uint(const MaybeString& str) { - auto none = llvm::Optional(); - if (str.hasValue()) { +Optional to_uint(const MaybeString& str) { + auto none = Optional(); + if (str.has_value()) { try { size_t pos; uint64_t val = std::stoull(*str, &pos, 0); - return pos == str->size() ? llvm::Optional(val) : none; + return pos == str->size() ? Optional(val) : none; } catch (std::invalid_argument) { } } return none; } -llvm::Optional to_float(const MaybeString& str) { - auto none = llvm::Optional(); - if (str.hasValue()) { +Optional to_float(const MaybeString& str) { + auto none = Optional(); + if (str.has_value()) { try { size_t pos; float val = std::stof(*str, &pos); - return pos == str->size() ? llvm::Optional(val) : none; + return pos == str->size() ? Optional(val) : none; } catch (std::invalid_argument) { } } return none; } -llvm::Optional to_bool(const MaybeString& str) { - auto none = llvm::Optional(); +Optional to_bool(const MaybeString& str) { + auto none = Optional(); if (auto num = to_int(str)) { if (*num == 0) return false; if (*num == 1) return true; @@ -184,9 +195,9 @@ llvm::Optional to_bool(const MaybeString& str) { } template -using MaybeRange = llvm::Optional>; +using MaybeRange = Optional>; -template Parse(const MaybeString&)> +template Parse(const MaybeString&)> MaybeRange to_range(const MaybeString& str) { auto none = MaybeRange(); if (str && !str->empty()) { @@ -202,6 +213,72 @@ MaybeRange to_range(const MaybeString& str) { return none; } +// Replacement for llvm::StringSwitch. +template +class StringSwitch { + public: + explicit StringSwitch(const std::string& key) : key(key) {} + operator T() const { + auto f = map.find(key); + if (f != map.end()) { + return f->second; + } + ICHECK(static_cast(def_val)) << "default value not set"; + return *def_val; + } + StringSwitch& Case(const std::string& key, T val) { + map.insert(std::make_pair(key, val)); + return *this; + } + StringSwitch& Default(T val) { + ICHECK(!static_cast(def_val)) << "default value already set"; + def_val = val; + return *this; + } + + private: + const std::string key; + std::map map; + Optional def_val; +}; + +// Replacement for llvm::sys::fs::access with AccessMode = Execute. +bool FileExists(const std::string& file) { return access(file.c_str(), X_OK) == 0; } + +// Replacement for llvm::sys::Process::FindInEnvPath. +MaybeString FindInEnvPath(const std::string& env_var, const std::string& file) { + auto none = MaybeString(); + if (file.empty() || file[0] == '/') { + return none; + } + + const char* e = getenv(env_var.c_str()); + std::string env_val = e != nullptr ? std::string(e) : std::string(); + + std::vector paths; + // Split the environment variable into individual paths. + size_t first = 0, env_size = env_val.size(); + for (size_t last = 0; last != env_size; ++last) { + if (env_val[last] == ':') { + if (last > first) { + paths.emplace_back(env_val, first, last - first); + } + first = last + 1; + } + } + if (first < env_size) { + paths.emplace_back(env_val, first, env_size - first); + } + + // Search for the file. + for (const std::string& dir : paths) { + std::string full = dir + '/' + file; + if (FileExists(full)) { + return full; + } + } + return none; +} } // namespace detail class HexagonSimulator final : public tvm::runtime::hexagon::Device { @@ -304,17 +381,17 @@ class HexagonSimulator final : public tvm::runtime::hexagon::Device { bool HandleV2PTranslation(string_list& rest); // NOLINT(*) bool HandleVerbose(string_list& rest); // NOLINT(*) - using MaybeUInt64 = llvm::Optional; + using MaybeUInt64 = detail::Optional; using MaybeUIntRange = std::pair; bool should_parse_next(const string_list& rest); - llvm::Optional to_interval(const detail::MaybeString& str); - llvm::Optional to_timingmode(const detail::MaybeString& str); - llvm::Optional to_verbosemode(const detail::MaybeString& str); - llvm::Optional to_nullptr(const detail::MaybeString& str); + detail::Optional to_interval(const detail::MaybeString& str); + detail::Optional to_timingmode(const detail::MaybeString& str); + detail::Optional to_verbosemode(const detail::MaybeString& str); + detail::Optional to_nullptr(const detail::MaybeString& str); MaybeUIntRange ahb_, axi2_; - llvm::Optional debug_port_; + detail::Optional debug_port_; detail::non_const_str sim_dev_args_; using OptionHandler = bool (HexagonSimulator::*)(string_list&); @@ -556,13 +633,13 @@ HexagonSimulator::HexagonSimulator(bool enable_queuing) { LOG(INFO) << "HexagonSimulator: Core version: " << arch_; // Locate the sim_dev binary in PATH, or in the current working directory. - llvm::StringRef sim_dev = "sim_dev"; - detail::MaybeString path_sim_dev = llvm::sys::Process::FindInEnvPath("PATH", sim_dev); + std::string sim_dev = "sim_dev"; + detail::MaybeString path_sim_dev = detail::FindInEnvPath("PATH", sim_dev); if (!path_sim_dev) { - if (!llvm::sys::fs::exists(sim_dev)) { + if (!detail::FileExists(sim_dev)) { LOG(FATAL) << "Cannot find sim_dev in PATH."; } - path_sim_dev = sim_dev.str(); + path_sim_dev = sim_dev; } CHECKED_CALL(ConfigureExecutableBinary, path_sim_dev->c_str()); @@ -767,19 +844,19 @@ bool HexagonSimulator::Configure(string_list& opts) { } // Check AHB. - if (ahb_.first.hasValue() && ahb_.second.hasValue()) { + if (ahb_.first.has_value() && ahb_.second.has_value()) { CHECKED_CALL(ConfigureAHB, *ahb_.first, *ahb_.second); } else { - ICHECK(!ahb_.first.hasValue() && !ahb_.second.hasValue()) + ICHECK(!ahb_.first.has_value() && !ahb_.second.has_value()) << "HexagonSimulator: please specify both low and high addresses " "for AHB"; } // Check AXI2. - if (axi2_.first.hasValue() && axi2_.second.hasValue()) { + if (axi2_.first.has_value() && axi2_.second.has_value()) { CHECKED_CALL(ConfigureAXI2, *axi2_.first, *axi2_.second); } else { - ICHECK(!axi2_.first.hasValue() && !axi2_.second.hasValue()) + ICHECK(!axi2_.first.has_value() && !axi2_.second.has_value()) << "HexagonSimulator: please specify both low and high addresses " "for AXI2"; } @@ -1260,8 +1337,8 @@ bool HexagonSimulator::should_parse_next(const string_list& rest) { return false; } -llvm::Optional HexagonSimulator::to_interval(const detail::MaybeString& str) { - auto none = llvm::Optional(); +detail::Optional HexagonSimulator::to_interval(const detail::MaybeString& str) { + auto none = detail::Optional(); if (!str) return none; if (auto val = detail::to_int(*str)) { @@ -1275,7 +1352,7 @@ llvm::Optional HexagonSimulator::to_interval(const detail::Mayb } } - return llvm::StringSwitch>(*str) + return detail::StringSwitch>(*str) .Case("MILLISEC", HEX_MILLISEC) .Case("MICROSEC", HEX_MICROSEC) .Case("NANOSEC", HEX_NANOSEC) @@ -1284,8 +1361,9 @@ llvm::Optional HexagonSimulator::to_interval(const detail::Mayb .Default(none); } -llvm::Optional HexagonSimulator::to_timingmode(const detail::MaybeString& str) { - auto none = llvm::Optional(); +detail::Optional HexagonSimulator::to_timingmode( + const detail::MaybeString& str) { + auto none = detail::Optional(); if (!str) return none; if (auto val = detail::to_int(*str)) { @@ -1298,7 +1376,7 @@ llvm::Optional HexagonSimulator::to_timingmode(const detail:: } } - return llvm::StringSwitch>(*str) + return detail::StringSwitch>(*str) .Case("NOTIMING", HEX_NOTIMING) .Case("TIMING_NODBC", HEX_TIMING_NODBC) .Case("TIMING", HEX_TIMING) @@ -1306,9 +1384,9 @@ llvm::Optional HexagonSimulator::to_timingmode(const detail:: .Default(none); } -llvm::Optional HexagonSimulator::to_verbosemode( +detail::Optional HexagonSimulator::to_verbosemode( const detail::MaybeString& str) { - auto none = llvm::Optional(); + auto none = detail::Optional(); if (!str) return none; if (auto val = detail::to_int(*str)) { @@ -1322,7 +1400,7 @@ llvm::Optional HexagonSimulator::to_verbosemode( } } - return llvm::StringSwitch>(*str) + return detail::StringSwitch>(*str) .Case("SILENT", HEX_SILENT) .Case("QUIET", HEX_QUIET) .Case("NORMAL", HEX_NORMAL) @@ -1331,8 +1409,8 @@ llvm::Optional HexagonSimulator::to_verbosemode( .Default(none); } -llvm::Optional HexagonSimulator::to_nullptr(const detail::MaybeString& str) { - auto none = llvm::Optional(); +detail::Optional HexagonSimulator::to_nullptr(const detail::MaybeString& str) { + auto none = detail::Optional(); if (!str) return none; if (auto val = detail::to_int(*str)) { @@ -1345,7 +1423,7 @@ llvm::Optional HexagonSimulator::to_nullptr(const detail::MaybeS } } - return llvm::StringSwitch>(*str) + return detail::StringSwitch>(*str) .Case("IGNORE", HEX_NULLPTR_IGNORE) .Case("WARN", HEX_NULLPTR_WARN) .Case("FATAL", HEX_NULLPTR_FATAL) diff --git a/src/runtime/hexagon/target/fastrpc/CMakeLists.txt b/src/runtime/hexagon/target/fastrpc/CMakeLists.txt index 0d790d76f7f7..5aac04a0ea56 100644 --- a/src/runtime/hexagon/target/fastrpc/CMakeLists.txt +++ b/src/runtime/hexagon/target/fastrpc/CMakeLists.txt @@ -23,22 +23,19 @@ if(NOT "${FASTRPC_LIBS}" STREQUAL "SKEL" AND message(SEND_ERROR "Please set FASTRPC_LIBS to either SKEL or STUB") endif() +include(../../../../../cmake/modules/HexagonSDK.cmake) -set(FASTRPC_SRC "${CMAKE_CURRENT_SOURCE_DIR}") +find_hexagon_sdk_root("${HEXAGON_SDK_ROOT}" "${HEXAGON_ARCH}") include_directories(include) -include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs) -include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef) -include_directories( - SYSTEM ${HEXAGON_SDK_ROOT}/libs/common/remote/ship/android_Release_aarch64) - -set(QAIC_EXE "${HEXAGON_SDK_ROOT}/tools/qaic/Ubuntu16/qaic") -set(QAIC_FLAGS - "-I${HEXAGON_SDK_ROOT}/incs/stddef" - "-I${HEXAGON_SDK_ROOT}/libs/common/remote/ship/android_Release_aarch64" - "-I${HEXAGON_SDK_ROOT}/libs/common/rpcmem/inc" -) +include_directories(SYSTEM ${HEXAGON_SDK_INCLUDES} ${HEXAGON_REMOTE_ROOT}) +set(QAIC_EXE "${HEXAGON_QAIC_EXE}") +foreach(INCDIR IN LISTS HEXAGON_SDK_INCLUDES HEXAGON_REMOTE_ROOT) + list(APPEND QAIC_FLAGS "-I${INCDIR}") +endforeach() + +set(FASTRPC_SRC "${CMAKE_CURRENT_SOURCE_DIR}") set(CMAKE_SKIP_RPATH TRUE) # Qaic for the non-domain header. @@ -51,13 +48,13 @@ set(TVM_REMOTE_ND_SKEL_C "tvm_remote_nd_skel.c") set(TVM_REMOTE_ND_STUB_C "tvm_remote_nd_stub.c") add_custom_command( - OUTPUT ${TVM_REMOTE_ND_SKEL_C} ${TVM_REMOTE_ND_STUB_C} - "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" - COMMAND ${QAIC_EXE} ${QAIC_FLAGS} - "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_IDL}" - COMMAND ${CMAKE_COMMAND} -E rename "${TVM_REMOTE_ND_H}" - "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" - MAIN_DEPENDENCY "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_IDL}" + OUTPUT ${TVM_REMOTE_ND_SKEL_C} ${TVM_REMOTE_ND_STUB_C} + "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" + COMMAND ${QAIC_EXE} ${QAIC_FLAGS} + "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_IDL}" + COMMAND ${CMAKE_COMMAND} -E rename "${TVM_REMOTE_ND_H}" + "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" + MAIN_DEPENDENCY "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_IDL}" ) # Qaic for the domain header. @@ -70,35 +67,20 @@ set(TVM_REMOTE_D_SKEL_C "tvm_remote_skel.c") set(TVM_REMOTE_D_STUB_C "tvm_remote_stub.c") add_custom_command( - OUTPUT ${TVM_REMOTE_D_SKEL_C} ${TVM_REMOTE_D_STUB_C} - "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" - COMMAND ${QAIC_EXE} ${QAIC_FLAGS} - "${FASTRPC_SRC}/include/${TVM_REMOTE_D_IDL}" - COMMAND ${CMAKE_COMMAND} -E rename "${TVM_REMOTE_D_H}" - "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" - MAIN_DEPENDENCY "${FASTRPC_SRC}/include/${TVM_REMOTE_D_IDL}" + OUTPUT ${TVM_REMOTE_D_SKEL_C} ${TVM_REMOTE_D_STUB_C} + "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" + COMMAND ${QAIC_EXE} ${QAIC_FLAGS} + "${FASTRPC_SRC}/include/${TVM_REMOTE_D_IDL}" + COMMAND ${CMAKE_COMMAND} -E rename "${TVM_REMOTE_D_H}" + "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" + MAIN_DEPENDENCY "${FASTRPC_SRC}/include/${TVM_REMOTE_D_IDL}" ) if("${FASTRPC_LIBS}" STREQUAL "SKEL") # Skel libraries. # - set(HEXARCH_DIR_v60 "ADSPv60MP") - set(HEXARCH_DIR_v62 "ADSPv62MP") - set(HEXARCH_DIR_v65 "computev65") - set(HEXARCH_DIR_v66 "computev66") - set(HEXARCH_DIR_STR "HEXARCH_DIR_${HEXAGON_ARCH}") - set(HEXARCH_DIR ${${HEXARCH_DIR_STR}}) - - if(NOT HEXARCH_DIR) - message(SEND_ERROR - "Please set HEXAGON_ARCH to one of v60, v62, v65, v66") - endif() - - include_directories( - SYSTEM ${HEXAGON_SDK_ROOT}/libs/common/qurt/${HEXARCH_DIR}/include/qurt) - include_directories( - SYSTEM ${HEXAGON_SDK_ROOT}/libs/common/qurt/${HEXARCH_DIR}/include/posix) + include_directories(SYSTEM ${HEXAGON_QURT_INCLUDES}) # Extra compile flags (both C and C++). set(EXTRA_COMP_FLAGS @@ -106,45 +88,40 @@ if("${FASTRPC_LIBS}" STREQUAL "SKEL") "-m${HEXAGON_ARCH}" ) string(REGEX REPLACE ";" " " EXTRA_COMP_FLAGS_STR "${EXTRA_COMP_FLAGS}") - message(STATUS "EXTRA_COMP_FLAGS_STR: ${EXTRA_COMP_FLAGS_STR}") set(CMAKE_C_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_C_FLAGS}") set(CMAKE_CXX_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_CXX_FLAGS}") set(EXTRA_LINK_FLAGS - "-Wl,--no-threads" - "-Wl,--wrap=malloc" - "-Wl,--wrap=calloc" - "-Wl,--wrap=free" - "-Wl,--wrap=realloc" - "-Wl,--wrap=memalign" - "-Wl,--wrap=posix_memalign" - "-Wl,--wrap=__stack_chk_fail" + "-Wl,--no-threads" + "-Wl,--wrap=malloc" + "-Wl,--wrap=calloc" + "-Wl,--wrap=free" + "-Wl,--wrap=realloc" + "-Wl,--wrap=memalign" + "-Wl,--wrap=posix_memalign" + "-Wl,--wrap=__stack_chk_fail" ) string(REGEX REPLACE ";" " " EXTRA_LINK_FLAGS_STR "${EXTRA_LINK_FLAGS}") - # Extra linker flags for linking shared libraries. - set(CMAKE_SHARED_LINKER_FLAGS - "${EXTRA_LINK_FLAGS_STR} ${CMAKE_SHARED_LINKER_FLAGS}") - set(SKEL_ND_SRCS - "src/tvm_hvx.cc" - "src/tvm_remote_nd_imp.cc" + "src/tvm_hvx.cc" + "src/tvm_remote_nd_imp.cc" ) add_library(tvm_remote_nd_skel SHARED - "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" - ${TVM_REMOTE_ND_SKEL_C} - ${SKEL_ND_SRCS} + "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" + ${TVM_REMOTE_ND_SKEL_C} + ${SKEL_ND_SRCS} ) set(SKEL_D_SRCS - # Also includes src/tvm_remote_nd_imp.cc - ${SKEL_ND_SRCS} - "src/tvm_remote_imp.cc" + # Also includes src/tvm_remote_nd_imp.cc + ${SKEL_ND_SRCS} + "src/tvm_remote_imp.cc" ) add_library(tvm_remote_skel SHARED - "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" - ${TVM_REMOTE_D_SKEL_C} - ${SKEL_D_SRCS} + "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" + ${TVM_REMOTE_D_SKEL_C} + ${SKEL_D_SRCS} ) # Separate shared library with __wrap_pthread_create. @@ -155,24 +132,28 @@ if("${FASTRPC_LIBS}" STREQUAL "SKEL") set(WRAP_PTHREAD_SRCS "src/tvm_wrap_pthread.cc") add_library(tvm_wrap_pthread SHARED ${WRAP_PTHREAD_SRCS}) + # Extra linker flags for linking shared libraries. + set_target_properties(tvm_remote_nd_skel PROPERTIES LINK_FLAGS ${EXTRA_LINK_FLAGS_STR}) + set_target_properties(tvm_remote_skel PROPERTIES LINK_FLAGS ${EXTRA_LINK_FLAGS_STR}) + set_target_properties(tvm_wrap_pthread PROPERTIES LINK_FLAGS ${EXTRA_LINK_FLAGS_STR}) else() # Stub libraries. # - include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/a1std) - include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/qlist) - include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/libs/common/rpcmem/inc) - link_directories( - SYSTEM ${HEXAGON_SDK_ROOT}/libs/common/remote/ship/android_Release_aarch64) + include_directories(SYSTEM + ${HEXAGON_SDK_INCLUDES} + "${HEXAGON_RPCMEM_ROOT}/inc" + ) + link_directories(${HEXAGON_REMOTE_ROOT}) add_library(tvm_remote_nd_stub SHARED - "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" - "${HEXAGON_SDK_ROOT}/libs/common/rpcmem/src/rpcmem_android.c" - "${TVM_REMOTE_ND_STUB_C}" + "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" + "${HEXAGON_RPCMEM_ROOT}/src/rpcmem_android.c" + "${TVM_REMOTE_ND_STUB_C}" ) add_library(tvm_remote_stub SHARED - "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" - "${HEXAGON_SDK_ROOT}/libs/common/rpcmem/src/rpcmem_android.c" - "${TVM_REMOTE_D_STUB_C}" + "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" + "${HEXAGON_RPCMEM_ROOT}/src/rpcmem_android.c" + "${TVM_REMOTE_D_STUB_C}" ) target_link_libraries(tvm_remote_nd_stub adsprpc) target_link_libraries(tvm_remote_stub adsprpc) diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index e3ec155dc291..66d9a44099da 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -54,8 +54,8 @@ inline String get_name_mangled(const String& module_name, const String& name) { */ class MetadataNode : public Object { public: - /*! \brief number of inputs of the main function */ - int num_inputs = 1; + /*! \brief input information for the main function */ + Array inputs; /*! \brief number of outputs of the main function */ int num_outputs = 1; /*! \brief the executor to be used to run the model */ @@ -73,9 +73,9 @@ class MetadataNode : public Object { */ class Metadata : public ObjectRef { public: - TVM_DLL Metadata(int num_inputs, int num_outputs, String executor, String mod_name) { + TVM_DLL Metadata(Array inputs, int num_outputs, String executor, String mod_name) { auto n = make_object(); - n->num_inputs = num_inputs; + n->inputs = inputs; n->num_outputs = num_outputs; n->executor = executor; n->mod_name = mod_name; @@ -99,11 +99,14 @@ Module MetadataModuleCreate( const std::unordered_map& metadata, const std::unordered_map>& sym_vars); +/*! \brief A tag to specify whether or not dynamic shared memory is used */ +constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; + /*! \brief function information needed by device */ struct FunctionInfo { std::string name; std::vector arg_types; - std::vector thread_axis_tags; + std::vector launch_param_tags; void Save(dmlc::JSONWriter* writer) const; void Load(dmlc::JSONReader* reader); diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 88501880557e..1e81ac1bbb34 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -178,7 +178,7 @@ void SaveToBinary(dmlc::Stream* stream) final { // initialize the METAL function. void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags) { + const std::vector& launch_param_tags) { w_ = metal::MetalWorkspace::Global(); m_ = m; sptr_ = sptr; @@ -186,7 +186,7 @@ void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_na num_buffer_args_ = num_buffer_args; num_pack_args_ = num_pack_args; std::fill(scache_.begin(), scache_.end(), (id)nil); - thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); + launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags); metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int dev_id = t->device.device_id; scache_[dev_id] = m->GetPipelineState(dev_id, func_name); @@ -201,7 +201,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons if (scache_[device_id] == nil) { scache_[device_id] = m_->GetPipelineState(device_id, func_name_); } - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2); auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup; CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup); @@ -242,8 +242,8 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons // Device state cache per device. // mark as mutable, to enable lazy initialization mutable std::array, kMetalMaxNumDevice> scache_; - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; }; PackedFunc MetalModuleNode::GetFunction(const std::string& name, @@ -261,7 +261,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons MetalWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, - info.thread_axis_tags); + info.launch_param_tags); pf = PackFuncNonBufferArg(f, info.arg_types); }; return pf; diff --git a/src/runtime/crt/host/crt_config.h b/src/runtime/micro/crt_config.h similarity index 90% rename from src/runtime/crt/host/crt_config.h rename to src/runtime/micro/crt_config.h index b81a74eb4ae6..c3e8fea1ba08 100644 --- a/src/runtime/crt/host/crt_config.h +++ b/src/runtime/micro/crt_config.h @@ -21,8 +21,8 @@ * \file tvm/runtime/crt/host/crt_config.h * \brief CRT configuration for the host-linked CRT. */ -#ifndef TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_ -#define TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_ +#ifndef TVM_RUNTIME_MICRO_CRT_CONFIG_H_ +#define TVM_RUNTIME_MICRO_CRT_CONFIG_H_ /*! Log level of the CRT runtime */ #define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG @@ -35,9 +35,9 @@ /*! Maximum supported arguments in generated functions */ #define TVM_CRT_MAX_ARGS 10 /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ -#define TVM_CRT_STRLEN_DLTYPE 10 +#define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_STRLEN_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 @@ -53,4 +53,4 @@ // #define TVM_CRT_FRAMER_ENABLE_LOGS -#endif // TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_ +#endif // TVM_RUNTIME_MICRO_CRT_CONFIG_H_ diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc index cd916d46971d..2dcd928b24f8 100644 --- a/src/runtime/micro/micro_session.cc +++ b/src/runtime/micro/micro_session.cc @@ -37,10 +37,10 @@ #include #include "../../support/str_escape.h" -#include "../crt/host/crt_config.h" #include "../rpc/rpc_channel.h" #include "../rpc/rpc_endpoint.h" #include "../rpc/rpc_session.h" +#include "crt_config.h" namespace tvm { namespace runtime { @@ -56,10 +56,12 @@ class CallbackWriteStream : public WriteStream { bytes.data = (const char*)data; bytes.size = data_size_bytes; if (write_timeout_ == ::std::chrono::microseconds::zero()) { - return static_cast(fsend_(bytes, nullptr)); + fsend_(bytes, nullptr); } else { - return static_cast(fsend_(bytes, write_timeout_.count())); + fsend_(bytes, write_timeout_.count()); } + + return static_cast(data_size_bytes); } void PacketDone(bool is_valid) override {} @@ -143,15 +145,16 @@ class MicroTransportChannel : public RPCChannel { } ::std::string chunk; + size_t bytes_needed = unframer_.BytesNeeded(); + CHECK_GT(bytes_needed, 0) << "unframer unexpectedly needs no data"; if (timeout != nullptr) { ::std::chrono::microseconds iter_timeout{ ::std::max(::std::chrono::microseconds{0}, ::std::chrono::duration_cast<::std::chrono::microseconds>( end_time - ::std::chrono::steady_clock::now()))}; - chunk = - frecv_(size_t(kReceiveBufferSizeBytes), iter_timeout.count()).operator std::string(); + chunk = frecv_(bytes_needed, iter_timeout.count()).operator std::string(); } else { - chunk = frecv_(size_t(kReceiveBufferSizeBytes), nullptr).operator std::string(); + chunk = frecv_(bytes_needed, nullptr).operator std::string(); } pending_chunk_ = chunk; if (pending_chunk_.size() == 0) { diff --git a/src/runtime/micro/standalone/microtvm_graph_executor.cc b/src/runtime/micro/standalone/microtvm_graph_executor.cc index d91d0ea74ba4..d09efc4be50e 100644 --- a/src/runtime/micro/standalone/microtvm_graph_executor.cc +++ b/src/runtime/micro/standalone/microtvm_graph_executor.cc @@ -321,7 +321,7 @@ void MicroGraphExecutor::SetupStorage() { } std::function CreateTVMOp(const DSOModule& module, const TVMOpParam& param, - const DynArray& args, size_t num_inputs) { + const DynArray& args) { typedef union { void* v_handle; } TVMValue; @@ -389,7 +389,7 @@ void MicroGraphExecutor::SetupOpExecs() { args[index + inode.inputs.size()] = data_entry_[eid].ToDLTensor(); } assert(inode.op_type == "tvm_op"); - op_execs_[nid] = CreateTVMOp(*module_, inode.param, args, inode.inputs.size()); + op_execs_[nid] = CreateTVMOp(*module_, inode.param, args); } } diff --git a/src/runtime/module.cc b/src/runtime/module.cc index acc7fc7286d1..f9c281ab9d02 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -84,7 +84,10 @@ Module Module::LoadFromFile(const std::string& file_name, const std::string& for } std::string load_f_name = "runtime.module.loadfile_" + fmt; const PackedFunc* f = Registry::Get(load_f_name); - ICHECK(f != nullptr) << "Loader of " << format << "(" << load_f_name << ") is not presented."; + ICHECK(f != nullptr) << "Loader for `." << format << "` files is not registered," + << " resolved to (" << load_f_name << ") in the global registry." + << "Ensure that you have loaded the correct runtime code, and" + << "that you are on the correct hardware architecture."; Module m = (*f)(file_name, format); return m; } @@ -113,7 +116,10 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { if (pf == nullptr) { const PackedFunc* f = Registry::Get(name); ICHECK(f != nullptr) << "Cannot find function " << name - << " in the imported modules or global registry"; + << " in the imported modules or global registry." + << " If this involves ops from a contrib library like" + << " cuDNN, ensure TVM was built with the relevant" + << " library."; return f; } else { import_cache_.insert(std::make_pair(name, std::make_shared(pf))); diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 23594b5d7d8a..3cd5df613f4a 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -138,8 +138,12 @@ class TypeContext { std::string TypeIndex2Key(uint32_t tindex) { std::lock_guard lock(mutex_); - ICHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) - << "Unknown type index " << tindex; + if (tindex != 0) { + // always return the right type key for root + // for non-root type nodes, allocated slots should not equal 0 + ICHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) + << "Unknown type index " << tindex; + } return type_table_[tindex].name; } @@ -258,3 +262,11 @@ int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key); API_END(); } + +int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key) { + API_BEGIN(); + auto key = tvm::runtime::Object::TypeIndex2Key(tindex); + *out_type_key = static_cast(malloc(key.size() + 1)); + strncpy(*out_type_key, key.c_str(), key.size()); + API_END(); +} diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 4040d82b33e7..f6c7f6232819 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -40,14 +40,14 @@ class OpenCLWrappedFunc { // initialize the OpenCL function. void Init(OpenCLModuleNode* m, ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, std::string func_name, std::vector arg_size, - const std::vector& thread_axis_tags) { + const std::vector& launch_param_tags) { w_ = m->GetGlobalWorkspace(); m_ = m; sptr_ = sptr; entry_ = entry; func_name_ = func_name; arg_size_ = arg_size; - thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags); + launch_param_config_.Init(arg_size.size(), launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { @@ -73,8 +73,8 @@ class OpenCLWrappedFunc { OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg)); } cl_command_queue queue = w_->GetQueue(t->device); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - cl_uint work_dim = static_cast(thread_axis_cfg_.work_dim()); + ThreadWorkLoad wl = launch_param_config_.Extract(args); + cl_uint work_dim = static_cast(launch_param_config_.work_dim()); for (cl_uint i = 0; i < work_dim; ++i) { wl.work_size[i] *= wl.work_size[i + 3]; } @@ -96,8 +96,8 @@ class OpenCLWrappedFunc { std::string func_name_; // convert code for void argument std::vector arg_size_; - // thread axis config - ThreadAxisConfig thread_axis_cfg_; + // launch parameters config + LaunchParamConfig launch_param_config_; }; OpenCLModuleNode::~OpenCLModuleNode() { @@ -148,7 +148,7 @@ PackedFunc OpenCLModuleNode::GetFunction(const std::string& name, } } // initialize the wrapped func. - f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.thread_axis_tags); + f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types); } diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index ab9d674fad50..596b6ace8831 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -23,8 +23,10 @@ */ #include +#include #include #include +#include #include #include @@ -100,16 +102,37 @@ TVM_REGISTER_GLOBAL("profiling.start_timer").set_body_typed(Timer::Start); namespace profiling { -void Profiler::Start(const std::vector& devs) { - CHECK(global_timers_.empty()) << "You can only call Start once per Profiler."; +Profiler::Profiler(std::vector devs, std::vector metric_collectors) + : devs_(devs), collectors_(metric_collectors) { + is_running_ = false; + std::vector wrapped_devs; for (auto dev : devs) { - global_timers_.emplace_back(dev, Timer::Start(dev)); + wrapped_devs.push_back(DeviceWrapper(make_object(dev))); + } + for (auto& x : collectors_) { + x->Init(wrapped_devs); + } + // reset the thread pool so that PAPI eventset hooks are set in all threads. + threading::ResetThreadPool(); +} + +void Profiler::Start() { + is_running_ = true; + for (auto dev : devs_) { + StartCall("Total", dev, {}); } } void Profiler::StartCall(String name, Device dev, std::unordered_map extra_metrics) { - in_flight_.push(CallFrame{dev, name, Timer::Start(dev), extra_metrics}); + std::vector> objs; + for (auto& collector : collectors_) { + ObjectRef obj = collector->Start(dev); + if (obj.defined()) { + objs.emplace_back(collector, obj); + } + } + in_flight_.push(CallFrame{dev, name, Timer::Start(dev), extra_metrics, objs}); } void Profiler::StopCall(std::unordered_map extra_metrics) { @@ -118,14 +141,21 @@ void Profiler::StopCall(std::unordered_map extra_metrics for (auto& p : extra_metrics) { cf.extra_metrics[p.first] = p.second; } + // collect the extra metrics from user defined collectors + for (const auto& obj : cf.extra_collectors) { + auto collector_metrics = obj.first->Stop(obj.second); + for (auto& p : collector_metrics) { + cf.extra_metrics[p.first] = p.second; + } + } in_flight_.pop(); calls_.push_back(cf); } void Profiler::Stop() { - // Stop all global timers. We wait to synchronize until we are making the report. - for (auto p : global_timers_) { - p.second->Stop(); + is_running_ = false; + for (size_t i = 0; i < devs_.size(); i++) { + StopCall(); } } @@ -198,6 +228,71 @@ String ReportNode::AsCSV() const { return s.str(); } +namespace { +void print_metric(std::ostream& os, ObjectRef o) { + if (o.as()) { + os << "\"" << Downcast(o) << "\""; + } else if (const CountNode* n = o.as()) { + os << "{\"count\":" << std::to_string(n->value) << "}"; + } else if (const DurationNode* n = o.as()) { + os << "{\"microseconds\":" << std::to_string(n->microseconds) << "}"; + } else if (const PercentNode* n = o.as()) { + os << "{\"percent\":" << std::to_string(n->percent) << "}"; + } else { + LOG(FATAL) << "Unprintable type " << o->GetTypeKey(); + } +} +} // namespace + +String ReportNode::AsJSON() const { + std::ostringstream s; + // DMLC's JSONWriter does not allow us to write a key value pair without + // implementing Write for the value. We want a specific write for the value, + // so we would have to implement a custom data structure for each type of + // value we want to print. Instead we construct the json by hand because it + // is easier. + s << "{"; + s << "\"calls\":["; + for (size_t i = 0; i < calls.size(); i++) { + size_t j = 0; + s << "{"; + for (const auto& kv : calls[i]) { + s << "\"" << kv.first << "\":"; + print_metric(s, kv.second); + if (j < calls[i].size() - 1) { + s << ","; + } + j++; + } + s << "}"; + if (i < calls.size() - 1) { + s << ","; + } + } + s << "],"; + s << "\"device_metrics\":{"; + size_t i = 0; + for (const auto& dev_kv : device_metrics) { + size_t j = 0; + s << "\"" << dev_kv.first << "\":{"; + for (const auto& metric_kv : dev_kv.second) { + s << "\"" << metric_kv.first << "\":"; + print_metric(s, metric_kv.second); + if (j < dev_kv.second.size() - 1) { + s << ","; + } + j++; + } + s << "}"; + if (i < device_metrics.size() - 1) { + s << ","; + } + i++; + } + s << "}}"; + return s.str(); +} + String ReportNode::AsTable(bool sort, bool aggregate) const { // aggregate calls by op hash (or op name if hash is not set) + argument shapes std::vector> aggregated_calls; @@ -396,31 +491,11 @@ std::string DeviceString(Device dev) { } Report Profiler::Report(bool aggregate, bool sort) { - std::vector> global_times; - for (auto p : global_timers_) { - global_times.emplace_back(p.first, p.second->SyncAndGetElapsedNanos() / 1e3); - } - - double overall_time = 0; - for (auto p : global_times) { - overall_time = std::max(overall_time, p.second); - } - - std::unordered_map> device_metrics; - for (auto p : global_times) { - std::unordered_map row; - row["Name"] = String("Total"); - row["Duration (us)"] = ObjectRef(make_object(p.second)); - row["Percent"] = ObjectRef(make_object(p.second / overall_time * 100)); - row["Device"] = String(DeviceString(p.first)); - device_metrics[DeviceString(p.first)] = row; - } - - std::vector> rows; + // sync all timers and normalize rows + std::vector> rows; for (auto& cf : calls_) { std::unordered_map row; double us = cf.timer->SyncAndGetElapsedNanos() / 1e3; - row["Percent"] = ObjectRef(make_object(us / overall_time * 100)); row["Duration (us)"] = ObjectRef(make_object(us)); row["Count"] = ObjectRef(make_object(1)); row["Name"] = cf.name; @@ -431,7 +506,30 @@ Report Profiler::Report(bool aggregate, bool sort) { rows.push_back(row); } - return profiling::Report(rows, device_metrics); + // the last couple of call frames are the overall times + double overall_time_us = 0; + std::unordered_map> device_metrics; + for (size_t i = 0; i < devs_.size(); i++) { + auto row = rows[rows.size() - 1]; + rows.pop_back(); + device_metrics[Downcast(row["Device"])] = row; + overall_time_us = + std::max(overall_time_us, row["Duration (us)"].as()->microseconds); + } + + // Calculate percentages + for (auto& row : rows) { + row["Percent"] = ObjectRef(make_object( + row["Duration (us)"].as()->microseconds / overall_time_us * 100)); + } + + // convert to map + std::vector> converted_rows; + for (const auto& row : rows) { + converted_rows.push_back(row); + } + + return profiling::Report(converted_rows, device_metrics); } Report::Report(Array> calls, @@ -446,8 +544,16 @@ TVM_REGISTER_OBJECT_TYPE(DurationNode); TVM_REGISTER_OBJECT_TYPE(PercentNode); TVM_REGISTER_OBJECT_TYPE(CountNode); TVM_REGISTER_OBJECT_TYPE(ReportNode); +TVM_REGISTER_OBJECT_TYPE(DeviceWrapperNode); +TVM_REGISTER_OBJECT_TYPE(MetricCollectorNode); TVM_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { return n->AsCSV(); }); +TVM_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) { + return n->AsJSON(); +}); +TVM_REGISTER_GLOBAL("runtime.profiling.DeviceWrapper").set_body_typed([](Device dev) { + return DeviceWrapper(dev); +}); } // namespace profiling } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 567557c56794..487ad23e16b9 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -147,12 +147,12 @@ class ROCMWrappedFunc { public: // initialize the ROCM function. void Init(ROCMModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& thread_axis_tags) { + size_t num_void_args, const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - thread_axis_cfg_.Init(num_void_args, thread_axis_tags); + launch_param_config_.Init(num_void_args, launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const { @@ -164,13 +164,14 @@ class ROCMWrappedFunc { hipStream_t strm = static_cast(ROCMThreadEntry::ThreadLocal()->stream); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes, HIP_LAUNCH_PARAM_END}; // HIP supports only extra_args. - ROCM_DRIVER_CALL(hipModuleLaunchKernel( - fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), - wl.block_dim(1), wl.block_dim(2), 0, strm, nullptr, reinterpret_cast(&config))); + ROCM_DRIVER_CALL(hipModuleLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), + wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), + wl.block_dim(2), wl.dyn_shmem_size, strm, nullptr, + reinterpret_cast(&config))); } private: @@ -183,8 +184,8 @@ class ROCMWrappedFunc { // Device function cache per device. // mark as mutable, to enable lazy initialization mutable std::array fcache_; - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; }; PackedFunc ROCMModuleNode::GetFunction(const std::string& name, @@ -195,7 +196,7 @@ PackedFunc ROCMModuleNode::GetFunction(const std::string& name, if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; ROCMWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncPackedArg(f, info.arg_types); } diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 7272269680c5..b9ed54e73508 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -417,8 +417,9 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") << "Cannot find " << f_preproc_name << " in the global function"; f_preproc = *pf_preproc; } - return WrapTimeEvaluator(m.GetFunction(name, false), dev, number, repeat, min_repeat_ms, - f_preproc); + PackedFunc pf = m.GetFunction(name, false); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global registry"; + return WrapTimeEvaluator(pf, dev, number, repeat, min_repeat_ms, f_preproc); } } else { auto* pf = runtime::Registry::Get(name); diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index cab04ec0db4a..c8d1845266e8 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -258,26 +258,30 @@ class SpscTaskQueue { class ThreadPool { public: ThreadPool() : num_workers_(tvm::runtime::threading::MaxConcurrency()) { - for (int i = 0; i < num_workers_; ++i) { - // The SpscTaskQueue only hosts ONE item at a time - queues_.emplace_back(std::unique_ptr(new SpscTaskQueue())); - } const char* exclude_worker0 = getenv("TVM_EXCLUDE_WORKER0"); if (exclude_worker0 && atoi(exclude_worker0) == 0) { exclude_worker0_ = false; } - threads_ = std::unique_ptr( - new tvm::runtime::threading::ThreadGroup( - num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, - exclude_worker0_ /* include_main_thread */)); - num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_); + Init(); } + ~ThreadPool() { for (std::unique_ptr& q : queues_) { q->SignalForKill(); } threads_.reset(); } + + void Reset() { + for (std::unique_ptr& q : queues_) { + q->SignalForKill(); + } + // Destroy threads before we destory the shared queue, otherwise we segfault on MacOS + threads_.reset(); + queues_.clear(); + Init(); + } + int Launch(FTVMParallelLambda flambda, void* cdata, int num_task, int need_sync) { ParallelLauncher* launcher = ParallelLauncher::ThreadLocal(); ICHECK(!launcher->is_worker) @@ -323,6 +327,19 @@ class ThreadPool { } private: + // Shared initialization code + void Init() { + for (int i = 0; i < num_workers_; ++i) { + // The SpscTaskQueue only hosts ONE item at a time + queues_.emplace_back(std::unique_ptr(new SpscTaskQueue())); + } + threads_ = std::unique_ptr( + new tvm::runtime::threading::ThreadGroup( + num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, + exclude_worker0_ /* include_main_thread */)); + num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_); + } + // Internal worker function. void RunWorker(int worker_id) { SpscTaskQueue* queue = queues_[worker_id].get(); @@ -359,6 +376,10 @@ TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body([](TVMArgs args, TVMRe ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads); }); +namespace threading { +void ResetThreadPool() { tvm::runtime::ThreadPool::ThreadLocal()->Reset(); } +} // namespace threading + } // namespace runtime } // namespace tvm diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index c0393600b60c..d577770db1a9 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -19,7 +19,7 @@ /*! * \file thread_storage_scope.h - * \brief Extract thread axis configuration from TVMArgs. + * \brief Extract launch parameters configuration from TVMArgs. */ #ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ @@ -29,6 +29,8 @@ #include #include +#include "meta_data.h" + namespace tvm { namespace runtime { @@ -118,7 +120,9 @@ struct StorageScope { */ static StorageScope Create(const std::string& s) { StorageScope r; - if (s.compare(0, 6, "global") == 0) { + if (s.empty()) { + r.rank = StorageRank::kGlobal; + } else if (s.compare(0, 6, "global") == 0) { r.rank = StorageRank::kGlobal; r.tag = s.substr(6, std::string::npos); } else if (s.compare(0, 6, "shared") == 0) { @@ -159,7 +163,7 @@ struct ThreadScope { */ static ThreadScope Create(const std::string& s) { ThreadScope r; - if (s == "vthread" || s == "cthread") { + if (s.compare(0, 7, "vthread") == 0 || s == "cthread") { // virtual thread at the same level as local r.rank = 1; r.dim_index = -1; @@ -180,6 +184,8 @@ struct ThreadScope { struct ThreadWorkLoad { // array, first three are thread configuration. size_t work_size[6]; + // Dynamic shared memory allocation size in bytes. + size_t dyn_shmem_size{0}; /*! * \param i The block dimension. * \return i-th block dim @@ -191,17 +197,23 @@ struct ThreadWorkLoad { */ inline size_t grid_dim(size_t i) const { return work_size[i]; } }; -/*! \brief Thread axis configuration */ -class ThreadAxisConfig { +/*! \brief Launch parameters configuration */ +class LaunchParamConfig { public: - void Init(size_t base, const std::vector& thread_axis_tags) { + void Init(size_t base, const std::vector& launch_param_tags) { base_ = base; std::vector filled(6, false); - for (size_t i = 0; i < thread_axis_tags.size(); ++i) { - const std::string& tag = thread_axis_tags[i]; - ThreadScope ts = ThreadScope::Create(tag); - arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); - filled[ts.rank * 3 + ts.dim_index] = true; + for (size_t i = 0; i < launch_param_tags.size(); ++i) { + const std::string& tag = launch_param_tags[i]; + if (tag == kUseDynamicSharedMemoryTag) { + ICHECK_EQ(i, launch_param_tags.size() - 1) + << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; + use_dyn_shared_memory_ = true; + } else { + ThreadScope ts = ThreadScope::Create(tag); + arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); + filled[ts.rank * 3 + ts.dim_index] = true; + } } work_dim_ = 1; for (int i = 0; i < 3; ++i) { @@ -221,6 +233,9 @@ class ThreadAxisConfig { w.work_size[arg_index_map_[i]] = size; } } + if (use_dyn_shared_memory_) { + w.dyn_shmem_size = static_cast(x.values[base_ + arg_index_map_.size()].v_int64); + } return w; } // return the work dim @@ -233,6 +248,8 @@ class ThreadAxisConfig { size_t work_dim_; /*! \brief The index mapping. */ std::vector arg_index_map_; + /*! \brief Whether or not use dynamic shared memory. */ + bool use_dyn_shared_memory_{false}; }; } // namespace runtime diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index e8b948d3d2ae..c2dc0307e166 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -234,8 +234,8 @@ TVMByteArray Executable::Save() { } void Executable::SaveGlobalSection(dmlc::Stream* strm) { - std::vector > globals(this->global_map.begin(), - this->global_map.end()); + std::vector> globals(this->global_map.begin(), + this->global_map.end()); auto comp = [](const std::pair& a, const std::pair& b) { return a.second < b.second; }; @@ -273,6 +273,20 @@ void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { primitive_names[packed_index] = it.first; } strm->Write(primitive_names); + std::map> primitive_attrs; + for (const auto& it : this->op_attrs) { + auto packed_index = static_cast(it.first); + std::map attrs; + for (const auto& elem : it.second) { + // TODO(tkonolige): cannot serialize ObjectRefs with dmlc's serializer, so we just serialize + // strings for now + if (elem.second.as()) { + attrs[elem.first] = Downcast(elem.second); + } + } + primitive_attrs[packed_index] = attrs; + } + strm->Write(primitive_attrs); } // Serialize a virtual machine instruction. It creates a list that contains the @@ -569,6 +583,16 @@ void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { for (size_t i = 0; i < primitive_names.size(); i++) { this->primitive_map.insert({primitive_names[i], i}); } + + std::map> primitive_attrs; + STREAM_CHECK(strm->Read(&primitive_attrs), "primitive attrs"); + for (const auto& fn : primitive_attrs) { + std::vector> attrs; + for (const auto& elem : fn.second) { + attrs.push_back({elem.first, String(elem.second)}); + } + this->op_attrs[fn.first] = Map(attrs.begin(), attrs.end()); + } } // Extract the `cnt` number of fields started at `start` from the list @@ -851,8 +875,8 @@ TVM_REGISTER_GLOBAL("runtime.GetGlobalFields").set_body([](TVMArgs args, TVMRetV const auto* exec = dynamic_cast(mod.operator->()); ICHECK(exec); int idx = args[1]; - std::vector > globals(exec->global_map.begin(), - exec->global_map.end()); + std::vector> globals(exec->global_map.begin(), + exec->global_map.end()); auto comp = [](const std::pair& a, const std::pair& b) { return a.second < b.second; }; diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index a7d65944d581..6d893114d623 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -43,26 +43,31 @@ namespace vm { PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "profile") { - return TypedPackedFunc([sptr_to_self, this](String arg_name) { - std::vector devices; - for (auto dev : devices_) { - if (dev.device_type > 0) { - devices.push_back(dev); - } - } - - auto invoke = VirtualMachine::GetFunction("invoke", sptr_to_self); - // warmup - for (int i = 0; i < 3; i++) { - invoke(arg_name); - } - - prof_ = profiling::Profiler(); // reset profiler - prof_.Start(devices); - invoke(arg_name); - prof_.Stop(); - return prof_.Report(); - }); + return TypedPackedFunc)>( + [sptr_to_self, this](String arg_name, Array collectors) { + std::vector devices; + for (auto dev : devices_) { + if (dev.device_type > 0) { + devices.push_back(dev); + } + } + + std::vector cs(collectors.begin(), collectors.end()); + prof_ = profiling::Profiler(devices, cs); + + auto invoke = VirtualMachine::GetFunction("invoke", sptr_to_self); + // warmup + for (int i = 0; i < 3; i++) { + invoke(arg_name); + } + + prof_.operator*().Start(); + invoke(arg_name); + prof_.operator*().Stop(); + auto report = prof_.operator*().Report(); + prof_ = dmlc::optional(); // releases hardware counters + return report; + }); } else { return VirtualMachine::GetFunction(name, sptr_to_self); } @@ -80,7 +85,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& fun Index output_size, const std::vector& args) { ICHECK(exec_); ICHECK(!devices_.empty()) << "Device has not been initialized yet."; - if (prof_.IsRunning()) { + if (prof_ && prof_.operator*().IsRunning()) { // The device of any input of the operator is used for synchronization. ICHECK_GT(arg_count, 0U); ObjectRef arg = args[0]; @@ -122,11 +127,11 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& fun } metrics["Argument Shapes"] = profiling::ShapeString(shapes); - prof_.StartCall(packed_index_map_[packed_index], dev, metrics); + prof_.operator*().StartCall(packed_index_map_[packed_index], dev, metrics); } VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args); - if (prof_.IsRunning()) { - prof_.StopCall(); + if (prof_ && prof_.operator*().IsRunning()) { + prof_.operator*().StopCall(); } } diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index 521a9bd454e7..1efefda52b97 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -25,6 +25,7 @@ #ifndef TVM_RUNTIME_VM_PROFILER_VM_H_ #define TVM_RUNTIME_VM_PROFILER_VM_H_ +#include #include #include @@ -39,7 +40,7 @@ namespace vm { class VirtualMachineDebug : public VirtualMachine { public: - VirtualMachineDebug() : VirtualMachine() {} + VirtualMachineDebug() : VirtualMachine(), prof_({}) {} PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; @@ -52,7 +53,7 @@ class VirtualMachineDebug : public VirtualMachine { const std::vector& args) final; std::unordered_map packed_index_map_; - profiling::Profiler prof_; + dmlc::optional prof_; }; } // namespace vm diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index c96364108a2a..4df013baa2fb 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -118,6 +118,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK(exec_) << "The executable is not created yet."; + std::string func_name = args[0]; auto git = exec_->global_map.find(func_name); ICHECK(git != exec_->global_map.end()) @@ -140,6 +141,26 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, TVMRetValue rv_; invoke.CallPacked(args, &rv_); }); + } else if (name == "invoke_return_to_device") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + Device host{static_cast(args[1].operator int()), args[2].operator int()}; + + SetInput(args[0].operator std::string(), args, 3); + PackedFunc invoke = GetFunction("invoke", sptr_to_self); + TVMRetValue rv_; + invoke.CallPacked(args, &rv_); // Invoke only uses the first arg, so the rest of the args + // should not cause an issue + if (rv_.type_code() == kTVMObjectHandle) { + ADT adt = Downcast(rv_.operator ObjectRef()); + std::vector transfered; + for (size_t i = 0; i < adt.size(); i++) { + transfered.push_back(CopyTo(adt[i], host)); + } + *rv = ADT(adt.tag(), transfered); + } else { + *rv = CopyTo(rv_, host); + } + }); } else if (name == "get_output") { return TypedPackedFunc([this](int64_t index) { if (this->return_register_.as()) { @@ -159,6 +180,21 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, return 1; } }); + } else if (name == "get_input_index") { + return TypedPackedFunc( + [this](std::string input_name, std::string func_name) { + auto gvit = exec_->global_map.find(func_name); + ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name; + auto func_index = gvit->second; + const auto& vm_func = exec_->functions[func_index]; + const auto& param_names = vm_func.params; + for (uint64_t i = 0; i < param_names.size(); i++) { + if (input_name == param_names[i]) { + return static_cast(i); + } + } + return static_cast(-1); + }); } else if (name == "init") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.size() % 3, 0); @@ -176,47 +212,49 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, this->Init(devices, alloc_types); }); } else if (name == "set_input") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK(exec_) << "The executable is not created yet."; - std::string func_name = args[0]; - auto gvit = exec_->global_map.find(func_name); - ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name; - auto func_index = gvit->second; - const auto& vm_func = exec_->functions[func_index]; - const auto& param_names = vm_func.params; - ICHECK_EQ(args.size() - 1, param_names.size()) - << "The number of provided parameters doesn't match the number of arguments"; - ICHECK_EQ(param_names.size(), vm_func.params_device_type.size()) - << "The number of provided parameters doesn't match the number of assigned devices"; - std::vector func_args(param_names.size()); - for (int i = 1; i < args.size(); ++i) { - Index device_type = vm_func.params_device_type[i - 1]; - Device dev = GetDevice(device_type); - - if (args[i].type_code() == kTVMDLTensorHandle) { - // Automatically convert input DLTensors to NDArray - DLTensor* tensor = args[i]; - std::vector shape; - for (int64_t i = 0; i < tensor->ndim; i++) { - shape.push_back(tensor->shape[i]); - } - NDArray ary = NDArray::Empty(shape, tensor->dtype, dev); - ary.CopyFrom(tensor); - func_args[i - 1] = ary; - } else { - ObjectRef obj = CopyTo(args[i], dev); - func_args[i - 1] = obj; - } - } - inputs_.erase(func_name); - inputs_.emplace(func_name, func_args); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1); }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); } } +void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) { + ICHECK(exec_) << "The executable is not created yet."; + auto gvit = exec_->global_map.find(func_name); + ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name; + auto func_index = gvit->second; + const auto& vm_func = exec_->functions[func_index]; + const auto& param_names = vm_func.params; + ICHECK_EQ(args.size() - offset, param_names.size()) + << "The number of provided parameters doesn't match the number of arguments"; + ICHECK_EQ(param_names.size(), vm_func.params_device_type.size()) + << "The number of provided parameters doesn't match the number of assigned devices"; + std::vector func_args(param_names.size()); + for (int i = offset; i < args.size(); ++i) { + Index device_type = vm_func.params_device_type[i - offset]; + Device dev = GetDevice(device_type); + + if (args[i].type_code() == kTVMDLTensorHandle) { + // Automatically convert input DLTensors to NDArray + DLTensor* tensor = args[i]; + std::vector shape; + for (int64_t i = 0; i < tensor->ndim; i++) { + shape.push_back(tensor->shape[i]); + } + NDArray ary = NDArray::Empty(shape, tensor->dtype, dev); + ary.CopyFrom(tensor); + func_args[i - offset] = ary; + } else { + ObjectRef obj = CopyTo(args[i], dev); + func_args[i - offset] = obj; + } + } + inputs_.erase(func_name); + inputs_.emplace(func_name, func_args); +} + inline Device VirtualMachine::GetDevice(Index device_type) const { ICHECK_GE(devices_.size(), device_type) << "devices_ doesn't contain device:" << device_type; diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index a03801cf511f..f1e0ef587ecc 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index 16987210b232..156f86dbb03e 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -156,6 +156,27 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, device_name = properties.properties.deviceName; driver_version = properties.properties.driverVersion; + switch (properties.properties.deviceType) { + case VK_PHYSICAL_DEVICE_TYPE_OTHER: + device_type = "other"; + break; + case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU: + device_type = "integrated"; + break; + case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU: + device_type = "discrete"; + break; + case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU: + device_type = "virtual"; + break; + case VK_PHYSICAL_DEVICE_TYPE_CPU: + device_type = "cpu"; + break; + default: + LOG(FATAL) << "Unknown vulkan device type: " << properties.properties.deviceType; + break; + } + // By default, use the maximum API version that the driver allows, // so that any supported features can be used by TVM shaders. // However, if we can query the conformance version, then limit to diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index 045628bc9092..412542029209 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -92,7 +92,8 @@ struct VulkanDeviceProperties { uint32_t max_storage_buffer_range{1 << 27}; uint32_t max_per_stage_descriptor_storage_buffer{4}; uint32_t max_shared_memory_per_block{16384}; - std::string device_name{"unknown device name"}; + std::string device_type{"unknown_device_type"}; + std::string device_name{"unknown_device_name"}; uint32_t driver_version{0}; uint32_t vulkan_api_version{VK_API_VERSION_1_0}; uint32_t max_spirv_version{0x10000}; diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 1fede98f7211..3d27e1651852 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -50,6 +50,28 @@ VulkanDeviceAPI::VulkanDeviceAPI() { devices_.push_back(std::move(device)); } } + + // Move discrete GPUs to the start of the list, so the default + // device_id=0 preferentially uses a discrete GPU. + auto preference = [](const VulkanDevice& device) { + const std::string& type = device.device_properties.device_type; + if (type == "discrete") { + return 0; + } else if (type == "integrated") { + return 1; + } else if (type == "virtual") { + return 2; + } else if (type == "cpu") { + return 3; + } else { + return 4; + } + }; + + std::stable_sort(devices_.begin(), devices_.end(), + [&preference](const VulkanDevice& a, const VulkanDevice& b) { + return preference(a) < preference(b); + }); } VulkanDeviceAPI::~VulkanDeviceAPI() {} @@ -100,7 +122,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) break; } case kDeviceName: - *rv = prop.device_name; + *rv = std::string(prop.device_name); break; case kMaxClockRate: @@ -214,9 +236,12 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, if (property == "max_shared_memory_per_block") { *rv = int64_t(prop.max_shared_memory_per_block); } - if (property == ":string device_name") { + if (property == "device_name") { *rv = prop.device_name; } + if (property == "device_type") { + *rv = prop.device_type; + } if (property == "driver_version") { *rv = int64_t(prop.driver_version); } diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h index b8be3eb43c79..851fede3067f 100644 --- a/src/runtime/vulkan/vulkan_device_api.h +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -106,7 +106,7 @@ class VulkanDeviceAPI final : public DeviceAPI { * Returns the results of feature/property queries done during the * device initialization. */ - void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv); + void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) final; private: std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 103b2aa7692c..0712f723bb64 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -33,13 +33,13 @@ namespace vulkan { void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags) { + const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; num_buffer_args_ = num_buffer_args; num_pack_args_ = num_pack_args; - thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); + launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags); } void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, @@ -50,7 +50,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); } const auto& pipeline = scache_[device_id]; - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); std::vector descriptor_buffers; descriptor_buffers.resize(num_buffer_args_); for (size_t i = 0; i < num_buffer_args_; ++i) { @@ -197,7 +197,7 @@ PackedFunc VulkanModuleNode::GetFunction(const std::string& name, VulkanWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, - info.thread_axis_tags); + info.launch_param_tags); return PackFuncNonBufferArg(std::move(f), info.arg_types); } diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index a174f22eba59..cd4774bf0f5a 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -58,7 +58,7 @@ class VulkanWrappedFunc { public: void Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags); + const std::vector& launch_param_tags); void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const; @@ -73,11 +73,10 @@ class VulkanWrappedFunc { size_t num_buffer_args_; // number of packed arguments. size_t num_pack_args_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; // Device state cache per device. // mark as mutable, to enable lazy initialization - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; - mutable std::array, kVulkanMaxNumDevice> scache_; }; diff --git a/src/support/array.h b/src/support/array.h index 2cf416c471ec..89e17433344b 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -18,6 +18,7 @@ */ #ifndef TVM_SUPPORT_ARRAY_H_ #define TVM_SUPPORT_ARRAY_H_ +#include #include #include @@ -67,6 +68,73 @@ inline bool ArrayWithSameContent(const std::vector& a, const std::vector return true; } +/*! + * \brief Convert a tvm::runtime::Array to std::vector + * \tparam TSrc The type of elements in the source Array + * \tparam TDst The type of elements in the result vector + * \return The result vector + */ +template +std::vector AsVector(const Array& vec); + +/********** Implementation details of AsVector **********/ +namespace details { + +template +struct AsVectorImpl {}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& vec) const { + return std::vector(vec.begin(), vec.end()); + } +}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& vec) const { + std::vector results; + for (const TSrcObjectRef& x : vec) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); + } + return results; + } +}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& vec) const { + std::vector results; + for (const TSrcObjectRef& x : vec) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); + } + return results; + } +}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& array) const { + std::vector results; + for (const TSrcObjectRef& x : array) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); + } + return results; + } +}; +} // namespace details + +template +inline std::vector AsVector(const Array& vec) { + return details::AsVectorImpl()(vec); +} + } // namespace support } // namespace tvm #endif // TVM_SUPPORT_ARRAY_H_ diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index a7a91c1bfcdb..c7cec21508ef 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -91,6 +91,14 @@ TVM_REGISTER_GLOBAL("testing.run_check_signal").set_body_typed([](int nsec) { LOG(INFO) << "Function finished without catching signal"; }); +TVM_REGISTER_GLOBAL("testing.identity_cpp").set_body([](TVMArgs args, TVMRetValue* ret) { + const auto* identity_func = tvm::runtime::Registry::Get("testing.identity_py"); + ICHECK(identity_func != nullptr) + << "AttributeError: \"testing.identity_py\" is not registered. Please check " + "if the python module is properly loaded"; + *ret = (*identity_func)(args[0]); +}); + // in src/api_test.cc void ErrorTest(int x, int y) { // raise ValueError diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index 6ff40aa20beb..7317cab665cf 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -87,8 +87,8 @@ #define TVM_INFO_USE_GRAPH_EXECUTOR "NOT-FOUND" #endif -#ifndef TVM_INFO_USE_GRAPH_EXECUTOR_DEBUG -#define TVM_INFO_USE_GRAPH_EXECUTOR_DEBUG "NOT-FOUND" +#ifndef TVM_INFO_USE_PROFILER +#define TVM_INFO_USE_PROFILER "NOT-FOUND" #endif #ifndef TVM_INFO_USE_OPENMP @@ -244,7 +244,7 @@ TVM_DLL Map GetLibInfo() { {"CUDA_VERSION", TVM_INFO_CUDA_VERSION}, {"USE_STACKVM_RUNTIME", TVM_INFO_USE_STACKVM_RUNTIME}, {"USE_GRAPH_EXECUTOR", TVM_INFO_USE_GRAPH_EXECUTOR}, - {"USE_GRAPH_EXECUTOR_DEBUG", TVM_INFO_USE_GRAPH_EXECUTOR_DEBUG}, + {"USE_PROFILER", TVM_INFO_USE_PROFILER}, {"USE_OPENMP", TVM_INFO_USE_OPENMP}, {"USE_RELAY_DEBUG", TVM_INFO_USE_RELAY_DEBUG}, {"USE_RTTI", TVM_INFO_USE_RTTI}, diff --git a/src/target/build_common.h b/src/target/build_common.h index d2fe6468eef8..c66c2b52822e 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -53,7 +53,12 @@ inline std::unordered_map ExtractFuncInfo(co if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { auto thread_axis = opt.value(); for (size_t i = 0; i < thread_axis.size(); ++i) { - info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); + info.launch_param_tags.push_back(thread_axis[i]->thread_tag); + } + } + if (auto opt = f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { + if (opt.value()) { + info.launch_param_tags.push_back(runtime::kUseDynamicSharedMemoryTag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 78f8a50e4e1b..7770e42086de 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -72,50 +72,45 @@ class CodeGenAMDGPU : public CodeGenLLVM { void VisitStmt_(const AllocateNode* op) final { ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - int32_t constant_size = op->constant_allocation_size(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + LOG(WARNING) << "Dynamic shared memory support for rocm is experimental."; + buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16), + llvm::GlobalValue::ExternalLinkage); + } else { + int32_t constant_size = op->constant_allocation_size(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } - // maximum necessary alignment in the AMD devices - if (info.alignment > 16) { - info.alignment = 16; - } - if (info.scope.rank == runtime::StorageRank::kLocal) { - // const int local_address_space = 5; - // TODO(tqchen): for higher version of LLVM, local address space can be set. - llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast(info.alignment)) { -#if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); -#else - alloca->setAlignment(info.alignment); -#endif + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); } - buf = alloca; - } else { - ICHECK(info.scope.rank == runtime::StorageRank::kShared) - << "Can only allocate shared or local memory inside kernel"; - // Shared memory: address space == 3 - const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); - // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, - llvm::GlobalValue::NotThreadLocal, shared_address_space); - if (global->getAlignment() < static_cast(info.alignment)) { + // maximum necessary alignment in the AMD devices + if (info.alignment > 16) { + info.alignment = 16; + } + if (storage_scope.rank == runtime::StorageRank::kLocal) { + // const int local_address_space = 5; + // TODO(tqchen): for higher version of LLVM, local address space can be set. + llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(info.alignment); + alloca->setAlignment(info.alignment); #endif + } + buf = alloca; + } else { + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) + << "Can only allocate shared or local memory inside kernel"; + // Shared memory: address space == 3 + buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment, + llvm::GlobalValue::PrivateLinkage); } - buf = global; } buf = builder_->CreatePointerCast( diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 9d324d56887f..26356a547990 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -704,8 +704,8 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { (void)CallOnce; std::unique_ptr tm = GetLLVMTargetMachine(target); - std::unique_ptr cg(new CodeGenHexagon()); std::unique_ptr ctx(new llvm::LLVMContext()); + std::unique_ptr cg(new CodeGenHexagon()); cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false); for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 48ccefafe3c4..6aabdc1bd804 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -501,7 +501,8 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp auto it = alloc_storage_info_.find(buf_var); if (it != alloc_storage_info_.end()) { const StorageInfo& info = it->second; - *p_native_bits = NativeVectorBits(info.scope); + *p_native_bits = + NativeVectorBits(runtime::StorageScope::Create(GetPtrStorageScope(GetRef(buf_var)))); max_align_bits = info.alignment * 8; } else { *p_native_bits = native_vector_bits_; @@ -523,6 +524,22 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp *p_alignment = align_bits / 8; } +llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t size, + unsigned int shared_address_space, + int alignment, + llvm::GlobalValue::LinkageTypes linkage) { + llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size); + llvm::GlobalVariable* global = + new llvm::GlobalVariable(*module_, type, false, linkage, nullptr, "shmem", nullptr, + llvm::GlobalValue::NotThreadLocal, shared_address_space); +#if TVM_LLVM_VERSION >= 100 + global->setAlignment(llvm::Align(alignment)); +#else + global->setAlignment(alignment); +#endif + return global; +} + std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { #if TVM_LLVM_VERSION >= 100 auto debug_info = std::make_unique(); @@ -844,7 +861,11 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { : llvm::Type::getVoidTy(*ctx_); llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " +#if TVM_LLVM_VERSION <= 130 << llvm::Intrinsic::getName(id, {}); +#else + << llvm::Intrinsic::getName(id, return_type, {}); +#endif return builder_->CreateCall(f, arg_value); } else if (op->op.same_as(builtin::bitwise_and())) { return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1])); @@ -1390,11 +1411,6 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value)); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - alloc_storage_info_[v].scope = - runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); ICHECK(v); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index d5fcfab6d889..52c5b98a0025 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -163,8 +163,6 @@ class CodeGenLLVM : public ExprFunctor, protected: /*! \brief The storage information */ struct StorageInfo { - /*! \brief The storage scope */ - runtime::StorageScope scope; /*! \brief The alignment of allocation */ int alignment{0}; }; @@ -294,6 +292,11 @@ class CodeGenLLVM : public ExprFunctor, const Var& loop_var, const Stmt& body); // add alias information. void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index); + + llvm::GlobalVariable* AllocateSharedMemory(DataType dtype, size_t size, + unsigned int shared_address_space, int alignment, + llvm::GlobalValue::LinkageTypes linkage); + // The IRBuilder. using IRBuilder = llvm::IRBuilder; // The current function diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 9e56529ec9ef..15543eda423f 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -48,48 +48,44 @@ class CodeGenNVPTX : public CodeGenLLVM { void VisitStmt_(const AllocateNode* op) final { ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - - int32_t constant_size = op->constant_allocation_size(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } // maximum necessary alignment in the NV devices if (info.alignment > 16) { info.alignment = 16; } - if (info.scope.rank == runtime::StorageRank::kLocal) { - // const int local_address_space = 5; - // TODO(tqchen): for higher version of LLVM, local address space can be set. - llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast(info.alignment)) { -#if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); -#else - alloca->setAlignment(info.alignment); -#endif - } - buf = alloca; - } else { - ICHECK(info.scope.rank == runtime::StorageRank::kShared) - << "Can only allocate shared or local memory inside kernel"; + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { // Shared memory: address space == 3 - const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); - // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, - llvm::GlobalValue::NotThreadLocal, shared_address_space); + buf = + AllocateSharedMemory(op->dtype, 0, 3, info.alignment, llvm::GlobalValue::ExternalLinkage); + } else { + int32_t constant_size = op->constant_allocation_size(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + } + if (storage_scope.rank == runtime::StorageRank::kLocal) { + // const int local_address_space = 5; + // TODO(tqchen): for higher version of LLVM, local address space can be set. + llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(info.alignment); + alloca->setAlignment(info.alignment); #endif - buf = global; + } + buf = alloca; + } else { + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) + << "Can only allocate shared or local memory inside kernel"; + buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment, + llvm::GlobalValue::PrivateLinkage); + } } buf = builder_->CreatePointerCast( diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 15a1493b8585..12c7a3132947 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -98,7 +98,11 @@ class LLVMModuleNode final : public runtime::ModuleNode { void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = runtime::GetFileFormat(file_name, format); std::error_code ecode; +#if TVM_LLVM_VERSION <= 70 llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None); +#else + llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::OF_None); +#endif ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); if (fmt == "o" || fmt == "obj") { #if TVM_LLVM_VERSION <= 60 diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 99c9452975d4..a311111532c8 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -83,6 +83,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); this->PrintFuncPrefix(); + this->PrintExtraAttrs(f); this->stream << " " << static_cast(global_symbol.value()) << "("; for (size_t i = 0; i < f->params.size(); ++i) { @@ -105,8 +106,8 @@ void CodeGenC::AddFunction(const PrimFunc& f) { } } - if (no_alias && restrict_keyword_.length() != 0) { - stream << ' ' << restrict_keyword_; + if (no_alias) { + PrintRestrict(v, stream); } } else { PrintType(GetType(v), stream); @@ -125,6 +126,8 @@ void CodeGenC::AddFunction(const PrimFunc& f) { void CodeGenC::PrintFuncPrefix() { stream << "void"; } +void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {} + void CodeGenC::PrintFinalReturn() {} std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); } @@ -861,12 +864,11 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - const VarNode* buffer = op->buffer_var.as(); - auto it = alloc_storage_scope_.find(buffer); - if (it != alloc_storage_scope_.end()) { - std::string scope = alloc_storage_scope_.at(buffer); - PrintStorageScope(scope, stream); - } + + auto scope = GetPtrStorageScope(op->buffer_var); + alloc_storage_scope_[op->buffer_var.get()] = scope; + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); stream << ' ' << vid << '[' << constant_size << "];\n"; @@ -882,10 +884,6 @@ void CodeGenC::VisitStmt_(const AttrStmtNode* op) { BindThreadIndex(iv); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - alloc_storage_scope_[v] = op->value.as()->value; } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); ICHECK(v); @@ -1020,6 +1018,12 @@ void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, return; } +void CodeGenC::PrintRestrict(const Var& v, std::ostream& os) { + if (restrict_keyword_.length() != 0) { + os << ' ' << restrict_keyword_; + } +} + static bool CheckOutermostBracketMatch(const std::string& s) { if (!s.empty() && s.front() == '(' && s.back() == ')') { size_t len = s.size(); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index ae451f39f89b..299f7e0a9cef 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -39,6 +39,7 @@ #include #include +#include "../../tir/transforms/ir_utils.h" #include "codegen_source_base.h" namespace tvm { @@ -102,6 +103,12 @@ class CodeGenC : public ExprFunctor, * Example: stream << "void"; */ virtual void PrintFuncPrefix(); // NOLINT(*) + /*! + * \brief Print extra function attributes + * + * Example: __launch_bounds__(256) for CUDA functions + */ + virtual void PrintExtraAttrs(const PrimFunc& f); /*! * \brief Print the final return at the end the function. */ @@ -193,6 +200,8 @@ class CodeGenC : public ExprFunctor, virtual std::string CastFromTo(std::string value, DataType from, DataType target); // Get load of single element with expression virtual void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os); + // Print restrict keyword for a given Var if applicable + virtual void PrintRestrict(const Var& v, std::ostream& os); protected: // Print reference to struct location diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 6e76c3538e71..0aad18ffb6f9 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -23,7 +23,9 @@ #include "codegen_cuda.h" +#include #include +#include #include #include @@ -46,6 +48,45 @@ void CodeGenCUDA::Init(bool output_ssa) { void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; } +class ThreadIdxExtractor : public tir::StmtVisitor { + private: + void VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x") { + threadIdx_x_ext = op->value; + } + if (iv->var->name_hint == "threadIdx.y" || iv->thread_tag == "threadIdx.y") { + threadIdx_y_ext = op->value; + } + if (iv->var->name_hint == "threadIdx.z" || iv->thread_tag == "threadIdx.z") { + threadIdx_z_ext = op->value; + } + } + StmtVisitor::VisitStmt_(op); + } + + public: + PrimExpr threadIdx_x_ext = Integer(1); + PrimExpr threadIdx_y_ext = Integer(1); + PrimExpr threadIdx_z_ext = Integer(1); +}; + +void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) { + ThreadIdxExtractor extractor; + extractor(f->body); + arith::Analyzer analyzer; + PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext * + extractor.threadIdx_z_ext); + if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as()) { + if (threadIdx_ext_int->value == 1) { + // unable to extract the number of threads per block, hence directly return + return; + } + stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; + } +} + std::string CodeGenCUDA::Finish() { if (enable_fp16_) { decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n"; @@ -525,6 +566,8 @@ void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) "all global arrays as input instead"; if (scope == "shared") { os << "__shared__ "; + } else if (scope == "shared.dyn") { + os << "extern __shared__ "; } } @@ -703,14 +746,8 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { std::string vid = AllocVarID(op->buffer_var.get()); this->PrintIndent(); - int32_t constant_size = op->constant_allocation_size(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + std::string scope = GetPtrStorageScope(op->buffer_var); const VarNode* buffer = op->buffer_var.as(); - auto it = alloc_storage_scope_.find(buffer); - ICHECK(it != alloc_storage_scope_.end()) - << "Buffer " << op->buffer_var << " is missing an AttrStmt with a \"storage_scope\" key"; - - std::string scope = it->second; if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || @@ -724,18 +761,28 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { op->dtype == DataType::Int(32)) << "Accumulator only support half, float and int type for now"; } - constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); PrintWmmaScope(scope, op->dtype, buffer, stream); } else { PrintStorageScope(scope, stream); PrintType(op->dtype, stream); } - if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) && - scope == "shared") { - constant_size = constant_size / (32 / op->dtype.bits()); + + if (scope == "shared.dyn") { + stream << ' ' << vid << "[];\n"; + } else { + int32_t constant_size = op->constant_allocation_size(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + + if (scope.find("wmma.") == 0) { + constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); + } + if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || + op->dtype == DataType::Int(1)) && + scope == "shared") { + constant_size = constant_size / (32 / op->dtype.bits()); + } + stream << ' ' << vid << '[' << constant_size << "];\n"; } - stream << ' ' << vid << '[' << constant_size << "];\n"; RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 2098b8ac8344..385b7343c8fd 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -46,6 +46,7 @@ class CodeGenCUDA final : public CodeGenC { } // override behavior void PrintFuncPrefix() final; + void PrintExtraAttrs(const PrimFunc& f) final; void VisitStmt_(const ForNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index edb614d9c122..7abff36a3ddb 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -27,18 +27,63 @@ #include #include "../../runtime/opencl/opencl_module.h" +#include "../../runtime/texture.h" #include "../../runtime/thread_storage_scope.h" #include "../build_common.h" namespace tvm { namespace codegen { -CodeGenOpenCL::CodeGenOpenCL() { restrict_keyword_ = "restrict"; } +class InferTextureAccess : public StmtExprVisitor { + public: + static constexpr const uint8_t kReadAccess = 1; + static constexpr const uint8_t kWriteAccess = 2; + + InferTextureAccess() {} + std::unordered_map Infer(const Stmt& n) { + StmtExprVisitor::VisitStmt(n); + std::unordered_map storage_scope_qualifiers; + for (auto& texture : var_access_map_) { + if (texture.second == kReadAccess) { + storage_scope_qualifiers.insert({texture.first, "texture_read"}); + } else if (texture.second == kWriteAccess) { + storage_scope_qualifiers.insert({texture.first, "texture_write"}); + } else if (texture.second == (kReadAccess | kWriteAccess)) { + storage_scope_qualifiers.insert({texture.first, ""}); + } + } + return storage_scope_qualifiers; + } + void VisitExpr_(const CallNode* op) { + if (op->op.same_as(builtin::texture2d_load())) { + var_access_map_[op->args[0].as()] |= kReadAccess; + } else if (op->op.same_as(builtin::texture2d_store())) { + var_access_map_[op->args[0].as()] |= kWriteAccess; + } else { + StmtExprVisitor::VisitExpr_(op); + } + StmtExprVisitor::VisitExpr_(op); + } + + private: + std::unordered_map var_access_map_; +}; + +CodeGenOpenCL::CodeGenOpenCL() { + // Set OpenCL specific restrict keyword + restrict_keyword_ = "restrict"; +} void CodeGenOpenCL::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); + this->SetTextureScope(InferTextureAccess().Infer(f->body)); for (Var arg : f->params) { - if (arg.dtype().is_handle()) { + auto ptr_type = arg->type_annotation.as(); + if (ptr_type && runtime::IsTextureStorage(std::string(ptr_type->storage_scope))) { + // Storage scope qualifiers for textures are inferred + // and set prior to function codegen. + continue; + } else if (arg.dtype().is_handle()) { alloc_storage_scope_[arg.get()] = "global"; } } @@ -75,6 +120,40 @@ std::string CodeGenOpenCL::Finish() { decl_stream << "#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable\n" "#pragma OPENCL EXTENSION cl_khr_global_int32_extended_atomics : enable\n\n"; } + + // Enable OpenCL 1.2 sampler-less texture reads, but utilize + // provided sampler in OpenCL 2.0. + if (enable_compliant_texture_reads_) { + // TODO(csullivan, lunderberg): Extend device attribute querying to support remote devices + // generically through the device API such that a target can be created from a specific device's + // attributes and utilized during codegen. Potential generlization of #8127 (c02cafb) for remote + // devices. + // + // E.g. Only provide an image sampler when the local or remote device supports OpenCL 2.0, + // see below for context. + // + // For backwards compatibility with OpenCL 1.2, sampler-less read_image calls are used. + // By default in sampler-less read_image calls OpenCL defaults to + // sampler_ = "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST"; + // See section 6.12.14.3 Built-in Image Sampler-less Read Functions in the OpenCL 1.2 + // specification. For OpenCL 2.0 it can be preferable to use, + // sampler_ = "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST"; + // For now we rely on OpenCL preprocessor directives to utilize the correct behavior + // depending on the OpenCL version detected at OpenCL compile time. + decl_stream << "#ifdef __OPENCL_VERSION__\n" + << "#if __OPENCL_VERSION__ == CL_VERSION_2_0\n" + << "#define READ_IMAGEH(image, sampler, coord) " + << "read_imageh(image, sampler, coord)\n" + << "#define READ_IMAGEF(image, sampler, coord) " + << "read_imagef(image, sampler, coord)\n" + << "#else\n" + << "#define READ_IMAGEH(image, sampler, coord) " + << "read_imageh(image, coord)\n" + << "#define READ_IMAGEF(image, sampler, coord) " + << "read_imagef(image, coord)\n" + << "#endif\n" + << "#endif\n\n"; + } return CodeGenC::Finish(); } @@ -162,6 +241,23 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type"; } +void CodeGenOpenCL::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) + if (auto* ptr = type.as()) { + return PrintType(ptr->dtype, os); + } else if (auto* ptr = type.as()) { + if (runtime::IsTextureStorage(std::string(ptr->storage_scope))) { + os << "image2d_t"; + } else { + PrintType(ptr->element_type, os); + os << '*'; + } + } else if (IsVoidType(type)) { + os << "void"; + } else { + LOG(FATAL) << "Type " << type << " does not have a corresponding C Type"; + } +} + void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, std::ostream& os) { // NOLINT(*) if (!HandleTypeMatch(buffer, t.element_of())) { @@ -210,6 +306,19 @@ void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os os << "__global "; } else if (scope == "shared") { os << "__local "; + } else if (scope == "texture_read") { + os << "__read_only "; + } else if (scope == "texture_write") { + os << "__write_only "; + } +} + +void CodeGenOpenCL::PrintRestrict(const Var& v, std::ostream& os) { + // Apply restrict qualifer for non-texture types only + if (auto* ptr = v->type_annotation.as()) { + if (!runtime::IsTextureStorage(std::string(ptr->storage_scope))) { + os << ' ' << restrict_keyword_; + } } } @@ -229,6 +338,39 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType return os.str(); } +void CodeGenOpenCL::VisitStmt_(const StoreNode* op) { + if (auto call = op->value.as()) { + if (call->op.same_as(builtin::texture2d_load())) { + need_texture_ssa_ = false; + // If storing a texture load into a buffer, don't use an + // intermediate local unless the buffer allocation is a + // single element selected from the texture read. + auto it = allocation_size_.find(op->buffer_var.get()); + if (it != allocation_size_.end() && it->second == 1) { + need_texture_ssa_ = true; + } + } + } + CodeGenC::VisitStmt_(op); + need_texture_ssa_ = true; +} + +void CodeGenOpenCL::VisitExpr_(const CastNode* op, std::ostream& os) { + if (auto call = op->value.as()) { + if (call->op.same_as(builtin::texture2d_load())) { + need_texture_ssa_ = false; + } + } + CodeGenC::VisitExpr_(op, os); + need_texture_ssa_ = true; +} + +void CodeGenOpenCL::VisitStmt_(const AllocateNode* op) { + allocation_size_.insert( + {op->buffer_var.get(), op->constant_allocation_size() * op->dtype.lanes()}); + CodeGenC::VisitStmt_(op); +} + void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { if (op->op.same_as(builtin::address_of())) { // Overload tvm_address_of to add storage scope (e.g. __global). @@ -243,6 +385,64 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { os << " *)" << this->GetVarID(load->buffer_var.get()) << " + "; this->PrintExpr(load->index, os); os << ')'; + } else if (op->op.same_as(builtin::texture2d_store())) { + auto* ptr_type = op->args[0].as()->type_annotation.as(); + ICHECK(ptr_type != nullptr) << "Texture Var's must be of PointerType"; + ICHECK(runtime::IsTextureStorage(std::string(ptr_type->storage_scope))) + << "builtin::texture2d_store() only supports storing to texture buffers"; + DataType buffer_type = ptr_type->element_type.as()->dtype; + if (buffer_type.is_float16()) { + os << "write_imageh("; + } else if (buffer_type.is_float()) { + os << "write_imagef("; + } else { + LOG(FATAL) << "Unsupported type: " << buffer_type + << ", currently only float and half are supported for image2d OpenCL codegen."; + } + this->PrintExpr(op->args[0], os); + os << ", "; + os << "(int2)("; + this->PrintExpr(op->args[1], os); + os << ", "; + this->PrintExpr(op->args[2], os); + os << "), "; + this->PrintExpr(op->args[3], os); + os << ")"; + } else if (op->op.same_as(builtin::texture2d_load())) { + enable_compliant_texture_reads_ = true; + std::stringstream ss; + if (op->dtype.is_float16()) { + ss << "READ_IMAGEH("; + } else if (op->dtype.is_float()) { + ss << "READ_IMAGEF("; + } else { + LOG(FATAL) << "Unsupported type: " << op->dtype + << ", currently only float and half are supported for image2d OpenCL codegen."; + } + this->PrintExpr(op->args[0], ss); + ss << ", "; + ss << "CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST, "; + ss << "((int2)("; + this->PrintExpr(op->args[1], ss); + ss << ", "; + this->PrintExpr(op->args[2], ss); + ss << ")))"; + + // Only use local SSA if texture is not already being stored + if (need_texture_ssa_) { + std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(4)); + if (op->args.back().as()) { + os << rhs; + } else { + os << "(("; + this->PrintType(op->dtype.with_lanes(1), os); + os << "*)&" << rhs << ")["; + this->PrintExpr(op->args.back(), os); + os << "]"; + } + } else { + os << ss.str(); + } } else if (op->op.same_as(builtin_call_extern_)) { auto func = Downcast(op->args[0]); // Enable atomics extension if used. @@ -280,6 +480,13 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N } } +void CodeGenOpenCL::SetTextureScope( + const std::unordered_map& scope) { // NOLINT(*) + for (auto& texture : scope) { + alloc_storage_scope_.insert(texture); + } +} + runtime::Module BuildOpenCL(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 32102fec22b9..a8c293c03056 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -27,6 +27,7 @@ #include #include +#include #include "codegen_c.h" @@ -45,18 +46,24 @@ class CodeGenOpenCL final : public CodeGenC { void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintType(const Type& type, std::ostream& os) final; // NOLINT(*) std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) final; void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value) final; // NOLINT(*) // the address of load/store void PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, - std::ostream& os); // NOLINT(*) - std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) + std::ostream& os); // NOLINT(*) + void PrintRestrict(const Var& v, std::ostream& os) final; // NOLINT(*) + std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) + void SetTextureScope(const std::unordered_map&); // NOLINT(*) // overload visitor - void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) + void VisitStmt_(const AllocateNode* op) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) + void VisitStmt_(const StoreNode* op) final; // NOLINT(*) private: // whether enable fp16 and fp64 extension @@ -64,6 +71,15 @@ class CodeGenOpenCL final : public CodeGenC { bool enable_fp64_{false}; // Whether to enable atomics extension. bool enable_atomics_{false}; + // Whether to enable sampler or sampler-less texture reads, + // where the choice depends on the OpenCL version used. + bool enable_compliant_texture_reads_{false}; + // Key to disable use of texture SSA in certain scenarios. For example, + // when loaded value is stored directly to a user declared l-value buffer + bool need_texture_ssa_{true}; + // Mapping from buffer to allocation size. + // Useful to track when a scalar store of a vectorized texture load is required. + std::unordered_map allocation_size_; }; } // namespace codegen diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 288bb2cfc069..64a50c3c84b1 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -49,6 +49,9 @@ TVM_REGISTER_OP("tir.round") TVM_REGISTER_OP("tir.exp").set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); +TVM_REGISTER_OP("tir.erf").set_attr("opencl.FLowerIntrinsic", + DispatchPureExtern); + TVM_REGISTER_OP("tir.exp2") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index ac4d7e3666ea..7728773b13d7 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -192,25 +192,26 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { << "}\n"; } - void GenerateEntrypointForUnpackedAPI(const std::string& run_func) { + void GenerateEntrypointForUnpackedAPI(const std::string& entrypoint_name, + const std::string& run_func) { code_ << "TVM_DLL int32_t " << run_func << "("; - int total_args = (metadata_->num_inputs + metadata_->num_outputs); - for (int i = 0; i < total_args; ++i) { - code_ << "arg" << i; + unsigned int total_args = (metadata_->inputs.size() + metadata_->num_outputs); + for (unsigned int i = 0; i < total_args; ++i) { + code_ << "void* arg" << i; if (i + 1 != total_args) { code_ << ","; } } code_ << ");\n"; - code_ << "static int32_t " << ::tvm::runtime::symbol::tvm_module_main; + code_ << "int32_t " << entrypoint_name; code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " "out_type_code, void* resource_handle) {\n"; code_ << "return " << run_func << "("; - for (int i = 0; i < metadata_->num_inputs; ++i) { + for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) { code_ << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,"; } for (int i = 0; i < metadata_->num_outputs; ++i) { - int j = metadata_->num_inputs + i; + int j = metadata_->inputs.size() + i; code_ << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data"; if (i + 1 != metadata_->num_outputs) { code_ << ","; @@ -220,11 +221,12 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "}\n"; } - void GenerateEntrypointForPackedAPI(const std::string& run_func) { + void GenerateEntrypointForPackedAPI(const std::string& entrypoint_name, + const std::string& run_func) { code_ << "TVM_DLL int32_t " << run_func; code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " "out_type_code, void* resource_handle);\n"; - code_ << "static int32_t " << ::tvm::runtime::symbol::tvm_module_main; + code_ << "int32_t " << entrypoint_name; code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " "out_type_code, void* resource_handle) {\n"; code_ << "return " << run_func; @@ -232,25 +234,70 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "}\n"; } + void GenerateCInterfaceEntrypoint(const std::string& entrypoint_name, const std::string& run_func, + const std::string& mod_name) { + code_ << "#include <" << mod_name << ".h>\n"; + code_ << "TVM_DLL int32_t " << run_func << "("; + unsigned int total_args = (metadata_->inputs.size() + metadata_->num_outputs); + for (unsigned int i = 0; i < total_args; ++i) { + code_ << "void* arg" << i; + if (i + 1 != total_args) { + code_ << ","; + } + } + code_ << ");\n"; + code_ << "int32_t " << entrypoint_name << "("; + code_ << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs," + << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs" + << ") {"; + code_ << "return " << run_func << "("; + for (const auto& input : metadata_->inputs) { + code_ << "inputs->" << input << ","; + } + if (metadata_->num_outputs == 1) { + code_ << "outputs->output"; + } else { + for (int i = 0; i < metadata_->num_outputs; ++i) { + code_ << "outputs->output" << i; + if (i + 1 != metadata_->num_outputs) { + code_ << ","; + } + } + } + code_ << ");\n"; + code_ << "}\n"; + } + void GenerateAOTDescriptor() { - const std::string run_func = ::tvm::runtime::symbol::tvm_run_func_suffix; - const std::string run_func_mangled = runtime::get_name_mangled(metadata_->mod_name, run_func); + const std::string run_func_suffix = ::tvm::runtime::symbol::tvm_run_func_suffix; + const std::string tvm_entrypoint_suffix = ::tvm::runtime::symbol::tvm_entrypoint_suffix; + const std::string run_func_mangled = + runtime::get_name_mangled(metadata_->mod_name, run_func_suffix); + const std::string entrypoint_mangled = + runtime::get_name_mangled(metadata_->mod_name, tvm_entrypoint_suffix); const std::string network_mangled = runtime::get_name_mangled(metadata_->mod_name, "network"); - code_ << "#include \"tvm/runtime/crt/internal/aot_executor/aot_executor.h\"\n"; + auto unpacked_api = target_->GetAttr("unpacked-api").value_or(Bool(false)); + auto interface_api = target_->GetAttr("interface-api").value_or(String("packed")); + code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n"; code_ << "#ifdef __cplusplus\n"; - code_ << "extern \"C\"\n"; + code_ << "extern \"C\" {\n"; code_ << "#endif\n"; - if (target_->GetAttr("unpacked-api").value_or(Bool(false))) { - GenerateEntrypointForUnpackedAPI(run_func_mangled); + + if (unpacked_api) { + if (interface_api == "c") { + GenerateCInterfaceEntrypoint(entrypoint_mangled, run_func_mangled, metadata_->mod_name); + } else { + GenerateEntrypointForUnpackedAPI(entrypoint_mangled, run_func_mangled); + } } else { - GenerateEntrypointForPackedAPI(run_func_mangled); + ICHECK_EQ(interface_api, "packed") << "Packed interface required for packed operators"; + GenerateEntrypointForPackedAPI(entrypoint_mangled, run_func_mangled); } - code_ << "const tvm_model_t " << network_mangled << " = {\n" - << " .run_func = &" << ::tvm::runtime::symbol::tvm_module_main << ",\n" - << " .num_input_tensors = " << metadata_->num_inputs << ",\n" - << " .num_output_tensors = " << metadata_->num_outputs << ", \n" - << "};\n"; + + code_ << "#ifdef __cplusplus\n"; + code_ << "}\n"; + code_ << "#endif\n"; } void CreateSource() { diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 5d52bee44e98..66952dae269e 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -32,6 +32,7 @@ #include "../../runtime/pack_args.h" #include "../../runtime/vulkan/vulkan_common.h" #include "../../runtime/vulkan/vulkan_shader.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { namespace codegen { @@ -42,7 +43,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: this->InitFuncState(); ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; - uint32_t num_buffer = 0; + uint32_t i_buffer = 0; // Currently, all storage and uniform buffer arguments are passed as // a single descriptor set at index 0. If ever non-zero, must @@ -52,24 +53,25 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: for (Var arg : f->params) { DataType t = arg.dtype(); if (t.is_handle()) { - if (auto* ptr = arg->type_annotation.as()) { - auto* prim = ptr->element_type.as(); - ICHECK(prim); - DataType value_storage_type = prim->dtype; - if (value_storage_type == DataType::UInt(1)) { - // We need a physically addressable buffer type to support boolean tensors. - // The loaded byte is cast to bool inside the LoadNode visitor below. - value_storage_type = DataType::UInt(8); - } - spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type), - descriptor_set, num_buffer); - builder_->SetName(arg_value, arg->name_hint); - storage_info_[arg.get()].UpdateContentType(value_storage_type); - var_map_[arg.get()] = arg_value; - } else { - LOG(FATAL) << "require all handles to be typed"; + auto* ptr = arg->type_annotation.as(); + ICHECK(ptr) << "All handles passed to the Vulkan codegen must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; + auto* prim = ptr->element_type.as(); + ICHECK(prim) << "All handles passed to the Vulkan codegen must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; + DataType value_storage_type = prim->dtype; + if (value_storage_type == DataType::Bool()) { + // We need a physically addressable buffer type to support boolean tensors. + // The loaded byte is cast to bool inside the LoadNode visitor below. + value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes()); } - ++num_buffer; + spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type), + descriptor_set, i_buffer++); + builder_->SetName(arg_value, arg->name_hint); + storage_info_[arg.get()].SetContentType(value_storage_type, arg->name_hint); + var_map_[arg.get()] = arg_value; } else { pod_args.push_back(arg); } @@ -94,7 +96,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: } else { shader.flag |= 1 << runtime::vulkan::ShaderMetaDataFlagMask::kUseUBO; // If we need to pass more arguments than push constants could handle, we use UBO. - spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, num_buffer); + spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, i_buffer++); for (size_t i = 0; i < pod_args.size(); ++i) { spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast(i)); var_map_[pod_args[i].get()] = value; @@ -108,6 +110,14 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: builder_->CommitKernelFunction(func_ptr, name); + ICHECK_LE(shared_memory_bytes_used_, spirv_support_.max_shared_memory_per_block) + << "Vulkan shader " << name << " uses " << shared_memory_bytes_used_ + << " bytes of shared memory, " + << "but target supports only " << spirv_support_.max_shared_memory_per_block << " bytes. " + << "If the device supports this allocation, " + << "please add -max_shared_memory_per_block=NBYTES to the target, " + << "or query all device parameters by adding -from_device=0."; + shader.data = builder_->Finalize(); return shader; } @@ -119,6 +129,7 @@ void CodeGenSPIRV::InitFuncState() { analyzer_.reset(new arith::Analyzer()); builder_.reset(new spirv::IRBuilder(spirv_support_)); builder_->InitHeader(); + shared_memory_bytes_used_ = 0; } spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) { @@ -403,14 +414,19 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { ICHECK(is_one(op->predicate)); - auto it = storage_info_.find(op->buffer_var.get()); + + DataType desired_read_type = op->dtype; + if (desired_read_type == DataType::Bool()) { + desired_read_type = boolean_storage_type_.with_lanes(desired_read_type.lanes()); + } + + const VarNode* buffer_var = op->buffer_var.get(); + auto it = storage_info_.find(buffer_var); ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; - if (!info.content_fixed) { - info.UpdateContentType(op->dtype); - } + info.CheckContentType(desired_read_type, op->index.dtype().lanes()); - spirv::SType content_type = builder_->GetSType(info.content_type); + spirv::SType content_type = builder_->GetSType(info.element_type); spirv::Value buffer = MakeValue(op->buffer_var); spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); @@ -418,47 +434,38 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { if (info.is_volatile) { mask |= spv::MemoryAccessVolatileMask; } - if (op->dtype.lanes() == 1) { + + if (desired_read_type == info.element_type) { + // Requested a single value from an array. This may be a scalar load + // or a vectorized load, based on the array element type. spirv::Value index = MakeValue(op->index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); - if (op->dtype == DataType::UInt(1)) { - // A bool tensor is backed by a byte buffer, we cast to bool here. - auto bool_ty = builder_->GetSType(DataType::UInt(1)); - return builder_->Cast(bool_ty, loaded); - } else { - ICHECK_EQ(info.content_type, op->dtype) - << "Vulkan only allow one type access to the same buffer"; - return loaded; + // OpTypeBool have no physical address/storage. Here, cast from + // the storage type to an OpTypeBool. + if (op->dtype == DataType::Bool()) { + auto spirv_bool = builder_->GetSType(DataType::Bool()); + loaded = builder_->Cast(spirv_bool, loaded); } + return loaded; + + } else if (desired_read_type.element_of() == info.element_type) { + // Requested several elements returned as an array. Read out each + // element and concatenate into the result. + std::vector values; + auto f = [&](int i, spirv::Value index) { + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); + values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); + }; + this->Scalarize(op->index, f); + return builder_->Concat(values); + } else { - if (op->dtype.element_of() == info.content_type) { - // because content type is element type, we can only do scalarize load. - std::vector values; - auto f = [&](int i, spirv::Value index) { - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); - values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); - }; - this->Scalarize(op->index, f); - return builder_->Concat(values); - } else { - if (const RampNode* ramp = op->index.as()) { - if (is_one(ramp->stride)) { - ICHECK_EQ(ramp->lanes, op->dtype.lanes()); - arith::ModularSet me = analyzer_->modular_set(ramp->base); - ICHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) - << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = - analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index)); - return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); - } - } - } - LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV"; + LOG(FATAL) << "Cannot perform buffer access of buffer variable '" << buffer_var->name_hint + << "' with element type " << info.element_type << " using index of type " + << op->index->dtype << " to produce output of type " << op->dtype; + return spirv::Value(); } - LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV"; - return spirv::Value(); } void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function f) { @@ -481,12 +488,9 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { auto it = storage_info_.find(op->buffer_var.get()); ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; + info.CheckContentType(op->value.dtype(), op->index.dtype().lanes()); - if (!info.content_fixed) { - info.UpdateContentType(op->value.dtype()); - } - - spirv::SType content_type = builder_->GetSType(info.content_type); + spirv::SType content_type = builder_->GetSType(info.element_type); spirv::Value buffer = MakeValue(op->buffer_var); spirv::Value value = MakeValue(op->value); spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); @@ -496,37 +500,29 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { mask |= spv::MemoryAccessVolatileMask; } - if (op->value.dtype().lanes() == 1) { - ICHECK_EQ(info.content_type, op->value.dtype()) + if (op->value.dtype() == info.element_type) { + // Requested store of a single value. This may be a scalar store + // or a vectorized store, based on the array element type. + ICHECK_EQ(info.element_type, op->value.dtype()) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(op->index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, value, mask); + + } else if (op->value.dtype().element_of() == info.element_type) { + // Requested store of several arbitrarily located values. Extract + // each value from the composite, then assign to the buffer. + auto f = [&](int i, spirv::Value index) { + spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); + builder_->MakeInst(spv::OpStore, ptr, elem, mask); + }; + this->Scalarize(op->index, f); + } else { - if (op->value.dtype().element_of() == info.content_type) { - // because content type is element type, we can only do scalarize load. - auto f = [&](int i, spirv::Value index) { - spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i); - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); - builder_->MakeInst(spv::OpStore, ptr, elem, mask); - }; - this->Scalarize(op->index, f); - } else { - if (const RampNode* ramp = op->index.as()) { - if (is_one(ramp->stride)) { - ICHECK_EQ(ramp->lanes, op->value.dtype().lanes()); - arith::ModularSet me = analyzer_->modular_set(ramp->base); - ICHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) - << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = - analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index)); - builder_->MakeInst(spv::OpStore, ptr, value, mask); - return; - } - } - LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV"; - } + LOG(FATAL) << "Cannot store value of type " << op->value.dtype() << " into buffer variable '" + << op->buffer_var->name_hint << "' with element type " << info.element_type + << " using index of type " << op->index->dtype; } } @@ -644,24 +640,30 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { ICHECK(!op->dtype.is_handle()); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + spirv::Value buf; - StorageInfo& info = storage_info_[op->buffer_var.get()]; + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); spirv::SType etype = builder_->GetSType(op->dtype); - if (info.scope.rank == runtime::StorageRank::kLocal) { + if (storage_scope.rank == runtime::StorageRank::kLocal) { buf = builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassFunction); - } else if (info.scope.rank == runtime::StorageRank::kShared) { + } else if (storage_scope.rank == runtime::StorageRank::kShared) { // Shared memory buf = builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassWorkgroup); + + size_t num_bytes = op->dtype.bytes() * op->dtype.lanes() * static_cast(constant_size); + shared_memory_bytes_used_ += num_bytes; } else { LOG(FATAL) << "Can only allocate shared or local memory inside kernel"; } builder_->SetName(buf, op->buffer_var->name_hint); - ICHECK(!info.content_fixed); - info.UpdateContentType(op->dtype); + StorageInfo& info = storage_info_[op->buffer_var.get()]; + ICHECK(!info.element_type_known); + info.SetContentType(op->dtype, op->buffer_var->name_hint); + ICHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); @@ -677,10 +679,6 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { var_map_[iv->var.get()] = GetThreadIndex(iv, op->value); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - storage_info_[v].scope = runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); ICHECK(v); diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 3868322a74e0..74b62e7613d1 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -114,51 +114,110 @@ class CodeGenSPIRV : public ExprFunctor, void VisitStmt_(const EvaluateNode* op) override; protected: - /*! \brief The storage information */ + /*! \brief Storage information for a buffer */ struct StorageInfo { - /*! \brief The storage scope */ - runtime::StorageScope scope; + /*! \brief The name of the tir::Var for the buffer + * + * Used for error messages. + */ + std::string name_hint; + /*! \brief Whether it is volatile */ bool is_volatile{false}; - /*! \brief Whether it is volatile */ - bool content_fixed{false}; - /*! \brief Current content type */ - DataType content_type{DataType::Handle()}; - - // Update content type if it hasn't beenupdated. - void UpdateContentType(DataType type) { - if (content_fixed) { - ICHECK_EQ(type, content_type) << "Cannot use two different content type in GLSL model"; - } else { - this->content_type = type; - content_fixed = true; - } + + /*! \brief Whether the element type of the buffer is known. + * + * This value is determined based on the type_annotation of the + * buffer variable (AllocateNode) or of the parameter (shader + * arguments). + */ + bool element_type_known{false}; + + /*! \brief The known element type of the buffer. + * + * This value is determined based on the type_annotation of the + * buffer variable (AllocateNode) or of the parameter (shader + * arguments). + */ + DataType element_type{DataType()}; + + /* \brief Check that the access type matches the known type + * + * Asserts that the type given is the same as the type previously + * stored in this array. + * + * @param type The data type being stored/loaded in the buffer + * + * @param index_lanes The number of lanes of the index. The + * number of lanes in the value being stored/loaded should be the + * product of the number of lanes of the buffer element type and + * the number of lanes of the index. + */ + void CheckContentType(DataType type, int index_lanes = 1) { + ICHECK(element_type_known) << "Cannot check element type of buffer " << name_hint + << " no previous element type defined"; + DataType expected_type = element_type.with_lanes(index_lanes * element_type.lanes()); + ICHECK_EQ(type, expected_type) << "Attempted to access buffer " << name_hint + << " as element type " << type << " using an index of size " + << index_lanes << " when the element type is " << element_type; + } + + // Update content type if it hasn't been updated. + void SetContentType(DataType type, std::string name_hint) { + ICHECK(!element_type_known) << "Cannot set element type of buffer " << name_hint + << " a second time."; + this->element_type = type; + this->name_hint = name_hint; + element_type_known = true; } }; // Reset the state so it works for a new function. void InitFuncState(); // Get the thread index spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent); + spirv::Value CreateStorageSync(const CallNode* op); void Scalarize(const PrimExpr& e, std::function f); + // SPIRV-related capabilities of the target SPIRVSupport spirv_support_; + // The builder std::unique_ptr builder_; + // Work group size of three uint32_t workgroup_size_[3]; + // Likely branch uint32_t weight_likely_branch_{128}; + + /* The data type used for the backing array for booleans. + * + * Currently matched to the data type used in Buffer::vstore and + * Buffer::vload. In the future, this should be the smallest + * integer type supported by the device, as not all Vulkan + * implementations support int8. + */ + DataType boolean_storage_type_{DataType::Int(8)}; + // the storage scope of allocation std::unordered_map storage_info_; + // The definition of local variable. std::unordered_map var_map_; + // The analyzer. std::unique_ptr analyzer_; + // deep comparison of PrimExpr ExprDeepEqual deep_equal_; + // binding of let variables. Enables duplicate var defs that map to same value std::unordered_map let_binding_; + + // Running total of the number of bytes of shared memory used. + // Checked against the max_shared_memory_per_group + size_t shared_memory_bytes_used_{0}; }; } // namespace codegen diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index 4a294d56bd9c..0f1207f3e9a8 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -52,6 +52,9 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) { if (target->GetAttr("max_storage_buffer_range")) { max_storage_buffer_range = target->GetAttr("max_storage_buffer_range").value(); } + if (target->GetAttr("max_shared_memory_per_block")) { + max_shared_memory_per_block = target->GetAttr("max_shared_memory_per_block").value(); + } if (target->GetAttr("max_per_stage_descriptor_storage_buffer")) { max_per_stage_descriptor_storage_buffers = target->GetAttr("max_per_stage_descriptor_storage_buffer").value(); diff --git a/src/target/spirv/spirv_support.h b/src/target/spirv/spirv_support.h index 1497c7c6333a..04d13cca5031 100644 --- a/src/target/spirv/spirv_support.h +++ b/src/target/spirv/spirv_support.h @@ -101,6 +101,22 @@ struct SPIRVSupport { */ uint32_t max_storage_buffer_range{1 << 27}; + /*! + * \brief The maximum amount of shared memory usable by a shader + * + * Vulkan extension: N/A + * Vulkan struct: VkPhysicalDeviceLimits + * Device Property: maxComputeSharedMemorySize + * SPV Extension name: N/A + * SPV Capability: N/A + * + * The maximum amount of shared memory (Workgroup scope) that may be + * allocated by a shader. Default value is from Vulkan spec, + * "Required Limits" table. Implementations may have a larger + * limit. + */ + uint32_t max_shared_memory_per_block{16384}; + /*! * \brief The maximum number of storage buffers accessible by a single shader. * diff --git a/src/target/target.cc b/src/target/target.cc index df810185784e..e0b9539380d7 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -21,6 +21,7 @@ * \file src/target/target.cc */ #include +#include #include #include #include @@ -29,6 +30,7 @@ #include #include +#include #include #include "../runtime/object_internal.h" @@ -57,6 +59,9 @@ class TargetInternal { n->host = target_host; return (Target)n; } + + private: + static std::unordered_map QueryDevice(int device_id, const TargetNode* target); }; /********** Helper functions **********/ @@ -146,17 +151,83 @@ static int FindFirstSubstr(const std::string& str, const std::string& substr) { } static Optional JoinString(const std::vector& array, char separator) { + char escape = '\\'; + char quote = '\''; + if (array.empty()) { return NullOpt; } + std::ostringstream os; - os << array[0]; - for (size_t i = 1; i < array.size(); ++i) { - os << separator << array[i]; + + for (size_t i = 0; i < array.size(); ++i) { + if (i > 0) { + os << separator; + } + + std::string str = array[i]; + + if ((str.find(separator) == std::string::npos) && (str.find(quote) == std::string::npos)) { + os << str; + } else { + os << quote; + for (char c : str) { + if (c == quote) { + os << escape; + } + os << c; + } + os << quote; + } } return String(os.str()); } +static std::vector SplitString(const std::string& str, char separator) { + char escape = '\\'; + char quote = '\''; + + std::vector output; + + const char* start = str.data(); + const char* end = start + str.size(); + const char* pos = start; + + std::stringstream current_word; + + auto finish_word = [&]() { + std::string word = current_word.str(); + if (word.size()) { + output.push_back(word); + current_word.str(""); + } + }; + + bool pos_quoted = false; + + while (pos < end) { + if ((*pos == separator) && !pos_quoted) { + finish_word(); + pos++; + } else if ((*pos == escape) && (pos + 1 < end) && (pos[1] == quote)) { + current_word << quote; + pos += 2; + } else if (*pos == quote) { + pos_quoted = !pos_quoted; + pos++; + } else { + current_word << *pos; + pos++; + } + } + + ICHECK(!pos_quoted) << "Mismatched quotes '' in string"; + + finish_word(); + + return output; +} + static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key, std::string* value) { int pos; @@ -206,9 +277,9 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi ObjectRef TargetInternal::ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info) { - std::istringstream is(str); if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing integer + std::istringstream is(str); int v; if (!(is >> v)) { std::string lower(str.size(), '\x0'); @@ -225,19 +296,18 @@ ObjectRef TargetInternal::ParseType(const std::string& str, } return Integer(v); } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - // Parsing string - std::string v; - if (!(is >> v)) { - throw Error(": Cannot parse into type \"String\" from string: " + str); - } - return String(v); + // Parsing string, strip leading/trailing spaces + auto start = str.find_first_not_of(' '); + auto end = str.find_last_not_of(' '); + return String(str.substr(start, (end - start + 1))); + } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing target return Target(TargetInternal::FromString(str)); } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) { // Parsing array std::vector result; - for (std::string substr; std::getline(is, substr, ',');) { + for (const std::string& substr : SplitString(str, ',')) { try { ObjectRef parsed = TargetInternal::ParseType(substr, *info.key); result.push_back(parsed); @@ -549,24 +619,14 @@ ObjectPtr TargetInternal::FromConfigString(const String& config_str) { } ObjectPtr TargetInternal::FromRawString(const String& target_str) { + ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string"; // Split the string by empty spaces - std::string name; - std::vector options; - std::string str; - for (std::istringstream is(target_str); is >> str;) { - if (name.empty()) { - name = str; - } else { - options.push_back(str); - } - } - if (name.empty()) { - throw Error(": Cannot parse empty target string"); - } + std::vector options = SplitString(std::string(target_str), ' '); + std::string name = options[0]; // Create the target config std::unordered_map config = {{"kind", String(name)}}; TargetKind kind = GetTargetKind(name); - for (size_t iter = 0, end = options.size(); iter < end;) { + for (size_t iter = 1, end = options.size(); iter < end;) { std::string key, value; try { // Parse key-value pair @@ -673,6 +733,21 @@ ObjectPtr TargetInternal::FromConfig(std::unordered_map(attrs.at("from_device")); + attrs.erase("from_device"); + auto device_params = QueryDevice(device_id, target.get()); + + for (const auto& kv : device_params) { + if (attrs.count(kv.first) == 0) { + attrs[kv.first] = kv.second; + } + } + } + // set default attribute values if they do not exist for (const auto& kv : target->kind->key2default_) { if (!attrs.count(kv.first)) { @@ -688,6 +763,69 @@ ObjectPtr TargetInternal::FromConfig(std::unordered_map TargetInternal::QueryDevice(int device_id, + const TargetNode* target) { + std::unordered_map output; + + Device device{static_cast(target->kind->device_type), device_id}; + + auto api = runtime::DeviceAPI::Get(device, true); + if (!api) { + LOG(INFO) << "Requested reading the parameters for " << target->kind->name << " from device_id " + << device_id << ", but support for this runtime wasn't enabled at compile-time. " + << "Using default target parameters."; + return output; + } + + TVMRetValue ret; + api->GetAttr(device, runtime::kExist, &ret); + if (!ret) { + ICHECK(ret) << "Requested reading the parameters for " << target->kind->name + << " from device_id " << device_id << ", but device_id " << device_id + << " doesn't exist. Using default target parameters."; + return output; + } + + for (const auto& kv : target->kind->key2vtype_) { + const String& key = kv.first; + const TargetKindNode::ValueTypeInfo& type_info = kv.second; + + TVMRetValue ret; + api->GetTargetProperty(device, key, &ret); + + switch (ret.type_code()) { + case kTVMNullptr: + // Nothing returned for this parameter, move on to the next one. + continue; + + case kTVMArgInt: + if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + output[key] = Integer(static_cast(ret)); + } else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + output[key] = Bool(static_cast(ret)); + } else { + LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key + << "', but received integer from device api"; + } + break; + + case kTVMStr: + ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex()) + << "Expected " << type_info.type_key << " parameter for attribute '" << key + << "', but received string from device api"; + output[key] = String(ret.operator std::string()); + break; + + default: + LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key + << "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api"; + break; + } + } + + return output; +} + /********** Registry **********/ TVM_REGISTER_GLOBAL("target.Target").set_body(TargetInternal::ConstructorDispatcher); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index d037b9dfdbdb..97317b5c4800 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -209,85 +209,6 @@ Map UpdateROCmAttrs(Map attrs) { return attrs; } -/*! - * \brief Update the attributes in the Vulkan target. - * \param attrs The original attributes - * \return The updated attributes - */ -Map UpdateVulkanAttrs(Map attrs) { - if (attrs.count("from_device")) { - int device_id = Downcast(attrs.at("from_device")); - Device device{kDLVulkan, device_id}; - const PackedFunc* get_target_property = - runtime::Registry::Get("device_api.vulkan.get_target_property"); - ICHECK(get_target_property) - << "Requested to read Vulkan parameters from device, but no Vulkan runtime available"; - - // Current vulkan implementation is partially a proof-of-concept, - // with long-term goal to move the -from_device functionality to - // TargetInternal::FromConfig, and to be usable by all targets. - // The duplicate list of parameters is needed until then, since - // TargetKind::Get("vulkan")->key2vtype_ is private. - std::vector bool_opts = { - "supports_float16", "supports_float32", - "supports_float64", "supports_int8", - "supports_int16", "supports_int32", - "supports_int64", "supports_8bit_buffer", - "supports_16bit_buffer", "supports_storage_buffer_storage_class", - "supports_push_descriptor", "supports_dedicated_allocation"}; - std::vector int_opts = {"supported_subgroup_operations", - "max_num_threads", - "thread_warp_size", - "max_block_size_x", - "max_block_size_y", - "max_block_size_z", - "max_push_constants_size", - "max_uniform_buffer_range", - "max_storage_buffer_range", - "max_per_stage_descriptor_storage_buffer", - "max_shared_memory_per_block", - "driver_version", - "vulkan_api_version", - "max_spirv_version"}; - std::vector str_opts = {"device_name"}; - - for (auto& key : bool_opts) { - if (!attrs.count(key)) { - attrs.Set(key, Bool(static_cast((*get_target_property)(device, key)))); - } - } - for (auto& key : int_opts) { - if (!attrs.count(key)) { - attrs.Set(key, Integer(static_cast((*get_target_property)(device, key)))); - } - } - for (auto& key : str_opts) { - if (!attrs.count(key)) { - attrs.Set(key, (*get_target_property)(device, key)); - } - } - - attrs.erase("from_device"); - } - - // Set defaults here, rather than in the .add_attr_option() calls. - // The priority should be user-specified > device-query > default, - // but defaults defined in .add_attr_option() are already applied by - // this point. Longer-term, would be good to add a - // `DeviceAPI::GetTargetProperty` function and extend "from_device" - // to work for all runtimes. - std::unordered_map defaults = {{"supports_float32", Bool(true)}, - {"supports_int32", Bool(true)}, - {"max_num_threads", Integer(256)}, - {"thread_warp_size", Integer(1)}}; - for (const auto& kv : defaults) { - if (!attrs.count(kv.first)) { - attrs.Set(kv.first, kv.second); - } - } - return attrs; -} - /********** Register Target kinds and attributes **********/ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) @@ -299,6 +220,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("runtime") .add_attr_option("link-params", Bool(false)) .add_attr_option("unpacked-api") + .add_attr_option("interface-api") .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("c", kDLCPU) @@ -310,6 +232,7 @@ TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("executor") .add_attr_option("workspace-byte-alignment") .add_attr_option("unpacked-api") + .add_attr_option("interface-api") .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) @@ -360,14 +283,13 @@ TVM_REGISTER_TARGET_KIND("metal", kDLMetal) TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("system-lib") - .add_attr_option("from_device") // Feature support .add_attr_option("supports_float16") - .add_attr_option("supports_float32") + .add_attr_option("supports_float32", Bool(true)) .add_attr_option("supports_float64") .add_attr_option("supports_int8") .add_attr_option("supports_int16") - .add_attr_option("supports_int32") + .add_attr_option("supports_int32", Bool(true)) .add_attr_option("supports_int64") .add_attr_option("supports_8bit_buffer") .add_attr_option("supports_16bit_buffer") @@ -376,8 +298,8 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("supports_dedicated_allocation") .add_attr_option("supported_subgroup_operations") // Physical device limits - .add_attr_option("max_num_threads") - .add_attr_option("thread_warp_size") + .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("thread_warp_size", Integer(1)) .add_attr_option("max_block_size_x") .add_attr_option("max_block_size_y") .add_attr_option("max_block_size_z") @@ -387,13 +309,13 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("max_per_stage_descriptor_storage_buffer") .add_attr_option("max_shared_memory_per_block") // Other device properties + .add_attr_option("device_type") .add_attr_option("device_name") .add_attr_option("driver_version") .add_attr_option("vulkan_api_version") .add_attr_option("max_spirv_version") // Tags - .set_default_keys({"vulkan", "gpu"}) - .set_attrs_preprocessor(UpdateVulkanAttrs); + .set_default_keys({"vulkan", "gpu"}); TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) .add_attr_option("system-lib") diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc index 76fed053fda1..240adf14b31d 100644 --- a/src/te/autodiff/ad_simplify.cc +++ b/src/te/autodiff/ad_simplify.cc @@ -834,7 +834,7 @@ std::pair ImplicationNotContainingVars( return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) && (pair_b.first || pair_a.second) && (pair_a.second || pair_b.second)}; - } else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) { + } else if (!tir::UsesVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) { return {cond, const_true()}; } else { return {const_true(), cond}; @@ -1014,7 +1014,7 @@ PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond, // Keep only those variables of the new vars which are used in the new_expr Array used_res_variables; for (const Var& var : res->dst->variables) { - if (ExprUseVar(new_expr, var)) { + if (tir::UsesVar(new_expr, [&var](const VarNode* var_) { return var_ == var.get(); })) { ICHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred."; used_res_variables.push_back(var); } diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 9a4eadb35619..c73a6e0ce120 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -260,7 +260,7 @@ void BaseComputeOpNode::GatherBound(const Operation& self, Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { ICHECK_EQ(stage->op.get(), this); Region bounds; for (IterVar iv : this->axis) { @@ -269,7 +269,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, Stmt realize = body; for (int i = this->num_outputs(); i > 0; --i) { Tensor t = stage->op.output(i - 1); - realize = tir::ProducerRealize(t, bounds, const_true(), realize); + realize = tir::ProducerRealize(t, bounds, const_true(), realize, storage_scope); // alignment requirement, only useful for compute for (size_t i = 0; i < num_schedulable_dims(); ++i) { auto it = stage->iter_var_attrs.find(this->axis[i]); @@ -591,7 +591,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_mapshape, tensor->dtype, tensor->GetNameHint()); + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); info->tensor2buffers[tensor] = buffer; // Step 3. Add Buffer to root_alloc @@ -270,7 +270,8 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { const te::Tensor& tensor = op.output(0); // Check op is in op list ICHECK(info.IsArg(tensor)); - const Buffer& buffer = decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name); + const Buffer& buffer = + decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name, "global"); info.tensor2buffers[tensor] = buffer; } else if (const auto* compute_op = op.as()) { // Case 2. ComputeOp (te.compute) diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index da20dd875ba5..2ed5fd4029a2 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -146,7 +146,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, for (size_t i = 0; i < size; ++i) { DataType t = reduces[i]->dtype; normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), - PointerType(PrimType(t))); + PointerType(PrimType(t), "local")); lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); } Array init_value = combiner->identity_element; @@ -177,7 +177,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, std::vector res_handles(size); for (size_t idx = 0; idx < size; ++idx) { DataType dtype = reduces[idx]->dtype; - res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype))); + res_handles[idx] = + Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local")); freduce_args.push_back(res_handles[idx]); } @@ -224,12 +225,9 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = AttrStmt(res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body); if (!normal_red.empty()) { body = Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = - AttrStmt(normal_res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body); } } body = Substitute(body, value_map); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 1c9a3cb336ae..b602efcfc28b 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -124,7 +124,7 @@ void ExternOpNode::GatherBound(const Operation& self, Stmt ExternOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { ICHECK_EQ(stage->op.get(), this); Stmt realize_body = body; for (int k = 0; k < num_outputs(); ++k) { @@ -133,7 +133,7 @@ Stmt ExternOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope); } return realize_body; } diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 65b8660ca1fb..5d2412abb3d2 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -144,7 +144,7 @@ void HybridOpNode::GatherBound(const Operation& self, Stmt HybridOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { // TODO(@were): Add attribute inject here and remove it from hybrid parser. ICHECK_EQ(stage->op.get(), this); Stmt realize_body = body; @@ -154,7 +154,7 @@ Stmt HybridOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope); } return realize_body; } diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc index b3897e142545..ddc78866ae02 100644 --- a/src/te/operation/op_utils.cc +++ b/src/te/operation/op_utils.cc @@ -156,10 +156,12 @@ std::vector > MakeLoopNest(const Stage& stage, nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op)); if (!debug_keep_trivial_loop && is_one(dom->extent)) { value_map[iv] = dom->min; + } else if (stage->scope == "") { + value_map[iv] = var; } else { runtime::ThreadScope ts = runtime::ThreadScope::Create(bind_iv->thread_tag); - if (stage->scope == "" || - static_cast(runtime::StorageScope::Create(stage->scope).rank) <= ts.rank) { + runtime::StorageScope ss = runtime::StorageScope::Create(stage->scope); + if (static_cast(ss.rank) <= ts.rank) { value_map[iv] = var; } else if (stage->scope == "warp" && ts.rank == 1) { // To determine whether a thread index is inside or outside a warp, we need diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index c51e53e16cd1..4f5df7ad3024 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -85,7 +85,7 @@ void PlaceholderOpNode::GatherBound(const Operation& self, Stmt PlaceholderOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { return body; } diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index a555e86097b7..39689bd9654a 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -234,7 +234,7 @@ void ScanOpNode::GatherBound(const Operation& self, } Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map& dom_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { arith::Analyzer analyzer; ICHECK_EQ(stage->op.get(), this); Range sdom = dom_map.at(this->scan_axis); @@ -250,7 +250,7 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_mapspatial_axis_[sp_idx]; bounds.push_back(dom_map.at(sp_ax)); } - ret = tir::ProducerRealize(t, bounds, const_true(), ret); + ret = tir::ProducerRealize(t, bounds, const_true(), ret, storage_scope); } return ret; } diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 0aa279fb9246..447fc501d03b 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -140,13 +140,13 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage, auto fbanned = [&](const VarNode* node) { return banned.count(node); }; for (const PrimExpr& pred : n.main_predicates) { - if (tir::ExprUseVar(pred, fbanned)) { + if (tir::UsesVar(pred, fbanned)) { LOG(FATAL) << "Tensorize failed, split condition " << pred << " relies on var defined inside tensorize scope"; } } for (const PrimExpr& pred : n.init_predicates) { - if (tir::ExprUseVar(pred, fbanned)) { + if (tir::UsesVar(pred, fbanned)) { LOG(FATAL) << "Tensorize failed, split condition " << pred << " relies on var defined inside tensorize scope"; } diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 355e3c39494b..825092d20ac0 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -51,11 +51,8 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_ if (consumer.defined() && !is_no_op(consumer)) { pipeline = SeqStmt({producer, consumer}); } - pipeline = s->op->BuildRealize(s, dom_map, pipeline); - // use attribute to mark scope of the operation. - pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline); - return pipeline; + return s->op->BuildRealize(s, dom_map, pipeline, s->scope); } // inject the operator's realization on the stmt. @@ -175,8 +172,7 @@ class SchedulePostProc : public StmtExprMutator { thread_extent_scope_.erase(op->node.get()); return ret; } - } else if (op->attr_key == tir::attr::realize_scope || - op->attr_key == tir::attr::double_buffer_scope) { + } else if (op->attr_key == tir::attr::double_buffer_scope) { auto it = replace_op_.find(op->node.get()); if (it != replace_op_.end()) { if (it->second.defined()) { @@ -218,7 +214,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_realize_.find(key); if (it != replace_realize_.end()) { if (it->second.defined()) { - Stmt ret = ProducerRealize(it->second, op->bounds, op->condition, op->body); + Stmt ret = + ProducerRealize(it->second, op->bounds, op->condition, op->body, op->storage_scope); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 5c59961fe011..439d0ff17255 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -49,12 +49,12 @@ namespace tvm { namespace te { // create a buffer for tensor. -Buffer CreateBufferFor(const Tensor& tensor) { +Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") { std::string name = tensor->op->name; if (tensor->op->num_outputs() != 1) { name += ".v" + std::to_string(tensor->value_index); } - Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name); + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope); return buffer; } @@ -67,10 +67,7 @@ class TensorToBufferMapper : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - // TODO(tvm-team): remove realize_scope, turn the info into - // Buffer's scope field in this pass. - if (op->attr_key == tir::attr::realize_scope || - op->attr_key == tir::attr::double_buffer_scope) { + if (op->attr_key == tir::attr::double_buffer_scope) { Stmt body = op->body; Operation operation = Downcast(op->node); for (int i = operation->num_outputs(); i != 0; --i) { @@ -95,7 +92,7 @@ class TensorToBufferMapper : public StmtExprMutator { Stmt VisitStmt_(const ProducerRealizeNode* op) final { Tensor tensor = Downcast(op->producer); - Buffer buffer = GetOrAllocBuffer(tensor); + Buffer buffer = GetOrAllocBuffer(tensor, op->storage_scope); auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); @@ -122,14 +119,16 @@ class TensorToBufferMapper : public StmtExprMutator { } private: - Buffer GetOrAllocBuffer(const Tensor& tensor) { return GetBuffer(tensor, true); } + Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") { + return GetBuffer(tensor, storage_scope, true); + } - Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) { + Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) { auto it = buffer_map_.find(tensor); if (it != buffer_map_.end()) return it->second; ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; - auto buffer = CreateBufferFor(tensor); + auto buffer = CreateBufferFor(tensor, storage_scope); buffer_map_[tensor] = buffer; return buffer; } @@ -171,7 +170,9 @@ PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, } body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body)); - return tir::PrimFunc(params, body, VoidType(), buffer_map); + // We mark this PrimFunc as coming from a TE schedule + return WithAttr(tir::PrimFunc(params, body, VoidType(), buffer_map), "from_legacy_te_schedule", + Bool(true)); } TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc") diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index b1da536f1dad..dd01aed61c52 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -26,6 +26,7 @@ #include #include +#include "../transforms/ir_utils.h" namespace tvm { namespace tir { @@ -65,8 +66,12 @@ class BlockReadWriteDetector : public StmtExprVisitor { std::vector> read_regions_; /*! \brief The write regions of the current block */ std::vector> write_regions_; + /*! \brief The opaque regions of the current block */ + std::vector> opaque_regions_; /*! \brief The outside buffer data mapping to its buffer */ Map buffer_var_map_; + /*! \brief The target buffer var mapping to its matching */ + std::unordered_map match_buffers_; /*! \brief The analyzer for simplifying*/ arith::Analyzer analyzer_; @@ -78,14 +83,18 @@ class BlockReadWriteDetector : public StmtExprVisitor { * \param region The provided region */ void Update(std::vector* buffers, std::vector>* regions, - const Buffer& buffer, const std::vector& region); + Buffer buffer, std::vector region); /*! \brief Helper function to collect access regions. */ Array CollectRegions(const std::vector& buffers, const std::vector>& regions); - /*! \brief Helper function to add a opaque buffer. */ - void AddOpaque(const Var& buffer_var); + /*! \brief Helper function to convert matched access region to source region. */ + std::vector ConvertMatchedRegion(const MatchBufferRegion& match_buffer, + const std::vector& int_sets) const; + + /*! \brief Helper function to update a opaque access. */ + void UpdateOpaque(const Var& buffer_var); void VisitStmt_(const ForNode* op) override; void VisitStmt_(const BlockRealizeNode* op) override; @@ -97,8 +106,16 @@ class BlockReadWriteDetector : public StmtExprVisitor { }; void BlockReadWriteDetector::operator()(const Stmt& stmt) { - ICHECK(stmt.as() != nullptr) - << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); + const auto* block = stmt.as(); + ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); + for (const MatchBufferRegion& match_buffer : block->match_buffers) { + const Var& target_var = match_buffer->buffer->data; + const Var& source_var = match_buffer->source->buffer->data; + if (buffer_var_map_.find(source_var) != buffer_var_map_.end()) { + match_buffers_[target_var.get()] = match_buffer; + buffer_var_map_.Set(target_var, match_buffer->buffer); + } + } StmtExprVisitor::operator()(stmt); } @@ -111,18 +128,13 @@ Array BlockReadWriteDetector::CollectWrites() { } Array BlockReadWriteDetector::CollectOpaques() { - Array res; - res.reserve(opaque_buffers_.size()); - for (const Buffer& buffer : opaque_buffers_) { - res.push_back(BufferRegion::FullRegion(buffer)); - } - return res; + return CollectRegions(opaque_buffers_, opaque_regions_); } -void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { AddOpaque(GetRef(op)); } +void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef(op)); } void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) { - AddOpaque(op->buffer_var); + UpdateOpaque(op->buffer_var); ExprVisitor::VisitExpr_(op); } @@ -143,7 +155,7 @@ void BlockReadWriteDetector::VisitStmt_(const ForNode* op) { } void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { - AddOpaque(op->buffer_var); + UpdateOpaque(op->buffer_var); StmtVisitor::VisitStmt_(op); } @@ -184,11 +196,39 @@ void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) { } } +std::vector BlockReadWriteDetector::ConvertMatchedRegion( + const MatchBufferRegion& match_buffer, const std::vector& int_sets) const { + const Buffer& buffer = match_buffer->buffer; + + Region region; + region.reserve(int_sets.size()); + ICHECK_EQ(buffer->shape.size(), int_sets.size()); + for (size_t i = 0; i < int_sets.size(); ++i) { + const tvm::arith::IntSet& int_set = int_sets[i]; + region.push_back(int_set.CoverRange(Range::FromMinExtent(0, buffer->shape[i]))); + } + + region = ConvertRegion(match_buffer, region); + + std::vector result; + result.reserve(region.size()); + for (const Range& range : region) { + result.push_back(arith::EvalSet(range, dom_map_)); + } + return result; +} + void BlockReadWriteDetector::Update(std::vector* buffers, - std::vector>* regions, - const Buffer& buffer, - const std::vector& region) { + std::vector>* regions, Buffer buffer, + std::vector region) { if (buffer_var_map_.find(buffer->data) == buffer_var_map_.end()) return; + // Handle match_buffer remap + auto it = match_buffers_.find(buffer->data.get()); + if (it != match_buffers_.end()) { + const MatchBufferRegion& match_buffer = it->second; + buffer = match_buffer->source->buffer; + region = ConvertMatchedRegion(match_buffer, std::move(region)); + } ICHECK_EQ(buffers->size(), regions->size()) << " Expected the buffer and regions to have the same size "; for (size_t i = 0; i < regions->size(); ++i) { @@ -200,8 +240,8 @@ void BlockReadWriteDetector::Update(std::vector* buffers, return; } } - buffers->push_back(buffer); - regions->push_back(region); + buffers->push_back(std::move(buffer)); + regions->push_back(std::move(region)); } Array BlockReadWriteDetector::CollectRegions( @@ -213,8 +253,9 @@ Array BlockReadWriteDetector::CollectRegions( for (size_t i = 0; i < regions.size(); ++i) { Array region; region.reserve(regions[i].size()); + ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { - tvm::arith::IntSet range = regions[i][j]; + const tvm::arith::IntSet& range = regions[i][j]; region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); } res.push_back(BufferRegion(buffers[i], region)); @@ -222,14 +263,18 @@ Array BlockReadWriteDetector::CollectRegions( return res; } -void BlockReadWriteDetector::AddOpaque(const Var& buffer_var) { +void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) { auto it = buffer_var_map_.find(buffer_var); if (it != buffer_var_map_.end()) { const Buffer& buffer = (*it).second; - for (const Buffer& opaque_buffer : opaque_buffers_) { - if (buffer.same_as(opaque_buffer)) return; + const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); + const Region& region = buffer_region->region; + std::vector int_set; + int_set.reserve(region.size()); + for (const Range& range : region) { + int_set.push_back(arith::EvalSet(range, dom_map_)); } - opaque_buffers_.push_back(buffer); + Update(&opaque_buffers_, &opaque_regions_, buffer, int_set); } } diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 6f2622f3a61e..e680d689735d 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -85,9 +85,17 @@ class LCADetector : public StmtExprVisitor { for (const Buffer& buf : op->alloc_buffers) { buffer_var_map_.emplace(buf->data.get(), buf.get()); } + const ScopeInfo* parent_scope = ancestor_scopes_.back(); auto* current_scope = arena_.make(parent_scope, op, n); + ancestor_scopes_.push_back(current_scope); + // Update match_buffers + for (const MatchBufferRegion& match_buffer : op->match_buffers) { + UpdateBufferLCA(match_buffer->source->buffer.get()); + match_buffers_.insert(match_buffer->buffer.get()); + } + StmtExprVisitor::VisitStmt_(op); ancestor_scopes_.pop_back(); } @@ -129,8 +137,11 @@ class LCADetector : public StmtExprVisitor { } void UpdateBufferLCA(const BufferNode* buffer) { - const ScopeInfo*& lca = buffer_lca_[buffer]; - lca = LowestCommonAncestor(lca, ancestor_scopes_.back()); + if (match_buffers_.find(buffer) == match_buffers_.end()) { + // Ingore buffer created by block match_buffer + const ScopeInfo*& lca = buffer_lca_[buffer]; + lca = LowestCommonAncestor(lca, ancestor_scopes_.back()); + } } static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) { @@ -164,6 +175,8 @@ class LCADetector : public StmtExprVisitor { std::unordered_map buffer_lca_ = {}; /*! \brief The map from Buffer data to the Buffer. */ std::unordered_map buffer_var_map_ = {}; + /*! \brief The match buffers inside blocks. */ + std::unordered_set match_buffers_ = {}; /*! \brief Internal arena. */ support::Arena arena_; }; diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 7eb8013f2a85..7f48cc439234 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -59,6 +59,9 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { auto* prhs = rhs.as(); return plhs->dtype == prhs->dtype && plhs->value == prhs->value; } + if (lhs.as()) { + return false; + } return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false); } diff --git a/src/tir/analysis/var_touch.cc b/src/tir/analysis/var_touch.cc index 40a2cce70ae9..c4acd2b74aad 100644 --- a/src/tir/analysis/var_touch.cc +++ b/src/tir/analysis/var_touch.cc @@ -22,23 +22,33 @@ * \brief Implementation of simple passes */ #include -#include #include namespace tvm { namespace tir { -class VarTouchVisitor : public ExprVisitor { +class VarTouchVisitor : public StmtExprVisitor { public: - explicit VarTouchVisitor(std::function var_set) : var_set_(var_set) {} + explicit VarTouchVisitor(std::function var_set) + : var_set_(std::move(var_set)) {} + + void VisitStmt(const Stmt& stmt) final { + if (use_var_) return; + StmtExprVisitor::VisitStmt(stmt); + } void VisitExpr(const PrimExpr& e) final { if (use_var_) return; - ExprVisitor::VisitExpr(e); + StmtExprVisitor::VisitExpr(e); } void VisitExpr_(const VarNode* op) final { Handle(op); } + void VisitStmt_(const StoreNode* op) final { + Handle(op->buffer_var.get()); + StmtVisitor::VisitStmt_(op); + } + void VisitExpr_(const LoadNode* op) final { Handle(op->buffer_var.get()); ExprVisitor::VisitExpr_(op); @@ -54,9 +64,15 @@ class VarTouchVisitor : public ExprVisitor { std::function var_set_; }; -bool ExprUseVar(const PrimExpr& e, std::function var_set) { - VarTouchVisitor visitor(var_set); - visitor(e); +bool UsesVar(const Stmt& stmt, std::function var_set) { + VarTouchVisitor visitor(std::move(var_set)); + visitor(stmt); + return visitor.use_var_; +} + +bool UsesVar(const PrimExpr& expr, std::function var_set) { + VarTouchVisitor visitor(std::move(var_set)); + visitor(expr); return visitor.use_var_; } diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index afd3c7add605..efffa9031ac0 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -30,6 +30,8 @@ #include #include +#include "../transforms/ir_utils.h" + namespace tvm { namespace tir { @@ -38,7 +40,7 @@ class GPUCodeVerifier : public StmtExprVisitor { std::vector Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block, int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z, - int64_t max_vthread, int64_t max_vector_bytes) { + int64_t max_vthread, int64_t max_vector_bytes, int64_t max_kernels) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); max_threads_per_block_ = static_cast(max_threads_per_block); @@ -47,7 +49,7 @@ class GPUCodeVerifier : public StmtExprVisitor { max_thread_z_ = static_cast(max_thread_z); max_vthread_ = static_cast(max_vthread); max_vector_bytes_ = static_cast(max_vector_bytes); - + max_kernels_ = static_cast(max_kernels); Reset_(); // TODO(jcf94): Add support of detecting CUDA Misaligned Address error @@ -58,11 +60,12 @@ class GPUCodeVerifier : public StmtExprVisitor { void VisitStmt_(const AllocateNode* op) final { StmtVisitor::VisitStmt_(op); + auto scope = GetPtrStorageScope(op->buffer_var); // visit an allocation of a buffer in shared memory, record its size - if (visited_local_buffers_.count(op->buffer_var.get()) != 0) { + if (scope == "local") { size_t size = static_cast(op->constant_allocation_size()); local_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); - } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) { + } else if (scope == "shared") { size_t size = static_cast(op->constant_allocation_size()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } @@ -78,18 +81,11 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - std::string op_value = op->value.as()->value; - if (op_value == "local") { - visited_local_buffers_.insert(op->node.as()); - } else if (op_value == "shared") { - visited_shared_buffers_.insert(op->node.as()); - } - StmtVisitor::VisitStmt_(op); - } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { if (nest_level_ == 0) { // enter a new kernel, reset statistics Reset_(); + kernels_launched_++; } Var var = op->node.as()->var; @@ -158,6 +154,13 @@ class GPUCodeVerifier : public StmtExprVisitor { err("threads per block", thread_per_block_, max_threads_per_block_); err("local memory per block", local_memory_per_block_, max_local_memory_per_block_); err("shared memory per block", shared_memory_per_block_, max_shared_memory_per_block_); + + if (kernels_launched_ > max_kernels_) { + std::stringstream s; + s << "Number of launched kernels (" << kernels_launched_ + << ") is greater than the allowed maximum (" << max_kernels_ << ")"; + errors_.push_back(s.str()); + } } } else { StmtVisitor::VisitStmt_(op); @@ -211,8 +214,6 @@ class GPUCodeVerifier : public StmtExprVisitor { private: int nest_level_{0}; - std::unordered_set visited_local_buffers_; - std::unordered_set visited_shared_buffers_; std::unordered_set visited_threads_; size_t thread_x_extent_, thread_y_extent_, thread_z_extent_; @@ -220,18 +221,18 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t local_memory_per_block_; size_t shared_memory_per_block_; size_t thread_per_block_; + size_t kernels_launched_{0}; size_t max_local_memory_per_block_; size_t max_shared_memory_per_block_; size_t max_threads_per_block_; size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_; size_t max_vector_bytes_; + size_t max_kernels_; std::vector errors_; void Reset_() { - visited_local_buffers_.clear(); - visited_shared_buffers_.clear(); local_memory_per_block_ = 0; shared_memory_per_block_ = 0; @@ -251,6 +252,7 @@ std::vector VerifyGPUCode_(const PrimFunc& func, Map c int64_t max_thread_z = INT64_MAX; int64_t max_vthread = INT64_MAX; int64_t max_vector_bytes = INT64_MAX; + int64_t max_kernels = INT64_MAX; for (auto iter : constraints) { const IntImmNode* val = iter.second.as(); @@ -270,6 +272,8 @@ std::vector VerifyGPUCode_(const PrimFunc& func, Map c max_vthread = val->value; } else if (iter.first == "max_vector_bytes") { max_vector_bytes = val->value; + } else if (iter.first == "max_kernels") { + max_kernels = val->value; } else { LOG(FATAL) << "Invalid check item: " << iter.first; } @@ -277,7 +281,7 @@ std::vector VerifyGPUCode_(const PrimFunc& func, Map c return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block, max_threads_per_block, max_thread_x, max_thread_y, max_thread_z, - max_vthread, max_vector_bytes); + max_vthread, max_vector_bytes, max_kernels); } bool VerifyGPUCode(const PrimFunc& func, Map constraints) { diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 2089ead98168..0382b8071de7 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -172,6 +172,9 @@ std::vector VerifyMemory_(const PrimFunc& func) { auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; + DLOG(INFO) << "verifying memory for target '" << target.value()->str() << "' for primitive\n" + << PrettyPrint(func); + if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { MemoryAccessVerifier v(func, target.value()->kind->device_type); diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 1667eb7d1fbd..335ff19dd775 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -32,6 +32,8 @@ #include #include +#include "../../arith/pattern_match.h" + namespace tvm { namespace tir { @@ -45,10 +47,11 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { return array; } -Buffer decl_buffer(Array shape, DataType dtype, String name, Span span) { +Buffer decl_buffer(Array shape, DataType dtype, String name, String storage_scope, + Span span) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); - return Buffer(Var(name, PointerType(PrimType(storage_dtype)), span), dtype, shape, - Array(), PrimExpr(), name, "", 0, 0, kDefault, span); + return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, + Array(), PrimExpr(), name, 0, 0, kDefault, span); } // Split the given expression w.r.t the add operator @@ -180,7 +183,12 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { // a list that contain all the elements that match Mod. // The elements in the Mod will be used to match against the elements in Mul. // The result will then be split and pushed back to these two lists. - PrimExpr simplified_base = analyzer->Simplify(base); + PrimExpr simplified_base = base; + arith::PVar x, y; + if ((floordiv(x, y) * y + floormod(x, y)).Match(simplified_base)) { + simplified_base = x.Eval(); + } + simplified_base = analyzer->Simplify(simplified_base); std::vector eles = ExprSplitAddition(simplified_base); std::list mult_exprs; std::list > mod_exprs; @@ -311,6 +319,15 @@ Stmt Buffer::vstore(Array begin, PrimExpr value) const { } } +String Buffer::scope() const { + const auto* ptr_type = (*this)->data->type_annotation.as(); + ICHECK(ptr_type) << "Buffer variable is not of pointer type"; + if (ptr_type->storage_scope.empty()) { + return "global"; + } + return ptr_type->storage_scope; +} + Buffer Buffer::MakeStrideView() const { if ((*this)->strides.size() != 0) return *this; if ((*this)->shape.size() == 0) return *this; @@ -350,7 +367,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const return MakeStrideView().MakeSlice(begins, extents); } } - return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->scope, + return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->data_alignment, 0, n->buffer_type); } @@ -383,8 +400,8 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane } Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, String name, String scope, int data_alignment, - int offset_factor, BufferType buffer_type, Span span) { + PrimExpr elem_offset, String name, int data_alignment, int offset_factor, + BufferType buffer_type, Span span) { DataType storage_dtype = dtype; // specially handle bool if (storage_dtype == DataType::Bool()) { @@ -401,10 +418,6 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array n->shape = std::move(shape); n->strides = std::move(strides); n->name = std::move(name); - if (scope.length() == 0) { - scope = "global"; - } - n->scope = std::move(scope); if (!elem_offset.defined()) { elem_offset = make_const(n->DefaultIndexType(), 0); } @@ -436,11 +449,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(BufferNode); TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args.size(), 11); - auto buffer_type = args[9].operator String(); + ICHECK_EQ(args.size(), 10); + auto buffer_type = args[8].operator String(); BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; - *ret = Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], - type, args[10]); + *ret = + Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type, args[9]); }); TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); @@ -449,5 +462,7 @@ TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); +TVM_REGISTER_GLOBAL("tir.BufferStorageScope").set_body_method(&Buffer::scope); + } // namespace tir } // namespace tvm diff --git a/src/tir/ir/buffer_common.h b/src/tir/ir/buffer_common.h new file mode 100644 index 000000000000..8dac41a02e57 --- /dev/null +++ b/src/tir/ir/buffer_common.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tir/ir/buffer_common.h + * \brief Common utils for buffer access + */ +#ifndef TVM_TIR_IR_BUFFER_COMMON_H_ +#define TVM_TIR_IR_BUFFER_COMMON_H_ + +#include +#include + +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Returns the type of object pointed to. + * + * \param type The type to be checked. + * + * \return A (bool, DataType) pair. If the type is a pointer to a + * primitive, the boolean is true and the DataType is the pointed-to + * type. Otherwise, the boolean is false and the DataType is + * default-constructed. This can be replaced with std::optional with + * C++17 if/when C++17 is required. + */ +inline std::pair GetPointerType(const Type& type) { + if (type.defined()) { + if (auto* ptr_type = type.as()) { + if (auto* prim_type = ptr_type->element_type.as()) { + return {true, prim_type->dtype}; + } + } + } + + return {false, DataType()}; +} + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_IR_BUFFER_COMMON_H_ diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 352f75abdf5e..afc5c36ebb92 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -25,10 +25,8 @@ #include #include -#include -#include - #include "../../support/str_escape.h" +#include "buffer_common.h" namespace tvm { namespace tir { @@ -618,8 +616,42 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, S ICHECK(buffer_var.defined()); ICHECK(predicate.defined()); ICHECK(index.defined()); - ICHECK_EQ(dtype.lanes(), index.dtype().lanes()); - ICHECK_EQ(dtype.lanes(), predicate.dtype().lanes()); + + // Assume that the array elements have 1 lane, unless a type + // annotation tells us otherwise. + int element_lanes = 1; + auto pointer_type = tir::GetPointerType(buffer_var->type_annotation); + if (pointer_type.first) { + // Cannot check element type of array, as it may be different than + // the loaded type in some cases. + // + // 1. Booleans use DataType::Int(8) while stored, and the codegens + // handle cast to boolean. + // + // 2. The StorageRewrite pass can merge multiple allocations at + // the same scope, regardless of element type. The codegen is + // then responsible for casting to the output type. + + // TODO(Lunderberg): Uncomment this check once it can be applied. + // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 + // for discussion. + + // ICHECK(dtype.element_of() == pointer_type.second.element_of()) + // << "Type mismatch, cannot load type " << dtype << " from buffer " << + // buffer_var->name_hint + // << " of type " << pointer_type.second; + element_lanes = pointer_type.second.lanes(); + } + + // The C-based codegens assume that all loads occur on a array with + // non-vectorized elements, and cast between + // vectorized/non-vectorized arrays as needed. Ideally, these + // should be changed to explicit casts in the TIR graph, rather than + // being handled at the code-gen level. + ICHECK((dtype.lanes() == element_lanes * index.dtype().lanes()) || + (dtype.lanes() == index.dtype().lanes())); + ICHECK((dtype.lanes() == element_lanes * predicate.dtype().lanes()) || + (dtype.lanes() == index.dtype().lanes())); ObjectPtr node = make_object(); node->dtype = dtype; diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index c15b3bb47bf4..f265a8ae2b1b 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -36,15 +36,12 @@ namespace tir { /*! \brief Generate surrounding loops automatically */ class ScriptCompleter : public StmtMutator { public: - explicit ScriptCompleter(Map* buffer_var_map, bool contain_root) - : buffer_var_map_(buffer_var_map), contain_root_(contain_root) {} + explicit ScriptCompleter(Map* buffer_var_map) : buffer_var_map_(buffer_var_map) {} /*! \brief Whether the stmt contains at least one block. */ bool contains_block = false; private: Map* buffer_var_map_; - bool contain_root_; - bool visited_root_ = false; Stmt VisitStmt_(const BlockRealizeNode* op) override { contains_block = true; Stmt body = StmtMutator::VisitStmt_(op); @@ -65,17 +62,23 @@ class ScriptCompleter : public StmtMutator { } Stmt VisitStmt_(const BlockNode* op) override { - bool is_root_block = contain_root_ && !visited_root_; - visited_root_ = true; // Buffers allocated in the block can be accessed by its body. for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->Set(alloc_buffer->data, alloc_buffer); } + for (const auto& match_buffer : op->match_buffers) { + const Buffer& target_buffer = match_buffer->buffer; + buffer_var_map_->Set(target_buffer->data, target_buffer); + } Block block = Downcast(StmtMutator::VisitStmt_(op)); // Remove buffers allocated inside block to detect its access region for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->erase(alloc_buffer->data); } + for (const auto& match_buffer : op->match_buffers) { + const Buffer& target_buffer = match_buffer->buffer; + buffer_var_map_->erase(target_buffer->data); + } // Get access detection mask // 0 for provided region, 1 and 3 for need detect read, 2 and 3 for need detect write int mask = 0; @@ -85,13 +88,6 @@ class ScriptCompleter : public StmtMutator { } // ignore root block or blocks which already has reads/writes regions if (mask != 0) { - if (op->iter_vars.empty()) { - // non-root opaque block is not allowed - CHECK(is_root_block) - << "ValueError: Can not auto detect buffer access region for an opaque block. Please " - "annotate the access region manually."; - return std::move(block); - } auto access_region = GetBlockAccessRegion(block, *buffer_var_map_); const Array& reads = access_region[0]; const Array& writes = access_region[1]; @@ -122,7 +118,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { } bool contain_root = root_allocates.empty() && func->body->IsInstance() && Downcast(func->body)->block->iter_vars.empty(); - ScriptCompleter script_completer(&buffer_var_map, contain_root); + ScriptCompleter script_completer(&buffer_var_map); // generate surrounding loops automatically Stmt res = script_completer(func->body); // generate root block automatically diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index aa5f271c20c2..768787735a1f 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -45,6 +46,27 @@ inline bool IsParam(const PrimFunc& func, const Var& param) { /**************** Specializer ****************/ +// Try fold constants if op's child get specialized to constant. +#define DEFINE_SPECIALIZER_BINARY_OP_MUTATE(BinaryNode, BinaryFunc) \ + PrimExpr VisitExpr_(const BinaryNode* op) final { \ + PrimExpr a = VisitExpr(op->a); \ + PrimExpr b = VisitExpr(op->b); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return BinaryFunc(a, b); \ + } \ + } +#define DEFINE_SPECIALIZER_UNARY_OP_MUTATE(UnaryNode, UnaryFunc) \ + PrimExpr VisitExpr_(const UnaryNode* op) final { \ + PrimExpr a = VisitExpr(op->a); \ + if (a.same_as(op->a)) { \ + return GetRef(op); \ + } else { \ + return UnaryFunc(a); \ + } \ + } + /*! \brief Mutator to specialize function and remove const parameters */ class PrimFuncSpecializer : public StmtExprMutator { public: @@ -157,14 +179,33 @@ class PrimFuncSpecializer : public StmtExprMutator { } } + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(AddNode, add); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(SubNode, sub); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MulNode, mul); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(DivNode, div); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(ModNode, truncmod); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(FloorDivNode, floordiv); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(FloorModNode, floormod); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MaxNode, max); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MinNode, min); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(EQNode, equal); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(NENode, not_equal); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(LTNode, less); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(LENode, less_equal); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(GTNode, greater); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(GENode, greater_equal); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(AndNode, logical_and); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(OrNode, logical_or); + DEFINE_SPECIALIZER_UNARY_OP_MUTATE(NotNode, logical_not); + private: - Buffer MutateBuffer(const Buffer& buffer) const { + Buffer MutateBuffer(const Buffer& buffer) { Array shape = - MutateArray(buffer->shape, [this](const PrimExpr& e) { return Substitute(e, var_map_); }); + MutateArray(buffer->shape, [this](const PrimExpr& e) { return VisitExpr(e); }); Array strides = - MutateArray(buffer->strides, [this](const PrimExpr& e) { return Substitute(e, var_map_); }); + MutateArray(buffer->strides, [this](const PrimExpr& e) { return VisitExpr(e); }); - PrimExpr elem_offset = Substitute(buffer->elem_offset, var_map_); + PrimExpr elem_offset = VisitExpr(buffer->elem_offset); if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) && buffer->strides.same_as(strides)) { diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index b2016eb74c91..d59c94dc5753 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -20,11 +20,14 @@ /*! * \file tvm/tir/stmt.cc */ +#include #include #include #include #include +#include "buffer_common.h" + namespace tvm { namespace tir { @@ -234,8 +237,29 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, ICHECK(value.defined()); ICHECK(index.defined()); ICHECK(predicate.defined()); - ICHECK_EQ(value.dtype().lanes(), index.dtype().lanes()); - ICHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes()); + + // Assume that the array elements have 1 lane, unless a type + // annotation tells us otherwise. + int element_lanes = 1; + auto pointer_type = tir::GetPointerType(buffer_var->type_annotation); + if (pointer_type.first) { + // Currently cannot check element type of array, see Load::Load + // for details. + + // TODO(Lunderberg): Uncomment this check once it can be applied. + // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 + // for discussion. + + // ICHECK_EQ(value.dtype().element_of(), pointer_type.second.element_of()) + // << "Type mismatch, cannot store type " << value.dtype() << " into buffer " + // << buffer_var->name_hint << " of type " << pointer_type.second; + element_lanes = pointer_type.second.lanes(); + } + + ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) || + (value.dtype().lanes() == index.dtype().lanes())); + ICHECK((value.dtype().lanes() == element_lanes * predicate.dtype().lanes()) || + (value.dtype().lanes() == index.dtype().lanes())); ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); @@ -360,13 +384,15 @@ TVM_REGISTER_NODE_TYPE(AllocateNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); + const auto* ptr_type = op->buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; p->PrintIndent(); p->stream << "allocate " << op->buffer_var << "[" << op->dtype; for (size_t i = 0; i < op->extents.size(); ++i) { p->stream << " * "; p->Print(op->extents[i]); } - p->stream << "]"; + p->stream << "], storage_scope = " << ptr_type->storage_scope; if (!is_one(op->condition)) { p->stream << " if "; p->Print(op->condition); @@ -377,7 +403,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // ProducerRealize ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, - Stmt body, Span span) { + Stmt body, String storage_scope, Span span) { for (size_t i = 0; i < bounds.size(); ++i) { ICHECK(bounds[i]->min.defined()); ICHECK(bounds[i]->extent.defined()); @@ -394,13 +420,14 @@ ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr node->condition = std::move(condition); node->body = std::move(body); node->span = std::move(span); + node->storage_scope = std::move(storage_scope); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.ProducerRealize") .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - Span span) { - return ProducerRealize(producer, bounds, condition, body, span); + String storage_scope, Span span) { + return ProducerRealize(producer, bounds, condition, body, storage_scope, span); }); TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); @@ -632,6 +659,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // BufferRegion BufferRegion::BufferRegion(Buffer buffer, Array region) { + CHECK_EQ(buffer->shape.size(), region.size()) + << "The dimension between " << buffer << " and region " << region + << " mismatched, the buffer is " << buffer; ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->region = std::move(region); @@ -679,6 +709,49 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // MatchBufferRegion MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { + const Buffer& source_buffer = source->buffer; + arith::Analyzer analyzer; + // Check scope and dtype + CHECK_EQ(buffer.scope(), source_buffer.scope()) + << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << " vs. " + << source_buffer.scope(); + CHECK_EQ(buffer->dtype, source_buffer->dtype) + << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << " vs. " + << source_buffer->dtype; + + // Check data_alignment + CHECK(source_buffer->data_alignment % buffer->data_alignment == 0) + << "Trying to match buffer to another one with lower alignment requirement " + << " required_alignment=" << buffer->data_alignment + << ", provided_alignment=" << source_buffer->data_alignment; + + // Check BufferType. AutoBroadcast is not allowed for now. + CHECK(buffer->buffer_type == BufferType::kDefault && + source_buffer->buffer_type == BufferType::kDefault) + << "AutoBroadcast is not allowed in MatchBuffer"; + + // Validate shape + CHECK(source->region.size() >= buffer->shape.size()) + << "Dimension of source Region expected to be larger or equal than target buffer shape, but " + "got " + << source->region.size() << " vs. " << buffer->shape.size(); + size_t offset = source->region.size() - buffer->shape.size(); + for (size_t i = 0; i < offset; ++i) { + CHECK(analyzer.CanProve(source->region[i]->extent == 1)) + << "The higher dimension should be 1, but got " << source->region[i]->extent << "."; + } + for (size_t i = 0; i < buffer->shape.size(); ++i) { + const Range& source_range = source->region[i + offset]; + const PrimExpr& buffer_shape = buffer->shape[i]; + if (!buffer_shape->IsInstance()) { + CHECK(analyzer.CanProve(source_range->extent == buffer_shape)) + << "The dimension mismatched between source region and target buffer shape, got " + << source_range->extent << " vs. " << buffer_shape << "."; + } + } + // Note that we do not check elem_offset and strides in this function + + // Construction ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->source = std::move(source); @@ -695,7 +768,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); - p->stream << op->buffer->name << " = match_buffer_region("; + p->stream << op->buffer->name << " = match_buffer("; p->Print(op->source); p->stream << ")\n"; }); diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index f0ca04cbd5fd..c593cbf7290c 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -246,6 +246,17 @@ TIR_DEFINE_BUILTIN_FUNC(vectorcombine) TIR_DEFINE_BUILTIN_FUNC(atomic_add) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(texture2d_alloca) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(texture2d_store) + .set_attr("TVectorizable", true) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(texture2d_load) + .set_attr("TVectorizable", true) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + } // namespace builtin } // namespace tir } // namespace tvm diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index af78804837ba..5db131c44f2a 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -79,12 +79,6 @@ Type GetType(const PrimExpr& expr) { return PrimType(dtype); } -// simple cast that only checks if type matches and cast -inline PrimExpr SimpleCast(const DataType& t, PrimExpr value, Span span) { - if (value.dtype() == t) return value; - return tir::Cast(t, value, span); -} - // LargeUIntImm PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high, Span span) { return tir::Call( @@ -112,34 +106,49 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; } if (lhs.dtype() == rhs.dtype()) return; - // Only do very simple type coversion - // int->float, DataType::Int(32)->int(64) - // require the types to be relatively consistent - // This will the reduce amount code generated by operators - // and also help user to find potential type conversion problems. - if (!lhs.dtype().is_float() && - (rhs.dtype().is_float() || - datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) { - // int->float - lhs = cast(rhs.dtype(), lhs); - } else if ((lhs.dtype().is_float() || - datatype::Registry::Global()->GetTypeRegistered(lhs.dtype().code())) && - !rhs.dtype().is_float()) { - // int->float - rhs = cast(lhs.dtype(), rhs); - } else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) || - (lhs.dtype().is_uint() && rhs.dtype().is_uint())) { - // promote int to higher bits - if (lhs.dtype().bits() < rhs.dtype().bits()) { - lhs = cast(rhs.dtype(), lhs); + + ltype = lhs.dtype(); + rtype = rhs.dtype(); + // We keep dtypes conversion to be relatively consistent to reduce the amount code generated by + // operators. This can be helpful for users to find potential type conversion problems. The + // following are exceptions: + if (ltype.is_float() && rtype.is_float()) { + // Given two dissimilar floats, cast the lower bit version to the higher bit version. + // E.g. fp16 + fp32 --> fp32 + fp32 + if (ltype.bits() < rtype.bits()) { + lhs = cast(rtype, lhs); } else { - rhs = cast(lhs.dtype(), rhs); + rhs = cast(ltype, rhs); + } + } else if (!ltype.is_float() && + (rtype.is_float() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) { + // Cast int->float when the other operand is a float + lhs = cast(rtype, lhs); + } else if ((ltype.is_float() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) && + !rtype.is_float()) { + // Cast int->float when the other operand is a float + rhs = cast(ltype, rhs); + } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) { + // Promote int to higher bits e.g. int8 + int16 --> int16 + int16 + if (ltype.bits() < rtype.bits()) { + lhs = cast(rtype, lhs); + } else { + rhs = cast(ltype, rhs); + } + } else if ((ltype.is_int() && rtype.is_uint()) || (ltype.is_uint() && rtype.is_int())) { + // Handle mixing signed and unsigned integers + if (ltype.bits() < rtype.bits()) { + lhs = cast(rtype, lhs); + } else if (ltype.bits() > rtype.bits()) { + rhs = cast(ltype, rhs); + } else { + // The width of signed and unsigned integers is same. + if (ltype.is_uint()) { + rhs = cast(ltype, rhs); + } else { + lhs = cast(rtype, lhs); + } } - } else if ((lhs.dtype().is_int() && rhs.dtype().is_uint()) || - (lhs.dtype().is_uint() && rhs.dtype().is_int())) { - int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits()); - lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span); - rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span); } else { LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype; } diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index dd7fee37e2d1..3fa0c63b2e2f 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -21,6 +21,9 @@ #include +#include +#include + namespace tvm { namespace tir { @@ -41,30 +44,35 @@ void VerifySRefTree(const ScheduleState& self); */ void VerifyCachedFlags(const ScheduleState& self); -/******** Scope ********/ +/******** IR Module ********/ /*! - * \brief Gets the sref to the scope root block, exclusive - * \param sref The block or loop sref to be retrieved - * \return The sref to the scope root block. NullOpt if `sref` is the root block of the IR + * \brief Get PrimFunc and GlobalVar that the root block belongs to + * \param mod The IRModule + * \param root_block The root block of the PrimFunc + * \param result_g_var The result GlobalVar + * \return The result PrimFunc where the root block belongs to + * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write */ -Optional GetScopeRoot(const StmtSRef& sref); +const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, + GlobalVar* result_g_var); +/******** Scope ********/ /*! * \brief Checks if scope the specified sref is in is a stage-pipeline and return it - * \param prim The name of the schedule primitive * \param self The schedule state * \param sref The sref whose scope is to be checked + * \param require_stage_pipeline A boolean indicating whether to check stage pipeline * \throw ScheduleError if the sref has been the root of the AST (so it has no scope root), or its * scope root is not a stage pipeline * \return The block sref to the scope root */ -StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const StmtSRef& sref); +StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline); /*! * \brief Checks whether the block is a complete block under the scope * \param self The schedule state * \param block_sref The block to be checked - * \param scope_root The sref to the root block of the scope that `block_sref` is in + * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in * \return A boolean indicating if the block is a complete block * \note Definition of a complete block: * 1) All block vars are data parallel @@ -73,10 +81,10 @@ StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const Stmt * 3) No overlap between the buffers the block reads and writes */ bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, - const StmtSRef& scope_root); + const StmtSRef& scope_root_sref); /*! - * \brief Checks if the block is a complete block + * \brief Check if the block is a complete block under the scope * \param self The schedule state * \param block_sref The sref to the block whose completeness is to be checked * \param scope_root_sref The scope root of the block @@ -85,6 +93,47 @@ bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref); +/*! + * \brief Check whether the block is a reduction block under the scope + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in + * \return A boolean indicating if the block is a reduction block + * \note Definition of a reduction block: + * 1) The block has the `init` statement + * 2) All the block bindings are quasi-affine expressions + * 3) All block vars are either data parallel block vars or reduction block vars + * 4) Dominant: the block is the only writer of its output, dominating the reader of its output + * buffers + * 5) The reduction block vars are not used to index the output buffers + */ +bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref); + +/*! + * \brief Check if the block is a reduction block under the scope + * \param self The schedule state + * \param block_sref The sref of the block to be checked + * \param scope_root_sref The scope root of the block + * \throw ScheduleError If the block is not a reduction block + */ +void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref); + +/*! + * \brief Check whether a subtree on SRef tree has compact data flow, and throw an exception if the + * subtree does not have compact data flow + * \details For a given StmtSRef, We say the subtree rooted from the StmtSRef has "compact data + * flow" property if: + * - the scope root of the input subtree root has stage-pipeline property, and + * - all its child blocks on SRef tree are complete blocks or reduction blocks. + * \param self The schedule state + * \param subtree_root_sref The root of the subtree to be checked in the SRef tree + * \throw ScheduleError If the subtree does not have compact data flow + * \sa IsCompleteBlock, IsReductionBlock + */ +void CheckSRefSubtreeCompactDataFlow(const ScheduleState& self, const StmtSRef& subtree_root_sref); + /******** Binding ********/ /*! * \brief Verifies if the block binding in a specific BlockRealize is an affine binding. @@ -97,6 +146,15 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, arith::Analyzer* analyzer); +/*! + * \brief Check whether a block has an affine binding using the cached flag, and throw an exception + * if the block does not have an affine binding. + * \param self The schedule state + * \param block The block to be checked + * \throw ScheduleError If the input block does not have an affine binding + */ +void CheckAffineBinding(const ScheduleState& self, Block block); + /*! * \brief Extracts the ranges of loop variables in a path of the sref tree * \param low_inclusive The lowest node in the path @@ -119,29 +177,88 @@ Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, */ Map GetBindings(const BlockRealize& realize); +/*! + * \brief Get the vars involved in the bindings of data parallel block vars and reduction block + * vars, respectively + * \param block_realize The BlockRealize to be analyzed + * \param data_par_vars The vars that appear in the binding of any data parallel block iter + * \param reduce_vars The vars that appear in the binding of any reduction block iter + * \return A boolean indicating whether the block has block iters that is neither a data parallel + * block iter nor a reduction block iter + */ +bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, + std::unordered_set* data_par_vars, + std::unordered_set* reduce_vars); + /******** Block-loop relation ********/ + /*! - * \brief Retrieves blocks in a specific function with its name + * \brief Gets StmtSRefs of leaf blocks of a scope where a specific block/loop is in * \param self The schedule state - * \param name The name of the blocks to be retrieved - * \param func_name The name of the function - * \return A list of blocks with the specific name + * \param parent_sref The StmtSRef that points to the parent block/loop + * \return A list of StmtSRefs of leaf block */ -Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name); +Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, const StmtSRef& parent_sref); + /*! - * \brief Gets the parent loops of the block in its scope, from outer to inner - * \param self The schedule state - * \param block_sref The query block - * \return A list of loops above the given block in its scope, from outer to inner + * \brief Gets the BlockRealize of the leaf blocks of a scope where a specific block/loop is in + * \param parent_sref The StmtSRef that points to the parent block/loop + * \return A list of leaf BlockRealize */ -Array GetLoops(const StmtSRef& block_sref); +Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref); + /*! - * \brief Gets the leaf blocks of a scope where a specific block/loop is in + * \brief Get the BlockRealize of the single child block of the block or loop specified by + * `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks * \param self The schedule state * \param parent_sref The StmtSRef that points to the parent block/loop - * \return A list of leaf blocks + * \return The BlockRealize of the single child block + * \throw ScheduleError If there is 0 or multiple child blocks + */ +BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref); +/*! + * \brief Get the BlockRealize of the input block + * \param self The schedule state + * \param block_sref The StmtSRef of the queried block + * \return The BlockRealize of the input block + */ +BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); + +/******** Block-buffer relation ********/ + +/*! + * \brief Get the BlockRealize of the single child block of the block or loop specified by + * `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks + * \param self The schedule state + * \param block The queried block + * \param n The index of the queried buffer + * \return The buffer of the n-th write region of the block. + * \throw ScheduleError If the buffer index is out of bound. + */ +Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n); + +/******** Commutative Reducer ********/ + +/*! + * \brief Get the list of the registered reducer-getter functions + * \return The list of the registered reducer-getter functions + * \sa ReducerRegistry + */ +std::vector> GetReducerGetters(); + +/*! + * \brief Given the input identity and the combiner BufferStore of a reduction, extract the + * corresponding commutative reducer and its lhs, rhs if possible. + * \param identity The identity of the reduction + * \param combiner The combiner of the reduction + * \param result_reducer The extracted CommReducer + * \param lhs The extracted lhs of the reducer + * \param rhs The extracted rhs of the reducer + * \return A boolean indicating whether a corresponding commutative reducer is found */ -Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); +bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, + CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index d58dece3c644..c9f8ff4c7e75 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -21,8 +21,37 @@ namespace tvm { namespace tir { +/******** IR Module ********/ + +const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, + GlobalVar* result_g_var) { + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* func = base_func.as()) { + if (const auto* realize = func->body.as()) { + if (realize->block.get() == root_block) { + if (result_g_var != nullptr) { + *result_g_var = g_var; + } + return func; + } + } + } + } + LOG(FATAL) << "IndexError: Could not get the corresponding function in the schedule state of the " + "statement:\n" + << GetRef(root_block); + throw; +} + /******** Scope ********/ +/*! + * \brief Gets the sref to the scope root block, exclusive + * \param sref The block or loop sref to be retrieved + * \return The sref to the scope root block. NullOpt if `sref` is the root block of the IR + */ Optional GetScopeRoot(const StmtSRef& sref) { for (const StmtSRefNode* p = sref->parent; p != nullptr; p = p->parent) { if (p->stmt->IsInstance()) { @@ -32,7 +61,8 @@ Optional GetScopeRoot(const StmtSRef& sref) { return NullOpt; } -StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const StmtSRef& sref) { +StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, + bool require_stage_pipeline) { class RootBlockError : public ScheduleError { public: explicit RootBlockError(IRModule mod) : mod_(mod) {} @@ -75,7 +105,7 @@ Definition of a scope that is a stage pipeline: throw RootBlockError(self->mod); } bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline; - if (stage_pipeline == false) { + if (require_stage_pipeline && stage_pipeline == false) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); throw NotStagePipelineError(self->mod, GetRef(block)); } @@ -106,20 +136,29 @@ bool IsDominantBlock(const BlockScope& self, const StmtSRef& block_sref) { return true; } -bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, - const StmtSRef& scope_root) { - BlockScope scope = self->GetBlockScope(scope_root); +/*! + * \brief A helper function that checks whether a given block is a complete block under the scope, + * or return the condition it violates if it is not a complete block + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in + * \return 0 if the block is a complete block, or a positive integer indicating which condition is + * first violated + */ +int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + BlockScope scope = self->GetBlockScope(scope_root_sref); // Cond 1. All block vars are data parallel - const auto* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type != kDataPar) { - return false; + return 1; } } // Cond 2. Dominant: the block is the only writer of its output, // dominating the reader of its output buffers if (!IsDominantBlock(scope, block_sref)) { - return false; + return 2; } // Cond 3. No overlap between the buffers the block reads and writes std::unordered_set written_buffers; @@ -129,35 +168,187 @@ bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, } for (const BufferRegion& read : block->reads) { if (written_buffers.count(read->buffer.get())) { - return false; + return 3; } } - return true; + return 0; +} + +bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + return CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref) == 0; } void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { class IncompleteBlockError : public ScheduleError { public: - explicit IncompleteBlockError(IRModule mod, Block block) : mod_(mod), block_(block) {} + explicit IncompleteBlockError(IRModule mod, Block block, int violated_cond) + : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} String FastErrorString() const final { return "ScheduleError: Incomplete block"; } String DetailRenderTemplate() const final { - return R"(The block {0} is not a complete block. -Definition of a complete block: + std::ostringstream os; + os << "The block {0} is not a complete block - it violates condition #" << violated_cond_ + << ".\n" + << R"(Definition of a complete block: 1) All block vars are data parallel 2) Dominant: the block is the only writer of its output, dominating the reader of its output buffers 3) No overlap between the buffers the block reads and writes)"; + return os.str(); } IRModule mod() const final { return mod_; } Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; + int violated_cond_; }; - bool result = IsCompleteBlock(self, block_sref, scope_root_sref); - if (result == false) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); - throw IncompleteBlockError(self->mod, GetRef(block)); + int error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref); + if (error_code != 0) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + throw IncompleteBlockError(self->mod, GetRef(block), error_code); + } +} + +/*! + * \brief A helper function that checks whether a given block is a reduction block under the scope, + * or return the condition it violates if it is not a reduction block + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in + * \return 0 if the block is a reduction block, or a positive integer indicating which condition is + * first violated + */ +int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + BlockScope scope = self->GetBlockScope(scope_root_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + // Cond 1. The block has the `init` statement. + if (!block->init.defined()) { + return 1; + } + // Cond 2. All the block bindings are quasi-affine expressions. + if (!self->IsAffineBlockBinding(block_sref)) { + return 2; + } + // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile, + // we collect all the reduction block vars. + std::unordered_set reduction_block_vars; + reduction_block_vars.reserve(block->iter_vars.size()); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { + return 3; + } else if (iter_var->iter_type == kCommReduce) { + reduction_block_vars.insert(iter_var->var.get()); + } + } + // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its + // output buffers. + if (!IsDominantBlock(scope, block_sref)) { + return 4; + } + // Cond 5. The reduction block vars are not used to index the output buffers. + std::unordered_set buffer_written; + buffer_written.reserve(block->writes.size()); + for (const BufferRegion& write_region : block->writes) { + buffer_written.insert(write_region->buffer.get()); + } + bool affected = false; + PreOrderVisit(block->body, [&](const ObjectRef& obj) { + if (affected) { + return false; + } + if (const auto* store = obj.as()) { + ICHECK(buffer_written.count(store->buffer.get())) + << "ValueError: The buffer \"" << store->buffer + << "\" is written in the block but is not in the block's signature"; + for (const PrimExpr& index : store->indices) { + if (UsesVar(index, [&reduction_block_vars](const VarNode* var) { + return reduction_block_vars.count(var); + })) { + affected = true; + return false; + } + } + return false; + } + return true; + }); + return !affected ? 0 : 5; +} + +bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + return CheckReductionBlockErrorCode(self, block_sref, scope_root_sref) == 0; +} + +void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + class NotReductionBlockError : public ScheduleError { + public: + explicit NotReductionBlockError(IRModule mod, Block block, int violated_cond) + : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} + String FastErrorString() const final { return "ScheduleError: Not a reduction block"; } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The block {0} is not a reduction block - it violates condition #" << violated_cond_ + << ".\n" + << R"(Definition of a reduction block: +1) The block has the `init` statement +2) All the block bindings are quasi-affine expressions +3) All block vars are either data parallel block vars or reduction block vars +4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers +5) The reduction block vars are not used to index the output buffers)"; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; + int violated_cond_; + }; + + int error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref); + if (error_code != 0) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + throw NotReductionBlockError(self->mod, GetRef(block), error_code); + } +} + +void CheckSRefSubtreeCompactDataFlow(const ScheduleState& self, const StmtSRef& subtree_root_sref) { + class NotCompactDataFlowError : public ScheduleError { + public: + explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block) + : mod_(std::move(mod)), + subtree_root_(std::move(subtree_root)), + violate_block_(std::move(violate_block)) { + ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); + } + String FastErrorString() const final { + return "ScheduleError: The queried subtree root in SRef tree does not have compact data " + "flow, because some of its child block on SRef tree is neither a complete block nor a " + "reduction block"; + } + String DetailRenderTemplate() const final { + return "The queried subtree root {0} in SRef tree does not have compact data flow, because " + "its child block {1} on SRef tree is neither a complete block nor a reduction block"; + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {subtree_root_, violate_block_}; } + + IRModule mod_; + Stmt subtree_root_; + Block violate_block_; + }; + + StmtSRef scope_root = GetScopeRoot(self, subtree_root_sref, /*require_stage_pipeline=*/true); + Array child_blocks = GetChildBlockSRefOnSRefTree(self, scope_root); + for (const StmtSRef& block : child_blocks) { + if (!IsCompleteBlock(self, block, scope_root) && !IsReductionBlock(self, block, scope_root)) { + const BlockNode* violate_block = TVM_SREF_TO_BLOCK(violate_block, block); + throw NotCompactDataFlowError(self->mod, GetRef(subtree_root_sref->stmt), + GetRef(violate_block)); + } } } @@ -186,6 +377,28 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va return true; } +void CheckAffineBinding(const ScheduleState& self, Block block) { + class NotAffineBindingError : public ScheduleError { + public: + explicit NotAffineBindingError(IRModule mod, Block block) + : mod_(std::move(mod)), block_(std::move(block)) {} + String FastErrorString() const final { + return "ScheduleError: The block is required to have an affine binding"; + } + String DetailRenderTemplate() const final { + return "The block {0} is required to have an affine binding"; + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; + }; + + if (!self->IsAffineBlockBinding(self->stmt2ref.at(block.get()))) { + throw NotAffineBindingError(self->mod, std::move(block)); + } +} + Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, const Optional& high_exclusive, const runtime::StorageScope& extra_relax_scope) { @@ -229,74 +442,504 @@ Map GetBindings(const BlockRealize& realize) { return result; } -/******** Block-loop relation ********/ - -Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name) { - struct Finder : public StmtVisitor { - explicit Finder(const ScheduleState& self, const String& name) : self_(self), name_(name) {} +bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, + std::unordered_set* data_par_vars, + std::unordered_set* reduce_vars) { + Block block = block_realize->block; + ICHECK(block_realize->block.same_as(block)) + << "ValueError: The input `block_realize` is required to be the exact BlockRealize of the " + "input block"; - void VisitStmt_(const BlockNode* block) override { - if (block->name_hint == name_) { - auto it = self_->stmt2ref.find(block); - ICHECK(it != self_->stmt2ref.end()); - results_.push_back(it->second); - } - StmtVisitor::VisitStmt_(block); + bool has_block_vars_of_other_types = false; + ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); + int n = static_cast(block->iter_vars.size()); + for (int i = 0; i < n; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& iter_value = block_realize->iter_values[i]; + std::unordered_set* set = nullptr; + if (iter_var->iter_type == IterVarType::kDataPar) { + set = data_par_vars; + } else if (iter_var->iter_type == IterVarType::kCommReduce) { + set = reduce_vars; + } else { + has_block_vars_of_other_types = true; } - const ScheduleState& self_; - const String& name_; - Array results_; - }; + Array vars_in_binding = UndefinedVars(iter_value); + for (const Var& var : vars_in_binding) { + set->insert(var.get()); + } + } - BaseFunc func = self->mod->Lookup(func_name); - const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode); - Finder finder(self, name); - finder(prim_func->body); - return std::move(finder.results_); + return has_block_vars_of_other_types; } -Array GetLoops(const StmtSRef& block_sref) { - std::vector result; - for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance(); - parent = parent->parent) { - result.push_back(GetRef(parent)); +/******** Block-loop relation ********/ + +Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref) { + Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); + Array child_block_srefs; + child_block_srefs.reserve(child_block_realize.size()); + + for (BlockRealize realize : child_block_realize) { + child_block_srefs.push_back(self->stmt2ref.at(realize->block.get())); } - return {result.rbegin(), result.rend()}; + return child_block_srefs; } -Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) { +Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref) { struct Collector : public StmtVisitor { - public: - static Array Collect(const ScheduleState& self, const Stmt& stmt) { - Collector collector(self); + static Array Collect(const Stmt& stmt) { + Collector collector; collector(stmt); return std::move(collector.result_); } - private: - explicit Collector(const ScheduleState& self) : self_(self) {} - - void VisitStmt_(const BlockNode* block) final { - auto it = self_->stmt2ref.find(block); - ICHECK(it != self_->stmt2ref.end()); - result_.push_back(it->second); + void VisitStmt_(const BlockRealizeNode* block_realize) final { + result_.push_back(GetRef(block_realize)); } - const ScheduleState& self_; - Array result_; + Array result_; }; if (parent_sref->stmt->IsInstance()) { const auto* loop = static_cast(parent_sref->stmt); - return Collector::Collect(self, loop->body); + return Collector::Collect(loop->body); } else if (parent_sref->stmt->IsInstance()) { const auto* block = static_cast(parent_sref->stmt); - return Collector::Collect(self, block->body); + return Collector::Collect(block->body); } ICHECK(false) << "Unreachable"; throw; } +BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref) { + class NonSingleChildBlockError : public ScheduleError { + public: + explicit NonSingleChildBlockError(IRModule mod, const StmtSRef& sref) + : mod_(std::move(mod)), stmt_(GetRef(sref->stmt)) { + sref_type_ = stmt_.as() != nullptr ? "block" : "loop"; + } + + String FastErrorString() const final { + std::ostringstream os; + os << "ScheduleError: The " << sref_type_ << " is required to have only one child block"; + return os.str(); + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The " << sref_type_ << " {0} is required to have only one child block"; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {stmt_}; } + + IRModule mod_; + Stmt stmt_; + String sref_type_; + }; + + Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); + if (child_block_realize.size() != 1) { + throw NonSingleChildBlockError(self->mod, parent_sref); + } + return child_block_realize[0]; +} + +BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref) { + struct BlockRealizeFinder : public StmtVisitor { + explicit BlockRealizeFinder(const BlockNode* target_block) + : target_block(target_block), result(nullptr) {} + + void VisitStmt(const Stmt& stmt) final { + if (result != nullptr) { + return; + } + StmtVisitor::VisitStmt(stmt); + } + + void VisitStmt_(const BlockRealizeNode* block_realize) final { + if (block_realize->block.get() == target_block) { + result = block_realize; + } + // No need to visit recursively, since the deeper BlockRealizes must not be the result. + } + + const BlockNode* target_block; + const BlockRealizeNode* result; + }; + + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block_sref->parent == nullptr) { + const PrimFuncNode* func = GetRootPrimFunc(self->mod, block, nullptr); + return Downcast(func->body); + } else { + BlockRealizeFinder finder(block); + finder(GetRef(block_sref->parent->stmt)); + ICHECK(finder.result != nullptr) + << "InternalError: Cannot find the BlockRealize of block " << GetRef(block); + return GetRef(finder.result); + } +} + +/******** Block-buffer relation ********/ + +Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) { + class WriteBufferIndexOutOfRangeError : public ScheduleError { + public: + explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index) + : mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {} + + String FastErrorString() const final { + return "ScheduleError: The input `buffer_index` is out of range. It is required to be in " + "range [0, num_write_regions) where `num_write_regions` is the number of buffer " + "regions written by the block."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + size_t num_writes = block_->writes.size(); + os << "The block {0} has " << num_writes + << " write regions, so `buffer_index` is required to be in [0, " << num_writes + << "). However, the input `buffer_index` is " << buffer_index_ + << ", which is out of the expected range"; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + int buffer_index_; + }; + + if (n < 0 || static_cast(n) >= block->writes.size()) { + throw WriteBufferIndexOutOfRangeError(self->mod, block, n); + } + return block->writes[n]->buffer; +} + +/******** Pattern Matcher ********/ + +/*! + * \brief PrimExpr pattern matcher. + * + * It is different from the pattern matcher in arith/pattern_match.h, which is dedicated + * for compile-time constant patterns. This pattern matcher can work on dynamic user-specific + * patterns. + * + * The code below shows how to use the pattern matcher. + * + * \code + * + * Var x("x"), y("y"); + * // use PrimExpr to declare patterns, x, y are holes that can be filled with + * PatternMatcher pattern_matcher(x + y); + * // expr = C[i, j] + A[i, k] * B[k, j], which is the expr we want to match + * pattern_matcher.Match(expr); + * + * if (pattern_matcher.Success()) { + * pattern_matcher.Eval(x) // C[i, j] + * pattern_matcher.Eval(y) // A[i, k] * B[k, j] + * } + * + * \endcode + */ +class PatternMatcher : public ExprVisitor { + public: + explicit PatternMatcher(PrimExpr pattern) : pattern_(std::move(pattern)) {} + + void VisitExpr_(const VarNode* op) final { + auto it = filled_map_.find(op); + if (it == filled_map_.end()) { + filled_map_[op] = expr_to_match_; + } else { + ExprDeepEqual equal; + if (it->second.same_as(expr_to_match_) || equal(it->second, expr_to_match_)) return; + match_success_ = false; + } + } + + void VisitExpr_(const LoadNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!op->buffer_var.same_as(ptr->buffer_var)) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->predicate; + VisitExpr(op->predicate); + expr_to_match_ = ptr->index; + VisitExpr(op->index); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const LetNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->var; + VisitExpr(op->var); + expr_to_match_ = ptr->value; + VisitExpr(op->value); + expr_to_match_ = ptr->body; + VisitExpr(op->body); + std::swap(expr_to_match_, tmp); + } + } + + void VisitExpr_(const CallNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!op->op.same_as(ptr->op)) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + for (size_t i = 0; i < op->args.size(); ++i) { + expr_to_match_ = ptr->args[i]; + VisitExpr(op->args[i]); + } + std::swap(expr_to_match_, tmp); + } + } + } + +#define TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OpName) \ + void VisitExpr_(const OpName* op) { \ + const auto* ptr = expr_to_match_.as(); \ + if (ptr == nullptr) { \ + match_success_ = false; \ + } else { \ + PrimExpr current = expr_to_match_; \ + expr_to_match_ = ptr->a; \ + VisitExpr(op->a); \ + expr_to_match_ = ptr->b; \ + VisitExpr(op->b); \ + std::swap(expr_to_match_, current); \ + } \ + } + + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AddNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(SubNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MulNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(DivNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(ModNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorDivNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorModNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MinNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MaxNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(EQNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(NENode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LTNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LENode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GTNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GENode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AndNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OrNode); + + void VisitExpr_(const CastNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!runtime::TypeEqual(op->dtype, ptr->dtype)) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->value; + VisitExpr(op->value); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const NotNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->a; + VisitExpr(op->a); + std::swap(expr_to_match_, tmp); + } + } + + void VisitExpr_(const SelectNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->condition; + VisitExpr(op->condition); + expr_to_match_ = ptr->true_value; + VisitExpr(op->true_value); + expr_to_match_ = ptr->false_value; + VisitExpr(op->false_value); + std::swap(expr_to_match_, tmp); + } + } + + void VisitExpr_(const RampNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (op->lanes != ptr->lanes) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->base; + VisitExpr(op->base); + expr_to_match_ = ptr->stride; + VisitExpr(op->stride); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const BroadcastNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (op->lanes != ptr->lanes) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->value; + VisitExpr(op->value); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const ShuffleNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (op->vectors.size() != ptr->vectors.size() || op->indices.size() != ptr->indices.size()) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + for (size_t i = 0; i < op->indices.size(); ++i) { + expr_to_match_ = ptr->indices[i]; + VisitExpr(op->indices[i]); + } + for (size_t i = 0; i < op->vectors.size(); ++i) { + expr_to_match_ = ptr->vectors[i]; + VisitExpr(op->vectors[i]); + } + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const IntImmNode* op) final { + const auto* ptr = expr_to_match_.as(); + match_success_ = ptr != nullptr && op->value == ptr->value; + } + + void VisitExpr_(const FloatImmNode* op) final { + const auto* ptr = expr_to_match_.as(); + match_success_ = ptr != nullptr && op->value == ptr->value; + } + + void VisitExpr_(const StringImmNode* op) final { + const auto* ptr = expr_to_match_.as(); + match_success_ = ptr != nullptr && op->value == ptr->value; + } + + void VisitExpr_(const BufferLoadNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!op->buffer.same_as(ptr->buffer) || op->indices.size() != ptr->indices.size()) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + for (size_t i = 0; i < op->indices.size(); ++i) { + expr_to_match_ = ptr->indices[i]; + VisitExpr(op->indices[i]); + } + std::swap(expr_to_match_, tmp); + } + } + } + + void Match(const PrimExpr& expr_to_match) { + this->match_success_ = true; + this->filled_map_.clear(); + this->expr_to_match_ = expr_to_match; + this->operator()(pattern_); + } + + PrimExpr Eval(const Var& var) { + auto it = filled_map_.find(var.operator->()); + ICHECK(it != filled_map_.end()) << "Unknown pattern variable"; + ICHECK(match_success_) << "Match failed"; + return it->second; + } + + bool Success() const { return match_success_; } + + private: + bool match_success_{true}; + PrimExpr pattern_, expr_to_match_; + std::unordered_map filled_map_; +}; + +/******** Commutative Reducer ********/ + +bool MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const PrimExpr& combiner, + const BufferLoad& load, PrimExpr* lhs, PrimExpr* rhs) { + if (!ExprDeepEqual()(reducer->identity_element[0], identity)) { + return false; + } + PatternMatcher pattern_matcher(reducer->result[0]); + pattern_matcher.Match(combiner); + if (pattern_matcher.Success()) { + PrimExpr lhs_tmp = pattern_matcher.Eval(reducer->lhs[0]); + PrimExpr rhs_tmp = pattern_matcher.Eval(reducer->rhs[0]); + if (ExprDeepEqual()(load, lhs_tmp)) { + *lhs = std::move(lhs_tmp); + *rhs = std::move(rhs_tmp); + } + return true; + } + return false; +} + +bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, + CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs) { + BufferLoad load(combiner->buffer, combiner->indices); + // Check reduction patterns. + for (const TypedPackedFunc& reducer_getter : GetReducerGetters()) { + CommReducer reducer = reducer_getter(identity.dtype()); + if (MatchReducer(reducer, identity, combiner->value, load, lhs, rhs)) { + *result_reducer = std::move(reducer); + return true; + } + } + return false; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 0563d39427b1..cd9aad8ae512 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -18,16 +18,19 @@ */ #include "./concrete_schedule.h" +#include + namespace tvm { namespace tir { -Schedule Schedule::Concrete(IRModule mod, int debug_mode, - ScheduleErrorRenderLevel error_render_level) { +Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); - n->state_ = ScheduleState(mod, debug_mode); + n->state_ = ScheduleState(mod, debug_mask); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); + support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); return Schedule(std::move(n)); } @@ -50,7 +53,7 @@ class ScheduleCopier { n->mod = src_state->mod; n->block_info = copier.Copy(src_state->block_info); n->stmt2ref = copier.Copy(src_state->stmt2ref); - n->debug_mode = src_state->debug_mode; + n->debug_mask = src_state->debug_mask; *new_state = ScheduleState(std::move(n)); *new_symbol_table = copier.Copy(self->symbol_table_); } @@ -182,8 +185,8 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb Schedule ConcreteScheduleNode::Copy() const { ObjectPtr n = make_object(); n->error_render_level_ = this->error_render_level_; - this->Copy(&n->state_, &n->symbol_table_); - n->analyzer_ = std::make_unique(); + ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); + n->analyzer_ = std::make_unique(); // new analyzer needed because it is stateful return Schedule(std::move(n)); } @@ -207,7 +210,31 @@ Schedule ConcreteScheduleNode::Copy() const { } \ } -/******** Block/Loop relation ********/ +/******** Schedule: Schedule: Sampling ********/ + +void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) { + if (seed == -1) { + seed = std::random_device()(); + } + support::LinearCongruentialEngine(&rand_state_).Seed(seed); +} + +support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { + // In order for reproducibility, we computer the new seed using RNG's random state and a different + // set of parameters. Note that both 32767 and 1999999973 are prime numbers. + return (support::LinearCongruentialEngine(&rand_state_)() * 32767) % 1999999973; +} + +ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); + TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); + throw; +} + +/******** Schedule: Get blocks & loops ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { class NotSingleResult : public ScheduleError { @@ -257,8 +284,139 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { return CreateRV(tir::GetLoops(this->GetSRef(block_rv))); } -/******** Schedule: loops manipulation ********/ -/******** Schedule: compute location ********/ +/******** Schedule: Transform loops ********/ + +LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { + CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)"; + Array loop_srefs = this->GetSRefs(loop_rvs); + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::Fuse(state_, loop_srefs); + TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + +Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, + const Array>& factor_rvs) { + class NotSingleInferFactorError : public ScheduleError { + public: + explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} + + String FastErrorString() const final { + return "ScheduleError: only one factor can be specified as -1 or none"; + } + + String DetailRenderTemplate() const final { + return "Only one factor can be specified as -1 or none"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + }; + + class WrongFactorProductError : public ScheduleError { + public: + explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The product of factors is not larger than or equal to the extent of " + "loop"; + } + + String DetailRenderTemplate() const final { + return "The product of factors is not larger than or equal to the extent of loop {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; + }; + // Prepare for the splitting + StmtSRef loop_sref = this->GetSRef(loop_rv); + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + Array factors; + factors.reserve(factor_rvs.size()); + int infer_index = -1; + PrimExpr tot_length = 1; + Array results; + TVM_TIR_SCHEDULE_BEGIN(); + // infer factor if needed and check validity of factors + for (size_t i = 0; i < factor_rvs.size(); i++) { + if (!factor_rvs[i].defined()) { + factors.push_back(Integer(-1)); + if (infer_index == -1) { + infer_index = i; + } else { + throw NotSingleInferFactorError(state_->mod); + } + } else { + PrimExpr factor = this->Get(factor_rvs[i].value()); + factors.push_back(factor); + tot_length *= factor; + } + } + if (infer_index != -1) { + factors.Set(infer_index, + this->analyzer_->Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); + } else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) { + throw WrongFactorProductError(state_->mod, GetRef(loop)); + } + results = tir::Split(state_, loop_sref, factors); + TVM_TIR_SCHEDULE_END("split", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(results); +} + +void ConcreteScheduleNode::Reorder(const Array& ordered_loop_rvs) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Reorder(state_, GetSRefs(ordered_loop_rvs)); + TVM_TIR_SCHEDULE_END("reorder", this->error_render_level_); + this->state_->DebugVerify(); +} + +/******** Schedule: Manipulate ForKind ********/ + +void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Parallel(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("parallel", this->error_render_level_); +} + +void ConcreteScheduleNode::Vectorize(const LoopRV& loop_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Vectorize(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("vectorize", this->error_render_level_); +} + +void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) { + if (thread_axis == "vthread") { + LOG(WARNING) << "`vthread` is legacy behavior and is going to be deprecated. Please use " + "`vthread.x`, `vthread.y` and `vthread.z` instead"; + } + TVM_TIR_SCHEDULE_BEGIN(); + tir::Bind(state_, this->GetSRef(loop_rv), + IterVar(/*dom=*/Range(nullptr), /*var=*/Var(thread_axis), /*iter_type=*/kThreadIndex, + /*thread_tag=*/thread_axis)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("bind", this->error_render_level_); +} + +void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Unroll(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("unroll", this->error_render_level_); +} + +/******** Schedule: Insert cache stages ********/ +/******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); @@ -274,14 +432,30 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { this->state_->DebugVerify(); } -/******** Schedule: loop binding/annotation ********/ -/******** Schedule: cache read/write ********/ -/******** Schedule: reduction ********/ -/******** Schedule: blockize & tensorize ********/ +/******** Schedule: Block Annotation ********/ + +void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, + int factor, int offset) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::StorageAlign(state_, this->GetSRef(block_rv), buffer_index, axis, factor, offset); + TVM_TIR_SCHEDULE_END("storage-align", this->error_render_level_); + this->state_->DebugVerify(); +} + +/******** Schedule: Reduction ********/ -/******** FFI ********/ +BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::RFactor(state_, this->GetSRef(loop_rv), factor_axis); + TVM_TIR_SCHEDULE_END("rfactor", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} -TVM_REGISTER_NODE_TYPE(ConcreteScheduleNode); +/******** Schedule: Blockize & Tensorize ********/ +/******** Schedule: Annotation ********/ +/******** Schedule: Misc ********/ } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 8945fb9ee0dc..0bd902d183bf 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -43,23 +43,26 @@ class ConcreteScheduleNode : public ScheduleNode { TSymbolTable symbol_table_; /*! \brief A persistent stateless arithmetic analyzer. */ std::unique_ptr analyzer_; + /*! \brief The value of random state for sampling. */ + support::LinearCongruentialEngine::TRandState rand_state_; public: void VisitAttrs(tvm::AttrVisitor* v) { - // `error_render_level_` is not visited // `state_` is not visited + // `error_render_level_` is not visited // `symbol_table_` is not visited - // `analyzer_` is not visitied + // `analyzer_` is not visited + // `rand_state_` is not visited } virtual ~ConcreteScheduleNode() = default; - static constexpr const char* _type_key = "tir.ConcreteSchedule"; - TVM_DECLARE_BASE_OBJECT_INFO(ConcreteScheduleNode, ScheduleNode); - public: ScheduleState state() const final { return state_; } + Optional trace() const override { return NullOpt; } Schedule Copy() const override; + void Seed(support::LinearCongruentialEngine::TRandState seed = -1) final; + support::LinearCongruentialEngine::TRandState ForkSeed() final; public: /******** Lookup random variables ********/ @@ -68,26 +71,53 @@ class ConcreteScheduleNode : public ScheduleNode { inline PrimExpr Get(const ExprRV& expr_rv) const final; inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; + inline Array GetSRefs(const Array& rvs) const; + inline Array GetSRefs(const Array& rvs) const; void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } using ScheduleNode::GetSRef; public: - /******** Block/Loop relation ********/ + /******** Schedule: Sampling ********/ + /*! + * \brief Sample an integer given the probability distribution + * \param candidates The candidates + * \param probs The probability distribution of the candidates + * \param decision The sampling decision, if it's given we would validate the decision, otherwise + * we would sample a decision from the distribution and set the decision accordingly. + * \return The random variable sampled from candidates + */ + ExprRV SampleCategorical(const Array& candidates, const Array& probs, + Optional decision = NullOpt) override; + /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; - /******** Schedule: loops manipulation ********/ - /******** Schedule: compute location ********/ + /******** Schedule: Transform loops ********/ + LoopRV Fuse(const Array& loop_rvs) override; + Array Split(const LoopRV& loop_rv, const Array>& factors) override; + void Reorder(const Array& ordered_loop_rvs) override; + /******** Schedule: Manipulate ForKind ********/ + void Parallel(const LoopRV& loop_rv) override; + void Vectorize(const LoopRV& loop_rv) override; + void Bind(const LoopRV& loop_rv, const String& thread_axis) override; + void Unroll(const LoopRV& loop_rv) override; + /******** Schedule: Insert cache stages ********/ + /******** Schedule: Compute location ********/ void ComputeInline(const BlockRV& block) override; void ReverseComputeInline(const BlockRV& block) override; - /******** Schedule: loop binding/annotation ********/ - /******** Schedule: cache read/write ********/ - /******** Schedule: reduction ********/ - /******** Schedule: blockize & tensorize ********/ + /******** Schedule: Reduction ********/ + BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override; + /******** Schedule: Block annotation ********/ + void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, + int offset) override; + /******** Schedule: Blockize & Tensorize ********/ + /******** Schedule: Annotation ********/ + /******** Schedule: Misc ********/ + void EnterPostproc() override {} - /******** Utility functions ********/ protected: + /******** Utility functions ********/ /*! * \brief Copy the schedule state, as well as the symbol table * \param new_state The ScheduleState copied @@ -111,17 +141,11 @@ class ConcreteScheduleNode : public ScheduleNode { template inline T CreateRV(const StmtSRef& sref); /*! - * \brief Add an expr as a random variable into the symbol table - * \param expr The expr to be added to the symbol table + * \brief Add an integer as a random variable into the symbol table + * \param value The integer to be added to the symbol table * \return The new random variable created */ - inline ExprRV CreateRV(const PrimExpr& expr); - /*! - * \brief Add expr as random variables into the symbol table - * \param exprs The expr to be added to the symbol table - * \return The new random variables created - */ - inline Array CreateRV(const Array& exprs); + inline ExprRV CreateRV(int64_t value); /*! \brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); }; @@ -132,28 +156,27 @@ class ConcreteScheduleNode : public ScheduleNode { inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const { StmtSRef sref = this->GetSRef(block_rv); - const auto* block = TVM_SREF_TO_BLOCK(block, sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, sref); return GetRef(block); } inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { StmtSRef sref = this->GetSRef(loop_rv); - const auto* loop = TVM_SREF_TO_FOR(loop, sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); return GetRef(loop); } inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { - auto it = this->symbol_table_.find(expr_rv); - if (it == this->symbol_table_.end()) { - LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << expr_rv; - } - const ObjectRef& obj = (*it).second; - const auto* expr_node = obj.as(); - if (expr_node == nullptr) { - LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: " - << (obj.defined() ? obj->GetTypeKey() : "None"); - } - return GetRef(expr_node); + PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> Optional { + auto it = this->symbol_table_.find(var); + if (it == this->symbol_table_.end()) { + LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var; + } + const ObjectRef& obj = (*it).second; + const auto* int_imm = TVM_TYPE_AS(int_imm, obj, IntImmNode); + return Integer(int_imm->value); + }); + return this->analyzer_->Simplify(transformed); } inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { @@ -198,6 +221,24 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { return GetRef(sref); } +template +inline Array GetSRefsHelper(const ConcreteScheduleNode* sch, const Array& rvs) { + Array result; + result.reserve(rvs.size()); + for (const T& rv : rvs) { + result.push_back(sch->GetSRef(rv)); + } + return result; +} + +inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { + return GetSRefsHelper(this, rvs); +} + +inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { + return GetSRefsHelper(this, rvs); +} + /******** Adding/Removing elements in the symbol table ********/ template @@ -219,23 +260,12 @@ inline T ConcreteScheduleNode::CreateRV(const StmtSRef& sref) { return std::move(rv); } -inline ExprRV ConcreteScheduleNode::CreateRV(const PrimExpr& expr) { - ExprRV rv; - this->symbol_table_.Set(rv, expr); +inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { + Var rv("v" + std::to_string(this->symbol_table_.size() + 1), DataType::Int(32)); + this->symbol_table_.Set(rv, Integer(static_cast(value))); return std::move(rv); } -inline Array ConcreteScheduleNode::CreateRV(const Array& exprs) { - Array result; - result.reserve(exprs.size()); - for (const PrimExpr& expr : exprs) { - ExprRV rv; - this->symbol_table_.Set(rv, expr); - result.push_back(rv); - } - return result; -} - inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) { auto it = this->symbol_table_.find(obj); if (it != this->symbol_table_.end()) { diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc new file mode 100644 index 000000000000..af721767c32f --- /dev/null +++ b/src/tir/schedule/instruction.cc @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace tir { + +Instruction::Instruction(InstructionKind kind, Array inputs, Array attrs, + Array outputs) { + ObjectPtr n = make_object(); + n->kind = std::move(kind); + n->inputs = std::move(inputs); + n->attrs = std::move(attrs); + n->outputs = std::move(outputs); + this->data_ = std::move(n); +} + +using InstructionKindRegistry = AttrRegistry; + +InstructionKind InstructionKind::Get(const String& name) { + const InstructionKindRegEntry* reg = InstructionKindRegistry::Global()->Get(name); + ICHECK(reg != nullptr) << "AttributeError: Instruction kind " << name << " is not registered"; + return reg->inst_kind_; +} + +InstructionKindRegEntry::InstructionKindRegEntry(uint32_t reg_index) { + this->inst_kind_ = InstructionKind(make_object()); +} + +InstructionKindRegEntry& InstructionKindRegEntry::RegisterOrGet(const String& name) { + return InstructionKindRegistry::Global()->RegisterOrGet(name); +} + +/**************** Repr ****************/ + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { + const auto* self = obj.as(); + ICHECK_NOTNULL(self); + Array inputs; + inputs.reserve(self->inputs.size()); + for (const ObjectRef& obj : self->inputs) { + if (!obj.defined()) { + inputs.push_back(String("None")); + } else if (obj->IsInstance() || obj->IsInstance()) { + inputs.push_back(String("_")); + } else if (const auto* str_obj = obj.as()) { + inputs.push_back(String('"' + std::string(str_obj->data) + '"')); + } else if (obj->IsInstance() || obj->IsInstance()) { + inputs.push_back(obj); + } else if (const auto* expr = obj.as()) { + PrimExpr new_expr = + Substitute(GetRef(expr), [](const Var& var) -> Optional { + ObjectPtr new_var = make_object(*var.get()); + new_var->name_hint = "_"; + return Var(new_var); + }); + std::ostringstream os; + os << new_expr; + inputs.push_back(String(os.str())); + } else { + LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << obj->GetTypeKey(); + throw; + } + } + p->stream << self->kind->f_as_python( + /*inputs=*/inputs, + /*attrs=*/self->attrs, + /*decision=*/NullOpt, + /*outputs=*/Array(self->outputs.size(), String("_"))); + }); + +/**************** FFI ****************/ + +TVM_REGISTER_NODE_TYPE(InstructionNode); +TVM_REGISTER_NODE_TYPE(InstructionKindNode); + +TVM_REGISTER_GLOBAL("tir.schedule.InstructionKindGet").set_body_typed(InstructionKind::Get); +TVM_REGISTER_GLOBAL("tir.schedule.Instruction") + .set_body_typed([](InstructionKind kind, Array inputs, Array attrs, + Array outputs) -> Instruction { + return Instruction(kind, inputs, attrs, outputs); + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h new file mode 100644 index 000000000000..95d636467aa0 --- /dev/null +++ b/src/tir/schedule/instruction_traits.h @@ -0,0 +1,536 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ +#define TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ + +#include +#include + +#include +#include +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Register an InstructionKind using a trait class + * \param InstructionKindTraits A traits class of an InstructionKind + * + * Example: + * + * \code + * + * struct SomeInstructionKindTraits { + * static constexpr const char* kName = "name-of-the-instruction"; + * static constexpr bool kIsPure = false; + * + * // Convertible to `InstructionKindNode::FInstructionApply` + * static Array ApplyToSchedule( + * const tir::Schedule& sch, + * const Array& inputs, + * const Array& attrs, + * const Optional& decision); + * + * // Convertible to `InstructionKindNode::FInstructionAsPython` + * static String AsPython( + * const Array& inputs, + * const Array& attrs, + * const Optional& decision, + * const Array& outputs); + * + * // Convertible to `InstructionKindNode::FInstructionAttrsAsJSON` + * static Array AttrsAsJSON( + * const Array& attrs); + * + * // Convertible to `InstructionKindNode::FInstructionAttrsFromJSON` + * static Array AttrsFromJSON( + * const Array& attrs_record); + * }; + * + * TVM_REGISTER_INST_KIND_TRAITS(SomeInstructionKindTraits); + * + * \endcode + */ +#define TVM_REGISTER_INST_KIND_TRAITS(InstructionKindTraits) \ + TVM_REGISTER_INST_KIND(InstructionKindTraits::kName) \ + .set_is_pure(InstructionKindTraits::kIsPure) \ + .set_apply_to_schedule(InstructionKindTraits::ApplyToSchedule) \ + .set_attrs_as_json(InstructionKindTraits::AttrsAsJSON) \ + .set_attrs_from_json(InstructionKindTraits::AttrsFromJSON) \ + .set_as_python(InstructionKindTraits::AsPython) + +/*! + * \brief A helper to conveniently register an InstructionKind. When inherited in curiously + * recursive template pattern, the derived class `TTraits` only needs to define two functions on the + * unpacked inputs, and the helper handles unpacking and downcasting. See the example for more + * details. + * + * \tparam TTraits The derived class + * + * Example: + * + * \code + * + * struct SamplePerfectTileTraits : public UnpackedInstTraits { + * // The name of this kind of instruction + * static constexpr const char* kName = "SamplePerfectTile"; + * // A boolean indicating if the instruction is pure, i.e. change nothing in the schedule state + * static constexpr bool kIsPure = true; + * // The number of inputs in this kind of instruction + * static constexpr size_t kNumInputs = 1; + * // The number of attributes in this kind of instruction + * static constexpr size_t kNumAttrs = 2; + * // The number of decisions in this kind of instruction (only 0 or 1 is allowed) + * static constexpr size_t kNumDecisions = 1; + * + * // Calling convention: + * // - All the arguments must be ObjectRef + * // - The 1st argument is Schedule + * // - The next `kNumInputs` arguments are input random variables + * // - The next `kNumAttrs` arguments are attributes + * // - The next argument is decision, if `kNumDecisions == 1` + * static Array UnpackedApplyToSchedule( + * Schedule sch, + * LoopRV loop_rv, + * Integer n, + * Integer max_innermost_factor, + * Optional> decision) { + * return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); + * } + * + * // Calling convention: + * // - All the arguments must be ObjectRef + * // - The 1st argument is an array containing names of output random variables + * // - The next `kNumInputs` arguments are names of input random variables + * // - The next `kNumAttrs` arguments are attributes + * // - The next argument is decision, if `kNumDecisions == 1` + * static String UnpackedAsPython( + * Array outputs, + * String loop_rv, + * Integer n, + * Integer max_innermost_factor, + * Optional> decision) { + * PythonAPICall py("sample_perfect_tile"); + * py.Input("loop", loop_rv); + * py.Input("n", n->value); + * py.Input("max_innermost_factor", max_innermost_factor->value); + * py.Decision(decision); + * py.OutputList(outputs); + * return py.Str(); + * } + * + * template + * friend struct UnpackedInstTraits; + * }; + * + * TVM_REGISTER_INST_KIND(SamplePerfectTileTraits); + * \endcode + */ +template +struct UnpackedInstTraits { + /*! + * \brief Unpack the arguments in the calling convention, and feed them into + * `TTraits::UnpackedApplyToSchedule` + * \sa InstructionKindNode::f_apply_to_schedule + */ + static Array ApplyToSchedule(const Schedule& sch, const Array& inputs, + const Array& attrs, + const Optional& decision); + + /*! + * \brief Unpack the arguments in the calling convention, and feed them into + * `TTraits::UnpackedAsPython` + * \sa InstructionKindNode::f_as_python + */ + static String AsPython(const Array& inputs, const Array& attrs, + const Optional& decision, const Array& outputs); + + /*! \brief No customized serializer by default */ + static constexpr std::nullptr_t AttrsAsJSON = nullptr; + + /*! \brief No customized deserializer by default */ + static constexpr std::nullptr_t AttrsFromJSON = nullptr; + + protected: + template + static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, + const Array& inputs); + template + static TVM_ALWAYS_INLINE void _SetAttrs(const runtime::TVMArgsSetter& setter, + const Array& attrs); + template + static TVM_ALWAYS_INLINE void _SetDecision(const runtime::TVMArgsSetter& setter, + const Optional& decision); + static TVM_ALWAYS_INLINE Array _ConvertOutputs(const TVMRetValue& rv); +}; + +/*! + * \brief A helper class that constructs schedule API call in python syntax, + * which helps convert an Inst to a python statement. + * \sa InstructionKindNode::f_as_python + */ +class PythonAPICall { + public: + /*! + * \brief Constructor + * \param method_name The name of the schedule API to be called + */ + explicit PythonAPICall(String method_name) : method_name_(method_name), output_(NullOpt) {} + /*! \brief Add an intger input */ + inline void Input(String arg_name, int arg); + /*! \brief Add an intger input */ + inline void Input(String arg_name, int64_t arg); + /*! \brief Add a double input */ + inline void Input(String arg_name, double arg); + /*! \brief Add an input random variable */ + inline void Input(String arg_name, String arg); + /*! \brief Add an input, dispatched to different implementations according to the object's type */ + inline void Input(String arg_name, ObjectRef arg); + /*! \brief Add the decision */ + inline void Decision(ObjectRef decision); + /*! + * \brief Add a single output random variable + * \param unit_array An array containing only one element + */ + inline void SingleOutput(Array unit_array); + /*! \brief Add a list of output random variables */ + inline void OutputList(Array outputs); + /*! \returns The schedule API call in python syntax */ + inline String Str() const; + + private: + /*! \brief Converts a TVM object to python string and print to the output stream */ + inline void AsPythonString(const ObjectRef& obj, std::ostream& os); + + private: + /*! \brief The name of the API to call */ + String method_name_; + /*! \brief The output of the instruction */ + Optional output_; + /*! \brief The names of input arguments */ + std::vector arg_names_; + /*! \brief The values of input arguments */ + std::vector args_; +}; + +/********** implementation details **********/ + +// forward declaration +namespace details { + +template +struct _ArgsPacker; + +template <> +struct _ArgsPacker<> { + static constexpr bool checked = true; +}; + +template +struct _ArgsPacker { + static constexpr bool checked = + std::is_base_of::value && _ArgsPacker::checked; +}; + +template +struct _MethodType {}; + +template +struct _MethodType { + using return_type = TReturn; + using argument_type = _ArgsPacker; +}; + +template +struct _NumArgs {}; + +template +struct _NumArgs { + static constexpr size_t value = sizeof...(Args); +}; + +template +struct _IsTVMArray : std::false_type {}; + +template +struct _IsTVMArray> : std::true_type {}; + +template +struct _IsSingleObject + : std::integral_constant::value && !_IsTVMArray::value> { +}; + +template +using ReturnType = typename _MethodType>::return_type; + +template +static constexpr bool ArgumentAreAllObjects = + _MethodType>::argument_type::checked; + +template +static constexpr size_t NumArgs = _NumArgs>::value; + +template +static constexpr int IsTVMArray = _IsTVMArray>::value; + +template +static constexpr int IsSingleObject = _IsSingleObject>::value; + +}; // namespace details + +template +Array UnpackedInstTraits::ApplyToSchedule(const Schedule& sch, + const Array& inputs, + const Array& attrs, + const Optional& decision) { + using method_type = decltype(TTraits::UnpackedApplyToSchedule); + using return_type = details::ReturnType; + static_assert(details::ArgumentAreAllObjects, + "All arguments to `UnpackedApplyToSchedule` must be subclasses of ObjectRef"); + constexpr size_t kNumArgs = details::NumArgs; + constexpr size_t kNumInputs = TTraits::kNumInputs; + constexpr size_t kNumAttrs = TTraits::kNumAttrs; + constexpr size_t kNumDecisions = TTraits::kNumDecisions; + static_assert(kNumArgs == 1 + kNumInputs + kNumAttrs + kNumDecisions, + "length of argument list mismatch"); + TVMValue tvm_values[kNumArgs]; + int tvm_type_codes[kNumArgs]; + runtime::TVMArgsSetter setter(tvm_values, tvm_type_codes); + setter(0, sch); + TTraits::template _SetInputs<1>(setter, inputs); + TTraits::template _SetAttrs<1 + kNumInputs>(setter, attrs); + TTraits::template _SetDecision<1 + kNumInputs + kNumAttrs>(setter, decision); + PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void { + using runtime::detail::unpack_call; + constexpr size_t kNumArgs = details::NumArgs; + ICHECK_EQ(args.size(), kNumArgs); + unpack_call(nullptr, TTraits::UnpackedApplyToSchedule, args, rv); + }); + TVMRetValue rv; + pf.CallPacked(TVMArgs(tvm_values, tvm_type_codes, kNumArgs), &rv); + return TTraits::_ConvertOutputs(rv); +} + +template +String UnpackedInstTraits::AsPython(const Array& inputs, + const Array& attrs, + const Optional& decision, + const Array& outputs) { + using method_type = decltype(TTraits::UnpackedAsPython); + using return_type = details::ReturnType; + static_assert(details::ArgumentAreAllObjects, + "All arguments to `UnpackedAsPython` must be subclasses of ObjectRef"); + constexpr size_t kNumArgs = details::NumArgs; + constexpr size_t kNumInputs = TTraits::kNumInputs; + constexpr size_t kNumAttrs = TTraits::kNumAttrs; + constexpr size_t kNumDecisions = TTraits::kNumDecisions; + static_assert(kNumArgs == 1 + kNumInputs + kNumAttrs + kNumDecisions, + "length of argument list mismatch"); + TVMValue tvm_values[kNumArgs]; + int tvm_type_codes[kNumArgs]; + runtime::TVMArgsSetter setter(tvm_values, tvm_type_codes); + setter(0, outputs); + TTraits::template _SetInputs<1>(setter, inputs); + TTraits::template _SetAttrs<1 + kNumInputs>(setter, attrs); + TTraits::template _SetDecision<1 + kNumInputs + kNumAttrs>(setter, decision); + PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void { + using runtime::detail::unpack_call; + constexpr size_t kNumArgs = details::NumArgs; + ICHECK_EQ(args.size(), kNumArgs); + unpack_call(nullptr, TTraits::UnpackedAsPython, args, rv); + }); + TVMRetValue rv; + pf.CallPacked(TVMArgs(tvm_values, tvm_type_codes, kNumArgs), &rv); + String result = rv; + return result; +} + +template +template +TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetInputs(const runtime::TVMArgsSetter& setter, + const Array& inputs) { + constexpr size_t kNumInputs = TTraits::kNumInputs; + ICHECK_EQ(kNumInputs, inputs.size()) + << "ValueError: Incorrect kNumInputs for instruction: " << TTraits::kName; + const ObjectRef* ptr = inputs.template as()->begin(); + for (size_t i = 0; i < kNumInputs; ++i) { + setter(i + index_offset, *(ptr + i)); + } +} + +template +template +TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetAttrs(const runtime::TVMArgsSetter& setter, + const Array& attrs) { + constexpr size_t kNumAttrs = TTraits::kNumAttrs; + ICHECK_EQ(kNumAttrs, attrs.size()) + << "ValueError: Incorrect kNumAttrs for instruction: " << TTraits::kName; + const ObjectRef* ptr = attrs.as()->begin(); + for (size_t i = 0; i < kNumAttrs; ++i) { + setter(i + index_offset, *(ptr + i)); + } +} + +template +template +TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetDecision( + const runtime::TVMArgsSetter& setter, const Optional& decision) { + constexpr size_t kNumDecisions = TTraits::kNumDecisions; + static_assert(kNumDecisions <= 1, "an instruction is supposed to have at most 1 decision"); + if (kNumDecisions == 1) { + setter(index_offset, decision); + } else { + ICHECK(!decision.defined()); + } +} + +template +TVM_ALWAYS_INLINE Array UnpackedInstTraits::_ConvertOutputs( + const TVMRetValue& rv) { + using method_type = decltype(TTraits::UnpackedApplyToSchedule); + using return_type = details::ReturnType; + constexpr int is_array = details::IsTVMArray; + constexpr int is_single_obj = details::IsSingleObject; + constexpr int is_void = std::is_void::value; + static_assert(is_array || is_single_obj || is_void, "return type not supported"); + static_assert(is_array + is_single_obj + is_void == 1, "internal template error"); + if (is_void) { + return {}; + } else if (is_single_obj) { + ObjectRef obj = rv; + return {obj}; + } else if (is_array) { + ObjectRef obj = rv; + const ArrayNode* array = obj.as(); + return GetRef>(array); + } +} + +/********** PythonAPICall **********/ + +inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os) { + if (const auto* str = obj.as()) { + os << str->data; + } else if (const auto* int_imm = obj.as()) { + os << int_imm->value; + } else if (const auto* float_imm = obj.as()) { + os.precision(17); + os << float_imm->value; + } else if (const auto* array = obj.as()) { + os << '['; + bool is_first = true; + for (const ObjectRef& e : *array) { + if (is_first) { + is_first = false; + } else { + os << ", "; + } + AsPythonString(e, os); + } + os << ']'; + } else { + LOG(FATAL) << "ValueError: Cannot translate type '" << obj->GetTypeKey() + << "' to python. Its value is: " << obj; + throw; + } +} + +void PythonAPICall::Input(String arg_name, int arg) { + arg_names_.emplace_back(std::move(arg_name)); + args_.push_back(std::to_string(arg)); +} + +void PythonAPICall::Input(String arg_name, int64_t arg) { + arg_names_.emplace_back(std::move(arg_name)); + args_.push_back(std::to_string(arg)); +} + +void PythonAPICall::Input(String arg_name, double arg) { + arg_names_.emplace_back(std::move(arg_name)); + std::ostringstream os; + os.precision(17); + os << arg; + args_.push_back(os.str()); +} + +void PythonAPICall::Input(String arg_name, String arg) { + arg_names_.emplace_back(std::move(arg_name)); + args_.emplace_back(std::move(arg)); +} + +void PythonAPICall::Input(String arg_name, ObjectRef arg) { + arg_names_.emplace_back(std::move(arg_name)); + std::ostringstream os; + AsPythonString(arg, os); + args_.push_back(os.str()); +} + +void PythonAPICall::Decision(ObjectRef decision) { + if (decision.defined()) { + this->Input("decision", decision); + } +} + +void PythonAPICall::SingleOutput(Array unit_array) { + ICHECK_EQ(unit_array.size(), 1); + this->output_ = unit_array[0]; +} + +void PythonAPICall::OutputList(Array outputs) { + if (outputs.empty()) { + return; + } + if (outputs.size() == 1) { + this->output_ = outputs[0] + ","; + return; + } + std::ostringstream os; + os << outputs[0]; + for (int i = 1, n = outputs.size(); i < n; ++i) { + os << ", " << outputs[i]; + } + this->output_ = os.str(); +} + +String PythonAPICall::Str() const { + std::ostringstream os; + if (output_.defined()) { + os << output_.value() << " = "; + } + os << "sch." << method_name_ << '('; + int n = args_.size(); + for (int i = 0; i < n; ++i) { + if (i > 0) { + os << ", "; + } + if (arg_names_[i].empty()) { + os << args_[i]; + } else { + os << arg_names_[i] << '=' << args_[i]; + } + } + os << ')'; + return os.str(); +} + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index ab8299e38169..be33c2acca10 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -19,14 +19,123 @@ #ifndef TVM_TIR_SCHEDULE_PRIMITIVE_H_ #define TVM_TIR_SCHEDULE_PRIMITIVE_H_ +#include #include namespace tvm { namespace tir { -/******** Schedule: loops manipulation ********/ +/******** Schedule: Sampling ********/ +/*! + * \brief Sample once category from candidates according to the probability weights. + * \param self The schedule to update + * \param rand_state The pointer to schedule's random state + * \param candidates The candidates + * \param probs The probability distribution of the candidates + * \param decision The sampling decision, if any + * \return The random variable sampled from candidates + */ +TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, + const Array& candidates, const Array& probs, + Optional* decision); + +/******** Schedule: Get blocks & loops ********/ +/*! + * \brief Retrieves blocks in a specific function with its name + * \param self The schedule state + * \param name The name of the blocks to be retrieved + * \param func_name The name of the function + * \return A list of blocks with the specific name + */ +Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name); +/*! + * \brief Gets the parent loops of the block in its scope, from outer to inner + * \param self The schedule state + * \param block_sref The query block + * \return A list of loops above the given block in its scope, from outer to inner + */ +Array GetLoops(const StmtSRef& block_sref); +/******** Schedule: Transform loops ********/ +/*! + * Split a loop into a list of consecutive loops. It requires: + * 1) The loop can't have annotation or thread binding. + * 2) The loop must start with 0. + * \param self The state of the schedule + * \param loop_sref The sref to the loop being split + * \param factors The splitting factors + * \return An array of srefs to the loops after splitting + */ +TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, + const Array& factors); +/*! + * \brief Fuse a list of consecutive loops into one. It requires: + * 1) The loops can't have annotations or thread bindings. + * 2) The inner loop must be the only child of the outer loop. + * 3) All loops must start with 0. + * \param self The state of the schedule + * \param loop_srefs An array of srefs to the loops to be fused + * \return The sref to the fused loop + */ +TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs); +/*! + * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. + * It requires: + * 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... , + * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between + * l_1 and l_n (which also indicates they are under the same scope). + * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. + * 3) For every block under the loop nests, its block binding must be affine, and the block + * variables must be either data parallel or reduction. + * 4) No duplicated loops are allowed in the arguments. + * \param self The state of the schedule + * \param ordered_loop_srefs An array of srefs which indicates the new order of loops + */ +TVM_DLL void Reorder(ScheduleState self, const Array& ordered_loop_srefs); -/******** Schedule: compute location ********/ +/******** Schedule: Manipulate ForKind ********/ +/*! + * \brief Parallelize the input loop. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' + * bindings + * \param self The state of the schedule + * \param loop_sref The sref of the loop to be parallelized + */ +TVM_DLL void Parallel(ScheduleState self, const StmtSRef& loop_sref); +/*! + * \brief Vectorize the input loop. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' + * bindings + * \param self The state of the schedule + * \param loop_sref The sref of the loop to be vectorized + */ +TVM_DLL void Vectorize(ScheduleState self, const StmtSRef& loop_sref); +/*! + * \brief Bind the input loop to the given thread axis. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, if the thread axis starts with "threadIdx`, the loop can only + * be contained in data-parallel block iter and reduction block iters' bindings. Otherwise the + * loop can only be contained in data-parallel block iters' bindings + * \param self The state of the schedule + * \param loop_sref The sref of the loop to be bound to the thread axis + * \param thread_axis The thread axis to be bound to the loop + */ +TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis); +/*! + * \brief Unroll the input loop. It requires nothing + * \param self The state of the schedule + * \param loop_sref The loop to be unrolled + */ +TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref); +/******** Schedule: Insert cache stages ********/ +/******** Schedule: Compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: * 1) The block is a complete non-root block, which only produces one buffer @@ -52,14 +161,42 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref); * \param block_sref The sref to the block to be inlined to its producer */ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref); +/******** Schedule: Reduction ********/ +/*! + * \brief Factor a reduction block by the specified loop + * \details See python/tvm/tir/schedule/schedule.py + * \param self The state of the schedule + * \param loop_sref The loop outside block for which we want to do rfactor + * \param factor_axis The position where the new dimension is placed in the new introduced rfactor + * buffer. Suppose the original reduction block writes to buffer `B` with + * ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1, + * ndim(B)]`, and the negative index will be normalized to a non-negative one + * \return The sref of the rfactor block + */ +TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis); +/******** Schedule: Block annotation ********/ +/*! + * \brief Set alignment requirement for specific dimension such that + * stride[axis] == k * factor + offset for some k. This is useful to set memory layout for + * more friendly memory access pattern. For example, we can set alignment to be factor=2, + * offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared + * memory. + * \param block_sref The producer block of the buffer + * \param buffer_index The index of the buffer in block's write region + * \param axis The dimension to be specified for alignment + * \param factor The factor multiple of alignment + * \param offset The required offset factor + */ +TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + int axis, int factor, int offset); -/******** Schedule: loop binding/annotation ********/ - -/******** Schedule: cache read/write ********/ - -/******** Schedule: reduction ********/ +/******** Annotation types for StorageAlign ********/ +using StorageAlignTuple = Array; // (buffer_idx, axis, factor, offset) +using StorageAlignAnnotation = Array; // unordered array of StorageAlignTuple -/******** Schedule: blockize & tensorize ********/ +/******** Schedule: Blockize & Tensorize ********/ +/******** Schedule: Annotation ********/ +/******** Schedule: Misc ********/ } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc new file mode 100644 index 000000000000..937bc7c3802f --- /dev/null +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../transform.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +class StorageAlignAxisOutOfRangeError : public ScheduleError { + public: + explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int axis) + : mod_(std::move(mod)), buffer_(std::move(buffer)), axis_(axis) {} + + String FastErrorString() const final { + return "ScheduleError: The input `axis` is out of range. It is required to be in range " + "[-ndim, ndim) where `ndim` is the number of dimensions of the buffer to set " + "storage alignment."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + int ndim = static_cast(buffer_->shape.size()); + os << "The buffer to set storage alignment of, " << buffer_->name << ", has " << ndim + << " dimension(s), so `axis` is required to be in [" << -(ndim) << ", " << ndim + << ") for storage_align. However, the input `axis` is " << axis_ + << ", which is out of the expected range."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int axis) { + int ndim = static_cast(buffer->shape.size()); + if (axis < -ndim || axis >= ndim) { + throw StorageAlignAxisOutOfRangeError(mod, buffer, axis); + } + // If axis is negative, convert it to a non-negative one. + if (axis < 0) { + axis += ndim; + } + return axis; + } + + private: + IRModule mod_; + Buffer buffer_; + int axis_; +}; + +/*! + * \brief Find the defining site of the buffer in the given block and its ancestors + * \param block_sref The block sref + * \param buffer The buffer + * \return The defining site of the buffer and whether the buffer is allocated (otherwise the + * buffer is from match_buffer). + */ +std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, + const Buffer& buffer) { + // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or + // match_buffers. + const StmtSRefNode* defining_site_sref = block_sref.get(); + while (defining_site_sref != nullptr) { + const auto* block = defining_site_sref->StmtAs(); + // If this sref is not a block sref, skip it. + if (block == nullptr) { + defining_site_sref = defining_site_sref->parent; + continue; + } + // Try to find the buffer in `allloc_buffers` + for (const Buffer& alloc_buffer : block->alloc_buffers) { + if (buffer.same_as(alloc_buffer)) { + return {GetRef(defining_site_sref), true}; + } + } + // We do not allow the buffer being defined in `match_buffer`. + for (const MatchBufferRegion match_buffer : block->match_buffers) { + if (buffer.same_as(match_buffer)) { + return {GetRef(defining_site_sref), false}; + } + } + defining_site_sref = defining_site_sref->parent; + } + // If we cannot find the defining site block, it means that the buffer must be in the function's + // buffer_map, which isn't an intermediate buffer. + return {NullOpt, false}; +} + +class NonAllocatedBufferError : public ScheduleError { + public: + explicit NonAllocatedBufferError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} + + String FastErrorString() const final { + return "ScheduleError: The input buffer is not allocated by a block. This means the buffer is " + " either a function parameter or defined in `match_buffer` of a block."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The input buffer " << buffer_->name + << " is not allocated by a block. This means the buffer is either a function parameter or " + "defined in `match_buffer` of a block."; + return os.str(); + } + + static void CheckBufferAllocated(const IRModule& mod, const StmtSRef& block_sref, + const Buffer& buffer) { + Optional defining_site_sref; + bool is_alloc; + std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, buffer); + if (!defining_site_sref || !is_alloc) { + throw NonAllocatedBufferError(mod, buffer); + } + } + + Array LocationsOfInterest() const final { return {}; } + IRModule mod() const final { return mod_; } + + private: + IRModule mod_; + Buffer buffer_; +}; + +class StorageAlignInvalidFactorError : public ScheduleError { + public: + explicit StorageAlignInvalidFactorError(IRModule mod, int factor) + : mod_(std::move(mod)), factor_(factor) {} + + String FastErrorString() const final { + return "ScheduleError: The input `factor` of storage_align is expected to be a positive " + "number."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The input `factor` of storage_align is expected to be a positive number. However, the " + "input `factor` is " + << factor_ << ", which is out of the expected range."; + return os.str(); + } + + static void Check(const IRModule& mod, int factor) { + if (factor <= 0) { + throw StorageAlignInvalidFactorError(mod, factor); + } + } + + Array LocationsOfInterest() const final { return {}; } + IRModule mod() const final { return mod_; } + + private: + IRModule mod_; + int factor_; +}; + +class StorageAlignInvalidAnnotationError : public ScheduleError { + public: + explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block) + : mod_(std::move(mod)), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The block annotation for storage align is expected to be an array of " + "4-integer-tuples (buffer_index, axis, factor, offset)."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The block annotation for storage align is expected to be an array of 4-integer-tuples " + "(buffer_index, axis, factor, offset). However, the block annotation with key " + << attr::buffer_dim_align << " of the block {0} is " + << block_->annotations.at(attr::buffer_dim_align) << ", which is unexpected."; + return os.str(); + } + + static StorageAlignAnnotation CheckAndGetAnnotation(const IRModule& mod, const Block& block) { + // Get existing annotation value. + auto it = block->annotations.find(attr::buffer_dim_align); + if (it != block->annotations.end()) { + if (!IsValidAnnotation(block, (*it).second)) { + throw StorageAlignInvalidAnnotationError(mod, block); + } + return Downcast((*it).second); + } + + // Create new annotation value + StorageAlignAnnotation storage_align_annotation; + return storage_align_annotation; + } + + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod() const final { return mod_; } + + private: + static bool IsValidAnnotation(const Block& block, const ObjectRef& anno_value) { + if (!anno_value->IsInstance()) { + return false; + } + auto storage_align_annotations = Downcast>(anno_value); + for (const ObjectRef& storage_align_annotation : storage_align_annotations) { + if (!storage_align_annotation->IsInstance()) { + return false; + } + auto storage_align_tuple = Downcast>(storage_align_annotation); + // Check if the annotation is a 4-tuple. + if (storage_align_tuple.size() != 4) { + return false; + } + for (const ObjectRef& tuple_element : storage_align_tuple) { + if (!tuple_element->IsInstance()) { + return false; + } + } + } + return true; + } + + IRModule mod_; + Block block_; +}; + +void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, + int factor, int offset) { + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + Buffer buffer = GetNthWriteBuffer(self, GetRef(block_ptr), buffer_index); + StorageAlignInvalidFactorError::Check(self->mod, factor); + axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis); + NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer); + + // Step 1: Get existing or create new annotation value. + StorageAlignAnnotation storage_align_annotation = + StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod, + GetRef(block_ptr)); + + // Step 2: Update the annotation value + // Array> buffer_storage_align = storage_align_annotation[buffer_index]; + bool found = false; + StorageAlignTuple new_storage_align_tuple{Integer(buffer_index), Integer(axis), Integer(factor), + Integer(offset)}; + for (size_t j = 0; j < storage_align_annotation.size(); ++j) { + const auto& storage_align_tuple = storage_align_annotation[j]; + ICHECK(storage_align_tuple.size() == 4); + if (storage_align_tuple[0] == buffer_index && storage_align_tuple[1] == axis) { + storage_align_annotation.Set(j, std::move(new_storage_align_tuple)); + found = true; + break; + } + } + if (!found) { + storage_align_annotation.push_back(std::move(new_storage_align_tuple)); + } + + // Step 3: Replace the block with the new annotation + Block new_block = WithAnnotation(block_ptr, attr::buffer_dim_align, storage_align_annotation); + self->Replace(block_sref, new_block, {{GetRef(block_ptr), new_block}}); +} + +/******** Instruction Registration ********/ + +struct StorageAlignTraits : public UnpackedInstTraits { + static constexpr const char* kName = "StorageAlign"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 4; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, + Integer axis, Integer factor, Integer offset) { + return sch->StorageAlign(block_rv, buffer_index->value, axis->value, factor->value, + offset->value); + } + + static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, + Integer axis, Integer factor, Integer offset) { + PythonAPICall py("storage_align"); + py.Input("block", block_rv); + py.Input("buffer_index", buffer_index); + py.Input("axis", axis); + py.Input("factor", factor); + py.Input("offset", offset); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 6bd6388fafff..2583b21227e4 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -622,7 +622,8 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { Block producer_block = GetRef(_producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); // Step 1. Get the scope block - StmtSRef scope_root_sref = GetScopeRootAndCheckStagePipeline(self, producer_block_sref); + StmtSRef scope_root_sref = + GetScopeRoot(self, producer_block_sref, /*require_stage_pipeline=*/true); // Step 2. Check completeness CheckCompleteBlock(self, producer_block_sref, scope_root_sref); // Step 3. Analyze the block body @@ -649,7 +650,8 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre Block consumer_block = GetRef(_consumer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block); // Step 1. Get the scope block - StmtSRef scope_root_sref = GetScopeRootAndCheckStagePipeline(self, consumer_block_sref); + StmtSRef scope_root_sref = + GetScopeRoot(self, consumer_block_sref, /*require_stage_pipeline=*/true); // Step 2. Check completeness CheckCompleteBlock(self, consumer_block_sref, scope_root_sref); // Step 3. Check if the consumer has a single complete producer @@ -673,5 +675,56 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); } +/******** Instruction Registration ********/ + +struct ComputeInlineTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ComputeInline"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + return sch->ComputeInline(block_rv); + } + + static String UnpackedAsPython(Array outputs, String block_rv) { + PythonAPICall py("compute_inline"); + py.Input("block", block_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct ReverseComputeInlineTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReverseComputeInline"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + return sch->ReverseComputeInline(block_rv); + } + + static String UnpackedAsPython(Array outputs, String block_rv) { + PythonAPICall py("reverse_compute_inline"); + py.Input("block", block_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ComputeInlineTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReverseComputeInlineTraits); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc new file mode 100644 index 000000000000..a6056d607042 --- /dev/null +++ b/src/tir/schedule/primitive/for_kind.cc @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +class WrongBlockIterTypeError : public ScheduleError { + public: + explicit WrongBlockIterTypeError(IRModule mod, ForKind for_kind, Var loop_var, Block block) + : mod_(std::move(mod)), loop_var_(std::move(loop_var)), block_(std::move(block)) { + op_str_ = for_kind == ForKind::kParallel + ? "parallel" + : for_kind == ForKind::kVectorized ? "vectorize" : "bind"; + } + String FastErrorString() const final { + std::ostringstream os; + os << "ScheduleError: The \"" << op_str_ + << "\" cannot be fulfilled with regard to some of its underlying block"; + return os.str(); + } + String DetailRenderTemplate() const final { + std::ostringstream os; + if (op_str_ != "bind") { + os << "The \"" << op_str_ + << "\" cannot be fulfilled with regard to block {0} because some block iter whose block " + "binding contains the loop var is not a data parallel block iter"; + } else { + os << "The \"bind\" cannot be fulfilled with regard to block {0}. This is because some of its" + " block iter whose block binding contains " + << loop_var_ + << " does not meet any of the conditions:\n1) the block iter is data parallel;\n2) the " + "block iter is a reduction block iter, and the thread axis to be bound is " + "\"threadIdx.x/y/z\""; + } + return os.str(); + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + std::string op_str_; + Var loop_var_; + Block block_; +}; + +/*! + * \brief Check if a loop can be parallelized/vectorized/bound with regard to a specific block + * \details There are two conditions: + * 1) The block is required to have affine bindings, and + * 2) For each block iter whose binding contains the input loop variable, either + * - the block iter is data parallel, or + * - the block iter is a reduction block iter, and the input `thread_tag` starts with "threadIdx" + * in case of cross-thread reduction. + * \param self The schedule state + * \param for_kind The desired ForKind (only `kParallel`, `kVectorized` and `kThreadBinding` are + * allowed) + * \param loop_var The loop variable of the loop to be checked + * \param block_realize The block-realize of the block to be checked + * \param thread_scope The thread scope of the thread axis to be bound, which is an invalid value if + * the operation is not "bind" + * \throws ScheduleError If the input loop cannot be parallelized/vectorized/bound with regard to + * the input block + */ +void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind, + const Var& loop_var, const BlockRealize& block_realize, + runtime::ThreadScope thread_scope) { + const Block& block = block_realize->block; + + // Cond 1. The block is required to have affine bindings. + CheckAffineBinding(self, block); + + // Cond 2. For each block iter whose binding contains `loop_var`, only two cases are allowed. + ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); + int n_iters = static_cast(block->iter_vars.size()); + for (int i = 0; i < n_iters; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& binding = block_realize->iter_values[i]; + + if (!UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) { + continue; + } + // Only two cases are allowed: + // - The block iter is data parallel, or + // - The block iter is a reduction block iter, and the `thread_scope` is "threadIdx.x/y/z" + // in case of cross-thread reduction. + IterVarType iter_type = iter_var->iter_type; + if (!(iter_type == kDataPar || + (iter_type == kCommReduce && thread_scope.rank == 1 && thread_scope.dim_index != -1))) { + throw WrongBlockIterTypeError(self->mod, for_kind, loop_var, block); + } + } +} + +/*! + * \brief For each block (recursive) under the given loop, check whether the input loop can be + * parallelized/vectorized/bound with regard to the block + * \param self The schedule state + * \param loop The loop to be parallelized/vectorized/bound + * \param for_kind The desired ForKind (only `kParallel`, `kVectorized` and `kThreadBinding` are + * allowed) + * \param thread_scope The thread scope of the thread axis to be bound, which is an invalid value if + * the operation is not "bind" + */ +void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind for_kind, + runtime::ThreadScope thread_scope) { + PreOrderVisit(loop, [&](const ObjectRef& node) { + if (const auto* realize = node.as()) { + CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, GetRef(realize), + thread_scope); + } + return true; + }); +} + +/*! + * \brief The implementation of parallelizing/vectorizing/binding a given loop + * \param self The schedule state + * \param loop_sref The sref of the loop to be parallelized/vectorized/bound + * \param for_kind The type of the operation (only `kParallel`, `kVectorized` and `kThreadBinding` + * are allowed) + * \param thread_axis The thread axis that the input loop is bound to, which is defined only when + * `for_kind` is `kThreadBinding` + */ +void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref, ForKind for_kind, + Optional thread_axis) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + + /* + * Check: + * - 1. the subtree rooted from the input loop in sref tree has compact data flow + * - 2. all the blocks under the given loop have affine block bindings + * - 3. the input loop can be only bound to data parallel block iters, or the loop can be bound to + * reduction block iter if `thread` is `threadIdx.x/y/z` in case of cross-thread reduction + * When the above conditions are all satisfied, this input loop can be + * parallelized/vectorized/bound. + */ + // Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow. + CheckSRefSubtreeCompactDataFlow(self, loop_sref); + + // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each + // underlying block. + CheckParallelizability(self, GetRef(loop), for_kind, + thread_axis.defined() + ? runtime::ThreadScope::Create(thread_axis.value()->thread_tag) + : runtime::ThreadScope{-1, -1}); + + // Step 3. Loop update and IR replacement + ObjectPtr new_loop = make_object(*loop); + new_loop->kind = for_kind; + new_loop->thread_binding = std::move(thread_axis); + self->Replace(loop_sref, For(new_loop), {}); +} + +void Parallel(ScheduleState self, const StmtSRef& loop_sref) { + ParallelizeComputation(self, loop_sref, ForKind::kParallel, NullOpt); +} + +void Vectorize(ScheduleState self, const StmtSRef& loop_sref) { + ParallelizeComputation(self, loop_sref, ForKind::kVectorized, NullOpt); +} + +void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis) { + ParallelizeComputation(self, loop_sref, ForKind::kThreadBinding, thread_axis); +} + +void Unroll(ScheduleState self, const StmtSRef& loop_sref) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + ObjectPtr new_loop = make_object(*loop); + new_loop->kind = ForKind::kUnrolled; + new_loop->thread_binding = NullOpt; + self->Replace(loop_sref, For(new_loop), {}); +} + +/******** Instruction Registration ********/ + +struct ParallelTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Parallel"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { + return sch->Parallel(loop_rv); + } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("parallel"); + py.Input("loop", loop_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct VectorizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Vectorize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { + return sch->Vectorize(loop_rv); + } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("vectorize"); + py.Input("loop", loop_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct BindTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Bind"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, String thread) { + return sch->Bind(loop_rv, thread); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, String thread) { + PythonAPICall py("bind"); + py.Input("loop", loop_rv); + py.Input("thread", thread); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct UnrollTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Unroll"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { return sch->Unroll(loop_rv); } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("unroll"); + py.Input("loop", loop_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ParallelTraits); +TVM_REGISTER_INST_KIND_TRAITS(VectorizeTraits); +TVM_REGISTER_INST_KIND_TRAITS(BindTraits); +TVM_REGISTER_INST_KIND_TRAITS(UnrollTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc new file mode 100644 index 000000000000..a8d9c5a69dc9 --- /dev/null +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name) { + struct Finder : public StmtVisitor { + explicit Finder(const ScheduleState& self, const String& name) : self_(self), name_(name) {} + + void VisitStmt_(const BlockNode* block) override { + if (block->name_hint == name_) { + auto it = self_->stmt2ref.find(block); + ICHECK(it != self_->stmt2ref.end()); + results_.push_back(it->second); + } + StmtVisitor::VisitStmt_(block); + } + + const ScheduleState& self_; + const String& name_; + Array results_; + }; + + BaseFunc func = self->mod->Lookup(func_name); + const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode); + Finder finder(self, name); + finder(prim_func->body); + return std::move(finder.results_); +} + +Array GetLoops(const StmtSRef& block_sref) { + std::vector result; + for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance(); + parent = parent->parent) { + result.push_back(GetRef(parent)); + } + return {result.rbegin(), result.rend()}; +} + +/******** Instruction Registration ********/ + +struct GetBlockTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetBlock"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 0; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, String name, String func_name) { + return sch->GetBlock(name, func_name); + } + + static String UnpackedAsPython(Array outputs, String name, String func_name) { + PythonAPICall py("get_block"); + py.Input("name", name); + py.Input("func_name", func_name); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct GetLoopsTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetLoops"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + return sch->GetLoops(block_rv); + } + + static String UnpackedAsPython(Array outputs, String block_rv) { + PythonAPICall py("get_loops"); + py.Input("block", block_rv); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits); +TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc new file mode 100644 index 000000000000..7c2b61344427 --- /dev/null +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -0,0 +1,797 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief Append a new predicate to the each child of type BlockRealize (not recursively) */ +class BlockPredicateAppender : public StmtMutator { + public: + /*! + * \brief Constructor + * \param to_append The predicate to be appended to BlockRealizeNode + */ + explicit BlockPredicateAppender(const PrimExpr& to_append) : to_append_(to_append) {} + + private: + // For each direct child of type BlockRealizeNode, append the predicate + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + // We do not recursively do this + ObjectPtr n = CopyOnWrite(realize); + n->predicate = n->predicate && to_append_; + return BlockRealize(n); + } + + /*! \brief The predicate to be appended */ + const PrimExpr& to_append_; +}; + +/*! \brief Substitute vars and collect the reuse mapping of opaque blocks */ +class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { + public: + explicit SubstituteVarAndCollectOpaqueBlock(std::function(const Var&)> vmap, + Map* opaque_blocks) + : vmap_(vmap), opaque_blocks_(opaque_blocks) {} + + private: + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + if (Optional ret = vmap_(var)) { + return ret.value(); + } else { + return std::move(var); + } + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + BlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); + if (realize->block->iter_vars.empty()) { + opaque_blocks_->Set(op->block, realize->block); + } + return std::move(realize); + } + + /*! \brief The substitute function */ + std::function(const Var&)> vmap_; + /*! \brief The reuse mapping of opaque blocks */ + Map* opaque_blocks_; +}; + +/*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */ +class IterMapSimplifyBlockBinding : public StmtExprMutator { + public: + explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map loop_var2extent) + : opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent) {} + + static For SimplifyBindings(Stmt stmt, const Array& loop_srefs, + MapNode* opaque_blocks) { + Map loop_var2extent; + for (const StmtSRef& sref : loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + return Downcast( + IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent))(std::move(stmt))); + } + + private: + Stmt VisitStmt_(const ForNode* op) final { + loop_var2extent_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + Stmt res = StmtMutator::VisitStmt_(op); + loop_var2extent_.erase(op->loop_var); + return res; + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + // skip opaque block and update mapping + if (op->iter_values.empty()) { + Block block = op->block; + BlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); + for (const std::pair& entry : *opaque_blocks_) { + if (entry.second.same_as(block)) { + opaque_blocks_->at(entry.first) = realize->block; + break; + } + } + return std::move(realize); + } + Array v = arith::IterMapSimplify(/*indices=*/op->iter_values, + /*input_iters=*/loop_var2extent_, + /*input_pred=*/op->predicate, + /*require_bijective=*/false); + if (v.same_as(op->iter_values)) { + return GetRef(op); + } else { + ObjectPtr n = CopyOnWrite(op); + n->iter_values = std::move(v); + return Stmt(n); + } + } + + /*! \brief The reuse mapping */ + MapNode* opaque_blocks_; + /*! \brief The range of loops */ + Map loop_var2extent_; +}; + +class BlockPropertyError : public ScheduleError { + public: + /*! + * \brief Check that all the blocks under the specific stmt have affine bindings and only have + * data-parallel or reduction block iters + * \param self The state of the schedule + * \param sref The sref to the specific stmt + */ + static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, + const StmtSRefNode* sref) { + class BlockIterTypeAndAffineBindingChecker : public StmtVisitor { + public: + explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state) : state_(state) {} + + private: + void VisitStmt_(const BlockNode* op) final { + for (const IterVar& iter_var : op->iter_vars) { + if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { + throw BlockPropertyError(state_->mod, GetRef(op)); + } + CheckAffineBinding(state_, GetRef(op)); + } + } + const ScheduleState& state_; + }; + + BlockIterTypeAndAffineBindingChecker checker(self); + checker(GetRef(sref->stmt)); + } + + explicit BlockPropertyError(IRModule mod, Block block) : mod_(mod), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The block under the loops to be reordered have block iter type other " + "than data-parallel or reduction"; + } + + String DetailRenderTemplate() const final { + return "The block {0} under the loops to be reordered have block iter type other than " + "data-parallel or reduction"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; +}; + +class HasAnnotationOrThreadBindingError : public ScheduleError { + public: + explicit HasAnnotationOrThreadBindingError(IRModule mod, For loop) + : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The primitive can't be applied because the loop has annotation or " + "thread binding"; + } + + String DetailRenderTemplate() const final { + return "The primitive can't be applied because the loop {0} has annotation or thread binding"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class OuterNotInnerParent : public ScheduleError { + public: + explicit OuterNotInnerParent(IRModule mod, For outer, For inner) + : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} + + String FastErrorString() const final { + return "ScheduleError: The outer loop is not the parent of the inner loop"; + } + + String DetailRenderTemplate() const final { + return "The loops can't be fused because the outer loop {0} is not the parent of the inner " + "loop {1}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {outer_, inner_}; } + + IRModule mod_; + For outer_; + For inner_; +}; + +class NotOnlyChildError : public ScheduleError { + public: + explicit NotOnlyChildError(IRModule mod, For outer, For inner) + : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} + + String FastErrorString() const final { + return "ScheduleError: The inner loop is not the only child of outer loop"; + } + + String DetailRenderTemplate() const final { + return "The loops can't be fused because the inner loop {1} is not the only child of outer " + "loop {0}."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {outer_, inner_}; } + + IRModule mod_; + For outer_; + For inner_; +}; + +class LoopNotStartWithZeroError : public ScheduleError { + public: + explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The primitive only supports loop starting with 0"; + } + + String DetailRenderTemplate() const final { + return "The loop {0} does not start with 0, which is not supported"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class NotSingleInferFactorError : public ScheduleError { + public: + explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} + + String FastErrorString() const final { + return "ScheduleError: only one factor can be specified as -1 or none"; + } + + String DetailRenderTemplate() const final { + return "Only one factor can be specified as -1 or none"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; +}; + +class WrongFactorProductError : public ScheduleError { + public: + explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The product of factors is not larger than or equal to the extent of " + "loop"; + } + + String DetailRenderTemplate() const final { + return "The product of factors is not larger than or equal to the extent of loop {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class LoopMultiAppearanceError : public ScheduleError { + public: + explicit LoopMultiAppearanceError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: Some loop appears in the input array for multiple times."; + } + + String DetailRenderTemplate() const final { + return "Loop {0} appears in the input array for multiple times."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class LoopsNotAChainError : public ScheduleError { + public: + enum class ProblemKind { kNotUnderAScope, kHaveNonSingleBranchStmt }; + + explicit LoopsNotAChainError(IRModule mod, Optional problematic_loop, ProblemKind kind) + : mod_(mod), problematic_loop_(std::move(problematic_loop)), kind_(kind) {} + + String FastErrorString() const final { return "ScheduleError: the loops are not in a chain"; } + + String DetailRenderTemplate() const final { + std::stringstream ss; + ss << "The loops are not in a chain because"; + if (kind_ == ProblemKind::kNotUnderAScope) { + ss << " they are not under the same scope."; + } else { + ss << " there is a non-single-branch stmt in between. Problematic stmt: {0}"; + } + return ss.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { + if (kind_ == ProblemKind::kNotUnderAScope) { + return {}; + } else { + ICHECK(problematic_loop_.defined()); + return {problematic_loop_.value()}; + } + } + + IRModule mod_; + Optional problematic_loop_; + ProblemKind kind_; +}; + +class DependentLoopError : public ScheduleError { + public: + explicit DependentLoopError(IRModule mod, For loop, String inner_var) + : mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)) {} + + String FastErrorString() const final { + return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop " + "in the new order"; + } + + String DetailRenderTemplate() const final { + return "Outer Loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ + + " in the new order"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; + String inner_var_; +}; + +Array Split(ScheduleState self, const StmtSRef& loop_sref, + const Array& factors) { + // Invariance + // - The total repeat number has not changed for each direct child block with updating predicate. + // - The execution order has not changed. (The block executes with the same args and the same + // order with before. + // Step 1. Check correctness + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (!loop->annotations.empty() || loop->thread_binding.defined()) { + throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); + } + // Currently, loops not starting with 0 are not supported + arith::Analyzer analyzer; + if (!analyzer.CanProve(loop->min == 0)) { + throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + } + // Step 2. Replace all occurrences of the original loop var with new variables + int n = factors.size(); + PrimExpr substitute_value = 0; + std::vector new_loop_vars; + new_loop_vars.reserve(n); + for (int i = 0; i < n; i++) { + const PrimExpr& factor = factors[i]; + Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); + substitute_value = substitute_value * factor + var; + analyzer.Bind(var, Range::FromMinExtent(0, factor)); + new_loop_vars.emplace_back(std::move(var)); + } + Map opaque_block_reuse; + Stmt new_stmt = loop->body; + new_stmt = SubstituteVarAndCollectOpaqueBlock( + [&](const Var& v) -> Optional { + if (v.same_as(loop->loop_var)) { + return substitute_value; + } else { + return NullOpt; + } + }, + &opaque_block_reuse)(std::move(new_stmt)); + // Step 3. Update predicate to guard the loop + PrimExpr predicate = substitute_value < loop->extent; + if (!analyzer.CanProve(predicate)) { + new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt)); + } + // Step 4. Generate nested loops to replace the original loop and simplify the binding + for (int i = n - 1; i >= 0; i--) { + new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt); + } + new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref), + opaque_block_reuse.CopyOnWrite()); + self->Replace(loop_sref, new_stmt, opaque_block_reuse); + Array result_srefs; + result_srefs.reserve(n); + for (int i = 0; i < n; i++) { + result_srefs.push_back(self->stmt2ref.at(new_stmt.get())); + const ForNode* outer_loop = TVM_TYPE_AS(outer_loop, new_stmt, ForNode); + new_stmt = outer_loop->body; + } + return result_srefs; +} + +StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { + // Invariance + // - The total repeat number has not changed for each direct child block. + // - The execution order has not changed. (The block executes with the same + // args and the same order with before.) + std::vector loops; + loops.reserve(loop_srefs.size()); + StmtSRef outer_loop_sref{nullptr}; + const ForNode* outer_loop = nullptr; + arith::Analyzer analyzer; + // Step 1. check correctness + for (const StmtSRef& sref : loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + if (!loop->annotations.empty() || loop->thread_binding.defined()) { + throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); + } + if (outer_loop_sref.defined()) { + if (sref->parent != outer_loop_sref.get()) { + throw OuterNotInnerParent(self->mod, GetRef(outer_loop), GetRef(loop)); + } + if (!outer_loop->body.same_as(GetRef(loop))) { + throw NotOnlyChildError(self->mod, GetRef(outer_loop), GetRef(loop)); + } + } + outer_loop_sref = sref; + outer_loop = loop; + if (!analyzer.CanProve(loop->min == 0)) { + throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + } + loops.push_back(loop); + } + // Step 2. Create fused loop var and replace the original loop vars + std::string suffix; + int n = loops.size(); + for (int i = 1; i < n; i++) { + suffix += "_" + loops[i]->loop_var->name_hint; + } + suffix += "_fused"; + Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); + Array substitute_value; + substitute_value.resize(loops.size()); + PrimExpr tot = fused_var; + for (int i = static_cast(loops.size()) - 1; i >= 0; i--) { + substitute_value.Set(i, floormod(tot, loops[i]->extent)); + tot = floordiv(tot, loops[i]->extent); + } + Stmt new_stmt = loops.back()->body; + Map opaque_block_reuse; + auto f_substitute = [&](const Var& v) -> Optional { + for (int i = 0; i < n; i++) { + if (v.same_as(loops[i]->loop_var)) { + return substitute_value[i]; + } + } + return NullOpt; + }; + new_stmt = + SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt)); + // Step 3. Generate a loop to replace the original loops + PrimExpr fused_extent = 1; + for (int i = 0; i < n; i++) { + fused_extent *= loops[i]->extent; + } + fused_extent = analyzer.Simplify(fused_extent); + new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt); + new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings( + std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite()); + self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse); + return self->stmt2ref.at(new_stmt.get()); +} +/*! + * \brief Collect an array of loop srefs into a set + * \param self The schedule state + * \param ordered_loop_srefs The array of loop srefs + * \return A set containing all loops in the array + * \throws ScheduleError If there are duplicate loops in the array + */ +std::unordered_set CollectLoopsIntoSet( + const ScheduleState& self, const Array& ordered_loop_srefs) { + std::unordered_set loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + for (const StmtSRef& loop_sref : ordered_loop_srefs) { + auto inserted = loop_srefs.insert(loop_sref.get()); + if (!inserted.second) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + throw LoopMultiAppearanceError(self->mod, GetRef(loop)); + } + } + return loop_srefs; +} + +/*! + * \brief Get the top and bottom boundary of reorder range (which should be a chain) + * \param self The schedule state + * \param loop_srefs The set containing the srefs to the loops to be reordered + * \return A pair containing the top and bottom boundary of the reorder range + * \throws ScheduleError If the loops to be reordered is not in a chain + */ +std::pair GetBoundaryOfReorderRange( + const ScheduleState& self, const std::unordered_set& loop_srefs) { + const StmtSRefNode* top = nullptr; + const StmtSRefNode* bottom = *loop_srefs.begin(); + std::unordered_set visited; + bool scope_block_visited = false; + bool first_traversal = true; + for (const StmtSRefNode* loop_sref : loop_srefs) { + if (visited.count(loop_sref)) { + continue; + } + for (const StmtSRefNode* v = loop_sref;; v = v->parent) { + // Case 1. If `v` corresponds to a block, stop traversal. + if (v->stmt->IsInstance()) { + if (scope_block_visited) { + throw LoopsNotAChainError(self->mod, NullOpt, + LoopsNotAChainError::ProblemKind::kNotUnderAScope); + } + scope_block_visited = true; + break; + } + // Case 2. If `v` corresponds to a previously-visited loop, stop traversal and update + // `bottom`. + if (visited.count(v)) { + if (v != bottom) { + throw LoopsNotAChainError(self->mod, GetRef(v->stmt), + LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); + } + bottom = loop_sref; + break; + } + // Case 3. Add `v` into `visited` + visited.insert(v); + // If it's the first traversal and the loop corresponding to `v` is in the input array, + // update `top`. + if (first_traversal && loop_srefs.count(v)) { + top = v; + } + } + first_traversal = false; + } + return std::make_pair(top, bottom); +} + +/*! + * \brief Get all the loops in the reorder range + * \param self The schedule state + * \param top The top boundary of the reorder range + * \param bottom The bottom boundary of the reorder range + * \return An array containing all the loops in the reorder range + * \throws ScheduleError If some loop in the reorder range is not single-branch + */ +std::vector GetLoopsInReorderRange(const ScheduleState& self, + const StmtSRefNode* top, + const StmtSRefNode* bottom) { + std::vector chain; + for (const StmtSRefNode* loop_sref = bottom; loop_sref != top;) { + const StmtSRefNode* parent_loop_sref = loop_sref->parent; + const ForNode* outer = parent_loop_sref->StmtAs(); + const ForNode* inner = loop_sref->StmtAs(); + ICHECK(outer != nullptr && inner != nullptr); + if (outer->body.get() != inner) { + throw LoopsNotAChainError(self->mod, GetRef(outer), + LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); + } + chain.push_back(loop_sref); + loop_sref = parent_loop_sref; + } + chain.push_back(top); + return chain; +} + +/*! + * \brief Construct a loop chain in the new order + * \param self The schedule state + * \param chain The loops in the reorder range + * \param ordered_loop_srefs The loop srefs to be reordered + * \param loop_srefs The set containing loop srefs to be reordered + * \return The new loop chain + * \throws ScheduleError If the domain of an outer loop depends on any of the inner loops after + * reordering + */ +For ConstructNewLoopChain(const ScheduleState& self, std::vector chain, + const Array& ordered_loop_srefs, + const std::unordered_set& loop_srefs) { + std::unordered_set inner_vars; + inner_vars.reserve(chain.size()); + For new_loop{nullptr}; + int index = static_cast(ordered_loop_srefs.size()) - 1; + for (const StmtSRefNode* loop_sref : chain) { + const ForNode* copy = nullptr; + if (loop_srefs.count(loop_sref)) { + copy = ordered_loop_srefs[index]->StmtAs(); + --index; + } else { + copy = loop_sref->StmtAs(); + } + ICHECK(copy != nullptr); + ObjectPtr n = make_object(*copy); + if (new_loop.defined()) { + n->body = new_loop; + } else { + n->body = loop_sref->StmtAs()->body; + } + const VarNode* used_var = nullptr; + auto f_contain = [&inner_vars, &used_var](const VarNode* var) { + if (inner_vars.count(var)) { + used_var = var; + return true; + } + return false; + }; + if (UsesVar(copy->min, f_contain) || UsesVar(copy->extent, f_contain)) { + throw DependentLoopError(self->mod, GetRef(copy), used_var->name_hint); + } + inner_vars.insert(copy->loop_var.get()); + new_loop = For(std::move(n)); + } + return new_loop; +} + +void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { + if (ordered_loop_srefs.size() <= 1) { + return; + } + // Step 1. Check uniqueness and collect the input loop srefs into a set + std::unordered_set loop_srefs = + CollectLoopsIntoSet(self, ordered_loop_srefs); + // Step 2. Gather loops to be reordered + // For each loop sref in the input sref array, traverse upwards along its parent pointer in the + // sref tree, and stop on either a block, or a previously-visited loop + // - the top of the reorder range is the last loop visited in the first traversal which exists in + // the input array + // - the bottom of the reorder range is the last loop in the input array which is not visited in + // the previous traversals + const StmtSRefNode* top = nullptr; + const StmtSRefNode* bottom = nullptr; + std::tie(top, bottom) = GetBoundaryOfReorderRange(self, loop_srefs); + // Step 3. Collect all loops in the chain and check the loops are single-branch + std::vector chain = GetLoopsInReorderRange(self, top, bottom); + // Step 4. Check the block below has all its block_var to be data-parallel or reduction, + // and the block has an affine binding. + BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom); + // Step 5. Replace the original loops with the reordered loops and check that outer loop is + // not dependent on inner loop + For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs); + self->Replace(GetRef(top), new_loop, {}); +} + +/******** Instruction Registration ********/ + +struct SplitTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Split"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + template + static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, + const Array& inputs) { + thread_local ObjectRef loop_rv{nullptr}; + thread_local Array factors{nullptr}; + loop_rv = inputs[0]; + factors = Array{inputs.begin() + 1, inputs.end()}; + setter(delta, loop_rv); + setter(delta + 1, factors); + } + + static Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, + Array> factors) { + return sch->Split(loop_rv, factors); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, Array factors) { + PythonAPICall py("split"); + py.Input("loop", loop_rv); + py.Input("factors", factors); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct FuseTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Fuse"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + template + static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, + const Array& inputs) { + setter(delta, inputs); + } + + static LoopRV UnpackedApplyToSchedule(Schedule sch, Array loop_rvs) { + return sch->Fuse(loop_rvs); + } + + static String UnpackedAsPython(Array outputs, Array loop_rvs) { + PythonAPICall py("fuse"); + for (const String& loop_rv : loop_rvs) { + py.Input("", loop_rv); + } + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct ReorderTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Reorder"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + template + static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, + const Array& inputs) { + setter(delta, inputs); + } + + static void UnpackedApplyToSchedule(Schedule sch, Array loop_rvs) { + return sch->Reorder(loop_rvs); + } + + static String UnpackedAsPython(Array outputs, Array loop_rvs) { + PythonAPICall py("reorder"); + for (const String& loop_rv : loop_rvs) { + py.Input("", loop_rv); + } + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(SplitTraits); +TVM_REGISTER_INST_KIND_TRAITS(FuseTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc new file mode 100644 index 000000000000..af77e51e4d83 --- /dev/null +++ b/src/tir/schedule/primitive/reduction.cc @@ -0,0 +1,994 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/******** Commutative Reducer ********/ + +/*! + * \brief A structure used for registering new commutative reducers, and store all the registered + * reducers. The reducers are preserved in a list, in the form of "reducer-getter function". When + * invoking a reducer-getter function with a specific datatype, the reducer-getter will return the + * CommReducer of the corresponding reduction pattern and the specific datatype + */ +struct ReducerRegistry { + ReducerRegistry() + : reducer_getters{CreateReducerGetter([](const Var& x, const Var& y) { return x + y; }, + [](DataType dtype) { return make_const(dtype, 0); }), + CreateReducerGetter([](const Var& x, const Var& y) { return x * y; }, + [](DataType dtype) { return make_const(dtype, 1); }), + CreateReducerGetter([](const Var& x, const Var& y) { return min(x, y); }, + [](DataType dtype) { return max_value(dtype); }), + CreateReducerGetter([](const Var& x, const Var& y) { return max(x, y); }, + [](DataType dtype) { return min_value(dtype); })} {} + + static void RegisterReducer(TypedPackedFunc combiner_getter, + TypedPackedFunc identity_getter) { + ReducerRegistry::Global()->reducer_getters.push_back(ReducerRegistry::CreateReducerGetter( + std::move(combiner_getter), std::move(identity_getter))); + } + + static TypedPackedFunc CreateReducerGetter( + TypedPackedFunc combiner_getter, + TypedPackedFunc identity_getter) { + return [combiner_getter = std::move(combiner_getter), + identity_getter = std::move(identity_getter)](DataType dtype) -> CommReducer { + Var lhs("x", dtype); + Var rhs("y", dtype); + return CommReducer({lhs}, {rhs}, {combiner_getter(lhs, rhs)}, {identity_getter(dtype)}); + }; + } + + static ReducerRegistry* Global() { + static ReducerRegistry instance; + return &instance; + } + + std::vector> reducer_getters; +}; + +std::vector> GetReducerGetters() { + return ReducerRegistry::Global()->reducer_getters; +} + +class NotSerialLoopKindError : public ScheduleError { + public: + explicit NotSerialLoopKindError(IRModule mod, For loop) + : mod_(std::move(mod)), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The input loop of rfactor is required to be `kSerial`"; + } + + String DetailRenderTemplate() const final { + String str_kind = ForKind2String(loop_->kind); + std::ostringstream os; + os << "ScheduleError: The input loop {0} of rfactor is required to be `Serial`. However, the " + "kind of {0} is `" + << str_kind << "`"; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class InitBodyNotBufferStoreError : public ScheduleError { + public: + explicit InitBodyNotBufferStoreError(IRModule mod, Block block, bool init_is_bufferstore, + bool body_is_bufferstore) + : mod_(std::move(mod)), + block_(std::move(block)), + init_is_bufferstore_(init_is_bufferstore), + body_is_bufferstore_(body_is_bufferstore) {} + + String FastErrorString() const final { + return "ScheduleError: The `init` and `body` of reduction block are required to be both " + "BufferStore"; + } + + String DetailRenderTemplate() const final { + if (!init_is_bufferstore_ && !body_is_bufferstore_) { + return "The `init` and `body` of block {0} are required to be BufferStore so that rfactor " + "can be applied"; + } else if (!init_is_bufferstore_) { + return "The `init` of block {0} is required to be BufferStore so that rfactor can be applied"; + } else { + ICHECK(!body_is_bufferstore_); + return "The `body` of block {0} is required to be BufferStore so that rfactor can be applied"; + } + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; + bool init_is_bufferstore_; + bool body_is_bufferstore_; +}; + +class InitBodyNotSameBufferAccessError : public ScheduleError { + public: + explicit InitBodyNotSameBufferAccessError(IRModule mod, Block block) + : mod_(std::move(mod)), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The `init` and `body` of the reduction block are required to have the " + "same buffer access pattern"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + const auto* init = block_->init.as(); + const auto* update = block_->body.as(); + os << "The `init` and `body` of the block {0} is required to have the same buffer access " + "pattern. However, in block {0} the `init` writes to " + << init->buffer->name << init->indices << ", and the `body` writes to " + << update->buffer->name << update->indices; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; +}; + +class FactorAxisOutOfRangeError : public ScheduleError { + public: + explicit FactorAxisOutOfRangeError(IRModule mod, Buffer buffer, int factor_axis) + : mod_(std::move(mod)), buffer_(std::move(buffer)), factor_axis_(factor_axis) {} + + String FastErrorString() const final { + return "ScheduleError: The input `factor_axis` is out of range. It is required to be in range " + "[-(ndim + 1), ndim] where `ndim` is the number of dimensions of the write buffer"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + int ndim = static_cast(buffer_->shape.size()); + os << "The write buffer " << buffer_->name << " has " << ndim + << " dimension(s), so `factor_axis` is required to be in [" << -(ndim + 1) << ", " << ndim + << "] for rfactor. However, the input `factor_axis` is " << factor_axis_ + << ", which is out of the expected range"; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int factor_axis) { + int ndim = static_cast(buffer->shape.size()); + if (factor_axis < -(ndim + 1) || factor_axis > ndim) { + throw FactorAxisOutOfRangeError(mod, buffer, factor_axis); + } + // If factor_axis is negative, convert it to a non-negative one. + if (factor_axis < 0) { + factor_axis += ndim + 1; + } + return factor_axis; + } + + IRModule mod_; + Buffer buffer_; + int factor_axis_; +}; + +class NoMatchedReducerError : public ScheduleError { + public: + explicit NoMatchedReducerError(IRModule mod, PrimExpr identity, BufferStore combiner) + : mod_(std::move(mod)), identity_(std::move(identity)), combiner_(std::move(combiner)) {} + + String FastErrorString() const final { + return "ScheduleError: No matched reducer for the identity and the combiner of this reduction " + "block. So rfactor cannot be applied."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "No matched reducer for identity " << identity_ << " and combiner " << combiner_ + << "In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for " + "default reducers or registering new reducers."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + PrimExpr identity_; + BufferStore combiner_; +}; + +class LoopPropertyError : public ScheduleError { + public: + enum ErrorType { + kDataParIterTouchRFactorLoop = 0, + kLoopTouchedByBothKindsOfBlockIters = 1, + kNotFirstChildBlockOfOutermostLoop = 2, + kUnboundLoopUnderReductionLoop = 3 + }; + + explicit LoopPropertyError(IRModule mod, For loop, ErrorType error_type) + : mod_(std::move(mod)), loop_(std::move(loop)), error_type_(error_type) {} + + String FastErrorString() const final { + switch (error_type_) { + case kDataParIterTouchRFactorLoop: + return "ScheduleError: The loop to be applied rfactor is required not to be touched by any " + "data parallel block iter of the block"; + case kLoopTouchedByBothKindsOfBlockIters: + return "ScheduleError: The loops outside of the reduction block are required not to be " + "touched by both data parallel block iters and reduction block iters"; + case kNotFirstChildBlockOfOutermostLoop: + return "ScheduleError: The reduction block should be the first child block of the " + "outermost loop outside of it"; + case kUnboundLoopUnderReductionLoop: + return "ScheduleError: A loop who has extent greater than one and is not bound to any " + "block iter should not appear under a reduction loop"; + } + ICHECK(false) << "Unreachable"; + throw; + } + + String DetailRenderTemplate() const final { + switch (error_type_) { + case kDataParIterTouchRFactorLoop: + return "The loop to be applied rfactor is {0}, which is required not to be touched by any " + "data parallel block iter of the block below. However, some of the block's data " + "parallel block iters touch this loop"; + case kLoopTouchedByBothKindsOfBlockIters: + return "It is not allowed that the loop {0} is touched by both some data parallel block " + "iters and some reduction block iters"; + case kNotFirstChildBlockOfOutermostLoop: + return "The first child block of the outermost loop {0} is not the reduction block."; + case kUnboundLoopUnderReductionLoop: + return "The loop {0} has extent greater than one, and is not bound to any block iter. " + "Therefore it shouldn't appear under a reduction loop"; + } + ICHECK(false) << "Unreachable"; + throw; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + static void CheckLoopProperty(const ScheduleState& self, const Array& loops, + const ForNode* rf_loop, const Block& block, + const std::unordered_set& data_par_loop_vars, + const std::unordered_set& reduce_loop_vars) { + Array children_of_outermost_loop = + GetChildBlockRealizeOnSRefTree(self->stmt2ref.at(loops[0].get())); + if (!children_of_outermost_loop[0]->block.same_as(block)) { + throw LoopPropertyError(self->mod, loops[0], kNotFirstChildBlockOfOutermostLoop); + } + + bool meet_reduction_loop = false; + for (const For& loop : loops) { + bool data_par_touched = data_par_loop_vars.count(loop->loop_var.get()); + bool reduction_touched = reduce_loop_vars.count(loop->loop_var.get()); + + if (data_par_touched && reduction_touched) { + throw LoopPropertyError(self->mod, loop, kLoopTouchedByBothKindsOfBlockIters); + } else if (data_par_touched) { + if (loop.get() == rf_loop) { + throw LoopPropertyError(self->mod, loop, kDataParIterTouchRFactorLoop); + } + continue; + } else if (reduction_touched) { + if (!meet_reduction_loop) { + CheckGetSingleChildBlockRealizeOnSRefTree(self, self->stmt2ref.at(loop.get())); + meet_reduction_loop = true; + } + continue; + } else if (meet_reduction_loop && !is_one(loop->extent)) { + throw LoopPropertyError(self->mod, loop, kUnboundLoopUnderReductionLoop); + } + } + } + + IRModule mod_; + For loop_; + ErrorType error_type_; +}; + +/*! + * \brief Convert the `init` and `body` of the input block to BufferStores + * \param self The schedule state + * \param block The block to be analyzed + * \return The BufferStores of the `init` and `body` of the input block + * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same + * buffer + */ +std::pair GetBufferStoreNodes(const ScheduleState& self, + const Block& block) { + const auto* init = block->init.as(); + const auto* body = block->body.as(); + if (!(init && body)) { + throw InitBodyNotBufferStoreError(self->mod, block, init != nullptr, body != nullptr); + } + if (!init->buffer.same_as(body->buffer)) { + throw InitBodyNotSameBufferAccessError(self->mod, block); + } + int ndim = static_cast(init->buffer->shape.size()); + for (int i = 0; i < ndim; ++i) { + if (!ExprDeepEqual()(init->indices[i], body->indices[i])) { + throw InitBodyNotSameBufferAccessError(self->mod, block); + } + } + return std::make_pair(GetRef(init), GetRef(body)); +} + +/*! + * \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative + * reducer, and extract the combiner lhs and combiner rhs + * \param self The schedule state + * \param identity The reduction identity to be analyzed + * \param combiner The reduction combiner to be analyzed + * \return The corresponding CommReducer, the combiner lhs and the combiner rhs + * \throw ScheduleError If no corresponding commutative reducer can be matched + */ +std::tuple GetReducerAndCombinerLhsRhs( + const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner) { + CommReducer reducer{nullptr}; + PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr}; + bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs); + if (!matched) { + throw NoMatchedReducerError(self->mod, identity, combiner); + } + return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)); +} + +/*! + * \brief For each loop in the given array of loop, associate its loop var with the loop itself + * using a mapping + * \param loops The loops to be analyzed + * \return A mapping from loops to their corresponding loop vars + */ +std::unordered_map GetLoopVar2LoopMap(const Array& loops) { + std::unordered_map loop_vars2loop; + loop_vars2loop.reserve(loops.size()); + for (const For& loop : loops) { + loop_vars2loop[loop->loop_var.get()] = loop; + } + return loop_vars2loop; +} + +/*! + * \brief Create the intermediate rfactor buffer, which the rfactor block writes to and the + * write-back block reads from + * \param buffer The buffer written by the reduction block + * \param factor_axis The `factor_axis` parameter of rfactor + * \param rf_loop The rfactor loop + * \return The new created intermediate rfactor buffer + */ +Buffer CreateRFactorBuffer(const Buffer& buffer, int factor_axis, const ForNode* rf_loop) { + Array rf_shape = buffer->shape; + rf_shape.insert(rf_shape.begin() + factor_axis, rf_loop->extent); + + ObjectPtr n = make_object(*buffer.get()); + n->shape = rf_shape; + n->name = buffer->name + ".rf"; + n->data = buffer->data.copy_with_suffix(".rf"); + return Buffer(n); +} + +/*! + * \brief The base class of the rfactor/write-back block creator, which creates the blocks in four + * steps: + * 1) Create the new block iters and the their iter bindings + * 2) Create the reduction update of the new block + * 3) Create the read/write regions of the new block + * 4) Create the new block and the new block-realize + */ +class BaseBlockCreator { + public: + explicit BaseBlockCreator(BlockRealize old_block_realize, For rf_loop, + BufferStore old_reduction_update, CommReducer reducer, Buffer rf_buffer, + bool is_rf_block) + : old_block_realize_(std::move(old_block_realize)), + rf_loop_(std::move(rf_loop)), + old_reduction_update_(std::move(old_reduction_update)), + reducer_(std::move(reducer)), + rf_buffer_(std::move(rf_buffer)), + is_rf_block_(is_rf_block) { + n_block_iters_ = static_cast(old_block_realize_->iter_values.size()); + } + + void CreateBlock() { + CreateAdditionalIter(); + for (int i = 0; i < n_block_iters_; ++i) { + CreateNormalIters(i); + } + CreateReductionUpdate(); + CreateReadWriteRegions(); + + String new_block_name = old_block_realize_->block->name_hint; + PrimExpr predicate = Bool(true); + if (is_rf_block_) { + new_block_name = new_block_name + "_rf"; + predicate = old_block_realize_->predicate; + } + new_block_ = Block( + /*iter_vars=*/iter_vars_, + /*reads=*/read_regions_, + /*writes=*/write_regions_, + /*name_hint=*/new_block_name, + /*body=*/new_reduction_update_, + /*init=*/ + BufferStore(new_reduction_update_->buffer, reducer_->identity_element[0], + new_reduction_update_->indices)); + new_block_realize_ = BlockRealize(iter_values_, predicate, new_block_); + } + + private: + virtual void CreateAdditionalIter() = 0; + virtual void CreateNormalIters(int idx) = 0; + virtual void CreateReductionUpdate() = 0; + virtual void CreateReadWriteRegions() = 0; + + public: + /*! \brief The new created block */ + Block new_block_; + /*! \brief The new created block-realize */ + BlockRealize new_block_realize_; + /*! \brief The indices used to access the intermediate rfactor buffer */ + Array rf_buf_access_indices_; + + protected: + /*! \brief The old block-realize */ + BlockRealize old_block_realize_; + /*! \brief The number of block iters in the old block */ + int n_block_iters_; + /*! \brief The rfactor loop */ + For rf_loop_; + /*! \brief The update BufferStore of the old block */ + BufferStore old_reduction_update_; + /*! \brief The matched commutative reducer */ + CommReducer reducer_; + /*! \brief The intermediate rfactor buffer */ + Buffer rf_buffer_; + + /*! \brief Whether we are creating the rfactor block or the write-back block */ + bool is_rf_block_; + /*! \brief The new block iters of the new created block */ + std::vector iter_vars_; + /*! \brief The new block iter bindings of the new created block-realize */ + std::vector iter_values_; + /*! + * \brief A mapping which maps old block iters to new expressions. The old iters will be replaced + * by the expressions in future substitution for the two blocks + */ + Map var_map_; + /*! \brief The update BufferStore of the new created block */ + BufferStore new_reduction_update_; + /*! \brief The read regions of the new created block */ + Array read_regions_; + /*! \brief The write regions of the new created block */ + Array write_regions_; +}; + +/*! + * \brief The derived class of the rfactor block creator, which implements all virtual methods in + * the base creator + * \details Start constructing the rfactor block. The main difficulty to construct the rfactor block + * is to create its block iters. So here we introduce the algorithm to create the block iters. + * 1. Create a block iter for the rfactor loop. The block binding of this iter is the loop var, and + * the block iter is data parallel. + * 2. For all the old block's block iters, there are two cases: + * (a) If it is data parallel block iter, or a reduction block iter which doesn't touch the + * rfactor loop, we keep it and its block binding in the rfactor block. + * (b) Otherwise it is a reduction block iter which touches the rfactor loop. In this case, we + * "split" the block iter into one or more new block iters and do not keep the old block + * var. More specifically, we create a new reduction block iter for each loop var that + * appears in the reduction block iter's binding (except for the rfactor loop), and the + * binding of the new block iter is exactly the loop var. (Note that for each loop var, we + * create at most one block iter, even if there are multiple old block iters which touch + * both this loop and the rfactor loop). + * Then we substitute the appearances of the old block iter with the new created block + * iters by recording two mappings: one maps loops vars to new created block iters which + * is used for binding substitution, and another maps old block iters to new expressions + * which is used for substitutions of the old block iters. + */ +class RFactorBlockCreator : public BaseBlockCreator { + public: + explicit RFactorBlockCreator(BlockRealize old_block_realize, For rf_loop, + BufferStore old_reduction_update, CommReducer reducer, + Buffer rf_buffer, + std::unordered_map loop_vars2loop, + int factor_axis, PrimExpr combiner_rhs) + : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), + std::move(old_reduction_update), std::move(reducer), std::move(rf_buffer), + true), + loop_vars2loop_(std::move(loop_vars2loop)), + factor_axis_(factor_axis), + combiner_rhs_(std::move(combiner_rhs)) {} + + private: + void CreateAdditionalIter() final { + // Create a new data parallel block iter for the rfactor loop. + additional_iter_ = + IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, IterVarType::kDataPar); + loop_var2block_binding_[rf_loop_->loop_var.get()] = additional_iter_->var; + iter_vars_.push_back(additional_iter_); + iter_values_.push_back(rf_loop_->loop_var); + } + + void CreateNormalIters(int idx) final { + IterVar old_iter = old_block_realize_->block->iter_vars[idx]; + PrimExpr old_binding = old_block_realize_->iter_values[idx]; + if (old_iter->iter_type == IterVarType::kDataPar || + !UsesVar(old_binding, + [v = rf_loop_->loop_var.get()](const VarNode* var) { return var == v; })) { + // The old block iter is either a data parallel block iter, or a reduction block iter that + // doesn't touch the rfactor loop. In this case reuse the old reduction block iter and its + // corresponding binding. + iter_vars_.push_back(old_iter); + iter_values_.push_back(old_binding); + return; + } + ICHECK(old_iter->iter_type == kCommReduce); + // This block iter is a reduction block iter that touches the rfactor loop. So next we try to + // create a new block iter for all loop vars that appear in the old binding. + Array vars_in_old_binding = UndefinedVars(old_binding); + for (const Var& var : vars_in_old_binding) { + auto it = loop_vars2loop_.find(var.get()); + if (it == loop_vars2loop_.end()) { + // `var` is not a loop var. So skip. + continue; + } + const For& loop = it->second; + if (loop_var2block_binding_.find(var.get()) == loop_var2block_binding_.end()) { + // We haven't created the new block iter for `var`. So here we create it, append it + // and its binding to `rf_block_iter_vars` and `rf_block_iter_values` respectively. + IterVar new_iter_var = + IterVarFromLoop(loop, "v" + loop->loop_var->name_hint, IterVarType::kCommReduce); + loop_var2block_binding_[var.get()] = new_iter_var->var; + iter_vars_.push_back(new_iter_var); + iter_values_.push_back(var); + } + } + // Substitute the original binding with new block iters. Store the result expression + // in `rf_var_map` for future substitution. + var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_)); + } + + void CreateReductionUpdate() final { + rf_buf_access_indices_ = old_reduction_update_->indices; + rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_, + additional_iter_->var); + new_reduction_update_ = BufferStore( + rf_buffer_, + (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0], + rf_buf_access_indices_); + new_reduction_update_ = Downcast(Substitute(new_reduction_update_, var_map_)); + } + + void CreateReadWriteRegions() final { + const Block& old_block = old_block_realize_->block; + read_regions_ = CreateRegions(old_block->reads); + write_regions_ = CreateRegions(old_block->writes); + } + + Array CreateRegions(const Array& old_regions) { + Array new_regions; + new_regions.reserve(old_regions.size()); + for (const BufferRegion& buffer_region : old_regions) { + if (buffer_region->buffer.same_as(old_reduction_update_->buffer)) { + Array region = buffer_region->region; + region.insert(region.begin() + factor_axis_, + Range::FromMinExtent(additional_iter_->var, 1)); + new_regions.push_back(BufferRegion(rf_buffer_, Substitute(region, var_map_))); + } else { + new_regions.push_back( + BufferRegion(buffer_region->buffer, Substitute(buffer_region->region, var_map_))); + } + } + return new_regions; + } + + public: + /*! \brief The generated additional block iter in rfactor block for the rfactor loop */ + IterVar additional_iter_; + + private: + /*! + * \brief A mapping which maps a loop var to its corresponding For loop for all the reduction + * block's outer loops + */ + std::unordered_map loop_vars2loop_; + /*! \brief The factor_axis specified for rfactor */ + int factor_axis_; + /*! \brief The rhs of the combiner in the reduction update of the old block */ + PrimExpr combiner_rhs_; + /*! + * \brief A mapping which maps loop vars to new created block iters. This map is used to + * substitute the loop vars which appear in the bindings of some old block iters with the new + * created block iters + */ + std::unordered_map loop_var2block_binding_; +}; + +/*! + * \brief The derived class of the write-back block creator, which implements all virtual methods in + * the base creator + */ +class WriteBackBlockCreator : public BaseBlockCreator { + public: + explicit WriteBackBlockCreator(BlockRealize old_block_realize, For rf_loop, + BufferStore old_reduction_update, CommReducer reducer, + Buffer rf_buffer, IterVar rf_additional_iter, + PrimExpr combiner_lhs, Array rf_buf_access_indices) + : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), + std::move(old_reduction_update), std::move(reducer), std::move(rf_buffer), + false), + rf_additional_iter_(std::move(rf_additional_iter)), + combiner_lhs_(std::move(combiner_lhs)) { + iter_vars_.reserve(n_block_iters_); + iter_values_.reserve(n_block_iters_); + rf_buf_access_indices_ = std::move(rf_buf_access_indices); + } + + private: + void CreateAdditionalIter() final { + // Create a new reduction block iter for the rfactor loop. + IterVar wb_new_block_iter = + IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce); + iter_vars_.push_back(wb_new_block_iter); + iter_values_.push_back(rf_loop_->loop_var); + var_map_.Set(rf_additional_iter_->var, wb_new_block_iter->var); + } + + void CreateNormalIters(int idx) final { + IterVar old_block_iter = old_block_realize_->block->iter_vars[idx]; + if (old_block_iter->iter_type == IterVarType::kDataPar) { + iter_vars_.emplace_back(old_block_iter->dom, old_block_iter->var.copy_with_suffix(""), + kDataPar); + iter_values_.push_back(old_block_realize_->iter_values[idx]); + var_map_.Set(old_block_iter->var, iter_vars_.back()); + } + } + + void CreateReductionUpdate() final { + wb_lhs_ = Downcast(Substitute(combiner_lhs_, var_map_)); + wb_rhs_ = + Downcast(Substitute(BufferLoad(rf_buffer_, rf_buf_access_indices_), var_map_)); + new_reduction_update_ = + BufferStore(old_reduction_update_->buffer, (*reducer_.get())({wb_lhs_}, {wb_rhs_})[0], + old_reduction_update_->indices); + new_reduction_update_ = Downcast(Substitute(new_reduction_update_, var_map_)); + } + + void CreateReadWriteRegions() final { + read_regions_.push_back(CreateRegion(wb_lhs_)); + read_regions_.push_back(CreateRegion(wb_rhs_)); + write_regions_.push_back(read_regions_[0]); + } + + static BufferRegion CreateRegion(const BufferLoad& load) { + Array region; + region.reserve(load->indices.size()); + for (const PrimExpr& index : load->indices) { + region.push_back(Range::FromMinExtent(index, 1)); + } + return BufferRegion(load->buffer, std::move(region)); + } + + private: + /*! \brief The new created additional block iter of the rfactor block */ + IterVar rf_additional_iter_; + /*! \brief The lhs of the combiner in the reduction update of the old block */ + PrimExpr combiner_lhs_; + /*! \brief The lhs of the combiner of the write-back block */ + BufferLoad wb_lhs_; + /*! \brief The rhs of the combiner of the write-back block */ + BufferLoad wb_rhs_; +}; + +/*! + * \brief Create new outer loops for the rfactor block, meanwhile update the rfactor block's iter + * bindings to use the new created loop vars + * \param rf_block_realize The BlockRealize of the rfactor block + * \param loops The loops to be wrapped over the rfactor block + * \return A Stmt which is the wrapping result + */ +Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const Array& loops) { + int n_loops = static_cast(loops.size()); + + // Step 1. Create new loop vars. + Array new_loops; + std::unordered_map new_loop_var_map; + new_loops.reserve(n_loops); + new_loop_var_map.reserve(n_loops); + for (const For& old_loop : loops) { + Var new_loop_var = old_loop->loop_var.copy_with_suffix(""); + new_loop_var_map[old_loop->loop_var.get()] = new_loop_var; + } + + // Step 2. Update the iter bindings of the rfactor block. + Array new_bindings; + new_bindings.reserve(rf_block_realize->iter_values.size()); + for (const PrimExpr& old_binding : rf_block_realize->iter_values) { + new_bindings.push_back(Substitute(old_binding, new_loop_var_map)); + } + rf_block_realize.CopyOnWrite()->iter_values = new_bindings; + + // Step 3. Wrap `rf_block_realize` with outer loops. + Stmt rf_body = rf_block_realize; + for (int i = n_loops - 1; i >= 0; --i) { + ObjectPtr p_loop = make_object(*loops[i].get()); + p_loop->loop_var = Downcast(new_loop_var_map[loops[i]->loop_var.get()]); + p_loop->body = rf_body; + rf_body = For(std::move(p_loop)); + } + + return rf_body; +} + +class BlockReplacer : public StmtMutator { + public: + /*! + * \brief The replace takes the old scope root block as input, and does four things: + * 1) replace the reduction block with the write-back block, + * 2) remove loops outside the write-back block that are touched by reduction block iters, except + * for the rfactor loop + * 3) combine the rfactor block (wrapped with outer loops) and the transformed outermost loop + * into a SeqStmt, and + * 4) insert the rfactor buffer into the scope root block's `alloc_buffers` + * After transformation, the function returns the new scope root block + * \param scope_root_block The old scope root block + * \param rf_body The rfactor block, which is already wrapped with outer loops + * \param outermost_loop The loop that is outermost among all loops outside the reduction block + * \param wb_block_realize The new created BlockRealize of the write-back block + * \param old_block_realize The BlockRealize of the reduction block + * \param rf_loop The rfactor loop, which should be kept outside the write-back block + * \param reduce_loop_vars The loops that are touched by reduction block iters, used to remove + * loops outside the write-back block + * \param loop_vars2loop The mapping from loop vars to loops that are outside the reduction block, + * which is used to reduce redundant recursive visits + * \param rf_buffer The rfactor buffer to be added into the scope root's `alloc_buffers` + * \return The transformed new scope root block + */ + static Block Replace(Block scope_root_block, Stmt rf_body, For outermost_loop, + BlockRealize wb_block_realize, BlockRealize old_block_realize, For rf_loop, + std::unordered_set reduce_loop_vars, + std::unordered_map loop_vars2loop, + const Buffer& rf_buffer) { + BlockReplacer replacer(std::move(rf_body), std::move(outermost_loop), + std::move(wb_block_realize), std::move(old_block_realize), + std::move(rf_loop), std::move(reduce_loop_vars), + std::move(loop_vars2loop)); + Block new_scope_root = Downcast(replacer(std::move(scope_root_block))); + BlockNode* p = new_scope_root.CopyOnWrite(); + p->alloc_buffers.push_back(rf_buffer); + return new_scope_root; + } + + private: + explicit BlockReplacer(Stmt rf_body, For outermost_loop, BlockRealize wb_block_realize, + BlockRealize old_block_realize, For rf_loop, + std::unordered_set reduce_loop_vars, + std::unordered_map loop_vars2loop) + : rf_body_(std::move(rf_body)), + outermost_loop_(std::move(outermost_loop)), + wb_block_realize_(std::move(wb_block_realize)), + old_block_realize_(std::move(old_block_realize)), + rf_loop_(std::move(rf_loop)), + reduce_loop_vars_(std::move(reduce_loop_vars)), + loop_vars2loop_(std::move(loop_vars2loop)) {} + + Stmt VisitStmt_(const ForNode* loop) final { + // Step 1. Check whether this loop is outside the reduction block. Given that we've made sure + // that the scope root block has stage-pipeline property, if this loop is not outside the + // reduction block, there's no need to recursively mutate. + if (!loop_vars2loop_.count(loop->loop_var.get())) { + return GetRef(loop); + } + + // Step 2. Recursively mutate. + Stmt body = StmtMutator::VisitStmt(loop->body); + + // Step 3. If this loop is the rfactor loop and isn't touched by any reduction block iter, it + // should be kept outside the write-back block. Otherwise it shouldn't. + if (loop == rf_loop_.get() || !reduce_loop_vars_.count(loop->loop_var.get())) { + ObjectPtr p_loop = CopyOnWrite(loop); + p_loop->body = body; + body = Stmt(p_loop); + } + + // Step 4. If this loop is the outermost loop of the reduction block, return the combination of + // `rf_body_` and the mutation result `body`. Otherwise return the mutation result. + return loop == outermost_loop_.get() ? SeqStmt({rf_body_, body}) : body; + } + + Stmt VisitStmt_(const BlockRealizeNode* block_realize) final { + // Due to the visitor's behavior on ForNode, this block-realize must be the reduction block's + // block-realize. And we directly return the new `wb_block_realize`. + ICHECK_EQ(block_realize, old_block_realize_.get()); + return wb_block_realize_; + } + + Stmt VisitStmt_(const SeqStmtNode* seq) final { + Array new_stmts; + new_stmts.reserve(static_cast(seq->seq.size())); + + for (const Stmt old_stmt : seq->seq) { + new_stmts.push_back(VisitStmt(old_stmt)); + } + return SeqStmt::Flatten(new_stmts); + } + + private: + Stmt rf_body_; + For outermost_loop_; + BlockRealize wb_block_realize_; + BlockRealize old_block_realize_; + For rf_loop_; + std::unordered_set reduce_loop_vars_; + std::unordered_map loop_vars2loop_; +}; + +StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_axis) { + // ***************************************************** + // * Condition Checks and Information Collection * + // ***************************************************** + + // Step 1. Check some basic conditions for rfactor. Get the block and block-realize. + BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, rf_loop_sref); + const StmtSRef& block_sref = self->stmt2ref.at(block_realize->block.get()); + const Block& block = block_realize->block; + StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + CheckReductionBlock(self, block_sref, scope_root); + const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop, rf_loop_sref); + if (rf_loop->kind != ForKind::kSerial) { + throw NotSerialLoopKindError(self->mod, GetRef(rf_loop)); + } + + // Step 2. Collect loop vars that are touched by data parallel block iters and reduction block + // iters, respectively. + std::unordered_set data_par_loop_vars; + std::unordered_set reduce_loop_vars; + GetVarsTouchedByBlockIters(block_realize, &data_par_loop_vars, &reduce_loop_vars); + + // Step 3. Collect the loops of the reduction block. Construct a mapping from loops to + // corresponding loop vars. + Array loops = LoopSRefs2Loops(GetLoops(block_sref)); + std::unordered_map loop_vars2loop = GetLoopVar2LoopMap(loops); + + // Step 4. Check four properties that the loops should have: + // - the rfactor loop cannot be touched by any data parallel block iter; + // - all the loops cannot be touched by both data parallel block iters and reduction block iters; + // - the outermost loop should have the reduction block as its first child block; + // - the outermost loop that is touched by some reduction block iters can only have one child + // block. + LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars, + reduce_loop_vars); + + // Step 5. Get the `init` identity and the `update` combiner of the reduction. Extract the + // commutative reducer, combiner lhs and combiner rhs from the reduction identity and the + // reduction combiner. The lhs will be used when constructing the write-back block, and the rhs + // will be used when constructing the rfactor block. + BufferStore init; + BufferStore update; + CommReducer reducer; + PrimExpr combiner_lhs, combiner_rhs; + std::tie(init, update) = GetBufferStoreNodes(self, block); + std::tie(reducer, combiner_lhs, combiner_rhs) = + GetReducerAndCombinerLhsRhs(self, init->value, update); + + // Step 6. Check whether `factor_axis` is in a correct range, and convert it to non-negative if it + // is negative. + factor_axis = FactorAxisOutOfRangeError::CheckAndUpdate(self->mod, update->buffer, factor_axis); + + // ***************************************************** + // * IR Manipulation * + // ***************************************************** + // Since rfactor splits the reduction block into two, we call the first one "rfactor block", and + // the latter one "write-back block", and the intermediate buffer is called "rfactor buffer". + + // Step 1. Create the intermediate buffer (a.k.a. rfactor buffer), which has an additional + // dimension that specified by `factor_axis` and `rf_loop`. + Buffer rf_buffer = CreateRFactorBuffer(update->buffer, factor_axis, rf_loop); + + // Step 2. Create the rfactor block. + RFactorBlockCreator rf_block_creator(block_realize, GetRef(rf_loop), update, reducer, + rf_buffer, loop_vars2loop, factor_axis, + std::move(combiner_rhs)); + rf_block_creator.CreateBlock(); + + // Step 3. Create the write-back block. + WriteBackBlockCreator wb_block_creator(block_realize, GetRef(rf_loop), update, reducer, + rf_buffer, std::move(rf_block_creator.additional_iter_), + std::move(combiner_lhs), + std::move(rf_block_creator.rf_buf_access_indices_)); + wb_block_creator.CreateBlock(); + + // Step 4. Wrap the rfactor block with loops. + Stmt rf_body = CreateLoopOutsideRfactorBlock(rf_block_creator.new_block_realize_, loops); + + // ***************************************************** + // * Schedule Replacement & Update * + // ***************************************************** + + // Step 1. Substitute the old scope root block with the new scope root block. + Block old_scope_root_block = GetRef(scope_root->StmtAs()); + Block new_scope_root_block = BlockReplacer::Replace( + old_scope_root_block, rf_body, loops[0], wb_block_creator.new_block_realize_, block_realize, + GetRef(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffer); + self->Replace( + scope_root, new_scope_root_block, + {{old_scope_root_block, new_scope_root_block}, {block, wb_block_creator.new_block_}}); + + // Step 2. Update scope information. + std::vector new_block_srefs{self->stmt2ref.at(rf_block_creator.new_block_.get()), + self->stmt2ref.at(wb_block_creator.new_block_.get())}; + for (const StmtSRef& new_block_sref : new_block_srefs) { + BlockInfo& info = self->block_info[new_block_sref]; + info.affine_binding = true; + info.region_cover = true; + info.scope->stage_pipeline = true; + } + return new_block_srefs[0]; +} + +/******** Instruction Registration ********/ + +struct RFactorTraits : public UnpackedInstTraits { + static constexpr const char* kName = "RFactor"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer factor_axis) { + return sch->RFactor(loop_rv, factor_axis->value); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, Integer factor_axis) { + PythonAPICall py("rfactor"); + py.Input("loop", loop_rv); + py.Input("factor_axis", factor_axis->value); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(RFactorTraits); + +/******** FFI ********/ + +TVM_REGISTER_GLOBAL("tir.schedule.RegisterReducer") + .set_body_typed([](PackedFunc combiner_getter, PackedFunc identity_getter) { + ReducerRegistry::RegisterReducer(std::move(combiner_getter), std::move(identity_getter)); + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc new file mode 100644 index 000000000000..ac40d27c4bf3 --- /dev/null +++ b/src/tir/schedule/primitive/sampling.cc @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +#include "../primitive.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, + const Array& candidates, const Array& probs, + Optional* decision) { + CHECK(candidates.size() == probs.size()) + << "ValueError: number of candidates does not match number of probabilities."; + int i = -1; + int n = candidates.size(); + + if (decision->defined()) { + const auto* int_imm = decision->as(); + i = int_imm->value; + CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n + << ", but decision is: " << i; + } else { + std::vector weights = support::AsVector(probs); + std::discrete_distribution dist(weights.begin(), weights.end()); + support::LinearCongruentialEngine rand_(rand_state); + i = dist(rand_); + ICHECK(0 <= i && i < n) << "ValueError: Unexpected decision generated, where n = " << n + << ", but decision is: " << i; + } + + *decision = Integer(i); // decision is guaranteed not to be nullptr. + return candidates[i]; +} + +struct SampleCategoricalTraits : public UnpackedInstTraits { + static constexpr const char* kName = "SampleCategorical"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 0; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 1; + + static ExprRV UnpackedApplyToSchedule(Schedule sch, // + Array candidates, // + Array probs, // + Optional decision) { + return sch->SampleCategorical(candidates, probs, decision); + } + + static String UnpackedAsPython(Array outputs, // + Array candidates, // + Array probs, // + Optional decision) { + PythonAPICall py("sample_categorical"); + py.Input("candidates", candidates); + py.Input("probs", probs); + py.Decision(decision); + py.SingleOutput(outputs); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 115f7936f64e..d24cdc625912 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include "./utils.h" namespace tvm { namespace tir { @@ -44,31 +44,35 @@ TVM_REGISTER_NODE_TYPE(BlockRVNode); TVM_REGISTER_NODE_TYPE(LoopRVNode); TVM_REGISTER_OBJECT_TYPE(ScheduleNode); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleModule") // +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod") // .set_body_method(&ScheduleNode::mod); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // .set_body_method(&ScheduleNode::state); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // - .set_body_method(&ScheduleNode::Seed); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // + .set_body_method(&ScheduleNode::trace); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // .set_body_method(&ScheduleNode::Copy); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // + .set_body_method(&ScheduleNode::Seed); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") // + .set_body_method(&ScheduleNode::ForkSeed); /**************** (FFI) Constructor ****************/ +TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); }); +TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); }); TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") - .set_body_typed([](ObjectRef obj, int debug_mode, int error_render_level) -> Schedule { - IRModule mod{nullptr}; - if (const auto* func = obj.as()) { - mod = IRModule({{GlobalVar("main"), GetRef(func)}}); - } else if (const auto* p_mod = obj.as()) { - mod = GetRef(p_mod); - } else { - LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " - << obj->GetTypeKey(); - } - return Schedule::Concrete(mod, debug_mode, + .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, int error_render_level) -> Schedule { + return Schedule::Concrete(mod, debug_mask, seed, static_cast(error_render_level)); }); +TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") + .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, int error_render_level) -> Schedule { + return Schedule::Traced(mod, seed, debug_mask, + static_cast(error_render_level)); + }); /******** (FFI) Lookup random variables ********/ @@ -116,22 +120,43 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") throw; }); -/***** (FFI) Block/Loop relation *****/ - +/******** (FFI) Sampling ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") + .set_body_method(&ScheduleNode::SampleCategorical); +/******** (FFI) Get blocks & loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") .set_body_method(&ScheduleNode::GetLoops); -/******** (FFI) loops manipulation ********/ -/******** (FFI) compute location ********/ +/******** (FFI) Transform loops ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder") + .set_body_method(&ScheduleNode::Reorder); +/******** (FFI) Manipulate ForKind ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel") + .set_body_method(&ScheduleNode::Parallel); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize") + .set_body_method(&ScheduleNode::Vectorize); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method(&ScheduleNode::Bind); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method(&ScheduleNode::Unroll); +/******** (FFI) Insert cache stages ********/ +/******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") .set_body_method(&ScheduleNode::ComputeInline); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") .set_body_method(&ScheduleNode::ReverseComputeInline); -/******** (FFI) loop binding/annotation ********/ -/******** (FFI) cache read/write ********/ -/******** (FFI) reduction ********/ -/******** (FFI) blockize & tensorize ********/ +/******** (FFI) Reduction ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor") + .set_body_method(&ScheduleNode::RFactor); +/******** (FFI) Block annotation ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") + .set_body_method(&ScheduleNode::StorageAlign); +/******** (FFI) Blockize & Tensorize ********/ +/******** (FFI) Annotation ********/ +/******** (FFI) Misc ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") + .set_body_method(&ScheduleNode::EnterPostproc); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index ca61dfea2768..9a9b97497e04 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -43,7 +43,7 @@ Array AnalyzeRegionUpperBound(const BufferRegion& region, AsIntSet(LoopDomainOfSRefTreePath( /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, - /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer->scope)))); + /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())))); } /*! @@ -67,7 +67,7 @@ Array AnalyzeRegionLowerBound(const BlockRealize& realize, LoopDomainOfSRefTreePath( /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, - /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer->scope)), + /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())), /*predicate=*/realize->predicate, /*analyzer=*/analyzer)) { return result.value(); } @@ -112,7 +112,7 @@ bool ProducerCoversConsumer(const Array& buffer_shape, * \param self The schedule class * \param stmt The statement, or the realize node of the statement whose sref to be set * \param seq_index The seq_index to be set - * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block + * \note The method is NOP for statements that are not schedulable, i.e. not For or Block */ void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) { if (const auto* realize = stmt.as()) { @@ -161,34 +161,6 @@ void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new sref->stmt = new_stmt; } -/*! - * \brief Get PrimFunc and GlobalVar that the root block belongs to - * \param mod The IRModule - * \param root_block The root block of the PrimFunc - * \param result_g_var The result GlobalVar - * \return The result PrimFunc where the root block belongs to - * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write - */ -const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, - GlobalVar* result_g_var) { - for (const auto& kv : mod->functions) { - const GlobalVar& g_var = kv.first; - const BaseFunc& base_func = kv.second; - if (const auto* func = base_func.as()) { - if (const auto* realize = func->body.as()) { - if (realize->block.get() == root_block) { - *result_g_var = g_var; - return func; - } - } - } - } - LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the " - "statement:\n" - << GetRef(root_block); - throw; -} - /**************** Creation ****************/ /*! \brief A helper class to create a new ScheduleStateNode from an IRModule */ @@ -198,13 +170,13 @@ class StateCreator : private StmtVisitor { * \brief The entry function * \param self The schedule state to be completed */ - static ObjectPtr Create(IRModule mod, int debug_mode) { + static ObjectPtr Create(IRModule mod, int debug_mask) { ObjectPtr n = make_object(); ScheduleStateNode* self = n.get(); // Set `n->mod` n->mod = std::move(mod); - // Set `n->debug_mode` - n->debug_mode = debug_mode; + // Set `n->debug_mask` + n->debug_mask = debug_mask; // Set `n->stmt2ref` and `n->block_info` StateCreator creator(self); for (const auto& kv : n->mod->functions) { @@ -433,20 +405,17 @@ class StateCreator : private StmtVisitor { std::unordered_map block2realize_; /*! \brief The stack frames of blocks in the DFS visit. */ std::vector> block_frames_; - /*! \brief The auxilary analyzer */ + /*! \brief The auxiliary analyzer */ arith::Analyzer analyzer_; }; /**************** Constructor ****************/ -ScheduleState::ScheduleState(IRModule mod, int debug_mode) { - CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported"; - data_ = StateCreator::Create(mod, debug_mode); +ScheduleState::ScheduleState(IRModule mod, int debug_mask) { + CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1 is not supported"; + data_ = StateCreator::Create(mod, debug_mask); } -ScheduleState::ScheduleState(PrimFunc func, int debug_mode) - : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {} - /**************** Replace ****************/ /* @@ -596,7 +565,7 @@ class SRefTreePruner : public StmtVisitor { } auto it = self_->stmt2ref.find(op); ICHECK(it != self_->stmt2ref.end()) - << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n" + << "IndexError: Cannot find corresponding StmtSRef for the loop:\n" << GetRef(op); StmtSRef& sref = it->second; // Detect reuse @@ -619,7 +588,7 @@ class SRefTreePruner : public StmtVisitor { } auto it = self_->stmt2ref.find(op); ICHECK(it != self_->stmt2ref.end()) - << "IndexError: Cannot find correpsonding StmtSRef for the block:\n" + << "IndexError: Cannot find corresponding StmtSRef for the block:\n" << GetRef(op); StmtSRef& sref = it->second; // Detect reuse @@ -737,7 +706,7 @@ class SRefUpdater : public StmtVisitor { void UpdateBlockInfo(const StmtSRef& block_sref) { using TIter = std::unordered_map::iterator; // The caller is responsible for correcting the flags - BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref))); + BlockInfo new_info((BlockScope(GetChildBlockSRefOnSRefTree(self_, block_sref)))); std::pair insert_result = self_->block_info.emplace(block_sref, new_info); bool inserted = insert_result.second; BlockInfo& info = insert_result.first->second; @@ -867,7 +836,7 @@ class ChildReplacer : private StmtMutator { void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, const Map& _block_sref_reuse) { - if (this->debug_mode != 0) { + if (this->debug_mask != 0) { const StmtNode* src_stmt = _src_sref->stmt; bool input_correct = (src_stmt->IsInstance() && tgt_stmt->IsInstance()) || @@ -1021,8 +990,8 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ new_map->at(g_var) = std::move(ref_new_func); this->mod = GetRef(new_mod); } - uint32_t flag = (debug_mode != -1) // - ? static_cast(debug_mode) // + uint32_t flag = (debug_mask != -1) // + ? static_cast(debug_mask) // : std::numeric_limits::max(); if (flag & ScheduleDebugMask::kVerifySRefTree) { VerifySRefTree(GetRef(this)); @@ -1030,9 +999,9 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ } void ScheduleStateNode::DebugVerify() const { - ICHECK_GE(debug_mode, -1); - uint32_t flag = (debug_mode != -1) // - ? static_cast(debug_mode) // + ICHECK_GE(debug_mask, -1); + uint32_t flag = (debug_mask != -1) // + ? static_cast(debug_mask) // : std::numeric_limits::max(); if (flag & ScheduleDebugMask::kVerifySRefTree) { VerifySRefTree(GetRef(this)); @@ -1045,7 +1014,7 @@ void ScheduleStateNode::DebugVerify() const { /**************** BlockInfo-related ****************/ BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const { - const auto* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); auto it = this->block_info.find(block_sref); CHECK(it != this->block_info.end()) << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" @@ -1063,16 +1032,10 @@ TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& bl /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(ScheduleStateNode); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState").set_body_typed([](ObjectRef obj, int debug_mode) { - if (const auto* func = obj.as()) { - return ScheduleState(GetRef(func), debug_mode); - } - if (const auto* mod = obj.as()) { - return ScheduleState(GetRef(mod), debug_mode); - } - LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " << obj->GetTypeKey(); - throw; -}); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState") + .set_body_typed([](IRModule mod, int debug_mask) -> ScheduleState { + return ScheduleState(mod, debug_mask); + }); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope") .set_body_method(&ScheduleStateNode::GetBlockScope); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace") diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc new file mode 100644 index 000000000000..d8c18f0de0d6 --- /dev/null +++ b/src/tir/schedule/trace.cc @@ -0,0 +1,533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace tir { + +/**************** Constructors ****************/ + +Trace::Trace() { data_ = make_object(); } + +Trace::Trace(Array insts, Map decisions) { + ObjectPtr n = make_object(); + n->insts = std::move(insts); + n->decisions = std::move(decisions); + data_ = std::move(n); +} + +/**************** Utilities ****************/ + +bool IsPostproc(const InstructionKind& inst_kind) { + static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); + return inst_kind.same_as(inst_enter_postproc); +} + +int GetNumValidInstructions(const Array& insts, bool remove_postproc) { + if (!remove_postproc) { + return insts.size(); + } + int n_insts = 0; + for (const Instruction& inst : insts) { + if (!IsPostproc(inst->kind)) { + ++n_insts; + } else { + break; + } + } + return n_insts; +} + +/**************** TranslateInputRVs ****************/ + +Array TranslateInputRVs(const Array& inputs, + const std::unordered_map& rv_map) { + Array result; + result.reserve(inputs.size()); + for (const ObjectRef& input : inputs) { + if (!input.defined() || // constant: nullptr + input->IsInstance() || // constant: string + input->IsInstance() || // constant: integer + input->IsInstance()) { // constant: float + result.push_back(input); + } else if (input->IsInstance() || // RV: block + input->IsInstance() || // RV: loop + input->IsInstance()) { // RV: var + auto it = rv_map.find(input.get()); + ICHECK(it != rv_map.end()) << "IndexError: Random variable doesn't exist: " << input; + result.push_back(GetRef(it->second)); + } else if (const auto* expr = input.as()) { // RV: Expr + result.push_back( + Substitute(GetRef(expr), [&rv_map](const Var& var) -> Optional { + auto it = rv_map.find(var.get()); + if (it == rv_map.end()) { + return NullOpt; + } + const Object* dst = it->second; + ICHECK(dst->IsInstance()) + << "TypeError: Expect 'tir.Var', but gets: " << dst->GetTypeKey(); + return GetRef(static_cast(dst)); + })); + } else { + ICHECK(false) << "TypeError: Cannot recognize the type of an input random variable: " + << input->GetTypeKey(); + throw; + } + } + return result; +} + +Array TranslateInputRVs( + const Array& inputs, + const std::unordered_map& rv_names) { + Array results; + results.reserve(inputs.size()); + for (const ObjectRef& input : inputs) { + if (!input.defined()) { + // Case 0. nullptr => None + results.push_back(String("None")); + continue; + } + auto it = rv_names.find(input); + if (it != rv_names.end()) { + // Case 1. BlockRV, LoopRV, VarRV + results.push_back(it->second); + } else if (const auto* str_obj = input.as()) { + // Case 2. string => "content" + results.push_back(String('"' + std::string(str_obj->data) + '"')); + } else if (input->IsInstance() || input->IsInstance()) { + // Case 3. integer or floating-point number + results.push_back(input); + } else if (input->IsInstance() || inputs->IsInstance() || + inputs->IsInstance()) { + LOG(FATAL) << "IndexError: Random variable is not defined " << input; + throw; + } else { + LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << input->GetTypeKey(); + throw; + } + } + return results; +} + +Array TranslateInputRVs(const Array& inputs, + const std::unordered_map& named_rvs) { + Array results; + results.reserve(inputs.size()); + for (const ObjectRef& input : inputs) { + // Case 3. integer or floating-point number + if (input->IsInstance() || input->IsInstance()) { + results.push_back(input); + continue; + } + const auto* str = input.as(); + CHECK(str) << "TypeError: Expect String, but gets: " << input->GetTypeKey(); + CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names"; + const char* name = str->data; + int64_t size = str->size; + // Case 2. string + if (size > 2 && name[0] == '"' && name[size - 1] == '"') { + results.push_back(String(std::string(name + 1, size - 2))); + continue; + } + // Case 0 & 1. None, BlockRV, LoopRV, VarRV + auto it = named_rvs.find(name); + CHECK(it != named_rvs.end()) << "ValueError: The random variable is not defined: " << name; + results.push_back(it->second); + } + return results; +} + +/**************** TranslateAddOutputRVs ****************/ + +void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_outputs, + std::unordered_map* rv_map) { + ICHECK_EQ(old_outputs.size(), new_outputs.size()); + int n = old_outputs.size(); + const ObjectRef* p_old = old_outputs.GetArrayNode()->begin(); + const ObjectRef* p_new = new_outputs.GetArrayNode()->begin(); + for (int i = 0; i < n; ++i) { + (*rv_map)[p_old[i].get()] = p_new[i].get(); + } +} + +Array TranslateAddOutputRVs( + const Array& outputs, + std::unordered_map* rv_names) { + Array results; + results.reserve(outputs.size()); + for (const ObjectRef& output : outputs) { + int i = rv_names->size(); + ICHECK(!rv_names->count(output)) + << "ValueError: The random variable has been produced once: " << rv_names->at(output); + String result{ObjectPtr{nullptr}}; + if (output->IsInstance()) { + result = "b" + std::to_string(i); + } else if (output->IsInstance()) { + result = "l" + std::to_string(i); + } else if (output->IsInstance()) { + result = "v" + std::to_string(i); + } else { + LOG(FATAL) << "TypeError: Cannot recognize the type of the random variable: " + << output->GetTypeKey(); + throw; + } + results.push_back(result); + rv_names->emplace(output, std::move(result)); + } + return results; +} + +void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_outputs, + std::unordered_map* named_rvs) { + ICHECK_EQ(old_outputs.size(), new_outputs.size()); + int n = old_outputs.size(); + const ObjectRef* p_old = old_outputs.GetArrayNode()->begin(); + const ObjectRef* p_new = new_outputs.GetArrayNode()->begin(); + for (int i = 0; i < n; ++i) { + const auto* name = static_cast(p_old[i].get()); + named_rvs->emplace(std::string(name->data, name->size), p_new[i]); + } +} + +/**************** Add/Remove/Get ****************/ + +Optional TraceNode::GetDecision(const Instruction& inst) const { + auto it = this->decisions.find(inst); + return it == this->decisions.end() ? Optional(NullOpt) : (*it).second; +} + +void TraceNode::Append(Instruction inst) { insts.push_back(std::move(inst)); } + +void TraceNode::Append(Instruction inst, ObjectRef decision) { + decisions.Set(inst, std::move(decision)); + insts.push_back(std::move(inst)); +} + +Optional TraceNode::Pop() { + if (insts.empty()) { + return NullOpt; + } + Instruction inst = insts.back(); + insts.pop_back(); + if (decisions.count(inst)) { + decisions.erase(inst); + } + return inst; +} + +/**************** Interfacing with InstructionKind ****************/ + +void TraceNode::ApplyToSchedule( + Schedule sch, bool remove_postproc, + runtime::TypedPackedFunc& inputs, // + const Array& attrs, // + const Optional& decision)> + decision_provider) const { + std::unordered_map rv_map; + for (const Instruction& inst : this->insts) { + if (remove_postproc && IsPostproc(inst->kind)) { + break; + } + Array inputs = TranslateInputRVs(inst->inputs, rv_map); + Array attrs = inst->attrs; + Optional decision = this->GetDecision(inst); + if (decision_provider != nullptr) { + decision = decision_provider(inst, inputs, attrs, decision); + } + Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, attrs, decision); + TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); + } +} + +ObjectRef TraceNode::AsJSON(bool remove_postproc) const { + std::unordered_map rv_names; + Array json_insts; + Array json_decisions; + json_insts.reserve(this->insts.size()); + json_decisions.reserve(this->insts.size()); + + int i = 0; + for (const Instruction& inst : this->insts) { + const InstructionKind& kind = inst->kind; + if (remove_postproc && IsPostproc(kind)) { + break; + } + json_insts.push_back(Array{ + /* 0: inst name */ kind->name, + /* 1: inputs */ TranslateInputRVs(inst->inputs, rv_names), + /* 2: attrs */ kind->f_attrs_as_json != nullptr ? kind->f_attrs_as_json(inst->attrs) + : ObjectRef(inst->attrs), + /* 3: outputs */ TranslateAddOutputRVs(inst->outputs, &rv_names), + }); + if (Optional decision = this->GetDecision(inst)) { + json_decisions.push_back(Array{ + /* 0: index */ Integer(i), + /* 1: decision */ decision.value(), + }); + } + ++i; + } + return Array{ + /* 0: trace */ std::move(json_insts), + /* 1: decision */ std::move(json_decisions), + }; +} + +Array TraceNode::AsPython(bool remove_postproc) const { + std::unordered_map rv_names; + Array py_trace; + py_trace.reserve(this->insts.size()); + for (const Instruction& inst : this->insts) { + if (remove_postproc && IsPostproc(inst->kind)) { + break; + } + Array attrs; + attrs.reserve(inst->attrs.size()); + for (const ObjectRef& obj : inst->attrs) { + if (const auto* str = obj.as()) { + attrs.push_back(String('"' + std::string(str->data) + '"')); + } else { + attrs.push_back(obj); + } + } + py_trace.push_back( + inst->kind->f_as_python(/*inputs=*/TranslateInputRVs(inst->inputs, rv_names), + /*attrs=*/attrs, + /*decision=*/this->GetDecision(inst), + /*outputs=*/TranslateAddOutputRVs(inst->outputs, &rv_names))); + } + return py_trace; +} + +void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { + Array json_insts{nullptr}; + Array json_decisions{nullptr}; + // Parse `json` into `json_insts` and `json_decisions` + try { + const ArrayNode* arr = json.as(); + ICHECK(arr && arr->size() == 2); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + ICHECK(arr0 && arr1); + json_insts = GetRef>(arr0); + json_decisions = GetRef>(arr1); + } catch (const tvm::Error& e) { + LOG(FATAL) << "ValueError: The json entry of a trace should contain two arrays, an array of " + "instructions and an array of decisions, but gets: " + << json; + throw; + } + // Parse `json_decisions` + std::vector> decisions(json_insts.size(), NullOpt); + for (const ObjectRef& decision_entry : json_decisions) { + int index = -1; + ObjectRef decision{nullptr}; + try { + const ArrayNode* arr = decision_entry.as(); + ICHECK(arr && arr->size() == 2); + const IntImmNode* arr0 = arr->at(0).as(); + ICHECK(arr0); + index = arr0->value; + decision = arr->at(1); + } catch (const tvm::Error& e) { + LOG(FATAL) << "ValueError: Each entry of a json decision should be a tuple [index, " + "decision], but gets: " + << decision_entry; + throw; + } + decisions[index] = std::move(decision); + } + // Parse `json_insts` + std::unordered_map named_rvs{{"None", ObjectRef{nullptr}}}; + int i = 0; + for (const ObjectRef& inst_entry : json_insts) { + InstructionKind kind{nullptr}; + Array inputs{nullptr}; + Array attrs{nullptr}; + Array outputs{ObjectPtr{nullptr}}; + // Parse the entry + try { + const auto* arr = inst_entry.as(); + ICHECK(arr && arr->size() == 4); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + const auto* arr2 = arr->at(2).as(); + const auto* arr3 = arr->at(3).as(); + ICHECK(arr0 && arr1 && arr2 && arr3); + for (const ObjectRef& str : *arr3) { + ICHECK(str->IsInstance()); + } + kind = InstructionKind::Get(arr0->data); + inputs = GetRef>(arr1); + attrs = GetRef>(arr2); + outputs = GetRef>(arr3); + } catch (const tvm::Error& e) { + LOG(FATAL) << "ValueError: Each entry of a json instruction should be a tuple [inst_name, " + "inputs, attrs, outputs], but gets: " + << inst_entry; + throw; + } + // Parse inputs + inputs = TranslateInputRVs(inputs, named_rvs); + // Parse attrs + if (kind->f_attrs_from_json != nullptr) { + attrs = kind->f_attrs_from_json(attrs); + } + // Apply to the schedule + Array new_outputs = kind->f_apply_to_schedule(sch, inputs, attrs, decisions[i]); + // Parse outputs + TranslateAddOutputRVs(outputs, new_outputs, &named_rvs); + ++i; + } +} + +/**************** Creation ****************/ + +Trace TraceNode::WithDecision(Instruction inst, ObjectRef decision, bool remove_postproc) const { + int n_insts = GetNumValidInstructions(this->insts, remove_postproc); + Array new_insts = + Array{this->insts.begin(), this->insts.begin() + n_insts}; + Map new_decisions{this->decisions.begin(), this->decisions.end()}; + new_decisions.Set(std::move(inst), std::move(decision)); + return Trace(new_insts, new_decisions); +} + +Trace TraceNode::Simplified(bool remove_postproc) const { + int n_insts = GetNumValidInstructions(this->insts, remove_postproc); + std::unordered_set used_rvs; + std::vector new_insts; + std::unordered_map new_decisions; + new_insts.reserve(n_insts); + new_decisions.reserve(this->decisions.size()); + for (int inst_idx = n_insts - 1; inst_idx >= 0; --inst_idx) { + const Instruction& inst = this->insts[inst_idx]; + // Check if all the variables the instruction defined are dead + // If so, and the instruction is pure, we can safely remove this instruction + bool all_defs_dead = inst->kind->is_pure; + if (all_defs_dead) { + for (const ObjectRef& obj : inst->outputs) { + if (used_rvs.count(obj.get())) { + all_defs_dead = false; + break; + } + } + } + // Remove this instruction + if (all_defs_dead) { + continue; + } + // Otherwise this instruction is not dead + new_insts.push_back(inst); + if (Optional decision = this->GetDecision(inst)) { + new_decisions.emplace(inst, std::move(decision)); + } + // Add its inputs as "used" ones + for (const ObjectRef& obj : inst->inputs) { + if (obj->IsInstance() || obj->IsInstance() || + obj->IsInstance()) { + used_rvs.insert(obj.get()); + continue; + } else if (obj->IsInstance()) { + PostOrderVisit(obj, [&used_rvs](const ObjectRef& obj) -> void { + if (obj->IsInstance()) { + used_rvs.insert(obj.get()); + } + }); + } + } + } + return Trace(Array(new_insts.rbegin(), new_insts.rend()), + Map(new_decisions)); +} + +/**************** Repr ****************/ + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { + const auto* self = obj.as(); + ICHECK_NOTNULL(self); + Array repr = self->AsPython(/*remove_postproc=*/false); + bool is_first = true; + for (const String& line : repr) { + if (is_first) { + is_first = false; + } else { + p->stream << std::endl; + } + p->stream << line; + } + }); + +/**************** Instruction Registration ****************/ + +struct EnterPostprocTraits : public UnpackedInstTraits { + static constexpr const char* kName = "EnterPostproc"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 0; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch) { return sch->EnterPostproc(); } + + static String UnpackedAsPython(Array outputs) { + PythonAPICall py("enter_postproc"); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(EnterPostprocTraits); + +/**************** FFI ****************/ + +TVM_REGISTER_NODE_TYPE(TraceNode); +TVM_REGISTER_GLOBAL("tir.schedule.Trace") + .set_body_typed([](Optional> insts, + Optional> decisions) { + return Trace(insts.value_or(Array()), + decisions.value_or(Map())); + }); +TVM_REGISTER_GLOBAL("tir.schedule.TraceGetDecision") + .set_body_method(&TraceNode::GetDecision); +TVM_REGISTER_GLOBAL("tir.schedule.TraceAppend") + .set_body_typed([](Trace self, Instruction inst, Optional decision) { + if (decision.defined()) { + return self->Append(inst, decision.value()); + } else { + return self->Append(inst); + } + }); +TVM_REGISTER_GLOBAL("tir.schedule.TracePop").set_body_method(&TraceNode::Pop); +TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyToSchedule") + .set_body_method(&TraceNode::ApplyToSchedule); +TVM_REGISTER_GLOBAL("tir.schedule.TraceAsJSON").set_body_method(&TraceNode::AsJSON); +TVM_REGISTER_GLOBAL("tir.schedule.TraceAsPython").set_body_method(&TraceNode::AsPython); +TVM_REGISTER_GLOBAL("tir.schedule.TraceWithDecision") + .set_body_method(&TraceNode::WithDecision); +TVM_REGISTER_GLOBAL("tir.schedule.TraceSimplified").set_body_method(&TraceNode::Simplified); +TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyJSONToSchedule") + .set_body_typed(Trace::ApplyJSONToSchedule); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc new file mode 100644 index 000000000000..af4a6588f064 --- /dev/null +++ b/src/tir/schedule/traced_schedule.cc @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./traced_schedule.h" + +namespace tvm { +namespace tir { + +Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, ScheduleErrorRenderLevel error_render_level) { + ObjectPtr n = make_object(); + n->state_ = ScheduleState(mod, debug_mask); + n->error_render_level_ = error_render_level; + n->symbol_table_ = {}; + n->analyzer_ = std::make_unique(); + n->trace_ = Trace(); + support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); + return Schedule(std::move(n)); +} + +Schedule TracedScheduleNode::Copy() const { + ObjectPtr n = make_object(); + n->error_render_level_ = this->error_render_level_; + ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); + n->analyzer_ = std::make_unique(); // new analyzer needed because it is stateful + n->trace_ = Trace(this->trace_->insts, this->trace_->decisions); + return Schedule(std::move(n)); +} + +/******** Schedule: Sampling ********/ +ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { + ExprRV result = + CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); + static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{}, + /*attrs=*/{candidates, probs}, + /*outputs=*/{result}), + /*decision=*/decision); + return result; +} + +/******** Schedule: Get blocks & loops ********/ + +BlockRV TracedScheduleNode::GetBlock(const String& name, const String& func_name) { + BlockRV result = ConcreteScheduleNode::GetBlock(name, func_name); + + static const InstructionKind& kind = InstructionKind::Get("GetBlock"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{}, + /*attrs=*/{name, func_name}, + /*outputs=*/{result})); + return result; +} + +Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { + Array results = ConcreteScheduleNode::GetLoops(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("GetLoops"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + +/******** Schedule: Transform loops ********/ + +LoopRV TracedScheduleNode::Fuse(const Array& loop_rvs) { + LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs); + + static const InstructionKind& kind = InstructionKind::Get("Fuse"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rvs.begin(), loop_rvs.end()}, + /*attrs=*/{}, + /*outputs=*/{result})); + return result; +} + +Array TracedScheduleNode::Split(const LoopRV& loop_rv, + const Array>& factor_rvs) { + Array results = ConcreteScheduleNode::Split(loop_rv, factor_rvs); + + std::vector inputs; + inputs.reserve(1 + factor_rvs.size()); + inputs.push_back(loop_rv); + for (const ObjectRef& obj : factor_rvs) { + inputs.push_back(obj); + } + + static const InstructionKind& kind = InstructionKind::Get("Split"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/inputs, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + +void TracedScheduleNode::Reorder(const Array& ordered_loop_rvs) { + ConcreteScheduleNode::Reorder(ordered_loop_rvs); + + static const InstructionKind& kind = InstructionKind::Get("Reorder"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{ordered_loop_rvs.begin(), ordered_loop_rvs.end()}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +/******** Schedule: Manipulate ForKind ********/ + +void TracedScheduleNode::Parallel(const LoopRV& loop_rv) { + ConcreteScheduleNode::Parallel(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("Parallel"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Vectorize(const LoopRV& loop_rv) { + ConcreteScheduleNode::Vectorize(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("Vectorize"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) { + ConcreteScheduleNode::Bind(loop_rv, thread_axis); + + static const InstructionKind& kind = InstructionKind::Get("Bind"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{thread_axis}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Unroll(const LoopRV& loop_rv) { + ConcreteScheduleNode::Unroll(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("Unroll"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +/******** Schedule: Insert cache stages ********/ + +/******** Schedule: Compute location ********/ + +void TracedScheduleNode::ComputeInline(const BlockRV& block_rv) { + ConcreteScheduleNode::ComputeInline(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("ComputeInline"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { + ConcreteScheduleNode::ReverseComputeInline(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("ReverseComputeInline"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +/******** Schedule: Reduction ********/ + +BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { + BlockRV result = ConcreteScheduleNode::RFactor(loop_rv, factor_axis); + static const InstructionKind& kind = InstructionKind::Get("RFactor"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{Integer(factor_axis)}, + /*outputs=*/{result})); + return result; +} + +/******** Schedule: Block annotation ********/ + +void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, + int factor, int offset) { + ConcreteScheduleNode::StorageAlign(block_rv, buffer_index, axis, factor, offset); + static const InstructionKind& kind = InstructionKind::Get("StorageAlign"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(buffer_index), Integer(axis), Integer(factor), Integer(offset)}, + /*outputs=*/{})); +} + +/******** Schedule: Blockize & Tensorize ********/ + +/******** Schedule: Annotation ********/ + +/******** Schedule: Misc ********/ + +void TracedScheduleNode::EnterPostproc() { + ConcreteScheduleNode::EnterPostproc(); + static const InstructionKind& kind = InstructionKind::Get("EnterPostproc"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h new file mode 100644 index 000000000000..48dadbc03b3b --- /dev/null +++ b/src/tir/schedule/traced_schedule.h @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ +#define TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ + +#include "./concrete_schedule.h" + +namespace tvm { +namespace tir { + +class TracedScheduleNode : public ConcreteScheduleNode { + friend class Schedule; + + protected: + Trace trace_; + + public: + void VisitAttrs(tvm::AttrVisitor* v) { + // `state_` is not visited + // `error_render_level_` is not visited + // `symbol_table_` is not visited + // `analyzer_` is not visitied + // `trace_` is not visited + } + + ~TracedScheduleNode() = default; + + public: + Optional trace() const final { return trace_; } + Schedule Copy() const final; + + public: + /******** Schedule: Sampling ********/ + /*! + * \brief Sample an integer given the probability distribution + * \param candidates The candidates + * \param probs The probability distribution of the candidates + * \param decision The sampling decision, if it's given we would validate the decision, otherwise + * we would sample a decision from the distribution and set the decision accordingly. + * \return The random variable sampled from candidates + */ + ExprRV SampleCategorical(const Array& candidates, const Array& probs, + Optional decision = NullOpt) final; + + /******** Schedule: Get blocks & loops ********/ + BlockRV GetBlock(const String& name, const String& func_name = "main") final; + Array GetLoops(const BlockRV& block_rv) final; + /******** Schedule: Transform loops ********/ + LoopRV Fuse(const Array& loop_rvs) final; + Array Split(const LoopRV& loop_rv, const Array>& factor_rvs) final; + void Reorder(const Array& ordered_loop_rvs) final; + /******** Schedule: Manipulate ForKind ********/ + void Parallel(const LoopRV& loop_rv) final; + void Vectorize(const LoopRV& loop_rv) final; + void Bind(const LoopRV& loop_rv, const String& thread_axis) final; + void Unroll(const LoopRV& loop_rv) final; + /******** Schedule: Insert cache stages ********/ + /******** Schedule: Compute location ********/ + void ComputeInline(const BlockRV& block_rv) final; + void ReverseComputeInline(const BlockRV& block_rv) final; + /******** Schedule: Reduction ********/ + BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final; + /******** Schedule: Block annotation ********/ + void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, + int offset) final; + /******** Schedule: Blockize & Tensorize ********/ + /******** Schedule: Annotation ********/ + /******** Schedule: Misc ********/ + void EnterPostproc() final; +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc new file mode 100644 index 000000000000..f27e0f6d62eb --- /dev/null +++ b/src/tir/schedule/transform.cc @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./transform.h" + +namespace tvm { +namespace tir { + +/******** Annotation ********/ +Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) { + Map annotations = block->annotations; + annotations.Set(attr_key, attr_value); + ObjectPtr new_block = make_object(*block); + new_block->annotations = std::move(annotations); + return Block(new_block); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h new file mode 100644 index 000000000000..53483829a303 --- /dev/null +++ b/src/tir/schedule/transform.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_ +#define TVM_TIR_SCHEDULE_TRANSFORM_H_ + +#include + +namespace tvm { +namespace tir { + +/******** Annotation ********/ + +/*! + * \brief Create a new block with the given annotation added + * \param block The block with original annotation + * \param attr_key The annotation key to be added + * \param attr_value The annotation value to be added + * \return A new block with the given annotation as its last annotation + */ +Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value); + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_ diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 19ed995ac8cc..8ccf8da731b5 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -25,18 +25,22 @@ #include #include #include +#include #include #include +#include #include #include #include +#include "../../node/attr_registry.h" #include "../../printer/text_printer.h" #include "../../runtime/thread_storage_scope.h" #include "../../support/array.h" #include "./analysis.h" #include "./error.h" +#include "./instruction_traits.h" #include "./primitive.h" namespace tvm { @@ -98,6 +102,21 @@ namespace tir { << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \ << "`, but gets: " << (From.defined() ? From->GetTypeKey() : "None") +/*! + * \brief Convert an array of loop StmtSRefs to an array of loops + * \param loop_srefs The loop StmtSRefs to be converted + * \return The conversion result loops + */ +inline Array LoopSRefs2Loops(const Array& loop_srefs) { + Array loops; + loops.reserve(loop_srefs.size()); + for (StmtSRef loop_sref : loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + loops.push_back(GetRef(loop)); + } + return loops; +} + /******** Storage scope ********/ /*! @@ -143,6 +162,18 @@ inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { return SeqStmt::Flatten(new_stmts); } +/*! + * \brief Create a new IterVar for the input For loop, with specified name and type + * \param loop The loop to be created from + * \param name The name of the new IterVar + * \param iter_var_type The type of the new IterVar + * \return The newly created IterVar + */ +inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_var_type) { + return IterVar(Range::FromMinExtent(loop->min, loop->extent), + Var(std::move(name), loop->loop_var.dtype()), iter_var_type); +} + /******** Integer set ********/ /*! diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 9cd29357f8c7..293c990d2745 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -88,7 +88,7 @@ void ArgBinder::BindArray(const Array& arg, const Array& val void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name, bool fuzzy_match) { - ICHECK_EQ(arg->scope, value->scope) << "Argument " << arg_name << " Buffer bind scope mismatch"; + ICHECK_EQ(arg.scope(), value.scope()) << "Argument " << arg_name << " Buffer bind scope mismatch"; ICHECK_EQ(arg->dtype, value->dtype) << "Argument " << arg_name << " Buffer bind data type mismatch"; if (value->data_alignment % arg->data_alignment != 0) { diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 7a8789457923..76845cbebd2a 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -323,8 +323,8 @@ class BF16LowerRewriter : public StmtExprMutator { DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype))); auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset, - oldbuf->name, oldbuf->scope, oldbuf->data_alignment, - oldbuf->offset_factor, oldbuf->buffer_type); + oldbuf->name, oldbuf->data_alignment, oldbuf->offset_factor, + oldbuf->buffer_type); buffer_remap_[oldbuf] = newbuf; var_remap_[oldbuf->data] = buffer_var; changes.emplace_back(itr.first, newbuf); diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index edbafe27cf13..961ea1721fa1 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -32,6 +32,7 @@ #include "../../support/arena.h" #include "../../support/utils.h" #include "../schedule/utils.h" +#include "ir_utils.h" namespace tvm { namespace tir { @@ -203,7 +204,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { std::unordered_map dom_map; for (const ForNode* loop : ancestor_loops_) { const VarNode* loop_var = loop->loop_var.get(); - if (NeedRelaxThread(GetRef(loop), runtime::StorageScope::Create(buffer->scope))) { + if (NeedRelaxThread(GetRef(loop), runtime::StorageScope::Create(buffer.scope()))) { dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent); } } @@ -302,18 +303,61 @@ class BufferAccessRegionCollector : public StmtExprVisitor { support::Arena arena_; }; +/*! \brief Collect storage alignment information from block annotations. */ +class StorageAlignCollector : public StmtVisitor { + public: + static std::unordered_map Collect( + const PrimFunc& f) { + StorageAlignCollector collector; + collector(f->body); + return std::move(collector.storage_align_); + } + + private: + void VisitStmt_(const BlockNode* op) final { + auto it = op->annotations.find(attr::buffer_dim_align); + if (it != op->annotations.end()) { + auto storage_align_annotation = Downcast((*it).second); + for (const auto& storage_align_tuple : storage_align_annotation) { + int buffer_index = storage_align_tuple[0]->value; + const Buffer& buffer = op->writes[buffer_index]->buffer; + storage_align_[buffer].push_back(storage_align_tuple); + } + } + StmtVisitor::VisitStmt_(op); + } + + /*! \brief The map from Buffer to its storage alignment information. */ + std::unordered_map storage_align_; +}; + /*! \brief Reallocate the buffers with minimal region. */ class BufferCompactor : public StmtExprMutator { public: static Stmt Compact( const PrimFunc& f, - const std::unordered_map& regions) { + const std::unordered_map& regions, + const std::unordered_map& + storage_align) { std::unordered_map buffer_info; for (const auto& kv : regions) { const Buffer& buffer = kv.first; Region region = kv.second; - buffer_info.emplace(buffer, BufferAllocInfo(std::move(region))); + BufferAllocInfo buffer_alloc_info(std::move(region)); + auto it = storage_align.find(buffer); + if (it != storage_align.end()) { + std::vector dim_aligns(buffer->shape.size()); + for (const StorageAlignTuple& dim_align : (*it).second) { + ICHECK(dim_align.size() == 4); + int dim = dim_align[1]->value; + int factor = dim_align[2]->value; + int offset = dim_align[3]->value; + dim_aligns.at(dim) = {factor, offset}; + } + buffer_alloc_info.dim_aligns = std::move(dim_aligns); + } + buffer_info.emplace(buffer, std::move(buffer_alloc_info)); } BufferCompactor compactor(std::move(buffer_info)); Stmt stmt = compactor(f->body); @@ -321,9 +365,19 @@ class BufferCompactor : public StmtExprMutator { } private: + /*! \brief The storage alignment for a dimension */ + struct DimAlignInfo { + /*! \brief The factor of the alignment */ + int align_factor{0}; + /*! \brief The offset of the alignment */ + int align_offset{0}; + }; + struct BufferAllocInfo { /*! \brief The buffer access region. */ Region region; + /*! \brief The storage alignment information. */ + std::vector dim_aligns; /*! * \brief The reallocated buffer with minimal size. * \note The value if NullOpt if the buffer do not need reallocate (e.g parameter buffer). @@ -362,6 +416,7 @@ class BufferCompactor : public StmtExprMutator { BlockNode* n = block.CopyOnWrite(); RewriteBufferRegions(&n->reads); RewriteBufferRegions(&n->writes); + RewriteMatchBuffers(&n->match_buffers); n->alloc_buffers = std::move(alloc_buffers); return std::move(block); } @@ -378,8 +433,25 @@ class BufferCompactor : public StmtExprMutator { for (const Range& range : info.region) { shape.push_back(range->extent); } + Array strides; + if (info.dim_aligns.size()) { + ICHECK(info.dim_aligns.size() == shape.size()); + strides.resize(shape.size()); + PrimExpr stride = make_const(shape[0].dtype(), 1); + for (size_t i = shape.size(); i != 0; --i) { + size_t dim = i - 1; + if (info.dim_aligns[dim].align_factor != 0) { + PrimExpr factor = make_const(stride.dtype(), info.dim_aligns[dim].align_factor); + PrimExpr offset = make_const(stride.dtype(), info.dim_aligns[dim].align_offset); + stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); + } + strides.Set(dim, stride); + stride = stride * shape[dim]; + } + } ObjectPtr n = make_object(*buffer.get()); n->shape = std::move(shape); + n->strides = std::move(strides); info.new_buffer = Buffer(std::move(n)); result.push_back(info.new_buffer); } @@ -434,16 +506,35 @@ class BufferCompactor : public StmtExprMutator { *regions = std::move(new_regions); } + void RewriteMatchBuffers(Array* match_buffers) const { + Array result; + result.reserve(match_buffers->size()); + for (const auto& match_buffer : *match_buffers) { + const BufferRegion& buffer_region = match_buffer->source; + auto p = make_object(*buffer_region.get()); + RewriteBufferRegion(&p->buffer, &p->region); + result.push_back(MatchBufferRegion(match_buffer->buffer, BufferRegion(p))); + } + *match_buffers = std::move(result); + } + /*! \brief The allocation information about each buffer. */ std::unordered_map buffer_info_; }; PrimFunc CompactBufferAllocation(PrimFunc f) { - PrimFuncNode* fptr = f.CopyOnWrite(); - std::unordered_map region = - BufferAccessRegionCollector::Collect(f); - fptr->body = BufferCompactor::Compact(f, region); - return f; + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + std::unordered_map region = + BufferAccessRegionCollector::Collect(f); + std::unordered_map + storage_align = StorageAlignCollector::Collect(f); + fptr->body = BufferCompactor::Compact(f, region, storage_align); + return f; + } else { + return f; + } } namespace transform { diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc index 4c5e1dd5125b..f7629d100645 100644 --- a/src/tir/transforms/convert_blocks_to_opaque.cc +++ b/src/tir/transforms/convert_blocks_to_opaque.cc @@ -25,6 +25,8 @@ #include #include +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -83,9 +85,14 @@ class OpaqueBlockConverter : public StmtExprMutator { }; PrimFunc ConvertBlocksToOpaque(PrimFunc f) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = OpaqueBlockConverter::Substitute(f); - return f; + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = OpaqueBlockConverter::Substitute(f); + return f; + } else { + return f; + } } namespace transform { diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 07f7b42fe2eb..5eb6d5b03921 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -28,6 +28,7 @@ #include #include "../../support/utils.h" +#include "ir_utils.h" namespace tvm { namespace tir { @@ -127,13 +128,9 @@ class BufferFlattener : public StmtExprMutator { } static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) { - String storage_scope = buffer->scope; - if (storage_scope.empty()) { - storage_scope = "global"; - } + String storage_scope = buffer.scope(); PrimExpr area = BufferArea(buffer); body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body)); - body = AttrStmt(buffer->data, attr::storage_scope, StringImm(storage_scope), std::move(body)); return body; } @@ -143,7 +140,10 @@ class BufferFlattener : public StmtExprMutator { /*var=*/std::move(var), /*iter_type=*/IterVarType::kThreadIndex, /*thread_tag=*/thread_tag); - String attr_key = thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent; + String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || + thread_tag == "vthread.y" || thread_tag == "vthread.z") + ? attr::virtual_thread + : attr::thread_extent; return AttrStmt(/*node=*/std::move(iter_var), /*attr_key=*/std::move(attr_key), /*value=*/std::move(extent), @@ -155,9 +155,14 @@ class BufferFlattener : public StmtExprMutator { }; PrimFunc FlattenBuffer(PrimFunc f) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = BufferFlattener::Flatten(f); - return f; + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = BufferFlattener::Flatten(f); + return f; + } else { + return f; + } } namespace transform { diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index f7443c74c0f7..f99cbd5b5a05 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -29,6 +29,7 @@ #include #include "../../arith/pattern_match.h" +#include "ir_utils.h" namespace tvm { namespace tir { @@ -42,10 +43,7 @@ class CopyIntrinInjector : public StmtMutator { flower_copy_fromto_(flower_copy_fromto) {} Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - storage_scope_[buf] = op->value.as()->value; - } else if (op->attr_key == pragma_key_) { + if (op->attr_key == pragma_key_) { Stmt ret; ICHECK(MatchCopyPattern(op->body, &ret)) << "Cannot match copy pattern of " << op->body; return ret; @@ -148,30 +146,18 @@ class CopyIntrinInjector : public StmtMutator { dst_strides.push_back(make_const(DataType::Int(32), 1)); } Buffer dst = Buffer(store->buffer_var, store->value.dtype(), dst_shape, dst_strides, - store_strides[loop_var_size], store->buffer_var->name_hint, - GetStorageScope(store->buffer_var.get()), 0, 0, kDefault); + store_strides[loop_var_size], store->buffer_var->name_hint, 0, 0, kDefault); Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset, - load->buffer_var->name_hint, GetStorageScope(load->buffer_var.get()), 0, 0, - kDefault); + load->buffer_var->name_hint, 0, 0, kDefault); *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); ICHECK(out->defined()) << "flower function did not return correct stmt"; return true; } - // Get storage scope - std::string GetStorageScope(const VarNode* var) const { - auto it = storage_scope_.find(var); - if (it != storage_scope_.end()) { - return it->second; - } else { - return ""; - } - } + // pragma key std::string pragma_key_; // function to lower copy intrinsics. const PackedFunc& flower_copy_fromto_; - // Storage scope - std::unordered_map storage_scope_; // arith analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 7a16c06d8058..0b45bde28dfe 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -95,16 +95,7 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - auto it = dbuffer_info_.find(buf); - if (it != dbuffer_info_.end()) { - it->second.scope = op->value.as()->value; - return this->VisitStmt(op->body); - } else { - return StmtExprMutator::VisitStmt_(op); - } - } else if (op->attr_key == attr::double_buffer_scope) { + if (op->attr_key == attr::double_buffer_scope) { return MakeProducer(op); } else { return StmtExprMutator::VisitStmt_(op); @@ -112,8 +103,10 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AllocateNode* op) final { - auto it = dbuffer_info_.find(op->buffer_var.get()); + const VarNode* buf = op->buffer_var.as(); + auto it = dbuffer_info_.find(buf); if (it != dbuffer_info_.end()) { + it->second.scope = GetPtrStorageScope(op->buffer_var); it->second.stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes(); @@ -125,8 +118,6 @@ class DoubleBufferInjector : public StmtExprMutator { } ICHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; - alloc_nest.emplace_back( - AttrStmt(op->buffer_var, attr::storage_scope, StringImm(it->second.scope), Evaluate(0))); alloc_nest.emplace_back( Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0))); return op->body; diff --git a/src/tir/transforms/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc index 4ce9c7639b77..f20577e3a01b 100644 --- a/src/tir/transforms/inject_prefetch.cc +++ b/src/tir/transforms/inject_prefetch.cc @@ -31,6 +31,8 @@ #include +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -96,9 +98,14 @@ namespace transform { Pass InjectPrefetch() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - n->body = PrefetchInjector()(std::move(n->body)); - return f; + // Only apply this pass to TIR from TE schedules + if (IsFromLegacyTESchedule(f)) { + auto* n = f.CopyOnWrite(); + n->body = PrefetchInjector()(std::move(n->body)); + return f; + } else { + return f; + } }; return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {}); } diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 4ef10f326bb0..4964bec0334e 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -459,12 +459,12 @@ class VirtualThreadInjector : public StmtMutator { op = stmt.as(); if (op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); - bool allow_share = iv->thread_tag == "vthread"; + bool allow_share = std::string(iv->thread_tag).substr(0, 7) == "vthread"; int nthread = static_cast(op->value.as()->value); VarTouchedAnalysis vs; auto touched = vs.TouchedVar(op->body, iv->var.get()); - VTInjector injecter(iv->var, nthread, touched, allow_share); - return injecter(op->body); + VTInjector injector(iv->var, nthread, touched, allow_share); + return injector(op->body); } else { return stmt; } @@ -476,11 +476,6 @@ class VirtualThreadInjector : public StmtMutator { } }; -Stmt InjectVirtualThread(Stmt stmt) { - stmt = VirtualThreadInjector()(std::move(stmt)); - return ConvertSSA(std::move(stmt)); -} - namespace transform { Pass InjectVirtualThread() { diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index cbae3f95ec68..a41905c148bf 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -23,6 +23,7 @@ */ #include "ir_utils.h" +#include #include #include @@ -172,16 +173,6 @@ class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { if (const VarNode* v = op->node.as()) { - if (op->attr_key == attr::storage_scope) { - const AllocateNode* alloc = op->body.as(); - if (alloc && op->node.same_as(alloc->buffer_var)) { - Stmt new_alloc = this->VisitStmt(op->body); - if (new_alloc.same_as(op->body)) return GetRef(op); - alloc = new_alloc.as(); - ICHECK(alloc); - return AttrStmt(alloc->buffer_var, op->attr_key, op->value, new_alloc); - } - } Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(v) && scope_[v].size() != 0) { @@ -201,5 +192,62 @@ class IRConvertSSA final : public StmtExprMutator { Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } +String GetPtrStorageScope(Var buffer_var) { + const auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return ptr_type->storage_scope; +} + +Array ConvertIndices(const MatchBufferRegion& match_buffer, + const Array& indices) { + const Buffer& target = match_buffer->buffer; + const BufferRegion& source = match_buffer->source; + ICHECK_EQ(indices.size(), target->shape.size()); + + arith::Analyzer analyzer; + Array result; + result.reserve(source->region.size()); + size_t offset = source->region.size() - indices.size(); + for (size_t i = 0; i < offset; ++i) { + const Range& range = source->region[i]; + ICHECK(analyzer.CanProve(range->extent == 1)); + result.push_back(range->min); + } + for (size_t i = 0; i < indices.size(); ++i) { + const Range& range = source->region[i + offset]; + const PrimExpr& index = indices[i]; + result.push_back(range->min + index); + } + return result; +} + +Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region) { + const Buffer& target = match_buffer->buffer; + const BufferRegion& source = match_buffer->source; + ICHECK_EQ(region.size(), target->shape.size()); + + arith::Analyzer analyzer; + Region result; + result.reserve(source->region.size()); + size_t offset = source->region.size() - region.size(); + for (size_t i = 0; i < offset; ++i) { + const Range& source_range = source->region[i]; + ICHECK(analyzer.CanProve(source_range->extent == 1)); + result.push_back(Range::FromMinExtent(source_range->min, 1)); + } + for (size_t i = 0; i < region.size(); ++i) { + const Range& source_range = source->region[i + offset]; + const Range& target_range = region[i]; + result.push_back( + Range::FromMinExtent(source_range->min + target_range->min, target_range->extent)); + } + return result; +} + +Bool IsFromLegacyTESchedule(PrimFunc f) { + Optional from_legacy_te_schedule = f->GetAttr("from_legacy_te_schedule", Bool(false)); + return from_legacy_te_schedule.value(); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 906ff8a38b6c..9be18b790b41 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -191,6 +192,38 @@ inline PrimExpr StackAlloca(std::string type, size_t num) { */ Stmt ConvertSSA(Stmt stmt); +/*! + * \brief Return the storage scope associated with a buffer variable. + * \param buffer_var The input buffer variable. + * \return A string representing the storage scope of this buffer variable. + */ +String GetPtrStorageScope(Var buffer_var); + +/*! + * \brief Convert match buffer target buffer access indices to original one. + * \param indices The indices of the target buffer + * \return The indices of source buffer. + */ +Array ConvertIndices(const MatchBufferRegion& match_buffer, + const Array& indices); + +/*! + * \brief Convert match buffer target buffer region to original one. + * \param region The sub-region of the target buffer + * \return The region of source buffer. + */ +Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region); + +/*! + * \brief Check if a given PrimFunc originated from a TE schedule. + * + * Internally this checks for the `from_legacy_te_schedule` attr of the PrimFunc. + * + * \param f PrimFunc to check + * \return Whether or not the PrimFunc was created from a te schedule + */ +Bool IsFromLegacyTESchedule(PrimFunc f); + } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index f1d816f0baef..97f5b6f90a70 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -84,19 +85,6 @@ using Partition = std::unordered_map; -bool ExprUseVars(PrimExpr expr, const std::unordered_set& vars) { - bool success = false; - PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) { - if (const VarNode* v = node.as()) { - if (vars.count(v)) { - success = true; - return; - } - } - }); - return success; -} - // Select potential candidate IRs that can be partitioned. // Rule: // - the range should not be const @@ -200,7 +188,8 @@ class PartitionFinder : public StmtExprVisitor { } void VisitStmt_(const ForNode* op) final { - if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; + auto f_vset_contains = [this](const VarNode* var) { return out_vars_.count(var); }; + if (UsesVar(op->min, f_vset_contains) || UsesVar(op->extent, f_vset_contains)) return; const VarNode* var = op->loop_var.get(); hint_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)}); @@ -230,7 +219,7 @@ class PartitionFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::likely())) { PrimExpr cond = op->args[0]; - if (ExprUseVars(cond, std::unordered_set({current_var_.get()}))) { + if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) { // For cond, find out the interval, if exists, in which we can prove that cond is // true. Also find the interval, if exists, in which we can prove that cond is // false. diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 829b7d822d11..b4ec91ba5012 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -41,39 +41,22 @@ using runtime::StorageScope; class StorageAccessInfoLower : public StmtExprMutator { public: Stmt VisitStmt_(const AllocateNode* op) final { - // Lower allocate to device allocate when needed. - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - // For special memory, remove allocate, or use head expr - auto it = storage_info_.find(op->buffer_var.get()); - if (it != storage_info_.end() && it->second.info.defined()) { - const MemoryInfo& info = it->second.info; - ++it->second.alloc_count; - ICHECK_LE(it->second.alloc_count, 1) - << "Double allocation of " << it->second.scope.to_string(); + auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (scope.tag.length() != 0 && scope.tag != ".dyn") { + auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var)); + ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string(); + ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end()) + << "Double allocation of " << scope.to_string(); + storage_info_[op->buffer_var.get()] = info; + // Lower allocate to device allocate when needed. + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); if (info->head_address.defined()) { return LetStmt(op->buffer_var, info->head_address, op->body); } else { return op->body; } - } else { - return stmt; - } - } - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::Create(op->value.as()->value); - StorageEntry e; - e.scope = scope; - if (scope.tag.length() != 0) { - e.info = GetMemoryInfo(op->value.as()->value); - ICHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string(); - } - storage_info_[buf] = e; - return StmtExprMutator::VisitStmt_(op); - } else { return StmtExprMutator::VisitStmt_(op); } @@ -99,8 +82,8 @@ class StorageAccessInfoLower : public StmtExprMutator { Var buffer_var = Downcast(op->args[1]); PrimExpr offset = op->args[2]; auto it = storage_info_.find(buffer); - if (it != storage_info_.end() && it->second.info.defined()) { - return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second.info); + if (it != storage_info_.end() && it->second.defined()) { + return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second); } ICHECK(op->dtype.is_handle()); // Change to address_of @@ -118,17 +101,8 @@ class StorageAccessInfoLower : public StmtExprMutator { return cast(ptr_type, analyzer_.Simplify( offset / make_const(offset.dtype(), info->unit_bits / dtype_bits))); } - // The storage entry. - struct StorageEntry { - // Whether it is tagged memory. - StorageScope scope; - // The memory info if any. - MemoryInfo info; - // Allocation counter - int alloc_count{0}; - }; // The storage scope of each buffer - std::unordered_map storage_info_; + std::unordered_map storage_info_; // analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc index c8aca5195085..d8621ac3b3e6 100644 --- a/src/tir/transforms/lower_init_block.cc +++ b/src/tir/transforms/lower_init_block.cc @@ -25,6 +25,8 @@ #include #include +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -63,9 +65,14 @@ class InitBlockLower : public StmtMutator { }; PrimFunc LowerInitBlock(PrimFunc func) { - auto fptr = func.CopyOnWrite(); - fptr->body = InitBlockLower()(std::move(fptr->body)); - return func; + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(func)) { + auto fptr = func.CopyOnWrite(); + fptr->body = InitBlockLower()(std::move(fptr->body)); + return func; + } else { + return func; + } } namespace transform { diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc new file mode 100644 index 000000000000..2f8fbe0ea6e7 --- /dev/null +++ b/src/tir/transforms/lower_match_buffer.cc @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_match_buffer.cc + * \brief The pass for lowering match_buffer. + */ + +#include +#include +#include +#include +#include + +#include "../ir/functor_common.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { +class MatchBufferLower : public StmtExprMutator { + public: + explicit MatchBufferLower(const PrimFunc& func) { + for (const Var& param : func->params) { + // Mark input var as const variable. + if (!param.dtype().is_handle()) var_map_.Set(param, param); + } + } + + private: + Stmt VisitStmt_(const BlockNode* op) final { + for (const MatchBufferRegion& match_buffer : op->match_buffers) { + CheckAndUpdateVarMap(match_buffer); + } + + Stmt stmt = StmtExprMutator ::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + Array reads = MutateArray( + op->reads, std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); + Array writes = MutateArray( + op->writes, std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); + + if (reads.same_as(op->reads) && writes.same_as(op->writes) && op->match_buffers.empty()) { + return stmt; + } else { + auto n = CopyOnWrite(op); + n->match_buffers = {}; + n->reads = std::move(reads); + n->writes = std::move(writes); + return Stmt(n); + } + } + + Stmt VisitStmt_(const ForNode* op) final { + analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + return StmtExprMutator::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + Var v = GetRef(op); + auto it = var_map_.find(v); + if (it != var_map_.end()) { + return (*it).second; + } else { + return std::move(v); + } + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + + auto it = match_buffers_.find(op->buffer); + if (it == match_buffers_.end()) { + return stmt; + } else { + const Buffer& buffer = (*it).first; + const BufferRegion& source = (*it).second; + + auto n = CopyOnWrite(op); + n->indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + n->buffer = source->buffer; + return Stmt(n); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK(op != nullptr); + + auto it = match_buffers_.find(op->buffer); + if (it == match_buffers_.end()) { + return expr; + } else { + const Buffer& buffer = (*it).first; + const BufferRegion& source = (*it).second; + Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + return BufferLoad(source->buffer, indices); + } + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + CHECK(var_map_.find(op->buffer_var) == var_map_.end()) + << "Load from buffer created by match_buffer is not allowed, but got: " << expr; + return expr; + } + + Stmt VisitStmt_(const StoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + CHECK(var_map_.find(op->buffer_var) == var_map_.end()) + << "Store from buffer created by match_buffer is not allowed, but got: " << stmt; + return stmt; + } + + BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + auto it = match_buffers_.find(buffer); + if (it == match_buffers_.end()) { + return buffer_region; + } else { + const BufferRegion& source = (*it).second; + Region region = ConvertRegion(MatchBufferRegion(buffer, source), buffer_region->region); + return BufferRegion(source->buffer, std::move(region)); + } + } + + private: + void CheckAndUpdateVarMap(const MatchBufferRegion& match_buffer) { + // Step.1. Check + const Buffer& buffer = match_buffer->buffer; + const BufferRegion& source = VisitBufferRegion(match_buffer->source); + const Buffer& source_buffer = source->buffer; + + // Step.1.1. Check scope & dtype + ICHECK_EQ(buffer.scope(), source_buffer.scope()) + << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << "vs." + << source_buffer.scope(); + ICHECK_EQ(buffer->dtype, source_buffer->dtype) + << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << "vs." + << source_buffer->dtype; + + // Step.1.2. Check data alignment + if (source_buffer->data_alignment % buffer->data_alignment != 0) { + LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement " + << " required_alignment=" << buffer->data_alignment + << ", provided_alignment=" << source_buffer->data_alignment; + } + if (is_zero(buffer->elem_offset)) { + ICHECK(is_zero(source_buffer->elem_offset)) + << "Trying to bind a Buffer with offset into one without offset " + << " required elem_offset=" << buffer->elem_offset + << ", provided elem_offset=" << source_buffer->elem_offset; + } + + // Step.2. Update + match_buffers_.Set(buffer, source); + // Step.2.1. Update buffer data + Bind(buffer->data, source_buffer->data, buffer->name + ".data"); + + // Step.2.2. Update element offset + // Note we create Load via vload and try to reuse index calculate. + { + Array indices; + indices.reserve(source->region.size()); + for (const Range& range : source->region) { + indices.push_back(range->min); + } + + Load load = Downcast(source_buffer.vload(indices, source_buffer->dtype)); + Bind(buffer->elem_offset, load->index, buffer->name + ".elem_offset"); + CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) + << "The source elem_offset " << buffer->elem_offset + << " does not satisfy the offset_factor " << buffer->offset_factor << "."; + } + + // Step 2.3. Check and update strides + // Check if target buffer strides are defined + if (!buffer->strides.empty()) { + ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); + PrimExpr stride = make_const(DataType::Int(32), 1); + for (size_t i = buffer->shape.size(); i > 0; --i) { + const PrimExpr& shape = source_buffer->shape[i - 1]; + Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1)); + stride *= shape; + } + } + + // Step 2.4. Check and update shape + ICHECK(source->region.size() >= buffer->shape.size()); + size_t offset = source->region.size() - buffer->shape.size(); + for (size_t i = 0; i < buffer->shape.size(); ++i) { + const Range& range = source->region[i + offset]; + Bind(buffer->shape[i], range->extent, buffer->name + ".shape_" + std::to_string(i)); + } + } + + void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name = "argument") { + CHECK_EQ(arg.dtype(), value.dtype()) + << "The data type mismatched: " << arg->dtype << " vs. " << value->dtype; + // Handle recursive case + value = Substitute(std::move(value), var_map_); + if (arg->IsInstance()) { + Var v = Downcast(arg); + auto it = var_map_.find(v); + if (it == var_map_.end()) { + var_map_.Set(v, value); + analyzer_.Bind(v, value); + } else { + AssertBinding((*it).second, value, arg_name); + } + } else { + AssertBinding(arg, value, arg_name); + } + } + + void AssertBinding(const PrimExpr& lhs, const PrimExpr& rhs, + const std::string& arg_name = "argument") { + CHECK(analyzer_.CanProve(lhs == rhs)) << "The buffer match constraint for " << arg_name + << " unmet: " << lhs << "==" << rhs << "."; + } + + private: + /*! \brief Buffer region mapping. */ + Map match_buffers_; + /*! \brief Var mapping for buffer signature (data, strides, element_offset, etc.) */ + Map var_map_; + /*! \brief The analyzer */ + arith::Analyzer analyzer_; +}; + +PrimFunc LowerMatchBuffer(PrimFunc func) { + auto fptr = func.CopyOnWrite(); + fptr->body = MatchBufferLower(func)(std::move(fptr->body)); + return func; +} + +namespace transform { + +Pass LowerMatchBuffer() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + return LowerMatchBuffer(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerMatchBuffer", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchBuffer); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 9e536814fa12..481b1bfd4b19 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -33,10 +33,32 @@ #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" +#include "update_pointer_storage_scope.h" namespace tvm { namespace tir { +class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScope { + public: + explicit UpdatePointerStorageScopeAllReduce( + const std::unordered_map& new_storage_scopes) + : UpdatePointerStorageScope(new_storage_scopes) {} + + Stmt VisitStmt_(const AllocateNode* op) final { + auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); + auto new_scope = GetPtrStorageScope(remapped); + if (new_scope != GetPtrStorageScope(op->buffer_var)) { + Stmt body = StmtExprMutator::VisitStmt(op->body); + if (new_scope == "shared") { + // use volatile access to shared buffer. + body = AttrStmt(remapped, attr::volatile_scope, 1, body); + } + return Allocate(remapped, op->dtype, op->extents, op->condition, body); + } + return StmtExprMutator::VisitStmt_(op); + } +}; + class ThreadAllreduceBuilder final : public StmtExprMutator { public: explicit ThreadAllreduceBuilder(const TargetNode* target) @@ -48,15 +70,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); thread_extents_.pop_back(); return ret; - } else if (op->attr_key == attr::storage_scope) { - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); - const VarNode* v = op->node.as(); - if (alloc_remap_.count(v)) { - return op->body; - } else { - return ret; - } } else if (op->attr_key == attr::reduce_scope) { const CommReducerNode* combiner = op->node.as(); ICHECK(combiner); @@ -86,12 +99,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt); + new_storage_scopes_[repl->buffer_var.get()] = "local"; } else { - // use volatile access to shared buffer. - stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body); - stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); - stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt); + stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); + new_storage_scopes_[repl->buffer_var.get()] = "shared"; } return stmt; } else { @@ -108,6 +119,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } } + std::unordered_map new_storage_scopes_; + private: // Thread entry struct ThreadEntry { @@ -366,7 +379,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = var.as(); if (repl) { body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - body = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), body); + new_storage_scopes_[repl->buffer_var.get()] = "local"; } } @@ -590,7 +603,10 @@ Pass LowerThreadAllreduce() { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; const TargetNode* target_node = target.as(); - n->body = ThreadAllreduceBuilder(target_node)(n->body); + ThreadAllreduceBuilder thread_all_reduce(target_node); + auto reduce_body = thread_all_reduce(n->body); + n->body = + UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 8b70817398e4..f5a553aa0598 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -98,6 +98,15 @@ class BuiltinLower : public StmtExprMutator { } } + Stmt VisitStmt_(const LetStmtNode* op) final { + if (const CallNode* call = op->value.as()) { + if (call->op.same_as(builtin::texture2d_alloca())) { + return StmtExprMutator::VisitStmt(MakeTextureAlloc(op, call)); + } + } + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const AllocateNode* op) { // Lower allocate to device allocate when needed. Stmt stmt = StmtExprMutator::VisitStmt_(op); @@ -341,6 +350,38 @@ class BuiltinLower : public StmtExprMutator { return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args); } + Stmt MakeTextureAlloc(const LetStmtNode* let, const CallNode* call) { + ICHECK(device_type_.defined()) << "Unknown device type in current IR"; + ICHECK(device_id_.defined()) << "Unknown device id in current IR"; + Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); + + Stmt body = SeqStmt( + {IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error), + let->body}); + DataType dtype = + let->var->type_annotation.as()->element_type.as()->dtype; + + std::string fdevapi_prefix = "device_api."; + fdevapi_prefix += runtime::DeviceName(device_type_.as()->value); + Call call_packed = + Call(let->var.dtype(), builtin::tvm_call_packed(), + {StringImm(fdevapi_prefix + ".AllocTexture"), cast(DataType::Int(32), device_type_), + cast(DataType::Int(32), device_id_), cast(DataType::UInt(64), call->args[0]), + cast(DataType::UInt(64), call->args[1]), IntImm(DataType::Int(32), dtype.code()), + IntImm(DataType::Int(32), dtype.bits())}); + + Stmt alloca = LetStmt(let->var, call_packed, body); + + Call free_op = + Call(DataType::Int(32), builtin::tvm_call_packed(), + {StringImm(fdevapi_prefix + ".FreeTexture"), cast(DataType::Int(32), device_type_), + cast(DataType::Int(32), device_id_), let->var}); + + Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); + body = SeqStmt({alloca, free_stmt}); + return body; + } + private: bool IsArrayHandle(const PrimExpr& arg) { // specially set array handle. diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index b95681a936ca..30ec148c37dd 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -40,6 +40,8 @@ #include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" +#include "update_pointer_storage_scope.h" namespace tvm { namespace tir { @@ -239,7 +241,8 @@ class WarpAccessRewriter : protected StmtExprMutator { if (op->buffer_var.get() == buffer_) { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); - return Store(op->buffer_var, op->value, local_index, op->predicate); + PrimExpr new_value = VisitExpr(op->value); + return Store(op->buffer_var, new_value, local_index, op->predicate); } else { return StmtExprMutator::VisitStmt_(op); } @@ -250,10 +253,13 @@ class WarpAccessRewriter : protected StmtExprMutator { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); // invariance: local index must do not contain warp id - ICHECK(!ExprUseVar(local_index, warp_index_)) + ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); + if (analyzer_->CanProveEqual(group, warp_index_)) { + return load_value; + } PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}); return Call(load_value.dtype(), builtin::tvm_warp_shuffle(), {mask, load_value, group, width_, warp_size_}); @@ -356,34 +362,21 @@ class WarpMemoryRewriter : private StmtMutator { return stmt; } + std::unordered_map new_storage_scopes_; + private: Stmt VisitStmt_(const AllocateNode* op) { auto ret = StmtMutator::VisitStmt_(op); op = ret.as(); - if (warp_buffer_.count(op->buffer_var.get())) { + if (GetPtrStorageScope(op->buffer_var) == "warp") { + new_storage_scopes_[op->buffer_var.get()] = "local"; WarpAccessRewriter rewriter(warp_size_, &analyzer_); ret = rewriter.Rewrite(op); } return ret; } - Stmt VisitStmt_(const AttrStmtNode* op) { - using runtime::StorageScope; - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::Create(op->value.as()->value); - if (scope.rank == runtime::StorageRank::kWarp) { - warp_buffer_.insert(buf); - Stmt ret = StmtMutator::VisitStmt_(op); - op = ret.as(); - return AttrStmt(op->node, op->attr_key, StringImm("local"), op->body); - } - } - return StmtMutator::VisitStmt_(op); - } - int warp_size_{0}; - std::unordered_set warp_buffer_; arith::Analyzer analyzer_; // variable domain std::unordered_map var_dom_; @@ -397,7 +390,9 @@ Pass LowerWarpMemory() { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); - n->body = WarpMemoryRewriter(warp_size).Rewrite(std::move(n->body)); + WarpMemoryRewriter warp_memory_rewriter(warp_size); + auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); + n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index ee52a6fc0988..393ce6c286b4 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -119,7 +119,12 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); ICHECK_LE(num_unpacked_args, num_args); - + bool pack_args = (num_unpacked_args == -1) || (num_args > num_unpacked_args); + if (num_unpacked_args == -1) { + // reset to zero + num_unpacked_args = 0; + } + ICHECK_GE(num_unpacked_args, 0); int num_packed_args = num_args - num_unpacked_args; // Data field definitions // The packed fields @@ -154,11 +159,10 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { } return res; }; - // --------------------------- // start of logics // add signiture for packed arguments. - if (num_packed_args != 0) { + if (pack_args) { args.push_back(v_packed_args); args.push_back(v_packed_arg_type_ids); args.push_back(v_num_packed_args); @@ -214,13 +218,13 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { } // allow return value if the function is packed. - if (num_packed_args != 0) { + if (pack_args) { args.push_back(v_out_ret_value); args.push_back(v_out_ret_tcode); args.push_back(v_resource_handle); } - size_t expected_nargs = num_unpacked_args + (num_packed_args != 0 ? 6 : 0); + size_t expected_nargs = num_unpacked_args + (pack_args ? 6 : 0); ICHECK_EQ(args.size(), expected_nargs); // Arg definitions are defined before buffer binding to avoid the use before @@ -282,6 +286,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { namespace transform { Pass MakePackedAPI(int num_unpacked_args) { + // packed arguments anyway while `num_unpacked_args` is -1 auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) { IRModuleNode* mptr = m.CopyOnWrite(); std::vector > updates; diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc new file mode 100644 index 000000000000..e8865b260dc1 --- /dev/null +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file merge_dynamic_shared_memory_allocations.cc + * \brief Each GPU kernel is allowed to have only one dynamic shared memory allocation. + * This pass merges multiple TIR-level dynamic shared memory allocations into one allocation. + */ +#include +#include +#include +#include + +#include +#include + +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +bool IsDynamicSharedMemory(Var buffer_var) { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var)); + return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn"; +} + +class AllocateCollector : public StmtExprVisitor { + public: + void VisitStmt_(const AllocateNode* op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + dyn_shmem_allocs_.insert(op); + } + StmtExprVisitor::VisitStmt_(op); + } + + std::unordered_set dyn_shmem_allocs_; +}; + +class DynamicSharedMemoryRewriter : public StmtExprMutator { + public: + explicit DynamicSharedMemoryRewriter( + const std::unordered_set& dyn_shmem_allocs) + : dyn_shmem_allocs_{dyn_shmem_allocs} {} + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::thread_extent && !allocated) { + // Allocate one dynamic shared memory allocation at the beginning of thread scope + int align = 1; + for (const auto& alloc : dyn_shmem_allocs_) { + ICHECK_EQ(alloc->dtype.lanes(), 1) << "vector dtype allocation not supported."; + align = std::max(align, alloc->dtype.bytes()); + } + for (const auto& alloc : dyn_shmem_allocs_) { + ICHECK_EQ(alloc->extents.size(), 1); + buffer_byte_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; + merged_alloc_size_ += alloc->extents[0] * align; + } + + allocated = true; + auto new_body = Allocate(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, + const_true(), StmtExprMutator::VisitStmt(op->body)); + return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span); + } + return StmtMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const AllocateNode* op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + return StmtExprMutator::VisitStmt(op->body); + } + return StmtExprMutator::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + auto offset = GetBufferOffset(op->buffer_var, op->dtype); + auto index = StmtExprMutator::VisitExpr(op->index); + return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span); + } + return StmtExprMutator::VisitExpr_(op); + } + + Stmt VisitStmt_(const StoreNode* op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + auto offset = GetBufferOffset(op->buffer_var, op->value->dtype); + auto index = StmtExprMutator::VisitExpr(op->index); + auto value = StmtExprMutator::VisitExpr(op->value); + return Store(merged_buf_var_, value, offset + index, op->predicate, op->span); + } + return StmtExprMutator::VisitStmt_(op); + } + + private: + PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) { + auto it = buffer_byte_offsets_.find(buffer_var.get()); + ICHECK(it != buffer_byte_offsets_.end()); + return indexdiv(it->second, dtype.bytes()); + } + + Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")}; + std::unordered_set dyn_shmem_allocs_; + PrimExpr merged_alloc_size_{0}; + std::unordered_map buffer_byte_offsets_; + bool allocated{false}; +}; + +Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) { + AllocateCollector collector; + collector(stmt); + if (collector.dyn_shmem_allocs_.size() > 1) { + return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_)(std::move(stmt)); + } + return stmt; +} + +namespace transform { + +Pass MergeDynamicSharedMemoryAllocations() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.MergeDynamicSharedMemoryAllocations") + .set_body_typed(MergeDynamicSharedMemoryAllocations); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 949c955b2dfe..59f9170786b6 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -26,6 +26,8 @@ #include #include +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -73,8 +75,6 @@ class BufferAllocationLocator : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* op) final { ICHECK(!op->init.defined()); - bool is_root = is_root_; - is_root_ = false; Array alloc_buffers; auto it = alloc_buffers_.find(op); if (it != alloc_buffers_.end()) { @@ -83,11 +83,23 @@ class BufferAllocationLocator : public StmtExprMutator { buffer_data_to_buffer_.Set(buf->data, buf); } } + for (const MatchBufferRegion match_buffer : op->match_buffers) { + const Var& target_var = match_buffer->buffer->data; + const Var& source_var = match_buffer->source->buffer->data; + ICHECK(buffer_data_to_buffer_.count(source_var)); + buffer_data_to_buffer_.Set(target_var, match_buffer->buffer); + } Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); ICHECK(op != nullptr); - // Ignore buffer allocated inside the block when getting access region. + // No longer consider buffers created by match_buffer inside the block when updating access + // region. + for (const MatchBufferRegion match_buffer : op->match_buffers) { + const Var& target_var = match_buffer->buffer->data; + buffer_data_to_buffer_.erase(target_var); + } + // No longer consider buffers allocated inside the block when updating access region. if (it != alloc_buffers_.end()) { for (const Buffer& buf : it->second) { buffer_data_to_buffer_.erase(buf->data); @@ -96,12 +108,9 @@ class BufferAllocationLocator : public StmtExprMutator { ObjectPtr n = CopyOnWrite(op); n->alloc_buffers = std::move(alloc_buffers); - // The read/write regions of root block are always empty. - if (!is_root) { - // Recalculate block access region - CollectReadWrite(GetRef(op), &n->reads, &n->writes); - } - + // Erase buffer allocated inside the block from access region. + n->reads = RemoveRedundantBufferRegion(n->reads); + n->writes = RemoveRedundantBufferRegion(n->writes); return Stmt(n); } @@ -125,8 +134,18 @@ class BufferAllocationLocator : public StmtExprMutator { return std::move(realize); } + Array RemoveRedundantBufferRegion(const Array& region) const { + Array result; + for (const BufferRegion& buffer_region : region) { + if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) { + result.push_back(buffer_region); + } + } + return result; + } + void CollectReadWrite(const Block& block, Array* reads, - Array* writes) { + Array* writes) const { Array> access = GetBlockAccessRegion(block, buffer_data_to_buffer_); *reads = access[0]; *writes = access[1]; @@ -140,15 +159,18 @@ class BufferAllocationLocator : public StmtExprMutator { std::unordered_map> alloc_buffers_; /*! \brief The buffer already allocated during recursive visiting. */ Map buffer_data_to_buffer_; - /*! \brief indicate the whether the block is root. */ - bool is_root_{true}; }; PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) { - auto fptr = func.CopyOnWrite(); - BufferAllocationLocator locator(func); - fptr->body = locator(fptr->body); - return func; + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(func)) { + auto fptr = func.CopyOnWrite(); + BufferAllocationLocator locator(func); + fptr->body = locator(fptr->body); + return func; + } else { + return func; + } } namespace transform { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index f01d98707586..795ae9d6a73a 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -33,6 +33,9 @@ #include +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -89,6 +92,17 @@ class VarUseDefAnalysis : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { this->HandleDef(op->buffer_var.get()); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; + ICHECK_GT(op->extents.size(), 0); + dyn_shmem_size_ = op->extents[0]; + for (size_t i = 1; i < op->extents.size(); ++i) { + dyn_shmem_size_ *= op->extents[i]; + } + dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); + use_dyn_shmem_ = true; + } return StmtExprMutator::VisitStmt_(op); } @@ -175,6 +189,8 @@ class VarUseDefAnalysis : public StmtExprMutator { Array undefined_; Array thread_axis_; Array thread_extent_; + PrimExpr dyn_shmem_size_{0}; + bool use_dyn_shmem_{false}; std::unordered_map use_count_; std::unordered_map def_count_; @@ -262,6 +278,10 @@ class HostDeviceSplitter : public StmtMutator { WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol)); device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); + if (m.use_dyn_shmem_) { + device_func = + WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); + } (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func); // generate calls to the device function @@ -273,6 +293,9 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } + if (m.use_dyn_shmem_) { + call_args.push_back(m.dyn_shmem_size_); + } return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); } diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 00002d3587db..0567c8613fcd 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -35,7 +35,7 @@ namespace tir { void StorageAccessVisitor::VisitExpr_(const LoadNode* op) { const VarNode* buf = op->buffer_var.as(); - StorageScope scope = GetScope(buf); + StorageScope scope = GetScope(op->buffer_var); if (Enabled(buf, scope)) { ICHECK(allow_append_) << op << " " << scope.to_string(); AccessEntry e; @@ -56,7 +56,7 @@ void StorageAccessVisitor::VisitStmt_(const StoreNode* op) { ICHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; const VarNode* buf = op->buffer_var.as(); - StorageScope scope = GetScope(buf); + StorageScope scope = GetScope(op->buffer_var); if (Enabled(buf, scope)) { AccessEntry e; e.threads = env_threads(); @@ -90,11 +90,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { } void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - storage_scope_[buf] = StorageScope::Create(op->value.as()->value); - StmtExprVisitor::VisitStmt_(op); - } else if (op->attr_key == attr::double_buffer_write) { + if (op->attr_key == attr::double_buffer_write) { ICHECK(double_buffer_write_ == nullptr); double_buffer_write_ = op->node.as(); scope_.push_back(std::vector()); @@ -176,6 +172,7 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { scope_.pop_back(); if (op->else_case.defined()) { scope_.push_back(std::vector()); + this->VisitStmt(op->else_case); auto v = Summarize(std::move(scope_.back()), nullptr); scope_.pop_back(); s.access.insert(s.access.end(), v.begin(), v.end()); @@ -208,7 +205,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { PrimExpr offset = op->args[2]; PrimExpr extent = op->args[3]; const IntImmNode* flag = op->args[4].as(); - StorageScope scope = GetScope(buffer); + StorageScope scope = GetScope(GetRef(buffer)); // The buffer scope. if (Enabled(buffer, scope)) { ICHECK(allow_append_); @@ -244,12 +241,11 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { } } -StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const { - auto it = storage_scope_.find(buf); - StorageScope s; - s.rank = StorageRank::kGlobal; - if (it == storage_scope_.end()) return s; - return it->second; +StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const { + if (buffer_var->type_annotation.as()) { + return StorageScope::Create(GetPtrStorageScope(buffer_var)); + } + return StorageScope(); // global by default } } // namespace tir diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index 663c570fd15c..9dc4c923b054 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -118,7 +118,7 @@ class StorageAccessVisitor : public StmtExprVisitor { * \brief Get the scope of the buffer array. * \return The scope of the final buffer array. */ - StorageScope GetScope(const VarNode* buf) const; + StorageScope GetScope(Var buffer_var) const; // access scope std::vector > scope_; @@ -135,8 +135,6 @@ class StorageAccessVisitor : public StmtExprVisitor { StmtEntry curr_stmt_; // The involving threads Array env_threads_; - // The storage scope of each buffer - std::unordered_map storage_scope_; }; } // namespace tir diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 43fc1f1ec53f..2c32cc7f0883 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -78,11 +78,7 @@ class StorageFlattener : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::realize_scope) { - storage_scope_[op->node.get()] = op->value.as()->value; - return this->VisitStmt(op->body); - } else if (op->attr_key == attr::double_buffer_scope && - op->node->IsInstance()) { + if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); auto it = buf_map_.find(buffer); @@ -156,10 +152,8 @@ class StorageFlattener : public StmtExprMutator { shape.push_back(r->extent); } // deduce current storage scope. - auto it = storage_scope_.find(op->buffer.get()); - ICHECK(it != storage_scope_.end()) << "Cannot find storage scope of " << op->buffer; StorageScope skey; - const std::string& strkey = it->second; + std::string strkey = GetPtrStorageScope(op->buffer->data); if (strkey.length() == 0) { if (curr_thread_scope_.size() != 0) { skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); @@ -167,7 +161,6 @@ class StorageFlattener : public StmtExprMutator { } else { skey = StorageScope::Create(strkey); } - // use small alignment for small arrays auto dtype = op->buffer->dtype; int32_t const_size = AllocateNode::constant_allocation_size(shape); @@ -200,9 +193,12 @@ class StorageFlattener : public StmtExprMutator { strides = Array(rstrides.rbegin(), rstrides.rend()); } - e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation), - op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, - skey.to_string(), align, 0, kDefault); + auto* ptr_type = op->buffer->data->type_annotation.as(); + ICHECK(ptr_type); + auto new_var = + Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string())); + e.buffer = Buffer(new_var, op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, + align, 0, kDefault); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); @@ -228,7 +224,6 @@ class StorageFlattener : public StmtExprMutator { ret = Allocate(e.buffer->data, storage_type, shape, make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } - ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, @@ -491,8 +486,6 @@ class StorageFlattener : public StmtExprMutator { std::unordered_map buf_map_; // Dimension alignment std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dim_align_; - // Storage scope - std::unordered_map storage_scope_; // The current thread scope. std::vector curr_thread_scope_; // Collects shapes. @@ -507,13 +500,19 @@ class StorageFlattener : public StmtExprMutator { }; PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { - auto fptr = func.CopyOnWrite(); - - IRVisitorWithAnalyzer bound_analyzer; - bound_analyzer(fptr->body); - fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, - &bound_analyzer)(std::move(fptr->body)); - return func; + // Only apply this pass to TIR from TE schedules + Optional from_legacy_te_schedule = func->GetAttr("from_legacy_te_schedule", Bool(false)); + if (from_legacy_te_schedule.value()) { + auto fptr = func.CopyOnWrite(); + + IRVisitorWithAnalyzer bound_analyzer; + bound_analyzer(fptr->body); + fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, + &bound_analyzer)(std::move(fptr->body)); + return func; + } else { + return func; + } } namespace transform { diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index c755576e2b88..592a6a33375e 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -37,6 +37,7 @@ #include #include "../../runtime/thread_storage_scope.h" +#include "../ir/buffer_common.h" #include "ir_utils.h" namespace tvm { @@ -75,8 +76,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { }; // The scope of each allocation struct AllocEntry { - // Scope used for allocation. - StorageScope storage_scope; // scope level size_t level{0}; // allocation stmt @@ -86,13 +85,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { void VisitStmt_(const AllocateNode* op) final { size_t level = scope_.size(); const VarNode* buf = op->buffer_var.get(); - auto it = alloc_info_.find(buf); - ICHECK(it != alloc_info_.end()) << "Could not find buffer `" << buf->name_hint - << "` in the list of allocated buffers. Perhaps you are " - "missing a storage_scope attr for this buffer."; - ICHECK(it->second.alloc == nullptr); - it->second.alloc = op; - it->second.level = level; + alloc_info_[buf].alloc = op; + alloc_info_[buf].level = level; StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const StoreNode* op) final { @@ -180,10 +174,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { VisitNewScope(op); } else if (op->attr_key == attr::virtual_thread) { VisitNewScope(op); - } else if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - alloc_info_[buf].storage_scope = StorageScope::Create(op->value.as()->value); - StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); } @@ -337,7 +327,14 @@ class InplaceOpVerifier : public StmtExprVisitor { const StoreNode* store_{nullptr}; }; -// Planner to plan and rewrite memory allocation. +/* \brief Rewrite and merge memory allocation. + * + * Using LinearAccessPatternFinder, determines which buffers could share an + * allocation. This includes both sequential usage of the same buffer and + * merging small allocations at the same scope into a single larger allocation. + * The merging of small allocations requires the codegen to cast the resulting + * value from the storage type to the output type after access. + */ class StoragePlanRewriter : public StmtExprMutator { public: using StmtEntry = LinearAccessPatternFinder::StmtEntry; @@ -409,10 +406,8 @@ class StoragePlanRewriter : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - return this->VisitStmt(op->body); - } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || - attr::IsPragmaKey(op->attr_key)) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || + attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; @@ -496,8 +491,6 @@ class StoragePlanRewriter : public StmtExprMutator { std::vector nest; for (StorageEntry* e : svec) { if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope, - StringImm(e->scope.to_string()), Evaluate(0))); nest.push_back(e->new_alloc); } } @@ -523,7 +516,7 @@ class StoragePlanRewriter : public StmtExprMutator { // try to find merge, for tagged memory for (size_t i = 0; i < vec.size(); ++i) { StorageEntry* e = vec[i]; - if (e->scope.tag.length() != 0) { + if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { ICHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size"; for (size_t j = 0; j < i; ++j) { if (e->scope == vec[j]->scope) { @@ -557,7 +550,7 @@ class StoragePlanRewriter : public StmtExprMutator { make_const(DataType::Int(32), 1), e->allocs[0]->extents); e->new_alloc = Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0)); - if (e->scope.tag.length() != 0) { + if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) @@ -598,7 +591,7 @@ class StoragePlanRewriter : public StmtExprMutator { combo_size = analyzer_.Simplify(combo_size); e->new_alloc = Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0)); - if (e->scope.tag.length() != 0) { + if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) @@ -716,7 +709,8 @@ class StoragePlanRewriter : public StmtExprMutator { for (const VarNode* var : it->second.gen) { ICHECK(alloc_info.count(var)); - const AllocEntry& ae = alloc_info.at(var); + const AllocateNode* alloc = alloc_info.at(var).alloc; + auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef(var))); StorageEntry* dst_entry = nullptr; // inplace detection if (detect_inplace) { @@ -726,13 +720,12 @@ class StoragePlanRewriter : public StmtExprMutator { if (!inplace_flag.count(src) && alloc_map_.count(src)) { InplaceOpVerifier visitor; StorageEntry* src_entry = alloc_map_.at(src); - if (src_entry->scope == ae.storage_scope && + if (src_entry->scope == storage_scope && src_entry->attach_scope_ == thread_scope_ && - src_entry->elem_type == ae.alloc->dtype.element_of() && + src_entry->elem_type == alloc->dtype.element_of() && visitor.Check(s.stmt, var, src)) { - uint64_t const_nbits = - static_cast(ae.alloc->constant_allocation_size()) * - ae.alloc->dtype.bits() * ae.alloc->dtype.lanes(); + uint64_t const_nbits = static_cast(alloc->constant_allocation_size()) * + alloc->dtype.bits() * alloc->dtype.lanes(); if (src_entry->const_nbits == const_nbits && !inplace_found) { // successfully inplace dst_entry = src_entry; @@ -744,9 +737,9 @@ class StoragePlanRewriter : public StmtExprMutator { } } if (dst_entry == nullptr) { - dst_entry = FindAlloc(ae.alloc, thread_scope_, ae.storage_scope); + dst_entry = FindAlloc(alloc, thread_scope_, storage_scope); } - dst_entry->allocs.emplace_back(ae.alloc); + dst_entry->allocs.emplace_back(alloc); alloc_map_[var] = dst_entry; } } @@ -896,107 +889,547 @@ class StoragePlanRewriter : public StmtExprMutator { arith::Analyzer analyzer_; }; -// Turn alloc into vector alloc -// if all its access is the same vector type. -class VectorAllocRewriter : public StmtExprMutator { +/* Helper struct containing information on how a buffer is declared and used + * + */ +struct BufferVarInfo { + enum DeclarationLocation { + kPrimFuncParam = (1 << 0), + kPrimFuncBufferMap = (1 << 1), + kAllocateNode = (1 << 2), + kLetNode = (1 << 3), + }; + + // The tir::Var that represents this buffer. + Var var; + + // The data type of an element of the buffer. + DataType element_dtype; + + /* The extent of the buffer. + * + * If multidimensional, the extent of the last dimension of the buffer. If the + * size is unknown (e.g. pointer arguments to PrimFunc with no corresponding + * entry in buffer_map), then extent is zero. + */ + PrimExpr extent; + + // Where the buffer was declared + DeclarationLocation declaration_location; + + // When accessed, which element type is it accessed as. This may + // differ both in base type (e.g. int32* cast to float32* after + // packing in StorageRewrite) or in number of lanes (e.g. float16* + // cast to float16x4*). + std::unordered_set access_dtype; + + DataType get_preferred_dtype() const { + std::unordered_set base_access_dtype; + for (auto dtype : access_dtype) { + base_access_dtype.insert(dtype.element_of()); + } + // If the array is accessed as multiple base types within a + // function, no point in changing the declared type. CodeGenC can + // handle this with a type-cast prior to indexing. Vulkan will + // raise an error at code-gen time, if a later pass doesn't split + // it out. + if (base_access_dtype.size() != 1) { + return element_dtype; + } + + DataType preferred_base_type = *base_access_dtype.begin(); + + // If there is only one vectorizable size used to access the + // buffer, and if that access size is compatible with the array + // size, then the buffer is vectorizable. In the future, this + // could be improved to allow vectorized buffer access of size + // GCD(*lanes_used), if necessary. + int preferred_lanes = element_dtype.lanes(); + if ((element_dtype.lanes() == 1) && (access_dtype.size() == 1)) { + arith::Analyzer analyzer_; + arith::ModularSet me = analyzer_.modular_set(extent); + + int lanes = access_dtype.begin()->lanes(); + if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { + preferred_lanes = lanes; + } + } + + return preferred_base_type.with_lanes(preferred_lanes); + } +}; + +/* Checks whether buffers are accessed as scalar or vector parameters in a + * function. + * + */ +class VectorTypeAccessChecker : public StmtExprVisitor { public: - PrimExpr VisitExpr_(const LoadNode* op) final { - UpdateTypeMap(op->buffer_var.get(), op->dtype); - return StmtExprMutator::VisitExpr_(op); + /* Constructor + * + * @param params The parameters passed to a PrimFunc + * + * @param buffer_map The buffer_map associated with a PrimFunc + * + * @param allow_untyped_handles If a buffer or pointer variable is + * missing a type annotation, assume that it has the same underlying + * type as it is later accessed, with scalar element types. + */ + VectorTypeAccessChecker(const Array& params, const Map& buffer_map, + bool allow_untyped_pointers = false) + : allow_untyped_pointers_(allow_untyped_pointers) { + // If a parameter is in the buffer map, we want to track the + // version in the map. + for (auto it : buffer_map) { + Buffer& buffer = it.second; + Var buffer_var = buffer->data; + DataType dtype = buffer->dtype; + PrimExpr extent = buffer->shape.size() ? buffer->shape[buffer->shape.size() - 1] : 0; + OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncParam); + } + + // If a pointer parameter isn't in the buffer map, then we want to + // track the parameter itself. + for (Var buffer_var : params) { + auto pointer_type = GetPointerType(buffer_var->type_annotation); + if (pointer_type.first && (buffer_map.count(buffer_var) == 0)) { + DataType dtype = pointer_type.second; + PrimExpr extent = 0; + OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncBufferMap); + } + } } - Stmt VisitStmt_(const StoreNode* op) final { - UpdateTypeMap(op->buffer_var.get(), op->value.dtype()); - return StmtExprMutator::VisitStmt_(op); + void VisitExpr_(const LoadNode* op) final { + OnArrayAccess(op->dtype, op->buffer_var.get(), op->index, op->predicate); + StmtExprVisitor::VisitExpr_(op); } - PrimExpr VisitExpr_(const CallNode* op) final { + + void VisitStmt_(const StoreNode* op) final { + OnArrayAccess(op->value.dtype(), op->buffer_var.get(), op->index, op->predicate); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); - UpdateTypeMap(buffer, dtype); + PrimExpr index = op->args[2]; + OnArrayAccess(dtype, buffer, index, const_true(dtype.lanes())); } - return StmtExprMutator::VisitExpr_(op); + StmtExprVisitor::VisitExpr_(op); } - Stmt VisitStmt_(const AllocateNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - const auto& tvec = acc_map_[op->buffer_var.get()]; - - if (tvec.size() == 1 && tvec[0].element_of() == op->dtype.element_of() && - tvec[0].lanes() % op->dtype.lanes() == 0 && tvec[0].lanes() != op->dtype.lanes()) { - int factor = tvec[0].lanes() / op->dtype.lanes(); - Array extents = op->extents; - arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]); - if (me->base % factor == 0 && me->coeff % factor == 0) { - extents.Set(extents.size() - 1, - extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); - // create a new buffer var - DataType new_dtype = tvec[0]; - Var new_buffer_var(op->buffer_var->name_hint, PointerType(PrimType(new_dtype))); - // update the remap req. - var_remap_.Set(op->buffer_var, new_buffer_var); - return Allocate(new_buffer_var, new_dtype, extents, op->condition, op->body); + void VisitStmt_(const AllocateNode* op) final { + const Array& extents = op->extents; + PrimExpr extent = extents[extents.size() - 1]; + OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateNode); + + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const LetNode* op) final { + HandleLetNode(op->var); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const LetStmtNode* op) final { + HandleLetNode(op->var); + StmtExprVisitor::VisitStmt_(op); + } + + void HandleLetNode(Var let_var) { + if (let_var->dtype.is_handle()) { + auto pointer_type = GetPointerType(let_var->type_annotation); + if (pointer_type.first) { + OnArrayDeclaration(let_var, pointer_type.second, 0, BufferVarInfo::kLetNode); + } else if (allow_untyped_pointers_) { + OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode); + } else { + LOG(FATAL) << "Let statement of variable " << let_var->name_hint + << " is missing a type annotation, " + << "or type annotation is not a pointer to primitive"; } } - return stmt; } - void UpdateTypeMap(const VarNode* buffer, DataType t) { - auto& tvec = acc_map_[buffer]; - if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) { - tvec.push_back(t); + /* Update the type map for a buffer based on its declaration + * + * @param buffer The VarNode representing the buffer. + * + * @param element_dtype The dtype of a single element of the buffer. + * If unknown, when used with the allow_untyped_handles option, + * should be a handle dtype. + * + * @param extent The extent of the buffer. Zero if size is unknown. + * + * @param declaration_location How the buffer was allocated, so that + * some locations can be rewritten without others. + */ + void OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent, + BufferVarInfo::DeclarationLocation declaration_location) { + ICHECK(info_map_.find(buffer.get()) == info_map_.end()) + << "Array declaration of " << buffer->name_hint << " occurred multiple times."; + + if (element_dtype == DataType::Bool()) { + element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes()); + } + + info_map_[buffer.get()] = {buffer, element_dtype, extent, declaration_location}; + } + + /* Update the type map for a buffer based on its usage + * + * @param value_dtype The dtype of the value being stored to or + * loaded from the buffer. + * + * @param buffer The VarNode representing the buffer. + * + * @param index The index at which the value is being stored/loaded. + * + * @param predicate The predicate used for the store/load. + */ + void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const PrimExpr& index, + const PrimExpr& predicate) { + auto it = info_map_.find(buffer); + ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer + << ") occurred before its declaration."; + BufferVarInfo& var_info = it->second; + + if (value_dtype.element_of() == DataType::Bool()) { + value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes()); + } + + if (var_info.element_dtype.is_handle()) { + ICHECK(allow_untyped_pointers_) << "Variable " << buffer->name_hint + << " was missing a type annotation in its declaration"; + var_info.element_dtype = value_dtype.element_of(); + } + + DataType access_dtype = value_dtype; + + int lanes_used = var_info.element_dtype.lanes(); + + // This can happen due to a previous pass that had rewrite_store_load = + // false. This occurs from the StorageRewrite in tvm::lower, followed by the + // PointerValueTypeRewrite in BuildSPIRV. The rewrite_store_load = false is + // necessary because the C-based codegens do not yet support vectorized + // pointer types (e.g. float16x4*). Once they do, this if statement should + // instead be replaced by the below ICHECK_EQ. + if (index.dtype().lanes() * var_info.element_dtype.lanes() != value_dtype.lanes()) { + ICHECK_EQ(index.dtype().lanes(), value_dtype.lanes()); + lanes_used = 1; + var_info.element_dtype = var_info.element_dtype.with_lanes(1); + } + + // TODO(Lunderberg): Uncomment this check once it can be applied. + // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 + // for discussion. + + // ICHECK_EQ(index.dtype().lanes() * var_info.element_dtype.lanes(), value_dtype.lanes()) + // << "Attempting to retrieve " << value_dtype.lanes() << " lanes of data with " + // << index.dtype().lanes() << " indices into an array whose elements have " + // << var_info.element_dtype.lanes() << " lanes. " + // << "Expected output with " << index.dtype().lanes() * var_info.element_dtype.lanes() + // << " lanes."; + + // If the index is a RampNode with stride of 1 and offset + // divisible by the number of number of lanes, and the predicate + // does not apply any masking, then this array access could be + // vectorized. + const RampNode* ramp_index = index.as(); + if (ramp_index && is_one(ramp_index->stride) && is_one(predicate)) { + arith::ModularSet me = analyzer_.modular_set(ramp_index->base); + if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) { + lanes_used = ramp_index->lanes; + } } + + var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used)); } - // Internal access map - std::unordered_map > acc_map_; - // Variables to remap - Map var_remap_; + // Map of buffer variable information determined + std::unordered_map info_map_; + + // + bool allow_untyped_pointers_{false}; + // internal analyzer arith::Analyzer analyzer_; }; -PrimFunc PointerValueTypeRewrite(PrimFunc f) { - auto* n = f.CopyOnWrite(); - VectorAllocRewriter rewriter; - n->body = rewriter(std::move(n->body)); +/* \brief Rewrites buffer/pointer variables from scalar types to vectorized + * types. + * + * Some runtimes do not allow casting between composite types and the underlying + * base type (e.g. Vulkan, casting from 1-lane float16* to 4-lane float16x4*). + * In these cases, in order to have vectorized load/store on an array, the + * element type of that array must be vectorized. This is in contrast to C-style + * runtimes, in which `float16x4* vec = *(float16x4*)(float_arr + offset)` is + * valid. + * + * By default, VectorTypeRewriter will attempt to rewrite all buffer variables to + * vectorized access, if the load/store occurring in the PrimFunc are all + * vectorized. This includes adjusting the indices being used to access the + * array. (e.g. If `float16* scalar_arr` is being converted to `float16x4* + * vec_arr`, then `scalar_arr[Ramp(offset, 1, 4)]` will be converted to + * `vec_arr[offset/4]`.) + * + * Currently, several of the C-style runtimes do not support buffers whose + * elements are vectorized types, or rely on the presence of the Ramp nodes to + * identify vectorized loads. The boolean parameters in the constructor are to + * mimic the previous behavior of VectorTypeRewriter, to avoid breaking these + * runtimes. Once all runtimes support vectorized buffer elements, these + * parameters can be removed. + */ +class VectorTypeRewriter : public StmtExprMutator { + public: + /* Constructor + * + * @param checker The VectorTypeAccessChecker that has previously read out + * information from the PrimFunc + * + * @param rewrite_params Whether pointer-type parameters passed into the + * function should be rewritten from scalar types to vectorized types. + * + * @param rewrite_buffer_map Whether buffers present in the buffer_map should + * have their data variable be rewritten from scalar types to vectorized types. + * + * @param rewrite_allocate_node Whether the buffer variable associated with + * AllocateNodes should be rewritten from scalar types to vectorized types. + * + * @param rewrite_indices Whether the indices to the Load and Store nodes + * should be rewritten to correspond to the new buffer_var type. + * + * @param rewrite_let_node Whether pointer declarations in let nodes + * should be re-written. + */ + VectorTypeRewriter(const std::unordered_map& info_map, + bool rewrite_params = true, bool rewrite_buffer_map = true, + bool rewrite_allocate_node = true, bool rewrite_indices = true, + bool rewrite_let_node = true) + : rewrite_indices_(rewrite_indices) { + int rewrite_mask = 0; + if (rewrite_params) { + rewrite_mask |= BufferVarInfo::kPrimFuncParam; + } + if (rewrite_buffer_map) { + rewrite_mask |= BufferVarInfo::kPrimFuncBufferMap; + } + if (rewrite_allocate_node) { + rewrite_mask |= BufferVarInfo::kAllocateNode; + } + if (rewrite_let_node) { + rewrite_mask |= BufferVarInfo::kLetNode; + } + + // Rewrite any buffer variables whose preferred type isn't their current type. + for (const auto& pair : info_map) { + const auto& var_info = pair.second; + DataType preferred = var_info.get_preferred_dtype(); + if (preferred != var_info.element_dtype && (rewrite_mask & var_info.declaration_location)) { + Var old_buffer_var = var_info.var; + Var new_buffer_var(old_buffer_var->name_hint, + PointerType(PrimType(preferred), GetPtrStorageScope(old_buffer_var)), + old_buffer_var->span); + + rewrite_map_[var_info.var.get()] = {var_info.var, new_buffer_var, var_info.element_dtype, + preferred}; + } + } + } - Map var_remap = std::move(rewriter.var_remap_); - Array args; + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); - // rewrite paramters if needed. - for (Var var : f->params) { - if (var.dtype().is_handle()) { - const auto& tvec = rewriter.acc_map_[var.get()]; + if (!rewrite_indices_) { + return expr; + } - if (tvec.size() == 1) { - tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0]))); - args.push_back(new_var); - var_remap.Set(var, new_var); - } else { - // always set data type to be non vectorized so - // load/store can still work via scalarization - if (tvec.size() != 0 && !var->type_annotation.defined()) { - tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0].with_lanes(1)))); - args.push_back(new_var); - var_remap.Set(var, new_var); - } else { - args.push_back(var); - } + auto it = rewrite_map_.find(op->buffer_var.get()); + if (it == rewrite_map_.end()) { + return expr; + } + const auto& info = it->second; + + DataType out_dtype_base = info.new_element_dtype.element_of(); + + const RampNode* ramp_index = op->index.as(); + if (ramp_index && is_one(ramp_index->stride)) { + PrimExpr new_index = + ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); + return Load(out_dtype_base.with_lanes(op->dtype.lanes()), info.new_buffer_var, new_index, + const_true(new_index.dtype().lanes()), op->span); + } else { + return Load(out_dtype_base, info.new_buffer_var, op->index, op->predicate); + } + } + + Stmt VisitStmt_(const StoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + if (!rewrite_indices_) { + return stmt; + } + + auto it = rewrite_map_.find(op->buffer_var.get()); + if (it == rewrite_map_.end()) { + return stmt; + } + const auto& info = it->second; + + const RampNode* ramp_index = op->index.as(); + if (ramp_index && is_one(ramp_index->stride)) { + PrimExpr new_index = + ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); + return Store(info.new_buffer_var, op->value, new_index, const_true(new_index.dtype().lanes()), + op->span); + } else { + return Store(info.new_buffer_var, op->value, op->index, op->predicate, op->span); + } + } + + PrimExpr VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + + if (!rewrite_indices_) { + return expr; + } + + const VarNode* buffer_var = op->args[1].as(); + auto it = rewrite_map_.find(buffer_var); + if (it == rewrite_map_.end()) { + return expr; } + const auto& info = it->second; + + PrimExpr index = op->args[2]; + PrimExpr extent = op->args[3]; + PrimExpr flag = op->args[4]; + + PrimExpr e_dtype = tir::TypeAnnotation(info.new_element_dtype); + PrimExpr factor = make_const(extent.dtype(), info.new_element_dtype.lanes()); + extent = extent / factor; + index = index / factor; + Array acc_args{e_dtype, info.new_buffer_var, index, extent, flag}; + return Call(info.new_element_dtype, builtin::tvm_access_ptr(), acc_args); + } else { - args.push_back(var); + return StmtExprMutator::VisitExpr_(op); + } + } + + Stmt VisitStmt_(const AllocateNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + auto it = rewrite_map_.find(op->buffer_var.get()); + if (it == rewrite_map_.end()) { + return stmt; + } + + const auto& info = it->second; + + Var new_buffer_var = info.new_buffer_var; + + int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); + + Array extents = op->extents; + extents.Set(extents.size() - 1, + extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); + return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body); + } + + /* Update the parameters and all remaining variable references + * + * Should be called after calling operator() on the body of the + * function. + * + * @param func A pointer to the PrimFunc being modified. + */ + void Finalize(PrimFunc* func_ptr) const { + ICHECK(func_ptr) << "Finalize expects a non-null pointer"; + auto& func = *func_ptr; + auto* n = func.CopyOnWrite(); + + // Remap any remaining references to the old buffer variables + Map var_remap; + for (const auto& pair : rewrite_map_) { + const auto& info = pair.second; + var_remap.Set(info.old_buffer_var, info.new_buffer_var); + } + n->body = Substitute(n->body, var_remap); + + // Remap the argument list to use the new buffer variables. + Array new_params; + for (const auto& old_param : n->params) { + auto it = rewrite_map_.find(old_param.get()); + if (it == rewrite_map_.end()) { + new_params.push_back(old_param); + } else { + const auto& info = it->second; + new_params.push_back(info.new_buffer_var); + } + } + n->params = new_params; + + // Remap the Buffer objects in so that the buffers use the new buffer variables + Map new_buffer_map; + for (const auto& pair : n->buffer_map) { + Var key = pair.first; + Buffer old_buffer = pair.second; + Var old_var = old_buffer->data; + + auto it = rewrite_map_.find(old_var.get()); + if (it == rewrite_map_.end()) { + new_buffer_map.Set(key, old_buffer); + } else { + auto& info = it->second; + int factor = info.new_element_dtype.lanes() / info.old_element_dtype.lanes(); + ICHECK_EQ(factor * info.new_element_dtype.lanes(), info.old_element_dtype.lanes()); + + auto* buffer_cow = old_buffer.CopyOnWrite(); + buffer_cow->data = info.new_buffer_var; + buffer_cow->dtype = info.new_element_dtype; + size_t ndim = buffer_cow->shape.size(); + const auto& last_dim = buffer_cow->shape[ndim - 1]; + buffer_cow->shape.Set(ndim - 1, last_dim / make_const(last_dim.dtype(), factor)); + new_buffer_map.Set(key, old_buffer); + } } + n->buffer_map = new_buffer_map; } - // no variable remap is needed. - if (var_remap.size() == 0) return f; + private: + struct RewriteInfo { + Var old_buffer_var; + Var new_buffer_var; + DataType old_element_dtype; + DataType new_element_dtype; + }; + + bool rewrite_indices_{true}; + std::unordered_map rewrite_map_; +}; + +// Rewrite allocates, pointer parameters, and buffer map into vectorized versions +// if each access into a buffer is the same vector type. +PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false, + bool rewrite_params = true, bool rewrite_buffer_map = true, + bool rewrite_allocate_node = true, bool rewrite_indices = true, + bool rewrite_let_node = true) { + VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers); + checker(f->body); + + VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, rewrite_buffer_map, + rewrite_allocate_node, rewrite_indices, rewrite_let_node); + PrimFuncNode* n = f.CopyOnWrite(); + n->body = rewriter(std::move(n->body)); + rewriter.Finalize(&f); - // remap the variables. - ICHECK_EQ(args.size(), n->params.size()); - n->params = args; - n->body = Substitute(n->body, var_remap); return f; } @@ -1006,7 +1439,7 @@ Pass StorageRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true); - return PointerValueTypeRewrite(std::move(f)); + return PointerValueTypeRewrite(std::move(f), true, false, false, true, false, true); }; return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index d0f58074ada0..1836b8ecec0d 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -69,7 +69,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(k); ICHECK(layout); - std::string scope = scopes[buffer_var]; + std::string scope = GetPtrStorageScope(GetRef(buffer_var)); if (fragments.count(buffer_var)) { // check if the fragment has met before FragmentInfo info = fragments[buffer_var]; @@ -102,7 +102,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(n); ICHECK(k); - std::string scope = scopes[buffer_var]; + std::string scope = GetPtrStorageScope(GetRef(buffer_var)); // Only wmma.accumulator can use tvm_fill_fragment ICHECK_EQ(scope, "wmma.accumulator"); if (fragments.count(buffer_var)) { @@ -119,16 +119,9 @@ class FragmentGetter : public StmtExprVisitor { // Get memory scope void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buffer = op->node.as(); - ICHECK(buffer); - scopes[buffer] = op->value.as()->value; - } StmtExprVisitor::VisitStmt_(op); } - // Memory scope for allocations - std::unordered_map scopes; // Fragment metadata for all fragments std::unordered_map fragments; }; diff --git a/src/tir/transforms/texture_flatten.cc b/src/tir/transforms/texture_flatten.cc new file mode 100644 index 000000000000..7dc800737944 --- /dev/null +++ b/src/tir/transforms/texture_flatten.cc @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file texture_flatten.cc + * \brief Flattens texture storage from multi-dimensional array + * to 2D (width, height) buffer access + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include "../../arith/ir_visitor_with_analyzer.h" +#include "../../runtime/texture.h" +#include "../../runtime/thread_storage_scope.h" + +namespace tvm { +namespace tir { +using runtime::ApplyTexture2DFlattening; +using runtime::DefaultTextureLayoutSeparator; +using runtime::IsTextureStorage; + +class TextureLoweringBase : public StmtExprMutator { + public: + explicit TextureLoweringBase(const Map& extern_buffer_map, + IRVisitorWithAnalyzer* bound_analyzer) + : bound_analyzer_{bound_analyzer} { + for (auto kv : extern_buffer_map) { + extern_buf_.insert(kv.second); + } + } + + inline PrimExpr SimplifyOffset(const Array& shape, const Array& index) const { + PrimExpr base = make_const(DataType::Int(32), 0); + ICHECK_EQ(shape.size(), index.size()); + if (index.size() > 0) { + PrimExpr offset = index[0]; + for (size_t i = 1; i < index.size(); ++i) { + offset = bound_analyzer_->Simplify(offset * shape[i] + index[i]); + } + base = base + offset; + } + return base; + } + + protected: + std::string GetStorageScope(const Buffer& buffer) { + auto* ptr = buffer->data->type_annotation.as(); + ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; + return ptr->storage_scope; + } + + // Set of all external input and output buffers + std::unordered_set extern_buf_; + // Bound analzer + IRVisitorWithAnalyzer* bound_analyzer_; +}; + +// Lower Nd storage access to 2d texture access using lowering convention +// specified by the buffers storage scope. +class TextureFlattener : public TextureLoweringBase { + public: + using StmtExprMutator::VisitStmt_; + explicit TextureFlattener(const Map& extern_buffer_map, + IRVisitorWithAnalyzer* bound_analyzer) + : TextureLoweringBase(extern_buffer_map, bound_analyzer) {} + + Stmt VisitStmt_(const BufferRealizeNode* op) final { + if (extern_buf_.count(op->buffer)) { + return this->VisitStmt(op->body); + } + + std::string storage_scope = GetStorageScope(op->buffer); + Var buffer_var(op->buffer->data->name_hint, + PointerType(PrimType(op->buffer->dtype), String(storage_scope))); + let_binding_.insert({op->buffer->data, buffer_var}); + + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + // Rewrite any buffer realizations with storage scope to 2d texture allocations + if (IsTextureStorage(storage_scope)) { + Stmt body = this->VisitStmt(op->body); + ICHECK(op->bounds.size() >= 3) << "Only 2d RGBA texture is currently supported"; + int vec_length = static_cast(op->bounds.back()->extent.as()->value); + ICHECK(vec_length == 4 || vec_length == 1) + << "Inner dimension of texture must be vector of length 1 or 4 (RGBA)"; + + struct ShapeFromRange { + const Array& bounds; + PrimExpr operator[](size_t i) const { return bounds[i]->extent; } + }; + size_t axis = DefaultTextureLayoutSeparator(op->bounds.size(), storage_scope); + auto texture = + ApplyTexture2DFlattening(ShapeFromRange{op->bounds}, op->bounds.size(), axis); + Array args = {texture.width, texture.height}; + stmt = LetStmt(buffer_var, Call(buffer_var.dtype(), builtin::texture2d_alloca(), args), body); + } + + return stmt; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + std::string storage_scope = GetStorageScope(op->buffer); + // Lower to two dimensional access + if (IsTextureStorage(storage_scope)) { + Array args = GetTextureAccessArgs(op, op->buffer); + args.push_back(op->value); + stmt = Evaluate(Call(args[0]->dtype, builtin::texture2d_store(), args)); + } + + return stmt; + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + // Lower to two dimensional access + std::string storage_scope = GetStorageScope(op->buffer); + if (IsTextureStorage(storage_scope)) { + Array args = GetTextureAccessArgs(op, op->buffer); + args.push_back(op->indices.back()); + expr = Call(op->buffer->dtype, builtin::texture2d_load(), args); + } + + return expr; + } + + protected: + template + Array GetTextureAccessArgs(const T* op, const Buffer& buffer) { + Array args; + if (let_binding_.count(op->buffer->data)) { + args.push_back(let_binding_[op->buffer->data]); + } else { + args.push_back(buffer->data); + } + Array row_dims, row_indices, col_dims, col_indices; + for (size_t i = 0; i < op->buffer->shape.size() - 1; i++) { + if (i < DefaultTextureLayoutSeparator(op->buffer->shape.size(), GetStorageScope(buffer))) { + col_dims.push_back(op->buffer->shape[i]); + col_indices.push_back(op->indices[i]); + } else { + row_dims.push_back(op->buffer->shape[i]); + row_indices.push_back(op->indices[i]); + } + } + PrimExpr row_offset = SimplifyOffset(row_dims, row_indices); + PrimExpr col_offset = SimplifyOffset(col_dims, col_indices); + args.push_back(row_offset); + args.push_back(col_offset); + return args; + } + + // Bindings to new texture vars with texture pointer scope + std::unordered_map let_binding_; +}; + +PrimFunc TextureFlatten(PrimFunc func) { + auto fptr = func.CopyOnWrite(); + IRVisitorWithAnalyzer bound_analyzer; + bound_analyzer(fptr->body); + fptr->body = TextureFlattener(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); + return func; +} + +namespace transform { + +Pass TextureFlatten() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return TextureFlatten(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.TextureFlatten", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.TextureFlatten").set_body_typed(TextureFlatten); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 8f757171afbd..35e4563b8f58 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -223,14 +223,14 @@ class ThreadSyncInserter : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { + GetScope(op->buffer_var).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].read_count; } return StmtExprMutator::VisitExpr_(op); } Stmt VisitStmt_(const StoreNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { + GetScope(op->buffer_var).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].write_count; } return StmtExprMutator::VisitStmt_(op); @@ -250,10 +250,6 @@ class ThreadSyncInserter : public StmtExprMutator { is_lead_ = PrimExpr(); } return ret; - } else if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - storage_scope_[buf] = StorageScope::Create(op->value.as()->value); - return StmtExprMutator::VisitStmt_(op); } else { return StmtExprMutator::VisitStmt_(op); } @@ -264,16 +260,15 @@ class ThreadSyncInserter : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); ICHECK_EQ(op->args.size(), 5U); - const VarNode* buffer_var = op->args[1].as(); - Var var(GetRef(buffer_var)); + Var buffer_var(GetRef(op->args[1].as())); const IntImmNode* flag = op->args[4].as(); if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal && GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[var].read_count; + ++rw_stats_[buffer_var].read_count; } if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal && GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[var].write_count; + ++rw_stats_[buffer_var].write_count; } return expr; } else { @@ -287,14 +282,12 @@ class ThreadSyncInserter : public StmtExprMutator { int read_count{0}; int write_count{0}; }; + // Get current storage scope. - StorageScope GetScope(const VarNode* buf) const { - auto it = storage_scope_.find(buf); - StorageScope s; - s.rank = StorageRank::kGlobal; - if (it == storage_scope_.end()) return s; - return it->second; + StorageScope GetScope(Var buffer_var) const { + return StorageScope::Create(GetPtrStorageScope(buffer_var)); } + // private functions. Stmt InitGlobalBarrier(const AttrStmtNode* op) { ICHECK(op != nullptr); @@ -337,8 +330,6 @@ class ThreadSyncInserter : public StmtExprMutator { // data structure. StorageScope sync_scope_; const std::unordered_set& syncs_; - // The storage scope of each buffer - std::unordered_map storage_scope_; // The read write statistics of storage std::unordered_map rw_stats_; // The statistics for global barrier diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc new file mode 100644 index 000000000000..6a26103e6079 --- /dev/null +++ b/src/tir/transforms/unify_thread_binding.cc @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file unify_thread_binding.cc + */ + +#include +#include +#include +#include + +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief A mutator which searches AttrStmts of thread bindings and changes the `node` field IterVar + * of the AttrStmts, so that for one kind of thread binding, all such thread bindings use the same + * IterVar + */ +class ThreadBindingUnifier : public StmtExprMutator { + public: + static Stmt Unify(Stmt stmt) { return ThreadBindingUnifier()(std::move(stmt)); } + + private: + Stmt VisitStmt_(const AttrStmtNode* attr) final { + // If this AttrStmt is not thread binding attribute, return as usual. + if (attr->attr_key != attr::thread_extent && attr->attr_key != attr::virtual_thread) { + return StmtMutator::VisitStmt_(attr); + } + + // Step 1. Fetch the old IterVar and the thread tag. + IterVar old_iter_var = Downcast(attr->node); + IterVar new_iter_var{nullptr}; + const String& thread_tag = old_iter_var->thread_tag; + + // Step 2: Increase `thread_block_depth_` if the thread tag starts with "blockIdx". If the + // thread block depth is 0 before the increasement, it means we are entering a new kernel, and + // therefore we need to make `thread_tag2iter_var_map_` empty, as different kernels can have + // thread axes with different extents. + if (std::string(thread_tag).substr(0, 9) == "blockIdx.") { + if (!thread_block_depth_) { + thread_tag2iter_var_map_.clear(); + } + ++thread_block_depth_; + } + + // Step 3. See if an IterVar for this kind of thread binding was created before. If so, we use + // the created IterVar. Otherwise, we create a new IterVar for this thread binding and store the + // IterVar in mapping `thread_tag2iter_var_map_`. + Map::iterator it = thread_tag2iter_var_map_.find(thread_tag); + if (it != thread_tag2iter_var_map_.end()) { + new_iter_var = (*it).second; + CHECK(ana.CanProveEqual(old_iter_var->dom->extent, (*it).second->dom->extent)) + << "ValueError: All loops that are bound to `" << thread_tag + << "` should have the same extent. However, there are two loops with extent " + << (*it).second->dom->extent << " and " << old_iter_var->dom->extent + << ", which are not equal"; + } else { + ObjectPtr p_new_iter_var = make_object(*old_iter_var.get()); + p_new_iter_var->var = Var(thread_tag); + new_iter_var = IterVar(p_new_iter_var); + thread_tag2iter_var_map_.Set(thread_tag, new_iter_var); + } + + // Step 4. We will substitute the occurrences of the old variable in the old IterVar with the + // new variable in further mutation. Thus, we store the mapping entry. + var_substitution_map_.Set(old_iter_var->var, new_iter_var->var); + + // Step 5. Mutate recursively, update the AttrStmt with the new IterVar, and decrease the depth + // counter if the thread tag starts with "blockIdx". + AttrStmt new_attr = Downcast(StmtMutator::VisitStmt_(attr)); + ObjectPtr p_new_attr = CopyOnWrite(new_attr.get()); + p_new_attr->node = new_iter_var; + if (std::string(thread_tag).substr(0, 9) == "blockIdx.") { + --thread_block_depth_; + } + return Stmt(p_new_attr); + } + + PrimExpr VisitExpr_(const VarNode* var) final { + // If this variable appears as a key in `var_substitution_map_`, we substitute it with its + // corresponding value in the mapping. + Map::iterator it = var_substitution_map_.find(GetRef(var)); + return it != var_substitution_map_.end() ? (*it).second : GetRef(var); + } + + /*! + * \brief A mapping from a thread tag to its corresponding IterVar that is shared by all + * occurrences of the thread tag + * */ + Map thread_tag2iter_var_map_; + /*! \brief A mapping from old variables to new variables, which is used for substitution */ + Map var_substitution_map_; + /*! \brief A integer counter storing the depth of thread bindings of "blockIdx.x/y/z" */ + int thread_block_depth_ = 0; + /*! \brief An analyzer used for equality proof */ + arith::Analyzer ana; +}; + +PrimFunc UnifyThreadBinding(PrimFunc f) { + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = ThreadBindingUnifier::Unify(std::move(f->body)); + return f; + } else { + return f; + } +} + +namespace transform { + +Pass UnifyThreadBinding() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return UnifyThreadBinding(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.UnifyThreadBinding", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.UnifyThreadBinding").set_body_typed(UnifyThreadBinding); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc new file mode 100644 index 000000000000..4143577a0b17 --- /dev/null +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file update_pointer_storage_scope.cc + * \brief A pass to update storage scopes for buffer variables. + */ +#include "update_pointer_storage_scope.h" + +#include +#include +#include +#include + +#include + +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +Var WithStorageScope(const VarNode* buffer_var, String storage_scope) { + auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), + buffer_var->span); +} + +UpdatePointerStorageScope::UpdatePointerStorageScope( + const std::unordered_map& new_storage_scopes) { + for (auto& kv : new_storage_scopes) { + new_var_remap_[kv.first] = WithStorageScope(kv.first, kv.second); + } +} + +PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) { + auto it = new_var_remap_.find(op); + if (it == new_var_remap_.end()) { + return GetRef(op); + } + return it->second; +} + +PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + return Load(op->dtype, Downcast(remapped), StmtExprMutator::VisitExpr(op->index), + StmtExprMutator::VisitExpr(op->predicate)); +} + +Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { + auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); + return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), + StmtExprMutator::VisitStmt(op->body)); +} + +Stmt UpdatePointerStorageScope::VisitStmt_(const StoreNode* op) { + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + return Store(Downcast(remapped), StmtExprMutator::VisitExpr(op->value), + StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/update_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h new file mode 100644 index 000000000000..f310194a4a51 --- /dev/null +++ b/src/tir/transforms/update_pointer_storage_scope.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file update_pointer_storage_scope.h + * \brief A pass to update storage scopes for buffer variables. + */ +#ifndef TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ +#define TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace tir { + +class UpdatePointerStorageScope : public StmtExprMutator { + public: + explicit UpdatePointerStorageScope( + const std::unordered_map& new_storage_scopes); + + virtual PrimExpr VisitExpr_(const VarNode*); + virtual PrimExpr VisitExpr_(const LoadNode*); + virtual Stmt VisitStmt_(const AllocateNode*); + virtual Stmt VisitStmt_(const StoreNode*); + + private: + std::unordered_map new_var_remap_; +}; + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 64956bc8ee54..cd2d230f5775 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -265,6 +265,20 @@ class Vectorizer : public StmtMutator, public ExprFunctorop.same_as(builtin::if_then_else())) { return MutateIfThenElseExpr_(op); + } else if (op->op.same_as(builtin::texture2d_load())) { + int lane = 0; + Array fcd = MutateArray({op->args.back()}, &lane); + auto new_args = op->args; + new_args.pop_back(); + new_args.push_back(fcd[0]); + return Call(op->dtype.with_lanes(4), op->op, new_args); + } else if (op->op.same_as(builtin::texture2d_store())) { + int lane = 0; + // Vectorize the value to store + Array value{op->args.back()}; + Array mutated_value = MutateArray(value, &lane); + Array new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]}; + return Call(op->dtype.with_lanes(lane), op->op, new_args); } auto* op_ptr = op->op.as(); bool vectorizable = op_ptr && op_vectorizable_.get(GetRef(op_ptr), false); diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc index 829e2689887e..a3e9bdfa56bd 100644 --- a/tests/cpp/arith_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -53,8 +53,3 @@ TEST(Simplify, Mod) { auto es = ana.canonical_simplify(mod - x); ICHECK(tvm::tir::is_zero(es)); } -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/attrs_test.cc b/tests/cpp/attrs_test.cc index 4d6de60b9706..d836639043d1 100644 --- a/tests/cpp/attrs_test.cc +++ b/tests/cpp/attrs_test.cc @@ -81,9 +81,3 @@ TEST(Attrs, Basic) { LOG(INFO) << "docstring\n" << os.str(); ICHECK(os.str().find("expr : PrimExpr, default=1") != std::string::npos); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/auto_scheduler_test.cc b/tests/cpp/auto_scheduler_test.cc index 16dfd56a69ea..0a753bc9a740 100644 --- a/tests/cpp/auto_scheduler_test.cc +++ b/tests/cpp/auto_scheduler_test.cc @@ -170,9 +170,3 @@ TEST(ComputeDAG, AccessAnalyzer) { } } } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 204a824f9248..2295c3dafe46 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -152,9 +152,9 @@ TEST(BuildModule, Heterogeneous) { auto b_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto c_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto pa = (float*)(a_val->data); - auto pb = (float*)(b_val->data); - auto pc = (float*)(c_val->data); + auto pa = static_cast(a_val->data); + auto pb = static_cast(b_val->data); + auto pc = static_cast(c_val->data); // Assign values. for (int i = 0; i < n; i++) { @@ -192,16 +192,10 @@ TEST(BuildModule, Heterogeneous) { run(); tvm::runtime::NDArray out = get_output(0); - float* p_out = (float*)out->data; + float* p_out = static_cast(out->data); // Check correctness. for (int i = 0; i < n; ++i) { ICHECK_LT(std::fabs(p_out[i] - (i + (i + 1.0) - (i - 1.0))), 1e-5); } } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 7d1fa790146e..019fde069878 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -42,7 +42,7 @@ class TestErrorSwitch { const_cast(other).should_fail = false; } - TestErrorSwitch(bool fail_flag) : should_fail{fail_flag} {} + explicit TestErrorSwitch(bool fail_flag) : should_fail{fail_flag} {} bool should_fail{false}; ~TestErrorSwitch() { @@ -695,9 +695,3 @@ TEST(Optional, PackedCall) { test_ffi(s, static_cast(kTVMObjectHandle)); test_ffi(String(s), static_cast(kTVMObjectRValueRefArg)); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/contrib/bnns.cc b/tests/cpp/contrib/bnns.cc deleted file mode 100644 index f7d40f176fb6..000000000000 --- a/tests/cpp/contrib/bnns.cc +++ /dev/null @@ -1,307 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -TEST(PackedFunc, Basic) { - using namespace tvm; - using namespace tvm::tir; - using namespace tvm::runtime; - int x = 0; - void* handle = &x; - DLTensor a; - - Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - ICHECK(args.num_args == 3); - ICHECK(args.values[0].v_float64 == 1.0); - ICHECK(args.type_codes[0] == kDLFloat); - ICHECK(args.values[1].v_handle == &a); - ICHECK(args.type_codes[1] == kTVMDLTensorHandle); - ICHECK(args.values[2].v_handle == &x); - ICHECK(args.type_codes[2] == kTVMOpaqueHandle); - *rv = Var("a"); - })(1.0, &a, handle); - ICHECK(v->name_hint == "a"); -} - -TEST(PackedFunc, Node) { - using namespace tvm; - using namespace tvm::tir; - using namespace tvm::runtime; - Var x; - Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - ICHECK(args.num_args == 1); - ICHECK(args[0].IsObjectRef()); - Var b = args[0]; - ICHECK(x.same_as(b)); - *rv = b; - })(x); - ICHECK(t.same_as(x)); -} - -TEST(PackedFunc, NDArray) { - using namespace tvm; - using namespace tvm::runtime; - auto x = NDArray::Empty({}, String2DLDataType("float32"), Device{kDLCPU, 0}); - reinterpret_cast(x->data)[0] = 10.0f; - ICHECK(x.use_count() == 1); - - PackedFunc forward([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; }); - - NDArray ret = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - NDArray y = args[0]; - DLTensor* ptr = args[0]; - ICHECK(ptr == x.operator->()); - ICHECK(x.same_as(y)); - ICHECK(x.use_count() == 2); - *rv = forward(y); - })(x); - ICHECK(ret.use_count() == 2); - ICHECK(ret.same_as(x)); -} - -TEST(PackedFunc, str) { - using namespace tvm; - using namespace tvm::runtime; - PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - ICHECK(args.num_args == 1); - std::string x = args[0]; - ICHECK(x == "hello"); - String y = args[0]; - ICHECK(y == "hello"); - *rv = x; - })("hello"); - - PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - ICHECK(args.num_args == 1); - runtime::String s = args[0]; - ICHECK(s == "hello"); - })(runtime::String("hello")); -} - -TEST(PackedFunc, func) { - using namespace tvm; - using namespace tvm::runtime; - PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0].operator int() + 1; }); - // function as arguments - int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - PackedFunc f = args[0]; - // TVMArgValue -> Arguments as function - *rv = f(args[1]).operator int(); - })(addone, 1); - ICHECK_EQ(r0, 2); - - int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - // TVMArgValue -> TVMRetValue - *rv = args[1]; - })(2, 100); - ICHECK_EQ(r1, 100); - - int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - // re-assignment - *rv = args[0]; - // TVMRetValue -> Function argument - *rv = addone(args[0].operator PackedFunc()(args[1], 1)); - })(addone, 100); - ICHECK_EQ(r2, 102); -} - -TEST(PackedFunc, Expr) { - using namespace tvm; - using namespace tvm::runtime; - // automatic conversion of int to expr - PackedFunc addone([](TVMArgs args, TVMRetValue* rv) { - PrimExpr x = args[0]; - *rv = x.as()->value + 1; - }); - int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - PackedFunc f = args[0]; - // TVMArgValue -> Arguments as function - *rv = f(args[1]).operator int(); - })(addone, 1); - ICHECK_EQ(r0, 2); -} - -TEST(PackedFunc, Type) { - using namespace tvm; - using namespace tvm::runtime; - auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - DataType x = args[0]; - *rv = x; - }); - auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; }); - ICHECK(get_type("int32").operator DataType() == DataType::Int(32)); - ICHECK(get_type("float").operator DataType() == DataType::Float(32)); - ICHECK(get_type2("float32x2").operator DataType() == DataType::Float(32, 2)); -} - -TEST(TypedPackedFunc, HighOrder) { - using namespace tvm; - using namespace tvm::runtime; - using Int1Func = TypedPackedFunc; - using Int2Func = TypedPackedFunc; - using BindFunc = TypedPackedFunc; - BindFunc ftyped; - ftyped = [](Int2Func f1, int value) -> Int1Func { - auto binded = [f1, value](int x) { return f1(value, x); }; - Int1Func x(binded); - return x; - }; - auto add = [](int x, int y) { return x + y; }; - ICHECK_EQ(ftyped(Int2Func(add), 1)(2), 3); - PackedFunc f = ftyped(Int2Func(add), 1); - ICHECK_EQ(f(3).operator int(), 4); - // call the type erased version. - Int1Func f1 = ftyped.packed()(Int2Func(add), 1); - ICHECK_EQ(f1(3), 4); -} - -TEST(TypedPackedFunc, Deduce) { - using namespace tvm::runtime; - using tvm::runtime::detail::function_signature; - - TypedPackedFunc x; - auto f = [](int x) -> int { return x + 1; }; - std::function y; - - static_assert(std::is_same::FType, int(float)>::value, - "invariant1"); - static_assert(std::is_same::FType, int(int)>::value, - "invariant2"); - static_assert(std::is_same::FType, void(float)>::value, - "invariant3"); -} - -TEST(PackedFunc, ObjectConversion) { - using namespace tvm; - using namespace tvm::tir; - using namespace tvm::runtime; - TVMRetValue rv; - auto x = NDArray::Empty({}, String2DLDataType("float32"), Device{kDLCPU, 0}); - // assign null - rv = ObjectRef(); - ICHECK_EQ(rv.type_code(), kTVMNullptr); - - // Can assign NDArray to ret type - rv = x; - ICHECK_EQ(rv.type_code(), kTVMNDArrayHandle); - // Even if we assign base type it still shows as NDArray - rv = ObjectRef(x); - ICHECK_EQ(rv.type_code(), kTVMNDArrayHandle); - // Check convert back - ICHECK(rv.operator NDArray().same_as(x)); - ICHECK(rv.operator ObjectRef().same_as(x)); - ICHECK(!rv.IsObjectRef()); - - auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args[0].type_code(), kTVMNDArrayHandle); - ICHECK(args[0].operator NDArray().same_as(x)); - ICHECK(args[0].operator ObjectRef().same_as(x)); - ICHECK(args[1].operator ObjectRef().get() == nullptr); - ICHECK(args[1].operator NDArray().get() == nullptr); - ICHECK(args[1].operator Module().get() == nullptr); - ICHECK(args[1].operator Array().get() == nullptr); - ICHECK(!args[0].IsObjectRef()); - }); - pf1(x, ObjectRef()); - pf1(ObjectRef(x), NDArray()); - - // testcases for modules - auto* pf = tvm::runtime::Registry::Get("runtime.SourceModuleCreate"); - ICHECK(pf != nullptr); - Module m = (*pf)("", "xyz"); - rv = m; - ICHECK_EQ(rv.type_code(), kTVMModuleHandle); - // Even if we assign base type it still shows as NDArray - rv = ObjectRef(m); - ICHECK_EQ(rv.type_code(), kTVMModuleHandle); - // Check convert back - ICHECK(rv.operator Module().same_as(m)); - ICHECK(rv.operator ObjectRef().same_as(m)); - ICHECK(!rv.IsObjectRef()); - - auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args[0].type_code(), kTVMModuleHandle); - ICHECK(args[0].operator Module().same_as(m)); - ICHECK(args[0].operator ObjectRef().same_as(m)); - ICHECK(args[1].operator ObjectRef().get() == nullptr); - ICHECK(args[1].operator NDArray().get() == nullptr); - ICHECK(args[1].operator Module().get() == nullptr); - ICHECK(!args[0].IsObjectRef()); - }); - pf2(m, ObjectRef()); - pf2(ObjectRef(m), Module()); -} - -TEST(TypedPackedFunc, RValue) { - using namespace tvm; - using namespace tvm::runtime; - { - auto inspect = [](TVMArgs args, TVMRetValue* rv) { - for (int i = 0; i < args.size(); ++i) { - ICHECK_EQ(args[0].type_code(), kTVMObjectRValueRefArg); - } - }; - PackedFunc finspect(inspect); - finspect(tir::Var("x")); - } - { - auto f = [](tir::Var x, bool move) { - if (move) { - ICHECK(x.unique()); - } else { - ICHECK(!x.unique()); - } - ICHECK(x->name_hint == "x"); - return x; - }; - TypedPackedFunc tf(f); - - tir::Var var("x"); - ICHECK(var.unique()); - tf(var, false); - // move the result to the function. - tir::Var ret = tf(std::move(var), true); - ICHECK(!var.defined()); - } - - { - // pass child class. - auto f = [](PrimExpr x, bool move) { - if (move) { - ICHECK(x.unique()); - } else { - ICHECK(!x.unique()); - } - return x; - }; - TypedPackedFunc tf(f); - - tir::Var var("x"); - ICHECK(var.unique()); - tf(var, false); - tf(std::move(var), true); - // auto conversion. - tf(1, true); - } -} - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/dataflow_pattern_test.cc b/tests/cpp/dataflow_pattern_test.cc index bdccaaa2e6ba..0545c19d2e3a 100644 --- a/tests/cpp/dataflow_pattern_test.cc +++ b/tests/cpp/dataflow_pattern_test.cc @@ -192,9 +192,3 @@ TEST(DFPattern, HasShape) { ICHECK(node->pattern == a); ICHECK(node->shape == shape); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index 99ff26dc0b58..9c9ea756bbb9 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -42,9 +42,3 @@ TEST(ExprNodeRef, Basic) { const tir::MaxNode* op = z.as(); ICHECK(GetRef(op).same_as(z)); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 9e8595d6809c..97809b0e1398 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -324,9 +324,3 @@ TEST(IRF, StmtMutator) { ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x)); } } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/microtvm_runtime_standalone_test.cc b/tests/cpp/microtvm_runtime_standalone_test.cc index ee324f89b48f..8a9ec1d4f85b 100644 --- a/tests/cpp/microtvm_runtime_standalone_test.cc +++ b/tests/cpp/microtvm_runtime_standalone_test.cc @@ -63,9 +63,9 @@ TEST(MicroStandaloneRuntime, BuildModule) { auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto pA = (float*)A->data; - auto pB = (float*)B->data; - auto pC = (float*)C->data; + auto pA = static_cast(A->data); + auto pB = static_cast(B->data); + auto pC = static_cast(C->data); for (int i = 0; i < 6; ++i) { pA[i] = i; @@ -118,7 +118,7 @@ TEST(MicroStandaloneRuntime, BuildModule) { MicroTVMRuntimeRun(handle); auto Y = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); MicroTVMRuntimeGetOutput(handle, 0, const_cast(Y.operator->())); - auto* pY = (float*)Y->data; + auto* pY = static_cast(Y->data); for (int i = 0; i < 6; ++i) { CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4); } @@ -128,9 +128,3 @@ TEST(MicroStandaloneRuntime, BuildModule) { #endif #endif - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/object_protocol_test.cc b/tests/cpp/object_protocol_test.cc index aaf9ee4af271..42928b484da9 100644 --- a/tests/cpp/object_protocol_test.cc +++ b/tests/cpp/object_protocol_test.cc @@ -95,9 +95,3 @@ TEST(ObjectHierachy, Basic) { ICHECK(refB.as() == nullptr); ICHECK(refB.as() != nullptr); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index f993f9605c91..ef72d03cf9ce 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -185,9 +185,10 @@ TEST(TypedPackedFunc, Deduce) { auto f = [](int x) -> int { return x + 1; }; std::function y; - static_assert(std::is_same::FType, int(float)>::value, - "invariant1"); - static_assert(std::is_same::FType, int(int)>::value, + static_assert( + std::is_same::FType, int(float)>::value, // NOLINT(*) + "invariant1"); + static_assert(std::is_same::FType, int(int)>::value, // NOLINT(*) "invariant2"); static_assert(std::is_same::FType, void(float)>::value, "invariant3"); @@ -306,9 +307,3 @@ TEST(TypedPackedFunc, RValue) { tf(1, true); } } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/parallel_for_test.cc b/tests/cpp/parallel_for_test.cc index a4549344bd11..c1e568e4cede 100644 --- a/tests/cpp/parallel_for_test.cc +++ b/tests/cpp/parallel_for_test.cc @@ -118,9 +118,3 @@ TEST(ParallelFor, Exception) { } ICHECK(exception); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index dfe09406ba52..4194c760628a 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -138,9 +138,3 @@ TEST(Pattern, IntImm) { // cannot match tx + 1 to v ICHECK(!(v * c).Match((tx + 1) * 3)); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/profiling_test.cc b/tests/cpp/profiling_test.cc index f770bfda8e5b..d2fc0e95db2c 100644 --- a/tests/cpp/profiling_test.cc +++ b/tests/cpp/profiling_test.cc @@ -39,9 +39,3 @@ TEST(DefaultTimer, Basic) { } } // namespace runtime } // namespace tvm - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/random_engine_test.cc b/tests/cpp/random_engine_test.cc new file mode 100644 index 000000000000..bc835dede4ee --- /dev/null +++ b/tests/cpp/random_engine_test.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +TEST(RandomEngine, Randomness) { + int64_t rand_state = 0; + + tvm::support::LinearCongruentialEngine rng(&rand_state); + rng.Seed(0x114514); + + bool covered[100]; + memset(covered, 0, sizeof(covered)); + for (int i = 0; i < 100000; i++) { + covered[rng() % 100] = true; + } + for (int i = 0; i < 100; i++) { + ICHECK(covered[i]); + } +} + +TEST(RandomEngine, Reproducibility) { + int64_t rand_state_a = 0, rand_state_b = 0; + tvm::support::LinearCongruentialEngine rng_a(&rand_state_a), rng_b(&rand_state_b); + + rng_a.Seed(0x23456789); + rng_b.Seed(0x23456789); + + for (int i = 0; i < 100000; i++) { + ICHECK_EQ(rng_a(), rng_b()); + } +} + +TEST(RandomEngine, Serialization) { + int64_t rand_state_a = 0, rand_state_b = 0; + tvm::support::LinearCongruentialEngine rng_a(&rand_state_a), rng_b(&rand_state_b); + + rng_a.Seed(0x56728); + + rand_state_b = rand_state_a; + for (int i = 0; i < 100000; i++) ICHECK_EQ(rng_a(), rng_b()); + + for (int i = 0; i < 123456; i++) rng_a(); + + rand_state_b = rand_state_a; + for (int i = 0; i < 100000; i++) ICHECK_EQ(rng_a(), rng_b()); +} diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 37e9e6f9c42c..ebb2867e7b69 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -86,9 +86,9 @@ TEST(Relay, BuildModule) { auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto pA = (float*)A->data; - auto pB = (float*)B->data; - auto pC = (float*)C->data; + auto pA = static_cast(A->data); + auto pB = static_cast(B->data); + auto pC = static_cast(C->data); for (int i = 0; i < 6; ++i) { pA[i] = i; @@ -127,7 +127,8 @@ TEST(Relay, BuildModule) { auto dev = A->device; auto pfr = tvm::runtime::Registry::Get("tvm.graph_executor.create"); ICHECK(mod.defined()) << "Module must be defined"; - tvm::runtime::Module run_mod = (*pfr)(json, mod, (int)dev.device_type, (int)dev.device_id); + tvm::runtime::Module run_mod = + (*pfr)(json, mod, static_cast(dev.device_type), dev.device_id); auto set_input_f = run_mod.GetFunction("set_input_zero_copy", false); auto run_f = run_mod.GetFunction("run", false); auto get_output_f = run_mod.GetFunction("get_output", false); @@ -136,7 +137,7 @@ TEST(Relay, BuildModule) { set_input_f("c", const_cast(C.operator->())); run_f(); tvm::runtime::NDArray Y = get_output_f(0); - auto pY = (float*)Y->data; + auto pY = static_cast(Y->data); for (int i = 0; i < 6; ++i) { ICHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4); } @@ -146,20 +147,20 @@ TEST(Relay, BuildModule) { } run_f(); tvm::runtime::NDArray Y2 = get_output_f(0); - auto pY2 = (float*)Y2->data; + auto pY2 = static_cast(Y2->data); for (int i = 0; i < 6; ++i) { ICHECK_LT(fabs(pY2[i] - (i + (i + 3) + (i + 2))), 1e-4); } // attach a different input and run it again auto C2 = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto pC2 = (float*)C2->data; + auto pC2 = static_cast(C2->data); for (int i = 0; i < 6; ++i) { pC2[i] = i + 4; } set_input_f("c", const_cast(C2.operator->())); run_f(); tvm::runtime::NDArray Y3 = get_output_f(0); - auto pY3 = (float*)Y3->data; + auto pY3 = static_cast(Y3->data); for (int i = 0; i < 6; ++i) { ICHECK_LT(fabs(pY3[i] - (i + (i + 3) + (i + 4))), 1e-4); } @@ -181,9 +182,3 @@ TEST(Relay, GetExprRefCount) { ICHECK(ref_count[y.get()] == 1); ICHECK(ref_count[z.get()] == 1); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/relay_dismantler_test.cc b/tests/cpp/relay_dismantler_test.cc index 8c74d4151818..37b44524e770 100644 --- a/tests/cpp/relay_dismantler_test.cc +++ b/tests/cpp/relay_dismantler_test.cc @@ -143,9 +143,3 @@ TEST(Relay, TupleiGetItemSharedTuple) { .as() ->args.size()); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 38ac906c6dac..6db595281813 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -41,9 +41,3 @@ TEST(Relay, SelfReference) { auto expected = relay::FuncType(tvm::Array{tensor_type}, tensor_type, {}, {}); ICHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected)); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/relay_text_printer_test.cc b/tests/cpp/relay_text_printer_test.cc index ed8029064720..58fa228f8a46 100644 --- a/tests/cpp/relay_text_printer_test.cc +++ b/tests/cpp/relay_text_printer_test.cc @@ -56,9 +56,3 @@ TEST(Relay, LargeGraphPrint) { }; ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*"); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/relay_transform_sequential_test.cc b/tests/cpp/relay_transform_sequential_test.cc index 6d38e1017042..ae1a546e5ddf 100644 --- a/tests/cpp/relay_transform_sequential_test.cc +++ b/tests/cpp/relay_transform_sequential_test.cc @@ -128,9 +128,3 @@ TEST(PassContextListConfigs, Basic) { auto config = configs["relay.backend.use_auto_scheduler"]; ICHECK_EQ(config["type"], "IntImm"); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/runtime_test.cc b/tests/cpp/runtime_test.cc new file mode 100644 index 000000000000..6dbcd61b8c37 --- /dev/null +++ b/tests/cpp/runtime_test.cc @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace tvm; +using namespace tvm::relay; + +TVM_REGISTER_GLOBAL("runtime_test.strategy") + .set_body_typed([](const Attrs& attrs, const Array& inputs, const Type& out_type, + const Target& target) { + FTVMCompute fcompute = [](const Attrs& attrs, const Array& inputs, + const Type& out_type) -> Array { + ICHECK_EQ(inputs.size(), 2U); + return {topi::add(inputs[0], inputs[1])}; + }; + FTVMSchedule fschedule = [](const Attrs& attrs, const Array& outs, + const Target& target) { + With target_scope(target); + return topi::generic::schedule_injective(target, outs); + }; + + auto n = make_object(); + auto strategy = tvm::relay::OpStrategy(std::move(n)); + strategy.AddImplementation(fcompute, fschedule, "runtime_test.strategy", 10); + return strategy; + }); + +TEST(Runtime, ZeroCopy) { + auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32)); + auto a = relay::Var("a", tensor_type); + auto b = relay::Var("b", tensor_type); + auto add_op = relay::Op::Get("add"); + auto x = relay::Call(add_op, {a, b}, tvm::Attrs(), {}); + auto c = relay::Var("c", tensor_type); + auto y = relay::Call(add_op, {x, c}, tvm::Attrs(), {}); + auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {}); + auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto Y = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + + auto pA = static_cast(A->data); + auto pB = static_cast(B->data); + auto pC = static_cast(C->data); + auto pY = static_cast(Y->data); + + for (int i = 0; i < 6; ++i) { + pA[i] = i; + pB[i] = i + 1; + pC[i] = i + 2; + } + // get schedule + auto reg = tvm::runtime::Registry::Get("ir.RegisterOpAttr"); + if (!reg) { + LOG(FATAL) << "no _Register"; + } + auto fs = tvm::runtime::Registry::Get("runtime_test.strategy"); + if (!fs) { + LOG(FATAL) << "No test_strategy registered."; + } + auto fgeneric = GenericFunc::Get("runtime_test.strategy_generic").set_default(*fs); + (*reg)("add", "FTVMStrategy", fgeneric, 10); + Array dep; + dep.push_back(0); + (*reg)("add", "TShapeDataDependent", dep, 10); + // build + auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule"); + tvm::runtime::Module build_mod = (*pfb)(); + auto build_f = build_mod.GetFunction("build", false); + auto json_f = build_mod.GetFunction("get_graph_json", false); + auto mod_f = build_mod.GetFunction("get_module", false); + Map targets; + Target llvm_tgt = Target("llvm"); + targets.Set(0, llvm_tgt); + auto relay_mod = tvm::IRModule::FromExpr(func); + ICHECK(relay_mod.defined()) << "Module must be defined"; + build_f(relay_mod, targets, llvm_tgt, runtime::kTvmExecutorGraph, ""); + // create graph executor + std::string json = json_f(); + tvm::runtime::Module mod = mod_f(); + auto dev = A->device; + auto pfr = tvm::runtime::Registry::Get("tvm.graph_executor.create"); + ICHECK(mod.defined()) << "Module must be defined"; + tvm::runtime::Module run_mod = + (*pfr)(json, mod, static_cast(dev.device_type), dev.device_id); + // get function + auto set_input_f = run_mod.GetFunction("set_input_zero_copy", false); + auto set_output_f = run_mod.GetFunction("set_output_zero_copy", false); + auto run_f = run_mod.GetFunction("run", false); + // set input zero copy + set_input_f("a", const_cast(A.operator->())); + set_input_f("b", const_cast(B.operator->())); + set_input_f("c", const_cast(C.operator->())); + // set output zero copy + set_output_f(0, const_cast(Y.operator->())); + run_f(); + // check correctness + for (int i = 0; i < 6; ++i) { + ICHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4); + } + // mutate the input a bit and run it again + for (int i = 0; i < 6; ++i) { + pB[i] = i + 3; + } + run_f(); + // check correctness + for (int i = 0; i < 6; ++i) { + ICHECK_LT(fabs(pY[i] - (i + (i + 3) + (i + 2))), 1e-4); + } + // attach a different input and run it again + auto C2 = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto pC2 = static_cast(C2->data); + for (int i = 0; i < 6; ++i) { + pC2[i] = i + 4; + } + set_input_f("c", const_cast(C2.operator->())); + run_f(); + // check correctness + for (int i = 0; i < 6; ++i) { + ICHECK_LT(fabs(pY[i] - (i + (i + 3) + (i + 4))), 1e-4); + } +} diff --git a/tests/cpp/support_test.cc b/tests/cpp/support_test.cc index 7d523fe8d537..df9271f4b49c 100644 --- a/tests/cpp/support_test.cc +++ b/tests/cpp/support_test.cc @@ -58,9 +58,3 @@ TEST(HashTests, HashStability) { } // namespace test } // namespace tvm - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 8dba462132ac..2e8ba11c0262 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -157,9 +157,3 @@ TEST(TargetKindRegistryListTargetKinds, Basic) { ICHECK_EQ(names.empty(), false); ICHECK_EQ(std::count(std::begin(names), std::end(names), "llvm"), 1); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/tensor_test.cc b/tests/cpp/tensor_test.cc index ea02ca656dce..a50af838f735 100644 --- a/tests/cpp/tensor_test.cc +++ b/tests/cpp/tensor_test.cc @@ -49,9 +49,3 @@ TEST(Tensor, Reduce) { {m, n}, [&](Var i, Var j) { return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); }, "C"); LOG(INFO) << C->op.as()->body; } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/texture_copy_test.cc b/tests/cpp/texture_copy_test.cc index 688bcab758ca..92c12bafdd9a 100644 --- a/tests/cpp/texture_copy_test.cc +++ b/tests/cpp/texture_copy_test.cc @@ -134,9 +134,3 @@ TEST(TextureCopy, OverwritePoolSubview) { } } } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/threading_backend_test.cc b/tests/cpp/threading_backend_test.cc index cf7434b4b036..948838971796 100644 --- a/tests/cpp/threading_backend_test.cc +++ b/tests/cpp/threading_backend_test.cc @@ -63,9 +63,3 @@ TEST(ThreadingBackend, TVMBackendParallelLaunchMultipleThreads) { } } } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/tir_analysis_side_effect.cc b/tests/cpp/tir_analysis_side_effect.cc index 022f2cffeda8..a59e4a7f8c05 100644 --- a/tests/cpp/tir_analysis_side_effect.cc +++ b/tests/cpp/tir_analysis_side_effect.cc @@ -33,9 +33,3 @@ TEST(SimplePasses, SideEffect) { ICHECK(tir::SideEffect(tir::Call(DataType::Handle(), tir::builtin::tvm_storage_sync(), {})) == tir::CallEffectKind::kUpdateState); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/topi_ewise_test.cc b/tests/cpp/topi_ewise_test.cc index 22ef8c7dffaa..9f4457de5192 100644 --- a/tests/cpp/topi_ewise_test.cc +++ b/tests/cpp/topi_ewise_test.cc @@ -25,15 +25,9 @@ namespace tvm { namespace topi { TEST(Tensor, Basic) { using namespace tvm; - Var m("m"), n("n"), l("l"); + Var m("m"), l("l"); Tensor A = placeholder({m, l}, DataType::Float(32), "A"); auto C = topi::exp(A); } } // namespace topi } // namespace tvm - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/crt/aot_executor_test.cc b/tests/crt/aot_executor_test.cc deleted file mode 100644 index ded6729d138b..000000000000 --- a/tests/crt/aot_executor_test.cc +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include - -int test_run_func(TVMValue* args, int* arg_type_ids, int num_args, TVMValue* out_ret_value, - int* out_ret_tcode, void* resource_handle) { - return kTvmErrorNoError; -} - -TEST(AOTRuntime, NoOp) { - const tvm_model_t test_model = { - .num_input_tensors = 0, - .num_output_tensors = 0, - .run_func = &test_run_func, - }; - - ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&test_model, NULL, NULL)); -} - -int32_t error_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, TVMValue* out_ret_value, - int* out_ret_tcode, void* resource_handle) { - return kTvmErrorPlatformNoMemory; -} - -TEST(AOTRuntime, Error) { - const tvm_model_t error_model = { - .num_input_tensors = 0, - .num_output_tensors = 0, - .run_func = &error_run_func, - }; - - ASSERT_EQ(kTvmErrorPlatformNoMemory, tvm_runtime_run(&error_model, NULL, NULL)); -} - -int32_t identity_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, - TVMValue* out_ret_value, int* out_ret_tcode, void* resource_handle) { - void* arg0 = (((TVMValue*)args)[0].v_handle); - void* arg1 = (((TVMValue*)args)[1].v_handle); - void* placeholder = (((DLTensor*)arg0)[0].data); - void* T_id = (((DLTensor*)arg1)[0].data); - ((uint32_t*)T_id)[(0)] = ((uint32_t*)placeholder)[(0)]; - return kTvmErrorNoError; -} - -TEST(AOTRuntime, Identity) { - const tvm_model_t identity_model = { - .num_input_tensors = 1, - .num_output_tensors = 1, - .run_func = &identity_run_func, - }; - - uint32_t inputs1[1] = {404}; - void* inputs[] = {inputs1}; - uint32_t outputs1[1]; - void* outputs[] = {outputs1}; - - ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&identity_model, inputs, outputs)); - ASSERT_EQ(outputs1[0], 404); -} - -int32_t add_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, TVMValue* out_ret_value, - int* out_ret_tcode, void* resource_handle) { - void* arg0 = (((TVMValue*)args)[0].v_handle); - void* arg1 = (((TVMValue*)args)[1].v_handle); - void* placeholder = (((DLTensor*)arg0)[0].data); - void* T_add = (((DLTensor*)arg1)[0].data); - ((uint32_t*)T_add)[(0)] = ((uint32_t*)placeholder)[(0)] + ((uint32_t*)placeholder)[(1)]; - return kTvmErrorNoError; - - return kTvmErrorNoError; -} - -TEST(AOTRuntime, Add) { - const tvm_model_t add_model = { - .num_input_tensors = 1, - .num_output_tensors = 1, - .run_func = &add_run_func, - }; - - uint32_t inputs1[2] = {404, 500}; - void* inputs[] = {inputs1}; - uint32_t outputs1[1]; - void* outputs[] = {outputs1}; - - ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&add_model, inputs, outputs)); - ASSERT_EQ(outputs1[0], 904); -} - -int32_t multiple_inputs_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, - TVMValue* out_ret_value, int* out_ret_tcode, - void* resource_handle) { - void* arg0 = (((TVMValue*)args)[0].v_handle); - void* arg1 = (((TVMValue*)args)[1].v_handle); - void* arg2 = (((TVMValue*)args)[2].v_handle); - void* placeholder = (((DLTensor*)arg0)[0].data); - void* placeholder1 = (((DLTensor*)arg1)[0].data); - void* T_add = (((DLTensor*)arg2)[0].data); - ((uint32_t*)T_add)[(0)] = ((uint32_t*)placeholder)[(0)] + ((uint32_t*)placeholder)[(1)] + - ((uint32_t*)placeholder1)[(0)] + ((uint32_t*)placeholder1)[(1)]; - return kTvmErrorNoError; -} - -TEST(AOTRuntime, MultipleInputs) { - const tvm_model_t multiple_inputs_model = { - .num_input_tensors = 2, - .num_output_tensors = 1, - .run_func = &multiple_inputs_run_func, - }; - - uint32_t inputs1[2] = {404, 500}; - uint32_t inputs2[2] = {200, 202}; - void* inputs[] = {inputs1, inputs2}; - - uint32_t outputs1[1]; - void* outputs[] = {outputs1}; - - ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&multiple_inputs_model, inputs, outputs)); - ASSERT_EQ(outputs1[0], 1306); -} - -int32_t multiple_outputs_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, - TVMValue* out_ret_value, int* out_ret_tcode, - void* resource_handle) { - void* arg0 = (((TVMValue*)args)[0].v_handle); - void* arg1 = (((TVMValue*)args)[1].v_handle); - void* arg2 = (((TVMValue*)args)[2].v_handle); - void* placeholder = (((DLTensor*)arg0)[0].data); - void* T_split1 = (((DLTensor*)arg1)[0].data); - void* T_split2 = (((DLTensor*)arg2)[0].data); - ((uint32_t*)T_split1)[(0)] = ((uint32_t*)placeholder)[(0)]; - ((uint32_t*)T_split2)[(0)] = ((uint32_t*)placeholder)[(1)]; - return kTvmErrorNoError; -} - -TEST(AOTRuntime, MultipleOutputs) { - const tvm_model_t multiple_outputs_model = { - .num_input_tensors = 1, - .num_output_tensors = 2, - .run_func = &multiple_outputs_run_func, - }; - - uint32_t inputs1[2] = {404, 500}; - void* inputs[] = {inputs1}; - - uint32_t outputs1[1]; - uint32_t outputs2[1]; - void* outputs[] = {outputs1, outputs2}; - - ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&multiple_outputs_model, inputs, outputs)); - ASSERT_EQ(outputs1[0], 404); - ASSERT_EQ(outputs2[0], 500); -} - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/crt/buffer_write_stream.h b/tests/crt/buffer_write_stream.h index 66ef044e6ba1..48a30ac4b273 100644 --- a/tests/crt/buffer_write_stream.h +++ b/tests/crt/buffer_write_stream.h @@ -24,6 +24,8 @@ #include #include +#include + using ::tvm::runtime::micro_rpc::FrameBuffer; using ::tvm::runtime::micro_rpc::WriteStream; @@ -51,7 +53,7 @@ class BufferWriteStream : public WriteStream { std::string BufferContents() { return std::string((const char*)buffer_data_, buffer_.Size()); } - static constexpr unsigned int capacity() { return N; }; + static constexpr unsigned int capacity() { return N; } private: bool packet_done_{false}; diff --git a/tests/crt/framing_test.cc b/tests/crt/framing_test.cc index 5ee226dc5ee7..e257dfc641ab 100644 --- a/tests/crt/framing_test.cc +++ b/tests/crt/framing_test.cc @@ -27,7 +27,6 @@ #include "buffer_write_stream.h" #include "crt_config.h" -#include "platform.cc" using ::tvm::runtime::micro_rpc::Escape; using ::tvm::runtime::micro_rpc::FrameBuffer; @@ -62,16 +61,6 @@ class TestPacket { std::string wire; }; -void PrintTo(const TestPacket* p, std::ostream* os) { - *os << "TestPacket(\"" << p->name << "\", ...)"; -} - -void PrintTo(tvm_crt_error_t p, std::ostream* os) { - std::ios_base::fmtflags f(os->flags()); - *os << "tvm_crt_error_t(0x" << std::hex << std::setw(8) << std::setfill('0') << p << ")"; - os->flags(f); -} - std::vector TestPacket::instances; #define TEST_PACKET(name, payload, wire) \ @@ -161,23 +150,25 @@ TEST_F(UnframerTest, PacketTooLong) { unframer_.Write(packet_length_bytes, sizeof(packet_length), &bytes_consumed)); EXPECT_EQ(sizeof(packet_length), bytes_consumed); - uint8_t long_payload[decltype(write_stream_)::capacity() + 1]; - for (size_t i = 0; i < sizeof(long_payload); i++) { + unsigned int long_payload_len = decltype(write_stream_)::capacity() + 1; + auto long_payload = std::make_unique(long_payload_len); + for (size_t i = 0; i < long_payload_len; i++) { long_payload[i] = i & 0xff; if (long_payload[i] == uint8_t(Escape::kEscapeStart)) { long_payload[i] = 0; } } - crc = tvm::runtime::micro_rpc::crc16_compute(long_payload, sizeof(long_payload), &crc); + crc = tvm::runtime::micro_rpc::crc16_compute(long_payload.get(), long_payload_len, &crc); EXPECT_EQ(kTvmErrorWriteStreamShortWrite, - unframer_.Write(long_payload, sizeof(long_payload), &bytes_consumed)); + unframer_.Write(long_payload.get(), long_payload_len, &bytes_consumed)); EXPECT_EQ(write_stream_.capacity(), bytes_consumed); - EXPECT_EQ(kTvmErrorNoError, unframer_.Write((uint8_t*)&crc, sizeof(crc), &bytes_consumed)); - EXPECT_EQ(2, bytes_consumed); // 2, because framer is now in kFindPacketStart. + EXPECT_EQ(kTvmErrorNoError, + unframer_.Write(reinterpret_cast(&crc), sizeof(crc), &bytes_consumed)); + EXPECT_EQ(2UL, bytes_consumed); // 2, because framer is now in kFindPacketStart. EXPECT_FALSE(write_stream_.packet_done()); EXPECT_FALSE(write_stream_.is_valid()); - EXPECT_EQ(std::string((char*)long_payload, write_stream_.capacity()), + EXPECT_EQ(std::string(reinterpret_cast(long_payload.get()), write_stream_.capacity()), write_stream_.BufferContents()); // Writing a smaller packet directly afterward should work. @@ -188,7 +179,7 @@ TEST_F(UnframerTest, PacketTooLong) { EXPECT_TRUE(write_stream_.packet_done()); EXPECT_TRUE(write_stream_.is_valid()); EXPECT_EQ(kPacket1.payload, write_stream_.BufferContents()); -}; +} class UnframerTestParameterized : public UnframerTest, public ::testing::WithParamInterface {}; @@ -210,7 +201,7 @@ TEST_P(UnframerTestParameterized, TestByteAtATime) { EXPECT_EQ(kTvmErrorNoError, unframer_.Write(reinterpret_cast(&GetParam()->wire[i]), 1, &bytes_consumed)); - EXPECT_EQ(1, bytes_consumed); + EXPECT_EQ(1UL, bytes_consumed); EXPECT_EQ(i == wire_size - 1, write_stream_.packet_done()); } EXPECT_TRUE(write_stream_.is_valid()); @@ -247,7 +238,7 @@ TEST_P(UnframerTestParameterized, TestArbitraryPacketReset) { unframer_.Reset(); write_stream_.Reset(); EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), 1, &bytes_consumed)); - EXPECT_EQ(1, bytes_consumed); + EXPECT_EQ(1UL, bytes_consumed); EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), wire_size, &bytes_consumed)); EXPECT_EQ(wire_size, bytes_consumed); EXPECT_TRUE(write_stream_.packet_done()); @@ -265,13 +256,13 @@ TEST_P(UnframerTestParameterized, TestArbitraryPacketReset) { // Interrupt the packet transmission. The first byte will return no error as it is the escape // byte. EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), 1, &bytes_consumed)); - EXPECT_EQ(1, bytes_consumed); + EXPECT_EQ(1UL, bytes_consumed); EXPECT_FALSE(write_stream_.packet_done()); // Secondt byte will return a short packet error. EXPECT_EQ(kTvmErrorFramingShortPacket, unframer_.Write(&GetParam()->wire_data()[1], 1, &bytes_consumed)); - EXPECT_EQ(0, bytes_consumed); + EXPECT_EQ(0UL, bytes_consumed); EXPECT_FALSE(write_stream_.packet_done()); EXPECT_EQ(kTvmErrorNoError, @@ -291,7 +282,7 @@ TEST_P(UnframerTestParameterized, TestArbitraryPacketReset) { // the internal state. EXPECT_EQ(kTvmErrorFramingShortPacket, unframer_.Write(GetParam()->wire_data(), wire_size, &bytes_consumed)); - EXPECT_EQ(1, bytes_consumed); + EXPECT_EQ(1UL, bytes_consumed); EXPECT_FALSE(write_stream_.packet_done()); EXPECT_EQ(kTvmErrorNoError, unframer_.Write(&GetParam()->wire_data()[1], wire_size - 1, &bytes_consumed)); @@ -309,9 +300,3 @@ TEST_P(UnframerTestParameterized, TestArbitraryPacketReset) { INSTANTIATE_TEST_CASE_P(UnframerTests, UnframerTestParameterized, ::testing::ValuesIn(TestPacket::instances)); #pragma GCC diagnostic pop - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/crt/func_registry_test.cc b/tests/crt/func_registry_test.cc index 2889f7b899a7..9f0e7f8d1a5a 100644 --- a/tests/crt/func_registry_test.cc +++ b/tests/crt/func_registry_test.cc @@ -22,8 +22,6 @@ #include #include -#include "platform.cc" - typedef struct { const char* a; const char* b; @@ -180,22 +178,23 @@ TEST(MutableFuncRegistry, Create) { for (unsigned int rem = 0; rem < kTvmAverageFuncEntrySizeBytes; rem++) { // test_function name will be used to test overfilling. - char test_function_name[kTvmAverageFunctionNameStrlenBytes + 2 + rem]; + auto test_function_name = + std::make_unique(kTvmAverageFunctionNameStrlenBytes + 2 + rem); TVMMutableFuncRegistry reg; memset(mem_buffer, 0, sizeof(mem_buffer)); EXPECT_EQ(kTvmErrorNoError, TVMMutableFuncRegistry_Create( ®, mem_buffer, kTvmAverageFuncEntrySizeBytes * 2 + rem)); - snprintf_truncate(test_function_name, kTvmAverageFunctionNameStrlenBytes + 1, + snprintf_truncate(test_function_name.get(), kTvmAverageFunctionNameStrlenBytes + 1, function_name_chars); // Add function #1, and verify it can be retrieved. - EXPECT_EQ(kTvmErrorNoError, - TVMMutableFuncRegistry_Set(®, test_function_name, TestFunctionHandle(0x01), 0)); + EXPECT_EQ(kTvmErrorNoError, TVMMutableFuncRegistry_Set(®, test_function_name.get(), + TestFunctionHandle(0x01), 0)); tvm_function_index_t func_index = 100; EXPECT_EQ(kTvmErrorNoError, - TVMFuncRegistry_Lookup(®.registry, test_function_name, &func_index)); + TVMFuncRegistry_Lookup(®.registry, test_function_name.get(), &func_index)); EXPECT_EQ(func_index, 0); TVMBackendPackedCFunc func = NULL; @@ -203,22 +202,23 @@ TEST(MutableFuncRegistry, Create) { EXPECT_EQ(func, TestFunctionHandle(0x01)); // Ensure that overfilling `names` by 1 char is not allowed. - snprintf_truncate(test_function_name, kTvmAverageFunctionNameStrlenBytes + rem + 2, + snprintf_truncate(test_function_name.get(), kTvmAverageFunctionNameStrlenBytes + rem + 2, function_name_chars + 1); - EXPECT_EQ(kTvmErrorFunctionRegistryFull, - TVMMutableFuncRegistry_Set(®, test_function_name, TestFunctionHandle(0x02), 0)); + EXPECT_EQ( + kTvmErrorFunctionRegistryFull, + TVMMutableFuncRegistry_Set(®, test_function_name.get(), TestFunctionHandle(0x02), 0)); EXPECT_EQ(kTvmErrorFunctionNameNotFound, - TVMFuncRegistry_Lookup(®.registry, test_function_name, &func_index)); + TVMFuncRegistry_Lookup(®.registry, test_function_name.get(), &func_index)); // Add function #2, with intentionally short (by 2 char) name. Verify it can be retrieved. - snprintf_truncate(test_function_name, kTvmAverageFunctionNameStrlenBytes - 2 + 1, + snprintf_truncate(test_function_name.get(), kTvmAverageFunctionNameStrlenBytes - 2 + 1, function_name_chars + 1); - EXPECT_EQ(kTvmErrorNoError, - TVMMutableFuncRegistry_Set(®, test_function_name, TestFunctionHandle(0x02), 0)); + EXPECT_EQ(kTvmErrorNoError, TVMMutableFuncRegistry_Set(®, test_function_name.get(), + TestFunctionHandle(0x02), 0)); EXPECT_EQ(kTvmErrorNoError, - TVMFuncRegistry_Lookup(®.registry, test_function_name, &func_index)); + TVMFuncRegistry_Lookup(®.registry, test_function_name.get(), &func_index)); EXPECT_EQ(func_index, 1); func = NULL; @@ -228,13 +228,8 @@ TEST(MutableFuncRegistry, Create) { // Try adding another function, which should fail due to lack of function pointers. test_function_name[0] = 'a'; test_function_name[1] = 0; - EXPECT_EQ(kTvmErrorFunctionRegistryFull, - TVMMutableFuncRegistry_Set(®, test_function_name, TestFunctionHandle(0x03), 0)); + EXPECT_EQ( + kTvmErrorFunctionRegistryFull, + TVMMutableFuncRegistry_Set(®, test_function_name.get(), TestFunctionHandle(0x03), 0)); } } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/crt/memory_test.cc b/tests/crt/page_allocator_test.cc similarity index 86% rename from tests/crt/memory_test.cc rename to tests/crt/page_allocator_test.cc index b531383058e6..924bf295ffd2 100644 --- a/tests/crt/memory_test.cc +++ b/tests/crt/page_allocator_test.cc @@ -32,13 +32,14 @@ static constexpr const unsigned int kNumUsablePages = static constexpr const unsigned int kPageSizeBytesLog = 8; // 256 byte pages. static constexpr const unsigned int kMemoryPoolSizeBytes = kTotalPages * (1 << kPageSizeBytesLog); -class MemoryManagerTest : public ::testing::Test { +class PageAllocatorTest : public ::testing::Test { protected: void SetUp() override { memset(raw_memory_pool, 0, sizeof(raw_memory_pool)); - memory_pool = (uint8_t*)(ROUND_UP(((uintptr_t)raw_memory_pool), (1 << kPageSizeBytesLog))); + memory_pool = reinterpret_cast( + ROUND_UP(((uintptr_t)raw_memory_pool), (1 << kPageSizeBytesLog))); PageMemoryManagerCreate(&interface, memory_pool, kMemoryPoolSizeBytes, kPageSizeBytesLog); - mgr = (MemoryManager*)interface; + mgr = reinterpret_cast(interface); ASSERT_EQ(kNumUsablePages, mgr->ptable.max_pages); dev_ = {kDLCPU, 0}; } @@ -57,7 +58,7 @@ class MemoryManagerTest : public ::testing::Test { #define EXPECT_PAGE(expected, actual) EXPECT_EQ(expected, AddressToPageNumber(actual)) -TEST_F(MemoryManagerTest, AllocFreeFifo) { +TEST_F(PageAllocatorTest, AllocFreeFifo) { EXPECT_EQ(interface->vleak_size, 0); for (int i = 0; i < 2; i++) { @@ -70,7 +71,7 @@ TEST_F(MemoryManagerTest, AllocFreeFifo) { } else { EXPECT_PAGE(kNumUsablePages - 1 - idx, a); } - EXPECT_EQ(interface->vleak_size, idx + 1); + EXPECT_EQ(static_cast(interface->vleak_size), idx + 1); ptrs[idx] = a; } @@ -80,9 +81,3 @@ TEST_F(MemoryManagerTest, AllocFreeFifo) { } } } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/crt/session_test.cc b/tests/crt/session_test.cc index 9840f55dc685..b6b58e819700 100644 --- a/tests/crt/session_test.cc +++ b/tests/crt/session_test.cc @@ -27,7 +27,6 @@ #include "buffer_write_stream.h" #include "crt_config.h" -#include "platform.cc" using ::tvm::runtime::micro_rpc::Framer; using ::tvm::runtime::micro_rpc::MessageType; @@ -52,7 +51,7 @@ class ReceivedMessage { class TestSession { public: - TestSession(uint8_t initial_nonce) + explicit TestSession(uint8_t initial_nonce) : framer{&framer_write_stream}, receive_buffer{receive_buffer_array, sizeof(receive_buffer_array)}, sess{&framer, &receive_buffer, TestSessionMessageReceivedThunk, this}, @@ -107,16 +106,6 @@ void TestSessionMessageReceivedThunk(void* context, MessageType message_type, Fr } } -void PrintTo(tvm_crt_error_t p, std::ostream* os) { - std::ios_base::fmtflags f(os->flags()); - *os << "tvm_crt_error_t(0x" << std::hex << std::setw(8) << std::setfill('0') << p << ")"; - os->flags(f); -} - -void PrintTo(ReceivedMessage msg, std::ostream* os) { - *os << "ReceivedMessage(" << int(msg.type) << ", \"" << msg.message << "\")"; -} - class SessionTest : public ::testing::Test { public: static constexpr const uint8_t kAliceNonce = 0x3c; @@ -158,7 +147,7 @@ TEST_F(SessionTest, NormalExchange) { bob_.WriteTo(&alice_); EXPECT_TRUE(alice_.sess.IsEstablished()); - ASSERT_EQ(alice_.messages_received.size(), 1); + ASSERT_EQ(alice_.messages_received.size(), 1UL); EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kStartSessionReply, "")); alice_.ClearBuffers(); @@ -167,7 +156,7 @@ TEST_F(SessionTest, NormalExchange) { "\xFF\xFD\b\0\0\0\x82" "f\x10hello\x90("); alice_.WriteTo(&bob_); - ASSERT_EQ(bob_.messages_received.size(), 2); + ASSERT_EQ(bob_.messages_received.size(), 2UL); EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kStartSessionReply, "")); EXPECT_EQ(bob_.messages_received[1], ReceivedMessage(MessageType::kNormal, "hello")); @@ -177,7 +166,7 @@ TEST_F(SessionTest, NormalExchange) { "\xff\xfd\b\0\0\0\x82" "f\x10ollehLv"); bob_.WriteTo(&alice_); - ASSERT_EQ(alice_.messages_received.size(), 1); + ASSERT_EQ(alice_.messages_received.size(), 1UL); EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kNormal, "olleh")); alice_.ClearBuffers(); @@ -186,13 +175,13 @@ TEST_F(SessionTest, NormalExchange) { alice_.sess.SendMessage(MessageType::kLog, reinterpret_cast("log1"), 4); EXPECT_FRAMED_PACKET(alice_, "\xff\xfd\a\0\0\0\0\0\x03log1\xf0\xd4"); alice_.WriteTo(&bob_); - ASSERT_EQ(bob_.messages_received.size(), 1); + ASSERT_EQ(bob_.messages_received.size(), 1UL); EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kLog, "log1")); bob_.sess.SendMessage(MessageType::kLog, reinterpret_cast("zero"), 4); EXPECT_FRAMED_PACKET(bob_, "\xff\xfd\a\0\0\0\0\0\x03zero\xb2h"); bob_.WriteTo(&alice_); - ASSERT_EQ(alice_.messages_received.size(), 1); + ASSERT_EQ(alice_.messages_received.size(), 1UL); EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kLog, "zero")); } @@ -200,13 +189,13 @@ TEST_F(SessionTest, LogBeforeSessionStart) { alice_.sess.SendMessage(MessageType::kLog, reinterpret_cast("log1"), 4); EXPECT_FRAMED_PACKET(alice_, "\xfe\xff\xfd\a\0\0\0\0\0\x03log1\xf0\xd4"); alice_.WriteTo(&bob_); - ASSERT_EQ(bob_.messages_received.size(), 1); + ASSERT_EQ(bob_.messages_received.size(), 1UL); EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kLog, "log1")); bob_.sess.SendMessage(MessageType::kLog, reinterpret_cast("zero"), 4); EXPECT_FRAMED_PACKET(bob_, "\xfe\xff\xfd\a\0\0\0\0\0\x03zero\xb2h"); bob_.WriteTo(&alice_); - ASSERT_EQ(alice_.messages_received.size(), 1); + ASSERT_EQ(alice_.messages_received.size(), 1UL); EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kLog, "zero")); } @@ -259,9 +248,3 @@ TEST_F(SessionTest, DoubleStart) { alice_.WriteTo(&bob_); EXPECT_TRUE(bob_.sess.IsEstablished()); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/crt/aot_memory_test.cc b/tests/crt/stack_allocator_test.cc similarity index 50% rename from tests/crt/aot_memory_test.cc rename to tests/crt/stack_allocator_test.cc index abda7bebf766..cd0c4a8b65e2 100644 --- a/tests/crt/aot_memory_test.cc +++ b/tests/crt/stack_allocator_test.cc @@ -16,91 +16,134 @@ * specific language governing permissions and limitations * under the License. */ -#include -#include #include "../../src/runtime/crt/memory/stack_allocator.c" -#include "platform.cc" + +#include +#include // Check with LIFO checks enabled for stack allocator #define TVM_CRT_STACK_ALLOCATOR_ENABLE_LIFO_CHECK + +// Number of memory misalignment in bytes +#define NUM_MEMORY_MISALIGNMENT_BYTES 1 + +/*! + * Align memory pointer. + * This function modifies memory_ptr to adjust alignment. + * \return Number of memory offset. + */ +static uint32_t align_pointer(uint8_t** memory_ptr) { + uint32_t extra = (uintptr_t)(*memory_ptr) % TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES; + uint32_t offset = + (TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES - extra) & (TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES - 1); + *memory_ptr += offset; + return offset; +} + +/*! + * Add misalignment to memory pointer. + * This function modifies memory_ptr. + * \return Number of memory offset. + */ +static uint32_t misalign_pointer(uint8_t** memory_ptr) { + uint32_t extra = (uintptr_t)(*memory_ptr) % TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES; + if (extra == 0) { + *memory_ptr += NUM_MEMORY_MISALIGNMENT_BYTES; + return 1; + } + return 0; +} + /* - * Tests allocations are properly aligned when allocated + * Tests allocations are properly aligned when allocated. */ -TEST(AOTMemory, Allocate) { - static uint8_t model_memory[96]; +TEST(StackAllocatorTest, Allocate) { + static uint8_t model_memory[128]; tvm_workspace_t tvm_runtime_workspace; + uint8_t* model_memory_ptr = model_memory; + uint32_t offset = align_pointer(&model_memory_ptr); + ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory_ptr, + sizeof(model_memory) - offset), + kTvmErrorNoError); - ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory, 96), kTvmErrorNoError); void* block_one = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_one, 1), kTvmErrorNoError); - ASSERT_EQ(block_one, &model_memory[0]); + ASSERT_EQ(block_one, &model_memory_ptr[0]); void* block_two = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 2, &block_two, 1), kTvmErrorNoError); - ASSERT_EQ(block_two, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(block_two, &model_memory_ptr[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); void* two_blocks = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 24, &two_blocks, 1), kTvmErrorNoError); - ASSERT_EQ(two_blocks, &model_memory[32 + 2 * STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(two_blocks, &model_memory_ptr[32 + 2 * STACK_ALLOCATOR_TAG_SIZE_BYTES]); void* block_three = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_three, 1), kTvmErrorNoError); - ASSERT_EQ(block_three, &model_memory[64 + 3 * STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(block_three, &model_memory_ptr[64 + 3 * STACK_ALLOCATOR_TAG_SIZE_BYTES]); } /* - * Tests resetting the stack after dealloc + * Tests resetting the stack after dealloc. */ -TEST(AOTMemory, Free) { +TEST(StackAllocatorTest, Free) { static uint8_t model_memory[80]; tvm_workspace_t tvm_runtime_workspace; - ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory, 80), kTvmErrorNoError); + uint8_t* model_memory_ptr = model_memory; + uint32_t offset = align_pointer(&model_memory_ptr); + ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory_ptr, + sizeof(model_memory) - offset), + kTvmErrorNoError); void* block_one = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_one, 1), kTvmErrorNoError); - ASSERT_EQ(block_one, &model_memory[0]); + ASSERT_EQ(block_one, &model_memory_ptr[0]); void* block_two = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_two, 1), kTvmErrorNoError); - ASSERT_EQ(block_two, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(block_two, &model_memory_ptr[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); ASSERT_EQ(kTvmErrorNoError, StackMemoryManager_Free_Body(&tvm_runtime_workspace, block_two, 1)); void* two_blocks = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 2, &two_blocks, 1), kTvmErrorNoError); - ASSERT_EQ(two_blocks, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(two_blocks, &model_memory_ptr[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); ASSERT_EQ(kTvmErrorNoError, StackMemoryManager_Free_Body(&tvm_runtime_workspace, two_blocks, 1)); void* block_three = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_three, 1), kTvmErrorNoError); - ASSERT_EQ(block_three, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(block_three, &model_memory_ptr[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); } /* - * Tests we return NULL if we over allocate + * Tests we return NULL if we over allocate. */ -TEST(AOTMemory, OverAllocate) { - static uint8_t model_memory[72]; +TEST(StackAllocatorTest, OverAllocate) { + static uint8_t model_memory[80]; tvm_workspace_t tvm_runtime_workspace; - ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory, 80), kTvmErrorNoError); + uint8_t* model_memory_ptr = model_memory; + uint32_t offset = align_pointer(&model_memory_ptr); + ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory_ptr, + sizeof(model_memory) - offset), + kTvmErrorNoError); void* block_one = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_one, 1), kTvmErrorNoError); - ASSERT_EQ(block_one, &model_memory[0]); + ASSERT_EQ(block_one, &model_memory_ptr[0]); void* block_two = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_two, 1), kTvmErrorNoError); - ASSERT_EQ(block_two, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(block_two, &model_memory_ptr[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); void* two_blocks = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 64, &two_blocks, 1), @@ -109,29 +152,50 @@ TEST(AOTMemory, OverAllocate) { } /* - * Test for out-of-order memory deallocation + * Test for out-of-order memory deallocation. */ -TEST(AOTMemory, FreeOutOfOrder) { +TEST(StackAllocatorTest, FreeOutOfOrder) { static uint8_t model_memory[80]; tvm_workspace_t tvm_runtime_workspace; - ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory, 80), kTvmErrorNoError); + uint8_t* model_memory_ptr = model_memory; + uint32_t offset = align_pointer(&model_memory_ptr); + ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory_ptr, + sizeof(model_memory) - offset), + kTvmErrorNoError); void* block_one = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_one, 1), kTvmErrorNoError); - ASSERT_EQ(block_one, &model_memory[0]); + ASSERT_EQ(block_one, &model_memory_ptr[0]); void* block_two = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_two, 1), kTvmErrorNoError); - ASSERT_EQ(block_two, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(block_two, &model_memory_ptr[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); ASSERT_EQ(StackMemoryManager_Free_Body(&tvm_runtime_workspace, block_one, 1), kTvmErrorPlatformStackAllocBadFree); } -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); +/* + * Test for initial memory misalignment. + */ +TEST(StackAllocatorTest, InitialMemoryMisAlignment) { + static uint8_t model_memory[80]; + tvm_workspace_t tvm_runtime_workspace; + uint8_t* model_memory_ptr = model_memory; + + // Add misaslignment to memory pointer + uint32_t offset = misalign_pointer(&model_memory_ptr); + + // Calculate expected offset + uint8_t* misaligned_ptr = model_memory_ptr; + uint32_t alignment_offset = align_pointer(&misaligned_ptr); + + ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory_ptr, + sizeof(model_memory) - offset), + kTvmErrorNoError); + + ASSERT_EQ(tvm_runtime_workspace.next_alloc, &model_memory_ptr[alignment_offset]); + ASSERT_EQ(tvm_runtime_workspace.workspace_size, sizeof(model_memory) - offset - alignment_offset); } diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index 7250a3ccd7d0..ed7288ef00d4 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -80,6 +80,12 @@ "idl", # opencl file "cl", + # zephyr config file + "conf", + # arduino sketch file + "ino", + # linker scripts + "ld", } # List of file names allowed @@ -127,35 +133,17 @@ # pytest config "pytest.ini", # microTVM tests - "tests/micro/zephyr/testdata/digit-2.jpg", - "tests/micro/zephyr/testdata/digit-9.jpg", - "tests/micro/zephyr/testdata/mnist-8.onnx", - "tests/micro/zephyr/testdata/ic_sample_fp32_8.npy", + "tests/micro/testdata/mnist/digit-2.jpg", + "tests/micro/testdata/mnist/digit-9.jpg", + "tests/micro/testdata/mnist/mnist-8.onnx", + "tests/micro/testdata/kws/yes_no.tflite", # microTVM Zephyr runtime - "apps/microtvm/zephyr/qemu-hack/qemu-system-i386", - "apps/microtvm/zephyr/qemu-hack/qemu-system-arm", - "apps/microtvm/zephyr/qemu-hack/qemu-system-riscv32", - "apps/microtvm/zephyr/qemu-hack/qemu-system-riscv64", - "apps/microtvm/zephyr/host_driven/prj.conf", - "apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf", - "apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf", - "apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf", - "apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf", - "apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf", - "apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf", - "apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf", - "apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf", - "apps/microtvm/zephyr/host_driven/qemu-hack", - "apps/microtvm/zephyr/aot_demo/prj.conf", - "apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf", - "apps/microtvm/zephyr/aot_demo/boards/qemu_riscv32.conf", - "apps/microtvm/zephyr/aot_demo/boards/qemu_riscv64.conf", - "apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf", - "apps/microtvm/zephyr/aot_demo/boards/nucleo_f746zg.conf", - "apps/microtvm/zephyr/aot_demo/boards/stm32f746g_disco.conf", - "apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf", - "apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf", - "apps/microtvm/zephyr/aot_demo/qemu-hack", + "apps/microtvm/zephyr/template_project/CMakeLists.txt.template", + "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm", + "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64", + "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-i386", + "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32", + "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64", # microTVM Virtual Machines "apps/microtvm/reference-vm/zephyr/Vagrantfile", "apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template", diff --git a/tests/lint/cpplint.sh b/tests/lint/cpplint.sh index 8836ee4321d9..31eb1d94a347 100755 --- a/tests/lint/cpplint.sh +++ b/tests/lint/cpplint.sh @@ -19,5 +19,6 @@ python3 3rdparty/dmlc-core/scripts/lint.py vta cpp vta/include vta/src python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp \ - include src \ - examples/extension/src examples/graph_executor/src + include src \ + examples/extension/src examples/graph_executor/src \ + tests/cpp tests/crt diff --git a/apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf b/tests/lint/flake8.sh old mode 100644 new mode 100755 similarity index 75% rename from apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf rename to tests/lint/flake8.sh index 505f1babc3f4..43ade61c7893 --- a/apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf +++ b/tests/lint/flake8.sh @@ -1,3 +1,4 @@ +#!/bin/bash -e # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -14,15 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -# This file is specific to the STM32F746G Discovery board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y -# For debugging. -CONFIG_LED=n +# Disabled until docker images are rebuilt +# python3 -m flake8 . --count --select=E9,F63,F7 --show-source --statistics diff --git a/tests/lint/rat-excludes b/tests/lint/rat-excludes index 5f0445134dea..3dff79c565ce 100644 --- a/tests/lint/rat-excludes +++ b/tests/lint/rat-excludes @@ -20,6 +20,9 @@ .*\.interp .*\.tokens +# microTVM test data files +testdata + # Generated modules .*\.egg-info .*gen_modules diff --git a/tests/lint/rust_format.sh b/tests/lint/rust_format.sh new file mode 100755 index 000000000000..10c8feec1fcf --- /dev/null +++ b/tests/lint/rust_format.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +TVM_HOME="$(git rev-parse --show-toplevel)" +RUST_DIR="$TVM_HOME/rust" + +if [[ "$1" == "-i" ]]; then + INPLACE_FORMAT=1 + shift 1 +else + INPLACE_FORMAT=0 +fi + +cd $RUST_DIR + +if [[ ${INPLACE_FORMAT} -eq 1 ]]; then + cargo fmt +else + cargo fmt -- --check +fi diff --git a/tests/micro/arduino/.gitignore b/tests/micro/arduino/.gitignore new file mode 100644 index 000000000000..30bf0f9bc376 --- /dev/null +++ b/tests/micro/arduino/.gitignore @@ -0,0 +1 @@ +workspace* diff --git a/apps/microtvm/zephyr/aot_demo/README.md b/tests/micro/arduino/README.md similarity index 59% rename from apps/microtvm/zephyr/aot_demo/README.md rename to tests/micro/arduino/README.md index a718da65e2fa..78e63cabb7e2 100644 --- a/apps/microtvm/zephyr/aot_demo/README.md +++ b/tests/micro/arduino/README.md @@ -15,6 +15,21 @@ -This directory contains a Zephyr-based ahead of time (AOT) "demo" runtime environment that -pulls together the microTVM runtime dependencies into a single application -that can run TVM on a microTVM device without the need to a host. +This directory contains tests for MicroTVM's integration with Arduino. + +To run the test, you first need to be running in a Python environment with +all of the appropriate TVM dependencies installed. You can run the test with: + +``` +$ cd tvm/tests/micro/arduino +$ pytest --microtvm-platforms spresense +``` + +Most of these tests require a supported Arduino board to be connected. +If you don't want to run these tests, you can pass the flag +`--test-build-only` to only test project generation and compilation. + +To see the list of supported values for `----microtvm-platforms`, run: +``` +$ pytest --help +``` diff --git a/tests/micro/arduino/conftest.py b/tests/micro/arduino/conftest.py new file mode 100644 index 000000000000..bcb2bddf2cab --- /dev/null +++ b/tests/micro/arduino/conftest.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +import pathlib + +import pytest +import tvm.target.target + +# The models that should pass this configuration. Maps a short, identifying platform string to +# (model, zephyr_board). +PLATFORMS = { + "due": ("sam3x8e", "due"), + "feathers2": ("esp32", "feathers2"), + "nano33ble": ("nrf52840", "nano33ble"), + "pybadge": ("atsamd51", "pybadge"), + "spresense": ("cxd5602gg", "spresense"), + "teensy40": ("imxrt1060", "teensy40"), + "teensy41": ("imxrt1060", "teensy41"), + "wioterminal": ("atsamd51", "wioterminal"), +} + +TEMPLATE_PROJECT_DIR = ( + pathlib.Path(__file__).parent + / ".." + / ".." + / ".." + / "apps" + / "microtvm" + / "arduino" + / "template_project" +).resolve() + + +def pytest_addoption(parser): + parser.addoption( + "--microtvm-platforms", + nargs="+", + required=True, + choices=PLATFORMS.keys(), + help="Target platforms for microTVM tests.", + ) + parser.addoption( + "--arduino-cli-cmd", + default="arduino-cli", + help="Path to `arduino-cli` command for flashing device.", + ) + parser.addoption( + "--test-build-only", + action="store_true", + help="Only run tests that don't require physical hardware.", + ) + parser.addoption( + "--tvm-debug", + action="store_true", + default=False, + help="If given, enable a debug session while the test is running. Before running the test, in a separate shell, you should run: ", + ) + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "requires_hardware: mark test to run only when an Arduino board is connected" + ) + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--test-build-only"): + skip_hardware_tests = pytest.mark.skip(reason="--test-build-only was passed") + for item in items: + if "requires_hardware" in item.keywords: + item.add_marker(skip_hardware_tests) + + +# We might do project generation differently for different boards in the future +# (to take advantage of multiple cores / external memory / etc.), so all tests +# are parameterized by board +def pytest_generate_tests(metafunc): + platforms = metafunc.config.getoption("microtvm_platforms") + metafunc.parametrize("platform", platforms, scope="session") + + +@pytest.fixture(scope="session") +def arduino_cli_cmd(request): + return request.config.getoption("--arduino-cli-cmd") + + +@pytest.fixture(scope="session") +def tvm_debug(request): + return request.config.getoption("--tvm-debug") + + +def make_workspace_dir(test_name, platform): + _, arduino_board = PLATFORMS[platform] + filepath = pathlib.Path(__file__) + board_workspace = ( + filepath.parent + / f"workspace_{test_name}_{arduino_board}" + / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + ) + + number = 0 + while board_workspace.exists(): + number += 1 + board_workspace = pathlib.Path(str(board_workspace) + f"-{number}") + board_workspace.parent.mkdir(exist_ok=True, parents=True) + t = tvm.contrib.utils.tempdir(board_workspace) + # time.sleep(200) + return t diff --git a/tests/micro/arduino/test_arduino_rpc_server.py b/tests/micro/arduino/test_arduino_rpc_server.py new file mode 100644 index 000000000000..1b165a02e9d1 --- /dev/null +++ b/tests/micro/arduino/test_arduino_rpc_server.py @@ -0,0 +1,368 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This unit test simulates an autotuning workflow, where we: +1. Instantiate the Arduino RPC server project +2. Build and flash that project onto our target board + +""" + +import datetime +import pathlib +import sys + +import numpy as np +import onnx +import pytest +import tvm +from PIL import Image +from tvm import micro, relay +from tvm.relay.testing import byoc + +import conftest + + +# We'll make a new workspace for each test +@pytest.fixture(scope="function") +def workspace_dir(platform): + return conftest.make_workspace_dir("arduino_rpc_server", platform) + + +def _make_session(model, arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config): + project = tvm.micro.generate_project( + str(conftest.TEMPLATE_PROJECT_DIR), + mod, + workspace_dir / "project", + { + "arduino_board": arduino_board, + "arduino_cli_cmd": arduino_cli_cmd, + "project_type": "host_driven", + "verbose": bool(build_config.get("debug")), + }, + ) + project.build() + project.flash() + return tvm.micro.Session(project.transport()) + + +def _make_sess_from_op( + model, arduino_board, arduino_cli_cmd, workspace_dir, op_name, sched, arg_bufs, build_config +): + target = tvm.target.target.micro(model) + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.build(sched, arg_bufs, target=target, name=op_name) + + return _make_session(model, arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config) + + +def _make_add_sess(model, arduino_board, arduino_cli_cmd, workspace_dir, build_config): + A = tvm.te.placeholder((2,), dtype="int8") + B = tvm.te.placeholder((1,), dtype="int8") + C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C") + sched = tvm.te.create_schedule(C.op) + return _make_sess_from_op( + model, arduino_board, arduino_cli_cmd, workspace_dir, "add", sched, [A, B, C], build_config + ) + + +# The same test code can be executed on both the QEMU simulation and on real hardware. +@tvm.testing.requires_micro +@pytest.mark.requires_hardware +def test_compile_runtime(platform, arduino_cli_cmd, tvm_debug, workspace_dir): + """Test compiling the on-device runtime.""" + + model, arduino_board = conftest.PLATFORMS[platform] + build_config = {"debug": tvm_debug} + + # NOTE: run test in a nested function so cPython will delete arrays before closing the session. + def test_basic_add(sess): + A_data = tvm.nd.array(np.array([2, 3], dtype="int8"), device=sess.device) + assert (A_data.numpy() == np.array([2, 3])).all() + B_data = tvm.nd.array(np.array([4], dtype="int8"), device=sess.device) + assert (B_data.numpy() == np.array([4])).all() + C_data = tvm.nd.array(np.array([0, 0], dtype="int8"), device=sess.device) + assert (C_data.numpy() == np.array([0, 0])).all() + + system_lib = sess.get_system_lib() + system_lib.get_function("add")(A_data, B_data, C_data) + assert (C_data.numpy() == np.array([6, 7])).all() + + with _make_add_sess(model, arduino_board, arduino_cli_cmd, workspace_dir, build_config) as sess: + test_basic_add(sess) + + +@tvm.testing.requires_micro +@pytest.mark.requires_hardware +def test_platform_timer(platform, arduino_cli_cmd, tvm_debug, workspace_dir): + """Test compiling the on-device runtime.""" + + model, arduino_board = conftest.PLATFORMS[platform] + build_config = {"debug": tvm_debug} + + # NOTE: run test in a nested function so cPython will delete arrays before closing the session. + def test_basic_add(sess): + A_data = tvm.nd.array(np.array([2, 3], dtype="int8"), device=sess.device) + assert (A_data.numpy() == np.array([2, 3])).all() + B_data = tvm.nd.array(np.array([4], dtype="int8"), device=sess.device) + assert (B_data.numpy() == np.array([4])).all() + C_data = tvm.nd.array(np.array([0, 0], dtype="int8"), device=sess.device) + assert (C_data.numpy() == np.array([0, 0])).all() + + system_lib = sess.get_system_lib() + time_eval_f = system_lib.time_evaluator( + "add", sess.device, number=20, repeat=3, min_repeat_ms=40 + ) + result = time_eval_f(A_data, B_data, C_data) + assert (C_data.numpy() == np.array([6, 7])).all() + assert result.mean > 0 + assert len(result.results) == 3 + + with _make_add_sess(model, arduino_board, arduino_cli_cmd, workspace_dir, build_config) as sess: + test_basic_add(sess) + + +@tvm.testing.requires_micro +@pytest.mark.requires_hardware +def test_relay(platform, arduino_cli_cmd, tvm_debug, workspace_dir): + """Testing a simple relay graph""" + model, arduino_board = conftest.PLATFORMS[platform] + build_config = {"debug": tvm_debug} + + shape = (10,) + dtype = "int8" + + # Construct Relay program. + x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) + xx = relay.multiply(x, x) + z = relay.add(xx, relay.const(np.ones(shape=shape, dtype=dtype))) + func = relay.Function([x], z) + + target = tvm.target.target.micro(model) + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build(func, target=target) + + with _make_session( + model, arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config + ) as session: + graph_mod = tvm.micro.create_local_graph_executor( + mod.get_graph_json(), session.get_system_lib(), session.device + ) + graph_mod.set_input(**mod.get_params()) + x_in = np.random.randint(10, size=shape[0], dtype=dtype) + graph_mod.run(x=x_in) + result = graph_mod.get_output(0).numpy() + tvm.testing.assert_allclose(graph_mod.get_input(0).numpy(), x_in) + tvm.testing.assert_allclose(result, x_in * x_in + 1) + + +@tvm.testing.requires_micro +@pytest.mark.requires_hardware +def test_onnx(platform, arduino_cli_cmd, tvm_debug, workspace_dir): + """Testing a simple ONNX model.""" + model, arduino_board = conftest.PLATFORMS[platform] + build_config = {"debug": tvm_debug} + + # Load test images. + this_dir = pathlib.Path(__file__).parent + mnist_testdata = this_dir.parent / "testdata" / "mnist" + digit_2 = Image.open(mnist_testdata / "digit-2.jpg").resize((28, 28)) + digit_2 = np.asarray(digit_2).astype("float32") + digit_2 = np.expand_dims(digit_2, axis=0) + + digit_9 = Image.open(mnist_testdata / "digit-9.jpg").resize((28, 28)) + digit_9 = np.asarray(digit_9).astype("float32") + digit_9 = np.expand_dims(digit_9, axis=0) + + # Load ONNX model and convert to Relay. + onnx_model = onnx.load(mnist_testdata / "mnist-8.onnx") + shape = {"Input3": (1, 1, 28, 28)} + relay_mod, params = relay.frontend.from_onnx(onnx_model, shape=shape, freeze_params=True) + relay_mod = relay.transform.DynamicToStatic()(relay_mod) + + target = tvm.target.target.micro(model, options=["-link-params=1"]) + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + lowered = relay.build(relay_mod, target, params=params) + graph = lowered.get_graph_json() + + with _make_session( + model, arduino_board, arduino_cli_cmd, workspace_dir, lowered, build_config + ) as session: + graph_mod = tvm.micro.create_local_graph_executor( + graph, session.get_system_lib(), session.device + ) + + # Send the digit-2 image and confirm that the correct result is returned. + graph_mod.set_input("Input3", tvm.nd.array(digit_2)) + graph_mod.run() + result = graph_mod.get_output(0).numpy() + print(result) + assert np.argmax(result) == 2 + + # Send the digit-9 image and confirm that the correct result is returned. + graph_mod.set_input("Input3", tvm.nd.array(digit_9)) + graph_mod.run() + result = graph_mod.get_output(0).numpy() + assert np.argmax(result) == 9 + + +def check_result( + relay_mod, + model, + arduino_board, + arduino_cli_cmd, + workspace_dir, + map_inputs, + out_shape, + result, + build_config, +): + """Helper function to verify results""" + TOL = 1e-5 + target = tvm.target.target.micro(model) + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build(relay_mod, target=target) + + with _make_session( + model, arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config + ) as session: + rt_mod = tvm.micro.create_local_graph_executor( + mod.get_graph_json(), session.get_system_lib(), session.device + ) + rt_mod.set_input(**mod.get_params()) + for name, data in map_inputs.items(): + rt_mod.set_input(name, data) + rt_mod.set_input(**mod.get_params()) + rt_mod.run() + + out_shapes = out_shape if isinstance(out_shape, list) else [out_shape] + results = result if isinstance(result, list) else [result] + + for idx, shape in enumerate(out_shapes): + out = tvm.nd.empty(shape, device=session.device) + out = rt_mod.get_output(idx, out) + tvm.testing.assert_allclose(out.numpy(), results[idx], rtol=TOL, atol=TOL) + + +@tvm.testing.requires_micro +@pytest.mark.requires_hardware +def test_byoc_microtvm(platform, arduino_cli_cmd, tvm_debug, workspace_dir): + """This is a simple test case to check BYOC capabilities of microTVM""" + model, arduino_board = conftest.PLATFORMS[platform] + build_config = {"debug": tvm_debug} + + x = relay.var("x", shape=(10, 10)) + w0 = relay.var("w0", shape=(10, 10)) + w1 = relay.var("w1", shape=(10, 10)) + w2 = relay.var("w2", shape=(10, 10)) + w3 = relay.var("w3", shape=(10, 10)) + w4 = relay.var("w4", shape=(10, 10)) + w5 = relay.var("w5", shape=(10, 10)) + w6 = relay.var("w6", shape=(10, 10)) + w7 = relay.var("w7", shape=(10, 10)) + + # C compiler + z0 = relay.add(x, w0) + p0 = relay.subtract(z0, w1) + q0 = relay.multiply(p0, w2) + + z1 = relay.add(x, w3) + p1 = relay.subtract(z1, w4) + q1 = relay.multiply(p1, w5) + + # Other parts on TVM + z2 = relay.add(x, w6) + q2 = relay.subtract(z2, w7) + + r = relay.concatenate((q0, q1, q2), axis=0) + f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) + mod = tvm.IRModule() + ann = byoc.CcompilerAnnotator() + mod["main"] = ann.visit(f) + mod = tvm.relay.transform.PartitionGraph()(mod) + mod = tvm.relay.transform.InferType()(mod) + + x_data = np.random.rand(10, 10).astype("float32") + w_data = [] + for _ in range(8): + w_data.append(np.random.rand(10, 10).astype("float32")) + + map_inputs = {"w{}".format(i): w_data[i] for i in range(8)} + map_inputs["x"] = x_data + check_result( + relay_mod=mod, + map_inputs=map_inputs, + out_shape=(30, 10), + result=np.concatenate( + ( + ((x_data + w_data[0]) - w_data[1]) * w_data[2], + ((x_data + w_data[3]) - w_data[4]) * w_data[5], + x_data + w_data[6] - w_data[7], + ), + axis=0, + ), + model=model, + build_config=build_config, + arduino_board=arduino_board, + arduino_cli_cmd=arduino_cli_cmd, + workspace_dir=workspace_dir, + ) + + +def _make_add_sess_with_shape( + model, arduino_board, arduino_cli_cmd, workspace_dir, shape, build_config +): + A = tvm.te.placeholder(shape, dtype="int8") + C = tvm.te.compute(A.shape, lambda i: A[i] + A[i], name="C") + sched = tvm.te.create_schedule(C.op) + return _make_sess_from_op( + model, arduino_board, arduino_cli_cmd, workspace_dir, "add", sched, [A, C], build_config + ) + + +@pytest.mark.parametrize( + "shape,", + [ + pytest.param((1 * 1024,), id="(1*1024)"), + pytest.param((4 * 1024,), id="(4*1024)"), + pytest.param((16 * 1024,), id="(16*1024)"), + ], +) +@tvm.testing.requires_micro +@pytest.mark.requires_hardware +def test_rpc_large_array(platform, arduino_cli_cmd, tvm_debug, workspace_dir, shape): + """Test large RPC array transfer.""" + model, arduino_board = conftest.PLATFORMS[platform] + build_config = {"debug": tvm_debug} + + # NOTE: run test in a nested function so cPython will delete arrays before closing the session. + def test_tensors(sess): + a_np = np.random.randint(low=-128, high=127, size=shape, dtype="int8") + + A_data = tvm.nd.array(a_np, device=sess.device) + assert (A_data.numpy() == a_np).all() + C_data = tvm.nd.array(np.zeros(shape, dtype="int8"), device=sess.device) + assert (C_data.numpy() == np.zeros(shape)).all() + + with _make_add_sess_with_shape( + model, arduino_board, arduino_cli_cmd, workspace_dir, shape, build_config + ) as sess: + test_tensors(sess) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/micro/arduino/test_arduino_workflow.py b/tests/micro/arduino/test_arduino_workflow.py new file mode 100644 index 000000000000..101d36f9bd2d --- /dev/null +++ b/tests/micro/arduino/test_arduino_workflow.py @@ -0,0 +1,253 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +import pathlib +import shutil +import sys + +import pytest +import tvm +from tvm import micro, relay + +import conftest + +""" +This unit test simulates a simple user workflow, where we: +1. Generate a base sketch using a simple audio model +2. Modify the .ino file, much like a user would +3. Compile the sketch for the target platform +-- If physical hardware is present -- +4. Upload the sketch to a connected board +5. Open a serial connection to the board +6. Use serial connection to ensure model behaves correctly +""" + + +# Since these tests are sequential, we'll use the same project for all tests +@pytest.fixture(scope="module") +def workspace_dir(request, platform): + return conftest.make_workspace_dir("arduino_workflow", platform) + + +@pytest.fixture(scope="module") +def project_dir(workspace_dir): + return workspace_dir / "project" + + +def _generate_project(arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config): + return tvm.micro.generate_project( + str(conftest.TEMPLATE_PROJECT_DIR), + mod, + workspace_dir / "project", + { + "arduino_board": arduino_board, + "arduino_cli_cmd": arduino_cli_cmd, + "project_type": "example_project", + "verbose": bool(build_config.get("debug")), + }, + ) + + +# We MUST pass workspace_dir, not project_dir, or the workspace will be dereferenced too soon +@pytest.fixture(scope="module") +def project(platform, arduino_cli_cmd, tvm_debug, workspace_dir): + this_dir = pathlib.Path(__file__).parent + model, arduino_board = conftest.PLATFORMS[platform] + build_config = {"debug": tvm_debug} + + with open(this_dir.parent / "testdata" / "kws" / "yes_no.tflite", "rb") as f: + tflite_model_buf = f.read() + + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 + try: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + except AttributeError: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) + + mod, params = relay.frontend.from_tflite(tflite_model) + target = tvm.target.target.micro( + model, options=["--link-params=1", "--unpacked-api=1", "--executor=aot"] + ) + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = relay.build(mod, target, params=params) + + return _generate_project(arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config) + + +def _get_directory_elements(directory): + return set(f.name for f in directory.iterdir()) + + +def test_project_folder_structure(project_dir, project): + assert set(["microtvm_api_server.py", "project.ino", "src"]).issubset( + _get_directory_elements(project_dir) + ) + + source_dir = project_dir / "src" + assert _get_directory_elements(source_dir) == set( + ["model", "standalone_crt", "model.c", "model.h"] + ) + + +def test_project_model_integrity(project_dir, project): + model_dir = project_dir / "src" / "model" + assert _get_directory_elements(model_dir) == set( + ["default_lib0.c", "default_lib1.c", "model.tar"] + ) + + +def test_model_header_templating(project_dir, project): + # Ensure model.h was templated with correct WORKSPACE_SIZE + with (project_dir / "src" / "model.h").open() as f: + model_h = f.read() + assert "#define WORKSPACE_SIZE 21312" in model_h + + +def test_import_rerouting(project_dir, project): + # Check one file to ensure imports were rerouted + runtime_path = project_dir / "src" / "standalone_crt" / "src" / "runtime" + c_backend_api_path = runtime_path / "crt" / "common" / "crt_backend_api.c" + assert c_backend_api_path.exists() + + with c_backend_api_path.open() as f: + c_backend_api_c = f.read() + assert '#include "inttypes.h"' in c_backend_api_c + assert "include/tvm/runtime/crt/platform.h" in c_backend_api_c + + +# Build on top of the generated project by replacing the +# top-level .ino fileand adding data input files, much +# like a user would +@pytest.fixture(scope="module") +def modified_project(project_dir, project): + this_dir = pathlib.Path(__file__).parent + kws_testdata_dir = this_dir.parent / "testdata" / "kws" + arduino_testdata_dir = this_dir / "testdata" + + shutil.copy2(arduino_testdata_dir / "project.ino", project_dir / "project.ino") + + project_data_dir = project_dir / "src" / "data" + project_data_dir.mkdir() + for sample in ["yes.c", "no.c", "silence.c", "unknown.c"]: + shutil.copy2(kws_testdata_dir / sample, project_data_dir / sample) + + return project + + +@pytest.fixture(scope="module") +def compiled_project(modified_project): + modified_project.build() + return modified_project + + +def test_compile_yes_no_project(project_dir, project, compiled_project): + build_dir = project_dir / "build" + assert build_dir.exists() + first_build_file = next(build_dir.iterdir(), None) + assert first_build_file is not None + + +"""------------------------------------------------------------ +If we're not running on real hardware, no further tests are run +------------------------------------------------------------""" + + +@pytest.fixture(scope="module") +def uploaded_project(compiled_project): + compiled_project.flash() + return compiled_project + + +""" Sample serial output: + +category,runtime,yes,no,silence,unknown +yes,56762,115,-123,-125,-123, +no,56762,-128,4,-123,-9, +silence,56792,-128,-118,107,-117, +unknown,56792,-128,-125,-128,125, +""" +SERIAL_OUTPUT_HEADERS = "category,runtime,yes,no,silence,unknown" + + +@pytest.fixture(scope="module") +def serial_output(uploaded_project): + transport = uploaded_project.transport() + transport.open() + out = transport.read(2048, -1) + out_str = out.decode("utf-8") + out_lines = out_str.split("\r\n") + + assert SERIAL_OUTPUT_HEADERS in out_lines + headers_index = out_lines.index(SERIAL_OUTPUT_HEADERS) + data_lines = out_lines[headers_index + 1 : headers_index + 5] + split_lines = [line.split(",") for line in data_lines] + + return [[line[0]] + list(map(int, line[1:6])) for line in split_lines] + + +TENSORFLOW_EVALUATIONS = { + "yes": [115, -123, -125, -123], + "no": [-128, 4, -123, -9], + "silence": [-128, -118, 107, -117], + "unknown": [-128, -125, -128, 125], +} +MAX_PREDICTION_DIFFERENCE = 2 + + +@pytest.mark.requires_hardware +def test_project_inference_correctness(serial_output): + predictions = {line[0]: line[2:] for line in serial_output} + + for sample, prediction in predictions.items(): + # Due to rounding issues, we don't get the *exact* same + # values as Tensorflow gives, but they're pretty close + + reference_prediction = TENSORFLOW_EVALUATIONS[sample] + deltas = [prediction[i] - reference_prediction[i] for i in range(4)] + assert max(deltas) < MAX_PREDICTION_DIFFERENCE + + +MAX_INFERENCE_TIME_US = 200 * 1000 +MAX_INFERENCE_TIME_RANGE_US = 1000 + + +@pytest.mark.requires_hardware +def test_project_inference_runtime(serial_output): + runtimes_us = [line[1] for line in serial_output] + + # Inference time will vary based on architecture + # and clock speed. However, anything more than 200 ms + # is way too long. Each inference takes ~60 ms on the + # Sony spresense, running at 156 MHz + assert max(runtimes_us) < MAX_INFERENCE_TIME_US + + # Clock speeds should be consistent for each input. On + # the Sony spresense, they vary by <100 us. Note that + # running with other attached hardware (like the + # Spresense extension board) may cause this check to fail + range_runtimes_us = max(runtimes_us) - min(runtimes_us) + assert range_runtimes_us < MAX_INFERENCE_TIME_RANGE_US + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/micro/arduino/testdata/project.ino b/tests/micro/arduino/testdata/project.ino new file mode 100644 index 000000000000..ebd1c5e0e650 --- /dev/null +++ b/tests/micro/arduino/testdata/project.ino @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "src/model.h" +#include "src/data/yes.c" +#include "src/data/no.c" +#include "src/data/unknown.c" +#include "src/data/silence.c" + +void performInference(int8_t input_data[1960], char *data_name) { + int8_t output_data[4]; + unsigned long start_time = micros(); + TVMExecute(input_data, output_data); + unsigned long end_time = micros(); + + Serial.print(data_name); + Serial.print(","); + Serial.print(end_time - start_time); + Serial.print(","); + for (int i = 0; i < 4; i++) { + Serial.print(output_data[i]); + Serial.print(","); + } + Serial.println(); +} + +void setup() { + TVMInitialize(); + Serial.begin(115200); +} + +void loop() { + Serial.println(); + Serial.println("category,runtime,yes,no,silence,unknown"); + performInference((int8_t*) input_yes, "yes"); + performInference((int8_t*) input_no, "no"); + performInference((int8_t*) input_silence, "silence"); + performInference((int8_t*) input_unknown, "unknown"); +} diff --git a/tests/micro/testdata/kws/no.c b/tests/micro/testdata/kws/no.c new file mode 100644 index 000000000000..a3bd78a5328d --- /dev/null +++ b/tests/micro/testdata/kws/no.c @@ -0,0 +1,128 @@ +/* + * This work is a derivative of "Speech Commands V2" by Google, used under CC BY 4.0. + */ + +static const char input_no[1960] = { + 0x80, 0x80, 0x80, 0xc5, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xc5, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xc5, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xb4, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xcf, 0xe4, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0xdb, 0xe4, 0xc5, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x2f, 0x1e, 0x7, 0xe4, 0xc5, 0xb4, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x52, 0x41, 0x4b, 0x3a, 0x20, 0xf6, 0xcf, 0xb4, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0xb4, 0x80, 0x80, 0xb4, 0x80, 0x80, 0x80, 0xc5, 0xb4, 0x80, 0x80, 0x80, 0xb4, 0x80, 0x80, + 0x62, 0x53, 0x5d, 0x51, 0x4a, 0xf9, 0xe4, 0xb4, 0xc5, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0xb4, 0x80, 0x80, 0x80, 0xc5, 0x80, 0x80, 0x80, 0xc5, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0xb4, 0xc5, 0x80, 0xcf, 0x80, 0x41, 0x49, 0x6a, 0x5d, 0x75, 0x62, 0x75, 0x63, + 0x7a, 0x65, 0x7b, 0x64, 0x78, 0x62, 0x75, 0x5d, 0x71, 0x5b, 0x37, 0xd, 0x3, 0xf6, 0xec, 0xd6, + 0x32, 0x2a, 0x1a, 0xf6, 0x42, 0x4b, 0x3f, 0xe0, 0xe4, 0xcf, 0xf3, 0xef, 0xf3, 0xfb, 0x3, 0x0, + 0x6d, 0x56, 0x6e, 0x57, 0x69, 0x55, 0x72, 0x5d, 0x66, 0x52, 0x6e, 0x5d, 0x6f, 0x46, 0x64, 0x52, + 0x62, 0x42, 0x4e, 0x29, 0x32, 0xe, 0x25, 0x35, 0x56, 0x49, 0x4d, 0x42, 0x5d, 0x57, 0x61, 0x34, + 0x1c, 0x5, 0x20, 0x17, 0x17, 0x17, 0x24, 0x20, 0x76, 0x65, 0x7a, 0x63, 0x7b, 0x65, 0x7b, 0x5d, + 0x70, 0x53, 0x73, 0x61, 0x70, 0x53, 0x66, 0x57, 0x63, 0x52, 0x5c, 0x3a, 0x54, 0x4d, 0x6b, 0x5f, + 0x78, 0x66, 0x7a, 0x64, 0x7b, 0x64, 0x75, 0x56, 0x5a, 0x46, 0x4b, 0x3d, 0x46, 0x3e, 0x4e, 0x3f, + 0x68, 0x58, 0x6e, 0x57, 0x6d, 0x5f, 0x76, 0x5a, 0x6e, 0x57, 0x75, 0x5d, 0x67, 0x53, 0x68, 0x50, + 0x67, 0x53, 0x6c, 0x59, 0x68, 0x5a, 0x6a, 0x53, 0x65, 0x5a, 0x74, 0x56, 0x6d, 0x5c, 0x6b, 0x4a, + 0x50, 0x46, 0x58, 0x48, 0x66, 0x56, 0x59, 0x46, 0x5e, 0x43, 0x61, 0x44, 0x61, 0x50, 0x6e, 0x55, + 0x67, 0x5a, 0x63, 0x4e, 0x5f, 0x3b, 0x63, 0x52, 0x5e, 0x4e, 0x67, 0x4d, 0x62, 0x51, 0x6a, 0x4e, + 0x62, 0x48, 0x69, 0x55, 0x66, 0x50, 0x62, 0x50, 0x59, 0x40, 0x4c, 0x41, 0x6c, 0x55, 0x5a, 0x3f, + 0x58, 0x3c, 0x5b, 0x28, 0x50, 0x3d, 0x62, 0x4b, 0x5b, 0x55, 0x62, 0x43, 0x5d, 0x3c, 0x50, 0x37, + 0x55, 0x2d, 0x55, 0x49, 0x59, 0x48, 0x53, 0x3e, 0x53, 0x46, 0x64, 0x53, 0x61, 0x3f, 0x5e, 0x2e, + 0x4d, 0x39, 0x4e, 0x41, 0x61, 0x4a, 0x53, 0x36, 0x52, 0x35, 0x55, 0x2a, 0x4f, 0x3a, 0x5a, 0x3e, + 0x55, 0x4f, 0x5e, 0x37, 0x4d, 0x34, 0x4c, 0x37, 0x4e, 0x28, 0x50, 0x36, 0x53, 0x39, 0x49, 0x2b, + 0x4f, 0x39, 0x5c, 0x47, 0x51, 0x35, 0x5d, 0x1b, 0x3f, 0x2b, 0x46, 0x3b, 0x5d, 0x44, 0x5a, 0x35, + 0x4d, 0x35, 0x4e, 0x30, 0x4b, 0x3f, 0x57, 0x35, 0x59, 0x3f, 0x45, 0xd, 0x2b, 0x4, 0x45, 0x26, + 0x48, 0x36, 0x47, 0x26, 0x44, 0x39, 0x50, 0x2e, 0x46, 0x2f, 0x55, 0x43, 0x4c, 0x23, 0x52, 0x2f, + 0x3f, 0x25, 0x43, 0x2d, 0x3b, 0xf9, 0x4d, 0x29, 0x44, 0x1b, 0x35, 0x38, 0x48, 0x3a, 0x46, 0x3c, + 0x5d, 0x29, 0x43, 0x5, 0x4a, 0xd, 0x26, 0xb4, 0x28, 0xcf, 0x3c, 0x13, 0x25, 0x2, 0x32, 0xf9, + 0x2f, 0x1e, 0x4d, 0x19, 0x3a, 0x2, 0x3c, 0x7, 0x3c, 0x12, 0x3c, 0x10, 0xdb, 0x80, 0x37, 0x24, + 0x42, 0x21, 0x3a, 0x30, 0x4a, 0x28, 0x32, 0x31, 0x48, 0xe7, 0x2d, 0x80, 0x19, 0xf9, 0x2d, 0xf3, + 0x32, 0x2, 0x24, 0xb4, 0x14, 0x80, 0x22, 0xb4, 0x35, 0x3, 0x40, 0xf, 0x30, 0x80, 0x26, 0x80, + 0x26, 0xcf, 0x21, 0x80, 0x80, 0x80, 0xf5, 0xef, 0x28, 0x80, 0x4b, 0x34, 0x3c, 0xdb, 0x34, 0x12, + 0x44, 0xe0, 0x26, 0x80, 0x1d, 0x80, 0xd6, 0x80, 0x21, 0xe4, 0x80, 0x80, 0xb4, 0x80, 0xf6, 0x11, + 0x2b, 0xff, 0x3e, 0x16, 0x1f, 0x80, 0x21, 0xf6, 0x14, 0xd6, 0x27, 0xcf, 0x80, 0x80, 0x0, 0xec, + 0x48, 0xd6, 0x3b, 0x0, 0x36, 0x1d, 0x28, 0xcf, 0x2d, 0xef, 0x25, 0x80, 0xcf, 0x80, 0xf5, 0x80, + 0xa, 0x80, 0x11, 0x80, 0x80, 0x80, 0xf8, 0xe4, 0x10, 0xea, 0x2a, 0xf1, 0x21, 0x80, 0xcf, 0x80, + 0x3, 0xe7, 0x1a, 0xb4, 0x80, 0x80, 0xe0, 0xdb, 0x31, 0xe0, 0x32, 0xc, 0x30, 0x80, 0x0, 0xc5, + 0x34, 0x80, 0x2, 0x80, 0xf1, 0x80, 0xcf, 0x80, 0xb4, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x2, 0x80, 0x14, 0x80, 0xd6, 0x80, 0x80, 0x80, 0xfb, 0xdb, 0x8, 0x80, 0x80, 0x80, 0xe4, 0xe7, + 0x28, 0xc5, 0x1e, 0xdb, 0x2a, 0xb4, 0x80, 0x80, 0x30, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xf8, 0xb4, 0x17, 0x80, 0xcf, 0x80, 0x80, 0x80, + 0x0, 0xcf, 0x12, 0x80, 0x80, 0x80, 0xdb, 0xb4, 0xe4, 0x80, 0x21, 0xb4, 0x2a, 0x80, 0x80, 0x80, + 0x13, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0xf3, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0xfd, 0x80, 0x80, 0x80, 0xe0, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xe4, 0x80, 0xb4, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}; diff --git a/tests/micro/testdata/kws/silence.c b/tests/micro/testdata/kws/silence.c new file mode 100644 index 000000000000..bc26efa70e4f --- /dev/null +++ b/tests/micro/testdata/kws/silence.c @@ -0,0 +1,128 @@ +/* + * This work is a derivative of "Speech Commands V2" by Google, used under CC BY 4.0. + */ + +static const char input_silence[1960] = { + 0x23, 0x17, 0xe0, 0x3, 0x9, 0xe7, 0xe7, 0xdb, 0xcf, 0xc5, 0xe0, 0xdb, 0xc5, 0xcf, 0xef, 0xcf, + 0xcf, 0xdb, 0xef, 0xdb, 0xe7, 0xc5, 0x5, 0x3, 0xfc, 0xe7, 0xf6, 0xdb, 0xcf, 0xe7, 0x9, 0xef, + 0xef, 0xdb, 0xcf, 0xe7, 0xe0, 0xe7, 0xe0, 0xc5, 0xff, 0xe0, 0x4, 0xcf, 0xdb, 0xb4, 0x80, 0xdb, + 0xef, 0x80, 0xc5, 0xe4, 0x9, 0xe4, 0xcf, 0xc5, 0xdb, 0xcf, 0xdb, 0xcf, 0xf5, 0xdb, 0xe7, 0xcf, + 0xef, 0xe4, 0xe7, 0xe4, 0xe7, 0xdb, 0xdb, 0xcf, 0xc5, 0xdb, 0xcf, 0xcf, 0xcf, 0xb4, 0xcf, 0xcf, + 0x13, 0xef, 0xf5, 0x80, 0x80, 0x80, 0xc5, 0xcf, 0xcf, 0x80, 0x80, 0xcf, 0xf5, 0xcf, 0x80, 0x80, + 0x80, 0x80, 0x80, 0xcf, 0xf9, 0xdb, 0xcf, 0x80, 0x80, 0xcf, 0xe7, 0xdb, 0xfb, 0xe4, 0xdb, 0xcf, + 0xe7, 0xcf, 0xe7, 0xb4, 0xdb, 0xe4, 0xcf, 0xb4, 0xfb, 0x0, 0x6, 0xd6, 0xec, 0xb4, 0x80, 0xb4, + 0x80, 0x80, 0x80, 0x80, 0xf3, 0xb4, 0xdb, 0xdb, 0xc5, 0xb4, 0xc5, 0x80, 0xcf, 0xb4, 0xdb, 0xb4, + 0xb4, 0x80, 0xcf, 0x80, 0xdb, 0xb4, 0xb4, 0x80, 0xc5, 0x80, 0xdb, 0xcf, 0xdb, 0xcf, 0xcf, 0xb4, + 0xff, 0xcf, 0xdb, 0x80, 0xb4, 0x80, 0x80, 0xd6, 0xcf, 0xcf, 0x80, 0xcf, 0xcf, 0xcf, 0xe4, 0xcf, + 0xc5, 0x80, 0x80, 0x80, 0xdb, 0x80, 0xb4, 0x80, 0xdb, 0x80, 0xb4, 0x80, 0xb4, 0xb4, 0xdb, 0xcf, + 0xec, 0xe0, 0xcf, 0xe0, 0xe4, 0xd6, 0xdb, 0x80, 0xef, 0xf6, 0xea, 0xd6, 0xb4, 0xd6, 0xec, 0xc5, + 0xec, 0xcf, 0xc5, 0x80, 0xdb, 0x80, 0x80, 0x80, 0x80, 0xb4, 0xdb, 0xcf, 0xdb, 0xd6, 0xe4, 0xc5, + 0xdb, 0xb4, 0xcf, 0xc5, 0xcf, 0xd6, 0xe4, 0xc5, 0xf3, 0xe0, 0xec, 0xe0, 0xfd, 0xe7, 0xcf, 0xb4, + 0x24, 0x1a, 0x0, 0xf1, 0x19, 0xe0, 0xec, 0xe0, 0xb4, 0xcf, 0xdb, 0xd6, 0xb4, 0xb4, 0xb4, 0x80, + 0xdb, 0x80, 0xdb, 0xc5, 0xf1, 0xe7, 0xea, 0xf8, 0xec, 0xc5, 0xe4, 0xe0, 0xec, 0xc5, 0xcf, 0xb4, + 0xe4, 0xd6, 0xe4, 0xdb, 0xf1, 0xdb, 0xdb, 0xc5, 0x22, 0xea, 0xe7, 0x80, 0xea, 0xf3, 0xec, 0xfb, + 0xec, 0xe0, 0xdb, 0xb4, 0xe4, 0xe0, 0xec, 0xd6, 0xf3, 0xb4, 0xb4, 0x80, 0xd6, 0xd6, 0xe4, 0xdb, + 0xcf, 0xb4, 0xdb, 0xdb, 0xf1, 0xe4, 0xcf, 0xb4, 0xe4, 0xcf, 0xe4, 0xea, 0xea, 0xe4, 0xe4, 0xd6, + 0xef, 0xb4, 0xc5, 0xc5, 0xd6, 0xc5, 0xe4, 0x80, 0x80, 0x80, 0xb4, 0x80, 0xcf, 0xc5, 0x0, 0xdb, + 0xb4, 0xb4, 0xdb, 0x80, 0xb4, 0x80, 0x80, 0x80, 0xb4, 0x80, 0x80, 0x80, 0xb4, 0xc5, 0xcf, 0xb4, + 0xcf, 0xcf, 0xe0, 0xcf, 0xcf, 0x80, 0xb4, 0x80, 0xec, 0xd6, 0xe0, 0xc5, 0xb4, 0xb4, 0xcf, 0x80, + 0xcf, 0xb4, 0xcf, 0x80, 0xd6, 0xc5, 0x80, 0x80, 0xdb, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0xcf, 0x80, 0x80, 0x80, 0xcf, 0xb4, 0xd6, 0xb4, 0xd6, 0xb4, 0xf1, 0xc5, 0xc5, 0x80, 0xb4, 0x80, + 0x11, 0xc5, 0xb4, 0x80, 0x80, 0x80, 0xb4, 0x80, 0xb4, 0x80, 0x80, 0x80, 0xc5, 0xcf, 0xb4, 0x80, + 0xe4, 0xb4, 0x80, 0xb4, 0x80, 0x80, 0x80, 0x80, 0xcf, 0x80, 0xb4, 0x80, 0x80, 0x80, 0xb4, 0xb4, + 0xd6, 0xc5, 0xb4, 0x80, 0xc5, 0x80, 0xb4, 0x80, 0xcf, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0xb4, 0xc5, 0xe4, 0xc5, 0xb4, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0xef, 0x80, 0xc5, 0xb4, 0xc5, 0xc5, 0xc5, 0xcf, 0xd6, 0xc5, 0xf5, 0xb4, 0xcf, 0x80, + 0xe4, 0xc5, 0xb4, 0xe0, 0xd6, 0xb4, 0xcf, 0x80, 0xb4, 0xc5, 0xcf, 0x80, 0xe0, 0xc5, 0xd6, 0x80, + 0x80, 0x80, 0xb4, 0x80, 0x80, 0x80, 0xb4, 0xb4, 0xc5, 0x80, 0xd6, 0xb4, 0xe0, 0xb4, 0xb4, 0xc5, + 0xc5, 0xb4, 0xc5, 0x80, 0xc5, 0xc5, 0xd6, 0x80, 0x80, 0x80, 0xf8, 0x80, 0x80, 0xb4, 0xd6, 0x80, + 0xd6, 0xb4, 0xb4, 0x80, 0xb4, 0x80, 0x80, 0x80, 0x80, 0xb4, 0xcf, 0xcf, 0xe7, 0x80, 0xb4, 0x80, + 0xc5, 0x80, 0xc5, 0x80, 0xb4, 0x80, 0xb4, 0xb4, 0xc5, 0x80, 0xb4, 0x80, 0xc5, 0x80, 0xe0, 0x80, + 0xef, 0x80, 0xcf, 0x80, 0xb4, 0x80, 0x80, 0x80, 0x80, 0x80, 0xb4, 0xb4, 0xfd, 0xb4, 0x80, 0xb4, + 0xe0, 0x80, 0xcf, 0xb4, 0xb4, 0x80, 0xe7, 0xb4, 0xe7, 0xb4, 0xb4, 0xd6, 0xb4, 0x80, 0xe0, 0xc5, + 0x80, 0x80, 0xc5, 0xc5, 0xd6, 0x80, 0xc5, 0x80, 0xdb, 0xc5, 0xea, 0x80, 0x80, 0x80, 0xb4, 0x80, + 0xb4, 0x80, 0xe0, 0x80, 0x80, 0x80, 0xc5, 0xb4, 0x80, 0x80, 0xd6, 0x80, 0xb4, 0x80, 0xb4, 0x80, + 0x80, 0xb4, 0xb4, 0x80, 0x80, 0x80, 0x80, 0x80, 0xb4, 0x80, 0xe7, 0xb4, 0xc5, 0x80, 0xd6, 0x80, + 0xe7, 0xc5, 0xdb, 0x80, 0xdb, 0xcf, 0xe0, 0x80, 0x80, 0x80, 0xc5, 0xb4, 0xdb, 0x80, 0xef, 0xc5, + 0x80, 0x80, 0x80, 0x80, 0xc5, 0xb4, 0x80, 0x80, 0xb4, 0x80, 0x80, 0x80, 0xb4, 0x80, 0xd6, 0x80, + 0xc5, 0xb4, 0xdb, 0x80, 0xb4, 0x80, 0x80, 0x80, 0xe0, 0x80, 0x80, 0xb4, 0xf6, 0xdb, 0xc5, 0x80, + 0x80, 0x80, 0xc5, 0x80, 0x80, 0x80, 0xb4, 0x80, 0xc5, 0x80, 0xb4, 0xb4, 0xd6, 0xb4, 0xd6, 0x80, + 0x80, 0xb4, 0xd6, 0xb4, 0x80, 0x80, 0xdb, 0xb4, 0xf3, 0xb4, 0xdb, 0x80, 0x80, 0x80, 0xc5, 0x80, + 0x1d, 0xcf, 0x16, 0x12, 0x17, 0xc, 0x23, 0x2, 0x1, 0xc5, 0xc5, 0xb4, 0x80, 0x80, 0x80, 0x80, + 0x80, 0xc5, 0xd6, 0xc5, 0xb4, 0xc5, 0xdb, 0x80, 0x80, 0x80, 0x80, 0x80, 0xb4, 0xb4, 0xdb, 0xc5, + 0xe4, 0x80, 0xdb, 0x80, 0xc5, 0xb4, 0x80, 0x80, 0x78, 0x64, 0x7a, 0x64, 0x76, 0x60, 0x67, 0x55, + 0x5a, 0x3a, 0x37, 0x24, 0xf6, 0xc5, 0x14, 0x17, 0x1e, 0x18, 0x31, 0x39, 0x44, 0x43, 0x49, 0x3e, + 0x39, 0x23, 0x18, 0x17, 0x42, 0x41, 0x40, 0x34, 0x39, 0x34, 0x37, 0x30, 0x38, 0x23, 0x22, 0x9, + 0x75, 0x63, 0x73, 0x63, 0x77, 0x58, 0x73, 0x5f, 0x64, 0x4d, 0x57, 0x41, 0x58, 0x46, 0x36, 0x32, + 0x45, 0x51, 0x64, 0x56, 0x72, 0x61, 0x67, 0x57, 0x60, 0x52, 0x49, 0x4e, 0x61, 0x53, 0x62, 0x57, + 0x67, 0x50, 0x66, 0x56, 0x63, 0x52, 0x5e, 0x3d, 0x6b, 0x5a, 0x70, 0x5d, 0x72, 0x50, 0x6c, 0x56, + 0x67, 0x5a, 0x69, 0x49, 0x5a, 0x4f, 0x56, 0x50, 0x61, 0x50, 0x6c, 0x5d, 0x71, 0x5d, 0x6e, 0x56, + 0x6c, 0x58, 0x69, 0x55, 0x6c, 0x57, 0x65, 0x57, 0x6c, 0x56, 0x68, 0x4c, 0x61, 0x58, 0x66, 0x44, + 0x68, 0x52, 0x6b, 0x56, 0x6c, 0x60, 0x6e, 0x52, 0x72, 0x4e, 0x5b, 0x4d, 0x56, 0x4e, 0x68, 0x51, + 0x69, 0x5a, 0x6a, 0x5a, 0x72, 0x54, 0x6f, 0x5d, 0x75, 0x5f, 0x67, 0x57, 0x65, 0x48, 0x5c, 0x4c, + 0x66, 0x52, 0x68, 0x52, 0x63, 0x53, 0x64, 0x44, 0x5f, 0x44, 0x60, 0x49, 0x69, 0x60, 0x71, 0x51, + 0x6c, 0x59, 0x6c, 0x53, 0x62, 0x4b, 0x5c, 0x4e, 0x61, 0x4c, 0x6a, 0x5c, 0x69, 0x4b, 0x6b, 0x56, + 0x6b, 0x40, 0x5d, 0x43, 0x6c, 0x55, 0x60, 0x3f, 0x5f, 0x4d, 0x69, 0x52, 0x64, 0x4d, 0x64, 0x41, + 0x59, 0x3b, 0x55, 0x35, 0x67, 0x55, 0x71, 0x5a, 0x69, 0x58, 0x65, 0x48, 0x5e, 0x4e, 0x6a, 0x55, + 0x69, 0x55, 0x73, 0x5c, 0x68, 0x35, 0x64, 0x57, 0x6a, 0x43, 0x57, 0x42, 0x63, 0x4c, 0x71, 0x57, + 0x60, 0x43, 0x5a, 0x44, 0x5c, 0x3e, 0x5d, 0x3e, 0x57, 0x31, 0x46, 0x7, 0x56, 0x4b, 0x73, 0x52, + 0x64, 0x4b, 0x5b, 0x4a, 0x66, 0x4f, 0x69, 0x4d, 0x69, 0x56, 0x6e, 0x3e, 0x4b, 0x37, 0x5c, 0x44, + 0x56, 0x24, 0x4f, 0x2a, 0x46, 0x3b, 0x61, 0x4e, 0x61, 0x43, 0x5d, 0x45, 0x5e, 0x44, 0x50, 0x3c, + 0x56, 0x2d, 0x45, 0x4, 0x50, 0x40, 0x64, 0x57, 0x69, 0x4d, 0x64, 0x50, 0x62, 0x4e, 0x67, 0x4e, + 0x62, 0x56, 0x67, 0x3c, 0x48, 0x23, 0x58, 0x43, 0x53, 0x28, 0x3b, 0xcf, 0x48, 0x48, 0x5c, 0x40, + 0x4d, 0x37, 0x4e, 0x3c, 0x56, 0x20, 0x3d, 0x11, 0x37, 0xc5, 0x4a, 0xd6, 0x2d, 0x2b, 0x57, 0x4e, + 0x5a, 0x44, 0x60, 0x43, 0x5a, 0x3f, 0x5c, 0x41, 0x67, 0x50, 0x60, 0x2f, 0x36, 0x1c, 0x54, 0x3e, + 0x4f, 0xc, 0x2d, 0x80, 0x36, 0x22, 0x50, 0x41, 0x5f, 0x3e, 0x50, 0x3f, 0x5f, 0x3d, 0x46, 0x19, + 0x41, 0xfd, 0x33, 0xd6, 0x25, 0x2, 0x40, 0x2f, 0x59, 0x3a, 0x4f, 0x3d, 0x47, 0x23, 0x52, 0x32, + 0x5c, 0x3e, 0x45, 0xcf, 0xd, 0xdb, 0x42, 0x2a, 0x3f, 0x80, 0x15, 0x80, 0xe4, 0xb4, 0x36, 0x28, + 0x49, 0x39, 0x52, 0x3a, 0x5a, 0x39, 0x52, 0xb, 0x26, 0x80, 0x27, 0xc5, 0x2f, 0xf6, 0x45, 0x24, + 0x40, 0x29, 0x52, 0x33, 0x43, 0xfc, 0x33, 0x1d, 0x44, 0x17, 0x2e, 0x80, 0x80, 0x80, 0xb4, 0x80, + 0x80, 0x80, 0x24, 0x80, 0xb4, 0x80, 0x34, 0x32, 0x4c, 0x32, 0x4b, 0x30, 0x54, 0x3f, 0x51, 0x30, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xe4, 0x80, 0x1, 0x80, 0x26, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xfd, 0x80, 0x80, 0x80, 0xb4, 0x80, + 0x29, 0xe0, 0xe0, 0xc5, 0x27, 0x80, 0x1b, 0x7, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x23, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0xf9, 0x80, 0x80, 0x80, 0x80, 0x80, 0xd6, 0x80, 0x80, 0x80, 0xb4, 0x80, 0xf5, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0xe0, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x1d, 0xe4, 0x11, 0xb4, 0x32, 0xa, + 0x6, 0x80, 0x80, 0x80, 0xd6, 0x80, 0x1c, 0xd, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x15, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0xf8, 0xcf, 0x10, 0x80, 0x17, 0x80, 0x1e, 0x80, 0xff, 0xec, 0x25, 0x80, 0x1c, 0x23, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x11, 0xb4, 0x2, 0x80, 0x30, 0x8, + 0x15, 0x80, 0x6, 0x20, 0x36, 0xf8, 0x2e, 0x18, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0xf3, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0xd, 0x4, 0xa, 0xea, 0x37, 0x24, 0x2a, 0xc, 0x39, 0x26, 0x43, 0x5, 0x2d, 0x1f, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x14, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x7, 0xcf, 0xf, 0xef, 0x32, 0xd, + 0x2a, 0x14, 0x37, 0x1, 0x32, 0x0, 0x38, 0x10, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x1c, 0x80, 0x80, 0x80, 0x28, 0xdb, 0xe4, 0xe0, 0xb4, 0x80, 0x16, 0xcf, 0x1b, 0xb4, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xb4, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}; diff --git a/tests/micro/testdata/kws/unknown.c b/tests/micro/testdata/kws/unknown.c new file mode 100644 index 000000000000..6e4df3d20b49 --- /dev/null +++ b/tests/micro/testdata/kws/unknown.c @@ -0,0 +1,128 @@ +/* + * This work is a derivative of "Speech Commands V2" by Google, used under CC BY 4.0. + */ + +static const char input_unknown[1960] = { + 0x78, 0x66, 0x7a, 0x63, 0x78, 0x62, 0x6d, 0x52, 0x58, 0x19, 0x0, 0xcf, 0x80, 0x80, 0x80, 0x80, + 0xcf, 0xc5, 0xc5, 0xc5, 0x80, 0x80, 0x80, 0xc5, 0xc5, 0xe7, 0xe0, 0x80, 0x80, 0xc5, 0x80, 0xcf, + 0xc5, 0xc5, 0x80, 0xc5, 0xcf, 0xe7, 0xe0, 0xdb, 0x72, 0x4b, 0x65, 0x60, 0x70, 0x50, 0x73, 0x59, + 0x60, 0x4f, 0x4d, 0x3c, 0x11, 0xff, 0xc5, 0xc5, 0xdb, 0xdb, 0xcf, 0xec, 0xe7, 0xcf, 0xcf, 0x2, + 0x31, 0x4d, 0x4c, 0xe7, 0xdb, 0xc5, 0x80, 0xcf, 0xef, 0xe4, 0x4, 0xff, 0xf5, 0xec, 0xef, 0x5, + 0x6c, 0x4b, 0x56, 0x54, 0x6a, 0x47, 0x6f, 0x5b, 0x63, 0x55, 0x4c, 0x41, 0x2d, 0x22, 0x20, 0x3a, + 0x4e, 0xf1, 0xcf, 0xfc, 0x19, 0xf3, 0xe7, 0x2d, 0x48, 0x4e, 0x5b, 0x80, 0xcf, 0xcf, 0x80, 0x80, + 0x80, 0xdb, 0x3, 0xfb, 0xf5, 0xea, 0x0, 0xf5, 0x62, 0x40, 0x46, 0x47, 0x62, 0x41, 0x68, 0x53, + 0x5f, 0x51, 0x57, 0x4e, 0x5b, 0x51, 0x58, 0x4b, 0x62, 0x2b, 0xef, 0x44, 0x5d, 0x41, 0x49, 0x5c, + 0x62, 0x56, 0x58, 0x2f, 0xc5, 0xb4, 0xcf, 0xcf, 0xc5, 0xe0, 0xf9, 0xe7, 0x7, 0xf5, 0xa, 0xfc, + 0x5b, 0x39, 0x35, 0x3d, 0x5c, 0x37, 0x5d, 0x49, 0x57, 0x49, 0x63, 0x57, 0x61, 0x55, 0x5e, 0x4d, + 0x64, 0x4b, 0x63, 0x58, 0x5c, 0x49, 0x5f, 0x57, 0x6a, 0x56, 0x68, 0x41, 0x15, 0xf1, 0x7, 0xf1, + 0xf9, 0xef, 0xfd, 0xfb, 0xc, 0xf6, 0x5, 0xef, 0x5a, 0x40, 0x4a, 0x44, 0x69, 0x57, 0x55, 0x50, + 0x63, 0x49, 0x67, 0x5a, 0x72, 0x60, 0x70, 0x5a, 0x71, 0x61, 0x77, 0x63, 0x75, 0x5e, 0x71, 0x52, + 0x6f, 0x5f, 0x78, 0x64, 0x78, 0x5d, 0x56, 0x57, 0x56, 0x28, 0x39, 0x3b, 0x58, 0x49, 0x3d, 0x33, + 0x58, 0x3f, 0x2a, 0x50, 0x6c, 0x53, 0x6a, 0x5b, 0x69, 0x57, 0x6e, 0x5e, 0x73, 0x60, 0x74, 0x5a, + 0x75, 0x61, 0x76, 0x60, 0x75, 0x59, 0x6e, 0x4c, 0x6b, 0x4c, 0x6b, 0x58, 0x74, 0x61, 0x6e, 0x36, + 0x49, 0x41, 0x5b, 0x5d, 0x6e, 0x57, 0x5e, 0x44, 0x50, 0x30, 0x3a, 0x46, 0x5f, 0x3c, 0x64, 0x4e, + 0x5d, 0x53, 0x69, 0x55, 0x6a, 0x57, 0x69, 0x52, 0x71, 0x5a, 0x6b, 0x47, 0x5f, 0x4d, 0x61, 0x43, + 0x5b, 0x37, 0x59, 0x3e, 0x57, 0x3f, 0x53, 0xe, 0x44, 0x47, 0x5c, 0x43, 0x62, 0x51, 0x5d, 0x3f, + 0x4a, 0x2a, 0x39, 0x3f, 0x59, 0x37, 0x5c, 0x40, 0x58, 0x50, 0x65, 0x4e, 0x65, 0x52, 0x67, 0x54, + 0x6f, 0x52, 0x59, 0x3b, 0x57, 0x48, 0x61, 0x49, 0x54, 0xf8, 0x3e, 0x2d, 0x4e, 0x3e, 0x50, 0xc, + 0x3e, 0x53, 0x67, 0x2d, 0x4c, 0x3b, 0x4f, 0x2a, 0x43, 0x14, 0x46, 0x37, 0x50, 0x23, 0x58, 0x36, + 0x57, 0x48, 0x63, 0x46, 0x67, 0x4e, 0x65, 0x55, 0x6d, 0x4c, 0x55, 0x35, 0x41, 0x3b, 0x58, 0x3f, + 0x53, 0x2f, 0x44, 0x25, 0x48, 0x37, 0x58, 0xe4, 0x4d, 0x48, 0x53, 0x2b, 0x41, 0x28, 0x4a, 0x2d, + 0x3d, 0x5, 0x44, 0x29, 0x44, 0x1c, 0x5c, 0x3b, 0x53, 0x35, 0x5a, 0x3b, 0x60, 0x45, 0x61, 0x50, + 0x64, 0x3a, 0x43, 0x1f, 0x35, 0x23, 0x4d, 0x4a, 0x5e, 0x3c, 0x4d, 0x30, 0x51, 0x2e, 0x51, 0xf3, + 0x4d, 0x3e, 0x50, 0x1a, 0x34, 0xfc, 0x44, 0x27, 0x37, 0xf8, 0x3a, 0x9, 0x32, 0x33, 0x5d, 0x37, + 0x57, 0x35, 0x5d, 0x3b, 0x58, 0x31, 0x60, 0x45, 0x50, 0xff, 0x3a, 0xe0, 0x24, 0x3, 0x24, 0x3a, + 0x4f, 0xe, 0x32, 0x1d, 0x46, 0x2d, 0x45, 0x4, 0x56, 0x3d, 0x50, 0x7, 0xa, 0x80, 0x3a, 0x1f, + 0x31, 0xe0, 0x43, 0x3, 0x26, 0x3a, 0x5b, 0x34, 0x56, 0x30, 0x58, 0x2e, 0x53, 0x1f, 0x61, 0x3f, + 0x3f, 0x80, 0x2f, 0xe4, 0x2f, 0x14, 0x30, 0x1e, 0x50, 0xe0, 0x22, 0x0, 0x4b, 0x2d, 0x39, 0xdb, + 0x56, 0x3e, 0x46, 0x34, 0x2d, 0x80, 0x29, 0x5, 0x2f, 0xc5, 0x46, 0xfb, 0x1c, 0x3a, 0x56, 0x26, + 0x53, 0x2b, 0x4e, 0x8, 0x53, 0x25, 0x65, 0x3a, 0xf, 0x80, 0xf5, 0x80, 0xb, 0xd6, 0x1e, 0x7, + 0x55, 0xd6, 0x6, 0x80, 0x2c, 0x0, 0x11, 0xe4, 0x3e, 0x26, 0x41, 0x25, 0x2c, 0x80, 0x1d, 0x2, + 0x2a, 0xd6, 0x45, 0xec, 0x4, 0x3c, 0x54, 0x20, 0x4d, 0x12, 0x49, 0xf6, 0x57, 0x32, 0x61, 0x23, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xb, 0xe7, 0x3b, 0x80, 0xc5, 0x80, 0xc5, 0x80, 0xcf, 0xdb, + 0x14, 0x1d, 0x3d, 0x36, 0x3f, 0x80, 0x19, 0xfc, 0x1f, 0x80, 0x40, 0xea, 0x8, 0x3c, 0x52, 0x22, + 0x3a, 0xf8, 0x49, 0x3, 0x58, 0x21, 0x3c, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xc5, 0x80, + 0xf6, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x37, 0x2d, 0x3b, 0x1b, 0x31, 0x80, 0x16, 0xf5, + 0xf3, 0x80, 0x3e, 0xcf, 0xec, 0x3b, 0x4e, 0x12, 0x4, 0x80, 0x4f, 0x26, 0x5a, 0x1a, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xfc, 0xb4, + 0x2c, 0x0, 0x1b, 0x2a, 0x2f, 0x80, 0xc, 0xdb, 0xd6, 0x80, 0x44, 0xfd, 0x11, 0x33, 0x44, 0xd6, + 0x8, 0x80, 0x4e, 0xe, 0x26, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0xb4, 0x80, 0x80, 0x80, 0x80, 0x80, 0x1, 0x80, 0xe7, 0x80, 0x80, 0x80, 0x80, 0x80, 0x14, 0xdb, + 0xf8, 0x80, 0x48, 0x0, 0x7, 0xe7, 0x18, 0x80, 0xef, 0x80, 0x36, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0xdb, 0x80, 0x80, 0x80, 0x80, 0x80, 0x17, 0x80, 0x80, 0x80, 0x48, 0x6, 0x10, 0x80, 0xf1, 0x80, + 0x24, 0x80, 0x7, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x3c, 0xf1, 0x7, 0x80, 0xc5, 0x80, 0x33, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0xe0, 0x80, 0x26, 0x80, 0xcf, 0x80, 0x80, 0x80, + 0x80, 0x80, 0xb4, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0xf6, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}; diff --git a/tests/micro/testdata/kws/yes.c b/tests/micro/testdata/kws/yes.c new file mode 100644 index 000000000000..ec18f20e46cf --- /dev/null +++ b/tests/micro/testdata/kws/yes.c @@ -0,0 +1,128 @@ +/* + * This work is a derivative of "Speech Commands V2" by Google, used under CC BY 4.0. + */ + +static const char input_yes[1960] = { + 0x7c, 0x66, 0x79, 0x65, 0x7d, 0x67, 0x7c, 0x67, 0x7c, 0x66, 0x7c, 0x67, 0x7c, 0x67, 0x7d, 0x66, + 0x7c, 0x67, 0x7d, 0x66, 0x7c, 0x67, 0x7d, 0x66, 0x7c, 0x67, 0x7d, 0x67, 0x7d, 0x67, 0x7d, 0x67, + 0x7d, 0x67, 0x7d, 0x67, 0x7d, 0x67, 0x7d, 0x67, 0x52, 0x57, 0x78, 0x5a, 0x67, 0x53, 0x6f, 0x4b, + 0x6d, 0x5c, 0x71, 0x52, 0x66, 0x4d, 0x6e, 0x56, 0x73, 0x50, 0x5f, 0x54, 0x6d, 0x55, 0x6a, 0x5b, + 0x6f, 0x57, 0x68, 0x50, 0x71, 0x58, 0x6d, 0x57, 0x69, 0x55, 0x6a, 0x55, 0x6c, 0x59, 0x6c, 0x5a, + 0x5b, 0x3c, 0x54, 0x44, 0x58, 0x4f, 0x66, 0x30, 0x58, 0x50, 0x61, 0x3d, 0x67, 0x36, 0x5b, 0x4d, + 0x64, 0x51, 0x6a, 0x4d, 0x60, 0x4b, 0x61, 0x53, 0x69, 0x54, 0x60, 0x47, 0x5c, 0x4d, 0x63, 0x45, + 0x64, 0x4d, 0x63, 0x4b, 0x67, 0x50, 0x68, 0x4d, 0x64, 0x4b, 0x64, 0x4e, 0x5f, 0x3d, 0x53, 0x42, + 0x59, 0x39, 0x57, 0x43, 0x5e, 0x3a, 0x44, 0x3b, 0x56, 0x3c, 0x5c, 0x46, 0x66, 0x4c, 0x61, 0x3e, + 0x5d, 0x49, 0x55, 0x48, 0x5d, 0x45, 0x5a, 0x48, 0x5f, 0x41, 0x59, 0x49, 0x5a, 0x46, 0x5d, 0x3b, + 0x51, 0x3d, 0x4c, 0x44, 0x57, 0x37, 0x54, 0x43, 0x4f, 0xa, 0x32, 0x28, 0x5b, 0x3a, 0x5e, 0x47, + 0x4d, 0x2b, 0x57, 0x4a, 0x5d, 0x34, 0x52, 0x3e, 0x50, 0x38, 0x54, 0x30, 0x53, 0x41, 0x57, 0x39, + 0x5c, 0x3c, 0x53, 0x41, 0x5a, 0x1e, 0x4e, 0x41, 0x4d, 0x2c, 0x3e, 0x18, 0x4c, 0x1c, 0x36, 0x11, + 0x4b, 0x32, 0x52, 0x2f, 0x50, 0x2d, 0x4e, 0x20, 0x50, 0x3c, 0x4a, 0x16, 0x44, 0x22, 0x48, 0x29, + 0x4d, 0x34, 0x4e, 0x2c, 0x52, 0x2e, 0x46, 0x35, 0x4b, 0x14, 0x50, 0x33, 0x53, 0x3e, 0x50, 0x2d, + 0x4a, 0x0, 0x4b, 0x3a, 0x47, 0x16, 0x45, 0x32, 0x45, 0x10, 0x42, 0x23, 0x49, 0x39, 0x41, 0x10, + 0x48, 0x32, 0x4e, 0x30, 0x40, 0x34, 0x46, 0x39, 0x54, 0xf5, 0x49, 0x38, 0x53, 0x2c, 0x4a, 0x37, + 0x51, 0x2c, 0x46, 0x2f, 0x4c, 0x2a, 0x4d, 0x2b, 0x3d, 0x2f, 0x4e, 0x20, 0x1e, 0x7, 0x41, 0x8, + 0x39, 0xd, 0x46, 0x20, 0x3b, 0x2a, 0x3f, 0x20, 0x40, 0xe, 0x4e, 0x2e, 0x3e, 0x21, 0x4f, 0x16, + 0x2e, 0x35, 0x54, 0x32, 0x41, 0x1c, 0x48, 0x2a, 0x44, 0xc, 0x48, 0x21, 0x41, 0x19, 0x48, 0x2a, + 0x3d, 0x21, 0x44, 0xb4, 0x41, 0x14, 0x3e, 0x2b, 0x45, 0x23, 0x50, 0x28, 0x3e, 0x1f, 0x43, 0x26, + 0x46, 0x1b, 0x48, 0x12, 0x44, 0x2d, 0x47, 0x22, 0x3c, 0x32, 0x48, 0x26, 0x2f, 0x21, 0x45, 0x17, + 0x43, 0x22, 0x43, 0x1d, 0x44, 0x28, 0x4d, 0x14, 0x56, 0x23, 0x40, 0x2c, 0x34, 0x80, 0x44, 0xf, + 0x37, 0x16, 0x49, 0x21, 0x34, 0x1e, 0x3f, 0x22, 0x2b, 0x16, 0x34, 0x28, 0x43, 0x2d, 0x43, 0x11, + 0x49, 0x1a, 0x46, 0x20, 0x46, 0x21, 0x3d, 0x17, 0x3d, 0x28, 0x3e, 0xf5, 0x33, 0x15, 0x39, 0x20, + 0x4d, 0x2d, 0x36, 0x80, 0x1a, 0xdb, 0x3e, 0x17, 0x3b, 0x1f, 0x40, 0x17, 0x2b, 0xcf, 0x39, 0x2d, + 0x4d, 0x2b, 0x35, 0xf6, 0x44, 0x29, 0x3d, 0x24, 0x30, 0x17, 0x3b, 0x28, 0x44, 0xd, 0x38, 0x20, + 0x3b, 0xf3, 0x45, 0x19, 0x4c, 0x24, 0x37, 0x15, 0xf3, 0xb4, 0x3c, 0x28, 0x36, 0xf3, 0x44, 0x1b, + 0x48, 0x25, 0x1d, 0xd6, 0x25, 0xcf, 0x3a, 0x9, 0x3f, 0xfc, 0x31, 0xf1, 0x41, 0x24, 0x44, 0x17, + 0x45, 0x20, 0x42, 0x2, 0x33, 0xb4, 0x31, 0x1b, 0x43, 0x18, 0x2c, 0x14, 0x44, 0xa, 0x43, 0x7, + 0x4, 0x80, 0x2b, 0xf3, 0x49, 0x2a, 0x47, 0xea, 0x3b, 0xec, 0x30, 0xfb, 0x3c, 0x18, 0x35, 0xff, + 0x14, 0x18, 0x39, 0x7, 0x3c, 0x5, 0xa, 0xf, 0x35, 0x12, 0x3a, 0x0, 0x2d, 0xc, 0x46, 0x13, + 0x3e, 0x23, 0x3f, 0x18, 0x3a, 0x16, 0x35, 0xf5, 0x3a, 0x1b, 0x4e, 0x2d, 0x3c, 0xef, 0x3c, 0xfc, + 0x2e, 0xa, 0x32, 0xb4, 0x23, 0xfb, 0x3e, 0x16, 0x40, 0xe, 0x24, 0x3, 0x44, 0x24, 0x3b, 0xa, + 0x19, 0x80, 0x28, 0x1a, 0x3b, 0xfb, 0x2a, 0xf, 0x31, 0x4, 0x3a, 0x4, 0x2d, 0xec, 0x29, 0xa, + 0x25, 0xb4, 0x20, 0xb4, 0x35, 0x1b, 0x31, 0xb4, 0x7, 0xc, 0x4b, 0x1b, 0x1c, 0x80, 0x28, 0xd6, + 0x23, 0x16, 0x2d, 0xf8, 0x35, 0xf6, 0x45, 0x11, 0x1d, 0xc5, 0x2a, 0xf6, 0x37, 0xea, 0x36, 0x11, + 0x3f, 0x7, 0x36, 0x11, 0x2e, 0xf1, 0x3b, 0x11, 0x16, 0x2a, 0x3a, 0x6, 0x37, 0xcf, 0x18, 0x80, + 0x30, 0xd6, 0x14, 0xf1, 0x16, 0xfc, 0x28, 0xe4, 0x3d, 0xe0, 0x2d, 0x80, 0x26, 0xec, 0x3d, 0xf8, + 0x36, 0xcf, 0x11, 0xef, 0x2c, 0x16, 0x2d, 0xff, 0x35, 0x12, 0x3e, 0xa, 0x35, 0xd, 0x2f, 0xf9, + 0x3f, 0x2d, 0x40, 0x80, 0xe7, 0x6, 0x2a, 0x80, 0x34, 0x4, 0x5, 0x1d, 0x3d, 0x12, 0x1e, 0xa, + 0x3f, 0x26, 0x2b, 0xfb, 0x2b, 0x80, 0x26, 0x80, 0x1e, 0x15, 0x24, 0xdb, 0x2a, 0xd6, 0x2b, 0x80, + 0x6, 0xdb, 0x26, 0xfd, 0x37, 0xec, 0x2a, 0xec, 0x2, 0x1c, 0x3c, 0xe7, 0x11, 0x80, 0xf3, 0xfd, + 0x3a, 0x1, 0x28, 0x17, 0x3a, 0xdb, 0xf6, 0x80, 0x2, 0xd6, 0x21, 0xcf, 0x2a, 0xdb, 0xf, 0x80, + 0x2b, 0x17, 0x24, 0xcf, 0x2e, 0xcf, 0x30, 0xf8, 0xa, 0xf1, 0x26, 0xe7, 0x2d, 0xf5, 0x31, 0xef, + 0x25, 0x80, 0x1, 0xfb, 0xd6, 0x80, 0x19, 0x1c, 0x37, 0xfb, 0x39, 0x11, 0x2c, 0x80, 0x23, 0x18, + 0x33, 0xf8, 0x2e, 0xd, 0x34, 0xcf, 0x2b, 0xf1, 0x21, 0x80, 0x29, 0x80, 0x1f, 0xe4, 0xe, 0xb, + 0x25, 0xc5, 0x1f, 0xc5, 0x21, 0x0, 0x19, 0x80, 0xef, 0x80, 0xb, 0xe4, 0x1c, 0xcf, 0x33, 0x16, + 0x3e, 0x7, 0x21, 0xf5, 0x2f, 0x0, 0x2e, 0xef, 0x23, 0x6, 0x3d, 0xe7, 0x23, 0xe7, 0x26, 0xd6, + 0x40, 0xfd, 0x30, 0x80, 0xa, 0xf5, 0x35, 0x0, 0x32, 0xf8, 0x20, 0xcf, 0x2d, 0xef, 0x32, 0x13, + 0x3c, 0x1c, 0x0, 0xfc, 0x26, 0xe0, 0x26, 0xd6, 0xec, 0x80, 0x16, 0xf3, 0xb4, 0xf1, 0x31, 0xcf, + 0x1f, 0x80, 0x7, 0xf6, 0x19, 0xfd, 0xe7, 0x80, 0x1, 0x80, 0x1c, 0x2, 0x2f, 0x80, 0x2f, 0x80, + 0x26, 0x4, 0x1c, 0xb4, 0x4, 0xdb, 0x1e, 0xcf, 0x2a, 0x80, 0xdb, 0x80, 0x1a, 0xea, 0x31, 0xa, + 0x18, 0x23, 0x39, 0xf8, 0x36, 0x22, 0x25, 0xc5, 0x1f, 0x80, 0x26, 0xef, 0x34, 0x80, 0x19, 0xe7, + 0x2d, 0xe0, 0x17, 0xe4, 0x2f, 0x17, 0x34, 0x7, 0x31, 0xef, 0x25, 0xe0, 0x1e, 0xf8, 0x1d, 0xdb, + 0xfd, 0xb, 0x11, 0x80, 0x11, 0x80, 0xe7, 0xcf, 0x32, 0x80, 0xc, 0xdb, 0xa, 0x80, 0xf9, 0x80, + 0x14, 0x14, 0x35, 0x80, 0x2c, 0xf9, 0x1f, 0xdb, 0x1b, 0xea, 0x11, 0x80, 0x26, 0xc5, 0xb, 0xb4, + 0xb, 0x80, 0x7, 0xef, 0x22, 0x6, 0x20, 0xe0, 0x0, 0x80, 0x1a, 0x1c, 0x25, 0xfb, 0x2f, 0x80, + 0x80, 0xea, 0x31, 0x19, 0x3c, 0xf, 0x23, 0x80, 0x16, 0x0, 0x38, 0xf1, 0x21, 0xea, 0x2c, 0x80, + 0x1e, 0xec, 0x2a, 0xe4, 0x7, 0x80, 0xf8, 0x80, 0x9, 0xd6, 0x20, 0xc5, 0x18, 0x80, 0x0, 0x14, + 0x2a, 0xcf, 0x1d, 0x80, 0xc, 0xe4, 0x1c, 0xa, 0x3a, 0x24, 0x1b, 0x80, 0xf8, 0x80, 0x8, 0x80, + 0x9, 0x80, 0x20, 0xdb, 0x20, 0xd6, 0x2d, 0x19, 0x1a, 0xd6, 0x25, 0x80, 0xb4, 0x80, 0x38, 0x12, + 0x17, 0xec, 0x14, 0x80, 0x20, 0xb4, 0x13, 0xdb, 0xb, 0x80, 0xfc, 0x15, 0x2f, 0x0, 0xdb, 0x80, + 0xf5, 0x0, 0x8, 0xcf, 0xf8, 0xe4, 0xc, 0x13, 0x34, 0x80, 0x17, 0x80, 0xe7, 0x80, 0x11, 0xcf, + 0x2f, 0xf6, 0x5, 0xdb, 0x27, 0x6, 0xf1, 0x80, 0x11, 0xc5, 0x24, 0x80, 0x11, 0xea, 0xa, 0x80, + 0x23, 0x1, 0x16, 0xf3, 0xfb, 0x80, 0x15, 0x13, 0x33, 0x6, 0xfc, 0x80, 0xd6, 0x80, 0x10, 0x80, + 0x1a, 0xf5, 0x11, 0x80, 0x9, 0xc5, 0xf, 0xcf, 0xef, 0xc5, 0x1b, 0xf9, 0x8, 0x80, 0x20, 0xc5, + 0x1c, 0xdb, 0x1f, 0x80, 0x1e, 0xf3, 0x12, 0xea, 0x26, 0xcf, 0x16, 0xcf, 0x2, 0xd6, 0x7, 0x80, + 0x24, 0x80, 0xf9, 0xcf, 0x1a, 0xb4, 0x26, 0xc5, 0xfb, 0x80, 0xfc, 0xc5, 0xef, 0xcf, 0x28, 0x80, + 0x19, 0xcf, 0x28, 0xea, 0x2c, 0xc5, 0x2f, 0xc, 0x1, 0xec, 0x2d, 0xb4, 0x14, 0x80, 0xc, 0xec, + 0xf5, 0xdb, 0x0, 0xc5, 0x20, 0x80, 0x21, 0x1, 0x0, 0x80, 0xa, 0x80, 0x29, 0x80, 0xdb, 0x7, + 0xf, 0xb4, 0x23, 0xfb, 0x27, 0xdb, 0x22, 0xec, 0x21, 0x80, 0xd6, 0xb4, 0x15, 0xd6, 0x11, 0x80, + 0x1f, 0xc5, 0x1a, 0xb4, 0x7, 0xe0, 0x21, 0xcf, 0x14, 0x16, 0x2a, 0x80, 0x80, 0x80, 0xa, 0xe7, + 0x6, 0x80, 0xb4, 0x80, 0xf, 0x80, 0xfc, 0xe4, 0x13, 0x80, 0x19, 0xb4, 0xd, 0xb4, 0xdb, 0xc5, + 0x18, 0x80, 0x21, 0xb4, 0x2d, 0xc5, 0xf1, 0xdb, 0xf, 0x80, 0x23, 0xd6, 0x28, 0x80, 0xea, 0xd6, + 0xe7, 0xcf, 0x11, 0xe4, 0xec, 0x2, 0x20, 0xb4, 0x29, 0xdb, 0x6, 0x80, 0xef, 0x80, 0xe0, 0x80, + 0x4, 0xc5, 0x32, 0xb4, 0x2f, 0x80, 0x7, 0xb4, 0xe0, 0x80, 0xf5, 0x80, 0x5, 0xb4, 0x8, 0xcf, + 0x1f, 0xf6, 0x28, 0xdb, 0x1b, 0xff, 0x12, 0x80, 0x2a, 0xff, 0x2f, 0xfc, 0xcf, 0x80, 0xc, 0xf1, + 0x21, 0x80, 0x2, 0x1, 0x2d, 0xf8, 0xf9, 0xf3, 0x25, 0x80, 0xdb, 0x80, 0xd6, 0x80, 0xc, 0xe4, + 0x1b, 0xc5, 0xe0, 0xec, 0xec, 0x80, 0x6, 0xb4, 0xf5, 0xcf, 0xc, 0x80, 0x1, 0xf6, 0x1d, 0x80, + 0xe7, 0x80, 0xf3, 0x80, 0xc5, 0x80, 0xf6, 0x80, 0x1b, 0xcf, 0x11, 0x80, 0xd6, 0x80, 0x80, 0x80, + 0xdb, 0x80, 0xec, 0x80, 0x19, 0xe0, 0x2, 0x80, 0x19, 0xef, 0x16, 0x80, 0xd6, 0x80, 0xe7, 0x80, + 0x11, 0xd6, 0xfc, 0x80, 0xa, 0xd6, 0x17, 0xe7, 0xe4, 0x80, 0xb4, 0xb4, 0x1d, 0xb4, 0xf, 0x80, + 0x32, 0xfb, 0x1b, 0xdb, 0x25, 0xec, 0xf5, 0x80, 0xd6, 0xef, 0x23, 0xec, 0x14, 0x80, 0xe0, 0xdb, + 0xf9, 0x80, 0xcf, 0x80, 0xff, 0xb4, 0xd, 0x80, 0xe4, 0x80, 0x0, 0xc5, 0x1f, 0xdb, 0x23, 0xe0, + 0x1, 0x80, 0x80, 0x80, 0xcf, 0x80, 0xb4, 0x80, 0xe0, 0xf6, 0x1d, 0xcf, 0xdb, 0x80, 0xdb, 0x80, + 0x80, 0xb4, 0xb, 0x80, 0x80, 0x80, 0x1d, 0x80, 0x4, 0xe4, 0xf5, 0x80, 0x80, 0x80, 0x4, 0x80, + 0xe4, 0x80, 0xfc, 0x80, 0xd6, 0x80, 0xf9, 0x80, 0x80, 0xb4, 0xc, 0x80, 0x26, 0xf9, 0x80, 0x80, + 0xb4, 0x80, 0xf1, 0x80, 0x80, 0x80, 0xf3, 0xb4, 0x0, 0x80, 0x2, 0xcf, 0xb4, 0xea, 0x14, 0x80, + 0x18, 0x80, 0xcf, 0x80, 0xd, 0x80, 0xe0, 0x80, 0x16, 0x80, 0xf8, 0xc5, 0x11, 0xb4, 0xf8, 0x80, + 0x80, 0x80, 0x80, 0x80, 0xe4, 0xe, 0x1c, 0x80, 0xfc, 0xb4, 0x2a, 0x6, 0x31, 0x10, 0x1c, 0x80, + 0xfd, 0xfc, 0xc, 0xe7, 0xea, 0x80, 0xe7, 0xd6, 0xd, 0xb4, 0x22, 0xf1, 0x7, 0xb4, 0x1d, 0xf6, + 0x11, 0xd6, 0x28, 0x80, 0xc5, 0xb4, 0x1f, 0xe0, 0x80, 0x80, 0x80, 0x80, 0xfb, 0xe7, 0xc, 0x80, + 0xdb, 0x80, 0xcf, 0x80, 0x80, 0x80, 0xd6, 0xc5, 0xf, 0x80, 0x80, 0xb4, 0x1b, 0x80, 0x0, 0xdb, + 0xf5, 0x80, 0x80, 0x80, 0x15, 0xec, 0xf, 0x80, 0xd6, 0x80, 0x80, 0xb4, 0xc, 0xd6, 0xd6, 0x80, + 0xd6, 0xd6, 0x9, 0x80, 0x80, 0x80, 0x3, 0xc5, 0x9, 0x80, 0x80, 0x80, 0xe4, 0x80, 0xf3, 0x80, + 0x10, 0xea, 0xb4, 0x80, 0xdb, 0xf3, 0xa, 0x80, 0xc5, 0x80, 0xef, 0x80, 0xc5, 0x80, 0xec, 0x80, + 0xff, 0x80, 0xa, 0xc5, 0xf1, 0x80, 0xb4, 0x80, 0xe0, 0x80, 0xfb, 0x80, 0xf8, 0x80, 0x3, 0x80, + 0xc, 0xcf, 0x80, 0xd6, 0xe0, 0x80, 0x80, 0xb4, 0xcf, 0xc5, 0x28, 0xd6, 0x17, 0x80, 0x80, 0x80, + 0xc5, 0xec, 0x14, 0x80, 0xf3, 0x80, 0xf8, 0x80, 0xf3, 0x80, 0xcf, 0x80, 0xf8, 0xe0, 0xea, 0x80, + 0xc5, 0x0, 0x35, 0xea, 0x3, 0x80, 0x80, 0x80, 0x17, 0xf, 0x16, 0x80, 0x19, 0xd6, 0x80, 0x80, + 0x80, 0x80, 0xe0, 0x80, 0xfd, 0x80, 0x4, 0xfc, 0x1e, 0x80, 0xef, 0x80, 0xef, 0xf1, 0x1f, 0x80, + 0xfc, 0x80, 0xe7, 0x80, 0xff, 0x80, 0xf8, 0x80, 0x80, 0x80, 0x17, 0x80, 0xcf, 0xfb, 0x1c, 0x0, + 0x26, 0x11, 0x16, 0x80, 0x80, 0xb4, 0x80, 0x80, 0x80, 0xcf, 0xf3, 0x80, 0x14, 0xb4, 0xdb, 0x5, + 0x19, 0x80, 0xd6, 0x80, 0xf5, 0x80, 0x17, 0xc5, 0x0, 0xc5, 0xcf, 0xc5, 0x4, 0x80, 0x5, 0x80, + 0xa, 0x80, 0x19, 0xd6, 0x28, 0x5, 0xea, 0x80, 0x80, 0x80, 0x80, 0xec, 0xd, 0x80, 0x80, 0x80, + 0x2, 0x80, 0xf1, 0x80, 0x80, 0x80, 0xd6, 0x80, 0xd6, 0x80, 0xdb, 0x80, 0xf3, 0x80, 0xff, 0x80, + 0x80, 0xc5, 0x20, 0x80, 0xea, 0x80, 0xb4, 0x80, 0x22, 0x80, 0x80, 0x80, 0x80, 0x80, 0xc5, 0x80, + 0x15, 0x80, 0x24, 0xc5, 0xfc, 0x80, 0xb, 0xe4, 0xcf, 0x80, 0x80, 0x80, 0xe7, 0xa, 0x1, 0xdb, + 0x12, 0x80, 0xf5, 0x80, 0x80, 0x80, 0xa, 0xd6, 0xfd, 0xf5, 0xfc, 0xcf, 0xe, 0x80, 0xd6, 0x80, + 0x80, 0x80, 0xef, 0x80, 0xfd, 0xc5, 0x12, 0xea, 0x20, 0x80, 0xe0, 0xdb, 0xc5, 0xd6, 0x1a, 0x80, + 0x80, 0xd6, 0x14, 0xc5, 0x80, 0x80, 0x80, 0xb4, 0x80, 0x80, 0xc5, 0xb4, 0xe4, 0xb4, 0xf6, 0x3, + 0xfc, 0x80, 0x80, 0x80, 0xfb, 0x80, 0x0, 0xe4, 0x80, 0x80, 0xb4, 0x80, 0x5, 0xb4, 0x80, 0x80, + 0x19, 0xd6, 0xe0, 0x80, 0x80, 0x80, 0xb4, 0xc5, 0xb4, 0x80, 0xfb, 0x4, 0x13, 0x80, 0xf, 0xc5, + 0x2, 0xec, 0xb4, 0xb4, 0xef, 0x80, 0xe0, 0x80, 0xcf, 0xf5, 0x1, 0x80, 0xe4, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x1, 0x1c, 0x80, 0x80, 0x80, 0xc5, 0xcf, 0xc, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0xfb, 0xc5, 0xcf, 0xdb, 0xcf, 0x80, 0xd6, 0x80, 0xea, 0x80, 0x80, 0x80, 0xd6, 0x80, + 0x80, 0xcf, 0xf9, 0xdb, 0xf8, 0x80, 0xdb, 0xb4, 0xff, 0xe0, 0xb4, 0x80, 0x80, 0x80, 0x80, 0x80, + 0xea, 0x80, 0xc5, 0x80, 0x80, 0x80, 0xe0, 0xec, 0xf5, 0x80, 0x80, 0x80, 0x17, 0x80, 0xcf, 0x80, + 0xf8, 0xf6, 0xe7, 0x80, 0xd6, 0x80, 0xcf, 0x80, 0x80, 0x80, 0xb4, 0x80, 0xe4, 0x80, 0xf8, 0x80, + 0x80, 0x80, 0xdb, 0x80, 0xfb, 0x80, 0x80, 0x80, 0xf3, 0x80, 0x11, 0xc5, 0x80, 0x80, 0xb4, 0x80, + 0x80, 0x80, 0xd6, 0x80, 0xec, 0xb4, 0x14, 0xb4, 0xf3, 0x80, 0xf9, 0x80, 0x8, 0x80, 0x80, 0x80, + 0xe7, 0x80, 0x80, 0xc5, 0xf1, 0x80, 0xf3, 0x80}; diff --git a/tests/micro/testdata/kws/yes_no.tflite b/tests/micro/testdata/kws/yes_no.tflite new file mode 100644 index 000000000000..4f533dac8405 Binary files /dev/null and b/tests/micro/testdata/kws/yes_no.tflite differ diff --git a/tests/micro/zephyr/testdata/digit-2.jpg b/tests/micro/testdata/mnist/digit-2.jpg similarity index 100% rename from tests/micro/zephyr/testdata/digit-2.jpg rename to tests/micro/testdata/mnist/digit-2.jpg diff --git a/tests/micro/zephyr/testdata/digit-9.jpg b/tests/micro/testdata/mnist/digit-9.jpg similarity index 100% rename from tests/micro/zephyr/testdata/digit-9.jpg rename to tests/micro/testdata/mnist/digit-9.jpg diff --git a/tests/micro/zephyr/testdata/mnist-8.onnx b/tests/micro/testdata/mnist/mnist-8.onnx similarity index 100% rename from tests/micro/zephyr/testdata/mnist-8.onnx rename to tests/micro/testdata/mnist/mnist-8.onnx diff --git a/tests/micro/zephyr/conftest.py b/tests/micro/zephyr/conftest.py index b1677e6c10f2..cfdb208c92b8 100644 --- a/tests/micro/zephyr/conftest.py +++ b/tests/micro/zephyr/conftest.py @@ -14,28 +14,34 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import datetime +import os +import pathlib + import pytest +import tvm.contrib.utils import tvm.target.target # The models that should pass this configuration. Maps a short, identifying platform string to # (model, zephyr_board). PLATFORMS = { - "host": ("host", "qemu_x86"), - "host_riscv32": ("host", "qemu_riscv32"), - "host_riscv64": ("host", "qemu_riscv64"), - "mps2_an521": ("mps2_an521", "mps2_an521-qemu"), + "qemu_x86": ("host", "qemu_x86"), + "qemu_riscv32": ("host", "qemu_riscv32"), + "qemu_riscv64": ("host", "qemu_riscv64"), + "mps2_an521": ("mps2_an521", "mps2_an521"), "nrf5340dk": ("nrf5340dk", "nrf5340dk_nrf5340_cpuapp"), "stm32f746xx_disco": ("stm32f746xx", "stm32f746g_disco"), "stm32f746xx_nucleo": ("stm32f746xx", "nucleo_f746zg"), "stm32l4r5zi_nucleo": ("stm32l4r5zi", "nucleo_l4r5zi"), + "zynq_mp_r5": ("zynq_mp_r5", "qemu_cortex_r5"), } def pytest_addoption(parser): parser.addoption( "--microtvm-platforms", - default="host", + default="qemu_x86", choices=PLATFORMS.keys(), help=( "Specify a comma-separated list of test models (i.e. as passed to tvm.target.micro()) " @@ -45,11 +51,6 @@ def pytest_addoption(parser): parser.addoption( "--west-cmd", default="west", help="Path to `west` command for flashing device." ) - parser.addoption( - "--skip-build", - action="store_true", - help="If set true, reuses build from the previous test run. Otherwise, build from the scratch.", - ) parser.addoption( "--tvm-debug", action="store_true", @@ -69,10 +70,27 @@ def west_cmd(request): @pytest.fixture -def skip_build(request): - return request.config.getoption("--skip-build") +def tvm_debug(request): + return request.config.getoption("--tvm-debug") @pytest.fixture -def tvm_debug(request): - return request.config.getoption("--tvm-debug") +def temp_dir(platform): + _, zephyr_board = PLATFORMS[platform] + parent_dir = pathlib.Path(os.path.dirname(__file__)) + filename = os.path.splitext(os.path.basename(__file__))[0] + board_workspace = ( + parent_dir + / f"workspace_{filename}_{zephyr_board}" + / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + ) + board_workspace_base = str(board_workspace) + number = 1 + while board_workspace.exists(): + board_workspace = pathlib.Path(board_workspace_base + f"-{number}") + number += 1 + + if not os.path.exists(board_workspace.parent): + os.makedirs(board_workspace.parent) + + return tvm.contrib.utils.tempdir(board_workspace) diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index 18587acd46ae..5a7e69e3c7f9 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -21,6 +21,7 @@ import glob import logging import os +import pathlib import subprocess import sys import logging @@ -35,8 +36,8 @@ import tvm.micro import tvm.testing import tvm.relay as relay +from tvm.relay.testing import byoc -from tvm.micro.contrib import zephyr from tvm.contrib import utils from tvm.relay.expr_functor import ExprMutator from tvm.relay.op.annotation import compiler_begin, compiler_end @@ -48,89 +49,63 @@ PLATFORMS = conftest.PLATFORMS -def _make_sess_from_op(model, zephyr_board, west_cmd, op_name, sched, arg_bufs, build_config): +def _make_sess_from_op( + temp_dir, model, zephyr_board, west_cmd, op_name, sched, arg_bufs, build_config +): target = tvm.target.target.micro(model) target = tvm.target.Target(target=target, host=target) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): mod = tvm.build(sched, arg_bufs, target=target, name=op_name) - return _make_session(model, target, zephyr_board, west_cmd, mod, build_config) - - -def _make_session(model, target, zephyr_board, west_cmd, mod, build_config): - parent_dir = os.path.dirname(__file__) - filename = os.path.splitext(os.path.basename(__file__))[0] - prev_build = f"{os.path.join(parent_dir, 'archive')}_{filename}_{zephyr_board}_last_build.micro" - workspace_root = os.path.join( - f"{os.path.join(parent_dir, 'workspace')}_{filename}_{zephyr_board}", - datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"), - ) - workspace_parent = os.path.dirname(workspace_root) - if not os.path.exists(workspace_parent): - os.makedirs(workspace_parent) - workspace = tvm.micro.Workspace(debug=True, root=workspace_root) - - test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - tvm_source_dir = os.path.join(test_dir, "..", "..", "..") - runtime_path = os.path.join(tvm_source_dir, "apps", "microtvm", "zephyr", "host_driven") - compiler = zephyr.ZephyrCompiler( - project_dir=runtime_path, - board=zephyr_board, - zephyr_toolchain_variant="zephyr", - west_cmd=west_cmd, + return _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config) + + +TEMPLATE_PROJECT_DIR = ( + pathlib.Path(__file__).parent + / ".." + / ".." + / ".." + / "apps" + / "microtvm" + / "zephyr" + / "template_project" +).resolve() + + +def _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config): + project = tvm.micro.generate_project( + str(TEMPLATE_PROJECT_DIR), + mod, + temp_dir / "project", + { + "project_type": "host_driven", + "west_cmd": west_cmd, + "verbose": bool(build_config.get("debug")), + "zephyr_board": zephyr_board, + }, ) - - opts = tvm.micro.default_options(os.path.join(runtime_path, "crt")) - # TODO(weberlo) verify this is necessary - opts["bin_opts"]["ccflags"] = ["-std=gnu++14"] - opts["lib_opts"]["ccflags"] = ["-std=gnu++14"] - - flasher_kw = {} - if build_config["debug"]: - flasher_kw["debug_rpc_session"] = tvm.rpc.connect("127.0.0.1", 9090) - - session_kw = { - "flasher": compiler.flasher(**flasher_kw), - } - - if not build_config["skip_build"]: - session_kw["binary"] = tvm.micro.build_static_runtime( - # the x86 compiler *expects* you to give the exact same dictionary for both - # lib_opts and bin_opts. so the library compiler is mutating lib_opts and - # the binary compiler is expecting those mutations to be in bin_opts. - # TODO(weberlo) fix this very bizarre behavior - workspace, - compiler, - mod, - opts, - ) - if os.path.exists(prev_build): - os.unlink(prev_build) - session_kw["binary"].archive(prev_build, metadata_only=True) - else: - unarchive_dir = utils.tempdir() - session_kw["binary"] = tvm.micro.MicroBinary.unarchive( - prev_build, unarchive_dir.relpath("binary") - ) - - return tvm.micro.Session(**session_kw) + project.build() + project.flash() + return tvm.micro.Session(project.transport()) -def _make_add_sess(model, zephyr_board, west_cmd, build_config): - A = tvm.te.placeholder((2,), dtype="int8") - B = tvm.te.placeholder((1,), dtype="int8") +def _make_add_sess(temp_dir, model, zephyr_board, west_cmd, build_config, dtype="int8"): + A = tvm.te.placeholder((2,), dtype=dtype) + B = tvm.te.placeholder((1,), dtype=dtype) C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C") sched = tvm.te.create_schedule(C.op) - return _make_sess_from_op(model, zephyr_board, west_cmd, "add", sched, [A, B, C], build_config) + return _make_sess_from_op( + temp_dir, model, zephyr_board, west_cmd, "add", sched, [A, B, C], build_config + ) # The same test code can be executed on both the QEMU simulation and on real hardware. @tvm.testing.requires_micro -def test_compile_runtime(platform, west_cmd, skip_build, tvm_debug): +def test_add_uint(temp_dir, platform, west_cmd, tvm_debug): """Test compiling the on-device runtime.""" model, zephyr_board = PLATFORMS[platform] - build_config = {"skip_build": skip_build, "debug": tvm_debug} + build_config = {"debug": tvm_debug} # NOTE: run test in a nested function so cPython will delete arrays before closing the session. def test_basic_add(sess): @@ -145,16 +120,55 @@ def test_basic_add(sess): system_lib.get_function("add")(A_data, B_data, C_data) assert (C_data.numpy() == np.array([6, 7])).all() - with _make_add_sess(model, zephyr_board, west_cmd, build_config) as sess: + with _make_add_sess(temp_dir, model, zephyr_board, west_cmd, build_config) as sess: + test_basic_add(sess) + + +def has_fpu(zephyr_board): + sys.path.insert(0, str(TEMPLATE_PROJECT_DIR)) + try: + import microtvm_api_server + finally: + sys.path.pop(0) + + return microtvm_api_server.Handler._has_fpu(zephyr_board) + + +# The same test code can be executed on both the QEMU simulation and on real hardware. +@tvm.testing.requires_micro +def test_add_float(temp_dir, platform, west_cmd, tvm_debug): + """Test compiling the on-device runtime.""" + model, zephyr_board = PLATFORMS[platform] + if not has_fpu(zephyr_board): + pytest.skip(f"FPU not enabled for {platform}") + + build_config = {"debug": tvm_debug} + + # NOTE: run test in a nested function so cPython will delete arrays before closing the session. + def test_basic_add(sess): + A_data = tvm.nd.array(np.array([2.5, 3.5], dtype="float32"), device=sess.device) + assert (A_data.numpy() == np.array([2.5, 3.5])).all() + B_data = tvm.nd.array(np.array([4.5], dtype="float32"), device=sess.device) + assert (B_data.numpy() == np.array([4.5])).all() + C_data = tvm.nd.array(np.array([0, 0], dtype="float32"), device=sess.device) + assert (C_data.numpy() == np.array([0, 0])).all() + + system_lib = sess.get_system_lib() + system_lib.get_function("add")(A_data, B_data, C_data) + assert (C_data.numpy() == np.array([7, 8])).all() + + with _make_add_sess( + temp_dir, model, zephyr_board, west_cmd, build_config, dtype="float32" + ) as sess: test_basic_add(sess) @tvm.testing.requires_micro -def test_platform_timer(platform, west_cmd, skip_build, tvm_debug): +def test_platform_timer(temp_dir, platform, west_cmd, tvm_debug): """Test compiling the on-device runtime.""" model, zephyr_board = PLATFORMS[platform] - build_config = {"skip_build": skip_build, "debug": tvm_debug} + build_config = {"debug": tvm_debug} # NOTE: run test in a nested function so cPython will delete arrays before closing the session. def test_basic_add(sess): @@ -174,15 +188,15 @@ def test_basic_add(sess): assert result.mean > 0 assert len(result.results) == 3 - with _make_add_sess(model, zephyr_board, west_cmd, build_config) as sess: + with _make_add_sess(temp_dir, model, zephyr_board, west_cmd, build_config) as sess: test_basic_add(sess) @tvm.testing.requires_micro -def test_relay(platform, west_cmd, skip_build, tvm_debug): +def test_relay(temp_dir, platform, west_cmd, tvm_debug): """Testing a simple relay graph""" model, zephyr_board = PLATFORMS[platform] - build_config = {"skip_build": skip_build, "debug": tvm_debug} + build_config = {"debug": tvm_debug} shape = (10,) dtype = "int8" @@ -191,16 +205,17 @@ def test_relay(platform, west_cmd, skip_build, tvm_debug): xx = relay.multiply(x, x) z = relay.add(xx, relay.const(np.ones(shape=shape, dtype=dtype))) func = relay.Function([x], z) + ir_mod = tvm.IRModule.from_expr(func) target = tvm.target.target.micro(model) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - graph, mod, params = tvm.relay.build(func, target=target) + mod = tvm.relay.build(ir_mod, target=target) - with _make_session(model, target, zephyr_board, west_cmd, mod, build_config) as session: + with _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config) as session: graph_mod = tvm.micro.create_local_graph_executor( - graph, session.get_system_lib(), session.device + mod.get_graph_json(), session.get_system_lib(), session.device ) - graph_mod.set_input(**params) + graph_mod.set_input(**mod.get_params()) x_in = np.random.randint(10, size=shape[0], dtype=dtype) graph_mod.run(x=x_in) result = graph_mod.get_output(0).numpy() @@ -209,23 +224,23 @@ def test_relay(platform, west_cmd, skip_build, tvm_debug): @tvm.testing.requires_micro -def test_onnx(platform, west_cmd, skip_build, tvm_debug): +def test_onnx(temp_dir, platform, west_cmd, tvm_debug): """Testing a simple ONNX model.""" model, zephyr_board = PLATFORMS[platform] - build_config = {"skip_build": skip_build, "debug": tvm_debug} + build_config = {"debug": tvm_debug} - # Load test images. - this_dir = os.path.dirname(__file__) - digit_2 = Image.open(f"{this_dir}/testdata/digit-2.jpg").resize((28, 28)) + this_dir = pathlib.Path(os.path.dirname(__file__)) + mnist_testdata = this_dir.parent / "testdata" / "mnist" + digit_2 = Image.open(mnist_testdata / "digit-2.jpg").resize((28, 28)) digit_2 = np.asarray(digit_2).astype("float32") digit_2 = np.expand_dims(digit_2, axis=0) - digit_9 = Image.open(f"{this_dir}/testdata/digit-9.jpg").resize((28, 28)) + digit_9 = Image.open(mnist_testdata / "digit-9.jpg").resize((28, 28)) digit_9 = np.asarray(digit_9).astype("float32") digit_9 = np.expand_dims(digit_9, axis=0) # Load ONNX model and convert to Relay. - onnx_model = onnx.load(f"{this_dir}/testdata/mnist-8.onnx") + onnx_model = onnx.load(mnist_testdata / "mnist-8.onnx") shape = {"Input3": (1, 1, 28, 28)} relay_mod, params = relay.frontend.from_onnx(onnx_model, shape=shape, freeze_params=True) relay_mod = relay.transform.DynamicToStatic()(relay_mod) @@ -239,7 +254,7 @@ def test_onnx(platform, west_cmd, skip_build, tvm_debug): lowered = relay.build(relay_mod, target, params=params) graph = lowered.get_graph_json() - with _make_session(model, target, zephyr_board, west_cmd, lowered.lib, build_config) as session: + with _make_session(temp_dir, zephyr_board, west_cmd, lowered, build_config) as session: graph_mod = tvm.micro.create_local_graph_executor( graph, session.get_system_lib(), session.device ) @@ -257,77 +272,23 @@ def test_onnx(platform, west_cmd, skip_build, tvm_debug): assert np.argmax(result) == 9 -class CcompilerAnnotator(ExprMutator): - """ - This is used to create external functions for ccompiler. - A simple annotator that creates the following program: - | - -- begin -- - | - add - | - subtract - | - multiply - | - -- end -- - | - """ - - def __init__(self): - super(CcompilerAnnotator, self).__init__() - self.in_compiler = 0 - - def visit_call(self, call): - if call.op.name == "add": # Annotate begin at args - if self.in_compiler == 1: - lhs = compiler_begin(super().visit(call.args[0]), "ccompiler") - rhs = compiler_begin(super().visit(call.args[1]), "ccompiler") - op = relay.add(lhs, rhs) - self.in_compiler = 2 - return op - elif call.op.name == "subtract": - if self.in_compiler == 1: - lhs = super().visit(call.args[0]) - rhs = super().visit(call.args[1]) - if isinstance(lhs, relay.expr.Var): - lhs = compiler_begin(lhs, "ccompiler") - if isinstance(rhs, relay.expr.Var): - rhs = compiler_begin(rhs, "ccompiler") - return relay.subtract(lhs, rhs) - elif call.op.name == "multiply": # Annotate end at output - self.in_compiler = 1 - lhs = super().visit(call.args[0]) - rhs = super().visit(call.args[1]) - if isinstance(lhs, relay.expr.Var): - lhs = compiler_begin(lhs, "ccompiler") - if isinstance(rhs, relay.expr.Var): - rhs = compiler_begin(rhs, "ccompiler") - op = relay.multiply(lhs, rhs) - if self.in_compiler == 2: - op = compiler_end(op, "ccompiler") - self.in_compiler = 0 - return op - return super().visit_call(call) - - def check_result( - relay_mod, model, zephyr_board, west_cmd, map_inputs, out_shape, result, build_config + temp_dir, relay_mod, model, zephyr_board, west_cmd, map_inputs, out_shape, result, build_config ): """Helper function to verify results""" TOL = 1e-5 target = tvm.target.target.micro(model) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - graph, mod, params = tvm.relay.build(relay_mod, target=target) + mod = tvm.relay.build(relay_mod, target=target) - with _make_session(model, target, zephyr_board, west_cmd, mod, build_config) as session: + with _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config) as session: rt_mod = tvm.micro.create_local_graph_executor( - graph, session.get_system_lib(), session.device + mod.get_graph_json(), session.get_system_lib(), session.device ) - rt_mod.set_input(**params) + rt_mod.set_input(**mod.get_params()) for name, data in map_inputs.items(): rt_mod.set_input(name, data) - rt_mod.set_input(**params) + rt_mod.set_input(**mod.get_params()) rt_mod.run() out_shapes = out_shape if isinstance(out_shape, list) else [out_shape] @@ -340,10 +301,10 @@ def check_result( @tvm.testing.requires_micro -def test_byoc_microtvm(platform, west_cmd, skip_build, tvm_debug): +def test_byoc_microtvm(temp_dir, platform, west_cmd, tvm_debug): """This is a simple test case to check BYOC capabilities of microTVM""" model, zephyr_board = PLATFORMS[platform] - build_config = {"skip_build": skip_build, "debug": tvm_debug} + build_config = {"debug": tvm_debug} x = relay.var("x", shape=(10, 10)) w0 = relay.var("w0", shape=(10, 10)) w1 = relay.var("w1", shape=(10, 10)) @@ -370,7 +331,7 @@ def test_byoc_microtvm(platform, west_cmd, skip_build, tvm_debug): r = relay.concatenate((q0, q1, q2), axis=0) f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) mod = tvm.IRModule() - ann = CcompilerAnnotator() + ann = byoc.CcompilerAnnotator() mod["main"] = ann.visit(f) mod = tvm.relay.transform.PartitionGraph()(mod) mod = tvm.relay.transform.InferType()(mod) @@ -383,6 +344,7 @@ def test_byoc_microtvm(platform, west_cmd, skip_build, tvm_debug): map_inputs = {"w{}".format(i): w_data[i] for i in range(8)} map_inputs["x"] = x_data check_result( + temp_dir=temp_dir, relay_mod=mod, map_inputs=map_inputs, out_shape=(30, 10), @@ -401,11 +363,13 @@ def test_byoc_microtvm(platform, west_cmd, skip_build, tvm_debug): ) -def _make_add_sess_with_shape(model, zephyr_board, west_cmd, shape, build_config): +def _make_add_sess_with_shape(temp_dir, model, zephyr_board, west_cmd, shape, build_config): A = tvm.te.placeholder(shape, dtype="int8") C = tvm.te.compute(A.shape, lambda i: A[i] + A[i], name="C") sched = tvm.te.create_schedule(C.op) - return _make_sess_from_op(model, zephyr_board, west_cmd, "add", sched, [A, C], build_config) + return _make_sess_from_op( + temp_dir, model, zephyr_board, west_cmd, "add", sched, [A, C], build_config + ) @pytest.mark.parametrize( @@ -417,21 +381,23 @@ def _make_add_sess_with_shape(model, zephyr_board, west_cmd, shape, build_config ], ) @tvm.testing.requires_micro -def test_rpc_large_array(platform, west_cmd, skip_build, tvm_debug, shape): +def test_rpc_large_array(temp_dir, platform, west_cmd, tvm_debug, shape): """Test large RPC array transfer.""" model, zephyr_board = PLATFORMS[platform] - build_config = {"skip_build": skip_build, "debug": tvm_debug} + build_config = {"debug": tvm_debug} # NOTE: run test in a nested function so cPython will delete arrays before closing the session. def test_tensors(sess): a_np = np.random.randint(low=-128, high=127, size=shape, dtype="int8") A_data = tvm.nd.array(a_np, device=sess.device) - assert (A_data.asnumpy() == a_np).all() + assert (A_data.numpy() == a_np).all() C_data = tvm.nd.array(np.zeros(shape, dtype="int8"), device=sess.device) - assert (C_data.asnumpy() == np.zeros(shape)).all() + assert (C_data.numpy() == np.zeros(shape)).all() - with _make_add_sess_with_shape(model, zephyr_board, west_cmd, shape, build_config) as sess: + with _make_add_sess_with_shape( + temp_dir, model, zephyr_board, west_cmd, shape, build_config + ) as sess: test_tensors(sess) diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot.py index ad4776e9c6b3..37aa0f76a852 100644 --- a/tests/micro/zephyr/test_zephyr_aot.py +++ b/tests/micro/zephyr/test_zephyr_aot.py @@ -15,13 +15,14 @@ # specific language governing permissions and limitations # under the License. -import datetime -from hashlib import new +import io import logging import os import sys import logging import pathlib +import tarfile +import tempfile import pytest import numpy as np @@ -29,12 +30,13 @@ import tvm import tvm.rpc import tvm.micro +from tvm.micro.project_api import server import tvm.testing import tvm.relay as relay -from tvm.micro.contrib import zephyr from tvm.contrib import utils from tvm.contrib.download import download_testdata +from tvm.micro.interface_api import generate_c_interface_header import conftest @@ -43,98 +45,77 @@ PLATFORMS = conftest.PLATFORMS -def _build_session_kw(model, target, zephyr_board, west_cmd, mod, runtime_path, build_config): - parent_dir = os.path.dirname(__file__) - filename = os.path.splitext(os.path.basename(__file__))[0] - prev_build = f"{os.path.join(parent_dir, 'archive')}_{filename}_{zephyr_board}_last_build.micro" - workspace_root = os.path.join( - f"{os.path.join(parent_dir, 'workspace')}_{filename}_{zephyr_board}", - datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"), +def _build_project(temp_dir, zephyr_board, west_cmd, mod, build_config, extra_files_tar=None): + template_project_dir = ( + pathlib.Path(__file__).parent + / ".." + / ".." + / ".." + / "apps" + / "microtvm" + / "zephyr" + / "template_project" + ).resolve() + project_dir = temp_dir / "project" + project = tvm.micro.generate_project( + str(template_project_dir), + mod, + project_dir, + { + "extra_files_tar": extra_files_tar, + "project_type": "aot_demo", + "west_cmd": west_cmd, + "verbose": bool(build_config.get("debug")), + "zephyr_board": zephyr_board, + }, ) - workspace_parent = os.path.dirname(workspace_root) - if not os.path.exists(workspace_parent): - os.makedirs(workspace_parent) - workspace = tvm.micro.Workspace(debug=True, root=workspace_root) - - compiler = zephyr.ZephyrCompiler( - project_dir=runtime_path, - board=zephyr_board, - zephyr_toolchain_variant="zephyr", - west_cmd=west_cmd, - env_vars={"ZEPHYR_RUNTIME": "ZEPHYR-AOT"}, - ) - - opts = tvm.micro.default_options(os.path.join(runtime_path, "crt")) - opts["bin_opts"]["include_dirs"].append(os.path.join(runtime_path, "include")) - opts["lib_opts"]["include_dirs"].append(os.path.join(runtime_path, "include")) - - flasher_kw = {} - if build_config["debug"]: - flasher_kw["debug_rpc_session"] = tvm.rpc.connect("127.0.0.1", 9090) - - session_kw = { - "flasher": compiler.flasher(**flasher_kw), - } - - if not build_config["skip_build"]: - session_kw["binary"] = tvm.micro.build_static_runtime( - workspace, - compiler, - mod, - opts, - executor="aot", - extra_libs=[tvm.micro.get_standalone_crt_lib("memory")], - ) - if os.path.exists(prev_build): - os.unlink(prev_build) - session_kw["binary"].archive(prev_build, metadata_only=True) - else: - unarchive_dir = utils.tempdir() - session_kw["binary"] = tvm.micro.MicroBinary.unarchive( - prev_build, unarchive_dir.relpath("binary") - ) - - return session_kw + project.build() + return project, project_dir -def _create_header_file(tensor_name, npy_data, output_path): +def _create_header_file(tensor_name, npy_data, output_path, tar_file): """ This method generates a header file containing the data contained in the numpy array provided. It is used to capture the tensor data (for both inputs and expected outputs). """ - file_path = pathlib.Path(f"{output_path}/" + tensor_name).resolve() - # create header file - raw_path = file_path.with_suffix(".h").resolve() - with open(raw_path, "w") as header_file: - header_file.write("#include \n") - header_file.write("#include \n") - header_file.write("#include \n") - header_file.write(f"const size_t {tensor_name}_len = {npy_data.size};\n") - - if npy_data.dtype == "int8": - header_file.write(f"int8_t {tensor_name}[] =") - elif npy_data.dtype == "int32": - header_file.write(f"int32_t {tensor_name}[] = ") - elif npy_data.dtype == "uint8": - header_file.write(f"uint8_t {tensor_name}[] = ") - elif npy_data.dtype == "float32": - header_file.write(f"float {tensor_name}[] = ") - else: - raise ValueError("Data type not expected.") - - header_file.write("{") - for i in np.ndindex(npy_data.shape): - header_file.write(f"{npy_data[i]}, ") - header_file.write("};\n\n") - - -def _read_line(fd): + header_file = io.StringIO() + header_file.write("#include \n") + header_file.write("#include \n") + header_file.write("#include \n") + header_file.write(f"const size_t {tensor_name}_len = {npy_data.size};\n") + + if npy_data.dtype == "int8": + header_file.write(f"int8_t {tensor_name}[] =") + elif npy_data.dtype == "int32": + header_file.write(f"int32_t {tensor_name}[] = ") + elif npy_data.dtype == "uint8": + header_file.write(f"uint8_t {tensor_name}[] = ") + elif npy_data.dtype == "float32": + header_file.write(f"float {tensor_name}[] = ") + else: + raise ValueError("Data type not expected.") + + header_file.write("{") + for i in np.ndindex(npy_data.shape): + header_file.write(f"{npy_data[i]}, ") + header_file.write("};\n\n") + + header_file_bytes = bytes(header_file.getvalue(), "utf-8") + raw_path = pathlib.Path(output_path) / f"{tensor_name}.h" + ti = tarfile.TarInfo(name=str(raw_path)) + ti.size = len(header_file_bytes) + ti.mode = 0o644 + ti.type = tarfile.REGTYPE + tar_file.addfile(ti, io.BytesIO(header_file_bytes)) + + +def _read_line(fd, timeout_sec: int): data = "" new_line = False while True: if new_line: break - new_data = fd.read(1, timeout_sec=10) + new_data = fd.read(1, timeout_sec=timeout_sec) logging.debug(f"read data: {new_data}") for item in new_data: new_c = chr(item) @@ -145,25 +126,26 @@ def _read_line(fd): return data -def _get_message(fd, expr: str): +def _get_message(fd, expr: str, timeout_sec: int): while True: - data = _read_line(fd) + data = _read_line(fd, timeout_sec) logging.debug(f"new line: {data}") if expr in data: return data @tvm.testing.requires_micro -def test_tflite(platform, west_cmd, skip_build, tvm_debug): +def test_tflite(temp_dir, platform, west_cmd, tvm_debug): """Testing a TFLite model.""" + + if platform not in ["qemu_x86", "mps2_an521", "nrf5340dk", "stm32l4r5zi_nucleo", "zynq_mp_r5"]: + pytest.skip(msg="Model does not fit.") + model, zephyr_board = PLATFORMS[platform] input_shape = (1, 32, 32, 3) output_shape = (1, 10) - build_config = {"skip_build": skip_build, "debug": tvm_debug} + build_config = {"debug": tvm_debug} - this_dir = os.path.dirname(__file__) - tvm_source_dir = os.path.join(this_dir, "..", "..", "..") - runtime_path = os.path.join(tvm_source_dir, "apps", "microtvm", "zephyr", "aot_demo") model_url = "https://github.com/eembc/ulpmark-ml/raw/fc1499c7cc83681a02820d5ddf5d97fe75d4f663/base_models/ic01/ic01_fp32.tflite" model_path = download_testdata(model_url, "ic01_fp32.tflite", module="model") @@ -183,7 +165,9 @@ def test_tflite(platform, west_cmd, skip_build, tvm_debug): tflite_model, shape_dict={"input_1": input_shape}, dtype_dict={"input_1 ": "float32"} ) - target = tvm.target.target.micro(model, options=["-link-params=1", "--executor=aot"]) + target = tvm.target.target.micro( + model, options=["-link-params=1", "--executor=aot", "--unpacked-api=1", "--interface-api=c"] + ) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): lowered = relay.build(relay_mod, target, params=params) @@ -193,20 +177,38 @@ def test_tflite(platform, west_cmd, skip_build, tvm_debug): sample_url, "testdata_image_classification_fp32_8.npy", module="data" ) sample = np.load(sample_path) - model_files_path = os.path.join(runtime_path, "include") - _create_header_file((f"input_data"), sample, model_files_path) - _create_header_file( - "output_data", np.zeros(shape=output_shape, dtype="float32"), model_files_path - ) - session_kw = _build_session_kw( - model, target, zephyr_board, west_cmd, lowered.lib, runtime_path, build_config - ) - transport = session_kw["flasher"].flash(session_kw["binary"]) - transport.open() - transport.write(b"start\n", timeout_sec=5) + with tempfile.NamedTemporaryFile() as tar_temp_file: + with tarfile.open(tar_temp_file.name, "w:gz") as tf: + with tempfile.TemporaryDirectory() as tar_temp_dir: + model_files_path = os.path.join(tar_temp_dir, "include") + os.mkdir(model_files_path) + header_path = generate_c_interface_header( + lowered.libmod_name, ["input_1"], ["output"], model_files_path + ) + tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir)) + + _create_header_file("input_data", sample, "include", tf) + _create_header_file( + "output_data", np.zeros(shape=output_shape, dtype="float32"), "include", tf + ) + + project, _ = _build_project( + temp_dir, + zephyr_board, + west_cmd, + lowered, + build_config, + extra_files_tar=tar_temp_file.name, + ) + + project.flash() + with project.transport() as transport: + timeout_read = 60 + _get_message(transport, "#wakeup", timeout_sec=timeout_read) + transport.write(b"start\n", timeout_sec=5) + result_line = _get_message(transport, "#result", timeout_sec=timeout_read) - result_line = _get_message(transport, "#result") result_line = result_line.strip("\n") result_line = result_line.split(":") result = int(result_line[1]) @@ -216,47 +218,59 @@ def test_tflite(platform, west_cmd, skip_build, tvm_debug): @tvm.testing.requires_micro -def test_qemu_make_fail(platform, west_cmd, skip_build, tvm_debug): - if platform not in ["host", "mps2_an521"]: +def test_qemu_make_fail(temp_dir, platform, west_cmd, tvm_debug): + """Testing QEMU make fail.""" + if platform not in ["qemu_x86", "mps2_an521"]: pytest.skip(msg="Only for QEMU targets.") - """Testing QEMU make fail.""" model, zephyr_board = PLATFORMS[platform] - build_config = {"skip_build": skip_build, "debug": tvm_debug} + build_config = {"debug": tvm_debug} shape = (10,) dtype = "float32" - this_dir = pathlib.Path(__file__).parent - tvm_source_dir = this_dir / ".." / ".." / ".." - runtime_path = tvm_source_dir / "apps" / "microtvm" / "zephyr" / "aot_demo" - # Construct Relay program. x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) xx = relay.multiply(x, x) z = relay.add(xx, relay.const(np.ones(shape=shape, dtype=dtype))) func = relay.Function([x], z) + ir_mod = tvm.IRModule.from_expr(func) target = tvm.target.target.micro(model, options=["-link-params=1", "--executor=aot"]) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - lowered = relay.build(func, target) + lowered = relay.build(ir_mod, target) # Generate input/output header files - model_files_path = os.path.join(runtime_path, "include") - _create_header_file((f"input_data"), np.zeros(shape=shape, dtype=dtype), model_files_path) - _create_header_file("output_data", np.zeros(shape=shape, dtype=dtype), model_files_path) + with tempfile.NamedTemporaryFile() as tar_temp_file: + with tarfile.open(tar_temp_file.name, "w:gz") as tf: + with tempfile.TemporaryDirectory() as tar_temp_dir: + model_files_path = os.path.join(tar_temp_dir, "include") + os.mkdir(model_files_path) + header_path = generate_c_interface_header( + lowered.libmod_name, ["input_1"], ["output"], model_files_path + ) + tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir)) + _create_header_file("input_data", np.zeros(shape=shape, dtype=dtype), "include", tf) + _create_header_file("output_data", np.zeros(shape=shape, dtype=dtype), "include", tf) + + project, project_dir = _build_project( + temp_dir, + zephyr_board, + west_cmd, + lowered, + build_config, + extra_files_tar=tar_temp_file.name, + ) - session_kw = _build_session_kw( - model, target, zephyr_board, west_cmd, lowered.lib, runtime_path, build_config + file_path = ( + pathlib.Path(project_dir) / "build" / "zephyr" / "CMakeFiles" / "run.dir" / "build.make" ) - - file_path = os.path.join(session_kw["binary"].base_dir, "zephyr/CMakeFiles/run.dir/build.make") - assert os.path.isfile(file_path), f"[{file_path}] does not exist." + assert file_path.is_file(), f"[{file_path}] does not exist." # Remove a file to create make failure. os.remove(file_path) - transport = session_kw["flasher"].flash(session_kw["binary"]) - with pytest.raises(RuntimeError) as excinfo: - transport.open() + project.flash() + with pytest.raises(server.JSONRPCError) as excinfo: + project.transport().open() assert "QEMU setup failed" in str(excinfo.value) diff --git a/tests/python/contrib/test_arm_compute_lib/test_dense.py b/tests/python/contrib/test_arm_compute_lib/test_dense.py index e6620a4bc1cb..6bdff0fdb857 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_dense.py +++ b/tests/python/contrib/test_arm_compute_lib/test_dense.py @@ -150,11 +150,16 @@ def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False): if has_bias: bias_dtype = "int32" if dtype == "uint8" else "float32" + bias_shape = ( + [1, weight_shape[0]] + if dtype == "float32" and weight_shape[0] != 1 + else [weight_shape[0]] + ) inputs.append( { "op": "const", "name": "", - "attrs": {"shape": [[[weight_shape[0]]]], "dtype": [[bias_dtype]]}, + "attrs": {"shape": [[bias_shape]], "dtype": [[bias_dtype]]}, } ) diff --git a/tests/python/contrib/test_bnns/test_conv2d_patterns.py b/tests/python/contrib/test_bnns/test_conv2d_patterns.py index b81e74b6d8fa..5fc9e9522fbd 100644 --- a/tests/python/contrib/test_bnns/test_conv2d_patterns.py +++ b/tests/python/contrib/test_bnns/test_conv2d_patterns.py @@ -57,7 +57,7 @@ def test_pattern_conv2d_with_bias_add(): res = relay.nn.bias_add(res, b, axis=axis) mod = partition(res) - bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_0"], "nn.bias_add") + bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_main_0"], "nn.bias_add") assert bias_is_fused if axis == 1 else not bias_is_fused @@ -73,7 +73,7 @@ def test_pattern_conv2d_with_add(): res = relay.add(res, b) mod = partition(res) - bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_0"], "add") + bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_main_0"], "add") assert bias_is_fused == should_be_fused @@ -102,6 +102,6 @@ def test_pattern_conv2d_with_non_cons_bias(): res = relay.nn.bias_add(res, b, axis=1) mod = partition(res) - bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_0"], "nn.bias_add") + bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_main_0"], "nn.bias_add") assert not bias_is_fused diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 069f3f3769b5..839c6f50c070 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -98,6 +98,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): @tvm.testing.requires_gpu +@requires_cudnn def test_conv2d(): verify_conv2d("float32", "float32", tensor_format=0) verify_conv2d("float16", "float32", tensor_format=1) @@ -171,6 +172,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1): @tvm.testing.requires_gpu +@requires_cudnn def test_conv3d(): verify_conv3d("float32", "float32", tensor_format=0) verify_conv3d("float32", "float32", tensor_format=0, groups=2) diff --git a/tests/python/contrib/test_edgetpu_runtime.py b/tests/python/contrib/test_edgetpu_runtime.py index 7e59ab2e3cc6..2bf58106dfdc 100644 --- a/tests/python/contrib/test_edgetpu_runtime.py +++ b/tests/python/contrib/test_edgetpu_runtime.py @@ -51,7 +51,7 @@ def init_interpreter(model_path, target_edgetpu): interpreter = tflite.Interpreter(model_path=model_path) return interpreter - def check_remote(target_edgetpu=False): + def check_remote(server, target_edgetpu=False): tflite_model_path = get_tflite_model_path(target_edgetpu) # inference via tflite interpreter python apis @@ -67,7 +67,6 @@ def check_remote(target_edgetpu=False): tflite_output = interpreter.get_tensor(output_details[0]["index"]) # inference via remote tvm tflite runtime - server = rpc.Server("127.0.0.1") remote = rpc.connect(server.host, server.port) dev = remote.cpu(0) if target_edgetpu: @@ -83,9 +82,9 @@ def check_remote(target_edgetpu=False): np.testing.assert_equal(out.numpy(), tflite_output) # Target CPU on coral board - check_remote() + check_remote(rpc.Server("127.0.0.1")) # Target EdgeTPU on coral board - check_remote(target_edgetpu=True) + check_remote(rpc.Server("127.0.0.1"), target_edgetpu=True) if __name__ == "__main__": diff --git a/tests/python/contrib/test_ethosn/test_networks.py b/tests/python/contrib/test_ethosn/test_networks.py index 6ff8011cf4d7..65d2738447cc 100644 --- a/tests/python/contrib/test_ethosn/test_networks.py +++ b/tests/python/contrib/test_ethosn/test_networks.py @@ -122,13 +122,13 @@ def test_mobilenet_v1(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"5d3cee6ecc488c40ecf533c5cbacc534"} + _compile_hash = {"1fd4ef29a1ea9f3a015cab87c0b8014a"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"896c28b4f06341ea638ead3a593e1aed"} + _compile_hash = {"b879dfbff1f907eaf6129dfd41b44ece"} if tei.get_ethosn_api_version() == 2011: - _compile_hash = {"9298b6c51e2a82f70e91dd11dd6af412"} + _compile_hash = {"9c9f63b30824f5b223cdb27d2f22c857"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"407eb47346c8afea2d15e8f0d1c079f2"} + _compile_hash = {"cd13279061df2319124a7aac81581d81"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz", @@ -148,13 +148,13 @@ def test_inception_v3(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"1bc66e83c3de5a9773a719b179c65b1a"} + _compile_hash = {"b90ed315639c6a0e97584c2dbc42a55c"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"551cde850c6ef960d19be4f317fb8e68"} + _compile_hash = {"5693569055695e581a8739194d0301aa"} if tei.get_ethosn_api_version() == 2011: - _compile_hash = {"d44eece5027ff56e5e7fcf014367378d"} + _compile_hash = {"46ccafc840633633aca441645e41b444"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"1ba555b4bc60c428018a0f2de9d90532"} + _compile_hash = {"4a33f397ac3e15c0f9869f7b8286fc2f"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/tflite_11_05_08/inception_v3_quant.tgz", @@ -173,13 +173,13 @@ def test_inception_v4(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"578b8ee279911b49912a77a64f5ff620"} + _compile_hash = {"b36877d2386d9f9c37a11772e3c4072c"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"30f078bd42757e8686eafa1f28d0d352"} + _compile_hash = {"b5046a6f56d78af0b4f51960bf2deeda"} if tei.get_ethosn_api_version() == 2011: - _compile_hash = {"53f126cf654d4cf61ebb23c767f6740b"} + _compile_hash = {"4a1a56393078367dd27915a188d6a6af"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"851665c060cf4719248919d17325ae02"} + _compile_hash = {"905caf389dd6b868aeff6acbca1fecef"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/inception_v4_299_quant_20181026.tgz", @@ -198,13 +198,13 @@ def test_ssd_mobilenet_v1(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"cd335229a2052f30273f127a233bd319", "95dedc29d911cdc6b28207ca08e42470"} + _compile_hash = {"956caf9e7fe5cfd5c042bd17857f7407", "4313033d14328e2aa022b1bd71b27b1c"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"deee52e136327436411fc725624ae2ea", "6526509d3cbee014e38c79e22bb29d7f"} + _compile_hash = {"dc60cc687d892cd2877873094e9dfc0b", "6b3deeec16c24c0dcef23df0db5fb162"} if tei.get_ethosn_api_version() == 2011: - _compile_hash = {"6e8c4586bdd26527c642a4f016f52284", "057c5efb094c79fbe4483b561147f1d2"} + _compile_hash = {"10826406ae724e52f360a06c35ced09d", "9a484d5ecec7acb18c9d6bc6058be031"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"dc687e60a4b6750fe740853f22aeb2dc", "1949d86100004eca41099c8e6fa919ab"} + _compile_hash = {"425b38830f34b6eb448fa77dbfe9ac96", "de49128643cbf1c659a9a63aad1cba62"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip", diff --git a/tests/python/contrib/test_miopen.py b/tests/python/contrib/test_miopen.py index 27a8ec6df357..81115b6c0238 100644 --- a/tests/python/contrib/test_miopen.py +++ b/tests/python/contrib/test_miopen.py @@ -19,9 +19,17 @@ from tvm import te from tvm.contrib import miopen import numpy as np +import pytest + + +requires_miopen = pytest.mark.skipif( + tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True) is None, + reason="MIOpen is not enabled", +) @tvm.testing.requires_rocm +@requires_miopen def test_conv2d(): in_channel = 3 out_channel = 64 @@ -35,9 +43,6 @@ def test_conv2d(): dilation_w = 1 xshape = [1, in_channel, 128, 128] - if not tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True): - print("skip because miopen is not enabled...") - return wshape = (out_channel, in_channel, filter_h, filter_w) X = te.placeholder(xshape, name="X") @@ -72,5 +77,60 @@ def verify(): verify() +def verify_softmax(shape, axis, dtype="float32", log_softmax=False): + miopen_op = miopen.log_softmax if log_softmax else miopen.softmax + testing_op = ( + tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python + ) + + A = te.placeholder(shape, dtype=dtype, name="A") + B = miopen_op(A, axis) + s = te.create_schedule([B.op]) + + dev = tvm.rocm(0) + a_np = np.random.uniform(size=shape).astype(dtype) + b_np = testing_op(a_np) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + f = tvm.build(s, [A, B], target="rocm --host=llvm", name="softmax") + f(a, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3) + + +def verify_softmax_4d(shape, dtype="float32", log_softmax=False): + miopen_op = miopen.log_softmax if log_softmax else miopen.softmax + testing_op = ( + tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python + ) + + A = te.placeholder(shape, dtype=dtype, name="A") + B = miopen_op(A, axis=1) + s = te.create_schedule([B.op]) + + dev = tvm.rocm(0) + n, c, h, w = shape + a_np = np.random.uniform(size=shape).astype(dtype) + b_np = testing_op(a_np.transpose(0, 2, 3, 1).reshape(h * w, c)) + b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + f = tvm.build(s, [A, B], target="rocm --host=llvm", name="softmax") + f(a, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3) + + +@tvm.testing.requires_rocm +@requires_miopen +def test_softmax(): + verify_softmax((32, 10), -1) + verify_softmax((3, 4), -1) + verify_softmax_4d((1, 16, 256, 256)) + verify_softmax_4d((1, 16, 256, 256)) + + verify_softmax((32, 10), -1, log_softmax=True) + verify_softmax((3, 4), -1, log_softmax=True) + verify_softmax_4d((1, 16, 256, 256), log_softmax=True) + + if __name__ == "__main__": test_conv2d() diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py index 8567f2c814cf..121edc4b8c60 100644 --- a/tests/python/contrib/test_onnx.py +++ b/tests/python/contrib/test_onnx.py @@ -50,8 +50,9 @@ def run_onnx(onnx_model, input_data): def run_relay(func, data_tuple): target = "llvm" dev = tvm.device("llvm", 0) - intrp = relay.create_executor("graph", device=dev, target=target) - relay_res = intrp.evaluate(func)(*data_tuple) + relay_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + *data_tuple + ) result = [] relay_res = relay_res if isinstance(relay_res, list) else [relay_res] @@ -655,6 +656,50 @@ def verify_cast(dshape, dtype): verify_cast(i, o_dtype) +def test_resize(): + """Resize unit test.""" + + def verify_resize(dshape, outsize, method, coord_trans, rounding_method, dtype="float32"): + x = relay.var("x", relay.ty.TensorType(dshape, dtype)) + y = relay.image.resize2d( + x, + outsize, + layout="NCHW", + method=method, + coordinate_transformation_mode=coord_trans, + rounding_method=rounding_method, + ) + func = relay.Function([x], y) + x_data = np.random.uniform(size=dshape).astype(dtype) + verify_results(func, [x_data], "test_resize", rtol=1e-4, atol=1e-4) + + method = ["nearest_neighbor", "linear", "cubic"] + coord_trans = ["half_pixel", "align_corners", "asymmetric"] + rounding_method = ["round", "floor", "ceil"] + + isize = (1, 3, 480, 640) + + # Downsample + osize = (240, 320) + for i in method: + for j in coord_trans: + for k in rounding_method: + if (i == "nearest_neighbor" and j == "align_corners") or ( + i == "cubic" and j in ["half_pixel", "align_corners"] + ): + continue + verify_resize(isize, osize, method=i, coord_trans=j, rounding_method=k) + + # Upsample + osize = (960, 1280) + for i in method: + for j in coord_trans: + for k in rounding_method: + if (i == "nearest_neighbor" and j == "align_corners") or (i == "cubic"): + continue + verify_resize(isize, osize, method=i, coord_trans=j, rounding_method=k) + + if __name__ == "__main__": test_add() test_bias_add() @@ -684,3 +729,4 @@ def verify_cast(dshape, dtype): test_copy() test_round() test_cast() + test_resize() diff --git a/tests/python/contrib/test_onnx_model.py b/tests/python/contrib/test_onnx_model.py index 84cff57d1d94..075085ff8806 100644 --- a/tests/python/contrib/test_onnx_model.py +++ b/tests/python/contrib/test_onnx_model.py @@ -59,9 +59,12 @@ def get_data(in_data_shapes, dtype="float32"): def run_relay(mod, params, in_data): target = "llvm" dev = tvm.device("llvm", 0) - intrp = relay.create_executor("graph", mod, device=dev, target=target) in_data = [tvm.nd.array(value) for value in in_data.values()] - return intrp.evaluate()(*in_data, **params).numpy() + return ( + relay.create_executor("graph", mod, device=dev, target=target) + .evaluate()(*in_data, **params) + .numpy() + ) def _verify_results(mod, params, in_data): diff --git a/tests/python/contrib/test_popen_pool.py b/tests/python/contrib/test_popen_pool.py index 6b5b367293eb..9ebe4c11c118 100644 --- a/tests/python/contrib/test_popen_pool.py +++ b/tests/python/contrib/test_popen_pool.py @@ -18,7 +18,16 @@ import pytest import time from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor -from tvm.testing import identity_after, terminate_self +from tvm.testing import ( + identity_after, + terminate_self, + initializer, + after_initializer, + register_ffi, + call_py_ffi, + call_cpp_ffi, + call_cpp_py_ffi, +) def test_popen_worker(): @@ -66,6 +75,37 @@ def test_popen_pool_executor(): assert val.value == idx +def test_popen_initializer(): + initargs = [1, 2, 3] + proc = PopenWorker(initializer=initializer, initargs=initargs) + proc.send(after_initializer) + test_global_state_1, test_global_state_2, test_global_state_3 = proc.recv() + assert test_global_state_1 == initargs[0] + assert test_global_state_2 == initargs[1] + assert test_global_state_3 == initargs[2] + + +def test_popen_ffi(): + proc = PopenWorker(register_ffi) + + # call python function via ffi + initargs = [0] + proc.send(call_py_ffi, initargs) + assert proc.recv() == initargs[0] + + # call cpp function via ffi + initargs = [1] + proc.send(call_cpp_ffi, initargs) + assert proc.recv() == initargs[0] + + # call python function from cpp function via ffi + initargs = [2] + proc.send(call_cpp_py_ffi, initargs) + assert proc.recv() == initargs[0] + + if __name__ == "__main__": test_popen_worker() test_popen_pool_executor() + test_popen_initializer() + test_popen_ffi() diff --git a/tests/python/contrib/test_random.py b/tests/python/contrib/test_random.py index bd92f2f70ea7..7a52c0dbf1ea 100644 --- a/tests/python/contrib/test_random.py +++ b/tests/python/contrib/test_random.py @@ -120,17 +120,20 @@ def test_rpc(dtype): return np_ones = np.ones((512, 512), dtype=dtype) - server = rpc.Server("127.0.0.1") - remote = rpc.connect(server.host, server.port) - value = tvm.nd.empty((512, 512), dtype, remote.cpu()) - random_fill = remote.get_function("tvm.contrib.random.random_fill") - random_fill(value) - assert np.count_nonzero(value.numpy()) == 512 * 512 + def check_remote(server): + remote = rpc.connect(server.host, server.port) + value = tvm.nd.empty((512, 512), dtype, remote.cpu()) + random_fill = remote.get_function("tvm.contrib.random.random_fill") + random_fill(value) - # make sure arithmentic doesn't overflow too - np_values = value.numpy() - assert np.isfinite(np_values * np_values + np_values).any() + assert np.count_nonzero(value.numpy()) == 512 * 512 + + # make sure arithmentic doesn't overflow too + np_values = value.numpy() + assert np.isfinite(np_values * np_values + np_values).any() + + check_remote(rpc.Server("127.0.0.1")) for dtype in [ "bool", diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 59f1c3aa4d68..f40b3368dc85 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -110,12 +110,16 @@ def run_and_verify_func(config, target="cuda"): with tvm.transform.PassContext( opt_level=3, config={"relay.ext.tensorrt.options": config} ): - exec = relay.create_executor(mode, mod=mod, device=dev, target=target) + func = relay.create_executor( + mode, mod=mod, device=dev, target=target + ).evaluate() else: with tvm.transform.PassContext(opt_level=3): - exec = relay.create_executor(mode, mod=mod, device=dev, target=target) + func = relay.create_executor( + mode, mod=mod, device=dev, target=target + ).evaluate() if not skip_runtime_test(): - result_dict[result_key] = exec.evaluate()(**input_dict, **params) + result_dict[result_key] = func(**input_dict, **params) if not skip_runtime_test(): assert_result_dict_holds(result_dict) @@ -143,12 +147,16 @@ def compile_and_run(mod, params, i_data, mode="vm", use_trt=True): with tvm.transform.PassContext( opt_level=3, config={"relay.ext.tensorrt.options": config} ): - exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda") + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() else: with tvm.transform.PassContext(opt_level=3): - exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda") + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() - res = exec.evaluate()(i_data, **params) if not skip_runtime_test() else None + res = func(i_data, **params) if not skip_runtime_test() else None return res dtype = "float32" @@ -198,16 +206,16 @@ def test_tensorrt_simple(): with tvm.transform.PassContext( opt_level=3, config={"relay.ext.tensorrt.options": config} ): - relay_exec = relay.create_executor( + func = relay.create_executor( mode, mod=mod, device=tvm.cuda(0), target="cuda" - ) + ).evaluate() else: with tvm.transform.PassContext(opt_level=3): - relay_exec = relay.create_executor( + func = relay.create_executor( mode, mod=mod, device=tvm.cuda(0), target="cuda" - ) + ).evaluate() if not skip_runtime_test(): - result_dict[result_key] = relay_exec.evaluate()(x_data, y_data, z_data) + result_dict[result_key] = func(x_data, y_data, z_data) if not skip_runtime_test(): assert_result_dict_holds(result_dict) @@ -247,9 +255,11 @@ def test_tensorrt_not_compatible(): mod, config = tensorrt.partition_for_tensorrt(mod) for mode in ["graph", "vm"]: with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): - exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda") + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() if not skip_runtime_test(): - results = exec.evaluate()(x_data) + results = func(x_data) def test_tensorrt_serialize_graph_executor(): @@ -474,14 +484,25 @@ def get_graph(x_shape=(1, 16), k_shape=(32, 16)): def test_batch_matmul(): - def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64)): + def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True): x = relay.var("x", shape=(x_shape), dtype="float32") y = relay.var("y", shape=(y_shape), dtype="float32") - out = relay.nn.batch_matmul(x, y) + out = relay.nn.batch_matmul(x, y, transpose_a=transa, transpose_b=transb) f = relay.Function([x, y], out) return f, {"x": x_shape, "y": y_shape}, [] - run_and_verify_func(get_graph()) + run_and_verify_func( + get_graph(x_shape=(12, 64, 128), y_shape=(12, 128, 64), transa=True, transb=True) + ) + run_and_verify_func( + get_graph(x_shape=(12, 64, 128), y_shape=(12, 64, 128), transa=True, transb=False) + ) + run_and_verify_func( + get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True) + ) + run_and_verify_func( + get_graph(x_shape=(12, 128, 64), y_shape=(12, 64, 128), transa=False, transb=False) + ) def test_bias_add(): @@ -730,12 +751,12 @@ def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt): assert are_ops_on_trt(mod, op_list=["reshape"]) == should_offload_to_trt if not skip_runtime_test(): with relay.build_config(opt_level=3): - relay_exec = relay.create_executor( + func = relay.create_executor( "vm", mod=mod, device=tvm.cpu(0), target="llvm" - ) + ).evaluate() for i, x_data in enumerate(x_data_list): - result_arr[i][use_trt] = relay_exec.evaluate()(x_data) + result_arr[i][use_trt] = func(x_data) if not skip_runtime_test(): for i in range(len(x_data_list)): @@ -1233,10 +1254,11 @@ def test_tensorrt_dynamic_batch(): if not skip_runtime_test(): with relay.build_config(opt_level=3): - relay_exec = relay.create_executor("vm", mod=mod, device=tvm.cpu(0), target="llvm") - + func = relay.create_executor( + "vm", mod=mod, device=tvm.cpu(0), target="llvm" + ).evaluate() for i, batch_size in enumerate(batches_to_test): - result_arr[i][use_trt] = relay_exec.evaluate()(x_data[:batch_size, ...]) + result_arr[i][use_trt] = func(x_data[:batch_size, ...]) if not skip_runtime_test(): for i in range(len(batches_to_test)): @@ -1251,33 +1273,33 @@ def test_tensorrt_dynamic_batch_conv(): x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32") k_shape = (16, 32, 3, 3) params = {"kernel": np.random.uniform(-1, 1, k_shape).astype("float32")} - result_arr = [{"cuda": {}, "llvm": {}} for _ in range(len(batches_to_test))] - for use_trt in [True, False]: - x = relay.var("x", shape=x_shape, dtype="float32") - kernel = relay.var("kernel", shape=k_shape, dtype="float32") - out = relay.nn.conv2d(x, kernel, channels=16, kernel_size=(3, 3), groups=1) - f = relay.Function([x, kernel], out) - mod = tvm.IRModule() - mod["main"] = f - if use_trt: - mod, _ = tensorrt.partition_for_tensorrt(mod, params) - + for use_implicit_batch in [True, False]: + result_arr = [{"cuda": {}, "llvm": {}} for _ in range(len(batches_to_test))] + for use_trt in [True, False]: + x = relay.var("x", shape=x_shape, dtype="float32") + kernel = relay.var("kernel", shape=k_shape, dtype="float32") + out = relay.nn.conv2d(x, kernel, channels=16, kernel_size=(3, 3), groups=1) + f = relay.Function([x, kernel], out) + mod = tvm.IRModule() + mod["main"] = f + if use_trt: + mod, config = tensorrt.partition_for_tensorrt( + mod, params, use_implicit_batch=use_implicit_batch + ) + if not skip_runtime_test(): + for target in ["llvm", "cuda"]: + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + func = relay.create_executor( + "vm", mod=mod, device=tvm.device(target), target=target + ).evaluate() + for i, batch_size in enumerate(batches_to_test): + result_arr[i][target][use_trt] = func(x_data[:batch_size, ...], **params) if not skip_runtime_test(): - for target in ["llvm", "cuda"]: - with relay.build_config(opt_level=3): - relay_exec = relay.create_executor( - "vm", mod=mod, device=tvm.cpu(0), target="llvm" - ) - - for i, batch_size in enumerate(batches_to_test): - result_arr[i][target][use_trt] = relay_exec.evaluate()( - x_data[:batch_size, ...], **params - ) - - if not skip_runtime_test(): - for i in range(len(batches_to_test)): - for target in ["llvm", "cuda"]: - assert_result_dict_holds(result_arr[i][target]) + for i in range(len(batches_to_test)): + for target in ["llvm", "cuda"]: + assert_result_dict_holds(result_arr[i][target]) def test_maskrcnn_resnet50() -> None: @@ -1421,9 +1443,11 @@ def test_empty_subgraph(): x_data = np.random.uniform(-1, 1, x_shape).astype("float32") for mode in ["graph", "vm"]: with tvm.transform.PassContext(opt_level=3): - exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda") + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() if not skip_runtime_test(): - results = exec.evaluate()(x_data) + results = func(x_data) if __name__ == "__main__": diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py index 93ab634feb15..6268a6aae615 100644 --- a/tests/python/contrib/test_tflite_runtime.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -128,18 +128,18 @@ def test_remote(): tflite_output = interpreter.get_tensor(output_details[0]["index"]) # inference via remote tvm tflite runtime - server = rpc.Server("127.0.0.1") - remote = rpc.connect(server.host, server.port) - a = remote.upload(tflite_model_path) - - with open(tflite_model_path, "rb") as model_fin: - runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) - runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) - runtime.invoke() - out = runtime.get_output(0) - np.testing.assert_equal(out.numpy(), tflite_output) - - server.terminate() + def check_remote(server): + remote = rpc.connect(server.host, server.port) + a = remote.upload(tflite_model_path) + + with open(tflite_model_path, "rb") as model_fin: + runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) + runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) + runtime.invoke() + out = runtime.get_output(0) + np.testing.assert_equal(out.numpy(), tflite_output) + + check_remote(rpc.Server("127.0.0.1")) if __name__ == "__main__": diff --git a/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py b/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py index 18c57d485d76..2e16792542ca 100644 --- a/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py +++ b/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py @@ -288,8 +288,8 @@ def expected(): func0 = relay.Function( [data0, weight0, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0], bn.astuple() ) - func0 = set_func_attr(func0, "vitis_ai", "tvmgen_default_vitis_ai_0") - gv0 = relay.GlobalVar("tvmgen_default_vitis_ai_0") + func0 = set_func_attr(func0, "vitis_ai", "tvmgen_default_vitis_ai_main_0") + gv0 = relay.GlobalVar("tvmgen_default_vitis_ai_main_0") mod = tvm.IRModule() mod[gv0] = func0 mod = relay.transform.InferType()(mod) diff --git a/tests/python/contrib/test_vitis_ai/test_vitis_ai_runtime_cpu_part.py b/tests/python/contrib/test_vitis_ai/test_vitis_ai_runtime_cpu_part.py index db9552c8eab2..f414d7d71fcc 100644 --- a/tests/python/contrib/test_vitis_ai/test_vitis_ai_runtime_cpu_part.py +++ b/tests/python/contrib/test_vitis_ai/test_vitis_ai_runtime_cpu_part.py @@ -59,10 +59,12 @@ def test_extern_vitis_ai_resnet18(): mod, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=1) ref_mod, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=1) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)) i_data = np.random.uniform(0, 1, ishape).astype(dtype) - ref_res = ref_ex.evaluate()(i_data, **params) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( + i_data, **params + ) + verify_result( mod, {"data": i_data}, diff --git a/tests/python/driver/tvmc/test_mlf.py b/tests/python/driver/tvmc/test_mlf.py index 4669fab916a6..0426f5678153 100644 --- a/tests/python/driver/tvmc/test_mlf.py +++ b/tests/python/driver/tvmc/test_mlf.py @@ -18,6 +18,7 @@ import pytest import os import shlex +import sys import tvm from tvm.driver import tvmc @@ -130,3 +131,7 @@ def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_compile assert tvmc_package.graph is None, ".graph must not be set in the MLF archive for AOT executor." assert tvmc_package.params is not None, ".params must be set in the MLF archive." assert tvmc_package.type == "mlf", ".type must be set to 'mlf' in the MLF format." + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/driver/tvmc/test_model.py b/tests/python/driver/tvmc/test_model.py index f5a28d419cbb..fd2637a85f1f 100644 --- a/tests/python/driver/tvmc/test_model.py +++ b/tests/python/driver/tvmc/test_model.py @@ -21,6 +21,7 @@ from tvm.driver import tvmc from tvm.driver.tvmc.model import TVMCModel, TVMCPackage, TVMCResult +from tvm.runtime.module import BenchmarkResult def test_tvmc_workflow(keras_simple): @@ -35,7 +36,7 @@ def test_tvmc_workflow(keras_simple): assert type(result) is TVMCResult assert path.exists(tuning_records) assert type(result.outputs) is dict - assert type(result.times) is tuple + assert type(result.times) is BenchmarkResult assert "output_0" in result.outputs.keys() diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py index 7acb376baba6..2ce363ab5911 100644 --- a/tests/python/driver/tvmc/test_runner.py +++ b/tests/python/driver/tvmc/test_runner.py @@ -20,6 +20,7 @@ from tvm.driver import tvmc from tvm.driver.tvmc.model import TVMCResult from tvm.driver.tvmc.result_utils import get_top_results +from tvm.runtime.module import BenchmarkResult def test_generate_tensor_data_zeros(): @@ -52,7 +53,7 @@ def test_generate_tensor_data__type_unknown(): def test_format_times__contains_header(): - fake_result = TVMCResult(outputs=None, times=[0.6, 1.2, 0.12, 0.42]) + fake_result = TVMCResult(outputs=None, times=BenchmarkResult([0.6, 1.2, 0.12, 0.42])) sut = fake_result.format_times() assert "std (ms)" in sut @@ -101,5 +102,5 @@ def test_run_tflite_module__with_profile__valid_input( tiger_cat_mobilenet_id in top_5_ids ), "tiger cat is expected in the top-5 for mobilenet v1" assert type(result.outputs) is dict - assert type(result.times) is tuple + assert type(result.times) is BenchmarkResult assert "output_0" in result.outputs.keys() diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py index cb6b82a32937..31fa688ad717 100644 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ b/tests/python/driver/tvmc/test_tvmc_common.py @@ -177,6 +177,11 @@ def test_shape_parser(): shape_dict = tvmc.common.parse_shape_string(shape_string) # Convert to strings to allow comparison with Any. assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" + # Check that multiple valid gpu inputs are parsed correctly. + shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}" + assert str(shape_dict) == expected # Check that invalid pattern raises expected error. shape_string = "input:[a,10]" @@ -186,6 +191,22 @@ def test_shape_parser(): shape_string = "input:5,10 input2:10,10" with pytest.raises(argparse.ArgumentTypeError): tvmc.common.parse_shape_string(shape_string) + # Check that input with a invalid slash raises error. + shape_string = "gpu_0/data_0:5,10 /:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + # Check that input with a invalid slash raises error. + shape_string = "gpu_0/data_0:5,10 data/:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + # Check that input with a invalid slash raises error. + shape_string = "gpu_0/data_0:5,10 /data:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + # Check that input with invalid slashes raises error. + shape_string = "gpu_0/invalid/data_0:5,10 data_1:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) def test_target_from_cli__error_duplicate(): diff --git a/tests/python/frontend/coreml/test_forward.py b/tests/python/frontend/coreml/test_forward.py index ee9159573ea2..8892018a79e9 100644 --- a/tests/python/frontend/coreml/test_forward.py +++ b/tests/python/frontend/coreml/test_forward.py @@ -30,6 +30,11 @@ import coremltools as cm import model_zoo import tvm.testing +import tempfile +from os import path +import tvm +import tvm.relay as relay +import tensorflow.keras as keras def get_tvm_output( @@ -783,6 +788,44 @@ def test_forward_convolution(): verify_convolution((1, 3, 224, 224), filter=(32, 3, 3, 3), padding="SAME") +def test_can_build_keras_to_coreml_to_relay(): + """Test multiple conversion paths and importing from + a saved file.""" + model = keras.models.Sequential() + model.add( + keras.layers.Conv2D( + filters=6, + kernel_size=(1, 1), + activation="relu", + padding="same", + input_shape=(3, 3, 1), + data_format="channels_first", + ) + ) + + with tempfile.TemporaryDirectory() as tmpdir: + kmodel_fn = path.join(tmpdir, "c1mdl.h5") + model.save(kmodel_fn) + + mdl = cm.convert(kmodel_fn) + model_file = path.join(tmpdir, "c1.mlmodel") + mdl.save(model_file) + + mdl = cm.models.MLModel(model_file) + desc = mdl.get_spec().description + iname = desc.input[0].name + ishape = desc.input[0].type.multiArrayType.shape + shape_dict = {} + for i in mdl.get_spec().description.input: + iname = i.name + ishape = i.type.multiArrayType.shape + shape_dict[iname] = ishape + mod, params = relay.frontend.from_coreml(mdl, shape_dict) + + with tvm.transform.PassContext(opt_level=3): + relay.build(mod, "llvm", params=params) + + if __name__ == "__main__": test_forward_AddLayerParams() test_forward_ConcatLayerParams() @@ -801,3 +844,4 @@ def test_forward_convolution(): test_resnet50_checkonly() test_forward_image_scaler() test_forward_convolution() + test_can_build_keras_to_coreml_to_relay() diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index a6c3d6efec56..44aa93061a62 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -333,8 +333,9 @@ def test_forward_where(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, args, auxs) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(np_cond, np_x, np_y) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + np_cond, np_x, np_y + ) tvm.testing.assert_allclose(op_res.numpy(), mx_out) @@ -352,14 +353,15 @@ def _mx_symbol(F, start, stop, step): return sym def verify(start, stop, step): - ref_res = _mx_symbol(mx.nd, start, stop, step).asnumpy() + ref_res = _mx_symbol(mx.nd, start, stop, step) mx_sym = _mx_symbol(mx.sym, start, stop, step) mod, _ = relay.frontend.from_mxnet(mx_sym, {}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()() - tvm.testing.assert_allclose(op_res.numpy(), ref_res) + op_res = relay.create_executor( + kind, mod=mod, device=dev, target=target + ).evaluate()() + tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify(0, 20, None) verify(0, 20, 2) @@ -416,8 +418,9 @@ def test_forward_broadcast_ops(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np, b_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np, b_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) @@ -451,8 +454,9 @@ def test_forward_elemwise_ops(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np, b_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np, b_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) @@ -500,8 +504,9 @@ def test_forward_unary_ops(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) @@ -530,8 +535,9 @@ def test_forward_scalar_ops(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) for op in ["maximum", "minimum"]: dtype = "float32" @@ -544,8 +550,9 @@ def test_forward_scalar_ops(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) @@ -558,8 +565,9 @@ def verify(shape, axis, begin, end): mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((3, 4), 0, 1, 2) @@ -583,8 +591,9 @@ def verify(x_shape, y_shape, axes): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": x_shape, "y": y_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np, y_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np, y_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((3, 4), (2, 3), None) @@ -617,8 +626,9 @@ def verify(shape, seq_lengths, use_seq_lengths, seq_axis): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(*in_data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + *in_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((3, 4), [1, 2, 3, 1], True, 0) @@ -653,8 +663,9 @@ def test_forward_logistic_regression_output(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) @@ -670,8 +681,9 @@ def verify(a_shape, b_shape, transpose_b=False): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np, b_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np, b_np + ) tvm.testing.assert_allclose( op_res.numpy(), ref_res.asnumpy(), rtol=1e-05, atol=1e-05 ) @@ -689,8 +701,9 @@ def verify(shape): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((1,)) @@ -711,8 +724,9 @@ def verify(shape, axis): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((1, 3, 1), None) @@ -731,8 +745,9 @@ def verify(shape, axis, size): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor( + kind, mod=mod, device=dev, target=target + ).evaluate()(x_np) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((1, 2, 1), 2, 3) @@ -748,8 +763,9 @@ def verify(input_shape, shape): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": input_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((1, 2, 3), (3, 2, 3)) @@ -766,8 +782,9 @@ def verify(input_shape, like_shape): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": input_shape, "y": like_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np, y_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np, y_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((1, 2, 3), (3, 2, 3)) @@ -785,8 +802,9 @@ def test_forward_logical_not(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) @@ -801,8 +819,9 @@ def verify(val, shape, dtype): # Skip testing graph executor because this op will be optimized out # by constant folding. for kind in ["debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()() + op_res = relay.create_executor( + kind, mod=mod, device=dev, target=target + ).evaluate()() tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify(2, (3, 4), "float32") @@ -825,8 +844,9 @@ def verify(data_shape, weight_shape): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": data_shape, "w": weight_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x=x_np, w=w_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x=x_np, w=w_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((2, 2), (4, 5)) @@ -852,8 +872,9 @@ def verify(shape, indices_src, axis, mode="clip"): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np, indices_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np, indices_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((2, 2), [[[1, 0], [0, 1]]], 0) @@ -876,8 +897,9 @@ def verify(xshape, yshape, y_data, error=False): ) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_data, y_data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) @@ -905,8 +927,9 @@ def verify(shape, transform_type, target_shape): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) verify((4, 6), "affine", (16, 32)) @@ -925,8 +948,9 @@ def verify(data_shape, grid_shape): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data, grid) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data, grid + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) verify((4, 4, 16, 32), (4, 2, 8, 8)) @@ -988,8 +1012,9 @@ def verify( for target, dev in tvm.testing.enabled_targets(): # only test graph executor because debug runtime is too slow for kind in ["graph"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(**inputs, **params) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + **inputs, **params + ) if init_states: assert len(op_res) == len(mx_res) for i, val in enumerate(op_res): @@ -1022,11 +1047,11 @@ def verify(xshape, yshape, offset=None): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": xshape, "y": yshape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) + func = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate() if offset is None or offset == (0, 0): - op_res = intrp.evaluate()(x_data, y_data) + op_res = func(x_data, y_data) else: - op_res = intrp.evaluate()(x_data) + op_res = func(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((1, 3, 40, 40), (1, 3, 20, 20)) @@ -1045,8 +1070,9 @@ def verify(shape, axis, is_ascend, dtype="float32"): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((2, 3, 4), axis=0, is_ascend=False) @@ -1076,8 +1102,9 @@ def verify(shape, k, axis, ret_type, is_ascend=None, dtype="float32"): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np + ) if isinstance(ref_res, list): assert len(op_res) == len(ref_res) for i, t in enumerate(op_res): @@ -1133,11 +1160,11 @@ def verify(shape, use_sequence_length, value, axis, dtype, itype): if use_sequence_length is False and kind == "graph": # Disable the test for 'graph' when it's identity. continue - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) + func = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate() if use_sequence_length: - op_res = intrp.evaluate()(data_np, valid_length_np) + op_res = func(data_np, valid_length_np) else: - op_res = intrp.evaluate()(data_np) + op_res = func(data_np) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((5, 10), True, 0.0, 0, "float32", "float32") @@ -1155,8 +1182,9 @@ def verify(shape): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((3, 4)) @@ -1203,8 +1231,9 @@ def verify(shape, axis=1, fix_gamma=False): # print(mod) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x, gamma, beta, moving_mean, moving_var) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x, gamma, beta, moving_mean, moving_var + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3) verify((2, 3, 4, 5)) @@ -1227,11 +1256,10 @@ def verify(shape, axis=1, epsilon=1e-5): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x, gamma, beta) - tvm.testing.assert_allclose( - op_res.asnumpy(), ref_res.asnumpy(), rtol=2e-5, atol=1e-5 + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x, gamma, beta ) + tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=2e-5, atol=1e-5) verify((2, 3, 4, 5)) verify((32, 64, 80, 64)) @@ -1253,8 +1281,9 @@ def verify(shape, axis=-1): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x, gamma, beta) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x, gamma, beta + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((2, 5)) @@ -1281,8 +1310,9 @@ def verify(shape, num_groups=1): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x, gamma, beta) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x, gamma, beta + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((1, 4, 2), num_groups=4) @@ -1302,8 +1332,9 @@ def verify(indices_shape, depth, on_value, off_value, dtype): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x.astype("float32")) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x.astype("float32") + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((3,), 3, 1, 0, "int32") @@ -1428,8 +1459,9 @@ def verify(data_shape, kernel_size, stride, pad, num_filter, is_depthwise=False) mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x, weight, bias) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x, weight, bias + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3) verify(data_shape=(1, 1, 1024 * 16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) @@ -1509,8 +1541,9 @@ def verify(data_shape, kernel_size, stride, pad, num_filter): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x, weight, bias) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x, weight, bias + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify(data_shape=(1, 1, 1024 * 16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) @@ -1542,8 +1575,9 @@ def verify(a_np, b_np): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["debug", "vm"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np, b_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np, b_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3) verify(np.asarray([1.0], "float32"), np.asarray([2.0], "float32")) @@ -1561,8 +1595,9 @@ def verify(from_dtype, to_dtype): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "vm", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(from_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + from_np + ) assert op_res.dtype == to_dtype, op_res.dtype tvm.testing.assert_allclose(op_res.numpy(), from_np.astype(to_dtype)) @@ -1584,8 +1619,9 @@ def verify(dtypes, cast_narrow, expected_dtype): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "vm", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(*x_nps) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + *x_nps + ) for i, res in enumerate(op_res): assert res.dtype == expected_dtype, res.dtype tvm.testing.assert_allclose(res.numpy(), x_nps[i].astype(expected_dtype)) @@ -1609,8 +1645,9 @@ def verify(x, shape, dtype): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "vm", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) for dtype in ["int32", "int64"]: @@ -1650,8 +1687,9 @@ def verify(shape, blocksize=2): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((1, 18, 3, 3), 3) @@ -1669,8 +1707,9 @@ def verify(shape, blocksize=2): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((1, 1, 9, 9), 3) @@ -1705,8 +1744,9 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data1, data2) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data1, data2 + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify( @@ -1810,8 +1850,9 @@ def verify(data_shape, start=None, step=None, axis=None): mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()() + op_res = relay.create_executor( + kind, mod=mod, device=dev, target=target + ).evaluate()() tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify(data_shape=(3,), start=0.0, step=1.0) @@ -1832,8 +1873,9 @@ def verify(batch, seq_length, num_heads, head_dim): mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) verify(1, 10, 3, 16) @@ -1857,8 +1899,9 @@ def verify(batch, seq_length, num_heads, head_dim): mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape, "weight": weight_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data=data_np, weight=weight_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data=data_np, weight=weight_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) verify(1, 10, 4, 16) @@ -1914,8 +1957,9 @@ def verify( ): target += " -libs=thrust" for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((1, 10, 6)) @@ -1953,8 +1997,9 @@ def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corn mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data, anchors) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data, anchors + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((1, 10, 4), (1, 10, 4)) @@ -1993,11 +2038,11 @@ def verify(data_shape, axis, use_length, length): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) + func = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate() if use_length: - op_res = intrp.evaluate()(x, length) + op_res = func(x, length) else: - op_res = intrp.evaluate()(x) + op_res = func(x) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) @@ -2033,8 +2078,7 @@ def test_forward_npi_pad(data_shape, pad_width, mode, dtype, constant_value, tar ref_res = np.pad(data_np, mode=mode, pad_width=pad_width) mx_sym = mx.sym.np.pad(data.as_np_ndarray(), mode=mode, pad_width=pad_width) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(data_np) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -2052,8 +2096,7 @@ def test_forward_npi_transpose(data_shape, axes, dtype, target, dev, kind): ref_res = mx.np.transpose(mx.np.array(data_np), axes=axes) mx_sym = mx.sym.np.transpose(data.as_np_ndarray(), axes=axes) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(data_np) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2080,8 +2123,9 @@ def test_forward_npi_concatenate(data_shape1, data_shape2, axis, dtype, target, mod, _ = relay.frontend.from_mxnet( mx_sym, shape={"data1": data_shape1, "data2": data_shape2}, dtype=dtype ) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np1, data_np2) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np1, data_np2 + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2108,8 +2152,9 @@ def test_forward_npi_stack(data_shape1, data_shape2, axis, dtype, target, dev, k mod, _ = relay.frontend.from_mxnet( mx_sym, shape={"data1": data_shape1, "data2": data_shape2}, dtype=dtype ) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np1, data_np2) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np1, data_np2 + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2123,8 +2168,7 @@ def test_forward_np_copy(data_shape, dtype, target, dev, kind): ref_res = mx.np.copy(mx.np.array(data_np)) mx_sym = mx.sym.np.copy(data.as_np_ndarray()) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(data_np) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2151,8 +2195,7 @@ def test_forward_npx_reshape(data_shape, out_shape, dtype, target, reverse, dev, ref_res = mx.npx.reshape(mx.np.array(data_np), newshape=out_shape, reverse=reverse) mx_sym = mx.sym.npx.reshape(data.as_np_ndarray(), newshape=out_shape, reverse=reverse) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(data_np) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2186,8 +2229,9 @@ def test_forward_npi_binary(data_shape, dtype, target, dev, kind): mod, _ = relay.frontend.from_mxnet( mx_sym, shape={"lhs": data_shape, "rhs": data_shape}, dtype=dtype ) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np1, data_np2) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np1, data_np2 + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2218,8 +2262,9 @@ def test_forward_npi_binary_scalar(data_shape, dtype, scalar, target, dev, kind) ref_res = ref_op(mx.np.array(data_np1), scalar) mx_sym = mx_op(data1.as_np_ndarray(), scalar) mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"lhs": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np1) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np1 + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2235,8 +2280,7 @@ def test_forward_npi_tanh(data_shape, dtype, target, dev, kind): ref_res = mx.np.tanh(mx.np.array(data_np1)) mx_sym = mx.sym.np.tanh(data1.as_np_ndarray()) mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"data": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np1) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(data_np1) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2267,8 +2311,9 @@ def test_forward_npi_where_rscalar( mod, _ = relay.frontend.from_mxnet( mx_sym, shape={"condition": cond_shape, "x": data_shape}, dtype=dtypeDic ) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(cond_np, data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + cond_np, data_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2296,8 +2341,7 @@ def test_forward_split_v2( data.as_nd_ndarray(), indices_or_sections, axis=axis, squeeze_axis=squeeze_axis ) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(data_np) op_res_ = [] for arr in op_res: op_res_.append(arr.numpy().tolist()) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c5407697de46..9e0eb1f75217 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -14,8 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import glob import os import re +import glob import numpy as np import pytest @@ -47,7 +49,14 @@ def get_input_data_shape_dict(graph_def, input_data): def get_tvm_output_with_vm( - graph_def, input_data, target, device, opset=None, freeze_params=False, convert_to_static=False + graph_def, + input_data, + target, + dev, + opset=None, + freeze_params=False, + convert_to_static=False, + convert_config=None, ): """Generic function to execute and get tvm output with vm executor""" if not isinstance(input_data, list): @@ -55,14 +64,19 @@ def get_tvm_output_with_vm( _, shape_dict = get_input_data_shape_dict(graph_def, input_data) mod, params = relay.frontend.from_onnx( - graph_def, shape_dict, opset=opset, freeze_params=freeze_params + graph_def, + shape_dict, + opset=opset, + freeze_params=freeze_params, + convert_config=convert_config, ) if convert_to_static: mod = relay.transform.DynamicToStatic()(mod) - ex = relay.create_executor("vm", mod=mod, device=device, target=target) - result = ex.evaluate()(*input_data, **params) + result = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()( + *input_data, **params + ) if isinstance(result, tvm.runtime.NDArray): return result.numpy() return [r.numpy() for r in result] @@ -72,25 +86,25 @@ def get_tvm_output( graph_def, input_data, target, - device, + dev, output_shape=None, output_dtype="float32", opset=None, opt_level=1, + convert_config=None, ): """Generic function to execute and get tvm output""" # TODO: Resolve the issues and remove the following lines - target = "llvm" - device = tvm.cpu(0) - input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data) - mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset) + mod, params = relay.frontend.from_onnx( + graph_def, shape_dict, opset=opset, convert_config=convert_config + ) with tvm.transform.PassContext(opt_level=opt_level): graph, lib, params = relay.build(mod, target, params=params) - m = graph_executor.create(graph, lib, device) + m = graph_executor.create(graph, lib, dev) # set inputs if isinstance(input_data, list): for i, e in enumerate(input_names): @@ -137,7 +151,8 @@ def verify_with_ort_with_inputs( model, inputs, out_shape=None, - targets=None, + target=None, + dev=None, use_vm=False, opset=None, freeze_params=False, @@ -147,48 +162,54 @@ def verify_with_ort_with_inputs( atol=1e-5, apply_softmax=False, opt_level=1, + convert_config=None, ): if opset is not None: model.opset_import[0].version = opset ort_out = get_onnxruntime_output(model, inputs) - if targets is None: - targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()] - - for target in targets: - dev = tvm.device(target, 0) - if use_vm: - tvm_out = get_tvm_output_with_vm( - model, - inputs, - target, - dev, - opset=opset, - freeze_params=freeze_params, - convert_to_static=convert_to_static, - ) - else: - tvm_out = get_tvm_output( - model, inputs, target, dev, out_shape, dtype, opset=opset, opt_level=opt_level - ) - if not isinstance(tvm_out, list): - tvm_out = [tvm_out] - if not isinstance(ort_out, list): - ort_out = [ort_out] - for tvm_val, ort_val in zip(tvm_out, ort_out): - if apply_softmax: - ort_val = scipy.special.softmax(ort_val) - tvm_val = scipy.special.softmax(tvm_val) - tvm.testing.assert_allclose(ort_val, tvm_val, rtol=rtol, atol=atol) - assert ort_val.dtype == tvm_val.dtype + if use_vm: + tvm_out = get_tvm_output_with_vm( + model, + inputs, + target, + dev, + opset=opset, + freeze_params=freeze_params, + convert_to_static=convert_to_static, + convert_config=convert_config, + ) + else: + tvm_out = get_tvm_output( + model, + inputs, + target, + dev, + out_shape, + dtype, + opset=opset, + opt_level=opt_level, + convert_config=convert_config, + ) + if not isinstance(tvm_out, list): + tvm_out = [tvm_out] + if not isinstance(ort_out, list): + ort_out = [ort_out] + for tvm_val, ort_val in zip(tvm_out, ort_out): + if apply_softmax: + ort_val = scipy.special.softmax(ort_val) + tvm_val = scipy.special.softmax(tvm_val) + tvm.testing.assert_allclose(ort_val, tvm_val, rtol=rtol, atol=atol) + assert ort_val.dtype == tvm_val.dtype def verify_with_ort( model, input_shapes, out_shape=None, - targets=None, + target=None, + dev=None, use_vm=False, opset=None, freeze_params=False, @@ -202,7 +223,8 @@ def verify_with_ort( model, inputs, out_shape=out_shape, - targets=targets, + target=target, + dev=dev, use_vm=use_vm, opset=opset, freeze_params=freeze_params, @@ -213,6 +235,39 @@ def verify_with_ort( ) +def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev): + from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType + + input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes] + + class RandomDataReader(CalibrationDataReader): + def __init__(self, n=10): + input_dict = dict(zip(input_names, input_shapes)) + self.data = iter( + [ + { + name: np.random.random(shape).astype("float32") + for name, shape in input_dict.items() + } + for _ in range(n) + ] + ) + + def get_next(self): + return next(self.data, None) + + d = tvm.contrib.utils.tempdir() + model_fp32 = os.path.join(d.temp_dir, "model.onnx") + onnx.save_model(onnx_model, model_fp32) + model_quant = os.path.join(d.temp_dir, "model.quant.onnx") + quantized_model = quantize_static(model_fp32, model_quant, RandomDataReader()) + # opt_level=1 will cause error with qnn lowering + model = onnx.load(model_quant) + verify_with_ort_with_inputs( + model, input_arrays, opt_level=2, target=target, dev=dev, use_vm=True + ) + + def make_constant_node(name, data_type, dims, vals): return helper.make_node( "Constant", @@ -228,8 +283,8 @@ def is_version_greater_than(ver): ) -@tvm.testing.uses_gpu -def test_reshape(): +@tvm.testing.parametrize_targets +def test_reshape(target, dev): in_shape = (4, 3, 3, 4) ref_shape = (6, 2, 4, 3) @@ -256,14 +311,13 @@ def test_reshape(): model = helper.make_model(graph, producer_name="reshape_test") - for target, dev in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype("int32") - tvm_out = get_tvm_output(model, x, target, dev, ref_shape, "float32") - tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + x = np.random.uniform(size=in_shape).astype("int32") + tvm_out = get_tvm_output(model, x, target, dev, ref_shape, "float32") + tvm.testing.assert_allclose(ref_shape, tvm_out.shape) -@tvm.testing.uses_gpu -def test_double_reshape(): +@tvm.testing.parametrize_targets +def test_double_reshape(target, dev): in_shape = (4, 3, 3, 4) ref_shape = (6, 2, 4, 3) @@ -292,14 +346,13 @@ def test_double_reshape(): model = helper.make_model(graph, producer_name="reshape_test") - for target, dev in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype("int32") - tvm_out = get_tvm_output(model, x, target, dev, ref_shape, "float32") - tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + x = np.random.uniform(size=in_shape).astype("int32") + tvm_out = get_tvm_output(model, x, target, dev, ref_shape, "float32") + tvm.testing.assert_allclose(ref_shape, tvm_out.shape) -@tvm.testing.uses_gpu -def test_expand(): +@tvm.testing.parametrize_targets +def test_expand(target, dev): def _test_expand(name, data, shape, ref_data, dtype="int32"): shape_array = np.array(shape) if dtype == "int32": @@ -339,9 +392,8 @@ def _test_expand(name, data, shape, ref_data, dtype="int32"): model = helper.make_model(graph, producer_name=name) - for target, dev in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm(model, data, target, dev, freeze_params=True) - tvm.testing.assert_allclose(ref_data, tvm_out) + tvm_out = get_tvm_output_with_vm(model, data, target, dev, freeze_params=True) + tvm.testing.assert_allclose(ref_data, tvm_out) in_shape = (3, 1) shape = (3, 4) @@ -358,51 +410,53 @@ def _test_expand(name, data, shape, ref_data, dtype="int32"): _test_expand("expand_with_dim_changed_test", data, shape, ref_data, "int64") -def verify_depth_to_space(inshape, outshape, mode, blockSize): - node = onnx.helper.make_node("DepthToSpace", inputs=["x"], outputs=["y"], blocksize=blockSize) - - graph = helper.make_graph( - [node], - "depth_to_space_test", - inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))], - ) +@tvm.testing.parametrize_targets +def test_depth_to_space(target, dev): + def verify_depth_to_space(inshape, outshape, mode, blockSize): + node = onnx.helper.make_node( + "DepthToSpace", inputs=["x"], outputs=["y"], blocksize=blockSize + ) - model = helper.make_model(graph, producer_name="depth_to_space_test") + graph = helper.make_graph( + [node], + "depth_to_space_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))], + ) - verify_with_ort(model, [inshape], [outshape]) + model = helper.make_model(graph, producer_name="depth_to_space_test") + verify_with_ort(model, [inshape], [outshape], target, dev) -@tvm.testing.uses_gpu -def test_depth_to_space(): # current onnx.checker use OpSet-1 version of DepthToSpace, which doesn't have a mode argument. - # TO-DO, we can add mode arguement to test CRD mode and DCR mode + # TO-DO, we can add mode argument to test CRD mode and DCR mode # in the future when we update to a newer onnx version. verify_depth_to_space((1, 8, 2, 3), (1, 2, 4, 6), mode="CRD", blockSize=2) -def verify_space_to_depth(inshape, outshape, blockSize): - node = onnx.helper.make_node("SpaceToDepth", inputs=["x"], outputs=["y"], blocksize=blockSize) - - graph = helper.make_graph( - [node], - "space_to_depth_test", - inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))], - ) +@tvm.testing.parametrize_targets +def test_space_to_depth(target, dev): + def verify_space_to_depth(inshape, outshape, blockSize): + node = onnx.helper.make_node( + "SpaceToDepth", inputs=["x"], outputs=["y"], blocksize=blockSize + ) - model = helper.make_model(graph, producer_name="space_to_depth_test") + graph = helper.make_graph( + [node], + "space_to_depth_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))], + ) - verify_with_ort(model, [inshape], [outshape]) + model = helper.make_model(graph, producer_name="space_to_depth_test") + verify_with_ort(model, [inshape], [outshape], target, dev) -@tvm.testing.uses_gpu -def test_space_to_depth(): verify_space_to_depth((1, 1, 4, 6), (1, 4, 2, 3), 2) -@tvm.testing.uses_gpu -def test_shape(): +@tvm.testing.parametrize_targets +def test_shape(target, dev): in_shape = (4, 3, 3, 4) ref_shape = (6, 2, 4, 3) @@ -431,76 +485,72 @@ def test_shape(): model = helper.make_model(graph, producer_name="shape_test") - for target, dev in tvm.testing.enabled_targets(): - x = np.random.uniform(size=in_shape).astype("int32") - tvm_out = get_tvm_output(model, x, target, dev, ref_shape, "int32") - tvm.testing.assert_allclose(ref_shape, tvm_out) + x = np.random.uniform(size=in_shape).astype("int32") + tvm_out = get_tvm_output(model, x, target, dev, ref_shape, "int32") + tvm.testing.assert_allclose(ref_shape, tvm_out) -def _test_power_iteration(x_shape, y_shape): - if isinstance(y_shape, int): - y_shape = [y_shape] +@tvm.testing.parametrize_targets +def test_power(target, dev): + def _test_power_iteration(x_shape, y_shape): + if isinstance(y_shape, int): + y_shape = [y_shape] - x = np.random.uniform(size=x_shape).astype(np.float32) - y = np.random.uniform(size=y_shape).astype(np.float32) + x = np.random.uniform(size=x_shape).astype(np.float32) + y = np.random.uniform(size=y_shape).astype(np.float32) - np_res = np.power(x, y).astype(np.float32) + np_res = np.power(x, y).astype(np.float32) - res = helper.make_node("Pow", ["x", "y"], ["out"]) + res = helper.make_node("Pow", ["x", "y"], ["out"]) - graph = helper.make_graph( - [res], - "power_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), - helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape)), - ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(np_res.shape))], - ) + graph = helper.make_graph( + [res], + "power_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(np_res.shape))], + ) - model = helper.make_model(graph, producer_name="power_test") + model = helper.make_model(graph, producer_name="power_test") - for target, dev in tvm.testing.enabled_targets(): tvm_out = get_tvm_output(model, [x, y], target, dev, np_res.shape) tvm.testing.assert_allclose(np_res, tvm_out, rtol=1e-5, atol=1e-5) - -@tvm.testing.uses_gpu -def test_power(): _test_power_iteration((1, 3), (1)) _test_power_iteration((2, 3), (2, 3)) _test_power_iteration((2, 3), (1, 3)) -def verify_range(start, limit, delta, dtype): - dtype_map = { - "float32": TensorProto.FLOAT, - "int32": TensorProto.INT32, - "int64": TensorProto.INT64, - } - dtype_onnx = dtype_map[dtype] - y = helper.make_node("Range", ["start", "limit", "delta"], ["output"]) - graph = helper.make_graph( - [y], - "range_test", - inputs=[ - helper.make_tensor_value_info("start", dtype_onnx, []), - helper.make_tensor_value_info("limit", dtype_onnx, []), - helper.make_tensor_value_info("delta", dtype_onnx, []), - ], - outputs=[ - helper.make_tensor_value_info( - "output", dtype_onnx, np.arange(start, limit, delta).shape - ) - ], - ) - model = helper.make_model(graph, producer_name="range_test") - inputs = [np.array(x).astype(dtype) for x in [start, limit, delta]] - verify_with_ort_with_inputs(model, inputs, use_vm=True) - +@tvm.testing.parametrize_targets +def test_range(target, dev): + def verify_range(start, limit, delta, dtype): + dtype_map = { + "float32": TensorProto.FLOAT, + "int32": TensorProto.INT32, + "int64": TensorProto.INT64, + } + dtype_onnx = dtype_map[dtype] + y = helper.make_node("Range", ["start", "limit", "delta"], ["output"]) + graph = helper.make_graph( + [y], + "range_test", + inputs=[ + helper.make_tensor_value_info("start", dtype_onnx, []), + helper.make_tensor_value_info("limit", dtype_onnx, []), + helper.make_tensor_value_info("delta", dtype_onnx, []), + ], + outputs=[ + helper.make_tensor_value_info( + "output", dtype_onnx, np.arange(start, limit, delta).shape + ) + ], + ) + model = helper.make_model(graph, producer_name="range_test") + inputs = [np.array(x).astype(dtype) for x in [start, limit, delta]] + verify_with_ort_with_inputs(model, inputs, target=target, dev=dev, use_vm=True) -@tvm.testing.uses_gpu -def test_range(): for t in ["float32", "int32", "int64"]: verify_range(0, 10, 1, t) verify_range(2, 8, 2, t) @@ -508,8 +558,8 @@ def test_range(): verify_range(-2, -7, -1, t) -@tvm.testing.uses_gpu -def test_squeeze(): +@tvm.testing.parametrize_targets +def test_squeeze(target, dev): in_shape = (1, 3, 1, 3, 1, 1) out_shape = (3, 3) y = helper.make_node("Squeeze", ["in"], ["out"], axes=[0, 2, 4, 5]) @@ -523,11 +573,11 @@ def test_squeeze(): model = helper.make_model(graph, producer_name="squeeze_test") x = np.random.uniform(size=in_shape).astype("float32") - verify_with_ort_with_inputs(model, [x], [out_shape], opset=11) + verify_with_ort_with_inputs(model, [x], [out_shape], target=target, dev=dev, opset=11) -@tvm.testing.uses_gpu -def test_flatten(): +@tvm.testing.parametrize_targets +def test_flatten(target, dev): in_shape = (1, 3, 4, 4) axis = 1 @@ -543,11 +593,11 @@ def test_flatten(): ) model = helper.make_model(graph, producer_name="flatten_test") - verify_with_ort(model, [in_shape]) + verify_with_ort(model, [in_shape], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_unsqueeze(): +@tvm.testing.parametrize_targets +def test_unsqueeze(target, dev): in_shape = (3, 3) axis = (0, 3, 4) out_shape = (1, 3, 3, 1, 1) @@ -561,37 +611,36 @@ def test_unsqueeze(): ) model = helper.make_model(graph, producer_name="squeeze_test") - verify_with_ort(model, [in_shape], opset=11) + verify_with_ort(model, [in_shape], target=target, dev=dev, opset=11) -def verify_gather(in_shape, indices, axis, dtype): - x = np.random.uniform(size=in_shape).astype(dtype) - indices = np.array(indices, dtype="int64") - out_np = np.take(x, indices, axis=axis) - - y = helper.make_node("Gather", ["in", "indices"], ["out"], axis=axis) +@tvm.testing.parametrize_targets +def test_gather(target, dev): + def verify_gather(in_shape, indices, axis, dtype): + x = np.random.uniform(size=in_shape).astype(dtype) + indices = np.array(indices, dtype="int64") + out_np = np.take(x, indices, axis=axis) - graph = helper.make_graph( - [y], - "gather_test", - inputs=[ - helper.make_tensor_value_info( - "in", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(in_shape) - ), - helper.make_tensor_value_info("indices", TensorProto.INT64, list(indices.shape)), - ], - outputs=[ - helper.make_tensor_value_info( - "out", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(out_np.shape) - ) - ], - ) - model = helper.make_model(graph, producer_name="gather_test") - verify_with_ort_with_inputs(model, [x, indices], dtype=dtype) + y = helper.make_node("Gather", ["in", "indices"], ["out"], axis=axis) + graph = helper.make_graph( + [y], + "gather_test", + inputs=[ + helper.make_tensor_value_info( + "in", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(in_shape) + ), + helper.make_tensor_value_info("indices", TensorProto.INT64, list(indices.shape)), + ], + outputs=[ + helper.make_tensor_value_info( + "out", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(out_np.shape) + ) + ], + ) + model = helper.make_model(graph, producer_name="gather_test") + verify_with_ort_with_inputs(model, [x, indices], target=target, dev=dev, dtype=dtype) -@tvm.testing.uses_gpu -def test_gather(): verify_gather((4,), [1], 0, "int32") verify_gather((1, 4), [0], 0, "int32") verify_gather((4,), [[[1, 0], [0, 1]]], 0, "float32") @@ -600,8 +649,8 @@ def test_gather(): verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, "float32") -@tvm.testing.uses_gpu -def test_dynamic_gather(): +@tvm.testing.parametrize_targets +def test_dynamic_gather(target, dev): dtype = "float32" in_shape = [2, 2] indices = 1 @@ -641,33 +690,30 @@ def test_dynamic_gather(): mod, params = relay.frontend.from_onnx(model) - for target, device in tvm.testing.enabled_targets(): - ex = relay.create_executor("vm", mod=mod, device=device, target=target) - result = ex.evaluate()(x, **params) - tvm.testing.assert_allclose(out_np, result.numpy(), rtol=1e-5, atol=1e-5) - + result = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()(x, **params) + tvm.testing.assert_allclose(out_np, result.numpy(), rtol=1e-5, atol=1e-5) -def verify_gatherelements(in_shape, indices, axis): - x = np.random.uniform(size=in_shape).astype("float32") - indices = np.array(indices, dtype="int32") - y = helper.make_node("GatherElements", ["data", "indices"], ["output"], axis=axis) - graph = helper.make_graph( - [y], - "gather_elements_test", - inputs=[ - helper.make_tensor_value_info("data", TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), - ], - outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], - ) - model = helper.make_model(graph, producer_name="gather_elements_test") +@tvm.testing.parametrize_targets +def test_gatherelements(target, dev): + def verify_gatherelements(in_shape, indices, axis): + x = np.random.uniform(size=in_shape).astype("float32") + indices = np.array(indices, dtype="int32") - verify_with_ort_with_inputs(model, [x, indices]) + y = helper.make_node("GatherElements", ["data", "indices"], ["output"], axis=axis) + graph = helper.make_graph( + [y], + "gather_elements_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], + ) + model = helper.make_model(graph, producer_name="gather_elements_test") + verify_with_ort_with_inputs(model, [x, indices], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_gatherelements(): verify_gatherelements((4,), [3, 0, 2, 1], 0) verify_gatherelements((2, 2), [[1, 0], [0, 1]], 0) verify_gatherelements((2, 2), [[0, 0], [1, 0]], 1) @@ -682,29 +728,30 @@ def test_gatherelements(): verify_gatherelements((3, 3, 3), indices, 2) -def verify_scatter(in_shape, indices, axis): - x = np.random.uniform(size=in_shape).astype("float32") - indices = np.array(indices, dtype="int32") - updates = np.random.uniform(size=indices.shape).astype("float32") - - y = helper.make_node("ScatterElements", ["data", "indices", "updates"], ["output"], axis=axis) +@tvm.testing.parametrize_targets +def test_scatter(target, dev): + def verify_scatter(in_shape, indices, axis): + x = np.random.uniform(size=in_shape).astype("float32") + indices = np.array(indices, dtype="int32") + updates = np.random.uniform(size=indices.shape).astype("float32") - graph = helper.make_graph( - [y], - "scatter_test", - inputs=[ - helper.make_tensor_value_info("data", TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), - helper.make_tensor_value_info("updates", TensorProto.FLOAT, list(indices.shape)), - ], - outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], - ) - model = helper.make_model(graph, producer_name="scatter_test") - verify_with_ort_with_inputs(model, [x, indices, updates]) + y = helper.make_node( + "ScatterElements", ["data", "indices", "updates"], ["output"], axis=axis + ) + graph = helper.make_graph( + [y], + "scatter_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), + helper.make_tensor_value_info("updates", TensorProto.FLOAT, list(indices.shape)), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], + ) + model = helper.make_model(graph, producer_name="scatter_test") + verify_with_ort_with_inputs(model, [x, indices, updates], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_scatter(): verify_scatter((4,), [1], 0) verify_scatter((1, 4), [[0]], 0) verify_scatter((4,), [2, 3], 0) @@ -713,120 +760,130 @@ def test_scatter(): verify_scatter((4, 3, 5, 6), [[[[2, 1, 0, 0]]]], 0) -def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None): - if axes: - y = helper.make_node("Slice", ["in"], ["out"], axes=axes, starts=starts, ends=ends) - else: - y = helper.make_node("Slice", ["in"], ["out"], starts=starts, ends=ends) +@tvm.testing.parametrize_targets +def test_slice(target, dev): + def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None): + if axes: + y = helper.make_node("Slice", ["in"], ["out"], axes=axes, starts=starts, ends=ends) + else: + y = helper.make_node("Slice", ["in"], ["out"], starts=starts, ends=ends) - graph = helper.make_graph( - [y], - "slice_test", - inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], - ) + graph = helper.make_graph( + [y], + "slice_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + ) - model = helper.make_model(graph, producer_name="slice_test") - verify_with_ort_with_inputs(model, [indata], [outdata.shape], opset=1) - - -def _test_slice_iteration_v10(indata, outdata, **attrs): - starts = attrs["starts"] - ends = attrs["ends"] - axes = None if "axes" not in attrs else attrs["axes"] - steps = None if "steps" not in attrs else attrs["steps"] - starts = np.asarray(starts) - ends = np.asarray(ends) - inputs = [ - helper.make_tensor_value_info("data", TensorProto.FLOAT, list(indata.shape)), - helper.make_tensor_value_info("starts", TensorProto.INT64, list(starts.shape)), - helper.make_tensor_value_info("ends", TensorProto.INT64, list(ends.shape)), - ] - initializer = [ - helper.make_tensor("starts", TensorProto.INT64, list(starts.shape), starts), - helper.make_tensor("ends", TensorProto.INT64, list(ends.shape), ends), - ] - nodes = [] + model = helper.make_model(graph, producer_name="slice_test") + verify_with_ort_with_inputs( + model, [indata], [outdata.shape], opset=1, target=target, dev=dev + ) - if "add_noop_to_input_attrs" in attrs: + def _test_slice_iteration_v10(indata, outdata, **attrs): + starts = attrs["starts"] + ends = attrs["ends"] + axes = None if "axes" not in attrs else attrs["axes"] + steps = None if "steps" not in attrs else attrs["steps"] + starts = np.asarray(starts) + ends = np.asarray(ends) + inputs = [ + helper.make_tensor_value_info("data", TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("starts", TensorProto.INT64, list(starts.shape)), + helper.make_tensor_value_info("ends", TensorProto.INT64, list(ends.shape)), + ] + initializer = [ + helper.make_tensor("starts", TensorProto.INT64, list(starts.shape), starts), + helper.make_tensor("ends", TensorProto.INT64, list(ends.shape), ends), + ] + nodes = [] + + if "add_noop_to_input_attrs" in attrs: + + def add_noop_to_input_attr(attr_name, attr): + output_name = attr_name + "_output" + + ref_shape = list(np.array(attr).shape) + ref_shape.insert(0, 1) + ref_shape = tuple(ref_shape) + ref_array = np.array(ref_shape) + ref_node = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["ref_in_" + attr_name], + value=onnx.helper.make_tensor( + name="const_tensor__1_" + attr_name, + data_type=onnx.TensorProto.INT64, + dims=ref_array.shape, + vals=ref_array.flatten().astype(int), + ), + ) + in_shape = np.array(attr).shape + in_array = np.array(in_shape) + ref_node2 = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["input_shape_" + attr_name], + value=onnx.helper.make_tensor( + name="const_tensor__2_" + attr_name, + data_type=onnx.TensorProto.INT64, + dims=in_array.shape, + vals=in_array.flatten().astype(int), + ), + ) - def add_noop_to_input_attr(attr_name, attr): - output_name = attr_name + "_output" + reshape1_node = helper.make_node( + "Reshape", [attr_name, "ref_in_" + attr_name], ["reshape_" + attr_name] + ) + reshape2_node = helper.make_node( + "Reshape", ["reshape_" + attr_name, "input_shape_" + attr_name], [output_name] + ) + return [ref_node, ref_node2, reshape1_node, reshape2_node] - ref_shape = list(np.array(attr).shape) - ref_shape.insert(0, 1) - ref_shape = tuple(ref_shape) - ref_array = np.array(ref_shape) - ref_node = onnx.helper.make_node( - "Constant", - inputs=[], - outputs=["ref_in_" + attr_name], - value=onnx.helper.make_tensor( - name="const_tensor__1_" + attr_name, - data_type=onnx.TensorProto.INT64, - dims=ref_array.shape, - vals=ref_array.flatten().astype(int), - ), + slice_inputs = [] + for attr_name in ["starts", "ends", "axes", "steps"]: + if attr_name not in attrs: + continue + if "add_noop_to_input_attrs" in attrs and attr_name in attrs["add_noop_to_input_attrs"]: + nodes.extend(add_noop_to_input_attr(attr_name, attrs[attr_name])) + slice_inputs.append(attr_name + "_output") + else: + slice_inputs.append(attr_name) + + if axes: + axes = np.asarray(axes) + inputs.append( + helper.make_tensor_value_info("axes", TensorProto.INT64, list(axes.shape)) ) - in_shape = np.array(attr).shape - in_array = np.array(in_shape) - ref_node2 = onnx.helper.make_node( - "Constant", - inputs=[], - outputs=["input_shape_" + attr_name], - value=onnx.helper.make_tensor( - name="const_tensor__2_" + attr_name, - data_type=onnx.TensorProto.INT64, - dims=in_array.shape, - vals=in_array.flatten().astype(int), - ), + initializer.append( + helper.make_tensor("axes", TensorProto.INT64, list(axes.shape), axes) ) - reshape1_node = helper.make_node( - "Reshape", [attr_name, "ref_in_" + attr_name], ["reshape_" + attr_name] + if steps: + assert axes is not None and len(axes) == len(steps) + steps = np.asarray(steps) + inputs.append( + helper.make_tensor_value_info("steps", TensorProto.INT64, list(axes.shape)) ) - reshape2_node = helper.make_node( - "Reshape", ["reshape_" + attr_name, "input_shape_" + attr_name], [output_name] + initializer.append( + helper.make_tensor("steps", TensorProto.INT64, list(steps.shape), steps) ) - return [ref_node, ref_node2, reshape1_node, reshape2_node] - - slice_inputs = [] - for attr_name in ["starts", "ends", "axes", "steps"]: - if attr_name not in attrs: - continue - if "add_noop_to_input_attrs" in attrs and attr_name in attrs["add_noop_to_input_attrs"]: - nodes.extend(add_noop_to_input_attr(attr_name, attrs[attr_name])) - slice_inputs.append(attr_name + "_output") - else: - slice_inputs.append(attr_name) - - if axes: - axes = np.asarray(axes) - inputs.append(helper.make_tensor_value_info("axes", TensorProto.INT64, list(axes.shape))) - initializer.append(helper.make_tensor("axes", TensorProto.INT64, list(axes.shape), axes)) - - if steps: - assert axes is not None and len(axes) == len(steps) - steps = np.asarray(steps) - inputs.append(helper.make_tensor_value_info("steps", TensorProto.INT64, list(axes.shape))) - initializer.append(helper.make_tensor("steps", TensorProto.INT64, list(steps.shape), steps)) - - y = helper.make_node("Slice", ["data", *slice_inputs], ["out"]) - nodes.append(y) - graph = helper.make_graph( - nodes, - "slice_test", - inputs=inputs, - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], - initializer=initializer, - ) - model = helper.make_model(graph, producer_name="slice_test") - verify_with_ort_with_inputs(model, [indata], opset=10, freeze_params=True, use_vm=True) + y = helper.make_node("Slice", ["data", *slice_inputs], ["out"]) + nodes.append(y) + graph = helper.make_graph( + nodes, + "slice_test", + inputs=inputs, + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + initializer=initializer, + ) + model = helper.make_model(graph, producer_name="slice_test") + verify_with_ort_with_inputs( + model, [indata], opset=10, freeze_params=True, use_vm=True, target=target, dev=dev + ) -@tvm.testing.uses_gpu -def test_slice(): x = np.random.randn(20, 10, 5).astype(np.float32) _test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1)) _test_slice_iteration_v1(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4)) @@ -900,7 +957,9 @@ def test_slice(): ) -def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs, opset=None): +def _test_onnx_op_elementwise( + target, dev, inshape, outfunc, npargs, dtype, opname, kwargs, opset=None +): indata = np.random.uniform(-1, 1, size=inshape).astype(dtype) outdata = outfunc(indata, **npargs) @@ -914,22 +973,26 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs, o ) model = helper.make_model(graph, producer_name=opname + "_test") - verify_with_ort_with_inputs(model, [indata], [outdata.shape], opset=opset, dtype=dtype) + verify_with_ort_with_inputs( + model, [indata], [outdata.shape], opset=opset, dtype=dtype, target=target, dev=dev + ) -@tvm.testing.uses_gpu -def test_floor(): - _test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, "float32", "Floor", {}) +@tvm.testing.parametrize_targets +def test_floor(target, dev): + _test_onnx_op_elementwise(target, dev, (2, 4, 5, 6), np.floor, {}, "float32", "Floor", {}) -@tvm.testing.uses_gpu -def test_ceil(): - _test_onnx_op_elementwise((2, 4, 5, 6), np.ceil, {}, "float32", "Ceil", {}) +@tvm.testing.parametrize_targets +def test_ceil(target, dev): + _test_onnx_op_elementwise(target, dev, (2, 4, 5, 6), np.ceil, {}, "float32", "Ceil", {}) -@tvm.testing.uses_gpu -def test_clip(): +@tvm.testing.parametrize_targets +def test_clip(target, dev): _test_onnx_op_elementwise( + target, + dev, (2, 4, 5, 6), np.clip, {"a_min": -1.0, "a_max": 1.0}, @@ -940,6 +1003,8 @@ def test_clip(): ) _test_onnx_op_elementwise( + target, + dev, (2, 4, 5, 6), np.clip, {"a_min": -np.inf, "a_max": 1.0}, @@ -950,6 +1015,8 @@ def test_clip(): ) _test_onnx_op_elementwise( + target, + dev, (2, 4, 5, 6), np.clip, {"a_min": -1.0, "a_max": np.inf}, @@ -960,8 +1027,8 @@ def test_clip(): ) -@tvm.testing.uses_gpu -def test_clip_min_max_as_inputs(): +@tvm.testing.parametrize_targets +def test_clip_min_max_as_inputs(target, dev): input_shape = (2, 4, 5, 6) nodes = [ make_constant_node("min", onnx.TensorProto.FLOAT, (), [0.0]), @@ -977,15 +1044,15 @@ def test_clip_min_max_as_inputs(): ) model = helper.make_model(graph, producer_name="clip_test") - verify_with_ort(model, [input_shape], out_shape=[input_shape]) + verify_with_ort(model, [input_shape], out_shape=[input_shape], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_round(): - _test_onnx_op_elementwise((2, 4, 5, 6), np.round, {}, "float32", "Round", {}) +@tvm.testing.parametrize_targets +def test_round(target, dev): + _test_onnx_op_elementwise(target, dev, (2, 4, 5, 6), np.round, {}, "float32", "Round", {}) -def _test_finite_ops(inshape, outfunc, npargs, dtype, opname, kwargs): +def _test_finite_ops(target, dev, inshape, outfunc, npargs, dtype, opname, kwargs): indata = np.random.choice(a=[np.nan, np.inf, -np.inf, 0.5, 1.0, 0], size=inshape).astype(dtype) outdata = outfunc(indata, **npargs) @@ -999,50 +1066,53 @@ def _test_finite_ops(inshape, outfunc, npargs, dtype, opname, kwargs): ) model = helper.make_model(graph, producer_name=opname + "_test") - verify_with_ort_with_inputs(model, [indata], [outdata.shape], dtype=dtype) - + verify_with_ort_with_inputs( + model, [indata], [outdata.shape], dtype=dtype, target=target, dev=dev + ) -@tvm.testing.uses_gpu -def test_isinf(): - _test_finite_ops((2, 4, 5, 6), np.isinf, {}, "float32", "IsInf", {}) +@tvm.testing.parametrize_targets +def test_isinf(target, dev): + _test_finite_ops(target, dev, (2, 4, 5, 6), np.isinf, {}, "float32", "IsInf", {}) -@tvm.testing.uses_gpu -def test_isnan(): - _test_finite_ops((2, 4, 5, 6), np.isnan, {}, "float32", "IsNaN", {}) +@tvm.testing.parametrize_targets +def test_isnan(target, dev): + _test_finite_ops(target, dev, (2, 4, 5, 6), np.isnan, {}, "float32", "IsNaN", {}) -def verify_gather_nd(in_shape, indices, out_shape, dtype="float32", batch_dims=0, opset=11): - x = np.random.uniform(size=in_shape).astype(dtype) - indices = np.array(indices, dtype="int64") - y = helper.make_node("GatherND", ["in", "indices"], ["out"]) +@tvm.testing.parametrize_targets +def test_gather_nd(target, dev): + def verify_gather_nd(in_shape, indices, out_shape, dtype="float32", batch_dims=0, opset=11): + x = np.random.uniform(size=in_shape).astype(dtype) + indices = np.array(indices, dtype="int64") - if opset >= 12: - batch_dims_attr = helper.make_attribute("batch_dims", batch_dims) - y.attribute.append(batch_dims_attr) + y = helper.make_node("GatherND", ["in", "indices"], ["out"]) - graph = helper.make_graph( - [y], - "gather_test", - inputs=[ - helper.make_tensor_value_info( - "in", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(in_shape) - ), - helper.make_tensor_value_info("indices", TensorProto.INT64, list(indices.shape)), - ], - outputs=[ - helper.make_tensor_value_info( - "out", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(out_shape) - ) - ], - ) - model = helper.make_model(graph, producer_name="gather_test") - verify_with_ort_with_inputs(model, [x, indices], [out_shape], opset=opset) + if opset >= 12: + batch_dims_attr = helper.make_attribute("batch_dims", batch_dims) + y.attribute.append(batch_dims_attr) + graph = helper.make_graph( + [y], + "gather_test", + inputs=[ + helper.make_tensor_value_info( + "in", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(in_shape) + ), + helper.make_tensor_value_info("indices", TensorProto.INT64, list(indices.shape)), + ], + outputs=[ + helper.make_tensor_value_info( + "out", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(out_shape) + ) + ], + ) + model = helper.make_model(graph, producer_name="gather_test") + verify_with_ort_with_inputs( + model, [x, indices], [out_shape], opset=opset, target=target, dev=dev + ) -@tvm.testing.uses_gpu -def test_gather_nd(): verify_gather_nd([2, 2], [[0, 0], [1, 1]], [2], "int32") verify_gather_nd([2, 2], [[1], [0]], [2, 2]) verify_gather_nd([2, 2, 2], [[0, 1], [1, 0]], [2, 2]) @@ -1059,8 +1129,8 @@ def test_gather_nd(): ) -@tvm.testing.uses_gpu -def test_onehot(): +@tvm.testing.parametrize_targets +def test_onehot(target, dev): indices_shape = [10] indices_array = np.random.randint(low=0, high=9, size=indices_shape, dtype="int32") depth = 10 @@ -1083,53 +1153,65 @@ def test_onehot(): model = helper.make_model(graph, producer_name="onehot_test") # TODO(jwfromm): Replace test against np with test against onnxrt once we update versions. - for target, dev in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm( - model, [indices_array, np.array([depth]).astype("int32"), values], target, dev - ) - tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) - + tvm_out = get_tvm_output_with_vm( + model, [indices_array, np.array([depth]).astype("int32"), values], target, dev + ) + tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) -def verify_gemm(a_shape, b_shape, c_shape=None, freeze_params=False, dtype="float32"): - out_shape = [a_shape[0], b_shape[1]] - a_array = np.random.uniform(size=a_shape).astype(dtype) - b_array = np.random.uniform(size=b_shape).astype(dtype) - input_names = ["a", "b"] - ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] - input_nodes = [ - helper.make_tensor_value_info("a", ONNX_DTYPE, list(a_shape)), - helper.make_tensor_value_info("b", ONNX_DTYPE, list(b_shape)), - ] - input_values = [a_array, b_array] - if c_shape is not None: - c_array = np.random.uniform(size=c_shape).astype(dtype) - input_names.append("c") - input_nodes.append(helper.make_tensor_value_info("c", ONNX_DTYPE, list(c_shape))) - input_values.append(c_array) - gemm_node = helper.make_node("Gemm", input_names, ["out"]) +@tvm.testing.parametrize_targets +def test_gemm(target, dev): + def verify_gemm(a_shape, b_shape, c_shape=None, freeze_params=False, dtype="float32"): + out_shape = [a_shape[0], b_shape[1]] + a_array = np.random.uniform(size=a_shape).astype(dtype) + b_array = np.random.uniform(size=b_shape).astype(dtype) + input_names = ["a", "b"] + ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] + input_nodes = [ + helper.make_tensor_value_info("a", ONNX_DTYPE, list(a_shape)), + helper.make_tensor_value_info("b", ONNX_DTYPE, list(b_shape)), + ] + input_values = [a_array, b_array] + if c_shape is not None: + c_array = np.random.uniform(size=c_shape).astype(dtype) + input_names.append("c") + input_nodes.append(helper.make_tensor_value_info("c", ONNX_DTYPE, list(c_shape))) + input_values.append(c_array) - graph = helper.make_graph( - [gemm_node], - "gemm_test", - inputs=input_nodes, - outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, list(out_shape))], - ) + gemm_node = helper.make_node("Gemm", input_names, ["out"]) - model = helper.make_model(graph, producer_name="gemm_test") - verify_with_ort_with_inputs(model, input_values, freeze_params=freeze_params, dtype=dtype) + graph = helper.make_graph( + [gemm_node], + "gemm_test", + inputs=input_nodes, + outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="gemm_test") + atol = 1e-5 + rtol = 1e-5 + if dtype == "float16": + atol = 1e-3 + rtol = 1e-3 + verify_with_ort_with_inputs( + model, + input_values, + freeze_params=freeze_params, + dtype=dtype, + atol=atol, + rtol=rtol, + target=target, + dev=dev, + ) -@tvm.testing.uses_gpu -def test_gemm(): verify_gemm(a_shape=(4, 3), b_shape=(3, 4)) verify_gemm(a_shape=(4, 3), b_shape=(3, 4), c_shape=(4,)) verify_gemm(a_shape=(4, 3), b_shape=(3, 4), c_shape=(4,), freeze_params=True) verify_gemm(a_shape=(4, 3), b_shape=(3, 4), c_shape=(4,), freeze_params=True, dtype="float16") -@tvm.testing.uses_gpu -def test_matmul(): +@tvm.testing.parametrize_targets +def test_matmul(target, dev): a_shape = (4, 3) b_shape = (3, 4) out_shape = [a_shape[0], b_shape[1]] @@ -1150,43 +1232,57 @@ def test_matmul(): ) model = helper.make_model(graph, producer_name="matmul_test") - verify_with_ort_with_inputs(model, [a_array, b_array]) + verify_with_ort_with_inputs(model, [a_array, b_array], target=target, dev=dev) -def verify_batch_matmul(a_shape, b_shape, out_shape, target, dev): - a_array = np.random.uniform(size=a_shape).astype("float32") - b_array = np.random.uniform(size=b_shape).astype("float32") +@tvm.testing.parametrize_targets +def test_batch_matmul(target, dev): + def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None): + a_array = np.random.uniform(size=a_shape).astype("float32") + b_array = np.random.uniform(size=b_shape).astype("float32") - mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) + mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) - graph = helper.make_graph( - [mul_node], - "matmul_test", - inputs=[ - helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), - helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), - ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)], - ) - - model = helper.make_model(graph, producer_name="matmul_test") - verify_with_ort_with_inputs(model, [a_array, b_array], use_vm=True, targets=[target]) + graph = helper.make_graph( + [mul_node], + "matmul_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)], + ) + model = helper.make_model(graph, producer_name="matmul_test") + verify_with_ort_with_inputs( + model, + [a_array, b_array], + use_vm=True, + target=target, + dev=dev, + convert_config=convert_config, + ) -# TODO(mbrookhart, electriclilies): Add CUDA as a target once batch matmul is fixed -@tvm.testing.parametrize_targets("llvm") -def test_batch_matmul(target, dev): - verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4), target, dev) - verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4), target, dev) - verify_batch_matmul((2, 3, 4, 3), (3, 4), (2, 3, 4, 4), target, dev) + verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4)) + verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4)) + verify_batch_matmul((2, 3, 4, 3), (3, 4), (2, 3, 4, 4)) # Test implicit broadcasting. - verify_batch_matmul((4, 3), (2, 3, 4), (2, 4, 4), target, dev) - verify_batch_matmul((2, 4, 3), (1, 3, 4), (2, 4, 4), target, dev) - verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4), target, dev) + verify_batch_matmul((4, 3), (2, 3, 4), (2, 4, 4)) + verify_batch_matmul((2, 4, 3), (1, 3, 4), (2, 4, 4)) + verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4)) + verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32)) + verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16)) + # Test transb=False + verify_batch_matmul( + (2, 3, 4, 3), + (2, 3, 3, 4), + (2, 3, 4, 4), + convert_config={"use_nt_batch_matmul": False}, + ) def verify_simple_dynamic_model(a_shape, b_shape, target, dev): - def verify_model(ex, a_shape, b_shape): + def verify_model(model, a_shape, b_shape): a_array = np.random.uniform(size=a_shape).astype("float32") b_array = np.random.uniform(size=b_shape).astype("float32") # matmul @@ -1194,7 +1290,7 @@ def verify_model(ex, a_shape, b_shape): # relu out_np[out_np < 0] = 0 - tvm_out = ex.evaluate()(a_array, b_array).numpy() + tvm_out = model(a_array, b_array).numpy() tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) @@ -1221,11 +1317,10 @@ def verify_model(ex, a_shape, b_shape): b_anys = [relay.Any()] * len(b_shape) mod, params = relay.frontend.from_onnx(model, {"a": a_anys, "b": b_anys}) - - ex = relay.create_executor("vm", mod=mod, device=dev, target=target) - verify_model(ex, a_shape, b_shape) - verify_model(ex, [a * 2 for a in a_shape], [b * 2 for b in b_shape]) - verify_model(ex, [a * 3 for a in a_shape], [b * 3 for b in b_shape]) + model = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate() + verify_model(model, a_shape, b_shape) + verify_model(model, [a * 2 for a in a_shape], [b * 2 for b in b_shape]) + verify_model(model, [a * 3 for a in a_shape], [b * 3 for b in b_shape]) # TODO(mbrookhart, electriclilies): Add CUDA as a target once batch matmul is fixed @@ -1236,70 +1331,71 @@ def test_batch_matmul_dynamic_model(target, dev): verify_simple_dynamic_model((2, 3, 4, 3), (3, 4), target, dev) -def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None): - in_array = np.random.uniform(size=shape).astype(dtype) - - if alpha == None and beta == None and bias == None: - alpha = 0.0001 - beta = 0.75 - bias = 1.0 - node = onnx.helper.make_node("LRN", inputs=["in"], outputs=["out"], size=nsize) - else: - node = onnx.helper.make_node( - "LRN", inputs=["in"], outputs=["out"], alpha=alpha, beta=beta, bias=bias, size=nsize - ) +@tvm.testing.parametrize_targets +def test_lrn(target, dev): + def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None): + in_array = np.random.uniform(size=shape).astype(dtype) - graph = helper.make_graph( - [node], - "lrn_test", - inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))], - ) - model = helper.make_model(graph, producer_name="lrn_test") - verify_with_ort_with_inputs(model, [in_array]) + if alpha == None and beta == None and bias == None: + alpha = 0.0001 + beta = 0.75 + bias = 1.0 + node = onnx.helper.make_node("LRN", inputs=["in"], outputs=["out"], size=nsize) + else: + node = onnx.helper.make_node( + "LRN", inputs=["in"], outputs=["out"], alpha=alpha, beta=beta, bias=bias, size=nsize + ) + graph = helper.make_graph( + [node], + "lrn_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))], + ) + model = helper.make_model(graph, producer_name="lrn_test") + verify_with_ort_with_inputs(model, [in_array], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_lrn(): verify_lrn((5, 5, 5, 5), 3, "float32") verify_lrn((5, 5, 5, 5), 3, "float32", alpha=0.0002, beta=0.5, bias=2.0) -def verify_instance_norm(shape, axis=1): - x = np.random.randn(*shape).astype(np.float32) - gamma = np.random.randn(shape[1]).astype(np.float32) - beta = np.random.randn(shape[1]).astype(np.float32) - epsilon = 1e-5 - - node = onnx.helper.make_node( - "InstanceNormalization", - inputs=["x", "gamma", "beta"], - outputs=["y"], - epsilon=epsilon, - ) - graph = helper.make_graph( - [node], - "instance_norm_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)), - helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)), - helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],)), - ], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))], - ) - model = helper.make_model(graph, producer_name="instance_norm_test") - verify_with_ort_with_inputs(model, [x, gamma, beta], out_shape=[shape]) +@tvm.testing.parametrize_targets +def test_instance_norm(target, dev): + def verify_instance_norm(shape, axis=1): + x = np.random.randn(*shape).astype(np.float32) + gamma = np.random.randn(shape[1]).astype(np.float32) + beta = np.random.randn(shape[1]).astype(np.float32) + epsilon = 1e-5 + node = onnx.helper.make_node( + "InstanceNormalization", + inputs=["x", "gamma", "beta"], + outputs=["y"], + epsilon=epsilon, + ) + graph = helper.make_graph( + [node], + "instance_norm_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)), + helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)), + helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],)), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))], + ) + model = helper.make_model(graph, producer_name="instance_norm_test") + verify_with_ort_with_inputs( + model, [x, gamma, beta], out_shape=[shape], target=target, dev=dev + ) -@tvm.testing.uses_gpu -def test_instance_norm(): verify_instance_norm((2, 3, 4, 5)) verify_instance_norm((32, 64, 80, 64)) verify_instance_norm((8, 6, 5)) verify_instance_norm((8, 7, 6, 5, 4)) -def verify_upsample_nearest(): +@tvm.testing.parametrize_targets +def test_upsample_nearest(target, dev): scale = 2 in_shape = (1, 1, 3, 3) out_shape = (1, 1, 3 * scale, 3 * scale) @@ -1315,10 +1411,11 @@ def verify_upsample_nearest(): ) model = helper.make_model(graph, producer_name="upsample_nearest_test") - verify_with_ort_with_inputs(model, [in_array], [out_shape], opset=7) + verify_with_ort_with_inputs(model, [in_array], [out_shape], opset=7, target=target, dev=dev) -def verify_upsample3d_nearest(): +@tvm.testing.parametrize_targets +def test_upsample3d_nearest(target, dev): scale = 2 in_shape = (1, 1, 3, 3, 3) out_shape = (1, 1, 3 * scale, 3 * scale, 3 * scale) @@ -1337,10 +1434,11 @@ def verify_upsample3d_nearest(): model = helper.make_model(graph, producer_name="upsample_nearest_test") # Upsample is deprecated after opset 9 - verify_with_ort_with_inputs(model, [in_array], [out_shape], opset=7) + verify_with_ort_with_inputs(model, [in_array], [out_shape], opset=7, target=target, dev=dev) -def verify_upsample_bilinear(): +@tvm.testing.parametrize_targets +def test_upsample_bilinear(target, dev): scale = 2 in_shape = (1, 1, 3, 3) out_shape = (1, 1, 3 * scale, 3 * scale) @@ -1356,10 +1454,11 @@ def verify_upsample_bilinear(): ) model = helper.make_model(graph, producer_name="upsample_bilinear_test") - verify_with_ort_with_inputs(model, [in_array], [out_shape], opset=7) + verify_with_ort_with_inputs(model, [in_array], [out_shape], opset=7, target=target, dev=dev) -def verify_upsample3d_trilinear(): +@tvm.testing.parametrize_targets +def test_upsample3d_trilinear(target, dev): scale = 2 in_shape = (1, 1, 3, 3, 3) out_shape = (1, 1, 3 * scale, 3 * scale, 3 * scale) @@ -1397,191 +1496,179 @@ def verify_upsample3d_trilinear(): model = helper.make_model(graph, producer_name="upsample_trilinear_test") # TODO(jwfromm): Trilinear upsampling not supported in 1.0.0 onnxruntime. # Replace topi comparison with verify_with_ort once we update. - for target, dev in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output(model, in_array, target, dev, out_shape, "float32") - tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) - - -@tvm.testing.uses_gpu -def test_upsample(): - verify_upsample_nearest() - verify_upsample_bilinear() - verify_upsample3d_nearest() - verify_upsample3d_trilinear() + tvm_out = get_tvm_output(model, in_array, target, dev, out_shape, "float32") + tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) -def verify_softmax(inshape, axis): - opname = "Softmax" - indata = np.random.uniform(size=inshape).astype(np.float32) - outshape = inshape - y = helper.make_node(opname, ["in"], ["out"]) - if axis is not None: - axis_attr = helper.make_attribute("axis", axis) - y.attribute.append(axis_attr) - - graph = helper.make_graph( - [y], - opname + "_test", - inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))], - ) +@tvm.testing.parametrize_targets +def test_softmax(target, dev): + def verify_softmax(inshape, axis): + opname = "Softmax" + indata = np.random.uniform(size=inshape).astype(np.float32) + outshape = inshape + y = helper.make_node(opname, ["in"], ["out"]) + if axis is not None: + axis_attr = helper.make_attribute("axis", axis) + y.attribute.append(axis_attr) - model = helper.make_model(graph, producer_name=opname + "_test") - verify_with_ort_with_inputs(model, [indata]) + graph = helper.make_graph( + [y], + opname + "_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))], + ) + model = helper.make_model(graph, producer_name=opname + "_test") + verify_with_ort_with_inputs(model, [indata], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_softmax(): verify_softmax((1, 10), None) verify_softmax((1, 10), 1) -def verify_min(input_dim): - dtype = "float32" - - a_np1 = np.random.uniform(size=input_dim).astype(dtype) - a_np2 = np.random.uniform(size=input_dim).astype(dtype) - a_np3 = np.random.uniform(size=input_dim).astype(dtype) +@tvm.testing.parametrize_targets +def test_forward_min(target, dev): + def verify_min(input_dim): + dtype = "float32" - min_node = helper.make_node("Min", ["a_np1", "a_np2", "a_np3"], ["out"]) + a_np1 = np.random.uniform(size=input_dim).astype(dtype) + a_np2 = np.random.uniform(size=input_dim).astype(dtype) + a_np3 = np.random.uniform(size=input_dim).astype(dtype) - graph = helper.make_graph( - [min_node], - "Min_test", - inputs=[ - helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), - ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], - ) + min_node = helper.make_node("Min", ["a_np1", "a_np2", "a_np3"], ["out"]) - model = helper.make_model(graph, producer_name="Min_test") - verify_with_ort_with_inputs(model, [a_np1, a_np2, a_np3]) + graph = helper.make_graph( + [min_node], + "Min_test", + inputs=[ + helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], + ) + model = helper.make_model(graph, producer_name="Min_test") + verify_with_ort_with_inputs(model, [a_np1, a_np2, a_np3], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_forward_min(): verify_min((1, 3, 20, 20)) verify_min((20, 20)) -def verify_max(input_dim): - dtype = "float32" - - a_np1 = np.random.uniform(size=input_dim).astype(dtype) - a_np2 = np.random.uniform(size=input_dim).astype(dtype) - a_np3 = np.random.uniform(size=input_dim).astype(dtype) +@tvm.testing.parametrize_targets +def test_forward_max(target, dev): + def verify_max(input_dim): + dtype = "float32" - max_node = helper.make_node("Max", ["a_np1", "a_np2", "a_np3"], ["out"]) + a_np1 = np.random.uniform(size=input_dim).astype(dtype) + a_np2 = np.random.uniform(size=input_dim).astype(dtype) + a_np3 = np.random.uniform(size=input_dim).astype(dtype) - graph = helper.make_graph( - [max_node], - "Max_test", - inputs=[ - helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), - ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], - ) + max_node = helper.make_node("Max", ["a_np1", "a_np2", "a_np3"], ["out"]) - model = helper.make_model(graph, producer_name="Max_test") - verify_with_ort_with_inputs(model, [a_np1, a_np2, a_np3]) + graph = helper.make_graph( + [max_node], + "Max_test", + inputs=[ + helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], + ) + model = helper.make_model(graph, producer_name="Max_test") + verify_with_ort_with_inputs(model, [a_np1, a_np2, a_np3], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_forward_max(): verify_max((1, 3, 20, 20)) verify_max((20, 20)) -def verify_mean(input_dim): - dtype = "float32" - - a_np1 = np.random.uniform(size=input_dim).astype(dtype) - a_np2 = np.random.uniform(size=input_dim).astype(dtype) - a_np3 = np.random.uniform(size=input_dim).astype(dtype) +@tvm.testing.parametrize_targets +def test_forward_mean(target, dev): + def verify_mean(input_dim): + dtype = "float32" - mean_node = helper.make_node("Mean", ["a_np1", "a_np2", "a_np3"], ["out"]) + a_np1 = np.random.uniform(size=input_dim).astype(dtype) + a_np2 = np.random.uniform(size=input_dim).astype(dtype) + a_np3 = np.random.uniform(size=input_dim).astype(dtype) - graph = helper.make_graph( - [mean_node], - "Mean_test", - inputs=[ - helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), - helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), - ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], - ) + mean_node = helper.make_node("Mean", ["a_np1", "a_np2", "a_np3"], ["out"]) - model = helper.make_model(graph, producer_name="Mean_test") - verify_with_ort_with_inputs(model, [a_np1, a_np2, a_np3]) + graph = helper.make_graph( + [mean_node], + "Mean_test", + inputs=[ + helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], + ) + model = helper.make_model(graph, producer_name="Mean_test") + verify_with_ort_with_inputs(model, [a_np1, a_np2, a_np3], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_forward_mean(): verify_mean((1, 3, 20, 20)) verify_mean((20, 20)) -def verify_hardsigmoid(input_dim, alpha, beta): - dtype = "float32" - - a_np1 = np.random.uniform(size=input_dim).astype(dtype) +@tvm.testing.parametrize_targets +def test_forward_hardsigmoid(target, dev): + def verify_hardsigmoid(input_dim, alpha, beta): + dtype = "float32" - hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], ["out"], alpha=alpha, beta=beta) + a_np1 = np.random.uniform(size=input_dim).astype(dtype) - graph = helper.make_graph( - [hardsigmoid_node], - "HardSigmoid_test", - inputs=[helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], - ) + hardsigmoid_node = helper.make_node( + "HardSigmoid", ["a_np1"], ["out"], alpha=alpha, beta=beta + ) - model = helper.make_model(graph, producer_name="HardSigmoid_test") - verify_with_ort_with_inputs(model, [a_np1]) + graph = helper.make_graph( + [hardsigmoid_node], + "HardSigmoid_test", + inputs=[helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], + ) + model = helper.make_model(graph, producer_name="HardSigmoid_test") + verify_with_ort_with_inputs(model, [a_np1], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_forward_hardsigmoid(): verify_hardsigmoid((1, 3, 20, 20), 0.5, 0.6) verify_hardsigmoid((20, 20), 0.3, 0.4) -def verify_argreduce(input_dim, op_name, axis=None, keepdims=None): - a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32) - out_shape = list(a_np1.shape) - def_axis = axis if axis is not None else 0 - if keepdims == 1 or keepdims == None: - out_shape[def_axis] = 1 - else: - out_shape.pop(def_axis) - - node = onnx.helper.make_node(op_name, inputs=["a_np1"], outputs=["out"]) +# TODO (mbrookhart, electriclilies) Fix argmin on GPU and enable this test +@tvm.testing.known_failing_targets("cuda") +@tvm.testing.parametrize_targets +def test_forward_arg_min_max(target, dev): + def verify_argreduce(input_dim, op_name, axis=None, keepdims=None): + a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32) + out_shape = list(a_np1.shape) + def_axis = axis if axis is not None else 0 + if keepdims == 1 or keepdims == None: + out_shape[def_axis] = 1 + else: + out_shape.pop(def_axis) - if keepdims is not None: - keepdims_attr = helper.make_attribute("keepdims", keepdims) - node.attribute.append(keepdims_attr) - if axis is not None: - axis_attr = helper.make_attribute("axis", axis) - node.attribute.append(axis_attr) + node = onnx.helper.make_node(op_name, inputs=["a_np1"], outputs=["out"]) - graph = helper.make_graph( - [node], - "argreduce_test", - inputs=[helper.make_tensor_value_info("a_np1", TensorProto.INT32, list(a_np1.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, list(out_shape))], - ) + if keepdims is not None: + keepdims_attr = helper.make_attribute("keepdims", keepdims) + node.attribute.append(keepdims_attr) + if axis is not None: + axis_attr = helper.make_attribute("axis", axis) + node.attribute.append(axis_attr) - model = helper.make_model(graph, producer_name="argreduce_test") - verify_with_ort_with_inputs(model, [a_np1]) + graph = helper.make_graph( + [node], + "argreduce_test", + inputs=[helper.make_tensor_value_info("a_np1", TensorProto.INT32, list(a_np1.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="argreduce_test") + verify_with_ort_with_inputs(model, [a_np1], target=target, dev=dev) -# TODO (mbrookhart, electriclilies) Fix argmin on GPU and enable this test -# @tvm.testing.uses_gpu -def test_forward_arg_min_max(): """Verify argmin and argmax""" verify_argreduce([3, 4, 4], "ArgMin") verify_argreduce([3, 4, 4], "ArgMax") @@ -1595,122 +1682,126 @@ def test_forward_arg_min_max(): verify_argreduce([3, 4, 4], "ArgMax", axis, keepdims) -def verify_constantofshape(input_dim, value, dtype): - fill_node = helper.make_node( - "ConstantOfShape", - ["input"], - ["output"], - value=helper.make_tensor( - "value", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], (1,), (value,) - ), - ) - - inputs = [helper.make_tensor_value_info("input", TensorProto.INT64, [len(input_dim)])] +@tvm.testing.parametrize_targets +def test_constantofshape(target, dev): + def verify_constantofshape(input_dim, value, dtype): + fill_node = helper.make_node( + "ConstantOfShape", + ["input"], + ["output"], + value=helper.make_tensor( + "value", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], (1,), (value,) + ), + ) - graph = helper.make_graph( - [fill_node], - "fill_test", - inputs, - outputs=[ - helper.make_tensor_value_info( - "output", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], input_dim - ) - ], - ) + inputs = [helper.make_tensor_value_info("input", TensorProto.INT64, [len(input_dim)])] - model = helper.make_model(graph, producer_name="fill_test") - input_np = np.array(input_dim).astype("int64") - verify_with_ort_with_inputs(model, [input_np], use_vm=True) + graph = helper.make_graph( + [fill_node], + "fill_test", + inputs, + outputs=[ + helper.make_tensor_value_info( + "output", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], input_dim + ) + ], + ) + model = helper.make_model(graph, producer_name="fill_test") + input_np = np.array(input_dim).astype("int64") + verify_with_ort_with_inputs(model, [input_np], use_vm=True, target=target, dev=dev) -@tvm.testing.uses_gpu -def test_constantofshape(): verify_constantofshape((2, 3, 4, 5), 10, "float32") verify_constantofshape((3, 3), 0, "int32") verify_constantofshape((1, 2, 3), -1, "float32") -def verify_pad(indata, pads, mode="constant", value=0.0): - indata = np.array(indata).astype(np.float32) - # numpy expect result - len_dim = len(pads) // 2 - np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)] - # onnx graph - if mode in ["edge", "reflect"]: - outdata = np.pad(indata, pad_width=np_pads, mode=mode) - node = helper.make_node( - "Pad", - inputs=["input"], - outputs=["output"], - mode=mode, - pads=pads, - ) - else: - outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) - node = helper.make_node( - "Pad", inputs=["input"], outputs=["output"], mode="constant", pads=pads, value=value - ) - graph = helper.make_graph( - [node], - "pad_test", - inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape))], - ) - model = helper.make_model(graph, producer_name="pad_test") - verify_with_ort_with_inputs(model, [indata], [outdata.shape], dtype="float32", opset=2) - - -def verify_pad_v11(indata, pads, mode="constant", value=0.0): - indata = np.array(indata).astype(np.float32) - # numpy expect result - len_dim = len(pads) // 2 - np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)] - pads = np.array(pads) - # onnx graph - if mode in ["edge", "reflect"]: - inputs = [indata] - outdata = np.pad(indata, pad_width=np_pads, mode=mode) - node = helper.make_node("Pad", inputs=["input", "pads"], outputs=["output"], mode=mode) +@tvm.testing.parametrize_targets +def test_pad(target, dev): + def verify_pad(indata, pads, mode="constant", value=0.0): + indata = np.array(indata).astype(np.float32) + # numpy expect result + len_dim = len(pads) // 2 + np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)] + # onnx graph + if mode in ["edge", "reflect"]: + outdata = np.pad(indata, pad_width=np_pads, mode=mode) + node = helper.make_node( + "Pad", + inputs=["input"], + outputs=["output"], + mode=mode, + pads=pads, + ) + else: + outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) + node = helper.make_node( + "Pad", inputs=["input"], outputs=["output"], mode="constant", pads=pads, value=value + ) graph = helper.make_graph( [node], "pad_test", - inputs=[ - helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)), - helper.make_tensor_value_info("pads", TensorProto.INT64, (len(pads),)), - ], - initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads)], + inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))], outputs=[ helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) ], ) - else: - inputs = [indata] - outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) - node = helper.make_node( - "Pad", inputs=["input", "pads", "constant_value"], outputs=["output"], mode="constant" - ) - graph = helper.make_graph( - [node], - "pad_test", - inputs=[ - helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)), - helper.make_tensor_value_info("pads", TensorProto.INT64, (len(pads),)), - helper.make_tensor_value_info("constant_value", TensorProto.FLOAT, (1,)), - ], - initializer=[ - helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads), - helper.make_tensor("constant_value", TensorProto.FLOAT, (1,), [value]), - ], - outputs=[ - helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) - ], + model = helper.make_model(graph, producer_name="pad_test") + verify_with_ort_with_inputs( + model, [indata], [outdata.shape], dtype="float32", opset=2, target=target, dev=dev ) - model = helper.make_model(graph, producer_name="pad_test") - verify_with_ort_with_inputs(model, inputs, opset=11, use_vm=True) + def verify_pad_v11(indata, pads, mode="constant", value=0.0): + indata = np.array(indata).astype(np.float32) + # numpy expect result + len_dim = len(pads) // 2 + np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)] + pads = np.array(pads) + # onnx graph + if mode in ["edge", "reflect"]: + inputs = [indata] + outdata = np.pad(indata, pad_width=np_pads, mode=mode) + node = helper.make_node("Pad", inputs=["input", "pads"], outputs=["output"], mode=mode) + graph = helper.make_graph( + [node], + "pad_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("pads", TensorProto.INT64, (len(pads),)), + ], + initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads)], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) + ], + ) + else: + inputs = [indata] + outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) + node = helper.make_node( + "Pad", + inputs=["input", "pads", "constant_value"], + outputs=["output"], + mode="constant", + ) + graph = helper.make_graph( + [node], + "pad_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("pads", TensorProto.INT64, (len(pads),)), + helper.make_tensor_value_info("constant_value", TensorProto.FLOAT, (1,)), + ], + initializer=[ + helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads), + helper.make_tensor("constant_value", TensorProto.FLOAT, (1,), [value]), + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) + ], + ) + model = helper.make_model(graph, producer_name="pad_test") + verify_with_ort_with_inputs(model, inputs, opset=11, use_vm=True, target=target, dev=dev) -@tvm.testing.uses_gpu -def test_pad(): verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], "constant", 0.0) verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], "constant", 0.0) verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], "constant", 5.0) @@ -1726,31 +1817,30 @@ def test_pad(): ) -def verify_reduce_func(func, data, axis, keepdims): - inshape = data.shape - outshape = np.sum(data, axis=axis, keepdims=keepdims == 1).shape - - if axis: - node = onnx.helper.make_node( - func, inputs=["x"], outputs=["y"], axes=axis, keepdims=keepdims - ) - else: - node = onnx.helper.make_node(func, inputs=["x"], outputs=["y"], keepdims=keepdims) +@tvm.testing.parametrize_targets +def test_all_reduce_funcs(target, dev): + def verify_reduce_func(func, data, axis, keepdims): + inshape = data.shape + outshape = np.sum(data, axis=axis, keepdims=keepdims == 1).shape - graph = helper.make_graph( - [node], - "reduce_test", - inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))], - ) + if axis: + node = onnx.helper.make_node( + func, inputs=["x"], outputs=["y"], axes=axis, keepdims=keepdims + ) + else: + node = onnx.helper.make_node(func, inputs=["x"], outputs=["y"], keepdims=keepdims) - model = helper.make_model(graph, producer_name="reduce_test") + graph = helper.make_graph( + [node], + "reduce_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))], + ) - verify_with_ort_with_inputs(model, [data], [outshape], opset=11) + model = helper.make_model(graph, producer_name="reduce_test") + verify_with_ort_with_inputs(model, [data], [outshape], opset=11, target=target, dev=dev) -@tvm.testing.uses_gpu -def test_all_reduce_funcs(): funcs = [ "ReduceMax", "ReduceMean", @@ -1791,58 +1881,64 @@ def test_all_reduce_funcs(): ) -def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11): - indata = np.array(indata).astype(np.float32) - outdatas = [np.array(o).astype(np.float32) for o in outdatas] - inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))] - input_names = ["input"] - initializer = [] - - if split: - split_index = range(len(split)) - else: - split_index = range(len(outdatas)) - - if pass_split: - if opset >= 13: - input_names.append("split") - np_split = np.array(split).astype(np.int64) - inputs.append( - helper.make_tensor_value_info("split", TensorProto.INT64, list(np_split.shape)) - ) - indata = [indata, np_split] - initializer.append( - helper.make_tensor("split", TensorProto.INT64, list(np_split.shape), np_split) - ) - node = helper.make_node( - "Split", - inputs=input_names, - outputs=["output_{}".format(i) for i in range(len(split_index))], - axis=axis, - ) +@tvm.testing.parametrize_targets +def test_split(target, dev): + def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11): + indata = np.array(indata).astype(np.float32) + outdatas = [np.array(o).astype(np.float32) for o in outdatas] + inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))] + input_names = ["input"] + initializer = [] - if pass_split and opset < 13: - split_attr = helper.make_attribute("split", split) - node.attribute.append(split_attr) + if split: + split_index = range(len(split)) + else: + split_index = range(len(outdatas)) + + if pass_split: + if opset >= 13: + input_names.append("split") + np_split = np.array(split).astype(np.int64) + inputs.append( + helper.make_tensor_value_info("split", TensorProto.INT64, list(np_split.shape)) + ) + indata = [indata, np_split] + initializer.append( + helper.make_tensor("split", TensorProto.INT64, list(np_split.shape), np_split) + ) + node = helper.make_node( + "Split", + inputs=input_names, + outputs=["output_{}".format(i) for i in range(len(split_index))], + axis=axis, + ) - graph = helper.make_graph( - [node], - "split_test", - inputs=inputs, - initializer=initializer, - outputs=[ - helper.make_tensor_value_info( - "output_{}".format(i), TensorProto.FLOAT, list(outdatas[i].shape) - ) - for i in range(len(split_index)) - ], - ) - model = helper.make_model(graph, producer_name="split_test") - verify_with_ort_with_inputs(model, indata, out_shape=list(range(len(split_index))), opset=opset) + if pass_split and opset < 13: + split_attr = helper.make_attribute("split", split) + node.attribute.append(split_attr) + graph = helper.make_graph( + [node], + "split_test", + inputs=inputs, + initializer=initializer, + outputs=[ + helper.make_tensor_value_info( + "output_{}".format(i), TensorProto.FLOAT, list(outdatas[i].shape) + ) + for i in range(len(split_index)) + ], + ) + model = helper.make_model(graph, producer_name="split_test") + verify_with_ort_with_inputs( + model, + indata, + out_shape=list(range(len(split_index))), + opset=opset, + target=target, + dev=dev, + ) -@tvm.testing.uses_gpu -def test_split(): # 1D verify_split([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 2, 2], 0) verify_split( @@ -1862,8 +1958,8 @@ def test_split(): verify_split([1], [[1]], [1], pass_split=True) -@tvm.testing.uses_gpu -def test_binary_ops(): +@tvm.testing.parametrize_targets +def test_binary_ops(target, dev): in_shape = (1, 2, 3, 3) dtype = "float32" out_shape = in_shape @@ -1884,7 +1980,7 @@ def verify_binary_ops(op, x, y, out_type="float32"): ], ) model = helper.make_model(graph, producer_name="_test") - verify_with_ort_with_inputs(model, [x, y]) + verify_with_ort_with_inputs(model, [x, y], target=target, dev=dev) x = np.random.uniform(size=in_shape).astype(dtype) y = np.random.uniform(size=in_shape).astype(dtype) @@ -1907,8 +2003,8 @@ def verify_binary_ops(op, x, y, out_type="float32"): verify_binary_ops("Equal", x, z, "bool") -@tvm.testing.uses_gpu -def test_unary_ops(): +@tvm.testing.parametrize_targets +def test_unary_ops(target, dev): in_shape = (1, 2, 3, 3) dtype = "float32" out_shape = in_shape @@ -1926,7 +2022,7 @@ def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5, dtype="float32"): outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, list(out_shape))], ) model = helper.make_model(graph, producer_name="_test") - verify_with_ort_with_inputs(model, [x], rtol=rtol, atol=atol) + verify_with_ort_with_inputs(model, [x], rtol=rtol, atol=atol, target=target, dev=dev) x = np.random.uniform(size=in_shape) verify_unary_ops("Neg", x) @@ -1954,32 +2050,41 @@ def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5, dtype="float32"): verify_unary_ops("Softsign", x) -@tvm.testing.uses_gpu -def test_leaky_relu(): +@tvm.testing.parametrize_targets +def test_leaky_relu(target, dev): def leaky_relu_x(x, alpha): return np.where(x >= 0, x, x * alpha) _test_onnx_op_elementwise( - (2, 4, 5, 6), leaky_relu_x, {"alpha": 0.25}, "float32", "LeakyRelu", {"alpha": 0.25} + target, + dev, + (2, 4, 5, 6), + leaky_relu_x, + {"alpha": 0.25}, + "float32", + "LeakyRelu", + {"alpha": 0.25}, ) -@tvm.testing.uses_gpu -def test_elu(): +@tvm.testing.parametrize_targets +def test_elu(target, dev): def elu_x(x, alpha): return np.where(x > 0, x, alpha * (np.exp(x) - 1.0)) _test_onnx_op_elementwise( - (2, 4, 5, 6), elu_x, {"alpha": 0.25}, "float32", "Elu", {"alpha": 0.25} + target, dev, (2, 4, 5, 6), elu_x, {"alpha": 0.25}, "float32", "Elu", {"alpha": 0.25} ) -@tvm.testing.uses_gpu -def test_selu(): +@tvm.testing.parametrize_targets +def test_selu(target, dev): def selu_x(x, alpha, gamma): return gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1.0)) _test_onnx_op_elementwise( + target, + dev, (2, 4, 5, 6), selu_x, {"alpha": 0.25, "gamma": 0.3}, @@ -1989,8 +2094,8 @@ def selu_x(x, alpha, gamma): ) -@tvm.testing.uses_gpu -def test_prelu(): +@tvm.testing.parametrize_targets +def test_prelu(target, dev): def verify_prelu(x_shape, a_shape): node = helper.make_node("PRelu", inputs=["X", "slope"], outputs=["Y"]) @@ -2012,6 +2117,8 @@ def verify_prelu(x_shape, a_shape): out_shape=[list(x_shape)], use_vm=True, convert_to_static=True, + target=target, + dev=dev, ) verify_prelu([3, 4, 5, 6], [1, 4, 1, 1]) @@ -2021,14 +2128,16 @@ def verify_prelu(x_shape, a_shape): verify_prelu([3, 1], [3, 1]) # Test non NCHW workload. -@tvm.testing.uses_gpu -def test_ThresholdedRelu(): +@tvm.testing.parametrize_targets +def test_ThresholdedRelu(target, dev): def ThresholdedRelu_x(x, alpha): out_np = np.clip(x, alpha, np.inf) out_np[out_np == alpha] = 0 return out_np _test_onnx_op_elementwise( + target, + dev, (2, 4, 5, 6), ThresholdedRelu_x, {"alpha": 0.25}, @@ -2038,26 +2147,35 @@ def ThresholdedRelu_x(x, alpha): ) -@tvm.testing.uses_gpu -def test_LogSoftmax(): +@tvm.testing.parametrize_targets +def test_LogSoftmax(target, dev): _test_onnx_op_elementwise( - (1, 4), tvm.topi.testing.log_softmax_python, {}, "float32", "LogSoftmax", {"axis": 1} + target, + dev, + (1, 4), + tvm.topi.testing.log_softmax_python, + {}, + "float32", + "LogSoftmax", + {"axis": 1}, ) -def check_torch_conversion(model, input_size): +def check_torch_conversion(model, input_size, target, dev): dummy_input = torch.randn(*input_size) file_name = "{}.onnx".format(model.__name__) # Set verbose=True for more output torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False) onnx_model = onnx.load(file_name) input_data = np.random.uniform(size=input_size).astype("float32") - verify_with_ort_with_inputs(onnx_model, [input_data], apply_softmax=True) + verify_with_ort_with_inputs( + onnx_model, [input_data], apply_softmax=True, target=target, dev=dev + ) -@tvm.testing.uses_gpu -def test_resnet(): - check_torch_conversion(torchvision.models.resnet18, (1, 3, 224, 224)) +@tvm.testing.parametrize_targets +def test_resnet(target, dev): + check_torch_conversion(torchvision.models.resnet18, (1, 3, 224, 224), target, dev) # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224)) @@ -2075,14 +2193,14 @@ def test_resnet(): # check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224)) -@tvm.testing.uses_gpu -def test_densenet(): - check_torch_conversion(torchvision.models.densenet161, (1, 3, 224, 224)) +@tvm.testing.parametrize_targets +def test_densenet(target, dev): + check_torch_conversion(torchvision.models.densenet161, (1, 3, 224, 224), target, dev) -@tvm.testing.uses_gpu -def test_inception(): - check_torch_conversion(torchvision.models.inception_v3, (1, 3, 224, 224)) +@tvm.testing.parametrize_targets +def test_inception(target, dev): + check_torch_conversion(torchvision.models.inception_v3, (1, 3, 224, 224), target, dev) # TODO(@jroesch): Update Torch + ONNX to support this import. @@ -2094,36 +2212,35 @@ def test_inception(): # check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224)) -@tvm.testing.uses_gpu -def test_sign(): +@tvm.testing.parametrize_targets +def test_sign(target, dev): def Sign_x(x): return np.sign(x) - _test_onnx_op_elementwise((3, 4, 5, 6), Sign_x, {}, "float32", "Sign", {}) - + _test_onnx_op_elementwise(target, dev, (3, 4, 5, 6), Sign_x, {}, "float32", "Sign", {}) -def verify_not(indata, dtype): - x = indata.astype(dtype) - node = helper.make_node( - "Not", - inputs=["in"], - outputs=["out"], - ) +@tvm.testing.parametrize_targets +def test_not(target, dev): + def verify_not(indata, dtype): + x = indata.astype(dtype) - graph = helper.make_graph( - [node], - "not_test", - inputs=[helper.make_tensor_value_info("in", TensorProto.BOOL, list(x.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(x.shape))], - ) + node = helper.make_node( + "Not", + inputs=["in"], + outputs=["out"], + ) - model = helper.make_model(graph, producer_name="not_test") - verify_with_ort_with_inputs(model, [x]) + graph = helper.make_graph( + [node], + "not_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.BOOL, list(x.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(x.shape))], + ) + model = helper.make_model(graph, producer_name="not_test") + verify_with_ort_with_inputs(model, [x], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_not(): # 2d verify_not(indata=(np.random.randn(3, 4) > 0), dtype=bool) # 3d @@ -2132,33 +2249,32 @@ def test_not(): verify_not(indata=(np.random.randn(3, 4, 5, 6) > 0), dtype=bool) -def verify_and(indata, dtype): - x = indata[0].astype(dtype) - y = indata[1].astype(dtype) - outdata = np.logical_and(x, y) +@tvm.testing.parametrize_targets +def test_and(target, dev): + def verify_and(indata, dtype): + x = indata[0].astype(dtype) + y = indata[1].astype(dtype) + outdata = np.logical_and(x, y) - node = helper.make_node( - "And", - inputs=["in1", "in2"], - outputs=["out"], - ) - - graph = helper.make_graph( - [node], - "and_test", - inputs=[ - helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)), - helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape)), - ], - outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))], - ) + node = helper.make_node( + "And", + inputs=["in1", "in2"], + outputs=["out"], + ) - model = helper.make_model(graph, producer_name="and_test") - verify_with_ort_with_inputs(model, [x, y], [outdata.shape]) + graph = helper.make_graph( + [node], + "and_test", + inputs=[ + helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)), + helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))], + ) + model = helper.make_model(graph, producer_name="and_test") + verify_with_ort_with_inputs(model, [x, y], [outdata.shape], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_and(): # 2d x = np.random.randn(3, 4) > 0 y = np.random.randn(3, 4) > 0 @@ -2185,75 +2301,76 @@ def test_and(): verify_and(indata=[x, y], dtype=bool) -def verify_tile_v6(indata, repeats, outdata): - node = helper.make_node("Tile", inputs=["input", "repeats"], outputs=["out"]) - graph = helper.make_graph( - [node], - "tile_test", - inputs=[ - helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)), - helper.make_tensor_value_info("repeats", TensorProto.INT64, list(repeats.shape)), - ], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], - ) - - model = helper.make_model(graph, producer_name="tile_test") - verify_with_ort_with_inputs(model, [indata, repeats], use_vm=True, opset=6) +@tvm.testing.parametrize_targets +def test_tile(target, dev): + def verify_tile_v6(indata, repeats, outdata): + node = helper.make_node("Tile", inputs=["input", "repeats"], outputs=["out"]) + graph = helper.make_graph( + [node], + "tile_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("repeats", TensorProto.INT64, list(repeats.shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + ) + model = helper.make_model(graph, producer_name="tile_test") + verify_with_ort_with_inputs( + model, [indata, repeats], use_vm=True, opset=6, target=target, dev=dev + ) -@tvm.testing.uses_gpu -def test_tile(): x = np.random.rand(2, 3, 4, 5).astype(np.float32) repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64) z = np.tile(x, repeats) verify_tile_v6(x, repeats, z) -def verify_erf(indata, outdata): - node = helper.make_node("Erf", inputs=["in"], outputs=["out"]) - graph = helper.make_graph( - [node], - "erf_test", - inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], - ) - model = helper.make_model(graph, producer_name="erf_test") - verify_with_ort_with_inputs(model, [indata], [outdata.shape]) - +@tvm.testing.parametrize_targets +def test_erf(target, dev): + def verify_erf(indata, outdata): + node = helper.make_node("Erf", inputs=["in"], outputs=["out"]) + graph = helper.make_graph( + [node], + "erf_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + ) + model = helper.make_model(graph, producer_name="erf_test") + verify_with_ort_with_inputs(model, [indata], [outdata.shape], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_erf(): x = np.random.rand(2, 3, 4, 6).astype(np.float32) z = scipy.special.erf(x) verify_erf(x, z) -def verify_where(condition, x, y, dtype, outdata, dynamic=False): - node_list = [] - where_inputs = ["condition", "x", "y"] - if dynamic: - shape_node = helper.make_node("Shape", ["x"], ["shape"]) - reshape_node = helper.make_node("Reshape", ["x", "shape"], ["X"]) - where_inputs[1] = "X" - node_list += [shape_node, reshape_node] - node = helper.make_node("Where", inputs=where_inputs, outputs=["out"]) - node_list.append(node) - graph = helper.make_graph( - node_list, - "where_test", - inputs=[ - helper.make_tensor_value_info("condition", TensorProto.BOOL, list(condition.shape)), - helper.make_tensor_value_info("x", dtype, list(x.shape)), - helper.make_tensor_value_info("y", dtype, list(y.shape)), - ], - outputs=[helper.make_tensor_value_info("out", dtype, list(outdata.shape))], - ) - model = helper.make_model(graph, producer_name="where_test") - verify_with_ort_with_inputs(model, [condition, x, y], [outdata.shape], use_vm=True) - +@tvm.testing.parametrize_targets +def test_where(target, dev): + def verify_where(condition, x, y, dtype, outdata, dynamic=False): + node_list = [] + where_inputs = ["condition", "x", "y"] + if dynamic: + shape_node = helper.make_node("Shape", ["x"], ["shape"]) + reshape_node = helper.make_node("Reshape", ["x", "shape"], ["X"]) + where_inputs[1] = "X" + node_list += [shape_node, reshape_node] + node = helper.make_node("Where", inputs=where_inputs, outputs=["out"]) + node_list.append(node) + graph = helper.make_graph( + node_list, + "where_test", + inputs=[ + helper.make_tensor_value_info("condition", TensorProto.BOOL, list(condition.shape)), + helper.make_tensor_value_info("x", dtype, list(x.shape)), + helper.make_tensor_value_info("y", dtype, list(y.shape)), + ], + outputs=[helper.make_tensor_value_info("out", dtype, list(outdata.shape))], + ) + model = helper.make_model(graph, producer_name="where_test") + verify_with_ort_with_inputs( + model, [condition, x, y], [outdata.shape], use_vm=True, target=target, dev=dev + ) -@tvm.testing.uses_gpu -def test_where(): condition = np.array([[1, 0], [1, 1]], dtype=bool) x = np.array([[1, 2], [3, 4]], dtype=np.int64) y = np.array([[9, 8], [7, 6]], dtype=np.int64) @@ -2288,33 +2405,32 @@ def test_where(): verify_where(condition, x, y, TensorProto.FLOAT, outdata, dynamic=True) -def verify_or(indata, dtype): - x = indata[0].astype(dtype) - y = indata[1].astype(dtype) - outdata = np.logical_or(x, y) - - node = helper.make_node( - "Or", - inputs=["in1", "in2"], - outputs=["out"], - ) +@tvm.testing.parametrize_targets +def test_or(target, dev): + def verify_or(indata, dtype): + x = indata[0].astype(dtype) + y = indata[1].astype(dtype) + outdata = np.logical_or(x, y) - graph = helper.make_graph( - [node], - "or_test", - inputs=[ - helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)), - helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape)), - ], - outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))], - ) + node = helper.make_node( + "Or", + inputs=["in1", "in2"], + outputs=["out"], + ) - model = helper.make_model(graph, producer_name="or_test") - verify_with_ort_with_inputs(model, [x, y], [outdata.shape]) + graph = helper.make_graph( + [node], + "or_test", + inputs=[ + helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)), + helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))], + ) + model = helper.make_model(graph, producer_name="or_test") + verify_with_ort_with_inputs(model, [x, y], [outdata.shape], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_or(): # 2d x = np.random.randn(3, 4) > 0 y = np.random.randn(3, 4) > 0 @@ -2341,8 +2457,8 @@ def test_or(): verify_or(indata=[x, y], dtype=bool) -@tvm.testing.uses_gpu -def test_batch_norm(): +@tvm.testing.parametrize_targets +def test_batch_norm(target, dev): def verify_batch_norm(in_shape): batchnorm = onnx.helper.make_node( "BatchNormalization", inputs=["x", "scale", "B", "mean", "var"], outputs=["Y"] @@ -2364,7 +2480,7 @@ def verify_batch_norm(in_shape): model = helper.make_model(graph, producer_name="batchnorm_test") # X, scale, b, mean, var inshapes = [in_shape, in_shape[1], in_shape[1], in_shape[1], in_shape[1]] - verify_with_ort(model, inshapes, out_shape=[in_shape]) + verify_with_ort(model, inshapes, out_shape=[in_shape], target=target, dev=dev) verify_batch_norm([1, 3, 224, 224]) verify_batch_norm([1, 3, 24, 24]) @@ -2373,8 +2489,8 @@ def verify_batch_norm(in_shape): verify_batch_norm([16, 16, 10, 10]) -@tvm.testing.uses_gpu -def test_batch_norm_dynamic_subgraph(): +@tvm.testing.parametrize_targets +def test_batch_norm_dynamic_subgraph(target, dev): def verify_batch_norm_dynamic_subgraph(in_shape, o_shape): batchnorm = onnx.helper.make_node( @@ -2401,81 +2517,88 @@ def verify_batch_norm_dynamic_subgraph(in_shape, o_shape): # X, inp, scale, b, mean, var inshapes = [in_shape, o_shape, in_shape[1], in_shape[1], in_shape[1], in_shape[1]] - verify_with_ort(model, inshapes, out_shape=[in_shape], use_vm=True) + verify_with_ort(model, inshapes, out_shape=[in_shape], use_vm=True, target=target, dev=dev) verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160]) -def verify_conv( - x_shape, - w_shape, - y_shape, - padding, - kernel_shape, - strides, - dilations, - group=1, - auto_pad="NOTSET", - unset_pad=False, -): - if unset_pad: - node = helper.make_node( - "Conv", - inputs=["x", "W"], - outputs=["y"], - kernel_shape=kernel_shape, - # Default values for other attributes: - strides=strides, - dilations=dilations, - group=group, - ) - elif padding is None: - ## autopadding with unset default attributes - kwargs = {} - if not all([s == 1 for s in strides]): - kwargs["strides"] = strides - if not all([d == 1 for d in dilations]): - kwargs["dilations"] = dilations +@tvm.testing.parametrize_targets +def test_conv(target, dev): + def verify_conv( + x_shape, + w_shape, + y_shape, + padding, + kernel_shape, + strides, + dilations, + group=1, + auto_pad="NOTSET", + unset_pad=False, + ): + if unset_pad: + node = helper.make_node( + "Conv", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + group=group, + ) + elif padding is None: + ## autopadding with unset default attributes + kwargs = {} + if not all([s == 1 for s in strides]): + kwargs["strides"] = strides + if not all([d == 1 for d in dilations]): + kwargs["dilations"] = dilations + + node = helper.make_node( + "Conv", + inputs=["x", "W"], + outputs=["y"], + # Default values for other attributes: + auto_pad=auto_pad, + group=group, + **kwargs, + ) + else: + node = helper.make_node( + "Conv", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + group=group, + pads=padding, + ) - node = helper.make_node( - "Conv", - inputs=["x", "W"], - outputs=["y"], - # Default values for other attributes: - auto_pad=auto_pad, - group=group, - **kwargs, - ) - else: - node = helper.make_node( - "Conv", - inputs=["x", "W"], - outputs=["y"], - kernel_shape=kernel_shape, - # Default values for other attributes: - strides=strides, - dilations=dilations, - group=group, - pads=padding, + graph = helper.make_graph( + [node], + "conv_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))], ) - graph = helper.make_graph( - [node], - "conv_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), - helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)), - ], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))], - ) - - model = helper.make_model(graph, producer_name="conv_test") - - verify_with_ort(model, [x_shape, w_shape], [y_shape], use_vm=True, convert_to_static=True) + model = helper.make_model(graph, producer_name="conv_test") + verify_with_ort( + model, + [x_shape, w_shape], + [y_shape], + use_vm=True, + convert_to_static=True, + target=target, + dev=dev, + ) -@tvm.testing.uses_gpu -def test_conv(): def repeat(N, D): return tuple([N for _ in range(D)]) @@ -2490,7 +2613,7 @@ def repeat(N, D): repeat(1, D), repeat(1, D), ) - # Convolution with assymetric padding + # Convolution with asymmetric padding verify_conv( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), @@ -2580,83 +2703,82 @@ def repeat(N, D): ) -def verify_convtranspose_with_padding( - x_shape, - w_shape, - y_shape, - padding, - kernel_shape, - strides, - dilations, - auto_pad="NOTSET", - unset_pad=False, - group=1, -): - node = helper.make_node( - "ConvTranspose", - inputs=["x", "W"], - outputs=["y"], - kernel_shape=kernel_shape, - # Default values for other attributes: - strides=strides, - dilations=dilations, - ) - if not unset_pad: - if padding is None: - pad_attr = helper.make_attribute("auto_pad", auto_pad) - else: - pad_attr = helper.make_attribute("pads", padding) - node.attribute.append(pad_attr) - - if group is not None: - group_attr = helper.make_attribute("group", group) - node.attribute.append(group_attr) - - graph = helper.make_graph( - [node], - "convtranspose_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), - helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)), - ], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))], - ) +@tvm.testing.parametrize_targets +def test_convtranspose(target, dev): + def verify_convtranspose_with_padding( + x_shape, + w_shape, + padding, + kernel_shape, + strides, + dilations, + auto_pad="NOTSET", + unset_pad=False, + group=1, + ): + node = helper.make_node( + "ConvTranspose", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + ) + if not unset_pad: + if padding is None: + pad_attr = helper.make_attribute("auto_pad", auto_pad) + else: + pad_attr = helper.make_attribute("pads", padding) + node.attribute.append(pad_attr) - model = helper.make_model(graph, producer_name="convtranspose_pad_test") + if group is not None: + group_attr = helper.make_attribute("group", group) + node.attribute.append(group_attr) - verify_with_ort(model, [x_shape, w_shape], [y_shape], use_vm=True, convert_to_static=True) + graph = helper.make_graph( + [node], + "convtranspose_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, ["?"] * len(x_shape))], + ) + model = helper.make_model(graph, producer_name="convtranspose_pad_test") -def verify_convtranspose(x_shape, w_shape, y_shape, p, group=1): - node = onnx.helper.make_node( - "ConvTranspose", - inputs=["x", "W"], - outputs=["y"], - strides=[3, 2], - kernel_shape=[3, 3], - pads=p, - ) + verify_with_ort( + model, [x_shape, w_shape], use_vm=True, convert_to_static=True, target=target, dev=dev + ) - if group is not None: - group_attr = helper.make_attribute("group", group) - node.attribute.append(group_attr) + def verify_convtranspose(x_shape, w_shape, y_shape, p, group=1): + node = onnx.helper.make_node( + "ConvTranspose", + inputs=["x", "W"], + outputs=["y"], + strides=[3, 2], + kernel_shape=[3, 3], + pads=p, + ) - graph = helper.make_graph( - [node], - "verify_convtranspose_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), - helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)), - ], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))], - ) + if group is not None: + group_attr = helper.make_attribute("group", group) + node.attribute.append(group_attr) - model = helper.make_model(graph, producer_name="convtranspose_test") - verify_with_ort(model, [x_shape, w_shape], y_shape) + graph = helper.make_graph( + [node], + "verify_convtranspose_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))], + ) + model = helper.make_model(graph, producer_name="convtranspose_test") + verify_with_ort(model, [x_shape, w_shape], y_shape, opset=11, target=target, dev=dev) -@tvm.testing.uses_gpu -def test_convtranspose(): # Convolution Transpose with padding # (1, 1, 3, 3) input tensor # (1, 2, 3, 3) tensor for convolution weights @@ -2669,14 +2791,12 @@ def test_convtranspose(): def repeat(N, D): return tuple([N for _ in range(D)]) - # TODO(mbrookhart): onnxruntime in CI only supports 2D, - # find something else to test 1D and 3D against - for D in [2]: + # Once onnxruntime update is complete + for D in [1, 2, 3]: # Convolution with padding verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(5, D), 2 * repeat(1, D), repeat(3, D), repeat(1, D), @@ -2686,50 +2806,45 @@ def repeat(N, D): verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(7, D), 2 * repeat(0, D), repeat(3, D), repeat(1, D), repeat(1, D), ) - # Convolution with autopadding + # Convolution with unset padding verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(5, D), - None, + 2 * repeat(0, D), repeat(3, D), repeat(1, D), repeat(1, D), - auto_pad="SAME_UPPER", + True, ) - # Convolution with valid autopadding + # Convolution with autopadding verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(7, D), None, repeat(3, D), repeat(1, D), repeat(1, D), - auto_pad="VALID", + auto_pad="SAME_UPPER", ) - # Convolution with unset padding + # Convolution with valid autopadding verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(7, D), - 2 * repeat(0, D), + None, repeat(3, D), repeat(1, D), repeat(1, D), - True, + auto_pad="VALID", ) # Convolution with non uniform stride verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(9, D), None, repeat(3, D), repeat(2, D), @@ -2741,7 +2856,6 @@ def repeat(N, D): # verify_convtranspose_with_padding( # (1, 1) + repeat(5, D), # (1, 1) + repeat(3, D), - # (1, 1) + repeat(5, D), # 2 * repeat(2, D), # repeat(3, D), # repeat(1, D), @@ -2749,8 +2863,8 @@ def repeat(N, D): # ) -@tvm.testing.uses_gpu -def test_unsqueeze_constant(): +@tvm.testing.parametrize_targets +def test_unsqueeze_constant(target, dev): from torch.nn import Linear, Module, Sequential class Flatten(Module): @@ -2770,43 +2884,50 @@ def forward(self, input): relay.frontend.from_onnx(onnx_model, {"0": input_size}) -def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_pad="NOTSET"): - x_np = np.random.uniform(size=x_shape).astype("float32") - - if mode == "max": - node_type = "MaxPool" - elif mode == "average": - node_type = "AveragePool" - else: - raise ValueError("Pool method {} is not supported.".format(mode)) +@tvm.testing.parametrize_targets +def test_pooling(target, dev): + def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_pad="NOTSET"): + x_np = np.random.uniform(size=x_shape).astype("float32") - pool_node = helper.make_node( - node_type, inputs=["x"], outputs=["y"], kernel_shape=kernel_shape, strides=strides - ) + if mode == "max": + node_type = "MaxPool" + elif mode == "average": + node_type = "AveragePool" + else: + raise ValueError("Pool method {} is not supported.".format(mode)) - if pads is None: - pad_attr = helper.make_attribute("auto_pad", auto_pad) - else: - pad_attr = helper.make_attribute("pads", pads) - pool_node.attribute.append(pad_attr) + pool_node = helper.make_node( + node_type, inputs=["x"], outputs=["y"], kernel_shape=kernel_shape, strides=strides + ) - if mode == "max": - storage_attr = helper.make_attribute("storage_order", 0) - pool_node.attribute.append(storage_attr) + if pads is None: + pad_attr = helper.make_attribute("auto_pad", auto_pad) + else: + pad_attr = helper.make_attribute("pads", pads) + pool_node.attribute.append(pad_attr) - graph = helper.make_graph( - [pool_node], - "pooling_test", - inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))], - ) + if mode == "max": + storage_attr = helper.make_attribute("storage_order", 0) + pool_node.attribute.append(storage_attr) - model = helper.make_model(graph, producer_name="pooling_test") - verify_with_ort(model, [x_shape], [out_shape], use_vm=False, convert_to_static=True) + graph = helper.make_graph( + [pool_node], + "pooling_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="pooling_test") + verify_with_ort( + model, + [x_shape], + [out_shape], + use_vm=False, + convert_to_static=True, + target=target, + dev=dev, + ) -@tvm.testing.uses_gpu -def test_pooling(): for mode in ["max", "average"]: # Pool1D verify_pooling( @@ -2889,31 +3010,38 @@ def test_pooling(): ) -def verify_global_pooling(x_shape, mode): - out_shape = x_shape[:2] + [1] * (len(x_shape) - 2) +@tvm.testing.parametrize_targets +def test_global_pooling(target, dev): + def verify_global_pooling(x_shape, mode): + out_shape = x_shape[:2] + [1] * (len(x_shape) - 2) - if mode == "max": - node_type = "GlobalMaxPool" - elif mode == "average": - node_type = "GlobalAveragePool" - else: - raise ValueError("Pool method {} is not supported.".format(mode)) - - pool_node = helper.make_node(node_type, inputs=["x"], outputs=["y"]) + if mode == "max": + node_type = "GlobalMaxPool" + elif mode == "average": + node_type = "GlobalAveragePool" + else: + raise ValueError("Pool method {} is not supported.".format(mode)) - graph = helper.make_graph( - [pool_node], - "global_pooling_test", - inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))], - ) + pool_node = helper.make_node(node_type, inputs=["x"], outputs=["y"]) - model = helper.make_model(graph, producer_name="global_pooling_test") - verify_with_ort(model, [x_shape], [out_shape], use_vm=False, convert_to_static=True) + graph = helper.make_graph( + [pool_node], + "global_pooling_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="global_pooling_test") + verify_with_ort( + model, + [x_shape], + [out_shape], + use_vm=False, + convert_to_static=True, + target=target, + dev=dev, + ) -@tvm.testing.uses_gpu -def test_global_pooling(): # Test each pooling mode across all N-D inputs. for mode in ["average", "max"]: # 1D Pooling (NCW) @@ -2927,29 +3055,28 @@ def test_global_pooling(): verify_global_pooling([4, 1, 2, 6, 4], mode) -def verify_mod(x_shape, y_shape, fmod, out_shape, dtype="float32"): - x_np = np.random.uniform(-100.0, 100.0, x_shape).astype(dtype) - y_np = np.random.uniform(-100.0, 100.0, y_shape).astype(dtype) - y_np = np.where(y_np == 0, 1, y_np) # remove 0's to avoid division by zero error - - mod_node = helper.make_node("Mod", inputs=["x", "y"], outputs=["z"], fmod=fmod) +@tvm.testing.parametrize_targets +def test_mod(target, dev): + def verify_mod(x_shape, y_shape, fmod, out_shape, dtype="float32"): + x_np = np.random.uniform(-100.0, 100.0, x_shape).astype(dtype) + y_np = np.random.uniform(-100.0, 100.0, y_shape).astype(dtype) + y_np = np.where(y_np == 0, 1, y_np) # remove 0's to avoid division by zero error - onnx_dtype = TensorProto.FLOAT if dtype == "float32" else TensorProto.INT32 - graph = helper.make_graph( - [mod_node], - "mod_test", - inputs=[ - helper.make_tensor_value_info("x", onnx_dtype, list(x_shape)), - helper.make_tensor_value_info("y", onnx_dtype, list(y_shape)), - ], - outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))], - ) - model = helper.make_model(graph, producer_name="mod_test") - verify_with_ort_with_inputs(model, [x_np, y_np], [out_shape]) + mod_node = helper.make_node("Mod", inputs=["x", "y"], outputs=["z"], fmod=fmod) + onnx_dtype = TensorProto.FLOAT if dtype == "float32" else TensorProto.INT32 + graph = helper.make_graph( + [mod_node], + "mod_test", + inputs=[ + helper.make_tensor_value_info("x", onnx_dtype, list(x_shape)), + helper.make_tensor_value_info("y", onnx_dtype, list(y_shape)), + ], + outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="mod_test") + verify_with_ort_with_inputs(model, [x_np, y_np], [out_shape], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_mod(): # Mod verify_mod( x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, out_shape=(1, 32, 32), dtype="int32" @@ -2978,31 +3105,30 @@ def test_mod(): verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32)) -def verify_xor(x_shape, y_shape): - x_np = np.random.choice(a=[False, True], size=x_shape).astype("bool") - y_np = np.random.choice(a=[False, True], size=y_shape).astype("bool") +@tvm.testing.parametrize_targets +def test_xor(target, dev): + def verify_xor(x_shape, y_shape): + x_np = np.random.choice(a=[False, True], size=x_shape).astype("bool") + y_np = np.random.choice(a=[False, True], size=y_shape).astype("bool") - np_out = np.logical_xor(x_np, y_np) - out_shape = np_out.shape + np_out = np.logical_xor(x_np, y_np) + out_shape = np_out.shape - xor_node = helper.make_node("Xor", inputs=["x", "y"], outputs=["z"]) - - onnx_dtype = TensorProto.BOOL - graph = helper.make_graph( - [xor_node], - "xor_test", - inputs=[ - helper.make_tensor_value_info("x", onnx_dtype, list(x_shape)), - helper.make_tensor_value_info("y", onnx_dtype, list(y_shape)), - ], - outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))], - ) - model = helper.make_model(graph, producer_name="xor_test") - verify_with_ort_with_inputs(model, [x_np, y_np], [out_shape]) + xor_node = helper.make_node("Xor", inputs=["x", "y"], outputs=["z"]) + onnx_dtype = TensorProto.BOOL + graph = helper.make_graph( + [xor_node], + "xor_test", + inputs=[ + helper.make_tensor_value_info("x", onnx_dtype, list(x_shape)), + helper.make_tensor_value_info("y", onnx_dtype, list(y_shape)), + ], + outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="xor_test") + verify_with_ort_with_inputs(model, [x_np, y_np], [out_shape], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_xor(): # XOR verify_xor(x_shape=[1, 32, 32], y_shape=[1, 32, 32]) @@ -3010,36 +3136,35 @@ def test_xor(): verify_xor(x_shape=[1, 32, 32], y_shape=[1, 1, 32]) -def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_shape): - if spatial_scale is None: - pool_node = helper.make_node( - "MaxRoiPool", inputs=["x", "rois"], outputs=["y"], pooled_shape=pooled_shape - ) - else: - pool_node = helper.make_node( - "MaxRoiPool", - inputs=["x", "rois"], - outputs=["y"], - pooled_shape=pooled_shape, - spatial_scale=spatial_scale, - ) - - graph = helper.make_graph( - [pool_node], - "pool_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), - helper.make_tensor_value_info("rois", TensorProto.FLOAT, list(rois_shape)), - ], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))], - ) +@tvm.testing.parametrize_targets +def test_max_roi_pool(target, dev): + def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_shape): + if spatial_scale is None: + pool_node = helper.make_node( + "MaxRoiPool", inputs=["x", "rois"], outputs=["y"], pooled_shape=pooled_shape + ) + else: + pool_node = helper.make_node( + "MaxRoiPool", + inputs=["x", "rois"], + outputs=["y"], + pooled_shape=pooled_shape, + spatial_scale=spatial_scale, + ) - model = helper.make_model(graph, producer_name="pool_test") - verify_with_ort(model, [x_shape, rois_shape], [out_shape]) + graph = helper.make_graph( + [pool_node], + "pool_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("rois", TensorProto.FLOAT, list(rois_shape)), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="pool_test") + verify_with_ort(model, [x_shape, rois_shape], [out_shape], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_max_roi_pool(): verify_max_roi_pool( x_shape=[1, 3, 6, 6], rois_shape=[3, 5], @@ -3057,41 +3182,48 @@ def test_max_roi_pool(): ) -def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="NOTSET"): - if pads is None: - pool_node = helper.make_node( - "LpPool", - inputs=["x"], - outputs=["y"], - kernel_shape=kernel_shape, - p=p, - auto_pad=auto_pad, - strides=strides, - ) - else: - pool_node = helper.make_node( - "LpPool", - inputs=["x"], - outputs=["y"], - kernel_shape=kernel_shape, - p=p, - pads=pads, - strides=strides, - ) - - graph = helper.make_graph( - [pool_node], - "lppool_test", - inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))], - ) +@tvm.testing.parametrize_targets +def test_lppool(target, dev): + def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="NOTSET"): + if pads is None: + pool_node = helper.make_node( + "LpPool", + inputs=["x"], + outputs=["y"], + kernel_shape=kernel_shape, + p=p, + auto_pad=auto_pad, + strides=strides, + ) + else: + pool_node = helper.make_node( + "LpPool", + inputs=["x"], + outputs=["y"], + kernel_shape=kernel_shape, + p=p, + pads=pads, + strides=strides, + ) - model = helper.make_model(graph, producer_name="lppool_test") - verify_with_ort(model, [x_shape], [out_shape], use_vm=True, convert_to_static=True) + graph = helper.make_graph( + [pool_node], + "lppool_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="lppool_test") + verify_with_ort( + model, + [x_shape], + [out_shape], + use_vm=True, + convert_to_static=True, + target=target, + dev=dev, + ) -@tvm.testing.uses_gpu -def test_lppool(): # Pool1D verify_lppool( x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[1], pads=[1, 1], out_shape=[1, 1, 32] @@ -3180,6 +3312,8 @@ def verify_rnn( use_peep=False, linear_before_reset=False, directions=1, + target=None, + dev=None, ): if rnn_type == "LSTM": multiplier = 4 @@ -3299,11 +3433,13 @@ def register(name, shape, proto_type): model = helper.make_model(graph, producer_name="rnn_test") - verify_with_ort_with_inputs(model, input_values, output_shapes, atol=1e-2, rtol=1e-2) + verify_with_ort_with_inputs( + model, input_values, output_shapes, atol=1e-2, rtol=1e-2, target=target, dev=dev + ) -@tvm.testing.uses_gpu -def test_lstm(): +@tvm.testing.parametrize_targets +def test_lstm(target, dev): for directions in [1, 2]: # No bias. verify_rnn( @@ -3314,6 +3450,8 @@ def test_lstm(): use_bias=False, rnn_type="LSTM", directions=directions, + target=target, + dev=dev, ) # large batch. verify_rnn( @@ -3324,6 +3462,8 @@ def test_lstm(): use_bias=True, rnn_type="LSTM", directions=directions, + target=target, + dev=dev, ) # Non power of two. verify_rnn( @@ -3334,6 +3474,8 @@ def test_lstm(): use_bias=True, rnn_type="LSTM", directions=directions, + target=target, + dev=dev, ) # Long sequence. verify_rnn( @@ -3344,6 +3486,8 @@ def test_lstm(): use_bias=True, rnn_type="LSTM", directions=directions, + target=target, + dev=dev, ) # Large hidden. verify_rnn( @@ -3354,6 +3498,8 @@ def test_lstm(): use_bias=True, rnn_type="LSTM", directions=directions, + target=target, + dev=dev, ) # Large input. verify_rnn( @@ -3364,6 +3510,8 @@ def test_lstm(): use_bias=True, rnn_type="LSTM", directions=directions, + target=target, + dev=dev, ) # Different activation testing. @@ -3377,8 +3525,10 @@ def test_lstm(): activations=["HardSigmoid", "Tanh", "Tanh"] * directions, rnn_type="LSTM", directions=directions, + target=target, + dev=dev, ) - # Multiple parameterized activations. + # Multiple parametrized activations. verify_rnn( seq_length=2, batch_size=1, @@ -3390,8 +3540,10 @@ def test_lstm(): betas=[0.3, 0.0, 0.0] * directions, rnn_type="LSTM", directions=directions, + target=target, + dev=dev, ) - # All parameterized with new Affine activation. + # All parametrized with new Affine activation. verify_rnn( seq_length=2, batch_size=1, @@ -3403,6 +3555,8 @@ def test_lstm(): betas=[0.3, 0.1, 0.0] * directions, rnn_type="LSTM", directions=directions, + target=target, + dev=dev, ) # Testing with initial state and peepholes @@ -3415,6 +3569,8 @@ def test_lstm(): use_initial_state=True, rnn_type="LSTM", directions=directions, + target=target, + dev=dev, ) verify_rnn( @@ -3427,11 +3583,13 @@ def test_lstm(): use_peep=True, rnn_type="LSTM", directions=directions, + target=target, + dev=dev, ) -@tvm.testing.uses_gpu -def test_gru(): +@tvm.testing.parametrize_targets +def test_gru(target, dev): for directions in [1, 2]: # No bias. verify_rnn( @@ -3442,6 +3600,8 @@ def test_gru(): use_bias=False, rnn_type="GRU", directions=directions, + target=target, + dev=dev, ) # large batch. verify_rnn( @@ -3453,6 +3613,8 @@ def test_gru(): rnn_type="GRU", linear_before_reset=True, directions=directions, + target=target, + dev=dev, ) # Non power of two. verify_rnn( @@ -3463,6 +3625,8 @@ def test_gru(): use_bias=True, rnn_type="GRU", directions=directions, + target=target, + dev=dev, ) # Long sequence. verify_rnn( @@ -3473,6 +3637,8 @@ def test_gru(): use_bias=True, rnn_type="GRU", directions=directions, + target=target, + dev=dev, ) # Large hidden. verify_rnn( @@ -3483,6 +3649,8 @@ def test_gru(): use_bias=True, rnn_type="GRU", directions=directions, + target=target, + dev=dev, ) # Large input. verify_rnn( @@ -3493,6 +3661,8 @@ def test_gru(): use_bias=True, rnn_type="GRU", directions=directions, + target=target, + dev=dev, ) # Different activation testing. @@ -3506,8 +3676,10 @@ def test_gru(): activations=["HardSigmoid", "Softsign"] * directions, rnn_type="GRU", directions=directions, + target=target, + dev=dev, ) - # Multiple parameterized activations. + # Multiple parametrized activations. verify_rnn( seq_length=2, batch_size=1, @@ -3519,8 +3691,10 @@ def test_gru(): betas=[0.3, 0.0] * directions, rnn_type="GRU", directions=directions, + target=target, + dev=dev, ) - # All parameterized with new Affine activation. + # All parametrized with new Affine activation. verify_rnn( seq_length=2, batch_size=1, @@ -3532,6 +3706,8 @@ def test_gru(): betas=[0.3, 0.1] * directions, rnn_type="GRU", directions=directions, + target=target, + dev=dev, ) # Testing with initial state @@ -3544,11 +3720,13 @@ def test_gru(): use_initial_state=True, rnn_type="GRU", directions=directions, + target=target, + dev=dev, ) -@tvm.testing.uses_gpu -def test_resize(): +@tvm.testing.parametrize_targets +def test_resize(target, dev): def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, exclude=False): nodes = [ make_constant_node("roi", onnx.TensorProto.FLOAT, (0,), []), @@ -3583,7 +3761,16 @@ def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, ex model = helper.make_model(graph, producer_name="resize_test") - verify_with_ort(model, [ishape], [oshape], use_vm=True, opset=11, freeze_params=True) + verify_with_ort( + model, + [ishape], + [oshape], + use_vm=True, + opset=11, + freeze_params=True, + target=target, + dev=dev, + ) for ndim in [1, 2, 3]: method = "nearest" @@ -3596,17 +3783,14 @@ def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, ex verify([1, 16] + [32] * ndim, [], [1, 1] + [0.5] * ndim, method, coord_trans) verify([1, 16] + [32] * ndim, [], [1, 1] + [2] * ndim, method, coord_trans) - if ndim == 2: - ## TODO(mbrookhart): ONNX Runtime in CI only supports 2D linear resize - ## Remove this condition when updating CI - method = "linear" - # upsampling - verify([1, 16] + [32] * ndim, [1, 16] + [64] * ndim, [], method) - # downsampling - verify([1, 16] + [32] * ndim, [1, 16] + [16] * ndim, [], method) - # scales are specified instead of sizes - verify([1, 16] + [32] * ndim, [], [1, 1] + [0.5] * ndim, method) - verify([1, 16] + [32] * ndim, [], [1, 1] + [2] * ndim, method) + method = "linear" + # upsampling + verify([1, 16] + [32] * ndim, [1, 16] + [64] * ndim, [], method) + # downsampling + verify([1, 16] + [32] * ndim, [1, 16] + [16] * ndim, [], method) + # scales are specified instead of sizes + verify([1, 16] + [32] * ndim, [], [1, 1] + [0.5] * ndim, method) + verify([1, 16] + [32] * ndim, [], [1, 1] + [2] * ndim, method) if ndim == 2: # ONNX Runtime only supports cubic interpolation for 2D images @@ -3672,14 +3856,23 @@ def verify_opset_10(ishape, scales, mode): ) model = helper.make_model(graph, producer_name="resize_test") - verify_with_ort(model, [ishape], [oshape], use_vm=True, freeze_params=True, opset=10) + verify_with_ort( + model, + [ishape], + [oshape], + use_vm=True, + freeze_params=True, + opset=10, + target=target, + dev=dev, + ) verify_opset_10([1, 16, 32, 32], [1, 1, 2, 2], "nearest") verify_opset_10([1, 16, 32, 32], [1, 1, 0.5, 0.5], "linear") -@tvm.testing.uses_gpu -def test_nonzero(): +@tvm.testing.parametrize_targets +def test_nonzero(target, dev): def verify_nonzero(indata, outdata, dtype): node = helper.make_node( "NonZero", @@ -3696,7 +3889,9 @@ def verify_nonzero(indata, outdata, dtype): model = helper.make_model(graph, producer_name="nonzero_test") - verify_with_ort_with_inputs(model, [indata], dtype="int64", use_vm=True, opset=9) + verify_with_ort_with_inputs( + model, [indata], dtype="int64", use_vm=True, opset=9, target=target, dev=dev + ) input_data = np.array([[1, 0], [1, 1]], dtype=np.int64) result = np.array((np.nonzero(input_data))) # expected output [[0, 1, 1], [0, 0, 1]] @@ -3707,8 +3902,8 @@ def verify_nonzero(indata, outdata, dtype): verify_nonzero(input_data, result, dtype=np.int64) -@tvm.testing.uses_gpu -def test_topk(): +@tvm.testing.parametrize_targets +def test_topk(target, dev): def verify_topk(input_dims, K, axis=-1): output_dims = list(input_dims) output_dims[axis] = K @@ -3739,7 +3934,9 @@ def verify_topk(input_dims, K, axis=-1): model = helper.make_model(graph, producer_name="topk_test") indata = np.random.uniform(-10, 10, input_dims).astype(np.float32) - verify_with_ort_with_inputs(model, [indata, np.array([K])], use_vm=True) + verify_with_ort_with_inputs( + model, [indata, np.array([K])], use_vm=True, target=target, dev=dev + ) for n in [12, 32]: for shape in [[n], [n, n], [n, n, n]]: @@ -3751,8 +3948,8 @@ def verify_topk(input_dims, K, axis=-1): verify_topk([n, n, n], 5, 2) -@tvm.testing.uses_gpu -def test_roi_align(): +@tvm.testing.parametrize_targets +def test_roi_align(target, dev): def verify_roi_align( input_dims, num_roi, @@ -3799,7 +3996,11 @@ def verify_roi_align( np_batch_indicies = np.random.randint(low=0, high=input_dims[0], size=num_roi) verify_with_ort_with_inputs( - model, [np_data, np_rois, np_batch_indicies], out_shape=[output_dims] + model, + [np_data, np_rois, np_batch_indicies], + out_shape=[output_dims], + target=target, + dev=dev, ) verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, spatial_scale=1.0) @@ -3816,8 +4017,8 @@ def verify_roi_align( # ONNX implementation of roi_align with max mode is incorrect, so we don't compare outputs here. -@tvm.testing.uses_gpu -def test_non_max_suppression(): +@tvm.testing.parametrize_targets +def test_non_max_suppression(target, dev): def verify_nms( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_dims ): @@ -3855,7 +4056,7 @@ def verify_nms( model = helper.make_model(graph, producer_name="nms_test") - verify_with_ort_with_inputs(model, inputs, use_vm=True) + verify_with_ort_with_inputs(model, inputs, use_vm=True, target=target, dev=dev) boxes = np.array( [ @@ -3909,270 +4110,314 @@ def verify_nms( ) -def verify_cond_loop(): - y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [1]) - y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [1]) - scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [1]) - cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) - cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) - iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, []) - - y = np.array([-2]).astype(np.float32) - - five_const_node = helper.make_node( - "Constant", - inputs=[], - outputs=["five"], - value=helper.make_tensor( - name="const_tensor_five", data_type=TensorProto.FLOAT, dims=(), vals=[5] - ), - ) - - iter_cast_node = helper.make_node( - "Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT - ) +# @tvm.testing.parametrize_targets +@pytest.mark.skip( + "Test regressed due to not being run in CI" + + " tracked here: https://github.com/apache/tvm/pull/8274" +) +def test_loop(target, dev): + def verify_cond_loop(): + y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [1]) + y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [1]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [1]) + cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) + cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) + iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, []) + + y = np.array([-2]).astype(np.float32) + + five_const_node = helper.make_node( + "Constant", + inputs=[], + outputs=["five"], + value=helper.make_tensor( + name="const_tensor_five", data_type=TensorProto.FLOAT, dims=(), vals=[5] + ), + ) - y_add_node = helper.make_node("Add", inputs=["y_in", "iter_cast"], outputs=["y_out"]) + iter_cast_node = helper.make_node( + "Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT + ) - less_node = helper.make_node("Less", inputs=["y_out", "five"], outputs=["cond_less"]) + y_add_node = helper.make_node("Add", inputs=["y_in", "iter_cast"], outputs=["y_out"]) - squeeze_node = helper.make_node("Squeeze", inputs=["cond_less"], outputs=["cond_squeeze"]) + less_node = helper.make_node("Less", inputs=["y_out", "five"], outputs=["cond_less"]) - cond_cast_node = helper.make_node( - "Cast", inputs=["cond_squeeze"], outputs=["cond_out"], to=onnx.TensorProto.BOOL - ) + squeeze_node = helper.make_node("Squeeze", inputs=["cond_less"], outputs=["cond_squeeze"]) - scan_identity_node = helper.make_node("Identity", inputs=["y_out"], outputs=["scan_out"]) + cond_cast_node = helper.make_node( + "Cast", inputs=["cond_squeeze"], outputs=["cond_out"], to=onnx.TensorProto.BOOL + ) - loop_body = helper.make_graph( - [ - five_const_node, - iter_cast_node, - y_add_node, - less_node, - squeeze_node, - cond_cast_node, - scan_identity_node, - ], - "loop_body", - [iter_count, cond_in, y_in], - [cond_out, y_out, scan_out], - ) + scan_identity_node = helper.make_node("Identity", inputs=["y_out"], outputs=["scan_out"]) - loop_node = helper.make_node( - "Loop", inputs=["trip_count", "cond", "y"], outputs=["res_y", "res_scan"], body=loop_body - ) + loop_body = helper.make_graph( + [ + five_const_node, + iter_cast_node, + y_add_node, + less_node, + squeeze_node, + cond_cast_node, + scan_identity_node, + ], + "loop_body", + [iter_count, cond_in, y_in], + [cond_out, y_out, scan_out], + ) - trip_count = np.array(5).astype(np.int64) - res_y = np.array([13]).astype(np.float32) - cond = np.array(1).astype(bool) - loop_graph = onnx.helper.make_graph( - [loop_node], - "loop_outer", - inputs=[ - onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []), - onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), - onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1]), - ], - outputs=[ - onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [1]), - onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 1]), - ], - ) - loop_model = onnx.helper.make_model(loop_graph) + loop_node = helper.make_node( + "Loop", + inputs=["trip_count", "cond", "y"], + outputs=["res_y", "res_scan"], + body=loop_body, + ) - # Set a high trip count so that condition trips first. - trip_count = np.array(40).astype(np.int64) - cond = np.array(1).astype(bool) - input_vals = [trip_count, cond, y] - verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True) + trip_count = np.array(5).astype(np.int64) + res_y = np.array([13]).astype(np.float32) + cond = np.array(1).astype(bool) + loop_graph = onnx.helper.make_graph( + [loop_node], + "loop_outer", + inputs=[ + onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []), + onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [1]), + onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 1]), + ], + ) + loop_model = onnx.helper.make_model(loop_graph) + # Set a high trip count so that condition trips first. + trip_count = np.array(40).astype(np.int64) + cond = np.array(1).astype(bool) + input_vals = [trip_count, cond, y] + verify_with_ort_with_inputs( + loop_model, + input_vals, + use_vm=True, + freeze_params=True, + opset=11, + target=target, + dev=dev, + ) -def verify_count_loop(): - y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, []) - y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, []) - scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, []) - cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) - cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) - iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, []) + def verify_count_loop(): + y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, []) + y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, []) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, []) + cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) + cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) + iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, []) - y = np.array(-2).astype(np.float32) + y = np.array(-2).astype(np.float32) - iter_cast_node = helper.make_node( - "Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT - ) + iter_cast_node = helper.make_node( + "Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT + ) - y_add_node = helper.make_node("Add", inputs=["y_in", "iter_cast"], outputs=["y_out"]) + y_add_node = helper.make_node("Add", inputs=["y_in", "iter_cast"], outputs=["y_out"]) - identity_node = helper.make_node("Identity", inputs=["cond_in"], outputs=["cond_out"]) + identity_node = helper.make_node("Identity", inputs=["cond_in"], outputs=["cond_out"]) - scan_identity_node = helper.make_node("Identity", inputs=["y_out"], outputs=["scan_out"]) + scan_identity_node = helper.make_node("Identity", inputs=["y_out"], outputs=["scan_out"]) - loop_body = helper.make_graph( - [identity_node, iter_cast_node, y_add_node, scan_identity_node], - "loop_body", - [iter_count, cond_in, y_in], - [cond_out, y_out, scan_out], - ) + loop_body = helper.make_graph( + [identity_node, iter_cast_node, y_add_node, scan_identity_node], + "loop_body", + [iter_count, cond_in, y_in], + [cond_out, y_out, scan_out], + ) - loop_node = helper.make_node( - "Loop", inputs=["trip_count", "cond", "y"], outputs=["res_y", "res_scan"], body=loop_body - ) + loop_node = helper.make_node( + "Loop", + inputs=["trip_count", "cond", "y"], + outputs=["res_y", "res_scan"], + body=loop_body, + ) - trip_count = np.array(5).astype(np.int64) - res_y = np.array([13]).astype(np.float32) - cond = np.array(1).astype(bool) - loop_graph = onnx.helper.make_graph( - [loop_node], - "loop_outer", - inputs=[ - onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []), - onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), - onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, []), - ], - outputs=[ - onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, []), - onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5]), - ], - ) - loop_model = onnx.helper.make_model(loop_graph) + trip_count = np.array(5).astype(np.int64) + res_y = np.array([13]).astype(np.float32) + cond = np.array(1).astype(bool) + loop_graph = onnx.helper.make_graph( + [loop_node], + "loop_outer", + inputs=[ + onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []), + onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, []), + ], + outputs=[ + onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, []), + onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5]), + ], + ) + loop_model = onnx.helper.make_model(loop_graph) - trip_count = np.array(5).astype(np.int64) - cond = np.array(1).astype(bool) - input_vals = [trip_count, cond, y] - verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True) + trip_count = np.array(5).astype(np.int64) + cond = np.array(1).astype(bool) + input_vals = [trip_count, cond, y] + verify_with_ort_with_inputs( + loop_model, + input_vals, + use_vm=True, + freeze_params=True, + opset=11, + target=target, + dev=dev, + ) + def verify_tensor_loop(shapeless_output=False): + y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [3, 3, 3, 3]) + y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [3, 3, 3, 3]) + scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [3, 3, 3, 3]) + cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) + cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) + iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, []) -def verify_tensor_loop(): - y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [3, 3, 3, 3]) - y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [3, 3, 3, 3]) - scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [3, 3, 3, 3]) - cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) - cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) - iter_count = helper.make_tensor_value_info("iter_count", TensorProto.INT64, []) + y = np.random.normal(size=[3, 3, 3, 3]).astype(np.float32) - y = np.random.normal(size=[3, 3, 3, 3]).astype(np.float32) + iter_cast_node = helper.make_node( + "Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT + ) - iter_cast_node = helper.make_node( - "Cast", inputs=["iter_count"], outputs=["iter_cast"], to=onnx.TensorProto.FLOAT - ) + y_add_node = helper.make_node("Add", inputs=["y_in", "iter_cast"], outputs=["y_out"]) - y_add_node = helper.make_node("Add", inputs=["y_in", "iter_cast"], outputs=["y_out"]) + identity_node = helper.make_node("Identity", inputs=["cond_in"], outputs=["cond_out"]) - identity_node = helper.make_node("Identity", inputs=["cond_in"], outputs=["cond_out"]) + scan_identity_node = helper.make_node("Identity", inputs=["y_out"], outputs=["scan_out"]) - scan_identity_node = helper.make_node("Identity", inputs=["y_out"], outputs=["scan_out"]) + loop_body = helper.make_graph( + [identity_node, iter_cast_node, y_add_node, scan_identity_node], + "loop_body", + [iter_count, cond_in, y_in], + [cond_out, y_out, scan_out], + ) - loop_body = helper.make_graph( - [identity_node, iter_cast_node, y_add_node, scan_identity_node], - "loop_body", - [iter_count, cond_in, y_in], - [cond_out, y_out, scan_out], - ) + loop_node = helper.make_node( + "Loop", + inputs=["trip_count", "cond", "y"], + outputs=["res_y", "res_scan"], + body=loop_body, + ) - loop_node = helper.make_node( - "Loop", inputs=["trip_count", "cond", "y"], outputs=["res_y", "res_scan"], body=loop_body - ) + trip_count = np.array(5).astype(np.int64) + cond = np.array(1).astype(bool) - trip_count = np.array(5).astype(np.int64) - cond = np.array(1).astype(bool) - loop_graph = onnx.helper.make_graph( - [loop_node], - "loop_outer", - inputs=[ - onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []), - onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), - onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]), - ], - outputs=[ - onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]), - onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 3, 3, 3, 3]), - ], - ) - loop_model = onnx.helper.make_model(loop_graph) + # Allow testing of malformed nodes since pytorch likes to create these. + if shapeless_output: + scan_shape = None + else: + scan_shape = [5, 3, 3, 3, 3] - trip_count = np.array(5).astype(np.int64) - cond = np.array(1).astype(bool) - input_vals = [trip_count, cond, y] - verify_with_ort_with_inputs( - loop_model, input_vals, use_vm=True, freeze_params=True, convert_to_static=True - ) + loop_graph = onnx.helper.make_graph( + [loop_node], + "loop_outer", + inputs=[ + onnx.helper.make_tensor_value_info("trip_count", onnx.TensorProto.INT64, []), + onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]), + ], + outputs=[ + onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]), + onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, scan_shape), + ], + ) + loop_model = onnx.helper.make_model(loop_graph) + trip_count = np.array(5).astype(np.int64) + cond = np.array(1).astype(bool) + input_vals = [trip_count, cond, y] + verify_with_ort_with_inputs( + loop_model, + input_vals, + use_vm=True, + freeze_params=True, + convert_to_static=True, + opset=11, + target=target, + dev=dev, + ) -def test_loop(): # Test a loop that exits once a condition is met. verify_cond_loop() # Test a loop that exits after a fixed number of iterations with scalar outputs. verify_count_loop() # Test a loop that uses an array output. verify_tensor_loop() + # Test a loop that is malformed and has no output shape defined. + verify_tensor_loop(shapeless_output=True) -def verify_if(cond_array, num_outputs): - # Given a bool scalar input cond. - # return constant tensor x if cond is True, otherwise return constant tensor y. +@tvm.testing.parametrize_targets +def test_if(target, dev): + def verify_if(cond_array, num_outputs): + # Given a bool scalar input cond. + # return constant tensor x if cond is True, otherwise return constant tensor y. - def append_constant_nodes(nodes, outputs, expected, name): - outputs.append(onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, [5])) + def append_constant_nodes(nodes, outputs, expected, name): + outputs.append(onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, [5])) - expected.append(np.random.randn(5).astype("float32")) + expected.append(np.random.randn(5).astype("float32")) - nodes.append( - onnx.helper.make_node( - "Constant", inputs=[], outputs=[name], value=numpy_helper.from_array(expected[-1]) + nodes.append( + onnx.helper.make_node( + "Constant", + inputs=[], + outputs=[name], + value=numpy_helper.from_array(expected[-1]), + ) ) - ) - if_outputs = [] - graph_outputs = [] + if_outputs = [] + graph_outputs = [] - then_nodes, then_outs, then_expected = [], [], [] - else_nodes, else_outs, else_expected = [], [], [] + then_nodes, then_outs, then_expected = [], [], [] + else_nodes, else_outs, else_expected = [], [], [] - for i in range(num_outputs): - append_constant_nodes(then_nodes, then_outs, then_expected, "then_out{}".format(i)) - append_constant_nodes(else_nodes, else_outs, else_expected, "else_out{}".format(i)) + for i in range(num_outputs): + append_constant_nodes(then_nodes, then_outs, then_expected, "then_out{}".format(i)) + append_constant_nodes(else_nodes, else_outs, else_expected, "else_out{}".format(i)) - if_outputs.append("res{}".format(i)) - graph_outputs.append( - onnx.helper.make_tensor_value_info("res{}".format(i), onnx.TensorProto.FLOAT, [5]), - ) + if_outputs.append("res{}".format(i)) + graph_outputs.append( + onnx.helper.make_tensor_value_info("res{}".format(i), onnx.TensorProto.FLOAT, [5]), + ) - then_body = onnx.helper.make_graph(then_nodes, "then_body", [], then_outs) - else_body = onnx.helper.make_graph(else_nodes, "else_body", [], else_outs) + then_body = onnx.helper.make_graph(then_nodes, "then_body", [], then_outs) + else_body = onnx.helper.make_graph(else_nodes, "else_body", [], else_outs) - if_node = onnx.helper.make_node( - "If", inputs=["cond"], outputs=if_outputs, then_branch=then_body, else_branch=else_body - ) + if_node = onnx.helper.make_node( + "If", inputs=["cond"], outputs=if_outputs, then_branch=then_body, else_branch=else_body + ) - if_graph = onnx.helper.make_graph( - [if_node], - "if_outer", - inputs=[ - onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), - ], - outputs=graph_outputs, - ) + if_graph = onnx.helper.make_graph( + [if_node], + "if_outer", + inputs=[ + onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + ], + outputs=graph_outputs, + ) - if_model = onnx.helper.make_model(if_graph) - if cond_array: - cond = np.array([1]).astype("bool") - else: - cond = np.array(1).astype("bool") - correct_out = then_expected if cond else else_expected + if_model = onnx.helper.make_model(if_graph) + if cond_array: + cond = np.array([1]).astype("bool") + else: + cond = np.array(1).astype("bool") + correct_out = then_expected if cond else else_expected - # TODO(jwfromm): Onnxruntime 1.0.0 is buggy with If statements. Replace this with - # verify_with_ort once we update versions. - for target, dev in tvm.testing.enabled_targets(): + # TODO(jwfromm): Onnxruntime 1.0.0 is buggy with If statements. Replace this with + # verify_with_ort once we update versions. tvm_out = get_tvm_output_with_vm(if_model, [cond], target, dev, freeze_params=True) if not isinstance(tvm_out, list): tvm_out = [tvm_out] for i in range(len(tvm_out)): tvm.testing.assert_allclose(correct_out[i], tvm_out[i], rtol=1e-05, atol=1e-05) - -@tvm.testing.uses_gpu -def test_if(): # Confirm that if works with cond as an array or scalar. verify_if(cond_array=False, num_outputs=1) verify_if(cond_array=False, num_outputs=2) @@ -4180,8 +4425,8 @@ def test_if(): verify_if(cond_array=True, num_outputs=2) -@tvm.testing.uses_gpu -def test_size(): +@tvm.testing.parametrize_targets +def test_size(target, dev): def verify_size(indata): node = helper.make_node( "Size", @@ -4198,7 +4443,9 @@ def verify_size(indata): model = helper.make_model(graph, producer_name="size_test") - verify_with_ort_with_inputs(model, [indata], dtype="int64", use_vm=True, opset=11) + verify_with_ort_with_inputs( + model, [indata], dtype="int64", use_vm=True, opset=11, target=target, dev=dev + ) input_data = np.array([[1, 0], [1, 1]], dtype=np.int64) verify_size(input_data) @@ -4207,8 +4454,8 @@ def verify_size(indata): verify_size(input_data) -@tvm.testing.uses_gpu -def test_maxunpool(): +@tvm.testing.parametrize_targets +def test_maxunpool(target, dev): def verify_maxunpool(data, indices, kernel_shape, strides, output_shape=None, pads=None): input_names = ["xT", "xI"] input_info = [ @@ -4257,7 +4504,9 @@ def verify_maxunpool(data, indices, kernel_shape, strides, output_shape=None, pa model = helper.make_model(graph, producer_name="size_test") - verify_with_ort_with_inputs(model, input_values, use_vm=True, opset=11) + verify_with_ort_with_inputs( + model, input_values, use_vm=True, opset=11, target=target, dev=dev + ) # Basic test xT = np.array([[[[5, 6], [7, 8]]]], dtype=np.float32) @@ -4275,8 +4524,8 @@ def verify_maxunpool(data, indices, kernel_shape, strides, output_shape=None, pa verify_maxunpool(xT, xI, [2, 2], strides=[2, 2], pads=pads) -@tvm.testing.uses_gpu -def test_softplus(): +@tvm.testing.parametrize_targets +def test_softplus(target, dev): def verify_softplus(indata): node = helper.make_node( "Softplus", @@ -4293,7 +4542,9 @@ def verify_softplus(indata): model = helper.make_model(graph, producer_name="softplus_test") - verify_with_ort_with_inputs(model, [indata], dtype="float32", use_vm=True, opset=11) + verify_with_ort_with_inputs( + model, [indata], dtype="float32", use_vm=True, opset=11, target=target, dev=dev + ) # Simple case with all signs. input_data = np.array([[-1, 0, 1]], dtype=np.float32) @@ -4303,8 +4554,8 @@ def verify_softplus(indata): verify_softplus(input_data) -@tvm.testing.uses_gpu -def test_cumsum(): +@tvm.testing.parametrize_targets +def test_cumsum(target, dev): def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): cumsum_node = onnx.helper.make_node( "CumSum", @@ -4338,7 +4589,9 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): model = helper.make_model(graph, producer_name="cumsum_test") - verify_with_ort_with_inputs(model, [indata], dtype=type, use_vm=True, opset=11) + verify_with_ort_with_inputs( + model, [indata], dtype=type, use_vm=True, opset=11, target=target, dev=dev + ) data = ( np.array( @@ -4380,8 +4633,8 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): verify_cumsum(data, 1, 1, 1, type="int32") -@tvm.testing.uses_gpu -def test_eyelike(): +@tvm.testing.parametrize_targets +def test_eyelike(target, dev): def verify_eyelike(indata): node = helper.make_node( "EyeLike", @@ -4398,121 +4651,333 @@ def verify_eyelike(indata): model = helper.make_model(graph, producer_name="eyelike_test") - verify_with_ort_with_inputs(model, [indata], dtype="float32", opset=9) + verify_with_ort_with_inputs( + model, [indata], dtype="float32", opset=9, target=target, dev=dev + ) input_data = np.zeros((5, 5), dtype=np.float32) verify_eyelike(input_data) """ - The following parameterized tests loads the tests that ONNX ships as + The following parametrized tests loads the tests that ONNX ships as serialized ONNX files, inputs, and outputs. The goal of this test is to ensure the ONNX importer is in line with the ONNX specification. To allow these tests to run in CI before all pass, a number of tests that are not yet supported are skipped. """ -from onnx import numpy_helper - -f = onnx.__file__ -import glob +onnx_test_node_dir = os.path.join(os.path.dirname(onnx.__file__), "backend", "test", "data", "node") -onnx_test_folders = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) +onnx_test_folders = sorted( + dirname + for dirname in os.listdir(onnx_test_node_dir) + if dirname.startswith("test") and os.path.isdir(os.path.join(onnx_test_node_dir, dirname)) +) unsupported_onnx_tests = [ - "test_basic_convinteger/", - "test_cast_DOUBLE_to_FLOAT16/", - "test_cast_FLOAT_to_STRING/", - "test_cast_STRING_to_FLOAT/", - "test_compress_0/", - "test_compress_1/", - "test_compress_default_axis/", - "test_compress_negative_axis/", - "test_convinteger_with_padding/", - "test_convtranspose_dilations/", - "test_convtranspose_output_shape/", - "test_cumsum_1d/", - "test_cumsum_1d_exclusive/", - "test_cumsum_1d_reverse/", - "test_cumsum_1d_reverse_exclusive/", - "test_cumsum_2d_axis_0/", - "test_cumsum_2d_axis_1/", - "test_cumsum_2d_negative_axis/", - "test_det_2d/", - "test_det_nd/", - "test_matmulinteger/", - "test_maxpool_2d_same_lower/", - "test_maxpool_2d_same_upper/", - "test_maxpool_with_argmax_2d_precomputed_pads/", - "test_maxpool_with_argmax_2d_precomputed_strides/", - "test_maxunpool_export_with_output_shape/", - "test_mvn/", - "test_qlinearmatmul_2D/", - "test_qlinearmatmul_3D/", - "test_resize_tf_crop_and_resize/", - ## For these three tests, ONNX 1.6.0 has incorrect graphs, they pass with ONNX 1.7.0 - "test_resize_upsample_sizes_nearest_ceil_half_pixel/", - "test_resize_upsample_sizes_nearest_floor_align_corners/", - "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/", - "test_rnn_seq_length/", - "test_round/", - "test_scan9_sum/", - "test_scan_sum/", - "test_simple_rnn_defaults/", - "test_simple_rnn_with_initial_bias/", - "test_strnormalizer_export_monday_casesensintive_lower/", - "test_strnormalizer_export_monday_casesensintive_nochangecase/", - "test_strnormalizer_export_monday_casesensintive_upper/", - "test_strnormalizer_export_monday_empty_output/", - "test_strnormalizer_export_monday_insensintive_upper_twodim/", - "test_strnormalizer_nostopwords_nochangecase/", - "test_tfidfvectorizer_tf_batch_onlybigrams_skip0/", - "test_tfidfvectorizer_tf_batch_onlybigrams_skip5/", - "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5/", - "test_tfidfvectorizer_tf_only_bigrams_skip0/", - "test_tfidfvectorizer_tf_onlybigrams_levelempty/", - "test_tfidfvectorizer_tf_onlybigrams_skip5/", - "test_tfidfvectorizer_tf_uniandbigrams_skip5/", - "test_unique_sorted_with_axis/", - "test_unique_sorted_with_axis_3d/", - "test_unique_sorted_with_negative_axis/", - "test_upsample_nearest/", + "test_adagrad", + "test_adagrad_multiple", + "test_adam", + "test_adam_multiple", + "test_argmax_default_axis_example_select_last_index", + "test_argmax_default_axis_random_select_last_index", + "test_argmax_keepdims_example_select_last_index", + "test_argmax_keepdims_random_select_last_index", + "test_argmax_negative_axis_keepdims_example_select_last_index", + "test_argmax_negative_axis_keepdims_random_select_last_index", + "test_argmax_no_keepdims_example_select_last_index", + "test_argmax_no_keepdims_random_select_last_index", + "test_argmin_default_axis_example_select_last_index", + "test_argmin_default_axis_random_select_last_index", + "test_argmin_keepdims_example_select_last_index", + "test_argmin_keepdims_random_select_last_index", + "test_argmin_negative_axis_keepdims_example_select_last_index", + "test_argmin_negative_axis_keepdims_random_select_last_index", + "test_argmin_no_keepdims_example_select_last_index", + "test_argmin_no_keepdims_random_select_last_index", + "test_cast_BFLOAT16_to_FLOAT", + "test_cast_DOUBLE_to_FLOAT16", + "test_cast_FLOAT_to_BFLOAT16", + "test_cast_FLOAT_to_STRING", + "test_cast_STRING_to_FLOAT", + "test_compress_0", + "test_compress_1", + "test_compress_default_axis", + "test_compress_negative_axis", + "test_convtranspose_dilations", + "test_convtranspose_output_shape", + "test_cumsum_1d", + "test_cumsum_1d_exclusive", + "test_cumsum_1d_reverse", + "test_cumsum_1d_reverse_exclusive", + "test_cumsum_2d_axis_0", + "test_cumsum_2d_axis_1", + "test_cumsum_2d_negative_axis", + "test_det_2d", + "test_det_nd", + "test_dropout_default", + "test_dropout_default_mask", + "test_dropout_default_mask_ratio", + "test_dropout_default_ratio", + "test_einsum_batch_diagonal", + "test_einsum_batch_matmul", + "test_einsum_inner_prod", + "test_einsum_sum", + "test_einsum_transpose", + "test_greater_equal", + "test_greater_equal_bcast", + "test_hardmax_axis_0", + "test_hardmax_axis_1", + "test_hardmax_default_axis", + "test_if_seq", + "test_less_equal", + "test_less_equal_bcast", + "test_logsoftmax_axis_0_expanded", + "test_logsoftmax_axis_1_expanded", + "test_logsoftmax_axis_2_expanded", + "test_logsoftmax_default_axis_expanded", + "test_logsoftmax_example_1_expanded", + "test_logsoftmax_large_number_expanded", + "test_logsoftmax_negative_axis_expanded", + "test_loop11", + "test_loop13_seq", + "test_matmulinteger", + "test_maxpool_2d_same_lower", + "test_maxpool_2d_same_upper", + "test_maxpool_with_argmax_2d_precomputed_pads", + "test_maxpool_with_argmax_2d_precomputed_strides", + "test_maxunpool_export_with_output_shape", + "test_momentum", + "test_momentum_multiple", + "test_mvn", + "test_nesterov_momentum", + "test_nllloss_NC", + "test_nllloss_NC_expanded", + "test_nllloss_NCd1", + "test_nllloss_NCd1_expanded", + "test_nllloss_NCd1_ii", + "test_nllloss_NCd1_ii_expanded", + "test_nllloss_NCd1_mean_weight_negative_ii", + "test_nllloss_NCd1_mean_weight_negative_ii_expanded", + "test_nllloss_NCd1_weight", + "test_nllloss_NCd1_weight_expanded", + "test_nllloss_NCd1_weight_ii", + "test_nllloss_NCd1_weight_ii_expanded", + "test_nllloss_NCd1d2", + "test_nllloss_NCd1d2_expanded", + "test_nllloss_NCd1d2_no_weight_reduction_mean_ii", + "test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded", + "test_nllloss_NCd1d2_reduction_mean", + "test_nllloss_NCd1d2_reduction_mean_expanded", + "test_nllloss_NCd1d2_reduction_sum", + "test_nllloss_NCd1d2_reduction_sum_expanded", + "test_nllloss_NCd1d2_with_weight", + "test_nllloss_NCd1d2_with_weight_expanded", + "test_nllloss_NCd1d2_with_weight_reduction_mean", + "test_nllloss_NCd1d2_with_weight_reduction_mean_expanded", + "test_nllloss_NCd1d2_with_weight_reduction_sum", + "test_nllloss_NCd1d2_with_weight_reduction_sum_expanded", + "test_nllloss_NCd1d2_with_weight_reduction_sum_ii", + "test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded", + "test_nllloss_NCd1d2d3_none_no_weight_negative_ii", + "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded", + "test_nllloss_NCd1d2d3_sum_weight_high_ii", + "test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded", + "test_nllloss_NCd1d2d3d4d5_mean_weight", + "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded", + "test_nllloss_NCd1d2d3d4d5_none_no_weight", + "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded", + "test_pow_types_float", + "test_pow_types_float32_int32", + "test_pow_types_float32_int64", + "test_pow_types_float32_uint32", + "test_pow_types_float32_uint64", + "test_pow_types_int", + "test_pow_types_int32_float32", + "test_pow_types_int32_int32", + "test_pow_types_int64_float32", + "test_pow_types_int64_int64", + "test_qlinearmatmul_2D", + "test_qlinearmatmul_3D", + "test_range_float_type_positive_delta_expanded", + "test_range_int32_type_negative_delta_expanded", + "test_reduce_sum_default_axes_keepdims_example", + "test_reduce_sum_default_axes_keepdims_random", + "test_reduce_sum_do_not_keepdims_example", + "test_reduce_sum_do_not_keepdims_random", + "test_reduce_sum_empty_axes_input_noop_example", + "test_reduce_sum_empty_axes_input_noop_random", + "test_reduce_sum_keepdims_example", + "test_reduce_sum_keepdims_random", + "test_reduce_sum_negative_axes_keepdims_example", + "test_reduce_sum_negative_axes_keepdims_random", + "test_resize_downsample_sizes_cubic", + "test_resize_downsample_sizes_linear_pytorch_half_pixel", + "test_resize_downsample_sizes_nearest", + "test_resize_tf_crop_and_resize", + "test_resize_upsample_sizes_cubic", + "test_resize_upsample_sizes_nearest", + "test_resize_upsample_sizes_nearest_ceil_half_pixel", + "test_resize_upsample_sizes_nearest_floor_align_corners", + "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", + "test_rnn_seq_length", + "test_round", + "test_scan9_sum", + "test_scan_sum", + "test_sce_NCd1_mean_weight_negative_ii", + "test_sce_NCd1_mean_weight_negative_ii_expanded", + "test_sce_NCd1_mean_weight_negative_ii_log_prob", + "test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded", + "test_sce_NCd1d2d3_none_no_weight_negative_ii", + "test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded", + "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob", + "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded", + "test_sce_NCd1d2d3_sum_weight_high_ii", + "test_sce_NCd1d2d3_sum_weight_high_ii_expanded", + "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob", + "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded", + "test_sce_NCd1d2d3d4d5_mean_weight", + "test_sce_NCd1d2d3d4d5_mean_weight_expanded", + "test_sce_NCd1d2d3d4d5_mean_weight_log_prob", + "test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded", + "test_sce_NCd1d2d3d4d5_none_no_weight", + "test_sce_NCd1d2d3d4d5_none_no_weight_expanded", + "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob", + "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded", + "test_sce_mean", + "test_sce_mean_3d", + "test_sce_mean_3d_expanded", + "test_sce_mean_3d_log_prob", + "test_sce_mean_3d_log_prob_expanded", + "test_sce_mean_expanded", + "test_sce_mean_log_prob", + "test_sce_mean_log_prob_expanded", + "test_sce_mean_no_weight_ii", + "test_sce_mean_no_weight_ii_3d", + "test_sce_mean_no_weight_ii_3d_expanded", + "test_sce_mean_no_weight_ii_3d_log_prob", + "test_sce_mean_no_weight_ii_3d_log_prob_expanded", + "test_sce_mean_no_weight_ii_4d", + "test_sce_mean_no_weight_ii_4d_expanded", + "test_sce_mean_no_weight_ii_4d_log_prob", + "test_sce_mean_no_weight_ii_4d_log_prob_expanded", + "test_sce_mean_no_weight_ii_expanded", + "test_sce_mean_no_weight_ii_log_prob", + "test_sce_mean_no_weight_ii_log_prob_expanded", + "test_sce_mean_weight", + "test_sce_mean_weight_expanded", + "test_sce_mean_weight_ii", + "test_sce_mean_weight_ii_3d", + "test_sce_mean_weight_ii_3d_expanded", + "test_sce_mean_weight_ii_3d_log_prob", + "test_sce_mean_weight_ii_3d_log_prob_expanded", + "test_sce_mean_weight_ii_4d", + "test_sce_mean_weight_ii_4d_expanded", + "test_sce_mean_weight_ii_4d_log_prob", + "test_sce_mean_weight_ii_4d_log_prob_expanded", + "test_sce_mean_weight_ii_expanded", + "test_sce_mean_weight_ii_log_prob", + "test_sce_mean_weight_ii_log_prob_expanded", + "test_sce_mean_weight_log_prob", + "test_sce_mean_weight_log_prob_expanded", + "test_sce_none", + "test_sce_none_expanded", + "test_sce_none_log_prob", + "test_sce_none_log_prob_expanded", + "test_sce_none_weights", + "test_sce_none_weights_expanded", + "test_sce_none_weights_log_prob", + "test_sce_none_weights_log_prob_expanded", + "test_sce_sum", + "test_sce_sum_expanded", + "test_sce_sum_log_prob", + "test_sce_sum_log_prob_expanded", + "test_sequence_insert_at_back", + "test_sequence_insert_at_front", + "test_simple_rnn_defaults", + "test_simple_rnn_with_initial_bias", + "test_softmax_axis_0_expanded", + "test_softmax_axis_1_expanded", + "test_softmax_axis_2_expanded", + "test_softmax_default_axis_expanded", + "test_softmax_example_expanded", + "test_softmax_large_number_expanded", + "test_softmax_negative_axis_expanded", + "test_split_variable_parts_1d", + "test_split_variable_parts_2d", + "test_split_variable_parts_default_axis", + "test_split_zero_size_splits", + "test_squeeze", + "test_squeeze_negative_axes", + "test_strnormalizer_export_monday_casesensintive_lower", + "test_strnormalizer_export_monday_casesensintive_nochangecase", + "test_strnormalizer_export_monday_casesensintive_upper", + "test_strnormalizer_export_monday_empty_output", + "test_strnormalizer_export_monday_insensintive_upper_twodim", + "test_strnormalizer_nostopwords_nochangecase", + "test_tfidfvectorizer_tf_batch_onlybigrams_skip0", + "test_tfidfvectorizer_tf_batch_onlybigrams_skip5", + "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5", + "test_tfidfvectorizer_tf_only_bigrams_skip0", + "test_tfidfvectorizer_tf_onlybigrams_levelempty", + "test_tfidfvectorizer_tf_onlybigrams_skip5", + "test_tfidfvectorizer_tf_uniandbigrams_skip5", + "test_training_dropout", + "test_training_dropout_default", + "test_training_dropout_default_mask", + "test_training_dropout_mask", + "test_training_dropout_zero_ratio", + "test_training_dropout_zero_ratio_mask", + "test_unique_sorted_with_axis", + "test_unique_sorted_with_axis_3d", + "test_unique_sorted_with_negative_axis", + "test_unsqueeze_axis_0", + "test_unsqueeze_axis_1", + "test_unsqueeze_axis_2", + "test_unsqueeze_negative_axes", + "test_unsqueeze_three_axes", + "test_unsqueeze_two_axes", + "test_unsqueeze_unsorted_axes", + "test_upsample_nearest", ] -targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()] - target_skips = { "cuda": [ - "test_mod_mixed_sign_float16/", - "test_qlinearconv/", - "test_resize_upsample_sizes_nearest/", + "test_range_float_type_positive_delta_expanded", + "test_range_int32_type_positive_delta_expanded", + "test_mod_mixed_sign_float16", + "test_qlinearconv", + "test_resize_upsample_sizes_nearest", ] } -@pytest.mark.parametrize("target", targets) -@pytest.mark.parametrize("test", onnx_test_folders) -def test_onnx_nodes(test, target): - if target in target_skips: - for failure in target_skips[target]: - if failure in test: - pytest.skip() - break - for failure in unsupported_onnx_tests: - if failure in test: - pytest.skip() - break +@pytest.mark.parametrize("onnx_test", onnx_test_folders) +@tvm.testing.parametrize_targets +def test_onnx_nodes(target, dev, onnx_test): + target_kind = tvm.target.Target(target).kind.name + + if onnx_test in unsupported_onnx_tests: + pytest.skip(f"Onnx test '{onnx_test}' not yet supported by TVM") + + target_specific_skips = target_skips.get(target_kind, []) + if onnx_test in target_specific_skips: + pytest.skip(f"Onnx test '{onnx_test}' not yet supported by TVM on {target_kind} targets") + + test_dir = os.path.join(onnx_test_node_dir, onnx_test) + atol = 1e-5 rtol = 1e-5 - if "roialign" in test: + if "roialign" in test_dir: # for some reason the ONNX test crops the # roialign results to 4 decimal places atol = 1e-4 - onnx_model = onnx.load(test + "/model.onnx") + onnx_model = onnx.load(test_dir + "/model.onnx") inputs = [] outputs = [] - for dataset in glob.glob(test + "/*/"): + for dataset in glob.glob(test_dir + "/*/"): tensors = sorted(glob.glob(dataset + "/*.pb")) for tensor in tensors: new_tensor = onnx.TensorProto() @@ -4525,7 +4990,6 @@ def test_onnx_nodes(test, target): else: raise ImportError(str(tensor) + " not labeled as an import or an output") - dev = tvm.device(target, 0) tvm_val = get_tvm_output_with_vm(onnx_model, inputs, target, dev) if len(outputs) == 1: tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=rtol, atol=atol) @@ -4559,7 +5023,8 @@ def test_wrong_input(): relay.frontend.from_onnx(model, shape=wrong_shape_dict) -def test_aten(): +@tvm.testing.parametrize_targets +def test_aten(target, dev): torch.set_grad_enabled(False) def _convert_to_onnx(model, inputs): @@ -4583,43 +5048,46 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None model = torch.nn.EmbeddingBag(num_embedding, embedding_dim) onnx_model = _convert_to_onnx(model, dummy_data) torch_out = model(dummy_data) - for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output_with_vm( - onnx_model, tvm_inputs, target, ctx, freeze_params=True, convert_to_static=True - ) - tvm.testing.assert_allclose(torch_out.numpy(), tvm_out) + tvm_out = get_tvm_output_with_vm( + onnx_model, + tvm_inputs, + freeze_params=True, + convert_to_static=True, + target=target, + dev=dev, + ) + tvm.testing.assert_allclose(torch_out.numpy(), tvm_out, atol=5e-7) verify_embedding_bag(10, 3, [2, 10]) verify_embedding_bag(32, 2, [3, 3]) -def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis): - node = onnx.helper.make_node( - "ReverseSequence", - inputs=["x", "sequence_lens"], - outputs=["y"], - time_axis=time_axis, - batch_axis=batch_axis, - ) - - graph = helper.make_graph( - [node], - "reverse_sequence_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape)), - helper.make_tensor_value_info( - "sequence_lens", TensorProto.INT64, list(sequence_lens.shape) - ), - ], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(x.shape))], - ) +@tvm.testing.parametrize_targets +def test_reverse_sequence(target, dev): + def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis): + node = onnx.helper.make_node( + "ReverseSequence", + inputs=["x", "sequence_lens"], + outputs=["y"], + time_axis=time_axis, + batch_axis=batch_axis, + ) - model = helper.make_model(graph, producer_name="reverse_sequence_test") - verify_with_ort_with_inputs(model, [x, sequence_lens], [x.shape]) + graph = helper.make_graph( + [node], + "reverse_sequence_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape)), + helper.make_tensor_value_info( + "sequence_lens", TensorProto.INT64, list(sequence_lens.shape) + ), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(x.shape))], + ) + model = helper.make_model(graph, producer_name="reverse_sequence_test") + verify_with_ort_with_inputs(model, [x, sequence_lens], [x.shape], target=target, dev=dev) -@tvm.testing.uses_gpu -def test_reverse_sequence(): x = np.array( [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]], dtype=np.float32, @@ -4631,95 +5099,96 @@ def test_reverse_sequence(): verify_reverse_sequence(x, sequence_lens, 1, 0) -def verify_qlinearconv( - x_shape, - w_shape, - y_shape, - padding, - kernel_shape, - strides, - dilations, - auto_pad="NOTSET", - bias=False, -): +@tvm.testing.known_failing_targets("cuda") +@tvm.testing.parametrize_targets +def test_qlinearconv(target, dev): + def verify_qlinearconv( + x_shape, + w_shape, + y_shape, + padding, + kernel_shape, + strides, + dilations, + auto_pad="NOTSET", + bias=False, + ): - x_array = np.random.randint(low=0, high=255, size=x_shape).astype("uint8") - w_array = np.random.uniform(low=0, high=255, size=w_shape).astype("uint8") + x_array = np.random.randint(low=0, high=255, size=x_shape).astype("uint8") + w_array = np.random.uniform(low=0, high=255, size=w_shape).astype("uint8") - initializer = [ - helper.make_tensor("x_scale", TensorProto.FLOAT, (), [np.random.rand()]), - helper.make_tensor("x_zero_point", TensorProto.UINT8, (), [np.random.randint(0, 255)]), - helper.make_tensor("w_scale", TensorProto.FLOAT, (), [np.random.rand()]), - helper.make_tensor("w_zero_point", TensorProto.UINT8, (), [np.random.randint(0, 255)]), - helper.make_tensor("y_scale", TensorProto.FLOAT, (), [np.random.rand()]), - helper.make_tensor("y_zero_point", TensorProto.UINT8, (), [np.random.randint(0, 255)]), - ] + initializer = [ + helper.make_tensor("x_scale", TensorProto.FLOAT, (), [np.random.rand()]), + helper.make_tensor("x_zero_point", TensorProto.UINT8, (), [np.random.randint(0, 255)]), + helper.make_tensor("w_scale", TensorProto.FLOAT, (), [np.random.rand()]), + helper.make_tensor("w_zero_point", TensorProto.UINT8, (), [np.random.randint(0, 255)]), + helper.make_tensor("y_scale", TensorProto.FLOAT, (), [np.random.rand()]), + helper.make_tensor("y_zero_point", TensorProto.UINT8, (), [np.random.randint(0, 255)]), + ] - input_nodes = [ - helper.make_tensor_value_info("x", TensorProto.UINT8, list(x_shape)), - helper.make_tensor_value_info("w", TensorProto.UINT8, list(w_shape)), - ] - input_names = [ - "x", - "x_scale", - "x_zero_point", - "w", - "w_scale", - "w_zero_point", - "y_scale", - "y_zero_point", - ] - input_values = [x_array, w_array] - - if bias is True: - b_shape = w_shape[0:1] - b_array = np.random.randint(low=0, high=65536, size=b_shape).astype("int32") - input_nodes.append(helper.make_tensor_value_info("B", TensorProto.INT32, list(b_shape))) - input_names.append("B") - input_values.append(b_array) - - if padding is None: - ## autopadding with unset default attributes - kwargs = {} - if not all([s == 1 for s in strides]): - kwargs["strides"] = strides - if not all([d == 1 for d in dilations]): - kwargs["dilations"] = dilations + input_nodes = [ + helper.make_tensor_value_info("x", TensorProto.UINT8, list(x_shape)), + helper.make_tensor_value_info("w", TensorProto.UINT8, list(w_shape)), + ] + input_names = [ + "x", + "x_scale", + "x_zero_point", + "w", + "w_scale", + "w_zero_point", + "y_scale", + "y_zero_point", + ] + input_values = [x_array, w_array] - node = helper.make_node( - "QLinearConv", - inputs=input_names, - outputs=["y"], - # Default values for other attributes: - auto_pad=auto_pad, - **kwargs, - ) - else: - node = helper.make_node( - "QLinearConv", - inputs=input_names, - outputs=["y"], - kernel_shape=kernel_shape, - # Default values for other attributes: - strides=strides, - dilations=dilations, - # groups=1 - pads=padding, - ) + if bias is True: + b_shape = w_shape[0:1] + b_array = np.random.randint(low=0, high=65536, size=b_shape).astype("int32") + input_nodes.append(helper.make_tensor_value_info("B", TensorProto.INT32, list(b_shape))) + input_names.append("B") + input_values.append(b_array) - graph = helper.make_graph( - [node], - "conv_test", - inputs=input_nodes, - outputs=[helper.make_tensor_value_info("y", TensorProto.UINT8, list(y_shape))], - initializer=initializer, - ) - model = helper.make_model(graph, producer_name="qlinearconv_test") - # opt_level=1 will cause error - verify_with_ort_with_inputs(model, input_values, opt_level=2) + if padding is None: + ## autopadding with unset default attributes + kwargs = {} + if not all([s == 1 for s in strides]): + kwargs["strides"] = strides + if not all([d == 1 for d in dilations]): + kwargs["dilations"] = dilations + + node = helper.make_node( + "QLinearConv", + inputs=input_names, + outputs=["y"], + # Default values for other attributes: + auto_pad=auto_pad, + **kwargs, + ) + else: + node = helper.make_node( + "QLinearConv", + inputs=input_names, + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + # groups=1 + pads=padding, + ) + graph = helper.make_graph( + [node], + "conv_test", + inputs=input_nodes, + outputs=[helper.make_tensor_value_info("y", TensorProto.UINT8, list(y_shape))], + initializer=initializer, + ) + model = helper.make_model(graph, producer_name="qlinearconv_test") + # opt_level=1 will cause error + verify_with_ort_with_inputs(model, input_values, opt_level=2, target=target, dev=dev) -def test_qlinearconv(): def repeat(N, D): return tuple([N for _ in range(D)]) @@ -4749,7 +5218,7 @@ def repeat(N, D): bias=True, ) - # Convolution with assymetric padding + # Convolution with asymmetric padding verify_qlinearconv( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), @@ -4814,62 +5283,278 @@ def repeat(N, D): ) -def verify_qlinearadd(a_shape, b_shape, c_shape): +@tvm.testing.parametrize_targets +def test_qlinearadd(target, dev): + def verify_qlinearadd(a_shape, b_shape, c_shape): - a_array = np.random.random(a_shape).astype("float32") - b_array = np.random.random(b_shape).astype("float32") + a_array = np.random.random(a_shape).astype("float32") + b_array = np.random.random(b_shape).astype("float32") - input_nodes = [ - helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), - helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), - ] - input_names = [ - "a", - "b", - ] - input_values = [a_array, b_array] + input_nodes = [ + helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), + ] + input_names = [ + "a", + "b", + ] + input_values = [a_array, b_array] + + node = helper.make_node("Add", ["a", "b"], ["C"]) + graph = helper.make_graph( + [node], + "qlinearadd_test", + inputs=input_nodes, + outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(c_shape))], + ) + model = helper.make_model(graph, producer_name="qlinearadd_test") + quantize_and_verify_with_ort(model, input_names, [a_shape, b_shape], target, dev) + + verify_qlinearadd([4, 2], [4, 2], [4, 2]) + verify_qlinearadd([4, 2], [2], [4, 2]) + verify_qlinearadd([5, 1, 7], [2, 7], [5, 2, 7]) - node = helper.make_node("QLinearAdd", inputs=input_names, outputs=["C"]) - node = helper.make_node("Add", ["a", "b"], ["C"]) - graph = helper.make_graph( - [node], - "qlinearadd_test", - inputs=input_nodes, - outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(c_shape))], +@tvm.testing.parametrize_targets +def test_qlinearmul(target, dev): + def verify_qlinearmul(a_shape, b_shape, c_shape): + + a_array = np.random.random(a_shape).astype("float32") + b_array = np.random.random(b_shape).astype("float32") + + input_nodes = [ + helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), + ] + input_names = [ + "a", + "b", + ] + input_values = [a_array, b_array] + + node = helper.make_node("Mul", input_names, ["C"]) + graph = helper.make_graph( + [node], + "qlinearmul_test", + inputs=input_nodes, + outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(c_shape))], + ) + model = helper.make_model(graph, producer_name="qlinearmul_test") + quantize_and_verify_with_ort(model, input_names, [a_shape, b_shape], target, dev) + + verify_qlinearmul([4, 2], [4, 2], [4, 2]) + verify_qlinearmul([4, 2], [2], [4, 2]) + verify_qlinearmul([5, 1, 7], [2, 7], [5, 2, 7]) + + +@tvm.testing.parametrize_targets +def test_random_uniform(target, dev): + def get_random_uniform(shape, dtype="float32", high=1.0, low=0.0, seed=None): + ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] + node = helper.make_node( + "RandomUniform", [], ["out"], shape=shape, dtype=ONNX_DTYPE, high=high, low=low + ) + if seed is not None: + seed_attr = helper.make_attribute("seed", seed) + node.attribute.append(seed_attr) + + graph = helper.make_graph( + [node], + "random_uniform_test", + inputs=[], + outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, shape)], + ) + model = helper.make_model(graph, producer_name="random_uniform_test") + return get_tvm_output_with_vm(model, [], target=target, dev=dev) + + # Check that function runs and produces proper shape. + vals = get_random_uniform([10], dtype="float32") + assert list(vals.shape) == [10] + assert vals.dtype == "float32" + + # Test N-D tensor generation. + vals = get_random_uniform([1, 3, 100, 100], dtype="float32") + assert list(vals.shape) == [1, 3, 100, 100] + + # Check that bounds aren't exceeded. + vals = get_random_uniform(shape=[100], high=100, low=-100) + assert list(vals.shape) == [100] + assert all(vals >= -100) and all(vals <= 100) + + # Check that a fixed seed produces the same values when run twice. + vals_1 = get_random_uniform(shape=[10], seed=1) + vals_2 = get_random_uniform(shape=[10], seed=1) + assert all(vals_1 == vals_2) + + # Test against an expected output with a fixed seed. + real = get_random_uniform(shape=[10], seed=5) + expected = np.asarray( + [ + 0.043976, + 0.96656, + 0.292199, + 0.904297, + 0.25167, + 0.521778, + 0.778985, + 0.085463, + 0.939846, + 0.194201, + ] ) - model = helper.make_model(graph, producer_name="qlinearconv_test") - from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType + tvm.testing.assert_allclose(real, expected, rtol=1e-5) + + +@tvm.testing.parametrize_targets +def test_convinteger(target, dev): + def verify_convinteger( + x_shape, + w_shape, + y_shape, + padding, + kernel_shape, + strides, + dilations, + auto_pad="NOTSET", + dtype="uint8", + ): + x_array = np.random.randint(low=0, high=255, size=x_shape).astype(dtype) + w_array = np.random.uniform(low=0, high=255, size=w_shape).astype(dtype) + x_zero_point_array = np.random.randint(0, 255, size=[1]).astype(dtype) + w_zero_point_array = np.random.randint(0, 255, size=[1]).astype(dtype) - class RandomDataReader(CalibrationDataReader): - def __init__(self, n=10): - self.data = iter( - [ - { - "a": np.random.random(a_shape).astype("float32"), - "b": np.random.random(b_shape).astype("float32"), - } - for _ in range(n) - ] + ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] + input_nodes = [ + helper.make_tensor_value_info("x", ONNX_DTYPE, list(x_shape)), + helper.make_tensor_value_info("w", ONNX_DTYPE, list(w_shape)), + ] + initializer = [ + helper.make_tensor("x_zero_point", ONNX_DTYPE, [], x_zero_point_array), + helper.make_tensor("w_zero_point", ONNX_DTYPE, [], w_zero_point_array), + ] + input_names = ["x", "w", "x_zero_point", "w_zero_point"] + input_values = [x_array, w_array] + + if padding is None: + ## autopadding with unset default attributes + kwargs = {} + if not all([s == 1 for s in strides]): + kwargs["strides"] = strides + if not all([d == 1 for d in dilations]): + kwargs["dilations"] = dilations + + node = helper.make_node( + "ConvInteger", + inputs=input_names, + outputs=["y"], + # Default values for other attributes: + auto_pad=auto_pad, + **kwargs, + ) + else: + node = helper.make_node( + "ConvInteger", + inputs=input_names, + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + # groups=1 + pads=padding, ) - def get_next(self): - return next(self.data, None) + graph = helper.make_graph( + [node], + "convinteger_test", + inputs=input_nodes, + initializer=initializer, + outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, list(y_shape))], + ) + model = helper.make_model(graph, producer_name="convinteger_test") + # opt_level=1 will cause error + verify_with_ort_with_inputs(model, input_values, target=target, dev=dev, opt_level=2) - d = tvm.contrib.utils.tempdir() - model_fp32 = os.path.join(d.temp_dir, "model.onnx") - onnx.save_model(model, model_fp32) - model_quant = os.path.join(d.temp_dir, "model.quant.onnx") - quantized_model = quantize_static(model_fp32, model_quant, RandomDataReader()) - # opt_level=1 will cause error with qnn lowering - model = onnx.load(model_quant) - verify_with_ort_with_inputs(model, input_values, opt_level=2) + def repeat(N, D): + return tuple([N for _ in range(D)]) + # only support 2D ConvInteger because we only support qnn.conv2d for now. + D = 2 -def test_qlinearadd(): - verify_qlinearadd([4, 2], [4, 2], [4, 2]) - verify_qlinearadd([4, 2], [2], [4, 2]) - verify_qlinearadd([5, 1, 7], [2, 7], [5, 2, 7]) + # Convolution with padding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + + # Convolution with asymmetric padding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(4, D), + repeat(0, D) + repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + # Convolution without padding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + 2 * repeat(0, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + # Convolution with autopadding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) + # Convolution with valid autopadding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="VALID", + ) + # Convolution with non uniform stride + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + None, + repeat(3, D), + repeat(2, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) + # Convolution with dilation + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(2, D), + repeat(3, D), + repeat(1, D), + repeat(2, D), + ) if __name__ == "__main__": @@ -4897,7 +5582,8 @@ def test_qlinearadd(): test_scatter() test_lrn() test_instance_norm() - test_upsample() + test_upsample_nearest() + test_upsample_bilinear() test_forward_min() test_forward_max() test_forward_mean() @@ -4955,3 +5641,6 @@ def test_qlinearadd(): test_reverse_sequence() test_eyelike() test_qlinearconv() + test_random_uniform() + test_convinteger() + test_batch_matmul() diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py new file mode 100644 index 000000000000..db07e07f9d83 --- /dev/null +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -0,0 +1,661 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +from pathlib import Path +import shutil + +import numpy as np +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relay +from tvm.contrib import graph_executor + +import paddle +import paddle.nn as nn + +PADDLE_TEST_DATA_ROOT_PATH = Path(Path("~").expanduser(), ".tvm_test_data", "paddle") +PADDLE_TEST_DATA_ROOT_PATH.mkdir(parents=True, exist_ok=True) + + +def assert_shapes_match(tru, est): + if tru.shape != est.shape: + msg = "Output shapes {} and {} don't match" + raise AssertionError(msg.format(tru.shape, est.shape)) + + +def get_paddle_model(func, input_spec): + global PADDLE_TEST_DATA_ROOT_PATH + model_path = Path(PADDLE_TEST_DATA_ROOT_PATH, "model") + + paddle.jit.save(func, str(model_path), input_spec=input_spec) + baseline_model = paddle.jit.load(str(model_path)) + + shutil.rmtree(str(PADDLE_TEST_DATA_ROOT_PATH)) + return baseline_model + + +def verify_model(func, input_data, rtol=1e-5, atol=1e-5): + if not (isinstance(input_data, (tuple, list))): + input_data = [input_data] + + input_spec = [] + input_names = [] + input_shape_dict = {} + compiled_input = {} + for idx, data in enumerate(input_data): + input_name = "input{}".format(idx) + input_spec.append( + paddle.static.InputSpec(dtype=data.dtype, shape=data.shape, name=input_name) + ) + input_names.append(input_name) + input_shape_dict[input_name] = data.shape + if isinstance(data, np.ndarray): + compiled_input[input_name] = data + else: + compiled_input[input_name] = data.numpy() + + baseline_model = get_paddle_model(func, input_spec) + baseline_outputs = baseline_model(*[input[:] for input in input_data]) + + # get paddle outputs + if isinstance(baseline_outputs, (tuple, list)): + baseline_outputs = tuple(out.numpy() for out in baseline_outputs) + else: + baseline_outputs = (baseline_outputs.numpy(),) + + mod, params = relay.frontend.from_paddle(baseline_model, input_shape_dict) + parms_num = min(len(input_names), len(mod["main"].params)) + compiled_names = [] + for arg in mod["main"].params[:parms_num]: + assert arg.name_hint in input_names + compiled_names.append(arg.name_hint) + + with tvm.transform.PassContext(opt_level=3): + for target, dev in tvm.testing.enabled_targets(): + lib = relay.build(mod, target=target, params=params) + gmod = graph_executor.GraphModule(lib["default"](dev)) + for name in compiled_names: + gmod.set_input(name, compiled_input[name]) + gmod.run() + + for i, baseline_output in enumerate(baseline_outputs): + compiled_output = gmod.get_output(i).numpy() + + assert_shapes_match(baseline_output, compiled_output) + tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol) + + +@tvm.testing.uses_gpu +def test_forward_add_subtract(): + input_shape = [10] + + @paddle.jit.to_static + def add_subtract(inputs): + return paddle.subtract(paddle.add(inputs, inputs), inputs) + + @paddle.jit.to_static + def add_subtract2(inputs): + return inputs + 1 - 2 + + @paddle.jit.to_static + def add_subtract3(inputs1, inputs2): + ones = paddle.ones([10], dtype="float32") + return inputs1 + ones - inputs2 + + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(add_subtract, input_data) + verify_model(add_subtract2, input_data) + input_data2 = paddle.rand(input_shape, dtype="float32") + verify_model(add_subtract3, [input_data, input_data2]) + + +@tvm.testing.uses_gpu +def test_forward_argmax(): + input_shape = [1, 3, 10, 10] + + class ArgMax(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return paddle.argmax(inputs) + + class ArgMax1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmax(axis=1) + + class ArgMax2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmax(axis=1, keepdim=False) + + class ArgMax3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return inputs.argmax(axis=2, keepdim=True) + + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(ArgMax(), input_data=input_data) + verify_model(ArgMax1(), input_data=input_data) + verify_model(ArgMax2(), input_data=input_data) + verify_model(ArgMax3(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_assign(): + @paddle.jit.to_static + def assign(inputs): + return paddle.assign(inputs) + + input_shape = [2, 3] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model( + assign, + [ + input_data, + ], + ) + input_data2 = np.random.randint(100, size=input_shape) + verify_model( + assign, + [ + input_data2, + ], + ) + + +@tvm.testing.uses_gpu +def test_forward_batch_norm(): + class BatchNorm1D(nn.Layer): + def __init__(self): + super(BatchNorm1D, self).__init__() + self.batch_norm = nn.BatchNorm1D(2) + + @paddle.jit.to_static + def forward(self, input_data): + return self.batch_norm(input_data) + + class BatchNorm2D(nn.Layer): + def __init__(self): + super(BatchNorm2D, self).__init__() + self.batch_norm = nn.BatchNorm2D(2) + + @paddle.jit.to_static + def forward(self, input_data): + return self.batch_norm(input_data) + + class BatchNorm3D(nn.Layer): + def __init__(self): + super(BatchNorm3D, self).__init__() + self.batch_norm = nn.BatchNorm3D(2) + + @paddle.jit.to_static + def forward(self, input_data): + return self.batch_norm(input_data) + + input_data = paddle.rand((2, 2, 3), dtype="float32") + verify_model(BatchNorm1D(), input_data=input_data) + input_data = paddle.rand((2, 2, 2, 3), dtype="float32") + verify_model(BatchNorm2D(), input_data=input_data) + input_data = paddle.rand((2, 2, 2, 2, 3), dtype="float32") + verify_model(BatchNorm3D(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_cast(): + @paddle.jit.to_static + def cast1(inputs, dtype="uint8"): + return paddle.cast(inputs, dtype) + + @paddle.jit.to_static + def cast2(inputs, dtype="int64"): + return inputs.cast(dtype) + + input_shape = [2, 3] + input_data = paddle.rand(input_shape, dtype="float32") * 100 + verify_model( + cast1, + [ + input_data, + ], + ) + verify_model( + cast2, + [ + input_data, + ], + ) + + +@tvm.testing.uses_gpu +def test_forward_concat_unsqueeze(): + @paddle.jit.to_static + def concat_unsqueeze1(inputs): + return paddle.concat([inputs[:, 0].unsqueeze(1), inputs[:, 1].unsqueeze(1)], axis=1) + + @paddle.jit.to_static + def concat_unsqueeze2(inputs): + a = (inputs[:, :, 0] + 2) * 7 + b = (inputs[:, :, 1] + 3) * 11 + c = (inputs[:, :, 2] + 5) * 13 + return paddle.concat([paddle.unsqueeze(t, axis=2) for t in [a, b, c]], axis=2) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(concat_unsqueeze1, input_data=input_data) + verify_model(concat_unsqueeze2, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_cumsum(): + @paddle.jit.to_static + def cusum1(inputs): + return paddle.cumsum(inputs) + + @paddle.jit.to_static + def cusum2(inputs): + return paddle.cumsum(inputs, axis=0) + + @paddle.jit.to_static + def cusum3(inputs): + return paddle.cumsum(inputs, axis=1) + + input_data = paddle.randint(0, 100, (10, 10), dtype=paddle.int32) + verify_model(cusum1, [input_data]) + verify_model(cusum1, [input_data.astype(paddle.int64)]) + verify_model( + cusum2, + [ + input_data, + ], + ) + verify_model( + cusum3, + [ + input_data, + ], + ) + + +@tvm.testing.uses_gpu +def test_forward_conv(): + conv2d_input_shape = [1, 3, 10, 10] + + class Conv2D1(nn.Layer): + def __init__(self): + super(Conv2D1, self).__init__() + self.conv = nn.Conv2D(3, 6, 7, bias_attr=True) + self.softmax = nn.Softmax() + + @paddle.jit.to_static + def forward(self, inputs): + return self.softmax(self.conv(inputs)) + + class Conv2D2(nn.Layer): + def __init__(self): + super(Conv2D2, self).__init__() + self.conv = nn.Conv2D(3, 6, 7, groups=3, bias_attr=False) + self.softmax = nn.Softmax() + + @paddle.jit.to_static + def forward(self, inputs): + return self.softmax(self.conv(inputs)) + + conv2d_input_data = paddle.rand(conv2d_input_shape, dtype="float32") + verify_model(Conv2D1(), input_data=conv2d_input_data) + verify_model(Conv2D2(), input_data=conv2d_input_data) + + +@tvm.testing.uses_gpu +def test_forward_dropout(): + @paddle.jit.to_static + def dropout(inputs): + return nn.functional.dropout(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(dropout, input_data=input_data[0, 0]) + verify_model(dropout, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_shape_full(): + @paddle.jit.to_static + def full1(inputs): + return paddle.full(paddle.shape(inputs), 3.14) + + @paddle.jit.to_static + def full2(inputs): + return paddle.full(paddle.shape(inputs), 1.0, dtype=inputs.dtype) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(full1, input_data=[input_data]) + verify_model(full2, input_data=[input_data]) + + +@tvm.testing.uses_gpu +def test_forward_ones_like(): + @paddle.jit.to_static + def ones_like1(inputs): + return paddle.ones_like(inputs) + + @paddle.jit.to_static + def ones_like2(inputs): + return paddle.ones_like(inputs, dtype="int32") + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(ones_like1, input_data=input_data) + verify_model(ones_like2, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_gelu(): + @paddle.jit.to_static + def gelu(inputs): + return nn.functional.gelu(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(gelu, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_hard_sigmoid(): + @paddle.jit.to_static + def hard_sigmoid(inputs): + return nn.functional.hardsigmoid(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(hard_sigmoid, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_hard_swish(): + @paddle.jit.to_static + def hard_swish(inputs): + return nn.functional.hardswish(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(hard_swish, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_layer_norm(): + @paddle.jit.to_static + def layer_norm(inputs, weight, bias): + return nn.functional.layer_norm(inputs, inputs.shape[-1], weight=weight, bias=bias) + + class LayerNorm(nn.Layer): + def __init__(self): + super(LayerNorm, self).__init__() + data_shape = [10] + self.layer_norm = nn.LayerNorm(data_shape) + + @paddle.jit.to_static + def forward(self, inputs): + return self.layer_norm(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + weight = paddle.rand([10], dtype="float32") + bias = paddle.rand([10], dtype="float32") + verify_model(layer_norm, input_data=[input_data, weight, bias]) + verify_model(LayerNorm(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_leaky_relu(): + @paddle.jit.to_static + def leaky_relu(inputs): + return nn.functional.leaky_relu(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(leaky_relu, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_look_up(): + @paddle.jit.to_static + def look_up(inputs, weight): + return nn.functional.embedding(inputs, weight) + + class LookUp(nn.Layer): + def __init__(self): + super(LookUp, self).__init__() + self.embedding = paddle.nn.Embedding(10, 4, sparse=True) + + @paddle.jit.to_static + def forward(self, inputs): + return self.embedding(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.randint(0, 10, input_shape, dtype="int32") + weight = paddle.rand([10, 4], dtype="float32") + verify_model(look_up, input_data=[input_data, weight]) + verify_model(LookUp(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_multiply(): + @paddle.jit.to_static + def multiply1(inputs): + return inputs * inputs + + @paddle.jit.to_static + def multiply2(inputs): + return inputs * 1.0 / 2.0 + + @paddle.jit.to_static + def multiply3(inputs, inputs2): + ones = paddle.ones([10], dtype="float32") + return inputs * ones / inputs2 + + input_shape = [10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(multiply1, input_data=input_data) + verify_model(multiply2, input_data=input_data) + input_data2 = paddle.rand(input_shape, dtype="float32") + verify_model(multiply3, input_data=[input_data, input_data2]) + + +@tvm.testing.uses_gpu +def test_forward_matmul(): + class MatMul1(nn.Layer): + def forward(self, input1, input2): + return paddle.matmul(input1, input2) + + # matrix x vector + input_data1 = paddle.randn((3, 4), dtype="float32") + input_data2 = paddle.randn((4,), dtype="float32") + verify_model(MatMul1(), input_data=[input_data1, input_data2]) + + # matrix x matrix + input_data1 = paddle.randn((5, 4), dtype="float32") + input_data2 = paddle.randn((4, 5), dtype="float32") + verify_model(MatMul1(), input_data=[input_data1, input_data2]) + + # batched matrix x batched matrix + input_data1 = paddle.randn((10, 3, 4), dtype="float32") + input_data2 = paddle.randn((10, 4, 5), dtype="float32") + verify_model(MatMul1(), input_data=[input_data1, input_data2]) + + # batched matrix x broadcasted matrix + input_data1 = paddle.randn((10, 3, 4), dtype="float32") + input_data2 = paddle.randn((4, 5), dtype="float32") + verify_model(MatMul1(), input_data=[input_data1, input_data2]) + + +@tvm.testing.uses_gpu +def test_forward_pool2d(): + @paddle.jit.to_static + def pool2d1(inputs): + return nn.functional.avg_pool2d(inputs, kernel_size=2, stride=2, padding=0) + + @paddle.jit.to_static + def pool2d2(inputs): + return nn.functional.adaptive_avg_pool2d(inputs, output_size=[3, 3]) + + @paddle.jit.to_static + def pool2d3(inputs): + return nn.functional.max_pool2d( + inputs, kernel_size=2, stride=2, padding=0, return_mask=True + ) + + input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype="float32", min=-1, max=1) + verify_model(pool2d1, input_data=input_data) + verify_model(pool2d2, input_data=input_data) + # verify_model(pool2d3, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_relu(): + @paddle.jit.to_static + def relu(inputs): + return nn.functional.relu(inputs) + + input_shape = [10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(relu, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_reshape(): + @paddle.jit.to_static + def reshape1(inputs, x): + new_shape = paddle.shape(x) + return paddle.reshape(inputs, new_shape) + + @paddle.jit.to_static + def reshape2(inputs): + return inputs.reshape([-1]) + + @paddle.jit.to_static + def reshape3(inputs): + data_shape = inputs.shape + return inputs.reshape([data_shape[0] * data_shape[1], data_shape[2]]) + + @paddle.jit.to_static + def reshape4(inputs, x): + new_shape = paddle.shape(x) + return paddle.reshape(inputs, [new_shape[2], 2, -1]) + + input_shape = [2, 1, 10, 1, 10] + input_data = paddle.rand(input_shape, dtype="float32") + input_data2 = paddle.randn([2, 1, 10, 10]) + verify_model(reshape1, input_data=[input_data, input_data2]) + verify_model(reshape2, input_data=input_data) + verify_model(reshape3, input_data=paddle.randn((2, 3, 4))) + verify_model(reshape4, input_data=[input_data, input_data2]) + + +@tvm.testing.uses_gpu +def test_forward_scale(): + @paddle.jit.to_static + def scale1(inputs): + return paddle.scale(inputs, scale=2.0, bias=1.0) + + @paddle.jit.to_static + def scale2(inputs): + return paddle.scale(inputs, scale=3, bias=2.1, act="gelu") + + input_data = paddle.randn(shape=[2, 3], dtype="float32") + verify_model( + scale1, + input_data=[ + input_data, + ], + ) + verify_model(scale2, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_slice(): + @paddle.jit.to_static + def slice1(inputs): + return inputs[:, :, :, :3] + + @paddle.jit.to_static + def slice2(inputs): + return inputs[0, :, :-3, :] + + @paddle.jit.to_static + def slice3(inputs): + return inputs[0::2, 0::2] + inputs[1::2, 1::2] + + @paddle.jit.to_static + def slice4(inputs): + x0 = paddle.to_tensor([2]) - paddle.to_tensor([1]) + x1 = paddle.to_tensor([3]) + paddle.to_tensor([1]) + return inputs[:, x0:, 1:x1, :] + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model( + slice1, + input_data=[ + input_data, + ], + ) + verify_model(slice2, input_data=input_data) + # need op "strided_slice" + # verify_model(slice3, input_data=paddle.randn((4, 4))) + # need op "assign_value" + # verify_model(slice4, input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_tanh(): + @paddle.jit.to_static + def tanh(inputs): + return paddle.tanh(inputs) + + input_shape = [1, 3, 10, 10] + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(tanh, input_data=input_data) + + +if __name__ == "__main__": + test_forward_add_subtract() + test_forward_argmax() + test_forward_assign() + test_forward_batch_norm() + test_forward_cast() + test_forward_concat_unsqueeze() + test_forward_cumsum() + test_forward_conv() + test_forward_dropout() + test_forward_shape_full() + test_forward_ones_like() + test_forward_gelu() + test_forward_hard_sigmoid() + test_forward_hard_swish() + test_forward_layer_norm() + test_forward_leaky_relu() + test_forward_look_up() + test_forward_multiply() + test_forward_matmul() + test_forward_pool2d() + test_forward_relu() + test_forward_reshape() + test_forward_scale() + test_forward_slice() + test_forward_tanh() diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index f76ea9a5d324..bae7c1b5498c 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -666,7 +666,7 @@ def test_forward_leakyrelu(): def test_forward_elu(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] - input_data = torch.rand(input_shape).float() + input_data = torch.randn(input_shape).float() verify_model(torch.nn.ELU().eval(), input_data=input_data) verify_model(torch.nn.ELU(alpha=0.3).eval(), input_data=input_data) verify_model(torch.nn.ELU(alpha=1.0).eval(), input_data=input_data) @@ -700,6 +700,14 @@ def test_forward_selu(): verify_model(torch.nn.SELU().eval(), input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_silu(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.SiLU().eval(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_softplus(): torch.set_grad_enabled(False) @@ -1569,8 +1577,10 @@ def forward(self, input, weight): return F.linear(input, weight) input2d = torch.rand([2, 2]).float() + input3d = torch.rand([4, 3, 2]).float() weight1d = torch.rand([2]).float() weight2d = torch.rand([2, 2]).float() + weight3x2 = torch.rand([3, 2]).float() bias1d = torch.rand([2]).float() bias2d = torch.rand([2, 2]).float() # 2D input, 2D weight, 1D bias @@ -1579,9 +1589,12 @@ def forward(self, input, weight): verify_model(Linear(), input_data=[input2d, weight2d, bias2d]) # 2D input, 2D weight, no bias verify_model(LinearNoBias(), input_data=[input2d, weight2d]) + verify_model(LinearNoBias(), input_data=[input2d, weight3x2]) # 2D input, 1D weight, 1D bias is not supported by torch.linear() # 2D input, 1D weight, no bias verify_model(LinearNoBias(), input_data=[input2d, weight1d]) + # 3D input, 2D weight, no bias + verify_model(LinearNoBias(), input_data=[input3d, weight3x2]) # TODO: Add the following cases when matmul(1D, _) is supported by TVM # 1D input, 2D weight, 1D bias # 1D input, 2D weight, no bias @@ -1756,6 +1769,9 @@ def forward(self, x): verify_model(Upsample(size=(64, 64), mode="bilinear", align_corners=True), inp) verify_model(Upsample(scale=2, mode="bilinear", align_corners=True), inp) verify_model(Upsample(size=(50, 50), mode="bilinear", align_corners=True), inp) + verify_model(Upsample(size=(64, 64), mode="bicubic", align_corners=True), inp) + verify_model(Upsample(scale=2, mode="bicubic", align_corners=True), inp) + verify_model(Upsample(size=(50, 50), mode="bicubic", align_corners=True), inp) @tvm.testing.uses_gpu @@ -2224,8 +2240,7 @@ def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=["llv print("Running on target", tgt) dev = tvm.device(tgt, 0) - executor = relay.create_executor("vm", mod=mod, device=dev, target=tgt) - evaluator = executor.evaluate() + evaluator = relay.create_executor("vm", mod=mod, device=dev, target=tgt).evaluate() # Inference for name, inp in zip(input_names, input_data): @@ -3981,6 +3996,7 @@ def forward(self, x): test_forward_logsoftmax() test_forward_sigmoid() test_forward_dense() + test_forward_linear() test_forward_avgpool1d() test_forward_avgpool2d() test_forward_avgpool3d() diff --git a/tests/python/frontend/pytorch/test_lstm.py b/tests/python/frontend/pytorch/test_lstm.py index 1aa8bff4076e..25d4563ee64e 100644 --- a/tests/python/frontend/pytorch/test_lstm.py +++ b/tests/python/frontend/pytorch/test_lstm.py @@ -221,9 +221,9 @@ def assert_equal(tvm_result, torch_result): def run_and_compare(mod, params, pt_result, target, device): - executor = relay.create_executor("vm", mod=mod, device=device, target=target) - evaluator = executor.evaluate() - exec_res = evaluator(**params) + exec_res = relay.create_executor("vm", mod=mod, device=device, target=target).evaluate()( + **params + ) def flatten(nested): res = [] diff --git a/tests/python/frontend/pytorch/test_rnns.py b/tests/python/frontend/pytorch/test_rnns.py new file mode 100644 index 000000000000..b5784a6fe1e1 --- /dev/null +++ b/tests/python/frontend/pytorch/test_rnns.py @@ -0,0 +1,430 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +import torch +import onnx +import io +import sys + +from tvm import relay +from tvm.contrib import graph_executor + +from torch import nn + +## LSTM parameters +lstm_feature_size = 16 +lstm_hidden_size = 32 +lstm_projection_size = 20 + +## GRU parameters +gru_feature_size = 8 +gru_hidden_size = 16 + +num_layers = 2 +seqs_length = 2 +batch_size = 2 + + +class RNN_Model(nn.Module): + """ + It is base class for RNN layer classes. + It contains some common fields and methods for child classes. + """ + + def __init__( + self, + ): + super().__init__() + + # model is defined in child class + self.model = None + + def forward(self, input, hidden_init=None): + """ + Computes the output tensor after input inference along RNN layer. + + :param input: batch of data as a tensor of shape (seqs_length, batch_size, feature_size) or (batch_size, seqs_length, feature_size) if self.batch_first = True + :param hidden_init: initial hidden state(s) of the RNN as a tensor(s) of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None. + :return: the output tensor of shape (batch_size, hidden_size) + """ + if self.model is None: + raise NotImplementedError("self.model must be defined in subclasses!") + out, _ = self.model(input, hidden_init) + + return out + + def gen_rnd_weights(self): + """ + Generate random weigths for the model + """ + if self.model is None: + raise NotImplementedError("self.model must be defined in subclasses!") + with torch.no_grad(): + for weight_group in self.model.all_weights: + for weight in weight_group: + weight.data = torch.rand(weight.shape) + + def get_dummy_inputs(self): + raise NotImplementedError("subclasses must override get_dummy_inputs()!") + + def get_input_names(self): + raise NotImplementedError("subclasses must override get_input_names()!") + + def get_shape_desc(self, frontend_type): + raise NotImplementedError("subclasses must override get_shape_desc(frontend_type)!") + + def get_tvm_inputs(self, dtype): + raise NotImplementedError("subclasses must override get_tvm_inputs(dtype)!") + + +class GRU_Model(RNN_Model): + def __init__( + self, + seq_len=seqs_length, + batch_size=batch_size, + feature_size=gru_feature_size, + hidden_size=gru_hidden_size, + batch_first=False, + layer_num=1, + bidirectional=False, + use_bias=True, + rnd_weights_init=False, + ): + super().__init__() + + # Shapes + self.shape = [seq_len, batch_size, feature_size] + if batch_first: + self.shape = [batch_size, seq_len, feature_size] + layers_num = 2 * layer_num if bidirectional else layer_num + self.h0_shape = [layers_num, batch_size, hidden_size] + # Dummy inputs + self.dummy_inputs = (torch.rand(self.shape), torch.zeros(self.h0_shape)) + + self.model = nn.GRU( + input_size=feature_size, + hidden_size=hidden_size, + num_layers=layer_num, + bidirectional=bidirectional, + batch_first=batch_first, + bias=use_bias, + ) + + if rnd_weights_init: + self.gen_rnd_weights() + + def gen_rnd_weights(self): + """ + Generate random weigths for the model with biases + For first uni- and bidirectional weights group: + Wi (3*hidden_size, feature_size) + Wh (3*hidden_size, hidden_size) + Bi (3*hidden_size) + Bh (3*hidden_size) + For other weights group: + Wi (3*hidden_size, hidden_size) + Wh (3*hidden_size, hidden_size) + Bi (3*hidden_size) + Bh (3*hidden_size) + For generation of random weigths for the model without biases the Bi and Bh weights are skipped + """ + super().gen_rnd_weights() + + def get_dummy_inputs(self): + return self.dummy_inputs + + def get_input_names(self): + return ["input", "h0"] + + def get_shape_desc(self, frontend_type): + shape_desc = None + if frontend_type == "pt": # PyTorch + shape_desc = [("input", self.shape)] + elif frontend_type == "onnx": # ONNX + shape_desc = { + "input": self.shape, + "h0": self.h0_shape, + } + return shape_desc + + def get_tvm_inputs(self, dtype): + return { + "input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)), + "h0": tvm.nd.array(self.dummy_inputs[1].numpy().astype(dtype)), + } + + +def check_torch_version_for_proj_in_lstm(): + """ + proj_size parameter is supported in torch.nn.LSTM layer started from 1.8.0 torch version + """ + me = False + + version = torch.__version__ + major, minor, micro = version.split(".") + + if int(major) > 1: + me = True + elif int(major) == 1: + if int(minor) >= 8: + me = True + + return me + + +class LSTM_Model(RNN_Model): + def __init__( + self, + seq_len=seqs_length, + batch_size=batch_size, + feature_size=lstm_feature_size, + hidden_size=lstm_hidden_size, + batch_first=False, + layer_num=1, + bidirectional=False, + proj_size=0, + use_bias=True, + rnd_weights_init=False, + ): + super().__init__() + + # Shapes + self.shape = [seq_len, batch_size, feature_size] + if batch_first: + self.shape = [batch_size, seq_len, feature_size] + layers_num = 2 * layer_num if bidirectional else layer_num + self.h0_shape = [layers_num, batch_size, hidden_size] + if proj_size > 0: + self.h0_shape = [layers_num, batch_size, proj_size] + self.c0_shape = [layers_num, batch_size, hidden_size] + # Dummy inputs + self.dummy_inputs = ( + torch.rand(self.shape), + (torch.zeros(self.h0_shape), torch.zeros(self.c0_shape)), + ) + + if check_torch_version_for_proj_in_lstm(): + self.model = nn.LSTM( + input_size=lstm_feature_size, + hidden_size=lstm_hidden_size, + num_layers=layer_num, + bidirectional=bidirectional, + proj_size=proj_size, + batch_first=batch_first, + bias=use_bias, + ) + else: + if proj_size > 0: + print( + "WARNING: projection is not supported for torch version less than 1.8.0! ", + "LSTM was constructed without projection!", + ) + # sys.exit() + self.model = nn.LSTM( + input_size=lstm_feature_size, + hidden_size=lstm_hidden_size, + num_layers=layer_num, + bidirectional=bidirectional, + batch_first=batch_first, + bias=use_bias, + ) + + if rnd_weights_init: + self.gen_rnd_weights() + + def gen_rnd_weights(self): + """ + Generate random weigths for the model with biases + Without projection: + For first weights group: + Wi (4*lstm_hidden_size, lstm_feature_size) + Wh (4*lstm_hidden_size, lstm_hidden_size) + Bi (4*lstm_hidden_size) + Bh (4*lstm_hidden_size) + For first bidirectional weights group: + Wi (4*lstm_hidden_size, lstm_feature_size) + Wh (4*lstm_hidden_size, lstm_hidden_size) + Bi (4*lstm_hidden_size) + Bh (4*lstm_hidden_size) + For other weights group: + Wi (4*lstm_hidden_size, lstm_hidden_size) + Wh (4*lstm_hidden_size, lstm_hidden_size) + Bi (4*lstm_hidden_size) + Bh (4*lstm_hidden_size) + With projection: + For first weights group: + Wi (4*lstm_hidden_size, lstm_feature_size) + Wh (4*lstm_hidden_size, proj_size) + Bi (4*lstm_hidden_size) + Bh (4*lstm_hidden_size) + P (proj_size, lstm_hidden_size) + For first bidirectional weights group: + Wi (4*lstm_hidden_size, lstm_feature_size) + Wh (4*lstm_hidden_size, proj_size) + Bi (4*lstm_hidden_size) + Bh (4*lstm_hidden_size) + P (proj_size, lstm_hidden_size) + For other weights group: + Wi (4*lstm_hidden_size, proj_size * num_directions) + Wh (4*lstm_hidden_size, proj_size) + Bi (4*lstm_hidden_size) + Bh (4*lstm_hidden_size) + P (proj_size, lstm_hidden_size) + For generation of random weigths for the model without biases Bi and Bh are skipped + """ + super().gen_rnd_weights() + + def get_dummy_inputs(self): + return self.dummy_inputs + + def get_input_names(self): + return ["input", "h0", "c0"] + + def get_shape_desc(self, frontend_type): + shape_desc = None + if frontend_type == "pt": # PyTorch + shape_desc = [("input", self.shape)] + elif frontend_type == "onnx": # ONNX + shape_desc = { + "input": self.shape, + "h0": self.h0_shape, + "c0": self.c0_shape, + } + return shape_desc + + def get_tvm_inputs(self, dtype): + return { + "input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)), + "h0": tvm.nd.array(self.dummy_inputs[1][0].numpy().astype(dtype)), + "c0": tvm.nd.array(self.dummy_inputs[1][1].numpy().astype(dtype)), + } + + +def compare(input, gold_data, rtol=1e-5, atol=1e-5): + tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol) + + +def check_rnn(rnn_type, rnn_mod, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0)): + def get_model( + rnn_type, + rnn_mod, + args, + ): + # Fill args + if "b" in rnn_mod: + args["bidirectional"] = True + if "s" in rnn_mod: + args["layer_num"] = num_layers + + if rnn_type == "GRU": + RNN_Model_selector = GRU_Model + elif rnn_type == "LSTM": + RNN_Model_selector = LSTM_Model + if "p" in rnn_mod: + args["proj_size"] = lstm_projection_size + + return RNN_Model_selector(**args) + + def get_onnx_model(model): + onnx_io = io.BytesIO() + with torch.no_grad(): + input_names = model.get_input_names() + inputs = model.get_dummy_inputs() + + # default export (without dynamic input) + torch.onnx.export(model, inputs, onnx_io, input_names=input_names) + + onnx_io.seek(0, 0) + return onnx.load_model(onnx_io) + + model = None + dtype = "float32" + device = torch.device("cpu") + for batch_first in (True, False): + for use_bias in (True, False): + for rnd_weights in [True]: # (True, False): + model_inputs = { + "batch_first": batch_first, + "use_bias": use_bias, + "rnd_weights_init": rnd_weights, + } + model = get_model(rnn_type, rnn_mod, model_inputs) + model.to(device) + model.eval() + + # Get golden output from original model + dummy_inputs = model.get_dummy_inputs() + golden_output = model.forward(dummy_inputs[0].to(device)).detach().cpu().numpy() + + tvm_output = None + for format in ["pt"]: # ["pt", "onnx"]: + shape_desc = model.get_shape_desc(format) + if format == "pt": + # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing. + traced_script_module = torch.jit.trace(model, dummy_inputs[0]).eval() + + # Import model to Relay + mod, params = relay.frontend.from_pytorch(traced_script_module, shape_desc) + elif format == "onnx": + try: + onnx_model = get_onnx_model(model) + except: + print( + "WARNING: torch.onnx.export does not support conversion LSTM with projection " + "from pytorch! TODO: waiting for the support and correct test after that." + ) + continue + + # Import model to Relay + mod, params = relay.frontend.from_onnx(onnx_model, shape_desc) + + # Model compilation by tvm + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) + + # Inference of the model with given input data + m = graph_executor.GraphModule(lib["default"](dev)) + + # Set inputs + tvm_inputs = model.get_tvm_inputs(dtype) + m.set_input(**tvm_inputs) + # Execute + m.run() + # Get outputs (converted to numpy array) + tvm_output = m.get_output(0).numpy() + + compare(tvm_output, golden_output) + + +@tvm.testing.uses_gpu +def test_rnns(): + for target, dev in tvm.testing.enabled_targets(): + # RNN types: GRU, LSTM + # GRU modifications: unidirectional, stacked, bidirectional, stacked bidirectional + for mod_type in ["uni", "s", "b", "sb"]: + check_rnn("GRU", mod_type, target, dev) + # LSTM modifications: unidirectional, stacked, bidirectional, stacked bidirectional, + # and all these types with projection ("p", "sp", "bp", "sbp") + # The latter are skiped for test acceleration + for mod_type in ["uni", "s", "b", "sb"]: + check_rnn("LSTM", mod_type, target, dev) + + +if __name__ == "__main__": + test_rnns() diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py index c91661db7e36..49dc5170c52f 100644 --- a/tests/python/frontend/tensorflow/test_control_flow.py +++ b/tests/python/frontend/tensorflow/test_control_flow.py @@ -34,8 +34,7 @@ def check_equal(graph, tf_out, input_map=None): mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) if input_map is not None: params.update(input_map) - ex = relay.create_executor("vm", mod=mod) - relay_out = ex.evaluate()(**params) + relay_out = relay.create_executor("vm", mod=mod).evaluate()(**params) if isinstance(relay_out, nd.NDArray): np.testing.assert_allclose(tf_out, relay_out.numpy()) else: diff --git a/tests/python/frontend/tensorflow/test_debugging.py b/tests/python/frontend/tensorflow/test_debugging.py index 26fe171fb789..0e08840e56ee 100644 --- a/tests/python/frontend/tensorflow/test_debugging.py +++ b/tests/python/frontend/tensorflow/test_debugging.py @@ -28,8 +28,7 @@ def run_relay(graph, shape_dict=None, *vars): mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True), shape=shape_dict) - ex = relay.create_executor("debug", mod=mod) - return ex.evaluate()(*vars) + return relay.create_executor("debug", mod=mod).evaluate()(*vars) def test_assert_true(): diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 583014f657ad..338d219401ee 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -29,6 +29,13 @@ import tensorflow.compat.v1 as tf except ImportError: import tensorflow as tf + +# Only allow TF to run on half the GPU RAM to save the other half +# For TVM +gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5) +sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) +sess.close() + from tensorflow.python.framework import constant_op from tensorflow.python.framework import graph_util from tensorflow.python.ops import nn_ops @@ -117,7 +124,7 @@ def run_tvm_graph( disabled_pass=None, ignore_in_shape=False, serialize=False, - use_dense_op=True, + convert_config=None, ): """Generic function to compile on relay and execute on tvm""" input_data = convert_to_list(input_data) @@ -136,11 +143,10 @@ def run_tvm_graph( layout=layout, shape=shape_dict, outputs=out_names, - use_dense_op=use_dense_op, + convert_config=convert_config, ) dev = tvm.device(target, 0) if mode == "debug": - ex = relay.create_executor(mode, mod=mod, device=tvm.cpu(), target="llvm") inputs = [] for param in mod["main"].params: found = False @@ -152,7 +158,9 @@ def run_tvm_graph( # Interpreter doesn't bind constants, so still need to find in params if not found: inputs.append(tvm.nd.array(params[param.name_hint])) - result = ex.evaluate()(*inputs) + result = relay.create_executor(mode, mod=mod, device=tvm.cpu(), target="llvm").evaluate()( + *inputs + ) return vmobj_to_list(result) elif mode == "vm": with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): @@ -218,7 +226,7 @@ def compare_tf_with_tvm( add_shapes_to_graph_def=True, targets=None, ignore_in_shape=False, - use_dense_op=True, + convert_config=None, ): """Generic function to generate and compare tensorflow and TVM output""" @@ -266,7 +274,7 @@ def name_without_num(name): mode=mode, cuda_layout=cuda_layout, ignore_in_shape=ignore_in_shape, - use_dense_op=use_dense_op, + convert_config=convert_config, ) # since the names from tensorflow and relay runs are not exactly same, # first len(tf_output) will be compared @@ -318,6 +326,8 @@ def _test_pooling(input_shape, **kwargs): if is_gpu_available(): if len(input_shape) == 4: input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] + if isinstance(kwargs["padding"], list): + kwargs["padding"] = [kwargs["padding"][ii] for ii in (0, 3, 1, 2)] kwargs["data_format"] = "NCHW" _test_pooling_iteration(input_shape, **kwargs) @@ -1802,8 +1812,12 @@ def _test_matmul(i, j, k, dtype, outer=None): A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) - compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, use_dense_op=True) - compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, use_dense_op=False) + compare_tf_with_tvm( + [A_np, B_np], [A.name, B.name], result.name, convert_config={"use_dense": True} + ) + compare_tf_with_tvm( + [A_np, B_np], [A.name, B.name], result.name, convert_config={"use_dense": False} + ) def test_forward_matmul(): @@ -1821,7 +1835,18 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) - compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name) + compare_tf_with_tvm( + [A_np, B_np], + [A.name, B.name], + result.name, + convert_config={"use_nt_batch_matmul": True}, + ) + compare_tf_with_tvm( + [A_np, B_np], + [A.name, B.name], + result.name, + convert_config={"use_nt_batch_matmul": False}, + ) def _test_batch_matmul_dynamic( @@ -1834,10 +1859,23 @@ def _test_batch_matmul_dynamic( A_np = np.random.uniform(high=5.0, size=A_np_shape).astype(dtype) B_np = np.random.uniform(high=5.0, size=B_np_shape).astype(dtype) - # for now, in TOPI, only cublas's implementation support dynamic shape + # for now, in TOPI, only llvm & cublas's implementation support dynamic shape # TODO add more backends support in TOPI compare_tf_with_tvm( - [A_np, B_np], [A.name, B.name], result.name, mode="vm", targets=["cuda -libs=cublas"] + [A_np, B_np], + [A.name, B.name], + result.name, + mode="vm", + targets=["llvm", "cuda -libs=cublas"], + convert_config={"use_nt_batch_matmul": True}, + ) + compare_tf_with_tvm( + [A_np, B_np], + [A.name, B.name], + result.name, + mode="vm", + targets=["llvm", "cuda -libs=cublas"], + convert_config={"use_nt_batch_matmul": False}, ) @@ -1856,7 +1894,6 @@ def test_forward_batch_matmul(): _test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False) -@tvm.testing.requires_cuda def test_forward_batch_matmul_dynamic(): _test_batch_matmul_dynamic((None, 5, 4), (None, 4, 5), (3, 5, 4), (3, 4, 5), "int32") _test_batch_matmul_dynamic( @@ -2475,9 +2512,15 @@ def _test_sparse_add(indices, values, A_shape, B_shape, dtype, flip=False): # TODO(ANSHUMAN87): support user input threashold values if flip: - result = tf.sparse.add(B, A_sp, threshold=0) + if package_version.parse(tf.VERSION) < package_version.parse("1.13.0"): + result = tf.sparse.add(B, A_sp, thresh=0) + else: + result = tf.sparse.add(B, A_sp, threshold=0) else: - result = tf.sparse.add(A_sp, B, threshold=0) + if package_version.parse(tf.VERSION) < package_version.parse("1.13.0"): + result = tf.sparse.add(A_sp, B, thresh=0) + else: + result = tf.sparse.add(A_sp, B, threshold=0) B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) @@ -2553,7 +2596,9 @@ def test_forward_stridedslice(): _test_stridedslice([], [0], [0], [1], "float32", new_axis_mask=1) _test_stridedslice([2], [1], [1], [1], "float32", shrink_axis_mask=1) + _test_stridedslice([4], [-1], [0], [1], "float32", shrink_axis_mask=1) _test_stridedslice([2, 1], [0], [1], [1], "float32", shrink_axis_mask=1) + _test_stridedslice([2, 3, 4], [-2], [0], [1], "float32", shrink_axis_mask=8) _test_stridedslice([2, 3, 4], [0], [1], [1], "float32", shrink_axis_mask=8) _test_stridedslice([3, 4, 3], [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32") _test_stridedslice([3, 4, 3], [1, 0], [4, 3], [2, 1], "float32", ellipsis_mask=8) diff --git a/tests/python/frontend/tensorflow/test_no_op.py b/tests/python/frontend/tensorflow/test_no_op.py index 38246ea5e14f..d8bfcee9673b 100644 --- a/tests/python/frontend/tensorflow/test_no_op.py +++ b/tests/python/frontend/tensorflow/test_no_op.py @@ -26,8 +26,7 @@ def run_relay(graph): mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) - ex = relay.create_executor("debug", mod=mod) - return ex.evaluate()(**params) + return relay.create_executor("debug", mod=mod).evaluate()(**params) def test_no_op(): diff --git a/tests/python/frontend/tensorflow2/common.py b/tests/python/frontend/tensorflow2/common.py index 9686909ff31f..4fbdbb07e940 100644 --- a/tests/python/frontend/tensorflow2/common.py +++ b/tests/python/frontend/tensorflow2/common.py @@ -71,7 +71,7 @@ def run_graph_executor(lib, input_, ctx=tvm.cpu(0)): mod = runtime.GraphModule(lib["default"](ctx)) mod.set_input(0, input_) mod.run() - return [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())] + return [mod.get_output(i).numpy() for i in range(mod.get_num_outputs())] def compare_tf_tvm(gdef, input_, output_, runtime="vm", output_tensors=None): diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index b3504ff38328..001ba6de1967 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -354,7 +354,8 @@ def get_input(self): @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) def func(self, x): a, b, c = tf.split(x, 3, axis=1) - return tf.raw_ops.ConcatV2(values=[a, b, c], axis=1) + axis = tf.add(tf.constant(1, dtype="int32"), tf.constant(0, dtype="int32")) + return tf.raw_ops.ConcatV2(values=[a, b, c], axis=axis) run_all(ConcatV2) @@ -447,5 +448,142 @@ def func(self, x): run_model_graph(StatelessWhile2Var, outputs=["Identity:output:0"]) +def test_tensorlist(): + def run_test(elem_shape): + class TensorList(tf.Module): + def get_input(self): + in_tens = np.ones((2, 3), dtype="float32") + in_tens[1, :] = np.zeros((3,), dtype="float32") + return in_tens + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=2, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :]) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=1, item=x[1, :]) + output = tf.raw_ops.TensorListGetItem( + input_handle=tl, index=0, element_shape=elem_shape, element_dtype=dtype + ) + return output + + run_model_graph(TensorList) + run_func_graph(TensorList, runtime="vm") + + run_test((3,)) + run_test((-1,)) + + +def test_tensorlist_stack(): + def run_test(elem_shape): + class TensorListStack(tf.Module): + def get_input(self): + in_tens = np.ones((2, 3), dtype="float32") + in_tens[1] = np.zeros((3,), dtype="float32") + return in_tens + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=2, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListFromTensor(tensor=x, element_shape=elem_shape) + output = tf.raw_ops.TensorListStack( + input_handle=tl, element_shape=elem_shape, element_dtype=dtype + ) + return output + + run_model_graph(TensorListStack) + run_func_graph(TensorListStack, runtime="vm") + + run_test((3,)) + run_test((-1,)) + + +def test_tensorlist_2d(): + def run_test(elem_shape): + class TensorList2D(tf.Module): + def get_input(self): + in_tens = np.ones((2, 3, 4), dtype="float32") + in_tens[1, :, :] = np.zeros((3, 4), dtype="float32") + return in_tens + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=2, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :, :]) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=1, item=x[1, :, :]) + output = tf.raw_ops.TensorListGetItem( + input_handle=tl, index=0, element_shape=elem_shape, element_dtype=dtype + ) + return output + + run_model_graph(TensorList2D) + run_func_graph(TensorList2D, runtime="vm") + + run_test((3, 4)) + run_test((-1, -1)) + + +def test_tensorlist_stack_2d(): + def run_test(elem_shape): + class TensorListStack2D(tf.Module): + def get_input(self): + in_tens = np.ones((2, 3, 4), dtype="float32") + in_tens[1, :, :] = np.zeros((3, 4), dtype="float32") + return in_tens + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=2, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListFromTensor(tensor=x, element_shape=elem_shape) + output = tf.raw_ops.TensorListStack( + input_handle=tl, element_shape=elem_shape, element_dtype=dtype + ) + return output + + run_model_graph(TensorListStack2D) + run_func_graph(TensorListStack2D, runtime="vm") + + run_test((3, 4)) + run_test((-1, -1)) + + +def test_tensorlist_stack_unpack(): + def run_test(elem_shape): + class TensorListStack2D(tf.Module): + def get_input(self): + in_tens = np.ones((1, 3, 4), dtype="float32") + return in_tens + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 3, 4), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=1, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :, :]) + output = tf.raw_ops.TensorListStack( + input_handle=tl, element_shape=elem_shape, element_dtype=dtype, num_elements=1 + ) + output = tf.raw_ops.Unpack(value=output, num=1, axis=0) + return output + + run_model_graph(TensorListStack2D) + run_func_graph(TensorListStack2D, runtime="vm") + + run_test((3, 4)) + run_test((-1, -1)) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/frontend/tensorflow2/test_sequential_models.py b/tests/python/frontend/tensorflow2/test_sequential_models.py index 394a49d0f2e9..1b5a6342f07d 100644 --- a/tests/python/frontend/tensorflow2/test_sequential_models.py +++ b/tests/python/frontend/tensorflow2/test_sequential_models.py @@ -109,5 +109,60 @@ def maxpool_batchnorm_model(input_shape, pool_size=(2, 2)): run_sequential_model(maxpool_batchnorm_model, input_shape=(1, 32, 32, 3)) +def test_tensorlist_stack_model(): + def tensorlist_stack_model(input_shape): + class TensorArrayStackLayer(tf.keras.layers.Layer): + def __init__(self): + super().__init__() + + def call(self, inputs): + inputs = tf.squeeze(inputs) + outputs = tf.TensorArray( + tf.float32, + size=inputs.shape[0], + infer_shape=False, + element_shape=inputs.shape[1:], + ) + outputs = outputs.unstack(inputs) + + return outputs.stack() + + input_shape = (3, 32) + model = tf.keras.Sequential( + [tf.keras.layers.Input(shape=input_shape, batch_size=1), TensorArrayStackLayer()] + ) + return model + + run_sequential_model(tensorlist_stack_model, input_shape=(3, 32)) + + +def test_tensorlist_read_model(): + def tensorlist_read_model(input_shape): + class TensorArrayReadLayer(tf.keras.layers.Layer): + def __init__(self): + super().__init__() + + def call(self, inputs): + inputs = tf.squeeze(inputs) + outputs = tf.TensorArray( + tf.float32, + size=inputs.shape[0], + infer_shape=False, + element_shape=inputs.shape[1:], + ) + for i in range(inputs.shape[0]): + outputs = outputs.write(i, inputs[i, :]) + + return outputs.read(0) + + input_shape = (3, 32) + model = tf.keras.Sequential( + [tf.keras.layers.Input(shape=input_shape, batch_size=1), TensorArrayReadLayer()] + ) + return model + + run_sequential_model(tensorlist_read_model, input_shape=(3, 32)) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 7b5377b3363d..f2941030f0ab 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -189,7 +189,6 @@ def run_tvm_graph( ) if mode in ["debug", "vm"]: - ex = relay.create_executor(mode, mod=mod, device=tvm.cpu(), target="llvm") inputs = [] for param in mod["main"].params: found = False @@ -201,7 +200,9 @@ def run_tvm_graph( # Interpreter doesn't bind constants, so still need to find in params if not found: inputs.append(tvm.nd.array(params[param.name_hint])) - result = ex.evaluate()(*inputs) + result = relay.create_executor(mode, mod=mod, device=tvm.cpu(), target="llvm").evaluate()( + *inputs + ) return vmobj_to_list(result) else: with tvm.transform.PassContext(opt_level=3): @@ -321,7 +322,6 @@ def compare_tflite_with_tvm( out_names=out_names, mode=mode, ) - # WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output # range for the specific operator. While adding test ensure that we aren't getting only clipped values # in output tensors that still pass the assertion. For reference see _test_elemwise_qnn_out_range() @@ -1516,7 +1516,9 @@ def test_forward_reshape(): # ------ -def _test_resize(tf_resize_op, images_data, size_data, align_corners, quantized=False): +def _test_resize( + tf_resize_op, images_data, size_data, align_corners, half_pixel_centers, quantized=False +): """One iteration of Resize""" # Test with tensor and constant with tf.Graph().as_default(): @@ -1529,7 +1531,10 @@ def _test_resize(tf_resize_op, images_data, size_data, align_corners, quantized= ) input_range = {"in": (-3, 2)} out_tensor = tf_resize_op( - images=images_tensor_q, size=size, align_corners=align_corners + images=images_tensor_q, + size=size, + align_corners=align_corners, + half_pixel_centers=half_pixel_centers, ) out_tensor = tf.quantization.fake_quant_with_min_max_args( out_tensor, min=-3, max=2, name="out_tensor" @@ -1544,7 +1549,12 @@ def _test_resize(tf_resize_op, images_data, size_data, align_corners, quantized= input_range=input_range, ) else: - out_tensor = tf_resize_op(images=images_tensor, size=size, align_corners=align_corners) + out_tensor = tf_resize_op( + images=images_tensor, + size=size, + align_corners=align_corners, + half_pixel_centers=half_pixel_centers, + ) compare_tflite_with_tvm([images_data], ["in:0"], [images_tensor], [out_tensor]) @@ -1560,6 +1570,7 @@ def test_all_resize(): images_data_float32, size_data, align_corners=False, + half_pixel_centers=False, quantized=False, ) _test_resize( @@ -1567,13 +1578,32 @@ def test_all_resize(): images_data_float32, size_data, align_corners=True, + half_pixel_centers=False, quantized=False, ) _test_resize( - tf.image.resize_bilinear, images_data_uint8, size_data, align_corners=False, quantized=True + tf.image.resize_bilinear, + images_data_uint8, + size_data, + align_corners=False, + half_pixel_centers=False, + quantized=True, + ) + _test_resize( + tf.image.resize_bilinear, + images_data_uint8, + size_data, + align_corners=True, + half_pixel_centers=False, + quantized=True, ) _test_resize( - tf.image.resize_bilinear, images_data_uint8, size_data, align_corners=True, quantized=True + tf.image.resize_bilinear, + images_data_uint8, + size_data, + align_corners=False, + half_pixel_centers=True, + quantized=True, ) ### RESIZE_NEAREST_NEIGHBOR (was added in v1.13) # According to topi resize.h @@ -1582,7 +1612,11 @@ def test_all_resize(): if "RESIZE_NEAREST_NEIGHBOR" in dir(BuiltinOperator()): _test_resize( - tf.image.resize_nearest_neighbor, images_data_float32, size_data, align_corners=False + tf.image.resize_nearest_neighbor, + images_data_float32, + size_data, + align_corners=False, + half_pixel_centers=False, ) @@ -2583,6 +2617,22 @@ def test_forward_select(): ) +@pytest.mark.parametrize("quant_bits", [2, 4, 8, 16]) +@pytest.mark.parametrize( + "value, min, max", [[-10.11, -6, 6], [-3.55, -6, 6], [0, -6, 6], [3.55, -6, 6], [10.11, -6, 6]] +) +def test_forward_fake_quant(value, min, max, quant_bits): + with tf.Graph().as_default(): + with tf.Session() as sess: + input = tf.placeholder(tf.float32, shape=[1], name="input") + out = tf.quantization.fake_quant_with_min_max_args( + input, min=min, max=max, num_bits=quant_bits, name=None + ) + + in_data = np.float32(value) + compare_tflite_with_tvm([in_data], ["input:0"], [input], [out]) + + # Squeeze # ------- @@ -4412,7 +4462,7 @@ def test_forward_mediapipe_hand_landmark(): # -------------- def test_prevent_tensorflow_dynamic_range(): """ - Should prevent runnung "dynamic range quantization" optimized TFLite graph + Should prevent running "dynamic range quantization" optimized TFLite graph """ data_array = np.random.randint(0, 2, (1, 1024, 1024)).astype(dtype=np.float32) filter_array = np.random.randint(0, 2, (1024, 1024)).astype(dtype=np.float32) diff --git a/tests/python/integration/test_lower.py b/tests/python/integration/test_lower.py new file mode 100644 index 000000000000..3fa4795870d5 --- /dev/null +++ b/tests/python/integration/test_lower.py @@ -0,0 +1,327 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument +"""Test workload for lowering and build""" +import tvm +from tvm import tir +from tvm.script import ty +import tvm.testing +import numpy as np + + +@tvm.script.tir +def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + # match buffer + A = tir.match_buffer(a, [1024, 1024], "float16") + B = tir.match_buffer(b, [1024, 1024], "float16") + C = tir.match_buffer(c, [1024, 1024], "float32") + + # body + for blockIdx_x in tir.thread_binding(0, 16, "blockIdx.x"): + for blockIdx_y in tir.thread_binding(0, 8, "blockIdx.y"): + with tir.block([16, 8]) as [bx, by]: + tir.bind(bx, blockIdx_x) + tir.bind(by, blockIdx_y) + shared_A = tir.alloc_buffer([1024, 1024], "float16", scope="shared") + shared_B = tir.alloc_buffer([1024, 1024], "float16", scope="shared") + wmma_A = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a") + wmma_B = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b") + wmma_C = tir.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator") + for ty in tir.thread_binding(0, 2, "threadIdx.y"): + for tz in tir.thread_binding(0, 2, "threadIdx.z"): + for i, j in tir.grid(2, 4): + with tir.block([64, 64]) as [vi, vj]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vj, by * 8 + tz * 4 + j) + tir.reads([]) + tir.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + C0 = tir.match_buffer( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[16 * 4, 1], + scope="wmma.accumulator", + offset_factor=1, + ) + tir.evaluate( + tir.tvm_fill_fragment( + C0.data, + 16, + 16, + 16, + i * 4 + j, + tir.float32(0), + dtype="handle", + ) + ) + + for ko in range(0, 32): + # copy data from global to shared + for tx in tir.thread_binding(0, 32, "threadIdx.x"): + for i0, j0 in tir.grid(1, 4): + for j1 in tir.vectorized(0, 4): + with tir.block([1024, 1024]) as [vi, vj]: + tir.bind(vi, bx * 64 + ty * 32 + tx + i0) + tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + shared_A[vi, vj + 8] = A[vi, vj] + + for i0, j0 in tir.grid(2, 4): + for j1 in tir.vectorized(0, 4): + with tir.block([1024, 1024]) as [vi, vj]: + tir.bind(vi, by * 128 + ty * 64 + tx * 2 + i0) + tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + shared_B[vi, vj + 8] = B[vi, vj] + + for ki in range(0, 2): + for i in range(0, 2): + with tir.block([64, 64]) as [vi, vk]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vk, ko * 2 + ki) + tir.reads( + shared_A[ + vi * 16 : vi * 16 + 16, + vk * 16 : vk * 16 + 16 + 8, + ] + ) + tir.writes( + wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16] + ) + s0 = tir.var("int32") + s1 = tir.var("int32") + A0 = tir.match_buffer( + shared_A[ + vi * 16 : vi * 16 + 16, + vk * 16 : vk * 16 + 16 + 8, + ], + (16, 16 + 8), + "float16", + strides=[s0, s1], + scope="shared", + offset_factor=1, + ) + wmma_A0 = tir.match_buffer( + wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="wmma.matrix_a", + offset_factor=1, + ) + tir.evaluate( + tir.tvm_load_matrix_sync( + wmma_A0.data, + 16, + 16, + 16, + i, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset + 8, + A0.strides[0], + 1, + dtype="handle", + ), + A0.strides[0], + "row_major", + dtype="handle", + ) + ) + for j in range(0, 4): + with tir.block([64, 64]) as [vj, vk]: + tir.bind(vj, by * 8 + tz * 4 + j) + tir.bind(vk, ko * 2 + ki) + tir.reads( + shared_B[ + vj * 16 : vj * 16 + 16, + vk * 16 : vk * 16 + 16 + 8, + ] + ) + tir.writes( + wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16] + ) + s0 = tir.var("int32") + s1 = tir.var("int32") + B0 = tir.match_buffer( + shared_B[ + vj * 16 : vj * 16 + 16, + vk * 16 : vk * 16 + 16 + 8, + ], + (16, 16 + 8), + "float16", + strides=[s0, s1], + scope="shared", + offset_factor=1, + ) + wmma_B0 = tir.match_buffer( + wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="wmma.matrix_b", + offset_factor=1, + ) + tir.evaluate( + tir.tvm_load_matrix_sync( + wmma_B0.data, + 16, + 16, + 16, + j, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + B0.data, + B0.elem_offset + 8, + B0.strides[0], + 1, + dtype="handle", + ), + B0.strides[0], + "col_major", + dtype="handle", + ) + ) + for i, j in tir.grid(2, 4): + with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [ + vi, + vj, + vk, + ]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vj, by * 8 + tz * 4 + j) + tir.bind(vk, ko * 2 + ki) + tir.reads( + [ + wmma_A[ + vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16 + ], + wmma_B[ + vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16 + ], + wmma_C[ + vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16 + ], + ] + ) + tir.writes( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16] + ) + wmma_A1 = tir.match_buffer( + wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="wmma.matrix_a", + offset_factor=1, + ) + wmma_B1 = tir.match_buffer( + wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="wmma.matrix_b", + offset_factor=1, + ) + wmma_C1 = tir.match_buffer( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[16 * 4, 1], + scope="wmma.accumulator", + offset_factor=1, + ) + tir.evaluate( + tir.tvm_mma_sync( + wmma_C1.data, + i * 4 + j, + wmma_A1.data, + i, + wmma_B1.data, + j, + wmma_C1.data, + i * 4 + j, + dtype="handle", + ) + ) + for i, j in tir.grid(2, 4): + with tir.block([64, 64]) as [vi, vj]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vj, by * 8 + tz * 4 + j) + tir.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + s0 = tir.var("int32") + s1 = tir.var("int32") + wmma_C2 = tir.match_buffer( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[16 * 4, 1], + scope="wmma.accumulator", + offset_factor=1, + ) + C1 = tir.match_buffer( + C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[s0, s1], + offset_factor=1, + ) + tir.evaluate( + tir.tvm_store_matrix_sync( + wmma_C2.data, + 16, + 16, + 16, + i * 4 + j, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float32"), + C1.data, + C1.elem_offset, + C1.strides[0], + 1, + dtype="handle", + ), + C1.strides[0], + "row_major", + dtype="handle", + ) + ) + + +@tvm.testing.requires_cuda +def test_gemm_tensorcore(): + dev = tvm.device("cuda", 0) + a_np = np.random.uniform(size=(1024, 1024)).astype("float16") + b_np = np.random.uniform(size=(1024, 1024)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.T.astype("float32")) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev) + f = tvm.build(tensorcore_gemm, target="cuda", name="dense") + f(a, b, c) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) + + evaluator = f.time_evaluator(f.entry_name, dev, number=100) + t = evaluator(a, b, c).mean + num_flops = 2 * 1024 * 1024 * 1024 + gflops = num_flops / (t * 1e3) / 1e6 + print("gemm with tensor core: %f ms" % (t * 1e3)) + print("GFLOPS: %f" % gflops) + + +if __name__ == "__main__": + test_gemm_tensorcore() diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 836ff4b22b20..e5ac85b115aa 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -15,28 +15,94 @@ # specific language governing permissions and limitations # under the License. +import datetime +import itertools +import json +import logging import os -import io -import struct -import numpy as np import pathlib +import platform import shutil import subprocess -import tempfile import tarfile -import json +from typing import NamedTuple, Union, Optional, List, Dict +import pytest +import numpy as np import tvm from tvm import relay -from tvm.relay import transform from tvm.contrib import utils, graph_executor from tvm.relay.backend import compile_engine from tvm.relay.backend.utils import mangle_module_name -from tvm.contrib import utils from tvm.micro import export_model_library_format +_LOG = logging.getLogger(__name__) + +AOT_SUCCESS_TOKEN = "AOT_TEST_SUCCESS" +AOT_FAILURE_TOKEN = "AOT_TEST_FAILURE" + + +class AOTTestModel(NamedTuple): + """Class to describe a model under test + + Parameters + ---------- + module: tvm.IRModule + IRModule to generate AOT executor for + inputs: Dict[str, np.array] + Dict of input names to value arrays + outputs: List[np.array] + Ordered list of output value arrays + name: str + Name to use for this model + params: Optional[Dict[str, np.array]] + Dict of parameter names to value arrays + """ + + module: tvm.IRModule + inputs: Dict[str, np.array] + outputs: List[np.array] + name: str = "default" + params: Optional[Dict[str, np.array]] = None + + +class AOTTestRunner(NamedTuple): + """Class to describe a test runner for AOT code + + Parameters + ---------- + makefile: str + Premade Makefile to use from the AOT test folder + prologue: str + Code to prepend to the main function + includes: List[str] + Additional includes required to run the AOT test runner + parameters: Map[str, str] + Additional parameters to pass to the make command + """ + + makefile: str = "default" + prologue: str = "" + includes: List[str] = [] + parameters: Dict[str, str] = {} + + +AOT_DEFAULT_RUNNER = AOTTestRunner() + +# AOT Test Runner using the Arm® Corstone™-300 Reference Systems +# see: https://developer.arm.com/ip-products/subsystem/corstone/corstone-300 +AOT_CORSTONE300_RUNNER = AOTTestRunner( + makefile="corstone300", + prologue=""" + uart_init(); + """, + includes=["uart.h"], + parameters={"NPU_VARIANT": "256"}, +) + + def mangle_name(mod_name, name): mod_name = mangle_module_name(mod_name) return mod_name + "_" + name @@ -82,32 +148,83 @@ def convert_to_list(x): return mod, params -def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout): +def parametrize_aot_options(test): + """Parametrize over valid option combinations""" + + skip_i386 = pytest.mark.skipif( + platform.machine() == "i686", reason="Reference system unavailable in i386 container" + ) + interface_api = ["packed", "c"] + use_unpacked_api = [True, False] + test_runner = [AOT_DEFAULT_RUNNER, AOT_CORSTONE300_RUNNER] + + all_combinations = itertools.product(interface_api, use_unpacked_api, test_runner) + + # Filter out packed operators with c interface + valid_combinations = filter( + lambda parameters: not (parameters[0] == "c" and not parameters[1]), + all_combinations, + ) + + # Only use reference system for C interface and unpacked API calls + valid_combinations = filter( + lambda parameters: not ( + parameters[2] == AOT_CORSTONE300_RUNNER + and (parameters[0] == "packed" or not parameters[1]) + ), + valid_combinations, + ) + + # Skip reference system tests if running in i386 container + marked_combinations = map( + lambda parameters: pytest.param(*parameters, marks=skip_i386) + if parameters[2] == AOT_CORSTONE300_RUNNER + else parameters, + valid_combinations, + ) + + return pytest.mark.parametrize( + ["interface_api", "use_unpacked_api", "test_runner"], + marked_combinations, + )(test) + + +def subprocess_log_output(cmd, cwd, logfile): """ This method runs a process and logs the output to both a log file and stdout """ - with subprocess.Popen( + _LOG.info("Execute (%s): %s", cwd, cmd) + cmd_base = cmd[0] if isinstance(cmd, (list, tuple)) else cmd.split(" ", 1)[0] + proc = subprocess.Popen( cmd, cwd=cwd, shell=True, bufsize=0, stdout=subprocess.PIPE, stderr=subprocess.STDOUT - ) as proc, open(logfile, "a") as f: + ) + with open(logfile, "ab") as f: + f.write( + bytes( + "\n" + + "-" * 80 + + f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}: Execute ({cwd}): {cmd}\n" + + "-" * 80, + "utf-8", + ) + ) while True: data = proc.stdout.readline() - result = proc.poll() - # process is done if there is no data and the result is valid - if data == b"" and result is not None: - return int(result) - if data: - text = data.decode("ascii", errors="backslashreplace") - f.write(text) - if stdout: - print(text, end="") + _LOG.debug("%s: %s", cmd_base, str(data, "utf-8", "replace").rstrip("\n")) + f.write(data) + # process is done if there is no data and the result is valid + if not data: # EOF + break -def emit_main_network_definition(main_file, mod_name): - main_file.write(f'extern tvm_model_t {mangle_name(mod_name,"network")};\n') + return proc.wait() -def emit_main_prologue(main_file, workspace_bytes): - main_file.write(f"#define WORKSPACE_SIZE ({workspace_bytes})\n") +def emit_main_prologue(main_file, custom_prologue, workspace_bytes): + # Add TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES because of memory alignment. + main_file.write( + f"#define WORKSPACE_SIZE ({workspace_bytes} + TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES)\n" + ) main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n") main_file.write("tvm_workspace_t app_workspace;\n") main_file.write( @@ -120,100 +237,190 @@ def emit_main_prologue(main_file, workspace_bytes): return StackMemoryManager_Free(&app_workspace,ptr); } -void TVMPlatformAbort(tvm_crt_error_t code) { } +void TVMPlatformAbort(tvm_crt_error_t code) { } void TVMLogf(const char* msg, ...) { } TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {} int main(){\n - - """ +""" ) + main_file.write(custom_prologue) -def emit_main_data(main_file, input_list, output_list, mod_name): - for i in range(0, len(input_list)): - main_file.write(f'#include "{mangle_name(mod_name,"input_data")}{i}.h"\n') +def emit_main_data(main_file, input_map, output_list, mod_name): + for key in input_map: + main_file.write(f'#include "{mangle_name(mod_name,"input_data")}_{key}.h"\n') for i in range(0, len(output_list)): main_file.write(f'#include "{mangle_name(mod_name,"expected_output_data")}{i}.h"\n') main_file.write(f'#include "{mangle_name(mod_name,"output_data")}{i}.h"\n') -def emit_main_run(main_file, input_list, output_list, mod_name): +def emit_main_data_structs(main_file, input_map, output_list, mod_name): + main_file.write( + f"struct {mangle_name(mod_name, 'inputs')} {mangle_name(mod_name, 'inputs')} = {{" + ) + for key in input_map: + main_file.write(f"\t.{key} = {mangle_name(mod_name, 'input_data')}_{key},\n") + main_file.write("};\n") + + main_file.write( + f"struct {mangle_name(mod_name, 'outputs')} {mangle_name(mod_name, 'outputs')} = {{" + ) + num_outputs = len(output_list) + if num_outputs == 1: + main_file.write(f"\t.output = {mangle_name(mod_name, 'output_data')}0,\n") + else: + for i in range(0, num_outputs): + main_file.write(f"\t.output{i} = {mangle_name(mod_name, 'output_data')}{i},\n") + main_file.write("};\n") + + +def emit_main_data_setup(main_file, input_map, output_list, mod_name): num_outputs = len(output_list) - num_inputs = len(input_list) + num_inputs = len(input_map) main_file.write(f'void* {mangle_name(mod_name,"inputs")}[{num_inputs}] = {{ ') - - for i in range(0, len(input_list)): - main_file.write(f'{mangle_name(mod_name,"input_data")}{i}, ') + for key in input_map: + main_file.write(f'{mangle_name(mod_name,"input_data")}_{key}, ') main_file.write("};\n") main_file.write(f'void* {mangle_name(mod_name,"outputs")}[{num_outputs}] = {{ ') - for i in range(0, len(output_list)): + for i in range(0, num_outputs): main_file.write(f'{mangle_name(mod_name,"output_data")}{i}, ') main_file.write("};\n") + + +def emit_main_c_interface_call(main_file, mod_name): main_file.write( - f'tvm_runtime_run(&{mangle_name(mod_name,"network")}, {mangle_name(mod_name,"inputs")}, {mangle_name(mod_name,"outputs")});' + f'{mangle_name(mod_name,"run")}(&{mangle_name(mod_name,"inputs")}, &{mangle_name(mod_name,"outputs")});\n' ) +def emit_main_fake_packed_values(main_file): + main_file.write( + """ + static DLDevice fake_device = {kDLCPU, 0}; + static int64_t fake_dims = 0; + static int64_t fake_shape = {0}; + """ + ) + + +def emit_main_packed_call(main_file, input_map, output_list, mod_name): + tensors_name = mangle_name(mod_name, "tensors") + values_name = mangle_name(mod_name, "values") + typeids_name = mangle_name(mod_name, "typeids") + + def fake_tensor(source, source_index, packed_index): + main_file.write( + f""" + {tensors_name}[{packed_index}].device = fake_device; + {tensors_name}[{packed_index}].data = {source}[{source_index}]; + {tensors_name}[{packed_index}].shape = &fake_shape; + {tensors_name}[{packed_index}].ndim = fake_dims; + {tensors_name}[{packed_index}].byte_offset = 0; + {tensors_name}[{packed_index}].strides = NULL; + {values_name}[{packed_index}].v_handle = &{tensors_name}[{packed_index}]; + """ + ) + + num_outputs = len(output_list) + num_inputs = len(input_map) + num_tensors = num_inputs + num_outputs + main_file.write( + f""" + DLTensor {tensors_name}[{num_tensors}]; + TVMValue {values_name}[{num_tensors}]; + int32_t {typeids_name}[{num_tensors}]; + """ + ) + + for i in range(0, num_inputs): + fake_tensor(mangle_name(mod_name, "inputs"), i, i) + for i in range(0, num_outputs): + fake_tensor(mangle_name(mod_name, "outputs"), i, i + num_inputs) + + main_file.write( + f'{mangle_name(mod_name, "run")}({values_name}, {typeids_name}, 0, NULL, 0, NULL);\n' + ) + main_file.write("\n") + + def emit_main_compare(main_file, output_list, mod_name): - for i in range(0, len(output_list)): + num_outputs = len(output_list) + actual_data_name = mangle_name(mod_name, "output_data") + expected_data_name = mangle_name(mod_name, "expected_output_data") + + for i in range(0, num_outputs): is_float_dtype = output_list[i].dtype == "float32" - main_file.write(f'for (int i = 0; i<{mangle_name(mod_name,"output_data")}{i}_len; i++){{\n') + main_file.write(f"for (int i = 0; i<{actual_data_name}{i}_len; i++){{\n") if is_float_dtype: main_file.write( - f'if (fabs({mangle_name(mod_name,"output_data")}{i}[i]-{mangle_name(mod_name,"expected_output_data")}{i}[i]) > 0.001f){{printf("ko\\n");return -1;}}\n' + f'if (fabs({actual_data_name}{i}[i]-{expected_data_name}{i}[i]) > 0.001f){{\n\tprintf("{AOT_FAILURE_TOKEN}\\n");\n\treturn -1;}}\n' ) else: main_file.write( - f'if ({mangle_name(mod_name,"output_data")}{i}[i]!={mangle_name(mod_name, "expected_output_data")}{i}[i]){{printf("ko\\n");return -1;}}\n' + f'if ({actual_data_name}{i}[i]!={expected_data_name}{i}[i]){{\n\tprintf("{AOT_FAILURE_TOKEN}\\n");\n\treturn -1;}}\n' ) main_file.write("}\n") def emit_main_init_memory_manager(main_file): main_file.write("StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE);") + main_file.write("\n") def emit_main_epilogue(main_file): - main_file.write('printf("ok\\n");') + main_file.write(f'printf("{AOT_SUCCESS_TOKEN}\\n");') main_file.write("return 0;") main_file.write("}\n") -def emit_main_common_includes(main_file): +def emit_main_common_includes(main_file, custom_includes): main_file.write("#include \n") main_file.write("#include \n") - main_file.write('#include "tvm/runtime/crt/internal/aot_executor/aot_executor.h"\n') + main_file.write('#include "tvm/runtime/c_runtime_api.h"\n') main_file.write('#include "tvm/runtime/crt/stack_allocator.h"\n') + for include in custom_includes: + main_file.write(f'#include "{include}"\n') + +def emit_main_micro_include(main_file, mod_name): + main_file.write(f"#include <{mangle_module_name(mod_name)}.h>\n") -def create_main(test_name, input_list_map, output_list_map, output_path, workspace_bytes): + +def create_main( + test_name, models, output_path, custom_includes, custom_prologue, interface_api, workspace_bytes +): file_path = pathlib.Path(f"{output_path}/" + test_name).resolve() # create header file raw_path = file_path.with_suffix(".c").resolve() with open(raw_path, "w") as main_file: - emit_main_common_includes(main_file) - - for k in input_list_map: - emit_main_network_definition(main_file, k) - - emit_main_prologue(main_file, workspace_bytes) + emit_main_common_includes(main_file, custom_includes) - for k in input_list_map: - emit_main_data(main_file, input_list_map[k], output_list_map[k], k) + if interface_api == "c": + for model in models: + emit_main_micro_include(main_file, model.name) + for model in models: + emit_main_data(main_file, model.inputs, model.outputs, model.name) + emit_main_prologue(main_file, custom_prologue, workspace_bytes) emit_main_init_memory_manager(main_file) - for k in input_list_map: - emit_main_run(main_file, input_list_map[k], output_list_map[k], k) - - for k in input_list_map: - emit_main_compare(main_file, output_list_map[k], k) + if interface_api == "c": + for model in models: + emit_main_data_structs(main_file, model.inputs, model.outputs, model.name) + emit_main_c_interface_call(main_file, model.name) + else: + emit_main_fake_packed_values(main_file) + for model in models: + emit_main_data_setup(main_file, model.inputs, model.outputs, model.name) + emit_main_packed_call(main_file, model.inputs, model.outputs, model.name) + for model in models: + emit_main_compare(main_file, model.outputs, model.name) emit_main_epilogue(main_file) @@ -246,40 +453,40 @@ def create_header_file(tensor_name, npy_data, output_path): header_file.write("};\n\n") -def extract_main_workspace_sizebytes(extract_dir): +def extract_main_workspace_size_bytes(extract_dir): with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) return metadata["memory"]["functions"]["main"][0]["workspace_size_bytes"] def compile_and_run( - mod, - input_list, - output_list, - target_options, - use_calculated_workspaces, - params=None, + models: Union[List[AOTTestModel], AOTTestModel], + runner: AOTTestRunner, + interface_api, + use_unpacked_api, + debug_calculated_workspaces=False, workspace_byte_alignment=8, - mod_name=None, enable_op_fusion=True, ): """ This method verifies the generated source """ - target = f"c -runtime=c --link-params --executor=aot --workspace-byte-alignment={workspace_byte_alignment} {target_options}" + base_target = "c -runtime=c --link-params --executor=aot" + extra_target = f"--workspace-byte-alignment={workspace_byte_alignment} --interface-api={interface_api} --unpacked-api={int(use_unpacked_api)}" + target = f"{base_target} {extra_target}" cflags = f"-DTVM_RUNTIME_ALLOC_ALIGNMENT_BYTES={workspace_byte_alignment} " + if not isinstance(models, list): + models = [models] + # The calculated workspaces will not account for stack allocator tags used for debugging - if not use_calculated_workspaces: + if debug_calculated_workspaces: cflags += "-DTVM_CRT_STACK_ALLOCATOR_ENABLE_LIFO_CHECK " config = {"tir.disable_vectorize": True} if not enable_op_fusion: config["relay.FuseOps.max_depth"] = 1 - with tvm.transform.PassContext(opt_level=3, config=config): - lib = tvm.relay.build(mod, target, target_host=target, params=params, mod_name=mod_name) - tmp_path = utils.tempdir() tmp_dir = tmp_path.temp_dir @@ -287,110 +494,90 @@ def compile_and_run( build_path = os.path.join(base_path, "build") os.makedirs(build_path, exist_ok=True) - tar_file = os.path.join(base_path, "test.tar") - export_model_library_format(lib, tar_file) - t = tarfile.open(tar_file) - t.extractall(base_path) - if use_calculated_workspaces: - workspace_bytes = extract_main_workspace_sizebytes(base_path) - else: - workspace_bytes = 16384 * 1024 - - for i in range(len(input_list)): - create_header_file((f'{mangle_name(mod_name, "input_data")}{i}'), input_list[i], build_path) - - for i in range(len(output_list)): - create_header_file( - (f'{mangle_name(mod_name,"output_data")}{i}'), - np.zeros(output_list[i].shape, output_list[i].dtype), - build_path, - ) - create_header_file( - (f'{mangle_name(mod_name, "expected_output_data")}{i}'), output_list[i], build_path - ) - - create_main( - "test.c", {mod_name: input_list}, {mod_name: output_list}, build_path, workspace_bytes + include_path = os.path.join(base_path, "include") + os.mkdir(include_path) + crt_root = tvm.micro.get_standalone_crt_dir() + shutil.copy2( + os.path.join(crt_root, "template", "crt_config-template.h"), + os.path.join(include_path, "crt_config.h"), ) - # Verify that compiles fine - file_dir = os.path.dirname(os.path.abspath(__file__)) - makefile = os.path.join(file_dir, "aot_test.mk") - make_cmd = ( - f"make CFLAGS='{cflags}' -f {makefile} build_dir=" - + build_path - + f" TVM_ROOT={file_dir}/../../../.." - ) - - compile_log_path = os.path.join(build_path, "test_compile.log") - ret = subprocess_with_stdout_and_log(make_cmd, ".", compile_log_path, False) - assert ret == 0 - - # Verify that runs fine - run_log_path = os.path.join(build_path, "test_run.log") - ret = subprocess_with_stdout_and_log("./aot_test_runner", build_path, run_log_path, False) - assert ret == 0 - - -def compile_and_run_multiple_models( - mod_map, input_list_map, output_list_map, target_options, param_map -): - """ - This method verifies the generated source - """ - target = f"c -runtime=c --link-params --executor=aot {target_options}" - tmp_path = utils.tempdir() - tmp_dir = tmp_path.temp_dir - - base_path = os.path.join(tmp_dir, "test") - build_path = os.path.join(base_path, "build") - os.makedirs(build_path, exist_ok=True) - for mod_name, mod in mod_map.items(): - - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + workspace_bytes = 0 + for model in models: + with tvm.transform.PassContext(opt_level=3, config=config): lib = tvm.relay.build( - mod, target, target_host=target, params=param_map[mod_name], mod_name=mod_name + model.module, + target, + target_host=target, + params=model.params, + mod_name=model.name, ) - tar_file = os.path.join(base_path, "test.tar") + tar_file = os.path.join(base_path, f"{model.name}.tar") export_model_library_format(lib, tar_file) t = tarfile.open(tar_file) t.extractall(base_path) - input_list = input_list_map[mod_name] - output_list = output_list_map[mod_name] + workspace_bytes += extract_main_workspace_size_bytes(base_path) - for i in range(len(input_list_map[mod_name])): + for key in model.inputs: create_header_file( - (f'{mangle_name(mod_name,"input_data")}{i}'), input_list[i], build_path + f'{mangle_name(model.name, "input_data")}_{key}', + model.inputs[key], + include_path, ) - for i in range(len(output_list_map[mod_name])): + for i in range(len(model.outputs)): create_header_file( - (f'{mangle_name(mod_name,"output_data")}{i}'), - np.zeros(output_list[i].shape, output_list[i].dtype), - build_path, + (f'{mangle_name(model.name,"output_data")}{i}'), + np.zeros(model.outputs[i].shape, model.outputs[i].dtype), + include_path, ) create_header_file( - (f'{mangle_name(mod_name,"expected_output_data")}{i}'), output_list[i], build_path + (f'{mangle_name(model.name, "expected_output_data")}{i}'), + model.outputs[i], + include_path, ) - create_main("test.c", input_list_map, output_list_map, build_path, workspace_bytes=16384 * 1024) + create_main( + "test.c", + models, + build_path, + runner.includes, + runner.prologue, + interface_api, + workspace_bytes, + ) # Verify that compiles fine file_dir = os.path.dirname(os.path.abspath(__file__)) - makefile = os.path.join(file_dir, "aot_test.mk") - make_cmd = f"make -f {makefile} build_dir=" + build_path + f" TVM_ROOT={file_dir}/../../../.." + codegen_path = os.path.join(base_path, "codegen") + makefile = os.path.join(file_dir, f"{runner.makefile}.mk") + custom_params = " ".join([f" {param}='{value}'" for param, value in runner.parameters.items()]) + make_command = ( + f"make -f {makefile} build_dir={build_path}" + + f" CFLAGS='{cflags}'" + + f" TVM_ROOT={file_dir}/../../../.." + + f" AOT_TEST_ROOT={file_dir}" + + f" CODEGEN_ROOT={codegen_path}" + + f" STANDALONE_CRT_DIR={tvm.micro.get_standalone_crt_dir()}" + + custom_params + ) compile_log_path = os.path.join(build_path, "test_compile.log") - ret = subprocess_with_stdout_and_log(make_cmd, ".", compile_log_path, False) + compile_command = f"{make_command} aot_test_runner" + ret = subprocess_log_output(compile_command, ".", compile_log_path) assert ret == 0 # Verify that runs fine run_log_path = os.path.join(build_path, "test_run.log") - ret = subprocess_with_stdout_and_log("./aot_test_runner", build_path, run_log_path, False) + run_command = f"{make_command} run" + ret = subprocess_log_output(run_command, build_path, run_log_path) assert ret == 0 + with open(run_log_path) as run_log: + assert AOT_SUCCESS_TOKEN in run_log.read() + def generate_ref_data(mod, input_data, params=None, target="llvm"): """Generate reference data through executing the relay module""" diff --git a/tests/python/relay/aot/corstone300.ld b/tests/python/relay/aot/corstone300.ld new file mode 100644 index 000000000000..4a6b22480d9f --- /dev/null +++ b/tests/python/relay/aot/corstone300.ld @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*------------------ Reference System Memories ------------- + +===================+============+=======+============+============+ + | Memory | Address | Size | CPU Access | NPU Access | + +===================+============+=======+============+============+ + | ITCM | 0x00000000 | 512KB | Yes (RO) | No | + +-------------------+------------+-------+------------+------------+ + | DTCM | 0x20000000 | 512KB | Yes (R/W) | No | + +-------------------+------------+-------+------------+------------+ + | SSE-300 SRAM | 0x21000000 | 2MB | Yes (R/W) | Yes (R/W) | + +-------------------+------------+-------+------------+------------+ + | Data SRAM | 0x01000000 | 2MB | Yes (R/W) | Yes (R/W) | + +-------------------+------------+-------+------------+------------+ + | DDR | 0x60000000 | 32MB | Yes (R/W) | Yes (R/W) | + +-------------------+------------+-------+------------+------------+ */ + +/*---------------------- ITCM Configuration ---------------------------------- + Flash Configuration + Flash Base Address <0x0-0xFFFFFFFF:8> + Flash Size (in Bytes) <0x0-0xFFFFFFFF:8> + + -----------------------------------------------------------------------------*/ +__ROM_BASE = 0x00000000; +__ROM_SIZE = 0x00080000; + +/*--------------------- DTCM RAM Configuration ---------------------------- + RAM Configuration + RAM Base Address <0x0-0xFFFFFFFF:8> + RAM Size (in Bytes) <0x0-0xFFFFFFFF:8> + + -----------------------------------------------------------------------------*/ +__RAM_BASE = 0x20000000; +__RAM_SIZE = 0x00080000; + +/*----------------------- Data SRAM Configuration ------------------------------ + Data SRAM Configuration + DATA_SRAM Base Address <0x0-0xFFFFFFFF:8> + DATA_SRAM Size (in Bytes) <0x0-0xFFFFFFFF:8> + + -----------------------------------------------------------------------------*/ +__DATA_SRAM_BASE = 0x01000000; +__DATA_SRAM_SIZE = 0x00200000; + +/*--------------------- Embedded SRAM Configuration ---------------------------- + SRAM Configuration + SRAM Base Address <0x0-0xFFFFFFFF:8> + SRAM Size (in Bytes) <0x0-0xFFFFFFFF:8> + + -----------------------------------------------------------------------------*/ +__SRAM_BASE = 0x21000000; +__SRAM_SIZE = 0x00200000; + +/*--------------------- Stack / Heap Configuration ---------------------------- + Stack / Heap Configuration + Stack Size (in Bytes) <0x0-0xFFFFFFFF:8> + Heap Size (in Bytes) <0x0-0xFFFFFFFF:8> + + -----------------------------------------------------------------------------*/ +__STACK_SIZE = 0x00008000; +__HEAP_SIZE = 0x00008000; + +/*--------------------- Embedded RAM Configuration ---------------------------- + DDR Configuration + DDR Base Address <0x0-0xFFFFFFFF:8> + DDR Size (in Bytes) <0x0-0xFFFFFFFF:8> + + -----------------------------------------------------------------------------*/ +__DDR_BASE = 0x60000000; +__DDR_SIZE = 0x02000000; + +/* + *-------------------- <<< end of configuration section >>> ------------------- + */ + +MEMORY +{ + ITCM (rx) : ORIGIN = __ROM_BASE, LENGTH = __ROM_SIZE + DTCM (rwx) : ORIGIN = __RAM_BASE, LENGTH = __RAM_SIZE + DATA_SRAM (rwx) : ORIGIN = __DATA_SRAM_BASE, LENGTH = __DATA_SRAM_SIZE + SRAM (rwx) : ORIGIN = __SRAM_BASE, LENGTH = __SRAM_SIZE + DDR (rwx) : ORIGIN = __DDR_BASE, LENGTH = __DDR_SIZE +} + +/* Linker script to place sections and symbol values. Should be used together + * with other linker script that defines memory regions ITCM and RAM. + * It references following symbols, which must be defined in code: + * Reset_Handler : Entry of reset handler + * + * It defines following symbols, which code can use without definition: + * __exidx_start + * __exidx_end + * __copy_table_start__ + * __copy_table_end__ + * __zero_table_start__ + * __zero_table_end__ + * __etext + * __data_start__ + * __preinit_array_start + * __preinit_array_end + * __init_array_start + * __init_array_end + * __fini_array_start + * __fini_array_end + * __data_end__ + * __bss_start__ + * __bss_end__ + * __end__ + * end + * __HeapLimit + * __StackLimit + * __StackTop + * __stack + */ +ENTRY(Reset_Handler) + +SECTIONS +{ + .text : + { + KEEP(*(.vectors)) + *(.text*) + + KEEP(*(.init)) + KEEP(*(.fini)) + + /* .ctors */ + *crtbegin.o(.ctors) + *crtbegin?.o(.ctors) + *(EXCLUDE_FILE(*crtend?.o *crtend.o) .ctors) + *(SORT(.ctors.*)) + *(.ctors) + + /* .dtors */ + *crtbegin.o(.dtors) + *crtbegin?.o(.dtors) + *(EXCLUDE_FILE(*crtend?.o *crtend.o) .dtors) + *(SORT(.dtors.*)) + *(.dtors) + + *(.rodata*) + + KEEP(*(.eh_frame*)) + } > ITCM + + .ARM.extab : + { + *(.ARM.extab* .gnu.linkonce.armextab.*) + } > ITCM + + __exidx_start = .; + .ARM.exidx : + { + *(.ARM.exidx* .gnu.linkonce.armexidx.*) + } > ITCM + __exidx_end = .; + + .copy.table : + { + . = ALIGN(4); + __copy_table_start__ = .; + LONG (__etext) + LONG (__data_start__) + LONG (__data_end__ - __data_start__) + /* Add each additional data section here */ + __copy_table_end__ = .; + } > ITCM + + .zero.table : + { + . = ALIGN(4); + __zero_table_start__ = .; + __zero_table_end__ = .; + } > ITCM + + /** + * Location counter can end up 2byte aligned with narrow Thumb code but + * __etext is assumed by startup code to be the LMA of a section in DTCM + * which must be 4byte aligned + */ + __etext = ALIGN (4); + + .data : AT (__etext) + { + __data_start__ = .; + *(vtable) + *(.data) + *(.data.*) + + . = ALIGN(4); + /* preinit data */ + PROVIDE_HIDDEN (__preinit_array_start = .); + KEEP(*(.preinit_array)) + PROVIDE_HIDDEN (__preinit_array_end = .); + + . = ALIGN(4); + /* init data */ + PROVIDE_HIDDEN (__init_array_start = .); + KEEP(*(SORT(.init_array.*))) + KEEP(*(.init_array)) + PROVIDE_HIDDEN (__init_array_end = .); + + + . = ALIGN(4); + /* finit data */ + PROVIDE_HIDDEN (__fini_array_start = .); + KEEP(*(SORT(.fini_array.*))) + KEEP(*(.fini_array)) + PROVIDE_HIDDEN (__fini_array_end = .); + + KEEP(*(.jcr*)) + . = ALIGN(4); + /* All data end */ + __data_end__ = .; + + } > DTCM + + .sram : + { + . = ALIGN(16); + *(.bss.ethosu_fast_memory); + . = ALIGN(16); + } > SRAM AT > SRAM + + .bss.NoInit : + { + . = ALIGN(16); + *(.bss.NoInit) + . = ALIGN(16); + } > DDR AT > DDR + + .bss : + { + . = ALIGN(4); + __bss_start__ = .; + *(.bss) + *(.bss.*) + *(COMMON) + . = ALIGN(4); + __bss_end__ = .; + } > DTCM AT > DTCM + + .data_sram : + { + . = ALIGN(16); + } > DATA_SRAM + + .heap (COPY) : + { + . = ALIGN(8); + __end__ = .; + PROVIDE(end = .); + . = . + __HEAP_SIZE; + . = ALIGN(8); + __HeapLimit = .; + } > DTCM + + .stack (ORIGIN(DTCM) + LENGTH(DTCM) - __STACK_SIZE) (COPY) : + { + . = ALIGN(8); + __StackLimit = .; + . = . + __STACK_SIZE; + . = ALIGN(8); + __StackTop = .; + } > DTCM + PROVIDE(__stack = __StackTop); + + /* Check if data + stack exceeds DTCM limit */ + ASSERT(__StackLimit >= __bss_end__, "region DTCM overflowed with stack") +} diff --git a/tests/python/relay/aot/corstone300.mk b/tests/python/relay/aot/corstone300.mk new file mode 100644 index 000000000000..3a946f2cd876 --- /dev/null +++ b/tests/python/relay/aot/corstone300.mk @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Makefile to build and run AOT tests against the reference system + +# Setup build environment +build_dir := build +TVM_ROOT=$(shell cd ../../../../..; pwd) +CRT_ROOT ?= ${TVM_ROOT}/build/standalone_crt +ifeq ($(shell ls -lhd $(CRT_ROOT)),) +$(error "CRT not found. Ensure you have built the standalone_crt target and try again") +endif + +ARM_CPU=ARMCM55 +DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core +ETHOSU_PATH=/opt/arm/ethosu +CMSIS_PATH=${ETHOSU_PATH}/cmsis +PLATFORM_PATH=${ETHOSU_PATH}/core_platform/targets/corstone-300 +PKG_COMPILE_OPTS = -g -Wall -O2 -Wno-incompatible-pointer-types -Wno-format -mcpu=cortex-m55 -mthumb -mfloat-abi=hard -std=gnu99 +CC = arm-none-eabi-gcc +AR = arm-none-eabi-ar +RANLIB = arm-none-eabi-ranlib +CC_OPTS = CC=$(CC) AR=$(AR) RANLIB=$(RANLIB) +PKG_CFLAGS = ${PKG_COMPILE_OPTS} \ + ${CFLAGS} \ + -I$(build_dir)/../include \ + -I$(CODEGEN_ROOT)/host/include \ + -I${PLATFORM_PATH} \ + -I${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Include/ \ + -I${CMSIS_PATH}/CMSIS/Core/Include \ + -I${CMSIS_PATH}/CMSIS/NN/Include \ + -I${CMSIS_PATH}/CMSIS/DSP/Include \ + -isystem$(STANDALONE_CRT_DIR)/include \ + +PKG_LDFLAGS = -lm -specs=nosys.specs -static -T ${AOT_TEST_ROOT}/corstone300.ld + +$(ifeq VERBOSE,1) +QUIET ?= +$(else) +QUIET ?= @ +$(endif) + +CRT_SRCS = $(shell find $(CRT_ROOT)) +CODEGEN_SRCS = $(shell find $(abspath $(CODEGEN_ROOT)/host/src/*.c)) +CODEGEN_OBJS = $(subst .c,.o,$(CODEGEN_SRCS)) +CMSIS_STARTUP_SRCS = $(shell find ${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Source/*.c) +CMSIS_NN_SRCS = $(shell find ${CMSIS_PATH}/CMSIS/NN/Source/*/*.c) +UART_SRCS = $(shell find ${PLATFORM_PATH}/*.c) + +aot_test_runner: $(build_dir)/aot_test_runner + +$(build_dir)/stack_allocator.o: $(TVM_ROOT)/src/runtime/crt/memory/stack_allocator.c + $(QUIET)mkdir -p $(@D) + $(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^ + +$(build_dir)/crt_backend_api.o: $(TVM_ROOT)/src/runtime/crt/common/crt_backend_api.c + $(QUIET)mkdir -p $(@D) + $(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^ + +$(build_dir)/libcodegen.a: $(CODEGEN_SRCS) + $(QUIET)cd $(abspath $(CODEGEN_ROOT)/host/src) && $(CC) -c $(PKG_CFLAGS) $(CODEGEN_SRCS) + $(QUIET)$(AR) -cr $(abspath $(build_dir)/libcodegen.a) $(CODEGEN_OBJS) + $(QUIET)$(RANLIB) $(abspath $(build_dir)/libcodegen.a) + +${build_dir}/libcmsis_startup.a: $(CMSIS_STARTUP_SRCS) + $(QUIET)mkdir -p $(abspath $(build_dir)/libcmsis_startup) + $(QUIET)cd $(abspath $(build_dir)/libcmsis_startup) && $(CC) -c $(PKG_CFLAGS) -D${ARM_CPU} $^ + $(QUIET)$(AR) -cr $(abspath $(build_dir)/libcmsis_startup.a) $(abspath $(build_dir))/libcmsis_startup/*.o + $(QUIET)$(RANLIB) $(abspath $(build_dir)/libcmsis_startup.a) + +${build_dir}/libcmsis_nn.a: $(CMSIS_NN_SRCS) + $(QUIET)mkdir -p $(abspath $(build_dir)/libcmsis_nn) + $(QUIET)cd $(abspath $(build_dir)/libcmsis_nn) && $(CC) -c $(PKG_CFLAGS) -D${ARM_CPU} $^ + $(QUIET)$(AR) -cr $(abspath $(build_dir)/libcmsis_nn.a) $(abspath $(build_dir))/libcmsis_nn/*.o + $(QUIET)$(RANLIB) $(abspath $(build_dir)/libcmsis_nn.a) + +${build_dir}/libuart.a: $(UART_SRCS) + $(QUIET)mkdir -p $(abspath $(build_dir)/libuart) + $(QUIET)cd $(abspath $(build_dir)/libuart) && $(CC) -c $(PKG_CFLAGS) $^ + $(QUIET)$(AR) -cr $(abspath $(build_dir)/libuart.a) $(abspath $(build_dir))/libuart/*.o + $(QUIET)$(RANLIB) $(abspath $(build_dir)/libuart.a) + +$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/crt_backend_api.o $(build_dir)/stack_allocator.o ${build_dir}/libcmsis_startup.a ${build_dir}/libcmsis_nn.a ${build_dir}/libuart.a $(build_dir)/libcodegen.a + $(QUIET)mkdir -p $(@D) + $(QUIET)$(CC) $(PKG_CFLAGS) -o $@ -Wl,--whole-archive $^ -Wl,--no-whole-archive $(PKG_LDFLAGS) + +clean: + $(QUIET)rm -rf $(build_dir)/crt + +cleanall: + $(QUIET)rm -rf $(build_dir) + +run: $(build_dir)/aot_test_runner + /opt/arm/FVP_Corstone_SSE-300_Ethos-U55/models/Linux64_GCC-6.4/FVP_Corstone_SSE-300_Ethos-U55 -C cpu0.CFGDTCMSZ=15 \ + -C cpu0.CFGITCMSZ=15 -C mps3_board.uart0.out_file=\"-\" -C mps3_board.uart0.shutdown_tag=\"EXITTHESIM\" \ + -C mps3_board.visualisation.disable-visualisation=1 -C mps3_board.telnetterminal0.start_telnet=0 \ + -C mps3_board.telnetterminal1.start_telnet=0 -C mps3_board.telnetterminal2.start_telnet=0 -C mps3_board.telnetterminal5.start_telnet=0 \ + -C ethosu.num_macs=$(NPU_VARIANT) $(build_dir)/aot_test_runner + +.SUFFIXES: + +.DEFAULT: aot_test_runner + +.PHONY: run \ No newline at end of file diff --git a/tests/python/relay/aot/aot_test.mk b/tests/python/relay/aot/default.mk similarity index 68% rename from tests/python/relay/aot/aot_test.mk rename to tests/python/relay/aot/default.mk index 2426d9fd2963..f5edcb3d6422 100644 --- a/tests/python/relay/aot/aot_test.mk +++ b/tests/python/relay/aot/default.mk @@ -16,25 +16,20 @@ # under the License. # Setup build environment # -AOT_ROOT ?= $(TVM_ROOT)/src/runtime/crt/aot +AOT_ROOT ?= $(CRT_ROOT)/aot ENABLE_TVM_PLATFORM_ABORT_BACKTRACE = 0 DMLC_CORE=$(TVM_ROOT)/3rdparty/dmlc-core -PKG_COMPILE_OPTS = -g +PKG_COMPILE_OPTS = -g CC = gcc AR = ar RANLIB = ranlib CC_OPTS = CC=$(CC) AR=$(AR) RANLIB=$(RANLIB) - PKG_CFLAGS = ${PKG_COMPILE_OPTS} \ - -I$(TVM_ROOT)/src/runtime/crt/include \ - -I$(TVM_ROOT)/src/runtime/crt/host \ - -I$(TVM_ROOT)/include \ - -I$(DMLC_CORE)/include \ - -I$(TVM_ROOT)/3rdparty/dlpack/include \ - -I$(AOT_ROOT)\ - -I$(build_dir) + -I$(build_dir)/../include \ + -I$(CODEGEN_ROOT)/host/include \ + -isystem$(STANDALONE_CRT_DIR)/include $(ifeq VERBOSE,1) QUIET ?= @@ -42,14 +37,12 @@ $(else) QUIET ?= @ $(endif) -CRT_SRCS = $(shell find $(CRT_ROOT)) - aot_test_runner: $(build_dir)/aot_test_runner source_libs= $(wildcard $(build_dir)/../codegen/host/src/*.c) -lib_objs =$(source_libs:.c=.o) +lib_objs =$(source_libs:.c=.o) -$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/aot_executor.o $(source_libs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o +$(build_dir)/aot_test_runner: $(build_dir)/test.c $(source_libs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(CFLAGS) $(PKG_CFLAGS) -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS) -lm @@ -57,15 +50,11 @@ $(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.c $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) -$(build_dir)/aot_executor.o: $(TVM_ROOT)/src/runtime/crt/aot_executor/aot_executor.c - $(QUIET)mkdir -p $(@D) - $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) - -$(build_dir)/stack_allocator.o: $(TVM_ROOT)/src/runtime/crt/memory/stack_allocator.c +$(build_dir)/stack_allocator.o: $(STANDALONE_CRT_DIR)/src/runtime/crt/memory/stack_allocator.c $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) -$(build_dir)/crt_backend_api.o: $(TVM_ROOT)/src/runtime/crt/common/crt_backend_api.c +$(build_dir)/crt_backend_api.o: $(STANDALONE_CRT_DIR)/src/runtime/crt/common/crt_backend_api.c $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) @@ -73,5 +62,13 @@ clean: $(QUIET)rm -rf $(build_dir)/crt cleanall: $(QUIET)rm -rf $(build_dir) + +run: $(build_dir)/aot_test_runner + $(build_dir)/aot_test_runner + # Don't define implicit rules; they tend to match on logical target names that aren't targets (i.e. bundle_static) .SUFFIXES: + +.DEFAULT: aot_test_runner + +.PHONY: run diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 13cbfa71b6ae..36cffefcd0bb 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -15,37 +15,48 @@ # specific language governing permissions and limitations # under the License. -import os -import io -import struct +from collections import OrderedDict +import sys + import numpy as np -import pathlib -import shutil -import subprocess -import tempfile -import tarfile import pytest import tvm from tvm import relay -from tvm.relay import transform -from tvm.relay.op.contrib import get_pattern_table -from tvm.contrib import utils -from tvm.relay.backend import compile_engine -from tvm.contrib import utils -from tvm.contrib import graph_executor -from tvm.micro import export_model_library_format -from tvm.relay import testing -from tvm.relay.op.annotation import compiler_begin, compiler_end -from tvm.contrib import utils -from tvm.relay.expr_functor import ExprMutator - -from aot_test_utils import * - - -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_conv_with_params(use_calculated_workspaces, target_options): +from tvm.ir.module import IRModule +from tvm.relay import testing, transform +from tvm.relay.testing import byoc +from aot_test_utils import ( + AOTTestModel, + AOT_DEFAULT_RUNNER, + generate_ref_data, + convert_to_relay, + compile_and_run, + parametrize_aot_options, +) + + +def test_error_c_interface_with_packed_api(): + interface_api = "c" + use_unpacked_api = False + test_runner = AOT_DEFAULT_RUNNER + + two = relay.add(relay.const(1), relay.const(1)) + func = relay.Function([], two) + + with pytest.raises(tvm.TVMError, match="Packed interface required for packed operators"): + compile_and_run( + AOTTestModel( + module=IRModule.from_expr(func), inputs={}, outputs=generate_ref_data(func, {}) + ), + test_runner, + interface_api, + use_unpacked_api, + ) + + +@parametrize_aot_options +def test_conv_with_params(interface_api, use_unpacked_api, test_runner): RELAY_MODEL = """ #[version = "0.0.5"] def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), int8]) { @@ -73,13 +84,16 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), inputs = {"data": input_data} output_list = generate_ref_data(mod, inputs, params) - input_list = [input_data] - compile_and_run(mod, input_list, output_list, target_options, use_calculated_workspaces, params) + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params), + test_runner, + interface_api, + use_unpacked_api, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_add_with_params(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_add_with_params(interface_api, use_unpacked_api, test_runner): x = relay.var("x", shape=(1, 10)) y = relay.var("y", shape=(1, 10)) z = relay.add(x, y) @@ -92,62 +106,48 @@ def test_add_with_params(use_calculated_workspaces, target_options): inputs = {"y": y_in} output_list = generate_ref_data(func, inputs, params) - input_list = [y_in] compile_and_run( - func, input_list, output_list, target_options, use_calculated_workspaces, params + AOTTestModel( + module=IRModule.from_expr(func), inputs=inputs, outputs=output_list, params=params + ), + test_runner, + interface_api, + use_unpacked_api, ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_conv2d(use_calculated_workspaces, target_options): +@parametrize_aot_options +@pytest.mark.parametrize("groups,weight_shape", [(1, 32), (32, 1)]) +def test_conv2d(interface_api, use_unpacked_api, test_runner, groups, weight_shape): """Test a subgraph with a single conv2d operator.""" + dtype = "float32" + ishape = (1, 32, 14, 14) + wshape = (32, weight_shape, 3, 3) - def conv2d_direct(): - dtype = "float32" - ishape = (1, 32, 14, 14) - w1shape = (32, 32, 3, 3) - - data0 = relay.var("data", shape=ishape, dtype=dtype) - weight0 = relay.var("weight", shape=w1shape, dtype=dtype) - out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1)) - main_f = relay.Function([data0, weight0], out) - mod = tvm.IRModule() - mod["main"] = main_f - mod = transform.InferType()(mod) - - i_data = np.random.uniform(0, 1, ishape).astype(dtype) - w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) - - return mod, {"data": i_data, "weight": w1_data}, (1, 32, 14, 14) - - def group_conv2d(): - dtype = "float32" - ishape = (1, 32, 14, 14) - w2shape = (32, 1, 3, 3) - - data0 = relay.var("data", shape=(ishape), dtype=dtype) - weight0 = relay.var("weight", shape=(w2shape), dtype=dtype) - out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=32) - main_f = relay.Function([data0, weight0], out) - mod = tvm.IRModule() - mod["main"] = main_f - mod = transform.InferType()(mod) + data0 = relay.var("data", shape=ishape, dtype=dtype) + weight0 = relay.var("weight", shape=wshape, dtype=dtype) + out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=groups) + main_f = relay.Function([data0, weight0], out) + mod = tvm.IRModule() + mod["main"] = main_f + mod = transform.InferType()(mod) - i_data = np.random.uniform(0, 1, ishape).astype(dtype) - w_data = np.random.uniform(0, 1, w2shape).astype(dtype) + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + w1_data = np.random.uniform(0, 1, wshape).astype(dtype) - return mod, {"data": i_data, "weight": w_data}, (1, 32, 14, 14) + inputs = OrderedDict([("data", i_data), ("weight", w1_data)]) - for mod, inputs, out_shape in [conv2d_direct(), group_conv2d()]: - output_list = generate_ref_data(mod, inputs) - input_list = [inputs["data"], inputs["weight"]] - compile_and_run(mod, input_list, output_list, target_options, use_calculated_workspaces) + output_list = generate_ref_data(mod, inputs) + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_concatenate(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_concatenate(interface_api, use_unpacked_api, test_runner): dtype = "float32" x = relay.var("x", shape=(10, 5), dtype=dtype) y = relay.var("y", shape=(10, 5), dtype=dtype) @@ -159,16 +159,19 @@ def test_concatenate(use_calculated_workspaces, target_options): x_data = np.random.rand(10, 5).astype(dtype) y_data = np.random.rand(10, 5).astype(dtype) t_data = np.random.uniform(size=()).astype(dtype) - inputs = {"x": x_data, "y": y_data, "z": t_data} + inputs = OrderedDict([("x", x_data), ("y", y_data), ("z", t_data)]) output_list = generate_ref_data(func, inputs) - input_list = [inputs["x"], inputs["y"], inputs["z"]] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + compile_and_run( + AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_nested_tuples(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_nested_tuples(interface_api, use_unpacked_api, test_runner): x = relay.var("x", shape=(10,)) x1 = x + relay.const(1.0) x2 = x1 + relay.const(1.0) @@ -180,168 +183,141 @@ def test_nested_tuples(use_calculated_workspaces, target_options): x_data = np.random.uniform(size=(10,)).astype(np.float32) inputs = {"x": x_data} output_list = generate_ref_data(func, inputs) - input_list = [x_data] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + + compile_and_run( + AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_tuple_getitem(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_tuple_getitem(interface_api, use_unpacked_api, test_runner): func = relay.Function([], relay.TupleGetItem(relay.Tuple([relay.const(1), relay.const(2)]), 0)) output_list = generate_ref_data(func, {}) - input_list = [] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + compile_and_run( + AOTTestModel(module=IRModule.from_expr(func), inputs={}, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_id(use_calculated_workspaces, target_options): + +@parametrize_aot_options +def test_id(interface_api, use_unpacked_api, test_runner): x = relay.var("x", "float32") ident = relay.Function([x], x) one = np.array(1.0, "float32") inputs = {"x": one} output_list = generate_ref_data(ident, inputs) - input_list = [one] - compile_and_run(ident, input_list, output_list, target_options, use_calculated_workspaces) + + compile_and_run( + AOTTestModel(module=IRModule.from_expr(ident), inputs=inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_add_const(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_add_const(interface_api, use_unpacked_api, test_runner): two = relay.add(relay.const(1), relay.const(1)) func = relay.Function([], two) output_list = generate_ref_data(func, {}) - input_list = [] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + + compile_and_run( + AOTTestModel(module=IRModule.from_expr(func), inputs={}, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_mul_param(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_mul_param(interface_api, use_unpacked_api, test_runner): x = relay.var("x", shape=(10, 10)) y = relay.var("y", shape=(1, 10)) func = relay.Function([x, y], relay.multiply(x, y)) x_data = np.random.rand(10, 10).astype("float32") y_data = np.random.rand(1, 10).astype("float32") - inputs = {"x": x_data, "y": y_data} + + inputs = OrderedDict([("x", x_data), ("y", y_data)]) output_list = generate_ref_data(func, inputs) - input_list = [inputs["x"], inputs["y"]] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + + compile_and_run( + AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_subtract(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_subtract(interface_api, use_unpacked_api, test_runner): i = relay.var("i", shape=[], dtype="int32") sub = relay.subtract(i, relay.const(1, dtype="int32")) func = relay.Function([i], sub, ret_type=relay.TensorType([], "int32")) i_data = np.array(1, dtype="int32") inputs = {"i": i_data} output_list = generate_ref_data(func, inputs) - input_list = [inputs["i"]] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + compile_and_run( + AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_tuple_output(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_tuple_output(interface_api, use_unpacked_api, test_runner): x = relay.var("x", shape=(6, 9)) y = relay.split(x, 3).astuple() a = relay.TupleGetItem(y, 0) b = relay.TupleGetItem(y, 1) - c = relay.TupleGetItem(y, 2) out = relay.Tuple([a, b]) func = relay.Function([x], out) x_data = np.random.rand(6, 9).astype("float32") inputs = {"x": x_data} output_list = generate_ref_data(func, inputs) - input_list = [inputs["x"]] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + compile_and_run( + AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + ) @pytest.mark.parametrize( - "use_calculated_workspaces_and_alignment", [(True, 1), (True, 16), (False, 1)] + ["debug_calculated_workspaces", "workspace_byte_alignment"], [(True, 1), (True, 16), (False, 1)] ) -@pytest.mark.parametrize("target_options", ["--unpacked-api"]) -def test_mobilenet(use_calculated_workspaces_and_alignment, target_options): - use_calculated_workspaces = use_calculated_workspaces_and_alignment[0] - workspace_byte_alignment = use_calculated_workspaces_and_alignment[1] +def test_mobilenet(debug_calculated_workspaces, workspace_byte_alignment): + use_unpacked_api = True + interface_api = "c" + test_runner = AOT_DEFAULT_RUNNER mod, params = testing.mobilenet.get_workload(batch_size=1) data_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] data = np.random.uniform(size=data_shape).astype("float32") inputs = {"data": data} output_list = generate_ref_data(mod, inputs, params) - input_list = [inputs["data"]] compile_and_run( - mod, - input_list, - output_list, - target_options, - use_calculated_workspaces, - params, - workspace_byte_alignment, + AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params), + test_runner, + interface_api, + use_unpacked_api, + workspace_byte_alignment=workspace_byte_alignment, + debug_calculated_workspaces=debug_calculated_workspaces, ) -class CcompilerAnnotator(ExprMutator): - """ - This is used to create external functions for ccompiler. - A simple annotator that creates the following program: - | - -- begin -- - | - add - | - subtract - | - multiply - | - -- end -- - | - """ - - def __init__(self): - super(CcompilerAnnotator, self).__init__() - self.in_compiler = 0 - - def visit_call(self, call): - if call.op.name == "add": # Annotate begin at args - if self.in_compiler == 1: - lhs = compiler_begin(super().visit(call.args[0]), "ccompiler") - rhs = compiler_begin(super().visit(call.args[1]), "ccompiler") - op = relay.add(lhs, rhs) - self.in_compiler = 2 - return op - elif call.op.name == "subtract": - if self.in_compiler == 1: - lhs = super().visit(call.args[0]) - rhs = super().visit(call.args[1]) - if isinstance(lhs, relay.expr.Var): - lhs = compiler_begin(lhs, "ccompiler") - if isinstance(rhs, relay.expr.Var): - rhs = compiler_begin(rhs, "ccompiler") - return relay.subtract(lhs, rhs) - elif call.op.name == "multiply": # Annotate end at output - self.in_compiler = 1 - lhs = super().visit(call.args[0]) - rhs = super().visit(call.args[1]) - if isinstance(lhs, relay.expr.Var): - lhs = compiler_begin(lhs, "ccompiler") - if isinstance(rhs, relay.expr.Var): - rhs = compiler_begin(rhs, "ccompiler") - op = relay.multiply(lhs, rhs) - if self.in_compiler == 2: - op = compiler_end(op, "ccompiler") - self.in_compiler = 0 - return op - return super().visit_call(call) - - -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", [""]) -def test_byoc_microtvm(use_calculated_workspaces, target_options): +def test_byoc_microtvm(): """This is a simple test case to check BYOC capabilities of AOT""" + use_unpacked_api = False + interface_api = "packed" + test_runner = AOT_DEFAULT_RUNNER + x = relay.var("x", shape=(10, 10)) w0 = relay.var("w0", shape=(10, 10)) w1 = relay.var("w1", shape=(10, 10)) @@ -368,7 +344,7 @@ def test_byoc_microtvm(use_calculated_workspaces, target_options): r = relay.concatenate((q0, q1, q2), axis=0) f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) mod = tvm.IRModule() - ann = CcompilerAnnotator() + ann = byoc.CcompilerAnnotator() mod["main"] = ann.visit(f) mod = tvm.relay.transform.PartitionGraph("mod_name")(mod) @@ -379,18 +355,20 @@ def test_byoc_microtvm(use_calculated_workspaces, target_options): for _ in range(8): w_data.append(np.random.rand(10, 10).astype("float32")) - map_inputs = {"w{}".format(i): w_data[i] for i in range(8)} - map_inputs["x"] = x_data + map_inputs = OrderedDict([("x", x_data)] + [("w{}".format(i), w_data[i]) for i in range(8)]) output_list = generate_ref_data(mod, map_inputs) input_list = [map_inputs["x"]] input_list.extend([map_inputs["w{}".format(i)] for i in range(8)]) compile_and_run( - mod, input_list, output_list, target_options, use_calculated_workspaces, mod_name="my_mod" + AOTTestModel(name="my_mod", module=mod, inputs=map_inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, ) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_add_name_mangling_with_params(target_options): +@parametrize_aot_options +def test_add_name_mangling_with_params(interface_api, use_unpacked_api, test_runner): x = relay.var("x", shape=(1, 10)) y = relay.var("y", shape=(1, 10)) z = relay.add(x, y) @@ -403,27 +381,22 @@ def test_add_name_mangling_with_params(target_options): inputs = {"y": y_in} output_list = generate_ref_data(func, inputs, params) - input_list = [y_in] compile_and_run( - func, - input_list, - output_list, - target_options, - use_calculated_workspaces=False, - params=params, - mod_name="my_mod", + AOTTestModel(name="my_mod", module=func, inputs=inputs, outputs=output_list, params=params), + test_runner, + interface_api, + use_unpacked_api, ) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_multiple_models(target_options): +@parametrize_aot_options +def test_multiple_models(interface_api, use_unpacked_api, test_runner): # Identity model without params x = relay.var("x", "float32") mod1 = relay.Function([x], x) one = np.array(1.0, "float32") inputs1 = {"x": one} output_list1 = generate_ref_data(mod1, inputs1) - input_list1 = [one] params1 = None # Convolution model @@ -453,15 +426,19 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), params2 = {"weight": weight_data} inputs2 = {"data": input_data} output_list2 = generate_ref_data(mod2, inputs2, params2) - input_list2 = [input_data] - input_list_map = {"mod1": input_list1, "mod2": input_list2} - output_list_map = {"mod1": output_list1, "mod2": output_list2} - mod_map = {"mod1": mod1, "mod2": mod2} - param_map = {"mod1": params1, "mod2": params2} - - compile_and_run_multiple_models( - mod_map, input_list_map, output_list_map, target_options, param_map + compile_and_run( + [ + AOTTestModel( + name="mod1", module=mod1, inputs=inputs1, outputs=output_list1, params=params1 + ), + AOTTestModel( + name="mod2", module=mod2, inputs=inputs2, outputs=output_list2, params=params2 + ), + ], + test_runner, + interface_api, + use_unpacked_api, ) @@ -473,6 +450,10 @@ def test_quant_mobilenet_tfl(): import tvm.relay.testing.tf as tf_testing + interface_api = "packed" + use_unpacked_api = False + test_runner = AOT_DEFAULT_RUNNER + tflite_model_file = tf_testing.get_workload_official( "https://storage.googleapis.com/download.tensorflow.org/" "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz", @@ -486,12 +467,16 @@ def test_quant_mobilenet_tfl(): mod, params = convert_to_relay(tflite_model_buf, data, "input") inputs = {"input": data} output_list = generate_ref_data(mod, inputs, params) - input_list = [inputs["input"]] - compile_and_run(mod, input_list, output_list, "--unpacked-api=0", True, params) + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params), + test_runner, + interface_api, + use_unpacked_api, + ) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_transpose(target_options): +@parametrize_aot_options +def test_transpose(interface_api, use_unpacked_api, test_runner): """Test that non-inpleaceable operations (e.g., transpose) do not happen in-place.""" dtype = "float32" @@ -506,12 +491,17 @@ def test_transpose(target_options): x_data = np.random.rand(10, 5).astype(dtype) y_data = np.random.rand(10, 5).astype(dtype) t_data = np.random.uniform(size=()).astype(dtype) - inputs = {"x": x_data, "y": y_data, "z": t_data} + inputs = {"x": x_data, "y": y_data, "z": t_data} output_list = generate_ref_data(func, inputs) - input_list = [inputs["x"], inputs["y"], inputs["z"]] - compile_and_run(func, input_list, output_list, target_options, True, enable_op_fusion=False) + compile_and_run( + AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + enable_op_fusion=False, + ) if __name__ == "__main__": - pytest.main([__file__]) + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py b/tests/python/relay/dyn/test_dynamic_op_level10.py index ad9a0ecd4e59..0f47ce02db49 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level10.py +++ b/tests/python/relay/dyn/test_dynamic_op_level10.py @@ -47,10 +47,9 @@ def verify_more_dynamic_broadcast_to(x_shape, out_shape): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate(func)( - x, np.array(x_shape).astype(shape_type), np.array(out_shape).astype(shape_type) - ) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate( + func + )(x, np.array(x_shape).astype(shape_type), np.array(out_shape).astype(shape_type)) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_more_dynamic_broadcast_to((4, 3), (3, 4, 3)) @@ -73,8 +72,9 @@ def verify_broadcast_to(x_shape, out_shape): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate(func)(x, np.array(out_shape).astype(shape_type)) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate( + func + )(x, np.array(out_shape).astype(shape_type)) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_broadcast_to((1,), (1, 1, 1)) @@ -103,8 +103,9 @@ def test_dyn_broadcast_to(): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type)) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate(func)( + x, np.array(dyn_shape).astype(shape_type) + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -136,8 +137,9 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - out_relay = intrp.evaluate()(indices_np, np.array(depth).astype("int32")) + out_relay = relay.create_executor( + kind, mod=mod, device=dev, target=target + ).evaluate()(indices_np, np.array(depth).astype("int32")) tvm.testing.assert_allclose(out_relay.numpy(), out_np) _verify((3,), 3, 1, 0, -1, "int32") diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index a6ea609be1e2..fd7ab7002806 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -60,8 +60,7 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()( + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( x_data, np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32") ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) @@ -127,8 +126,7 @@ def verify_upsampling3d( for target, dev in enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()( + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( x_data, np.array(scale_d).astype("float32"), np.array(scale_h).astype("float32"), diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index 3673f08cf8b2..d2ad5a47f15b 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -31,8 +31,9 @@ def verify_func(func, data, ref_res, target_device=tvm.testing.enabled_targets() for target, dev in target_device: for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(*data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + *data + ) if isinstance(op_res, tvm.runtime.container.ADT): assert len(op_res) == len( ref_res diff --git a/tests/python/relay/dyn/test_dynamic_op_level4.py b/tests/python/relay/dyn/test_dynamic_op_level4.py index f5afbd7588fd..2a4606fcf93f 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level4.py +++ b/tests/python/relay/dyn/test_dynamic_op_level4.py @@ -66,8 +66,9 @@ def verify(dshape, begin, end, strides, slice_mode="end", test_ref=True, dtype=" return for target, dev in tvm.testing.enabled_targets(): mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor("vm", mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(*data) + op_res = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()( + *data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) verify( diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py index d3459afaab06..c29ea2cd392f 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level5.py +++ b/tests/python/relay/dyn/test_dynamic_op_level5.py @@ -64,8 +64,9 @@ def verify_resize2d(dshape, scale, method, layout): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_data, size) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_data, size + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) for method in ["linear", "nearest_neighbor"]: diff --git a/tests/python/relay/dyn/test_dynamic_op_level6.py b/tests/python/relay/dyn/test_dynamic_op_level6.py index 03823062eab7..530c402b2947 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level6.py +++ b/tests/python/relay/dyn/test_dynamic_op_level6.py @@ -55,8 +55,9 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(np_data, np.array([k]).astype("float32")) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + np_data, np.array([k]).astype("float32") + ) if ret_type == "both": tvm.testing.assert_allclose(op_res[0].numpy(), np_values) tvm.testing.assert_allclose(op_res[1].numpy(), np_indices) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 51f46799e606..8cf31f94378e 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -34,7 +34,14 @@ def count(e): dev = tvm.device("llvm", 0) -intrp = create_executor(mod=prelude.mod, device=dev, target="llvm") + + +def eval(expr): + # CAUTION: These tests re-process the entire prelude for each test expression. + # Hoisting the create_executor won't improve that since preprocessing won't begin + # until the evaluate. + return create_executor(mod=prelude.mod, device=dev, target="llvm").evaluate(expr) + nat, z, s = prelude.mod.get_type("nat") @@ -139,7 +146,7 @@ def get_scalar(tv): # @tvm.testing.uses_gpu def test_nat_value(): assert count(make_nat_value(p, 10)) == 10 - assert count(intrp.evaluate(s(s(z())))) == 2 + assert count(eval(s(s(z())))) == 2 @tvm.testing.uses_gpu @@ -158,14 +165,14 @@ def test_nat_constructor(): @tvm.testing.uses_gpu def test_double(): assert prelude.mod[double].checked_type == relay.FuncType([nat()], nat()) - res = intrp.evaluate(double(s(z()))) + res = eval(double(s(z()))) assert count(res) == 2 @tvm.testing.uses_gpu def test_add(): assert prelude.mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) - res = intrp.evaluate(add(s(z()), s(z()))) + res = eval(add(s(z()), s(z()))) assert count(res) == 2 @@ -187,7 +194,7 @@ def test_hd_tl(): got = [] for i in range(len(expected)): - got.append(count(intrp.evaluate(hd(l)))) + got.append(count(eval(hd(l)))) l = tl(l) assert got == expected @@ -202,7 +209,7 @@ def test_nth(): for i in range(len(expected)): nth = prelude.mod.get_global_var("nth") - item = intrp.evaluate(nth(l, relay.const(i))) + item = eval(nth(l, relay.const(i))) assert get_scalar(item) == i @@ -220,7 +227,7 @@ def test_update(): got = [] for i in range(len(expected)): - got.append(count(intrp.evaluate(nth(l, relay.const(i))))) + got.append(count(eval(nth(l, relay.const(i))))) assert got == expected @@ -231,7 +238,7 @@ def test_length(): assert prelude.mod[length].checked_type == relay.FuncType( [rlist(a)], relay.scalar_type("int32"), [a] ) - res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil()))))) + res = eval(length(cons(z(), cons(z(), cons(z(), nil()))))) assert get_scalar(res) == 3 @@ -245,7 +252,7 @@ def test_map(): x = relay.Var("x") add_one = relay.Function([x], s(x)) - res = intrp.evaluate(map(add_one, cons(z(), cons(z(), nil())))) + res = eval(map(add_one, cons(z(), cons(z(), nil())))) ones = to_list(res) assert len(ones) == 2 assert count(ones[0]) == 1 and count(ones[1]) == 1 @@ -263,7 +270,7 @@ def test_foldl(): x = relay.Var("x") y = relay.Var("y") rev_dup = relay.Function([y, x], cons(x, cons(x, y))) - res = intrp.evaluate( + res = eval( foldl( rev_dup, nil(), @@ -291,7 +298,7 @@ def test_foldr(): x = relay.Var("x") y = relay.Var("y") identity = relay.Function([x, y], cons(x, y)) - res = intrp.evaluate( + res = eval( foldr( identity, nil(), @@ -316,7 +323,7 @@ def test_foldr1(): x = relay.Var("x") y = relay.Var("y") f = relay.Function([x, y], add(x, y)) - res = intrp.evaluate( + res = eval( foldr1( f, cons( @@ -334,7 +341,7 @@ def test_sum(): assert prelude.mod[sum].checked_type == relay.FuncType( [rlist(relay.scalar_type("int32"))], relay.scalar_type("int32") ) - res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil())))) + res = eval(sum(cons(relay.const(1), cons(relay.const(2), nil())))) assert get_scalar(res) == 3 @@ -345,7 +352,7 @@ def test_concat(): l1 = cons(make_nat_expr(prelude, 1), cons(make_nat_expr(prelude, 2), nil())) l2 = cons(make_nat_expr(prelude, 3), cons(make_nat_expr(prelude, 4), nil())) - res = intrp.evaluate(concat(l1, l2)) + res = eval(concat(l1, l2)) catted = to_list(res) assert len(catted) == 4 @@ -379,7 +386,7 @@ def test_filter(): ], ), ) - res = intrp.evaluate( + res = eval( filter( greater_than_one, cons( @@ -416,7 +423,7 @@ def test_zip(): ) l2 = cons(nil(), cons(cons(nil(), nil()), cons(cons(nil(), cons(nil(), nil())), nil()))) - res = intrp.evaluate(zip(l1, l2)) + res = eval(zip(l1, l2)) zipped = to_list(res) assert len(zipped) == 3 assert count(zipped[0][0]) == 1 @@ -428,7 +435,7 @@ def test_zip(): # test truncation l3 = cons(make_nat_expr(prelude, 4), cons(make_nat_expr(prelude, 5), nil())) - shorter_res = intrp.evaluate(zip(l3, l2)) + shorter_res = eval(zip(l3, l2)) truncated = to_list(shorter_res) assert len(truncated) == 2 assert count(truncated[0][0]) == 4 @@ -437,7 +444,7 @@ def test_zip(): assert len(to_list(truncated[1][1])) == 1 l4 = cons(nil(), nil()) - shortest_res = intrp.evaluate(zip(l3, l4)) + shortest_res = eval(zip(l3, l4)) singleton = to_list(shortest_res) assert len(singleton) == 1 assert count(singleton[0][0]) == 4 @@ -449,7 +456,7 @@ def test_rev(): a = relay.TypeVar("a") assert prelude.mod[rev].checked_type == relay.FuncType([rlist(a)], rlist(a), [a]) - res = intrp.evaluate( + res = eval( rev( cons( make_nat_expr(prelude, 1), @@ -488,7 +495,7 @@ def test_unfoldr(): ), ) - res = intrp.evaluate(unfoldr(count_down, make_nat_expr(prelude, 3))) + res = eval(unfoldr(count_down, make_nat_expr(prelude, 3))) unfolded = to_list(res) assert len(unfolded) == 3 @@ -520,7 +527,7 @@ def test_unfoldl(): ), ) - res = intrp.evaluate(unfoldl(count_down, make_nat_expr(prelude, 3))) + res = eval(unfoldl(count_down, make_nat_expr(prelude, 3))) unfolded = to_list(res) assert len(unfolded) == 3 @@ -549,7 +556,7 @@ def test_map_accumr(): make_nat_expr(prelude, 1), cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())), ) - res = intrp.evaluate(map_accumr(add_acc_to_each, z(), vals)) + res = eval(map_accumr(add_acc_to_each, z(), vals)) sum = count(res[0]) new_vals = to_list(res[1]) @@ -581,7 +588,7 @@ def test_map_accuml(): make_nat_expr(prelude, 1), cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())), ) - res = intrp.evaluate(map_accuml(add_to_acc, z(), vals)) + res = eval(map_accuml(add_to_acc, z(), vals)) sum = count(res[0]) new_vals = to_list(res[1]) @@ -609,7 +616,7 @@ def test_optional_matching(): ), ) - res = intrp.evaluate( + res = eval( foldr( condense, nil(), @@ -636,9 +643,7 @@ def test_tmap(): x = relay.Var("x") add_one = relay.Function([x], s(x)) - res = intrp.evaluate( - tmap(add_one, rose(z(), cons(rose(z(), nil()), cons(rose(z(), nil()), nil())))) - ) + res = eval(tmap(add_one, rose(z(), cons(rose(z(), nil()), cons(rose(z(), nil()), nil()))))) tree_dict = tree_to_dict(res) assert count(tree_dict["member"]) == 1 @@ -657,7 +662,7 @@ def test_size(): root = rose(z(), cons(rose(z(), nil()), cons(rose(z(), nil()), nil()))) t = rose(z(), cons(root, cons(root, cons(root, nil())))) - res = intrp.evaluate(size(t)) + res = eval(size(t)) assert get_scalar(res) == 10 @@ -666,7 +671,7 @@ def test_wildcard_match_solo(): x = relay.Var("x", nat()) copy = relay.Function([x], relay.Match(x, [relay.Clause(relay.PatternWildcard(), x)]), nat()) - res = intrp.evaluate(copy(s(s(s(z()))))) + res = eval(copy(s(s(s(z()))))) assert count(res) == 3 @@ -690,7 +695,7 @@ def test_wildcard_match_order(): nat(), ) - res = intrp.evaluate(return_zero(cons(s(z()), nil()))) + res = eval(return_zero(cons(s(z()), nil()))) # wildcard pattern is evaluated first assert count(res) == 0 @@ -744,7 +749,7 @@ def test_nested_matches(): ) final_list = cons(first_list, cons(second_list, nil())) - res = intrp.evaluate(flatten(final_list)) + res = eval(flatten(final_list)) flat = to_list(res) assert len(flat) == 6 @@ -758,8 +763,8 @@ def test_match_full_var(): v = relay.Var("v") id_func = relay.Function([x], relay.Match(x, [relay.Clause(relay.PatternVar(v), v)])) - res1 = intrp.evaluate(id_func(nil())) - res2 = intrp.evaluate(id_func(cons(z(), cons(z(), nil())))) + res1 = eval(id_func(nil())) + res2 = eval(id_func(cons(z(), cons(z(), nil())))) empty = to_list(res1) assert len(empty) == 0 @@ -794,7 +799,7 @@ def test_nested_pattern_match(): ) get_second = relay.Function([x], match) - res = intrp.evaluate(get_second(cons(s(z()), cons(s(s(z())), nil())))) + res = eval(get_second(cons(s(z()), cons(s(s(z())), nil())))) assert count(res) == 2 @@ -804,14 +809,14 @@ def test_compose(): n = relay.Var("n") inc = relay.Function([n], s(n)) x = relay.Var("x") - res = intrp.evaluate(relay.Call(compose(inc, double), [s(s(z()))])) + res = eval(relay.Call(compose(inc, double), [s(s(z()))])) assert count(res) == 5 @tvm.testing.uses_gpu def test_iterate(): expr = relay.Call(iterate(double, relay.const(2)), [make_nat_expr(prelude, 3)]) - res = intrp.evaluate(relay.Function([], expr)()) + res = eval(relay.Function([], expr)()) assert count(res) == 12 diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index e94b5145ccc2..6430e6aa2116 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -58,8 +58,7 @@ def check_result( continue if kind == "debug" and (only_vm or dev.device_type != tvm.cpu().device_type): continue - ex = relay.create_executor(kind, mod=mod, device=dev, target=tgt) - result = ex.evaluate()(*args) + result = relay.create_executor(kind, mod=mod, device=dev, target=tgt).evaluate()(*args) if isinstance(result, tvm.runtime.container.ADT): result = [r.numpy() for r in result] else: @@ -496,13 +495,24 @@ def verify_any_conv2d( dilation, static_data_shape, ref_out_shape, + data_layout="NCHW", + kernel_layout="OIHW", use_cudnn=False, ): mod = tvm.IRModule() dtype = "float32" data = relay.var("data", shape=data_shape, dtype=dtype) kernel = relay.var("kernel", shape=kernel_shape, dtype=dtype) - y = relay.nn.conv2d(data, kernel, strides, padding, dilation, kernel_size=kernel_shape[2:4]) + y = relay.nn.conv2d( + data, + kernel, + strides, + padding, + dilation, + kernel_size=kernel_shape[2:4] if kernel_layout == "OIHW" else kernel_shape[0:2], + data_layout=data_layout, + kernel_layout=kernel_layout, + ) mod["main"] = relay.Function([data, kernel], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) @@ -545,6 +555,28 @@ def test_any_conv2d(): (1, 64, 224, 224), use_cudnn=True, ) + verify_any_conv2d( + (relay.Any(), 224, 224, 64), + (3, 3, 64, 64), + (1, 1), + (1, 1), + (1, 1), + (1, 224, 224, 64), + (1, 224, 224, 64), + data_layout="NHWC", + kernel_layout="HWIO", + ) + verify_any_conv2d( + (relay.Any(), 224, 224, 64), + (3, 3, 64, 64), + (1, 1), + (1, 1), + (2, 2), + (2, 224, 224, 64), + (2, 222, 222, 64), + data_layout="NHWC", + kernel_layout="HWIO", + ) def verify_any_conv2d_NCHWc( @@ -582,7 +614,6 @@ def verify_any_conv2d_NCHWc( # TODO(@kevinthesun): Support dynamic input height and width. -@tvm.testing.uses_gpu def test_any_conv2d_NCHWc(): verify_any_conv2d_NCHWc( (relay.Any(), 8, 224, 224, 8), @@ -610,6 +641,63 @@ def test_any_conv2d_NCHWc(): ) +def verify_any_conv1d_transpose_ncw( + data_shape, + kernel_shape, + strides, + padding, + dilation, + groups, + static_data_shape, + ref_out_shape, + output_padding, +): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var("data", shape=data_shape, dtype=dtype) + kernel = relay.var("kernel", shape=kernel_shape, dtype=dtype) + y = relay.nn.conv1d_transpose( + data, + kernel, + strides, + padding, + dilation, + groups, + kernel_size=kernel_shape[2:], + output_padding=output_padding, + ) + mod["main"] = relay.Function([data, kernel], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) + check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_conv1d_transpose_ncw(): + verify_any_conv1d_transpose_ncw( + (relay.Any(), 64, 224), + (64, 192, 3), + (1,), + (1,), + (1,), + 1, + (2, 64, 224), + (2, 192, 224), + (0, 0), + ) + verify_any_conv1d_transpose_ncw( + (relay.Any(), 32, 224), + (32, 64, 3), + (2,), + (1,), + (1,), + 1, + (1, 32, 224), + (1, 64, 448), + (1, 1), + ) + + def verify_any_conv2d_transpose_nchw( data_shape, kernel_shape, @@ -762,8 +850,9 @@ def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, r mod["main"] = relay.Function([data], y.astuple()) data_np = np.random.uniform(size=static_data_shape).astype(dtype) for kind in ["vm"]: - ex = relay.create_executor(kind, mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) + result = relay.create_executor(kind, mod=mod, device=tvm.cpu(), target="llvm").evaluate()( + data_np + ) for ret, ref_ret in zip(result, ref_out_shape): assert ret.numpy().shape == ref_ret, "Shape mismatch: expect %s but got %s." % ( str(ref_ret), @@ -830,6 +919,99 @@ def test_any_dense_dynamic_batch(): verify_any_dense((relay.Any(), 40), (50, 40), 50, (4, 40), (50, 40), (4, 50), use_cublas=True) +def verify_any_batch_matmul( + x_shape, + y_shape, + out_shape, + x_var_shape, + y_var_shape, + dtype="float32", + trans_x=False, + trans_y=True, +): + x = relay.var("x", relay.TensorType(x_var_shape, dtype)) + y = relay.var("y", relay.TensorType(y_var_shape, dtype)) + z = relay.nn.batch_matmul(x, y, transpose_a=trans_x, transpose_b=trans_y) + + func = relay.Function([x, y], z) + x_np = np.random.uniform(size=x_shape).astype(dtype) + y_np = np.random.uniform(size=y_shape).astype(dtype) + z_np = tvm.topi.testing.batch_matmul(x_np, y_np, trans_x=trans_x, trans_y=trans_y) + + for target, dev in tvm.testing.enabled_targets(): + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + z = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np, y_np + ) + tvm.testing.assert_allclose(z.numpy(), z_np, rtol=1e-5) + + +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu +def test_any_batch_matmul(): + verify_any_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16), (1, 16, 32), (relay.Any(),) * 3) + verify_any_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16), (5, 16, 32), (relay.Any(),) * 3) + verify_any_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20), (5, 16, 32), (relay.Any(),) * 3) + verify_any_batch_matmul( + (30, 16, 32), (30, 20, 32), (30, 16, 20), (30, 16, 32), (relay.Any(),) * 3 + ) + + verify_any_batch_matmul( + (1, 16, 32), (1, 16, 32), (1, 16, 16), (relay.Any(), 16, 32), (relay.Any(), 16, 32) + ) + verify_any_batch_matmul( + (5, 16, 32), (5, 16, 32), (5, 16, 16), (relay.Any(), 16, 32), (relay.Any(), 16, 32) + ) + verify_any_batch_matmul( + (5, 16, 32), (5, 20, 32), (5, 16, 20), (relay.Any(), 16, 32), (relay.Any(), 20, 32) + ) + verify_any_batch_matmul( + (30, 16, 32), (30, 20, 32), (30, 16, 20), (relay.Any(), 16, 32), (relay.Any(), 20, 32) + ) + + verify_any_batch_matmul( + (1, 32, 16), (1, 16, 32), (1, 16, 16), (1, 32, 16), (relay.Any(),) * 3, trans_x=True + ) + verify_any_batch_matmul( + (5, 16, 32), (5, 32, 16), (5, 16, 16), (5, 16, 32), (relay.Any(),) * 3, trans_y=False + ) + verify_any_batch_matmul( + (5, 32, 16), + (5, 32, 20), + (5, 16, 20), + (5, 32, 16), + (relay.Any(),) * 3, + trans_x=True, + trans_y=False, + ) + verify_any_batch_matmul( + (1, 32, 16), + (1, 16, 32), + (1, 16, 16), + (relay.Any(), 32, 16), + (relay.Any(), 16, 32), + trans_x=True, + ) + verify_any_batch_matmul( + (5, 16, 32), + (5, 32, 16), + (5, 16, 16), + (relay.Any(), 16, 32), + (relay.Any(), 32, 16), + trans_y=False, + ) + verify_any_batch_matmul( + (5, 32, 16), + (5, 32, 20), + (5, 16, 20), + (relay.Any(), 32, 16), + (relay.Any(), 32, 20), + trans_x=True, + trans_y=False, + ) + + @tvm.testing.uses_gpu def verify_any_pad(data_shape, pad_width, static_data_shape): mod = tvm.IRModule() @@ -899,6 +1081,72 @@ def test_any_softmax(): verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1)) +def verify_any_relu(data_shape, static_data_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var("data", shape=data_shape, dtype=dtype) + y = relay.nn.relu(data) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + check_result([data_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_relu(): + verify_any_relu(any_dims(3), (1, 2, 3), (1, 2, 3)) + verify_any_relu(any_dims(4), (13, 11, 3, 1), (13, 11, 3, 1)) + + +def verify_any_prelu(data_shape, alpha, static_data_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var("data", shape=data_shape, dtype=dtype) + alpha = relay.const(np.array([alpha]), dtype=dtype) + y = relay.nn.prelu(data, alpha) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + check_result([data_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_prelu(): + verify_any_prelu(any_dims(3), 1, (1, 2, 3), (1, 2, 3)) + verify_any_prelu(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1)) + + +def verify_any_leaky_relu(data_shape, alpha, static_data_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var("data", shape=data_shape, dtype=dtype) + y = relay.nn.leaky_relu(data, alpha) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + check_result([data_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_leaky_relu(): + verify_any_leaky_relu(any_dims(3), 0.1, (1, 2, 3), (1, 2, 3)) + verify_any_leaky_relu(any_dims(4), 0.2, (13, 11, 3, 1), (13, 11, 3, 1)) + + +def verify_any_bias_add(data_shape, static_data_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var("data", shape=data_shape, dtype=dtype) + bias = relay.const(np.random.randn(1), dtype=dtype) + y = relay.nn.bias_add(data, bias) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + check_result([data_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_bias_add(): + verify_any_bias_add(any_dims(3), (1, 2, 3), (1, 2, 3)) + verify_any_bias_add(any_dims(4), (13, 11, 3, 1), (13, 11, 3, 1)) + + def verify_any_topk(data_shape, kval, np_dshape, dtype, ret_type="indices", const_k=False): mod = tvm.IRModule() data = relay.var("data", shape=data_shape, dtype=dtype) diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py b/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py index 6a9d9a5cf0ad..dd3126a09810 100644 --- a/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py +++ b/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py @@ -179,7 +179,7 @@ def get_output(data, lib): actual_output = get_output(data, lib) expected_output = get_output(data, lib2) - tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4, atol=1e-4) + tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) def test_conv2d(): @@ -188,7 +188,7 @@ def test_conv2d(): def test_conv2d_winograd(): - mod, data, weight = get_relay_conv2d(kh=3, kw=3) + mod, data, weight = get_relay_conv2d(outc=128, kh=3, kw=3) tune_and_check(mod, data, weight) diff --git a/tests/python/relay/test_auto_scheduler_task_extraction.py b/tests/python/relay/test_auto_scheduler_task_extraction.py index 0acd8b0c87d6..a53b68cca885 100644 --- a/tests/python/relay/test_auto_scheduler_task_extraction.py +++ b/tests/python/relay/test_auto_scheduler_task_extraction.py @@ -15,10 +15,13 @@ # specific language governing permissions and limitations # under the License. """Test task extraction for auto-scheduler""" -import pytest +import json +import tempfile +import pytest import tvm.relay.testing import tvm.testing +from tvm import _ffi as _ffi_api from tvm import auto_scheduler, relay @@ -96,52 +99,61 @@ def get_network(name, batch_size=1, layout="NHWC"): @tvm.testing.requires_cuda -def test_task_extraction_cuda(): +@pytest.mark.parametrize( + "params", + [ + ("mlp", "NHWC", 1, 2), + ("resnet-18", "NHWC", 24, 25), + ("resnet-18", "NCHW", 24, 25), + ("mobilenet", "NHWC", 22, 30), + ("mobilenet", "NCHW", 22, 30), + ("resnet3d-18", "NCDHW", 23, 24), + ("resnet3d-18", "NDHWC", 23, 24), + ], +) +def test_task_extraction_cuda(params): target = tvm.target.Target("cuda") + network, layout, expected_task, expected_weights = params - mod, params = get_network("mlp") + mod, params = get_network(network, layout=layout) tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) - - assert len(tasks) == 1 - assert sum(task_weights) == 2 - - for layout in ["NHWC", "NCHW"]: - mod, params = get_network("resnet-18", layout=layout) - tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) - - assert len(tasks) == 24 - assert sum(task_weights) == 25 - - mod, params = get_network("mobilenet", layout=layout) - tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) - - assert len(tasks) == 22 - assert sum(task_weights) == 30 - - for layout in ["NCDHW", "NDHWC"]: - mod, params = get_network("resnet3d-18", layout=layout) - tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) - - assert len(tasks) == 23 - assert sum(task_weights) == 24, sum(task_weights) - - -def test_task_extraction(): + for task, weight in zip(tasks, task_weights): + print(task.desc, task.workload_key, weight) + + assert len(tasks) == expected_task + assert sum(task_weights) == expected_weights + + +@pytest.mark.parametrize( + "params", + [ + # Relay FuseOps puts two conv2ds to separate functions and results in two tasks. + ("basic_func", 2, False), + # Relay FuseOps will not break the primitive function and result in one task. + ("fused_func", 1, False), + # The Relay function without complex ops will not form a task by default. + ("simple_func", 0, False), + # Every Relay function becomes a task regardless what ops in its body. + ("simple_func", 1, True), + # The Relay function without any reduce op is considered as a simple task. + ("shape_of_func", 0, False), + ("shape_of_func", 1, True), + # The Relay function with dynamic shape inputs/outputs will not be extracted. + ("dyn_shape_func", 0, False), + # The Conv2D in the Relay function with control flow could still be a task. + # Also, two identical Conv2D should only be one task with weight=2. + ("control_flow_func", 1, False), + # The first function with unsupported op (NMS) will not be extracted. + ("func_w_unsupported_op", 1, True), + ], +) +def test_task_extraction_cpu(params): ishape = (1, 3, 224, 224) w1shape = (32, 3, 3, 3) w2shape = (32, 32, 3, 3) dtype = "float32" target = tvm.target.Target("llvm") - def verify_task_extraction(func, expected_task, include_simple_tasks=False): - mod = tvm.IRModule.from_expr(func) - tasks, task_weights = auto_scheduler.extract_tasks( - mod["main"], None, target, include_simple_tasks=include_simple_tasks - ) - - assert len(tasks) == expected_task - assert len(task_weights) == expected_task - def get_func(): data = relay.var("data", shape=(ishape), dtype=dtype) weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype) @@ -183,13 +195,16 @@ def get_func_with_dynamic_shape(): def get_func_with_control_flow(): data = relay.var("data", shape=(1, 3, 224, 224)) - weight = relay.var("weight", shape=(32, 3, 3, 3)) + weight = relay.var("weight", shape=(3, 3, 3, 3)) eq1 = relay.var("e1", shape=[], dtype="float32") eq2 = relay.var("e2", shape=[], dtype="float32") eq = relay.equal(eq1, eq2) - true_branch = relay.zeros(shape=(1, 32, 222, 222), dtype="float32") - false_branch = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=32) + true_branch = relay.zeros(shape=(1, 3, 224, 224), dtype="float32") + false_branch = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=3, padding=(1, 1)) + false_branch = relay.nn.conv2d( + false_branch, weight, kernel_size=(3, 3), channels=3, padding=(1, 1) + ) ife = relay.If(eq, true_branch, false_branch) out = relay.erf(ife) return relay.Function([data, weight, eq1, eq2], out) @@ -213,32 +228,67 @@ def get_postproc_func(): out = relay.Call(get_postproc_func(), [nms]) return relay.Function([cls_prob, loc_pred, anchors], out) - # Relay FuseOps puts two conv2ds to separate functions and results in two tasks. - verify_task_extraction(get_func(), 2) + func_map = { + "basic_func": get_func, + "fused_func": get_fused_func, + "simple_func": get_simple_func, + "shape_of_func": get_shape_of_func, + "dyn_shape_func": get_func_with_dynamic_shape, + "control_flow_func": get_func_with_control_flow, + "func_w_unsupported_op": get_func_with_unsupported_op, + } + + def verify_task_extraction(func_name, expected_task, include_simple_tasks=False): + func = func_map[func_name]() + mod = tvm.IRModule.from_expr(func) + tasks, task_weights = auto_scheduler.extract_tasks( + mod["main"], None, target, include_simple_tasks=include_simple_tasks + ) + + assert len(tasks) == expected_task + assert len(task_weights) == expected_task - # By setting the function to primitive, Relay FuseOps will not break it and result in one task. - verify_task_extraction(get_fused_func(), 1) + verify_task_extraction(*params) + + +def test_dump_workload_to_dag_extract_tasks(): + mod, _ = get_network("mobilenet", layout="NHWC") + with tempfile.NamedTemporaryFile() as f: + tasks, _ = auto_scheduler.extract_tasks( + mod["main"], None, "llvm", include_simple_tasks=True, dump_workload_to_dag_log=f.name + ) + expected = {task.workload_key: str(task.compute_dag) for task in tasks} + actual = json.load(f) + assert expected == actual - # The Relay function without complex ops will not form a task by default. - verify_task_extraction(get_simple_func(), 0) - # Every Relay function becomes a task regardless what ops in its body. - verify_task_extraction(get_simple_func(), 1, True) +def test_custom_hash_func_extract_tasks(): + @_ffi_api.register_func("auto_scheduler.compute_dag.hash_func") + def counting_unique_hash(str_dag): + ret = counting_unique_hash.i + counting_unique_hash.i += 1 + return ret - # The Relay function without any reduce op is considered as a simple task. - verify_task_extraction(get_shape_of_func(), 0) - verify_task_extraction(get_shape_of_func(), 1, True) + counting_unique_hash.i = 0 - # The Relay function with dynamic shape inputs/outputs will not be extracted. - verify_task_extraction(get_func_with_dynamic_shape(), 0) + mod, _ = get_network("mobilenet", layout="NHWC") + tasks, _ = auto_scheduler.extract_tasks(mod["main"], None, "llvm", include_simple_tasks=True) - # The Conv2D in the Relay function with control flow could still be a task. - verify_task_extraction(get_func_with_control_flow(), 1) + hash_values = [] + for task in tasks: + # task.workload_key should look like + # [43, [3, 3, 1024, 1], [1024], [3, 3, 1024, 1]] where the first int is the result of the hash + # Extract the hash and keep track of every hash + hash_value = int(task.workload_key[1:].split(",")[0]) + hash_values.append(hash_value) - # Func1 (with NMS) -> Func2 (injective). - verify_task_extraction(get_func_with_unsupported_op(), 1, True) + # All values are unique, and we know the min and max + # This is a sufficient condition to know that hashes in hash_values are an increasing list + # of hashes up to counting_unique_hash.i - 1 + assert len(hash_values) == len(set(hash_values)) + assert min(hash_values) == 0 + assert max(hash_values) == counting_unique_hash.i - 1 if __name__ == "__main__": - test_task_extraction_cuda() - test_task_extraction() + pytest.main([__file__]) diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py index e7040f55f631..f1ab58e7bf07 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest +from unittest.mock import patch import tvm import json @@ -22,6 +24,9 @@ from tvm.contrib import graph_executor from tvm.relay.op import add import tvm.testing +from tvm.relay.testing import mlp +from tvm import rpc +from tvm.contrib import utils # @tq, @jr should we put this in testing ns? def check_rts(expr, args, expected_result, mod=None): @@ -40,27 +45,48 @@ def check_rts(expr, args, expected_result, mod=None): expected_result: The expected result of running the expression. """ - intrp = relay.create_executor("debug", mod=mod) - graph = relay.create_executor("graph", mod=mod) - eval_result = intrp.evaluate(expr)(*args) - rts_result = graph.evaluate(expr)(*args) + eval_result = relay.create_executor("debug", mod=mod).evaluate(expr)(*args) + rts_result = relay.create_executor("graph", mod=mod).evaluate(expr)(*args) tvm.testing.assert_allclose(eval_result.numpy(), rts_result.numpy()) tvm.testing.assert_allclose(eval_result.numpy(), expected_result) def test_add_op_scalar(): """ - Program: + test_add_op_scalar: fn (x, y) { return x + y; } """ - x = relay.var("x", shape=()) - y = relay.var("y", shape=()) + x = relay.var("x", shape=()) # Default to float32 + y = relay.var("y", shape=()) # Default to float32 func = relay.Function([x, y], add(x, y)) - x_data = np.array(10.0, dtype="float32") - y_data = np.array(1.0, dtype="float32") - check_rts(func, [x_data, y_data], x_data + y_data) + x_y_data = [ + (np.array(10.0, dtype="float32"), np.array(1.0, dtype="float32")), + (np.float32(10.0), np.float32(1.0)), + (10.0, 1.0), + ] + for (x_data, y_data) in x_y_data: + check_rts(func, [x_data, y_data], x_data + y_data) + + +def test_add_op_scalar_int(): + """ + test_add_op_scalar_int: + fn (x, y) { + return x + y; + } + """ + x = relay.var("x", shape=(), dtype="int32") + y = relay.var("y", shape=(), dtype="int32") + func = relay.Function([x, y], add(x, y)) + x_y_data = [ + (np.array(10.0, dtype="int32"), np.array(1.0, dtype="int32")), + (np.int32(10), np.int32(1)), + (10, 1), + ] + for (x_data, y_data) in x_y_data: + check_rts(func, [x_data, y_data], x_data + y_data) def test_add_op_tensor(): @@ -271,10 +297,9 @@ def test_graph_executor_nested_tuples(): out = relay.Tuple([x, relay.Tuple([y, relay.Tuple([z, w])])]) func = relay.Function([x, y, z, w], out) - exe = relay.create_executor( + f = relay.create_executor( kind="graph", mod=tvm.IRModule.from_expr(func), device=tvm.cpu(0), target="llvm" - ) - f = exe.evaluate() + ).evaluate() data = [np.random.uniform(size=(2, 3)).astype("float32") for _ in "xyzw"] out = f(*data) @@ -287,5 +312,77 @@ def test_graph_executor_nested_tuples(): tvm.testing.assert_allclose(out[1][1][1].numpy(), data[3]) +def test_graph_executor_api(): + dname_0, dname_1 = "data_0", "data_1" + data_0, data_1 = [relay.var(c, shape=(1, 1), dtype="float32") for c in [dname_0, dname_1]] + net = relay.add(data_0, data_1) + func = relay.Function((data_0, data_1), net) + + lib = relay.build(tvm.IRModule.from_expr(func), "llvm") + mod = graph_executor.GraphModule(lib["default"](tvm.cpu(0))) + + assert mod.get_input_index(dname_1) == 1 + assert mod.get_input_index(dname_0) == 0 + assert mod.get_input_index("Invalid") == -1 + + +@tvm.testing.requires_llvm +def test_benchmark(): + mod, params = mlp.get_workload(1) + lib = relay.build(mod, target="llvm", params=params) + exe = graph_executor.create(lib.get_graph_json(), lib.lib, tvm.cpu()) + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32")) + result = exe.benchmark(tvm.cpu(), data=data, func_name="run", repeat=2, number=1) + assert result.mean == result.median + assert result.mean > 0 + assert len(result.results) == 2 + + with patch.object( + tvm.runtime.module.Module, + "time_evaluator", + return_value=lambda: tvm.runtime.module.BenchmarkResult([1, 2, 2, 5]), + ) as method: + result = exe.benchmark(tvm.cpu(), data=data, func_name="run", repeat=2, number=1) + assert result.mean == 2.5 + assert result.median == 2.0 + assert result.max == 5 + assert result.min == 1 + assert result.std == 1.5 + + +@tvm.testing.parametrize_targets("cuda", "llvm") +def test_benchmark_end_to_end(dev, target): + mod, params = mlp.get_workload(1) + lib = relay.build(mod, target=target, params=params) + exe = graph_executor.create(lib.get_graph_json(), lib.lib, dev) + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32")) + result = exe.benchmark(dev, data=data, func_name="run", repeat=2, number=1, end_to_end=True) + assert result.mean > 0 + assert len(result.results) == 2 + + +@tvm.testing.requires_llvm +def test_benchmark_end_to_end_rpc(): + server = rpc.Server("127.0.0.1") + remote = rpc.connect(server.host, server.port) + + mod, params = mlp.get_workload(1) + lib = relay.build(mod, target="llvm", params=params) + + temp = utils.tempdir() + path = temp.relpath("library.so") + lib.export_library(path) + remote.upload(path) + rlib = remote.load_module("library.so") + + dev = remote.cpu() + exe = graph_executor.create(lib.get_graph_json(), rlib, dev) + + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"), device=dev) + result = exe.benchmark(dev, data=data, func_name="run", repeat=2, number=1, end_to_end=True) + assert result.mean > 0 + assert len(result.results) == 2 + + if __name__ == "__main__": - sys.exit(pytest.main([file] + sys.argv[1:])) + pytest.main([__file__]) diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index d65bcad3364d..af2dcf32c305 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -15,27 +15,26 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm -from tvm import te -import tvm.testing +from tvm import testing from tvm import nd from tvm import relay from tvm.runtime import container from tvm.relay.backend.interpreter import RefValue, ConstructorValue from tvm.relay.scope_builder import ScopeBuilder -from tvm.relay import testing, create_executor def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): # TODO(tqchen) add more types once the schedule register is fixed. for target in ["llvm"]: dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): + if not testing.device_enabled(target): return - intrp = create_executor(mod=mod, device=dev, target=target) - result = intrp.evaluate(expr)(*args) - # use tvm.testing which also set atol - tvm.testing.assert_allclose(result.numpy(), expected_result, rtol=rtol) + func = relay.create_executor(mod=mod, device=dev, target=target).evaluate(expr) + result = func if args is None else func(*args) + # use testing which also set atol + testing.assert_allclose(result.numpy(), expected_result, rtol=rtol) def test_tuple_value(): @@ -146,10 +145,9 @@ def test_ref(): def test_binds(): x = relay.var("x") y = relay.add(x, x) - intrp = create_executor("debug") xx = np.ones((10, 20)) - res = intrp.evaluate(y, binds={x: xx}).numpy() - tvm.testing.assert_allclose(xx + xx, res) + res = relay.create_executor().evaluate(y, binds={x: xx}).numpy() + testing.assert_allclose(xx + xx, res) def test_kwargs_params(): @@ -161,15 +159,13 @@ def test_kwargs_params(): y_data = np.random.rand(1, 10).astype("float32") z_data = np.random.rand(1, 10).astype("float32") params = {"y": y_data, "z": z_data} - intrp = create_executor("debug") - res = intrp.evaluate(f)(x_data, **params) - tvm.testing.assert_allclose(res.numpy(), x_data + y_data + z_data) + res = relay.create_executor().evaluate(f)(x_data, **params) + testing.assert_allclose(res.numpy(), x_data + y_data + z_data) def test_function_taking_adt_ref_tuple(): mod = tvm.IRModule() prelude = relay.prelude.Prelude(mod) - intrp = create_executor("debug", mod) _, cons, nil = prelude.mod.get_type("List") nil_value = ConstructorValue(nil.tag, [], nil) @@ -184,7 +180,7 @@ def test_function_taking_adt_ref_tuple(): [nd.array(np.random.rand(1, 10).astype("float32")) for _ in range(10)] ) - id_func = intrp.evaluate(prelude.id) + id_func = relay.create_executor(mod=mod).evaluate(prelude.id) res_nil = id_func(nil_value) assert res_nil.tag == nil_value.tag @@ -193,17 +189,17 @@ def test_function_taking_adt_ref_tuple(): res_cons = id_func(cons_value) assert res_cons.tag == cons_value.tag assert len(res_cons.fields) == len(cons_value.fields) - tvm.testing.assert_allclose(res_cons.fields[0].numpy(), cons_value.fields[0].numpy()) + testing.assert_allclose(res_cons.fields[0].numpy(), cons_value.fields[0].numpy()) assert isinstance(res_cons.fields[1], ConstructorValue) assert res_cons.fields[1].tag == nil.tag assert len(res_cons.fields[1].fields) == 0 res_ref = id_func(ref_value) - tvm.testing.assert_allclose(res_ref.value.numpy(), ref_value.value.numpy()) + testing.assert_allclose(res_ref.value.numpy(), ref_value.value.numpy()) res_tuple = id_func(tuple_value) for i in range(10): - tvm.testing.assert_allclose(res_tuple[i].numpy(), tuple_value[i].numpy()) + testing.assert_allclose(res_tuple[i].numpy(), tuple_value[i].numpy()) def test_tuple_passing(): @@ -222,28 +218,72 @@ def test_tuple_passing(): dev = tvm.cpu() target = tvm.target.Target("llvm") - exec = relay.create_executor(mod=mod, device=dev, target=target) - f = exec.evaluate(gv) + f = relay.create_executor(mod=mod, device=dev, target=target).evaluate(gv) # First use a Python tuple. out = f((10, 8)) - tvm.testing.assert_allclose(out.numpy(), np.array(10)) + testing.assert_allclose(out.numpy(), np.array(10)) # Second use a tuple value. value_tuple = container.tuple_object([nd.array(np.array(11)), nd.array(np.array(12))]) out = f(value_tuple) - tvm.testing.assert_allclose(out.numpy(), np.array(11)) + testing.assert_allclose(out.numpy(), np.array(11)) + + +def test_dynamic(): + n = 3 + m = 2 + x = relay.Var("x", relay.TensorType([relay.Any(), m], "float32")) + y = relay.Var("y", relay.TensorType([relay.Any(), m], "float32")) + xx = x - relay.expr.const(3.0) + yy = y * relay.expr.const(5.0) + z = relay.op.concatenate([xx, yy], axis=0) + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + x_np = np.random.uniform(size=(n, m)).astype("float32") + y_np = np.random.uniform(size=(n, m)).astype("float32") + expected = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0) + check_eval(None, [x_np, y_np], expected, mod) + + +def test_ref_global_from_expr(): + n = 3 + x = relay.Var("x", relay.TensorType([n], "float32")) + y = relay.Var("y", relay.TensorType([n], "float32")) + mod = tvm.IRModule() + mod["add"] = relay.Function([x, y], relay.add(x, y)) + x_np = np.random.uniform(size=(n,)).astype("float32") + y_np = np.random.uniform(size=(n,)).astype("float32") + expected = np.add(x_np, y_np) + expr = relay.Call(mod.get_global_var("add"), [relay.const(x_np), relay.const(y_np)]) + check_eval(expr, None, expected, mod) + + +def test_keyword_args(): + n = 3 + x = relay.Var("x", relay.TensorType([n], "float32")) + y = relay.Var("y", relay.TensorType([n], "float32")) + z = relay.add(x, y) + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + x_np = np.random.uniform(size=(n,)).astype("float32") + y_np = np.random.uniform(size=(n,)).astype("float32") + expected = np.add(x_np, y_np) + actual = relay.create_executor(mod=mod).evaluate()(y=y_np, x=x_np) + testing.assert_allclose(actual.numpy(), expected) + + +# TODO(mbs): Support? Would help reduce wasted work when we need to prepare +# multiple functions w.r.t. the same module. +@pytest.mark.skip(reason="closures are currently not directly Python callable") +def test_functional_returns(): + n = 3 + x = relay.Var("x", relay.TensorType([n], "float32")) + f = relay.Function([x], x) + t = relay.Tuple([f, f]) + c = np.random.rand(n).astype("float32") + result1, result2 = relay.create_executor().evaluate(t) + testing.assert_allclose(result1(c).numpy(), c) + testing.assert_allclose(result2(c).numpy(), c) if __name__ == "__main__": - test_id() - test_add_const() - test_equal() - test_subtract() - test_simple_loop() - test_loop() - test_binds() - test_kwargs_params() - test_ref() - test_tuple_value() - test_tuple_getitem() - test_function_taking_adt_ref_tuple() - test_tuple_passing() + pytest.main([__file__]) diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 1c721f40d129..74e03f6a9755 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -1727,69 +1727,37 @@ def test_partition_constant_embedding(): assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) +def test_rewrite_once(): + # This class recursively removes the arguments to concat until there is nothing left to concatenate. + class ConcatRewriter(DFPatternCallback): + def __init__(self, rewrite_once): + super().__init__(rewrite_once=rewrite_once) + self.pattern = is_op("concatenate")(None) + + def callback(self, pre, post, node_map): + concat_args = post.args[0] + # Remove the last argument + new_args = [concat_args[i] for i in range(len(concat_args) - 1)] + if new_args: + return relay.op.concatenate(relay.expr.Tuple(new_args), axis=0) + else: + return concat_args + + x = relay.var("x") + y = relay.var("y") + z = relay.var("z") + concat = relay.op.concatenate(relay.expr.Tuple([x, y, z]), axis=0) + + # Let the rewriter run recursively + out = rewrite(ConcatRewriter(False), concat) + expected = relay.expr.Tuple([x]) + assert tvm.ir.structural_equal(out, expected) + + # Run the rewriter once + out = rewrite(ConcatRewriter(True), concat) + expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0) + assert tvm.ir.structural_equal(out, expected) + + if __name__ == "__main__": - test_expr_pattern() - test_var_pattern() - test_constant_pattern() - test_wildcard_pattern() - test_CallPattern() - test_TuplePattern() - test_TupleGetItemPattern() - test_AltPattern() - test_TypePattern() - test_DataTypePattern() - test_ShapePattern() - test_AttrPattern() - test_match_op() - test_no_match_op() - test_match_op_or() - test_match_call_commutive() - test_no_match_call_commutive() - test_match_call() - test_no_match_call() - test_match_option() - test_no_match_option() - test_match_const() - test_match_tuple() - test_no_match_tuple() - test_match_type() - test_no_match_type() - test_match_dtype() - test_no_match_dtype() - test_match_shape() - test_no_match_shape() - test_match_op_attr() - test_no_match_op_attr() - test_match_func_attr() - test_no_match_func_attr() - test_match_call_attr() - test_no_match_call_attr() - test_match_diamond() - test_no_match_diamond() - test_match_fake_diamond() - test_match_dominator() - test_not_match_dominator() - test_rewrite() - test_rewrite_func() - test_nested_rewrite() - test_not_fuse_multi_diamond() - test_fuse_batchnorm() - test_no_fuse_batchnorm() - test_fuse_double_batchnorm() - test_partial_fuse_double_batchnorm() - test_fuse_batchnorm_commutation() - test_quadruple_rewrite_dominator() - test_algebraic_simplify() - test_double_partition() - test_partition_dominator() - test_quadruple_partition_dominator() - test_partition_batchnorm() - test_partition_double_batchnorm() - test_partition_check() - test_partition_check_types() - test_partition_option() - test_match_match() - test_partition_constant_embedding() - test_IfPattern() - test_match_if() - test_no_match_if() + pytest.main([__file__]) diff --git a/tests/python/relay/test_debug.py b/tests/python/relay/test_debug.py index c4ed657701ae..61557867f070 100644 --- a/tests/python/relay/test_debug.py +++ b/tests/python/relay/test_debug.py @@ -23,7 +23,6 @@ def test_debug(): global _test_debug_hit - ex = create_executor() x = var("x", shape=(), dtype="int32") _test_debug_hit = False @@ -32,7 +31,7 @@ def did_exec(x): _test_debug_hit = True prog = debug(x, debug_func=did_exec) - result = ex.evaluate(prog, {x: const(1, "int32")}) + result = create_executor().evaluate(prog, {x: const(1, "int32")}) assert _test_debug_hit assert result.numpy() == 1 @@ -40,7 +39,6 @@ def did_exec(x): def test_debug_with_expr(): global _test_debug_hit _test_debug_hit = False - ex = create_executor() x = var("x", shape=(), dtype="int32") _test_debug_hit = False @@ -49,6 +47,6 @@ def did_exec(x): _test_debug_hit = True prog = debug(x + x * x, debug_func=did_exec) - result = ex.evaluate(prog, {x: const(2, "int32")}) + result = create_executor().evaluate(prog, {x: const(2, "int32")}) assert _test_debug_hit assert result.numpy() == 6 diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 84e2fa305bfe..30db5facc208 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -15,18 +15,15 @@ # specific language governing permissions and limitations # under the License. """Unit tests for graph partitioning.""" + import os import sys +from collections import OrderedDict import numpy as np +import pytest import tvm -from tvm import te -import tvm.relay.testing -import tvm.relay.transform - -from tvm import relay -from tvm import runtime -from tvm.relay import transform +from tvm import relay, runtime from tvm.contrib import utils from tvm.relay.build_module import bind_params_by_name from tvm.relay.op.annotation import compiler_begin, compiler_end @@ -48,37 +45,52 @@ def update_lib(lib): return lib -def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu()): - if sys.platform == "win32": - print("Skip test on Windows for now") - return - - def check_vm_result(): - with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): - exe = relay.vm.compile(mod, target=target) - code, lib = exe.save() - lib = update_lib(lib) - exe = runtime.vm.Executable.load_exec(code, lib) - vm = runtime.vm.VirtualMachine(exe, device) - out = vm.run(**map_inputs) - tvm.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol) - - def check_graph_executor_result(): - with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): - json, lib, _ = relay.build(mod, target=target) - lib = update_lib(lib) - rt_mod = tvm.contrib.graph_executor.create(json, lib, device) - - for name, data in map_inputs.items(): - rt_mod.set_input(name, data) - rt_mod.run() - out = tvm.nd.empty(out_shape, device=device) - out = rt_mod.get_output(0, out) - - tvm.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol) - - check_vm_result() - check_graph_executor_result() +def check_vm_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu()): + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + exe = relay.vm.compile(mod, target=target) + code, lib = exe.save() + lib = update_lib(lib) + exe = runtime.vm.Executable.load_exec(code, lib) + vm = runtime.vm.VirtualMachine(exe, device) + out = vm.run(**map_inputs) + tvm.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol) + + +def check_graph_executor_result( + mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu() +): + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + json, lib, _ = relay.build(mod, target=target) + lib = update_lib(lib) + rt_mod = tvm.contrib.graph_executor.create(json, lib, device) + + for name, data in map_inputs.items(): + rt_mod.set_input(name, data) + rt_mod.run() + out = tvm.nd.empty(out_shape, device=device) + out = rt_mod.get_output(0, out) + + tvm.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol) + + +def check_aot_executor_result( + mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu() +): + if tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON": + pytest.skip("MicroTVM support not enabled. Set USE_MICRO=ON in config.cmake to enable.") + + # Late import to avoid breaking test with USE_MICRO=OFF. + from aot.aot_test_utils import AOTTestModel, AOT_DEFAULT_RUNNER, compile_and_run + + interface_api = "packed" + use_unpacked_api = False + test_runner = AOT_DEFAULT_RUNNER + compile_and_run( + AOTTestModel(module=mod, inputs=map_inputs, outputs=[result]), + test_runner, + interface_api, + use_unpacked_api, + ) def set_external_func_attr(func, compiler, ext_symbol): @@ -88,7 +100,11 @@ def set_external_func_attr(func, compiler, ext_symbol): return func -def test_multi_node_subgraph(): +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") +@pytest.mark.parametrize( + "check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result] +) +def test_multi_node_subgraph(check_result): x = relay.var("x", shape=(10, 10)) w0 = relay.var("w0", shape=(10, 10)) w1 = relay.var("w1", shape=(10, 10)) @@ -138,8 +154,7 @@ def test_multi_node_subgraph(): for _ in range(8): w_data.append(np.random.rand(10, 10).astype("float32")) - map_inputs = {"w{}".format(i): w_data[i] for i in range(8)} - map_inputs["x"] = x_data + map_inputs = OrderedDict([("x", x_data)] + [("w{}".format(i), w_data[i]) for i in range(8)]) check_result( mod, map_inputs, @@ -155,7 +170,11 @@ def test_multi_node_subgraph(): ) -def test_extern_gcc_single_op(): +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") +@pytest.mark.parametrize( + "check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result] +) +def test_extern_gcc_single_op(check_result): x = relay.var("x", shape=(8, 8)) y = relay.var("y", shape=(8, 8)) @@ -172,7 +191,11 @@ def test_extern_gcc_single_op(): check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data) -def test_extern_gcc_single_op_int(): +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") +@pytest.mark.parametrize( + "check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result] +) +def test_extern_gcc_single_op_int(check_result): x = relay.var("x", shape=(8, 8), dtype="int32") y = relay.var("y", shape=(8, 8), dtype="int32") @@ -189,7 +212,11 @@ def test_extern_gcc_single_op_int(): check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data) -def test_extern_gcc(): +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") +@pytest.mark.parametrize( + "check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result] +) +def test_extern_gcc(check_result): x = relay.var("x", shape=(2, 2)) y = relay.var("y", shape=(2, 2)) @@ -221,9 +248,17 @@ def test_extern_gcc(): x_data = np.random.rand(2, 2).astype("float32") y_data = np.random.rand(2, 2).astype("float32") - check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data)) + inputs = OrderedDict( + [ + ("y", y_data), + ("x", x_data), + ] + ) + + check_result(mod, inputs, (2, 2), (y_data * y_data) - (x_data + x_data)) +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") def test_extern_gcc_consts(): @tvm._ffi.register_func("relay.ext.ccompiler.constant_updater") def constant_updater(expr, symbol): @@ -257,11 +292,13 @@ def constant_updater(expr, symbol): tvm._ffi.registry.remove_global_func("relay.ext.ccompiler.constant_updater") -def test_extern_dnnl(): - if not tvm.get_global_func("relay.ext.dnnl", True): - print("skip because DNNL codegen is not available") - return - +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") +@pytest.mark.skipif( + not tvm.get_global_func("relay.ext.dnnl", True), + reason="skip because DNNL codegen is not available", +) +@pytest.mark.parametrize("check_result", [check_vm_result, check_graph_executor_result]) +def test_extern_dnnl(check_result): dtype = "float32" ishape = (1, 32, 14, 14) w1shape = (32, 1, 3, 3) @@ -290,18 +327,21 @@ def test_extern_dnnl(): i_data = np.random.uniform(0, 1, ishape).astype(dtype) w_data = np.random.uniform(0, 1, w1shape).astype(dtype) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()) - ref_res = ref_ex.evaluate()(i_data, w_data, w_data) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()).evaluate()( + i_data, w_data, w_data + ) check_result( mod, {"data0": i_data, "weight0": w_data}, (1, 32, 14, 14), ref_res.numpy(), tol=1e-5 ) -def test_extern_dnnl_const(): - if not tvm.get_global_func("relay.ext.dnnl", True): - print("skip because DNNL codegen is not available") - return - +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") +@pytest.mark.skipif( + not tvm.get_global_func("relay.ext.dnnl", True), + reason="skip because DNNL codegen is not available", +) +@pytest.mark.parametrize("check_result", [check_vm_result, check_graph_executor_result]) +def test_extern_dnnl_const(check_result): dtype = "float32" ishape = (1, 32, 14, 14) w1shape = (32, 1, 3, 3) @@ -329,8 +369,7 @@ def test_extern_dnnl_const(): i_data = np.random.uniform(0, 1, ishape).astype(dtype) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()) - ref_res = ref_ex.evaluate()(i_data) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()).evaluate()(i_data) check_result(mod, {"data0": i_data}, (1, 32, 14, 14), ref_res.numpy(), tol=1e-5) @@ -349,7 +388,7 @@ def test_load_params_with_constants_in_ext_codegen(): zce = compiler_end(z, "ccompiler") mod["main"] = relay.Function([x, y], zce) mod["main"] = bind_params_by_name(mod["main"], params) - mod = transform.PartitionGraph()(mod) + mod = relay.transform.PartitionGraph()(mod) graph_module = relay.build(mod, target="llvm", params=params) # Params will be stored in metadata module. @@ -360,11 +399,4 @@ def test_load_params_with_constants_in_ext_codegen(): if __name__ == "__main__": - test_multi_node_subgraph() - test_extern_gcc_single_op() - test_extern_gcc_single_op_int() - test_extern_gcc() - test_extern_gcc_consts() - test_extern_dnnl() - test_extern_dnnl_const() - test_load_params_with_constants_in_ext_codegen() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_memory_passes.py b/tests/python/relay/test_memory_passes.py index 7ad72a35a1a0..bed17dbbd830 100644 --- a/tests/python/relay/test_memory_passes.py +++ b/tests/python/relay/test_memory_passes.py @@ -32,13 +32,15 @@ def check_memory_plan(func, check_fn): data = np.random.rand(*sh).astype(param.dtype) args.append(tvm.nd.array(data)) - # Compute without memory planning. + # TODO(mbs): Why does the executor need to be shared? Seems wrong. ex = relay.create_executor("vm", mod) - no_plan_result = ex.evaluate(mod["main"])(*args) + + # Compute without memory planning. + no_plan_result = ex.evaluate()(*args) # Compute with memory planning. with tvm.transform.PassContext(opt_level=1, disabled_pass=["MemoryPlan"]): - plan_result = ex.evaluate(mod["main"])(*args) + plan_result = ex.evaluate()(*args) # Compute Python result. py_res = check_fn(*[arg.numpy() for arg in args]) diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py index 686c0ea556c3..11099ffe50ee 100644 --- a/tests/python/relay/test_op_grad_level1.py +++ b/tests/python/relay/test_op_grad_level1.py @@ -54,8 +54,9 @@ def check_single_op(opfunc, ref, dtype): bwd_func = run_infer_type(gradient(fwd_func)) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - op_res, (op_grad, _) = intrp.evaluate(bwd_func)(data, grad_in) + op_res, (op_grad, _) = relay.create_executor(device=dev, target=target).evaluate( + bwd_func + )(data, grad_in) np.testing.assert_allclose(op_grad.numpy(), ref_grad, rtol=0.01) for opfunc, ref in [ @@ -105,8 +106,9 @@ def check_binary_op(opfunc, ref, dtype): bwd_func = run_infer_type(gradient(fwd_func)) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - op_res, (op_grad0, op_grad1) = intrp.evaluate(bwd_func)(x_data, y_data) + op_res, (op_grad0, op_grad1) = relay.create_executor( + device=dev, target=target + ).evaluate(bwd_func)(x_data, y_data) np.testing.assert_allclose(op_grad0.numpy(), ref_grad0, rtol=0.01) np.testing.assert_allclose(op_grad1.numpy(), ref_grad1, rtol=0.01) diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py index e2145f77b366..8d961eb60b18 100644 --- a/tests/python/relay/test_op_grad_level10.py +++ b/tests/python/relay/test_op_grad_level10.py @@ -62,10 +62,24 @@ def test_checkpoint(): check_grad(relay.Function(inputs, out_single)) +def verify_batch_matmul_grad(a_shape, b_shape, transpose_a, transpose_b): + tensor_a = relay.var("tensor_a", relay.TensorType(a_shape, "float32")) + tensor_b = relay.var("tensor_b", relay.TensorType(b_shape, "float32")) + check_grad( + relay.Function( + [tensor_a, tensor_b], + relay.op.nn.batch_matmul( + tensor_a, tensor_b, transpose_a=transpose_a, transpose_b=transpose_b + ), + ) + ) + + def test_batch_matmul_grad(): - x = relay.var("x", shape=(2, 3, 5), dtype="float64") - y = relay.var("y", shape=(2, 4, 5), dtype="float64") - check_grad(relay.Function([x, y], relay.op.nn.batch_matmul(x, y))) + verify_batch_matmul_grad((2, 3, 5), (2, 5, 4), False, False) + verify_batch_matmul_grad((2, 3, 5), (2, 4, 5), False, True) + verify_batch_matmul_grad((2, 5, 3), (2, 5, 4), True, False) + verify_batch_matmul_grad((2, 5, 3), (2, 4, 5), True, True) def test_reverse_reshape_grad(): diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index c8a94683eec4..115ed48d5888 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -51,8 +51,9 @@ def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode): ) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - op_res, (op_grad,) = intrp.evaluate(bwd_func)(data) + op_res, (op_grad,) = relay.create_executor(device=dev, target=target).evaluate(bwd_func)( + data + ) np.testing.assert_allclose(op_grad.numpy(), ref_grad, rtol=0.01) @@ -100,8 +101,9 @@ def verify_avg_pool2d_grad( ) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - op_res, (op_grad,) = intrp.evaluate(bwd_func)(data) + op_res, (op_grad,) = relay.create_executor(device=dev, target=target).evaluate( + bwd_func + )(data) np.testing.assert_allclose(op_grad.numpy(), ref_grad, rtol=0.01) @@ -156,8 +158,9 @@ def verify_global_avg_pool2d_grad(x_shape): ) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - op_res, (op_grad,) = intrp.evaluate(bwd_func)(data) + op_res, (op_grad,) = relay.create_executor(device=dev, target=target).evaluate(bwd_func)( + data + ) np.testing.assert_allclose(op_grad.numpy(), ref_grad, rtol=0.01) diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index ae3fc2641a25..30d849853d87 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -41,8 +41,9 @@ def test_clip(): bwd_func = run_infer_type(gradient(fwd_func)) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - op_res, (op_grad,) = intrp.evaluate(bwd_func)(data) + op_res, (op_grad,) = relay.create_executor(device=dev, target=target).evaluate( + bwd_func + )(data) np.testing.assert_allclose(op_grad.numpy(), ref_grad, rtol=0.01) @@ -181,8 +182,9 @@ def test_zeros_ones_grad_dynamic(): bwd_func = run_infer_type(gradient(run_infer_type(fwd_func))) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - res, (grad,) = intrp.evaluate(bwd_func)(dyn_shape) + res, (grad,) = relay.create_executor(device=dev, target=target).evaluate(bwd_func)( + dyn_shape + ) tvm.testing.assert_allclose(res.numpy(), op_ref(dyn_shape, dtype="float32")) tvm.testing.assert_allclose(grad.numpy(), np.zeros((rank,), dtype="int32")) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index cbc3e7fbd1e5..97e10eb25a95 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -70,8 +70,9 @@ def check_single_op(opfunc, ref, dtype): and not have_fp16(tvm.cuda(0).compute_version) ): continue - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data + ) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) for opfunc, ref in [ @@ -132,8 +133,9 @@ def check_binary_op(opfunc, ref, dtype): and not have_fp16(tvm.cuda(0).compute_version) ): continue - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data + ) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01, atol=1e-3) for opfunc, ref in [ @@ -163,8 +165,7 @@ def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): continue data = np.random.uniform(size=dshape).astype(dtype) ref_res = data.reshape(oshape) - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) for dtype in ["float16", "float32"]: @@ -196,8 +197,9 @@ def test_bias_add(): and not have_fp16(tvm.cuda(0).compute_version) ): continue - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data + ) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol) @@ -240,8 +242,9 @@ def test_softmax(): x_data = np.random.uniform(size=shape).astype(dtype) ref_res = tvm.topi.testing.softmax_python(x_data) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data + ) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -261,8 +264,9 @@ def test_log_softmax(): x_data = np.random.uniform(size=shape).astype(dtype) ref_res = tvm.topi.testing.log_softmax_python(x_data) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data + ) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -317,11 +321,13 @@ def test_concatenate(): and not have_fp16(tvm.cuda(0).compute_version) ): continue - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data, t_data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=0.01) - op_res2 = intrp2.evaluate(func)(x_data, y_data, t_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + x_data, y_data, t_data + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=0.01) @@ -341,8 +347,7 @@ def test_dropout(): func = relay.Function([], y) for target, dev in tvm.testing.enabled_targets(): for backend in ["debug", "graph"]: - intrp = relay.create_executor("debug", device=dev, target=target) - op_res = intrp.evaluate(func)() + op_res = relay.create_executor("debug", device=dev, target=target).evaluate(func)() tvm.testing.assert_allclose(op_res.numpy(), in_np, rtol=0.01) @@ -461,11 +466,13 @@ def test_matmul(): ref_res = np.dot(x_data.transpose(), w_data) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data, w_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, w_data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data, w_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + x_data, w_data + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -521,11 +528,13 @@ def test_dense(): ref_res = np.dot(x_data, w_data.T) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data, w_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, w_data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data, w_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + x_data, w_data + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 71598e61f694..f796abe5e7d7 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -17,6 +17,7 @@ """ Support level10 operator test cases. """ import numpy as np +import pytest import tvm import tvm.testing import tvm.topi.testing @@ -39,9 +40,10 @@ def test_checkpoint(): inputs = [np.random.uniform() for _ in range(len(xs))] for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - f_res = intrp.evaluate(f)(*inputs) - f_checkpoint_res = intrp.evaluate(f_checkpoint)(*inputs) + f_res = relay.create_executor(kind, device=dev, target=target).evaluate(f)(*inputs) + f_checkpoint_res = relay.create_executor(kind, device=dev, target=target).evaluate( + f_checkpoint + )(*inputs) tvm.testing.assert_allclose(f_res.numpy(), f_checkpoint_res.numpy(), 0, 0) @@ -171,8 +173,7 @@ def test_collapse_sum_like(): ref_res = np.sum(x, 0) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x, y) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x, y) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -191,8 +192,7 @@ def test_collapse_sum_to(): ref_res = np.sum(x, 0) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -211,8 +211,7 @@ def test_broadcast_to(): ref_res = np.broadcast_to(x, shape_like) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -235,8 +234,7 @@ def test_broadcast_to_like(): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x, y) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x, y) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -280,8 +278,9 @@ def verify_slice_like(data, slice_like, axes, output, dtype="float32"): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -314,8 +313,9 @@ def verify_reverse_reshape(shape, newshape, oshape): ref_res = np.reshape(x_data, oshape) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_reverse_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2)) @@ -325,22 +325,21 @@ def verify_reverse_reshape(shape, newshape, oshape): verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12)) -def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): +def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32", trans_x=False, trans_y=True): x = relay.var("x", relay.TensorType(x_shape, dtype)) y = relay.var("y", relay.TensorType(y_shape, dtype)) - z = relay.nn.batch_matmul(x, y) + z = relay.nn.batch_matmul(x, y, transpose_a=trans_x, transpose_b=trans_y) zz = run_infer_type(z) assert zz.checked_type == relay.ty.TensorType(out_shape, dtype) func = relay.Function([x, y], z) x_np = np.random.uniform(size=x_shape).astype(dtype) y_np = np.random.uniform(size=y_shape).astype(dtype) - z_np = tvm.topi.testing.batch_matmul(x_np, y_np) + z_np = tvm.topi.testing.batch_matmul(x_np, y_np, trans_x=trans_x, trans_y=trans_y) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - z = intrp.evaluate(func)(x_np, y_np) + z = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x_np, y_np) tvm.testing.assert_allclose(z.numpy(), z_np, rtol=1e-5) @@ -353,60 +352,13 @@ def test_batch_matmul(): zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((b, m, n), "float32") - verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16)) - verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16)) - verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20)) - verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20)) - - -def verify_dynamic_batch_matmul( - x_shape, y_shape, out_shape, x_var_shape, y_var_shape, dtype="float32" -): - x = relay.var("x", relay.TensorType(x_var_shape, dtype)) - y = relay.var("y", relay.TensorType(y_var_shape, dtype)) - z = relay.nn.batch_matmul(x, y) - - func = relay.Function([x, y], z) - x_np = np.random.uniform(size=x_shape).astype(dtype) - y_np = np.random.uniform(size=y_shape).astype(dtype) - z_np = tvm.topi.testing.batch_matmul(x_np, y_np) - - for target, dev in tvm.testing.enabled_targets(): - for kind in ["vm", "debug"]: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - z = intrp.evaluate()(x_np, y_np) - tvm.testing.assert_allclose(z.numpy(), z_np, rtol=1e-5) - - -# TODO(mbrookhart): enable once VM supports heterogenous execution -# @tvm.testing.uses_gpu -def test_dynamic_batch_matmul(): - verify_dynamic_batch_matmul( - (1, 16, 32), (1, 16, 32), (1, 16, 16), (1, 16, 32), (relay.Any(),) * 3 - ) - verify_dynamic_batch_matmul( - (5, 16, 32), (5, 16, 32), (5, 16, 16), (5, 16, 32), (relay.Any(),) * 3 - ) - verify_dynamic_batch_matmul( - (5, 16, 32), (5, 20, 32), (5, 16, 20), (5, 16, 32), (relay.Any(),) * 3 - ) - verify_dynamic_batch_matmul( - (30, 16, 32), (30, 20, 32), (30, 16, 20), (30, 16, 32), (relay.Any(),) * 3 - ) - - verify_dynamic_batch_matmul( - (1, 16, 32), (1, 16, 32), (1, 16, 16), (relay.Any(), 16, 32), (relay.Any(), 16, 32) - ) - verify_dynamic_batch_matmul( - (5, 16, 32), (5, 16, 32), (5, 16, 16), (relay.Any(), 16, 32), (relay.Any(), 16, 32) - ) - verify_dynamic_batch_matmul( - (5, 16, 32), (5, 20, 32), (5, 16, 20), (relay.Any(), 16, 32), (relay.Any(), 20, 32) - ) - verify_dynamic_batch_matmul( - (30, 16, 32), (30, 20, 32), (30, 16, 20), (relay.Any(), 16, 32), (relay.Any(), 20, 32) - ) + verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16), trans_x=False, trans_y=True) + verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16), trans_x=False, trans_y=True) + verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20), trans_x=False, trans_y=True) + verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20), trans_x=False, trans_y=True) + verify_batch_matmul((1, 32, 16), (1, 16, 32), (1, 16, 16), trans_x=True, trans_y=True) + verify_batch_matmul((5, 16, 32), (5, 32, 16), (5, 16, 16), trans_x=False, trans_y=False) + verify_batch_matmul((5, 32, 16), (5, 32, 20), (5, 16, 20), trans_x=True, trans_y=False) @tvm.testing.uses_gpu @@ -420,8 +372,7 @@ def test_shape_of(): # Because using graph executor, this op will be optimized after # constant folding pass, here we only test with interpreter for kind in ["debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), np.array(shape).astype("int32")) @@ -436,8 +387,9 @@ def verify_ndarray_size(shape): ref_res = np.size(x_data) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) verify_ndarray_size((2, 3, 5)) @@ -454,8 +406,9 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc): np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - relay_out = intrp1.evaluate(func)(np_data) + relay_out = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + np_data + ) tvm.testing.assert_allclose(relay_out.numpy(), np_out, rtol=1e-5, atol=1e-5) @@ -515,8 +468,9 @@ def _verify(data_shape, mask_value, axis, dtype, itype): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - out_relay = intrp.evaluate(func)(data_np, valid_length_np) + out_relay = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data_np, valid_length_np + ) tvm.testing.assert_allclose(out_relay.numpy(), gt_out_np) _verify((5, 10), 0.0, 1, "float32", "int32") @@ -555,8 +509,9 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - out_relay = intrp.evaluate(func)(indices_np) + out_relay = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + indices_np + ) tvm.testing.assert_allclose(out_relay.numpy(), out_np) _verify((3,), 3, 1, 0, -1, "int32") @@ -585,8 +540,9 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - out_relay = intrp.evaluate(func)(input_np, diagonal_np) + out_relay = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + input_np, diagonal_np + ) tvm.testing.assert_allclose(out_relay.numpy(), out_np) _verify((2, 2), (2,), "float32") @@ -626,9 +582,10 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float3 ) for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np) - tvm.testing.assert_allclose(out_relay.asnumpy(), out_np, rtol=1e-6, atol=1e-6) + out_relay = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + predictions_np, targets_np, weights_np + ) + tvm.testing.assert_allclose(out_relay.numpy(), out_np, rtol=1e-6, atol=1e-6) _verify((10, 5)) _verify((10, 5, 2, 2)) @@ -639,15 +596,4 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float3 if __name__ == "__main__": - test_adaptive_pool() - test_collapse_sum_like() - test_broadcast_to() - test_broadcast_to_like() - test_slice_like() - test_reverse_reshape() - test_batch_matmul() - test_shape_of() - test_sequence_mask() - test_one_hot() - test_ndarray_size() - test_matrix_set_diag() + pytest.main([__file__]) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index f05c5054415d..87cdc41570d0 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -98,8 +98,9 @@ def run_test_conv1d( if target in except_targets: continue dev = tvm.device(target, 0) - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) # normal conv1d @@ -226,8 +227,9 @@ def run_test_conv2d( if target in except_targets: continue dev = tvm.device(target, 0) - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-4, atol=1e-4) def compile_test_conv2d_arm_cpu( @@ -513,8 +515,9 @@ def run_test_conv3d( continue dev = tvm.device(target, 0) - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) # normal conv3d @@ -578,8 +581,9 @@ def run_test_conv3d( continue dev = tvm.device(target, 0) - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) # normal conv3d @@ -761,8 +765,9 @@ def test_conv3d_transpose_ncdhw_run(): ref_res = tvm.topi.testing.conv3d_transpose_ncdhw_python(data, kernel, 1, 1, 0) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -804,8 +809,9 @@ def test_conv2d_transpose_nchw_run(): ref_res = tvm.topi.testing.conv2d_transpose_nchw_python(data, kernel, 2, 1, (1, 1)) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -840,8 +846,9 @@ def test_conv2d_transpose_nhwc_run(): ) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -862,8 +869,9 @@ def test_conv1d_transpose_ncw_run(): ref_res = tvm.topi.testing.conv1d_transpose_ncw_python(data, kernel, 2, 1, output_padding=(1,)) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -947,8 +955,7 @@ def _test_global_pool2d(opfunc, reffunc): data = np.random.uniform(size=dshape).astype(dtype) ref_res = reffunc(data, axis=(2, 3), keepdims=True) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -980,8 +987,7 @@ def _test_pool2d(opfunc, pool_type, pool_size=2, strides=2, dilation=1, padding= ceil_mode=False, ) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) def _test_pool2d_int(opfunc, reffunc, dtype): @@ -1001,8 +1007,9 @@ def _test_pool2d_int(opfunc, reffunc, dtype): data = np.random.randint(low=-128, high=128, size=dshape) ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) _test_pool2d(relay.nn.max_pool2d, "max") @@ -1039,8 +1046,7 @@ def _test_global_pool1d(opfunc, reffunc): data = np.random.uniform(size=dshape).astype(dtype) ref_res = reffunc(data, axis=(2,), keepdims=True) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -1075,8 +1081,9 @@ def _test_pool1d( ceil_mode=False, ) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) _test_pool1d(relay.nn.max_pool1d, "max") @@ -1135,8 +1142,9 @@ def _test_pool3d( ceil_mode=False, ) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) _test_pool3d(relay.nn.max_pool3d, "max") @@ -1187,8 +1195,7 @@ def test_avg_pool2d_no_count_pad(): data = a_np for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -1222,11 +1229,9 @@ def test_flatten_infer_type(): ref_res = x_data.flatten().reshape(o_shape) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -1296,8 +1301,9 @@ def _test_run(dtype): data = np.random.uniform(size=dshape).astype(dtype) ref_res = _get_numpy_pad(dshape, data, pad) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) _test_run("float32") @@ -1320,8 +1326,9 @@ def _test_run(dtype): ref_res = _get_numpy_pad(dshape, data_arr, pad, pad_value=pad_value_arr) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(kind="graph", device=dev, target=target) - result = intrp.evaluate(f)(data_arr, pad_value_arr) + result = relay.create_executor(kind="graph", device=dev, target=target).evaluate(f)( + data_arr, pad_value_arr + ) tvm.testing.assert_allclose(result.numpy(), ref_res, rtol=1e-5, atol=1e-5) _test_run("float32") @@ -1353,11 +1360,9 @@ def test_lrn(): ref_res = tvm.topi.testing.lrn_python(x_data, size, axis, bias, alpha, beta) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -1383,11 +1388,9 @@ def test_l2_normalize(): ref_res = tvm.topi.testing.l2_normalize_python(x_data, eps, axis) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -1408,8 +1411,7 @@ def test_batch_flatten(): data = np.random.rand(5, 10, 5).astype(t1.dtype) ref_res = batch_flatten(data) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) @@ -1458,8 +1460,7 @@ def get_shape(): "align_corners" if align_corners else "asymmetric", ) for target, dev in tvm.testing.enabled_targets(): - executor = relay.create_executor("graph", device=dev, target=target) - out = executor.evaluate(func)(data) + out = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(out.numpy(), ref, rtol=1e-5, atol=1e-5) @@ -1530,8 +1531,7 @@ def get_shape(): coordinate_transformation_mode, ) for target, dev in tvm.testing.enabled_targets(): - executor = relay.create_executor("graph", device=dev, target=target) - out = executor.evaluate(func)(data) + out = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(out.numpy(), ref, rtol=1e-5, atol=1e-5) @@ -1602,7 +1602,7 @@ def _has_fast_int8_instructions(asm, target): targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"] llvm_version = tvm.target.codegen.llvm_version_major() for target in targets: - if llvm_version >= 8: + if tvm.testing.device_enabled(target) and llvm_version >= 8: dtypes = ("uint8", "int8", "int32") # Sweep the input channels to check int8 robustness # Input channels should be a multiple of 4 internally. @@ -1654,7 +1654,7 @@ def _has_fast_int8_instructions(asm, target): # Check that int8 x int8 goes through legalization so that fast instructions can be picked up. for target in targets: - if llvm_version >= 8: + if tvm.testing.device_enabled(target) and llvm_version >= 8: dtypes = ("int8", "int8", "int32") # Check that both non-divisible oc and ic work asm = _compile( @@ -1676,17 +1676,18 @@ def _has_fast_int8_instructions(asm, target): # Check that a vectorized instruction is generated for older Intel # generations, because we default to NCHWc layout. target = "llvm -mcpu=core-avx2" - fast_int8_dtypes = ("uint8", "int8", "int32") - asm = _compile( - ic=16, - oc=32, - target=target, - data_layout="NCHW", - kernel_layout="OIHW", - dtypes=fast_int8_dtypes, - ) - # Check that vector int mult and add instructions are generated. - assert "vpmulld" in asm and "vpadd" in asm + if tvm.testing.device_enabled(target): + fast_int8_dtypes = ("uint8", "int8", "int32") + asm = _compile( + ic=16, + oc=32, + target=target, + data_layout="NCHW", + kernel_layout="OIHW", + dtypes=fast_int8_dtypes, + ) + # Check that vector int mult and add instructions are generated. + assert "vpmulld" in asm and "vpadd" in asm @tvm.testing.uses_gpu @@ -1797,8 +1798,9 @@ def _test_correlation( ) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data1_np, data2_np) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data1_np, data2_np + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) _test_correlation( diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index fc67f0b90295..41a866a0a034 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -34,8 +34,7 @@ def test_zeros_ones(): y = op(shape=(124, 50), dtype="float64") yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((124, 50), "float64") - intrp = create_executor() - intrp_res = intrp.evaluate(y).numpy() + intrp_res = create_executor().evaluate(y).numpy() np.testing.assert_allclose(intrp_res, ref((124, 50), "float64")) @@ -60,8 +59,7 @@ def test_unary_identity(): if ref is not None: data = np.random.rand(*shape).astype("float32") - intrp = create_executor() - op_res = intrp.evaluate(y, {x: relay.const(data)}) + op_res = create_executor().evaluate(y, {x: relay.const(data)}) ref_res = ref(data) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) @@ -87,8 +85,7 @@ def test_clip(): assert yy.checked_type == relay.TensorType((10, 4), "float32") data = np.random.rand(10, 4).astype("float32") - intrp = create_executor() - op_res = intrp.evaluate(y, {a: relay.const(data)}) + op_res = create_executor().evaluate(y, {a: relay.const(data)}) ref_res = np.clip(data, 1.0, 4.0) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) @@ -105,8 +102,7 @@ def test_fixed_point_multiply(): assert yy.checked_type == relay.TensorType((10, 4), "int32") data = 23 * np.ones((10, 4)).astype("int32") - intrp = create_executor() - op_res = intrp.evaluate(y, {a: relay.const(data)}) + op_res = create_executor().evaluate(y, {a: relay.const(data)}) ref_res = np.ones((10, 4)).astype("int32") np.testing.assert_allclose(op_res.numpy(), ref_res, atol=1) @@ -118,8 +114,7 @@ def test_reinterpret(): assert yy.checked_type == relay.TensorType((1000, 4), "int32") data = np.random.randn(1000, 4).astype("float32") * 1000 - intrp = create_executor() - op_res = intrp.evaluate(y, {a: relay.const(data)}) + op_res = create_executor().evaluate(y, {a: relay.const(data)}) ref_res = data.view("int32") np.testing.assert_equal(op_res.numpy(), ref_res) @@ -155,8 +150,7 @@ def approximate_tanh(x): yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((1000,), "float32") data = np.linspace(-5, 5, 1000).astype("float32") - intrp = create_executor() - op_res = intrp.evaluate(y, {a: relay.const(data)}) + op_res = create_executor().evaluate(y, {a: relay.const(data)}) def reference_sigmoid(x): return np.exp(-np.logaddexp(0, -x)) @@ -167,8 +161,7 @@ def reference_sigmoid(x): yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((1000,), "float32") data = np.linspace(-5, 5, 1000).astype("float32") - intrp = create_executor() - op_res = intrp.evaluate(y, {a: relay.const(data)}) + op_res = create_executor().evaluate(y, {a: relay.const(data)}) def reference_tanh(x): return np.tanh(x) @@ -184,8 +177,7 @@ def verify_squeeze(shape, dtype, axis): np_axis = tuple(axis) if axis is not None else None data = np.random.random_sample(shape).astype(dtype) - intrp = create_executor() - op_res = intrp.evaluate(squeeze, {x: relay.const(data)}) + op_res = create_executor().evaluate(squeeze, {x: relay.const(data)}) ref_res = np.squeeze(data, axis=np_axis) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) @@ -220,8 +212,9 @@ def verify_transpose(dshape, axes): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_transpose((2, 3, 4), (0, 2, 1)) @@ -275,8 +268,9 @@ def verify_reshape(shape, newshape, oshape): ref_res = np.reshape(x_data, oshape) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_reshape((2, 3, 4), (8, 3), (8, 3)) @@ -293,6 +287,7 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((2, 3, 4), (-3, -2), (6, 4)) verify_reshape((2, 3, 4), (-4, 1, 2, -2), (1, 2, 3, 4)) verify_reshape((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4)) + verify_reshape((1,), (), ()) def test_reshape_fail(): @@ -364,8 +359,9 @@ def verify_reshape_like(shape, oshape, shape_like=None, reshape_like_kwargs={}): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_reshape_like((2, 3, 4), (1, 8, 3)) @@ -410,8 +406,9 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, indices_src) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data, indices_src + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_take((4,), [1]) @@ -545,8 +542,9 @@ def verify_full(fill_value, src_shape, dtype): ref_res = np.full(src_shape, fill_value) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(np.array(fill_value, dtype)) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + np.array(fill_value, dtype) + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_full(4, (1, 3, 4, 4), "int32") @@ -584,8 +582,9 @@ def verify_full_like(base, fill_value, dtype): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, np.array(fill_value, dtype)) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data, np.array(fill_value, dtype) + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_full_like((1, 3, 4, 4), 4, "int32") @@ -613,11 +612,9 @@ def test_infer_type_leaky_relu(): ref_res = np.where(x_data > 0, x_data, x_data * 0.1) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -650,11 +647,13 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"): ref_res = (x_data < 0) * (x_data * a_data.reshape(1, 1, 3)) + (x_data >= 0) * x_data for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data, a_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, a_data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data, a_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + x_data, a_data + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -695,8 +694,7 @@ def verify_arange(start, stop, step): func = relay.Function([], x) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)() + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)() tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_arange(None, 20, None) @@ -734,8 +732,9 @@ def verify_meshgrid(lengths, indexing="ij"): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(*input_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + *input_data + ) assert len(op_res) == len(ref_res) for i in range(len(op_res)): tvm.testing.assert_allclose(op_res[i].numpy(), ref_res[i], rtol=1e-5) @@ -760,8 +759,9 @@ def verify_tile(dshape, reps): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_tile((2, 3, 4), (3, 2, 1)) @@ -778,8 +778,7 @@ def verify_repeat(dshape, repeats, axis): ref_res = np.repeat(data, repeats, axis) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_repeat((3,), 2, 0) @@ -803,8 +802,9 @@ def verify_stack(input_expr, relay_args, ref_res, axis): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(*relay_args) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + *relay_args + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) def verify_tup_lit_stack(dshapes, axis): @@ -855,8 +855,9 @@ def verify_reverse(dshape, axis): ref_res = np.flip(x_data, axis) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_reverse((2, 3, 4), 1) @@ -876,8 +877,9 @@ def verify_reverse_sequence(x_data, seq_lengths, batch_axis, seq_axis, ref_res): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32") @@ -970,8 +972,9 @@ def verify_scatter(dshape, ishape, axis=0): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(data_np, indices_np, updates_np) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data_np, indices_np, updates_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) def verify_dynamic_scatter(dshape, ishape, axis=0): @@ -991,8 +994,9 @@ def verify_dynamic_scatter(dshape, ishape, axis=0): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np, indices_np, updates_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np, indices_np, updates_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_scatter((10,), (10,), 0) @@ -1244,8 +1248,9 @@ def verify_gather(data, axis, indices, ref_res): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(data, indices) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data, indices + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_gather(data, axis, indices, ref_res) @@ -1271,8 +1276,9 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) @@ -1319,8 +1325,7 @@ def _verify_infiniteness_ops(relay_op, ref_op): ] = np.infty data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.nan - intrp = create_executor() - op_res = intrp.evaluate(y, {x: data}) + op_res = create_executor().evaluate(y, {x: data}) ref_res = ref_op(data) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) @@ -1354,8 +1359,9 @@ def verify_unravel_index(indices, shape, dtype): ref_res = np.unravel_index(x_data, y_data) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) for dtype in ["int64", "int32"]: @@ -1400,13 +1406,11 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ func = relay.Function(args, d) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) + f = relay.create_executor(kind, device=dev, target=target).evaluate(func) if default_value is None: - op_res = intrp.evaluate(func)(sparse_indices_data, sparse_values_data) + op_res = f(sparse_indices_data, sparse_values_data) else: - op_res = intrp.evaluate(func)( - sparse_indices_data, sparse_values_data, default_value_data - ) + op_res = f(sparse_indices_data, sparse_values_data, default_value_data) tvm.testing.assert_allclose(op_res.numpy(), xpected, rtol=1e-5) verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0]) # scalar @@ -1743,8 +1747,9 @@ def verify_func(func, data, ref_res, target_device=tvm.testing.enabled_targets() for target, dev in target_device: for kind in ["vm"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(*data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + *data + ) if isinstance(op_res, tvm.runtime.container.ADT): assert len(op_res) == len( ref_res @@ -1775,8 +1780,9 @@ def verify_adv_index(data_shape, index_shapes): func = relay.Function(inputs, out) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(*np_args) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + *np_args + ) tvm.testing.assert_allclose(op_res.numpy(), np_out, rtol=1e-5) verify_adv_index((10, 5), [(3, 4), (3, 1)]) @@ -1813,8 +1819,7 @@ def assert_relay_scanop( func = relay.Function([inp], out) for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(data_np) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(data_np) tvm.testing.assert_allclose(op_res.numpy(), np_out, rtol=rtol, atol=atol) data = np.array([2, 3, 0]) @@ -1875,8 +1880,9 @@ def verify_scatter_nd( func = relay.Function([data, indices, updates], out) for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(data_np, indices_np, updates_np) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data_np, indices_np, updates_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol) def verify_scatter_nd_with_stack( @@ -1884,7 +1890,8 @@ def verify_scatter_nd_with_stack( ): data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype)) indices_vars = [ - relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np) + relay.var("ind%d" % i, shape=v.shape, dtype=str(v.dtype)) + for i, v in enumerate(indices_np) ] updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype)) @@ -1900,8 +1907,7 @@ def verify_scatter_nd_with_stack( for a in indices_np: fargs.append(a) for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(*fargs) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(*fargs) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol) data = np.zeros((2, 2)).astype("int64") @@ -1926,7 +1932,7 @@ def verify_scatter_nd_with_stack( out[1, :] += updates[0, :] out[0, :] += updates[1, :] out[0, :] += updates[2, :] - verify_scatter_nd(data, indices, updates, out) + verify_scatter_nd(data, indices, updates, out, mode="add") verify_scatter_nd_with_stack(data, indices, updates, out) for mode in ["add", "update"]: @@ -1986,8 +1992,9 @@ def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): for target, dev in tvm.testing.enabled_targets(): for kind in backends: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - tvm_res = intrp.evaluate()( + tvm_res = relay.create_executor( + kind, mod=mod, device=dev, target=target + ).evaluate()( x_data ) # unique, indices, inverse_indices, num_unique, (counts) np_res = calc_numpy_unique( diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index b59325aea2f9..df77c33658de 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -50,8 +50,9 @@ def check_binary_op(opfunc, ref): func = relay.Function([x, y], z) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) for opfunc, ref in [(relay.power, np.power)]: @@ -88,8 +89,9 @@ def test_cmp_type(): func = relay.Function([x, y], z) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) @@ -113,8 +115,9 @@ def test_binary_int_broadcast_1(): ref_res = ref(x_data, y_data) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) @@ -138,8 +141,9 @@ def test_binary_int_broadcast_2(): ref_res = ref(x_data, y_data) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) @@ -148,8 +152,9 @@ def test_where(): def run(func, inputs, ref_res): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(*inputs) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + *inputs + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) def verify(x_np, y_np, cond_np): @@ -258,11 +263,9 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -352,12 +355,10 @@ def verify_mean_var_std(funcs, shape, axis, keepdims): ref_res = ref_func(x_data, axis=axis, dtype=dtype, keepdims=keepdims) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res1[0].numpy(), ref_mean, rtol=1e-5) tvm.testing.assert_allclose(op_res1[1].numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2[0].numpy(), ref_mean, rtol=1e-5) tvm.testing.assert_allclose(op_res2[1].numpy(), ref_res, rtol=1e-5) @@ -425,8 +426,9 @@ def verify( if not test_ref: return for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) verify((1, 3, 10, 10), [0, 0, 0, 0], [-1, 3, 10, 10], [1], (0, 3, 10, 10), dtype="int64") @@ -503,8 +505,9 @@ def verify( return for target, dev in tvm.testing.enabled_targets(): mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor("vm", mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_data) + op_res = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) verify( @@ -562,8 +565,9 @@ def verify(dshape, begin, end, strides, vshape, test_ref=True): v_data = np.random.uniform(size=vshape).astype("float32") ref_res = tvm.topi.testing.strided_set_python(x_data, v_data, begin, end, strides) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, v_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, v_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) verify((3, 4, 16), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index d93de5419f56..c08b538d22e6 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -62,8 +62,9 @@ def verify_resize(dshape, scale, method, layout, coord_trans): func = relay.Function([x], z) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-4) for method in ["nearest_neighbor", "linear", "cubic"]: @@ -113,8 +114,9 @@ def verify_resize(dshape, scale, method, layout, coord_trans): func = relay.Function([x], z) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-4) for method in ["nearest_neighbor", "linear", "cubic"]: @@ -167,8 +169,7 @@ def verify_resize(dshape, scale, method, layout): func = relay.Function([x], z) for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) for method in ["nearest_neighbor", "linear", "cubic"]: @@ -202,8 +203,9 @@ def verify_crop_and_resize( for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(image_data, boxes, box_indices) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + image_data, boxes, box_indices + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-04) boxes_nhwc = np.array([[0.1, 0.2, 0.8, 0.7], [0.2, 0, 1, 0.6]]).astype("float32") @@ -302,11 +304,9 @@ def verify_multibox_prior( func = relay.Function([x], z) func = run_infer_type(func) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res2 = intrp2.evaluate(func)(data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) sizes = (0.3, 1.5, 0.7) @@ -361,8 +361,7 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): func = relay.Function([x], z.astuple()) func = run_infer_type(func) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("debug", device=dev, target=target) - out = intrp.evaluate(func)(np_data) + out = relay.create_executor("debug", device=dev, target=target).evaluate(func)(np_data) tvm.testing.assert_allclose(out[0].numpy(), np_out1, rtol=1e-3, atol=1e-04) tvm.testing.assert_allclose(out[1].numpy(), np_out2, rtol=1e-3, atol=1e-04) @@ -433,15 +432,21 @@ def verify_nms( func_indices = relay.Function([x0, x1, x2, x3], z_indices) func_indices = run_infer_type(func_indices) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x0_data, x1_data, x2_data, x3_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x0_data, x1_data, x2_data, x3_data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data, x3_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + x0_data, x1_data, x2_data, x3_data + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) - op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data) + op_indices_res1 = relay.create_executor("graph", device=dev, target=target).evaluate( + func_indices + )(x0_data, x1_data, x2_data, x3_data) tvm.testing.assert_allclose(op_indices_res1[0].numpy(), ref_indices_res, rtol=1e-5) - op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data) + op_indices_res2 = relay.create_executor("debug", device=dev, target=target).evaluate( + func_indices + )(x0_data, x1_data, x2_data, x3_data) tvm.testing.assert_allclose(op_indices_res2[0].numpy(), ref_indices_res, rtol=1e-5) np_data = np.array( @@ -624,11 +629,13 @@ def test_default_value(): func = relay.Function([cls_prob, loc_pred, anchors], nms) func = run_infer_type(func) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds, np_anchors) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + np_cls_prob, np_loc_preds, np_anchors + ) tvm.testing.assert_allclose(op_res1.numpy(), expected_np_out, rtol=1e-5) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res2 = intrp2.evaluate(func)(np_cls_prob, np_loc_preds, np_anchors) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + np_cls_prob, np_loc_preds, np_anchors + ) tvm.testing.assert_allclose(op_res2.numpy(), expected_np_out, rtol=1e-5) def test_threshold(): @@ -718,11 +725,13 @@ def verify_roi_align( ) for target, dev in tvm.testing.enabled_targets(): print("test on", target) - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(np_data, np_rois) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + np_data, np_rois + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-4) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res2 = intrp2.evaluate(func)(np_data, np_rois) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + np_data, np_rois + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-4) def verify_roi_align_nchw( @@ -813,11 +822,13 @@ def verify_roi_pool(data_shape, rois_shape, pooled_size, spatial_scale): np_data, np_rois, pooled_size=pooled_size, spatial_scale=spatial_scale ) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(np_data, np_rois) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + np_data, np_rois + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-4) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res2 = intrp2.evaluate(func)(np_data, np_rois) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + np_data, np_rois + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-4) verify_roi_pool((1, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=1.0) @@ -841,11 +852,13 @@ def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs): print("Skip test because %s is not enabled." % target) continue dev = tvm.device(target, 0) - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(np_cls_prob, np_bbox_pred, np_im_info) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + np_cls_prob, np_bbox_pred, np_im_info + ) tvm.testing.assert_allclose(op_res1.numpy(), np_out, rtol=1e-4) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res2 = intrp2.evaluate(func)(np_cls_prob, np_bbox_pred, np_im_info) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + np_cls_prob, np_bbox_pred, np_im_info + ) tvm.testing.assert_allclose(op_res2.numpy(), np_out, rtol=1e-4) attrs = { @@ -935,8 +948,9 @@ def verify_yolo_reorg(shape, stride): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_yolo_reorg((1, 100, 20, 20), 10) @@ -1070,8 +1084,9 @@ def test_run(batch, in_channel, size, out_channel, deformable_groups, groups, la if target == "cuda" and layout == "NHWC": continue # Cannot run NHWC layout on cuda target, only on llvm for kind in ["graph", "debug"]: - intrp1 = relay.create_executor(kind, device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, offset, kernel) + op_res1 = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data, offset, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) test_run(1, 4, 16, 4, 1, 1, "NCHW") @@ -1115,8 +1130,9 @@ def verify_depth_to_space(dshape, block_size, layout, mode): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4) for layout in ["NHWC", "NCHW"]: @@ -1159,8 +1175,9 @@ def verify_space_to_depth(dshape, block_size, layout): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4) for layout in ["NHWC", "NCHW"]: @@ -1215,8 +1232,9 @@ def run_test_dilation2d( for target, dev in tvm.testing.enabled_targets(): if target in except_targets: continue - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(indata, kernel) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + indata, kernel + ) tvm.testing.assert_allclose(op_res.numpy(), out, rtol=1e-5, atol=1e-5) def _convert_data(indata, kernel, out, layout=None): @@ -1317,8 +1335,9 @@ def verify_affine_grid(num_batch, target_shape): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp1 = relay.create_executor(kind, device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data_np) + op_res1 = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data_np + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) verify_affine_grid(1, (16, 32)) @@ -1344,8 +1363,9 @@ def verify_grid_sample(data_shape, grid_shape): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp1 = relay.create_executor(kind, device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data_np, grid_np) + op_res1 = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data_np, grid_np + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8)) @@ -1371,8 +1391,9 @@ def verify_space_to_batch_nd(dshape, block_shape, paddings): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4) verify_space_to_batch_nd([3, 3, 2, 1], [3], [[0, 0]]) @@ -1398,8 +1419,9 @@ def verify_batch_to_space_nd(dshape, block_shape, crops): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4) verify_batch_to_space_nd([4, 1, 1, 3], [2, 2], [[0, 0], [0, 0]]) @@ -1432,8 +1454,9 @@ def verify_all_class_non_max_suppression( for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - selected_indices, num_detections = intrp.evaluate(func)(boxes_np, scores_np) + selected_indices, num_detections = relay.create_executor( + kind, device=dev, target=target + ).evaluate(func)(boxes_np, scores_np) tvm_res = selected_indices.numpy()[: num_detections.numpy()[0]] np.testing.assert_equal(tvm_res, expected_indices) diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index 1838233e3a3a..ea640c62dfeb 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -16,24 +16,23 @@ # under the License. """ Support level6 operator test cases. """ +import pytest import numpy as np import tvm -from tvm import te from tvm import relay import tvm.testing @tvm.testing.uses_gpu def test_sort(): - def verify_sort(shape, axis, is_ascend, is_dyn=False): - + def verify_sort(shape, axis, is_ascend, is_dyn=False, in_dtype="float32"): if is_dyn: - x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32")) + x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), in_dtype)) else: - x = relay.var("x", relay.TensorType(shape, "float32")) + x = relay.var("x", relay.TensorType(shape, in_dtype)) z = relay.sort(x, axis=axis, is_ascend=is_ascend) func = relay.Function([x], z) - x_data = np.random.uniform(size=shape).astype("float32") + x_data = np.random.uniform(size=shape).astype(in_dtype) if is_ascend: ref_res = np.sort(x_data, axis=axis) else: @@ -46,8 +45,9 @@ def verify_sort(shape, axis, is_ascend, is_dyn=False): for target, dev in tvm.testing.enabled_targets(): for kind in backends: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) for is_dyn in [False, True]: @@ -56,18 +56,19 @@ def verify_sort(shape, axis, is_ascend, is_dyn=False): verify_sort((3, 5, 6), axis=-1, is_ascend=False, is_dyn=is_dyn) verify_sort((3, 2000, 6), axis=1, is_ascend=False, is_dyn=is_dyn) verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn) + verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn, in_dtype="float16") @tvm.testing.uses_gpu def test_argsort(): - def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): + def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False, in_dtype="float32"): if is_dyn: - x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32")) + x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), in_dtype)) else: - x = relay.var("x", relay.TensorType(shape, "float32")) + x = relay.var("x", relay.TensorType(shape, in_dtype)) z = relay.argsort(x, axis=axis, is_ascend=is_ascend, dtype=dtype) func = relay.Function([x], z) - x_data = np.random.uniform(size=shape).astype("float32") + x_data = np.random.uniform(size=shape).astype(in_dtype) if is_ascend: ref_res = np.argsort(x_data, axis=axis, kind="stable") else: @@ -80,8 +81,9 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): for target, dev in tvm.testing.enabled_targets(): for kind in backends: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.astype(dtype), rtol=1e-5) for is_dyn in [False, True]: @@ -93,39 +95,43 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): verify_argsort((3, 6000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) verify_argsort((1000, 1, 1), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn) verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) + verify_argsort( + (1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn, in_dtype="float16" + ) @tvm.testing.uses_gpu def test_topk(): - def verify_topk(k, axis, ret_type, is_ascend, dtype): + def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"): shape = (20, 100) - x = relay.var("x", relay.TensorType(shape, "float32")) + x = relay.var("x", relay.TensorType(shape, in_dtype)) out = relay.topk(x, k, axis, ret_type, is_ascend, dtype) if isinstance(out, relay.expr.TupleWrapper): out = out.astuple() func = relay.Function([x], out) - np_data = np.random.uniform(size=shape).astype("float32") + np_data = np.random.uniform(size=shape).astype(in_dtype) if is_ascend: - np_indices = np.argsort(np_data, axis=axis) + np_indices = np.argsort(np_data, axis=axis, kind="stable") else: - np_indices = np.argsort(-np_data, axis=axis) + np_indices = np.argsort(-np_data, axis=axis, kind="stable") kk = k if k >= 1 else shape[axis] if axis == 0: np_indices = np_indices[:kk, :] - np_values = np.zeros(np_indices.shape).astype("float32") + np_values = np.zeros(np_indices.shape).astype(in_dtype) for i in range(shape[1]): np_values[:, i] = np_data[np_indices[:, i], i] else: np_indices = np_indices[:, :kk] - np_values = np.zeros(np_indices.shape).astype("float32") + np_values = np.zeros(np_indices.shape).astype(in_dtype) for i in range(shape[0]): np_values[i, :] = np_data[i, np_indices[i, :]] np_indices = np_indices.astype(dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(np_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + np_data + ) if ret_type == "both": tvm.testing.assert_allclose(op_res[0].numpy(), np_values) tvm.testing.assert_allclose(op_res[1].numpy(), np_indices) @@ -140,9 +146,8 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): for ret_type in ["both", "values", "indices"]: verify_topk(k, axis, ret_type, True, "int64") verify_topk(k, axis, ret_type, False, "float32") + verify_topk(k, axis, ret_type, False, "int64", "float16") if __name__ == "__main__": - test_sort() - test_argsort() - test_topk() + pytest.main([__file__]) diff --git a/tests/python/relay/test_op_qnn_add.py b/tests/python/relay/test_op_qnn_add.py index d3a3b8ffca5f..b38ada718cc5 100644 --- a/tests/python/relay/test_op_qnn_add.py +++ b/tests/python/relay/test_op_qnn_add.py @@ -63,8 +63,9 @@ def test_tflite_same_io_qnn_params(): y_data = y_datas[i] golden_output = golden_outputs[i] - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) @@ -111,8 +112,9 @@ def test_tflite_different_io_qnn_params(): y_data = y_datas[i] golden_output = golden_outputs[i] - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) @@ -143,8 +145,9 @@ def test_saturation(): y_data = np.array((255, 255, 128, 0)).reshape((1, 4)) golden_output = np.array((255, 255, 129, 0)).reshape((1, 4)) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) # Same params, different scale @@ -169,8 +172,9 @@ def test_saturation(): y_data = np.array((255, 255, 127, 0)).reshape((1, 4)) golden_output = np.array((255, 129, 65, 0)).reshape((1, 4)) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) # Same io params, different output scale @@ -195,8 +199,9 @@ def test_saturation(): y_data = np.array((255, 255, 127, 0)).reshape((1, 4)) golden_output = np.array((255, 129, 65, 0)).reshape((1, 4)) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) # All params different @@ -221,8 +226,9 @@ def test_saturation(): y_data = np.array((0, 128, 64, 0)).reshape((1, 4)) golden_output = np.array((255, 255, 132, 0)).reshape((1, 4)) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) diff --git a/tests/python/relay/test_op_qnn_batch_matmul.py b/tests/python/relay/test_op_qnn_batch_matmul.py new file mode 100644 index 000000000000..91648aca3dbc --- /dev/null +++ b/tests/python/relay/test_op_qnn_batch_matmul.py @@ -0,0 +1,247 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import numpy as np +from tvm import relay +from tvm.contrib import graph_executor +from tvm.relay.testing.temp_op_attr import TempOpAttr + +# We use llvm target for testing functionality. `llvm` points to an older Intel +# generation machine, that legalizes to a simple lowering. Therefore, the +# legalization is overwritten such that it can be skipped and we use the +# QNNCanonicalizeOps lowering for the testing. +def legalize_qnn_batch_matmul(attrs, inputs, types): + return None + + +def make_requantize_params(input_scale, output_scale, output_zero_point, out_dtype): + config = { + "input_scale": input_scale, + "output_scale": output_scale, + "output_zero_point": output_zero_point, + "out_dtype": out_dtype, + } + return config + + +def make_configuration( + quantized_x, + quantized_y, + dtype, + x_shape, + y_shape, + x_zero_point, + y_zero_point, + x_scale, + y_scale, + output, + out_dtype="int32", + requantize=None, +): + config = { + "quantized_x": quantized_x, + "quantized_y": quantized_y, + "dtype": dtype, + "x_shape": x_shape, + "y_shape": y_shape, + "x_zero_point": x_zero_point, + "y_zero_point": y_zero_point, + "x_scale": x_scale, + "y_scale": y_scale, + "output": output, + "out_dtype": out_dtype, + "requantize": requantize, + } + return config + + +def make_int_configuration( + xzero_point_zero=True, yzero_point_zero=True, requantize_output=False, per_channel=False +): + x_shape, y_shape, output_shape = (1, 4, 5), (1, 3, 5), (1, 4, 3) + if xzero_point_zero == True: + x_zero_point = 0 + else: + x_zero_point = -123 + + if yzero_point_zero == True: + y_zero_point = 0 + else: + y_zero_point = -123 + + in_dtype = "int8" + out_dtype = "int32" if not requantize_output else "int8" + quantized_x_np = ( + np.array( + [ + 1, + 3, + 5, + 7, + 9, # sum = 25 + 11, + 13, + 15, + -19, + -21, # sum = -1 + 1, + 3, + 5, + 7, + 9, # sum = 25 + 11, + 13, + -17, + 17, + -21, + ] + ) # sum = 3 + .astype(in_dtype) + .reshape(x_shape) + ) + quantized_y_np = ( + np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 1, 3, 5, 7, 9]) + .astype(in_dtype) + .reshape(y_shape) + ) + x_scale = 0.5 + y_scale = 0.5 + output_scale = 2.0 + + if requantize_output: + assert xzero_point_zero is True + assert yzero_point_zero is True + output = np.array([20, 51, 20, -26, -27, -26, 20, 51, 20, -14, -10, -14]) + elif xzero_point_zero is False and yzero_point_zero is False: + output = np.array( + [81960, 88360, 81960, 78400, 84540, 78400, 81960, 88360, 81960, 78984, 85164, 78984] + ) + elif xzero_point_zero is True and yzero_point_zero is False: + output = np.array([3240, 3490, 3240, -320, -330, -320, 3240, 3490, 3240, 264, 294, 264]) + elif xzero_point_zero is False and yzero_point_zero is True: + output = np.array([3240, 9640, 3240, 2878, 9018, 2878, 3240, 9640, 3240, 2970, 9150, 2970]) + else: + output = np.array([165, 415, 165, -197, -207, -197, 165, 415, 165, -105, -75, -105]) + + requant_params = ( + make_requantize_params(x_scale * y_scale, output_scale, -1, "int8") + if requantize_output + else None + ) + + output = output.astype(out_dtype).reshape(output_shape) + return make_configuration( + quantized_x=quantized_x_np, + quantized_y=quantized_y_np, + dtype=in_dtype, + x_shape=x_shape, + y_shape=y_shape, + x_zero_point=x_zero_point, + y_zero_point=y_zero_point, + x_scale=x_scale, + y_scale=y_scale, + output=output, + requantize=requant_params, + ) + + +def qnn_batch_matmul_driver(test_configuration): + in_dtype = test_configuration["dtype"] + out_dtype = test_configuration["out_dtype"] + quantized_x_name = "quantized_x" + quantized_y_name = "quantized_y" + expected_out_dtype = test_configuration["out_dtype"] + quantized_x = relay.var(quantized_x_name, shape=test_configuration["x_shape"], dtype=in_dtype) + quantized_y = relay.var(quantized_y_name, shape=test_configuration["y_shape"], dtype=in_dtype) + mod = relay.qnn.op.batch_matmul( + quantized_x, + quantized_y, + relay.const(test_configuration["x_zero_point"], "int32"), + relay.const(test_configuration["y_zero_point"], "int32"), + relay.const(test_configuration["x_scale"], "float32"), + relay.const(test_configuration["y_scale"], "float32"), + ) + if test_configuration["requantize"] is not None: + requantize_config = test_configuration["requantize"] + mod = relay.qnn.op.requantize( + mod, + input_scale=relay.const(requantize_config["input_scale"], "float32"), + input_zero_point=relay.const(0, "int32"), + output_scale=relay.const(requantize_config["output_scale"], "float32"), + output_zero_point=relay.const(requantize_config["output_zero_point"], "int32"), + out_dtype=requantize_config["out_dtype"], + ) + expected_out_dtype = requantize_config["out_dtype"] + + mod = relay.Function(relay.analysis.free_vars(mod), mod) + mod = tvm.IRModule.from_expr(mod) + mod = relay.transform.InferType()(mod) + mod = relay.qnn.transform.CanonicalizeOps()(mod) + with tvm.transform.PassContext(opt_level=2): + graph, lib, params = relay.build(mod, "llvm", params=None) + mod = graph_executor.create(graph, lib, device=tvm.cpu(0)) + mod.set_input(quantized_x_name, test_configuration[quantized_x_name]) + mod.set_input(quantized_y_name, test_configuration[quantized_y_name]) + mod.set_input(**params) + mod.run() + res = mod.get_output(0).numpy() + np.testing.assert_equal(res, test_configuration["output"]) + assert res.dtype == expected_out_dtype + + +def test_qnn_batch_matmul_xzp0_yzp0(): + with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int32_output_params = make_int_configuration(xzero_point_zero=True, yzero_point_zero=True) + qnn_batch_matmul_driver(int32_output_params) + + +def test_qnn_batch_matmul_xzp0(): + with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int32_output_params = make_int_configuration(xzero_point_zero=True, yzero_point_zero=False) + qnn_batch_matmul_driver(int32_output_params) + + +def test_qnn_batch_matmul_yzp0(): + with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int32_output_params = make_int_configuration(xzero_point_zero=False, yzero_point_zero=True) + qnn_batch_matmul_driver(int32_output_params) + + +def test_qnn_batch_matmul(): + with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int32_output_params = make_int_configuration(xzero_point_zero=False, yzero_point_zero=False) + qnn_batch_matmul_driver(int32_output_params) + + +def test_qnn_batch_matmul_with_requantized_output(): + with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int8_requantized_output_params = make_int_configuration(requantize_output=True) + qnn_batch_matmul_driver(int8_requantized_output_params) + + +if __name__ == "__main__": + test_qnn_batch_matmul_xzp0_yzp0() + test_qnn_batch_matmul_xzp0() + test_qnn_batch_matmul_yzp0() + test_qnn_batch_matmul() + test_qnn_batch_matmul_with_requantized_output() diff --git a/tests/python/relay/test_op_qnn_concatenate.py b/tests/python/relay/test_op_qnn_concatenate.py index 12571aad0822..c5f7bf1908ce 100644 --- a/tests/python/relay/test_op_qnn_concatenate.py +++ b/tests/python/relay/test_op_qnn_concatenate.py @@ -51,8 +51,9 @@ def test_same_io_qnn_params(): golden_output = np.concatenate((x_data, y_data), axis=axis) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) @@ -86,8 +87,9 @@ def test_different_io_qnn_params(): golden_output = np.concatenate((x_data - 2, y_data - 3), axis=axis) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) @@ -121,8 +123,9 @@ def test_few_same_io_qnn_params(): golden_output = np.concatenate((x_data + 1, y_data), axis=axis) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) @@ -156,8 +159,9 @@ def test_same_i_qnn_params(): golden_output = np.concatenate((x_data + 1, y_data + 1), axis=axis) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) @@ -183,8 +187,7 @@ def test_call_input(): ) func = relay.Function([x], z) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(x_data) np.testing.assert_equal(op_res.numpy(), x_data) diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 3a81e6e7b47a..3736350cbfe1 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -49,10 +49,21 @@ def get_ref_func( groups, channels=None, ): + if isinstance(input_zero_point, (int, float)): + input_zero_point = relay.const(input_zero_point, "int32") + if isinstance(kernel_zero_point, (int, float)): + kernel_zero_point = relay.const(kernel_zero_point, "int32") + else: + # Kernel zero point expression requires manual broadcasting for some layouts. + if kernel_layout == "OIHW": + kernel_zero_point = relay.reshape(kernel_zero_point, [-1, 1, 1, 1]) + elif kernel_layout == "HWOI": + kernel_zero_point = relay.reshape(kernel_zero_point, [1, 1, -1, 1]) + casted_data = relay.op.cast(data, "int32") casted_kernel = relay.op.cast(kernel, "int32") - shifted_data = relay.op.subtract(casted_data, relay.const(input_zero_point, "int32")) - shifted_kernel = relay.op.subtract(casted_kernel, relay.const(kernel_zero_point, "int32")) + shifted_data = relay.op.subtract(casted_data, input_zero_point) + shifted_kernel = relay.op.subtract(casted_kernel, kernel_zero_point) func = relay.op.nn.conv2d( shifted_data, shifted_kernel, @@ -88,11 +99,16 @@ def get_qnn_func( channels, groups, ): + if isinstance(input_zero_point, (int, float)): + input_zero_point = relay.const(input_zero_point, "int32") + if isinstance(kernel_zero_point, (int, float)): + kernel_zero_point = relay.const(kernel_zero_point, "int32") + func = relay.qnn.op.conv2d( data, kernel, - input_zero_point=relay.const(input_zero_point, "int32"), - kernel_zero_point=relay.const(kernel_zero_point, "int32"), + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, input_scale=relay.const(input_scale, "float32"), kernel_scale=relay.const(kernel_scale, "float32"), kernel_size=kernel_size, @@ -419,6 +435,62 @@ def test_both_zero_point(): verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) +def test_dynamic_zero_point(): + with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d): + + # uint8 input with non static zero points. + data_shape = (2, 4, 2, 4) + data_dtype = "uint8" + kernel_shape = (3, 4, 2, 2) + kernel_dtype = "uint8" + input_zero_point = relay.op.multiply( + relay.const(2, dtype="int32"), relay.const(2, dtype="int32") + ) + kernel_zero_point = relay.const(np.random.randint(10, size=[3]), "int32") + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + # int8 input + data_shape = (2, 4, 2, 4) + data_dtype = "int8" + kernel_shape = (3, 4, 2, 2) + kernel_dtype = "int8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + def test_layout(): with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d): @@ -888,13 +960,17 @@ def test_depthwise_depth_multiplier(): data_dtype = "uint8" kernel_shape = (4, 1, 3, 3) kernel_dtype = "uint8" + input_zero_point = relay.op.multiply( + relay.const(2, dtype="int32"), relay.const(2, dtype="int32") + ) + kernel_zero_point = relay.const(np.random.randint(10, size=[4]), "int32") ref_func, qnn_func = get_funcs( data_shape=data_shape, data_dtype=data_dtype, kernel_shape=kernel_shape, kernel_dtype=kernel_dtype, - input_zero_point=5, - kernel_zero_point=3, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, input_scale=1.0, kernel_scale=1.0, kernel_size=(3, 3), @@ -905,6 +981,7 @@ def test_depthwise_depth_multiplier(): kernel_layout="OIHW", out_dtype="int32", groups=4, + channels=4, ) verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) @@ -919,8 +996,8 @@ def test_depthwise_depth_multiplier(): data_dtype=data_dtype, kernel_shape=kernel_shape, kernel_dtype=kernel_dtype, - input_zero_point=5, - kernel_zero_point=3, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, input_scale=1.0, kernel_scale=1.0, kernel_size=(3, 3), @@ -946,8 +1023,8 @@ def test_depthwise_depth_multiplier(): data_dtype=data_dtype, kernel_shape=kernel_shape, kernel_dtype=kernel_dtype, - input_zero_point=5, - kernel_zero_point=3, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, input_scale=1.0, kernel_scale=1.0, kernel_size=(3, 3), @@ -971,8 +1048,8 @@ def test_depthwise_depth_multiplier(): data_dtype=data_dtype, kernel_shape=kernel_shape, kernel_dtype=kernel_dtype, - input_zero_point=5, - kernel_zero_point=3, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, input_scale=1.0, kernel_scale=1.0, kernel_size=(3, 3), diff --git a/tests/python/relay/test_op_qnn_mul.py b/tests/python/relay/test_op_qnn_mul.py index c4cd3244c8fe..af84f9778638 100644 --- a/tests/python/relay/test_op_qnn_mul.py +++ b/tests/python/relay/test_op_qnn_mul.py @@ -80,8 +80,9 @@ def test_tflite_same_io_qnn_params(): y_rec = recover(y_data, rhs_scale, rhs_zero_point) golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), np.uint8(golden)) @@ -134,8 +135,9 @@ def test_tflite_different_io_qnn_params(): y_rec = recover(y_data, rhs_scale, rhs_zero_point) golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), np.uint8(golden)) @@ -172,8 +174,9 @@ def test_saturation(): golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), np.uint8(golden)) # Same params, different scale @@ -206,8 +209,9 @@ def test_saturation(): golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), np.uint8(golden)) # All params different @@ -241,8 +245,9 @@ def test_saturation(): golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), np.uint8(golden)) diff --git a/tests/python/relay/test_op_qnn_quantize.py b/tests/python/relay/test_op_qnn_quantize.py index 345e8b815da1..322382ca002c 100644 --- a/tests/python/relay/test_op_qnn_quantize.py +++ b/tests/python/relay/test_op_qnn_quantize.py @@ -88,6 +88,20 @@ def test_float32_to_int8(): ) +def test_scalar_float32_to_int8(): + data = np.array(-63.5).astype("float32") + output = np.array(-128).astype("int8") + quant_args = {"out_zero_point": np.int32(-1), "out_scale": np.float32(0.5)} + quantize_test_driver( + in_dtype="float32", + quant_args=quant_args, + axis=-1, + out_dtype="int8", + in_data=data, + verify_output_data=output, + ) + + def test_channelwise_axis_0(): data = ( np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) @@ -163,6 +177,7 @@ def test_dynamic_quantize(): if __name__ == "__main__": test_float32_to_uint8() test_float32_to_int8() + test_scalar_float32_to_int8() test_channelwise_axis_0() test_channelwise_axis_1() test_dynamic_quantize() diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index ad9805e74929..0f512df25cdf 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ b/tests/python/relay/test_op_qnn_requantize.py @@ -92,6 +92,24 @@ def test_same_scale(): verify(mod, (golden_data, golden_output)) +def test_scalar_same_scale(): + # Have same scales, everything within range + golden_data = np.array(-10).astype("int32") + golden_output = golden_data + + for rounding in roundings: + mod = get_mod( + data_shape=(), + data_dtype="int32", + out_dtype="int8", + input_scale=0.5, + output_scale=0.5, + rounding=rounding, + ) + assert "right_shift" not in mod.astext() + verify(mod, (golden_data, golden_output)) + + def test_downscale(): for rounding in roundings: mod = get_mod( @@ -437,6 +455,7 @@ def test_per_channel_different_scale(): if __name__ == "__main__": test_same_scale() + test_scalar_same_scale() test_downscale() test_upscale() test_non_power_of_two() diff --git a/tests/python/relay/test_op_qnn_simulated_dequantize.py b/tests/python/relay/test_op_qnn_simulated_dequantize.py index 64c70be9d7a7..e15fdf770c64 100644 --- a/tests/python/relay/test_op_qnn_simulated_dequantize.py +++ b/tests/python/relay/test_op_qnn_simulated_dequantize.py @@ -77,8 +77,8 @@ def verify_simulated_dequantize_simple(dtype): ) input_data = relay.var("input_data", shape=data.shape, dtype="float32") scale = relay.var("scale", shape=[]) - zp = relay.var("zp", shape=[]) - dtype = relay.var("dtype", shape=[]) + zp = relay.var("zp", shape=[], dtype="int32") + dtype = relay.var("dtype", shape=[], dtype="int32") vm = build_simulated_dequantize(input_data, scale, zp, dtype) sim_dq_out = vm.invoke("main", input_data=data_fp, scale=scale_np, zp=zp_np, dtype=dtype_np) np.testing.assert_allclose(sim_dq_out.numpy(), dq_out, rtol=1e-5) @@ -109,7 +109,7 @@ def test_dynamic_channels(): input_data = relay.var("input_data", shape=data.shape, dtype="float32") scale = relay.var("scale", shape=[relay.Any()], dtype="float32") zp = relay.var("zp", shape=[relay.Any()], dtype="int32") - dtype = relay.var("dtype", shape=[]) + dtype = relay.var("dtype", shape=[], dtype="int32") vm = build_simulated_dequantize(input_data, scale, zp, dtype, axis=0) sim_dq_out = vm.invoke("main", input_data=data_fp, scale=scale_np, zp=zp_np, dtype=dtype_np) np.testing.assert_allclose(sim_dq_out.numpy(), dq_out, rtol=1e-5) @@ -150,7 +150,7 @@ def test_dynamic_dtype(): input_data = relay.var("input_data", shape=data.shape, dtype="float32") scale = relay.var("scale", shape=[relay.Any()], dtype="float32") zp = relay.var("zp", shape=[relay.Any()], dtype="int32") - dtype = relay.var("dtype", shape=[]) + dtype = relay.var("dtype", shape=[], dtype="int32") vm = build_simulated_dequantize(input_data, scale, zp, dtype) sim_dq_out = vm.invoke("main", input_data=data_fp, scale=scale_np, zp=zp_np, dtype=dtype_np) np.testing.assert_allclose(sim_dq_out.numpy(), dq_out, rtol=1e-5) diff --git a/tests/python/relay/test_op_qnn_simulated_quantize.py b/tests/python/relay/test_op_qnn_simulated_quantize.py index 14014f2e4605..69ce261f6b09 100644 --- a/tests/python/relay/test_op_qnn_simulated_quantize.py +++ b/tests/python/relay/test_op_qnn_simulated_quantize.py @@ -85,8 +85,8 @@ def verify_simulated_quantize_simple(dtype): ) input_data = relay.var("input_data", shape=data.shape, dtype="float32") scale = relay.var("scale", shape=[]) - zp = relay.var("zp", shape=[]) - dtype = relay.var("dtype", shape=[]) + zp = relay.var("zp", shape=[], dtype="int32") + dtype = relay.var("dtype", shape=[], dtype="int32") vm = build_simulated_quantize(input_data, scale, zp, dtype) sim_q_out = vm.invoke("main", input_data=data, scale=scale_np, zp=zp_np, dtype=dtype_np) allclose_with_rounding(sim_q_out.numpy(), q_out) @@ -117,7 +117,7 @@ def test_dynamic_channels(): input_data = relay.var("input_data", shape=data.shape, dtype="float32") scale = relay.var("scale", shape=[relay.Any()], dtype="float32") zp = relay.var("zp", shape=[relay.Any()], dtype="int32") - dtype = relay.var("dtype", shape=[]) + dtype = relay.var("dtype", shape=[], dtype="int32") vm = build_simulated_quantize(input_data, scale, zp, dtype, axis=0) sim_q_out = vm.invoke("main", input_data=data, scale=scale_np, zp=zp_np, dtype=dtype_np) allclose_with_rounding(sim_q_out.numpy(), q_out) @@ -159,7 +159,7 @@ def test_dynamic_dtype(): input_data = relay.var("input_data", shape=data.shape, dtype="float32") scale = relay.var("scale", shape=[relay.Any()], dtype="float32") zp = relay.var("zp", shape=[relay.Any()], dtype="int32") - dtype = relay.var("dtype", shape=[]) + dtype = relay.var("dtype", shape=[], dtype="int32") vm = build_simulated_quantize(input_data, scale, zp, dtype) sim_q_out = vm.invoke("main", input_data=data, scale=scale_np, zp=zp_np, dtype=dtype_np) allclose_with_rounding(sim_q_out.numpy(), q_out) diff --git a/tests/python/relay/test_op_qnn_subtract.py b/tests/python/relay/test_op_qnn_subtract.py index 4f9a36757b81..f7117b559401 100644 --- a/tests/python/relay/test_op_qnn_subtract.py +++ b/tests/python/relay/test_op_qnn_subtract.py @@ -52,8 +52,9 @@ def qnn_subtract_driver(x_datas, y_datas, golden_outputs, scale_and_zp, data_dty x_data = x_datas[i] y_data = y_datas[i] golden_output = golden_outputs[i] - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 3917d81fc2c3..b5702a1542a9 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -712,7 +712,6 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) -@tvm.testing.uses_gpu def test_alter_layout_strided_slice(): """Test rewriting strided_slice during alter_iop_layout""" @@ -759,12 +758,16 @@ def expected(): with relay.build_config(opt_level=3): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug", "vm"]: - ex_before = relay.create_executor(kind, mod=mod_before, device=dev, target=target) - ex_new = relay.create_executor(kind, mod=mod_new, device=dev, target=target) np_data = np.random.uniform(size=(1, 32, 28, 28)).astype("float32") np_weight = np.random.uniform(size=(32, 32, 3, 3)).astype("float32") - result_before = ex_before.evaluate()(np_data, np_weight) - result_new = ex_new.evaluate()(np_data, np_weight) + f_before = relay.create_executor( + kind, mod=mod_before, device=dev, target=target + ).evaluate() + result_before = f_before(np_data, np_weight) + f_new = relay.create_executor( + kind, mod=mod_new, device=dev, target=target + ).evaluate() + result_new = f_new(np_data, np_weight) tvm.testing.assert_allclose( result_before.numpy(), result_new.numpy(), rtol=1e-5, atol=1e-5 ) @@ -1316,7 +1319,9 @@ def expected(): weight = relay.var("weight", shape=(48, 64)) target_layout = "NK16n" weight_transform = relay.layout_transform(weight, "NK", target_layout) - y = relay.nn.contrib_dense_pack(x, weight_transform, units=None, out_dtype="float32") + y = relay.nn.contrib_dense_pack( + x, weight_transform, target_layout, units=None, out_dtype="float32" + ) y = relay.Function(analysis.free_vars(y), y) return y @@ -1354,6 +1359,49 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): assert before.body.attrs.layout == "NCHW" +def test_alter_op_dense_packed_data(): + def before(): + x = relay.var("x", shape=(1, 32, 8, 8)) + weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3)) + conv = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1)) + pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0]) + squeeze = relay.squeeze(pool, axis=[2, 3]) + dense = relay.nn.dense(squeeze, relay.var("dense_weight", shape=(16, 32))) + return relay.Function(analysis.free_vars(dense), dense) + + def expected(): + x = relay.var("x", shape=(1, 32, 8, 8)) + conv_weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3)) + dense_weight = relay.var("dense_weight", shape=(16, 32)) + conv = relay.nn.contrib_conv2d_nchwc( + relay.layout_transform(x, "NCHW", "NCHW8c"), + relay.layout_transform(conv_weight, "OIHW", "OIHW8i8o"), + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW8c", + kernel_layout="OIHW8i8o", + out_layout="NCHW8c", + ) + pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0], layout="NCHW8c") + squeeze = relay.squeeze(pool, axis=[2, 3]) + dense = relay.nn.contrib_dense_pack( + relay.layout_transform(squeeze, "NC8c", "NC"), + relay.layout_transform(dense_weight, "NK", "NK16n"), + "NK16n", + out_dtype="float32", + ) + return relay.Function(analysis.free_vars(dense), dense) + + with tvm.target.Target("llvm"): + with TempOpAttr( + "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout + ): + a = run_opt_pass(before(), transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b) + + if __name__ == "__main__": test_alter_op() test_alter_return_none() @@ -1378,3 +1426,4 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): test_alter_op_dense() test_alter_layout_strided_slice_axes_nhwc() test_not_inplace_modify() + test_alter_op_dense_packed_data() diff --git a/tests/python/relay/test_pass_annotate_spans_defuse.py b/tests/python/relay/test_pass_annotate_spans_defuse.py new file mode 100644 index 000000000000..def4a1da1b55 --- /dev/null +++ b/tests/python/relay/test_pass_annotate_spans_defuse.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for annotating spans.""" + + +import tvm +import tvm.relay as relay +from tvm.relay import testing +import tvm.testing + + +def test_annotate_spans_compatibility(): + data = relay.var("data", relay.TensorType((1, 3, 64, 64), "float32")) + weight = relay.var("weight") + + bn_gamma = relay.var("bn_gamma") + bn_beta = relay.var("bn_beta") + bn_mmean = relay.var("bn_mean") + bn_mvar = relay.var("bn_var") + + simple_net = relay.nn.conv2d( + data=data, weight=weight, kernel_size=(3, 3), channels=3, padding=(1, 1) + ) + simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0] + simple_net = relay.Function(relay.analysis.free_vars(simple_net), simple_net) + + module, params = testing.create_workload(simple_net) + + # Apply some simple passes to legalize the IR. + with tvm.transform.PassContext(opt_level=0): + module, params = relay.optimize(module, tvm.testing.enabled_targets()[0][0], params) + + seq = tvm.transform.Sequential([relay.transform.AnnotateSpans(), relay.transform.DefuseOps()]) + with tvm.transform.PassContext(opt_level=3): + module = seq(module) + + +if __name__ == "__main__": + test_annotate_spans_compatibility() diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 098fb5c64e82..23ef9d11eb77 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -144,8 +144,9 @@ def test_run(): i_data = np.random.uniform(0, 1, ishape).astype(dtype) w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()) - ref_res = ref_ex.evaluate()(i_data, w1_data) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()).evaluate()( + i_data, w1_data + ) check_result( mod, {"data": i_data, "weight1": w1_data}, (1, 32, 14, 14), ref_res.numpy(), tol=1e-5 @@ -171,8 +172,9 @@ def test_extern_dnnl_mobilenet(): i_data = np.random.uniform(0, 1, ishape).astype(dtype) ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype="float32") - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)) - ref_res = ref_ex.evaluate()(i_data, **params) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( + i_data, **params + ) check_result(mod, {"data": i_data}, (1, 1000), ref_res.numpy(), tol=1e-5, params=params) diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py index 51b9f5f24d1d..030682148a5f 100644 --- a/tests/python/relay/test_pass_auto_quantize.py +++ b/tests/python/relay/test_pass_auto_quantize.py @@ -185,8 +185,9 @@ def verify_partition(mod, params): params = [gen_rand_tvm(param.type_annotation, 0, 1) for param in partitioned_mod["main"].params] def _eval_mod(mod): - vm = relay.create_executor("vm", device=tvm.cpu(0), target="llvm", mod=mod) - return vm.evaluate()(*params) + return relay.create_executor("vm", device=tvm.cpu(0), target="llvm", mod=mod).evaluate()( + *params + ) partitioned_mod_result = _eval_mod(partitioned_mod) unpartitioned_mod_result = _eval_mod(unpartitioned_mod) diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index fafab3ee3584..2b7e3e9eb3a9 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -21,6 +21,8 @@ from tvm import relay from tvm.relay.op import register_alter_op_layout from tvm.relay import transform, analysis +from tvm.relay.transform.infer_layout_utils import InferCorrectLayoutOutput +from tvm.relay.op import op as reg def run_opt_pass(expr, passes): @@ -1881,6 +1883,48 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_infer_correct_layout(): + test_infer_correct_layout_flag = False + + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.nn.relu(y) + y = relay.Function([x, weight], y) + return y + + @reg.register_infer_correct_layout("nn.relu", level=11) + def infer_correct_layout_relu(attrs, new_in_layouts, old_in_layouts, old_in_types): + nonlocal test_infer_correct_layout_flag + test_infer_correct_layout_flag = True + ret = tvm.tir.layout("") + if new_in_layouts: + assert len(new_in_layouts) >= 1 + ret = new_in_layouts[0] + else: + for i in range(len(old_in_layouts)): + if old_in_layouts[i]: + ret = old_in_layouts[i] + break + input_layouts = [] + for i in range(len(old_in_layouts)): + input_layouts.append(ret) + return InferCorrectLayoutOutput(input_layouts, [ret], attrs) + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) + assert test_infer_correct_layout_flag == True + + if __name__ == "__main__": test_qnn_binary_no_convert_layout() test_no_convert_layout() @@ -1914,3 +1958,4 @@ def expected(): test_conv_strided_slice_axes_convert_layout() test_image_resize_convert_layout() test_conv_image_resize_convert_layout() + test_infer_correct_layout() diff --git a/tests/python/relay/test_pass_defunctionalization.py b/tests/python/relay/test_pass_defunctionalization.py index 57dbb82c2d0d..30f2203be0b5 100644 --- a/tests/python/relay/test_pass_defunctionalization.py +++ b/tests/python/relay/test_pass_defunctionalization.py @@ -124,8 +124,7 @@ def to_adt_list(mod, arr): li = nil() for a in arr: li = cons(relay.const(a), li) - ex = relay.create_executor(mod=mod) - adt = ex.evaluate(li) + adt = relay.create_executor(mod=mod).evaluate(li) mod["main"] = expr return adt @@ -148,11 +147,9 @@ def @main(%l: Tensor[(5, 5), float32]) -> Tensor[(5, 5), float32] { input = np.random.rand(5, 5).astype("float32") - ex = relay.create_executor("debug", mod=mod) - defunc_ex = relay.create_executor("debug", mod=defunc_mod) + out = relay.create_executor("debug", mod=mod).evaluate()(input) - out = ex.evaluate()(input) - defunc_out = defunc_ex.evaluate()(input) + defunc_out = relay.create_executor("debug", mod=defunc_mod).evaluate()(input) np.testing.assert_equal(out.numpy(), defunc_out.numpy()) @@ -182,11 +179,11 @@ def @main(%l: List[float32]) -> List[float32] { input = np.random.rand(10).astype("float32") - ex = relay.create_executor("debug", mod=mod) - defunc_ex = relay.create_executor("debug", mod=defunc_mod) + out = relay.create_executor("debug", mod=mod).evaluate(mod["main"])(to_adt_list(mod, input)) - out = ex.evaluate(mod["main"])(to_adt_list(mod, input)) - defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input)) + defunc_out = relay.create_executor("debug", mod=defunc_mod).evaluate()( + to_adt_list(defunc_mod, input) + ) np.testing.assert_array_equal(to_list(mod, out), to_list(defunc_mod, defunc_out)) @@ -220,11 +217,11 @@ def @main(%l: List[int32]) -> int32 { input = np.random.randint(1, 100, 10) - ex = relay.create_executor("debug", mod=mod) - defunc_ex = relay.create_executor("debug", mod=defunc_mod) + out = relay.create_executor("debug", mod=mod).evaluate(mod["main"])(to_adt_list(mod, input)) - out = ex.evaluate(mod["main"])(to_adt_list(mod, input)) - defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input)) + defunc_out = relay.create_executor("debug", mod=defunc_mod).evaluate()( + to_adt_list(defunc_mod, input) + ) tvm.testing.assert_allclose(out.numpy(), defunc_out.numpy()) diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 962b7bebb12b..836d49b3441b 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -40,8 +40,9 @@ def verify_func(func, data, ref_res, rtol=1e-5, atol=1e-7): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(*data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + *data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol) @@ -181,8 +182,9 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): continue for kind in ["graph", "vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func2) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(np_data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + np_data + ) if ret_type == "both": tvm.testing.assert_allclose(op_res[0].numpy(), np_values) tvm.testing.assert_allclose(op_res[1].numpy(), np_indices) diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 3271379cf3ef..2bc2e4e635f0 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -22,6 +22,31 @@ from tvm import relay +def compare_fq_to_int(expr, args, allow_rounding_error=False): + mod = tvm.IRModule.from_expr(expr) + mod = tvm.relay.transform.InferType()(mod) + + mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod_int) + + result = ( + relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + .evaluate()(*args) + .numpy() + ) + + result_int = ( + relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm") + .evaluate()(*args) + .numpy() + ) + + if allow_rounding_error: + assert np.all(np.abs(result - result_int) <= 1) + else: + assert np.array_equal(result, result_int) + + def test_fake_quantize_conv(): for out_dtype in ["int8", "uint8"]: x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") @@ -35,23 +60,48 @@ def test_fake_quantize_conv(): ) op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np, w_np]) + - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, w_np).asnumpy() +def test_fake_quantize_dense(): + for out_dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[128, 64], dtype="int8") + w = relay.var("w", shape=[256, 64], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + op = relay.op.nn.dense( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize(w, relay.const(0.5), zero), + ) + op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np, w_np).asnumpy() + x_np = np.random.randint(-128, 127, size=[128, 64], dtype="int8") + w_np = np.random.randint(-128, 127, size=[256, 64], dtype="int8") - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np, w_np]) + + +def test_fake_quantize_batch_matmul(): + for out_dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[1, 128, 64], dtype="int8") + w = relay.var("w", shape=[1, 256, 64], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + op = relay.op.nn.batch_matmul( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize(w, relay.const(0.5), zero), + ) + op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) + + x_np = np.random.randint(-128, 127, size=[1, 128, 64], dtype="int8") + w_np = np.random.randint(-128, 127, size=[1, 256, 64], dtype="int8") + + compare_fq_to_int(op, [x_np, w_np]) def test_fake_transpose_quantize_conv(): @@ -65,23 +115,10 @@ def test_fake_transpose_quantize_conv(): op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) op = relay.qnn.op.quantize(op, one, zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) - - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, w_np).asnumpy() - - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np, w_np).asnumpy() - - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np, w_np]) def test_fake_transpose_quantize_conv_bias_add(): @@ -97,24 +134,32 @@ def test_fake_transpose_quantize_conv_bias_add(): op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, one, zero)) op = relay.qnn.op.quantize(op, one, zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np, w_np, bias_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, w_np, bias_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np, w_np, bias_np).asnumpy() +def test_fake_transpose_quantize_conv_bias_add_mismatch(): + x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + bias = relay.var("bias", shape=[16], dtype="int32") + one = relay.const(1.0) + two = relay.const(2.0) + zero = relay.const(0) + + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + x = relay.transpose(x, [0, 3, 1, 2]) + op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, two, zero)) + op = relay.qnn.op.quantize(op, one, zero) + + x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32") - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np, w_np, bias_np]) def test_fake_quantize_maxpool(): @@ -125,101 +170,121 @@ def test_fake_quantize_maxpool(): op = relay.op.nn.max_pool2d(x, [3, 3]) op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_avgpool(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.nn.avg_pool2d(x, [3, 3]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np], True) -def test_fake_quantize_avgpool(): +def test_fake_quantize_reshape(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") zero = relay.const(0) x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) - op = relay.op.nn.avg_pool2d(x, [3, 3]) + op = relay.op.reshape(x, [1, 3, -1]) op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_expand_dims(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.expand_dims(x, axis=1) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() +def test_fake_quantize_squeeze(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.squeeze(x, axis=[0]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - assert np.all(np.abs(result - result2) <= 1) + compare_fq_to_int(op, [x_np]) -def test_fake_quantize_reshape(): +def test_fake_quantize_strided_slice(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") zero = relay.const(0) x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) - op = relay.op.reshape(x, [1, 3, -1]) + op = relay.op.strided_slice(x, begin=[0, 0, 0, 0], end=[1, 1, 112, 112]) op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_split(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.split(x, axis=3, indices_or_sections=2) + op = relay.qnn.op.quantize(op[0], relay.const(2.0), zero) x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() + op = relay.op.split(x, axis=3, indices_or_sections=[56, 112, 168]) + op = relay.qnn.op.quantize(op[1], relay.const(2.0), zero) - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np]) -def test_fake_quantize_transpose_reshape(): +def test_fake_quantize_batch_flatten(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") zero = relay.const(0) x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) - op = relay.op.transpose(x, [1, 0, 2, 3]) - op = relay.op.reshape(op, [3, -1]) + op = relay.op.nn.batch_flatten(x) op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() +def test_fake_quantize_transpose_reshape(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") - assert np.array_equal(result, result2) + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.transpose(x, [1, 0, 2, 3]) + op = relay.op.reshape(op, [3, -1]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np]) def test_fake_quantize_concat(): @@ -234,24 +299,11 @@ def test_fake_quantize_concat(): concat = relay.op.concatenate(inputs, axis=1) out = relay.qnn.op.quantize(concat, relay.const(3.5), zero) - mod = tvm.IRModule.from_expr(out) - mod = tvm.relay.transform.InferType()(mod) - inputs_np = [] for i in range(4): inputs_np.append(np.random.randint(-128, 127, size=[1, 4], dtype="int8")) - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) - - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(*inputs_np).asnumpy() - - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(*inputs_np).asnumpy() - - assert np.array_equal(result, result2) + compare_fq_to_int(out, inputs_np) def test_fake_quantize_clip(): @@ -261,19 +313,67 @@ def test_fake_quantize_clip(): op = relay.op.clip(x, 0, 6) op = relay.qnn.op.quantize(op, relay.const(2.0), relay.const(114), out_dtype="uint8") - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) + + +@pytest.mark.parametrize( + "operator", + [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum], +) +def test_fake_quantize_binary(operator): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + x = relay.qnn.op.dequantize(x, relay.const(0.1), relay.const(0)) + + y = relay.var("y", shape=[1, 3, 224, 224], dtype="int8") + y = relay.qnn.op.dequantize(y, relay.const(0.2), relay.const(0)) + + op = operator(x, y) + if operator == relay.op.multiply: + out_scale = relay.const(20.0) + else: + out_scale = relay.const(0.1) + + op = relay.qnn.op.quantize(op, out_scale, relay.const(0), out_dtype="int8") + + x_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8") + y_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np, y_np]) + + +@pytest.mark.parametrize( + "operator", + [ + relay.op.add, + relay.op.multiply, + relay.op.subtract, + relay.op.subtract, + relay.op.minimum, + relay.op.maximum, + ], +) +def test_fake_quantize_binary_const(operator): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + x = relay.qnn.op.dequantize(x, relay.const(0.1), relay.const(10)) + + y = relay.const(1.0) + + op = operator(x, y) + op = relay.qnn.op.quantize(op, relay.const(0.1), relay.const(10), out_dtype="int8") + + x_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np]) + - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() +def test_fake_quantize_pad(): + x = relay.var("x", shape=[1, 383, 128], dtype="int8") + x = relay.qnn.op.dequantize(x, relay.const(1.0), relay.const(10)) + op = relay.op.nn.pad(x, [[0, 0], [0, 1], [0, 0]], 0.0) + op = relay.qnn.op.quantize(op, relay.const(1.0), relay.const(10), out_dtype="int8") - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() + x_np = np.random.randint(-25, 25, size=[1, 383, 128], dtype="int8") - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np]) diff --git a/tests/python/relay/test_pass_fold_explicit_padding.py b/tests/python/relay/test_pass_fold_explicit_padding.py index 58ba58aa06d3..effebaaf1e8b 100644 --- a/tests/python/relay/test_pass_fold_explicit_padding.py +++ b/tests/python/relay/test_pass_fold_explicit_padding.py @@ -70,12 +70,14 @@ def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout): mod2 = tvm.IRModule.from_expr(zz) with tvm.transform.PassContext(): - ex1 = relay.create_executor("vm", mod=mod1, device=tvm.cpu(), target="llvm") - ex2 = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + func1 = relay.create_executor( + "vm", mod=mod1, device=tvm.cpu(), target="llvm" + ).evaluate() + func2 = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm").evaluate() x_np = np.random.rand(*shape).astype("float32") w_np = np.random.rand(*wshape).astype("float32") - result1 = ex1.evaluate()(x_np, w_np) - result2 = ex2.evaluate()(x_np, w_np) + result1 = func1(x_np, w_np) + result2 = func2(x_np, w_np) tvm.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-5, atol=1e-5) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 931f453f9a6d..855650f810a5 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -775,9 +775,9 @@ def test_fuse_dynamic_squeeze_slice_take(): take = relay.op.take(strided_slice, take_val, axis=0) mod = tvm.IRModule.from_expr(take) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - - result = ex.evaluate()(*input_data) + result = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm").evaluate()( + *input_data + ) np_result = np.squeeze(input_data[0][:, input_data[1][0], :], axis=0) diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index cd0edf95aba7..126fcf22e823 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -45,9 +45,8 @@ def test_fo_id(): func = run_infer_type(func) back_func = run_infer_type(gradient(func, mode="first_order")) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor() x = rand(dtype, *shape) - forward, (grad,) = ex.evaluate(back_func)(x) + forward, (grad,) = create_executor().evaluate(back_func)(x) tvm.testing.assert_allclose(forward.numpy(), x.numpy()) tvm.testing.assert_allclose(grad.numpy(), np.ones_like(x.numpy())) @@ -61,9 +60,8 @@ def test_id(): func = run_infer_type(func) back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor() x = rand(dtype, *shape) - forward, (grad,) = ex.evaluate(back_func)(x) + forward, (grad,) = create_executor().evaluate(back_func)(x) tvm.testing.assert_allclose(forward.numpy(), x.numpy()) tvm.testing.assert_allclose(grad.numpy(), np.ones_like(x.numpy())) @@ -89,9 +87,8 @@ def test_add(): func = run_infer_type(func) back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor() x = rand(dtype, *shape) - forward, (grad,) = ex.evaluate(back_func)(x) + forward, (grad,) = create_executor().evaluate(back_func)(x) tvm.testing.assert_allclose(forward.numpy(), 2 * x.numpy()) tvm.testing.assert_allclose(grad.numpy(), 2 * np.ones_like(x.numpy())) @@ -118,9 +115,8 @@ def test_temp_add(): func = run_infer_type(func) back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor() x = rand(dtype, *shape) - forward, (grad,) = ex.evaluate(back_func)(x) + forward, (grad,) = create_executor().evaluate(back_func)(x) tvm.testing.assert_allclose(forward.numpy(), 4 * x.numpy()) tvm.testing.assert_allclose(grad.numpy(), 4 * np.ones_like(x.numpy())) @@ -134,9 +130,8 @@ def test_sub(): func = run_infer_type(func) back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor() x = rand(dtype, *shape) - forward, (grad,) = ex.evaluate(back_func)(x) + forward, (grad,) = create_executor().evaluate(back_func)(x) tvm.testing.assert_allclose(forward.numpy(), np.zeros_like(x.numpy())) tvm.testing.assert_allclose(grad.numpy(), np.zeros_like(x.numpy())) @@ -163,8 +158,7 @@ def test_broadcast_add(): [relay.TensorType(expected_forward.shape, dtype), relay.TupleType([t1, t2])] ), ) - ex = create_executor() - forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd) + forward, (grad_x, grad_y) = create_executor().evaluate(full_func)(x_nd, y_nd) tvm.testing.assert_allclose(forward.numpy(), expected_forward) tvm.testing.assert_allclose( grad_x.numpy(), np.ones_like(expected_forward).sum(axis=2, keepdims=True) @@ -197,8 +191,7 @@ def test_broadcast_subtract(): [relay.TensorType(expected_forward.shape, dtype), relay.TupleType([t1, t2])] ), ) - ex = create_executor() - forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd) + forward, (grad_x, grad_y) = create_executor().evaluate(full_func)(x_nd, y_nd) tvm.testing.assert_allclose(forward.numpy(), expected_forward) tvm.testing.assert_allclose( grad_x.numpy(), np.ones_like(expected_forward).sum(axis=2, keepdims=True) @@ -247,8 +240,7 @@ def _test_tuple(mode): y_np = y_nd.numpy() z_np = z_nd.numpy() expected_forward = x_np + y_np - z_np - ex = create_executor() - forward, (grad_x, grad_y, grad_z) = ex.evaluate(back_func)(x_nd, y_nd, z_nd) + forward, (grad_x, grad_y, grad_z) = create_executor().evaluate(back_func)(x_nd, y_nd, z_nd) tvm.testing.assert_allclose(forward.numpy(), expected_forward) tvm.testing.assert_allclose(grad_x.numpy(), np.ones_like(grad_x.numpy())) tvm.testing.assert_allclose(grad_y.numpy(), np.ones_like(grad_y.numpy())) @@ -271,8 +263,7 @@ def _test_tuple_argument(mode): xs = [rand(dtype, *shape) for _ in range(fields)] xs_np = np.array([x.numpy() for x in xs]) expected_forward = np.sum(xs_np, axis=0) - ex = create_executor() - forward, grad = ex.evaluate(back_func)(tuple(xs)) + forward, grad = create_executor().evaluate(back_func)(tuple(xs)) tvm.testing.assert_allclose(forward.numpy(), expected_forward) for field in grad[0]: tvm.testing.assert_allclose(field.numpy(), np.ones_like(field.numpy())) @@ -315,8 +306,7 @@ def test_pow(): back_func = m["main"] assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) i_nd = rand(dtype, *shape) - ex = create_executor(mod=mod) - forward, (grad_i,) = ex.evaluate(back_func)(i_nd) + forward, (grad_i,) = create_executor(mod=mod).evaluate(back_func)(i_nd) tvm.testing.assert_allclose(forward.numpy(), 8 * i_nd.numpy()) tvm.testing.assert_allclose(grad_i.numpy(), 8 * np.ones_like(grad_i.numpy())) @@ -336,8 +326,7 @@ def test_ref(): back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) x_nd = rand(dtype, *shape) - ex = create_executor() - forward, (grad_x,) = ex.evaluate(back_func)(x_nd) + forward, (grad_x,) = create_executor().evaluate(back_func)(x_nd) tvm.testing.assert_allclose(forward.numpy(), 2 * x_nd.numpy()) tvm.testing.assert_allclose(grad_x.numpy(), 2 * np.ones_like(grad_x.numpy())) @@ -358,8 +347,7 @@ def test_square_second_order(): back_back_func = run_infer_type(gradient(back_func_adjusted)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) x_nd = rand(dtype, *shape) - ex = create_executor() - forward, (grad_x,) = ex.evaluate(back_back_func)(x_nd) + forward, (grad_x,) = create_executor().evaluate(back_back_func)(x_nd) tvm.testing.assert_allclose(forward.numpy(), 2 * x_nd.numpy()) tvm.testing.assert_allclose(grad_x.numpy(), 2 * np.ones_like(grad_x.numpy())) @@ -390,9 +378,8 @@ def test_grad_tuple(): assert back_func.checked_type == relay.FuncType( [t], relay.TupleType([relay.TupleType([t, t]), relay.TupleType([t])]) ) - ex = create_executor() x = rand(dtype, *shape) - (forward_four, forward_two), (grad,) = ex.evaluate(back_func)(x) + (forward_four, forward_two), (grad,) = create_executor().evaluate(back_func)(x) tvm.testing.assert_allclose(forward_four.numpy(), 4 * x.numpy()) tvm.testing.assert_allclose(forward_two.numpy(), 2 * x.numpy()) tvm.testing.assert_allclose(grad.numpy(), 4 * np.ones_like(x.numpy())) @@ -463,9 +450,8 @@ def test_global_function(): m = tvm.relay.transform.InferType()(m) back_func = m[g] assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor(mod=m) x = rand(dtype, *shape) - forward, (grad,) = ex.evaluate(back_func)(x) + forward, (grad,) = create_executor(mod=m).evaluate(back_func)(x) tvm.testing.assert_allclose(forward.numpy(), 4 * x.numpy()) tvm.testing.assert_allclose(grad.numpy(), 4 * np.ones_like(x.numpy())) diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py index f37856669306..a0af2205a5d0 100644 --- a/tests/python/relay/test_pass_lazy_gradient_init.py +++ b/tests/python/relay/test_pass_lazy_gradient_init.py @@ -65,9 +65,8 @@ def test_add(): assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x.numpy() + x.numpy()) @@ -92,9 +91,8 @@ def test_add_tuple(): assert mod["main"].checked_type == relay.FuncType([t], tensor_type) - ex = create_executor(mod=mod) x = (rand(dtype, *shape), rand(dtype, *shape)) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x[0].numpy() + x[1].numpy()) @@ -117,9 +115,8 @@ def test_mult(): assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x.numpy() * x.numpy()) @@ -143,9 +140,8 @@ def test_ret_tuple(): assert mod["main"].checked_type == relay.FuncType([t], relay.TupleType([t, t])) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(func)(x) + y = create_executor(mod=mod).evaluate(func)(x) assert_allclose(y[0].numpy(), x.numpy()) assert_allclose(y[1].numpy(), x.numpy() * 2.0) @@ -177,8 +173,7 @@ def test_add_broadcast(): expected_forward_type = relay.TensorType(expected_forward.shape, dtype) assert mod["main"].checked_type == relay.FuncType([t1, t2], expected_forward_type) - ex = create_executor(mod=mod) - forward = ex.evaluate(func)(x1_np, x2_np) + forward = create_executor(mod=mod).evaluate(func)(x1_np, x2_np) assert_allclose(forward.numpy(), expected_forward) @@ -208,9 +203,8 @@ def test_reverse_ad_identity(): [t], relay.TupleType([t, relay.TupleType([t])]) ) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - (forward), (grad,) = ex.evaluate(back_func)(x) + (forward), (grad,) = create_executor(mod=mod).evaluate(back_func)(x) assert_allclose(forward.numpy(), x.numpy()) assert_allclose(grad.numpy(), np.ones_like(x.numpy())) @@ -240,10 +234,9 @@ def test_multivar_reverse_ad(): [t, t], relay.TupleType([t, relay.TupleType([t, t])]) ) - ex = create_executor(mod=mod) x = rand(dtype, *shape) y = rand(dtype, *shape) - (forward), (grad_x, grad_y,) = ex.evaluate( + (forward), (grad_x, grad_y,) = create_executor(mod=mod).evaluate( back_func )(x, y) assert_allclose(forward.numpy(), x.numpy() * y.numpy()) @@ -305,10 +298,9 @@ def test_after_partial_eval(): [t, t], relay.TupleType([t, relay.TupleType([t, t])]) ) - ex = create_executor(mod=mod) x = rand(dtype, *shape) y = rand(dtype, *shape) - (forward), (grad_x, grad_y,) = ex.evaluate( + (forward), (grad_x, grad_y,) = create_executor(mod=mod).evaluate( back_func )(x, y) assert_allclose(forward.numpy(), x.numpy() * y.numpy()) @@ -343,10 +335,9 @@ def test_before_partial_eval(): [t, t], relay.TupleType([t, relay.TupleType([t, t])]) ) - ex = create_executor(mod=mod) x = rand(dtype, *shape) y = rand(dtype, *shape) - (forward), (grad_x, grad_y,) = ex.evaluate( + (forward), (grad_x, grad_y,) = create_executor(mod=mod).evaluate( back_func )(x, y) assert_allclose(forward.numpy(), x.numpy() * y.numpy()) @@ -372,9 +363,8 @@ def test_zeros(): assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x.numpy()) @@ -396,9 +386,8 @@ def test_ones(): assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x.numpy() + np.ones_like(x.numpy())) @@ -420,9 +409,8 @@ def test_zeros_like(): assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x.numpy()) @@ -444,9 +432,8 @@ def test_ones_like(): assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x.numpy() + np.ones_like(x.numpy())) diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index fb1094becb21..c7926f7a3d79 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -180,11 +180,13 @@ def test_pass_run(): y_nd = get_rand(shape, dtype) ref_res = x_nd.numpy() + y_nd.numpy() for target, dev in tvm.testing.enabled_targets(): - exe1 = relay.create_executor("graph", device=dev, target=target) - exe2 = relay.create_executor("debug", device=dev, target=target) - res1 = exe1.evaluate(new_add)(x_nd, y_nd) + res1 = relay.create_executor("graph", device=dev, target=target).evaluate(new_add)( + x_nd, y_nd + ) tvm.testing.assert_allclose(res1.numpy(), ref_res, rtol=1e-5) - res2 = exe2.evaluate(new_add)(x_nd, y_nd) + res2 = relay.create_executor("debug", device=dev, target=target).evaluate(new_add)( + x_nd, y_nd + ) tvm.testing.assert_allclose(res2.numpy(), ref_res, rtol=1e-5) test_pass_registration() @@ -277,11 +279,9 @@ def test_pass_run(): x_nd = get_rand(shape, dtype) ref_res = np.log(x_nd.numpy() * 2) for target, dev in tvm.testing.enabled_targets(): - exe1 = relay.create_executor("graph", device=dev, target=target) - exe2 = relay.create_executor("debug", device=dev, target=target) - res1 = exe1.evaluate(new_log)(x_nd) + res1 = relay.create_executor("graph", device=dev, target=target).evaluate(new_log)(x_nd) tvm.testing.assert_allclose(res1.numpy(), ref_res, rtol=1e-5) - res2 = exe2.evaluate(new_log)(x_nd) + res2 = relay.create_executor("debug", device=dev, target=target).evaluate(new_log)(x_nd) tvm.testing.assert_allclose(res2.numpy(), ref_res, rtol=1e-5) test_pass_registration() @@ -439,22 +439,22 @@ def test_multiple_passes(): y_nd = get_rand(shape, dtype) ref_res = np.subtract(x_nd.numpy() * 2, y_nd.numpy() * 2) for target, dev in tvm.testing.enabled_targets(): - exe1 = relay.create_executor("graph", device=dev, target=target) - exe2 = relay.create_executor("debug", device=dev, target=target) - res1 = exe1.evaluate(new_sub)(x_nd, y_nd) + res1 = relay.create_executor("graph", device=dev, target=target).evaluate(new_sub)( + x_nd, y_nd + ) tvm.testing.assert_allclose(res1.numpy(), ref_res, rtol=1e-5) - res2 = exe2.evaluate(new_sub)(x_nd, y_nd) + res2 = relay.create_executor("debug", device=dev, target=target).evaluate(new_sub)( + x_nd, y_nd + ) tvm.testing.assert_allclose(res2.numpy(), ref_res, rtol=1e-5) # Execute the updated abs function. x_nd = get_rand((5, 10), dtype) ref_res = np.abs(x_nd.numpy() * 2) for target, dev in tvm.testing.enabled_targets(): - exe1 = relay.create_executor("graph", device=dev, target=target) - exe2 = relay.create_executor("debug", device=dev, target=target) - res1 = exe1.evaluate(new_abs)(x_nd) + res1 = relay.create_executor("graph", device=dev, target=target).evaluate(new_abs)(x_nd) tvm.testing.assert_allclose(res1.numpy(), ref_res, rtol=1e-5) - res2 = exe2.evaluate(new_abs)(x_nd) + res2 = relay.create_executor("debug", device=dev, target=target).evaluate(new_abs)(x_nd) tvm.testing.assert_allclose(res2.numpy(), ref_res, rtol=1e-5) test_pass_registration() @@ -507,6 +507,34 @@ def expected(): assert tvm.ir.structural_equal(zz, zexpected) +def test_nested_sequential_with_scoping(): + def before(): + x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") + w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") + y = relay.nn.conv2d(x, w, padding=(1, 1)) + y = relay.reshape(y, newshape=(1, 16, -1)) + y = relay.reshape(y, newshape=(4, 8, -1, 16)) + y = relay.reverse_reshape(y, newshape=(32, 0, -1)) + return tvm.IRModule.from_expr(y) + + def expected(): + x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") + w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") + y = relay.nn.conv2d(x, w, padding=(1, 1)) + y = relay.reshape(y, newshape=(32, 16, 16)) + return tvm.IRModule.from_expr(y) + + z = before() + passes = [ + tvm.transform.Sequential([relay.transform.SimplifyExpr()]), + ] + with tvm.transform.PassContext(opt_level=1): + zz = tvm.transform.Sequential(passes)(z) + + expected = relay.transform.InferType()(expected()) + assert tvm.ir.structural_equal(zz, expected) + + def test_print_ir(capfd): shape = (1, 2, 3) tp = relay.TensorType(shape, "float32") diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index 129ac047cd89..ce36abd83c40 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -31,9 +31,7 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07): dev = tvm.device("llvm", 0) - intrp = create_executor(mod=mod, device=dev, target="llvm") - - result = intrp.evaluate(expr) + result = create_executor(mod=mod, device=dev, target="llvm").evaluate(expr) np.testing.assert_allclose(result.numpy(), expected_result, rtol=rtol) @@ -144,9 +142,8 @@ def test_if_ref(): body = Let(eff, body, RefRead(r)) f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body))) pe_f = tipe(f) - ex = create_executor() - f_res = ex.evaluate(f)(const(True)) - pe_f_res = ex.evaluate(pe_f)(const(True)) + f_res = create_executor().evaluate(f)(const(True)) + pe_f_res = create_executor().evaluate(pe_f)(const(True)) np.testing.assert_allclose(f_res.numpy(), 2 * np.ones_like(f_res.numpy())) np.testing.assert_allclose(pe_f_res.numpy(), 2 * np.ones_like(pe_f_res.numpy())) @@ -168,9 +165,8 @@ def test_function_invalidate(): body = Let(r, RefCreate(const(0)), body) f = Function([d], body) pe_f = tipe(f) - ex = create_executor() - f_res = ex.evaluate(f)(const(True)) - pe_f_res = ex.evaluate(pe_f)(const(True)) + f_res = create_executor().evaluate(f)(const(True)) + pe_f_res = create_executor().evaluate(pe_f)(const(True)) np.testing.assert_allclose(f_res.numpy(), np.ones_like(f_res.numpy())) np.testing.assert_allclose(pe_f_res.numpy(), np.ones_like(pe_f_res.numpy())) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 98d7161ae36c..93cd6f791765 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -27,6 +27,7 @@ from tvm import relay from tvm import runtime from tvm.relay import transform +from tvm.relay.testing import byoc from tvm.contrib import utils from tvm.relay.backend import compile_engine from tvm.relay.expr_functor import ExprMutator @@ -63,59 +64,6 @@ def visit_call(self, call): return Annotator().visit(func) -class CcompilerAnnotator(ExprMutator): - """ - A simple annotator that creates the following program: - | - -- begin -- - | - add - | - subtract - | - multiply - | - -- end -- - | - """ - - def __init__(self): - super(CcompilerAnnotator, self).__init__() - self.in_compiler = 0 - - def visit_call(self, call): - if call.op.name == "add": # Annotate begin at args - if self.in_compiler == 1: - lhs = compiler_begin(super().visit(call.args[0]), "ccompiler") - rhs = compiler_begin(super().visit(call.args[1]), "ccompiler") - op = relay.add(lhs, rhs) - self.in_compiler = 2 - return op - elif call.op.name == "subtract": - if self.in_compiler == 1: - lhs = super().visit(call.args[0]) - rhs = super().visit(call.args[1]) - if isinstance(lhs, relay.expr.Var): - lhs = compiler_begin(lhs, "ccompiler") - if isinstance(rhs, relay.expr.Var): - rhs = compiler_begin(rhs, "ccompiler") - return relay.subtract(lhs, rhs) - elif call.op.name == "multiply": # Annotate end at output - self.in_compiler = 1 - lhs = super().visit(call.args[0]) - rhs = super().visit(call.args[1]) - if isinstance(lhs, relay.expr.Var): - lhs = compiler_begin(lhs, "ccompiler") - if isinstance(rhs, relay.expr.Var): - rhs = compiler_begin(rhs, "ccompiler") - op = relay.multiply(lhs, rhs) - if self.in_compiler == 2: - op = compiler_end(op, "ccompiler") - self.in_compiler = 0 - return op - return super().visit_call(call) - - class WholeGraphAnnotator(ExprMutator): """ An annotator that creates a compiler for an entire graph. @@ -261,7 +209,7 @@ def test_multi_node_compiler(): r = relay.concatenate((q0, q1, q2), axis=0) f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) mod = tvm.IRModule() - ann = CcompilerAnnotator() + ann = byoc.CcompilerAnnotator() mod["main"] = ann.visit(f) mod = transform.PartitionGraph()(mod) mod = transform.InferType()(mod) @@ -339,8 +287,57 @@ def expected(): add = x0 + y0 # Function that uses C compiler func = relay.Function([x0, y0], add) - func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_0") - glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_0") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0") + mod[glb_0] = func + add_call = relay.Call(glb_0, [x, y]) + # Function that uses default compiler. Ops are fused in this function. + p0 = relay.var("p0", shape=(8, 8)) + log = relay.log(p0) + exp = relay.exp(p0) + concat = relay.concatenate([log, exp], axis=0) + fused_func = relay.Function([p0], concat) + fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + fused_call = relay.Call(fused_func, [add_call]) + main = relay.Function([x, y], fused_call) + mod["main"] = main + mod = transform.InferType()(mod) + return mod + + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + add = x + y + log = relay.log(add) + exp = relay.exp(add) + concat = relay.concatenate([log, exp], axis=0) + f = relay.Function([x, y], concat) + mod = tvm.IRModule() + mod["main"] = f + mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod) + mod = transform.PartitionGraph()(mod) + fused_mod = transform.FuseOps(2)(mod) + expected_mod = expected() + assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True) + + x_data = np.random.rand(8, 8).astype("float32") + y_data = np.random.rand(8, 8).astype("float32") + np_add = x_data + y_data + res = np.concatenate([np.log(np_add), np.exp(np_add)]) + check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res) + + +def test_extern_ccompiler_multiple_functions(): + def expected(): + mod = tvm.IRModule() + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + x0 = relay.var("x0", shape=(8, 8)) + y0 = relay.var("y0", shape=(8, 8)) + add = x0 + y0 + # Function that uses C compiler + func = relay.Function([x0, y0], add) + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0") mod[glb_0] = func add_call = relay.Call(glb_0, [x, y]) # Function that uses default compiler. Ops are fused in this function. @@ -353,6 +350,28 @@ def expected(): fused_call = relay.Call(fused_func, [add_call]) main = relay.Function([x, y], fused_call) mod["main"] = main + # define the second one + a = relay.var("a", shape=(16, 16)) + b = relay.var("b", shape=(16, 16)) + a0 = relay.var("a0", shape=(16, 16)) + b0 = relay.var("b0", shape=(16, 16)) + add = a0 + b0 + # Function that uses C compiler + func = relay.Function([a0, b0], add) + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_subfunction_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_subfunction_0") + mod[glb_0] = func + add_call = relay.Call(glb_0, [a, b]) + # Function that uses default compiler. Ops are fused in this function. + p0 = relay.var("p0", shape=(16, 16)) + log = relay.log(p0) + exp = relay.exp(p0) + concat = relay.concatenate([log, exp], axis=0) + fused_func = relay.Function([p0], concat) + fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + fused_call = relay.Call(fused_func, [add_call]) + sunfunction = relay.Function([a, b], fused_call) + mod["subfunction"] = sunfunction mod = transform.InferType()(mod) return mod @@ -365,6 +384,15 @@ def expected(): f = relay.Function([x, y], concat) mod = tvm.IRModule() mod["main"] = f + # define second function + a = relay.var("a", shape=(16, 16)) + b = relay.var("b", shape=(16, 16)) + add = a + b + log = relay.log(add) + exp = relay.exp(add) + concat = relay.concatenate([log, exp], axis=0) + f2 = relay.Function([a, b], concat) + mod["subfunction"] = f2 mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod) mod = transform.PartitionGraph()(mod) @@ -416,8 +444,8 @@ def expected(): out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) func = relay.Function([data0, input0], out) - func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_0") - glb_var = relay.GlobalVar("tvmgen_default_dnnl_0") + func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_main_0") + glb_var = relay.GlobalVar("tvmgen_default_dnnl_main_0") mod = tvm.IRModule() mod[glb_var] = func mod = transform.InferType()(mod) @@ -456,8 +484,9 @@ def get_func(): i_data = np.random.uniform(0, 1, ishape).astype(dtype) w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()) - ref_res = ref_ex.evaluate()(i_data, w1_data) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()).evaluate()( + i_data, w1_data + ) check_result( mod, {"data": i_data, "weight1": w1_data}, (1, 32, 14, 14), ref_res.numpy(), tol=1e-5 ) @@ -476,8 +505,9 @@ def test_extern_dnnl_mobilenet(): mod = transform.PartitionGraph()(mod) i_data = np.random.uniform(0, 1, ishape).astype(dtype) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)) - ref_res = ref_ex.evaluate()(i_data, **params) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( + i_data, **params + ) compile_engine.get().clear() check_result(mod, {"data": i_data}, (1, 1000), ref_res.numpy(), tol=1e-5, params=params) @@ -532,8 +562,8 @@ def expected(): bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple()) - func0 = set_func_attr(func0, "test_compiler", "tvmgen_default_test_compiler_2") - gv0 = relay.GlobalVar("tvmgen_default_test_compiler_2") + func0 = set_func_attr(func0, "test_compiler", "tvmgen_default_test_compiler_main_2") + gv0 = relay.GlobalVar("tvmgen_default_test_compiler_main_2") mod[gv0] = func0 mod = transform.InferType()(mod) @@ -544,8 +574,8 @@ def expected(): data=data1, weight=weight1, kernel_size=(3, 3), channels=16, padding=(1, 1) ) func1 = relay.Function([data1, weight1], conv) - func1 = set_func_attr(func1, "test_compiler", "tvmgen_default_test_compiler_0") - gv1 = relay.GlobalVar("tvmgen_default_test_compiler_0") + func1 = set_func_attr(func1, "test_compiler", "tvmgen_default_test_compiler_main_0") + gv1 = relay.GlobalVar("tvmgen_default_test_compiler_main_0") mod[gv1] = func1 mod = transform.InferType()(mod) @@ -613,7 +643,7 @@ def expected(): bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple()) - func0 = set_func_attr(func0, "test_compiler", "tvmgen_default_test_compiler_0") + func0 = set_func_attr(func0, "test_compiler", "tvmgen_default_test_compiler_main_0") # main function data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32")) @@ -643,8 +673,8 @@ def expected(): add = x0 + y0 # Function that uses C compiler func = relay.Function([y0], add) - func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_0") - glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_0") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0") mod[glb_0] = func mod = relay.transform.InferType()(mod) add_call = relay.Call(glb_0, [y]) @@ -733,8 +763,8 @@ def expected(): tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2])) func0 = relay.Function([data, weight, bn_gamma, bn_beta, bn_mean, bn_var], tuple_o) - func0 = set_func_attr(func0, "test_target", "tvmgen_default_test_target_0") - gv0 = relay.GlobalVar("tvmgen_default_test_target_0") + func0 = set_func_attr(func0, "test_target", "tvmgen_default_test_target_main_0") + gv0 = relay.GlobalVar("tvmgen_default_test_target_main_0") mod[gv0] = func0 mod = relay.transform.InferType()(mod) @@ -796,8 +826,8 @@ def expected(): f1_O_2 = relay.nn.relu(f1_O_1) f1_out = relay.Tuple((f1_O_2, f1_O_1)) func1 = relay.Function([f1_cb1], f1_out) - func1 = set_func_attr(func1, "test_target", "tvmgen_default_test_target_0") - gv1 = relay.GlobalVar("tvmgen_default_test_target_0") + func1 = set_func_attr(func1, "test_target", "tvmgen_default_test_target_main_0") + gv1 = relay.GlobalVar("tvmgen_default_test_target_main_0") mod[gv1] = func1 mod = relay.transform.InferType()(mod) @@ -806,8 +836,8 @@ def expected(): f2_cb4 = relay.var("test_target_1_i1", shape=(10, 10)) f2_O_3 = relay.add(f2_cb3, f2_cb4) func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3) - func0 = set_func_attr(func0, "test_target", "tvmgen_default_test_target_1") - gv0 = relay.GlobalVar("tvmgen_default_test_target_1") + func0 = set_func_attr(func0, "test_target", "tvmgen_default_test_target_main_1") + gv0 = relay.GlobalVar("tvmgen_default_test_target_main_1") mod[gv0] = func0 mod = relay.transform.InferType()(mod) @@ -917,8 +947,9 @@ def test_partition_mobilenet(): def test_exec(mod, params, ref_mod, ref_params, out_shape): ishape = (1, 3, 224, 224) i_data = np.random.randn(*ishape).astype(np.float32) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)) - ref_res = ref_ex.evaluate()(i_data, **ref_params) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( + i_data, **ref_params + ) compile_engine.get().clear() mod = get_partitoned_mod(mod, params, dnnl_patterns) @@ -955,8 +986,8 @@ def expected_same_output_region(): mul = log * sub # The partitioned graph contains log, subtract, and multiply func = relay.Function([x0, y0], mul) - func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_0") - glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_0") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0") mod[glb_0] = func mod = transform.InferType()(mod) @@ -977,8 +1008,8 @@ def expected_different_output_region(): i0 = relay.var("i0", shape=(8, 8)) log = relay.log(i0) func = relay.Function([i0], log) - func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_0") - glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_0") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0") mod[glb_0] = func mod = transform.InferType()(mod) @@ -987,8 +1018,8 @@ def expected_different_output_region(): y0 = relay.var("y0", shape=(8, 8)) sub = x0 - y0 func = relay.Function([x0, y0], sub) - func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_1") - glb_1 = relay.GlobalVar("tvmgen_default_ccompiler_1") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_1") + glb_1 = relay.GlobalVar("tvmgen_default_ccompiler_main_1") mod[glb_1] = func mod = transform.InferType()(mod) @@ -1063,8 +1094,8 @@ def expected(): func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", target) - func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_0") - gv0 = relay.GlobalVar("tvmgen_default_" + target + "_0") + func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_main_0") + gv0 = relay.GlobalVar("tvmgen_default_" + target + "_main_0") mod[gv0] = func0 mod = transform.InferType()(mod) @@ -1140,8 +1171,8 @@ def expected(): func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", target) - func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_0") - gv0 = relay.GlobalVar("tvmgen_default_" + target + "_0") + func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_main_0") + gv0 = relay.GlobalVar("tvmgen_default_" + target + "_main_0") mod[gv0] = func0 mod = transform.InferType()(mod) @@ -1216,7 +1247,7 @@ def create_graph(): partitioned = seq(create_graph()) - concat = partitioned["tvmgen_default_const_tuples_0"].body + concat = partitioned["tvmgen_default_const_tuples_main_0"].body assert type(concat.args[1]) == relay.Tuple assert type(concat.args[2]) == relay.Tuple assert type(concat.args[3]) == relay.Constant @@ -1266,8 +1297,8 @@ def expected(): func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", target) - func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_0") - gv0 = relay.GlobalVar("tvmgen_default_" + target + "_0") + func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_main_0") + gv0 = relay.GlobalVar("tvmgen_default_" + target + "_main_0") mod[gv0] = func0 mod = transform.InferType()(mod) @@ -1349,9 +1380,9 @@ def Optimize(mod): mod = transform.PartitionGraph()(mod) try: - t0 = mod["tvmgen_default_test_target_0"] + t0 = mod["tvmgen_default_test_target_main_0"] except: - raise KeyError("test_target_0 not found") + raise KeyError("test_target_main_0 not found") assert isinstance(t0.body, relay.Constant) expected = np.empty([2, 2]) @@ -1359,10 +1390,39 @@ def Optimize(mod): tvm.testing.assert_allclose(t0.body.data.numpy(), expected, rtol=1e-5, atol=1e-5) +def test_preserve_type_import(): + """Test to make sure type definition and imports are preserved during the BYOC pipeline.""" + from tvm.relay.prelude import Prelude, StaticTensorArrayOps + + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + tensor_array = p.get_global_var_static("tensor_array", dtype, shape) + tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + write = p.get_global_var_static("tensor_array_write", dtype, shape) + gather = p.get_global_var_static("tensor_array_gather", dtype, shape) + v = relay.var("v") + indice = relay.var("indice") + init_tensor_array = tensor_array(relay.const(3)) + tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v)) + tensor_array2 = write(tensor_array1, relay.const(1), tensor(v)) + tensor_array3 = write(tensor_array2, relay.const(2), tensor(v)) + out = gather(tensor_array3, indice) + mod["main"] = relay.Function([v, indice], out) + mod = transform.RemoveUnusedFunctions()(mod) + mod = transform.PartitionGraph()(mod) + + run("float32", [2, 3]) + + if __name__ == "__main__": test_multi_node_compiler() test_extern_ccompiler_single_op() test_extern_ccompiler_default_ops() + test_extern_ccompiler_multiple_functions() test_extern_ccompiler() test_extern_dnnl() test_extern_dnnl_mobilenet() @@ -1379,3 +1439,4 @@ def Optimize(mod): test_flatten_tuple_output() test_tuple_output_exec() test_extern_opt() + test_static_tensor_array_gather_partition() diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 61e5b8ea9407..cd2e5d2fd249 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -37,9 +37,7 @@ def run_opt_pass(expr, passes): def check_eval(expr, expected_result, mod=None, rtol=1e-07): dev = tvm.device("llvm", 0) - intrp = create_executor(mod=mod, device=dev, target="llvm") - - result = intrp.evaluate(expr) + result = create_executor(mod=mod, device=dev, target="llvm").evaluate(expr) np.testing.assert_allclose(result.numpy(), expected_result, rtol=rtol) @@ -151,6 +149,7 @@ def test_nat_add(): add = p.mod.get_global_var("nat_add") dev = tvm.device("llvm", 0) intrp = create_executor(mod=mod, device=dev, target="llvm") + # CAUTION: Following calls to intrp.evaluate(...) will re-prepare the prelude. assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2 expr = add(s(z()), s(z())) diff --git a/tests/python/relay/test_pass_to_basic_block_normal_form.py b/tests/python/relay/test_pass_to_basic_block_normal_form.py index d345d465c53e..d04afe15b5bb 100644 --- a/tests/python/relay/test_pass_to_basic_block_normal_form.py +++ b/tests/python/relay/test_pass_to_basic_block_normal_form.py @@ -22,7 +22,7 @@ from tvm.relay.analysis import detect_feature from tvm.relay import op, create_executor, transform from tvm.relay.prelude import Prelude -from tvm.relay.testing import count +from tvm.relay.testing import count, create_workload from tvm.relay.analysis import Feature from tvm.relay.analysis import check_basic_block_normal_form @@ -39,9 +39,7 @@ def run_opt_pass(expr, passes): def check_eval(expr, expected_result, mod=None, rtol=1e-07): dev = tvm.device("llvm", 0) - intrp = create_executor(mod=mod, device=dev, target="llvm") - - result = intrp.evaluate(expr) + result = create_executor(mod=mod, device=dev, target="llvm").evaluate(expr) np.testing.assert_allclose(result.numpy(), expected_result, rtol=rtol) @@ -267,16 +265,20 @@ def test_nat_add(): nat, z, s = p.mod.get_type("nat") add = p.mod.get_global_var("nat_add") dev = tvm.device("llvm", 0) - intrp = create_executor(mod=mod, device=dev, target="llvm") assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) - assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2 + assert ( + count(p, create_executor(mod=mod, device=dev, target="llvm").evaluate(add(s(z()), s(z())))) + == 2 + ) expr = add(s(z()), s(z())) f = relay.GlobalVar("f") mod[f] = relay.Function([], expr) mod = transform.InferType()(mod) mod = transform.ToBasicBlockNormalForm()(mod) opt_expr = mod["f"] - assert count(p, intrp.evaluate(opt_expr.body)) == 2 + assert ( + count(p, create_executor(mod=mod, device=dev, target="llvm").evaluate(opt_expr.body)) == 2 + ) assert not Feature.fLet in detect_feature(mod[add]) check_basic_block_normal_form(opt_expr) @@ -489,5 +491,27 @@ def test_higher_order_nested(): check_basic_block_normal_form(bblock) +def test_immutability(): + simple_net = relay.nn.conv2d( + data=relay.var("data", relay.TensorType((1, 3, 224, 224), "float32")), + weight=relay.var("weight"), + kernel_size=(5, 5), + channels=3, + padding=(1, 1), + ) + simple_net = relay.Function(relay.analysis.free_vars(simple_net), simple_net) + mod, _ = create_workload(simple_net) + + old_mod = mod + + with tvm.transform.PassContext(opt_level=4): + with tvm.target.Target("llvm"): + seq = tvm.transform.Sequential(passes=[transform.ToBasicBlockNormalForm()], opt_level=4) + new_mod = seq(mod) + + assert old_mod.astext() == mod.astext() + assert old_mod.astext() != new_mod.astext() + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index 0cde1d9ae492..4825cc29e6e4 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -58,9 +58,8 @@ def test_recursion(): mod["main"] = to_cps(mod["main"], mod=mod) mod = relay.transform.InferType()(mod) mod["main"] = un_cps(mod["main"]) - ex = create_executor(mod=mod) i_nd = rand(dtype, *shape) - forward = ex.evaluate()(i_nd) + forward = create_executor(mod=mod).evaluate()(i_nd) tvm.testing.assert_allclose(forward.numpy(), 8 * i_nd.numpy()) diff --git a/tests/python/relay/test_pass_to_graph_normal_form.py b/tests/python/relay/test_pass_to_graph_normal_form.py index 4f5084d83f9c..6a8c99d076e4 100644 --- a/tests/python/relay/test_pass_to_graph_normal_form.py +++ b/tests/python/relay/test_pass_to_graph_normal_form.py @@ -34,9 +34,7 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): mod = tvm.IRModule() dev = tvm.device("llvm", 0) - intrp = create_executor(mod=mod, device=dev, target="llvm") - - result = intrp.evaluate(expr)(*args) + result = create_executor(mod=mod, device=dev, target="llvm").evaluate(expr)(*args) np.testing.assert_allclose(result.numpy(), expected_result, rtol=rtol) diff --git a/tests/python/relay/test_prng.py b/tests/python/relay/test_prng.py index f79c79329b67..79ed014c5503 100644 --- a/tests/python/relay/test_prng.py +++ b/tests/python/relay/test_prng.py @@ -92,7 +92,7 @@ def test_threefry_sequential_generate_remaining(target, dev): ).evaluate()() assert ( - out1.asnumpy()[-3:] != out2.asnumpy()[-3:] + out1.numpy()[-3:] != out2.numpy()[-3:] ).any(), "Sequential generates should not have the same output" diff --git a/tests/python/relay/test_sparse_conv2d_convert.py b/tests/python/relay/test_sparse_conv2d_convert.py index 0af78fc033ac..045462475ee1 100644 --- a/tests/python/relay/test_sparse_conv2d_convert.py +++ b/tests/python/relay/test_sparse_conv2d_convert.py @@ -25,6 +25,7 @@ from tvm.ir import IRModule from tvm import relay from tvm.topi.sparse.utils import random_bsr_matrix +from tvm.relay.build_module import bind_params_by_name def run_func(func, params, x): @@ -100,6 +101,68 @@ def test_bsr_sparse_conv2d_nhwc(): np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5) +def test_bsr_sparse_conv2d_3x3_nchw(): + data = relay.var("data", shape=(1, 64, 32, 32), dtype="float32") + x = relay.nn.relu(data) + w = relay.var("weight", shape=(128, 64, 3, 3), dtype="float32") + y = relay.nn.conv2d( + x, w, channels=128, kernel_size=3, padding=1, data_layout="NCHW", kernel_layout="OIHW" + ) + z = relay.nn.relu(y) + func = relay.Function(relay.analysis.free_vars(z), z) + + params = { + "weight": tvm.nd.array( + np.array(random_bsr_matrix(128, 64 * 9, 16, 1, 0.1, "float32").todense()).reshape( + 128, 64, 3, 3 + ) + ) + } + + x_np = np.random.randn(1, 64, 32, 32).astype("float32") + # dense output + dense_output = run_func(func, params, x_np) + # sparse + func = bind_params_by_name(func, params) + sparse_func, params = relay.data_dep_optimization.bsr_conv2d.convert2( + func, {}, (16, 1), 0.2, "NCHW", 3 + ) + sparse_output = run_func(sparse_func, params, x_np) + np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5) + + +def test_bsr_sparse_conv2d_3x3_nhwc(): + data = relay.var("data", shape=(1, 32, 32, 64), dtype="float32") + x = relay.nn.relu(data) + w = relay.var("weight", shape=(3, 3, 64, 128), dtype="float32") + y = relay.nn.conv2d( + x, w, channels=128, kernel_size=3, padding=1, data_layout="NHWC", kernel_layout="HWIO" + ) + z = relay.nn.relu(y) + func = relay.Function(relay.analysis.free_vars(z), z) + + params = { + "weight": tvm.nd.array( + np.array(random_bsr_matrix(128, 64 * 9, 16, 1, 0.1, "float32").todense()).T.reshape( + 3, 3, 64, 128 + ) + ) + } + + x_np = np.random.randn(1, 32, 32, 64).astype("float32") + # dense output + dense_output = run_func(func, params, x_np) + # sparse + func = bind_params_by_name(func, params) + sparse_func, params = relay.data_dep_optimization.bsr_conv2d.convert2( + func, {}, (16, 1), 0.2, "NHWC", 3 + ) + sparse_output = run_func(sparse_func, params, x_np) + np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": test_bsr_sparse_conv2d_nhwc() test_bsr_sparse_conv2d_nchw() + test_bsr_sparse_conv2d_3x3_nhwc() + test_bsr_sparse_conv2d_3x3_nchw() diff --git a/tests/python/relay/test_tensor_array.py b/tests/python/relay/test_tensor_array.py index e93831bef95f..21043abb3c84 100644 --- a/tests/python/relay/test_tensor_array.py +++ b/tests/python/relay/test_tensor_array.py @@ -63,8 +63,9 @@ def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", rtol=1e-5): for target, dev in [("llvm", tvm.cpu(0))]: # testing.enabled_targets(): if kind == "debug" and dev.device_type != tvm.cpu().device_type: continue - ex = relay.create_executor(kind, mod=ta_mod, device=dev, target=target) - result = ex.evaluate()(*args) + result = relay.create_executor(kind, mod=ta_mod, device=dev, target=target).evaluate()( + *args + ) got = vmobj_to_list(ta_mod, result, dtype) tvm.testing.assert_allclose(ref_res, got, rtol=rtol, atol=rtol) diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 7a3fbfafc089..472f98715ec5 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -27,13 +27,12 @@ def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List: dev = tvm.device("llvm", 0) - intrp = relay.create_executor("debug", mod, device=dev, target="llvm") - result = intrp.evaluate()(**mod_params) + result = relay.create_executor("debug", mod, device=dev, target="llvm").evaluate()(**mod_params) if isinstance(result, tvm.runtime.container.ADT): - result = [r.asnumpy() for r in result] + result = [r.numpy() for r in result] return result else: - return [result.asnumpy()] + return [result.numpy()] def verify_mixed_precision_output_close( @@ -222,12 +221,36 @@ def test_do_not_convert_softmax(): b = relay.nn.softmax(a) mod = tvm.IRModule.from_expr(b) mod = tvm.relay.transform.InferType()(mod) + out_mod = ToMixedPrecision("float16")(mod) + orig_mod = tvm.relay.transform.InferType()(mod) + assert tvm.ir.structural_equal(orig_mod, out_mod) - mod_params = { - "a": np.random.uniform(-1, 1, size=shape).astype("float32"), - } - output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0) - assert tvm.ir.structural_equal(mod, output_mod) + +def test_do_not_convert_arange(): + """Arange is a red listed operation and therefore should never be fp16.""" + dtype = "float32" + arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype)) + mod = tvm.IRModule.from_expr(arange) + out_mod = ToMixedPrecision("float16")(mod) + orig_mod = tvm.relay.transform.InferType()(mod) + assert tvm.ir.structural_equal(orig_mod, out_mod) + + +def test_do_not_convert_summation(): + """Ops that could involve a large summation are not allowed in fp16.""" + shape = [1, 3, 16, 16] + a = relay.var("a", shape=shape) + ops = [ + relay.sum, + relay.mean, + relay.nn.global_avg_pool2d, + lambda inp: relay.nn.adaptive_avg_pool2d(inp, (1, 1)), + ] + for op in ops: + mod = tvm.IRModule.from_expr(op(a)) + out_mod = ToMixedPrecision("float16")(mod) + orig_mod = tvm.relay.transform.InferType()(mod) + assert tvm.ir.structural_equal(orig_mod, out_mod) def test_green_gray_propagates_simple(): @@ -363,7 +386,7 @@ def test_let_statement_simple(): "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"), "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"), } - output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01) + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.05, rtol=0.15) # Construct expected structure var1 = relay.var("var1", shape=[1, 20], dtype="float16") diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 151e5ecc160b..4c5b98514724 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -17,6 +17,7 @@ import numpy as np import pytest import time +from unittest.mock import patch import tvm from tvm import runtime @@ -30,6 +31,7 @@ from tvm import rpc import tvm.testing from tvm.relay.transform import InferType +from tvm.relay.testing import mlp def check_result(args, expected_result, mod=None): @@ -46,8 +48,9 @@ def check_result(args, expected_result, mod=None): The expected result of running the expression. """ for target, dev in tvm.testing.enabled_targets(): - vm = relay.create_executor("vm", device=dev, target=target, mod=mod) - rts_result = vm.evaluate()(*args) + rts_result = relay.create_executor("vm", device=dev, target=target, mod=mod).evaluate()( + *args + ) tvm.testing.assert_allclose(expected_result, rts_result.numpy()) @@ -182,8 +185,8 @@ def test_multiple_ifs(): fn = relay.Function([b], out) mod["main"] = fn dev = tvm.runtime.device("llvm", 0) - vm = relay.create_executor(device=dev, mod=mod, kind="vm") - res = vmobj_to_list(vm.evaluate()(False)) + func = relay.create_executor(device=dev, mod=mod, kind="vm").evaluate() + res = vmobj_to_list(func(False)) assert res == [1, 0] @@ -647,13 +650,39 @@ def test_add_op_scalar(): } """ mod = tvm.IRModule() - x = relay.var("x", shape=()) - y = relay.var("y", shape=()) + x = relay.var("x", shape=()) # Default to float32 + y = relay.var("y", shape=()) # Default to float32 func = relay.Function([x, y], relay.op.add(x, y)) - x_data = np.array(10.0, dtype="float32") - y_data = np.array(1.0, dtype="float32") - mod["main"] = func - check_result([x_data, y_data], x_data + y_data, mod=mod) + x_y_data = [ + (np.array(10.0, dtype="float32"), np.array(1.0, dtype="float32")), + (np.float32(10.0), np.float32(1.0)), + (10.0, 1.0), + ] + for (x_data, y_data) in x_y_data: + mod["main"] = func + check_result([x_data, y_data], x_data + y_data, mod=mod) + + +@tvm.testing.uses_gpu +def test_add_op_scalar_int(): + """ + test_add_op_scalar_int: + fn (x, y) { + return x + y; + } + """ + mod = tvm.IRModule() + x = relay.var("x", shape=(), dtype="int32") + y = relay.var("y", shape=(), dtype="int32") + func = relay.Function([x, y], relay.op.add(x, y)) + x_y_data = [ + (np.array(10.0, dtype="int32"), np.array(1.0, dtype="int32")), + (np.int32(10), np.int32(1)), + (10, 1), + ] + for (x_data, y_data) in x_y_data: + mod["main"] = func + check_result([x_data, y_data], x_data + y_data, mod=mod) @tvm.testing.uses_gpu @@ -853,29 +882,25 @@ def test_vm_rpc(): # Use local rpc server for testing. # Server must use popen so it doesn't inherit the current process state. It # will crash otherwise. - server = rpc.Server("localhost", port=9120) - remote = rpc.connect(server.host, server.port, session_timeout=10) - - # Upload the serialized Executable. - remote.upload(path) - # Get a handle to remote Executable. - rexec = remote.load_module("vm_library.so") + def check_remote(server): + remote = rpc.connect(server.host, server.port, session_timeout=10) - ctx = remote.cpu() - # Build a VM out of the executable and context. - vm_factory = runtime.vm.VirtualMachine(rexec, ctx) - np_input = np.random.uniform(size=(10, 1)).astype("float32") - input_tensor = tvm.nd.array(np_input, ctx) - # Invoke its "main" function. - out = vm_factory.invoke("main", input_tensor) - # Check the result. - np.testing.assert_allclose(out.numpy(), np_input + np_input) + # Upload the serialized Executable. + remote.upload(path) + # Get a handle to remote Executable. + rexec = remote.load_module("vm_library.so") - # delete tensors before the server shuts down so we don't throw errors. - del input_tensor - del out + ctx = remote.cpu() + # Build a VM out of the executable and context. + vm_factory = runtime.vm.VirtualMachine(rexec, ctx) + np_input = np.random.uniform(size=(10, 1)).astype("float32") + input_tensor = tvm.nd.array(np_input, ctx) + # Invoke its "main" function. + out = vm_factory.invoke("main", input_tensor) + # Check the result. + np.testing.assert_allclose(out.numpy(), np_input + np_input) - server.terminate() + check_remote(rpc.Server("127.0.0.1")) def test_get_output_single(): @@ -915,5 +940,78 @@ def test_get_output_multiple(): np.testing.assert_allclose(outputs[1].numpy(), inp) +def test_get_input_index(): + target = tvm.target.Target("llvm") + + # Build a IRModule. + data_0, data_1 = ["d1", "d2"] + x, y = [relay.var(c, shape=(10,)) for c in [data_0, data_1]] + f = relay.Function([x, y], x + y) + mod = IRModule.from_expr(f) + + # Compile to VMExecutable. + vm_exec = vm.compile(mod, target=target) + vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu()) + assert vm_factory.get_input_index(data_1) == 1 + assert vm_factory.get_input_index(data_0) == 0 + assert vm_factory.get_input_index("invalid") == -1 + + +@tvm.testing.requires_llvm +def test_benchmark(): + mod, params = mlp.get_workload(1) + lib = vm.compile(mod, target="llvm", params=params) + exe = runtime.vm.VirtualMachine(lib, tvm.cpu()) + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32")) + result = exe.benchmark(tvm.cpu(), data, func_name="main", repeat=2, number=1) + assert result.mean == result.median + assert result.mean > 0 + assert len(result.results) == 2 + + with patch.object( + tvm.runtime.module.Module, + "time_evaluator", + return_value=lambda x: tvm.runtime.module.BenchmarkResult([1, 2, 2, 5]), + ) as method: + result = exe.benchmark(tvm.cpu(), data, func_name="main", repeat=2, number=1) + assert result.mean == 2.5 + assert result.median == 2.0 + assert result.max == 5 + assert result.min == 1 + assert result.std == 1.5 + + +@tvm.testing.parametrize_targets("cuda", "llvm") +def test_benchmark_end_to_end(dev, target): + mod, params = mlp.get_workload(1) + lib = vm.compile(mod, target=target, params=params) + exe = runtime.vm.VirtualMachine(lib, dev) + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"), device=dev) + result = exe.benchmark(dev, data, func_name="main", repeat=2, number=1, end_to_end=True) + assert result.mean > 0 + + +@tvm.testing.requires_llvm +def test_benchmark_end_to_end_rpc(): + server = rpc.Server("127.0.0.1") + remote = rpc.connect(server.host, server.port) + + mod, params = mlp.get_workload(1) + lib = vm.compile(mod, target="llvm", params=params) + + temp = utils.tempdir() + path = temp.relpath("vm_library.so") + lib.mod.export_library(path) + remote.upload(path) + rlib = remote.load_module("vm_library.so") + + exe = runtime.vm.VirtualMachine(rlib, remote.cpu()) + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"), device=remote.cpu()) + result = exe.benchmark( + remote.cpu(), data=data, func_name="main", repeat=2, number=1, end_to_end=True + ) + assert result.mean > 0 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index ef7d9111b84c..f579f74a24ac 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -54,8 +54,7 @@ def get_serialized_output(mod, *data, params=None, target="llvm", device=tvm.cpu def run_network(mod, params, dtype="float32"): def get_vm_output(mod, data, params, target, device, dtype="float32"): - ex = relay.create_executor("vm", mod=mod, device=device) - result = ex.evaluate()(data, **params) + result = relay.create_executor("vm", mod=mod, device=device).evaluate()(data, **params) return result.numpy().astype(dtype) data_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] diff --git a/tests/python/topi/python/test_topi_batch_matmul.py b/tests/python/topi/python/test_topi_batch_matmul.py index 8c8ad37287dc..9bd9dd286b1a 100644 --- a/tests/python/topi/python/test_topi_batch_matmul.py +++ b/tests/python/topi/python/test_topi_batch_matmul.py @@ -85,7 +85,8 @@ def check_device(target, dev): tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) for target, dev in tvm.testing.enabled_targets(): - if dynamic and (target == "cuda" or target == "nvptx"): + target_kind = tvm.target.Target(target).kind.name + if dynamic and target_kind in ["cuda", "nvptx", "vulkan", "opencl"]: print("Dynamic batch matmul test is skippped on %s" % target) continue diff --git a/tests/python/topi/python/test_topi_conv2d_nchw.py b/tests/python/topi/python/test_topi_conv2d_nchw.py index 8dbe94b45a2f..96a7ff9b926c 100644 --- a/tests/python/topi/python/test_topi_conv2d_nchw.py +++ b/tests/python/topi/python/test_topi_conv2d_nchw.py @@ -16,13 +16,15 @@ # under the License. """Example code to do convolution.""" +import sys + +import pytest import numpy as np + import tvm -from tvm import te -from tvm import autotvm -from tvm import topi +from tvm import autotvm, te, topi import tvm.topi.testing -from tvm.contrib.pickle_memoize import memoize +from tvm.contrib import cudnn from tvm.topi.nn.utils import get_pad_tuple from tvm.topi.utils import get_const_tuple from tvm.topi.nn.conv2d import _get_workload @@ -30,238 +32,323 @@ import tvm.testing +dtype = tvm.testing.parameter("float16", "float32") +random_seed = tvm.testing.parameter(0) -def verify_conv2d_nchw( - batch, - in_channel, - in_size, - num_filter, - kernel, - stride, - padding, - dilation=1, - add_bias=False, - add_relu=False, - use_cudnn=False, -): - pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) - padding_sum = pad_top + pad_left + pad_bottom + pad_right - print( - "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" - % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation) - ) +@tvm.testing.fixture +def input_shape(batch, in_channel, in_size): + return (batch, in_channel, in_size, in_size) - in_height = in_width = in_size - - A = te.placeholder((batch, in_channel, in_height, in_width), name="A") - W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W") - bias = te.placeholder((num_filter, 1, 1), name="bias") - - a_shape = get_const_tuple(A.shape) - w_shape = get_const_tuple(W.shape) - bias_shape = get_const_tuple(bias.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = np.random.uniform(size=bias_shape).astype(dtype) - dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) - if add_bias: - c_np += b_np - if add_relu: - c_np = np.maximum(c_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - def verify_workload_padding(): - _, _, out_height, out_width = get_const_tuple(c_np.shape) - wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype) - - # check if tile_ow candidates are the factors of the right output weight. - cfg = autotvm.get_config() - _fallback_schedule(cfg, wkl) - ow_tile = np.prod(cfg["tile_ow"].size) - tvm.testing.assert_allclose(ow_tile, out_width) +@tvm.testing.fixture +def weight_shape(num_filter, in_channel, kernel): + return (num_filter, in_channel, kernel, kernel) - def check_target(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - print("Running on target: %s" % target) - if "cudnn" in target: - fcompute, fschedule = topi.cuda.conv2d_cudnn, topi.cuda.schedule_conv2d_cudnn - else: - fcompute, fschedule = tvm.topi.testing.get_conv2d_nchw_implement(target) +@tvm.testing.fixture +def bias_shape(num_filter): + return (num_filter, 1, 1) - with tvm.target.Target(target): - if "cudnn" in target: - C = fcompute( - A, W, (stride, stride), padding, (dilation, dilation), 1, "NCHW", dtype - ) + +@tvm.testing.fixture(cache_return_value=True) +def ref_data( + random_seed, + input_shape, + weight_shape, + bias_shape, + dtype, + stride, + padding, + dilation, + add_bias, + apply_relu, +): + np.random.seed(random_seed) + + # scipy.signal.convolve2d does not support float16 data types, and + # the python fallback is too slow for general use. Computing + # ref_data in float32 will have fewer rounding errors than the TVM + # float16 compute, but those vary based on schedule anyways. + conv_dtype = "float32" if dtype == "float16" else dtype + + a_np = np.random.uniform(size=input_shape).astype(dtype) + w_np = np.random.uniform(size=weight_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + c_np = tvm.topi.testing.conv2d_nchw_python( + a_np.astype(conv_dtype), dw_np.astype(conv_dtype), stride, padding + ).astype(dtype) + + if add_bias: + c_np = c_np + b_np + if apply_relu: + c_np = np.maximum(c_np, 0) + return a_np, w_np, b_np, c_np + + +class BaseConv2DTests: + add_bias = tvm.testing.parameter(False) + apply_relu = tvm.testing.parameter(False) + dilation = tvm.testing.parameter(1) + batch = tvm.testing.parameter(1) + + def test_conv2d_nchw( + self, + target, + dev, + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dtype, + ref_data, + dilation, + add_bias, + apply_relu, + ): + target = tvm.target.Target(target) + is_cudnn_target = target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []) + + if target.kind.name == "vulkan" and dtype == "float16": + if not target.attrs.get("supports_float16", False) or not target.attrs.get( + "supports_16bit_buffer", False + ): + pytest.xfail("Vulkan device does not support float16") + + if ( + target.kind.name == "cuda" + and dtype == "float16" + and not tvm.contrib.nvcc.have_fp16(dev.compute_version) + ): + pytest.xfail("CUDA float16 intrinsics not available") + + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + + has_asymmetric_padding = (pad_top != pad_bottom) or (pad_left != pad_right) + if is_cudnn_target and has_asymmetric_padding: + pytest.xfail("CuDNN does not support asymmetric padding") + + a_np, w_np, b_np, c_np = ref_data + + A = te.placeholder(a_np.shape, name="A", dtype=dtype) + W = te.placeholder(w_np.shape, name="W", dtype=dtype) + bias = te.placeholder(b_np.shape, name="bias", dtype=dtype) + + if "int" in dtype: + tol = {"atol": 0, "rtol": 0} + elif dtype == "float32": + tol = {"rtol": 1e-4, "atol": 2e-4} + elif dtype == "float16": + # A summation in float16 with a single accumulator very + # quickly runs into large rounding errors. At some point, + # this tolerance should be schedule-dependent for to avoid + # false negatives. + num_values_summed = in_channel * kernel * kernel + gap_size = np.nextafter(c_np.max(), np.inf, dtype=c_np.dtype) - c_np.max() + tol = {"rtol": 1e-3, "atol": num_values_summed * gap_size / 2} + + with autotvm.tophub.context(target): # load tophub pre-tuned parameters + if is_cudnn_target: + fcompute, fschedule = topi.cuda.conv2d_cudnn, topi.cuda.schedule_conv2d_cudnn else: - C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype) - if add_bias: - C = topi.add(C, bias) - if add_relu: - C = topi.nn.relu(C) - s = fschedule([C]) - - if "llvm" in target: - verify_workload_padding() - - a = tvm.nd.array(a_np, dev) - w = tvm.nd.array(w_np, dev) - b = tvm.nd.array(b_np, dev) - - c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) - if add_bias: + fcompute, fschedule = tvm.topi.testing.get_conv2d_nchw_implement(target) + + with target: + if is_cudnn_target: + C = fcompute( + A, W, (stride, stride), padding, (dilation, dilation), 1, "NCHW", dtype + ) + else: + C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype) + if add_bias: + C = topi.add(C, bias) + if apply_relu: + C = topi.nn.relu(C) + s = fschedule([C]) + + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(b_np, dev) + + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) func = tvm.build( s, [A, W, bias, C], target, - name="relu_%d_%d_%d_%d_%d_%d_%d_%d" - % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation), + name="conv2d_{}_{}_{}_{}_{}_{}_{}_{}_{}".format( + dtype, + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding_sum, + dilation, + ), ) func(a, w, b, c) - else: - func = tvm.build( - s, - [A, W, C], - target, - name="relu_%d_%d_%d_%d_%d_%d_%d_%d" - % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation), - ) - func(a, w, c) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4) + tvm.testing.assert_allclose(c.numpy(), c_np, **tol) + + @tvm.testing.parametrize_targets("llvm") + def test_workload_padding( + self, + target, + input_shape, + weight_shape, + stride, + padding, + dilation, + dtype, + ref_data, + ): + a_np, w_np, b_np, c_np = ref_data + _, _, out_height, out_width = c_np.shape + + A = te.placeholder(input_shape, name="A", dtype=dtype) + W = te.placeholder(weight_shape, name="W", dtype=dtype) - for target, dev in tvm.testing.enabled_targets(): - with autotvm.tophub.context(target): # load tophub pre-tuned parameters - check_target(target) - - if use_cudnn: - check_target("cuda -model=unknown -libs=cudnn") - if ("opencl", tvm.device("opencl")) in tvm.testing.enabled_targets(): - check_target("opencl -device=intel_graphics") - - -@tvm.testing.uses_gpu -def test_conv2d_nchw(): - # ResNet18 workloads - verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0) - verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1) - verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0) - verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1) - verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1) - verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0) - verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1) - verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) - verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) - verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) - - # bias, relu - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_relu=True) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True) - - # dilation = 2 - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, dilation=2) - - # batch size - verify_conv2d_nchw(4, 64, 56, 64, 3, 1, 1) - verify_conv2d_nchw(9, 64, 56, 64, 3, 1, 1) - - # weird workloads - verify_conv2d_nchw(2, 2, 2, 2, 2, 2, 2) - verify_conv2d_nchw(3, 3, 3, 3, 3, 3, 3) - verify_conv2d_nchw(4, 4, 4, 4, 4, 4, 4) - verify_conv2d_nchw(5, 5, 5, 5, 5, 5, 5) - verify_conv2d_nchw(6, 6, 6, 6, 6, 6, 6) - - # disable these tests due to some bugs of llvm with nvptx - # verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=1) - # verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=2) - # verify_conv2d_nchw(2, 13, 71, 59, 3, 1, 1) - - # inception v3 workloads - verify_conv2d_nchw(1, 3, 299, 32, 3, 2, 0) - verify_conv2d_nchw(1, 32, 149, 32, 3, 1, 0) - verify_conv2d_nchw(1, 32, 147, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 73, 80, 1, 1, 0) - verify_conv2d_nchw(1, 80, 73, 192, 3, 1, 0) - verify_conv2d_nchw(1, 192, 35, 64, 1, 1, 0) - verify_conv2d_nchw(1, 192, 35, 48, 1, 1, 0) - verify_conv2d_nchw(1, 48, 35, 64, 5, 1, 2) - verify_conv2d_nchw(1, 64, 35, 96, 3, 1, 1) - verify_conv2d_nchw(1, 96, 35, 96, 3, 1, 1) - verify_conv2d_nchw(1, 192, 35, 32, 1, 1, 0) - verify_conv2d_nchw(1, 256, 35, 64, 1, 1, 0) - verify_conv2d_nchw(1, 256, 35, 48, 1, 1, 0) - verify_conv2d_nchw(1, 288, 35, 64, 1, 1, 0) - verify_conv2d_nchw(1, 288, 35, 48, 1, 1, 0) - verify_conv2d_nchw(1, 288, 35, 384, 3, 2, 0) - verify_conv2d_nchw(1, 96, 35, 96, 3, 2, 0) - verify_conv2d_nchw(1, 768, 17, 192, 1, 1, 0) - verify_conv2d_nchw(1, 768, 17, 128, 1, 1, 0) - verify_conv2d_nchw(1, 128, 17, 128, 1, 1, 0) - verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3) - verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3) - verify_conv2d_nchw(1, 128, 17, 192, 1, 1, 0) - verify_conv2d_nchw(1, 768, 17, 160, 1, 1, 0) - # disable these tests due to some bugs of llvm with nvptx - # verify_conv2d_nchw(1, 160, 17, 160, 1, 1, 0) - verify_conv2d_nchw(1, 160, 17, 192, 7, 1, 3) - verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3) - verify_conv2d_nchw(1, 160, 17, 192, 1, 1, 0) - verify_conv2d_nchw(1, 192, 17, 192, 1, 1, 0) - verify_conv2d_nchw(1, 192, 17, 192, 7, 1, 3) - verify_conv2d_nchw(1, 192, 17, 320, 3, 2, 0) - verify_conv2d_nchw(1, 192, 17, 192, 3, 2, 0) - verify_conv2d_nchw(1, 1280, 8, 320, 1, 1, 0) - verify_conv2d_nchw(1, 1280, 8, 384, 1, 1, 0) - verify_conv2d_nchw(1, 384, 8, 384, 1, 1, 0) - verify_conv2d_nchw(1, 384, 8, 384, 3, 1, 1) - verify_conv2d_nchw(1, 1280, 8, 448, 1, 1, 0) - verify_conv2d_nchw(1, 448, 8, 384, 3, 1, 1) - verify_conv2d_nchw(1, 1280, 8, 192, 1, 1, 0) - verify_conv2d_nchw(1, 2048, 8, 320, 1, 1, 0) - verify_conv2d_nchw(1, 2048, 8, 384, 1, 1, 0) - verify_conv2d_nchw(1, 2048, 8, 448, 1, 1, 0) - verify_conv2d_nchw(1, 2048, 8, 192, 1, 1, 0) - verify_conv2d_nchw(1, 1024, 19, 84, 3, 1, 1) - verify_conv2d_nchw(1, 2048, 10, 126, 3, 1, 1) - verify_conv2d_nchw(1, 512, 5, 126, 3, 1, 1) - verify_conv2d_nchw(1, 256, 3, 126, 3, 1, 1) - - # Asymmetric padding - verify_conv2d_nchw(1, 3, 35, 64, 7, 2, (0, 0, 1, 1)) - verify_conv2d_nchw(1, 64, 8, 128, 3, 1, (3, 3, 2, 2)) - verify_conv2d_nchw(1, 64, 8, 64, 1, 1, (1, 2, 2, 1)) - verify_conv2d_nchw(1, 64, 17, 192, 1, 1, (1, 2)) - verify_conv2d_nchw(1, 64, 8, 64, 3, 1, (3, 1)) - verify_conv2d_nchw(1, 128, 8, 384, 3, 1, (0, 2)) - verify_conv2d_nchw(1, 64, 35, 64, 3, 1, (1, 2), use_cudnn=True) - verify_conv2d_nchw(1, 64, 8, 64, 1, 1, "VALID") - verify_conv2d_nchw(1, 388, 8, 64, 3, 1, "VALID") - verify_conv2d_nchw(1, 64, 10, 48, 3, 1, "VALID", use_cudnn=True) - verify_conv2d_nchw(1, 512, 19, 64, 1, 1, "SAME") - verify_conv2d_nchw(1, 64, 5, 32, 2, 1, "SAME") - verify_conv2d_nchw(1, 64, 8, 64, 3, 1, "SAME", use_cudnn=True) - verify_conv2d_nchw(1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True) - verify_conv2d_nchw(1, 64, 8, 64, 5, 2, (1, 3), add_bias=True) - verify_conv2d_nchw(1, 64, 8, 64, 3, 1, "VALID", add_bias=True, add_relu=True) - verify_conv2d_nchw(1, 64, 8, 64, 24, 1, "SAME", add_bias=True, add_relu=True) - verify_conv2d_nchw(1, 32, 35, 64, 7, 2, (0, 0, 2, 2)) + with tvm.target.Target(target): + wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype) + + # check if tile_ow candidates are the factors of the right output weight. + cfg = autotvm.get_config() + _fallback_schedule(cfg, wkl) + ow_tile = np.prod(cfg["tile_ow"].size) + + tvm.testing.assert_allclose(ow_tile, out_width) + + +class TestResNet18Workloads(BaseConv2DTests): + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (3, 224, 64, 7, 2, 3), + (64, 56, 64, 3, 1, 1), + (64, 56, 64, 1, 1, 0), + (64, 56, 128, 3, 2, 1), + (64, 56, 128, 1, 2, 0), + (128, 28, 128, 3, 1, 1), + (128, 28, 256, 3, 2, 1), + (128, 28, 256, 1, 2, 0), + (256, 14, 256, 3, 1, 1), + (256, 14, 512, 3, 2, 1), + (256, 14, 512, 1, 2, 0), + (512, 7, 512, 3, 1, 1), + ) + + +class TestInceptionV3Workloads(BaseConv2DTests): + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (3, 299, 32, 3, 2, 0), + (32, 149, 32, 3, 1, 0), + (32, 147, 64, 3, 1, 1), + (64, 73, 80, 1, 1, 0), + (80, 73, 192, 3, 1, 0), + (192, 35, 64, 1, 1, 0), + (192, 35, 48, 1, 1, 0), + (48, 35, 64, 5, 1, 2), + (64, 35, 96, 3, 1, 1), + (96, 35, 96, 3, 1, 1), + (192, 35, 32, 1, 1, 0), + (256, 35, 64, 1, 1, 0), + (256, 35, 48, 1, 1, 0), + (288, 35, 64, 1, 1, 0), + (288, 35, 48, 1, 1, 0), + (288, 35, 384, 3, 2, 0), + (96, 35, 96, 3, 2, 0), + (768, 17, 192, 1, 1, 0), + (768, 17, 128, 1, 1, 0), + (128, 17, 128, 1, 1, 0), + (128, 17, 192, 7, 1, 3), + (128, 17, 128, 7, 1, 3), + (128, 17, 192, 1, 1, 0), + (768, 17, 160, 1, 1, 0), + # disable these tests due to some bugs of llvm with nvptx + # (160, 17, 160, 1, 1, 0), + (160, 17, 192, 7, 1, 3), + (160, 17, 160, 7, 1, 3), + (160, 17, 192, 1, 1, 0), + (192, 17, 192, 1, 1, 0), + (192, 17, 192, 7, 1, 3), + (192, 17, 320, 3, 2, 0), + (192, 17, 192, 3, 2, 0), + (1280, 8, 320, 1, 1, 0), + (1280, 8, 384, 1, 1, 0), + (384, 8, 384, 1, 1, 0), + (384, 8, 384, 3, 1, 1), + (1280, 8, 448, 1, 1, 0), + (448, 8, 384, 3, 1, 1), + (1280, 8, 192, 1, 1, 0), + (2048, 8, 320, 1, 1, 0), + (2048, 8, 384, 1, 1, 0), + (2048, 8, 448, 1, 1, 0), + (2048, 8, 192, 1, 1, 0), + (1024, 19, 84, 3, 1, 1), + (2048, 10, 126, 3, 1, 1), + (512, 5, 126, 3, 1, 1), + (256, 3, 126, 3, 1, 1), + ) + + +class TestWeirdWorkloads(BaseConv2DTests): + batch, in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (2, 2, 2, 2, 2, 2, 2), + (3, 3, 3, 3, 3, 3, 3), + (4, 4, 4, 4, 4, 4, 4), + (5, 5, 5, 5, 5, 5, 5), + (6, 6, 6, 6, 6, 6, 6), + # disable these tests due to some bugs of llvm with nvptx + # (1, 1, 1, 1, 1, 1, 1), + # (2, 13, 71, 59, 3, 1, 1), + ) + + +class TestAsymmetricPadding(BaseConv2DTests): + dilation = tvm.testing.parameter(1, 2) + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (3, 35, 64, 7, 2, (0, 0, 1, 1)), + (64, 8, 128, 3, 1, (3, 3, 2, 2)), + (64, 8, 64, 1, 1, (1, 2, 2, 1)), + (64, 17, 192, 1, 1, (1, 2)), + (64, 8, 64, 3, 1, (3, 1)), + (128, 8, 384, 3, 1, (0, 2)), + (64, 35, 64, 3, 1, (1, 2)), + (64, 8, 64, 1, 1, "VALID"), + (388, 8, 64, 3, 1, "VALID"), + (64, 10, 48, 3, 1, "VALID"), + (512, 19, 64, 1, 1, "SAME"), + (64, 5, 32, 2, 1, "SAME"), + (64, 8, 64, 3, 1, "SAME"), + (64, 8, 64, 3, 1, (1, 2, 2, 1)), + (64, 8, 64, 5, 2, (1, 3)), + (64, 8, 64, 3, 1, "VALID"), + (64, 8, 64, 24, 1, "SAME"), + (32, 35, 64, 7, 2, (0, 0, 2, 2)), + ) + + +class TestBatchSize(BaseConv2DTests): + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (64, 56, 64, 3, 1, 1), + ) + batch = tvm.testing.parameter(1, 4, 9) + + +class TestBiasRelu(BaseConv2DTests): + apply_relu = tvm.testing.parameter(True, False, ids=["relu", "no_relu"]) + add_bias = tvm.testing.parameter(True, False, ids=["bias", "no_bias"]) + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (64, 56, 64, 3, 1, 1), + (64, 8, 64, 3, 1, (1, 2, 2, 1)), + (64, 8, 64, 5, 2, (1, 3)), + (64, 8, 64, 3, 1, "VALID"), + (64, 8, 64, 24, 1, "SAME"), + ) if __name__ == "__main__": - test_conv2d_nchw() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_conv3d_ncdhw.py b/tests/python/topi/python/test_topi_conv3d_ncdhw.py index 056ef7fc146a..c45aaa188834 100644 --- a/tests/python/topi/python/test_topi_conv3d_ncdhw.py +++ b/tests/python/topi/python/test_topi_conv3d_ncdhw.py @@ -116,7 +116,7 @@ def check_target(target, dev): % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation), ) func(a, w, c) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-6) for target, dev in tvm.testing.enabled_targets(): with autotvm.tophub.context(target): # load tophub pre-tuned parameters diff --git a/tests/python/topi/python/test_topi_conv3d_winograd.py b/tests/python/topi/python/test_topi_conv3d_winograd.py index 0b9d579287c3..54dd72a2f544 100644 --- a/tests/python/topi/python/test_topi_conv3d_winograd.py +++ b/tests/python/topi/python/test_topi_conv3d_winograd.py @@ -138,7 +138,7 @@ def check_device(device): ), ) func(a, w, c) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-6) for device in ["cuda"]: with autotvm.tophub.context(device): # load tophub pre-tuned parameters diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py index 964c1621fa47..8f58415da329 100644 --- a/tests/python/topi/python/test_topi_dense.py +++ b/tests/python/topi/python/test_topi_dense.py @@ -28,11 +28,14 @@ from common import Int8Fallback +random_seed = tvm.testing.parameter(0) + use_bias = tvm.testing.parameter(True, False) batch_size = tvm.testing.parameter(1, 2, 128) in_dim, out_dim = tvm.testing.parameters((1024, 1000)) in_dtype, out_dtype = tvm.testing.parameters( ("float32", "float32"), + ("float16", "float16"), ("int8", "int32"), ) @@ -55,7 +58,9 @@ @tvm.testing.fixture(cache_return_value=True) -def dense_ref_data(batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype): +def dense_ref_data(random_seed, batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype): + np.random.seed(random_seed) + if "float" in in_dtype: a_np = np.random.uniform(size=(batch_size, in_dim)).astype(in_dtype) b_np = np.random.uniform(size=(out_dim, in_dim)).astype(in_dtype) @@ -90,19 +95,24 @@ def test_dense( ): target = tvm.target.Target(target) - if ( - in_dtype == "int8" - and target.kind.name == "cuda" - and not tvm.contrib.nvcc.have_int8(dev.compute_version) - ): - pytest.xfail("CUDA int8 intrinsics not available") - - if ( - in_dtype == "int8" - and target.kind.name == "vulkan" - and not target.attrs.get("supports_int8", False) - ): - pytest.xfail("Vulkan int8 driver support not available") + if target.kind.name == "cuda": + if in_dtype == "int8" and not tvm.contrib.nvcc.have_int8(dev.compute_version): + pytest.xfail("CUDA int8 intrinsics not available") + + if in_dtype == "float16" and not tvm.contrib.nvcc.have_fp16(dev.compute_version): + pytest.xfail("CUDA float16 intrinsics not available") + + if target.kind.name == "vulkan": + if in_dtype == "int8" and ( + not target.attrs.get("supports_int8", False) + or not target.attrs.get("supports_8bit_buffer", False) + ): + pytest.xfail("Vulkan int8 driver support not available") + if in_dtype == "float16" and ( + not target.attrs.get("supports_float16", False) + or not target.attrs.get("supports_16bit_buffer", False) + ): + pytest.xfail("Vulkan float16 driver support not available") if ( target.kind.name not in ["llvm", "c"] @@ -110,6 +120,13 @@ def test_dense( ): pytest.xfail("No implementation for tvm.topi.testing.dispatch to find") + if "int" in in_dtype: + tol = {"atol": 0, "rtol": 0} + elif in_dtype == "float32": + tol = {"rtol": 1e-5, "atol": 1e-5} + elif in_dtype == "float16": + tol = {"rtol": 5e-2, "atol": 1e-5} + A = te.placeholder((batch_size, in_dim), name="A", dtype=in_dtype) B = te.placeholder((out_dim, in_dim), name="B", dtype=in_dtype) C = te.placeholder((out_dim,), name="C", dtype=out_dtype) @@ -131,11 +148,10 @@ def test_dense( d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), dev) f = tvm.build(s, [A, B, C, D], target, name="dense") f(a, b, c, d) - tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-5) + tvm.testing.assert_allclose(d.numpy(), d_np, **tol) @pytest.mark.parametrize("target,in_dtype,out_dtype", [("cuda", "int8", "int32")]) -@tvm.testing.requires_gpu def test_dense_cuda_int8( target, dev, diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py index 76093c51b4c8..27601cd32b89 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py @@ -14,561 +14,424 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import sys + +import numpy as np +import pytest + import tvm -from tvm import te -from tvm import autotvm -from tvm import topi +import tvm.testing import tvm.topi.testing -import numpy as np + +from tvm import autotvm, te, topi from tvm.topi.utils import get_const_tuple from tvm.topi.nn.utils import get_pad_tuple from tvm.contrib.pickle_memoize import memoize from tvm.topi.nn.depthwise_conv2d import _get_workload from tvm.topi.x86.depthwise_conv2d import _fallback_schedule -import tvm.testing -_depthwise_conv2d_nchw_implement = { - "generic": [(topi.nn.depthwise_conv2d_nchw, topi.generic.schedule_depthwise_conv2d_nchw)], - "arm_cpu": [ - (topi.arm_cpu.depthwise_conv2d_nchw, topi.arm_cpu.schedule_depthwise_conv2d_nchw), - ( - topi.arm_cpu.depthwise_conv2d_nchw_spatial_pack, - topi.arm_cpu.schedule_depthwise_conv2d_nchw_spatial_pack, - ), - ], - "gpu": [(topi.cuda.depthwise_conv2d_nchw, topi.cuda.schedule_depthwise_conv2d_nchw)], - "mali": [(topi.mali.depthwise_conv2d_nchw, topi.mali.schedule_depthwise_conv2d_nchw)], - "bifrost": [(topi.nn.depthwise_conv2d_nchw, topi.bifrost.schedule_depthwise_conv2d_nchw)], - "intel_graphics": [ - ( - topi.intel_graphics.depthwise_conv2d_nchw, - topi.intel_graphics.schedule_depthwise_conv2d_nchw, - ) - ], +_depthwise_conv2d_implement = { + "NCHW": { + "generic": [(topi.nn.depthwise_conv2d_nchw, topi.generic.schedule_depthwise_conv2d_nchw)], + "arm_cpu": [ + (topi.arm_cpu.depthwise_conv2d_nchw, topi.arm_cpu.schedule_depthwise_conv2d_nchw), + ( + topi.arm_cpu.depthwise_conv2d_nchw_spatial_pack, + topi.arm_cpu.schedule_depthwise_conv2d_nchw_spatial_pack, + ), + ], + "gpu": [(topi.cuda.depthwise_conv2d_nchw, topi.cuda.schedule_depthwise_conv2d_nchw)], + "mali": [(topi.mali.depthwise_conv2d_nchw, topi.mali.schedule_depthwise_conv2d_nchw)], + "bifrost": [(topi.nn.depthwise_conv2d_nchw, topi.bifrost.schedule_depthwise_conv2d_nchw)], + "intel_graphics": [ + ( + topi.intel_graphics.depthwise_conv2d_nchw, + topi.intel_graphics.schedule_depthwise_conv2d_nchw, + ) + ], + }, + "NHWC": { + "generic": [(topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc)], + "arm_cpu": [ + ( + topi.arm_cpu.compute_depthwise_conv2d_nhwc, + topi.arm_cpu.schedule_depthwise_conv2d_nhwc, + ) + ], + "gpu": [(topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc)], + "mali": [(topi.mali.depthwise_conv2d_nhwc, topi.mali.schedule_depthwise_conv2d_nhwc)], + "bifrost": [(topi.mali.depthwise_conv2d_nhwc, topi.mali.schedule_depthwise_conv2d_nhwc)], + }, + "NCHWc": { + "generic": [(topi.x86.depthwise_conv2d_NCHWc, topi.x86.schedule_depthwise_conv2d_NCHWc)], + }, } -_depthwise_conv2d_nhwc_implement = { - "generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc), - "arm_cpu": ( - topi.arm_cpu.compute_depthwise_conv2d_nhwc, - topi.arm_cpu.schedule_depthwise_conv2d_nhwc, - ), - "gpu": (topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc), -} +random_seed = tvm.testing.parameter(0) +in_dtype, out_dtype = tvm.testing.parameters( + ("float32", "float32"), + ("float16", "float16"), +) -def compile_depthwise_NHWC_int8_arm( - batch, - in_channel, - in_size, - kernel, - depth_multiplier, - stride, - padding, - add_bias=False, - dilation=1, -): - pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) - padding_sum = pad_top + pad_left + pad_bottom + pad_right - - in_height = in_width = in_size - A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="int16") - W = te.placeholder((kernel, kernel, in_channel, depth_multiplier), name="W", dtype="int16") - bias = te.placeholder((in_channel * depth_multiplier,), name="bias", dtype="int32") - dtype = "int32" - - target = "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu" - compute = topi.arm_cpu.compute_depthwise_conv2d_nhwc - schedule = topi.arm_cpu.schedule_depthwise_conv2d_nhwc - - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - - print("Compiling on arm AArch64 target: %s" % target) - with tvm.target.Target(target): - assert topi.arm_cpu.arm_utils.is_aarch64_arm(), "AArch64 target not recognized" - - C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype) - if add_bias: - C += bias - ins_outs = [A, W, bias, C] - else: - ins_outs = [A, W, C] - - s = schedule([C]) - - func = tvm.build( - s, - ins_outs, - target, - name="depthwise_conv2d", - ) - - -def depthwise_conv2d_with_workload_nchw( - target, - dev, - batch, - in_channel, - in_height, - channel_multiplier, - filter_height, + +@tvm.testing.fixture +def input_shape(layout, batch, in_channel, in_size, filter_shape): + if layout == "NCHW": + return (batch, in_channel, in_size, in_size) + elif layout == "NHWC": + return (batch, in_size, in_size, in_channel) + elif layout == "NCHWc": + oc_block = filter_shape[-1] + ic_block = next(bn for bn in range(oc_block, 0, -1) if in_channel % bn == 0) + return (batch, in_channel // ic_block, in_size, in_size, ic_block) + + +@tvm.testing.fixture +def filter_shape(layout, in_channel, channel_multiplier, kernel): + filter_channel = in_channel + if layout == "NCHW": + return (filter_channel, channel_multiplier, kernel, kernel) + elif layout == "NHWC": + return (kernel, kernel, filter_channel, channel_multiplier) + elif layout == "NCHWc": + out_channel = in_channel * channel_multiplier + # For testing the functionality, we choose an arbitrary block + # size that can divide out_channel, regardless of the + # performance. + oc_block = next(bn for bn in range(16, 0, -1) if out_channel % bn == 0) + return (out_channel // oc_block, 1, kernel, kernel, 1, oc_block) + + +@tvm.testing.fixture +def scale_shape(layout, in_channel, channel_multiplier, filter_shape): + out_channel = in_channel * channel_multiplier + + if layout in ("NCHW", "NHWC"): + return (out_channel,) + + if layout == "NCHWc": + oc_block = filter_shape[-1] + return (out_channel // oc_block, oc_block) + + raise ValueError("Unknown layout {}".format(layout)) + + +@tvm.testing.fixture +def shift_shape(scale_shape): + return scale_shape + + +@tvm.testing.fixture(cache_return_value=True) +def ref_data( + random_seed, + in_dtype, + out_dtype, + layout, + input_shape, + filter_shape, + dilation, stride, padding, - dilation=1, + scale_shape, + shift_shape, + use_scale_shift, + apply_relu, ): - in_width = in_height - filter_channel = in_channel - filter_width = filter_height - stride_h = stride_w = stride - - if dilation == 1: - # here we transform the padding argument from 'str' to 'tuple' , - # because we need this to match the "workload" tuple to the records in TopHub - padt, padl, padb, padr = get_pad_tuple(padding, (filter_height, filter_width)) - padding_args = (padt, padl, padb, padr) - else: - padding_args = padding - - # placeholder - Input = te.placeholder((batch, in_channel, in_height, in_width), name="Input") - Filter = te.placeholder( - (filter_channel, channel_multiplier, filter_height, filter_width), name="Filter" + np.random.seed(random_seed) + + # scipy.signal.convolve2d does not support float16 data types, and + # the python fallback is too slow for general use. Computing + # ref_data in float32 will have fewer rounding errors than the TVM + # float16 compute, but those vary based on schedule anyways. + conv_dtype = "float32" if in_dtype == "float16" else in_dtype + + input_np = np.random.uniform(size=input_shape).astype(in_dtype) + filter_np = np.random.uniform(size=filter_shape).astype(in_dtype) + scale_np = np.random.uniform(size=scale_shape).astype(out_dtype) + shift_np = np.random.uniform(size=shift_shape).astype(out_dtype) + if layout == "NCHW": + np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nchw + dilation = (1, 1, dilation, dilation) + reshape = (1, -1, 1, 1) + elif layout == "NHWC": + np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nhwc + dilation = (dilation, dilation, 1, 1) + reshape = (1, 1, 1, -1) + elif layout == "NCHWc": + np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nchwc + dilation = (1, 1, dilation, dilation, 1, 1) + reshape = (1, scale_shape[0], 1, 1, scale_shape[1]) + + dilated_filter_np = tvm.topi.testing.dilate_python(filter_np, dilation) + output_np = np_depthwise_conv2d( + input_np.astype(conv_dtype), dilated_filter_np.astype(conv_dtype), stride, padding + ).astype(out_dtype) + + if use_scale_shift: + output_np = output_np * scale_np.reshape(reshape) + shift_np.reshape(reshape) + if apply_relu: + output_np = np.maximum(output_np, 0) + + return ( + input_np, + filter_np, + scale_np, + shift_np, + output_np, ) - Scale = te.placeholder((in_channel * channel_multiplier,), name="Scale") - Shift = te.placeholder((in_channel * channel_multiplier,), name="Shift") - dtype = "float32" - with autotvm.tophub.context(target): # load tophub pre-tuned parameters - impl_list = tvm.topi.testing.dispatch(target, _depthwise_conv2d_nchw_implement)[:] - if target == "llvm" and channel_multiplier == 1 and dilation == 1: - impl_list.append( - (topi.x86.depthwise_conv2d_nchw, topi.x86.schedule_depthwise_conv2d_nchw) - ) +class BaseDepthwiseConv2D: + """Provides the test_conv2d test function, to be used by other test classes. - for fcompute, fschedule in impl_list: - with tvm.target.Target(target): - # declare - DepthwiseConv2d = fcompute( - Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype - ) - ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift) - Relu = topi.nn.relu(ScaleShift) - # schedule - s1 = fschedule(DepthwiseConv2d) - s2 = fschedule(ScaleShift) - s3 = fschedule(Relu) - # build the kernels - f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], target) - f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], target) - f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], target) - - # Prepare pod type for test data closure - input_shape = get_const_tuple(Input.shape) - filter_shape = get_const_tuple(Filter.shape) - scale_shape = get_const_tuple(Scale.shape) - shift_shape = get_const_tuple(Shift.shape) - scale_shift_shape = get_const_tuple(ScaleShift.shape) - - # Use memoize, pickle the test data for next time use. - @memoize("topi.tests.test_topi_depthwise_conv2d.nchw") - def get_ref_data(): - input_np = np.random.uniform(size=input_shape).astype(dtype) - filter_np = np.random.uniform(size=filter_shape).astype(dtype) - dilated_filter_np = tvm.topi.testing.dilate_python( - filter_np, (1, 1, dilation, dilation) - ) - scale_np = np.random.uniform(size=scale_shape).astype(dtype) - shift_np = np.random.uniform(size=shift_shape).astype(dtype) - # correctness with scipy - depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nchw( - input_np, dilated_filter_np, stride, padding - ) - scale_shift_scipy = np.zeros(shape=scale_shift_shape) - for c in range(in_channel * channel_multiplier): - scale_shift_scipy[:, c, :, :] = ( - depthwise_conv2d_scipy[:, c, :, :] * scale_np[c] + shift_np[c] - ) - relu_scipy = np.maximum(scale_shift_scipy, 0) - return ( - input_np, - filter_np, - scale_np, - shift_np, - depthwise_conv2d_scipy, - scale_shift_scipy, - relu_scipy, - ) + Test parameter sets are split out into different classes for + readability (e.g. used for mobilenet), and for restrictions + (e.g. implemented only for llvm). + """ - # Get the test data - ( - input_np, - filter_np, - scale_np, - shift_np, - depthwise_conv2d_scipy, - scale_shift_scipy, - relu_scipy, - ) = get_ref_data() - - def verify_workload_padding(): - _, _, out_height, out_width = get_const_tuple(depthwise_conv2d_scipy.shape) - wkl = _get_workload( - Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype - ) + layout = tvm.testing.parameter("NCHW", "NHWC") - # check if tile_ow candidates are the factors of the right output weight. - with tvm.target.Target(target): - cfg = autotvm.get_config() - _fallback_schedule(cfg, wkl) - ow_tile = np.prod(cfg["tile_ow"].size) - - tvm.testing.assert_allclose(ow_tile, out_width) - - if "llvm" in target: - verify_workload_padding() - - input_tvm = tvm.nd.array(input_np, dev) - filter_tvm = tvm.nd.array(filter_np, dev) - scale_tvm = tvm.nd.array(scale_np, dev) - shift_tvm = tvm.nd.array(shift_np, dev) - depthwise_conv2d_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), - dev, - ) - scale_shift_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), dev - ) - relu_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), dev - ) - # launch kernel 1 (depthwise_conv2d) - timer_1 = f1.time_evaluator(f1.entry_name, dev, number=1) - tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean - # launch kernel 2 (depthwise_conv2d + scale_shift) - timer_2 = f2.time_evaluator(f2.entry_name, dev, number=1) - tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean - # launch kernel 3 (depthwise_conv2d + scale_shift + relu) - timer_3 = f3.time_evaluator(f3.entry_name, dev, number=1) - tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean - tvm.testing.assert_allclose( - depthwise_conv2d_tvm.numpy(), depthwise_conv2d_scipy, rtol=1e-5 - ) - tvm.testing.assert_allclose(scale_shift_tvm.numpy(), scale_shift_scipy, rtol=1e-5) - tvm.testing.assert_allclose(relu_tvm.numpy(), relu_scipy, rtol=1e-5) - - -def depthwise_conv2d_with_workload_nhwc( - target, - dev, - batch, - in_channel, - in_height, - channel_multiplier, - filter_height, - stride_h, - padding, - dilation=1, -): - in_width = in_height - filter_channel = in_channel - filter_width = filter_height - stride_w = stride_h - - if dilation == 1: - # here we transform the padding argument from 'str' to 'tuple' , - # because we need this to match the "workload" tuple to the records in TopHub - pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width)) - padding_args = (pad_h, pad_w) - else: - padding_args = padding - - # placeholder - Input = te.placeholder((batch, in_height, in_width, in_channel), name="Input") - Filter = te.placeholder( - (filter_height, filter_width, filter_channel, channel_multiplier), name="Filter" + (batch, in_channel, in_size, channel_multiplier, kernel, stride) = tvm.testing.parameters( + (1, 728, 32, 1, 3, 1), + (4, 256, 64, 2, 5, 2), ) - Scale = te.placeholder((in_channel * channel_multiplier,), name="Scale") - Shift = te.placeholder((in_channel * channel_multiplier,), name="Shift") - - dtype = "float32" - - with autotvm.tophub.context(target): # load tophub pre-tuned parameters - fcompute, fschedule = tvm.topi.testing.dispatch(target, _depthwise_conv2d_nhwc_implement) - with tvm.target.Target(target): - # declare - DepthwiseConv2d = fcompute( - Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype - ) - ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift) - Relu = topi.nn.relu(ScaleShift) - # schedule - s1 = fschedule(DepthwiseConv2d) - s2 = fschedule(ScaleShift) - s3 = fschedule(Relu) - # build the kernels - f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], target) - f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], target) - f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], target) - - # Prepare pod type for test data closure - input_shape = get_const_tuple(Input.shape) - filter_shape = get_const_tuple(Filter.shape) - scale_shape = get_const_tuple(Scale.shape) - shift_shape = get_const_tuple(Shift.shape) - scale_shift_shape = get_const_tuple(ScaleShift.shape) - - # Use memoize, pickle the test data for next time use. - @memoize("topi.tests.test_topi_depthwise_conv2d.nhwc.v2") - def get_ref_data(): - input_np = np.random.uniform(size=input_shape).astype(dtype) - filter_np = np.random.uniform(size=filter_shape).astype(dtype) - dilated_filter_np = tvm.topi.testing.dilate_python( - filter_np, (dilation, dilation, 1, 1) - ) - scale_np = np.random.uniform(size=scale_shape).astype(dtype) - shift_np = np.random.uniform(size=shift_shape).astype(dtype) - # correctness with scipy - depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nhwc( - input_np, dilated_filter_np, stride=[stride_h, stride_w], padding=padding - ) - scale_shift_scipy = np.zeros(shape=scale_shift_shape) - for c in range(in_channel * channel_multiplier): - scale_shift_scipy[:, :, :, c] = ( - depthwise_conv2d_scipy[:, :, :, c] * scale_np[c] + shift_np[c] - ) - relu_scipy = np.maximum(scale_shift_scipy, 0) - return ( - input_np, - filter_np, - scale_np, - shift_np, - depthwise_conv2d_scipy, - scale_shift_scipy, - relu_scipy, + padding = tvm.testing.parameter("SAME", "VALID") + dilation = tvm.testing.parameter(1, 2) + + use_scale_shift = tvm.testing.parameter(True, False, ids=["with_scale_shift", "no_scale_shift"]) + apply_relu = tvm.testing.parameter(True, False, ids=["with_relu", "no_relu"]) + + run_after_compile = True + + def test_conv2d( + self, + target, + dev, + in_dtype, + out_dtype, + layout, + input_shape, + filter_shape, + scale_shape, + shift_shape, + use_scale_shift, + apply_relu, + batch, + in_channel, + channel_multiplier, + kernel, + stride, + padding, + dilation, + ref_data, + ): + target = tvm.target.Target(target) + if ( + target.kind.name == "cuda" + and in_dtype == "float16" + and not tvm.contrib.nvcc.have_fp16(dev.compute_version) + ): + pytest.xfail("CUDA float16 intrinsics not available") + + if ( + target.kind.name == "vulkan" + and in_dtype == "float16" + and ( + not target.attrs.get("supports_float16", False) + or not target.attrs.get("supports_16bit_buffer", False) ) - - # Get the test data - ( - input_np, - filter_np, - scale_np, - shift_np, - depthwise_conv2d_scipy, - scale_shift_scipy, - relu_scipy, - ) = get_ref_data() - - # prepare data - input_tvm = tvm.nd.array(input_np, dev) - filter_tvm = tvm.nd.array(filter_np, dev) - scale_tvm = tvm.nd.array(scale_np, dev) - shift_tvm = tvm.nd.array(shift_np, dev) - depthwise_conv2d_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), dev - ) - scale_shift_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), dev - ) - relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), dev) - # launch kernel 1 (depthwise_conv2d) - timer_1 = f1.time_evaluator(f1.entry_name, dev, number=1) - tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean - # launch kernel 2 (depthwise_conv2d + scale_shift) - timer_2 = f2.time_evaluator(f2.entry_name, dev, number=1) - tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean - # launch kernel 3 (depthwise_conv2d + scale_shift + relu) - timer_3 = f3.time_evaluator(f3.entry_name, dev, number=1) - tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean - relu_scipy = np.maximum(scale_shift_scipy, 0) - tvm.testing.assert_allclose(depthwise_conv2d_tvm.numpy(), depthwise_conv2d_scipy, rtol=1e-5) - tvm.testing.assert_allclose(scale_shift_tvm.numpy(), scale_shift_scipy, rtol=1e-5) - tvm.testing.assert_allclose(relu_tvm.numpy(), relu_scipy, rtol=1e-5) - - -def _transform_data(data, bn): - # NCHW -> NCHW[x]c - batch_size, channel, height, width = data.shape - data = np.reshape(data, (batch_size, channel // bn, bn, height, width)) - data = np.transpose(data, (0, 1, 3, 4, 2)) - return data - - -def _transform_kernel(kernel, bn): - # channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block - channel, channel_multiplier, kh, kw = kernel.shape - out_channel = channel * channel_multiplier - kernel = np.reshape(kernel, (out_channel // bn, bn, kh, kw)) - kernel = np.transpose(kernel, (0, 2, 3, 1)) - out_channel_chunk, kh, kw, out_channel_block = kernel.shape - return kernel.reshape(out_channel_chunk, 1, kh, kw, 1, out_channel_block) - - -def depthwise_conv2d_with_workload_NCHWc( - target, - dev, - batch, - in_channel, - in_height, - channel_multiplier, - filter_height, - stride, - padding, - dilation=1, -): - in_width = in_height - filter_channel = in_channel - filter_width = filter_height - stride_h = stride_w = stride - - assert ( - channel_multiplier == 1 - ), "depthwise_conv2d_NCHWc currently does not support channel multiplier > 1." - pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width)) - padding_args = (pad_h, pad_w) - - out_channel = filter_channel * channel_multiplier - # for testing functionality, - # we choose arbitrary block size that can divide the channel, - # regardless of the performance. - oc_block = 1 - for bn in range(16, 0, -1): - if out_channel % bn == 0: - oc_block = bn - break - - ic_block = 1 - for bn in range(oc_block, 0, -1): - if in_channel % bn == 0: - ic_block = bn - break - - # placeholder - Input = te.placeholder( - (batch, in_channel // ic_block, in_height, in_width, ic_block), name="Input" - ) - Filter = te.placeholder( - (out_channel // oc_block, 1, filter_height, filter_width, 1, oc_block), name="Filter" - ) - in_layout = "NCHW%dc" % ic_block - out_layout = "NCHW%dc" % oc_block - dtype = "float32" - - with autotvm.tophub.context(target): # load tophub pre-tuned parameters - dev = tvm.device(target, 0) - with tvm.target.Target(target): - # declare - DepthwiseConv2d = topi.x86.depthwise_conv2d_NCHWc( + ): + pytest.xfail("Vulkan float16 driver support not available") + + # Transform the padding argument from 'str' to 'tuple' to + # match the "workload" tuple in TopHub. Which padding_args to + # use for each layout chosen to reproduce previous behavior. + if dilation == 1: + padding_args = get_pad_tuple(padding, (kernel, kernel)) + padding_args_i = [0, 1, 2, 3] if layout == "NCHW" else [0, 1] + padding_args = [padding_args[i] for i in padding_args_i] + else: + padding_args = padding + + # placeholder + Input = te.placeholder(input_shape, name="Input", dtype=in_dtype) + Filter = te.placeholder(filter_shape, name="Filter", dtype=in_dtype) + Scale = te.placeholder(scale_shape, name="Scale", dtype=out_dtype) + Shift = te.placeholder(shift_shape, name="Shift", dtype=out_dtype) + + if layout == "NCHW": + topi_scale_shift = topi.nn.scale_shift_nchw + fcompute_args = (Input, Filter, stride, padding_args, dilation, out_dtype) + + elif layout == "NHWC": + topi_scale_shift = topi.nn.scale_shift_nhwc + fcompute_args = (Input, Filter, stride, padding_args, dilation, out_dtype) + + elif layout == "NCHWc": + topi_scale_shift = topi.nn.scale_shift_nchwc + in_layout = "NCHW{}c".format(input_shape[-1]) + out_layout = "NCHW{}c".format(filter_shape[-1]) + fcompute_args = ( Input, Filter, - (stride_h, stride_w), + stride, padding, - (dilation, dilation), + dilation, in_layout, out_layout, - dtype, - ) - # TODO: add scale_shift implement for NCHWc and add test here - Relu = topi.nn.relu(DepthwiseConv2d) - # schedule - s1 = topi.x86.schedule_depthwise_conv2d_NCHWc(DepthwiseConv2d) - s2 = topi.x86.schedule_depthwise_conv2d_NCHWc(Relu) - # build the kernels - f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], target) - f2 = tvm.build(s2, [Input, Filter, Relu], target) - - # Prepare pod type for test data closure - input_shape = (batch, in_channel, in_height, in_width) - filter_shape = (filter_channel, channel_multiplier, filter_height, filter_width) - - # Use memoize, pickle the test data for next time use. - @memoize("topi.tests.test_topi_depthwise_conv2d.NCHWc") - def get_ref_data(): - input_np = np.random.uniform(size=input_shape).astype(dtype) - filter_np = np.random.uniform(size=filter_shape).astype(dtype) - # correctness with scipy - dw_np = tvm.topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation)).astype( - dtype - ) - depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nchw( - input_np, dw_np, stride, padding - ) - relu_scipy = np.maximum(depthwise_conv2d_scipy, 0) - return ( - _transform_data(input_np, ic_block), - _transform_kernel(filter_np, oc_block), - _transform_data(depthwise_conv2d_scipy, oc_block), - _transform_data(relu_scipy, oc_block), + out_dtype, ) - # Get the test data - (input_np, filter_np, depthwise_conv2d_scipy, relu_scipy) = get_ref_data() - - input_tvm = tvm.nd.array(input_np, dev) - filter_tvm = tvm.nd.array(filter_np, dev) - - depthwise_conv2d_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), dev - ) - relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), dev) - # launch kernel 1 (depthwise_conv2d) - f1(input_tvm, filter_tvm, depthwise_conv2d_tvm) - # launch kernel 2 (depthwise_conv2d + relu) - f2(input_tvm, filter_tvm, relu_tvm) - tvm.testing.assert_allclose(depthwise_conv2d_tvm.numpy(), depthwise_conv2d_scipy, rtol=1e-5) - tvm.testing.assert_allclose(relu_tvm.numpy(), relu_scipy, rtol=1e-5) - - -@tvm.testing.parametrize_targets -def test_depthwise_conv2d_nchw(target, dev): - # mobilenet workloads - depthwise_conv2d_with_workload_nchw(target, dev, 1, 32, 112, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 64, 112, 1, 3, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 128, 56, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 128, 56, 1, 3, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 256, 28, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 256, 28, 1, 3, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 512, 14, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 512, 14, 1, 3, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 1024, 7, 1, 3, 1, "SAME") - - depthwise_conv2d_with_workload_nchw(target, dev, 1, 728, 32, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 4, 256, 64, 2, 5, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 728, 32, 1, 3, 1, "VALID") - depthwise_conv2d_with_workload_nchw(target, dev, 4, 256, 64, 2, 5, 2, "VALID") - # dilation = 2 - depthwise_conv2d_with_workload_nchw(target, dev, 1, 728, 64, 1, 3, 1, "SAME", dilation=2) - - -@tvm.testing.parametrize_targets -def test_depthwise_conv2d_nhwc(target, dev): - depthwise_conv2d_with_workload_nhwc(target, dev, 1, 728, 32, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nhwc(target, dev, 4, 256, 64, 2, 5, 2, "SAME") - depthwise_conv2d_with_workload_nhwc(target, dev, 1, 728, 32, 1, 3, 1, "VALID") - depthwise_conv2d_with_workload_nhwc(target, dev, 4, 256, 64, 2, 5, 2, "VALID") - - # dilation = 2 - # disabled because it uses too large shared memory on cuda - # depthwise_conv2d_with_workload_nhwc(target, dev, 1, 728, 64, 1, 3, 1, "SAME", dilation=2) - - -# test llvm only for now since depthwise_conv2d_NCHWc implement is missing in other backend. + with autotvm.tophub.context(target): # load tophub pre-tuned parameters + impl_list = tvm.topi.testing.dispatch(target, _depthwise_conv2d_implement[layout])[:] + if target == "llvm" and layout == "NCHW" and channel_multiplier == 1 and dilation == 1: + impl_list.append( + (topi.x86.depthwise_conv2d_nchw, topi.x86.schedule_depthwise_conv2d_nchw) + ) + + for fcompute, fschedule in impl_list: + with tvm.target.Target(target): + # Declare, build schedule + C = fcompute(*fcompute_args) + if use_scale_shift: + C = topi_scale_shift(C, Scale, Shift) + if apply_relu: + C = topi.nn.relu(C) + + s = fschedule(C) + + # Build and run + f = tvm.build(s, [Input, Filter, Scale, Shift, C], target) + + if self.run_after_compile: + input_np, filter_np, scale_np, shift_np, output_np = ref_data + if "int" in out_dtype: + tol = {"atol": 0, "rtol": 0} + elif out_dtype == "float32": + tol = {"rtol": 1e-4, "atol": 1e-5} + elif out_dtype == "float16": + # A summation in float16 with a single accumulator very + # quickly runs into large rounding errors. At some point, + # this tolerance should be schedule-dependent for to avoid + # false negatives. + num_values_summed = kernel * kernel + gap_size = ( + np.nextafter(output_np.max(), np.inf, dtype=output_np.dtype) + - output_np.max() + ) + tol = {"rtol": 1e-3, "atol": num_values_summed * gap_size / 2} + + input_tvm = tvm.nd.array(input_np, dev) + filter_tvm = tvm.nd.array(filter_np, dev) + scale_tvm = tvm.nd.array(scale_np, dev) + shift_tvm = tvm.nd.array(shift_np, dev) + output_tvm = tvm.nd.array( + np.zeros(shape=get_const_tuple(C.shape), dtype=C.dtype), + dev, + ) + + f(input_tvm, filter_tvm, scale_tvm, shift_tvm, output_tvm) + tvm.testing.assert_allclose(output_np, output_tvm.numpy(), **tol) + + +class TestDepthwiseConv2D(BaseDepthwiseConv2D): + """Test variety of parameters, defined in BaseDepthwiseConv2D. Also + has llvm-specific tests for workload padding.""" + + @tvm.testing.parametrize_targets("llvm") + def test_workload_padding( + self, + out_dtype, + layout, + input_shape, + filter_shape, + target, + ref_data, + stride, + padding, + dilation, + ): + input_np, filter_np, scale_np, shift_np, output_np = ref_data + if layout == "NCHW": + _, _, out_height, out_width = output_np.shape + elif layout == "NHWC": + _, out_height, out_width, _ = output_np.shape + elif layout == "NCHWc": + _, _, out_height, out_width, _ = output_np.shape + + Input = te.placeholder(input_shape, name="Input") + Filter = te.placeholder(filter_shape, name="Filter") + wkl = _get_workload(Input, Filter, (stride, stride), padding, dilation, out_dtype, layout) + + # check if tile_ow candidates are the factors of the right output weight. + with tvm.target.Target(target): + cfg = autotvm.get_config() + _fallback_schedule(cfg, wkl) + ow_tile = np.prod(cfg["tile_ow"].size) + + tvm.testing.assert_allclose(ow_tile, out_width) + + +class TestDepthwiseConv2D_MobilenetWorkloads(BaseDepthwiseConv2D): + """Extra tests to verify functionality for workloads used by mobilenet.""" + + layout = tvm.testing.parameter("NCHW") + + batch = tvm.testing.parameter(1) + channel_multiplier = tvm.testing.parameter(1) + kernel = tvm.testing.parameter(3) + padding = tvm.testing.parameter("SAME") + dilation = tvm.testing.parameter(1) + + in_channel, in_size, stride = tvm.testing.parameters( + (32, 112, 1), + (64, 112, 2), + (128, 56, 1), + (128, 56, 2), + (256, 28, 1), + (256, 28, 2), + (512, 14, 1), + (512, 14, 2), + (1024, 7, 1), + ) + + @tvm.testing.parametrize_targets("llvm") -def test_depthwise_conv2d_nchwc(target, dev): - # NCHW[x]c - depthwise_conv2d_with_workload_NCHWc(target, dev, 1, 728, 32, 1, 3, 1, "SAME", dilation=2) - depthwise_conv2d_with_workload_NCHWc(target, dev, 1, 728, 32, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_NCHWc(target, dev, 1, 728, 32, 1, 3, 1, "VALID") +class TestDepthwiseConv2D_NCHWc(BaseDepthwiseConv2D): + """Tests specific to NCHWc layouts. + + Once the implementation supports channel_multiplier>1 and GPU + devices, this class can be merged into TestDepthwiseConv2D. + """ + + # depthwise_conv2d_NCHWc currently does not support channel multiplier > 1 + layout = tvm.testing.parameter("NCHWc") + (batch, in_channel, in_size, channel_multiplier, kernel, stride) = tvm.testing.parameters( + (1, 728, 32, 1, 3, 1), + ) + + +@tvm.testing.parametrize_targets("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu") +class TestDepthwiseConv2DArmCompile(BaseDepthwiseConv2D): + """Compile-only tests for cross-compiling to ARM.""" + layout = tvm.testing.parameter("NHWC", "NCHW") + batch = tvm.testing.parameter(1) + dilation = tvm.testing.parameter(1) + in_dtype, out_dtype = tvm.testing.parameters(("int16", "int32")) + in_channel = tvm.testing.parameter(728) + in_size = tvm.testing.parameter(32) + kernel = tvm.testing.parameter(1) + channel_multiplier = tvm.testing.parameter(1, 3) + stride = tvm.testing.parameter(1) + padding = tvm.testing.parameter("SAME") + use_scale_shift = tvm.testing.parameter(True, False, ids=["with_scale_shift", "no_scale_shift"]) -def test_depthwise_conv2d_arm(): - # Test compilation on arm targets - compile_depthwise_NHWC_int8_arm(1, 728, 32, 1, 3, 1, "SAME") - compile_depthwise_NHWC_int8_arm(1, 728, 32, 1, 1, 1, "SAME", True) + run_after_compile = False if __name__ == "__main__": - test_depthwise_conv2d() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_loss.py b/tests/python/topi/python/test_topi_loss.py index c2db8bebc0f5..bb7655b192f5 100644 --- a/tests/python/topi/python/test_topi_loss.py +++ b/tests/python/topi/python/test_topi_loss.py @@ -52,7 +52,7 @@ def verify_nll_loss( weights_nd = tvm.nd.array(weights_npy, dev) out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(nll_loss_result.dtype), dev) fn(predictions_nd, targets_nd, weights_nd, out_nd) - out_topi = out_nd.asnumpy() + out_topi = out_nd.numpy() tvm.testing.assert_allclose(out_topi, out_npy, rtol=1e-4, atol=1e-5) diff --git a/tests/python/topi/python/test_topi_pooling.py b/tests/python/topi/python/test_topi_pooling.py index 57877e3d202c..b7f11de1391f 100644 --- a/tests/python/topi/python/test_topi_pooling.py +++ b/tests/python/topi/python/test_topi_pooling.py @@ -315,6 +315,7 @@ def verify_poolnd( pool_type, count_include_pad, ceil_mode, + layout=layout, ) np.testing.assert_equal(tuple(output_shape), tuple(ref_np.shape)) @@ -355,7 +356,7 @@ def verify_pool3d( padding, pool_type, ceil_mode, - layout="NCDHW", + layout=layout, count_include_pad=count_include_pad, ) @@ -363,18 +364,106 @@ def verify_pool3d( @tvm.testing.uses_gpu def test_pool3d(): """test cases of pool3d""" + verify_pool3d( + [1, 16, 32, 32, 32], [2, 2, 2], [2, 2, 2], [1, 1, 1], [0, 0, 0, 0, 0, 0], "avg", False, True + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [1, 1, 1], [1, 1, 2, 2, 2, 1], "avg", False, True + ) + verify_pool3d( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [1, 1, 2, 2, 2, 1], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [4, 4, 4], + [4, 4, 4], + [1, 1, 1], + [3, 3, 3, 3, 3, 3], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [4, 4, 4], + [4, 4, 4], + [1, 1, 1], + [0, 0, 0, 0, 0, 0], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 32, 32, 32], [2, 2, 2], [2, 2, 2], [1, 1, 1], [0, 0, 0, 0, 0, 0], "max", False + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [1, 1, 1], [2, 2, 1, 1, 1, 2], "max", False + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [1, 1, 1], [2, 2, 1, 1, 1, 2], "max", True + ) + + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [1, 1, 1], [2, 1, 0, 5, 4, 3], "avg", False, True + ) + verify_pool3d( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [0, 5, 4, 3, 2, 1], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [1, 1, 1], [1, 0, 5, 4, 3, 2], "max", False + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [1, 1, 1], [3, 2, 1, 0, 5, 4], "max", True + ) + + # Test non-1 dilation + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [3, 3, 3], [2, 1, 0, 5, 4, 3], "avg", False, True + ) verify_pool3d( [1, 16, 32, 32, 32], [2, 2, 2], [2, 2, 2], + [2, 2, 2], + [0, 5, 4, 3, 2, 1], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [2, 1, 3], [1, 0, 5, 4, 3, 2], "max", False + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [2, 2, 3], [3, 2, 1, 0, 5, 4], "max", True + ) + # Test channel last layouts + verify_pool3d( + [1, 32, 32, 32, 16], + [2, 2, 2], + [2, 2, 2], [1, 1, 1], [0, 0, 0, 0, 0, 0], "avg", False, True, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [3, 3, 3], [3, 3, 3], [1, 1, 1], @@ -382,9 +471,10 @@ def test_pool3d(): "avg", False, True, + layout="NDHWC", ) verify_pool3d( - [1, 16, 32, 32, 32], + [1, 32, 32, 32, 16], [2, 2, 2], [2, 2, 2], [1, 1, 1], @@ -392,9 +482,10 @@ def test_pool3d(): "avg", False, False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [4, 4, 4], [4, 4, 4], [1, 1, 1], @@ -402,9 +493,10 @@ def test_pool3d(): "avg", False, False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [4, 4, 4], [4, 4, 4], [1, 1, 1], @@ -412,37 +504,41 @@ def test_pool3d(): "avg", False, False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 32, 32, 32], + [1, 32, 32, 32, 16], [2, 2, 2], [2, 2, 2], [1, 1, 1], [0, 0, 0, 0, 0, 0], "max", False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [3, 3, 3], [3, 3, 3], [1, 1, 1], [2, 2, 1, 1, 1, 2], "max", False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [3, 3, 3], [3, 3, 3], [1, 1, 1], [2, 2, 1, 1, 1, 2], "max", True, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [3, 3, 3], [3, 3, 3], [1, 1, 1], @@ -450,9 +546,10 @@ def test_pool3d(): "avg", False, True, + layout="NDHWC", ) verify_pool3d( - [1, 16, 32, 32, 32], + [1, 32, 32, 32, 16], [2, 2, 2], [2, 2, 2], [1, 1, 1], @@ -460,36 +557,32 @@ def test_pool3d(): "avg", False, False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [3, 3, 3], [3, 3, 3], [1, 1, 1], [1, 0, 5, 4, 3, 2], "max", False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [3, 3, 3], [3, 3, 3], [1, 1, 1], [3, 2, 1, 0, 5, 4], "max", True, + layout="NDHWC", ) # Test non-1 dilation verify_pool3d( - [1, 16, 31, 31, 31], - [3, 3, 3], - [3, 3, 3], - [3, 3, 3], - [2, 1, 0, 5, 4, 3], - "avg", - False, - True, + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [3, 3, 3], [2, 1, 0, 5, 4, 3], "avg", False, True ) verify_pool3d( [1, 16, 32, 32, 32], @@ -502,27 +595,23 @@ def test_pool3d(): False, ) verify_pool3d( - [1, 16, 31, 31, 31], - [3, 3, 3], - [3, 3, 3], - [2, 1, 3], - [1, 0, 5, 4, 3, 2], - "max", - False, + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [2, 1, 3], [1, 0, 5, 4, 3, 2], "max", False ) verify_pool3d( - [1, 16, 31, 31, 31], - [3, 3, 3], - [3, 3, 3], - [2, 2, 3], - [3, 2, 1, 0, 5, 4], - "max", - True, + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [2, 2, 3], [3, 2, 1, 0, 5, 4], "max", True ) def verify_pool2d( - input_shape, kernel, stride, dilation, padding, pool_type, ceil_mode, count_include_pad=True + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad=True, + layout="NCHW", ): verify_poolnd( 2, @@ -533,7 +622,7 @@ def verify_pool2d( padding, pool_type, ceil_mode, - layout="NCHW", + layout=layout, count_include_pad=count_include_pad, ) @@ -541,162 +630,69 @@ def verify_pool2d( @tvm.testing.uses_gpu def test_pool2d(): """test cases of pool""" + verify_pool2d([1, 16, 32, 32], [2, 2], [2, 2], [1, 1], [0, 0, 0, 0], "avg", False, True) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [1, 2, 1, 2], "avg", False, True) + verify_pool2d([1, 16, 32, 32], [2, 2], [2, 2], [1, 1], [1, 2, 1, 2], "avg", False, False) + verify_pool2d([1, 16, 31, 31], [4, 4], [4, 4], [1, 1], [3, 3, 3, 3], "avg", False, False) + verify_pool2d([1, 16, 31, 31], [4, 4], [4, 4], [1, 1], [0, 0, 0, 0], "avg", False, False) + verify_pool2d([1, 16, 32, 32], [2, 3], [2, 2], [1, 1], [0, 0, 0, 0], "max", False) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", False) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", True) + + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [2, 1, 0, 3], "avg", False, True) + verify_pool2d([1, 16, 32, 32], [2, 3], [2, 2], [1, 1], [0, 3, 2, 1], "avg", False, False) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [1, 0, 3, 2], "max", False) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [3, 2, 1, 0], "max", True) + + # Test non-1 dilations + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [2, 1], [2, 1, 0, 3], "avg", False, True) + verify_pool2d([1, 16, 32, 32], [2, 3], [2, 2], [2, 3], [0, 3, 2, 1], "avg", False, False) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [3, 3], [1, 0, 3, 2], "max", False) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [2, 2], [3, 2, 1, 0], "max", True) + # Test channel last verify_pool2d( - [1, 16, 32, 32], - [2, 2], - [2, 2], - [1, 1], - [0, 0, 0, 0], - "avg", - False, - True, - ) - verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [1, 1], - [1, 2, 1, 2], - "avg", - False, - True, + [1, 32, 32, 16], [2, 2], [2, 2], [1, 1], [0, 0, 0, 0], "avg", False, True, layout="NHWC" ) verify_pool2d( - [1, 16, 32, 32], - [2, 2], - [2, 2], - [1, 1], - [1, 2, 1, 2], - "avg", - False, - False, + [1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [1, 2, 1, 2], "avg", False, True, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [4, 4], - [4, 4], - [1, 1], - [3, 3, 3, 3], - "avg", - False, - False, + [1, 32, 32, 16], [2, 2], [2, 2], [1, 1], [1, 2, 1, 2], "avg", False, False, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [4, 4], - [4, 4], - [1, 1], - [0, 0, 0, 0], - "avg", - False, - False, + [1, 31, 31, 16], [4, 4], [4, 4], [1, 1], [3, 3, 3, 3], "avg", False, False, layout="NHWC" ) verify_pool2d( - [1, 16, 32, 32], - [2, 3], - [2, 2], - [1, 1], - [0, 0, 0, 0], - "max", - False, + [1, 31, 31, 16], [4, 4], [4, 4], [1, 1], [0, 0, 0, 0], "avg", False, False, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [1, 1], - [2, 1, 2, 1], - "max", - False, + [1, 32, 32, 16], [2, 3], [2, 2], [1, 1], [0, 0, 0, 0], "max", False, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [1, 1], - [2, 1, 2, 1], - "max", - True, + [1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", False, layout="NHWC" ) + verify_pool2d([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", True, layout="NHWC") verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [1, 1], - [2, 1, 0, 3], - "avg", - False, - True, - ) - verify_pool2d( - [1, 16, 32, 32], - [2, 3], - [2, 2], - [1, 1], - [0, 3, 2, 1], - "avg", - False, - False, - ) - verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [1, 1], - [1, 0, 3, 2], - "max", - False, + [1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [2, 1, 0, 3], "avg", False, True, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [1, 1], - [3, 2, 1, 0], - "max", - True, + [1, 32, 32, 16], [2, 3], [2, 2], [1, 1], [0, 3, 2, 1], "avg", False, False, layout="NHWC" ) - - # Test non-1 dilations verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [2, 1], - [2, 1, 0, 3], - "avg", - False, - True, + [1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [1, 0, 3, 2], "max", False, layout="NHWC" ) + verify_pool2d([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [3, 2, 1, 0], "max", True, layout="NHWC") verify_pool2d( - [1, 16, 32, 32], - [2, 3], - [2, 2], - [2, 3], - [0, 3, 2, 1], - "avg", - False, - False, + [1, 31, 31, 16], [3, 3], [3, 3], [2, 1], [2, 1, 0, 3], "avg", False, True, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [3, 3], - [1, 0, 3, 2], - "max", - False, + [1, 32, 32, 16], [2, 3], [2, 2], [2, 3], [0, 3, 2, 1], "avg", False, False, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [2, 2], - [3, 2, 1, 0], - "max", - True, + [1, 31, 31, 16], [3, 3], [3, 3], [3, 3], [1, 0, 3, 2], "max", False, layout="NHWC" ) + verify_pool2d([1, 31, 31, 16], [3, 3], [3, 3], [2, 2], [3, 2, 1, 0], "max", True, layout="NHWC") def verify_pool1d( @@ -719,7 +715,7 @@ def verify_pool1d( padding, pool_type, ceil_mode, - layout="NCW", + layout=layout, count_include_pad=count_include_pad, ) @@ -727,162 +723,43 @@ def verify_pool1d( @tvm.testing.uses_gpu def test_pool1d(): """test cases of pool1d""" - verify_pool1d( - [1, 16, 32], - [2], - [2], - [1], - [0, 0], - "avg", - False, - True, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [1], - [1, 2], - "avg", - False, - True, - ) - verify_pool1d( - [1, 16, 32], - [2], - [2], - [1], - [1, 2], - "avg", - False, - False, - ) - verify_pool1d( - [1, 16, 31], - [4], - [4], - [1], - [3, 3], - "avg", - False, - False, - ) - verify_pool1d( - [1, 16, 31], - [4], - [4], - [1], - [0, 0], - "avg", - False, - False, - ) - verify_pool1d( - [1, 16, 32], - [2], - [2], - [1], - [0, 0], - "max", - False, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [1], - [2, 1], - "max", - False, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [1], - [2, 1], - "max", - True, - ) - - verify_pool1d( - [1, 16, 31], - [3], - [3], - [1], - [2, 5], - "avg", - False, - True, - ) - verify_pool1d( - [1, 16, 32], - [2], - [2], - [1], - [0, 3], - "avg", - False, - False, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [1], - [1, 4], - "max", - False, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [1], - [3, 0], - "max", - True, - ) + verify_pool1d([1, 16, 32], [2], [2], [1], [0, 0], "avg", False, True) + verify_pool1d([1, 16, 31], [3], [3], [1], [1, 2], "avg", False, True) + verify_pool1d([1, 16, 32], [2], [2], [1], [1, 2], "avg", False, False) + verify_pool1d([1, 16, 31], [4], [4], [1], [3, 3], "avg", False, False) + verify_pool1d([1, 16, 31], [4], [4], [1], [0, 0], "avg", False, False) + verify_pool1d([1, 16, 32], [2], [2], [1], [0, 0], "max", False) + verify_pool1d([1, 16, 31], [3], [3], [1], [2, 1], "max", False) + verify_pool1d([1, 16, 31], [3], [3], [1], [2, 1], "max", True) + + verify_pool1d([1, 16, 31], [3], [3], [1], [2, 5], "avg", False, True) + verify_pool1d([1, 16, 32], [2], [2], [1], [0, 3], "avg", False, False) + verify_pool1d([1, 16, 31], [3], [3], [1], [1, 4], "max", False) + verify_pool1d([1, 16, 31], [3], [3], [1], [3, 0], "max", True) # Test non-1 dilations - verify_pool1d( - [1, 16, 31], - [3], - [3], - [2], - [2, 5], - "avg", - False, - True, - ) - verify_pool1d( - [1, 16, 32], - [2], - [2], - [3], - [0, 3], - "avg", - False, - False, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [2], - [1, 4], - "max", - False, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [3], - [3, 0], - "max", - True, - ) + verify_pool1d([1, 16, 31], [3], [3], [2], [2, 5], "avg", False, True) + verify_pool1d([1, 16, 32], [2], [2], [3], [0, 3], "avg", False, False) + verify_pool1d([1, 16, 31], [3], [3], [2], [1, 4], "max", False) + verify_pool1d([1, 16, 31], [3], [3], [3], [3, 0], "max", True) + # Test Channel last + verify_pool1d([1, 32, 16], [2], [2], [1], [0, 0], "avg", False, True, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [1], [1, 2], "avg", False, True, layout="NWC") + verify_pool1d([1, 32, 16], [2], [2], [1], [1, 2], "avg", False, False, layout="NWC") + verify_pool1d([1, 31, 16], [4], [4], [1], [3, 3], "avg", False, False, layout="NWC") + verify_pool1d([1, 31, 16], [4], [4], [1], [0, 0], "avg", False, False, layout="NWC") + verify_pool1d([1, 32, 16], [2], [2], [1], [0, 0], "max", False, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [1], [2, 1], "max", False, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [1], [2, 1], "max", True, layout="NWC") + + verify_pool1d([1, 31, 16], [3], [3], [1], [2, 5], "avg", False, True, layout="NWC") + verify_pool1d([1, 31, 16], [2], [2], [1], [0, 3], "avg", False, False, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [1], [1, 4], "max", False, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [1], [3, 0], "max", True, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [2], [2, 5], "avg", False, True, layout="NWC") + verify_pool1d([1, 32, 16], [2], [2], [3], [0, 3], "avg", False, False, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [2], [1, 4], "max", False, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [3], [3, 0], "max", True, layout="NWC") if __name__ == "__main__": diff --git a/tests/python/topi/python/test_topi_prng.py b/tests/python/topi/python/test_topi_prng.py index b9ac51419772..60ef7b3b234c 100644 --- a/tests/python/topi/python/test_topi_prng.py +++ b/tests/python/topi/python/test_topi_prng.py @@ -57,7 +57,7 @@ def uniform(target, dev, gen, low, high, size, dtype): out_gen = tvm.nd.array(np.zeros(gen.shape, dtype="uint64")) rands = tvm.nd.array(np.zeros(size, dtype=dtype)) f(tvm.nd.array(gen), tvm.nd.array(low), tvm.nd.array(high), out_gen, rands) - return out_gen.asnumpy(), rands.asnumpy() + return out_gen.numpy(), rands.asnumpy() @tvm.testing.parametrize_targets @@ -143,7 +143,7 @@ def test_threefry_wrapping(target, dev): @tvm.testing.parametrize_targets def test_uniform(target, dev): - gen = tvm.relay.random.threefry_key(0).data.asnumpy() + gen = tvm.relay.random.threefry_key(0).data.numpy() m = 1024 n = 1024 dtypes = ["float32", "float64"] diff --git a/tests/python/topi/python/test_topi_relu.py b/tests/python/topi/python/test_topi_relu.py index 83007e16f81d..509d09781fa8 100644 --- a/tests/python/topi/python/test_topi_relu.py +++ b/tests/python/topi/python/test_topi_relu.py @@ -54,7 +54,7 @@ def test_relu(target, dev, m, n, dtype): b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) foo = tvm.build(s, [A, B], target, name="relu") foo(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) size, alpha = tvm.testing.parameters((100, 0.1)) diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py index 003b89f7122a..11006576fea3 100644 --- a/tests/python/topi/python/test_topi_sparse.py +++ b/tests/python/topi/python/test_topi_sparse.py @@ -35,21 +35,20 @@ } -def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True): +def verify_dynamic_csrmv(batch, in_dim, out_dim, dtype, use_bias=True): nr, nc, n = te.var("nr"), te.var("nc"), te.var("n") - dtype = "float32" A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name="A") - B = te.placeholder((in_dim, 1), name="B") - C = te.placeholder((nr,), name="C") + B = te.placeholder((in_dim, 1), dtype=dtype, name="B") + C = te.placeholder((nr,), dtype=dtype, name="C") D = topi.sparse.csrmv(A, B, C if use_bias else None) s = te.create_schedule(D.op) dtype = A.dtype # get the test data def get_ref_data(): - a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype) - 0.5, 0.0) - b_np = np.random.uniform(size=(in_dim, 1)).astype(dtype) - 0.5 - c_np = np.random.uniform(size=(batch,)).astype(dtype) + a_np = np.random.uniform(size=(batch, in_dim), high=100).astype(dtype) + b_np = np.random.uniform(size=(in_dim, 1), high=100).astype(dtype) + c_np = np.random.uniform(size=(batch,), high=100).astype(dtype) if use_bias: d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1)) else: @@ -81,21 +80,20 @@ def check_device(device): check_device(device) -def verify_dynamic_csrmm(batch, in_dim, out_dim, use_bias=True): +def verify_dynamic_csrmm(batch, in_dim, out_dim, dtype, use_bias=True): nr, nc, n = te.var("nr"), te.var("nc"), te.var("n") - dtype = "float32" A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name="A") - B = te.placeholder((in_dim, out_dim), name="B") - C = te.placeholder((nr,), name="C") + B = te.placeholder((in_dim, out_dim), dtype=dtype, name="B") + C = te.placeholder((nr,), dtype=dtype, name="C") D = topi.sparse.csrmm(A, B, C if use_bias else None) s = te.create_schedule(D.op) dtype = A.dtype # get the test data def get_ref_data(): - a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype) - 0.5, 0.0) - b_np = np.random.uniform(size=(in_dim, out_dim)).astype(dtype) - 0.5 - c_np = np.random.uniform(size=(batch,)).astype(dtype) + a_np = np.random.uniform(size=(batch, in_dim), high=100).astype(dtype) + b_np = np.random.uniform(size=(in_dim, out_dim), high=100).astype(dtype) + c_np = np.random.uniform(size=(batch,), high=100).astype(dtype) if use_bias: d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1)) else: @@ -212,14 +210,15 @@ def check_device(device): def test_csrmv(): - verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=False) - verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=True) + verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, dtype="float32", use_bias=False) + verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, dtype="float64", use_bias=True) + verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, dtype="int32", use_bias=True) def test_csrmm(): M, K, N = 5, 7, 2 - verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=False) - verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=True) + verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, dtype="int64", use_bias=False) + verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, dtype="float64", use_bias=True) def test_dense_si(): diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index ddde2e20e754..42d2463b8952 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -879,24 +879,24 @@ def test_transpose_unfused_schedule(target, dev): shape = (100, tvm.target.Target(target).thread_warp_size + 3) x = relay.var("x", relay.TensorType(shape, "float32")) f = relay.transpose(x) - ex = relay.create_executor( - kind="graph", mod=tvm.IRModule.from_expr(relay.Function([x], f)), device=dev, target=target - ) r = np.random.rand(*shape) - tvm.testing.assert_allclose(ex.evaluate()(r).asnumpy(), np.transpose(r)) + func = relay.create_executor( + kind="graph", mod=tvm.IRModule.from_expr(relay.Function([x], f)), device=dev, target=target + ).evaluate() + tvm.testing.assert_allclose(func(r).numpy(), np.transpose(r)) # We want to make sure schedule does not fire here, but there is no way of # inspecting which schedules were used. x = relay.var("x", relay.TensorType(shape, "float32")) y = relay.var("y", relay.TensorType(shape, "float32")) f = relay.transpose(x + y) - ex = relay.create_executor( + func = relay.create_executor( kind="graph", mod=tvm.IRModule.from_expr(relay.Function([x, y], f)), device=dev, target=target, - ) - tvm.testing.assert_allclose(ex.evaluate()(r, r).asnumpy(), np.transpose(r + r)) + ).evaluate() + tvm.testing.assert_allclose(func(r, r).numpy(), np.transpose(r + r)) @tvm.testing.uses_gpu diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index b34acb9ae359..c307034c04c9 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -285,6 +285,15 @@ def test_predicate(): ) assert len(res) == 0 + # zero iter + xo = tvm.tir.Var("xo", "int32"), 1 + xi = tvm.tir.Var("xi", "int32"), 129 + y = tvm.tir.Var("y", "int32"), 128 + + res = tvm.arith.detect_iter_map( + [xo[0] * 129 + xi[0], y[0]], var_dom([xo, xi, y]), xo[0] * 129 + xi[0] < 128 + ) + def convert_division(divisions): if divisions is None or len(divisions) == 0: diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index c3afa6c65627..641eed51d5cf 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -101,15 +101,16 @@ def test_vector_simplify(): ck.verify( fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), - ) + ) # Example negative case: x = 15; [60, 61, 62, 63, 64] / 64 = [0, 0, 0, 0, 1] ck.verify( fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), - ) + ) # Example negative case: x = 15; [63, 64, 65, 66] % 64 = [0, 1, 1, 1] ck.verify( fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), - ) + ) # Example negative case: x = 9; [63, 70, 77, 84] % 64 = [0, 1, 1, 1] + # floor mod ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2")) ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(flm(x, 2), 4)) @@ -136,16 +137,21 @@ def test_vector_simplify(): flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 8, 64), 2, 4) ) ck.verify( - flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), tvm.tir.Ramp(flm(x * 4, 64), 1, 5) - ) + flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), + flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), + ) # Example negative case: x = 15; [60, 61, 62, 63, 64] % 64 = [60, 61, 62, 63, 0] ck.verify( flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), - tvm.tir.Ramp(flm(x * 4 + 3, 64), 1, 4), - ) + flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), + ) # Example negative case: x = 15; [63, 64, 65, 66] % 64 = [63, 0, 1, 2] + ck.verify( + flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)), + flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)), + ) # Example negative case: x = 9; [18, 19, 20, ..., 25] % 20 = [18, 19, 0, 1, ..., 5] ck.verify( flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), - ) + ) # Example negative case: x = 9; [63, 70, 77, 84] % 64 = [63, 6, 13, 20] # Min/Max rules vx = te.var("vx", dtype="int32x2") @@ -275,6 +281,7 @@ def test_add_index_simplify(): def test_sub_index_simplify(): ck = RewriteChecker() x, y, z = te.var("x"), te.var("y"), te.var("z") + a, b = tvm.tir.Any(), tvm.tir.Any() ck.verify(x + y - y, x) ck.verify(x + y - x, y) @@ -293,6 +300,8 @@ def test_sub_index_simplify(): # mul co-efficient foldng ck.verify(x - x, 0) + ck.verify(a - a, 0) + ck.verify(a - b, a - b) ck.verify(x * y - x, x * (y + (-1))) ck.verify(x * y - 10 * x, x * (y + (-10))) ck.verify(y * x - x * z, x * (y - z)) diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py index b303ef56c1d2..81ee5cabbfbc 100644 --- a/tests/python/unittest/test_auto_scheduler_compute_dag.py +++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py @@ -23,7 +23,7 @@ from tvm import topi from tvm import auto_scheduler, te -from test_auto_scheduler_common import ( +from tvm.testing.auto_scheduler import ( get_tiled_matmul, invalid_compute_definition, matmul_auto_scheduler_test, @@ -62,6 +62,11 @@ def test_estimate_flop(): dag = auto_scheduler.ComputeDAG([A, B, F]) assert abs(dag.flop_ct - (2 * N ** 3 + 1234)) < 0.5 + A = te.placeholder((N, N), dtype="float32", name="A") + F = te.compute((N, N), lambda i, j: te.if_then_else(A[i, j] > 0, A[i, j], 0)) + dag = auto_scheduler.ComputeDAG([A, F]) + assert abs(dag.flop_ct - N ** 2) < 0.5 + def test_stage_order(): """Test if the stage order is preserved when recovering a DAG.""" diff --git a/tests/python/unittest/test_auto_scheduler_cost_model.py b/tests/python/unittest/test_auto_scheduler_cost_model.py index 0b34615583db..50e3ceb6f5fa 100644 --- a/tests/python/unittest/test_auto_scheduler_cost_model.py +++ b/tests/python/unittest/test_auto_scheduler_cost_model.py @@ -24,7 +24,7 @@ import tvm from tvm import auto_scheduler -from test_auto_scheduler_common import matmul_auto_scheduler_test +from tvm.testing.auto_scheduler import matmul_auto_scheduler_test def get_sample_records(number): diff --git a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py index e28219d0979f..b5c99c0f05fd 100644 --- a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py +++ b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py @@ -18,7 +18,7 @@ import tvm import pytest -from test_auto_scheduler_common import matmul_auto_scheduler_test +from tvm.testing.auto_scheduler import matmul_auto_scheduler_test from tvm import auto_scheduler, te from tvm.auto_scheduler.cost_model.cost_model import PythonBasedModel diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py index 82cfb1d6508b..96090e328328 100644 --- a/tests/python/unittest/test_auto_scheduler_feature.py +++ b/tests/python/unittest/test_auto_scheduler_feature.py @@ -23,7 +23,7 @@ import tvm from tvm import te, auto_scheduler -from test_auto_scheduler_common import matmul_auto_scheduler_test +from tvm.testing.auto_scheduler import matmul_auto_scheduler_test def fequal(a, b): diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py index c9291965613b..39673fad2495 100644 --- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py +++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py @@ -26,7 +26,7 @@ from tvm import topi from tvm import auto_scheduler, te -from test_auto_scheduler_common import get_tiled_matmul, matmul_auto_scheduler_test +from tvm.testing.auto_scheduler import get_tiled_matmul, matmul_auto_scheduler_test def test_apply_steps_with_layout_rewrite(): diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 44ed1fc42562..0965ed9efbac 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -23,7 +23,7 @@ from tvm import auto_scheduler, te from tvm import topi -from test_auto_scheduler_common import ( +from tvm.testing.auto_scheduler import ( matmul_auto_scheduler_test, conv2d_nchw_bn_relu_auto_scheduler_test, ) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index d82cfd447a40..9eae3dd33672 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -26,7 +26,7 @@ import tempfile import tvm.testing import pickle -from test_auto_scheduler_common import matmul_auto_scheduler_test +from tvm.testing.auto_scheduler import matmul_auto_scheduler_test from tvm.auto_scheduler import workload_registry @@ -293,6 +293,15 @@ def test_dag_measure_local_builder_runner(): assert mress[0].error_no == 0 +def test_workload_serialization(): + key = tvm.auto_scheduler.utils.get_func_name(matmul_auto_scheduler_test) + transfer_data = workload_registry.serialize_workload_registry_entry(key) + f_data = pickle.dumps(transfer_data) + f_new = pickle.loads(f_data) + del workload_registry.WORKLOAD_FUNC_REGISTRY[key] + workload_registry.deserialize_workload_registry_entry(f_new) + + def test_measure_local_builder_rpc_runner(): if not tvm.testing.device_enabled("llvm"): return @@ -423,6 +432,7 @@ def foo(): test_workload_dis_factor() test_measure_local_builder_runner() test_dag_measure_local_builder_runner() + test_workload_serialization() test_measure_local_builder_rpc_runner() test_measure_target_host() test_measure_special_inputs_map_by_name_local_runner() diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index d114ce4f9d16..a9f6596a8548 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -27,7 +27,7 @@ from tvm import auto_scheduler from tvm.auto_scheduler.utils import get_const_tuple -from test_auto_scheduler_common import ( +from tvm.testing.auto_scheduler import ( matmul_auto_scheduler_test, zero_rank_compute_auto_scheduler_test, zero_rank_reduce_auto_scheduler_test, diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py index cd47f1e468ff..f23b02c24298 100644 --- a/tests/python/unittest/test_auto_scheduler_search_task.py +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -24,7 +24,7 @@ import tvm.testing from tvm import auto_scheduler from tvm.auto_scheduler.utils import get_const_tuple -from test_auto_scheduler_common import ( +from tvm.testing.auto_scheduler import ( matmul_auto_scheduler_test, zero_rank_compute_auto_scheduler_test, zero_rank_reduce_auto_scheduler_test, diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py index 4092ae0b0500..6d2f870ca14d 100644 --- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py +++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py @@ -27,7 +27,7 @@ from tvm.auto_scheduler import _ffi_api from tvm.auto_scheduler.loop_state import Stage -from test_auto_scheduler_common import ( +from tvm.testing.auto_scheduler import ( matmul_auto_scheduler_test, double_matmul_auto_scheduler_test, conv2d_nchw_bn_relu_auto_scheduler_test, diff --git a/tests/python/unittest/test_auto_scheduler_task_scheduler.py b/tests/python/unittest/test_auto_scheduler_task_scheduler.py index bbe29b1ba4f9..a3f356929dd1 100644 --- a/tests/python/unittest/test_auto_scheduler_task_scheduler.py +++ b/tests/python/unittest/test_auto_scheduler_task_scheduler.py @@ -25,7 +25,7 @@ import tvm.testing from tvm import auto_scheduler -from test_auto_scheduler_common import matmul_auto_scheduler_test +from tvm.testing.auto_scheduler import matmul_auto_scheduler_test @tvm.testing.requires_llvm diff --git a/tests/python/unittest/test_autotvm_graph_tuner_core.py b/tests/python/unittest/test_autotvm_graph_tuner_core.py index 3d7d304f13d5..bcc43648de22 100644 --- a/tests/python/unittest/test_autotvm_graph_tuner_core.py +++ b/tests/python/unittest/test_autotvm_graph_tuner_core.py @@ -188,6 +188,49 @@ def test_graph_tuner_layout_transform(): ) +def test_graph_tuner_layout_transform_runner(): + log_file = "%s/test_tuner.log" % (os.getcwd()) + target = "llvm" + dshape = (1, 3, 8, 8) + dtype = "float32" + layout = "NCHW" + conv2d = relay.op.get("nn.conv2d") + target_ops = [conv2d] + + g, records, ltf_records, ltf_keys, _ = _create_data(target, dshape, dtype, layout) + executor = DPTuner(g, {"data": dshape}, records, target_ops, target=target, log_file=log_file) + runner = autotvm.LocalRunner(number=100, repeat=1, timeout=10) + executor.benchmark_layout_transform( + layout_records=ltf_records, infer_layout=True, runner=runner + ) + out = executor._layout_transform_perf_records + + num_flops = 0 + total_time = 0 + for record in ltf_records: + ltf_wkl = record[0].task.workload + input_shape = ltf_wkl[1][1] + flops = np.prod(input_shape) + num_flops += flops + total_time += record[1].costs[0] + avg_time = total_time / num_flops + + for ltf_workload in out: + input_shape = ltf_workload[1][1] + flops = 1 + for i in input_shape: + flops *= i + expected_time = flops * avg_time + out_time = out[ltf_workload][1].costs[0] + assert ( + expected_time == out_time + ), "Inferred layout transformation time mismatch for %s: " "expecting %f but got %f" % ( + str(ltf_workload), + expected_time, + out_time, + ) + + def test_DPTuner_run(): log_file = "%s/test_tuner.log" % (os.getcwd()) target = "llvm" diff --git a/tests/python/unittest/test_autotvm_measure.py b/tests/python/unittest/test_autotvm_measure.py index 9db9f18fa377..a89c69c37d64 100644 --- a/tests/python/unittest/test_autotvm_measure.py +++ b/tests/python/unittest/test_autotvm_measure.py @@ -26,6 +26,8 @@ from test_autotvm_common import DummyRunner, bad_matmul, get_sample_task from tvm import autotvm from tvm.autotvm.measure.measure import MeasureErrorNo, MeasureResult +from tvm.autotvm import measure +from inspect import Signature def test_task_tuner_without_measurement(): @@ -60,8 +62,30 @@ def test_task_tuner_without_measurement_spawn(): p.join() +def test_task_runner_with_ref_input(): + """test runner ref_input without measurement""" + refinp = [np.random.rand(128, 128) for i in range(3)] + runner = measure.LocalRunner() + runner.ref_input = refinp + + class DummyExecutor(measure.executor.Executor): + def __init__(self): + self.ran_dummy_executor = False + + def submit(self, func, *args, **kwargs): + self.ran_dummy_executor = True + sig = Signature.from_callable(func) + assert sig.bind(*args, **kwargs).arguments["ref_input"] == refinp + return measure.local_executor.LocalFutureNoFork(None) + + runner.executor = DummyExecutor() + runner.run([None], [None]) + assert runner.executor.ran_dummy_executor + + if __name__ == "__main__": logging.basicConfig(level=logging.INFO) test_task_tuner_without_measurement() test_task_tuner_without_measurement_spawn() + test_task_runner_with_ref_input() diff --git a/tests/python/unittest/test_autotvm_space.py b/tests/python/unittest/test_autotvm_space.py index 2d40371b3ef6..d56ca9e07214 100644 --- a/tests/python/unittest/test_autotvm_space.py +++ b/tests/python/unittest/test_autotvm_space.py @@ -84,6 +84,8 @@ def count4(n): cfg = FallbackConfigEntity() cfg.define_split("tile_n", cfg.axis(128), num_outputs=3) cfg.fallback_split("tile_n", [-1, 8, 4]) + # verify if define_split override previously manualy defined split params + cfg.define_split("tile_n", cfg.axis(128), num_outputs=3) assert cfg["tile_n"].size == [4, 8, 4] cfg = FallbackConfigEntity() diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 3ba508a40a77..586e9fbfb91e 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -19,7 +19,9 @@ import copy import glob import os +import pathlib import pytest +import shutil pytest.importorskip("pty") import sys @@ -43,46 +45,36 @@ TARGET = tvm.target.target.micro("host") -def _make_sess_from_op(workspace, op_name, sched, arg_bufs): +def _make_sess_from_op(temp_dir, op_name, sched, arg_bufs): with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): mod = tvm.build(sched, arg_bufs, Target(TARGET, TARGET), name=op_name) - return _make_session(workspace, mod) + return _make_session(temp_dir, mod) -def _make_session(workspace, mod): - compiler = tvm.micro.DefaultCompiler(target=TARGET) - opts = tvm.micro.default_options( - os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host") +def _make_session(temp_dir, mod): + template_project_dir = os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host") + project = tvm.micro.generate_project( + template_project_dir, mod, temp_dir / "project", {"verbose": 1} ) - micro_binary = tvm.micro.build_static_runtime( - workspace, - compiler, - mod, - opts, - extra_libs=[tvm.micro.get_standalone_crt_lib("memory")], - ) - - flasher_kw = { - "debug": DEBUG, - } - flasher = compiler.flasher(**flasher_kw) - return tvm.micro.Session(binary=micro_binary, flasher=flasher) + project.build() + project.flash() + return tvm.micro.Session(project.transport()) -def _make_add_sess(workspace): +def _make_add_sess(temp_dir): A = tvm.te.placeholder((2,), dtype="int8") B = tvm.te.placeholder((1,), dtype="int8") C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C") sched = tvm.te.create_schedule(C.op) - return _make_sess_from_op(workspace, "add", sched, [A, B, C]) + return _make_sess_from_op(temp_dir, "add", sched, [A, B, C]) -def _make_ident_sess(workspace): +def _make_ident_sess(temp_dir): A = tvm.te.placeholder((2,), dtype="int8") B = tvm.te.compute(A.shape, lambda i: A[i], name="B") sched = tvm.te.create_schedule(B.op) - return _make_sess_from_op(workspace, "ident", sched, [A, B]) + return _make_sess_from_op(temp_dir, "ident", sched, [A, B]) @tvm.testing.requires_micro @@ -90,9 +82,9 @@ def test_compile_runtime(): """Test compiling the on-device runtime.""" import tvm.micro - workspace = tvm.micro.Workspace() + temp_dir = tvm.contrib.utils.tempdir() - with _make_add_sess(workspace) as sess: + with _make_add_sess(temp_dir) as sess: A_data = tvm.nd.array(np.array([2, 3], dtype="int8"), device=sess.device) assert (A_data.numpy() == np.array([2, 3])).all() B_data = tvm.nd.array(np.array([4], dtype="int8"), device=sess.device) @@ -128,9 +120,9 @@ def test_reset(): import tvm.micro from tvm.micro import transport - workspace = tvm.micro.Workspace() + temp_dir = tvm.contrib.utils.tempdir() - with _make_add_sess(workspace) as sess: + with _make_add_sess(temp_dir) as sess: try: sess._rpc.get_function("tvm.testing.reset_server")() assert False, "expected to raise SessionTerminatedError; did not raise" @@ -141,9 +133,11 @@ def test_reset(): @tvm.testing.requires_micro def test_graph_executor(): """Test use of the graph executor with microTVM.""" - import tvm.micro - workspace = tvm.micro.Workspace(debug=True) + ws_root = pathlib.Path(os.path.dirname(__file__) + "/micro-workspace") + if ws_root.exists(): + shutil.rmtree(ws_root) + temp_dir = tvm.contrib.utils.tempdir(ws_root.resolve()) relay_mod = tvm.parser.fromtext( """ #[version = "0.0.5"] @@ -156,7 +150,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) { with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): factory = tvm.relay.build(relay_mod, target=TARGET) - with _make_session(workspace, factory.get_lib()) as sess: + with _make_session(temp_dir, factory) as sess: graph_mod = tvm.micro.create_local_graph_executor( factory.get_graph_json(), sess.get_system_lib(), sess.device ) @@ -176,9 +170,9 @@ def test_std_math_functions(): """Verify that standard math functions can be used.""" import tvm.micro - workspace = tvm.micro.Workspace() + temp_dir = tvm.contrib.utils.tempdir() - with _make_add_sess(workspace) as sess: + with _make_add_sess(temp_dir) as sess: A_data = tvm.nd.array(np.array([2, 3], dtype="int8"), device=sess.device) assert (A_data.numpy() == np.array([2, 3])).all() B_data = tvm.nd.array(np.array([4], dtype="int8"), device=sess.device) @@ -189,12 +183,12 @@ def test_std_math_functions(): system_lib = sess.get_system_lib() system_lib.get_function("add")(A_data, B_data, C_data) - workspace = tvm.micro.Workspace() + temp_dir = tvm.contrib.utils.tempdir() A = tvm.te.placeholder((2,), dtype="float32", name="A") B = tvm.te.compute(A.shape, lambda i: tvm.te.exp(A[i]), name="B") s = tvm.te.create_schedule(B.op) - with _make_sess_from_op(workspace, "myexpf", s, [A, B]) as sess: + with _make_sess_from_op(temp_dir, "myexpf", s, [A, B]) as sess: A_data = tvm.nd.array(np.array([2.0, 3.0], dtype="float32"), device=sess.device) B_data = tvm.nd.array(np.array([2.0, 3.0], dtype="float32"), device=sess.device) lib = sess.get_system_lib() @@ -208,12 +202,12 @@ def test_platform_timer(): """Verify the platform timer can be used to time remote functions.""" import tvm.micro - workspace = tvm.micro.Workspace() + temp_dir = tvm.contrib.utils.tempdir() A = tvm.te.placeholder((2,), dtype="float32", name="A") B = tvm.te.compute(A.shape, lambda i: tvm.te.exp(A[i]), name="B") s = tvm.te.create_schedule(B.op) - with _make_sess_from_op(workspace, "myexpf", s, [A, B]) as sess: + with _make_sess_from_op(temp_dir, "myexpf", s, [A, B]) as sess: A_data = tvm.nd.array(np.array([2.0, 3.0], dtype="float32"), device=sess.device) B_data = tvm.nd.array(np.array([2.0, 3.0], dtype="float32"), device=sess.device) lib = sess.get_system_lib() @@ -226,5 +220,4 @@ def test_platform_timer(): if __name__ == "__main__": - test_graph_executor() -# sys.exit(pytest.main([__file__] + sys.argv[1:])) + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py index 5f962ef7f74f..b135973718bc 100644 --- a/tests/python/unittest/test_custom_datatypes.py +++ b/tests/python/unittest/test_custom_datatypes.py @@ -90,17 +90,17 @@ def change_dtype(src, dst, module, params): def compare(module, input, src_dtype, dst_dtype, rtol, atol, params={}, target="llvm"): module = relay.transform.InferType()(module) module = relay.transform.SimplifyInference()(module) - ex = relay.create_executor("graph", mod=module) - correct = ex.evaluate()(*input, **params) + correct = relay.create_executor("graph", mod=module).evaluate()(*input, **params) module, converted_params = change_dtype(src_dtype, dst_dtype, module, params) - ex = relay.create_executor("graph", mod=module, target=target) # converts all inputs to dst_dtype x_converted = [convert_ndarray(dst_dtype, arr) for arr in input] # Vectorization is not implemented with custom datatypes with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - maybe_correct = ex.evaluate()(*x_converted, **converted_params) + maybe_correct = relay.create_executor("graph", mod=module, target=target).evaluate()( + *x_converted, **converted_params + ) # currently this only works for comparing single output maybe_correct_converted = convert_ndarray(src_dtype, maybe_correct) np.testing.assert_allclose( diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index 51799bba61fd..4f9cfffd0640 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -348,7 +348,7 @@ def _run_unlinked(lib_mod): @tvm.testing.requires_micro def test_crt_link_params(): - import tvm.micro + from tvm import micro for dtype in LINKABLE_DTYPES: mod, param_init = _make_mod_and_params(dtype) @@ -356,34 +356,21 @@ def test_crt_link_params(): main_func = mod["main"] target = "c --system-lib --runtime=c --link-params" with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - graph_json, lib, params = tvm.relay.build(mod, target, params=param_init) - assert set(params.keys()) == {"p0", "p1"} # NOTE: op folded + factory = tvm.relay.build(mod, target, params=param_init) + assert set(factory.get_params().keys()) == {"p0", "p1"} # NOTE: op folded - workspace = tvm.micro.Workspace() - compiler = tvm.micro.DefaultCompiler(target=target) - opts = tvm.micro.default_options( - os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host") + temp_dir = tvm.contrib.utils.tempdir() + template_project_dir = os.path.join( + tvm.micro.get_standalone_crt_dir(), "template", "host" ) - opts["bin_opts"]["ldflags"].append("-DTVM_HOST_USE_GRAPH_EXECUTOR_MODULE") - - micro_binary = tvm.micro.build_static_runtime( - workspace, - compiler, - lib, - compiler_options=opts, - extra_libs=[ - tvm.micro.get_standalone_crt_lib(m) - for m in ("memory", "graph_executor_module", "graph_executor") - ], + project = tvm.micro.generate_project( + template_project_dir, factory, temp_dir / "project", {"verbose": 1} ) - - flasher_kw = { - "debug": False, - } - flasher = compiler.flasher(**flasher_kw) - with tvm.micro.Session(binary=micro_binary, flasher=flasher) as sess: + project.build() + project.flash() + with tvm.micro.Session(project.transport()) as sess: graph_rt = tvm.micro.session.create_local_graph_executor( - graph_json, sess.get_system_lib(), sess.device + factory.get_graph_json(), sess.get_system_lib(), sess.device ) # NOTE: not setting params here. diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index 4505a7bed244..e5528a8c4756 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -50,6 +50,25 @@ def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: @tvm.script.tir class LoweredModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + # function attr dict + tir.func_attr( + {"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True} + ) + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + # body + for x, y in tir.grid(128, 128): + C.data[x * 128 + y] = 0.0 + for k in tir.serial(0, 128): + C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y) + tir.load( + "float32", A.data, x * 128 + k + ) * tir.load("float32", B.data, y * 128 + k) + + +@tvm.script.tir +class LoweredTIRModule: def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # function attr dict tir.func_attr({"global_symbol": "main", "tir.noalias": True}) @@ -83,7 +102,7 @@ def test_lower_build_te_schedule(): def test_lower_build_tir_func(): # check lowering ir_mod = tvm.lower(matmul) - tvm.ir.assert_structural_equal(ir_mod, LoweredModule()) + tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule()) # check building mod = tvm.build(matmul, target="llvm") _check_module_with_numpy(mod) @@ -95,7 +114,7 @@ def test_lower_build_tir_module(): ir_mod = IRModule({"main": func}) # check lowering lowered_mod = tvm.lower(ir_mod) - tvm.ir.assert_structural_equal(lowered_mod, LoweredModule()) + tvm.ir.assert_structural_equal(lowered_mod, LoweredTIRModule()) # check building mod = tvm.build(ir_mod, target="llvm") _check_module_with_numpy(mod) @@ -103,8 +122,8 @@ def test_lower_build_tir_module(): def test_lower_build_lowered_module(): # check lowering - ir_mod = tvm.lower(LoweredModule()) - tvm.ir.assert_structural_equal(ir_mod, LoweredModule()) + ir_mod = tvm.lower(LoweredTIRModule()) + tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule()) # check building mod = tvm.build(ir_mod, target="llvm") _check_module_with_numpy(mod) diff --git a/tests/python/unittest/test_micro_artifact.py b/tests/python/unittest/test_micro_artifact.py deleted file mode 100644 index fc180200720d..000000000000 --- a/tests/python/unittest/test_micro_artifact.py +++ /dev/null @@ -1,149 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Unit tests for the artifact module.""" - -import pytest -import json -import os -import shutil -import tvm - -from tvm.contrib import utils - -pytest.importorskip("tvm.micro") -from tvm.micro import artifact - -FILE_LIST = ["label1", "label2", "label12", "unlabelled"] - - -TEST_METADATA = {"foo": "bar"} - - -TEST_LABELS = {"label1": ["label1", "label12"], "label2": ["label2", "label12"]} - - -def build_artifact(artifact_path, immobile=False): - os.mkdir(artifact_path) - - for f in FILE_LIST: - with open(os.path.join(artifact_path, f), "w") as lib_f: - lib_f.write(f"{f}\n") - - sub_dir = os.path.join(artifact_path, "sub_dir") - os.mkdir(sub_dir) - os.symlink("label1", os.path.join(artifact_path, "rel_symlink")) - os.symlink("label2", os.path.join(artifact_path, "abs_symlink"), "label2") - os.symlink( - os.path.join(artifact_path, "sub_dir"), os.path.join(artifact_path, "abs_dir_symlink") - ) - - from tvm.micro import artifact - - art = artifact.Artifact(artifact_path, TEST_LABELS, TEST_METADATA, immobile=immobile) - - return art - - -@tvm.testing.requires_micro -def test_basic_functionality(): - temp_dir = utils.tempdir() - artifact_path = temp_dir.relpath("foo") - art = build_artifact(artifact_path) - - assert art.abspath("bar") == os.path.join(artifact_path, "bar") - - for label, paths in TEST_LABELS.items(): - assert art.label(label) == paths - assert art.label_abspath(label) == [os.path.join(artifact_path, p) for p in paths] - - -@tvm.testing.requires_micro -def test_archive(): - from tvm.micro import artifact - - temp_dir = utils.tempdir() - art = build_artifact(temp_dir.relpath("foo")) - - # Create archive - archive_path = art.archive(temp_dir.temp_dir) - assert archive_path == temp_dir.relpath("foo.tar") - - # Inspect created archive - unpack_dir = temp_dir.relpath("unpack") - os.mkdir(unpack_dir) - shutil.unpack_archive(archive_path, unpack_dir) - - for path in FILE_LIST: - with open(os.path.join(unpack_dir, "foo", path)) as f: - assert f.read() == f"{path}\n" - - with open(os.path.join(unpack_dir, "foo", "metadata.json")) as metadata_f: - metadata = json.load(metadata_f) - - assert metadata["version"] == 2 - assert metadata["labelled_files"] == TEST_LABELS - assert metadata["metadata"] == TEST_METADATA - - # Unarchive and verify basic functionality - unarchive_base_dir = temp_dir.relpath("unarchive") - unarch = artifact.Artifact.unarchive(archive_path, unarchive_base_dir) - - assert unarch.metadata == TEST_METADATA - assert unarch.labelled_files == TEST_LABELS - for f in FILE_LIST: - assert os.path.exists(os.path.join(unarchive_base_dir, f)) - - -@tvm.testing.requires_micro -def test_metadata_only(): - from tvm.micro import artifact - - temp_dir = utils.tempdir() - base_dir = temp_dir.relpath("foo") - art = build_artifact(base_dir) - - artifact_path = art.archive(temp_dir.relpath("foo.artifact"), metadata_only=True) - unarch_base_dir = temp_dir.relpath("bar") - unarch = artifact.Artifact.unarchive(artifact_path, unarch_base_dir) - assert unarch.base_dir == base_dir - - for p in unarch.label_abspath("label1") + unarch.label_abspath("label2"): - assert os.path.exists(p) - - os.unlink(art.abspath("label1")) - with open(art.abspath("label2"), "w+") as f: - f.write("changed line\n") - - try: - artifact.Artifact.unarchive(artifact_path, os.path.join(temp_dir.temp_dir, "bar2")) - assert False, "unarchive should raise error" - except artifact.ArchiveModifiedError as err: - assert str(err) == ( - "Files in metadata-only archive have been modified:\n" - " * label1: original file not found\n" - " * label2: sha256 mismatch: expected " - "6aa3c5668c8794c791400e19ecd7123949ded1616eafb0395acdd2d896354e83, got " - "ed87db21670a81819d65eccde87c5ae0243b2b61783bf77e9b27993be9a3eca0" - ) - - -if __name__ == "__main__": - test_basic_functionality() - test_archive() - test_metadata_only() - # TODO: tests for dir symlinks, symlinks out of bounds, loading malformed artifact tars. diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 246c0336a001..92c1174e728c 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -27,6 +27,7 @@ import tvm import tvm.relay from tvm.relay.backend import executor_factory +from tvm.relay.testing import byoc import tvm.runtime.module import tvm.testing from tvm.contrib import utils @@ -56,7 +57,7 @@ def test_export_operator_model_library_format(): with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 4 + assert metadata["version"] == 5 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" @@ -89,7 +90,7 @@ def test_export_operator_model_library_format(): def validate_graph_json(extract_dir, factory): - with open(os.path.join(extract_dir, "runtime-config", "graph", "graph.json")) as graph_f: + with open(os.path.join(extract_dir, "executor-config", "graph", "graph.json")) as graph_f: graph_json = graph_f.read() assert graph_json == factory.graph_json @@ -102,14 +103,20 @@ def validate_graph_json(extract_dir, factory): @tvm.testing.requires_micro @pytest.mark.parametrize( - "target", + "executor,target,should_generate_interface", [ - ("graph", tvm.target.target.micro("host")), - ("aot", tvm.target.target.micro("host", options="-executor=aot")), + ("graph", tvm.target.target.micro("host"), False), + ("aot", tvm.target.target.micro("host", options="-executor=aot"), False), + ( + "aot", + tvm.target.target.micro( + "host", options="-executor=aot --unpacked-api=1 --interface-api=c" + ), + True, + ), ], ) -def test_export_model_library_format_c(target): - executor, _target = target +def test_export_model_library_format_c(executor, target, should_generate_interface): with utils.TempDirectory.set_keep_for_debug(True): with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): relay_mod = tvm.parser.fromtext( @@ -122,8 +129,8 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ ) factory = tvm.relay.build( relay_mod, - _target, - target_host=_target, + target, + target_host=target, mod_name="add", params={"c": numpy.array([[2.0, 4.0]], dtype="float32")}, ) @@ -141,13 +148,13 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 4 + assert metadata["version"] == 5 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) - assert metadata["target"] == {"1": str(_target)} + assert metadata["target"] == {"1": str(target)} if executor == "graph": assert metadata["memory"]["sids"] == [ {"storage_id": 0, "size_bytes": 2, "input_binding": "a"}, @@ -173,6 +180,9 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "add_lib0.c")) assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "add_lib1.c")) + assert should_generate_interface == os.path.exists( + os.path.join(extract_dir, "codegen", "host", "include", "tvmgen_add.h") + ) if executor == "graph": validate_graph_json(extract_dir, factory) @@ -221,7 +231,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 4 + assert metadata["version"] == 5 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" @@ -265,13 +275,9 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ @tvm.testing.requires_micro @pytest.mark.parametrize( "target", - [ - ("graph", tvm.target.target.micro("host")), - ("aot", tvm.target.target.micro("host", options="-executor=aot")), - ], + [tvm.target.target.micro("host"), tvm.target.target.micro("host", options="-executor=aot")], ) def test_export_model_library_format_workspace(target): - executor, _target = target with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): relay_mod = tvm.parser.fromtext( """ @@ -285,7 +291,7 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 } """ ) - factory = tvm.relay.build(relay_mod, _target, target_host=_target, mod_name="qnn_conv2d") + factory = tvm.relay.build(relay_mod, target, target_host=target, mod_name="qnn_conv2d") temp_dir = utils.tempdir() mlf_tar_path = temp_dir.relpath("lib.tar") @@ -300,13 +306,13 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 4 + assert metadata["version"] == 5 assert metadata["model_name"] == "qnn_conv2d" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) - assert metadata["target"] == {"1": str(_target)} + assert metadata["target"] == {"1": str(target)} assert metadata["memory"]["functions"]["main"] == [ { "constants_size_bytes": 0, @@ -327,9 +333,6 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 @tvm.testing.requires_micro def test_export_non_dso_exportable(): module = tvm.support.FrontendTestModule() - factory = executor_factory.GraphExecutorFactoryModule( - None, tvm.target.target.micro("host"), '"graph_json"', module, "test_module", {}, {} - ) temp_dir = utils.tempdir() import tvm.micro as micro @@ -343,5 +346,69 @@ def test_export_non_dso_exportable(): ) +@tvm.testing.requires_micro +def test_export_byoc_c_module(): + """Test BYOC flow when it produces DSO-exportable modules. + + NOTE the general BYOC flow is not fully supported by Model Library Format right now. + """ + x = tvm.relay.var("x", shape=(10, 10)) + w0 = tvm.relay.var("w0", shape=(10, 10)) + w1 = tvm.relay.var("w1", shape=(10, 10)) + w2 = tvm.relay.var("w2", shape=(10, 10)) + w3 = tvm.relay.var("w3", shape=(10, 10)) + w4 = tvm.relay.var("w4", shape=(10, 10)) + w5 = tvm.relay.var("w5", shape=(10, 10)) + w6 = tvm.relay.var("w6", shape=(10, 10)) + w7 = tvm.relay.var("w7", shape=(10, 10)) + + # C compiler + z0 = tvm.relay.add(x, w0) + p0 = tvm.relay.subtract(z0, w1) + q0 = tvm.relay.multiply(p0, w2) + + z1 = tvm.relay.add(x, w3) + p1 = tvm.relay.subtract(z1, w4) + q1 = tvm.relay.multiply(p1, w5) + + # Other parts on TVM + z2 = tvm.relay.add(x, w6) + q2 = tvm.relay.subtract(z2, w7) + + r = tvm.relay.concatenate((q0, q1, q2), axis=0) + f = tvm.relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) + mod = tvm.IRModule() + ann = byoc.CcompilerAnnotator() + mod["main"] = ann.visit(f) + mod = tvm.relay.transform.PartitionGraph("mod_name")(mod) + mod = tvm.relay.transform.InferType()(mod) + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + factory = tvm.relay.build(mod, tvm.target.target.micro("host")) + + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir.relpath("lib.tar") + + from tvm import micro + + micro.export_model_library_format(factory, mlf_tar_path) + + with tarfile.open(mlf_tar_path, "r:*") as tf: + tar_members = [ti.name for ti in tf.getmembers()] + print("tar members", tar_members) + assert "./metadata.json" in tar_members + with tf.extractfile("./metadata.json") as f: + metadata = json.load(f) + main_md = metadata["memory"]["functions"]["main"] + assert main_md == [ + { + "constants_size_bytes": 0, + "device": 1, + "io_size_bytes": 4800, + "workspace_size_bytes": 800, + } + ] + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_micro_project_api.py b/tests/python/unittest/test_micro_project_api.py new file mode 100644 index 000000000000..b5e2a57c122c --- /dev/null +++ b/tests/python/unittest/test_micro_project_api.py @@ -0,0 +1,424 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import collections +import io +import json +import sys +import unittest +from unittest import mock + +import pytest + +import tvm +from tvm.micro import project_api + + +class BaseTestHandler(project_api.server.ProjectAPIHandler): + + DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo( + platform_name="platform_name", + is_template=True, + model_library_format_path="./model-library-format-path.sh", + project_options=[ + project_api.server.ProjectOption(name="foo", help="Option foo"), + project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"), + ], + ) + + def server_info_query(self, tvm_version): + return self.DEFAULT_TEST_SERVER_INFO + + def generate_project(self, model_library_format_path, crt_path, project_path, options): + assert False, "generate_project is not implemented for this test" + + def build(self, options): + assert False, "build is not implemented for this test" + + def flash(self, options): + assert False, "flash is not implemented for this test" + + def open_transport(self, options): + assert False, "open_transport is not implemented for this test" + + def close_transport(self, options): + assert False, "open_transport is not implemented for this test" + + def read_transport(self, n, timeout_sec): + assert False, "read_transport is not implemented for this test" + + def write_transport(self, data, timeout_sec): + assert False, "write_transport is not implemented for this test" + + +class Transport: + def readable(self): + return True + + def writable(self): + return True + + def seekable(self): + return False + + closed = False + + def __init__(self): + self.data = bytearray() + self.rpos = 0 + + self.items = [] + + def read(self, size=-1): + to_read = len(self.data) - self.rpos + if size != -1: + to_read = min(size, to_read) + + rpos = self.rpos + self.rpos += to_read + return self.data[rpos : self.rpos] + + def write(self, data): + self.data.extend(data) + + +class ClientServerFixture: + def __init__(self, handler): + self.handler = handler + self.client_to_server = Transport() + self.server_to_client = Transport() + + self.server = project_api.server.ProjectAPIServer( + self.client_to_server, self.server_to_client, handler + ) + self.client = project_api.client.ProjectAPIClient( + self.server_to_client, + self.client_to_server, + testonly_did_write_request=self._process_server_request, + ) + + self.expect_failure = False + + def _process_server_request(self): + assert self.server.serve_one_request() == ( + not self.expect_failure + ), "Server failed to process request" + + +def test_server_info_query(): + fixture = ClientServerFixture(BaseTestHandler()) + + # Examine reply explicitly because these are the defaults for all derivative test cases. + reply = fixture.client.server_info_query(tvm.__version__) + assert reply["protocol_version"] == 1 + assert reply["platform_name"] == "platform_name" + assert reply["is_template"] == True + assert reply["model_library_format_path"] == "./model-library-format-path.sh" + assert reply["project_options"] == [ + {"name": "foo", "choices": None, "help": "Option foo"}, + {"name": "bar", "choices": ["qux"], "help": "Option bar"}, + ] + + +def test_server_info_query_wrong_tvm_version(): + def server_info_query(tvm_version): + raise project_api.server.UnsupportedTVMVersionError() + + with mock.patch.object(BaseTestHandler, "server_info_query", side_effect=server_info_query): + fixture = ClientServerFixture(BaseTestHandler()) + with pytest.raises(project_api.server.UnsupportedTVMVersionError) as exc_info: + fixture.client.server_info_query(tvm.__version__) + + assert "UnsupportedTVMVersionError" in str(exc_info.value) + + +def test_server_info_query_wrong_protocol_version(): + ServerInfoProtocol = collections.namedtuple( + "ServerInfoProtocol", list(project_api.server.ServerInfo._fields) + ["protocol_version"] + ) + + def server_info_query(tvm_version): + return ServerInfoProtocol( + protocol_version=0, **BaseTestHandler.DEFAULT_TEST_SERVER_INFO._asdict() + ) + + with mock.patch.object(BaseTestHandler, "server_info_query", side_effect=server_info_query): + fixture = ClientServerFixture(BaseTestHandler()) + with pytest.raises(project_api.client.UnsupportedProtocolVersionError) as exc_info: + fixture.client.server_info_query(tvm.__version__) + + assert "microTVM API Server supports protocol version 0; want 1" in str(exc_info.value) + + +def test_base_test_handler(): + """All methods should raise AssertionError on BaseTestHandler.""" + fixture = ClientServerFixture(BaseTestHandler()) + + for method in dir(fixture.handler): + if method.startswith("_") or not callable(method) or method == "server_info_query": + continue + + with self.assertThrows(AssertionError) as exc_info: + getattr(fixture.client, method)() + + assert (exc_info.exception) == f"{method} is not implemented for this test" + + +def test_build(): + with mock.patch.object(BaseTestHandler, "build", return_value=None) as patch: + fixture = ClientServerFixture(BaseTestHandler()) + fixture.client.build(options={"bar": "baz"}) + + fixture.handler.build.assert_called_once_with(options={"bar": "baz"}) + + +def test_flash(): + with mock.patch.object(BaseTestHandler, "flash", return_value=None) as patch: + fixture = ClientServerFixture(BaseTestHandler()) + fixture.client.flash(options={"bar": "baz"}) + fixture.handler.flash.assert_called_once_with(options={"bar": "baz"}) + + +def test_open_transport(): + timeouts = project_api.server.TransportTimeouts( + session_start_retry_timeout_sec=1.0, + session_start_timeout_sec=2.0, + session_established_timeout_sec=3.0, + ) + + with mock.patch.object(BaseTestHandler, "open_transport", return_value=timeouts) as patch: + fixture = ClientServerFixture(BaseTestHandler()) + assert fixture.client.open_transport(options={"bar": "baz"}) == { + "timeouts": dict(timeouts._asdict()) + } + fixture.handler.open_transport.assert_called_once_with({"bar": "baz"}) + + +def test_close_transport(): + with mock.patch.object(BaseTestHandler, "close_transport", return_value=None) as patch: + fixture = ClientServerFixture(BaseTestHandler()) + fixture.client.close_transport() + fixture.handler.close_transport.assert_called_once_with() + + +def test_read_transport(): + with mock.patch.object(BaseTestHandler, "read_transport", return_value=b"foo\x1b") as patch: + fixture = ClientServerFixture(BaseTestHandler()) + assert fixture.client.read_transport(128, timeout_sec=5.0) == {"data": b"foo\x1b"} + + fixture.handler.read_transport.assert_called_with(128, 5.0) + + fixture.handler.read_transport.side_effect = project_api.server.IoTimeoutError + with pytest.raises(project_api.server.IoTimeoutError) as exc_info: + fixture.client.read_transport(256, timeout_sec=10.0) + + fixture.handler.read_transport.assert_called_with(256, 10.0) + + fixture.handler.read_transport.side_effect = project_api.server.TransportClosedError + with pytest.raises(project_api.server.TransportClosedError) as exc_info: + fixture.client.read_transport(512, timeout_sec=15.0) + + fixture.handler.read_transport.assert_called_with(512, 15.0) + + assert fixture.handler.read_transport.call_count == 3 + + +def test_write_transport(): + with mock.patch.object(BaseTestHandler, "write_transport", return_value=None) as patch: + fixture = ClientServerFixture(BaseTestHandler()) + assert fixture.client.write_transport(b"foo", timeout_sec=5.0) is None + fixture.handler.write_transport.assert_called_with(b"foo", 5.0) + + fixture.handler.write_transport.side_effect = project_api.server.IoTimeoutError + with pytest.raises(project_api.server.IoTimeoutError) as exc_info: + fixture.client.write_transport(b"bar", timeout_sec=10.0) + + fixture.handler.write_transport.assert_called_with(b"bar", 10.0) + + fixture.handler.write_transport.side_effect = project_api.server.TransportClosedError + with pytest.raises(project_api.server.TransportClosedError) as exc_info: + fixture.client.write_transport(b"baz", timeout_sec=15.0) + + fixture.handler.write_transport.assert_called_with(b"baz", 15.0) + + assert fixture.handler.write_transport.call_count == 3 + + +class ProjectAPITestError(Exception): + """An error raised in test.""" + + +def test_method_raises_error(): + with mock.patch.object( + BaseTestHandler, "close_transport", side_effect=ProjectAPITestError + ) as patch: + fixture = ClientServerFixture(BaseTestHandler()) + with pytest.raises(project_api.server.ServerError) as exc_info: + fixture.client.close_transport() + + fixture.handler.close_transport.assert_called_once_with() + assert "ProjectAPITestError" in str(exc_info.value) + + +def test_method_not_found(): + fixture = ClientServerFixture(BaseTestHandler()) + + with pytest.raises(project_api.server.JSONRPCError) as exc_info: + fixture.client._request_reply("invalid_method", {"bar": None}) + + assert exc_info.value.code == project_api.server.ErrorCode.METHOD_NOT_FOUND + + +def test_extra_param(): + fixture = ClientServerFixture(BaseTestHandler()) + + # test one with has_preprocssing and one without + assert hasattr(fixture.server, "_dispatch_build") == False + with pytest.raises(project_api.server.JSONRPCError) as exc_info: + fixture.client._request_reply("build", {"invalid_param_name": None, "options": {}}) + + assert exc_info.value.code == project_api.server.ErrorCode.INVALID_PARAMS + assert "build: extra parameters: invalid_param_name" in str(exc_info.value) + + assert hasattr(fixture.server, "_dispatch_open_transport") == True + with pytest.raises(project_api.server.JSONRPCError) as exc_info: + fixture.client._request_reply("open_transport", {"invalid_param_name": None, "options": {}}) + + assert exc_info.value.code == project_api.server.ErrorCode.INVALID_PARAMS + assert "open_transport: extra parameters: invalid_param_name" in str(exc_info.value) + + +def test_missing_param(): + fixture = ClientServerFixture(BaseTestHandler()) + + # test one with has_preprocssing and one without + assert hasattr(fixture.server, "_dispatch_build") == False + with pytest.raises(project_api.server.JSONRPCError) as exc_info: + fixture.client._request_reply("build", {}) + + assert exc_info.value.code == project_api.server.ErrorCode.INVALID_PARAMS + assert "build: parameter options not given" in str(exc_info.value) + + assert hasattr(fixture.server, "_dispatch_open_transport") == True + with pytest.raises(project_api.server.JSONRPCError) as exc_info: + fixture.client._request_reply("open_transport", {}) + + assert exc_info.value.code == project_api.server.ErrorCode.INVALID_PARAMS + assert "open_transport: parameter options not given" in str(exc_info.value) + + +def test_incorrect_param_type(): + fixture = ClientServerFixture(BaseTestHandler()) + + # The error message given at the JSON-RPC server level doesn't make sense when preprocessing is + # used. Only test without preprocessing here. + assert hasattr(fixture.server, "_dispatch_build") == False + with pytest.raises(project_api.server.JSONRPCError) as exc_info: + fixture.client._request_reply("build", {"options": None}) + + assert exc_info.value.code == project_api.server.ErrorCode.INVALID_PARAMS + assert "build: parameter options: want , got " in str( + exc_info.value + ) + + +def test_invalid_request(): + fixture = ClientServerFixture(BaseTestHandler()) + + # Invalid JSON does not get a reply. + fixture.client_to_server.write(b"foobar\n") + assert fixture.server.serve_one_request() == False + assert fixture.server_to_client.read() == b"" + + # EOF causes a clean return + assert fixture.server.serve_one_request() == False + assert fixture.server_to_client.read() == b"" + + def _request_reply(request): + fixture.client_to_server.write(request + b"\n") + assert fixture.server.serve_one_request() == False + return json.loads(fixture.server_to_client.read()) + + # Parseable JSON with the wrong schema gets a reply. + assert _request_reply(b"1") == { + "error": { + "code": project_api.server.ErrorCode.INVALID_REQUEST, + "data": None, + "message": "request: want dict; got 1", + }, + "id": None, + "jsonrpc": "2.0", + } + + # Incorrect JSON-RPC spec version. + assert _request_reply(b'{"jsonrpc": 1.0}') == { + "error": { + "code": project_api.server.ErrorCode.INVALID_REQUEST, + "data": None, + "message": 'request["jsonrpc"]: want "2.0"; got 1.0', + }, + "id": None, + "jsonrpc": "2.0", + } + + # Method not a str + assert _request_reply(b'{"jsonrpc": "2.0", "method": 123}') == { + "error": { + "code": project_api.server.ErrorCode.INVALID_REQUEST, + "data": None, + "message": 'request["method"]: want str; got 123', + }, + "id": None, + "jsonrpc": "2.0", + } + + # Method name has invalid characters + assert _request_reply(b'{"jsonrpc": "2.0", "method": "bar!"}') == { + "error": { + "code": project_api.server.ErrorCode.INVALID_REQUEST, + "data": None, + "message": "request[\"method\"]: should match regex ^[a-zA-Z0-9_]+$; got 'bar!'", + }, + "id": None, + "jsonrpc": "2.0", + } + + # params not a dict + assert _request_reply(b'{"jsonrpc": "2.0", "method": "bar", "params": 123}') == { + "error": { + "code": project_api.server.ErrorCode.INVALID_REQUEST, + "data": None, + "message": "request[\"params\"]: want dict; got ", + }, + "id": None, + "jsonrpc": "2.0", + } + + # id not valid + assert _request_reply(b'{"jsonrpc": "2.0", "method": "bar", "params": {}, "id": {}}') == { + "error": { + "code": project_api.server.ErrorCode.INVALID_REQUEST, + "data": None, + "message": 'request["id"]: want str, number, null; got {}', + }, + "id": None, + "jsonrpc": "2.0", + } + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_micro_transport.py b/tests/python/unittest/test_micro_transport.py index b0f99681af2e..a188e612763f 100644 --- a/tests/python/unittest/test_micro_transport.py +++ b/tests/python/unittest/test_micro_transport.py @@ -132,25 +132,25 @@ def test_transport_logger(self): transport.to_return = 3 transport_logger.write(b"data", 3.0) assert test_log.records[-1].getMessage() == ( - "foo: write { 3.00s} <- [ 3 B]: 64 61 74 " - " dat" + "foo: write { 3.00s} <- [ 4 B]: 64 61 74 61" + " data" ) # Normal log, multi-line data written. transport.to_return = 20 transport_logger.write(b"data" * 6, 3.0) assert test_log.records[-1].getMessage() == ( - "foo: write { 3.00s} <- [ 20 B]:\n" + "foo: write { 3.00s} <- [ 24 B]:\n" "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n" - "0010 64 61 74 61 data" + "0010 64 61 74 61 64 61 74 61 datadata" ) # Lack of timeout prints. transport.to_return = 3 transport_logger.write(b"data", None) assert test_log.records[-1].getMessage() == ( - "foo: write { None } <- [ 3 B]: 64 61 74 " - " dat" + "foo: write { None } <- [ 4 B]: 64 61 74 61" + " data" ) # IoTimeoutError includes the timeout value. diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py index 781fd7f93886..4c72f2c6083b 100644 --- a/tests/python/unittest/test_runtime_container.py +++ b/tests/python/unittest/test_runtime_container.py @@ -48,8 +48,7 @@ def test_tuple_object(): fn = relay.Function([x], relay.expr.TupleGetItem(x, 0)) mod = tvm.IRModule.from_expr(fn) - exe = relay.create_executor(kind="vm", mod=mod, device=nd.cpu(), target="llvm") - f = exe.evaluate() + f = relay.create_executor(kind="vm", mod=mod, device=nd.cpu(), target="llvm").evaluate() value_tuple = _container.tuple_object([nd.array(np.array(11)), nd.array(np.array(12))]) # pass an ADT object to evaluate out = f(value_tuple) diff --git a/tests/python/unittest/test_runtime_graph.py b/tests/python/unittest/test_runtime_graph.py index 1259e77afbf8..458952fb5641 100644 --- a/tests/python/unittest/test_runtime_graph.py +++ b/tests/python/unittest/test_runtime_graph.py @@ -65,9 +65,8 @@ def check_verify(): out = mod.get_output(0, tvm.nd.empty((n,))) np.testing.assert_equal(out.numpy(), a + 1) - def check_remote(): + def check_remote(server): mlib = tvm.build(s, [A, B], "llvm", name="myadd") - server = rpc.Server("127.0.0.1") remote = rpc.connect(server.host, server.port) temp = utils.tempdir() dev = remote.cpu(0) @@ -115,7 +114,7 @@ def check_sharing(): del mod check_verify() - check_remote() + check_remote(rpc.Server("127.0.0.1")) check_sharing() diff --git a/tests/python/unittest/test_runtime_graph_debug.py b/tests/python/unittest/test_runtime_graph_debug.py index 192e0dad702f..cadc8ae6a4c0 100644 --- a/tests/python/unittest/test_runtime_graph_debug.py +++ b/tests/python/unittest/test_runtime_graph_debug.py @@ -32,6 +32,7 @@ @tvm.testing.requires_llvm +@tvm.testing.requires_rpc def test_graph_simple(): n = 4 A = te.placeholder((n,), name="A") @@ -160,9 +161,8 @@ def split_debug_line(i): # verify dump root delete after cleanup assert not os.path.exists(directory) - def check_remote(): + def check_remote(server): mlib = tvm.build(s, [A, B], "llvm", name="myadd") - server = rpc.Server("127.0.0.1") remote = rpc.connect(server.host, server.port) temp = utils.tempdir() dev = remote.cpu(0) @@ -182,7 +182,7 @@ def check_remote(): np.testing.assert_equal(out.numpy(), a + 1) check_verify() - check_remote() + check_remote(rpc.Server("127.0.0.1")) if __name__ == "__main__": diff --git a/tests/python/unittest/test_runtime_measure.py b/tests/python/unittest/test_runtime_measure.py index 0d02f910a44c..8955b03241a2 100644 --- a/tests/python/unittest/test_runtime_measure.py +++ b/tests/python/unittest/test_runtime_measure.py @@ -20,6 +20,7 @@ import tvm from tvm import te from tvm.contrib.utils import tempdir +from tvm.runtime.module import BenchmarkResult def test_min_repeat_ms(): @@ -56,5 +57,15 @@ def my_debug(filename): assert ct > 10 + 2 +def test_benchmark_result(): + r = BenchmarkResult([1, 2, 2, 5]) + assert r.mean == 2.5 + assert r.median == 2.0 + assert r.min == 1 + assert r.max == 5 + assert r.std == 1.5 + + if __name__ == "__main__": test_min_repeat_ms() + test_benchmark_result() diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 9bb05dfed65f..e984979ac14f 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -275,29 +275,31 @@ def verify_rpc_gpu_export(obj_format): from tvm import rpc - server = rpc.Server("127.0.0.1", port=9094) - remote = rpc.connect(server.host, server.port) - remote.upload(path_lib) - loaded_lib = remote.load_module(path_lib) - data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") - dev = remote.cuda() - - # raw api - gmod = loaded_lib["default"](dev) - set_input = gmod["set_input"] - run = gmod["run"] - get_output = gmod["get_output"] - set_input("data", tvm.nd.array(data, device=dev)) - run() - out = get_output(0).numpy() - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - - # graph executor wrapper - gmod = graph_executor.GraphModule(loaded_lib["default"](dev)) - gmod.set_input("data", data) - gmod.run() - out = gmod.get_output(0).numpy() - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + def check_remote(server): + remote = rpc.connect(server.host, server.port) + remote.upload(path_lib) + loaded_lib = remote.load_module(path_lib) + data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32") + dev = remote.cuda() + + # raw api + gmod = loaded_lib["default"](dev) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data, device=dev)) + run() + out = get_output(0).numpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph executor wrapper + gmod = graph_executor.GraphModule(loaded_lib["default"](dev)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).numpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + check_remote(rpc.Server("127.0.0.1")) for obj_format in [".so", ".tar"]: verify_cpu_export(obj_format) diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index ee8032550b39..8306f2f67fa1 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -18,6 +18,8 @@ import pytest from io import StringIO import csv +import os +import json import tvm.testing from tvm.runtime import profiler_vm @@ -26,6 +28,24 @@ from tvm.contrib.debugger import debug_executor +def read_csv(report): + f = StringIO(report.csv()) + headers = [] + rows = [] + reader = csv.reader(f, delimiter=",") + # force parsing + in_header = True + for row in reader: + if in_header: + headers = row + in_header = False + rows = [[] for x in headers] + else: + for i in range(len(row)): + rows[i].append(row[i]) + return dict(zip(headers, rows)) + + @pytest.mark.skipif(not profiler_vm.enabled(), reason="VM Profiler not enabled") @tvm.testing.parametrize_targets def test_vm(target, dev): @@ -39,14 +59,9 @@ def test_vm(target, dev): assert "fused_nn_softmax" in str(report) assert "Total" in str(report) - f = StringIO(report.csv()) - reader = csv.reader(f, delimiter=",") - # force parsing - in_header = True - for row in reader: - if in_header: - assert "Hash" in row - in_header = False + csv = read_csv(report) + assert "Hash" in csv.keys() + assert all([float(x) > 0 for x in csv["Duration (us)"]]) @tvm.testing.parametrize_targets @@ -61,3 +76,60 @@ def test_graph_executor(target, dev): assert "fused_nn_softmax" in str(report) assert "Total" in str(report) assert "Hash" in str(report) + + +@tvm.testing.parametrize_targets("cuda", "llvm") +@pytest.mark.skipif( + tvm.get_global_func("runtime.profiling.PAPIMetricCollector", allow_missing=True) is None, + reason="PAPI profiling not enabled", +) +def test_papi(target, dev): + target = tvm.target.Target(target) + if str(target.kind) == "llvm": + metric = "PAPI_FP_OPS" + elif str(target.kind) == "cuda": + metric = "cuda:::event:shared_load:device=0" + else: + pytest.skip(f"Target {target.kind} not supported by this test") + mod, params = mlp.get_workload(1) + + exe = relay.vm.compile(mod, target, params=params) + vm = profiler_vm.VirtualMachineProfiler(exe, dev) + + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"), device=dev) + report = vm.profile( + [data], + func_name="main", + collectors=[tvm.runtime.profiling.PAPIMetricCollector({dev: [metric]})], + ) + print(report) + assert metric in str(report) + + csv = read_csv(report) + assert metric in csv.keys() + assert any([float(x) > 0 for x in csv[metric]]) + + +@tvm.testing.requires_llvm +def test_json(): + mod, params = mlp.get_workload(1) + + exe = relay.vm.compile(mod, "llvm", params=params) + vm = profiler_vm.VirtualMachineProfiler(exe, tvm.cpu()) + + data = np.random.rand(1, 1, 28, 28).astype("float32") + report = vm.profile(data, func_name="main") + parsed = json.loads(report.json()) + assert "device_metrics" in parsed + assert "calls" in parsed + assert "Duration (us)" in parsed["calls"][0] + assert "microseconds" in parsed["calls"][0]["Duration (us)"] + assert len(parsed["calls"]) > 0 + for call in parsed["calls"]: + assert isinstance(call["Name"], str) + assert isinstance(call["Count"]["count"], int) + assert isinstance(call["Duration (us)"]["microseconds"], float) + + +if __name__ == "__main__": + test_papi("llvm", tvm.cpu()) diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index f90c9548ec02..22aea8d1fcea 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -53,6 +53,9 @@ ), ) +# NOTE: When writing tests, wrap remote related checking in a sub-function +# to ensure all the remote resources destructs before the server terminates + @tvm.testing.requires_rpc def test_bigendian_rpc(): @@ -90,38 +93,49 @@ def verify_rpc(remote, target, shape, dtype): def test_rpc_simple(): server = rpc.Server(key="x1") client = rpc.connect("127.0.0.1", server.port, key="x1") - f1 = client.get_function("rpc.test.addone") - assert f1(10) == 11 - f3 = client.get_function("rpc.test.except") - with pytest.raises(tvm._ffi.base.TVMError): - f3("abc") + def check_remote(): + f1 = client.get_function("rpc.test.addone") + assert f1(10) == 11 + f3 = client.get_function("rpc.test.except") + + with pytest.raises(tvm._ffi.base.TVMError): + f3("abc") + + f2 = client.get_function("rpc.test.strcat") + assert f2("abc", 11) == "abc:11" - f2 = client.get_function("rpc.test.strcat") - assert f2("abc", 11) == "abc:11" + check_remote() @tvm.testing.requires_rpc def test_rpc_runtime_string(): server = rpc.Server(key="x1") client = rpc.connect("127.0.0.1", server.port, key="x1") - func = client.get_function("rpc.test.runtime_str_concat") - x = tvm.runtime.container.String("abc") - y = tvm.runtime.container.String("def") - assert str(func(x, y)) == "abcdef" + + def check_remote(): + func = client.get_function("rpc.test.runtime_str_concat") + x = tvm.runtime.container.String("abc") + y = tvm.runtime.container.String("def") + assert str(func(x, y)) == "abcdef" + + check_remote() @tvm.testing.requires_rpc def test_rpc_array(): - x = np.ones((3, 4)) - server = rpc.Server() remote = rpc.connect("127.0.0.1", server.port) - r_cpu = tvm.nd.array(x, remote.cpu(0)) - assert str(r_cpu.device).startswith("remote") - np.testing.assert_equal(r_cpu.numpy(), x) - fremote = remote.get_function("rpc.test.remote_array_func") - fremote(r_cpu) + + def check_remote(): + x = np.ones((3, 4)) + r_cpu = tvm.nd.array(x, remote.cpu(0)) + assert str(r_cpu.device).startswith("remote") + np.testing.assert_equal(r_cpu.numpy(), x) + fremote = remote.get_function("rpc.test.remote_array_func") + fremote(r_cpu) + + check_remote() @tvm.testing.requires_rpc @@ -129,13 +143,17 @@ def test_rpc_large_array(): # testcase of large array creation server = rpc.Server() remote = rpc.connect("127.0.0.1", server.port) - dev = remote.cpu(0) - a_np = np.ones((5041, 720)).astype("float32") - b_np = np.ones((720, 192)).astype("float32") - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - np.testing.assert_equal(a.numpy(), a_np) - np.testing.assert_equal(b.numpy(), b_np) + + def check_remote(): + dev = remote.cpu(0) + a_np = np.ones((5041, 720)).astype("float32") + b_np = np.ones((720, 192)).astype("float32") + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + np.testing.assert_equal(a.numpy(), a_np) + np.testing.assert_equal(b.numpy(), b_np) + + check_remote() @tvm.testing.requires_rpc @@ -186,10 +204,14 @@ def check_minrpc(): def test_rpc_file_exchange(): server = rpc.Server() remote = rpc.connect("127.0.0.1", server.port) - blob = bytearray(np.random.randint(0, 10, size=(10))) - remote.upload(blob, "dat.bin") - rev = remote.download("dat.bin") - assert rev == blob + + def check_remote(): + blob = bytearray(np.random.randint(0, 10, size=(10))) + remote.upload(blob, "dat.bin") + rev = remote.download("dat.bin") + assert rev == blob + + check_remote() @tvm.testing.requires_rpc @@ -321,9 +343,13 @@ def check_remote_link_cl(remote): def test_rpc_return_func(): server = rpc.Server(key="x1") client = rpc.connect("127.0.0.1", server.port, key="x1") - f1 = client.get_function("rpc.test.add_to_lhs") - fadd = f1(10) - assert fadd(12) == 22 + + def check_remote(): + f1 = client.get_function("rpc.test.add_to_lhs") + fadd = f1(10) + assert fadd(12) == 22 + + check_remote() @tvm.testing.requires_rpc @@ -386,14 +412,18 @@ def run_arr_test(): @tvm.testing.requires_rpc def test_local_func(): client = rpc.LocalSession() - f1 = client.get_function("rpc.test.add_to_lhs") - fadd = f1(10) - assert fadd(12) == 22 - - blob = bytearray(np.random.randint(0, 10, size=(10))) - client.upload(blob, "dat.bin") - rev = client.download("dat.bin") - assert rev == blob + + def check_remote(): + f1 = client.get_function("rpc.test.add_to_lhs") + fadd = f1(10) + assert fadd(12) == 22 + + blob = bytearray(np.random.randint(0, 10, size=(10))) + client.upload(blob, "dat.bin") + rev = client.download("dat.bin") + assert rev == blob + + check_remote() @tvm.testing.requires_rpc diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index e3422ae45945..75b61d281840 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -29,7 +29,9 @@ def test_basic(dev, target): return exe = relay.vm.compile(mod, target, params=params) - vm = profiler_vm.VirtualMachineProfiler(exe, dev) + code, lib = exe.save() + des_exe = tvm.runtime.vm.Executable.load_exec(code, lib) + vm = profiler_vm.VirtualMachineProfiler(des_exe, dev) data = np.random.rand(1, 1, 28, 28).astype("float32") res = vm.profile(tvm.nd.array(data), func_name="main") diff --git a/tests/python/unittest/test_target_codegen_device.py b/tests/python/unittest/test_target_codegen_device.py index 99b504219c14..b4181fb7b014 100644 --- a/tests/python/unittest/test_target_codegen_device.py +++ b/tests/python/unittest/test_target_codegen_device.py @@ -45,7 +45,7 @@ def check_target(device): assert a.numpy()[0] == value + 3 check_target("cuda") - check_target("vulkan") + check_target("vulkan -from_device=0") @tvm.testing.requires_gpu diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index 98340f0e6ac5..56392ec8cccc 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -17,6 +17,7 @@ import tvm from tvm import te import tvm.testing +import re target = "opencl" @@ -120,6 +121,25 @@ def check_max(dev, n, dtype): check_max(dev, 1, "float64") +def test_opencl_erf(): + def check_erf(dev, n, dtype): + A = te.placeholder((n,), name="A", dtype=dtype) + C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C") + s = te.create_schedule(C.op) + s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x")) + fun = tvm.build(s, [A, C], target) + source_str = fun.imported_modules[0].get_source() + matches = re.findall("erf", source_str) + error_matches = re.findall("erff", source_str) + assert len(matches) == 1 and len(error_matches) == 0 + + dev = tvm.device(target, 0) + + check_erf(dev, 1, "float32") + check_erf(dev, 1, "float64") + + if __name__ == "__main__": test_opencl_ternary_expression() test_opencl_inf_nan() + test_opencl_erf() diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 0551fcd54855..1edc5d311759 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -220,8 +220,7 @@ def do_copy(A, B, n): def check_mod(target, dev, mod, x_np, res_np): - ex = relay.create_executor("vm", mod=mod, device=dev, target=target) - res = ex.evaluate()(x_np).numpy() + res = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()(x_np).numpy() tvm.testing.assert_allclose(res, res_np, atol=1e-5) @@ -433,5 +432,129 @@ def do_compute(A, B, n): tvm.testing.assert_allclose(b.numpy(), a_np) +class TestVectorizedIndices: + load_type, store_type = tvm.testing.parameters( + # Load N values, write to N locations. + # Vectorized copy. + ("ramp", "ramp"), + # Load 1 value, write to N locations. + # Scalar load, vectorized store. + # + # Most TVM operations (e.g. schedule[tensor].vectorize(axis)) have + # the broadcast outside of the index, but it is semantically okay + # for the broadcast to be inside the index, and it shows up with + # some optimizations. + ("broadcast", "ramp"), + # Load 1 values, write to 1 location. + # Broadcasting on both sides should be equivalent to a scalar copy. + ("broadcast", "broadcast"), + # Loads N values, write to 1 location. + # Disabled as it would have unclear semantics. + # ("ramp","broadcoast"), + ) + indirect_indices = tvm.testing.parameter(True, False, ids=["reorder", "no_reorder"]) + + @tvm.testing.fixture + def ref_data(self, load_type, store_type, indirect_indices): + n = 4 + + index_map = { + "ramp": np.arange(n), + "broadcast": np.zeros(n, dtype="int32"), + } + + a_np = np.random.randint(np.iinfo("int32").max, size=n).astype("int32") + b_np = np.zeros(shape=n, dtype=a_np.dtype) + reorder_np = np.arange(n, dtype="int32")[::-1] + + load_index = index_map[load_type] + store_index = index_map[store_type] + + if indirect_indices: + load_index = reorder_np[load_index] + + b_np[store_index] = a_np[load_index] + + return a_np, reorder_np, b_np + + @tvm.testing.fixture + def mod(self, target, load_type, store_type, indirect_indices): + target = tvm.target.Target(target) + + n = 4 + dtype = "int32" + A = te.placeholder((n,), dtype=dtype, name="A") + R = te.placeholder((n,), dtype=dtype, name="R") + + def do_compute(ins, outs): + ib = tvm.tir.ir_builder.create() + A, R = map(ib.buffer_ptr, ins) + B = ib.buffer_ptr(outs[0]) + + if "gpu" in target.keys: + ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0) + + index_map = { + "ramp": tvm.tir.Ramp(0, 1, 4), + "broadcast": tvm.tir.Broadcast(0, 4), + } + + load_index = index_map[load_type] + store_index = index_map[store_type] + + if indirect_indices: + load_index = tvm.tir.expr.Load("int32x4", R, load_index) + + transfer = tvm.tir.expr.Load("int32x4", A, load_index) + ib.emit(tvm.tir.stmt.Store(B, transfer, store_index)) + + return ib.get() + + B = te.extern(A.shape, [A, R], do_compute, dtype="int32") + s = te.create_schedule(B.op) + + return tvm.lower(s, [A, R, B]) + + def test_ramp_broadcast_index(self, target, dev, mod, ref_data): + f = tvm.build(mod, target=target) + + a_np, reorder_np, b_np = ref_data + a = tvm.nd.array(a_np, dev) + r = tvm.nd.array(reorder_np, dev) + b = tvm.nd.array(np.zeros(shape=b_np.shape, dtype="int32"), dev) + f(a, r, b) + tvm.testing.assert_allclose(b.numpy(), b_np) + + +@tvm.testing.parametrize_targets("vulkan -max_shared_memory_per_block=16384") +def test_shared_mem_alloc(target, dev): + alloc_nbytes = 16384 * 2 + + def do_compute(ins, outs): + ib = tvm.tir.ir_builder.create() + out = ib.buffer_ptr(outs[0]) + + ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0) + + array = ib.allocate("int32", (alloc_nbytes,), name="array", scope="shared") + array[0] = 0 + out[0] = array[0] + + return ib.get() + + Out = te.extern( + shape=(1,), + inputs=[], + fcompute=do_compute, + dtype="int32", + ) + s = te.create_schedule(Out.op) + + # Codegen should raise error when allocating more memory than the + # target supports. + with pytest.raises(tvm.TVMError): + tvm.build(s, [Out], target) + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index bb3aa9e86267..0a0ad49a7767 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -76,6 +76,19 @@ def test_target_string_parse(): assert tvm.target.arm_cpu().device_name == "arm_cpu" +def test_target_string_with_spaces(): + target = tvm.target.Target( + "vulkan -device_name='Name of GPU with spaces' -device_type=discrete" + ) + assert target.attrs["device_name"] == "Name of GPU with spaces" + assert target.attrs["device_type"] == "discrete" + + target = tvm.target.Target(str(target)) + + assert target.attrs["device_name"] == "Name of GPU with spaces" + assert target.attrs["device_type"] == "discrete" + + def test_target_create(): targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), vta(), bifrost()] for tgt in targets: diff --git a/tests/python/unittest/test_target_texture_codegen_opencl.py b/tests/python/unittest/test_target_texture_codegen_opencl.py new file mode 100644 index 000000000000..03944c85ade5 --- /dev/null +++ b/tests/python/unittest/test_target_texture_codegen_opencl.py @@ -0,0 +1,1400 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import sys + +import numpy as np +import pytest + +import tvm +from tvm import autotvm +from tvm import te +from tvm.topi import testing +from tvm.topi.utils import get_const_tuple, simplify +from tvm.topi import nn + + +def compute_plus_one_rank3(shape): + X = te.placeholder(shape, name="X", dtype="float32") + Y = te.compute(shape, lambda i, j, k: X[i, j, k] + 1, name="Compute_Y") + return X, Y + + +def schedule_plus_one_rank3(X, Y): + s = te.create_schedule(Y.op) + # Xt = s.cache_read(X, "texture", [Y]) + # Xt = s.cache_read(X, "global", [Y]) + Xt = s.cache_read(X, "global.texture", [Y]) + + # copy to texture stage + x, y, c = s[Xt].op.axis + s[Xt].bind(x, te.thread_axis("blockIdx.x")) + s[Xt].bind(y, te.thread_axis("threadIdx.x")) + s[Xt].vectorize(c) + + # the compute stage + x, y, c = s[Y].op.axis + xo, yo, xi, yi = s[Y].tile(x, y, 4, 4) + s[Y].bind(xo, te.thread_axis("blockIdx.x")) + s[Y].bind(yo, te.thread_axis("threadIdx.x")) + s[Y].vectorize(c) + return s + + +def compute_plus_one_rank5(shape): + X = te.placeholder(shape, name="X", dtype="float32") + Y = te.compute(shape, lambda i, j, k, l, m: X[i, j, k, l, m] + 1, name="Compute_Y") + return X, Y + + +def schedule_plus_one_rank5(X, Y): + s = te.create_schedule(Y.op) + Xt = s.cache_read(X, "global.texture", [Y]) + + # copy to texture stage + a, b, c, d, e = s[Xt].op.axis + abc = s[Xt].fuse(a, b, c) + s[Xt].bind(abc, te.thread_axis("blockIdx.x")) + s[Xt].bind(d, te.thread_axis("threadIdx.x")) + s[Xt].vectorize(e) + + # the compute stage + a, b, c, d, e = s[Y].op.axis + abc = s[Y].fuse(a, b, c) + xo, yo, xi, yi = s[Y].tile(abc, d, 4, 4) + s[Y].bind(xo, te.thread_axis("blockIdx.x")) + s[Y].bind(yo, te.thread_axis("threadIdx.x")) + s[Y].vectorize(e) + return s + + +def compute_matmul(shape): + A = te.placeholder(shape, name="A", dtype="float32") + B = te.placeholder(shape, name="B", dtype="float32") + k = te.reduce_axis((0, shape[1]), name="k") + C = te.compute( + (shape[0] * shape[2], shape[0] * shape[2]), + lambda i, j: te.sum( + A[i // shape[2], k, i % shape[2]].astype("float32") + * B[j // shape[2], k, j % shape[2]].astype("float32"), + axis=[k], + ), + name="Compute_MatMul", + ) + return A, B, C + + +def schedule_matmul(A, B, C, local=False): + s = te.create_schedule(C.op) + At = s.cache_read(A, "global.texture", [C]) + Bt = s.cache_read(B, "global.texture", [C]) + if local: + Al = s.cache_read(At, "local", [C]) + Bl = s.cache_read(Bt, "local", [C]) + Cl = s.cache_write(C, "local") + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + + def copy_to_texture(stage): + _io, _k, _ii = s[stage].op.axis + s[stage].vectorize(_ii) + s[stage].bind(_io, bx) + s[stage].bind(_k, tx) + + copy_to_texture(At) + copy_to_texture(Bt) + + # copy to global stage + _i, _j = s[C].op.axis + xo, yo, xi, yi = s[C].tile(_i, _j, 4, 4) + s[C].unroll(xi) + s[C].vectorize(yi) + s[C].bind(xo, te.thread_axis("blockIdx.x")) + s[C].bind(yo, te.thread_axis("threadIdx.x")) + + # the compute stage + s[Cl].compute_at(s[C], yo) + (_k,) = Cl.op.reduce_axis + _x, _y = s[Cl].op.axis + s[Cl].reorder(_k, _x, _y) + s[Cl].unroll(_x) + s[Cl].vectorize(_y) + + if local: + s[Al].compute_at(s[Cl], _k) + s[Al].vectorize(s[Al].op.axis[-1]) + s[Bl].compute_at(s[Cl], _k) + s[Bl].vectorize(s[Bl].op.axis[-1]) + + return s + + +def compute_matmul_inner(shape): + A = te.placeholder(shape, name="A", dtype="float32") + B = te.placeholder(shape, name="B", dtype="float32") + k = te.reduce_axis((0, shape[1] * shape[2]), name="k") + # (M, K) x (N, K) + # (32, 256) x (32, 256) + # (32, 64, 4) x (32, 64, 4) + C = te.compute( + (shape[0], shape[0]), + lambda i, j: te.sum( + A[i, k // shape[2], k % shape[2]].astype("float32") + * B[j, k // shape[2], k % shape[2]].astype("float32"), + axis=[k], + ), + name="Compute_MatMul", + ) + return A, B, C + + +def schedule_matmul_inner(A, B, C, local=False): + s = te.create_schedule(C.op) + At = s.cache_read(A, "global.texture", [C]) + Bt = s.cache_read(B, "global.texture", [C]) + if local: + Al = s.cache_read(At, "local", [C]) + Bl = s.cache_read(Bt, "local", [C]) + Cl = s.cache_write(C, "local") + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + + def copy_to_texture(stage): + _i, _ko, _ki = s[stage].op.axis + s[stage].vectorize(_ki) + s[stage].bind(_i, bx) + s[stage].bind(_ko, tx) + + copy_to_texture(At) + copy_to_texture(Bt) + + # copy to global stage + _i, _j = s[C].op.axis + xo, yo, xi, yi = s[C].tile(_i, _j, 4, 4) + s[C].unroll(xi) + s[C].vectorize(yi) + s[C].bind(xo, te.thread_axis("blockIdx.x")) + s[C].bind(yo, te.thread_axis("threadIdx.x")) + + # the compute stage + s[Cl].compute_at(s[C], yo) + (_k,) = Cl.op.reduce_axis + _x, _y = s[Cl].op.axis + s[Cl].reorder(_x, _y, _k) + s[Cl].unroll(_x) + # TODO(csullivan): consider whether the below error is worth resolving + # s[Cl].vectorize(_y) # error + + if local: + s[Al].compute_at(s[Cl], _x) + s[Al].vectorize(s[Al].op.axis[-1]) + s[Bl].compute_at(s[Cl], _x) + s[Bl].vectorize(s[Bl].op.axis[-1]) + + return s + + +def compute_matmul_vector_accumulator(shapeA, shapeB): + # A x B + # (K/4, M, K%4) x (K, N/4, N%4) = (M, N) + # (32, 64, 4) x (128, 16, 4) = (64, 64) + A = te.placeholder(shapeA, name="A", dtype="float32") + B = te.placeholder(shapeB, name="B", dtype="float32") + k = te.reduce_axis((0, shapeB[0]), name="k") + C = te.compute( + (shapeA[1], shapeB[1] * shapeB[2]), + lambda i, j: te.sum( + A[k // shapeA[-1], i, k % shapeA[-1]].astype("float32") + * B[k, j // shapeB[-1], j % shapeB[-1]].astype("float32"), + axis=[k], + ), + name="Compute_MatMul", + ) + return A, B, C + + +def schedule_matmul_vector_accumulator(A, B, C, local=False): + s = te.create_schedule(C.op) + At = s.cache_read(A, "global.texture", [C]) + Bt = s.cache_read(B, "global.texture", [C]) + if local: + Al = s.cache_read(At, "local", [C]) + Bl = s.cache_read(Bt, "local", [C]) + Cl = s.cache_write(C, "local") + + def copy_to_texture(stage): + _y, _x, _v = s[stage].op.axis + # TODO(csullivan): removing this vectorize results in numerical errors, autovectorize + s[stage].vectorize(_v) + s[stage].bind(_y, te.thread_axis("blockIdx.x")) + s[stage].bind(_x, te.thread_axis("threadIdx.x")) + + copy_to_texture(At) + copy_to_texture(Bt) + + # copy to global stage + _i, _j = s[C].op.axis + xo, yo, xi, yi = s[C].tile(_i, _j, 4, 4) + s[C].unroll(xi) + s[C].vectorize(yi) + s[C].bind(xo, te.thread_axis("blockIdx.x")) + s[C].bind(yo, te.thread_axis("threadIdx.x")) + + # the compute stage + s[Cl].compute_at(s[C], yo) + (_k,) = Cl.op.reduce_axis + _a, _b = s[Cl].op.axis + _ko, _ki = s[Cl].split(_k, factor=4) + s[Cl].reorder(_ko, _a, _ki, _b) + s[Cl].unroll(_ki) + s[Cl].unroll(_a) + s[Cl].vectorize(_b) + + if local: + s[Al].compute_at(s[Cl], _a) + _aa, _ka, _ba = s[Al].op.axis + # TODO(csullivan)[BEFORE PR]: removing this vectorize command causes a crash. This needs to be autovectorized. + s[Al].vectorize(_ba) + s[Bl].compute_at(s[Cl], _ko) + _ab, _kb, _bb = s[Bl].op.axis + s[Bl].vectorize(_bb) + s[Bl].unroll(_ab) + + return s + + +def compute_conv2d_1x1_NCHWc_RSCKk(input_shape, filter_shape): + # conv2d( [N, C, H, W, c] , [1, 1, C, K, k] + data = te.placeholder(input_shape, name="data", dtype="float32") + filt = te.placeholder(filter_shape, name="filter", dtype="float32") + c = te.reduce_axis((0, input_shape[1]), name="C") + c4 = te.reduce_axis((0, input_shape[-1]), name="c4") + kh = te.reduce_axis((0, filter_shape[0]), name="kh") + kw = te.reduce_axis((0, filter_shape[1]), name="kw") + conv = te.compute( + (input_shape[0], filter_shape[-2], input_shape[2], input_shape[3], filter_shape[-1]), + lambda n, ko, i, j, ki: te.sum( + data[n, c, i, j, c4].astype("float32") + * filt[kh, kw, c * input_shape[-1] + c4, ko, ki].astype("float32"), + axis=[kh, kw, c, c4], + ), + # name="Compute_conv2d_1x1_NCHWc_RSCKk", + name="conv2d_1x1", + ) + return data, filt, conv + + +def schedule_conv2d_1x1_NCHWc_RSCKk(data, filt, conv): + # inputs: (1, 128//4, 56, 56, 4), (1, 1, 128, 128//4, 4) + # outputs: + s = te.create_schedule(conv.op) + A, B, C = data, filt, conv + At = s.cache_read(A, "global.texture", [C]) + Bt = s.cache_read(B, "global.texture", [C]) + Al = s.cache_read(At, "local", [C]) + Bl = s.cache_read(Bt, "local", [C]) + Cl = s.cache_write(C, "local") + + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + + copy_to_texture(At) + copy_to_texture(Bt) + + _n, _ko, _h, _w, _ki = s[C].op.axis + s[C].vectorize(_ki) + s[C].bind(_n, te.thread_axis("blockIdx.x")) + s[C].bind(_ko, te.thread_axis("threadIdx.x")) + + s[Cl].compute_at(s[C], _w) + _nl, _kol, _hl, _wl, _kil = s[Cl].op.axis + _khl, _kwl, _cl, _cl4 = s[Cl].op.reduce_axis + _clo, _cli = s[Cl].split(_cl, factor=4) + s[Cl].reorder(_clo, _cli, _cl4, _kil) + s[Cl].unroll(_cli) + s[Cl].unroll(_cl4) + s[Cl].vectorize(_kil) + + s[Al].compute_at(s[Cl], _cli) + s[Al].vectorize(s[Al].op.axis[-1]) + s[Bl].compute_at(s[Cl], _kwl) + s[Bl].vectorize(s[Bl].op.axis[-1]) + + return s + + +def compute_conv2d_1x1_WCHNc_CRSKk(input_shape, filter_shape): + # input_shape = [W, C, H, N, c] -> [W, C, H*N, c] + # filter_shape = [C, R, S, K, k] -> [C, R*S*K, k] + # output_shape: [WK, HN, k] -> [W, K, H, N, k] + data = te.placeholder(input_shape, name="data", dtype="float32") + filt = te.placeholder(filter_shape, name="filter", dtype="float32") + + packed_data = te.compute( + (input_shape[0], input_shape[1], input_shape[2] * input_shape[3], input_shape[4]), + lambda i, j, k, l: data[i, j, k // input_shape[3], k % input_shape[3], l], + name="packed_data", + ) + + # Logical transformation of Nd -> 3d tensor + # CRSKk -> C|RSK|k + # r = rsk // SK + # sk = rsk % SK + # s = sk // K == (rsk % SK) // K == (rsk // K) % S + # k = sk % K == (rsk % SK) % K == rsk % K + packed_filter = te.compute( + (filter_shape[0], filter_shape[1] * filter_shape[2] * filter_shape[3], filter_shape[4]), + lambda i, j, k: filt[ + i, + j // (filter_shape[3] * filter_shape[2]), + (j // filter_shape[3]) % filter_shape[2], + j % filter_shape[3], + k, + ], + name="packed_filter", + ) + + c = te.reduce_axis((0, input_shape[1]), name="C") + c4 = te.reduce_axis((0, input_shape[-1]), name="c4") + r = te.reduce_axis((0, filter_shape[1]), name="r") + s = te.reduce_axis((0, filter_shape[2]), name="s") + + conv = te.compute( + (input_shape[0], filter_shape[3], input_shape[2], input_shape[3], filter_shape[4]), + lambda w, ko, h, n, ki: te.sum( + packed_data[w, c, h * input_shape[3] + n, c4].astype("float32") + * packed_filter[ + c * input_shape[-1] + c4, ((r * filter_shape[2]) + s) * filter_shape[3] + ko, ki + ].astype("float32"), + axis=[r, s, c, c4], + ), + name="conv2d_1x1", + ) + return data, filt, packed_data, packed_filter, conv + + +def schedule_conv2d_1x1_WCHNc_CRSKk(data, filt, packed_data, packed_filter, conv): + # data: [W, C, H*N, c] + # filter: [C, R*S*K, k] + # output: [W, K, H, N, k] + + # conv2d( [N, C, H, W, c] , [1, 1, C, K, k] + # inputs: (1, 128//4, 56, 56, 4), (1, 1, 128, 128//4, 4) + + # data: (56, 128//4, 56*1, 4) = (56, 32, 56, 4) + # filt: (128, 1*1*128//4, 4) = (128, 32, 4) + # conv: (56, 32, 56, 1, 4) + + s = te.create_schedule(conv.op) + cfg = autotvm.get_config() + + s[packed_data].compute_inline() + s[packed_filter].compute_inline() + A, B, C = packed_data, packed_filter, conv + At = s.cache_read(A, "global.texture", [C]) + Bt = s.cache_read(B, "global.texture", [C]) + Al = s.cache_read(At, "local", [C]) + Bl = s.cache_read(Bt, "local", [C]) + Cl = s.cache_write(C, "local") + + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + + copy_to_texture(At) + copy_to_texture(Bt) + + _w, _ko, _h, _n, _ki = s[C].op.axis + kernel_scope, _n = s[C].split(_n, nparts=1) + + cfg.define_split("tile_f", _ko, num_outputs=4) + cfg.define_split("tile_w", _w, num_outputs=4) + cfg.define_split("tile_h", _h, num_outputs=4) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + bk, vk, tk, ki = cfg["tile_f"].apply(s, C, _ko) + bw, vw, tw, wi = cfg["tile_w"].apply(s, C, _w) + bh, vh, th, hi = cfg["tile_h"].apply(s, C, _h) + s[C].reorder(bh, _n, vh, th, hi) + bhn = s[C].fuse(bh, _n) + + s[C].bind(bk, te.thread_axis("blockIdx.z")) + s[C].bind(bhn, te.thread_axis("blockIdx.y")) + s[C].bind(bw, te.thread_axis("blockIdx.x")) + s[C].bind(vk, te.thread_axis("vthread")) + s[C].bind(vh, te.thread_axis("vthread")) + s[C].bind(vw, te.thread_axis("vthread")) + s[C].bind(tk, te.thread_axis("threadIdx.z")) + s[C].bind(th, te.thread_axis("threadIdx.y")) + s[C].bind(tw, te.thread_axis("threadIdx.x")) + s[C].reorder(bw, bk, bhn, vw, vk, vh, tw, tk, th, ki, hi, wi, _ki) + s[C].vectorize(_ki) + + # TODO(csullivan): Try uneven workgroup split + # _wo, _wi = s[C].split(_w, factor=4) + # #_hno, _hni = s[C].split(_hn, factor=8) + # #s[C].reorder(_wo, _wi, _ko, _hno, _hni, _ki) + # s[C].reorder(_wo, _ko, _hn, _ki, _wi) + # s[C].unroll(_wi) + + # # mace: + # # const int out_ch_blk = get_global_id(0); + # # const int out_w_blk = get_global_id(1); + # # const int out_hb = get_global_id(2); + + # bx = te.thread_axis("blockIdx.x") + # by = te.thread_axis("blockIdx.y") + # bz = te.thread_axis("blockIdx.z") + # s[C].bind(_ko, bx) + # s[C].bind(_wo, by) + # s[C].bind(_hn, bz) + + # s[Cl].compute_at(s[C], _hn) + s[Cl].compute_at(s[C], th) + + _wl, _kol, _hl, _nl, _kil = s[Cl].op.axis + _khl, _kwl, _cl, _cl4 = s[Cl].op.reduce_axis + + cfg.define_split("tile_c", _cl, num_outputs=2) + cfg.define_split("tile_kh", _khl, num_outputs=2) + cfg.define_split("tile_kw", _kwl, num_outputs=2) + + _clo, _cli = cfg["tile_c"].apply(s, Cl, _cl) + _khlo, _khli = cfg["tile_kh"].apply(s, Cl, _khl) + _kwlo, _kwli = cfg["tile_kw"].apply(s, Cl, _kwl) + # s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x) + s[Cl].reorder(_clo, _khlo, _kwlo, _cli, _cl4, _khli, _kwli, _kol, _hl, _nl, _kil, _wl) + # s[Cl].reorder(_clo, _khlo, _kwlo, _cli, _cl4, _khli, _kwli) + # s[Cl].reorder(_cl, _cl4, _kil, _wl) + s[Cl].unroll(_cl4) + s[Cl].unroll(_wl) + s[Cl].vectorize(_kil) + + _wla, _cla, _hnla, _cl4a = s[Al].op.axis + s[Al].compute_at(s[Cl], _cli) + s[Al].vectorize(_cl4a) + s[Al].unroll(_wla) + + _clb, _rskolb, _kilb = s[Bl].op.axis + s[Bl].compute_at(s[Cl], _cli) + s[Bl].vectorize(_kilb) + s[Bl].unroll(_clb) + + s[C].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + + WO, K, HO, N, K4 = get_const_tuple(C.shape) + RSC, _, _ = get_const_tuple(B.shape) + cfg.add_flop(2 * N * K * K4 * HO * WO * RSC) + + return s + + +def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None): + """Convolution operator in NCHWc layout. """ + + if out_dtype is None: + out_dtype = Input.dtype + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_channel_chunk, in_height, in_width, in_channel_block = Input.shape + num_filter_chunk, channel, kernel_h, kernel_w, num_filter_block = Filter.shape + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + # compute graph + pad_before = [0, 0, pad_top, pad_left, 0] + pad_after = [0, 0, pad_down, pad_right, 0] + temp = nn.pad(Input, pad_before, pad_after, name="pad_temp") + + rcc = te.reduce_axis((0, in_channel_chunk), name="rc") + rcb = te.reduce_axis((0, in_channel_block), name="rc") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + + # NCHWc x KCRSk + # texture: NCH|W|c + # texture: K|CRS|k + # c = crs//RS + # rs = crs % RS + # r = rs // W == (crs // S) % R + # s = rs % W == crs % S + Filter = te.compute( + (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block), + lambda ffc, crs, ffb: Filter[ + ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb + ], + name="packed_filter", + ) + return te.compute( + (batch, num_filter_chunk, out_height, out_width, num_filter_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + temp[ + nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb + ].astype(out_dtype) + * Filter[ + ffc, ((rcc * in_channel_block + rcb) * kernel_h + ry) * kernel_w + rx, ffb + ].astype(out_dtype), + axis=[rcc, rcb, ry, rx], + ), + tag="conv2d_nchwc_kcrsk_texture", + ) + + +def schedule_conv2d_NCHWc_KCRSk(cfg, s, conv): + """schedule optimized for batch size = 1""" + + ##### space definition begin ##### + n, fc, y, x, fb = s[conv].op.axis + rcc, rcb, ry, rx = s[conv].op.reduce_axis + cfg.define_split("tile_fc", fc, num_outputs=4) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_rcc", rcc, num_outputs=2) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + pad_data, flattened_kernel = s[conv].op.input_tensors + kernel = s[flattened_kernel].op.input_tensors[0] + s[flattened_kernel].compute_inline() + + s[pad_data].compute_inline() + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + kernel = flattened_kernel + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, "local") + else: + output = s.outputs[0].output(0) + s[conv].set_scope("local") + OL = conv + + # create cache stage + AT = s.cache_read(pad_data, "global.texture", [OL]) + WT = s.cache_read(kernel, "global.texture", [OL]) + + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + + copy_to_texture(AT) + copy_to_texture(WT) + + AA = s.cache_read(AT, "shared", [OL]) + WW = s.cache_read(WT, "shared", [OL]) + + # tile and bind spatial axes + n, fc, y, x, fb = s[output].op.axis + + kernel_scope, n = s[output].split(n, nparts=1) + + bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + bf = s[output].fuse(n, bf) + s[output].bind(bf, te.thread_axis("blockIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(vf, te.thread_axis("vthread")) + s[output].bind(vy, te.thread_axis("vthread")) + s[output].bind(vx, te.thread_axis("vthread")) + s[output].bind(tf, te.thread_axis("threadIdx.z")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) + s[output].vectorize(fb) + s[OL].compute_at(s[output], tx) + + # tile reduction axes + n, fc, y, x, fb = s[OL].op.axis + + rcc, rcb, ry, rx = s[OL].op.reduce_axis + rco, rci = cfg["tile_rcc"].apply(s, OL, rcc) + ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) + rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) + + # TODO(csullivan): check position of rcb + s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb) + s[OL].vectorize(fb) + s[OL].unroll(rcb) + + s[AA].compute_at(s[OL], rxo) + s[WW].compute_at(s[OL], rxo) + # cooperative fetching + for load in [AA, WW]: + if load == WW: + n, fyx, v = s[load].op.axis + fused = s[load].fuse(n, fyx) + else: + n, f, y, x, v = s[load].op.axis + fused = s[load].fuse(n, f, y, x) + tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2]) + ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) + tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + s[load].vectorize(v) + + # unroll + s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + + N, OCC, OH, OW, OCB = get_const_tuple(output.shape) + _, ICKHKW, _ = get_const_tuple(kernel.shape) + + if isinstance(N, int): + cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) + + +def compute_conv2d_NCHWc_KCRSk_acc32(Input, Filter, stride, padding, dilation, out_dtype=None): + """Convolution operator in NCHWc layout. """ + + if out_dtype is None: + out_dtype = Input.dtype + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_channel_chunk, in_height, in_width, in_channel_block = Input.shape + num_filter_chunk, channel, kernel_h, kernel_w, num_filter_block = Filter.shape + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + # compute graph + pad_before = [0, 0, pad_top, pad_left, 0] + pad_after = [0, 0, pad_down, pad_right, 0] + temp = nn.pad(Input, pad_before, pad_after, name="pad_temp") + + rcc = te.reduce_axis((0, in_channel_chunk), name="rc") + rcb = te.reduce_axis((0, in_channel_block), name="rc") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + + # NCHWc x KCRSk + # texture: NCH|W|c + # texture: K|CRS|k + # c = crs//RS + # rs = crs % RS + # r = rs // W == (crs // S) % R + # s = rs % W == crs % S + Filter = te.compute( + (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block), + lambda ffc, crs, ffb: Filter[ + ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb + ], + name="packed_filter", + ) + conv = te.compute( + (batch, num_filter_chunk, out_height, out_width, num_filter_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + ( + temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] + * Filter[ffc, ((rcc * in_channel_block + rcb) * kernel_h + ry) * kernel_w + rx, ffb] + ).astype(out_dtype), + axis=[rcc, rcb, ry, rx], + ), + tag="conv2d_nchwc_kcrsk_texture", + ) + output = te.compute(conv.shape, lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype("float32")) + return output + + +def schedule_conv2d_NCHWc_KCRSk_acc32(cfg, s, output): + """schedule optimized for batch size = 1""" + + conv = output.op.input_tensors[0] + + ##### space definition begin ##### + n, fc, y, x, fb = s[conv].op.axis + rcc, rcb, ry, rx = s[conv].op.reduce_axis + cfg.define_split("tile_fc", fc, num_outputs=4) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_rcc", rcc, num_outputs=2) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + pad_data, flattened_kernel = s[conv].op.input_tensors + kernel = s[flattened_kernel].op.input_tensors[0] + s[flattened_kernel].compute_inline() + + s[pad_data].compute_inline() + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + kernel = flattened_kernel + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, "local") + else: + output = s.outputs[0].output(0) + s[conv].set_scope("local") + OL = conv + + # create cache stage + AT = s.cache_read(pad_data, "global.texture", [OL]) + WT = s.cache_read(kernel, "global.texture", [OL]) + + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + + copy_to_texture(AT) + copy_to_texture(WT) + + AA = s.cache_read(AT, "shared", [OL]) + WW = s.cache_read(WT, "shared", [OL]) + + # tile and bind spatial axes + n, fc, y, x, fb = s[output].op.axis + + kernel_scope, n = s[output].split(n, nparts=1) + + bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + bf = s[output].fuse(n, bf) + s[output].bind(bf, te.thread_axis("blockIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(vf, te.thread_axis("vthread")) + s[output].bind(vy, te.thread_axis("vthread")) + s[output].bind(vx, te.thread_axis("vthread")) + s[output].bind(tf, te.thread_axis("threadIdx.z")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) + s[output].vectorize(fb) + + s[OL].compute_at(s[output], tx) + + # tile reduction axes + n, fc, y, x, fb = s[OL].op.axis + + rcc, rcb, ry, rx = s[OL].op.reduce_axis + rco, rci = cfg["tile_rcc"].apply(s, OL, rcc) + ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) + rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) + + # TODO(csullivan): check position of rcb + s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb) + s[OL].vectorize(fb) + s[OL].unroll(rcb) + + s[AA].compute_at(s[OL], rxo) + s[WW].compute_at(s[OL], rxo) + # cooperative fetching + for load in [AA, WW]: + if load == WW: + n, fyx, v = s[load].op.axis + fused = s[load].fuse(n, fyx) + else: + n, f, y, x, v = s[load].op.axis + fused = s[load].fuse(n, f, y, x) + tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2]) + ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) + tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + s[load].vectorize(v) + + # unroll + s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + + N, OCC, OH, OW, OCB = get_const_tuple(output.shape) + _, ICKHKW, _ = get_const_tuple(kernel.shape) + + if isinstance(N, int): + cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) + + +def compute_depthwise_conv2d_NCHWc_KCRSk_acc32( + Input, Filter, stride, padding, dilation, out_dtype=None +): + """Depthwise convolution operator in NCHWc layout. """ + if out_dtype is None: + out_dtype = Input.dtype + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, channel_chunk, in_height, in_width, channel_block = Input.shape + _, channel_multiplier, kernel_h, kernel_w, _ = Filter.shape + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + out_channel_chunk = simplify(channel_chunk * channel_multiplier) + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + # compute graph + pad_before = [0, 0, pad_top, pad_left, 0] + pad_after = [0, 0, pad_down, pad_right, 0] + temp = nn.pad(Input, pad_before, pad_after, name="pad_temp") + + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + + # NCHWc x CMRSc = [N,(C//4)M,OH,OW, 4c] + # NCHWc x CMRS + # texture: NCH|W|c + # texture: C|MRS|c + # output: N + # m = mrs//RS + # rs = mrs % RS + # r = rs // W == (mrs // S) % R + # s = rs % W == mrs % S + Filter = te.compute( + (channel_chunk, channel_multiplier * kernel_h * kernel_w, channel_block), + lambda ffc, mrs, ffb: Filter[ + ffc, mrs // (kernel_h * kernel_w), (mrs // kernel_w) % kernel_h, mrs % kernel_w, ffb + ], + name="packed_filter", + ) + + conv = te.compute( + (batch, out_channel_chunk, out_height, out_width, channel_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + ( + temp[ + nn, + ffc // channel_multiplier, + yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, + ffb, + ] + * Filter[ + ffc // channel_multiplier, + ((ffc % channel_multiplier) * kernel_h + ry) * kernel_w + rx, + ffb, + ] + ).astype(out_dtype), + axis=[ry, rx], + ), + tag="depthwise_conv2d_nchwc_kcrsk_texture", + ) + return te.compute( + conv.shape, lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype("float32") + ) + + +def schedule_depthwise_conv2d_NCHWc_KCRSk_acc32(cfg, s, output): + """schedule optimized for batch size = 1""" + + conv = output.op.input_tensors[0] + + ##### space definition begin ##### + n, fc, y, x, fb = s[conv].op.axis + ry, rx = s[conv].op.reduce_axis + cfg.define_split("tile_fc", fc, num_outputs=4) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + pad_data, flattened_kernel = s[conv].op.input_tensors + kernel = s[flattened_kernel].op.input_tensors[0] + s[flattened_kernel].compute_inline() + + s[pad_data].compute_inline() + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + kernel = flattened_kernel + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, "local") + else: + output = s.outputs[0].output(0) + s[conv].set_scope("local") + OL = conv + + # create cache stage + AT = s.cache_read(pad_data, "global.texture", [OL]) + WT = s.cache_read(kernel, "global.texture", [OL]) + + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + + copy_to_texture(AT) + copy_to_texture(WT) + + AA = s.cache_read(AT, "shared", [OL]) + WW = s.cache_read(WT, "shared", [OL]) + + # tile and bind spatial axes + n, fc, y, x, fb = s[output].op.axis + + kernel_scope, n = s[output].split(n, nparts=1) + + bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + bf = s[output].fuse(n, bf) + s[output].bind(bf, te.thread_axis("blockIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(vf, te.thread_axis("vthread")) + s[output].bind(vy, te.thread_axis("vthread")) + s[output].bind(vx, te.thread_axis("vthread")) + s[output].bind(tf, te.thread_axis("threadIdx.z")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) + s[output].vectorize(fb) + + s[OL].compute_at(s[output], tx) + + # tile reduction axes + n, fc, y, x, fb = s[OL].op.axis + + ry, rx = s[OL].op.reduce_axis + ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) + rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) + + s[OL].reorder(ryo, rxo, ryi, rxi, n, fc, y, x, fb) + s[OL].vectorize(fb) + # s[OL].unroll() + + s[AA].compute_at(s[OL], rxo) + s[WW].compute_at(s[OL], rxo) + # cooperative fetching + for load in [AA, WW]: + if load == WW: + n, fyx, v = s[load].op.axis + fused = s[load].fuse(n, fyx) + else: + n, f, y, x, v = s[load].op.axis + fused = s[load].fuse(n, f, y, x) + tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2]) + ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) + tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + s[load].vectorize(v) + + # unroll + s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + + N, OCC, OH, OW, OCB = get_const_tuple(output.shape) + ICC, MKHKW, ICB = get_const_tuple(kernel.shape) + M = (OCC * OCB) // (ICC * ICB) + KHKW = MKHKW // M + + if isinstance(N, int): + cfg.add_flop(2 * N * OH * OW * OCC * OCB * KHKW) + + +def scheduler(compute, schedule, *args, **kwargs): + placeholders = compute(*args) + s = schedule(*placeholders, **kwargs) + return s, placeholders + + +def conv2d_1x1_NCHWc_RSCKk(input_shape, filter_shape): + placeholders = compute_conv2d_1x1_NCHWc_RSCKk(input_shape, filter_shape) + s = schedule_conv2d_1x1_NCHWc_RSCKk(*placeholders) + return s, placeholders + + +def conv2d_1x1_WCHNc_CRSKk(input_shape, filter_shape): + placeholders = compute_conv2d_1x1_WCHNc_CRSKk(input_shape, filter_shape) + s = schedule_conv2d_1x1_WCHNc_CRSKk(*placeholders) + return s, (placeholders[0], placeholders[1], placeholders[-1]) + + +def conv2d_NCHWc_KCRSk(input_shape, filter_shape): + data = te.placeholder(input_shape, name="data", dtype="float32") + filt = te.placeholder(filter_shape, name="filter", dtype="float32") + conv = compute_conv2d_NCHWc_KCRSk(data, filt, [1, 1], [0, 0], [1, 1], "float32") + cfg = autotvm.get_config() + s = te.create_schedule([x.op for x in [conv]]) + schedule_conv2d_NCHWc_KCRSk(cfg, s, conv) + return s, (data, filt, conv) + + +def conv2d_NCHWc_KCRSk_fp32_acc(input_shape, filter_shape): + data = te.placeholder(input_shape, name="data", dtype="float32") + filt = te.placeholder(filter_shape, name="filter", dtype="float32") + output = compute_conv2d_NCHWc_KCRSk_acc32(data, filt, [1, 1], [0, 0], [1, 1], "float32") + cfg = autotvm.get_config() + s = te.create_schedule([x.op for x in [output]]) + schedule_conv2d_NCHWc_KCRSk_acc32(cfg, s, output) + return s, (data, filt, output) + + +def depthwise_conv2d_NCHWc_KCRSk_acc32(input_shape, filter_shape): + data = te.placeholder(input_shape, name="data", dtype="float32") + filt = te.placeholder(filter_shape, name="filter", dtype="float32") + output = compute_depthwise_conv2d_NCHWc_KCRSk_acc32( + data, filt, [1, 1], [0, 0], [1, 1], "float32" + ) + cfg = autotvm.get_config() + s = te.create_schedule([x.op for x in [output]]) + schedule_depthwise_conv2d_NCHWc_KCRSk_acc32(cfg, s, output) + return s, (data, filt, output) + + +def ref_convolution(data, kernel, stride, pad): + import mxnet as mx + + groups = 1 + kernel_size = (kernel.shape[2], kernel.shape[3]) + num_filter = kernel.shape[0] + ref_res = mx.nd.Convolution( + data=mx.nd.array(data), + weight=mx.nd.array(kernel), + bias=None, + no_bias=True, + kernel=kernel_size, + stride=stride, + pad=pad, + num_filter=num_filter, + num_group=groups, + ) + return ref_res.asnumpy() + + +def ref_depthwise_convolution(data, kernel, stride, pad): + import mxnet as mx + + groups = kernel.shape[0] + kernel_size = (kernel.shape[2], kernel.shape[3]) + num_filter = kernel.shape[0] + multiplier = kernel.shape[1] + ref_res = mx.nd.Convolution( + data=mx.nd.array(data), + weight=mx.nd.array(kernel), + bias=None, + no_bias=True, + kernel=kernel_size, + stride=stride, + pad=pad, + num_filter=num_filter, + num_group=groups, + ) + return ref_res.asnumpy() + + +def validate(workload, target, dev, input_shapes, *args, **kwargs): + s, placeholders = workload(*input_shapes, *args, **kwargs) + func = tvm.driver.build(s, [*placeholders], target=target, name="TestFunction") + + args_tvm = [] + args_np = [] + for var in placeholders[:-1]: + var_np = np.random.uniform(size=[i.value for i in var.shape]).astype(var.dtype) + args_np.append(var_np) + args_tvm.append(tvm.nd.array(var_np, dev)) + args_tvm.append( + tvm.nd.array( + np.zeros([i.value for i in placeholders[-1].shape], dtype=placeholders[-1].dtype), dev + ) + ) + func(*args_tvm) + + if "plus_one" in workload.__name__: + np_result = args_np[0] + 1.0 + elif "matmul" in workload.__name__: + if "inner" in workload.__name__: + np_result = np.matmul( + args_np[0].reshape(32, 256), args_np[1].reshape(32, 256).transpose(1, 0) + ) + elif "accum" in workload.__name__: + np_result = np.matmul( + args_np[0].transpose((1, 0, 2)).reshape(64, 128), args_np[1].reshape(128, 64) + ) + else: + np_result = np.matmul( + args_np[0].transpose((0, 2, 1)).reshape(128, 64), + args_np[1].transpose(1, 0, 2).reshape(64, 128), + ) + elif "conv2d_1x1_NCHWc_RSCKk" in workload.__name__: + vec_length = args_np[1].shape[-1] + # nchwc -> nchw + args_np[0] = ( + args_np[0] + .transpose((0, 1, 4, 2, 3)) + .reshape( + args_np[0].shape[0], + args_np[0].shape[1] * args_np[0].shape[-1], + args_np[0].shape[2], + args_np[0].shape[3], + ) + ) + # rsckk -> rsck -> kcrs + args_np[1] = ( + args_np[1] + .reshape( + args_np[1].shape[0], + args_np[1].shape[1], + args_np[1].shape[2], + args_np[1].shape[3] * args_np[1].shape[4], + ) + .transpose((3, 2, 0, 1)) + ) + np_result = testing.conv2d_nchw_python(args_np[0], args_np[1], 1, 0) + # nkhw -> nkhwk + np_result = np_result.reshape( + np_result.shape[0], + np_result.shape[1] // vec_length, + vec_length, + np_result.shape[2], + np_result.shape[3], + ).transpose(0, 1, 3, 4, 2) + elif "conv2d_1x1_WCHNc_CRSKk" in workload.__name__: + vec_length = args_np[1].shape[-1] + # wchnc -> nchw + args_np[0] = ( + args_np[0] + .transpose((3, 1, 4, 2, 0)) + .reshape( + args_np[0].shape[3], + args_np[0].shape[1] * args_np[0].shape[-1], + args_np[0].shape[2], + args_np[0].shape[0], + ) + ) + # crskk -> crsk -> kcrs + args_np[1] = ( + args_np[1] + .reshape( + args_np[1].shape[0], + args_np[1].shape[1], + args_np[1].shape[2], + args_np[1].shape[3] * args_np[1].shape[4], + ) + .transpose((3, 0, 1, 2)) + ) + np_result = testing.conv2d_nchw_python(args_np[0], args_np[1], 1, 0) + # nkhw -> nkkhw -> wkhnk + np_result = np_result.reshape( + np_result.shape[0], + np_result.shape[1] // vec_length, + vec_length, + np_result.shape[2], + np_result.shape[3], + ).transpose(4, 1, 3, 0, 2) + elif "NCHW_KCRS" in workload.__name__: + np_result = testing.conv2d_nchw_python(args_np[0], args_np[1], 1, 0) + elif "NCHWc_KCRSk" in workload.__name__: + vec_length = args_np[1].shape[-1] + # nchwc -> nchw + args_np[0] = ( + args_np[0] + .transpose((0, 1, 4, 2, 3)) + .reshape( + args_np[0].shape[0], + args_np[0].shape[1] * args_np[0].shape[-1], + args_np[0].shape[2], + args_np[0].shape[3], + ) + ) + # kcrsk/cmrsc -> kcrs/cmrs + args_np[1] = ( + args_np[1] + .transpose((0, 4, 1, 2, 3)) + .reshape( + args_np[1].shape[0] * args_np[1].shape[4], + args_np[1].shape[1], + args_np[1].shape[2], + args_np[1].shape[3], + ) + ) + if "depthwise" in workload.__name__: + # np_result = testing.depthwise_conv2d_python_nchw(args_np[0], args_np[1], 1, "VALID") + np_result = ref_depthwise_convolution(args_np[0], args_np[1], [], []) + else: + # np_result = testing.conv2d_nchw_python(args_np[0], args_np[1], 1, 0) + np_result = ref_convolution(args_np[0], args_np[1], [], []) + # nkhw -> nkhwk + np_result = np_result.reshape( + np_result.shape[0], + np_result.shape[1] // vec_length, + vec_length, + np_result.shape[2], + np_result.shape[3], + ).transpose(0, 1, 3, 4, 2) + np.testing.assert_allclose(args_tvm[-1].asnumpy(), np_result, rtol=1e-2, atol=1e-2) + + +class BaseSingleShapeValidator: + @tvm.testing.parametrize_targets("opencl") + def test_unary(self, test_func, input_shape, target, dev): + validate(test_func, target, dev, [input_shape]) + + +class TestPlusOneRank3(BaseSingleShapeValidator): + input_shape = tvm.testing.parameter((32, 32, 4)) + + def plus_one(input_shape): + return scheduler(compute_plus_one_rank3, schedule_plus_one_rank3, input_shape) + + test_func = tvm.testing.parameter(plus_one) + + +class TestPlusOneRank5(BaseSingleShapeValidator): + input_shape = tvm.testing.parameter((32, 2, 4, 4, 4)) + + def plus_one(input_shape): + return scheduler(compute_plus_one_rank5, schedule_plus_one_rank5, input_shape) + + test_func = tvm.testing.parameter(plus_one) + + +class TestMatmul: + input_shape = tvm.testing.parameter((32, 64, 4)) + local = tvm.testing.parameter(False, True) + + def matmul(input_shape, local): + return scheduler(compute_matmul, schedule_matmul, input_shape, local=local) + + def matmul_inner(input_shape, local): + return scheduler(compute_matmul_inner, schedule_matmul_inner, input_shape, local=local) + + test_func = tvm.testing.parameter(matmul, matmul_inner) + + @tvm.testing.parametrize_targets("opencl") + def test_matmul(self, test_func, input_shape, local, target, dev): + validate(test_func, target, dev, [input_shape], local=local) + + +class TestMatmulVectorAccumulator: + shapeA = tvm.testing.parameter((32, 64, 4)) + shapeB = tvm.testing.parameter((128, 16, 4)) + local = tvm.testing.parameter(False, True) + + def matmul_vector_accumulator(shapeA, shapeB, local): + return scheduler( + compute_matmul_vector_accumulator, + schedule_matmul_vector_accumulator, + shapeA, + shapeB, + local=local, + ) + + test_func = tvm.testing.parameter(matmul_vector_accumulator) + + @tvm.testing.parametrize_targets("opencl") + def test_matmul_vec_acc(self, test_func, shapeA, shapeB, local, target, dev): + validate(test_func, target, dev, [shapeA, shapeB], local=local) + + +class BaseConv2DValidator: + @tvm.testing.parametrize_targets("opencl") + def test_conv2d(self, test_func, input_shapes, target, dev): + validate(test_func, target, dev, input_shapes) + + +class TestConv2dNCHWcRSCKk(BaseConv2DValidator): + input_shapes = tvm.testing.parameter([(1, 32, 56, 56, 4), (1, 1, 128, 32, 4)]) + test_func = tvm.testing.parameter(conv2d_1x1_NCHWc_RSCKk) + + +class TestConv2dWCHNcCRSKk(BaseConv2DValidator): + input_shapes = tvm.testing.parameter([(56, 32, 56, 1, 4), (128, 1, 1, 32, 4)]) + test_func = tvm.testing.parameter(conv2d_1x1_WCHNc_CRSKk) + + +class TestConv2dNCHWcKCRSk(BaseConv2DValidator): + input_shapes = tvm.testing.parameter( + [(1, 32, 56, 56, 4), (32, 128, 1, 1, 4)], [(1, 32, 112, 112, 4), (32, 128, 3, 3, 4)] + ) + test_func = tvm.testing.parameter(conv2d_NCHWc_KCRSk, conv2d_NCHWc_KCRSk_fp32_acc) + + +class TestDepthwiseConv2dNCHWcKCRSk(BaseConv2DValidator): + input_shapes = tvm.testing.parameter([(1, 24, 257, 257, 4), (24, 1, 3, 3, 4)]) + test_func = tvm.testing.parameter(depthwise_conv2d_NCHWc_KCRSk_acc32) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 52dc4ccd9fef..2fdafe08e60f 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -27,7 +27,7 @@ def test_unique_name(): B = te.compute((16, 16), lambda x, y: A[x, y] * 2, name="main") C = te.compute((16, 16), lambda x, y: B[x, y] + 1, name="main") func = te.create_prim_func([A, C]) - s = tir.Schedule(func, debug_mode=True) + s = tir.Schedule(func, debug_mask="all") assert isinstance(s.get_sref(s.get_block("main")), tir.schedule.StmtSRef) assert isinstance(s.get_sref(s.get_block("main_1")), tir.schedule.StmtSRef) @@ -36,7 +36,7 @@ def _check_workload(te_workload, tir_workload): func = te.create_prim_func(te_workload()) tvm.ir.assert_structural_equal(func, tir_workload) # make sure that we can create schedule from the func - s = tir.Schedule(func, debug_mode=True) + s = tir.Schedule(func, debug_mask="all") assert s diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index 30b96546f991..e9626e7f31b4 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -189,9 +189,7 @@ def fanout(n, a): assert ir.min.value == 0 assert tvm.ir.structural_equal(ir.extent, n - 3) # Check loopbody - ibody = ir.body - assert isinstance(ibody, tvm.tir.AttrStmt) - abody = ibody.body + abody = ir.body assert isinstance(abody, tvm.tir.ProducerRealize) assert abody.bounds[0].min.value == 0 assert abody.bounds[0].extent.value == 1 diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index 60a324727f81..8b504df120e0 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -218,7 +218,7 @@ def test_rfactor(): assert set(BF.op.body[0].axis) == set([k2]) assert s[B].op.body[0].axis[0].dom.extent == n assert len(s[B].all_iter_vars) == 2 - # schedule with splot + # schedule with split s = te.create_schedule(B.op) ko, ki = s[B].split(k1, factor=4) xo, xi = s[B].split(B.op.axis[0], factor=8) diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index e2c2f7f7e0e5..ae5e7051bfba 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -379,8 +379,8 @@ def intrin_func(ins, outs): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) # The loop that we tried to tensorize still exists in the code # That means tensorize didn't work as expected - assert isinstance(stmt.body.body, tvm.tir.For) - assert stmt.body.body.loop_var.name == C.op.axis[0].var.name + assert isinstance(stmt.body, tvm.tir.For) + assert stmt.body.loop_var.name == C.op.axis[0].var.name if __name__ == "__main__": diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index ed4a21397885..2931925965b7 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -309,7 +309,7 @@ def get_B1_realize(x): ret = [] tvm.tir.stmt_functor.post_order_visit(stmt, get_B1_realize) - assert stmt.node == C.op and len(ret) == 1 + assert stmt.producer == C and len(ret) == 1 def test_tensor_inputs(): diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 36fd80fd07de..8c2b2710f1ba 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -70,6 +70,22 @@ def lca_is_func_root(a: ty.handle) -> None: A.data[0] = 1.0 +@tvm.script.tir +def match_buffer_func(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + B = tir.match_buffer(b, (128, 128), "float32") + with tir.block([8, 8], "block") as [vi, vj]: + tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) + with tir.block([16, 16], "AAA") as [i, j]: + AA = tir.match_buffer(A[i, j], ()) + AA[()] = 1.0 + tir.evaluate(B0.data) + tir.evaluate(B1.data) + + def test_buffer_load_store(): func = buffer_load_store_func A, B = [func.buffer_map[x] for x in func.params] @@ -115,7 +131,24 @@ def test_lca_func_root(): assert lca[A] is None +def test_match_buffer(): + func = match_buffer_func + A, B = [func.buffer_map[x] for x in func.params] + lca = tir.analysis.detect_buffer_access_lca(func) + + root_block = func.body.block + block = root_block.body.body.body.block + block_inner = block.body[0].body.body.block + + # LCA of Buffer C is the inner block + assert lca[A] == block_inner + + # LCA of Buffer C is the main block + assert lca[B] == block + + if __name__ == "__main__": test_buffer_load_store() test_opaque_access() test_lca_func_root() + test_match_buffer() diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 7e4d7d87c1e1..9c95b9819e6f 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -39,6 +39,48 @@ def func() -> None: tir.evaluate(D.data) +@tvm.script.tir +def match_buffer_func() -> None: + with tir.block([], "root"): + A = tir.alloc_buffer((128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + tir.reads([]) + tir.writes([]) + # Need add read/write region manually to avoid triggering block access region detector + with tir.block([8, 8], "block") as [vi, vj]: + tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + AA = tir.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) + B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) + with tir.block([16, 16], "AAA") as [i, j]: + tir.reads([]) + tir.writes(AA[i, j]) + AAA = tir.match_buffer(AA[i, j], ()) + AAA[()] = 1.0 + tir.evaluate(B0.data) + tir.evaluate(B1.data) + + +@tvm.script.tir +def opaque_block_func() -> None: + with tir.block([], "root"): + A = tir.alloc_buffer((16, 16), "float32") + B = tir.alloc_buffer((16, 16), "float32") + tir.reads([]) + tir.writes([]) + # Need add read/write region manually to avoid triggering block access region detector + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes([B[i, 0:16]]) + for j in range(0, 16): + with tir.block([]): + tir.reads(A[i, j]) + tir.writes(B[i, j]) + B[i, j] = A[i, j] + 1.0 + + def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -53,5 +95,50 @@ def test_block_access_region_detector(): ) +def test_opaque_block(): + alloc_buffers = opaque_block_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + block0 = opaque_block_func.body.block.body.body.block + ret = tir.analysis.get_block_access_region(block0, buffer_var_map) + tvm.ir.assert_structural_equal(block0.reads, ret[0]) + tvm.ir.assert_structural_equal(block0.writes, ret[1]) + + block1 = block0.body.body.block + ret = tir.analysis.get_block_access_region(block1, buffer_var_map) + tvm.ir.assert_structural_equal(block1.reads, ret[0]) + tvm.ir.assert_structural_equal(block1.writes, ret[1]) + + +def test_match_buffer(): + root_block = match_buffer_func.body.block + block = root_block.body.body.body.block + block_inner = block.body[0].body.body.block + alloc_buffers = match_buffer_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + # Check block + ret = tir.analysis.get_block_access_region(block, buffer_var_map) + tvm.ir.assert_structural_equal(block.writes, ret[1]) + # B is opaque access + tvm.ir.assert_structural_equal(block.reads, ret[2]) + + # Check inner block AAA without updating buffer_var_map + ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + # Since AA is not in the buffer_var_map, region of AA will not be collected. + tvm.ir.assert_structural_equal([], ret[1]) + + # Check inner block AAA + for match_buffer in block.match_buffers: + target_buffer = match_buffer.buffer + buffer_var_map[target_buffer.data] = target_buffer + + ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + tvm.ir.assert_structural_equal(block_inner.reads, ret[0]) + tvm.ir.assert_structural_equal(block_inner.writes, ret[1]) + + if __name__ == "__main__": test_block_access_region_detector() + test_opaque_block() + test_match_buffer() diff --git a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py index 7a23c8aa3359..9e9563a66a5d 100644 --- a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py +++ b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py @@ -17,7 +17,9 @@ """Test gpu code verifier""" import tvm from tvm import te +from tvm import topi import tvm.testing +import tvm.topi.testing def get_verify_pass(valid, **kwargs): @@ -373,6 +375,33 @@ def test_vthread(): assert not valid[0] +@tvm.testing.requires_gpu +def test_redundant_kernels(): + dtype = "float32" + A = te.placeholder(shape=(1,), name="A", dtype=dtype) + B = te.placeholder(shape=(1,), name="B", dtype=dtype) + C = te.placeholder(shape=(1,), name="C", dtype=dtype) + D = topi.less(A, C) + E = topi.less(B, C) + F = topi.logical_or(D, E) + G = topi.identity(F) + + for target in ["opencl", "cuda"]: + if not tvm.testing.device_enabled(target): + continue + print("Running on target: %s" % target) + valid = [None] + + with tvm.target.Target(target): + s = tvm.topi.testing.get_reduce_schedule(target)(G) + + with tvm.transform.PassContext( + config={"tir.add_lower_pass": [(2, get_verify_pass(valid, max_kernels=1))]} + ): + tvm.build(s, [A, B, C, G], target) + assert valid[0] + + if __name__ == "__main__": test_local_memory() test_shared_memory() @@ -381,3 +410,4 @@ def test_vthread(): test_wrong_bind() test_vectorize() test_vthread() + test_redundant_kernels() diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py index 6e081a179059..66f3ef9e599f 100644 --- a/tests/python/unittest/test_tir_base.py +++ b/tests/python/unittest/test_tir_base.py @@ -15,8 +15,12 @@ # specific language governing permissions and limitations # under the License. import tvm +import pytest from tvm import tir +from tvm._ffi.base import TVMError from tvm.ir.transform import PassContext +import itertools +import pytest def build_tir_func(func): @@ -30,15 +34,69 @@ def build_tir_func(func): def test_scalar_add(): - a = tir.Var("a", "float32") - b = tir.Var("b", "float32") - c = a + b - c = tir.ret(c) - c = tir.Evaluate(c) - func = tir.PrimFunc([a, b], c) + # All these types should be interchangeable with each other + # E.g. float16 + float32 upconverts the float16 --> float32 + # Meanwhile if an int or float or together the int will be + # cast to the float type. + lhs_types = ["float32", "float16", "int32", "int64"] + rhs_types = ["float32", "float16"] + for lhs_type, rhs_type in itertools.product(lhs_types, rhs_types): + # Input vars should be float32, we will cast to test for upcasting between them + lhs_input = tir.Var("lhs", "float32") + rhs_input = tir.Var("rhs", "float32") + lhs = tir.Cast(lhs_type, lhs_input) + rhs = tir.Cast(rhs_type, rhs_input) + output = lhs + rhs + output = tir.ret(output) + output = tir.Evaluate(output) + func = tir.PrimFunc([lhs_input, rhs_input], output) + func = build_tir_func(func) + out = func(1.0, 2.0) + assert out == 3.0 + + +def assignment_helper(store_dtype, value_dtype): + store = tir.Var("store", dtype=store_dtype) + value = tir.Var("value", dtype=value_dtype) + tir.Let(store, value, body=store) + + +def test_fail_implicit_downcasts_same_type(): + # These lists should be sorted + bits = [8, 16, 32, 64] + for type in ["float", "int", "uint"]: + for i in range(len(bits) - 1): + with pytest.raises(TVMError): + assignment_helper( + store_dtype=f"{type}{bits[i]}", value_dtype=f"{type}{bits[i + 1]}" + ) + + +def test_cast_between_types(): + # We should only be able to assign values with the same types + bits = [16, 32] + types = ["float", "int", "uint"] + for store_type, store_bits, value_type, value_bits in itertools.product( + types, bits, types, bits + ): + store_dtype = f"{store_type}{store_bits}" + value_dtype = f"{value_type}{value_bits}" + if store_dtype == value_dtype: + assignment_helper(store_dtype, value_dtype) + else: + # TODO: we might want to allow casts between uint and int types + with pytest.raises(TVMError): + assignment_helper(store_dtype, value_dtype) + + +def test_ret_const(): + a = tir.const(0) + b = tir.ret(a) + b = tir.Evaluate(b) + func = tir.PrimFunc([], b) func = build_tir_func(func) - out = func(1.0, 2.0) - assert out == 3.0 + out = func() + assert out == 0 def test_control_flow_jump(): @@ -55,6 +113,13 @@ def test_control_flow_jump(): assert out == 1.0 +def test_exception(): + with pytest.raises(tvm.TVMError): + x = tir.Var(name=1, dtype="int") + + if __name__ == "__main__": test_scalar_add() + test_ret_const() test_control_flow_jump() + test_exception() diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 83377e443764..42f9c34133df 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -131,6 +131,23 @@ def assert_simplified_equal(index_simplified, index_direct): ) assert_simplified_equal(index_simplified, index_direct) + # Test Case5 + B = tvm.tir.decl_buffer((1, 14, 14, 1024)) + i = te.size_var("i") + j = te.size_var("j") + k = te.size_var("k") + + index_simplified = B.vload( + ( + idxd(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14), + idxm(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14), + idxm(idxd((i * 50176 + j * 28672 + k), 1024), 14), + idxm((i * 50176 + j * 28672 + k), 1024), + ) + ) + index_direct = B.vload((0, 0, 0, (i * 50176 + j * 28672 + k))) + assert_simplified_equal(index_simplified, index_direct) + @tvm.testing.requires_llvm def test_buffer_broadcast(): diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 355d3abed559..5b123e883849 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -18,6 +18,7 @@ from tvm import te import numpy as np import tvm.testing +from tvm.topi.math import cast def test_for(): @@ -30,8 +31,6 @@ def test_for(): A[j] = A[j] + 2 body = ib.get() - assert isinstance(body, tvm.tir.AttrStmt) - body = body.body assert isinstance(body, tvm.tir.Allocate) body = body.body assert isinstance(body, tvm.tir.For) @@ -497,6 +496,62 @@ def check_target(target, ir): check_target("vulkan", searchsorted_ir_gpu) +@tvm.testing.requires_gpu +def test_dyn_shared(): + n = te.size_var("n") + dtype = "float32" + A = te.placeholder((n,), name="A") + + def test_device_ir(A, B): + n = A.shape[0] + ib = tvm.tir.ir_builder.create() + + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", n) + + temp = ib.allocate(dtype, (n,), scope="shared.dyn") # n is symbolic size + + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + + temp[tx] = Aptr[tx] + depth = tvm.tir.log2(cast(n, "float32")) + + with ib.for_range(0, depth) as i: + ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + d = n >> (i + 1) + with ib.if_scope(tx < d): + temp[tx] += temp[tx + d] + + Bptr[0] = temp[0] + return ib.get() + + B = te.extern( + (1,), + [A], + lambda ins, outs: test_device_ir(ins[0], outs[0]), + name="reduce", + dtype=dtype, + ) + s = te.create_schedule(B.op) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + freduce = tvm.build(s, [A, B], target) + dev = tvm.device(target, 0) + + for n in [512, 1024]: + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.nd.array(np.zeros(1, dtype=B.dtype), dev) + freduce(a, b) + tvm.testing.assert_allclose(b.numpy()[0], np.sum(a.numpy()), 1e-4, 1e-4) + + for target in ["cuda", "nvptx"]: + check_target(target) + + if __name__ == "__main__": test_prefetch() test_if() @@ -507,3 +562,4 @@ def check_target(target, ir): test_while_collatz() test_while_mandel() test_while_binary_search() + test_dyn_shared() diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py new file mode 100644 index 000000000000..78a8c5117849 --- /dev/null +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -0,0 +1,455 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm +from tvm import tir +from tvm.script import ty + + +def _check(original, transformed): + mod = tvm.IRModule.from_expr(original) + mod = tvm.tir.transform.LowerMatchBuffer()(mod) + mod = tvm.tir.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed) + + +def _check_fail(original): + mod = tvm.IRModule.from_expr(original) + with pytest.raises(tvm.TVMError): + mod = tvm.tir.transform.LowerMatchBuffer()(mod) + + +@tvm.script.tir +def buffer_load_store(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16, 16)) + C = tir.match_buffer(c, (16, 16)) + for i, j, k in tir.grid(4, 16, 8): + with tir.block([]): + tir.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) + tir.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) + sub_A = tir.match_buffer( + A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2], (4, 1, 2), offset_factor=1 + ) + sub_C = tir.match_buffer( + C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2], (4, 2), offset_factor=1 + ) + for ii, kk in tir.grid(4, 2): + sub_A[ii, 0, kk] += sub_C[ii, kk] + + +@tvm.script.tir +def transformed_buffer_load_store(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16, 16)) + C = tir.match_buffer(c, (16, 16)) + for i, j, k in tir.grid(4, 16, 8): + with tir.block([]): + tir.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) + tir.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) + for ii, kk in tir.grid(4, 2): + A[i * 4 + ii, j, k * 2 + kk] += C[i * 4 + ii, k * 2 + kk] + + +@tvm.ir.register_op_attr("tir.intrin_test", "") +def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1): + return 0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (32, 64, 128)) + B = tir.match_buffer(b, (64, 64, 64)) + for i, j, k in tir.grid(2, 64, 8): + with tir.block([]): + tir.reads([]) + tir.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) + sub_A = tir.match_buffer( + A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16], + (16, 1, 16), + strides=[8192, 128, 1], + offset_factor=1, + ) + tir.evaluate( + tir.intrin_test( + sub_A.data, + sub_A.elem_offset, + sub_A.strides[0], + sub_A.strides[1], + sub_A.shape[0], + sub_A.shape[1], + dtype="handle", + ) + ) + for i, j, k in tir.grid(64, 2, 8): + with tir.block([]): + Bs_0 = tir.var("int32") + Bs_1 = tir.var("int32") + tir.reads([]) + tir.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) + sub_B = tir.match_buffer( + B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8], + (32, 8), + strides=[Bs_0, Bs_1], + offset_factor=1, + ) + tir.evaluate( + tir.intrin_test( + sub_B.data, + sub_B.elem_offset, + sub_B.strides[0], + sub_B.strides[1], + sub_B.shape[0], + sub_B.shape[1], + dtype="handle", + ) + ) + + +@tvm.script.tir +def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (32, 64, 128)) + B = tir.match_buffer(b, (64, 64, 64)) + for i, j, k in tir.grid(2, 64, 8): + with tir.block([]): + tir.reads([]) + tir.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) + tir.evaluate( + tir.intrin_test( + A.data, + i * 131072 + j * 128 + k * 16, + 8192, + 128, + 16, + 1, + dtype="handle", + ) + ) + for i, j, k in tir.grid(64, 2, 8): + with tir.block([]): + tir.reads([]) + tir.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) + tir.evaluate( + tir.intrin_test( + B.data, + i * 4096 + j * 2048 + k * 8, + 64, + 1, + 32, + 8, + dtype="handle", + ) + ) + + +@tvm.script.tir +def recursive_match(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (64, 64, 64)) + B = tir.match_buffer(b, (64, 64, 64)) + for i, j, k in tir.grid(64, 4, 4): + with tir.block([]): + tir.reads([]) + tir.writes( + [ + A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + ] + ) + As_0 = tir.var("int32") + As_1 = tir.var("int32") + sub_A = tir.match_buffer( + A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + (16, 16), + strides=[As_0, As_1], + offset_factor=1, + ) + sub_B = tir.match_buffer( + B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + (16, 16), + offset_factor=1, + ) + for jj, kk in tir.grid(4, 4): + with tir.block([]): + tir.reads([]) + tir.writes( + [ + sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], + sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], + ] + ) + Ass_0 = tir.var("int32") + Ass_1 = tir.var("int32") + sub_sub_A = tir.match_buffer( + sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], + (4, 4), + strides=[Ass_0, Ass_1], + offset_factor=1, + ) + sub_sub_B = tir.match_buffer( + sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], + (4, 4), + offset_factor=1, + ) + tir.evaluate( + tir.intrin_test( + sub_sub_A.data, + sub_sub_A.elem_offset, + sub_sub_A.strides[0], + sub_sub_A.strides[1], + sub_sub_A.shape[0], + sub_sub_A.shape[1], + dtype="handle", + ) + ) + for jjj, kkk in tir.grid(4, 4): + sub_sub_B[jjj, kkk] = 1 + + +@tvm.script.tir +def transformed_recursive_match(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (64, 64, 64)) + B = tir.match_buffer(b, (64, 64, 64)) + for i, j, k in tir.grid(64, 4, 4): + with tir.block([]): + tir.reads([]) + tir.writes( + [ + A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + ] + ) + for jj, kk in tir.grid(4, 4): + with tir.block([]): + tir.reads([]) + tir.writes( + [ + A[ + i, + j * 16 + jj * 4 : j * 16 + jj * 4 + 4, + k * 16 + kk * 4 : k * 16 + kk * 4 + 4, + ], + B[ + i, + j * 16 + jj * 4 : j * 16 + jj * 4 + 4, + k * 16 + kk * 4 : k * 16 + kk * 4 + 4, + ], + ] + ) + tir.evaluate( + tir.intrin_test( + A.data, + i * 4096 + j * 1024 + jj * 256 + k * 16 + kk * 4, + 64, + 1, + 4, + 4, + dtype="handle", + ) + ) + for jjj, kkk in tir.grid(4, 4): + B[i, j * 16 + jj * 4 + jjj, k * 16 + kk * 4 + kkk] = 1 + + +@tvm.script.tir +def symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.int32) -> None: + A = tir.match_buffer(a, (n * m, m)) + B = tir.match_buffer(b, (n * 2, m * 4)) + for i in range(0, n): + with tir.block([]): + tir.reads([]) + tir.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) + Bs_0 = tir.var("int32") + Bs_1 = tir.var("int32") + sub_A = tir.match_buffer(A[i * m : i * m + m, 0:m], (m, m), offset_factor=1) + sub_B = tir.match_buffer( + B[i * n : i * n + 2, 0 : m * 4], (2, m * 4), strides=[Bs_0, Bs_1], offset_factor=1 + ) + for ii, jj in tir.grid(m, m): + sub_A[ii, jj] = 1 + for j in range(0, 4): + tir.evaluate( + tir.intrin_test( + sub_B.data, + sub_B.elem_offset, + sub_B.strides[0], + sub_B.strides[1], + sub_B.shape[0], + sub_B.shape[1], + dtype="handle", + ) + ) + + +@tvm.script.tir +def transformed_symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.int32) -> None: + A = tir.match_buffer(a, (n * m, m)) + B = tir.match_buffer(b, (n * 2, m * 4)) + for i in range(0, n): + with tir.block([]): + tir.reads([]) + tir.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) + for ii, jj in tir.grid(m, m): + A[i * m + ii, jj] = 1 + for j in range(0, 4): + tir.evaluate( + tir.intrin_test( + B.data, + i * n * (m * 4), + m * 4, + 1, + 2, + m * 4, + dtype="handle", + ) + ) + + +@tvm.script.tir +def rank0_buffer(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + B = tir.match_buffer(b, (8, 8)) + for i, j in tir.grid(8, 8): + with tir.block([]): + tir.reads([]) + tir.writes([A[i, j], B[i, j]]) + sub_A = tir.match_buffer(A[i, j], (), offset_factor=1) + sub_B = tir.match_buffer(B[i, j], (), offset_factor=1) + sub_A[()] = 1 + tir.evaluate( + tir.intrin_test( + sub_B.data, + sub_B.elem_offset, + 0, + 0, + 0, + 0, + dtype="handle", + ) + ) + + +@tvm.script.tir +def transformed_rank0_buffer(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + B = tir.match_buffer(b, (8, 8)) + for i, j in tir.grid(8, 8): + with tir.block([]): + tir.reads([]) + tir.writes([A[i, j], B[i, j]]) + A[i, j] = 1 + tir.evaluate( + tir.intrin_test( + B.data, + i * 8 + j, + 0, + 0, + 0, + 0, + dtype="handle", + ) + ) + + +@tvm.script.tir +def fail_match_load(a: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 8): + with tir.block([]): + tir.reads(A[i, j]) + tir.writes([]) + sub_A = tir.match_buffer(A[i, j], ()) + tir.evaluate(tir.load("float32", sub_A.data, 0)) + + +@tvm.script.tir +def fail_match_store(a: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 8): + with tir.block([]): + tir.reads([]) + tir.writes(A[i, j]) + sub_A = tir.match_buffer(A[i, j], ()) + sub_A.data[0] = 1 + + +@tvm.script.tir +def fail_buffer_bind(a: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 2): + with tir.block([]): + stride = tir.var("int32") + sub_A = tir.match_buffer( + A[i, j * 4 : j * 4 + 4], (1, 4), strides=[stride, stride], offset_factor=1 + ) + for jj in range(0, 4): + sub_A[i, j * 4 + jj] = 1 + + +@tvm.script.tir +def fail_match_func_param(a: ty.handle, m: ty.handle, n: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 2): + with tir.block([]): + sub_A = tir.match_buffer( + A[i, j * 4 : j * 4 + 4], (1, 4), strides=[m, n], offset_factor=1 + ) + for jj in range(0, 4): + sub_A[i, j * 4 + jj] = 1 + + +def test_buffer_load_store(): + _check(buffer_load_store, transformed_buffer_load_store) + + +def test_opaque_access(): + _check(opaque_access, transformed_opaque_access) + + +def test_recursive_match(): + _check(recursive_match, transformed_recursive_match) + + +def test_symbolic_match(): + _check(symbolic_match, transformed_symbolic_match) + + +def test_rank0_buffer(): + _check(rank0_buffer, transformed_rank0_buffer) + + +def test_fail_load_store(): + _check_fail(fail_match_load) + _check_fail(fail_match_store) + + +def test_fail_buffer_bind(): + _check_fail(fail_buffer_bind) + + +def test_fail_match_func_param(): + _check_fail(fail_match_func_param) + + +if __name__ == "__main__": + test_buffer_load_store() + test_opaque_access() + test_recursive_match() + test_symbolic_match() + test_rank0_buffer() + test_fail_load_store() + test_fail_buffer_bind() + test_fail_match_func_param() diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 07a82ba9936c..dbae0b6fa516 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -398,7 +398,7 @@ def test_block_blockrealize(): ) ] writes = [tvm.tir.BufferRegion(A, [tvm.ir.Range.from_min_extent(vx_var, 1)])] - match_buffer_region = tvm.tir.MatchBufferRegion( + block_match_buffer = tvm.tir.MatchBufferRegion( match_buffer, tvm.tir.BufferRegion(B, [tvm.ir.Range(0, 16), tvm.ir.Range(0, 16)]) ) @@ -410,7 +410,7 @@ def test_block_blockrealize(): body, init=init_body, alloc_buffers=[alloc_buffer], - match_buffers=[match_buffer_region], + match_buffers=[block_match_buffer], annotations={"attr_key": "attr_value"}, ) @@ -462,7 +462,7 @@ def test_block_blockrealize(): assert output.find("reads") != -1 assert output.find("writes") != -1 assert output.find("alloc_buffer") != -1 - assert output.find("match_buffer_region") != -1 + assert output.find("match_buffer") != -1 assert output.find("attr") != -1 assert output.find("with init()") != -1 @@ -471,7 +471,6 @@ def test_block_blockrealize(): test_intimm_cond() test_buffer_load_store() test_vars() - test_scoped_storage_var() test_prim_func() test_cast() test_attr() diff --git a/tests/python/unittest/test_tir_ops.py b/tests/python/unittest/test_tir_ops.py index f1f8cf70d0c9..9725650eadae 100644 --- a/tests/python/unittest/test_tir_ops.py +++ b/tests/python/unittest/test_tir_ops.py @@ -119,7 +119,8 @@ def verify_general_dtype_support(f, is_conditional=False): [("bool", "int32"), "int32"], [("int32", "float32"), "float32"], [("int32", "int64"), "int64"], - [("uint32", "int32"), "int32"], + [("uint32", "int8"), "uint32"], + [("uint32", "int32"), "uint32"], ] for (lhs_dtype, rhs_dtype), out_dtype in rules: lhs = te.var("lhs", dtype=lhs_dtype) @@ -146,13 +147,22 @@ def verify_callop_float_only(f): rhs = te.var("rhs", dtype=rhs_dtype) if "float" not in lhs_dtype and "float" not in rhs_dtype: check_throws(lambda: f(lhs, rhs)) - elif "float" in lhs_dtype and "float" in rhs_dtype and lhs_dtype != rhs_dtype: - check_throws(lambda: f(lhs, rhs)) elif "float" in lhs_dtype: out = f(lhs, rhs) - assert out.dtype == lhs_dtype - assert out.args[0].dtype == lhs_dtype - assert out.args[1].dtype == lhs_dtype + + # Upcasting for floating point types + dtypes = [lhs_dtype, rhs_dtype] + if "float64" in dtypes: + target_dtype = "float64" + elif "float32" in dtypes: + target_dtype = "float32" + else: + target_dtype = "int32" + assert out.dtype == target_dtype + + # Final inputs are the right type + assert out.args[0].dtype == target_dtype + assert out.args[1].dtype == target_dtype else: out = f(lhs, rhs) assert out.dtype == rhs_dtype @@ -165,14 +175,18 @@ def verify_callop_float_only(f): verify_general_dtype_support(lambda a, b: a <= b, is_conditional=True) verify_callop_float_only(lambda a, b: te.power(a, b)) + # verify bool & int32 constant folding + assert tvm.tir.const(1) == tvm.tir.const(True) + assert tvm.tir.const(2) != tvm.tir.const(True) + def test_if_then_else(): cases = [ [(te.var("cond", dtype="bool"), "bool", "int32"), "int32"], [(True, "int32", "float32"), "float32"], [(False, "int32", "int64"), "int64"], - [(te.var("cond", dtype="bool"), "uint32", "int32"), "int32"], - [(te.var("cond", dtype="int32"), "uint32", "int32"), "int32"], + [(te.var("cond", dtype="bool"), "uint32", "int32"), "uint32"], + [(te.var("cond", dtype="int32"), "uint32", "int32"), "uint32"], ] for (cond, lhs_dtype, rhs_dtype), out_dtype in cases: lhs = te.var("lhs", dtype=lhs_dtype) diff --git a/tests/python/unittest/test_tir_schedule_block_scope.py b/tests/python/unittest/test_tir_schedule_block_scope.py index 4a914f5063f8..f66dca30d998 100644 --- a/tests/python/unittest/test_tir_schedule_block_scope.py +++ b/tests/python/unittest/test_tir_schedule_block_scope.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest import tvm from tvm import tir from tvm.script import ty @@ -81,7 +84,7 @@ def f_visit(node): def test_elementwise_dependency(): - s = tir.ScheduleState(elementwise, debug_mode=True) + s = tir.ScheduleState(elementwise, debug_mask="all") root = _get_block(s, "root") block_b = _get_block(s, "B") block_c = _get_block(s, "C") @@ -98,7 +101,7 @@ def test_elementwise_dependency(): def test_matmul_dependency(): - s = tir.ScheduleState(matmul, debug_mode=True) + s = tir.ScheduleState(matmul, debug_mask="all") root = _get_block(s, "root") init = _get_block(s, "init") update = _get_block(s, "update") @@ -123,7 +126,7 @@ def test_matmul_dependency(): def test_war_dependency(): - s = tir.ScheduleState(war_dependency, debug_mode=True) + s = tir.ScheduleState(war_dependency, debug_mask="all") root = _get_block(s, "root") block_c = _get_block(s, "C") block_b = _get_block(s, "B") @@ -140,6 +143,4 @@ def test_war_dependency(): if __name__ == "__main__": - test_elementwise_dependency() - test_matmul_dependency() - test_war_dependency() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index c34ec8d610d6..ea322920b846 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -15,10 +15,13 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring +import sys + import pytest import tvm from tvm import tir from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: disable=no-member,invalid-name,unused-variable @@ -171,7 +174,7 @@ def buffer_matched(a: ty.handle, c: ty.handle) -> None: with tir.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 with tir.block([128, 128], "C") as [vi, vj]: - Bb = tir.match_buffer_region(B[vi : vi + 1, vj]) + Bb = tir.match_buffer(B[vi : vi + 1, vj], (1, 1)) C[vi, vj] = Bb[0, 0] + 1.0 @@ -221,34 +224,37 @@ def elementwise_multi_loads_inlined(a: ty.handle, c: ty.handle) -> None: def test_compute_inline_elementwise(): - sch = tir.Schedule(elementwise, debug_mode=True) + sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" + verify_trace_roundtrip(sch=sch, mod=elementwise) def test_compute_inline_under_loop(): - sch = tir.Schedule(elementwise_under_loop, debug_mode=True) + sch = tir.Schedule(elementwise_under_loop, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" + verify_trace_roundtrip(sch=sch, mod=elementwise_under_loop) def test_compute_inline_as_dce(): - sch = tir.Schedule(elementwise_standalone, debug_mode=True) + sch = tir.Schedule(elementwise_standalone, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_standalone_dce, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" + verify_trace_roundtrip(sch=sch, mod=elementwise_standalone) def test_compute_inline_multi_consumer(): - sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mode=True) + sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") block_d = sch.get_block("D") @@ -256,118 +262,108 @@ def test_compute_inline_multi_consumer(): tvm.ir.assert_structural_equal(elementwise_multi_consumer_inlined, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" assert sch.get(block_d).name_hint == "D" + verify_trace_roundtrip(sch=sch, mod=elementwise_multi_producer_consumer) def test_compute_inline_fail_multi_writer(): - sch = tir.Schedule(fail_multi_reader_writer, debug_mode=True, error_render_level="detail") + sch = tir.Schedule(fail_multi_reader_writer, debug_mask="all") block_b = sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) def test_reverse_compute_inline_elementwise(): - sch = tir.Schedule(elementwise, debug_mode=True) + sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) assert sch.get(block_b).name_hint == "B" + verify_trace_roundtrip(sch=sch, mod=elementwise) def test_reverse_compute_inline_under_loop(): - sch = tir.Schedule(elementwise_under_loop, debug_mode=True) + sch = tir.Schedule(elementwise_under_loop, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) assert sch.get(block_b).name_hint == "B" + verify_trace_roundtrip(sch=sch, mod=elementwise_under_loop) def test_reverse_compute_inline_fail_as_dce(): - sch = tir.Schedule(elementwise_standalone, debug_mode=True) + sch = tir.Schedule(elementwise_standalone, debug_mask="all") block_b = sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_b) def test_reverse_compute_inline_fail_multi_producer(): - sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mode=True) + sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all") block_d = sch.get_block("D") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_d) def test_reverse_compute_inline_fail_multi_reader(): - sch = tir.Schedule(fail_multi_reader_writer, debug_mode=True) + sch = tir.Schedule(fail_multi_reader_writer, debug_mask="all") block_c = sch.get_block("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_c) def test_reverse_compute_multi_reverse_loads(): - sch = tir.Schedule(elementwise_multi_reverse_loads, debug_mode=True) + sch = tir.Schedule(elementwise_multi_reverse_loads, debug_mask="all") block_c = sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_multi_reverse_loads_inlined, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_loads) def test_reverse_compute_fail_multi_reverse_loads(): - sch = tir.Schedule(elementwise_multi_loads, debug_mode=True) + sch = tir.Schedule(elementwise_multi_loads, debug_mask="all") block_c = sch.get_block("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_c) def test_opaque_access_load(): - sch = tir.Schedule(opaque_access_load, debug_mode=True) + sch = tir.Schedule(opaque_access_load, debug_mask="all") block_b = sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) def test_opaque_access_store(): - sch = tir.Schedule(opaque_access_store, debug_mode=True) + sch = tir.Schedule(opaque_access_store, debug_mask="all") block_b = sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) def test_buffer_matched(): - sch = tir.Schedule(buffer_matched, debug_mode=True) + sch = tir.Schedule(buffer_matched, debug_mask="all") block_b = sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) def test_compute_inline_predicate(): - sch = tir.Schedule(elementwise_predicate, debug_mode=True) + sch = tir.Schedule(elementwise_predicate, debug_mask="all") block_b = sch.get_block("B") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_predicate_inlined, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_predicate) def test_compute_inline_multi_loads(): - sch = tir.Schedule(elementwise_multi_loads, debug_mode=True) + sch = tir.Schedule(elementwise_multi_loads, debug_mask="all") block_b = sch.get_block("B") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_multi_loads_inlined, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_multi_loads) if __name__ == "__main__": - test_compute_inline_elementwise() - test_compute_inline_under_loop() - test_compute_inline_as_dce() - test_compute_inline_multi_consumer() - test_compute_inline_fail_multi_writer() - test_reverse_compute_inline_elementwise() - test_reverse_compute_inline_under_loop() - test_reverse_compute_inline_fail_as_dce() - test_reverse_compute_inline_fail_multi_producer() - test_reverse_compute_inline_fail_multi_reader() - test_reverse_compute_multi_reverse_loads() - test_reverse_compute_fail_multi_reverse_loads() - test_opaque_access_load() - test_opaque_access_store() - test_buffer_matched() - test_compute_inline_predicate() - test_compute_inline_multi_loads() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_error.py b/tests/python/unittest/test_tir_schedule_error.py index 1fa658feabe3..6fcd0dc2aedc 100644 --- a/tests/python/unittest/test_tir_schedule_error.py +++ b/tests/python/unittest/test_tir_schedule_error.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring +import sys + import pytest import tvm from tvm import tir from tvm.script import ty - # pylint: disable=no-member,invalid-name,unused-variable @@ -41,7 +42,7 @@ def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: def test_tir_schedule_error_detail(): - sch = tir.Schedule(matmul, debug_mode=True, error_render_level="detail") + sch = tir.Schedule(matmul, debug_mask="all", error_render_level="detail") with pytest.raises(tir.ScheduleError) as excinfo: sch.get_block("wrong_name") (msg,) = excinfo.value.args @@ -49,7 +50,7 @@ def test_tir_schedule_error_detail(): def test_tir_schedule_error_fast(): - sch = tir.Schedule(matmul, debug_mode=True, error_render_level="fast") + sch = tir.Schedule(matmul, debug_mask="all", error_render_level="fast") with pytest.raises(tir.ScheduleError) as excinfo: sch.get_block("wrong_name") (msg,) = excinfo.value.args @@ -57,7 +58,7 @@ def test_tir_schedule_error_fast(): def test_tir_schedule_error_none(): - sch = tir.Schedule(matmul, debug_mode=True, error_render_level="none") + sch = tir.Schedule(matmul, debug_mask="all", error_render_level="none") with pytest.raises(tir.ScheduleError) as excinfo: sch.get_block("wrong_name") (msg,) = excinfo.value.args @@ -65,6 +66,4 @@ def test_tir_schedule_error_none(): if __name__ == "__main__": - test_tir_schedule_error_detail() - test_tir_schedule_error_fast() - test_tir_schedule_error_none() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py new file mode 100644 index 000000000000..5649a06bd3b8 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -0,0 +1,365 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def element_wise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_parallelized(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i0 in tir.parallel(0, 128): + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, i1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_i_bound(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i0 in tir.thread_binding(0, 128, thread="threadIdx.x"): + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, i1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_compute_at_split(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + for i in tir.serial(0, 128): + for j0 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j0) + B[vi, vj] = A[vi, vj] * 2.0 + for j1o, j1i in tir.grid(32, 4): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j1o * 4 + j1i) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def element_wise_compute_at_split_vectorized(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + for i in tir.serial(0, 128): + for j0 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j0) + B[vi, vj] = A[vi, vj] * 2.0 + for j1o in tir.serial(0, 32): + for j1i in tir.vectorized(0, 4): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j1o * 4 + j1i) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def element_wise_split_predicate(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + for i, j_0, j_1 in tir.grid(128, 13, 10): + with tir.block([128, 128], "B") as [vi, vj]: + tir.where(j_0 * 10 + j_1 < 128) + tir.bind(vi, i) + tir.bind(vj, j_0 * 10 + j_1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_split_predicate_parallelized(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + for i in tir.serial(0, 128): + for j_0 in tir.parallel(0, 13): + for j_1 in tir.serial(0, 10): + with tir.block([128, 128], "B") as [vi, vj]: + tir.where(j_0 * 10 + j_1 < 128) + tir.bind(vi, i) + tir.bind(vj, j_0 * 10 + j_1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_split_predicate_vectorized(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + for i in tir.vectorized(0, 128): + for j_0, j_1 in tir.grid(13, 10): + with tir.block([128, 128], "B") as [vi, vj]: + tir.where(j_0 * 10 + j_1 < 128) + tir.bind(vi, i) + tir.bind(vj, j_0 * 10 + j_1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_compute_at_split_j0_j1o_bound(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + for i in tir.serial(0, 128): + for j0 in tir.thread_binding(0, 128, thread="threadIdx.x"): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j0) + B[vi, vj] = A[vi, vj] * 2.0 + for j1o in tir.thread_binding(0, 32, thread="threadIdx.x"): + for j1i in tir.serial(0, 4): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j1o * 4 + j1i) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def rowsum(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_unrolled(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + for i0 in tir.unroll(0, 128): + for i1 in tir.serial(0, 128): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i0) + tir.bind(vk, i1) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + for i, k in tir.grid(128, 16): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i) + tir.bind(vk, tir.floordiv(k * k, 2)) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_compact_data_flow(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vk] = 0.0 + B[vk] = B[vk] + A[vi, vk] + + +@tvm.script.tir +def rowsum_cross_thread_reduction(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + for i0 in tir.serial(0, 128): + for i1 in tir.thread_binding(0, 128, thread="threadIdx.x"): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i0) + tir.bind(vk, i1) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def opaque_block(a: ty.handle) -> None: + A = tir.match_buffer(a, (16,)) + for i in tir.serial(0, 15): + with tir.block([], "opaque"): + A[i + 1] = A[i + 1] + A[i] + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_parallel(): + s = tir.Schedule(element_wise, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.parallel(i) + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_parallelized) + verify_trace_roundtrip(s, mod=element_wise) + + +def test_parallel_predicate(): + s = tir.Schedule(element_wise_split_predicate, debug_mask="all") + _, j, _ = s.get_loops(s.get_block("B")) + s.parallel(j) + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_split_predicate_parallelized) + verify_trace_roundtrip(s, mod=element_wise_split_predicate) + + +def test_parallel_reduction_block_iter(): + s = tir.Schedule(matmul, debug_mask="all") + _, _, k = s.get_loops(s.get_block("C")) + with pytest.raises(tvm.tir.ScheduleError): + s.parallel(k) + + +def test_parallel_not_quasi_affine(): + s = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.parallel(i) + + +def test_parallel_not_compact_data_flow(): + s = tir.Schedule(rowsum_not_compact_data_flow, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.parallel(i) + + +def test_vectorize(): + s = tir.Schedule(element_wise_compute_at_split, debug_mask="all") + _, _, j1i = s.get_loops(s.get_block("C")) + s.vectorize(j1i) + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_compute_at_split_vectorized) + verify_trace_roundtrip(s, mod=element_wise_compute_at_split) + + +def test_vectorize_predicate(): + s = tir.Schedule(element_wise_split_predicate, debug_mask="all") + i, _, _ = s.get_loops(s.get_block("B")) + s.vectorize(i) + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_split_predicate_vectorized) + verify_trace_roundtrip(s, mod=element_wise_split_predicate) + + +def test_vectorize_opaque_block(): + s = tir.Schedule(opaque_block, debug_mask="all") + (i,) = s.get_loops(s.get_block("opaque")) + with pytest.raises(tvm.tir.ScheduleError): + s.vectorize(i) + + +def test_unroll(): + s = tir.Schedule(rowsum, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.unroll(i) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_unrolled) + verify_trace_roundtrip(s, mod=rowsum) + + +def test_unroll_after_bind(): + s = tir.Schedule(rowsum, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.bind(i, "blockIdx.x") + s.unroll(i) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_unrolled) + verify_trace_roundtrip(s, mod=rowsum) + + +def test_bind1(): + s = tir.Schedule(element_wise, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.bind(i, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_i_bound) + verify_trace_roundtrip(s, mod=element_wise) + + +def test_bind2(): + s = tir.Schedule(element_wise_compute_at_split, debug_mask="all") + _, j0 = s.get_loops(s.get_block("B")) + _, j1o, _ = s.get_loops(s.get_block("C")) + s.bind(j0, "threadIdx.x") + s.bind(j1o, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_compute_at_split_j0_j1o_bound) + verify_trace_roundtrip(s, mod=element_wise_compute_at_split) + + +def test_bind_cross_thread_reduction(): + s = tir.Schedule(rowsum, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) + s.bind(k, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_cross_thread_reduction) + verify_trace_roundtrip(s, mod=rowsum) + + +def test_bind_not_cross_thread_reduction(): + s = tir.Schedule(rowsum, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.bind(k, "blockIdx.x") + + +def test_bind_after_bind(): + s = tir.Schedule(element_wise, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.bind(i, "blockIdx.x") + s.bind(i, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_i_bound) + verify_trace_roundtrip(s, mod=element_wise) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_instruction.py b/tests/python/unittest/test_tir_schedule_instruction.py new file mode 100644 index 000000000000..9e6f447dd3e6 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_instruction.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +# mypy: ignore-errors +import sys + +import pytest +from tvm.tir.schedule import BlockRV, Instruction, InstructionKind, LoopRV + + +def test_inst_kind_get(): + kind = InstructionKind.get("EnterPostproc") + assert not kind.is_pure + assert kind.name == "EnterPostproc" + + +def test_inst_construct_1(): + block = BlockRV() + loop0 = LoopRV() + loop1 = LoopRV() + inst = Instruction( + kind=InstructionKind.get("GetLoops"), + inputs=[block], + attrs=[], + outputs=[loop0, loop1], + ) + assert str(inst) == "_, _ = sch.get_loops(block=_)" + assert len(inst.inputs) == 1 + assert len(inst.attrs) == 0 + assert len(inst.outputs) == 2 + assert inst.kind.same_as(InstructionKind.get("GetLoops")) + assert inst.inputs[0].same_as(block) + assert inst.outputs[0].same_as(loop0) + assert inst.outputs[1].same_as(loop1) + + +def test_inst_construct_2(): + block = BlockRV() + inst = Instruction( + kind=InstructionKind.get("ComputeInline"), + inputs=[block], + attrs=[], + outputs=[], + ) + assert str(inst) == "sch.compute_inline(block=_)" + assert len(inst.inputs) == 1 + assert len(inst.attrs) == 0 + assert len(inst.outputs) == 0 + assert inst.kind.same_as(InstructionKind.get("ComputeInline")) + assert inst.inputs[0].same_as(block) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py new file mode 100644 index 000000000000..067952899c0a --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -0,0 +1,632 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg + + +@tvm.script.tir +def transformed_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(128, 128, 4, 8, 4): + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + tir.bind(vi, i0) + tir.bind(vj, i1) + tir.bind(vk, (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner)) + tir.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) + tir.writes([C[vi, vj]]) + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) + + +@tvm.script.tir +def matmul_rfactor(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + C_rf = tir.alloc_buffer([4, 128, 128]) + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(128, 128, 4, 8, 4): + with tir.block( + [4, 128, 128, tir.reduce_axis(0, 4), tir.reduce_axis(0, 8)], "update_rf" + ) as [vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer]: + tir.bind(vi2_inner_inner, i2_inner_inner) + tir.bind(vi, i0) + tir.bind(vj, i1) + tir.bind(vi2_outer, i2_outer) + tir.bind(vi2_inner_outer, i2_inner_outer) + with tir.init(): + C_rf[vi2_inner_inner, vi, vj] = 0.0 + C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + ( + A[vi, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] + * B[vj, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] + ) + + for i0_1, i1_1, i2_inner_inner_1 in tir.grid(128, 128, 4): + with tir.block([tir.reduce_axis(0, 4), 128, 128], "update") as [ + vi2_inner_inner_1, + vi_1, + vj_1, + ]: + tir.bind(vi2_inner_inner_1, i2_inner_inner_1) + tir.bind(vi_1, i0_1) + tir.bind(vj_1, i1_1) + with tir.init(): + C[vi_1, vj_1] = 0.0 + C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1] + + +@tvm.script.tir +def matmul_not_stage_pipeline(a: ty.handle, b: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, [256, 256]) + B = tir.match_buffer(b, [256, 256]) + D = tir.match_buffer(d, [256, 256]) + C = tir.alloc_buffer([256, 256]) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + with tir.block([256, 256], "D") as [vi, vj]: + D[vi, vj] = C[vi, vj] + + +@tvm.script.tir +def matmul_not_same_buffer_access(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj] + + +@tvm.script.tir +def matmul_loop_multiple_children(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + D = tir.match_buffer(d, [128, 128]) + + for k, i, j in tir.grid(128, 128, 128): + with tir.block([tir.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]: + tir.bind(ck, k) + tir.bind(ci, i) + tir.bind(cj, j) + with tir.init(): + C[ci, cj] = 0.0 + C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj] + with tir.block([tir.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]: + tir.bind(dk, k) + tir.bind(di, i) + tir.bind(dj, j) + with tir.init(): + D[di, dj] = 0.0 + D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj] + + +@tvm.script.tir +def square_sum(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [16, 256, 256]) + C = tir.match_buffer(c, [16]) + + with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]: + with tir.init(): + C[b] = 0.0 + C[b] = C[b] + A[b, i, j] * A[b, i, j] + + +@tvm.script.tir +def square_sum_rfactor(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [16, 256, 256]) + C = tir.match_buffer(c, [16]) + C_rf = tir.alloc_buffer([16, 256]) + + for i0, i1, i2 in tir.grid(16, 256, 256): + with tir.block([256, 16, tir.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]: + tir.bind(vi2, i2) + tir.bind(b, i0) + tir.bind(i, i1) + with tir.init(): + C_rf[b, vi2] = 0.0 + C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2]) + + for i0_1, i2_1 in tir.grid(16, 256): + with tir.block([tir.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]: + tir.bind(vi2_1, i2_1) + tir.bind(b_1, i0_1) + with tir.init(): + C[b_1] = 0.0 + C[b_1] = C[b_1] + C_rf[b_1, vi2_1] + + +@tvm.script.tir +def transformed_square_sum_square_root(a: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, [16, 256, 256]) + D = tir.match_buffer(d, [16]) + C = tir.alloc_buffer([16]) + + for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1): + with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]: + tir.bind(b, i0) + tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256)) + tir.bind(j, tir.floormod(i1_i2_fused_outer, 256)) + tir.reads([C[b], A[b, i, j]]) + tir.writes([C[b]]) + with tir.init(): + C[b] = 0.0 + C[b] = C[b] + (A[b, i, j] * A[b, i, j]) + for i0_1 in tir.serial(0, 16): + with tir.block([16], "D") as [b_1]: + tir.bind(b_1, i0_1) + tir.reads([C[b_1]]) + tir.writes([D[b_1]]) + D[b_1] = tir.sqrt(C[b_1], dtype="float32") + + +@tvm.script.tir +def square_sum_square_root_rfactor(a: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, [16, 256, 256]) + D = tir.match_buffer(d, [16]) + C = tir.alloc_buffer([16]) + C_rf = tir.alloc_buffer([1, 16]) + + for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1): + with tir.block([1, 16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C_rf") as [ + vi1_i2_fused_inner, + b, + i, + j, + ]: + tir.bind(vi1_i2_fused_inner, i1_i2_fused_inner) + tir.bind(b, i0) + tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256)) + tir.bind(j, tir.floormod(i1_i2_fused_outer, 256)) + with tir.init(): + C_rf[vi1_i2_fused_inner, b] = 0.0 + C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j]) + + for i0_1, i1_i2_fused_inner_1 in tir.grid(16, 1): + with tir.block([tir.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]: + tir.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1) + tir.bind(b_1, i0_1) + with tir.init(): + C[b_1] = 0.0 + C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1] + + for i0_2 in tir.serial(0, 16): + with tir.block([16], "D") as [b_2]: + tir.bind(b_2, i0_2) + D[b_2] = tir.sqrt(C[b_2], dtype="float32") + + +@tvm.script.tir +def element_wise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def rowsum(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + for i, k in tir.grid(128, 16): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i) + tir.bind(vk, tir.floordiv(k * k, 2)) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_dominant(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi, vk] = 0.0 + B[vi, vk] = B[vi, vk] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_serial(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + for i in tir.serial(0, 128): + for k in tir.parallel(0, 128): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i) + tir.bind(vk, k) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_wrong_reduce_pattern1(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi] = 1.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_wrong_reduce_pattern2(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] - A[vi, vk] + + +@tvm.script.tir +def rowsum_transformed(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + for io, ii_ko_fused, ki in tir.grid(32, 128, 4): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, io * 4 + tir.floordiv(ii_ko_fused, 32)) + tir.bind(vk, tir.floormod(ii_ko_fused, 32) * 4 + ki) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_zero_dim(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128]) + B = tir.match_buffer(b, []) + + with tir.block([tir.reduce_axis(0, 128)], "B") as [k]: + with tir.init(): + B[()] = 0.0 + B[()] = B[()] + A[k] + + +@tvm.script.tir +def rowsum_zero_dim_rfactor(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128]) + B = tir.match_buffer(b, []) + B_rf = tir.alloc_buffer([128]) + + with tir.block([128], "B_rf") as [vi0]: + with tir.init(): + B_rf[vi0] = 0.0 + B_rf[vi0] = B_rf[vi0] + A[vi0] + + with tir.block([tir.reduce_axis(0, 128)], "B") as [vi0_1]: + with tir.init(): + B[()] = 0.0 + B[()] = B[()] + B_rf[vi0_1] + + +@tvm.script.tir +def multiple_reduction_blocks(a: ty.handle, f: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16, 16)) + C = tir.alloc_buffer((16, 16)) + D = tir.alloc_buffer((16, 16)) + E = tir.alloc_buffer((16, 16)) + F = tir.match_buffer(f, (16, 16)) + + for i in tir.serial(0, 16): + for j1 in tir.serial(0, 16): + for k1o, k1i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "C") as [ci, cj, ck]: + tir.bind(ci, i) + tir.bind(cj, j1) + tir.bind(ck, k1o * 4 + k1i) + with tir.init(): + C[ci, cj] = 0.0 + C[ci, cj] = C[ci, cj] + A[ci, cj, ck] + for k2o, k2i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]: + tir.bind(di, i) + tir.bind(dj, j1) + tir.bind(dk, k2o * 4 + k2i) + with tir.init(): + D[di, dj] = 0.0 + D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj] + for j2 in tir.serial(0, 16): + for k3o, k3i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]: + tir.bind(ei, i) + tir.bind(ej, j2) + tir.bind(ek, k3o * 4 + k3i) + with tir.init(): + E[ei, ej] = 0.0 + E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej] + for k4o, k4i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]: + tir.bind(fi, i) + tir.bind(fj, j2) + tir.bind(fk, k4o * 4 + k4i) + with tir.init(): + F[fi, fj] = 0.0 + F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj] + + +@tvm.script.tir +def multiple_reduction_blocks_rfactor(a: ty.handle, f: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16, 16]) + C = tir.alloc_buffer([16, 16]) + D = tir.alloc_buffer([16, 16]) + E = tir.alloc_buffer([16, 16]) + F = tir.match_buffer(f, [16, 16]) + C_rf = tir.alloc_buffer([16, 16, 4]) + + for i, j1, k1o, k1i in tir.grid(16, 16, 4, 4): + with tir.block([4, 16, 16, tir.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]: + tir.bind(vk1o, k1o) + tir.bind(ci, i) + tir.bind(cj, j1) + tir.bind(vk1i, k1i) + with tir.init(): + C_rf[ci, cj, vk1o] = 0.0 + C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)] + for i_1 in tir.serial(0, 16): + for j1_1 in tir.serial(0, 16): + for k1o_1 in tir.serial(0, 4): + with tir.block([tir.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]: + tir.bind(vk1o_1, k1o_1) + tir.bind(ci_1, i_1) + tir.bind(cj_1, j1_1) + with tir.init(): + C[ci_1, cj_1] = 0.0 + C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1] + for k2o, k2i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]: + tir.bind(di, i_1) + tir.bind(dj, j1_1) + tir.bind(dk, (k2o * 4) + k2i) + with tir.init(): + D[di, dj] = 0.0 + D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj] + for j2 in tir.serial(0, 16): + for k3o, k3i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]: + tir.bind(ei, i_1) + tir.bind(ej, j2) + tir.bind(ek, (k3o * 4) + k3i) + with tir.init(): + E[ei, ej] = 0.0 + E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej] + for k4o, k4i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]: + tir.bind(fi, i_1) + tir.bind(fj, j2) + tir.bind(fk, (k4o * 4) + k4i) + with tir.init(): + F[fi, fj] = 0.0 + F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj] + + +# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg + + +def test_reduction_rfactor_matmul(): + s = tir.Schedule(transformed_matmul, debug_mask="all") + update = s.get_block("update") + _, _, _, _, kii = s.get_loops(update) + rf_block = s.rfactor(kii, 0) + tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) + assert s.get(update).same_as(s.get(s.get_block("update"))) + verify_trace_roundtrip(s, mod=transformed_matmul) + + +def test_reduction_rfactor_square_sum(): + s = tir.Schedule(square_sum, debug_mask="all") + C = s.get_block("C") + _, _, j = s.get_loops(C) + rf_block = s.rfactor(j, 1) + tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) + assert s.get(C).same_as(s.get(s.get_block("C"))) + verify_trace_roundtrip(s, mod=square_sum) + + +def test_reduction_rfactor_square_sum_square_root(): + s = tir.Schedule(transformed_square_sum_square_root, debug_mask="all") + C = s.get_block("C") + _, _, f_i = s.get_loops(C) + rf_block = s.rfactor(f_i, 0) + tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) + assert s.get(C).same_as(s.get(s.get_block("C"))) + verify_trace_roundtrip(s, mod=transformed_square_sum_square_root) + + +def test_reduction_rfactor_loop_multiple_children(): + s = tir.Schedule(matmul_loop_multiple_children, debug_mask="all") + k, _, _ = s.get_loops(s.get_block("C")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_not_stage_pipeline(): + s = tir.Schedule(matmul_not_stage_pipeline, debug_mask="all") + _, _, k = s.get_loops(s.get_block("C")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_not_reduction_block1(): + s = tir.Schedule(element_wise, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(i, 0) + + +def test_reduction_rfactor_not_reduction_block2(): + s = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_not_reduction_block3(): + s = tir.Schedule(rowsum_not_dominant, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_not_serial_loop(): + s = tir.Schedule(rowsum_not_serial, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_not_same_buffer_access(): + s = tir.Schedule(matmul_not_same_buffer_access, debug_mask="all") + _, _, k = s.get_loops(s.get_block("C")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_factor_axis_range_fail(): + s = tir.Schedule(transformed_matmul, debug_mask="all") + _, _, _, _, kii = s.get_loops(s.get_block("update")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(kii, 3) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(kii, -4) + + +def test_reduction_rfactor_factor_axis_range(): + s = tir.Schedule(transformed_matmul, debug_mask="all") + update = s.get_block("update") + _, _, _, _, kii = s.get_loops(update) + rf_block = s.rfactor(kii, -3) + tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) + assert s.get(update).same_as(s.get(s.get_block("update"))) + verify_trace_roundtrip(s, mod=transformed_matmul) + + +def test_reduction_rfactor_wrong_reduce_pattern1(): + s = tir.Schedule(rowsum_wrong_reduce_pattern1, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_wrong_reduce_pattern2(): + s = tir.Schedule(rowsum_wrong_reduce_pattern2, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_wrong_loops1(): + s = tir.Schedule(rowsum, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(i, 0) + + +def test_reduction_rfactor_wrong_loops2(): + s = tir.Schedule(rowsum_transformed, debug_mask="all") + _, _, k_i = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k_i, 0) + + +def test_reduction_rfactor_zero_dim(): + s = tir.Schedule(rowsum_zero_dim, debug_mask="all") + B = s.get_block("B") + (k,) = s.get_loops(B) + rf_block = s.rfactor(k, 0) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_zero_dim_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("B_rf"))) + assert s.get(B).same_as(s.get(s.get_block("B"))) + verify_trace_roundtrip(s, mod=rowsum_zero_dim) + + +def test_reduction_rfactor_outermost_loop_multiple_children_fail(): # pylint: disable=invalid-name + s = tir.Schedule(multiple_reduction_blocks, debug_mask="all") + _, _, k2o, k2i = s.get_loops(s.get_block("D")) + _, _, k3o, k3i = s.get_loops(s.get_block("E")) + _, _, k4o, k4i = s.get_loops(s.get_block("F")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k2o, 0) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k2i, 0) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k3o, 0) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k3i, 0) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k4o, 0) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k4i, 0) + + +def test_reduction_rfactor_outermost_loop_multiple_children(): # pylint: disable=invalid-name + s = tir.Schedule(multiple_reduction_blocks, debug_mask="all") + C = s.get_block("C") + _, _, k1o, _ = s.get_loops(C) + rf_block = s.rfactor(k1o, 2) + tvm.ir.assert_structural_equal(s.mod["main"], multiple_reduction_blocks_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) + assert s.get(C).same_as(s.get(s.get_block("C"))) + verify_trace_roundtrip(s, mod=multiple_reduction_blocks) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py new file mode 100644 index 000000000000..091a77df2030 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -0,0 +1,294 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def elementwise_not_affine(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in tir.grid(128, 128, 128, 8): + with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.bind(vl, l * 16) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def elementwise_dependent_loop(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + for i in tir.serial(0, 128): + for j, k, l in tir.grid(128, i, 128): + with tir.block([128, 128, i, 128], "B") as [vi, vj, vk, vl]: + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def elementwise_predicate(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in tir.grid(128, 128, 128, 128): + with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def elementwise_non_single_branch(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + C = tir.alloc_buffer((128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "C") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_loops_not_same_scope(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "A") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + for k in tir.serial(0, 128): + with tir.block([128], "B") as [vk]: + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_wrong_block_var_type(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j, k in tir.grid(128, 128, 128): + with tir.block([128, 128, tir.scan_axis(0, 128)], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_reordered(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + for l, j, k, i in tir.grid(128, 128, 128, 128): + with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.bind(vl, l) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def elementwise_reordered2(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + for k, j, i, l in tir.grid(128, 128, 128, 128): + with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.bind(vl, l) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def elementwise_reordered_with_predicate(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + for l, j, k, i in tir.grid(128, 128, 128, 128): + with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100) + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.bind(vl, l) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16], "float32") + B = tir.match_buffer(b, [16, 16], "float32") + with tir.block([16, 16], "A") as [vi, vj]: + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, vi * 16 + vj, 1) + with tir.block([16, 16], "B") as [vi, vj]: + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + + +@tvm.script.tir +def opaque_access_reorder(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16], "float32") + B = tir.match_buffer(b, [16, 16], "float32") + for j, i in tir.grid(16, 16): + with tir.block([16, 16], "A") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, vi * 16 + vj, 1) + for j, i in tir.grid(16, 16): + with tir.block([16, 16], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_reorder(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + sch.reorder(l, i) + tvm.ir.assert_structural_equal(elementwise_reordered, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_reorder2(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + sch.reorder(k, i, l) + tvm.ir.assert_structural_equal(elementwise_reordered2, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_reorder_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mask="all") + block_a = sch.get_block("A") + i, j = sch.get_loops(block_a) + sch.reorder(j, i) + block_b = sch.get_block("B") + i, j = sch.get_loops(block_b) + sch.reorder(j, i) + tvm.ir.assert_structural_equal(opaque_access_reorder, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=opaque_access) + + +def test_reorder_with_predicate(): + sch = tir.Schedule(elementwise_predicate, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + sch.reorder(l, i) + tvm.ir.assert_structural_equal(elementwise_reordered_with_predicate, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_predicate) + + +def test_reorder_fail_with_multi_appearance_loops(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(k, i, i) + + +def test_reorder_fail_with_non_single_branch_loop(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(k, i) + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + block_c = sch.get_block("C") + i, j, k1 = sch.get_loops(block_b) + _, _, k2 = sch.get_loops(block_c) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(k1, i, k2) + + +def test_reorder_fail_with_loops_not_under_same_scope(): + sch = tir.Schedule(elementwise_with_loops_not_same_scope, debug_mask="all") + block_b = sch.get_block("B") + block_a = sch.get_block("A") + i, j = sch.get_loops(block_a) + k = sch.get_loops(block_b)[0] + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(k, i) + + +def test_reorder_fail_with_wrong_block_var_type(): + sch = tir.Schedule(elementwise_with_wrong_block_var_type, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(k, i) + + +def test_reorder_fail_with_dependent_loops(): + sch = tir.Schedule(elementwise_dependent_loop, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(l, i) + + +def test_reorder_fail_not_affine_bindings(): + sch = tir.Schedule(elementwise_not_affine, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(l, i) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py new file mode 100644 index 000000000000..2bfd68663c99 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import sys +from collections import defaultdict + +import pytest +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule import Trace + + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_sample_categorical(): + """Test sample categprical sampling function""" + n = 1000 + sch = tir.Schedule(elementwise, seed=42, debug_mask="all") + counter = defaultdict(int) + candidates = [5, 2, 7, 1] + probs = [0.15, 0.55, 0.05, 0.25] + for _ in range(n): + v = sch.get(sch.sample_categorical(candidates, probs)) + counter[v] += 1 + for i, prob in enumerate(probs): + assert (prob - 0.07) * n <= counter[candidates[i]] <= (prob + 0.07) * n + verify_trace_roundtrip(sch, mod=elementwise) + + +def test_sample_categorical_copy(): + """Check the random variable sampling results after schedule copy""" + n = 100 + sch = tir.Schedule(elementwise, seed=42, debug_mask="all") + candidates = [1, 2, 3, 4] + probs = [0.1, 0.2, 0.3, 0.4] + rv_decisions = [] + for _ in range(n): + rv = sch.sample_categorical(candidates, probs) # pylint: disable=invalid-name + rv_decisions.append((rv, sch.get(rv))) + sch_copy = sch.copy() + for rv, decision in rv_decisions: # pylint: disable=invalid-name + decision_copy = sch_copy.get(rv) + assert int(decision) == int(decision_copy) + + +def test_sample_categorical_serialize(): + """Check the random variable sampling results after schedule serialization""" + n = 100 + sch = tir.Schedule(elementwise, seed=42, debug_mask="all") + candidates = [5, 6, 7, 8] + probs = [0.23, 0.19, 0.37, 0.21] + decisions = [] + for _ in range(n): + rv = sch.get(sch.sample_categorical(candidates, probs)) # pylint: disable=invalid-name + decisions.append(rv) + new_sch = verify_trace_roundtrip(sch, mod=elementwise) + for i, new_inst in enumerate(new_sch.trace.insts): + assert decisions[i] == candidates[new_sch.trace.decisions[new_inst].value] + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py new file mode 100644 index 000000000000..2284f9d996b1 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -0,0 +1,466 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i, j, k in tir.grid(128, 128, n): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic_fused(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i_j_k_fused in tir.serial(0, (n * 16384)): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(i_j_k_fused, (n * 128))) + tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, n), 128)) + tir.bind(vk, tir.floormod(i_j_k_fused, n)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic_split(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i, j, k0, k1 in tir.grid(128, 128, 10, tir.floordiv((n + 9), 10)): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + tir.where((((k0 * tir.floordiv((n + 9), 10)) + k1) < n)) + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, ((k0 * tir.floordiv((n + 9), 10)) + k1)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_seq(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + C = tir.alloc_buffer((128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "C") as [vi, vj, vk]: + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_anno(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(0, 128, annotations={"useless_annotation": True}): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_thread_binding(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.thread_binding(0, 128, thread="threadIdx.x"): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_starting_point(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(10, 128): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j, k in tir.grid(128, 128, 128): + with tir.block([], "opaque"): + tir.reads([A[i, j, k]]) + tir.writes([B[i, j, k]]) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_fused(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for fused in tir.serial(0, 2097152): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(fused, 16384)) + tir.bind(vj, tir.floormod(tir.floordiv(fused, 128), 128)) + tir.bind(vk, tir.floormod(fused, 128)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_case0(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128, 128, 128]) + for i1, i2, i3, j1, j2, k1, k2 in tir.grid(2, 1, 64, 4, 32, 16, 8): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, ((i1 * 64) + i3)) + tir.bind(vj, ((j1 * 32) + j2)) + tir.bind(vk, ((k1 * 8) + k2)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_case1(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128, 128, 128]) + for i1, i2, i3, j1, j2, j3, k1, k2, k3 in tir.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i1 * 64 + i3) + tir.bind(vj, j1 * 64 + j3) + tir.bind(vk, k1 * 64 + k3) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_with_predicate(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + for i0, i1, i2, j0, j1, k0, k1 in tir.grid(1000, 2, 3, 1, 129, 3, 43): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.where( + ( + ((((((i0 * 2) + i1) * 3) + i2) < 128) and (((j0 * 129) + j1) < 128)) + and (((k0 * 43) + k1) < 128) + ) + ) + tir.bind(vi, (((i0 * 6) + (i1 * 3)) + i2)) + tir.bind(vj, j1) + tir.bind(vk, ((k0 * 43) + k1)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_fuse_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + for i_j_k_fused in tir.serial(0, 2097152): + with tir.block([], "opaque"): + tir.reads( + [ + A[ + tir.floormod(tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128), + tir.floormod(tir.floordiv(i_j_k_fused, 128), 128), + tir.floormod(i_j_k_fused, 128), + ] + ] + ) + tir.writes( + [ + B[ + tir.floormod(tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128), + tir.floormod(tir.floordiv(i_j_k_fused, 128), 128), + tir.floormod(i_j_k_fused, 128), + ] + ] + ) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(i_j_k_fused, 16384)) + tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, 128), 128)) + tir.bind(vk, tir.floormod(i_j_k_fused, 128)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + + for i0, i1, j, k in tir.grid(8, 16, 128, 128): + with tir.block([], "opaque"): + tir.reads([A[i0 * 16 + i1, j, k]]) + tir.writes([B[i0 * 16 + i1, j, k]]) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i0 * 16 + i1) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16], "float32") + B = tir.match_buffer(b, [16, 16], "float32") + with tir.block([16, 16], "A") as [vi, vj]: + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, vi * 16 + vj, 1) + with tir.block([16, 16], "B") as [vi, vj]: + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + + +@tvm.script.tir +def opaque_access_fused(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16]) + B = tir.match_buffer(b, [16, 16]) + for i_j_fused in tir.serial(0, 256): + with tir.block([16, 16], "A") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 16)) + tir.bind(vj, tir.floormod(i_j_fused, 16)) + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, ((vi * 16) + vj), 1, 1) + for i_j_fused in tir.serial(0, 256): + with tir.block([16, 16], "B") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 16)) + tir.bind(vj, tir.floormod(i_j_fused, 16)) + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate( + tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle") + ) + + +@tvm.script.tir +def opaque_access_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + B = tir.match_buffer(b, (16, 16)) + for i, j0, j1 in tir.grid(16, 4, 4): + with tir.block([16, 16], "A") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, ((j0 * 4) + j1)) + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, ((vi * 16) + vj), 1, 1) + for i, j0, j1 in tir.grid(16, 4, 4): + with tir.block([16, 16], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, ((j0 * 4) + j1)) + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate( + tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle") + ) + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_fuse(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_fused, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_split(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[2, 1, 64]) + sch.split(j, factors=[4, 32]) + sch.split(k, factors=[16, 8]) + tvm.ir.assert_structural_equal(elementwise_split_case0, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_split_with_inferred_factor(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[None, 1, 64]) + sch.split(j, factors=[2, None, 64]) + sch.split(k, factors=[2, 1, None]) + tvm.ir.assert_structural_equal(elementwise_split_case1, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_split_with_predicate(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[1000, 2, 3]) + sch.split(j, factors=[None, 129]) + sch.split(k, factors=[3, None]) + tvm.ir.assert_structural_equal(elementwise_split_with_predicate, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_fuse_fail_not_only_child(): + sch = tir.Schedule(elementwise_with_seq, debug_mask="all") + block_b = sch.get_block("B") + _, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + + +def test_fuse_split_fail_with_annotation(): + sch = tir.Schedule(elementwise_with_anno, debug_mask="all") + block_b = sch.get_block("B") + _, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_split_fail_not_start_with_zero(): + sch = tir.Schedule(elementwise_with_anno, debug_mask="all") + block_b = sch.get_block("B") + _, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_with_opaque_block(): + sch = tir.Schedule(elementwise_with_opaque_block, debug_mask="all") + block_opaque = sch.get_block("opaque") + i, j, k = sch.get_loops(block_opaque) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_fuse_with_opaque_block, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_with_opaque_block) + + +def test_fuse_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mask="all") + block_a = sch.get_block("A") + i, j = sch.get_loops(block_a) + sch.fuse(i, j) + block_b = sch.get_block("B") + i, j = sch.get_loops(block_b) + sch.fuse(i, j) + tvm.ir.assert_structural_equal(opaque_access_fused, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=opaque_access) + + +def test_split_with_opaque_block(): + sch = tir.Schedule(elementwise_with_opaque_block, debug_mask="all") + block_opaque = sch.get_block("opaque") + i, _, _ = sch.get_loops(block_opaque) + sch.split(i, factors=[None, 16]) + tvm.ir.assert_structural_equal(elementwise_split_with_opaque_block, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_with_opaque_block) + + +def test_split_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mask="all") + block_a = sch.get_block("A") + _, j = sch.get_loops(block_a) + sch.split(j, factors=[None, 4]) + block_b = sch.get_block("B") + _, j = sch.get_loops(block_b) + sch.split(j, factors=[None, 4]) + tvm.ir.assert_structural_equal(opaque_access_split, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=opaque_access) + + +def test_fuse_split_fail_with_thread_binding(): + sch = tir.Schedule(elementwise_with_thread_binding, debug_mask="all") + block_b = sch.get_block("B") + _, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_symbolic(): + sch = tir.Schedule(elementwise_symbolic, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_symbolic_fused, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_symbolic) + + +def test_split_symbolic(): + sch = tir.Schedule(elementwise_symbolic, debug_mask="all") + block_b = sch.get_block("B") + _, _, k = sch.get_loops(block_b) + sch.split(k, factors=[10, None]) + tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_symbolic) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_state.py b/tests/python/unittest/test_tir_schedule_state.py index 34041120f252..856d6a5c17eb 100644 --- a/tests/python/unittest/test_tir_schedule_state.py +++ b/tests/python/unittest/test_tir_schedule_state.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring - import gc +import sys +import pytest import tvm from tvm import tir from tvm.ir import IRModule @@ -77,7 +78,7 @@ def block_in_opaque_block(a: ty.handle, b: ty.handle) -> None: def replace_ir_builder(deep_copy=False, realize=False): new_func = tvm.script.from_source(tvm.script.asscript(elementwise)) - s = tir.ScheduleState(new_func, debug_mode=True) + s = tir.ScheduleState(new_func, debug_mask="all") target = tvm.tir.Block( iter_vars=[], reads=[], @@ -105,7 +106,7 @@ def replace_ir_builder_module(deep_copy=False, realize=False): new_func = tvm.script.from_source(tvm.script.asscript(elementwise)) other_func = tvm.script.from_source(tvm.script.asscript(elementwise)) mod = IRModule(functions={"main": new_func, "other": other_func}) - s = tir.ScheduleState(mod, debug_mode=True) + s = tir.ScheduleState(mod, debug_mask="all") target = tvm.tir.Block( iter_vars=[], reads=[], @@ -131,7 +132,7 @@ def replace_ir_builder_module(deep_copy=False, realize=False): def replace_ir_builder_with_opaque(): func = tvm.script.from_source(tvm.script.asscript(block_in_opaque_block)) - s = tir.ScheduleState(func, debug_mode=True) + s = tir.ScheduleState(func, debug_mask="all") gc.collect() return s @@ -291,7 +292,7 @@ def test_replace_root_copy3(): def test_replace_block_remap(): func = elementwise - s = tir.ScheduleState(func, debug_mode=True) + s = tir.ScheduleState(func, debug_mask="all") # The target stmt target = matmul.body.block.body.body.body[0].block sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block) @@ -338,16 +339,4 @@ def test_replace_ir_module(): if __name__ == "__main__": - test_replace_direct_write0() - test_replace_direct_write1() - test_replace_copy() - test_replace_partial_copy0() - test_replace_partial_copy1() - test_replace_root_write() - test_replace_root_copy0() - test_replace_root_copy1() - test_replace_root_copy2() - test_replace_root_copy3() - test_replace_block_remap() - test_replace_block_in_opaque_block() - test_replace_ir_module() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index a320812b339f..075b6cd689c4 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring +import sys +import pytest import tvm from tvm import tir from tvm.script import ty @@ -280,7 +282,7 @@ def f_visit(node): def test_elementwise(): - s = tir.ScheduleState(elementwise, debug_mode=True) + s = tir.ScheduleState(elementwise, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( affine_binding=True, @@ -301,7 +303,7 @@ def test_elementwise(): def test_matmul(): - s = tir.ScheduleState(matmul, debug_mode=True) + s = tir.ScheduleState(matmul, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "init")) == CachedFlags( affine_binding=True, @@ -322,7 +324,7 @@ def test_matmul(): def test_block_in_opaque_block(): - s = tir.ScheduleState(block_in_opaque_block, debug_mode=True) + s = tir.ScheduleState(block_in_opaque_block, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( affine_binding=True, @@ -353,7 +355,7 @@ def test_block_in_opaque_block(): def test_write_after_read(): - s = tir.ScheduleState(write_after_read, debug_mode=True) + s = tir.ScheduleState(write_after_read, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( affine_binding=True, @@ -374,7 +376,7 @@ def test_write_after_read(): def test_loop_carried_dependency(): - s = tir.ScheduleState(loop_carried_dependency, debug_mode=True) + s = tir.ScheduleState(loop_carried_dependency, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( affine_binding=True, @@ -395,7 +397,7 @@ def test_loop_carried_dependency(): def test_concatenate_multi_producer_covered(): # pylint: disable=invalid-name - s = tir.ScheduleState(concatenate_multi_producer, debug_mode=True) + s = tir.ScheduleState(concatenate_multi_producer, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "A_0")) == CachedFlags( affine_binding=True, @@ -421,7 +423,7 @@ def test_concatenate_multi_producer_covered(): # pylint: disable=invalid-name def test_concatenate_multi_producer_uncovered(): # pylint: disable=invalid-name - s = tir.ScheduleState(concatenate_multi_producer_uncovered, debug_mode=True) + s = tir.ScheduleState(concatenate_multi_producer_uncovered, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "A_0")) == CachedFlags( affine_binding=True, @@ -447,7 +449,7 @@ def test_concatenate_multi_producer_uncovered(): # pylint: disable=invalid-name def test_lca_at_loop(): - s = tir.ScheduleState(lca_at_loop, debug_mode=True) + s = tir.ScheduleState(lca_at_loop, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( affine_binding=True, @@ -468,7 +470,7 @@ def test_lca_at_loop(): def test_multi_producer_consumer(): - s = tir.ScheduleState(multi_producer_consumer, debug_mode=True) + s = tir.ScheduleState(multi_producer_consumer, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "A_0")) == CachedFlags( affine_binding=True, @@ -494,7 +496,7 @@ def test_multi_producer_consumer(): def test_elementwise_affine_producer(): - s = tir.ScheduleState(elementwise_affine_producer, debug_mode=True) + s = tir.ScheduleState(elementwise_affine_producer, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -515,7 +517,7 @@ def test_elementwise_affine_producer(): def test_subblock(): - s = tir.ScheduleState(elementwise_subblock, debug_mode=True) + s = tir.ScheduleState(elementwise_subblock, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -541,7 +543,7 @@ def test_subblock(): def test_subblock_uncovered(): - s = tir.ScheduleState(elementwise_subblock_uncovered, debug_mode=True) + s = tir.ScheduleState(elementwise_subblock_uncovered, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -567,7 +569,7 @@ def test_subblock_uncovered(): def test_thread_binding(): - s = tir.ScheduleState(bound_to_thread, debug_mode=True) + s = tir.ScheduleState(bound_to_thread, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -588,7 +590,7 @@ def test_thread_binding(): def test_equal_ranked_threads(): - s = tir.ScheduleState(equal_ranked_threads, debug_mode=True) + s = tir.ScheduleState(equal_ranked_threads, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -609,7 +611,7 @@ def test_equal_ranked_threads(): def test_warp_memory(): - s = tir.ScheduleState(warp_memory, debug_mode=True) + s = tir.ScheduleState(warp_memory, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -630,7 +632,7 @@ def test_warp_memory(): def test_warp_memory_negative(): - s = tir.ScheduleState(warp_memory_negative, debug_mode=True) + s = tir.ScheduleState(warp_memory_negative, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -651,19 +653,4 @@ def test_warp_memory_negative(): if __name__ == "__main__": - test_elementwise() - test_matmul() - test_block_in_opaque_block() - test_write_after_read() - test_loop_carried_dependency() - test_concatenate_multi_producer_covered() - test_concatenate_multi_producer_uncovered() - test_lca_at_loop() - test_multi_producer_consumer() - test_elementwise_affine_producer() - test_subblock() - test_subblock_uncovered() - test_thread_binding() - test_equal_ranked_threads() - test_warp_memory() - test_warp_memory_negative() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py new file mode 100644 index 000000000000..a0a069347f95 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_storage_align.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name + +@tvm.script.tir +def element_wise(a: ty.handle, c: ty.handle) -> None: + C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with tir.block([], "root"): + tir.reads([]) + tir.writes([]) + B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in tir.serial(0, 128): + for ax1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, ax1) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + B[vi, vj] = (A[vi, vj]*tir.float32(2)) + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi_1, vj_1]: + tir.bind(vi_1, i0) + tir.bind(vj_1, i1) + tir.reads([B[vi_1, vj_1]]) + tir.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) + + +@tvm.script.tir +def element_wise_storage_align(a: ty.handle, c: ty.handle) -> None: + C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with tir.block([], "root"): + tir.reads([]) + tir.writes([]) + B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in tir.serial(0, 128): + for ax1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, ax1) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + tir.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]}) + B[vi, vj] = (A[vi, vj]*tir.float32(2)) + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi_1, vj_1]: + tir.bind(vi_1, i0) + tir.bind(vj_1, i1) + tir.reads([B[vi_1, vj_1]]) + tir.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) + + +@tvm.script.tir +def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None: + C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with tir.block([], "root"): + tir.reads([]) + tir.writes([]) + B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in tir.serial(0, 128): + for ax1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.block_attr({"buffer_dim_align": [0]}) + tir.bind(vi, i0) + tir.bind(vj, ax1) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + B[vi, vj] = (A[vi, vj]*tir.float32(2)) + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi_1, vj_1]: + tir.bind(vi_1, i0) + tir.bind(vj_1, i1) + tir.reads([B[vi_1, vj_1]]) + tir.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) + + +def test_storage_align(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + s.storage_align(B, 0, axis=0, factor=128, offset=127) + tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_storage_align_update(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + s.storage_align(B, 0, axis=0, factor=128, offset=0) + s.storage_align(B, 0, axis=0, factor=128, offset=127) + tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_storage_align_invalid_factor1(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=0, factor=0, offset=127) + + +def test_storage_align_invalid_factor2(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=0, factor=-1, offset=127) + + +def test_storage_align_invalid_buffer(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + C = s.get_block("C") + with pytest.raises(tir.ScheduleError): + s.storage_align(C, 0, axis=0, factor=128, offset=127) + + +def test_storage_align_invalid_buffer_index(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 2, axis=0, factor=128, offset=127) + + +def test_storage_align_invalid_axis(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=2, factor=128, offset=127) + + +def test_storage_align_invalid_annotation(): + func = element_wise_invalid_annotation + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=2, factor=128, offset=127) + + +if __name__ == "__main__": + test_storage_align() + test_storage_align_update() + test_storage_align_invalid_factor1() + test_storage_align_invalid_factor2() + test_storage_align_invalid_buffer() + test_storage_align_invalid_buffer_index() + test_storage_align_invalid_axis() + test_storage_align_invalid_annotation() diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py new file mode 100644 index 000000000000..da7b096ade17 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -0,0 +1,241 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +# mypy: ignore-errors +import sys + +import pytest +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule import BlockRV, Instruction, InstructionKind, LoopRV, Trace + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def elementwise_inlined(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def _make_get_block(name, output): + return Instruction( + kind=InstructionKind.get("GetBlock"), + inputs=[], + attrs=[name, "main"], + outputs=[output], + ) + + +def _make_get_loops(input, outputs): # pylint: disable=redefined-builtin + return Instruction( + kind=InstructionKind.get("GetLoops"), + inputs=[input], + attrs=[], + outputs=outputs, + ) + + +def _make_compute_inline(input): # pylint: disable=redefined-builtin + return Instruction( + kind=InstructionKind.get("ComputeInline"), + inputs=[input], + attrs=[], + outputs=[], + ) + + +def _make_enter_postproc(): + return Instruction( + kind=InstructionKind.get("EnterPostproc"), + inputs=[], + attrs=[], + outputs=[], + ) + + +def _make_trace_1(b0, l1, l2): # pylint: disable=invalid-name + return Trace( + insts=[ + _make_get_block(name="block", output=b0), + _make_get_loops(input=b0, outputs=[l1, l2]), + ], + decisions={}, + ) + + +def _make_trace_2(b0): # pylint: disable=invalid-name + return Trace( + insts=[ + _make_get_block(name="B", output=b0), + _make_compute_inline(input=b0), + ], + decisions={}, + ) + + +def _make_trace_3(b0, b1, add_postproc): # pylint: disable=invalid-name + if add_postproc: + insts = [ + _make_get_block(name="B", output=b0), + _make_compute_inline(input=b0), + _make_get_block(name="C", output=b1), + _make_enter_postproc(), + _make_compute_inline(input=b1), + ] + else: + insts = [ + _make_get_block(name="B", output=b0), + _make_compute_inline(input=b0), + _make_get_block(name="C", output=b1), + ] + return Trace(insts=insts, decisions={}) + + +def test_trace_construct_1(): + trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="block", func_name="main")', + "l1, l2 = sch.get_loops(block=b0)", + ) + ) + assert len(trace.insts) == 2 + assert len(trace.decisions) == 0 + + +def test_trace_construct_get_decision_1(): + trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + assert trace.get_decision(trace.insts[0]) is None + assert trace.get_decision(trace.insts[1]) is None + + +def test_trace_construct_append_1(): + trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + trace.append(inst=_make_get_block("block2", BlockRV())) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="block", func_name="main")', + "l1, l2 = sch.get_loops(block=b0)", + 'b3 = sch.get_block(name="block2", func_name="main")', + ) + ) + + +def test_trace_construct_pop_1(): + trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + last_inst = trace.insts[-1] + assert trace.pop().same_as(last_inst) + assert str(trace) == 'b0 = sch.get_block(name="block", func_name="main")' + + +def test_trace_construct_pop_2(): + trace = Trace([], {}) + assert str(trace) == "" + assert trace.pop() is None + assert str(trace) == "" + + +def test_trace_apply_to_schedule(): + trace = _make_trace_2(BlockRV()) + sch = tir.Schedule(elementwise, debug_mask="all") + trace.apply_to_schedule(sch, remove_postproc=False, decision_provider=None) + tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + + +def test_trace_as_json_1(): + trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + obj = trace.as_json() + assert obj == [ + [ + ["GetBlock", [], ["block", "main"], ["b0"]], + ["GetLoops", ["b0"], [], ["l1", "l2"]], + ], + [], + ] + + +def test_trace_simplified_1(): + trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="B", func_name="main")', + "sch.compute_inline(block=b0)", + 'b1 = sch.get_block(name="C", func_name="main")', + "sch.enter_postproc()", + "sch.compute_inline(block=b1)", + ) + ) + trace = trace.simplified(remove_postproc=True) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="B", func_name="main")', + "sch.compute_inline(block=b0)", + ) + ) + + +def test_trace_simplified_2(): + trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="B", func_name="main")', + "sch.compute_inline(block=b0)", + 'b1 = sch.get_block(name="C", func_name="main")', + "sch.enter_postproc()", + "sch.compute_inline(block=b1)", + ) + ) + trace = trace.simplified(remove_postproc=False) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="B", func_name="main")', + "sch.compute_inline(block=b0)", + 'b1 = sch.get_block(name="C", func_name="main")', + "sch.enter_postproc()", + "sch.compute_inline(block=b1)", + ) + ) + + +def test_apply_json_to_schedule_1(): + trace = _make_trace_2(BlockRV()) + json_obj = trace.as_json() + sch = tir.Schedule(elementwise, debug_mask="all") + Trace.apply_json_to_schedule(json_obj, sch) + tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index af89ca252738..dcaeaaad6164 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -15,12 +15,15 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring +import sys + import pytest import tvm + from tvm import tir from tvm.ir import IRModule from tvm.script import ty - +from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: disable=no-member,invalid-name,unused-variable @@ -46,8 +49,8 @@ def test_tir_schedule_creation(): # - Schedule.__init__ for PrimFunc and IRModule # - Schedule.mod # - Schedule.state - sch_1 = tir.Schedule(matmul, debug_mode=True) - sch_2 = tir.Schedule(IRModule({"main": matmul}), debug_mode=True) + sch_1 = tir.Schedule(matmul, debug_mask="all") + sch_2 = tir.Schedule(IRModule({"main": matmul}), debug_mask="all") assert sch_1.mod["main"].same_as(sch_2.mod["main"]) assert sch_1.state.mod["main"].same_as(sch_2.state.mod["main"]) @@ -57,7 +60,7 @@ def test_tir_schedule_get_block(): # - Schedule.get_block # - Schedule.get_sref # - Schedule.get - sch = tir.Schedule(matmul, debug_mode=True) + sch = tir.Schedule(matmul, debug_mask="all") block_rv = sch.get_block(name="update") block_sref = sch.get_sref(block_rv) block = sch.get(block_rv) @@ -71,7 +74,7 @@ def test_tir_schedule_get_loops(): # Tests: # - Schedule.get_loops # - Schedule.get - sch = tir.Schedule(matmul, debug_mode=True) + sch = tir.Schedule(matmul, debug_mask="all") block_rv = sch.get_block(name="update") i, j, k = sch.get_loops(block_rv) assert sch.get(i).loop_var.name == "i" @@ -79,10 +82,10 @@ def test_tir_schedule_get_loops(): assert sch.get(k).loop_var.name == "k" -def test_tir_schedule_copy(): +def test_tir_schedule_copy_1(): # Tests: # - Schedule.copy - sch_1 = tir.Schedule(matmul, debug_mode=True) + sch_1 = tir.Schedule(matmul, debug_mask="all") block_rv = sch_1.get_block(name="update") i, j, k = sch_1.get_loops(block_rv) assert sch_1.get(i).loop_var.name == "i" @@ -96,10 +99,40 @@ def test_tir_schedule_copy(): assert sch_2.get(k).loop_var.name == "k" +def test_tir_schedule_copy_2(): + sch = tir.Schedule(mod=matmul, debug_mask="all") + i, j, k = sch.get_loops(sch.get_block("update")) + sch_copy = sch.copy() + assert not sch.get_sref(i).same_as(sch_copy.get_sref(i)) + assert not sch.get_sref(j).same_as(sch_copy.get_sref(j)) + assert not sch.get_sref(k).same_as(sch_copy.get_sref(k)) + assert sch.get_sref(i).stmt.same_as(sch_copy.get_sref(i).stmt) + assert sch.get_sref(j).stmt.same_as(sch_copy.get_sref(j).stmt) + assert sch.get_sref(k).stmt.same_as(sch_copy.get_sref(k).stmt) + i_0, i_1 = sch.split(i, factors=[None, 64]) + j_0, j_1 = sch_copy.split(j, factors=[None, 32]) + + assert sch.get_sref(i_0).stmt.extent == 2 + assert sch.get_sref(i_1).stmt.extent == 64 + with pytest.raises(IndexError): + sch_copy.get_sref(i_0) + with pytest.raises(IndexError): + sch_copy.get_sref(i_1) + + with pytest.raises(IndexError): + sch.get_sref(j_0) + with pytest.raises(IndexError): + sch.get_sref(j_1) + assert sch_copy.get_sref(j_0).stmt.extent == 4 + assert sch_copy.get_sref(j_1).stmt.extent == 32 + verify_trace_roundtrip(sch, mod=matmul) + verify_trace_roundtrip(sch_copy, mod=matmul) + + def test_tir_schedule_remove_rv(): # Tests: # - Schedule.remove_rv - sch = tir.Schedule(matmul, debug_mode=True) + sch = tir.Schedule(matmul, debug_mask="all") block_rv = sch.get_block(name="update") assert sch.get(block_rv).name_hint == "update" sch.remove_rv(block_rv) @@ -108,8 +141,4 @@ def test_tir_schedule_remove_rv(): if __name__ == "__main__": - test_tir_schedule_creation() - test_tir_schedule_get_block() - test_tir_schedule_get_loops() - test_tir_schedule_copy() - test_tir_schedule_remove_rv() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_specialize.py b/tests/python/unittest/test_tir_specialize.py index 2e9f1110732a..d6cfadaf1fbc 100644 --- a/tests/python/unittest/test_tir_specialize.py +++ b/tests/python/unittest/test_tir_specialize.py @@ -145,6 +145,24 @@ def mem_copy_m_n_p_n(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty B[vi, vj] = A[vi, vj] +@tvm.script.tir +def param_in_arith_exprs(a: ty.handle, b: ty.handle) -> None: + n = tir.var("int32") + A = tir.match_buffer(a, [n // 8, 8], "int32") + B = tir.match_buffer(b, [n], "int32") + with tir.block([n - 1], "") as [vi]: + B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42 + + +@tvm.script.tir +def param_in_arith_exprs_n_16(a: ty.handle, b: ty.handle) -> None: + n = tir.var("int32") + A = tir.match_buffer(a, [2, 8], "int32") + B = tir.match_buffer(b, [16], "int32") + with tir.block([15], "") as [vi]: + B[vi] = A[vi // 8, vi % 8] + 714 + + def test_specialize_nothing(): func = matmul.specialize({}) assert func.same_as(matmul) # Pointer the same @@ -191,9 +209,16 @@ def test_specialize_recursive_load(): pass +def test_specialize_with_const_folding(): + b = param_in_arith_exprs.params[1] + func = param_in_arith_exprs.specialize({b: tir.decl_buffer([16])}) + tvm.ir.assert_structural_equal(func, param_in_arith_exprs_n_16) + + if __name__ == "__main__": test_specialize_nothing() test_specialize_matmul() test_specialize_elemwise() test_specialize_mem_copy() test_specialize_recursive_load() + test_specialize_with_const_folding() diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py index 1e2f8460e586..d25780a01f79 100644 --- a/tests/python/unittest/test_tir_structural_equal_hash.py +++ b/tests/python/unittest/test_tir_structural_equal_hash.py @@ -165,6 +165,24 @@ def func2(): assert consistent_equal(func2(), func2()) +def test_buffer_storage_scope(): + x = te.var("x", dtype="handle") + + buffer_local_0 = tvm.tir.decl_buffer((10, 10), "float32", scope="local") + buffer_local_1 = tvm.tir.decl_buffer((10, 10), "float32", scope="local") + buffer_global = tvm.tir.decl_buffer((10, 10), "float32", scope="global") + buffer_empty = tvm.tir.decl_buffer((10, 10), "float32", scope="") + + func0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_local_0}) + func1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_local_1}) + func2 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_global}) + func3 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_empty}) + + assert consistent_equal(func0, func1) + assert consistent_equal(func2, func3) + assert not consistent_equal(func0, func2) + + def test_buffer_load_store(): b = tvm.tir.decl_buffer((10, 10), "float32") x = tvm.tir.BufferLoad(b, [0, 1]) @@ -188,4 +206,5 @@ def test_buffer_load_store(): test_array() test_env_func() test_stmt() + test_buffer_storage_scope() test_buffer_load_store() diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 7c06b5ef5ca1..cefdb5fd8c6a 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir +from tvm import tir, te from tvm.script import ty @@ -293,6 +293,96 @@ def compacted_complex_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None: C[i, j] = B[0, j] +@tvm.script.tir +def match_buffer_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + C = tir.match_buffer(c, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + C0 = tir.match_buffer(C[i, 0:16], (16)) + B = tir.alloc_buffer((16, 16)) + with tir.block([]): + B0 = tir.match_buffer(B[i, 0:16], (16)) + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + B1 = tir.match_buffer(B0[j], ()) + B1[()] = A1[()] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + C1 = tir.match_buffer(C0[j], ()) + B2 = tir.match_buffer(B[i, j], ()) + C1[()] = B2[()] * 2.0 + + +@tvm.script.tir +def compacted_match_buffer_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + C = tir.match_buffer(c, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + C0 = tir.match_buffer(C[i, 0:16], (16)) + B = tir.alloc_buffer((1, 16)) + with tir.block([]): + B0 = tir.match_buffer(B[0, 0:16], (16)) + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + B1 = tir.match_buffer(B0[j], ()) + B1[()] = A1[()] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + C1 = tir.match_buffer(C0[j], ()) + B2 = tir.match_buffer(B[0, j], ()) + C1[()] = B2[()] * 2.0 + + +@tvm.script.tir +def storage_align_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(C[i, 0:16]) + B = tir.alloc_buffer((16, 16), "float32") + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(A[i, j]) + tir.writes(B[i, j]) + tir.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) + B[i, j] = A[i, j] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(B[i, j]) + tir.writes(C[i, j]) + C[i, j] = B[i, j] * 2.0 + + +@tvm.script.tir +def compacted_storage_align_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(C[i, 0:16]) + B = tir.alloc_buffer((1, 16), strides=(31, 1), dtypes="float32") + for j in range(0, 16): + with tir.block() as []: + tir.reads(A[i, j]) + tir.writes(B[0, j]) + tir.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) + B[0, j] = A[i, j] + 1.0 + for j in range(0, 16): + with tir.block() as []: + tir.reads(B[0, j]) + tir.writes(C[i, j]) + C[i, j] = B[0, j] * 2.0 + + def test_elementwise(): _check(elementwise_func, compacted_elementwise_func) @@ -321,6 +411,23 @@ def test_complex(): _check(complex_func, compacted_complex_func) +def test_match_buffer(): + _check(match_buffer_func, compacted_match_buffer_func) + + +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.CompactBufferAllocation()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # CompactBufferAllocation should do nothing on TE + + +def test_storage_align(): + _check(storage_align_func, compacted_storage_align_func) + + if __name__ == "__main__": test_elementwise() test_unschedulable_block() @@ -329,3 +436,6 @@ def test_complex(): test_warp_mem() test_symbolic() test_complex() + test_match_buffer() + test_storage_align() + test_lower_te() diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py index 38fe1c967456..cfdcc1a65911 100644 --- a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir +from tvm import tir, te from tvm.script import ty @@ -73,5 +73,15 @@ def test_elementwise(): _check(elementwise_func, substituted_elementwise_func) +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.ConvertBlocksToOpaque()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # ConvertBlocksToOpaque should do nothing on TE + + if __name__ == "__main__": test_elementwise() + test_lower_te() diff --git a/tests/python/unittest/test_tir_transform_coproc_sync.py b/tests/python/unittest/test_tir_transform_coproc_sync.py index 2d45118f39f2..7dacd8e046cc 100644 --- a/tests/python/unittest/test_tir_transform_coproc_sync.py +++ b/tests/python/unittest/test_tir_transform_coproc_sync.py @@ -51,7 +51,7 @@ def meminfo_cache(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body - body = stmt.body.body.body + body = stmt.body.body blist = tvm.tir.stmt_list(body) assert blist[1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_read_barrier")) @@ -112,7 +112,7 @@ def __check_list(tvm_array, py_list): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body - slist = tvm.tir.stmt_list(stmt[0].body.body) + slist = tvm.tir.stmt_list(stmt[0].body) push_st = slist[2] slist = tvm.tir.stmt_list(slist[-1]) pop_st = slist[0].body[0] diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index c997748649cd..c51b5319e85f 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir +from tvm import tir, te from tvm.script import ty @@ -35,7 +35,7 @@ def compacted_elementwise_func(a: ty.handle, c: ty.handle) -> None: with tir.block([]): tir.reads(A[i, 0:16]) tir.writes(C[i, 0:16]) - B = tir.alloc_buffer([1, 16], "float32") + B = tir.alloc_buffer([1, 16], "float32", scope="global") for j in range(0, 16): with tir.block() as []: tir.reads(A[i, j]) @@ -111,7 +111,7 @@ def compacted_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32, m: ty.int32 with tir.block([]): tir.reads(A[i, m]) tir.writes(C[i, m]) - B = tir.alloc_buffer((m,), "float32") + B = tir.alloc_buffer((m,), "float32", scope="global") for j in range(0, m): with tir.block([]) as []: tir.reads(A[i, j]) @@ -190,8 +190,8 @@ def compacted_multi_alloc_func(a: ty.handle, d: ty.handle) -> None: with tir.block([]) as []: tir.reads(A[i]) tir.writes(D[i]) - B = tir.alloc_buffer((32,)) - C = tir.alloc_buffer((32,)) + B = tir.alloc_buffer((32,), scope="global") + C = tir.alloc_buffer((32,), scope="global") B[i] = A[i] + 1.0 C[i] = A[i] + B[i] D[i] = C[i] * 2.0 @@ -234,6 +234,15 @@ def test_multi_alloc(): _check(compacted_multi_alloc_func, flattened_multi_alloc_func) +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.FlattenBuffer()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do nothing on TE + + if __name__ == "__main__": test_elementwise() test_gpu_workload() @@ -241,3 +250,4 @@ def test_multi_alloc(): test_predicate() test_unit_loops() test_multi_alloc() + test_lower_te() diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index 252a187dbdc5..b111e2be75c7 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -636,7 +636,7 @@ def test_hoisting_block_scope_4(): def test_hoisting_block_scope_5(): ib = tvm.tir.ir_builder.create() - data = ib.pointer("float32", name="data") + data = ib.pointer("float32", name="data", scope="global") l = te.var("l") m = te.var("m") n = te.var("n") diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index ceb32c484c6d..9b37bcaaacbc 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -47,8 +47,8 @@ def test_double_buffer(): mod = opt(mod) stmt = mod["db"].body - assert isinstance(stmt.body.body, tvm.tir.Allocate) - assert stmt.body.body.extents[0].value == 2 + assert isinstance(stmt.body, tvm.tir.Allocate) + assert stmt.body.extents[0].value == 2 f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 3e7a5a0cb300..673267a9b1fa 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -49,13 +49,13 @@ def get_vthread(name): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("vthread"))) - )["main"].body + )["main"] assert stmt.body.body.extents[0].value == 2 stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"].body + )["main"] assert len(stmt.body.body.extents) == 3 @@ -94,11 +94,11 @@ def get_vthread(name): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"].body + )["main"] assert stmt.body.body.extents[0].value == 2 - assert stmt.body.body.body.body.body.body.extents[0].value == 2 - assert len(stmt.body.body.body.body.body.body.extents) == 3 + assert stmt.body.body.body.body.extents[0].value == 2 + assert len(stmt.body.body.body.body.extents) == 3 def test_vthread_if_then_else(): @@ -119,7 +119,7 @@ def test_vthread_if_then_else(): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) - )["main"].body + )["main"] assert stmt.body.body.body[0].else_case != None assert stmt.body.body.body[1].else_case == None diff --git a/tests/python/unittest/test_tir_transform_lift_attr_scope.py b/tests/python/unittest/test_tir_transform_lift_attr_scope.py index 12ad16dfe092..65e317dfbcb8 100644 --- a/tests/python/unittest/test_tir_transform_lift_attr_scope.py +++ b/tests/python/unittest/test_tir_transform_lift_attr_scope.py @@ -38,7 +38,7 @@ def test_coproc_lift(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] assert body.body.body.node == cp @@ -58,7 +58,7 @@ def test_coproc_lift(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] assert body.body.body.body[1].node == cp assert len(body.body.body.body) == 2 diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 9e8848083908..c632f744bb81 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -40,7 +40,7 @@ def test_basic(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) mod = tvm.tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tir.transform.Simplify()(mod)["main"] assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) assert any(collect_visit(stmt.body.body[1], lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -156,7 +156,7 @@ def test_thread_axis(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) mod = tvm.tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tir.transform.Simplify()(mod)["main"] assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -178,7 +178,7 @@ def test_vectorize(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].vectorize(x) - stmt = tvm.lower(s, [A, B], name="main")["main"].body + stmt = tvm.lower(s, [A, B], name="main")["main"] body = stmt.body.body.body.body assert x.var.name not in str(body.condition) assert any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))) @@ -229,7 +229,7 @@ def test_thread_axis2(): _, x = s[C].split(x, factor=m) s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) - stmt = tvm.lower(s, [A, B], name="main")["main"].body + stmt = tvm.lower(s, [A, B], name="main")["main"] for_body = stmt.body.body.body.body[0] assert "threadIdx" not in str(for_body.extent) diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py index 3fb8331d39fc..1f8a4adf7054 100644 --- a/tests/python/unittest/test_tir_transform_lower_init_block.py +++ b/tests/python/unittest/test_tir_transform_lower_init_block.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir +from tvm import tir, te from tvm.script import ty +# pylint: disable=no-self-argument + @tvm.script.tir class WithInit: @@ -43,11 +45,56 @@ def main(a: ty.handle, b: ty.handle) -> None: B[i] += A[i, j, k] +@tvm.script.tir +class InitWithMatchBuffer: + def main(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [64, 64, 64]) + B = tir.match_buffer(b, [64]) + + with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: + BB = tir.match_buffer(B[i], ()) + AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64)) + with tir.init(): + BB[()] = tir.float32(0) + BB[()] += AA[j, k] + + +@tvm.script.tir +class BranchWithMatchBuffer: + def main(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [64, 64, 64]) + B = tir.match_buffer(b, [64]) + + with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: + BB = tir.match_buffer(B[i], ()) + AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64)) + if (j == 0) and (k == 32): + BB[()] = tir.float32(0) + BB[()] += AA[j, k] + + def test_lower_reduction(): origin_mod = WithInit() mod = tvm.tir.transform.LowerInitBlock()(origin_mod) tvm.ir.assert_structural_equal(mod, WithBranch(), True) +def test_lower_match_buffer(): + origin_mod = InitWithMatchBuffer() + mod = tvm.tir.transform.LowerInitBlock()(origin_mod) + tvm.ir.assert_structural_equal(mod, BranchWithMatchBuffer(), True) + + +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.LowerInitBlock()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # LowerInitBlock should do nothing on TE + + if __name__ == "__main__": test_lower_reduction() + test_lower_match_buffer() + test_lower_te() diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index ef474c15cfbb..675a7feb3b1f 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -47,8 +47,9 @@ def test_lower_warp_memory_local_scope(): fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] mod = tvm.IRModule.from_expr(fdevice) fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] - assert fdevice.body.body.value.value == "local" - assert fdevice.body.body.body.extents[0].value == 2 + allocate = fdevice.body.body + assert allocate.buffer_var.type_annotation.storage_scope == "local" + assert fdevice.body.body.extents[0].value == 2 @tvm.testing.requires_cuda @@ -72,8 +73,8 @@ def test_lower_warp_memory_correct_indices(): bounds = tvm.te.schedule.InferBound(s) ir = tvm.te.schedule.ScheduleOps(s, bounds) - inner_func = ir.body.body.body.body - store_A_warp = inner_func.body.seq[0].body.body + inner_func = ir.body.body.body + store_A_warp = inner_func.seq[0].body.body indices = list(store_A_warp.indices) # A.warp is actually many buffers, one for each warp, although they are all called A.warp @@ -282,6 +283,33 @@ def check(device, m): check(device, m=65) +@tvm.testing.requires_cuda +def test_lower_warp_memory_same_thread(): + m = n = 128 + A = te.placeholder((m, n), name="A") + k = te.reduce_axis((0, n), name="k") + B = te.compute((m,), lambda i: te.sum(A[i, k], axis=[k])) + + s = te.create_schedule(B.op) + BB = s.cache_write(B, "warp") + tx = te.thread_axis("threadIdx.x") + xo, xi = s[B].split(B.op.axis[0], factor=32) + s[B].bind(xi, tx) + s[B].bind(xo, te.thread_axis("blockIdx.x")) + s[BB].compute_at(s[B], xo) + xo, xi = s[BB].split(s[BB].op.axis[0], factor=32) + s[BB].bind(xi, tx) + + cuda_target = tvm.target.Target("cuda") + assert cuda_target.thread_warp_size == 32 + mod = tvm.lower(s, [A, B], name="f") + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) + fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] + mod = tvm.IRModule.from_expr(fdevice) + fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] + assert "tvm_warp_shuffle" not in fdevice.astext() + + if __name__ == "__main__": test_lower_warp_memory_local_scope() test_lower_warp_memory_correct_indices() @@ -289,3 +317,4 @@ def check(device, m): test_lower_warp_memory_cuda_half_a_warp() test_lower_warp_memory_cuda_2_buffers() test_lower_warp_memory_roundup() + test_lower_warp_memory_same_thread() diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py new file mode 100644 index 000000000000..9c511f1de6b9 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -0,0 +1,259 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te +import numpy as np +import tvm.testing +from tvm.topi.math import cast + + +def run_passes(sch, args): + bounds = tvm.te.schedule.InferBound(sch) + assert isinstance(bounds, tvm.container.Map) + stmt = tvm.te.schedule.ScheduleOps(sch, bounds) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) + mod = tvm.IRModule.from_expr(func) + return tvm.transform.Sequential( + [ + tvm.tir.transform.StorageFlatten(64), + tvm.tir.transform.Simplify(), + tvm.tir.transform.VectorizeLoop(), + tvm.tir.transform.StorageRewrite(), + tvm.tir.transform.MergeDynamicSharedMemoryAllocations(), + ] + )(mod) + + +def verify_single_allocation(stmt, alloc_size=None): + num_alloc = [0] + alloc_extents = [] + + def verify(n): + if ( + isinstance(n, tvm.tir.Allocate) + and n.buffer_var.type_annotation.storage_scope == "shared.dyn" + ): + num_alloc[0] += 1 + alloc_extents.append(n.extents[0]) + + tvm.tir.stmt_functor.post_order_visit(stmt, verify) + assert num_alloc[0] == 1 + + if alloc_size: + assert alloc_extents[0] == alloc_size + + +@tvm.testing.requires_gpu +def test_matmul_dyn_shared(): + n = 1024 + block = 16 + A = te.placeholder((n, n), name="A", dtype="float16") + B = te.placeholder((n, n), name="B", dtype="float16") + + def syncthread(): + return tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])) + + def test_matmul_ir(A, B, C): + ib = tvm.tir.ir_builder.create() + + tx = te.thread_axis("threadIdx.x") + ty = te.thread_axis("threadIdx.y") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", block) + ib.scope_attr(ty, "thread_extent", block) + ib.scope_attr(bx, "thread_extent", n // block) + ib.scope_attr(by, "thread_extent", n // block) + + A_sh = ib.allocate(A.dtype, (block, block), scope="shared.dyn", name="A_sh") # fp16 + B_sh = ib.allocate(B.dtype, (block, block), scope="shared.dyn", name="B_sh") # fp16 + # Create a dynamic shared memory for the accumulation. + # This is for testing merging dynamic shared memory alloctions with different data type. + # In practice, there is no need to allocate a shared memory for C. + C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn", name="C_sh") # fp32 + + A_ptr = ib.buffer_ptr(A) + B_ptr = ib.buffer_ptr(B) + C_ptr = ib.buffer_ptr(C) + + C_sh[ty, tx] = 0.0 + + with ib.for_range(0, n // block, name="i") as i: + A_sh[ty, tx] = A_ptr[by * block + ty, i * block + tx] + B_sh[ty, tx] = B_ptr[i * block + ty, bx * block + tx] + ib.emit(syncthread()) + + with ib.for_range(0, block, name="k") as k: + C_sh[ty, tx] += cast(A_sh[ty, k] * B_sh[k, tx], "float32") + + ib.emit(syncthread()) + + C_ptr[by * block + ty, bx * block + tx] = C_sh[ty, tx] + + return ib.get() + + C = te.extern( + A.shape, + [A, B], + lambda ins, outs: test_matmul_ir(ins[0], ins[1], outs[0]), + name="matmul", + dtype="float32", + ) + s = te.create_schedule(C.op) + mod = run_passes(s, [A, B, C]) + expected_alloc_size = block * block * 3 * 4 + verify_single_allocation(mod["main"].body, expected_alloc_size) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + fmatmul = tvm.build(s, [A, B, C], target) + dev = tvm.device(target, 0) + + size = (n, n) + a_np = np.random.uniform(size=size).astype(A.dtype) + b_np = np.random.uniform(size=size).astype(B.dtype) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(size, dtype=C.dtype), dev) + fmatmul(a, b, c) + np_ref = np.dot(a_np.astype("float32"), b_np.astype("float32")) + tvm.testing.assert_allclose(c.numpy(), np_ref, 1e-4, 1e-4) + + for target in ["cuda", "nvptx"]: + check_target(target) + + +@tvm.testing.requires_gpu +def test_dyn_shared_vectorized_store(): + """Test vectorized store into dynamic shared memory""" + n = te.size_var("n") + A = te.placeholder((n,), name="A", dtype="float16") + B = te.placeholder((n,), name="B", dtype="float32") + + def test_device_ir(A, B, C): + n = A.shape[0] + ib = tvm.tir.ir_builder.create() + + values_per_thread = 4 + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", tvm.tir.indexdiv(n, values_per_thread)) + + A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn") # fp16 + B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn") # fp32 + + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + with ib.for_range(0, values_per_thread, kind="vectorize") as i: + A_sh[tx * values_per_thread + i] = Aptr[tx * values_per_thread + i] + B_sh[tx * values_per_thread + i] = Bptr[tx * values_per_thread + i] + + with ib.for_range(0, values_per_thread) as i: + Cptr[tx * values_per_thread + i] = ( + cast(A_sh[tx * values_per_thread + i], "float32") + B_sh[tx * values_per_thread + i] + ) + + return ib.get() + + C = te.extern( + (n,), + [A, B], + lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]), + name="vadd", + dtype="float32", + ) + s = te.create_schedule(C.op) + + mod = run_passes(s, [A, B, C]) + verify_single_allocation(mod["main"].body) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + fadd = tvm.build(s, [A, B, C], target) + dev = tvm.device(target, 0) + + for n in [512, 1024]: + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) + c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), dev) + fadd(a, b, c) + tvm.testing.assert_allclose( + c.numpy(), a.numpy().astype("float32") + b.numpy(), 1e-4, 1e-4 + ) + + for target in ["cuda", "nvptx"]: + check_target(target) + + +@tvm.testing.requires_gpu +def test_dyn_shared_reuse_and_merge(): + n = 64 + A = te.placeholder((n,), name="A", dtype="float32") + B = te.placeholder((n,), name="B", dtype="float32") + C = te.placeholder((te.size_var("n_dyn"),), name="C", dtype="float32") + + def test_device_ir(A, B, C, D): + ib = tvm.tir.ir_builder.create() + + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", n) + + A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn", name="A_sh") + B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn", name="B_sh") + C_sh = ib.allocate(C.dtype, (C.shape[0],), scope="shared.dyn", name="C_sh") + + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + Dptr = ib.buffer_ptr(D) + + A_sh[tx] = Aptr[tx] + Dptr[tx] = A_sh[tx] + + B_sh[tx] = Bptr[tx] + Dptr[tx] += B_sh[tx] + + C_sh[tx] = Cptr[tx] # C cannot reuse other buffers since it size is dynamic + Dptr[tx] += C_sh[tx] + + return ib.get() + + D = te.extern( + (n,), + [A, B, C], + lambda ins, outs: test_device_ir(ins[0], ins[1], ins[2], outs[0]), + name="vadd", + dtype="float32", + ) + s = te.create_schedule(D.op) + + mod = run_passes(s, [A, B, C, D]) + # merged allocation + # allocate(buf_dyn_shmem: Pointer(shared.dyn uint8), uint8, [((n_dyn*4) + 256)]); + verify_single_allocation(mod["main"].body) + + +if __name__ == "__main__": + test_matmul_dyn_shared() + test_dyn_shared_vectorized_store() + test_dyn_shared_reuse_and_merge() diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index d42c5e1f8626..07140ab458e6 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import tir +from tvm import tir, te from tvm.script import ty @@ -115,6 +115,85 @@ def transformed_func() -> None: ) +@tvm.script.tir +def match_buffer_func() -> None: + C = tir.alloc_buffer((128, 128)) + with tir.block([128]) as [vi]: + C0 = tir.match_buffer(C[vi, 0:128], (128)) + with tir.block([128]) as [jj]: + C1 = tir.match_buffer(C0[jj], ()) + C1[()] = 0 + + +@tvm.script.tir +def transformed_match_buffer_func() -> None: + for i in range(0, 128): + with tir.block([128]) as [vi]: + tir.bind(vi, i) + C = tir.alloc_buffer((128, 128)) + C0 = tir.match_buffer(C[vi, 0:128], (128)) + with tir.block([128]) as [jj]: + C1 = tir.match_buffer(C0[jj], ()) + C1[()] = 0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [1024]) + B = tir.match_buffer(b, [1024]) + A_cache = tir.alloc_buffer([1024]) + for i in tir.serial(0, 8): + with tir.block([8]) as [vi]: + with tir.block([8]) as [v]: + tir.bind(v, vi) + tir.reads([A[(v * 128) : ((v * 128) + 128)]]) + tir.writes([A_cache[(v * 128) : ((v * 128) + 128)]]) + tir.evaluate( + tir.call_extern( + "test", + A_cache.data, + (v * 128), + 128, + A.data, + (v * 128), + 128, + dtype="float32", + ) + ) + for j in tir.serial(0, 128): + with tir.block([1024]) as [v]: + tir.bind(v, ((vi * 128) + j)) + tir.reads([A_cache[v]]) + tir.writes([B[v]]) + B[v] = A_cache[v] + + +@tvm.script.tir +def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [1024]) + B = tir.match_buffer(b, [1024]) + for i in tir.serial(0, 8): + with tir.block([8]) as [vi]: + tir.reads(A[vi * 128 : vi * 128 + 128]) + tir.writes(B[vi * 128 : vi * 128 + 128]) + A_cache = tir.alloc_buffer([1024]) + with tir.block([8]) as [v]: + tir.bind(v, vi) + tir.reads([A[v * 128 : v * 128 + 128]]) + tir.writes([A_cache[v * 128 : v * 128 + 128]]) + tir.evaluate( + tir.call_extern( + "test", A_cache.data, v * 128, 128, A.data, v * 128, 128, dtype="float32" + ) + ) + for j in tir.serial(0, 128): + with tir.block([1024]) as [v]: + tir.bind(v, ((vi * 128) + j)) + tir.reads([A_cache[v]]) + tir.writes([B[v]]) + B[v] = A_cache[v] + + def test_elementwise(): _check(element_func, transformed_element_func) @@ -123,6 +202,28 @@ def test_locate_buffer_allocation(): _check(original_func, transformed_func) +def test_match_buffer_allocation(): + _check(match_buffer_func, transformed_match_buffer_func) + + +def test_opaque_access(): + _check(opaque_access, transformed_opaque_access) + + +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(orig_mod) + tvm.ir.assert_structural_equal( + mod, orig_mod + ) # PlanAndUpdateBufferAllocationLocation should do nothing on TE + + if __name__ == "__main__": test_elementwise() test_locate_buffer_allocation() + test_match_buffer_allocation() + test_opaque_access() + test_lower_te() diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 2d1fea01aa32..b57fa6c417b2 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -16,6 +16,8 @@ # under the License. import tvm from tvm import te +from tvm.script import ty +from tvm.relay import GlobalVar def test_flatten2(): @@ -79,7 +81,7 @@ def test_flatten_storage_align(): )(mod) stmt = mod["main"].body - assert stmt.body.extents[0].value == 17 * 8 + assert stmt.extents[0].value == 17 * 8 def test_flatten_double_buffer(): @@ -102,7 +104,9 @@ def test_flatten_double_buffer(): stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt)) + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A, C], stmt).with_attr("from_legacy_te_schedule", True) + ) with tvm.transform.PassContext(config={"tir.InjectDoubleBuffer": {"split_loop": 2}}): mod = tvm.transform.Sequential( @@ -114,8 +118,8 @@ def test_flatten_double_buffer(): )(mod) stmt = mod["main"].body - assert isinstance(stmt.body.body, tvm.tir.Allocate) - assert stmt.body.body.extents[0].value == 2 + assert isinstance(stmt.body, tvm.tir.Allocate) + assert stmt.body.extents[0].value == 2 mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db")) f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] @@ -130,8 +134,24 @@ def count_sync(op): assert count[0] == 4 +@tvm.script.tir +def tir_func(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [2, 2]) + B = tir.match_buffer(a, [2, 2]) + A[0, 1] = B[1, 1] + + +def test_flatten_tir(): + orig_mod = tvm.IRModule({GlobalVar("main"): tir_func}) + mod = tvm.tir.transform.StorageFlatten(64)(orig_mod) + tvm.ir.assert_structural_equal( + orig_mod, mod + ) # StorageFlatten should do nothing to TIR functions + + if __name__ == "__main__": test_flatten2() test_flatten_storage_align() test_flatten_double_buffer() test_flatten_prefetch() + test_flatten_tir() diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 70e77ff69fea..9e738b136b17 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -298,9 +298,9 @@ def test_storage_share_gpu(): alloc_stats = {"global": 0, "shared": 0} def verify(n): - if isinstance(n, tvm.tir.AttrStmt): - if n.attr_key == "storage_scope": - alloc_stats[n.value.value] += 1 + if isinstance(n, tvm.tir.Allocate): + scope = n.buffer_var.type_annotation.storage_scope + alloc_stats[scope] += 1 tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert alloc_stats["global"] == 2 @@ -317,7 +317,7 @@ def test_parallel_alloc(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + body = tvm.tir.transform.StorageRewrite()(mod)["main"] assert isinstance(body.body.body, tvm.tir.Allocate) @@ -334,7 +334,7 @@ def test_parallel_alloc(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + body = tvm.tir.transform.StorageRewrite()(mod)["main"] assert isinstance(body.body.body.body.body, tvm.tir.Allocate) @@ -356,7 +356,6 @@ def get_mod(kind="serial"): mod = get_mod(kind="parallel") # parallel (i, 0, n) { - # // attr [j] storage_scope = "global" # allocate j[int32 * 1] # j[0] = 0 # while((j[0] < 10)){ @@ -366,11 +365,9 @@ def get_mod(kind="serial"): # j[0] = (j[0] + (j[0] + 1)) # } # } - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + body = tvm.tir.transform.StorageRewrite()(mod)["main"] # parallel (i, 0, n) { - # // attr [j] storage_scope = "global" # allocate j[int32 * 1] - # // attr [A] storage_scope = "global" # allocate A[float32 * n] # j[0] = 0 # while((j[0] < 10)){ @@ -379,11 +376,10 @@ def get_mod(kind="serial"): # } # } assert isinstance(body.body.body, tvm.tir.Allocate) # j - assert isinstance(body.body.body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body.body, tvm.tir.Allocate) # A mod = get_mod(kind="serial") # for (i, 0, n) { - # // attr [j] storage_scope = "global" # allocate j[int32 * 1] # j[0] = 0 # while((j[0] < 10)){ @@ -393,10 +389,8 @@ def get_mod(kind="serial"): # j[0] = (j[0] + (j[0] + 1)) # } # } - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body - # // attr [j] storage_scope = "global" + body = tvm.tir.transform.StorageRewrite()(mod)["main"] # allocate j[int32 * 1] - # // attr [A] storage_scope = "global" # allocate A[float32 * n] # for (i, 0, n) { # j[0] = 0 @@ -406,7 +400,7 @@ def get_mod(kind="serial"): # } # } assert isinstance(body.body, tvm.tir.Allocate) # j - assert isinstance(body.body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body, tvm.tir.Allocate) # A def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 030c01713927..ffdf4b5916c4 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -19,6 +19,21 @@ import tvm.testing +def run_passes(inputs, stmt): + func = tvm.te.schedule.SchedulePostProcToPrimFunc(inputs, stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + + cuda_target = tvm.target.Target("cuda") + + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) + )(mod) + + mod = tvm.tir.transform.SplitHostDevice()(mod) + return tvm.tir.transform.ThreadSync("shared")(mod) + + @tvm.testing.requires_cuda def test_thread_storage_sync(): m = te.size_var("m") @@ -38,23 +53,46 @@ def test_thread_storage_sync(): assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) - mod = tvm.IRModule.from_expr(func) - mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) + mod = run_passes([A, A2], stmt) + f = mod["test_kernel0"] + body_list = tvm.tir.stmt_list(f.body.body.body) + assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")) - cuda_target = tvm.target.Target("cuda") - mod = tvm.tir.transform.Apply( - lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) - )(mod._move()) +@tvm.testing.requires_cuda +def test_sync_else_branch(): + def ir(A, B): + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) - fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"] - mod = tvm.IRModule.from_expr(fdevice) - cuda_target = tvm.target.Target("cuda") - f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"] - body_list = tvm.tir.stmt_list(f.body.body.body.body) - assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", 1) + + local = ib.allocate(A.dtype, (8,), name="buf_local", scope="local") + shared = ib.allocate(A.dtype, (8,), name="buf_shared", scope="shared") + + with ib.for_range(0, 8) as i: + with ib.if_scope(Aptr[i] < 0): + local[i] = Aptr[i] + with ib.else_scope(): + shared[i] = Aptr[i] + + with ib.for_range(0, 8) as i: + with ib.if_scope(Aptr[i] < 0): + Bptr[i] = local[i] + with ib.else_scope(): + Bptr[i] = shared[i] + + return ib.get() + + A = tvm.tir.decl_buffer((8,), "float32") + B = tvm.tir.decl_buffer((8,), "float32") + stmt = ir(A, B) + mod = run_passes([A, B], stmt) + assert "@tir.tvm_storage_sync" in str(mod) if __name__ == "__main__": test_thread_storage_sync() + test_sync_else_branch() diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py b/tests/python/unittest/test_tir_transform_unify_thread_binding.py new file mode 100644 index 000000000000..8e0b6dc804aa --- /dev/null +++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py @@ -0,0 +1,227 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +from tvm import tir, te +from tvm.script import ty + + +def _check(original, transformed): + mod = tvm.IRModule.from_expr(original) + mod = tvm.tir.transform.UnifyThreadBinding()(mod) + mod = tvm.tir.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + + +def _check_fail(original): + mod = tvm.IRModule.from_expr(original) + with pytest.raises(ValueError): + tvm.tir.transform.UnifyThreadBinding()(mod) + + +@tvm.script.tir +def element_wise_thread_x(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + j1_0 = tir.env_thread("threadIdx.x") + j0_0 = tir.env_thread("threadIdx.x") + i = tir.env_thread("blockIdx.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + tir.launch_thread(i, 128) + with tir.launch_thread(j0_0, 4): + for j0_1 in tir.serial(0, 32): + tir.store( + B.data, + i * 128 + j0_0 * 32 + j0_1, + tir.load("float32", A.data, i * 128 + j0_0 * 32 + j0_1) * 2.0, + True, + ) + tir.launch_thread(j1_0, 4) + for j1_1 in tir.serial(0, 32): + tir.store( + C.data, + i * 128 + j1_0 * 32 + j1_1, + tir.load("float32", A.data, i * 128 + j1_0 * 32 + j1_1) + 1.0, + True, + ) + + +@tvm.script.tir +def unified_element_wise_thread_x(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + thread_x = tir.env_thread("threadIdx.x") + block_x = tir.env_thread("blockIdx.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + tir.launch_thread(block_x, 128) + with tir.launch_thread(thread_x, 4): + for j0_1 in tir.serial(0, 32): + tir.store( + B.data, + block_x * 128 + thread_x * 32 + j0_1, + tir.load("float32", A.data, block_x * 128 + thread_x * 32 + j0_1) * 2.0, + True, + ) + tir.launch_thread(thread_x, 4) + for j1_1 in tir.serial(0, 32): + tir.store( + C.data, + block_x * 128 + thread_x * 32 + j1_1, + tir.load("float32", A.data, block_x * 128 + thread_x * 32 + j1_1) + 1.0, + True, + ) + + +@tvm.script.tir +def element_wise_vthread_x(a: ty.handle, b: ty.handle) -> None: + i_0 = tir.env_thread("vthread.x") + i_1 = tir.env_thread("threadIdx.x") + j_0 = tir.env_thread("vthread.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + tir.launch_thread(i_0, 2) + tir.launch_thread(i_1, 64) + tir.launch_thread(j_0, 2) + for j_1 in tir.serial(0, 64): + tir.store( + B.data, + i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1, + tir.load("float32", A.data, i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1) * 2.0, + True, + ) + + +@tvm.script.tir +def unified_element_wise_vthread_x(a: ty.handle, b: ty.handle) -> None: + vthread_x = tir.env_thread("vthread.x") + thread_x = tir.env_thread("threadIdx.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + tir.launch_thread(vthread_x, 2) + tir.launch_thread(thread_x, 64) + tir.launch_thread(vthread_x, 2) + for j_1 in tir.serial(0, 64): + tir.store( + B.data, + vthread_x * 8256 + thread_x * 128 + j_1, + tir.load("float32", A.data, vthread_x * 8256 + thread_x * 128 + j_1) * 2.0, + True, + ) + + +@tvm.script.tir +def element_wise_two_thread_x_in_same_kernel_not_equal( + a: ty.handle, b: ty.handle, c: ty.handle +) -> None: + i = tir.env_thread("blockIdx.x") + j0 = tir.env_thread("threadIdx.x") + j1 = tir.env_thread("threadIdx.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 64]) + tir.launch_thread(i, 128) + with tir.launch_thread(j0, 128): + tir.store(B.data, i * 64 + j0, tir.load("float32", A.data, i * 128 + j0) * 2.0, True) + tir.launch_thread(j1, 64) + tir.store(C.data, i * 64 + j1, tir.load("float32", A.data, i * 128 + j1) + 1.0, True) + + +@tvm.script.tir +def element_wise_kernels_with_different_size( + a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle +) -> None: + i0 = tir.env_thread("blockIdx.x") + j0 = tir.env_thread("threadIdx.x") + i1 = tir.env_thread("blockIdx.x") + j1 = tir.env_thread("threadIdx.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [256, 256]) + D = tir.match_buffer(d, [256, 256]) + with tir.launch_thread(i0, 128): + tir.launch_thread(j0, 128) + tir.store(B.data, i0 * 128 + j0, tir.load("float32", A.data, i0 * 128 + j0) * 2.0, True) + tir.launch_thread(i1, 256) + tir.launch_thread(j1, 256) + tir.store(D.data, i1 * 256 + j1, tir.load("float32", C.data, i1 * 256 + j1) + 1.0, True) + + +@tvm.script.tir +def unified_element_wise_kernels_with_different_size( + a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle +) -> None: + block_x = tir.env_thread("blockIdx.x") + thread_x = tir.env_thread("threadIdx.x") + block_x_1 = tir.env_thread("blockIdx.x") + thread_x_1 = tir.env_thread("threadIdx.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [256, 256]) + D = tir.match_buffer(d, [256, 256]) + with tir.launch_thread(block_x, 128): + tir.launch_thread(thread_x, 128) + tir.store( + B.data, + block_x * 128 + thread_x, + tir.load("float32", A.data, block_x * 128 + thread_x) * 2.0, + True, + ) + tir.launch_thread(block_x_1, 256) + tir.launch_thread(thread_x_1, 256) + tir.store( + D.data, + block_x_1 * 256 + thread_x_1, + tir.load("float32", C.data, block_x_1 * 256 + thread_x_1) + 1.0, + True, + ) + + +def test_thread_x(): + _check(element_wise_thread_x, unified_element_wise_thread_x) + + +def test_vthread_x(): + _check(element_wise_vthread_x, unified_element_wise_vthread_x) + + +def test_two_thread_x_in_same_kernel_not_equal(): + _check_fail(element_wise_two_thread_x_in_same_kernel_not_equal) + + +def test_kernels_with_different_size(): + _check( + element_wise_kernels_with_different_size, unified_element_wise_kernels_with_different_size + ) + + +def test_lower_te(): + a = te.placeholder((32, 2, 2)) + b = te.compute((32, 2, 2), lambda i, j, k: a[i, j, k] * 2.0) + s = te.create_schedule(b.op) + s[b].bind(b.op.axis[1], te.thread_axis("threadIdx.x")) + s[b].bind(b.op.axis[2], te.thread_axis("threadIdx.x")) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [a, b]) + mod = tvm.tir.transform.UnifyThreadBinding()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # UnifyThreadBinding should do nothing on TE + + +if __name__ == "__main__": + test_thread_x() + test_vthread_x() + test_two_thread_x_in_same_kernel_not_equal() + test_kernels_with_different_size() + test_lower_te() diff --git a/tests/python/unittest/test_tvm_testing_features.py b/tests/python/unittest/test_tvm_testing_features.py index 07b8c652bf1f..4c9c5d91901a 100644 --- a/tests/python/unittest/test_tvm_testing_features.py +++ b/tests/python/unittest/test_tvm_testing_features.py @@ -148,6 +148,22 @@ def test_cached_count(self): assert self.cached_calls == len(self.param1_vals) +class TestCachedFixtureIsCopy: + param = tvm.testing.parameter(1, 2, 3, 4) + + @tvm.testing.fixture(cache_return_value=True) + def cached_mutable_fixture(self): + return {"val": 0} + + def test_modifies_fixture(self, param, cached_mutable_fixture): + assert cached_mutable_fixture["val"] == 0 + + # The tests should receive a copy of the fixture value. If + # the test receives the original and not a copy, then this + # will cause the next parametrization to fail. + cached_mutable_fixture["val"] = param + + class TestBrokenFixture: # Tests that use a fixture that throws an exception fail, and are # marked as setup failures. The tests themselves are never run. @@ -180,5 +196,79 @@ def test_num_uses_cached(self): assert self.num_uses_broken_cached_fixture == 0 +class TestAutomaticMarks: + @staticmethod + def check_marks(request, target): + parameter = tvm.testing.plugin._pytest_target_params([target])[0] + required_marks = [decorator.mark for decorator in parameter.marks] + applied_marks = list(request.node.iter_markers()) + + for required_mark in required_marks: + assert required_mark in applied_marks + + def test_automatic_fixture(self, request, target): + self.check_marks(request, target) + + @tvm.testing.parametrize_targets + def test_bare_parametrize(self, request, target): + self.check_marks(request, target) + + @tvm.testing.parametrize_targets("llvm", "cuda", "vulkan") + def test_explicit_parametrize(self, request, target): + self.check_marks(request, target) + + @pytest.mark.parametrize("target", ["llvm", "cuda", "vulkan"]) + def test_pytest_mark(self, request, target): + self.check_marks(request, target) + + @pytest.mark.parametrize("target,other_param", [("llvm", 0), ("cuda", 1), ("vulkan", 2)]) + def test_pytest_mark_covariant(self, request, target, other_param): + self.check_marks(request, target) + + +@pytest.mark.skipif( + bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))), + reason="Cannot test cache behavior while caching is disabled", +) +class TestCacheableTypes: + class EmptyClass: + pass + + @tvm.testing.fixture(cache_return_value=True) + def uncacheable_fixture(self): + return self.EmptyClass() + + def test_uses_uncacheable(self, request): + # Normally the num_tests_use_this_fixture would be set before + # anything runs. For this test case only, because we are + # delaying the use of the fixture, we need to manually + # increment it. + self.uncacheable_fixture.num_tests_use_this_fixture[0] += 1 + with pytest.raises(TypeError): + request.getfixturevalue("uncacheable_fixture") + + class ImplementsReduce: + def __reduce__(self): + return super().__reduce__() + + @tvm.testing.fixture(cache_return_value=True) + def fixture_with_reduce(self): + return self.ImplementsReduce() + + def test_uses_reduce(self, fixture_with_reduce): + pass + + class ImplementsDeepcopy: + def __deepcopy__(self, memo): + return type(self)() + + @tvm.testing.fixture(cache_return_value=True) + def fixture_with_deepcopy(self): + return self.ImplementsDeepcopy() + + def test_uses_deepcopy(self, fixture_with_deepcopy): + pass + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index a4d2dec0cce9..4798e9e09865 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -177,19 +177,6 @@ def test_complete_part_region(): _check_elementwise(func_with_part_access_region) -def test_complete_opaque_block_error(): - def render(e): - pass - - override_renderer(render) - - try: - from_source(func_with_opaque_block) - except tvm.error.DiagnosticError: - return - assert False - - @tvm.script.tir def func_with_bufferslice_indices(data: ty.handle, index: ty.handle) -> None: data_buf = tir.match_buffer(data, (16, 16), "float32") @@ -255,10 +242,46 @@ def test_complete_buffer_indices(): tvm.ir.assert_structural_equal(new_func, expected_recursive_bufferslice_indices) +@tvm.script.tir +def match_buffer_func(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + with tir.block([]): + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + A1[()] = 1.0 + + +@tvm.script.tir +def expected_match_buffer_func(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + for i in range(0, 16): + with tir.block([]): + tir.reads([]) + tir.writes(A[i, 0:16]) + A0 = tir.match_buffer(A[i, 0:16], (16)) + with tir.block([]): + tir.reads([]) + tir.writes(A0[0:16]) + for j in range(0, 16): + with tir.block([]) as []: + tir.reads([]) + tir.writes(A0[j]) + A1 = tir.match_buffer(A0[j], ()) + A1[()] = 1.0 + + +def test_complete_match_buffer(): + tvm.ir.assert_structural_equal(match_buffer_func, expected_match_buffer_func) + + if __name__ == "__main__": test_complete_matmul() test_complete_matmul_original() test_complete_with_root() - test_complete_opaque_block_error() test_complete_part_region() test_complete_buffer_indices() + test_complete_match_buffer() diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index a72b13e38829..7aeceeccfa89 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -202,7 +202,7 @@ def test_inconsistent_grid(): def invalid_match_buffer_region() -> None: with tir.block([16, 16]) as [vi, vj]: - A = tir.match_buffer_region(vi) # error + A = tir.match_buffer(vi) # error tir.evaluate(1.0) @@ -363,6 +363,23 @@ def test_tvm_exception_catch(): check_error(intrin_except_assign, 3) +def buffer_shape_mismatch(a: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 2): + with tir.block([]): + tir.reads([]) + tir.writes([A[i, j * 4 : j * 4 + 4]]) + sub_A = tir.match_buffer( + A[i, j * 4 : j * 4 + 4], (5) + ) # error: shape mismatched between 4 and 5 + for jj in range(0, 4): + sub_A[i, j * 4 + jj] = 1 + + +def test_match_buffer_shape_mismatch(): + check_error(buffer_shape_mismatch, 7) + + def check_error(module, rel_lineno): # Override the default renderer to accumulate errors _, start_line = inspect.getsourcelines(module) @@ -414,3 +431,4 @@ def render(e): test_error_index_with_stop_slice() test_mismatch_args() test_tvm_exception_catch() + test_match_buffer_shape_mismatch() diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 164949552859..0566ff5044d9 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -277,8 +277,8 @@ def mmult( } ) # var definition - C_global = tir.var("handle") - packedB = tir.var("handle") + C_global = tir.buffer_var("float32", "global") + packedB = tir.buffer_var("float32", "global") # body assert num_args == 3, "mmult: num_args should be 3" arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle") @@ -2820,6 +2820,43 @@ def test_for_thread_binding(): assert rt_func.body.body.thread_binding.thread_tag == "threadIdx.y" +@tvm.script.tir +def match_buffer_region(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16, 16), "float32") + B = tir.match_buffer(b, (1), "float32") + + with tir.block([16, 4]) as [vi, vj]: + C = tir.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) + with tir.block([4]) as [vii]: + D = tir.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) + for i, j in tir.grid(4, 4): + B[0] += D[i, 0, j] + + +def test_match_buffer_region(): + func = match_buffer_region + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + assert isinstance(rt_func.body, tir.stmt.BlockRealize) + root = rt_func.body.block + + assert isinstance(root.body, tir.stmt.For) + assert isinstance(root.body.body, tir.stmt.For) + assert isinstance(root.body.body.body, tir.stmt.BlockRealize) + outer_block = root.body.body.body.block + assert len(outer_block.match_buffers) == 1 + buffer_C = outer_block.match_buffers[0].buffer + tvm.ir.assert_structural_equal(buffer_C.shape, [16, 1, 4]) + + assert isinstance(outer_block.body, tir.stmt.For) + assert isinstance(outer_block.body.body, tir.stmt.BlockRealize) + inner_block = outer_block.body.body.block + assert len(inner_block.match_buffers) == 1 + buffer_D = inner_block.match_buffers[0].buffer + tvm.ir.assert_structural_equal(buffer_D.shape, [4, 1, 4]) + + @tvm.script.tir def block_elements(a: ty.handle, b: ty.handle) -> None: A = tir.match_buffer(a, (16, 16), "float32") @@ -2832,10 +2869,10 @@ def block_elements(a: ty.handle, b: ty.handle) -> None: tir.writes(B[0, 0]) tir.block_attr({"attr_key": "attr_value"}) C = tir.alloc_buffer((4, 4), dtype="float32") - D = tir.match_buffer_region(A[0:4, 0]) + D = tir.match_buffer(A[0:4, 0], (4, 1)) with tir.init(): B[0, 0] = tir.float32(0) - B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2, 0] + B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2] def test_block_elements(): @@ -2946,6 +2983,34 @@ def test_minmax(): tvm.ir.assert_structural_equal(func, rt_func) +@tvm.script.tir +def abs(a: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + + with tir.block([128, 128], "A") as [vi, vj]: + A[vi, vj] = tir.abs(A[vi, vj]) + + +def test_abs(): + func = abs + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + +@tvm.script.tir +def constant_folding(a: ty.handle) -> None: + A = tir.match_buffer(a, (), "float32") + A[()] = tir.min(2.2, 5.2) + A[()] = tir.max(tir.float32(2.2), tir.float32(tir.float32(5.2))) + A[()] = tir.min(2.2, 5.0) + + +def test_script_printer(): + func = constant_folding + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + if __name__ == "__main__": test_opt_gemm_normalize() test_opt_gemm_mod_host() @@ -2960,5 +3025,8 @@ def test_minmax(): test_element_wise() test_predicate() test_for_thread_binding() + test_match_buffer_region() test_block_elements() test_opaque_block() + test_abs() + test_script_printer() diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh index e534e8bc114b..753d17d8afe5 100755 --- a/tests/scripts/task_ci_setup.sh +++ b/tests/scripts/task_ci_setup.sh @@ -36,3 +36,6 @@ python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.3.0 # Jenkinsfile. We expect config.cmake to be present from pack_lib(). # TODO(areusch): Make pack_lib() pack all the data dependencies of TVM. (cd build && cmake .. && make standalone_crt) + +# Ensure no stale pytest-results remain from a previous test run. +(cd build && rm -rf pytest-results) diff --git a/tests/scripts/task_config_build_arm.sh b/tests/scripts/task_config_build_arm.sh index cae28467830f..47fa243e8d38 100755 --- a/tests/scripts/task_config_build_arm.sh +++ b/tests/scripts/task_config_build_arm.sh @@ -31,7 +31,7 @@ echo set\(USE_PROFILER ON\) >> config.cmake echo set\(USE_LLVM llvm-config-8\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake -echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_ARM_COMPUTE_LIB ON\) >> config.cmake echo set\(USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR "/opt/acl"\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 2af91d7c6b8e..167e5becd4a7 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -46,3 +46,4 @@ echo set\(USE_ETHOSN_HW OFF\) >> config.cmake echo set\(USE_VITIS_AI ON\) >> config.cmake echo set\(USE_VERILATOR ON\) >> config.cmake echo set\(USE_LIBBACKTRACE ON\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index 609325c9962b..5f86476c64c7 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -38,10 +38,10 @@ echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake echo set\(USE_STACKVM_RUNTIME ON\) >> config.cmake echo set\(USE_PROFILER ON\) >> config.cmake echo set\(USE_ANTLR ON\) >> config.cmake -echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_BLAS openblas\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(USE_TENSORRT_CODEGEN ON\) >> config.cmake echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu_vulkan.sh b/tests/scripts/task_config_build_gpu_vulkan.sh index 17d11397718d..a5a26a1db0fb 100755 --- a/tests/scripts/task_config_build_gpu_vulkan.sh +++ b/tests/scripts/task_config_build_gpu_vulkan.sh @@ -30,3 +30,4 @@ echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_PROFILER ON\) >> config.cmake echo set\(USE_LIBBACKTRACE OFF\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_i386.sh b/tests/scripts/task_config_build_i386.sh index 05acbb022124..298259682972 100755 --- a/tests/scripts/task_config_build_i386.sh +++ b/tests/scripts/task_config_build_i386.sh @@ -31,6 +31,8 @@ echo set\(USE_PROFILER ON\) >> config.cmake echo set\(USE_LLVM llvm-config-4.0\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake -echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake +echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VERILATOR ON\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake + diff --git a/tests/scripts/task_config_build_qemu.sh b/tests/scripts/task_config_build_qemu.sh index 086ca8034dc9..d821da0eb262 100755 --- a/tests/scripts/task_config_build_qemu.sh +++ b/tests/scripts/task_config_build_qemu.sh @@ -29,3 +29,4 @@ echo set\(USE_LLVM llvm-config-10\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_wasm.sh b/tests/scripts/task_config_build_wasm.sh index 78dc7550028b..9a1edbccc1fc 100755 --- a/tests/scripts/task_config_build_wasm.sh +++ b/tests/scripts/task_config_build_wasm.sh @@ -32,5 +32,5 @@ echo set\(USE_ANTLR ON\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake -echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_cpp_unittest.sh b/tests/scripts/task_cpp_unittest.sh index 62c563b2610d..3df7b580d79d 100755 --- a/tests/scripts/task_cpp_unittest.sh +++ b/tests/scripts/task_cpp_unittest.sh @@ -30,9 +30,7 @@ export VTA_HW_PATH=`pwd`/3rdparty/vta-hw export TVM_BIND_THREADS=0 export OMP_NUM_THREADS=1 -# Remove existing testcases -rm -f build/*_test - +# Build cpptest suite make cpptest -j2 # "make crttest" requires USE_MICRO to be enabled, which is not always the case. @@ -40,9 +38,7 @@ if grep crttest build/Makefile > /dev/null; then make crttest # NOTE: don't parallelize, due to issue with build deps. fi -for test in build/*_test; do - ./$test -done +cd build && ctest --gtest_death_test_style=threadsafe && cd .. # Test MISRA-C runtime cd apps/bundle_deploy diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh index 023ab2a31bd8..2889c3a94f11 100755 --- a/tests/scripts/task_lint.sh +++ b/tests/scripts/task_lint.sh @@ -39,11 +39,15 @@ tests/lint/cpplint.sh echo "clang-format check..." tests/lint/clang_format.sh +echo "Rust check..." +tests/lint/rust_format.sh + echo "black check..." tests/lint/python_format.sh echo "Linting the Python code..." tests/lint/pylint.sh +tests/lint/flake8.sh echo "Lintinf the JNI code..." tests/lint/jnilint.sh diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index 71c03888fddf..b05acb090c2f 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -19,3 +19,9 @@ set -o pipefail echo "Checking MyPy Type defs in the schedule package." mypy --check-untyped-defs python/tvm/tir/schedule + +echo "Checking MyPy Type defs in the analysis package." +mypy --check-untyped-defs python/tvm/tir/analysis/ + +echo "Checking MyPy Type defs in the transofrm package." +mypy --check-untyped-defs python/tvm/tir/transform/ diff --git a/tests/scripts/task_python_arm_compute_library.sh b/tests/scripts/task_python_arm_compute_library.sh index 75380624572d..7df894d93399 100755 --- a/tests/scripts/task_python_arm_compute_library.sh +++ b/tests/scripts/task_python_arm_compute_library.sh @@ -27,5 +27,4 @@ source tests/scripts/setup-pytest-env.sh find . -type f -path "*.pyc" | xargs rm -f make cython3 -echo "Temporarily suspended while we understand flakiness with #8117" -#run_pytest ctypes python-arm_compute_lib tests/python/contrib/test_arm_compute_lib +run_pytest ctypes python-arm_compute_lib tests/python/contrib/test_arm_compute_lib diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index fb388a6b7edd..a2f6d706a163 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -35,10 +35,7 @@ echo "Running relay MXNet frontend test..." run_pytest cython python-frontend-mxnet tests/python/frontend/mxnet echo "Running relay ONNX frontend test..." -# Enable tvm.testing decorators in the ONNX importer test (not enabling in the other tests because we -# they do not consistently use the decorators to indicate that tests should run on GPU) -# In the future, we should enable tvm.testing decorators for all the test files. -PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" run_pytest cython python-frontend-onnx tests/python/frontend/onnx +run_pytest cython python-frontend-onnx tests/python/frontend/onnx echo "Running relay CoreML frontend test..." run_pytest cython python-frontend-coreml tests/python/frontend/coreml @@ -54,3 +51,6 @@ run_pytest cython python-frontend-darknet tests/python/frontend/darknet echo "Running relay PyTorch frontend test..." run_pytest cython python-frontend-pytorch tests/python/frontend/pytorch + +echo "Running relay PaddlePaddle frontend test..." +run_pytest cython python-frontend-paddlepaddle tests/python/frontend/paddlepaddle diff --git a/tests/scripts/task_python_microtvm.sh b/tests/scripts/task_python_microtvm.sh index aa49d90eaa43..7b7758c3da24 100755 --- a/tests/scripts/task_python_microtvm.sh +++ b/tests/scripts/task_python_microtvm.sh @@ -23,5 +23,9 @@ set -x # NOTE(areusch): Adding to diagnose flaky timeouts source tests/scripts/setup-pytest-env.sh make cython3 -run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --microtvm-platforms=host +run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --microtvm-platforms=qemu_x86 run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --microtvm-platforms=mps2_an521 + +run_pytest ctypes python-microtvm-arduino apps/microtvm/arduino/template_project/tests +run_pytest ctypes python-microtvm-arduino-nano33ble tests/micro/arduino --test-build-only --microtvm-platforms=nano33ble +run_pytest ctypes python-microtvm-arduino-due tests/micro/arduino --test-build-only --microtvm-platforms=due diff --git a/tests/scripts/task_python_vta_tsim.sh b/tests/scripts/task_python_vta_tsim.sh index 3a6a35e5a06f..4c21f46c5f81 100755 --- a/tests/scripts/task_python_vta_tsim.sh +++ b/tests/scripts/task_python_vta_tsim.sh @@ -27,6 +27,9 @@ export VTA_HW_PATH=`pwd`/3rdparty/vta-hw export TVM_BIND_THREADS=0 export OMP_NUM_THREADS=1 +# temporary skip tsim test, enable later +exit 0 + # cleanup pycache find . -type f -path "*.pyc" | xargs rm -f diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index 9ddd1f2b5a4b..5cc1dc0503f7 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -28,19 +28,21 @@ echo "Using PYTHONPATH=$PYTHONPATH" export RUST_DIR="$TVM_HOME/rust" echo "Using RUST_DIR=$RUST_DIR" - export LLVM_CONFIG_DEFAULT=`which llvm-config-10` export LLVM_CONFIG_PATH="${LLVM_CONFIG_PATH:-$LLVM_CONFIG_DEFAULT}" echo "Using LLVM_CONFIG_PATH=$LLVM_CONFIG_PATH" +TVM_RUSTC_VERSION=`rustc --version` +echo "Using TVM_RUSTC_VERSION=$TVM_RUSTC_VERSION" + +TVM_CARGO_VERSION=`cargo --version` +echo "Using TVM_CARGO_VERSION=$TVM_CARGO_VERSION" + # to avoid CI CPU thread throttling. export TVM_BIND_THREADS=0 export OMP_NUM_THREADS=1 -cd $RUST_DIR -cargo fmt -- --check - # First we test tvm-sys the core Rust bindings. cd $RUST_DIR/tvm-sys # First we test w/o the bindings feature on. diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py index 5b0931405212..1619a55dc7e9 100644 --- a/tutorials/auto_scheduler/tune_network_arm.py +++ b/tutorials/auto_scheduler/tune_network_arm.py @@ -349,11 +349,7 @@ def tune_and_evaluate(): # Evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500) - prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, repeat=3, min_repeat_ms=500)) # We do not run the tuning in our webpage server since the server doesn't have a Raspberry Pi, diff --git a/tutorials/auto_scheduler/tune_network_cuda.py b/tutorials/auto_scheduler/tune_network_cuda.py index 7b5619c671be..08c15264e3c1 100644 --- a/tutorials/auto_scheduler/tune_network_cuda.py +++ b/tutorials/auto_scheduler/tune_network_cuda.py @@ -288,9 +288,7 @@ def run_tuning(): # Evaluate print("Evaluate inference time cost...") -ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500) -prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond -print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) +print(module.benchmark(dev, repeat=3, min_repeat_ms=500)) ################################################################# diff --git a/tutorials/auto_scheduler/tune_network_mali.py b/tutorials/auto_scheduler/tune_network_mali.py index 8275f96806b8..2d1e51520952 100644 --- a/tutorials/auto_scheduler/tune_network_mali.py +++ b/tutorials/auto_scheduler/tune_network_mali.py @@ -264,11 +264,7 @@ def tune_and_evaluate(): # Evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500) - prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, repeat=3, min_repeat_ms=500)) # We do not run the tuning in our webpage server since server doesn't have mali gpu. diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index 76068fa79605..6cb8d6f14cb9 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -322,9 +322,7 @@ def run_tuning(): # Evaluate print("Evaluate inference time cost...") -ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500) -prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond -print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) +print(module.benchmark(dev, repeat=3, min_repeat_ms=500)) ################################################################# diff --git a/tutorials/autotvm/tune_relay_arm.py b/tutorials/autotvm/tune_relay_arm.py index debf8b8ecf60..f072c5ddac93 100644 --- a/tutorials/autotvm/tune_relay_arm.py +++ b/tutorials/autotvm/tune_relay_arm.py @@ -359,12 +359,7 @@ def tune_and_evaluate(tuning_opt): # evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, number=1, repeat=10) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" - % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, number=1, repeat=10)) # We do not run the tuning in our webpage server since it takes too long. diff --git a/tutorials/autotvm/tune_relay_cuda.py b/tutorials/autotvm/tune_relay_cuda.py index 65991cc83454..b2af2e13f4fe 100644 --- a/tutorials/autotvm/tune_relay_cuda.py +++ b/tutorials/autotvm/tune_relay_cuda.py @@ -244,12 +244,7 @@ def tune_and_evaluate(tuning_opt): # evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, number=1, repeat=600) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" - % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, number=1, repeat=600)) # We do not run the tuning in our webpage server since it takes too long. diff --git a/tutorials/autotvm/tune_relay_mobile_gpu.py b/tutorials/autotvm/tune_relay_mobile_gpu.py index 790c2ff2c2b9..d3f4ec62fafc 100644 --- a/tutorials/autotvm/tune_relay_mobile_gpu.py +++ b/tutorials/autotvm/tune_relay_mobile_gpu.py @@ -352,12 +352,7 @@ def tune_and_evaluate(tuning_opt): # evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, number=1, repeat=30) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" - % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, number=1, repeat=30)) # We do not run the tuning in our webpage server since it takes too long. diff --git a/tutorials/autotvm/tune_relay_x86.py b/tutorials/autotvm/tune_relay_x86.py index dd5d4057c211..771220bb3314 100644 --- a/tutorials/autotvm/tune_relay_x86.py +++ b/tutorials/autotvm/tune_relay_x86.py @@ -194,6 +194,18 @@ def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True): # Finally, we launch tuning jobs and evaluate the end-to-end performance. +def evaluate_performance(lib, data_shape): + # upload parameters to device + dev = tvm.cpu() + data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype)) + module = runtime.GraphModule(lib["default"](dev)) + module.set_input(input_name, data_tvm) + + # evaluate + print("Evaluate inference time cost...") + print(module.benchmark(dev, number=100, repeat=3)) + + def tune_and_evaluate(tuning_opt): # extract workloads from relay program print("Extract tasks...") @@ -206,26 +218,28 @@ def tune_and_evaluate(tuning_opt): tune_kernels(tasks, **tuning_opt) tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file) + # compile kernels in default mode + print("Evaluation of the network compiled in 'default' mode without auto tune:") + with tvm.transform.PassContext(opt_level=3): + print("Compile...") + lib = relay.build(mod, target=target, params=params) + evaluate_performance(lib, data_shape) + + # compile kernels in kernel tuned only mode + print("\nEvaluation of the network been tuned on kernel level:") + with autotvm.apply_history_best(log_file): + print("Compile...") + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) + evaluate_performance(lib, data_shape) + # compile kernels with graph-level best records + print("\nEvaluation of the network been tuned on graph level:") with autotvm.apply_graph_best(graph_opt_sch_file): print("Compile...") with tvm.transform.PassContext(opt_level=3): lib = relay.build_module.build(mod, target=target, params=params) - - # upload parameters to device - dev = tvm.cpu() - data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype)) - module = runtime.GraphModule(lib["default"](dev)) - module.set_input(input_name, data_tvm) - - # evaluate - print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, number=100, repeat=3) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" - % (np.mean(prof_res), np.std(prof_res)) - ) + evaluate_performance(lib, data_shape) # We do not run the tuning in our webpage server since it takes too long. @@ -256,6 +270,29 @@ def tune_and_evaluate(tuning_opt): # [Task 10/12] Current/Best: 182.33/1734.45 GFLOPS | Progress: (360/360) | 1755.06 s Done. # [Task 11/12] Current/Best: 372.18/1745.15 GFLOPS | Progress: (360/360) | 1684.50 s Done. # [Task 12/12] Current/Best: 215.34/2271.11 GFLOPS | Progress: (400/400) | 2128.74 s Done. +# INFO Start to benchmark layout transformation... +# INFO Benchmarking layout transformation successful. +# INFO Start to run dynamic programming algorithm... +# INFO Start forward pass... +# INFO Finished forward pass. +# INFO Start backward pass... +# INFO Finished backward pass... +# INFO Finished DPExecutor run. +# INFO Writing optimal schedules to resnet-18_graph_opt.log successfully. +# +# Evaluation of the network compiled in 'default' mode without auto tune: +# Compile... +# Evaluate inference time cost... +# Mean inference time (std dev): 4.5 ms (0.03 ms) +# +# Evaluation of the network been tuned on kernel level: +# Compile... +# Evaluate inference time cost... +# Mean inference time (std dev): 3.2 ms (0.03 ms) +# +# Evaluation of the network been tuned on graph level: # Compile... +# Config for target=llvm -keys=cpu -link-params=0, workload=('dense_nopack.x86', ('TENSOR', (1, 512), 'float32'), ('TENSOR', (1000, 512), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression. +# Config for target=llvm -keys=cpu -link-params=0, workload=('dense_pack.x86', ('TENSOR', (1, 512), 'float32'), ('TENSOR', (1000, 512), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression. # Evaluate inference time cost... # Mean inference time (std dev): 3.16 ms (0.03 ms) diff --git a/tutorials/dev/bring_your_own_datatypes.py b/tutorials/dev/bring_your_own_datatypes.py index 06d96e14d28c..1cf556ddd056 100644 --- a/tutorials/dev/bring_your_own_datatypes.py +++ b/tutorials/dev/bring_your_own_datatypes.py @@ -82,9 +82,7 @@ ###################################################################### # Finally, we're ready to run the program: -ex = relay.create_executor(mod=module) - -z_output = ex.evaluate()(x_input, y_input) +z_output = relay.create_executor(mod=module).evaluate()(x_input, y_input) print("z: {}".format(z_output)) ###################################################################### @@ -135,8 +133,7 @@ # Now that we can express our program without errors, let's try running it! try: with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - ex = relay.create_executor("graph", mod=module) - z_output_myfloat = ex.evaluate()(x_input, y_input) + z_output_myfloat = relay.create_executor("graph", mod=module).evaluate()(x_input, y_input) print("z: {}".format(y_myfloat)) except tvm.TVMError as e: # Print last line of error @@ -181,8 +178,7 @@ # We can now re-try running the program: try: with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - ex = relay.create_executor("graph", mod=module) - z_output_myfloat = ex.evaluate()(x_input, y_input) + z_output_myfloat = relay.create_executor("graph", mod=module).evaluate()(x_input, y_input) print("z: {}".format(z_output_myfloat)) except tvm.TVMError as e: # Print last line of error @@ -211,8 +207,7 @@ # Now, we can run our program without errors. with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - compiled = ex.evaluate(program) - z_output_myfloat = compiled(x_input, y_input) + z_output_myfloat = relay.create_executor(mod=module).evaluate()(x_input, y_input) print("z: {}".format(z_output_myfloat)) print("x:\t\t{}".format(x_input)) @@ -262,9 +257,9 @@ def get_cat_image(): ###################################################################### # It's easy to execute MobileNet with native TVM: -ex = tvm.relay.create_executor("graph", mod=module) +ex = tvm.relay.create_executor("graph", mod=module, params=params) input = get_cat_image() -result = ex.evaluate()(input, **params).numpy() +result = ex.evaluate()(input).numpy() # print first 10 elements print(result.flatten()[:10]) @@ -311,7 +306,9 @@ def convert_ndarray(dst_dtype, array): try: # Vectorization is not implemented with custom datatypes. with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - result_myfloat = ex.evaluate(expr)(input, **params) + result_myfloat = tvm.relay.create_executor("graph", mod=module).evaluate(expr)( + input, **params + ) except tvm.TVMError as e: print(str(e).split("\n")[-1]) @@ -401,7 +398,7 @@ def convert_ndarray(dst_dtype, array): # Vectorization is not implemented with custom datatypes. with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - result_myfloat = ex.evaluate(expr)(input, **params) + result_myfloat = relay.create_executor(mod=module).evaluate(expr)(input, **params) result_myfloat = convert_ndarray(src_dtype, result_myfloat).numpy() # print first 10 elements print(result_myfloat.flatten()[:10]) diff --git a/tutorials/frontend/deploy_model_on_android.py b/tutorials/frontend/deploy_model_on_android.py index f435befb8250..c7b610d5d503 100644 --- a/tutorials/frontend/deploy_model_on_android.py +++ b/tutorials/frontend/deploy_model_on_android.py @@ -332,9 +332,7 @@ def transform_image(image): print("TVM prediction top-1: {}".format(synset[top1])) print("Evaluate inference time cost...") -ftimer = module.module.time_evaluator("run", dev, number=1, repeat=10) -prof_res = np.array(ftimer().results) * 1000 # convert to millisecond -print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) +print(module.benchmark(dev, number=1, repeat=10)) ###################################################################### # Sample Output diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized.py index a59655222278..11a9e3e3eee8 100644 --- a/tutorials/frontend/deploy_prequantized.py +++ b/tutorials/frontend/deploy_prequantized.py @@ -199,9 +199,7 @@ def quantize_model(model, inp): # Here we give an example of how to measure performance of TVM compiled models. n_repeat = 100 # should be bigger to make the measurement more accurate dev = tvm.cpu(0) -ftimer = rt_mod.module.time_evaluator("run", dev, number=1, repeat=n_repeat) -prof_res = np.array(ftimer().results) * 1e3 -print("Elapsed average ms:", np.mean(prof_res)) +print(rt_mod.benchmark(dev, number=1, repeat=n_repeat)) ###################################################################### # .. note:: diff --git a/tutorials/frontend/deploy_prequantized_tflite.py b/tutorials/frontend/deploy_prequantized_tflite.py index e3934e9b250f..7bbb06bdf801 100644 --- a/tutorials/frontend/deploy_prequantized_tflite.py +++ b/tutorials/frontend/deploy_prequantized_tflite.py @@ -232,9 +232,7 @@ def run_tvm(lib): # Here we give an example of how to measure performance of TVM compiled models. n_repeat = 100 # should be bigger to make the measurement more accurate dev = tvm.cpu(0) -ftimer = rt_mod.module.time_evaluator("run", dev, number=1, repeat=n_repeat) -prof_res = np.array(ftimer().results) * 1e3 -print("Elapsed average ms:", np.mean(prof_res)) +print(rt_mod.benchmark(dev, number=1, repeat=n_repeat)) ###################################################################### # .. note:: diff --git a/tutorials/frontend/deploy_quantized.py b/tutorials/frontend/deploy_quantized.py index b2210b8ab69b..2d9275796eb5 100644 --- a/tutorials/frontend/deploy_quantized.py +++ b/tutorials/frontend/deploy_quantized.py @@ -146,11 +146,11 @@ def quantize(mod, params, data_aware): # ------------- # We create a Relay VM to build and execute the model. def run_inference(mod): - executor = relay.create_executor("vm", mod, dev, target) + model = relay.create_executor("vm", mod, dev, target).evaluate() val_data, batch_fn = get_val_data() for i, batch in enumerate(val_data): data, label = batch_fn(batch) - prediction = executor.evaluate()(data) + prediction = model(data) if i > 10: # only run inference on a few samples in this tutorial break diff --git a/tutorials/frontend/deploy_sparse.py b/tutorials/frontend/deploy_sparse.py index d3375c40fe72..768a697f45cf 100644 --- a/tutorials/frontend/deploy_sparse.py +++ b/tutorials/frontend/deploy_sparse.py @@ -90,6 +90,20 @@ import scipy.sparse as sp +# Ask tensorflow to limit its GPU memory to what's actually needed +# instead of gobbling everything that's available. +# https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth +# This way this tutorial is a little more friendly to sphinx-gallery. +gpus = tf.config.list_physical_devices("GPU") +if gpus: + try: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + print("tensorflow will use experimental.set_memory_growth(True)") + except RuntimeError as e: + print("experimental.set_memory_growth option is not available: {}".format(e)) + + ############################################################################### # Configure Settings # ------------------ @@ -219,12 +233,7 @@ def run_relay_graph(mod, params, shape_dict, target, dev): m.run() tvm_output = m.get_output(0) - ftimer = m.module.time_evaluator("run", dev, repeat=5, number=5) - prof_res = np.array(ftimer().results) * 1000 - print( - "%-20s %-19s (%s)" - % ("Runtime:", "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)) - ) + print(m.benchmark(dev, repeat=5, number=5)) return tvm_output diff --git a/tutorials/frontend/from_keras.py b/tutorials/frontend/from_keras.py index 1c48aff799d4..182e769b35b1 100644 --- a/tutorials/frontend/from_keras.py +++ b/tutorials/frontend/from_keras.py @@ -97,14 +97,20 @@ # compile the model target = "cuda" dev = tvm.cuda(0) -with tvm.transform.PassContext(opt_level=3): - executor = relay.build_module.create_executor("graph", mod, dev, target) + +# TODO(mbs): opt_level=3 causes nn.contrib_conv2d_winograd_weight_transform +# to end up in the module which fails memory validation on cuda most likely +# due to a latent bug. Note that the pass context only has an effect within +# evaluate() and is not captured by create_executor(). +with tvm.transform.PassContext(opt_level=0): + model = relay.build_module.create_executor("graph", mod, dev, target, params).evaluate() + ###################################################################### # Execute on TVM # --------------- dtype = "float32" -tvm_out = executor.evaluate()(tvm.nd.array(data.astype(dtype)), **params) +tvm_out = model(tvm.nd.array(data.astype(dtype))) top1_tvm = np.argmax(tvm_out.numpy()[0]) ##################################################################### diff --git a/tutorials/frontend/from_mxnet.py b/tutorials/frontend/from_mxnet.py index 0ce610a2cdd6..027e9e6eb757 100644 --- a/tutorials/frontend/from_mxnet.py +++ b/tutorials/frontend/from_mxnet.py @@ -32,7 +32,7 @@ pip install mxnet --user -or please refer to offical installation guide. +or please refer to official installation guide. https://mxnet.apache.org/versions/master/install/index.html """ # some standard imports diff --git a/tutorials/frontend/from_onnx.py b/tutorials/frontend/from_onnx.py index 4eba297935f0..fd51d7a76992 100644 --- a/tutorials/frontend/from_onnx.py +++ b/tutorials/frontend/from_onnx.py @@ -29,7 +29,7 @@ pip install onnx --user -or please refer to offical site. +or please refer to official site. https://github.com/onnx/onnx """ import onnx @@ -92,13 +92,15 @@ mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) with tvm.transform.PassContext(opt_level=1): - intrp = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target) + executor = relay.build_module.create_executor( + "graph", mod, tvm.cpu(0), target, params + ).evaluate() ###################################################################### # Execute on TVM # --------------------------------------------- dtype = "float32" -tvm_output = intrp.evaluate()(tvm.nd.array(x.astype(dtype)), **params).numpy() +tvm_output = executor(tvm.nd.array(x.astype(dtype))).numpy() ###################################################################### # Display results @@ -122,7 +124,7 @@ # Notes # --------------------------------------------- # By default, ONNX defines models in terms of dynamic shapes. The ONNX importer -# retains that dynamism upon import, and the compiler attemps to convert the model +# retains that dynamism upon import, and the compiler attempts to convert the model # into a static shapes at compile time. If this fails, there may still be dynamic # operations in the model. Not all TVM kernels currently support dynamic shapes, # please file an issue on discuss.tvm.apache.org if you hit an error with dynamic kernels. diff --git a/tutorials/frontend/from_tensorflow.py b/tutorials/frontend/from_tensorflow.py index fc87c07fb569..4563e245c0cf 100644 --- a/tutorials/frontend/from_tensorflow.py +++ b/tutorials/frontend/from_tensorflow.py @@ -36,6 +36,21 @@ # Tensorflow imports import tensorflow as tf + +# Ask tensorflow to limit its GPU memory to what's actually needed +# instead of gobbling everything that's available. +# https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth +# This way this tutorial is a little more friendly to sphinx-gallery. +gpus = tf.config.list_physical_devices("GPU") +if gpus: + try: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + print("tensorflow will use experimental.set_memory_growth(True)") + except RuntimeError as e: + print("experimental.set_memory_growth option is not available: {}".format(e)) + + try: tf_compat_v1 = tf.compat.v1 except ImportError: diff --git a/tutorials/micro/micro_reference_vm.py b/tutorials/micro/micro_reference_vm.py index 93395a44c8ae..bb262893eb6b 100644 --- a/tutorials/micro/micro_reference_vm.py +++ b/tutorials/micro/micro_reference_vm.py @@ -151,7 +151,6 @@ .. code-block:: bash $ cd /Users/yourusername/path/to/tvm - $ sudo ./docker/install/ubuntu_install_qemu.sh $ cd apps/microtvm/reference-vm/zephyr/ $ poetry run pytest ../../../../tests/micro/qemu/test_zephyr.py --microtvm-platforms=host diff --git a/tutorials/micro/micro_tflite.py b/tutorials/micro/micro_tflite.py index 5e517bf062ef..5a39be08e108 100644 --- a/tutorials/micro/micro_tflite.py +++ b/tutorials/micro/micro_tflite.py @@ -208,52 +208,92 @@ with tvm.transform.PassContext( opt_level=3, config={"tir.disable_vectorize": True}, disabled_pass=["FuseOps", "AlterOpLayout"] ): - graph, c_mod, c_params = relay.build(mod, target=TARGET, params=params) + module = relay.build(mod, target=TARGET, params=params) -# Compiling for a host simulated device -# ------------------------------------- +# Inspecting the compilation output +# --------------------------------- # -# First, compile a static microTVM runtime for the targeted device. In this case, the host simulated -# device is used. -compiler = tvm.micro.DefaultCompiler(target=TARGET) -opts = tvm.micro.default_options( - os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host") +# The compilation process has produced some C code implementing the operators in this graph. We +# can inspect it by printing the CSourceModule contents (for the purposes of this tutorial, let's +# just print the first 10 lines): + +c_source_module = module.get_lib().imported_modules[0] +assert c_source_module.type_key == "c", "tutorial is broken" + +c_source_code = c_source_module.get_source() +first_few_lines = c_source_code.split("\n")[:10] +assert any( + l.startswith("TVM_DLL int32_t tvmgen_default_") for l in first_few_lines +), f"tutorial is broken: {first_few_lines!r}" +print("\n".join(first_few_lines)) + + +# Compiling the generated code +# ---------------------------- +# +# Now we need to incorporate the generated C code into a project that allows us to run inference on the +# device. The simplest way to do this is to integrate it yourself, using microTVM's standard output format +# (:doc:`Model Library Format` `). This is a tarball with a standard layout: + +# Get a temporary path where we can store the tarball (since this is running as a tutorial). +import tempfile + +fd, model_library_format_tar_path = tempfile.mkstemp() +os.close(fd) +os.unlink(model_library_format_tar_path) +tvm.micro.export_model_library_format(module, model_library_format_tar_path) + +import tarfile + +with tarfile.open(model_library_format_tar_path, "r:*") as tar_f: + print("\n".join(f" - {m.name}" for m in tar_f.getmembers())) + +# Cleanup for tutorial: +os.unlink(model_library_format_tar_path) + + +# TVM also provides a standard way for embedded platforms to automatically generate a standalone +# project, compile and flash it to a target, and communicate with it using the standard TVM RPC +# protocol. The Model Library Format serves as the model input to this process. When embedded +# platforms provide such an integration, they can be used directly by TVM for both host-driven +# inference and autotuning . This integration is provided by the +# `microTVM Project API` _, +# +# Embedded platforms need to provide a Template Project containing a microTVM API Server (typically, +# this lives in a file ``microtvm_api_server.py`` in the root directory). Let's use the example ``host`` +# project in this tutorial, which simulates the device using a POSIX subprocess and pipes: + +import subprocess +import pathlib + +repo_root = pathlib.Path( + subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip() ) +template_project_path = repo_root / "src" / "runtime" / "crt" / "host" +project_options = {} # You can use options to provide platform-specific options through TVM. # Compiling for physical hardware (or an emulated board, like the mps_an521) # -------------------------------------------------------------------------- -# For physical hardware, comment out the previous section selecting TARGET and BOARD and use this -# compiler definition instead of the one above. -# -# import subprocess -# from tvm.micro.contrib import zephyr -# -# repo_root = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding='utf-8').strip() -# project_dir = os.path.join(repo_root, "apps", "microtvm", "zephyr", "host_driven") -# compiler = zephyr.ZephyrCompiler( -# project_dir=project_dir, -# board=BOARD, -# zephyr_toolchain_variant="zephyr", -# ) -# -# opts = tvm.micro.default_options(f"{project_dir}/crt") -# -# -# # Enable printing memory usage statistics for the runtime image generated by Zephyr -# logging.basicConfig(level="INFO") - -workspace = tvm.micro.Workspace() -micro_binary = tvm.micro.build_static_runtime( - workspace, - compiler, - c_mod, - opts, - # Use the microTVM memory manager. If, in your main.cc, you change TVMPlatformMemoryAllocate and - # TVMPlatformMemoryFree to use e.g. malloc() and free(), you can omit this extra library. - extra_libs=[tvm.micro.get_standalone_crt_lib("memory")], +# For physical hardware, you can try out the Zephyr platform by using a different template project +# and options: +# +# template_project_path = repo_root / "apps" / "microtvm" / "zephyr" / "template_project" +# project_options = {"project_type": "host_driven", zephyr_board": "nucleo_f746zg"}} + +# Create a temporary directory +import tvm.contrib.utils + +temp_dir = tvm.contrib.utils.tempdir() +generated_project_dir = temp_dir / "generated-project" +generated_project = tvm.micro.generate_project( + template_project_path, module, generated_project_dir, project_options ) +# Build and flash the project +generated_project.build() +generated_project.flash() + ###################################################################### # Next, establish a session with the simulated device and run the @@ -261,14 +301,13 @@ # microcontroller, but in this tutorial, it simply launches a subprocess # to stand in for an attached microcontroller. -flasher = compiler.flasher() -with tvm.micro.Session(binary=micro_binary, flasher=flasher) as session: +with tvm.micro.Session(transport_context_manager=generated_project.transport()) as session: graph_mod = tvm.micro.create_local_graph_executor( - graph, session.get_system_lib(), session.device + module.get_graph_json(), session.get_system_lib(), session.device ) # Set the model parameters using the lowered parameters produced by `relay.build`. - graph_mod.set_input(**c_params) + graph_mod.set_input(**module.get_params()) # The model consumes a single float32 value and returns a predicted sine value. To pass the # input value we construct a tvm.nd.array object with a single contrived number as input. For diff --git a/vta/python/vta/__init__.py b/vta/python/vta/__init__.py index 5fce76808c45..af840c9c55f3 100644 --- a/vta/python/vta/__init__.py +++ b/vta/python/vta/__init__.py @@ -21,6 +21,7 @@ configure the hardware environment and access remote device through RPC. """ import sys +import tvm._ffi.base from .autotvm import module_loader from .bitstream import get_bitstream_path, download_bitstream @@ -29,8 +30,9 @@ __version__ = "0.1.0" + # do not from tvm import topi when running vta.exec.rpc_server -# to maintain minimum dependency on the board -if sys.argv[0] not in ("-c", "-m"): +# in lib tvm runtime only mode +if not tvm._ffi.base._RUNTIME_ONLY: from . import top from .build_module import build_config, lower, build diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py index 3b62edd1a978..8ced8e5ce494 100644 --- a/vta/python/vta/build_module.py +++ b/vta/python/vta/build_module.py @@ -17,8 +17,9 @@ # pylint: disable=unused-argument, invalid-name """VTA specific buildin for runtime.""" import tvm +from tvm.ir import register_intrin_lowering from . import transform -from .environment import get_env +from .environment import get_env, Environment def EarlyRewrite(): @@ -134,3 +135,65 @@ def build(*args, **kwargs): tvm.ir.register_op_attr("tir.vta.command_handle", "TGlobalSymbol", "VTATLSCommandHandle") tvm.ir.register_op_attr("tir.vta.command_handle", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque) + +# The memory information for the compiler +@tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope) +def mem_info_inp_buffer(): + spec = get_env() + return tvm.ir.make_node( + "MemoryInfo", + unit_bits=spec.INP_ELEM_BITS, + max_simd_bits=spec.INP_ELEM_BITS, + max_num_bits=spec.INP_BUFF_SIZE * 8, + head_address=None, + ) + + +@tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope) +def mem_info_wgt_buffer(): + spec = get_env() + return tvm.ir.make_node( + "MemoryInfo", + unit_bits=spec.WGT_ELEM_BITS, + max_simd_bits=spec.WGT_ELEM_BITS, + max_num_bits=spec.WGT_BUFF_SIZE * 8, + head_address=None, + ) + + +@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope) +def mem_info_acc_buffer(): + spec = get_env() + return tvm.ir.make_node( + "MemoryInfo", + unit_bits=spec.ACC_ELEM_BITS, + max_simd_bits=spec.ACC_ELEM_BITS, + max_num_bits=spec.ACC_BUFF_SIZE * 8, + head_address=None, + ) + + +# TVM Op related registration +@register_intrin_lowering("tir.vta.coproc_sync", "default") +def coproc_sync(op): + _ = op + return tvm.tir.call_extern( + "int32", + "VTASynchronize", + get_env().dev.command_handle, + tvm.runtime.const(1 << 31, dtype="uint32"), + ) + + +@register_intrin_lowering("tir.vta.coproc_dep_push", "default") +def coproc_dep_push(op): + return tvm.tir.call_extern( + "int32", "VTADepPush", get_env().dev.command_handle, op.args[0], op.args[1] + ) + + +@register_intrin_lowering("tir.vta.coproc_dep_pop", "default") +def coproc_dep_pop(op): + return tvm.tir.call_extern( + "int32", "VTADepPop", get_env().dev.command_handle, op.args[0], op.args[1] + ) diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index 9181a44fa523..087c7e852cf6 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -23,7 +23,6 @@ import copy import tvm from tvm import te -from tvm.ir import register_intrin_lowering from . import intrin @@ -255,69 +254,6 @@ def get_env(): return Environment.current -# The memory information for the compiler -@tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope) -def mem_info_inp_buffer(): - spec = get_env() - return tvm.ir.make_node( - "MemoryInfo", - unit_bits=spec.INP_ELEM_BITS, - max_simd_bits=spec.INP_ELEM_BITS, - max_num_bits=spec.INP_BUFF_SIZE * 8, - head_address=None, - ) - - -@tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope) -def mem_info_wgt_buffer(): - spec = get_env() - return tvm.ir.make_node( - "MemoryInfo", - unit_bits=spec.WGT_ELEM_BITS, - max_simd_bits=spec.WGT_ELEM_BITS, - max_num_bits=spec.WGT_BUFF_SIZE * 8, - head_address=None, - ) - - -@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope) -def mem_info_acc_buffer(): - spec = get_env() - return tvm.ir.make_node( - "MemoryInfo", - unit_bits=spec.ACC_ELEM_BITS, - max_simd_bits=spec.ACC_ELEM_BITS, - max_num_bits=spec.ACC_BUFF_SIZE * 8, - head_address=None, - ) - - -# TVM Op related registration -@register_intrin_lowering("tir.vta.coproc_sync", "default") -def coproc_sync(op): - _ = op - return tvm.tir.call_extern( - "int32", - "VTASynchronize", - get_env().dev.command_handle, - tvm.runtime.const(1 << 31, dtype="uint32"), - ) - - -@register_intrin_lowering("tir.vta.coproc_dep_push", "default") -def coproc_dep_push(op): - return tvm.tir.call_extern( - "int32", "VTADepPush", get_env().dev.command_handle, op.args[0], op.args[1] - ) - - -@register_intrin_lowering("tir.vta.coproc_dep_pop", "default") -def coproc_dep_pop(op): - return tvm.tir.call_extern( - "int32", "VTADepPop", get_env().dev.command_handle, op.args[0], op.args[1] - ) - - def _init_env(): """Initialize the default global env""" config_path = os.path.join(get_vta_hw_path(), "config/vta_config.json") diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py index b7a9c79392d2..dcf564dd0314 100644 --- a/vta/python/vta/exec/rpc_server.py +++ b/vta/python/vta/exec/rpc_server.py @@ -34,7 +34,6 @@ from ..libinfo import find_libvta -@tvm.register_func("tvm.rpc.server.start", override=True) def server_start(): """VTA RPC server extension.""" # pylint: disable=unused-variable @@ -148,8 +147,21 @@ def main(): else: tracker_addr = None + # register the initialization callback + def server_init_callback(): + # pylint: disable=redefined-outer-name, reimported, import-outside-toplevel, import-self + import tvm + import vta.exec.rpc_server + + tvm.register_func("tvm.rpc.server.start", vta.exec.rpc_server.server_start, override=True) + server = rpc.Server( - args.host, args.port, args.port_end, key=args.key, tracker_addr=tracker_addr + args.host, + args.port, + args.port_end, + key=args.key, + tracker_addr=tracker_addr, + server_init_callback=server_init_callback, ) server.proc.join() diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index a982b88b75e8..f15e4922b4a8 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -56,13 +56,24 @@ def _pack_batch_channel(data, dshape, bfactor, cfactor): return data -def _unpack_batch_channel(data, old_shape): +def _unpack_batch_channel(data, old_shape, unpack_transpose=False): """Unpack the data channel dimension.""" - data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3)) + if unpack_transpose: + data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3)) data = op.reshape(data, newshape=old_shape) return data +def _channel_const_match(channel_length, cfactor_out): + """Round the chanel const variant if the value not divisible by cfactor_out""" + diff = int(channel_length) % cfactor_out + if diff != 0: + diff = cfactor_out - diff + channel_length = channel_length + diff + + return diff, channel_length + + def _const_shape_match(data, dshape, cfactor_out): """Pad the constant if the shape[0] not divisible by cfactor_out.""" assert len(dshape) == 3 @@ -299,6 +310,7 @@ def __init__(self, bfactor, cfactor, weight_bits): self.upsampling = op.op.get("nn.upsampling") self.reshape = op.op.get("reshape") self.number_of_conv2d = 0 + self.unpack_transpose = True super().__init__() def visit_call(self, call): @@ -319,7 +331,7 @@ def visit_call(self, call): self.start_pack = False data = args[0] data_shape = _get_tensor_shape(call.args[0]) - return _unpack_batch_channel(data, data_shape) + return _unpack_batch_channel(data, data_shape, self.unpack_transpose) if self.start_pack: # Operator cases if call.op == self.conv2d and odtype == "int32": @@ -429,12 +441,12 @@ def visit_call(self, call): if len(pad_width) == 6: pass elif len(pad_width) == 4: - (data,) = args + (data, pad_value) = args new_pad_width = [] new_pad_width.extend(pad_width) for _ in range(2): new_pad_width.append([0, 0]) - return op.nn.pad(data, pad_value=call.attrs.pad_value, pad_width=new_pad_width) + return op.nn.pad(data, pad_value=pad_value, pad_width=new_pad_width) elif call.op == self.upsampling: (data,) = args scale_h = call.attrs.scale_h @@ -445,8 +457,17 @@ def visit_call(self, call): return op.nn.upsampling(data, scale_h, scale_w, data_layout, method, align_corners) elif call.op == self.reshape and len(input_types[0].shape) == 4: (data,) = args + self.unpack_transpose = False data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3)) - return op.reshape(data, [int(x) for x in input_types[0].shape]) + new_shape = [int(x) for x in input_types[0].shape] + # Check if the reshape match with such shape after pad + pad, new_shape[1] = _channel_const_match(new_shape[1], self.cfactor) + data = op.reshape(data, new_shape) + # remove pad data + if pad != 0: + new_pad_width = [[0, 0], [0, -pad], [0, 0], [0, 0]] + data = op.nn.pad(data, pad_width=new_pad_width) + return data return relay.Call(self.visit(call.op), args, call.attrs) diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 7c7d02b40fbb..383841f19e34 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -495,21 +495,21 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): # FIXME: pad_value is ignored... env = get_env() _ = pad_value - if dst.scope == "global": + if dst.scope() == "global": # Store if pad_before or pad_after: raise RuntimeError("Do not support copy into DRAM with pad") - if src.scope == env.acc_scope: + if src.scope() == env.acc_scope: elem_width = env.OUT_WIDTH elem_bytes = env.OUT_ELEM_BYTES mem_type = env.dev.MEM_ID_OUT data_type = "int%d" % env.OUT_WIDTH task_qid = env.dev.QID_STORE_OUT else: - raise RuntimeError("Do not support copy %s->dram" % (src.scope)) + raise RuntimeError("Do not support copy %s->dram" % (src.scope())) _check_compact(src) x_size, y_size, x_stride, offset = _get_2d_pattern( - dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True + dst, elem_width, elem_bytes, data_type, src.scope(), allow_fold=True ) irb = tvm.tir.ir_builder.create() irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid)) @@ -528,27 +528,27 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): ) ) return irb.get() - elif src.scope == "global": - if dst.scope == env.acc_scope: + elif src.scope() == "global": + if dst.scope() == env.acc_scope: elem_width = env.ACC_WIDTH elem_bytes = env.ACC_ELEM_BYTES mem_type = env.dev.MEM_ID_ACC data_type = "int%d" % env.ACC_WIDTH task_qid = env.dev.QID_LOAD_OUT - elif dst.scope == env.inp_scope: + elif dst.scope() == env.inp_scope: elem_width = env.INP_WIDTH elem_bytes = env.INP_ELEM_BYTES mem_type = env.dev.MEM_ID_INP data_type = "int%d" % env.INP_WIDTH task_qid = env.dev.QID_LOAD_INP - elif dst.scope == env.wgt_scope: + elif dst.scope() == env.wgt_scope: elem_width = env.WGT_WIDTH elem_bytes = env.WGT_ELEM_BYTES mem_type = env.dev.MEM_ID_WGT data_type = "int%d" % env.WGT_WIDTH task_qid = env.dev.QID_LOAD_WGT else: - raise RuntimeError("Do not support copy dram->%s" % (dst.scope)) + raise RuntimeError("Do not support copy dram->%s" % (dst.scope())) # collect pad statistics if pad_before: assert pad_after @@ -586,7 +586,7 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): _check_compact(dst) x_size, y_size, x_stride, offset = _get_2d_pattern( - src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold + src, elem_width, elem_bytes, data_type, dst.scope(), allow_fold=allow_fold ) if data_type != src.dtype: @@ -617,7 +617,7 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): return irb.get() else: - raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope)) + raise RuntimeError("Do not support copy %s->%s" % (src.scope(), dst.scope())) return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy) diff --git a/vta/tutorials/autotvm/tune_alu_vta.py b/vta/tutorials/autotvm/tune_alu_vta.py index 7b1fd411be57..3d38bbbbbd3d 100644 --- a/vta/tutorials/autotvm/tune_alu_vta.py +++ b/vta/tutorials/autotvm/tune_alu_vta.py @@ -28,7 +28,7 @@ import tvm from tvm import te from tvm import rpc, autotvm, relay -from tvm.contrib import graph_runtime, download +from tvm.contrib import download from tvm.autotvm.measure.measure_methods import request_remote from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner from tvm.autotvm import record diff --git a/vta/tutorials/frontend/deploy_classification.py b/vta/tutorials/frontend/deploy_classification.py index b2f909b9710a..139e30333f1e 100644 --- a/vta/tutorials/frontend/deploy_classification.py +++ b/vta/tutorials/frontend/deploy_classification.py @@ -52,7 +52,7 @@ import tvm from tvm import te from tvm import rpc, autotvm, relay -from tvm.contrib import graph_executor, utils, download, graph_runtime +from tvm.contrib import graph_executor, utils, download from tvm.contrib.debugger import debug_executor from tvm.relay import transform @@ -223,10 +223,10 @@ if env.TARGET == "intelfocl" or env.TARGET == "sim": ctxes = [remote.ext_dev(0), remote.cpu(0)] - m = graph_runtime.create(graph, lib, ctxes) + m = graph_executor.create(graph, lib, ctxes) else: # Graph runtime - m = graph_runtime.create(graph, lib, ctx) + m = graph_executor.create(graph, lib, ctx) ###################################################################### # Perform image classification inference diff --git a/vta/tutorials/frontend/legacy/deploy_detection.py b/vta/tutorials/frontend/deploy_detection.py similarity index 99% rename from vta/tutorials/frontend/legacy/deploy_detection.py rename to vta/tutorials/frontend/deploy_detection.py index 1d78786848e7..771801851a48 100644 --- a/vta/tutorials/frontend/legacy/deploy_detection.py +++ b/vta/tutorials/frontend/deploy_detection.py @@ -177,9 +177,9 @@ # Get execution context from remote ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0) -#################################### +##################################### # Build the inference graph executor. -# ---------------------------------- +# ----------------------------------- # Using Darknet library load downloaded vision model and compile with Relay. # The compilation steps are: # @@ -191,7 +191,6 @@ # 5. Perform relay build to object file. # 6. Load the object file onto remote (FPGA device). # 7. Generate graph executor, `m`. -# # Load pre-configured AutoTVM schedules with autotvm.tophub.context(target): @@ -212,7 +211,7 @@ # Note: We set opt_level to 3 in order to fold batch norm with tvm.transform.PassContext(opt_level=3): with relay.quantize.qconfig( - global_scale=33.0, + global_scale=23.0, skip_conv_layers=[0], store_lowbit_output=True, round_for_shift=True, diff --git a/web/Makefile b/web/Makefile index 8c4dbc20dadc..34a1b8172484 100644 --- a/web/Makefile +++ b/web/Makefile @@ -18,7 +18,7 @@ TVM_ROOT=$(shell cd ..; pwd) INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\ - -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include + -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include -I$(TVM_ROOT)/3rdparty/compiler-rt .PHONY: clean all rmtypedep preparetest diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts index f12837f421f8..226797eb7d19 100644 --- a/web/src/webgpu.ts +++ b/web/src/webgpu.ts @@ -39,7 +39,7 @@ export async function detectGPUDevice(): Promise { interface FunctionInfo { name: string; arg_types: Array; - thread_axis_tags: Array; + launch_param_tags: Array; } /** @@ -114,8 +114,8 @@ export class WebGPUContext { const dispatchToDim: Array = []; - for (let i = 0; i < finfo.thread_axis_tags.length; ++i) { - const tag: string = finfo.thread_axis_tags[i]; + for (let i = 0; i < finfo.launch_param_tags.length; ++i) { + const tag: string = finfo.launch_param_tags[i]; if (tag.startsWith("blockIdx.")) { const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); assert(target >= 0 && target < 3); diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py index e26937038fa2..c326f9fc1a16 100644 --- a/web/tests/python/webgpu_rpc_test.py +++ b/web/tests/python/webgpu_rpc_test.py @@ -70,11 +70,11 @@ def check(remote): a = tvm.nd.array(adata, dev) b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev) - np.testing.assert_equal(a.asnumpy(), adata) + np.testing.assert_equal(a.numpy(), adata) f1 = remote.system_lib() addone = f1.get_function("addone") addone(a, b) - np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + np.testing.assert_equal(b.numpy(), a.numpy() + 1) print("Test pass..") check(remote) diff --git a/web/tests/python/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py index 84a0c9f134d5..ee94e40a678c 100644 --- a/web/tests/python/websock_rpc_test.py +++ b/web/tests/python/websock_rpc_test.py @@ -82,7 +82,7 @@ def check(remote): time_f(a, b) cost = time_f(a, b).mean print("%g secs/op" % cost) - np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + np.testing.assert_equal(b.numpy(), a.numpy() + 1) check(remote)